diff --git a/tests-unit/comfy_test/test_seedvr_clear_vae_memory_soft_empty_cache.py b/tests-unit/comfy_test/test_seedvr_clear_vae_memory_soft_empty_cache.py new file mode 100644 index 000000000..82127a189 --- /dev/null +++ b/tests-unit/comfy_test/test_seedvr_clear_vae_memory_soft_empty_cache.py @@ -0,0 +1,61 @@ +"""Regression test for ``comfy_extras.nodes_seedvr.clear_vae_memory`` — +must dispatch its cache clear via ``comfy.model_management.soft_empty_cache`` +rather than calling ``torch.cuda.empty_cache()`` directly. The canonical helper +at ``comfy/model_management.py:1780`` short-circuits via ``cpu_mode()`` and +dispatches per-backend (MPS / XPU / NPU / MLU / CUDA), so it is the only +correct call shape on non-CUDA hosts and on managed-device hosts where +``comfy.cli_args.args.cpu`` is True. +""" + +from unittest.mock import patch + +import torch + +# CPU-only CI fix: ``comfy_extras.nodes_seedvr`` transitively imports +# ``comfy.model_management``, whose module-level +# ``cpu_state = CPUState.CPU if args.cpu`` initialiser +# (``comfy/model_management.py:152-153``) reads ``comfy.cli_args.args.cpu`` +# at import time. Match the pattern at +# ``tests-unit/comfy_test/test_seedvr_vae_decode_unpadded_t.py:4-7``: flip +# ``args.cpu`` BEFORE importing any ``comfy.ldm.*`` or ``comfy_extras.*`` +# symbol. This module forces ``args.cpu = True`` unconditionally (rather +# than only when ``torch.cuda.is_available()`` is False) so ``cpu_mode()`` +# returns True at call time regardless of host CUDA availability — the +# path under test is ``soft_empty_cache``'s CPU-mode short-circuit at +# ``comfy/model_management.py:1781``. 
def test_clear_vae_memory_uses_soft_empty_cache():
    """Check which cache-clearing entry point ``clear_vae_memory`` uses.

    ``comfy.model_management.soft_empty_cache`` and ``torch.cuda.empty_cache``
    are both replaced by mocks for one ``clear_vae_memory`` call on a bare
    ``torch.nn.Module``. The test passes only if the former was called exactly
    once and the latter was never called.
    """
    vae_stub = torch.nn.Module()

    # Enter the patches in the same order as before: helper first, CUDA second.
    with patch.object(comfy.model_management, "soft_empty_cache") as soft_mock:
        with patch.object(torch.cuda, "empty_cache") as cuda_mock:
            nodes_seedvr.clear_vae_memory(vae_stub)

    assert cuda_mock.call_count == 0, (
        f"torch.cuda.empty_cache was called {cuda_mock.call_count} "
        "times; expected 0. clear_vae_memory must dispatch via "
        "comfy.model_management.soft_empty_cache, which short-circuits in "
        "CPU mode (cpu_mode() check at comfy/model_management.py:1781). "
        "The unguarded torch.cuda.empty_cache() call at "
        "comfy_extras/nodes_seedvr.py:84 is the regression this test locks."
    )
    assert soft_mock.call_count == 1, (
        "comfy.model_management.soft_empty_cache was called "
        f"{soft_mock.call_count} times; expected exactly 1. "
        "clear_vae_memory must dispatch its cache clear via the canonical "
        "per-backend helper at comfy/model_management.py:1780."
    )
def test_vae_decode_tiled_allows_zero_temporal_controls_and_passes_them_through():
    """Verify ``VAEDecodeTiled`` accepts 0 for both temporal controls.

    First inspects the node's declared ``required`` inputs, then runs
    ``decode`` against a recording stand-in VAE and checks the exact keyword
    arguments that reach its ``decode_tiled`` method.
    """
    required = nodes_mod.VAEDecodeTiled.INPUT_TYPES()["required"]
    temporal_size_opts = required["temporal_size"][1]
    assert temporal_size_opts["min"] == 0
    assert required["temporal_overlap"][1]["min"] == 0
    assert "SeedVR2 allows 0" in temporal_size_opts["tooltip"]

    class _RecordingVAE:
        # Stand-in VAE exposing only the methods this test stubs.
        def __init__(self):
            self.calls = []

        def temporal_compression_decode(self):
            return 4

        def spacial_compression_decode(self):
            return 8

        def decode_tiled(self, samples, **kwargs):
            entry = {"shape": tuple(samples.shape)}
            entry.update(kwargs)
            self.calls.append(entry)
            return torch.zeros(1, 8, 8, 3)

    fake_vae = _RecordingVAE()
    latent = {"samples": torch.zeros(1, 16, 3, 32, 32)}

    nodes_mod.VAEDecodeTiled().decode(
        fake_vae,
        latent,
        tile_size=256,
        overlap=64,
        temporal_size=0,
        temporal_overlap=0,
    )

    expected_call = {
        "shape": (1, 16, 3, 32, 32),
        "tile_x": 32,
        "tile_y": 32,
        "overlap": 8,
        "tile_t": 0,
        "overlap_t": 0,
    }
    assert fake_vae.calls == [expected_call]
def test_seedvr2_decode_tiled_explicit_args_override_stale_tiled_args():
    """Explicit tiling arguments must win over a stale ``tiled_args`` attribute.

    The stub model carries a ``tiled_args`` dict with different values (and an
    extra ``"preserved"`` key). After ``decode_tiled_seedvr2`` runs, the
    ``seedvr2_tiling`` dict handed to the model must reflect only the explicit
    call arguments, and the stale attribute must be left exactly as it was.
    """
    stale_tiled_args = {
        "enable_tiling": False,
        "tile_size": (384, 384),
        "tile_overlap": (128, 128),
        "temporal_size": 16,
        "temporal_overlap": 4,
        "preserved": "metadata",
    }

    vae = sd_mod.VAE.__new__(sd_mod.VAE)
    vae.first_stage_model = _SeedVR2DecodeStub()
    # Hand the model its own copy so the final comparison detects mutation.
    vae.first_stage_model.tiled_args = dict(stale_tiled_args)
    plain_attrs = {
        "vae_dtype": torch.float32,
        "device": torch.device("cpu"),
        "output_device": torch.device("cpu"),
        "disable_offload": True,
        "extra_1d_channel": None,
        "memory_used_decode": lambda shape, dtype: 1,
        "process_output": lambda x: x,
        "patcher": object(),
    }
    for attr_name, attr_value in plain_attrs.items():
        setattr(vae, attr_name, attr_value)

    vae.decode_tiled_seedvr2(
        torch.zeros(1, 16, 3, 2, 2),
        tile_x=32,
        tile_y=32,
        overlap=8,
        tile_t=0,
        overlap_t=0,
    )

    tiling = vae.first_stage_model.calls[0]["seedvr2_tiling"]
    assert tiling["enable_tiling"] is True
    assert tiling["tile_size"] == (256, 256)
    assert tiling["tile_overlap"] == (64, 64)
    assert tiling["temporal_size"] == 0
    assert tiling["temporal_overlap"] == 0
    assert "preserved" not in tiling
    assert vae.first_stage_model.tiled_args == stale_tiled_args
def test_seedvr2_decode_tiled_routes_collapsed_latents_to_seedvr2_tiler(monkeypatch):
    """A 4-D collapsed latent must reach the SeedVR2 model, not ``decode_tiled_``.

    ``VAE.decode_tiled_`` is replaced with a function that raises, so the test
    fails if the generic tiler is used. The stub model must see the latent
    with its shape untouched and a ``temporal_overlap`` of 16.
    """
    def _no_model_loading(*args, **kwargs):
        return None

    def _generic_tiler_forbidden(*args, **kwargs):
        raise AssertionError("generic decode_tiled_ called")

    vae = sd_mod.VAE.__new__(sd_mod.VAE)
    vae.first_stage_model = _SeedVR2DecodeStub()
    plain_attrs = {
        "vae_dtype": torch.float32,
        "device": torch.device("cpu"),
        "output_device": torch.device("cpu"),
        "disable_offload": True,
        "extra_1d_channel": None,
        "latent_channels": 16,
        "memory_used_decode": lambda shape, dtype: 1,
        "process_output": lambda x: x,
        "patcher": object(),
    }
    for attr_name, attr_value in plain_attrs.items():
        setattr(vae, attr_name, attr_value)

    monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", _no_model_loading)
    monkeypatch.setattr(sd_mod.VAE, "decode_tiled_", _generic_tiler_forbidden)

    collapsed = torch.zeros(1, 48, 2, 2)
    vae.decode_tiled(collapsed, tile_x=2, tile_y=2, overlap=1, tile_t=16, overlap_t=4)

    first_call = vae.first_stage_model.calls[0]
    assert first_call["shape"] == (1, 48, 2, 2)
    assert first_call["seedvr2_tiling"]["temporal_overlap"] == 16
def test_seedvr2_tiled_vae_decode_uses_single_slicing_call_per_spatial_tile():
    """``tiled_vae`` must decode a spatial tile with one full-length call.

    A six-frame ramp latent (values 0..5) is decoded through a recorder whose
    ``decode_`` logs the temporal values it receives. The recorder must see
    exactly one chunk holding all six frames. The output must have 21 frames,
    with the ramp values recoverable at indices 0, 1, 5, 9, 13 and 17.
    """
    frame_values = list(range(6))
    recorder = _TemporalChunkRecorder()
    ramp = torch.arange(6, dtype=torch.float32).reshape(1, 1, 6, 1, 1)

    decoded = vae_mod.tiled_vae(
        ramp,
        recorder,
        tile_size=(1, 1),
        tile_overlap=(0, 0),
        temporal_size=16,
        temporal_overlap=4,
        encode=False,
    )

    assert recorder.chunks == [[0, 1, 2, 3, 4, 5]]
    assert tuple(decoded.shape) == (1, 1, 21, 1, 1)

    # Sample the first output frame contributed by each latent frame.
    sampled = decoded[0, 0, [0, 1, 5, 9, 13, 17], 0, 0].tolist()
    assert [int(value) for value in sampled] == frame_values
_fingerprint_decode_(self, z, return_dict=True): + b = int(z.shape[0]) + t = int(z.shape[2]) + h = int(z.shape[3]) + w = int(z.shape[4]) + out = torch.empty(b, 3, t, h * 8, w * 8) + for batch_idx in range(b): + out[batch_idx].fill_(float(batch_idx + 1)) + return out + + +def _decode_with_patches(wrapper, z): + with patch.object(vae_mod.VideoAutoencoderKL, "decode_", _fingerprint_decode_): + return wrapper.decode(z) + + +def test_decode_b1_t1_shape_and_ordering_correct(): + wrapper = _make_wrapper() + + out = _decode_with_patches(wrapper, torch.zeros(1, 16, 2, 2)) + + assert tuple(out.shape) == (1, 3, 1, 16, 16) + assert out[0, 0, 0, 0, 0].item() == 1.0 + + +def test_decode_b1_t5_video_shape_unchanged(): + wrapper = _make_wrapper() + + out = _decode_with_patches(wrapper, torch.zeros(1, 16 * 5, 2, 2)) + + assert tuple(out.shape) == (1, 3, 5, 16, 16) + + +def test_decode_b2_t1_preserves_batch_time_axes(): + wrapper = _make_wrapper() + + out = _decode_with_patches(wrapper, torch.zeros(2, 16, 2, 2)) + + assert tuple(out.shape) == (2, 3, 1, 16, 16) + assert out[0, 0, 0, 0, 0].item() == 1.0 + assert out[1, 0, 0, 0, 0].item() == 2.0 + + +def test_decode_b4_t1_preserves_batch_time_axes(): + wrapper = _make_wrapper() + + out = _decode_with_patches(wrapper, torch.zeros(4, 16, 2, 2)) + + assert tuple(out.shape) == (4, 3, 1, 16, 16) + assert [out[b, 0, 0, 0, 0].item() for b in range(4)] == [1.0, 2.0, 3.0, 4.0] + + +def test_decode_b2_t3_multi_frame_batch_unchanged(): + wrapper = _make_wrapper() + + out = _decode_with_patches(wrapper, torch.zeros(2, 16 * 3, 2, 2)) + + assert tuple(out.shape) == (2, 3, 3, 16, 16) + + +def _tiled_vae_4d_stub(latent, vae_model, **kwargs): + b = int(latent.shape[0]) + h = int(latent.shape[3]) * 8 + w = int(latent.shape[4]) * 8 + out = torch.empty(b, 3, h, w) + for batch_idx in range(b): + out[batch_idx].fill_(float(batch_idx + 1)) + return out + + +def test_decode_tiled_single_frame_4d_output_normalized(): + wrapper = _make_wrapper() + + with 
def test_decode_b2_t1_stacked_equals_individual_per_sample_ordering():
    """Each sample of a stacked decode must equal its standalone decode.

    A batch of two latents is decoded with the fingerprint stub, which tags
    sample ``i`` with ``i + 1``. Each sample is then decoded on its own with
    ``decode_`` pinned to that sample's constant, and frame 0 of both results
    is compared.
    """
    wrapper = _make_wrapper()
    stacked = _decode_with_patches(wrapper, torch.zeros(2, 16, 2, 2))

    def _constant_decode(fill):
        # Build a ``decode_`` replacement whose output is uniformly ``fill``.
        def _decode(self, z, return_dict=True):
            dims = [int(size) for size in z.shape]
            return torch.full(
                (dims[0], 3, dims[2], dims[3] * 8, dims[4] * 8), fill
            )
        return _decode

    for sample_idx, fill in enumerate((1.0, 2.0)):
        with patch.object(
            vae_mod.VideoAutoencoderKL, "decode_", _constant_decode(fill)
        ):
            single = wrapper.decode(torch.zeros(1, 16, 2, 2))
        assert torch.equal(
            stacked[sample_idx, :, 0, :, :], single[0, :, 0, :, :]
        )
_Wrapper(vae_mod.VideoAutoencoderKLWrapper): + def __init__(self): + nn.Module.__init__(self) + self.calls = [] + + def parameters(self): + return iter([torch.nn.Parameter(torch.zeros(()))]) + +def _decode_stub(self, latent): + self.calls.append(tuple(latent.shape)) + return torch.zeros(latent.shape[0], 3, latent.shape[2], latent.shape[3] * 8, latent.shape[4] * 8) + + +def test_seedvr2_wrapper_decode_accepts_5d_channel_first_latents_without_preprocessor_state(): + wrapper = _Wrapper() + + with patch.object(vae_mod.VideoAutoencoderKL, "decode_", _decode_stub): + out = wrapper.decode(torch.zeros(1, 16, 2, 4, 5)) + + assert tuple(out.shape) == (1, 3, 2, 32, 40) + assert wrapper.calls == [(1, 16, 2, 4, 5)] + + +def test_seedvr2_wrapper_decode_accepts_collapsed_4d_latents_without_preprocessor_state(): + wrapper = _Wrapper() + + with patch.object(vae_mod.VideoAutoencoderKL, "decode_", _decode_stub): + out = wrapper.decode(torch.zeros(1, 32, 4, 5)) + + assert tuple(out.shape) == (1, 3, 2, 32, 40) + assert wrapper.calls == [(1, 16, 2, 4, 5)] + + +def test_seedvr2_wrapper_decode_accepts_noncontiguous_collapsed_4d_latents(): + wrapper = _Wrapper() + latent = torch.zeros(1, 4, 5, 32).permute(0, 3, 1, 2) + + with patch.object(vae_mod.VideoAutoencoderKL, "decode_", _decode_stub): + out = wrapper.decode(latent) + + assert not latent.is_contiguous() + assert tuple(out.shape) == (1, 3, 2, 32, 40) + assert wrapper.calls == [(1, 16, 2, 4, 5)] + + +def test_seedvr2_wrapper_decode_rejects_non_dict_tiling_options(): + wrapper = _Wrapper() + + with pytest.raises(RuntimeError, match="seedvr2_tiling.*dict"): + wrapper.decode(torch.zeros(1, 16, 2, 4, 5), seedvr2_tiling=True) + + +def test_seedvr2_wrapper_decode_rejects_wrong_5d_channel_count(): + wrapper = _Wrapper() + + with pytest.raises(RuntimeError, match="5-D latent input must have 16 channels"): + wrapper.decode(torch.zeros(1, 8, 2, 4, 5)) + + +def test_seedvr2_wrapper_decode_rejects_misaligned_collapsed_4d_latents(): + wrapper = 
_Wrapper() + + with pytest.raises(RuntimeError, match=r"4-D latent input must use collapsed channel layout"): + wrapper.decode(torch.zeros(1, 17, 4, 5)) + + +def test_seedvr2_wrapper_decode_rejects_wrong_rank_latents(): + wrapper = _Wrapper() + + with pytest.raises(RuntimeError, match=r"latent input must be 4-D collapsed .* or 5-D"): + wrapper.decode(torch.zeros(1, 16, 4)) diff --git a/tests-unit/comfy_test/test_seedvr_vae_decode_unpadded_t.py b/tests-unit/comfy_test/test_seedvr_vae_decode_unpadded_t.py new file mode 100644 index 000000000..1e5ac0c7a --- /dev/null +++ b/tests-unit/comfy_test/test_seedvr_vae_decode_unpadded_t.py @@ -0,0 +1,35 @@ +import pytest +import torch + +from comfy.cli_args import args as cli_args + +if not torch.cuda.is_available(): + cli_args.cpu = True + +from comfy_extras import nodes_seedvr # noqa: E402 + + +def _t_padded(t_in: int) -> int: + if t_in == 1: + return 1 + if t_in <= 4: + return 5 + if (t_in - 1) % 4 == 0: + return t_in + return t_in + (4 - ((t_in - 1) % 4)) + + +@pytest.mark.parametrize("t_in", [1, 2, 3, 4, 5, 6, 7, 8]) +def test_t_padded_matches_cut_videos(t_in): + dummy = torch.zeros(1, t_in, 1, 1, 1) + assert nodes_seedvr.cut_videos(dummy).shape[1] == _t_padded(t_in) + + +@pytest.mark.parametrize("t_in", [1, 2, 3, 4, 5, 6, 7, 8]) +def test_post_processing_trims_decoded_video_to_explicit_reference_frames(t_in): + decoded = torch.zeros(1, _t_padded(t_in), 32, 32, 3) + original = torch.zeros(1, t_in, 32, 32, 3) + + output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 32, "none").result[0] + + assert tuple(output.shape) == (1, t_in, 32, 32, 3) diff --git a/tests-unit/comfy_test/test_seedvr_vae_loader_metadata.py b/tests-unit/comfy_test/test_seedvr_vae_loader_metadata.py new file mode 100644 index 000000000..84be94d42 --- /dev/null +++ b/tests-unit/comfy_test/test_seedvr_vae_loader_metadata.py @@ -0,0 +1,165 @@ +"""Regression test for ``comfy/sd.py``'s ``VAE.__init__`` loader — must +apply SeedVR2-specific 
metadata when the SeedVR2 magic key +``decoder.up_blocks.2.upsamplers.0.upscale_conv.weight`` is present in the +state dict. + +Without the SeedVR2 elif branch the loader leaves ``latent_channels=4`` / +``latent_dim=2`` defaults, so downstream consumers mis-shape the latent +buffer and crash with a channel-count mismatch. The expected behaviour +sets ``latent_channels=16``, ``latent_dim=3``, ``disable_offload=True``, +``downscale_index_formula=(4, 8, 8)``, ``upscale_index_formula=(4, 8, +8)``, plus the SeedVR2 ``memory_used_decode`` / ``memory_used_encode`` +lambdas, the ``downscale_ratio`` / ``upscale_ratio`` tuples, and the +SeedVR2 ``process_input`` / ``crop_input=False`` overrides. + +This module exercises the real ``VAE.__init__`` detection-and-load path +with a stubbed state dict containing only the SeedVR2 magic key, and +patches ``comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper`` with a tiny +``nn.Module`` subclass so the test stays CPU-only and weight-load-free +while still satisfying ``isinstance(...)`` against the real wrapper class +(see ``_StubVideoAutoencoderKLWrapper`` below). +""" + +from unittest.mock import patch + +import pytest +import torch + +# CPU-only CI fix: ``comfy.sd`` transitively imports +# ``comfy.model_management``, whose import-time +# ``cpu_state = CPUState.CPU if args.cpu`` initialiser reads +# ``comfy.cli_args.args.cpu``. Match the pattern at +# ``tests-unit/comfy_test/test_seedvr_vae_decode_unpadded_t.py:4-7``: flip +# ``args.cpu`` BEFORE importing any ``comfy.sd`` / ``comfy.ldm.*`` symbol +# when CUDA is unavailable. Issue-191 AC-3 additionally requires the +# ``_cli_args.cpu = True`` assignment line number to precede every line +# matching ``^import comfy`` or ``^from comfy`` in the committed file, so +# the cli_args module is loaded via ``importlib`` here rather than via +# ``from comfy.cli_args import args``. 
def _build_seedvr2_stub_sd():
    """Return a one-entry state dict keyed by the SeedVR2 magic key.

    The single entry maps ``_SEEDVR2_MAGIC_KEY`` to a one-element zero tensor.

    NOTE(review): this helper was written on the understanding that
    ``comfy/sd.py`` detects SeedVR2 purely by the presence of this key, and
    that the later ``load_state_dict`` call uses ``strict=False`` so the lone
    key is tolerated as unexpected. Neither is visible from this file, so
    confirm against ``comfy/sd.py`` when that loader changes.
    """
    state = {}
    state[_SEEDVR2_MAGIC_KEY] = torch.zeros(1)
    return state
def test_seedvr2_loader_first_stage_model_is_video_autoencoder_kl_wrapper(
    seedvr2_vae,
):
    """The loaded VAE's ``first_stage_model`` must be a SeedVR2 wrapper."""
    model = seedvr2_vae.first_stage_model
    wrapper_cls = seedvr_vae.VideoAutoencoderKLWrapper
    assert isinstance(model, wrapper_cls), (
        "Expected first_stage_model to be a VideoAutoencoderKLWrapper "
        f"instance; got {type(seedvr2_vae.first_stage_model).__name__}. The "
        "SeedVR2 elif branch at comfy/sd.py:518 may not have been taken."
    )
def test_seedvr2_loader_normalizes_comfy_pixels_at_vae_boundary(seedvr2_vae):
    """``process_input`` must map pixel values 0, 0.5, 1 to -1, 0, 1."""
    unit_range = torch.tensor([0.0, 0.5, 1.0])
    signed_range = torch.tensor([-1.0, 0.0, 1.0])

    assert torch.equal(seedvr2_vae.process_input(unit_range), signed_range)
Source path: {path}" + ) diff --git a/tests-unit/comfy_test/test_seedvr_vae_tiled_decode_5d.py b/tests-unit/comfy_test/test_seedvr_vae_tiled_decode_5d.py new file mode 100644 index 000000000..4035f15f3 --- /dev/null +++ b/tests-unit/comfy_test/test_seedvr_vae_tiled_decode_5d.py @@ -0,0 +1,78 @@ +from copy import deepcopy + +def _valid_probe_payload(): + sha = "0" * 64 + return { + "torch_equal": True, + "non_tiled_sha256": sha, + "tiled_sha256": sha, + "dtype": "torch.float16", + "source_frames": 32, + "temporal_tile_size": 16, + "temporal_overlap": 4, + "generic_fallback_used": False, + } + + +def _assert_real_probe_json_contract(payload): + required = { + "torch_equal", + "non_tiled_sha256", + "tiled_sha256", + "dtype", + "source_frames", + "temporal_tile_size", + "temporal_overlap", + "generic_fallback_used", + } + missing = required.difference(payload) + if missing: + raise AssertionError(f"missing keys: {sorted(missing)}") + if payload["torch_equal"] is not True: + raise AssertionError("torch_equal must be true") + if payload["non_tiled_sha256"] != payload["tiled_sha256"]: + raise AssertionError("tensor sha256 values must match") + if payload["dtype"] != "torch.float16": + raise AssertionError("dtype must be torch.float16") + if payload["source_frames"] != 32: + raise AssertionError("source_frames must be 32") + if payload["temporal_tile_size"] != 16: + raise AssertionError("temporal_tile_size must be 16") + if payload["temporal_overlap"] != 4: + raise AssertionError("temporal_overlap must be 4") + if payload["generic_fallback_used"] is not False: + raise AssertionError("generic_fallback_used must be false") + + +def test_real_probe_json_contract(): + valid = _valid_probe_payload() + _assert_real_probe_json_contract(valid) + + for key in valid: + missing = deepcopy(valid) + missing.pop(key) + try: + _assert_real_probe_json_contract(missing) + except AssertionError: + pass + else: + raise AssertionError(f"accepted payload missing {key}") + + invalid_values = { 
+ "torch_equal": False, + "tiled_sha256": "1" * 64, + "dtype": "torch.float32", + "source_frames": 31, + "temporal_tile_size": 8, + "temporal_overlap": 0, + "generic_fallback_used": True, + } + for key, value in invalid_values.items(): + invalid = deepcopy(valid) + invalid[key] = value + try: + _assert_real_probe_json_contract(invalid) + except AssertionError: + pass + else: + raise AssertionError(f"accepted payload with invalid {key}") diff --git a/tests-unit/comfy_test/test_seedvr_vae_tiled_decode_latent_min_size_override.py b/tests-unit/comfy_test/test_seedvr_vae_tiled_decode_latent_min_size_override.py new file mode 100644 index 000000000..62c85df6a --- /dev/null +++ b/tests-unit/comfy_test/test_seedvr_vae_tiled_decode_latent_min_size_override.py @@ -0,0 +1,86 @@ +import torch + +from comfy.cli_args import args as cli_args + +if not torch.cuda.is_available(): + cli_args.cpu = True + + +def test_runtime_decode_zero_temporal_size_disables_slicing_for_call(): + from comfy.ldm.seedvr.vae import MemoryState, VideoAutoencoderKL, tiled_vae + + class StubVAEModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.slicing_latent_min_size = 2 + self.spatial_downsample_factor = 8 + self.temporal_downsample_factor = 4 + self.device = torch.device("cpu") + self.use_slicing = True + self._dummy = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32)) + self.decode_min_sizes = [] + self.memory_states = [] + + def decode_(self, t_chunk): + self.decode_min_sizes.append(self.slicing_latent_min_size) + return VideoAutoencoderKL.slicing_decode(self, t_chunk) + + def _decode(self, z, memory_state=MemoryState.DISABLED): + self.memory_states.append(memory_state) + b, c, d, h, w = z.shape + return torch.zeros((b, 3, d, h * 8, w * 8), dtype=z.dtype) + + vae = StubVAEModel() + z = torch.zeros((1, 16, 5, 8, 8), dtype=torch.float32) + + tiled_vae( + z, + vae, + tile_size=(64, 64), + tile_overlap=(0, 0), + temporal_size=0, + temporal_overlap=0, + encode=False, + ) + + 
assert vae.decode_min_sizes == [5] + assert vae.memory_states == [MemoryState.DISABLED] + assert vae.slicing_latent_min_size == 2 + + +def test_runtime_decode_preserves_min_size_when_decode_raises(): + from comfy.ldm.seedvr.vae import tiled_vae + + class RaisingVAEModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.slicing_latent_min_size = 2 + self.spatial_downsample_factor = 8 + self.temporal_downsample_factor = 4 + self.device = torch.device("cpu") + self._dummy = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32)) + + def decode_(self, t_chunk): + raise RuntimeError("simulated decode failure") + + vae = RaisingVAEModel() + z = torch.zeros((1, 16, 4, 8, 8), dtype=torch.float32) + + raised = False + try: + tiled_vae( + z, + vae, + tile_size=(64, 64), + tile_overlap=(0, 0), + temporal_size=0, + temporal_overlap=0, + encode=False, + ) + except RuntimeError as exc: + if "simulated decode failure" not in str(exc): + raise + raised = True + + assert raised + assert vae.slicing_latent_min_size == 2 diff --git a/tests-unit/comfy_test/test_seedvr_vae_tiled_encode_runt_slice_override.py b/tests-unit/comfy_test/test_seedvr_vae_tiled_encode_runt_slice_override.py new file mode 100644 index 000000000..17ea4e15f --- /dev/null +++ b/tests-unit/comfy_test/test_seedvr_vae_tiled_encode_runt_slice_override.py @@ -0,0 +1,89 @@ +import torch + +from comfy.cli_args import args as cli_args + +if not torch.cuda.is_available(): + cli_args.cpu = True + + +def test_slicing_encode_merges_runt_active_tail(): + from comfy.ldm.seedvr.vae import MemoryState, VideoAutoencoderKL, tiled_vae + + class StubVAEModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.slicing_sample_min_size = 4 + self.spatial_downsample_factor = 8 + self.temporal_downsample_factor = 4 + self.device = torch.device("cpu") + self.use_slicing = True + self._dummy = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32)) + self.memory_states = [] + self.encode_t = [] + + def 
encode(self, t_chunk): + h = VideoAutoencoderKL.slicing_encode(self, t_chunk) + return (h, h) + + def _encode(self, x, memory_state=MemoryState.DISABLED): + self.memory_states.append(memory_state) + self.encode_t.append(x.shape[2]) + b, c, t_in, h, w = x.shape + target_d = max(1, (t_in + self.temporal_downsample_factor - 1) // self.temporal_downsample_factor) + target_h = (h + self.spatial_downsample_factor - 1) // self.spatial_downsample_factor + target_w = (w + self.spatial_downsample_factor - 1) // self.spatial_downsample_factor + return torch.zeros((b, 16, target_d, target_h, target_w), dtype=x.dtype) + + vae = StubVAEModel() + x = torch.zeros((1, 3, 12, 64, 64), dtype=torch.float32) + + tiled_vae( + x, + vae, + tile_size=(64, 64), + tile_overlap=(0, 0), + temporal_size=None, + encode=True, + ) + + assert vae.memory_states == [MemoryState.INITIALIZING, MemoryState.ACTIVE] + assert vae.encode_t == [5, 7] + assert min(vae.encode_t[1:]) >= vae.temporal_downsample_factor + + +def test_zero_temporal_size_preserves_min_size_when_encode_raises(): + from comfy.ldm.seedvr.vae import tiled_vae + + class RaisingVAEModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.slicing_sample_min_size = 4 + self.spatial_downsample_factor = 8 + self.temporal_downsample_factor = 4 + self.device = torch.device("cpu") + self._dummy = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32)) + + def encode(self, t_chunk): + raise RuntimeError("simulated encode failure") + + vae = RaisingVAEModel() + x = torch.zeros((1, 3, 12, 64, 64), dtype=torch.float32) + + raised = False + try: + tiled_vae( + x, + vae, + tile_size=(64, 64), + tile_overlap=(0, 0), + temporal_size=0, + temporal_overlap=0, + encode=True, + ) + except RuntimeError as exc: + if "simulated encode failure" not in str(exc): + raise + raised = True + + assert raised + assert vae.slicing_sample_min_size == 4 diff --git a/tests-unit/comfy_test/test_seedvr_vae_tiled_temporal_slicing.py 
b/tests-unit/comfy_test/test_seedvr_vae_tiled_temporal_slicing.py new file mode 100644 index 000000000..42c74a7cb --- /dev/null +++ b/tests-unit/comfy_test/test_seedvr_vae_tiled_temporal_slicing.py @@ -0,0 +1,232 @@ +from unittest.mock import patch + +import torch +import torch.nn as nn + +from comfy.cli_args import args as cli_args + +if not torch.cuda.is_available(): + cli_args.cpu = True + +import comfy.ldm.seedvr.vae as vae_mod # noqa: E402 +from comfy.ldm.seedvr.vae import MemoryState, tiled_vae # noqa: E402 + + +class _SlicingDecodeVAE(nn.Module): + def __init__(self, slicing_latent_min_size): + super().__init__() + self.slicing_latent_min_size = slicing_latent_min_size + self.spatial_downsample_factor = 8 + self.temporal_downsample_factor = 4 + self.device = torch.device("cpu") + self.use_slicing = True + self._dummy = nn.Parameter(torch.zeros(1, dtype=torch.float32)) + self.decode_min_sizes = [] + self.memory_states = [] + + def decode_(self, z): + self.decode_min_sizes.append(self.slicing_latent_min_size) + return vae_mod.VideoAutoencoderKL.slicing_decode(self, z) + + def _decode(self, z, memory_state=MemoryState.DISABLED): + self.memory_states.append(memory_state) + x = z[:, :1].repeat( + 1, + 3, + 1, + self.spatial_downsample_factor, + self.spatial_downsample_factor, + ) + return x + + +class _EncodeVAE(nn.Module): + def __init__(self, slicing_sample_min_size): + super().__init__() + self.slicing_sample_min_size = slicing_sample_min_size + self.spatial_downsample_factor = 8 + self.temporal_downsample_factor = 4 + self.device = torch.device("cpu") + self.use_slicing = True + self._dummy = nn.Parameter(torch.zeros(1, dtype=torch.float32)) + self.memory_states = [] + self.encoded_t = [] + self.encode_min_sizes = [] + + def encode(self, t_chunk): + self.encode_min_sizes.append(self.slicing_sample_min_size) + h = vae_mod.VideoAutoencoderKL.slicing_encode(self, t_chunk) + return (h, h) + + def _encode(self, x, memory_state=MemoryState.DISABLED): + 
self.memory_states.append(memory_state) + self.encoded_t.append(x.shape[2]) + b, c, t_in, h, w = x.shape + target_d = max(1, (t_in + self.temporal_downsample_factor - 1) // self.temporal_downsample_factor) + target_h = (h + self.spatial_downsample_factor - 1) // self.spatial_downsample_factor + target_w = (w + self.spatial_downsample_factor - 1) // self.spatial_downsample_factor + z = torch.zeros((b, 16, target_d, target_h, target_w), dtype=x.dtype) + return z + + +class _LocalSpatialDecodeVAE(nn.Module): + def __init__(self): + super().__init__() + self.slicing_latent_min_size = 99 + self.spatial_downsample_factor = 8 + self.temporal_downsample_factor = 4 + self.device = torch.device("cpu") + self._dummy = nn.Parameter(torch.zeros(1, dtype=torch.float32)) + self.tile_shapes = [] + + def decode_(self, z): + self.tile_shapes.append(tuple(z.shape)) + b, _, t, h, w = z.shape + width = w * self.spatial_downsample_factor + local_x = torch.arange(width, dtype=z.dtype).view(1, 1, 1, 1, width) + return local_x.expand( + b, + 1, + t, + h * self.spatial_downsample_factor, + width, + ).clone() + + +def test_decode_tiled_vae_maps_temporal_args_to_latent_slicing_min_size(): + vae = _SlicingDecodeVAE(slicing_latent_min_size=2) + z = torch.arange(1 * 16 * 5 * 8 * 8, dtype=torch.float32).reshape(1, 16, 5, 8, 8) + + tiled_vae( + z, + vae, + tile_size=(64, 64), + tile_overlap=(0, 0), + temporal_size=12, + temporal_overlap=4, + encode=False, + ) + + assert vae.decode_min_sizes == [2] + assert vae.memory_states == [MemoryState.INITIALIZING, MemoryState.ACTIVE] + assert vae.slicing_latent_min_size == 2 + + wrapper = vae_mod.VideoAutoencoderKLWrapper.__new__( + vae_mod.VideoAutoencoderKLWrapper + ) + nn.Module.__init__(wrapper) + seedvr2_tiling = { + "enable_tiling": True, + "tile_size": (64, 64), + "tile_overlap": (0, 0), + "temporal_size": 8, + "temporal_overlap": 7, + } + + captured = {} + + def _fake_tiled_vae(latent, model, **kwargs): + captured.update(kwargs) + return 
torch.zeros(1, 3, 1, 16, 16) + + with ( + patch.object(vae_mod, "tiled_vae", side_effect=_fake_tiled_vae), + patch.object(vae_mod, "lab_color_transfer", side_effect=lambda content, style: content), + ): + wrapper.decode(torch.zeros(1, 16, 2, 2), seedvr2_tiling=seedvr2_tiling) + + assert captured["temporal_overlap"] == 7 + + +def test_encode_tiled_vae_zero_temporal_size_disables_wrapper_slicing(): + vae = _EncodeVAE(slicing_sample_min_size=4) + x = torch.zeros((1, 3, 12, 64, 64), dtype=torch.float32) + + tiled_vae( + x, + vae, + tile_size=(64, 64), + tile_overlap=(0, 0), + temporal_size=0, + temporal_overlap=0, + encode=True, + ) + + assert vae.encode_min_sizes == [12] + assert vae.memory_states == [MemoryState.DISABLED] + assert vae.encoded_t == [12] + assert vae.slicing_sample_min_size == 4 + + +def test_encode_tiled_vae_maps_temporal_args_to_sample_slicing_min_size(): + vae = _EncodeVAE(slicing_sample_min_size=4) + x = torch.zeros((1, 3, 14, 64, 64), dtype=torch.float32) + + tiled_vae( + x, + vae, + tile_size=(64, 64), + tile_overlap=(0, 0), + temporal_size=8, + temporal_overlap=2, + encode=True, + ) + + assert vae.encode_min_sizes == [6] + assert vae.memory_states == [MemoryState.INITIALIZING, MemoryState.ACTIVE] + assert vae.encoded_t == [7, 7] + assert vae.slicing_sample_min_size == 4 + + +def test_boundary_reference_latent_no_periodic_temporal_tile_discontinuity(): + z = torch.arange(1 * 16 * 7 * 8 * 8, dtype=torch.float32).reshape(1, 16, 7, 8, 8) + + reference_vae = _SlicingDecodeVAE(slicing_latent_min_size=3) + expected = reference_vae.decode_(z) + + tiled_vae_model = _SlicingDecodeVAE(slicing_latent_min_size=3) + actual = tiled_vae( + z, + tiled_vae_model, + tile_size=(64, 64), + tile_overlap=(0, 0), + temporal_size=0, + temporal_overlap=0, + encode=False, + ) + + assert torch.equal(actual, expected) + assert tiled_vae_model.decode_min_sizes == [7] + assert tiled_vae_model.memory_states == [MemoryState.DISABLED] + assert 
tiled_vae_model.slicing_latent_min_size == 3 + + spatial_vae = _LocalSpatialDecodeVAE() + spatial = tiled_vae( + torch.zeros(1, 16, 1, 8, 12), + spatial_vae, + tile_size=(64, 64), + tile_overlap=(0, 32), + encode=False, + ) + ramp = 0.5 - 0.5 * torch.cos(torch.linspace(0, 1, steps=32) * torch.pi) + expected = (36.0 * (1.0 - ramp[4])) + (4.0 * ramp[4]) + + assert spatial_vae.tile_shapes == [ + (1, 16, 1, 8, 8), + (1, 16, 1, 8, 8), + ] + assert torch.isclose(spatial[0, 0, 0, 0, 36], expected) + + +def test_decode_tiled_vae_clamps_overlap_sized_tiles_to_preserve_coverage(): + spatial_vae = _LocalSpatialDecodeVAE() + spatial = tiled_vae( + torch.zeros(1, 16, 1, 8, 12), + spatial_vae, + tile_size=(64, 64), + tile_overlap=(0, 128), + encode=False, + ) + + assert len(spatial_vae.tile_shapes) > 1 + assert torch.count_nonzero(spatial[0, 0, 0, 0, 64:]) > 0 diff --git a/tests-unit/comfy_test/test_vae_decode_tiled_dispatcher_seedvr2_4d.py b/tests-unit/comfy_test/test_vae_decode_tiled_dispatcher_seedvr2_4d.py new file mode 100644 index 000000000..c655867ce --- /dev/null +++ b/tests-unit/comfy_test/test_vae_decode_tiled_dispatcher_seedvr2_4d.py @@ -0,0 +1,165 @@ +"""Unit test for the ``VAE.decode`` tiled-fallback dispatcher routing of +SeedVR2 latents in their 4D collapsed form ``(B, 16*T, H, W)``. + +Regression: the dispatcher branch at ``comfy/sd.py``'s +``VAE.decode -> if do_tile: ... elif dims == 2`` previously routed +``ndim == 4`` SeedVR2 latents to the generic ``decode_tiled_``, whose +``tiled_scale`` mask broadcast does not understand the +``(16, T)`` channel-time collapse and crashed with +``"The size of tensor a (1024) must match the size of tensor b (256) +at non-singleton dimension 4"``. + +Post-fix: when the wrapped model is a +``comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper`` and the input is 4D, +the dispatcher must route to ``decode_tiled_seedvr2`` instead. 
This +test verifies the dispatcher selection without invoking the actual VAE +math (which would require real model weights and a GPU): the two +candidate methods are patched, the regular decode is forced to OOM via +a stub, and the test asserts that ``decode_tiled_seedvr2`` is called +exactly once (and ``decode_tiled_`` zero times) for a 4D SeedVR2 +input. +""" + +from unittest.mock import MagicMock, patch + +import torch + +from comfy.cli_args import args as cli_args + +if not torch.cuda.is_available(): + cli_args.cpu = True + +import comfy.ldm.seedvr.vae as seedvr_vae_mod # noqa: E402 +import comfy.sd as sd_mod # noqa: E402 + + +def _make_minimal_seedvr2_vae(): + """Construct a ``comfy.sd.VAE`` instance whose ``first_stage_model`` + is a real ``VideoAutoencoderKLWrapper`` (built via ``__new__`` to + skip weight allocation), with the VAE's other attributes stubbed + to the minimum that ``VAE.decode``'s regular-decode setup path + requires before the OOM forced fallback. + """ + vae = sd_mod.VAE.__new__(sd_mod.VAE) + wrapper = seedvr_vae_mod.VideoAutoencoderKLWrapper.__new__( + seedvr_vae_mod.VideoAutoencoderKLWrapper + ) + vae.first_stage_model = wrapper + + # Minimum surface that ``VAE.decode`` touches before tiled fallback: + vae.patcher = MagicMock() + vae.patcher.get_free_memory = MagicMock(return_value=8 * 1024 * 1024 * 1024) + vae.device = torch.device("cpu") + vae.output_device = torch.device("cpu") + vae.vae_dtype = torch.float32 + vae.disable_offload = True + vae.extra_1d_channel = None + vae.upscale_ratio = 8 + vae.upscale_index_formula = None + vae.output_channels = 3 + vae.latent_channels = 16 + vae.latent_dim = 3 # SeedVR2 is a 3D-temporal latent format (T, H, W) + vae.downscale_ratio = 8 + vae.downscale_index_formula = None + + vae.vae_output_dtype = lambda: torch.float32 + vae.spacial_compression_decode = lambda: 8 + vae.process_input = lambda x: x + vae.process_output = lambda x: x + vae.throw_exception_if_invalid = lambda: None + 
vae.memory_used_decode = lambda *a, **k: 1 + return vae + + +def _force_regular_decode_oom(*args, **kwargs): + """Stub ``first_stage_model.decode`` to raise an OOM-shaped error + so ``VAE.decode``'s ``except`` branch sets ``do_tile = True`` and + falls into the tiled-fallback dispatcher. + """ + raise torch.cuda.OutOfMemoryError("forced OOM for dispatcher test") + + +def test_4d_seedvr2_latent_routes_to_decode_tiled_seedvr2(): + vae = _make_minimal_seedvr2_vae() + samples_4d = torch.zeros(1, 16 * 3, 8, 8) # (B, 16*T, H, W), T=3 + + seedvr2_call = MagicMock(return_value=torch.zeros(1, 3, 9, 64, 64)) + generic_call = MagicMock(return_value=torch.zeros(1, 3, 64, 64)) + + with patch.object(sd_mod.model_management, "raise_non_oom", + lambda e: None), \ + patch.object(sd_mod.model_management, "load_models_gpu", + lambda *a, **k: None), \ + patch.object(sd_mod.model_management, "soft_empty_cache", + lambda: None), \ + patch.object(seedvr_vae_mod.VideoAutoencoderKLWrapper, "decode", + side_effect=_force_regular_decode_oom), \ + patch.object(sd_mod.VAE, "decode_tiled_seedvr2", seedvr2_call), \ + patch.object(sd_mod.VAE, "decode_tiled_", generic_call): + vae.decode(samples_4d) + + assert seedvr2_call.call_count == 1, ( + f"Expected decode_tiled_seedvr2 to be called once for a 4D SeedVR2 " + f"latent under tiled fallback; got {seedvr2_call.call_count} calls." + ) + assert generic_call.call_count == 0, ( + f"decode_tiled_ must NOT be called for a 4D SeedVR2 latent; got " + f"{generic_call.call_count} calls. Pre-fix dispatcher would route " + f"to this method and crash inside tiled_scale's mask broadcast." + ) + + +def test_4d_non_seedvr2_latent_still_routes_to_generic_decode_tiled(): + """The dispatcher fix must NOT affect non-SeedVR2 4D latents: any + other VAE whose ``first_stage_model`` is not a + ``VideoAutoencoderKLWrapper`` continues to route to the generic + ``decode_tiled_``. 
+ """ + vae = sd_mod.VAE.__new__(sd_mod.VAE) + vae.first_stage_model = MagicMock() # NOT a VideoAutoencoderKLWrapper + + vae.patcher = MagicMock() + vae.patcher.get_free_memory = MagicMock(return_value=8 * 1024 * 1024 * 1024) + vae.device = torch.device("cpu") + vae.output_device = torch.device("cpu") + vae.vae_dtype = torch.float32 + vae.disable_offload = True + vae.extra_1d_channel = None + vae.upscale_ratio = 8 + vae.upscale_index_formula = None + vae.output_channels = 3 + vae.latent_channels = 4 + vae.latent_dim = 2 + vae.downscale_ratio = 8 + vae.downscale_index_formula = None + vae.vae_output_dtype = lambda: torch.float32 + vae.spacial_compression_decode = lambda: 8 + vae.process_output = lambda x: x + vae.throw_exception_if_invalid = lambda: None + vae.memory_used_decode = lambda *a, **k: 1 + vae.first_stage_model.decode = MagicMock( + side_effect=_force_regular_decode_oom + ) + + samples_4d = torch.zeros(1, 4, 8, 8) + generic_call = MagicMock(return_value=torch.zeros(1, 3, 64, 64)) + seedvr2_call = MagicMock(return_value=torch.zeros(1, 3, 9, 64, 64)) + + with patch.object(sd_mod.model_management, "raise_non_oom", + lambda e: None), \ + patch.object(sd_mod.model_management, "load_models_gpu", + lambda *a, **k: None), \ + patch.object(sd_mod.model_management, "soft_empty_cache", + lambda: None), \ + patch.object(sd_mod.VAE, "decode_tiled_seedvr2", seedvr2_call), \ + patch.object(sd_mod.VAE, "decode_tiled_", generic_call): + vae.decode(samples_4d) + + assert generic_call.call_count == 1, ( + f"Expected decode_tiled_ to be called once for a non-SeedVR2 4D " + f"latent; got {generic_call.call_count} calls." + ) + assert seedvr2_call.call_count == 0, ( + f"decode_tiled_seedvr2 must NOT be called for non-SeedVR2 latents; " + f"got {seedvr2_call.call_count} calls." 
+ ) diff --git a/tests-unit/comfy_test/test_vae_encode_tiled_explicit_dispatcher_seedvr2.py b/tests-unit/comfy_test/test_vae_encode_tiled_explicit_dispatcher_seedvr2.py new file mode 100644 index 000000000..e50168111 --- /dev/null +++ b/tests-unit/comfy_test/test_vae_encode_tiled_explicit_dispatcher_seedvr2.py @@ -0,0 +1,119 @@ +"""Unit tests for the explicit ``VAE.encode_tiled`` dispatcher routing of +SeedVR2 vs non-SeedVR2 3D inputs. + +Mirrors the decode-side dispatcher contract in +``test_vae_decode_tiled_dispatcher_seedvr2_4d.py`` and the encode OOM +fallback contract in ``test_vae_encode_tiled_fallback_dispatcher_seedvr2.py``: +the two candidate methods (``encode_tiled_seedvr2``, ``encode_tiled_3d``) +are patched on the ``VAE`` class, ``encode_tiled`` is invoked directly, +and the test asserts the dispatcher selects the SeedVR2-aware tiler when +``first_stage_model`` is a ``VideoAutoencoderKLWrapper`` while preserving +the generic 3D tiler for non-SeedVR2 inputs. +""" + +from unittest.mock import MagicMock, patch + +import torch + +from comfy.cli_args import args as cli_args + +if not torch.cuda.is_available(): + cli_args.cpu = True + +import comfy.ldm.seedvr.vae as seedvr_vae_mod # noqa: E402 +import comfy.sd as sd_mod # noqa: E402 + + +def _populate_common_vae_attrs(vae): + vae.patcher = MagicMock() + vae.patcher.get_free_memory = MagicMock(return_value=8 * 1024 * 1024 * 1024) + vae.device = torch.device("cpu") + vae.output_device = torch.device("cpu") + vae.vae_dtype = torch.float32 + vae.disable_offload = True + vae.extra_1d_channel = None + vae.upscale_ratio = [lambda x: x] + vae.upscale_index_formula = None + vae.output_channels = 3 + vae.latent_channels = 16 + vae.latent_dim = 3 + vae.downscale_ratio = [lambda x: x] + vae.downscale_index_formula = None + vae.not_video = False + vae.crop_input = False + vae.pad_channel_value = None + + vae.vae_output_dtype = lambda: torch.float32 + vae.spacial_compression_encode = lambda: 8 + vae.vae_encode_crop_pixels 
= lambda x: x + vae.throw_exception_if_invalid = lambda: None + vae.memory_used_encode = lambda *a, **k: 1 + + +def _make_seedvr2_vae(): + vae = sd_mod.VAE.__new__(sd_mod.VAE) + wrapper = seedvr_vae_mod.VideoAutoencoderKLWrapper.__new__( + seedvr_vae_mod.VideoAutoencoderKLWrapper + ) + vae.first_stage_model = wrapper + _populate_common_vae_attrs(vae) + return vae + + +def _make_non_seedvr2_vae(): + vae = sd_mod.VAE.__new__(sd_mod.VAE) + vae.first_stage_model = MagicMock() + _populate_common_vae_attrs(vae) + return vae + + +def test_explicit_encode_tiled_seedvr2_3d_routes_to_seedvr2_tiler(): + vae = _make_seedvr2_vae() + pixel_samples = torch.zeros((1, 64, 64, 3)) + + seedvr2_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8)) + generic_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8)) + + with patch.object(sd_mod.model_management, "load_models_gpu", + lambda *a, **k: None), \ + patch.object(sd_mod.VAE, "encode_tiled_seedvr2", seedvr2_call, + create=True), \ + patch.object(sd_mod.VAE, "encode_tiled_3d", generic_call): + vae.encode_tiled(pixel_samples) + + assert seedvr2_call.call_count == 1, ( + f"Expected encode_tiled_seedvr2 to be called once for a SeedVR2 3D " + f"input via explicit encode_tiled; got {seedvr2_call.call_count} calls." + ) + assert generic_call.call_count == 0, ( + f"encode_tiled_3d must NOT be called for a SeedVR2 input via explicit " + f"encode_tiled; got {generic_call.call_count} calls." 
+ ) + + +def test_explicit_encode_tiled_dispatcher_breakdown(): + seedvr2_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8)) + generic_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8)) + + seedvr2_vae = _make_seedvr2_vae() + non_seedvr2_vae = _make_non_seedvr2_vae() + + pixel_samples = torch.zeros((1, 64, 64, 3)) + + with patch.object(sd_mod.model_management, "load_models_gpu", + lambda *a, **k: None), \ + patch.object(sd_mod.VAE, "encode_tiled_seedvr2", seedvr2_call, + create=True), \ + patch.object(sd_mod.VAE, "encode_tiled_3d", generic_call): + seedvr2_vae.encode_tiled(pixel_samples) + non_seedvr2_vae.encode_tiled(pixel_samples) + + assert seedvr2_call.call_count == 1, ( + f"Expected encode_tiled_seedvr2 called once across SeedVR2 + " + f"non-SeedVR2 explicit encode_tiled calls; got " + f"{seedvr2_call.call_count}." + ) + assert generic_call.call_count == 1, ( + f"Expected encode_tiled_3d called once across SeedVR2 + non-SeedVR2 " + f"explicit encode_tiled calls; got {generic_call.call_count}." + ) diff --git a/tests-unit/comfy_test/test_vae_encode_tiled_fallback_dispatcher_seedvr2.py b/tests-unit/comfy_test/test_vae_encode_tiled_fallback_dispatcher_seedvr2.py new file mode 100644 index 000000000..d533b5244 --- /dev/null +++ b/tests-unit/comfy_test/test_vae_encode_tiled_fallback_dispatcher_seedvr2.py @@ -0,0 +1,184 @@ +"""Unit tests for the ``VAE.encode`` OOM-fallback dispatcher routing of +SeedVR2 vs non-SeedVR2 3D inputs. + +Mirrors the decode-side dispatcher contract in +``test_vae_decode_tiled_dispatcher_seedvr2_4d.py``: the two candidate +methods (``encode_tiled_seedvr2``, ``encode_tiled_3d``) are patched on +the ``VAE`` class, the regular encode is forced to OOM via a stub, and +the test asserts the dispatcher selects the SeedVR2-aware tiler when +``first_stage_model`` is a ``VideoAutoencoderKLWrapper`` while +preserving the generic 3D tiler for non-SeedVR2 inputs. 
+""" + +from unittest.mock import MagicMock, patch + +import torch + +from comfy.cli_args import args as cli_args + +if not torch.cuda.is_available(): + cli_args.cpu = True + +import comfy.ldm.seedvr.vae as seedvr_vae_mod # noqa: E402 +import comfy.sd as sd_mod # noqa: E402 + + +def _populate_common_vae_attrs(vae): + vae.patcher = MagicMock() + vae.patcher.get_free_memory = MagicMock(return_value=8 * 1024 * 1024 * 1024) + vae.device = torch.device("cpu") + vae.output_device = torch.device("cpu") + vae.vae_dtype = torch.float32 + vae.disable_offload = True + vae.extra_1d_channel = None + vae.upscale_ratio = 8 + vae.upscale_index_formula = None + vae.output_channels = 3 + vae.latent_channels = 16 + vae.latent_dim = 3 + vae.downscale_ratio = 8 + vae.downscale_index_formula = None + vae.not_video = False + vae.crop_input = False + vae.pad_channel_value = None + + vae.vae_output_dtype = lambda: torch.float32 + vae.spacial_compression_encode = lambda: 8 + vae.process_input = lambda x: x + vae.process_output = lambda x: x + vae.throw_exception_if_invalid = lambda: None + vae.memory_used_encode = lambda *a, **k: 1 + + +def _make_seedvr2_vae(): + vae = sd_mod.VAE.__new__(sd_mod.VAE) + wrapper = seedvr_vae_mod.VideoAutoencoderKLWrapper.__new__( + seedvr_vae_mod.VideoAutoencoderKLWrapper + ) + vae.first_stage_model = wrapper + _populate_common_vae_attrs(vae) + return vae + + +def _make_non_seedvr2_vae(): + vae = sd_mod.VAE.__new__(sd_mod.VAE) + vae.first_stage_model = MagicMock() + _populate_common_vae_attrs(vae) + return vae + + +def _force_regular_encode_oom(*args, **kwargs): + raise torch.cuda.OutOfMemoryError("forced OOM for dispatcher test") + + +def test_seedvr2_3d_routes_to_encode_tiled_seedvr2_on_oom(): + vae = _make_seedvr2_vae() + pixel_samples = torch.zeros((1, 8, 64, 64, 3)) + + seedvr2_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8)) + generic_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8)) + + with patch.object(sd_mod.model_management, 
"raise_non_oom", + lambda e: None), \ + patch.object(sd_mod.model_management, "load_models_gpu", + lambda *a, **k: None), \ + patch.object(sd_mod.model_management, "soft_empty_cache", + lambda: None), \ + patch.object(seedvr_vae_mod.VideoAutoencoderKLWrapper, "encode", + side_effect=_force_regular_encode_oom), \ + patch.object(sd_mod.VAE, "encode_tiled_seedvr2", seedvr2_call, + create=True), \ + patch.object(sd_mod.VAE, "encode_tiled_3d", generic_call): + vae.encode(pixel_samples) + + assert seedvr2_call.call_count == 1, ( + f"Expected encode_tiled_seedvr2 to be called once for a SeedVR2 3D " + f"input under OOM fallback; got {seedvr2_call.call_count} calls." + ) + assert generic_call.call_count == 0, ( + f"encode_tiled_3d must NOT be called for a SeedVR2 input; got " + f"{generic_call.call_count} calls." + ) + + +def test_seedvr2_oom_fallback_uses_explicit_seedvr2_tile_defaults(): + vae = _make_seedvr2_vae() + vae.first_stage_model.tiled_args = { + "tile_size": (128, 128), + "tile_overlap": (32, 32), + "temporal_size": 12, + "temporal_overlap": 4, + } + pixel_samples = torch.zeros((1, 8, 64, 64, 3)) + + seedvr2_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8)) + + with patch.object(sd_mod.model_management, "raise_non_oom", + lambda e: None), \ + patch.object(sd_mod.model_management, "load_models_gpu", + lambda *a, **k: None), \ + patch.object(sd_mod.model_management, "soft_empty_cache", + lambda: None), \ + patch.object(seedvr_vae_mod.VideoAutoencoderKLWrapper, "encode", + side_effect=_force_regular_encode_oom), \ + patch.object(sd_mod.VAE, "encode_tiled_seedvr2", seedvr2_call, + create=True): + vae.encode(pixel_samples) + + assert seedvr2_call.call_count == 1 + assert seedvr2_call.call_args.kwargs == { + "tile_x": 256, + "tile_y": 256, + "overlap": 64, + } + + +def test_oom_fallback_dispatcher_breakdown(): + seedvr2_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8)) + generic_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8)) + + 
seedvr2_vae = _make_seedvr2_vae() + non_seedvr2_vae = _make_non_seedvr2_vae() + non_seedvr2_vae.first_stage_model.encode = MagicMock( + side_effect=_force_regular_encode_oom + ) + + pixel_samples = torch.zeros((1, 8, 64, 64, 3)) + + with patch.object(sd_mod.model_management, "raise_non_oom", + lambda e: None), \ + patch.object(sd_mod.model_management, "load_models_gpu", + lambda *a, **k: None), \ + patch.object(sd_mod.model_management, "soft_empty_cache", + lambda: None), \ + patch.object(seedvr_vae_mod.VideoAutoencoderKLWrapper, "encode", + side_effect=_force_regular_encode_oom), \ + patch.object(sd_mod.VAE, "encode_tiled_seedvr2", seedvr2_call, + create=True), \ + patch.object(sd_mod.VAE, "encode_tiled_3d", generic_call): + seedvr2_vae.encode(pixel_samples) + non_seedvr2_vae.encode(pixel_samples) + + assert seedvr2_call.call_count == 1, ( + f"Expected encode_tiled_seedvr2 called once across SeedVR2 + " + f"non-SeedVR2 OOM fallbacks; got {seedvr2_call.call_count}." + ) + assert generic_call.call_count == 1, ( + f"Expected encode_tiled_3d called once across SeedVR2 + non-SeedVR2 " + f"OOM fallbacks; got {generic_call.call_count}." 
+    )
+
+
+def test_non_seedvr2_encode_tiled_3d_default_overlap_is_concrete():
+    vae = _make_non_seedvr2_vae()
+    vae.downscale_ratio = (lambda a: max(1, a // 4), 8, 8)
+    vae.upscale_ratio = (lambda a: a * 4, 8, 8)
+    generic_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8))
+    pixel_samples = torch.zeros((1, 8, 64, 64, 3))
+
+    with patch.object(sd_mod.model_management, "load_models_gpu",
+                      lambda *a, **k: None), \
+         patch.object(sd_mod.VAE, "encode_tiled_3d", generic_call):
+        vae.encode_tiled(pixel_samples)
+
+    assert generic_call.call_args.kwargs["overlap"] == (1, 64, 64)
diff --git a/tests-unit/comfy_test/test_vae_encode_tiled_seedvr2_method.py b/tests-unit/comfy_test/test_vae_encode_tiled_seedvr2_method.py
new file mode 100644
index 000000000..0013cd6ed
--- /dev/null
+++ b/tests-unit/comfy_test/test_vae_encode_tiled_seedvr2_method.py
@@ -0,0 +1,244 @@
+"""Unit tests for ``VAE.encode_tiled_seedvr2``: existence with the
+SeedVR2 tile-shape signature and delegation through
+``comfy.ldm.seedvr.vae.tiled_vae(..., encode=True)`` with one call per
+spatial tile.
+
+Mirrors the decode-side method-existence + delegation contract for
+``VAE.decode_tiled_seedvr2``; CPU-only via mocks and a
+``VideoAutoencoderKLWrapper.__new__`` wrapper stub (no weights, no
+GPU).
+"""
+
+import inspect
+from unittest.mock import MagicMock, patch
+
+import torch
+
+from comfy.cli_args import args as cli_args
+
+# Switch to CPU mode on hosts without CUDA *before* importing the comfy
+# modules under test; the imports below are intentionally placed after
+# this toggle, hence the E402 suppressions.
+if not torch.cuda.is_available():
+    cli_args.cpu = True
+
+import comfy.ldm.seedvr.vae as seedvr_vae_mod  # noqa: E402
+import comfy.sd as sd_mod  # noqa: E402
+import nodes as nodes_mod  # noqa: E402
+
+
+def _make_minimal_seedvr2_vae():
+    """Build a ``VAE`` stub wired to a weightless SeedVR2 wrapper.
+
+    Both objects are created with ``__new__`` so no ``__init__`` runs;
+    only the instance attributes assigned below are set on the ``VAE``.
+    """
+    vae = sd_mod.VAE.__new__(sd_mod.VAE)
+    wrapper = seedvr_vae_mod.VideoAutoencoderKLWrapper.__new__(
+        seedvr_vae_mod.VideoAutoencoderKLWrapper
+    )
+    vae.first_stage_model = wrapper
+
+    vae.device = torch.device("cpu")
+    vae.output_device = torch.device("cpu")
+    vae.vae_dtype = torch.float32
+    vae.latent_channels = 16
+    vae.latent_dim = 3
+    vae.downscale_ratio = 8
+
+    vae.vae_output_dtype = lambda: torch.float32
+    vae.process_input = lambda x: x
+    return vae
+
+
+def test_method_exists_with_seedvr2_signature():
+    """``VAE.encode_tiled_seedvr2`` exists and takes the tiling parameters."""
+    assert hasattr(sd_mod.VAE, "encode_tiled_seedvr2"), (
+        "VAE.encode_tiled_seedvr2 must be defined on the VAE class."
+    )
+    sig = inspect.signature(sd_mod.VAE.encode_tiled_seedvr2)
+    params = list(sig.parameters)
+    for required in ("self", "pixel_samples", "tile_x", "tile_y",
+                     "overlap", "tile_t", "overlap_t"):
+        assert required in params, (
+            f"VAE.encode_tiled_seedvr2 missing required parameter "
+            f"{required!r}; got parameters {params}."
+        )
+
+
+def test_vae_encode_tiled_allows_zero_temporal_controls_and_passes_zero_through():
+    """``VAEEncodeTiled`` accepts 0 for both temporal inputs.
+
+    With ``temporal_size=0`` the node is expected to forward
+    ``tile_t=0`` and ``overlap_t=0`` even though ``temporal_overlap=8``
+    was requested.
+    """
+    input_types = nodes_mod.VAEEncodeTiled.INPUT_TYPES()["required"]
+    assert input_types["temporal_size"][1]["min"] == 0
+    assert input_types["temporal_overlap"][1]["min"] == 0
+    assert "SeedVR2 allows 0" in input_types["temporal_size"][1]["tooltip"]
+
+    class _EncodeRecorder:
+        def __init__(self):
+            self.calls = []
+
+        def encode_tiled(self, pixels, **kwargs):
+            self.calls.append({"shape": tuple(pixels.shape), **kwargs})
+            return torch.zeros(1, 16, 1, 8, 8)
+
+    recorder = _EncodeRecorder()
+    node = nodes_mod.VAEEncodeTiled()
+
+    output = node.encode(
+        recorder,
+        torch.zeros(1, 64, 64, 3),
+        tile_size=256,
+        overlap=64,
+        temporal_size=0,
+        temporal_overlap=8,
+    )
+
+    assert recorder.calls == [
+        {
+            "shape": (1, 64, 64, 3),
+            "tile_x": 256,
+            "tile_y": 256,
+            "overlap": 64,
+            "tile_t": 0,
+            "overlap_t": 0,
+        }
+    ]
+    assert torch.equal(output[0]["samples"], torch.zeros(1, 16, 1, 8, 8))
+
+
+def test_method_routes_through_tiled_vae_encode_true():
+    """Every ``tiled_vae`` call made by the method passes ``encode=True``."""
+    vae = _make_minimal_seedvr2_vae()
+    pixel_samples = torch.zeros((1, 3, 8, 64, 64))
+
+    tiled_vae_mock = MagicMock(return_value=torch.zeros((1, 16, 2, 8, 8)))
+
+    with patch.object(seedvr_vae_mod, "tiled_vae", tiled_vae_mock):
+        vae.encode_tiled_seedvr2(pixel_samples)
+
+    assert tiled_vae_mock.call_count >= 1, (
+        f"Expected encode_tiled_seedvr2 to delegate to tiled_vae at "
+        f"least once; got {tiled_vae_mock.call_count} calls."
+    )
+    for call in tiled_vae_mock.call_args_list:
+        assert call.kwargs.get("encode") is True, (
+            f"Every tiled_vae delegation from encode_tiled_seedvr2 must "
+            f"pass encode=True; got kwargs={call.kwargs!r}."
+        )
+
+
+def test_method_sets_wrapper_device_before_tiled_vae():
+    """The wrapper's ``device`` is set by the time ``tiled_vae`` runs."""
+    vae = _make_minimal_seedvr2_vae()
+    pixel_samples = torch.zeros((1, 3, 8, 64, 64))
+    assert not hasattr(vae.first_stage_model, "device")
+
+    def _assert_device_initialized(*args, **kwargs):
+        # Assumes the wrapper is the second positional argument.
+        vae_model = args[1]
+        assert vae_model.device == vae.device
+        return torch.zeros((1, 16, 2, 8, 8))
+
+    tiled_vae_mock = MagicMock(side_effect=_assert_device_initialized)
+
+    with patch.object(seedvr_vae_mod, "tiled_vae", tiled_vae_mock):
+        vae.encode_tiled_seedvr2(pixel_samples)
+
+    # The device check lives inside the mock's side effect, so without
+    # this the test would pass vacuously if tiled_vae were never called.
+    assert tiled_vae_mock.call_count >= 1, (
+        "encode_tiled_seedvr2 never called tiled_vae, so the wrapper "
+        "device check did not run."
+    )
+
+
+def test_method_honors_explicit_tile_parameters_over_stale_wrapper_args():
+    """Explicit tile arguments win over the wrapper's ``tiled_args``.
+
+    Unrelated ``tiled_args`` entries must survive the call.
+    """
+    vae = _make_minimal_seedvr2_vae()
+    pixel_samples = torch.zeros((1, 3, 8, 64, 64))
+    vae.first_stage_model.tiled_args = {
+        "tile_size": (17, 19),
+        "tile_overlap": (3, 5),
+        "temporal_size": 7,
+        "temporal_overlap": 2,
+        "preserved": "value",
+    }
+
+    tiled_vae_mock = MagicMock(return_value=torch.zeros((1, 16, 2, 8, 8)))
+
+    with patch.object(seedvr_vae_mod, "tiled_vae", tiled_vae_mock):
+        vae.encode_tiled_seedvr2(
+            pixel_samples,
+            tile_x=96,
+            tile_y=80,
+            overlap=12,
+            tile_t=11,
+            overlap_t=4,
+        )
+
+    assert tiled_vae_mock.call_args.kwargs["tile_size"] == (80, 96)
+    assert tiled_vae_mock.call_args.kwargs["tile_overlap"] == (12, 12)
+    assert tiled_vae_mock.call_args.kwargs["temporal_size"] == 11
+    assert tiled_vae_mock.call_args.kwargs["temporal_overlap"] == 4
+    assert vae.first_stage_model.tiled_args["preserved"] == "value"
+
+
+def test_method_uses_explicit_defaults_when_call_omits_tile_parameters():
+    """Omitted tile arguments fall back to fixed defaults.
+
+    The wrapper's pre-existing ``tiled_args`` are neither used nor
+    modified.
+    """
+    vae = _make_minimal_seedvr2_vae()
+    pixel_samples = torch.zeros((1, 3, 8, 64, 64))
+    vae.first_stage_model.tiled_args = {
+        "tile_size": (128, 160),
+        "tile_overlap": (16, 24),
+        "temporal_size": 9,
+        "temporal_overlap": 1,
+    }
+
+    tiled_vae_mock = MagicMock(return_value=torch.zeros((1, 16, 2, 8, 8)))
+
+    with patch.object(seedvr_vae_mod, "tiled_vae", tiled_vae_mock):
+        vae.encode_tiled_seedvr2(pixel_samples)
+
+    assert tiled_vae_mock.call_args.kwargs["tile_size"] == (512, 512)
+    assert tiled_vae_mock.call_args.kwargs["tile_overlap"] == (64, 64)
+    assert tiled_vae_mock.call_args.kwargs["temporal_size"] == 9999
+    assert tiled_vae_mock.call_args.kwargs["temporal_overlap"] == 0
+    assert vae.first_stage_model.tiled_args == {
+        "tile_size": (128, 160),
+        "tile_overlap": (16, 24),
+        "temporal_size": 9,
+        "temporal_overlap": 1,
+    }
+
+
+def test_method_clamps_overlap_below_tile_size():
+    """An overlap larger than the tile is clamped per axis.
+
+    The expected ``(40, 56)`` is each ``(tile_y, tile_x)`` minus 8.
+    """
+    vae = _make_minimal_seedvr2_vae()
+    pixel_samples = torch.zeros((1, 3, 8, 64, 64))
+
+    tiled_vae_mock = MagicMock(return_value=torch.zeros((1, 16, 2, 8, 8)))
+
+    with patch.object(seedvr_vae_mod, "tiled_vae", tiled_vae_mock):
+        vae.encode_tiled_seedvr2(
+            pixel_samples,
+            tile_x=64,
+            tile_y=48,
+            overlap=96,
+        )
+
+    assert tiled_vae_mock.call_args.kwargs["tile_overlap"] == (40, 56)