diff --git a/tests-unit/comfy_extras_test/test_seedvr2_conditioning.py b/tests-unit/comfy_extras_test/test_seedvr2_conditioning.py new file mode 100644 index 000000000..2a6e3d430 --- /dev/null +++ b/tests-unit/comfy_extras_test/test_seedvr2_conditioning.py @@ -0,0 +1,213 @@ +"""Consolidated SeedVR2 conditioning and refactor regression tests. + +Merges the prior test_seedvr2_refactor_nodes.py and +test_seedvr_conditioning_hardening.py modules. Refactor tests use the +top-level comfy_extras.nodes_seedvr import; conditioning-hardening tests +use _import_nodes_seedvr_isolated() for sys.modules isolation when +mocking comfy.model_management. +""" + +import importlib +import sys +from unittest.mock import MagicMock + +import pytest +import torch +import torch.nn as nn + +from comfy.cli_args import args as cli_args + +if not torch.cuda.is_available(): + cli_args.cpu = True + + +_SENTINEL = object() +_TARGETS = ( + ("comfy.model_management", "comfy"), + ("comfy_extras.nodes_seedvr", "comfy_extras"), +) + + +def _import_nodes_seedvr_isolated(): + """Import comfy_extras.nodes_seedvr with comfy.model_management mocked.""" + priors = [] + for mod_name, parent_name in _TARGETS: + prior_mod = sys.modules.get(mod_name, _SENTINEL) + parent = sys.modules.get(parent_name) + attr = mod_name.split(".")[-1] + prior_attr = ( + getattr(parent, attr, _SENTINEL) if parent is not None else _SENTINEL + ) + priors.append((mod_name, parent_name, attr, prior_mod, prior_attr)) + + mock_mm = MagicMock() + for fn in ( + "xformers_enabled", "xformers_enabled_vae", + "pytorch_attention_enabled", "pytorch_attention_enabled_vae", + "sage_attention_enabled", "flash_attention_enabled", + "is_intel_xpu", + ): + getattr(mock_mm, fn).return_value = False + tv = torch.version.__version__.split(".") + mock_mm.torch_version_numeric = (int(tv[0]), int(tv[1])) + mock_mm.WINDOWS = False + sys.modules["comfy.model_management"] = mock_mm + if sys.modules.get("comfy") is None: + import comfy as _comfy_pkg # noqa: F401 + comfy_pkg = sys.modules.get("comfy") + if comfy_pkg is not None: + setattr(comfy_pkg, "model_management", mock_mm) + nodes_seedvr = sys.modules.get("comfy_extras.nodes_seedvr") or ( + importlib.import_module("comfy_extras.nodes_seedvr") + ) + + def _restore(): + for mod_name, parent_name, attr, prior_mod, prior_attr in priors: + if prior_mod is _SENTINEL: + sys.modules.pop(mod_name, None) + else: + sys.modules[mod_name] = prior_mod + parent = sys.modules.get(parent_name) + if parent is None: + continue + if prior_attr is _SENTINEL: + if hasattr(parent, attr): + delattr(parent, attr) + else: + setattr(parent, attr, prior_attr) + + return nodes_seedvr, _restore + + +class _Rope(nn.Module): + """Minimal RoPE stub exposing a `freqs` parameter.""" + def __init__(self): + super().__init__() + self.freqs = nn.Parameter(torch.zeros(4)) + + +class _Block(nn.Module): + """Minimal transformer block stub holding a `_Rope`.""" + def __init__(self): + super().__init__() + self.rope = _Rope() + + +class _DiffusionModel(nn.Module): + """Stub diffusion model with N blocks and pos/neg conditioning buffers.""" + def __init__(self, n_blocks=3, zero_conditioning=False, conditioning_dtype=torch.float32): + super().__init__() + self.blocks = nn.ModuleList([_Block() for _ in range(n_blocks)]) + pos = torch.zeros if zero_conditioning else torch.ones + self.register_buffer("positive_conditioning", pos((2, 4), dtype=conditioning_dtype)) + self.register_buffer("negative_conditioning", torch.zeros((3, 4), dtype=conditioning_dtype)) + + +class _ModelInner: + """Inner model wrapper exposing `.diffusion_model`.""" + def __init__(self, diffusion_model): + self.diffusion_model = diffusion_model + + +class _ModelPatcher: + """ModelPatcher stub exposing `.model._ModelInner`.""" + def __init__(self, diffusion_model): + self.model = _ModelInner(diffusion_model) + + +def test_seedvr2_conditioning_schema_exposes_model_passthrough_output(): + nodes_seedvr, restore = _import_nodes_seedvr_isolated() + try: + schema = nodes_seedvr.SeedVR2Conditioning.define_schema() + assert [input_item.id for input_item in schema.inputs] == [ + "model", + "vae_conditioning", + ] + assert schema.inputs[1].display_name == "latent" + assert [output.display_name for output in schema.outputs] == [ + "model", + "positive", + "negative", + "latent", + ] + finally: + restore() + + +def test_seedvr2_conditioning_returns_packed_input_latent_deterministically(): + nodes_seedvr, restore = _import_nodes_seedvr_isolated() + try: + diffusion_model = _DiffusionModel() + patcher = _ModelPatcher(diffusion_model) + samples = torch.arange(1, 25, dtype=torch.float32).reshape(1, 2, 3, 2, 2) + vae_conditioning = {"samples": samples} + + _, first_positive, first_negative, first_latent = ( + nodes_seedvr.SeedVR2Conditioning.execute( + patcher, + vae_conditioning, + ) + ) + _, second_positive, second_negative, second_latent = ( + nodes_seedvr.SeedVR2Conditioning.execute( + patcher, + vae_conditioning, + ) + ) + + expected_latent = samples.reshape(1, 6, 2, 2) + channel_last = samples.movedim(1, -1).contiguous() + expected_condition = torch.cat( + [ + channel_last, + torch.ones((*channel_last.shape[:-1], 1)), + ], + dim=-1, + ).movedim(-1, 1).reshape(1, 9, 2, 2) + + assert torch.equal(first_latent["samples"], expected_latent) + assert torch.equal(second_latent["samples"], expected_latent) + assert torch.equal( + first_positive[0][1]["condition"], + expected_condition, + ) + assert torch.equal( + second_positive[0][1]["condition"], + expected_condition, + ) + assert torch.equal( + first_negative[0][1]["condition"], + expected_condition, + ) + assert torch.equal( + second_negative[0][1]["condition"], + expected_condition, + ) + finally: + restore() + + +def test_seedvr2_conditioning_fails_loud_on_zero_buffers(): + nodes_seedvr, restore = _import_nodes_seedvr_isolated() + try: + diffusion_model = _DiffusionModel(zero_conditioning=True) + patcher = _ModelPatcher(diffusion_model) + vae_conditioning = {"samples": torch.zeros((1, 2, 1, 1, 1))} + + with pytest.raises(RuntimeError) as excinfo: + nodes_seedvr.SeedVR2Conditioning.execute( + patcher, vae_conditioning, + ) + + message = str(excinfo.value) + assert message.startswith( + nodes_seedvr._SEEDVR2_INVALID_MODEL_MSG_PREFIX + ), ( + "Fail-loud message must use the standard " + "_SEEDVR2_INVALID_MODEL_MSG_PREFIX so callers/log scrapers " + f"can match it. Got: {message!r}" + ) + assert "positive_conditioning" in message + assert "negative_conditioning" in message + finally: + restore() diff --git a/tests-unit/comfy_extras_test/test_seedvr2_nodes.py b/tests-unit/comfy_extras_test/test_seedvr2_nodes.py new file mode 100644 index 000000000..f7d9a4f65 --- /dev/null +++ b/tests-unit/comfy_extras_test/test_seedvr2_nodes.py @@ -0,0 +1,55 @@ +import importlib +import inspect +import sys +from unittest.mock import MagicMock, patch + +import torch + +from comfy.cli_args import args as cli_args + +if not torch.cuda.is_available(): + cli_args.cpu = True + + +def test_seedvr_node_signature_matches_schema(): + mock_mm = MagicMock() + mock_mm.xformers_enabled.return_value = False + mock_mm.xformers_enabled_vae.return_value = False + mock_mm.sage_attention_enabled.return_value = False + mock_mm.flash_attention_enabled.return_value = False + + sentinel = object() + prior_cpu = cli_args.cpu + cli_args.cpu = True + prior_module = sys.modules.get("comfy_extras.nodes_seedvr", sentinel) + comfy_pkg = sys.modules.get("comfy") + prior_mm_attr = getattr(comfy_pkg, "model_management", sentinel) if comfy_pkg else sentinel + + with patch.dict(sys.modules, {"comfy.model_management": mock_mm}): + if comfy_pkg is not None: + setattr(comfy_pkg, "model_management", mock_mm) + sys.modules.pop("comfy_extras.nodes_seedvr", None) + try: + nodes_seedvr = importlib.import_module("comfy_extras.nodes_seedvr") + for node_cls in (nodes_seedvr.SeedVR2Preprocess, nodes_seedvr.SeedVR2PostProcessing, nodes_seedvr.SeedVR2Conditioning, nodes_seedvr.SeedVR2ProgressiveSampler): + schema_ids = [i.id for i in node_cls.define_schema().inputs] + exec_params = [ + p for p in inspect.signature(node_cls.execute).parameters.keys() + if p != "cls" + ] + assert schema_ids == exec_params, ( + f"{node_cls.__name__} schema/execute drift: " + f"schema_ids={schema_ids}, exec_params={exec_params}" + ) + finally: + cli_args.cpu = prior_cpu + if prior_module is sentinel: + sys.modules.pop("comfy_extras.nodes_seedvr", None) + else: + sys.modules["comfy_extras.nodes_seedvr"] = prior_module + if comfy_pkg is not None: + if prior_mm_attr is sentinel: + if hasattr(comfy_pkg, "model_management"): + delattr(comfy_pkg, "model_management") + else: + setattr(comfy_pkg, "model_management", prior_mm_attr) diff --git a/tests-unit/comfy_extras_test/test_seedvr2_post_processing.py b/tests-unit/comfy_extras_test/test_seedvr2_post_processing.py new file mode 100644 index 000000000..a27a8f8df --- /dev/null +++ b/tests-unit/comfy_extras_test/test_seedvr2_post_processing.py @@ -0,0 +1,57 @@ +from unittest.mock import patch + +import torch + +from comfy.cli_args import args as cli_args + +if not torch.cuda.is_available(): + cli_args.cpu = True + +from comfy_extras import nodes_seedvr # noqa: E402 + + +def _schema_ids(items): + return [item.id for item in items] + + +def test_seedvr2_post_processing_schema(): + schema = nodes_seedvr.SeedVR2PostProcessing.define_schema() + + assert _schema_ids(schema.inputs) == ["images", "original_resized_images", "color_correction_method"] + assert schema.inputs[2].options == ["lab", "wavelet", "adain", "none"] + assert schema.inputs[2].default == "lab" + assert schema.outputs[0].get_io_type() == "IMAGE" + + +def test_seedvr2_post_processing_oom_error_uses_color_correction_method(monkeypatch): + decoded = torch.full((1, 3, 4, 4), 0.25) + reference = torch.full((1, 3, 4, 4), 0.75) + + def _lab(content, style): + raise torch.cuda.OutOfMemoryError("CUDA out of memory") + + monkeypatch.setattr(nodes_seedvr.comfy.model_management, "vae_device", lambda: torch.device("cpu")) + monkeypatch.setattr(nodes_seedvr.comfy.model_management, "get_free_memory", lambda device: 1_000_000) + monkeypatch.setattr(nodes_seedvr.comfy.model_management, "soft_empty_cache", lambda: None) + + with patch.object(nodes_seedvr, "lab_color_transfer", _lab): + try: + nodes_seedvr.SeedVR2PostProcessing._color_transfer_chunked( + decoded, reference, torch.device("cpu"), "lab", + ) + except RuntimeError as exc: + assert "color_correction_method=lab" in str(exc) + assert " method=lab" not in str(exc) + else: + raise AssertionError("expected RuntimeError for one-frame LAB OOM") + + +def test_seedvr2_post_processing_unknown_color_correction_method_raises(): + decoded = torch.zeros(1, 2, 4, 4, 3) + original = torch.zeros(1, 2, 4, 4, 3) + try: + nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, "bogus") + except ValueError as exc: + assert "color_correction_method" in str(exc) + else: + raise AssertionError("expected ValueError for unknown color_correction_method")