From f632ec67da9f86056ff7fa95cd79a677f840296b Mon Sep 17 00:00:00 2001 From: John Pollock Date: Mon, 25 May 2026 22:13:06 -0500 Subject: [PATCH] Add SeedVR2 integration coverage --- .github/workflows/test-unit.yml | 5 +- .../comfy_test/test_seedvr2_refactor_nodes.py | 227 ++++++++++++++++++ 2 files changed, 229 insertions(+), 3 deletions(-) create mode 100644 tests-unit/comfy_test/test_seedvr2_refactor_nodes.py diff --git a/.github/workflows/test-unit.yml b/.github/workflows/test-unit.yml index d05179cd3..c52defc7d 100644 --- a/.github/workflows/test-unit.yml +++ b/.github/workflows/test-unit.yml @@ -2,9 +2,9 @@ name: Unit Tests on: push: - branches: [ main, master, release/** ] + branches: [ main, master, develop, release/** ] pull_request: - branches: [ main, master, release/** ] + branches: [ main, master, develop, release/** ] jobs: test: @@ -12,7 +12,6 @@ jobs: matrix: os: [ubuntu-latest, windows-2022, macos-latest] runs-on: ${{ matrix.os }} - continue-on-error: true steps: - uses: actions/checkout@v4 - name: Set up Python diff --git a/tests-unit/comfy_test/test_seedvr2_refactor_nodes.py b/tests-unit/comfy_test/test_seedvr2_refactor_nodes.py new file mode 100644 index 000000000..40b5f9204 --- /dev/null +++ b/tests-unit/comfy_test/test_seedvr2_refactor_nodes.py @@ -0,0 +1,227 @@ +import torch + +from comfy.cli_args import args as cli_args + +if not torch.cuda.is_available(): + cli_args.cpu = True + +import comfy_extras.nodes_seedvr as nodes_seedvr +import nodes + + +def test_seedvr2_postprocessing_restores_flat_decoded_batch_time(): + decoded = torch.arange(6 * 4 * 6 * 1, dtype=torch.float32).reshape(6, 4, 6, 1) + original = torch.ones((2, 3, 4, 6, 1), dtype=torch.float32) + + output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 4, "none").result[0] + + assert output.shape == (6, 4, 6, 1) + torch.testing.assert_close(output, decoded) + + +def test_seedvr2_postprocessing_crops_to_resized_original_size(): + decoded = torch.ones((1, 128, 176, 3), dtype=torch.float32) + original = torch.full((1, 1, 120, 169, 3), 0.25, dtype=torch.float32) + + output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 120, "none").result[0] + + assert output.shape == (1, 120, 168, 3) + + +def test_seedvr2_postprocessing_uses_decoded_size_when_resized_original_is_larger(): + decoded = torch.ones((1, 128, 160, 3), dtype=torch.float32) + original = torch.full((1, 1, 480, 640, 3), 0.25, dtype=torch.float32) + + output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 480, "none").result[0] + + assert output.shape == (1, 128, 160, 3) + + +def test_seedvr2_postprocessing_does_not_trim_real_black_original_edges(): + decoded = torch.ones((1, 128, 176, 3), dtype=torch.float32) + original = torch.zeros((1, 1, 128, 176, 3), dtype=torch.float32) + + output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 128, "none").result[0] + + assert output.shape == (1, 128, 176, 3) + + +def test_seedvr2_postprocessing_crops_height_only_to_resized_original_size(): + decoded = torch.ones((1, 128, 176, 3), dtype=torch.float32) + original = torch.full((1, 1, 120, 176, 3), 0.25, dtype=torch.float32) + + output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 120, "none").result[0] + + assert output.shape == (1, 120, 176, 3) + + +def test_seedvr2_postprocessing_lab_uses_resized_original_size(monkeypatch): + decoded = torch.ones((1, 128, 176, 3), dtype=torch.float32) + original = torch.full((1, 1, 120, 169, 3), 0.25, dtype=torch.float32) + calls = [] + + def fake_lab_color_transfer(decoded_flat, reference_flat): + calls.append((tuple(decoded_flat.shape), tuple(reference_flat.shape))) + return decoded_flat + + monkeypatch.setattr(nodes_seedvr, "lab_color_transfer", fake_lab_color_transfer) + + output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 120, "lab").result[0] + + assert calls == [((1, 3, 120, 169), (1, 3, 120, 169))] + assert output.shape == (1, 120, 168, 3) + + +def test_seedvr2_tiled_decode_node_ignores_seedvr2_sideband_metadata(): + class FakeVAE: + def __init__(self): + self.decode_call = None + + def temporal_compression_decode(self): + return 4 + + def spacial_compression_decode(self): + return 8 + + def decode_tiled(self, samples, **kwargs): + self.decode_call = kwargs + return torch.zeros((1, 1, 2, 2, 3), dtype=torch.float32) + + vae = FakeVAE() + samples = { + "samples": torch.zeros((1, 16, 4, 4, 16), dtype=torch.float32), + "seedvr2_channel_last": True, + } + + nodes.VAEDecodeTiled().decode( + vae, + samples, + tile_size=64, + overlap=0, + temporal_size=64, + temporal_overlap=8, + ) + + assert "seedvr2_channel_last" not in vae.decode_call + + +def test_seedvr2_decode_node_ignores_seedvr2_sideband_metadata(): + class FakeVAE: + def __init__(self): + self.decode_call = None + + def decode(self, samples, **kwargs): + self.decode_call = kwargs + return torch.zeros((1, 1, 2, 2, 3), dtype=torch.float32) + + vae = FakeVAE() + samples = { + "samples": torch.zeros((1, 16, 4, 4, 16), dtype=torch.float32), + "seedvr2_channel_last": True, + } + + nodes.VAEDecode().decode(vae, samples) + + assert "seedvr2_channel_last" not in vae.decode_call + + +def test_seedvr2_decode_node_leaves_unmarked_ambiguous_latent_unforced(): + class FakeVAE: + def __init__(self): + self.decode_call = None + + def decode(self, samples, **kwargs): + self.decode_call = kwargs + return torch.zeros((1, 1, 2, 2, 3), dtype=torch.float32) + + vae = FakeVAE() + samples = {"samples": torch.zeros((1, 16, 4, 4, 16), dtype=torch.float32)} + + nodes.VAEDecode().decode(vae, samples) + + assert "seedvr2_channel_last" not in vae.decode_call + + +def test_seedvr2_encode_node_does_not_mark_model_specific_layout_metadata(): + class FakeVAE: + def encode(self, pixels): + return torch.zeros((1, 16, 2, 3, 4), dtype=torch.float32) + + output = nodes.VAEEncode().encode(FakeVAE(), torch.zeros((1, 8, 8, 3)))[0] + + assert set(output) == {"samples"} + + +def test_seedvr2_tiled_encode_node_does_not_mark_model_specific_layout_metadata(): + class FakeVAE: + def encode_tiled(self, pixels, **kwargs): + return torch.zeros((1, 16, 2, 3, 4), dtype=torch.float32) + + output = nodes.VAEEncodeTiled().encode(FakeVAE(), torch.zeros((1, 8, 8, 3)), 64, 0)[0] + + assert set(output) == {"samples"} + + +def test_seedvr2_saved_latent_does_not_persist_model_specific_layout_metadata(monkeypatch): + saved = {} + + def fake_save_image_path(filename_prefix, output_dir): + return output_dir, filename_prefix, 1, "", filename_prefix + + def fake_save_torch_file(output, file, metadata=None): + saved.update(output) + + monkeypatch.setattr(nodes.folder_paths, "get_save_image_path", fake_save_image_path) + monkeypatch.setattr(nodes.comfy.utils, "save_torch_file", fake_save_torch_file) + monkeypatch.setattr(nodes.folder_paths, "get_annotated_filepath", lambda latent: latent) + monkeypatch.setattr(nodes.safetensors.torch, "load_file", lambda latent_path, device="cpu": saved) + + original = torch.zeros((1, 16, 4, 4, 16), dtype=torch.float32) + nodes.SaveLatent().save({"samples": original, "seedvr2_channel_last": True}, "seedvr2_latent") + loaded = nodes.LoadLatent().load("seedvr2_latent")[0] + + assert "seedvr2_channel_last" not in saved + assert "seedvr2_channel_last" not in loaded + torch.testing.assert_close(loaded["samples"], original) + + +def test_seedvr2_tiled_decode_node_preserves_legacy_decode_tiled_signature(): + class FakeVAE: + def __init__(self): + self.decode_call = None + + def temporal_compression_decode(self): + return 4 + + def spacial_compression_decode(self): + return 8 + + def decode_tiled(self, samples, tile_x, tile_y, overlap, tile_t, overlap_t): + self.decode_call = { + "tile_x": tile_x, + "tile_y": tile_y, + "overlap": overlap, + "tile_t": tile_t, + "overlap_t": overlap_t, + } + return torch.zeros((1, 1, 2, 2, 3), dtype=torch.float32) + + vae = FakeVAE() + samples = {"samples": torch.zeros((1, 16, 4, 4, 16), dtype=torch.float32)} + + nodes.VAEDecodeTiled().decode( + vae, + samples, + tile_size=64, + overlap=0, + temporal_size=64, + temporal_overlap=8, + ) + + assert vae.decode_call == { + "tile_x": 8, + "tile_y": 8, + "overlap": 0, + "tile_t": 16, + "overlap_t": 2, + }