diff --git a/tests-unit/comfy_extras_test/test_seedvr2_node_boundaries.py b/tests-unit/comfy_extras_test/test_seedvr2_node_boundaries.py new file mode 100644 index 000000000..ea6793489 --- /dev/null +++ b/tests-unit/comfy_extras_test/test_seedvr2_node_boundaries.py @@ -0,0 +1,58 @@ +import ast +import inspect +import textwrap + +import torch + +from comfy.cli_args import args as cli_args + +if not torch.cuda.is_available(): + cli_args.cpu = True + +from comfy_extras import nodes_seedvr # noqa: E402 + + +def _schema_ids(items): + return [item.id for item in items] + + +def test_resize_schemas_are_preprocess_only(): + simple = nodes_seedvr.SeedVR2Resize.define_schema() + advanced = nodes_seedvr.SeedVR2ResizeAdvanced.define_schema() + + assert _schema_ids(simple.inputs) == ["images", "multiplier"] + assert _schema_ids(simple.outputs) == ["input_pixels", "original_image", "upscaled_shorter_edge"] + assert simple.outputs[0].get_io_type() == "IMAGE" + + assert _schema_ids(advanced.inputs) == ["images", "shorter_edge"] + assert _schema_ids(advanced.outputs) == ["input_pixels", "original_image", "upscaled_shorter_edge"] + assert advanced.outputs[0].get_io_type() == "IMAGE" + + +def test_resize_nodes_do_not_call_encode_decode_or_color_transfer(): + source = "\n".join( + [ + inspect.getsource(nodes_seedvr.SeedVR2Resize.execute), + inspect.getsource(nodes_seedvr.SeedVR2ResizeAdvanced.execute), + ] + ) + tree = ast.parse(textwrap.dedent(source)) + forbidden_names = { + "encode", + "encode_tiled", + "decode", + "decode_tiled", + "tiled_vae", + "lab_color_transfer", + } + + for node in ast.walk(tree): + if isinstance(node, ast.Call): + func = node.func + if isinstance(func, ast.Name): + name = func.id + elif isinstance(func, ast.Attribute): + name = func.attr + else: + continue + assert name not in forbidden_names diff --git a/tests-unit/comfy_extras_test/test_seedvr2_post_processing.py b/tests-unit/comfy_extras_test/test_seedvr2_post_processing.py new file mode 100644 index 
000000000..e260499ee --- /dev/null +++ b/tests-unit/comfy_extras_test/test_seedvr2_post_processing.py @@ -0,0 +1,461 @@ +import inspect +from unittest.mock import patch + +import torch + +from comfy.cli_args import args as cli_args + +if not torch.cuda.is_available(): + cli_args.cpu = True + +from comfy_extras import nodes_seedvr # noqa: E402 + + +def _schema_ids(items): + return [item.id for item in items] + + +def test_seedvr2_post_processing_schema(): + schema = nodes_seedvr.SeedVR2PostProcessing.define_schema() + + assert _schema_ids(schema.inputs) == ["decoded", "original_image", "upscaled_shorter_edge", "color_correction_method"] + assert schema.inputs[2].default is None + assert schema.inputs[2].min == 2 + assert schema.inputs[2].force_input is True + assert schema.inputs[3].options == ["lab", "wavelet", "adain", "none"] + assert schema.inputs[3].default == "lab" + assert schema.outputs[0].get_io_type() == "IMAGE" + + +def test_seedvr2_post_processing_color_correction_memory_multipliers_are_named(): + assert nodes_seedvr.LAB_SCALE_MULTIPLIER == 13 + assert nodes_seedvr.WAVELET_SCALE_MULTIPLIER == 10 + assert nodes_seedvr.ADAIN_SCALE_MULTIPLIER == 6 + + +def test_seedvr2_post_processing_lab_autochunks_from_memory_estimate(monkeypatch): + decoded = torch.full((1, 5, 2, 2, 3), 0.25) + original = torch.full((1, 5, 2, 2, 3), 0.75) + calls = [] + + def _lab(content, style): + calls.append(content.shape[0]) + return content + + monkeypatch.setattr(nodes_seedvr.comfy.model_management, "vae_device", lambda: torch.device("cpu")) + monkeypatch.setattr(nodes_seedvr.comfy.model_management, "get_free_memory", lambda device: 1700) + + with patch.object(nodes_seedvr, "lab_color_transfer", _lab): + output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 2, "lab").result[0] + + assert calls == [1, 1, 1, 1, 1] + assert tuple(output.shape) == (1, 5, 2, 2, 3) + + +def test_seedvr2_post_processing_lab_runs_each_frame_independently(monkeypatch): + decoded = 
torch.full((1, 4, 2, 2, 3), 0.25) + original = torch.full((1, 4, 2, 2, 3), 0.75) + calls = [] + + def _lab(content, style): + calls.append(content.shape[0]) + return content + + monkeypatch.setattr(nodes_seedvr.comfy.model_management, "vae_device", lambda: torch.device("cpu")) + monkeypatch.setattr(nodes_seedvr.comfy.model_management, "get_free_memory", lambda device: 1_000_000) + + with patch.object(nodes_seedvr, "lab_color_transfer", _lab): + output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 2, "lab").result[0] + + assert calls == [1, 1, 1, 1] + assert tuple(output.shape) == (1, 4, 2, 2, 3) + + +def test_seedvr2_post_processing_lab_derives_reference_from_original_and_upscaled_shorter_edge(): + decoded = torch.full((1, 3, 9, 11, 3), 0.25) + original = torch.full((1, 2, 16, 20, 3), 0.75) + calls = [] + + def _lab(content, style): + calls.append((content.clone(), style.clone())) + return torch.zeros_like(content) + + with patch.object(nodes_seedvr, "lab_color_transfer", _lab): + output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 8, "lab").result[0] + + assert tuple(output.shape) == (1, 2, 8, 10, 3) + assert torch.equal(output, torch.full_like(output, 0.5)) + assert len(calls) == 2 + assert calls[0][0].shape == (1, 3, 8, 10) + assert calls[0][1].shape == (1, 3, 8, 10) + assert torch.equal(calls[0][0], torch.full_like(calls[0][0], -0.5)) + assert torch.allclose(calls[0][1], torch.full_like(calls[0][1], 0.5)) + + +def test_seedvr2_post_processing_lab_runs_color_transfer_on_vae_device(): + source = inspect.getsource(nodes_seedvr.SeedVR2PostProcessing.execute) + chunk_source = inspect.getsource(nodes_seedvr.SeedVR2PostProcessing._run_color_transfer_chunks) + helper_source = inspect.getsource(nodes_seedvr.SeedVR2PostProcessing._lab_color_transfer_on_vae_device) + + assert "_color_transfer_chunked" in source + assert "_lab_color_transfer_on_vae_device" in chunk_source + assert "torch.cat" not in chunk_source + assert "torch.empty" 
in chunk_source + assert ".copy_(" in chunk_source + assert "reference_5d.to(device=decoded_5d.device)" not in source + assert "comfy.model_management.vae_device()" in helper_source + assert ".to(device=color_device)" in helper_source + assert ".to(device=output_device)" in helper_source + + +def test_seedvr2_post_processing_lab_chunking_is_frame_independent(monkeypatch): + decoded = torch.linspace(-0.9, 0.9, 3 * 3 * 24 * 24).reshape(3, 3, 24, 24) + reference = torch.linspace(0.8, -0.8, 3 * 3 * 24 * 24).reshape(3, 3, 24, 24) + + monkeypatch.setattr(nodes_seedvr.comfy.model_management, "vae_device", lambda: torch.device("cpu")) + + one_frame = nodes_seedvr.SeedVR2PostProcessing._run_color_transfer_chunks( + decoded.clone(), reference.clone(), torch.device("cpu"), "lab", 1, + ) + multi_frame = nodes_seedvr.SeedVR2PostProcessing._run_color_transfer_chunks( + decoded.clone(), reference.clone(), torch.device("cpu"), "lab", 3, + ) + + assert torch.equal(one_frame, multi_frame) + + +def test_seedvr2_post_processing_lab_retry_does_not_mutate_reference(monkeypatch): + decoded = torch.full((2, 3, 4, 4), 0.25) + reference = torch.full((2, 3, 4, 4), 0.75) + original_reference = reference.clone() + calls = [] + cache_clears = [] + + def _lab(content, style): + calls.append((content.clone(), style.clone())) + style.add_(10.0) + if len(calls) == 1: + raise torch.cuda.OutOfMemoryError("CUDA out of memory") + return content + + monkeypatch.setattr(nodes_seedvr.comfy.model_management, "vae_device", lambda: torch.device("cpu")) + monkeypatch.setattr(nodes_seedvr.comfy.model_management, "get_free_memory", lambda device: 1_000_000) + monkeypatch.setattr(nodes_seedvr.comfy.model_management, "soft_empty_cache", lambda: cache_clears.append(True)) + + with patch.object(nodes_seedvr, "lab_color_transfer", _lab): + nodes_seedvr.SeedVR2PostProcessing._color_transfer_chunked( + decoded, reference, torch.device("cpu"), "lab", + ) + + assert len(cache_clears) == 1 + assert torch.equal(reference, 
original_reference) + assert torch.equal(calls[1][1], original_reference[0:1]) + + +def test_seedvr2_post_processing_oom_error_uses_color_correction_method(monkeypatch): + decoded = torch.full((1, 3, 4, 4), 0.25) + reference = torch.full((1, 3, 4, 4), 0.75) + + def _lab(content, style): + raise torch.cuda.OutOfMemoryError("CUDA out of memory") + + monkeypatch.setattr(nodes_seedvr.comfy.model_management, "vae_device", lambda: torch.device("cpu")) + monkeypatch.setattr(nodes_seedvr.comfy.model_management, "get_free_memory", lambda device: 1_000_000) + monkeypatch.setattr(nodes_seedvr.comfy.model_management, "soft_empty_cache", lambda: None) + + with patch.object(nodes_seedvr, "lab_color_transfer", _lab): + try: + nodes_seedvr.SeedVR2PostProcessing._color_transfer_chunked( + decoded, reference, torch.device("cpu"), "lab", + ) + except RuntimeError as exc: + assert "color_correction_method=lab" in str(exc) + assert " method=lab" not in str(exc) + else: + raise AssertionError("expected RuntimeError for one-frame LAB OOM") + + +def test_seedvr2_post_processing_raw_conversion_does_not_probe_full_tensor_range(): + source = inspect.getsource(nodes_seedvr.SeedVR2PostProcessing._to_seedvr2_raw) + + assert ".amin" not in source + assert ".item" not in source + + +def test_seedvr2_post_processing_none_does_not_resize_reference_pixels(): + decoded = torch.full((1, 2, 10, 12, 3), 0.25) + original = torch.full((1, 2, 16, 20, 3), 0.75) + + with patch.object(nodes_seedvr, "side_resize") as resize: + output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 8, "none").result[0] + + resize.assert_not_called() + assert tuple(output.shape) == (1, 2, 8, 10, 3) + + +def test_seedvr2_post_processing_rejects_invalid_upscaled_shorter_edge(): + decoded = torch.full((1, 2, 10, 12, 3), 0.25) + original = torch.full((1, 2, 16, 20, 3), 0.75) + + for edge in (None, 1, 1.5): + try: + nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, edge, "none") + except ValueError as exc: 
+ assert "upscaled_shorter_edge" in str(exc) + else: + raise AssertionError(f"expected ValueError for upscaled_shorter_edge={edge!r}") + + +def test_seedvr2_post_processing_lab_resizes_full_reference_frame(): + decoded = torch.full((1, 2, 4, 5, 3), 0.25) + original = torch.full((1, 2, 16, 20, 3), 0.75) + resize_calls = [] + lab_calls = [] + + def _resize(images, size, interpolation=None, antialias=None): + resize_calls.append((images.clone(), size, interpolation, antialias)) + if isinstance(size, int): + return torch.full((2, 3, size, round(images.shape[-1] * size / images.shape[-2])), 0.5) + return torch.full((2, 3, size[0], size[1]), 0.5) + + def _lab(content, style): + lab_calls.append((content.clone(), style.clone())) + return torch.zeros_like(content) + + with patch.object(nodes_seedvr.TVF, "resize", _resize): + with patch.object(nodes_seedvr, "lab_color_transfer", _lab): + output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 8, "lab").result[0] + + assert tuple(output.shape) == (1, 2, 4, 4, 3) + assert torch.equal(output, torch.full_like(output, 0.5)) + assert resize_calls[0][0].shape == (2, 3, 16, 20) + assert resize_calls[0][1] == 8 + assert resize_calls[1][0].shape == (2, 3, 8, 10) + assert resize_calls[1][1] == (4, 5) + assert len(lab_calls) == 2 + assert lab_calls[0][1].shape == (1, 3, 4, 5) + assert torch.equal(lab_calls[0][1], torch.zeros_like(lab_calls[0][1])) + + +def test_seedvr2_post_processing_none_trims_and_crops_without_color_correction(): + decoded = torch.arange(1 * 3 * 9 * 11 * 3, dtype=torch.float32).reshape(1, 3, 9, 11, 3) + original = torch.zeros(1, 2, 16, 20, 3) + + with patch.object(nodes_seedvr, "lab_color_transfer") as lab: + output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 8, "none").result[0] + + assert lab.call_count == 0 + assert tuple(output.shape) == (1, 2, 8, 10, 3) + assert torch.equal(output, decoded[:, :2, :8, :10, :]) + + +def 
test_seedvr2_post_processing_restores_flattened_padded_batches_before_trimming(): + decoded = torch.arange(10 * 4 * 6 * 1, dtype=torch.float32).reshape(10, 4, 6, 1) + original = torch.zeros(2, 2, 4, 6, 1) + + output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 4, "none").result[0] + + expected = torch.cat((decoded[0:2], decoded[5:7]), dim=0) + assert tuple(output.shape) == (4, 4, 6, 1) + assert torch.equal(output, expected) + + +def test_seedvr2_post_processing_none_preserves_decoded_spatial_size_when_reference_is_larger(): + decoded = torch.arange(1 * 3 * 8 * 10 * 3, dtype=torch.float32).reshape(1, 3, 8, 10, 3) + original = torch.zeros(1, 2, 16, 20, 3) + + with patch.object(nodes_seedvr, "lab_color_transfer") as lab: + output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 16, "none").result[0] + + assert lab.call_count == 0 + assert tuple(output.shape) == (1, 2, 8, 10, 3) + assert torch.equal(output, decoded[:, :2, :, :, :]) + + +def test_seedvr2_post_processing_crops_to_reference_tensor_when_reference_is_smaller(): + decoded = torch.ones((1, 1, 720, 1280, 3), dtype=torch.float32) + original = torch.ones((1, 1, 360, 640, 3), dtype=torch.float32) + + output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 360, "none").result[0] + + assert tuple(output.shape) == (1, 1, 360, 640, 3) + + +def test_seedvr2_post_processing_uses_decoded_size_when_reference_is_larger(): + decoded = torch.ones((1, 1, 128, 160, 3), dtype=torch.float32) + original = torch.ones((1, 1, 480, 640, 3), dtype=torch.float32) + + output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 480, "none").result[0] + + assert tuple(output.shape) == (1, 1, 128, 160, 3) + + +def test_seedvr2_post_processing_derives_crop_from_upscaled_shorter_edge(): + decoded = torch.ones((1, 1, 128, 224, 3), dtype=torch.float32) + original = torch.ones((1, 1, 1080, 1920, 3), dtype=torch.float32) + + output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, 
original, 120, "none").result[0] + + assert tuple(output.shape) == (1, 1, 120, 212, 3) + + +def test_seedvr2_post_processing_uses_even_crop_from_odd_resized_width(): + decoded = torch.ones((1, 1, 128, 256, 3), dtype=torch.float32) + original = torch.ones((1, 1, 120, 169, 3), dtype=torch.float32) + + output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 120, "none").result[0] + + assert tuple(output.shape) == (1, 1, 120, 168, 3) + + +def test_seedvr2_post_processing_none_preserves_black_bottom_row_content(): + decoded = torch.ones((1, 2, 8, 10, 3), dtype=torch.float32) + original = torch.ones((1, 2, 8, 10, 3), dtype=torch.float32) + original[:, :, -1, :, :] = -1.0 + + output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 8, "none").result[0] + + assert tuple(output.shape) == (1, 2, 8, 10, 3) + assert torch.equal(output, decoded) + + +def test_seedvr2_post_processing_none_preserves_black_right_column_content(): + decoded = torch.ones((1, 2, 8, 10, 3), dtype=torch.float32) + original = torch.ones((1, 2, 8, 10, 3), dtype=torch.float32) + original[:, :, :, -1, :] = -1.0 + + output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 8, "none").result[0] + + assert tuple(output.shape) == (1, 2, 8, 10, 3) + assert torch.equal(output, decoded) + + +def test_seedvr2_post_processing_wavelet_dispatch_routes_through_wavelet_color_transfer(): + decoded = torch.full((1, 3, 9, 11, 3), 0.25) + original = torch.full((1, 2, 16, 20, 3), 0.75) + wavelet_calls = [] + lab_calls = [] + + def _wavelet(content, style): + wavelet_calls.append((content.clone(), style.clone())) + return torch.zeros_like(content) + + def _lab(content, style): + lab_calls.append((content.clone(), style.clone())) + return torch.zeros_like(content) + + with patch.object(nodes_seedvr, "wavelet_color_transfer", _wavelet): + with patch.object(nodes_seedvr, "lab_color_transfer", _lab): + output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 8, 
"wavelet").result[0] + + assert len(wavelet_calls) == 1 + assert len(lab_calls) == 0 + assert tuple(output.shape) == (1, 2, 8, 10, 3) + assert torch.equal(output, torch.full_like(output, 0.5)) + assert wavelet_calls[0][0].shape == (2, 3, 8, 10) + assert wavelet_calls[0][1].shape == (2, 3, 8, 10) + assert torch.equal(wavelet_calls[0][0], torch.full_like(wavelet_calls[0][0], -0.5)) + assert torch.allclose(wavelet_calls[0][1], torch.full_like(wavelet_calls[0][1], 0.5)) + + +def test_seedvr2_post_processing_adain_dispatch_routes_through_adain_color_transfer(): + decoded = torch.full((1, 3, 9, 11, 3), 0.25) + original = torch.full((1, 2, 16, 20, 3), 0.75) + adain_calls = [] + lab_calls = [] + + def _adain(content, style): + adain_calls.append((content.clone(), style.clone())) + return torch.zeros_like(content) + + def _lab(content, style): + lab_calls.append((content.clone(), style.clone())) + return torch.zeros_like(content) + + with patch.object(nodes_seedvr, "adain_color_transfer", _adain): + with patch.object(nodes_seedvr, "lab_color_transfer", _lab): + output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 8, "adain").result[0] + + assert len(adain_calls) == 1 + assert len(lab_calls) == 0 + assert tuple(output.shape) == (1, 2, 8, 10, 3) + assert torch.equal(output, torch.full_like(output, 0.5)) + assert adain_calls[0][0].shape == (2, 3, 8, 10) + assert adain_calls[0][1].shape == (2, 3, 8, 10) + + +def test_seedvr2_color_transfer_helper_runs_on_vae_device(): + import inspect as _inspect + helper_source = _inspect.getsource(nodes_seedvr.SeedVR2PostProcessing._color_transfer_on_vae_device) + assert "comfy.model_management.vae_device()" in helper_source + assert ".to(device=color_device)" in helper_source + assert ".to(device=output_device)" in helper_source + assert "transfer_fn" in helper_source + + +def test_seedvr2_wavelet_color_transfer_matches_primary_source_reconstruction(): + from comfy.ldm.seedvr import vae as seedvr_vae + torch.manual_seed(0) 
+ content = torch.rand(1, 3, 12, 16) * 2.0 - 1.0 + style = torch.rand(1, 3, 12, 16) * 2.0 - 1.0 + out = seedvr_vae.wavelet_color_transfer(content, style) + expected = seedvr_vae.wavelet_reconstruction(content.clone(), style.clone()) + assert torch.equal(out, expected) + + +def test_seedvr2_adain_color_transfer_matches_huang_belongie_formula(): + from comfy.ldm.seedvr import vae as seedvr_vae + torch.manual_seed(0) + content = torch.rand(2, 3, 5, 7) * 2.0 - 1.0 + style = torch.rand(2, 3, 5, 7) * 2.0 - 1.0 + out = seedvr_vae.adain_color_transfer(content.clone(), style.clone()) + + b, c = 2, 3 + cf = content.float().reshape(b, c, -1) + sf = style.float().reshape(b, c, -1) + eps = 1e-5 + mu_c = cf.mean(dim=2).reshape(b, c, 1, 1) + sd_c = (cf.var(dim=2, correction=0) + eps).sqrt().reshape(b, c, 1, 1) + mu_s = sf.mean(dim=2).reshape(b, c, 1, 1) + sd_s = (sf.var(dim=2, correction=0) + eps).sqrt().reshape(b, c, 1, 1) + expected = ((content.float() - mu_c) / sd_c) * sd_s + mu_s + expected = expected.clamp(-1.0, 1.0) + assert torch.allclose(out, expected, atol=1e-6) + + +def test_seedvr2_adain_single_pixel_uses_population_variance_without_nan(): + from comfy.ldm.seedvr import vae as seedvr_vae + content = torch.tensor([[[[0.25]], [[-0.5]], [[0.75]]]], dtype=torch.float32) + style = torch.tensor([[[[-0.25]], [[0.5]], [[-0.75]]]], dtype=torch.float32) + + out = seedvr_vae.adain_color_transfer(content, style) + + assert torch.isfinite(out).all() + assert torch.equal(out, style) + + +def test_seedvr2_adain_preserves_input_dtype(): + from comfy.ldm.seedvr import vae as seedvr_vae + content = (torch.rand(1, 3, 4, 4) * 2.0 - 1.0).to(torch.float16) + style = (torch.rand(1, 3, 4, 4) * 2.0 - 1.0).to(torch.float16) + out = seedvr_vae.adain_color_transfer(content, style) + assert out.dtype == torch.float16 + + +def test_seedvr2_adain_resizes_mismatched_style_to_content_shape(): + from comfy.ldm.seedvr import vae as seedvr_vae + content = torch.rand(1, 3, 8, 10) * 2.0 - 1.0 + style = 
torch.rand(1, 3, 16, 20) * 2.0 - 1.0 + out = seedvr_vae.adain_color_transfer(content, style) + assert tuple(out.shape) == (1, 3, 8, 10) + + +def test_seedvr2_post_processing_unknown_color_correction_method_raises(): + decoded = torch.zeros(1, 2, 4, 4, 3) + original = torch.zeros(1, 2, 4, 4, 3) + try: + nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 4, "bogus") + except ValueError as exc: + assert "color_correction_method" in str(exc) + else: + raise AssertionError("expected ValueError for unknown color_correction_method") diff --git a/tests-unit/comfy_extras_test/test_seedvr_conditioning_hardening.py b/tests-unit/comfy_extras_test/test_seedvr_conditioning_hardening.py new file mode 100644 index 000000000..063c7216b --- /dev/null +++ b/tests-unit/comfy_extras_test/test_seedvr_conditioning_hardening.py @@ -0,0 +1,601 @@ +"""Regression tests for SeedVR2 conditioning model resolution and RoPE +frequency cast. + +Pin two behaviors: + + 1. ``_resolve_seedvr2_diffusion_model`` returns the inner diffusion-model + for the expected ``model.model.diffusion_model`` shape and fails loud + with a ``RuntimeError`` whose message begins with + ``_SEEDVR2_INVALID_MODEL_MSG_PREFIX`` for any other shape, including + the four distinct missing-vs-None subcases of the chain. + 2. ``_apply_rope_freqs_float32_cast`` is idempotent **per-tensor by + dtype check**, NOT per-instance by sentinel attribute. Every call + walks the diffusion-model module tree and invokes ``.to(float32)`` + only on tensors whose dtype is not already ``float32``. A cache-by- + attribute (sentinel) approach is rejected because the sentinel + would survive ComfyUI's dynamic model unload/reload cycle while + ``rope.freqs`` itself is restored to the archived dtype, so the + next call would short-circuit and leave RoPE running in fp16/bf16 + — the exact failure this helper is supposed to prevent. The dtype + check is self-correcting against any weight-restore lifecycle + event. 
+ +Import isolation: ``comfy.model_management`` is stubbed via direct +``sys.modules`` assignment so importing ``comfy_extras.nodes_seedvr`` does +not trigger GPU/server-side initialization. ``patch.dict`` is intentionally +NOT used here because its snapshot/restore semantics evict transitively +imported third-party modules (e.g. ``torchvision``) on exit, which causes +``torch``'s global op-library Meta-key registrations to double-register on +re-import. Module-level cached import + scoped restore of the four mocked +entries avoids that hazard. See ``_import_nodes_seedvr_isolated``. +""" + +import importlib +import sys +from unittest.mock import MagicMock + +import pytest +import torch +import torch.nn as nn + + +_SENTINEL = object() + + +def _import_nodes_seedvr_isolated(): + """Stub ``comfy.model_management``, import (or reuse a cached import of) + ``comfy_extras.nodes_seedvr``, and return ``(module, restore)``. + + ``restore()`` puts back four in-process import-state surfaces, each + snapshotted on entry before the stub is installed: + + 1. ``sys.modules["comfy.model_management"]`` — the stubbed module. + 2. The ``model_management`` attribute on the ``comfy`` package, which + is rebound to the stub so attribute-style access also sees it. + 3. ``sys.modules["comfy_extras.nodes_seedvr"]`` — the imported test + target. If we leave this in ``sys.modules`` after the test, a + later test importing the real ``comfy_extras.nodes_seedvr`` will + get our stubbed-``comfy.model_management`` cached version, which + does not re-resolve against the real ``comfy.model_management``. + 4. ``comfy_extras.nodes_seedvr`` package attribute on the + ``comfy_extras`` package, mirroring the + ``comfy.model_management`` attribute restore in item 2. + + All four are restored verbatim if previously set; deleted on exit + if previously unset. No global state leaks into later tests. 
+ """ + prior_comfy_mm = sys.modules.get("comfy.model_management", _SENTINEL) + prior_comfy_mm_attr = _SENTINEL + comfy_pkg = sys.modules.get("comfy") + if comfy_pkg is not None: + prior_comfy_mm_attr = getattr(comfy_pkg, "model_management", _SENTINEL) + prior_nodes_seedvr_module = sys.modules.get( + "comfy_extras.nodes_seedvr", _SENTINEL, + ) + prior_nodes_seedvr_attr = _SENTINEL + comfy_extras_pkg = sys.modules.get("comfy_extras") + if comfy_extras_pkg is not None: + prior_nodes_seedvr_attr = getattr( + comfy_extras_pkg, "nodes_seedvr", _SENTINEL, + ) + + # ``comfy_extras.nodes_seedvr`` imports ``comfy.sample`` (added in PR + # #59) which pulls in the full samplers/k_diffusion/model_patcher + # transitive chain. That chain re-imports ``comfy.model_management`` + # and calls feature-detection predicates like ``xformers_enabled()`` + # in module-init code (``comfy/ldm/modules/attention.py:18``); a bare + # ``MagicMock()`` returns truthy for those calls and triggers a real + # ``import xformers`` that fails in the test environment. Pin the + # boolean-returning predicates to ``False`` so the import chain + # follows the no-extension path. + # Configure stub so every ``..._enabled[_*]()`` predicate returns + # False. The transitive import chain through ``comfy.sample`` → ... + # invokes several feature-detection predicates at module-init time + # (``comfy/ldm/modules/attention.py`` ``xformers_enabled()``, + # ``comfy/ldm/modules/diffusionmodules/model.py`` + # ``xformers_enabled_vae()``, etc.). A bare ``MagicMock()`` returns + # truthy auto-attrs, which triggers real ``import xformers`` calls + # that fail in the test environment. 
+ mock_mm = MagicMock() + mock_mm.xformers_enabled.return_value = False + mock_mm.xformers_enabled_vae.return_value = False + mock_mm.pytorch_attention_enabled.return_value = False + mock_mm.pytorch_attention_enabled_vae.return_value = False + mock_mm.sage_attention_enabled.return_value = False + mock_mm.flash_attention_enabled.return_value = False + torch_version_parts = torch.version.__version__.split(".") + mock_mm.torch_version_numeric = ( + int(torch_version_parts[0]), + int(torch_version_parts[1]), + ) + mock_mm.WINDOWS = False + mock_mm.is_intel_xpu.return_value = False + sys.modules["comfy.model_management"] = mock_mm + # The transitive import chain reaches code paths that do + # ``comfy.model_management.<name>`` (attribute access on the comfy + # package, not a fresh import). Setting only ``sys.modules`` is not + # enough — also bind the stub as the package attribute. If the + # ``comfy`` package isn't imported yet at stub-time (cold first run), + # importing it now is safe and idempotent. + if comfy_pkg is None: + import comfy as _comfy_pkg # noqa: F401 + comfy_pkg = sys.modules.get("comfy") + if comfy_pkg is not None: + setattr(comfy_pkg, "model_management", mock_mm) + if "comfy_extras.nodes_seedvr" in sys.modules: + nodes_seedvr = sys.modules["comfy_extras.nodes_seedvr"] + else: + nodes_seedvr = importlib.import_module("comfy_extras.nodes_seedvr") + + def _restore(): + # 1. comfy.model_management sys.modules entry + if prior_comfy_mm is _SENTINEL: + sys.modules.pop("comfy.model_management", None) + else: + sys.modules["comfy.model_management"] = prior_comfy_mm + # 2. comfy.model_management package attribute on comfy + comfy_pkg_now = sys.modules.get("comfy") + if comfy_pkg_now is not None: + if prior_comfy_mm_attr is _SENTINEL: + if hasattr(comfy_pkg_now, "model_management"): + delattr(comfy_pkg_now, "model_management") + else: + setattr(comfy_pkg_now, "model_management", prior_comfy_mm_attr) + # 3. 
comfy_extras.nodes_seedvr sys.modules entry + if prior_nodes_seedvr_module is _SENTINEL: + sys.modules.pop("comfy_extras.nodes_seedvr", None) + else: + sys.modules["comfy_extras.nodes_seedvr"] = prior_nodes_seedvr_module + # 4. comfy_extras.nodes_seedvr package attribute on comfy_extras + comfy_extras_pkg_now = sys.modules.get("comfy_extras") + if comfy_extras_pkg_now is not None: + if prior_nodes_seedvr_attr is _SENTINEL: + if hasattr(comfy_extras_pkg_now, "nodes_seedvr"): + delattr(comfy_extras_pkg_now, "nodes_seedvr") + else: + setattr( + comfy_extras_pkg_now, "nodes_seedvr", + prior_nodes_seedvr_attr, + ) + + return nodes_seedvr, _restore + + +class _Rope(nn.Module): + def __init__(self): + super().__init__() + self.freqs = nn.Parameter(torch.zeros(4)) + + +class _Block(nn.Module): + def __init__(self): + super().__init__() + self.rope = _Rope() + + +class _DiffusionModel(nn.Module): + def __init__( + self, + n_blocks=3, + zero_conditioning=False, + conditioning_dtype=torch.float32, + ): + super().__init__() + self.blocks = nn.ModuleList([_Block() for _ in range(n_blocks)]) + if zero_conditioning: + # Simulates a numz-format DiT-only file loaded via UNETLoader: + # ``register_buffer`` zero-init at ``comfy/ldm/seedvr/model.py`` + # leaves the buffers at zero when ``load_state_dict`` cannot + # find ``positive_conditioning`` / ``negative_conditioning`` + # keys in the state_dict. The fail-loud guard at + # ``SeedVR2Conditioning.execute`` distinguishes this from a + # properly-baked file by ``abs().sum() == 0`` on both buffers. 
+ self.register_buffer( + "positive_conditioning", + torch.zeros((2, 4), dtype=conditioning_dtype), + ) + self.register_buffer( + "negative_conditioning", + torch.zeros((3, 4), dtype=conditioning_dtype), + ) + else: + self.register_buffer( + "positive_conditioning", + torch.ones((2, 4), dtype=conditioning_dtype), + ) + self.register_buffer( + "negative_conditioning", + torch.zeros((3, 4), dtype=conditioning_dtype), + ) + + +class _ModelInner: + def __init__(self, diffusion_model): + self.diffusion_model = diffusion_model + + +class _ModelPatcher: + def __init__(self, diffusion_model): + self.model = _ModelInner(diffusion_model) + + +def test_resolve_seedvr2_diffusion_model_returns_inner_when_valid(): + nodes_seedvr, restore = _import_nodes_seedvr_isolated() + try: + diffusion_model = _DiffusionModel() + patcher = _ModelPatcher(diffusion_model) + resolved = nodes_seedvr._resolve_seedvr2_diffusion_model(patcher) + assert resolved is diffusion_model + finally: + restore() + + +def test_seedvr2_conditioning_schema_exposes_model_passthrough_output(): + nodes_seedvr, restore = _import_nodes_seedvr_isolated() + try: + schema = nodes_seedvr.SeedVR2Conditioning.define_schema() + assert [input_item.id for input_item in schema.inputs] == [ + "model", + "vae_conditioning", + ] + assert schema.inputs[1].display_name == "LATENT" + assert [output.display_name for output in schema.outputs] == [ + "model", + "positive", + "negative", + "latent", + ] + finally: + restore() + + +def test_seedvr2_conditioning_returns_packed_input_latent_deterministically(): + nodes_seedvr, restore = _import_nodes_seedvr_isolated() + try: + diffusion_model = _DiffusionModel() + patcher = _ModelPatcher(diffusion_model) + samples = torch.arange(1, 25, dtype=torch.float32).reshape(1, 2, 3, 2, 2) + vae_conditioning = {"samples": samples} + + _, first_positive, first_negative, first_latent = ( + nodes_seedvr.SeedVR2Conditioning.execute( + patcher, + vae_conditioning, + ) + ) + _, second_positive, 
second_negative, second_latent = ( + nodes_seedvr.SeedVR2Conditioning.execute( + patcher, + vae_conditioning, + ) + ) + + expected_latent = samples.reshape(1, 6, 2, 2) + channel_last = samples.movedim(1, -1).contiguous() + expected_condition = torch.cat( + [ + channel_last, + torch.ones((*channel_last.shape[:-1], 1)), + ], + dim=-1, + ).movedim(-1, 1).reshape(1, 9, 2, 2) + + assert torch.equal(first_latent["samples"], expected_latent) + assert torch.equal(second_latent["samples"], expected_latent) + assert torch.equal( + first_positive[0][1]["condition"], + expected_condition, + ) + assert torch.equal( + second_positive[0][1]["condition"], + expected_condition, + ) + assert torch.equal( + first_negative[0][1]["condition"], + expected_condition, + ) + assert torch.equal( + second_negative[0][1]["condition"], + expected_condition, + ) + finally: + restore() + + +def test_resolve_seedvr2_diffusion_model_raises_runtime_error_with_specific_prefix(): + """Pin all four failure modes of the resolver chain to the same error + prefix and to message text that distinguishes 'attribute missing' + from 'attribute present but None'. The four modes: + + mode 1: input has no 'model' attribute + mode 2: input.model is None + mode 3: 'model.model' has no 'diffusion_model' attribute + mode 4: 'model.model.diffusion_model' is None + """ + nodes_seedvr, restore = _import_nodes_seedvr_isolated() + try: + # Mode 1: model has no 'model' attribute at all. + class _NoModelAttr: + pass + + with pytest.raises(RuntimeError) as excinfo: + nodes_seedvr._resolve_seedvr2_diffusion_model(_NoModelAttr()) + msg = str(excinfo.value) + assert msg.startswith(nodes_seedvr._SEEDVR2_INVALID_MODEL_MSG_PREFIX) + assert "no 'model' attribute" in msg + + # Mode 2: model.model exists but is None (must not be conflated + # with "no 'model' attribute"). 
+ class _ModelIsNone: + def __init__(self): + self.model = None + + with pytest.raises(RuntimeError) as excinfo: + nodes_seedvr._resolve_seedvr2_diffusion_model(_ModelIsNone()) + msg = str(excinfo.value) + assert msg.startswith(nodes_seedvr._SEEDVR2_INVALID_MODEL_MSG_PREFIX) + assert "input.model is None" in msg + + # Mode 3: model.model exists, has no 'diffusion_model' attribute. + class _NoDiffusionAttr: + def __init__(self): + self.model = object() + + with pytest.raises(RuntimeError) as excinfo: + nodes_seedvr._resolve_seedvr2_diffusion_model(_NoDiffusionAttr()) + msg = str(excinfo.value) + assert msg.startswith(nodes_seedvr._SEEDVR2_INVALID_MODEL_MSG_PREFIX) + assert "no 'diffusion_model' attribute" in msg + + # Mode 4: model.model.diffusion_model exists but is None (must not + # be conflated with "no 'diffusion_model' attribute"). + class _DiffusionIsNoneInner: + def __init__(self): + self.diffusion_model = None + + class _DiffusionIsNone: + def __init__(self): + self.model = _DiffusionIsNoneInner() + + with pytest.raises(RuntimeError) as excinfo: + nodes_seedvr._resolve_seedvr2_diffusion_model(_DiffusionIsNone()) + msg = str(excinfo.value) + assert msg.startswith(nodes_seedvr._SEEDVR2_INVALID_MODEL_MSG_PREFIX) + assert "'model.model.diffusion_model' is None" in msg + finally: + restore() + + +def test_apply_rope_freqs_float32_cast_idempotent_on_unchanged_dtype(): + """Calling the helper twice on a model whose rope.freqs is already + float32 must NOT mutate the tensor identity or contents — the dtype + check on every nested module short-circuits the .to() call when the + tensor is already in float32. + """ + nodes_seedvr, restore = _import_nodes_seedvr_isolated() + try: + diffusion_model = _DiffusionModel() + + # Starting dtype is non-float32 so the first call has work to do. 
+ for module in diffusion_model.modules(): + if hasattr(module, "rope") and hasattr(module.rope, "freqs"): + module.rope.freqs.data = module.rope.freqs.data.to(torch.float64) + + nodes_seedvr._apply_rope_freqs_float32_cast(diffusion_model) + first_call_data_ids = [] + for module in diffusion_model.modules(): + if hasattr(module, "rope") and hasattr(module.rope, "freqs"): + assert module.rope.freqs.data.dtype == torch.float32 + first_call_data_ids.append(id(module.rope.freqs.data)) + + # Second call on the same already-float32 model: every per-tensor + # dtype check sees float32 and skips the .to() call. Tensor data + # identity must be preserved (no re-allocation). + nodes_seedvr._apply_rope_freqs_float32_cast(diffusion_model) + for module, prior_id in zip( + (m for m in diffusion_model.modules() + if hasattr(m, "rope") and hasattr(m.rope, "freqs")), + first_call_data_ids, + strict=True, + ): + assert module.rope.freqs.data.dtype == torch.float32 + assert id(module.rope.freqs.data) == prior_id, ( + "Already-float32 rope.freqs must not be re-allocated on " + "subsequent calls; the per-tensor dtype check must skip the " + ".to(float32) call when the tensor is already in float32." + ) + finally: + restore() + + +def test_apply_rope_freqs_float32_cast_recovers_after_dtype_reset(): + """After a model unload/reload that restores rope.freqs from an + archived non-float32 dtype, the next call must re-cast to float32. + A bool-sentinel cache approach would short-circuit here and leave + RoPE running in fp16/bf16. + """ + nodes_seedvr, restore = _import_nodes_seedvr_isolated() + try: + diffusion_model = _DiffusionModel() + for module in diffusion_model.modules(): + if hasattr(module, "rope") and hasattr(module.rope, "freqs"): + module.rope.freqs.data = module.rope.freqs.data.to(torch.float64) + + # First call casts to float32. 
+ nodes_seedvr._apply_rope_freqs_float32_cast(diffusion_model) + for module in diffusion_model.modules(): + if hasattr(module, "rope") and hasattr(module.rope, "freqs"): + assert module.rope.freqs.data.dtype == torch.float32 + + # Simulate a Comfy dynamic unload/reload that restores rope.freqs + # to the archived (non-float32) dtype. + for module in diffusion_model.modules(): + if hasattr(module, "rope") and hasattr(module.rope, "freqs"): + module.rope.freqs.data = module.rope.freqs.data.to(torch.float64) + + # Second call must detect the dtype regression and re-cast. + nodes_seedvr._apply_rope_freqs_float32_cast(diffusion_model) + for module in diffusion_model.modules(): + if hasattr(module, "rope") and hasattr(module.rope, "freqs"): + assert module.rope.freqs.data.dtype == torch.float32, ( + "After a model unload/reload that resets rope.freqs to " + "non-float32, the next _apply_rope_freqs_float32_cast " + "call MUST re-cast to float32. A bool-sentinel cache " + "would have short-circuited here." + ) + finally: + restore() + + +# --------------------------------------------------------------------------- +# Fail-loud guard: zero-valued conditioning buffers +# --------------------------------------------------------------------------- + + +def test_seedvr2_conditioning_fails_loud_on_zero_buffers(): + """A SeedVR2 model whose ``positive_conditioning`` AND + ``negative_conditioning`` buffers are both zero-valued is an + unrecoverable load state — a numz-format DiT-only ``.safetensors`` + file was loaded via ``UNETLoader`` without the SeedVR2 conditioning + keys baked in. ``SeedVR2Conditioning.execute`` must raise + ``RuntimeError`` carrying the standard SeedVR2 invalid-model prefix + instead of letting the diffusion sampler run on null prompt + conditioning (which silently produces wrong output). 
+ """ + nodes_seedvr, restore = _import_nodes_seedvr_isolated() + try: + diffusion_model = _DiffusionModel(zero_conditioning=True) + patcher = _ModelPatcher(diffusion_model) + vae_conditioning = {"samples": torch.zeros((1, 2, 1, 1, 1))} + + with pytest.raises(RuntimeError) as excinfo: + nodes_seedvr.SeedVR2Conditioning.execute( + patcher, vae_conditioning, + ) + + message = str(excinfo.value) + assert message.startswith( + nodes_seedvr._SEEDVR2_INVALID_MODEL_MSG_PREFIX + ), ( + "Fail-loud message must use the standard " + "_SEEDVR2_INVALID_MODEL_MSG_PREFIX so callers/log scrapers " + f"can match it. Got: {message!r}" + ) + assert "positive_conditioning" in message + assert "negative_conditioning" in message + finally: + restore() + + +def test_seedvr2_conditioning_fails_loud_on_fp8_zero_buffers(): + """The zero-buffer sentinel must reduce fp8 conditioning tensors + without hitting PyTorch's unsupported float8 reductions. + """ + fp8_dtype = getattr(torch, "float8_e4m3fn", None) + if fp8_dtype is None: + pytest.skip("torch build does not expose float8_e4m3fn") + + nodes_seedvr, restore = _import_nodes_seedvr_isolated() + try: + diffusion_model = _DiffusionModel( + zero_conditioning=True, + conditioning_dtype=fp8_dtype, + ) + patcher = _ModelPatcher(diffusion_model) + vae_conditioning = {"samples": torch.zeros((1, 2, 1, 1, 1))} + + with pytest.raises(RuntimeError) as excinfo: + nodes_seedvr.SeedVR2Conditioning.execute( + patcher, vae_conditioning, + ) + + message = str(excinfo.value) + assert message.startswith( + nodes_seedvr._SEEDVR2_INVALID_MODEL_MSG_PREFIX + ) + assert "zero-valued" in message + finally: + restore() + + +def test_seedvr2_conditioning_does_not_fire_on_partial_zero_buffers(): + """The guard checks BOTH buffers together: a model with zero + ``negative_conditioning`` but non-zero ``positive_conditioning`` + (the existing baseline mock fixture) must NOT trigger the fail-loud + path. 
This pins the AND-gating semantic and prevents a future + regression to OR-gating from rejecting valid bundled checkpoints + where one buffer happens to be all-zeros. + """ + nodes_seedvr, restore = _import_nodes_seedvr_isolated() + try: + # Baseline _DiffusionModel has positive=ones, negative=zeros. + diffusion_model = _DiffusionModel(zero_conditioning=False) + patcher = _ModelPatcher(diffusion_model) + vae_conditioning = {"samples": torch.zeros((1, 2, 1, 1, 1))} + + # Should not raise. + passthrough_model, positive, negative, latent = ( + nodes_seedvr.SeedVR2Conditioning.execute( + patcher, vae_conditioning, + ) + ) + assert positive[0][0].shape == (1, 2, 4) + assert negative[0][0].shape == (1, 3, 4) + assert passthrough_model is patcher + finally: + restore() + + +def test_seedvr2_conditioning_fail_loud_never_exposes_safetensors_path(): + """The fail-loud message must not expose local model paths from + ``cached_patcher_init``. Public runtime errors should describe the + invalid SeedVR2 contract without making filesystem paths part of the + public behavior contract. + """ + nodes_seedvr, restore = _import_nodes_seedvr_isolated() + try: + diffusion_model = _DiffusionModel(zero_conditioning=True) + patcher = _ModelPatcher(diffusion_model) + # Mimic the ``cached_patcher_init`` shape comfy.sd attaches. 
+ patcher.cached_patcher_init = ( + object(), # function reference + ("/some/models/diffusion_models/seedvr2_ema_7b_fp16.safetensors",), + ) + vae_conditioning = {"samples": torch.zeros((1, 2, 1, 1, 1))} + + with pytest.raises(RuntimeError) as excinfo: + nodes_seedvr.SeedVR2Conditioning.execute( + patcher, vae_conditioning, + ) + + message = str(excinfo.value) + assert "/some/models/diffusion_models" not in message + assert "seedvr2_ema_7b_fp16.safetensors" not in message + assert "Source file:" not in message + assert "positive_conditioning" in message + assert "negative_conditioning" in message + finally: + restore() + + +def test_seedvr2_conditioning_fail_loud_falls_back_when_path_unavailable(): + """When ``cached_patcher_init`` is missing or its tuple does not + contain a ``.safetensors`` path, the fail-loud message still + delivers the actionable diagnostic without leaking ``None`` or + raising during message formatting. + """ + nodes_seedvr, restore = _import_nodes_seedvr_isolated() + try: + diffusion_model = _DiffusionModel(zero_conditioning=True) + patcher = _ModelPatcher(diffusion_model) + # No cached_patcher_init set on the patcher. + vae_conditioning = {"samples": torch.zeros((1, 2, 1, 1, 1))} + + with pytest.raises(RuntimeError) as excinfo: + nodes_seedvr.SeedVR2Conditioning.execute( + patcher, vae_conditioning, + ) + message = str(excinfo.value) + assert "Source file:" not in message # no empty path leak + assert "Re-bake" in message # actionable guidance still present + assert "bf16 keys" not in message + finally: + restore() diff --git a/tests-unit/comfy_extras_test/test_seedvr_node_signature.py b/tests-unit/comfy_extras_test/test_seedvr_node_signature.py new file mode 100644 index 000000000..c16993f4e --- /dev/null +++ b/tests-unit/comfy_extras_test/test_seedvr_node_signature.py @@ -0,0 +1,103 @@ +"""Regression test: SeedVR2 resize schema input ids must match +execute() positional parameter order. 
Drift between the two would silently +swap arguments at runtime; this test fails loudly on any future drift. + +The schema input attribute is `.id` (verified live via Python introspection +on the upstream class -- there is no `.name`). + +`comfy.model_management` is stubbed via `patch.dict(sys.modules, ...)` for +the import performed inside this test, so importing +`comfy_extras.nodes_seedvr` here does not call +`torch.cuda.is_available()` or trigger other GPU/server-side +initialization through that dependency. Live introspection indicated that +`comfy_extras.nodes_seedvr` pulls in `comfy.model_management` +transitively here (not `nodes`, not `server`). + +The test snapshots three pieces of import state before patching and +restores all three in `finally` via a sentinel: + +1. `sys.modules["comfy_extras.nodes_seedvr"]` +2. `comfy.model_management` package attribute on the `comfy` package +3. `comfy_extras.nodes_seedvr` attribute on the `comfy_extras` package + +If any of the three was set before the test, it is restored verbatim; +if it was unset, it is deleted on exit. 
This prevents the test from +clobbering a real `comfy.model_management` (or +`comfy_extras.nodes_seedvr`) module that another test may have +legitimately imported earlier in the same pytest process, while still +preventing the test's mock from leaking into later tests that import +the real `comfy_extras.nodes_seedvr`.""" + +import importlib +import inspect +import sys +from unittest.mock import MagicMock, patch + +from comfy.cli_args import args as cli_args + + +def test_seedvr_node_signature_matches_schema(): + mock_model_management = MagicMock() + mock_model_management.xformers_enabled.return_value = False + mock_model_management.xformers_enabled_vae.return_value = False + mock_model_management.sage_attention_enabled.return_value = False + mock_model_management.flash_attention_enabled.return_value = False + sentinel = object() + prior_cpu = cli_args.cpu + cli_args.cpu = True + + comfy_module_pre = sys.modules.get("comfy") + comfy_extras_module_pre = sys.modules.get("comfy_extras") + prior_comfy_mm_attr = ( + getattr(comfy_module_pre, "model_management", sentinel) + if comfy_module_pre is not None + else sentinel + ) + prior_comfy_extras_seedvr_attr = ( + getattr(comfy_extras_module_pre, "nodes_seedvr", sentinel) + if comfy_extras_module_pre is not None + else sentinel + ) + prior_comfy_extras_seedvr_module = sys.modules.get("comfy_extras.nodes_seedvr", sentinel) + + with patch.dict(sys.modules, {"comfy.model_management": mock_model_management}): + if comfy_module_pre is not None: + setattr(comfy_module_pre, "model_management", mock_model_management) + sys.modules.pop("comfy_extras.nodes_seedvr", None) + try: + nodes_seedvr = importlib.import_module("comfy_extras.nodes_seedvr") + for node_cls in ( + nodes_seedvr.SeedVR2Resize, + nodes_seedvr.SeedVR2ResizeAdvanced, + ): + schema_ids = [i.id for i in node_cls.define_schema().inputs] + exec_params = [ + p + for p in inspect.signature(node_cls.execute).parameters.keys() + if p != "cls" + ] + assert schema_ids == 
exec_params, ( + f"{node_cls.__name__} schema input ids do not match " + f"execute() parameter order: schema_ids={schema_ids}, " + f"exec_params={exec_params}" + ) + finally: + if prior_comfy_extras_seedvr_module is sentinel: + sys.modules.pop("comfy_extras.nodes_seedvr", None) + else: + sys.modules["comfy_extras.nodes_seedvr"] = prior_comfy_extras_seedvr_module + cli_args.cpu = prior_cpu + comfy_extras_module = sys.modules.get("comfy_extras") + if comfy_extras_module is not None: + if prior_comfy_extras_seedvr_attr is sentinel: + if hasattr(comfy_extras_module, "nodes_seedvr"): + delattr(comfy_extras_module, "nodes_seedvr") + else: + setattr(comfy_extras_module, "nodes_seedvr", prior_comfy_extras_seedvr_attr) + comfy_module = sys.modules.get("comfy") + if comfy_module is not None: + if prior_comfy_mm_attr is sentinel: + if hasattr(comfy_module, "model_management"): + delattr(comfy_module, "model_management") + else: + setattr(comfy_module, "model_management", prior_comfy_mm_attr) diff --git a/tests-unit/comfy_test/test_seedvr2_hidden_state_static_audit.py b/tests-unit/comfy_test/test_seedvr2_hidden_state_static_audit.py new file mode 100644 index 000000000..a85eda627 --- /dev/null +++ b/tests-unit/comfy_test/test_seedvr2_hidden_state_static_audit.py @@ -0,0 +1,40 @@ +import ast +from pathlib import Path + +import pytest + + +ROOT = Path(__file__).resolve().parents[2] +FILES = [ + ROOT / "comfy/ldm/seedvr/vae.py", + ROOT / "comfy/sd.py", + ROOT / "comfy_extras/nodes_seedvr.py", +] +FORBIDDEN_ATTRS = {"original_image_video", "img_dims", "tiled_args"} +FORBIDDEN_KEYS = { + "sampler_metadata", + "latent_sidecar_metadata", + "saved_latent_metadata", + "workflow_hidden_state", +} +FORBIDDEN_GETSET_KEYS = {"original_image_video", "img_dims", "tiled_args"} + + +def test_seedvr2_decode_paths_do_not_use_hidden_vae_object_state(): + for path in FILES: + tree = ast.parse(path.read_text(encoding="utf-8")) + for node in ast.walk(tree): + if isinstance(node, ast.Attribute) and 
node.attr in FORBIDDEN_ATTRS: + pytest.fail(f"{path}: forbidden VAE object state attr {node.attr}") + if isinstance(node, ast.Call) and isinstance(node.func, ast.Name): + if node.func.id in {"getattr", "setattr", "delattr"} and len(node.args) >= 2: + key = node.args[1] + if isinstance(key, ast.Constant) and key.value in FORBIDDEN_GETSET_KEYS: + pytest.fail(f"{path}: forbidden VAE object state access {key.value}") + if isinstance(node, ast.Constant) and isinstance(node.value, str): + if node.value in FORBIDDEN_ATTRS or node.value in FORBIDDEN_KEYS: + pytest.fail(f"{path}: forbidden hidden-state string {node.value}") + + +if __name__ == "__main__": + raise SystemExit(pytest.main([__file__])) diff --git a/tests-unit/comfy_test/test_seedvr2_non_goal_static_audit.py b/tests-unit/comfy_test/test_seedvr2_non_goal_static_audit.py new file mode 100644 index 000000000..01892be77 --- /dev/null +++ b/tests-unit/comfy_test/test_seedvr2_non_goal_static_audit.py @@ -0,0 +1,43 @@ +import os +import subprocess +from pathlib import Path + +import pytest + + +ROOT = Path(__file__).resolve().parents[2] +FORBIDDEN_FILES = { + "comfy/ldm/seedvr/model.py", + "comfy/ldm/modules/attention.py", + "comfy/sample.py", + "comfy/samplers.py", +} + +pytestmark = pytest.mark.skipif( + os.environ.get("SEEDVR2_NON_GOAL_STATIC_AUDIT") != "1", + reason="SEEDVR2_NON_GOAL_STATIC_AUDIT=1 is required for git-index audit execution.", +) + + +def _git_changed_paths(*args): + result = subprocess.run( + ["git", "-C", str(ROOT), "diff", "--name-only", *args], + text=True, + capture_output=True, + check=False, + ) + if result.returncode != 0: + pytest.skip(f"git diff unavailable: {result.stderr.strip()}") + return set(result.stdout.splitlines()) + + +def test_seedvr2_non_goal_files_are_not_dirty(): + changed = _git_changed_paths() + changed.update(_git_changed_paths("--cached")) + changed_forbidden = sorted(FORBIDDEN_FILES.intersection(changed)) + if changed_forbidden: + pytest.fail(f"forbidden non-goal files 
changed: {changed_forbidden}") + + +if __name__ == "__main__": + raise SystemExit(pytest.main([__file__])) diff --git a/tests-unit/comfy_test/test_seedvr2_resize_and_pad_pre_encode_state.py b/tests-unit/comfy_test/test_seedvr2_resize_and_pad_pre_encode_state.py new file mode 100644 index 000000000..21a16b227 --- /dev/null +++ b/tests-unit/comfy_test/test_seedvr2_resize_and_pad_pre_encode_state.py @@ -0,0 +1,110 @@ +import torch + +from comfy.cli_args import args as cli_args + +if not torch.cuda.is_available(): + cli_args.cpu = True + +import comfy_extras.nodes_seedvr as nodes_seedvr # noqa: E402 + + +def test_resize_simple_multiplier_resolves_upscaled_shorter_edge(): + images = torch.zeros(1, 3, 16, 20, 3) + + output = nodes_seedvr.SeedVR2Resize.execute(images, 4.0) + + input_pixels, original_image, upscaled_shorter_edge = output.result + assert tuple(input_pixels.shape) == (1, 5, 64, 80, 3) + assert input_pixels.min().item() == 0.0 + assert input_pixels.max().item() == 0.0 + assert original_image is images + assert upscaled_shorter_edge == 64 + + +def test_resize_simple_silent_spatial_padding_keeps_unpadded_edge_output(): + images = torch.zeros(1, 1, 16, 16, 3) + + output = nodes_seedvr.SeedVR2Resize.execute(images, 7.5) + + input_pixels, original_image, upscaled_shorter_edge = output.result + assert tuple(input_pixels.shape) == (1, 1, 128, 128, 3) + assert original_image is images + assert upscaled_shorter_edge == 120 + + +def test_resize_simple_rejects_non_positive_multiplier(): + images = torch.zeros(1, 1, 16, 16, 3) + + try: + nodes_seedvr.SeedVR2Resize.execute(images, 0.0) + except ValueError as e: + assert "multiplier must be > 0" in str(e) + else: + raise AssertionError("non-positive multiplier was not rejected") + + +def test_resize_simple_rejects_multiplier_resolving_to_too_small_edge(): + images = torch.zeros(1, 1, 16, 16, 3) + + try: + nodes_seedvr.SeedVR2Resize.execute(images, 0.01) + except ValueError as e: + assert "multiplier resolved 
upscaled_shorter_edge" in str(e) + assert "at least 2 pixels" in str(e) + else: + raise AssertionError("too-small resolved edge was not rejected") + + +def test_resize_advanced_takes_exact_shorter_edge(): + images = torch.zeros(1, 1, 16, 16, 3) + + output = nodes_seedvr.SeedVR2ResizeAdvanced.execute(images, 120) + + input_pixels, original_image, upscaled_shorter_edge = output.result + assert tuple(input_pixels.shape) == (1, 1, 128, 128, 3) + assert original_image is images + assert upscaled_shorter_edge == 120 + + +def test_resize_advanced_treats_4d_image_as_one_video_frame_sequence(): + images = torch.zeros(2, 16, 16, 3) + + output = nodes_seedvr.SeedVR2ResizeAdvanced.execute(images, 120) + + input_pixels, original_image, upscaled_shorter_edge = output.result + assert tuple(input_pixels.shape) == (1, 5, 128, 128, 3) + assert original_image is images + assert upscaled_shorter_edge == 120 + + +def test_resize_advanced_rejects_one_pixel_shorter_edge(): + images = torch.zeros(1, 1, 16, 16, 3) + + try: + nodes_seedvr.SeedVR2ResizeAdvanced.execute(images, 1) + except ValueError as e: + assert "upscaled_shorter_edge must be at least 2 pixels" in str(e) + else: + raise AssertionError("one-pixel shorter_edge was not rejected") + + +def test_resize_node_schemas_and_execute_signatures_are_preprocess_only(): + simple = nodes_seedvr.SeedVR2Resize.define_schema() + advanced = nodes_seedvr.SeedVR2ResizeAdvanced.define_schema() + + assert [item.id for item in simple.inputs] == ["images", "multiplier"] + assert simple.inputs[1].default == 4.0 + assert [item.id for item in simple.outputs] == [ + "input_pixels", + "original_image", + "upscaled_shorter_edge", + ] + + assert [item.id for item in advanced.inputs] == ["images", "shorter_edge"] + assert advanced.inputs[1].min == 2 + assert advanced.inputs[1].step is None + assert [item.id for item in advanced.outputs] == [ + "input_pixels", + "original_image", + "upscaled_shorter_edge", + ] diff --git 
a/tests-unit/comfy_test/test_seedvr2_saved_latent_decode_boundary.py b/tests-unit/comfy_test/test_seedvr2_saved_latent_decode_boundary.py new file mode 100644 index 000000000..24eec8301 --- /dev/null +++ b/tests-unit/comfy_test/test_seedvr2_saved_latent_decode_boundary.py @@ -0,0 +1,38 @@ +import io + +import torch + +from comfy.cli_args import args as cli_args + +if not torch.cuda.is_available(): + cli_args.cpu = True + +from comfy_extras import nodes_seedvr # noqa: E402 +import nodes as nodes_mod # noqa: E402 + + +class _DecodeOnlyVAE: + def __init__(self): + self.decode_calls = 0 + + def decode(self, latent): + self.decode_calls += 1 + b, tc, h, w = latent.shape + t = tc // 16 + return torch.full((b, t, h * 8, w * 8, 3), 0.25) + + +def test_saved_loaded_seedvr2_latent_decode_boundary_does_not_rerun_preprocessing(): + latent = {"samples": torch.zeros(1, 32, 4, 5)} + buffer = io.BytesIO() + torch.save(latent["samples"], buffer) + buffer.seek(0) + loaded = {"samples": torch.load(buffer, weights_only=True)} + + vae = _DecodeOnlyVAE() + decoded = nodes_mod.VAEDecode().decode(vae, loaded)[0] + original = torch.full((1, 2, 32, 40, 3), 0.75) + output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 32, "none").result[0] + + assert vae.decode_calls == 1 + assert tuple(output.shape) == (2, 32, 40, 3) diff --git a/tests-unit/comfy_test/test_seedvr2_vae_graph_boundaries.py b/tests-unit/comfy_test/test_seedvr2_vae_graph_boundaries.py new file mode 100644 index 000000000..a6e48801a --- /dev/null +++ b/tests-unit/comfy_test/test_seedvr2_vae_graph_boundaries.py @@ -0,0 +1,210 @@ +from unittest.mock import MagicMock + +import torch +import torch.nn as nn + +from comfy.cli_args import args as cli_args + +if not torch.cuda.is_available(): + cli_args.cpu = True + +import comfy.ldm.seedvr.vae as seedvr_vae_mod # noqa: E402 +import comfy.sd as sd_mod # noqa: E402 +import nodes as nodes_mod # noqa: E402 + + +class _Patcher: + def get_free_memory(self, device): + return 
1024 * 1024 * 1024 + + +class _EncodeWrapper(seedvr_vae_mod.VideoAutoencoderKLWrapper): + def __init__(self, encoded): + nn.Module.__init__(self) + self.encoded = encoded + self.spatial_downsample_factor = 8 + self.temporal_downsample_factor = 4 + self.seen = [] + + def encode(self, x): + self.seen.append(tuple(x.shape)) + return self.encoded.to(device=x.device, dtype=x.dtype) + + +class _DecodeWrapper(seedvr_vae_mod.VideoAutoencoderKLWrapper): + def __init__(self): + nn.Module.__init__(self) + self.spatial_downsample_factor = 8 + self.temporal_downsample_factor = 4 + self.calls = [] + + def decode(self, z, seedvr2_tiling=None): + self.calls.append({"shape": tuple(z.shape), "seedvr2_tiling": seedvr2_tiling}) + if z.ndim == 4: + b, tc, h, w = z.shape + t = tc // 16 + else: + b, _, t, h, w = z.shape + return torch.zeros(b, 3, t, h * 8, w * 8, dtype=z.dtype, device=z.device) + + +def _make_vae(wrapper): + vae = sd_mod.VAE.__new__(sd_mod.VAE) + vae.first_stage_model = wrapper + vae.device = torch.device("cpu") + vae.output_device = torch.device("cpu") + vae.vae_dtype = torch.float32 + vae.latent_channels = 16 + vae.latent_dim = 3 + vae.downscale_ratio = (lambda a: max(0, (a + 3) // 4), 8, 8) + vae.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8) + vae.output_channels = 3 + vae.disable_offload = True + vae.extra_1d_channel = None + vae.crop_input = False + vae.not_video = False + vae.patcher = _Patcher() + vae.process_input = lambda image: image + vae.process_output = lambda image: image.add(1.0).div(2.0).clamp(0.0, 1.0) + vae.vae_output_dtype = lambda: torch.float32 + vae.memory_used_encode = lambda shape, dtype: 1 + vae.memory_used_decode = lambda shape, dtype: 1 + vae.throw_exception_if_invalid = lambda: None + vae.vae_encode_crop_pixels = lambda pixels: pixels + vae.spacial_compression_decode = lambda: 8 + vae.temporal_compression_decode = lambda: 4 + return vae + + +def test_seedvr2_encode_and_encode_tiled_preserve_native_latent_contract(monkeypatch): + 
monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None) + + encoded = torch.full((1, 16, 2, 4, 5), 2.0) + vae = _make_vae(_EncodeWrapper(encoded)) + pixels = torch.zeros(1, 5, 32, 40, 3) + + node_output = nodes_mod.VAEEncode().encode(vae, pixels)[0] + node_latent = node_output["samples"] + assert set(node_output) == {"samples"} + assert tuple(node_latent.shape) == (1, 16, 2, 4, 5) + assert node_latent.dtype == torch.float32 + assert node_latent.stride()[-1] == 1 + assert torch.equal(node_latent, torch.full_like(node_latent, 2.0 * 0.9152)) + + tiled = torch.full((1, 16, 2, 4, 5), 3.0) + monkeypatch.setattr(seedvr_vae_mod, "tiled_vae", MagicMock(return_value=tiled)) + tiled_output = nodes_mod.VAEEncodeTiled().encode( + vae, + pixels, + tile_size=512, + overlap=64, + temporal_size=16, + temporal_overlap=4, + )[0] + tiled_latent = tiled_output["samples"] + assert set(tiled_output) == {"samples"} + assert tuple(tiled_latent.shape) == (1, 16, 2, 4, 5) + assert tiled_latent.dtype == torch.float32 + assert torch.equal(tiled_latent, torch.full_like(tiled_latent, 3.0 * 0.9152)) + + +def test_seedvr2_decode_and_decode_tiled_do_not_require_preprocessor_state(monkeypatch): + monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None) + vae = _make_vae(_DecodeWrapper()) + + latent = {"samples": torch.zeros(1, 32, 4, 5)} + decoded = nodes_mod.VAEDecode().decode(vae, latent)[0] + assert tuple(decoded.shape) == (2, 32, 40, 3) + + tiled = nodes_mod.VAEDecodeTiled().decode( + vae, + {"samples": torch.zeros(1, 16, 2, 4, 5)}, + tile_size=512, + overlap=64, + temporal_size=16, + temporal_overlap=4, + )[0] + assert tuple(tiled.shape) == (2, 32, 40, 3) + + +def test_seedvr2_vaedecode_does_not_repair_latent_layout(monkeypatch): + monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None) + vae = _make_vae(_DecodeWrapper()) + + latent = {"samples": torch.zeros(1, 2, 4, 5, 16)} + 
nodes_mod.VAEDecode().decode(vae, latent) + + assert vae.first_stage_model.calls == [{"shape": (1, 2, 4, 5, 16), "seedvr2_tiling": None}] + + +def test_seedvr2_vaedecode_keeps_public_channel_first_width_16_latents(monkeypatch): + monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None) + vae = _make_vae(_DecodeWrapper()) + + nodes_mod.VAEDecode().decode( + vae, + {"samples": torch.zeros(1, 16, 4, 5, 16)}, + ) + + assert vae.first_stage_model.calls == [{"shape": (1, 16, 4, 5, 16), "seedvr2_tiling": None}] + + +def test_seedvr2_direct_decode_preserves_channel_first_width_16(monkeypatch): + monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None) + vae = _make_vae(_DecodeWrapper()) + + vae.decode(torch.zeros(1, 16, 2, 4, 16)) + + assert vae.first_stage_model.calls == [{"shape": (1, 16, 2, 4, 16), "seedvr2_tiling": None}] + + +def test_seedvr2_decode_tiled_preserves_direct_channel_first_width_16(monkeypatch): + monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None) + vae = _make_vae(_DecodeWrapper()) + + vae.decode_tiled_seedvr2(torch.zeros(1, 16, 2, 4, 16)) + + assert vae.first_stage_model.calls[0]["shape"] == (1, 16, 2, 4, 16) + + +def test_seedvr2_vaedecode_tiled_keeps_public_channel_first_width_16_latents(monkeypatch): + monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None) + vae = _make_vae(_DecodeWrapper()) + + nodes_mod.VAEDecodeTiled().decode( + vae, + {"samples": torch.zeros(1, 16, 4, 5, 16)}, + tile_size=512, + overlap=64, + temporal_size=16, + temporal_overlap=4, + ) + + assert vae.first_stage_model.calls[0]["shape"] == (1, 16, 4, 5, 16) + + +def test_vaedecode_tiled_visible_inputs_are_seedvr2_decode_tiling_authority(monkeypatch): + monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None) + vae = _make_vae(_DecodeWrapper()) + + nodes_mod.VAEDecodeTiled().decode( + vae, + {"samples": torch.zeros(1, 16, 2, 4, 
5)}, + tile_size=512, + overlap=64, + temporal_size=16, + temporal_overlap=4, + ) + + assert vae.first_stage_model.calls == [ + { + "shape": (1, 16, 2, 4, 5), + "seedvr2_tiling": { + "enable_tiling": True, + "tile_size": (512, 512), + "tile_overlap": (64, 64), + "temporal_size": 16, + "temporal_overlap": 4, + }, + } + ] diff --git a/tests-unit/comfy_test/test_seedvr2_windows_static_verify.py b/tests-unit/comfy_test/test_seedvr2_windows_static_verify.py new file mode 100644 index 000000000..1053980f2 --- /dev/null +++ b/tests-unit/comfy_test/test_seedvr2_windows_static_verify.py @@ -0,0 +1,40 @@ +from pathlib import Path + +import pytest + + +ROOT = Path(__file__).resolve().parents[2] + + +def _read(relative): + return (ROOT / relative).read_text(encoding="utf-8") + + +def test_seedvr2_windows_static_contract_tokens(): + nodes = _read("comfy_extras/nodes_seedvr.py") + sd = _read("comfy/sd.py") + vae = _read("comfy/ldm/seedvr/vae.py") + + required = [ + "SeedVR2Resize", + "SeedVR2ResizeAdvanced", + "SeedVR2PostProcessing", + 'io.Image.Input("decoded")', + 'io.Image.Input("original_image")', + 'io.Int.Input("upscaled_shorter_edge", min=2, force_input=True)', + 'io.Combo.Input("color_correction_method", options=["lab", "wavelet", "adain", "none"], default="lab")', + "def _format_seedvr2_encoded_samples", + "def decode(self, z, seedvr2_tiling=None)", + ] + for needle in required: + if needle not in nodes + sd + vae: + pytest.fail(f"missing required static token: {needle}") + + forbidden = ["original_image_video", "img_dims", "tiled_args"] + for needle in forbidden: + if needle in nodes + sd + vae: + pytest.fail(f"forbidden hidden-state token remains: {needle}") + + +if __name__ == "__main__": + raise SystemExit(pytest.main([__file__])) diff --git a/tests-unit/comfy_test/test_seedvr_progressive_sampler.py b/tests-unit/comfy_test/test_seedvr_progressive_sampler.py new file mode 100644 index 000000000..5d7e44c7d --- /dev/null +++ 
b/tests-unit/comfy_test/test_seedvr_progressive_sampler.py @@ -0,0 +1,1070 @@ +"""Unit tests for ``comfy_extras.nodes_seedvr.SeedVR2ProgressiveSampler``. + +Covers: + +- Single-chunk degeneracy (``frames_per_chunk >= T_pixel``) takes the + short-circuit path and calls ``comfy.sample.sample`` exactly once with + the full unsliced latent. +- Multi-chunk path slices ``samples_4d`` along the latent T axis, + invokes the inner sampler once per chunk, and concatenates results + back into the same total ``(B, 16*T_total, H, W)`` shape with no NaN + or Inf values. +- ``frames_per_chunk`` that violates the 4n+1 pixel-frame constraint + is rejected with a typed ``ValueError`` before any model invocation. +- Determinism: given a fixed seed, slicing into N chunks runs each + chunk against the same global noise tensor (sliced per chunk), so + the same seed always produces the same final latent regardless of + chunk count, modulo the inherent T-axis chunk-boundary independence + of the model. +- Latent-space Hann overlap blend: ``temporal_overlap=0`` produces + output byte-identical to the no-overlap path; small-overlap path + uses a linear ramp; Hann blend reconstructs source under a + passthrough inner sampler. + +The tests mock ``comfy.sample.sample``, ``comfy.sample.prepare_noise``, +and ``comfy.sample.fix_empty_latent_channels`` so the slicing / +concatenation / cond-handling logic can be exercised in isolation +without GPU, model weights, or ComfyUI's full sampling stack. 
+""" + +from unittest.mock import patch + +import pytest +import torch + +from comfy.cli_args import args as cli_args + +if not torch.cuda.is_available(): + cli_args.cpu = True + +import comfy.sample # noqa: E402 +import comfy_extras.nodes_seedvr as nodes_seedvr_mod # noqa: E402 +from comfy_extras.nodes_seedvr import ( # noqa: E402 + SeedVR2ProgressiveSampler, + _blend_overlap_region, + _concat_chunks_along_t, + _concat_chunks_with_overlap_blend, + _hann_blend_weights_1d, + _slice_collapsed_4d_along_t, + _slice_seedvr2_cond_along_t, +) + +_LAT_C = 16 +_COND_C = 17 + + +def _make_inputs(B: int = 1, T: int = 5, H: int = 8, W: int = 8): + """Build minimal SeedVR2-shaped sampling inputs. + + The latent and condition tensors carry deterministic, reversible + values (an arange laid out in a 5D ``(B, C, T, H, W)`` view that is + then collapsed) so per-chunk slices can be cross-checked against + the original 5D source without ambiguity. + """ + samples_5d = torch.arange( + B * _LAT_C * T * H * W, dtype=torch.float32 + ).reshape(B, _LAT_C, T, H, W) + samples = samples_5d.reshape(B, _LAT_C * T, H, W).contiguous() + + cond_5d = torch.arange( + B * _COND_C * T * H * W, dtype=torch.float32 + ).reshape(B, _COND_C, T, H, W) + 10000.0 + cond = cond_5d.reshape(B, _COND_C * T, H, W).contiguous() + + text_pos = torch.zeros(1, 4, 32) + text_neg = torch.zeros(1, 4, 32) + positive = [[text_pos, {"condition": cond.clone()}]] + negative = [[text_neg, {"condition": cond.clone()}]] + latent_image = {"samples": samples} + return latent_image, positive, negative, samples_5d, cond_5d + + +def _identity_fix_empty(model, latent_image, downscale_ratio_spacial=None): + return latent_image + + +def _fingerprinted_prepare_noise(latent_image, seed, batch_inds=None): + """Return a tensor whose values encode ``(seed, position)`` so the + chunked slicing path can be verified end-to-end against a global + noise tensor. 
+ """ + base = torch.arange( + latent_image.numel(), dtype=torch.float32 + ).reshape(latent_image.shape) + return base + float(seed) * 1e6 + + +def _passthrough_sample_returning_latent( + model, noise, steps, cfg, sampler_name, scheduler, + positive, negative, latent_image, denoise=1.0, + noise_mask=None, seed=None, +): + """Mock for ``comfy.sample.sample``: returns the per-call + ``latent_image`` unchanged so we can verify the post-concat result + equals the original input under per-chunk slice + concat. + """ + return latent_image.clone() + + +# --------------------------------------------------------------------------- +# Helper-level tests (slicing / concat / cond plumbing) +# --------------------------------------------------------------------------- + + +def test_progressive_sampler_schema_exposes_manual_default_auto_chunking(): + schema = SeedVR2ProgressiveSampler.define_schema() + inputs = {item.id: item for item in schema.inputs} + + assert inputs["chunking_mode"].options == ["manual", "auto"] + assert inputs["chunking_mode"].default == "manual" + + +def test_slice_collapsed_4d_along_t_shape_correct(): + t = torch.zeros(1, _LAT_C * 5, 8, 8) + out = _slice_collapsed_4d_along_t(t, 1, 4, _LAT_C) + assert tuple(out.shape) == (1, _LAT_C * 3, 8, 8) + + +def test_slice_collapsed_preserves_per_frame_values(): + """Slicing ``[t_start:t_end]`` must preserve the ``(t_start + i)``-th + latent frame's channel layout at the i'th position of the slice. 
+ """ + B, T, H, W = 1, 6, 4, 4 + t5 = torch.arange( + B * _LAT_C * T * H * W, dtype=torch.float32 + ).reshape(B, _LAT_C, T, H, W) + t4 = t5.reshape(B, _LAT_C * T, H, W).contiguous() + out_4d = _slice_collapsed_4d_along_t(t4, 2, 5, _LAT_C) + out_5d = out_4d.reshape(B, _LAT_C, 3, H, W) + for i, src_t in enumerate([2, 3, 4]): + assert torch.equal(out_5d[:, :, i], t5[:, :, src_t]) + + +def test_slice_collapsed_4d_along_t_accepts_non_contiguous_input(): + """Collapsed latents may arrive from slicing/cropping views; temporal + slicing must not require contiguous input storage. + """ + B, T, H, W = 1, 5, 4, 4 + wide = torch.arange( + B * _LAT_C * T * H * W * 2, dtype=torch.float32, + ).reshape(B, _LAT_C * T, H, W * 2) + src = wide[:, :, :, ::2] + assert not src.is_contiguous() + + out = _slice_collapsed_4d_along_t(src, 1, 4, _LAT_C) + expected = src.reshape(B, _LAT_C, T, H, W)[:, :, 1:4].contiguous() + expected = expected.reshape(B, _LAT_C * 3, H, W) + + assert torch.equal(out, expected) + + +def test_concat_chunks_along_t_roundtrip_recovers_source(): + """Slicing a tensor and concatenating the slices must reproduce the + source byte-identically (within tensor equality). + """ + B, T, H, W = 1, 7, 4, 4 + t = torch.arange( + B * _LAT_C * T * H * W, dtype=torch.float32 + ).reshape(B, _LAT_C, T, H, W).reshape(B, _LAT_C * T, H, W).contiguous() + a = _slice_collapsed_4d_along_t(t, 0, 3, _LAT_C) + b = _slice_collapsed_4d_along_t(t, 3, 5, _LAT_C) + c = _slice_collapsed_4d_along_t(t, 5, 7, _LAT_C) + cat = _concat_chunks_along_t([a, b, c], _LAT_C) + assert torch.equal(cat, t) + + +def test_concat_chunks_along_t_accepts_non_contiguous_chunks(): + """Concatenation must accept non-contiguous chunk tensors returned by + sampling or upstream tensor views. 
+ """ + B, H, W = 1, 4, 4 + wide_a = torch.arange( + B * _LAT_C * 2 * H * W * 2, dtype=torch.float32, + ).reshape(B, _LAT_C * 2, H, W * 2) + wide_b = torch.arange( + B * _LAT_C * 3 * H * W * 2, dtype=torch.float32, + ).reshape(B, _LAT_C * 3, H, W * 2) + 10000.0 + chunk_a = wide_a[:, :, :, ::2] + chunk_b = wide_b[:, :, :, ::2] + assert not chunk_a.is_contiguous() + assert not chunk_b.is_contiguous() + + out = _concat_chunks_along_t([chunk_a, chunk_b], _LAT_C) + expected = torch.cat( + [ + chunk_a.reshape(B, _LAT_C, 2, H, W), + chunk_b.reshape(B, _LAT_C, 3, H, W), + ], + dim=2, + ).reshape(B, _LAT_C * 5, H, W) + + assert tuple(out.shape) == (B, _LAT_C * 5, H, W) + assert torch.equal(out, expected) + + +def test_slice_seedvr2_cond_along_t_passes_other_keys_unchanged(): + """The cond-list slicer must mutate only ``options['condition']``; + every other key must pass through unchanged, and the source + options dict must not be mutated. + """ + B, T, H, W = 1, 5, 8, 8 + cond = torch.zeros(B, _COND_C * T, H, W) + text = torch.zeros(1, 4, 32) + sentinel = object() + src_options = {"condition": cond, "extra_key": sentinel} + cond_list = [[text, src_options]] + out = _slice_seedvr2_cond_along_t(cond_list, 1, 4) + assert out[0][1]["extra_key"] is sentinel + assert out[0][1]["condition"].shape == (B, _COND_C * 3, H, W) + # Source options dict not mutated. + assert src_options["condition"].shape == (B, _COND_C * T, H, W) + + +def test_slice_seedvr2_cond_passes_through_entries_without_condition_key(): + """Entries lacking a ``condition`` key are forwarded verbatim — the + sampler must not crash on conditioning produced by non-SeedVR2 + upstream nodes. 
+ """ + text = torch.zeros(1, 4, 32) + cond_list = [[text, {"unrelated": 1}]] + out = _slice_seedvr2_cond_along_t(cond_list, 0, 1) + assert out[0] is cond_list[0] + assert out[0][1] == {"unrelated": 1} + + +# --------------------------------------------------------------------------- +# Single-chunk degeneracy +# --------------------------------------------------------------------------- + + +def test_t1_single_chunk_degeneracy_calls_sampler_once_with_full_latent(): + """When ``frames_per_chunk >= T_pixel``, the short-circuit + standard path runs and calls ``comfy.sample.sample`` exactly once + with the full unsliced ``(B, 16*T_total, H, W)`` latent. + """ + latent, pos, neg, _, _ = _make_inputs(T=5) # T_pixel = 4*4+1 = 17 + full_shape = tuple(latent["samples"].shape) + calls = [] + + def _record(model, noise, steps, cfg, sampler_name, scheduler, + positive, negative, latent_image, denoise=1.0, + noise_mask=None, seed=None): + calls.append(tuple(latent_image.shape)) + return latent_image.clone() + + with patch.object(comfy.sample, "sample", side_effect=_record), \ + patch.object(comfy.sample, "fix_empty_latent_channels", + side_effect=_identity_fix_empty), \ + patch.object(comfy.sample, "prepare_noise", + side_effect=_fingerprinted_prepare_noise): + out = SeedVR2ProgressiveSampler.execute( + model=None, seed=0, steps=2, cfg=1.0, + sampler_name="euler", scheduler="simple", + positive=pos, negative=neg, latent_image=latent, + denoise=1.0, frames_per_chunk=21, temporal_overlap=0, + ) + + assert len(calls) == 1 + assert calls[0] == full_shape + out_latent = out.result[0] + assert tuple(out_latent["samples"].shape) == full_shape + + +# --------------------------------------------------------------------------- +# Multi-chunk path +# --------------------------------------------------------------------------- + + +def test_t2_two_chunk_path_shape_preserved_and_no_nan_inf(): + """A T_pixel that exceeds frames_per_chunk + triggers chunking; the inner sampler is invoked once 
per chunk; + the concatenated output preserves the original + ``(B, 16*T_total, H, W)`` shape and contains no NaN/Inf values. + """ + # T_latent=11 -> T_pixel=4*10+1=41; chunk_pixel=21 -> chunk_latent=6. + # Expected chunks: [0:6], [6:11] (two chunks; second is a runt of 5). + latent, pos, neg, _, _ = _make_inputs(T=11) + full_shape = tuple(latent["samples"].shape) + chunk_shapes = [] + + def _record(model, noise, steps, cfg, sampler_name, scheduler, + positive, negative, latent_image, denoise=1.0, + noise_mask=None, seed=None): + chunk_shapes.append(tuple(latent_image.shape)) + return latent_image.clone() + + with patch.object(comfy.sample, "sample", side_effect=_record), \ + patch.object(comfy.sample, "fix_empty_latent_channels", + side_effect=_identity_fix_empty), \ + patch.object(comfy.sample, "prepare_noise", + side_effect=_fingerprinted_prepare_noise): + out = SeedVR2ProgressiveSampler.execute( + model=None, seed=0, steps=2, cfg=1.0, + sampler_name="euler", scheduler="simple", + positive=pos, negative=neg, latent_image=latent, + denoise=1.0, frames_per_chunk=21, temporal_overlap=0, + ) + + # Two chunks: latent T = 6 then 5. + assert len(chunk_shapes) == 2 + assert chunk_shapes[0] == (1, _LAT_C * 6, 8, 8) + assert chunk_shapes[1] == (1, _LAT_C * 5, 8, 8) + + # Final shape preserved. + out_latent = out.result[0] + assert tuple(out_latent["samples"].shape) == full_shape + + # Boundedness. + samples_out = out_latent["samples"] + assert not torch.isnan(samples_out).any() + assert not torch.isinf(samples_out).any() + + +def test_t2_concat_equals_source_under_passthrough_sampler(): + """When the inner sampler is a passthrough (returns its + ``latent_image`` argument verbatim), the multi-chunk run must + reconstruct the original input latent byte-identically — that is, + the slice / sample / concat composition is the identity on the + latent. 
+ """ + latent, pos, neg, _, _ = _make_inputs(T=11) + src = latent["samples"].clone() + + with patch.object(comfy.sample, "sample", + side_effect=_passthrough_sample_returning_latent), \ + patch.object(comfy.sample, "fix_empty_latent_channels", + side_effect=_identity_fix_empty), \ + patch.object(comfy.sample, "prepare_noise", + side_effect=_fingerprinted_prepare_noise): + out = SeedVR2ProgressiveSampler.execute( + model=None, seed=0, steps=2, cfg=1.0, + sampler_name="euler", scheduler="simple", + positive=pos, negative=neg, latent_image=latent, + denoise=1.0, frames_per_chunk=21, temporal_overlap=0, + ) + + out_latent = out.result[0] + assert torch.equal(out_latent["samples"], src) + + +def test_t2_per_chunk_cond_slice_matches_chunk_latent_t(): + """Each per-chunk ``comfy.sample.sample`` invocation must receive + a positive / negative cond list whose ``condition`` tensor has been + sliced to match the chunk's latent length. + """ + latent, pos, neg, _, _ = _make_inputs(T=11) + cond_shapes = [] + + def _record_conds(model, noise, steps, cfg, sampler_name, scheduler, + positive, negative, latent_image, denoise=1.0, + noise_mask=None, seed=None): + pos_cond_t = positive[0][1]["condition"] + neg_cond_t = negative[0][1]["condition"] + cond_shapes.append((tuple(pos_cond_t.shape), tuple(neg_cond_t.shape))) + return latent_image.clone() + + with patch.object(comfy.sample, "sample", side_effect=_record_conds), \ + patch.object(comfy.sample, "fix_empty_latent_channels", + side_effect=_identity_fix_empty), \ + patch.object(comfy.sample, "prepare_noise", + side_effect=_fingerprinted_prepare_noise): + SeedVR2ProgressiveSampler.execute( + model=None, seed=0, steps=2, cfg=1.0, + sampler_name="euler", scheduler="simple", + positive=pos, negative=neg, latent_image=latent, + denoise=1.0, frames_per_chunk=21, temporal_overlap=0, + ) + + assert cond_shapes[0] == ((1, _COND_C * 6, 8, 8), (1, _COND_C * 6, 8, 8)) + assert cond_shapes[1] == ((1, _COND_C * 5, 8, 8), (1, _COND_C * 5, 8, 
8)) + + +def test_t2_standard_noise_mask_passed_through_for_sampler_expansion(): + """Standard ``SetLatentNoiseMask`` masks are ``(B, 1, H, W)`` and + must be forwarded unchanged so KSampler can expand them to each + chunk's latent shape. + """ + latent, pos, neg, _, _ = _make_inputs(T=11) + latent["noise_mask"] = torch.ones(1, 1, 8, 8) + mask_shapes = [] + + def _record_mask(model, noise, steps, cfg, sampler_name, scheduler, + positive, negative, latent_image, denoise=1.0, + noise_mask=None, seed=None): + mask_shapes.append(tuple(noise_mask.shape)) + return latent_image.clone() + + with patch.object(comfy.sample, "sample", side_effect=_record_mask), \ + patch.object(comfy.sample, "fix_empty_latent_channels", + side_effect=_identity_fix_empty), \ + patch.object(comfy.sample, "prepare_noise", + side_effect=_fingerprinted_prepare_noise): + SeedVR2ProgressiveSampler.execute( + model=None, seed=0, steps=2, cfg=1.0, + sampler_name="euler", scheduler="simple", + positive=pos, negative=neg, latent_image=latent, + denoise=1.0, frames_per_chunk=21, temporal_overlap=0, + ) + + assert mask_shapes == [(1, 1, 8, 8), (1, 1, 8, 8)] + + +def test_t2_collapsed_noise_mask_sliced_per_chunk(): + """A pre-expanded collapsed ``(B, 16*T, H, W)`` noise mask must be + sliced along latent T to match each chunk before sampling. 
+ """ + latent, pos, neg, _, _ = _make_inputs(T=11) + latent["noise_mask"] = torch.ones_like(latent["samples"]) + mask_shapes = [] + + def _record_mask(model, noise, steps, cfg, sampler_name, scheduler, + positive, negative, latent_image, denoise=1.0, + noise_mask=None, seed=None): + mask_shapes.append(tuple(noise_mask.shape)) + return latent_image.clone() + + with patch.object(comfy.sample, "sample", side_effect=_record_mask), \ + patch.object(comfy.sample, "fix_empty_latent_channels", + side_effect=_identity_fix_empty), \ + patch.object(comfy.sample, "prepare_noise", + side_effect=_fingerprinted_prepare_noise): + SeedVR2ProgressiveSampler.execute( + model=None, seed=0, steps=2, cfg=1.0, + sampler_name="euler", scheduler="simple", + positive=pos, negative=neg, latent_image=latent, + denoise=1.0, frames_per_chunk=21, temporal_overlap=0, + ) + + assert mask_shapes == [(1, _LAT_C * 6, 8, 8), (1, _LAT_C * 5, 8, 8)] + + +# --------------------------------------------------------------------------- +# Auto chunking OOM fallback +# --------------------------------------------------------------------------- + + +def test_auto_chunking_success_without_retry(): + """Auto mode must leave a successful current chunk geometry alone.""" + latent, pos, neg, _, _ = _make_inputs(T=11) + calls = [] + + def _record(model, noise, steps, cfg, sampler_name, scheduler, + positive, negative, latent_image, denoise=1.0, + noise_mask=None, seed=None): + calls.append(tuple(latent_image.shape)) + return latent_image.clone() + + with patch.object(comfy.sample, "sample", side_effect=_record), \ + patch.object(comfy.sample, "fix_empty_latent_channels", + side_effect=_identity_fix_empty), \ + patch.object(comfy.sample, "prepare_noise", + side_effect=_fingerprinted_prepare_noise), \ + patch.object(nodes_seedvr_mod.comfy.model_management, + "soft_empty_cache") as soft_empty: + out = SeedVR2ProgressiveSampler.execute( + model=None, seed=0, steps=2, cfg=1.0, + sampler_name="euler", scheduler="simple", 
+ positive=pos, negative=neg, latent_image=latent, + denoise=1.0, frames_per_chunk=21, temporal_overlap=0, + chunking_mode="auto", + ) + + assert calls == [(1, _LAT_C * 6, 8, 8), (1, _LAT_C * 5, 8, 8)] + assert torch.equal(out.result[0]["samples"], latent["samples"]) + soft_empty.assert_not_called() + + +def test_auto_chunking_retries_current_oom_with_next_stricter_chunk(): + """An OOM in the current geometry must retry with a smaller chunk.""" + latent, pos, neg, _, _ = _make_inputs(T=11) + calls = [] + + def _oom_on_full(model, noise, steps, cfg, sampler_name, scheduler, + positive, negative, latent_image, denoise=1.0, + noise_mask=None, seed=None): + calls.append(tuple(latent_image.shape)) + if latent_image.shape[1] == _LAT_C * 11: + raise torch.cuda.OutOfMemoryError("full oom") + return latent_image.clone() + + with patch.object(comfy.sample, "sample", side_effect=_oom_on_full), \ + patch.object(comfy.sample, "fix_empty_latent_channels", + side_effect=_identity_fix_empty), \ + patch.object(comfy.sample, "prepare_noise", + side_effect=_fingerprinted_prepare_noise), \ + patch.object(nodes_seedvr_mod.comfy.model_management, + "soft_empty_cache") as soft_empty: + out = SeedVR2ProgressiveSampler.execute( + model=None, seed=0, steps=2, cfg=1.0, + sampler_name="euler", scheduler="simple", + positive=pos, negative=neg, latent_image=latent, + denoise=1.0, frames_per_chunk=45, temporal_overlap=0, + chunking_mode="auto", + ) + + assert calls == [ + (1, _LAT_C * 11, 8, 8), + (1, _LAT_C * 6, 8, 8), + (1, _LAT_C * 5, 8, 8), + ] + assert torch.equal(out.result[0]["samples"], latent["samples"]) + assert soft_empty.call_count == 1 + + +def test_auto_chunking_walks_two_three_four_chunk_ladder(): + """Auto mode must walk 2-, 3-, then 4-chunk geometries on OOM.""" + latent, pos, neg, _, _ = _make_inputs(T=17) + calls = [] + + def _oom_until_four_chunks(model, noise, steps, cfg, sampler_name, + scheduler, positive, negative, + latent_image, denoise=1.0, + noise_mask=None, 
seed=None): + calls.append(tuple(latent_image.shape)) + if latent_image.shape[1] > _LAT_C * 5: + raise torch.cuda.OutOfMemoryError("chunk too large") + return latent_image.clone() + + with patch.object(comfy.sample, "sample", + side_effect=_oom_until_four_chunks), \ + patch.object(comfy.sample, "fix_empty_latent_channels", + side_effect=_identity_fix_empty), \ + patch.object(comfy.sample, "prepare_noise", + side_effect=_fingerprinted_prepare_noise), \ + patch.object(nodes_seedvr_mod.comfy.model_management, + "soft_empty_cache") as soft_empty: + out = SeedVR2ProgressiveSampler.execute( + model=None, seed=0, steps=2, cfg=1.0, + sampler_name="euler", scheduler="simple", + positive=pos, negative=neg, latent_image=latent, + denoise=1.0, frames_per_chunk=65, temporal_overlap=0, + chunking_mode="auto", + ) + + assert calls[:4] == [ + (1, _LAT_C * 17, 8, 8), + (1, _LAT_C * 9, 8, 8), + (1, _LAT_C * 6, 8, 8), + (1, _LAT_C * 5, 8, 8), + ] + assert torch.equal(out.result[0]["samples"], latent["samples"]) + assert soft_empty.call_count == 3 + + +def test_auto_chunking_exhausted_floor_rethrows_loudly(): + """If one-latent-frame chunks still OOM, auto mode must fail loud.""" + latent, pos, neg, _, _ = _make_inputs(T=3) + + def _always_oom(*args, **kwargs): + raise torch.cuda.OutOfMemoryError("stable oom") + + with patch.object(comfy.sample, "sample", side_effect=_always_oom), \ + patch.object(comfy.sample, "fix_empty_latent_channels", + side_effect=_identity_fix_empty), \ + patch.object(comfy.sample, "prepare_noise", + side_effect=_fingerprinted_prepare_noise), \ + patch.object(nodes_seedvr_mod.comfy.model_management, + "soft_empty_cache") as soft_empty: + with pytest.raises(RuntimeError) as excinfo: + SeedVR2ProgressiveSampler.execute( + model=None, seed=0, steps=2, cfg=1.0, + sampler_name="euler", scheduler="simple", + positive=pos, negative=neg, latent_image=latent, + denoise=1.0, frames_per_chunk=9, temporal_overlap=0, + chunking_mode="auto", + ) + + assert "exhausted auto 
chunking attempts" in str(excinfo.value) + assert "[9, 5, 1]" in str(excinfo.value) + assert soft_empty.call_count == 2 + + +def test_auto_chunking_non_oom_does_not_retry(): + """Only real OOM failures are eligible for auto chunk retry.""" + latent, pos, neg, _, _ = _make_inputs(T=11) + + def _raise_non_oom(*args, **kwargs): + raise ValueError("not oom") + + with patch.object(comfy.sample, "sample", side_effect=_raise_non_oom), \ + patch.object(comfy.sample, "fix_empty_latent_channels", + side_effect=_identity_fix_empty), \ + patch.object(comfy.sample, "prepare_noise", + side_effect=_fingerprinted_prepare_noise), \ + patch.object(nodes_seedvr_mod.comfy.model_management, + "soft_empty_cache") as soft_empty: + with pytest.raises(ValueError, match="not oom"): + SeedVR2ProgressiveSampler.execute( + model=None, seed=0, steps=2, cfg=1.0, + sampler_name="euler", scheduler="simple", + positive=pos, negative=neg, latent_image=latent, + denoise=1.0, frames_per_chunk=45, temporal_overlap=0, + chunking_mode="auto", + ) + + soft_empty.assert_not_called() + + +def test_auto_chunking_matches_manual_at_resolved_chunk_size(): + """After resolving to a chunk size, auto output must match manual.""" + latent_auto, pos_auto, neg_auto, _, _ = _make_inputs(T=11) + latent_manual, pos_manual, neg_manual, _, _ = _make_inputs(T=11) + + def _oom_full_only(model, noise, steps, cfg, sampler_name, scheduler, + positive, negative, latent_image, denoise=1.0, + noise_mask=None, seed=None): + if latent_image.shape[1] == _LAT_C * 11: + raise torch.cuda.OutOfMemoryError("full oom") + return latent_image.clone() + + with patch.object(comfy.sample, "sample", side_effect=_oom_full_only), \ + patch.object(comfy.sample, "fix_empty_latent_channels", + side_effect=_identity_fix_empty), \ + patch.object(comfy.sample, "prepare_noise", + side_effect=_fingerprinted_prepare_noise), \ + patch.object(nodes_seedvr_mod.comfy.model_management, + "soft_empty_cache"): + out_auto = SeedVR2ProgressiveSampler.execute( + 
model=None, seed=123, steps=2, cfg=1.0, + sampler_name="euler", scheduler="simple", + positive=pos_auto, negative=neg_auto, latent_image=latent_auto, + denoise=1.0, frames_per_chunk=45, temporal_overlap=0, + chunking_mode="auto", + ) + + with patch.object(comfy.sample, "sample", + side_effect=_passthrough_sample_returning_latent), \ + patch.object(comfy.sample, "fix_empty_latent_channels", + side_effect=_identity_fix_empty), \ + patch.object(comfy.sample, "prepare_noise", + side_effect=_fingerprinted_prepare_noise): + out_manual = SeedVR2ProgressiveSampler.execute( + model=None, seed=123, steps=2, cfg=1.0, + sampler_name="euler", scheduler="simple", + positive=pos_manual, negative=neg_manual, + latent_image=latent_manual, denoise=1.0, + frames_per_chunk=21, temporal_overlap=0, + ) + + assert torch.equal(out_auto.result[0]["samples"], + out_manual.result[0]["samples"]) + + +# --------------------------------------------------------------------------- +# 4n+1 violation rejection +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("bad_chunk", [0, -1, 2, 3, 4, 6, 7, 8, 10, 12]) +def test_t3_invalid_frames_per_chunk_raises_value_error(bad_chunk): + """``frames_per_chunk`` violating 4n+1 (for n >= 0) must raise + ``ValueError`` with a message naming the offending value, before any + model invocation. ``frames_per_chunk < 1`` is also rejected. 
+ """ + latent, pos, neg, _, _ = _make_inputs(T=5) + + sampler_called = {"n": 0} + + def _should_not_be_called(*args, **kwargs): + sampler_called["n"] += 1 + return torch.zeros(1) + + with patch.object(comfy.sample, "sample", + side_effect=_should_not_be_called), \ + patch.object(comfy.sample, "fix_empty_latent_channels", + side_effect=_identity_fix_empty), \ + patch.object(comfy.sample, "prepare_noise", + side_effect=_fingerprinted_prepare_noise): + with pytest.raises(ValueError) as excinfo: + SeedVR2ProgressiveSampler.execute( + model=None, seed=0, steps=2, cfg=1.0, + sampler_name="euler", scheduler="simple", + positive=pos, negative=neg, latent_image=latent, + denoise=1.0, frames_per_chunk=bad_chunk, temporal_overlap=0, + ) + assert str(bad_chunk) in str(excinfo.value) + assert sampler_called["n"] == 0 + + +@pytest.mark.parametrize("good_chunk", [1, 5, 9, 13, 17, 21, 25]) +def test_t3_valid_frames_per_chunk_does_not_raise(good_chunk): + """The 4n+1 sequence (1, 5, 9, 13, ...) must be accepted.""" + latent, pos, neg, _, _ = _make_inputs(T=5) + + with patch.object(comfy.sample, "sample", + side_effect=_passthrough_sample_returning_latent), \ + patch.object(comfy.sample, "fix_empty_latent_channels", + side_effect=_identity_fix_empty), \ + patch.object(comfy.sample, "prepare_noise", + side_effect=_fingerprinted_prepare_noise): + SeedVR2ProgressiveSampler.execute( + model=None, seed=0, steps=2, cfg=1.0, + sampler_name="euler", scheduler="simple", + positive=pos, negative=neg, latent_image=latent, + denoise=1.0, frames_per_chunk=good_chunk, temporal_overlap=0, + ) + + +# --------------------------------------------------------------------------- +# Determinism +# --------------------------------------------------------------------------- + + +def test_t4_determinism_same_seed_same_output(): + """Two runs with identical (seed, inputs, + frames_per_chunk) must produce byte-identical output, given the + inner sampler is deterministic (here: passthrough). 
+ """ + latent_a, pos_a, neg_a, _, _ = _make_inputs(T=11) + latent_b, pos_b, neg_b, _, _ = _make_inputs(T=11) + + with patch.object(comfy.sample, "sample", + side_effect=_passthrough_sample_returning_latent), \ + patch.object(comfy.sample, "fix_empty_latent_channels", + side_effect=_identity_fix_empty), \ + patch.object(comfy.sample, "prepare_noise", + side_effect=_fingerprinted_prepare_noise): + out_a = SeedVR2ProgressiveSampler.execute( + model=None, seed=42, steps=2, cfg=1.0, + sampler_name="euler", scheduler="simple", + positive=pos_a, negative=neg_a, latent_image=latent_a, + denoise=1.0, frames_per_chunk=21, temporal_overlap=0, + ) + out_b = SeedVR2ProgressiveSampler.execute( + model=None, seed=42, steps=2, cfg=1.0, + sampler_name="euler", scheduler="simple", + positive=pos_b, negative=neg_b, latent_image=latent_b, + denoise=1.0, frames_per_chunk=21, temporal_overlap=0, + ) + + assert torch.equal(out_a.result[0]["samples"], + out_b.result[0]["samples"]) + + +def test_t4_chunk_count_invariance_under_passthrough(): + """When the inner sampler is the identity, the final latent must be + identical regardless of how the work is partitioned: a single-chunk + run and a multi-chunk run on the same input must produce the same + output. This pins the slice / concat composition as a true identity + on the latent under a deterministic inner sampler. 
+ """ + latent_single, pos_s, neg_s, _, _ = _make_inputs(T=11) + latent_multi, pos_m, neg_m, _, _ = _make_inputs(T=11) + + with patch.object(comfy.sample, "sample", + side_effect=_passthrough_sample_returning_latent), \ + patch.object(comfy.sample, "fix_empty_latent_channels", + side_effect=_identity_fix_empty), \ + patch.object(comfy.sample, "prepare_noise", + side_effect=_fingerprinted_prepare_noise): + out_single = SeedVR2ProgressiveSampler.execute( + model=None, seed=7, steps=2, cfg=1.0, + sampler_name="euler", scheduler="simple", + positive=pos_s, negative=neg_s, latent_image=latent_single, + denoise=1.0, frames_per_chunk=45, temporal_overlap=0, # >= T_pixel=41 + ) + out_multi = SeedVR2ProgressiveSampler.execute( + model=None, seed=7, steps=2, cfg=1.0, + sampler_name="euler", scheduler="simple", + positive=pos_m, negative=neg_m, latent_image=latent_multi, + denoise=1.0, frames_per_chunk=21, temporal_overlap=0, # forces 2 chunks + ) + + assert torch.equal(out_single.result[0]["samples"], + out_multi.result[0]["samples"]) + + +# --------------------------------------------------------------------------- +# Hann overlap blend helper tests (Hann window + blend region + concat-with-blend) +# --------------------------------------------------------------------------- + + +def test_hann_weights_overlap_3_matches_numz_formula(): + """At ``overlap >= 3`` the Hann formula + ``0.5 + 0.5 * cos(pi * u)`` (with the [1/3, 2/3] dead-band) + must produce values identical to numz's + ``blend_overlapping_frames``: endpoints at ``1.0`` and ``0.0`` for + the previous-chunk weight, midpoint at ``0.5``. 
+ """ + w = _hann_blend_weights_1d(3, torch.device("cpu"), torch.float32) + assert tuple(w.shape) == (3,) + assert torch.allclose(w[0], torch.tensor(1.0)) + assert torch.allclose(w[-1], torch.tensor(0.0)) + assert torch.allclose(w[1], torch.tensor(0.5), atol=1e-6) + + +def test_hann_weights_overlap_lt_3_uses_linear_ramp(): + """At ``overlap < 3`` the Hann dead-band collapses, so the helper + falls back to a linear ramp from 1.0 to 0.0. + """ + w1 = _hann_blend_weights_1d(1, torch.device("cpu"), torch.float32) + assert torch.equal(w1, torch.tensor([1.0])) + w2 = _hann_blend_weights_1d(2, torch.device("cpu"), torch.float32) + assert torch.equal(w2, torch.tensor([1.0, 0.0])) + + +def test_hann_weights_monotone_non_increasing(): + """The previous-chunk weight is a crossfade ramp; it must be + non-increasing along the overlap axis (any reversal would produce + audible/visible boundary artifacts). + """ + for n in [3, 4, 5, 7, 8, 11, 16]: + w = _hann_blend_weights_1d(n, torch.device("cpu"), torch.float32) + diffs = w[1:] - w[:-1] + assert torch.all(diffs <= 1e-6), ( + f"Hann weights non-monotone at overlap={n}: {w.tolist()}" + ) + + +def test_blend_region_endpoints_reproduce_pure_chunks(): + """At the first overlap position the result must equal the + previous chunk's tail; at the last position it must equal the + current chunk's head. Verifies the weights actually anchor at 0 + and 1 ends on the underlying tensor. + """ + B, C, T_overlap, H, W = 1, 16, 5, 4, 4 + prev = torch.full((B, C, T_overlap, H, W), 7.0) + cur = torch.full((B, C, T_overlap, H, W), -3.0) + blended = _blend_overlap_region(prev, cur) + assert torch.allclose(blended[:, :, 0], prev[:, :, 0]) + assert torch.allclose(blended[:, :, -1], cur[:, :, -1]) + + +def test_blend_region_equal_inputs_returns_input(): + """If both chunks agree perfectly in the overlap region, the + crossfade output must equal the common value at every position. + Linear combination of equal inputs is always the input. 
+    """
+    B, C, T_overlap, H, W = 1, 16, 5, 4, 4
+    same = torch.randn(B, C, T_overlap, H, W)
+    blended = _blend_overlap_region(same.clone(), same.clone())  # both inputs are clones of one tensor
+    assert torch.allclose(blended, same, atol=1e-6)
+
+
+def test_concat_with_overlap_zero_matches_plain_concat():
+    """``overlap_latent == 0`` must take the fast path and produce the
+    same tensor as ``_concat_chunks_along_t`` of the same chunks.
+    Required so that ``temporal_overlap=0`` is byte-identical to the
+    no-overlap chunked path.
+    """
+    B, T1, T2, H, W = 1, 3, 4, 4, 4
+    a4 = torch.randn(B, _LAT_C * T1, H, W)
+    b4 = torch.randn(B, _LAT_C * T2, H, W)
+    plain = _concat_chunks_along_t([a4, b4], _LAT_C)  # reference result from the plain-concat helper
+    blended = _concat_chunks_with_overlap_blend(
+        [(0, T1, a4), (T1, T1 + T2, b4)], _LAT_C, overlap_latent=0,  # second entry starts where the first ends
+    )
+    assert torch.equal(blended, plain)  # exact equality, no tolerance
+
+
+def test_concat_with_overlap_two_chunks_blends_only_overlap_region():
+    """For two chunks that overlap by ``overlap_latent`` latent frames,
+    the non-overlap portions must be copied verbatim from each chunk;
+    only the overlap region carries the blended values.
+    """
+    B, H, W = 1, 4, 4
+    chunk_T = 4
+    overlap = 2
+    cs0, ce0 = 0, chunk_T  # 0..3
+    cs1, ce1 = chunk_T - overlap, chunk_T - overlap + chunk_T  # 2..5
+    a4 = torch.full((B, _LAT_C * chunk_T, H, W), 1.0)  # chunk 0 is all 1.0
+    b4 = torch.full((B, _LAT_C * chunk_T, H, W), 2.0)  # chunk 1 is all 2.0
+    out = _concat_chunks_with_overlap_blend(
+        [(cs0, ce0, a4), (cs1, ce1, b4)], _LAT_C,
+        overlap_latent=overlap,
+    )
+    assert tuple(out.shape) == (B, _LAT_C * (chunk_T + chunk_T - overlap), H, W)
+    out_5d = out.view(B, _LAT_C, chunk_T + chunk_T - overlap, H, W)  # split dim 1 so dim 2 can be indexed per frame
+    # Pre-overlap: chunk 0 verbatim (index 0..chunk_T - overlap - 1)
+    for i in range(chunk_T - overlap):
+        assert torch.allclose(out_5d[:, :, i], torch.tensor(1.0))
+    # Post-overlap: chunk 1 verbatim (last chunk_T - overlap frames)
+    for i in range(chunk_T + chunk_T - overlap - (chunk_T - overlap),
+                   chunk_T + chunk_T - overlap):  # simplifies to range(chunk_T, 2 * chunk_T - overlap)
+        assert torch.allclose(out_5d[:, :, i], torch.tensor(2.0))  # NOTE(review): overlap frames 2..3 are not asserted
+
+
+def test_concat_with_overlap_runt_chunk_uses_min_available_overlap():
+    """When the final chunk is a runt shorter than the configured
+    overlap, the blend must be performed on the actually-available
+    overlap width rather than overrun the runt chunk.
+    """
+    B, H, W = 1, 4, 4
+    overlap = 3  # larger than the 1-frame runt chunk built below
+    a4 = torch.full((B, _LAT_C * 4, H, W), 1.0)  # T 0..3
+    b4 = torch.full((B, _LAT_C * 1, H, W), 2.0)  # T 1..1 (runt of 1)
+    # b4 starts at 1, ends at 2: overlaps [1:4] -> available width 1.
+    out = _concat_chunks_with_overlap_blend(
+        [(0, 4, a4), (1, 2, b4)], _LAT_C, overlap_latent=overlap,
+    )
+    # Total covered: indices 0..3 -> length 4.
+    assert tuple(out.shape) == (B, _LAT_C * 4, H, W)  # shape-only check; blended values are not inspected
+
+
+# ---------------------------------------------------------------------------
+# overlap=0 is byte-identical to the no-overlap chunked path
+# ---------------------------------------------------------------------------
+
+
+def test_t5_overlap_zero_byte_identical_to_slice1_path():
+    """``temporal_overlap=0`` must produce output byte-identical
+    to the no-overlap chunked path under a deterministic inner sampler.
+    Verifies the overlap=0 fast path is wired correctly through
+    ``_concat_chunks_with_overlap_blend``.
+    """
+    latent, pos, neg, _, _ = _make_inputs(T=11)
+    src = latent["samples"].clone()  # copy of the input taken before execute() runs
+
+    with patch.object(comfy.sample, "sample",
+                      side_effect=_passthrough_sample_returning_latent), \
+         patch.object(comfy.sample, "fix_empty_latent_channels",
+                      side_effect=_identity_fix_empty), \
+         patch.object(comfy.sample, "prepare_noise",
+                      side_effect=_fingerprinted_prepare_noise):
+        out = SeedVR2ProgressiveSampler.execute(
+            model=None, seed=0, steps=2, cfg=1.0,
+            sampler_name="euler", scheduler="simple",
+            positive=pos, negative=neg, latent_image=latent,
+            denoise=1.0, frames_per_chunk=21, temporal_overlap=0,
+        )
+
+    out_latent = out.result[0]
+    assert torch.equal(out_latent["samples"], src)  # exact equality with the pre-call copy
+
+
+# ---------------------------------------------------------------------------
+# Small overlap (linear ramp path)
+# ---------------------------------------------------------------------------
+
+
+def test_t6_small_overlap_linear_ramp_no_nan_inf():
+    """``temporal_overlap=2`` exercises
+    the linear-ramp fallback (overlap < 3). The output must preserve
+    the source's overall T_total shape and contain no NaN/Inf.
+    """
+    latent, pos, neg, _, _ = _make_inputs(T=11)
+    full_shape = tuple(latent["samples"].shape)  # shape the output is compared against
+
+    with patch.object(comfy.sample, "sample",
+                      side_effect=_passthrough_sample_returning_latent), \
+         patch.object(comfy.sample, "fix_empty_latent_channels",
+                      side_effect=_identity_fix_empty), \
+         patch.object(comfy.sample, "prepare_noise",
+                      side_effect=_fingerprinted_prepare_noise):
+        out = SeedVR2ProgressiveSampler.execute(
+            model=None, seed=0, steps=2, cfg=1.0,
+            sampler_name="euler", scheduler="simple",
+            positive=pos, negative=neg, latent_image=latent,
+            denoise=1.0, frames_per_chunk=21, temporal_overlap=2,  # below 3: the linear-ramp case from the docstring
+        )
+
+    samples_out = out.result[0]["samples"]
+    assert tuple(samples_out.shape) == full_shape
+    assert not torch.isnan(samples_out).any()
+    assert not torch.isinf(samples_out).any()
+
+
+# ---------------------------------------------------------------------------
+# Hann blend (overlap >= 3): bounded, no boundary discontinuity
+# ---------------------------------------------------------------------------
+
+
+def test_t7_hann_blend_bounded_under_passthrough_inner_sampler():
+    """Boundedness for the Hann path. With a passthrough inner
+    sampler the per-chunk outputs equal the per-chunk input slices,
+    so the post-blend output equals the source latent at every frame
+    (the overlap regions blend two slices of the same source). This
+    is the strongest available unit-level statement of "no boundary
+    discontinuity introduced by the blend".
+    """
+    latent, pos, neg, _, _ = _make_inputs(T=11)
+    src = latent["samples"].clone()  # copy of the input taken before execute() runs
+
+    with patch.object(comfy.sample, "sample",
+                      side_effect=_passthrough_sample_returning_latent), \
+         patch.object(comfy.sample, "fix_empty_latent_channels",
+                      side_effect=_identity_fix_empty), \
+         patch.object(comfy.sample, "prepare_noise",
+                      side_effect=_fingerprinted_prepare_noise):
+        out = SeedVR2ProgressiveSampler.execute(
+            model=None, seed=0, steps=2, cfg=1.0,
+            sampler_name="euler", scheduler="simple",
+            positive=pos, negative=neg, latent_image=latent,
+            denoise=1.0, frames_per_chunk=21, temporal_overlap=3,  # smallest value in the "overlap >= 3" Hann range
+        )
+
+    samples_out = out.result[0]["samples"]
+    assert torch.allclose(samples_out, src, atol=1e-5), (
+        "Passthrough inner sampler + Hann blend must reconstruct source: "
+        "blending two equal slices of the same source must equal the "
+        "source at every position."
+    )
+    assert not torch.isnan(samples_out).any()
+    assert not torch.isinf(samples_out).any()
+
+
+@pytest.mark.parametrize(
+    ("frames_per_chunk", "expected_sample_calls"),
+    [
+        (1, 5),  # chunk_latent=1; overlap=999 resolves to 0.
+        (5, 4),  # chunk_latent=2; overlap=999 resolves to 1.
+    ],
+)
+def test_t7_oversized_overlap_uses_maximum_valid_overlap(
+    frames_per_chunk, expected_sample_calls,
+):
+    """Users do not know the latent chunk length. Oversized positive
+    ``temporal_overlap`` values must resolve to the maximum valid
+    overlap instead of hard-failing.
+    """
+    latent, pos, neg, _, _ = _make_inputs(T=5)
+    src = latent["samples"].clone()  # copy of the input taken before execute() runs
+
+    sampler_called = {"n": 0}  # mutable counter the nested stub can update
+
+    def _sample(*args, **kwargs):
+        sampler_called["n"] += 1
+        return _passthrough_sample_returning_latent(*args, **kwargs)  # count the call, then delegate unchanged
+
+    with patch.object(comfy.sample, "sample",
+                      side_effect=_sample), \
+         patch.object(comfy.sample, "fix_empty_latent_channels",
+                      side_effect=_identity_fix_empty), \
+         patch.object(comfy.sample, "prepare_noise",
+                      side_effect=_fingerprinted_prepare_noise):
+        out = SeedVR2ProgressiveSampler.execute(
+            model=None, seed=0, steps=2, cfg=1.0,
+            sampler_name="euler", scheduler="simple",
+            positive=pos, negative=neg, latent_image=latent,
+            denoise=1.0, frames_per_chunk=frames_per_chunk,
+            temporal_overlap=999,  # oversized on purpose
+        )
+    assert torch.equal(out.result[0]["samples"], src)
+    assert sampler_called["n"] == expected_sample_calls  # number of patched comfy.sample.sample invocations
+
+
+def test_t7_negative_overlap_rejected():
+    """Negative ``temporal_overlap`` raises ``ValueError`` before any sampler call."""
+    latent, pos, neg, _, _ = _make_inputs(T=5)
+
+    sampler_called = {"n": 0}  # mutable counter the nested stub can update
+
+    def _should_not_be_called(*args, **kwargs):
+        sampler_called["n"] += 1
+        return torch.zeros(1)  # dummy value; the final assert expects this stub never to run
+
+    with patch.object(comfy.sample, "sample",
+                      side_effect=_should_not_be_called), \
+         patch.object(comfy.sample, "fix_empty_latent_channels",
+                      side_effect=_identity_fix_empty), \
+         patch.object(comfy.sample, "prepare_noise",
+                      side_effect=_fingerprinted_prepare_noise):
+        with pytest.raises(ValueError) as excinfo:
+            SeedVR2ProgressiveSampler.execute(
+                model=None, seed=0, steps=2, cfg=1.0,
+                sampler_name="euler", scheduler="simple",
+                positive=pos, negative=neg, latent_image=latent,
+                denoise=1.0, frames_per_chunk=5, temporal_overlap=-1,
+            )
+    assert "temporal_overlap" in str(excinfo.value)  # the error message must name the offending parameter
+    assert sampler_called["n"] == 0  # the sampler stub was never invoked