Add SeedVR2 VAE tiling coverage

This commit is contained in:
John Pollock 2026-05-25 22:12:33 -05:00
parent 9eb6c7fe9e
commit c3bfb743e8
15 changed files with 2004 additions and 0 deletions

View File

@ -0,0 +1,61 @@
"""Regression test for ``comfy_extras.nodes_seedvr.clear_vae_memory`` —
must dispatch its cache clear via ``comfy.model_management.soft_empty_cache``
rather than calling ``torch.cuda.empty_cache()`` directly. The canonical helper
at ``comfy/model_management.py:1780`` short-circuits via ``cpu_mode()`` and
dispatches per-backend (MPS / XPU / NPU / MLU / CUDA), so it is the only
correct call shape on non-CUDA hosts and on managed-device hosts where
``comfy.cli_args.args.cpu`` is True.
"""
from unittest.mock import patch
import torch
# CPU-only CI fix: ``comfy_extras.nodes_seedvr`` transitively imports
# ``comfy.model_management``, whose module-level
# ``cpu_state = CPUState.CPU if args.cpu`` initialiser
# (``comfy/model_management.py:152-153``) reads ``comfy.cli_args.args.cpu``
# at import time. Match the pattern at
# ``tests-unit/comfy_test/test_seedvr_vae_decode_unpadded_t.py:33-44``: flip
# ``args.cpu`` BEFORE importing any ``comfy.ldm.*`` or ``comfy_extras.*``
# symbol. This module forces ``args.cpu = True`` unconditionally (rather
# than only when ``torch.cuda.is_available()`` is False) so ``cpu_mode()``
# returns True at call time regardless of host CUDA availability — the
# path under test is ``soft_empty_cache``'s CPU-mode short-circuit at
# ``comfy/model_management.py:1781``.
from comfy.cli_args import args as _cli_args
_cli_args.cpu = True
import comfy.model_management # noqa: E402
import comfy_extras.nodes_seedvr as nodes_seedvr # noqa: E402
def test_clear_vae_memory_uses_soft_empty_cache():
"""``clear_vae_memory(stub)`` must invoke
``comfy.model_management.soft_empty_cache`` exactly once and
``torch.cuda.empty_cache`` zero times when ``args.cpu`` is True.
"""
stub = torch.nn.Module()
with patch.object(
comfy.model_management, "soft_empty_cache"
) as soft_empty_spy, patch.object(
torch.cuda, "empty_cache"
) as cuda_empty_spy:
nodes_seedvr.clear_vae_memory(stub)
assert cuda_empty_spy.call_count == 0, (
f"torch.cuda.empty_cache was called {cuda_empty_spy.call_count} "
f"times; expected 0. clear_vae_memory must dispatch via "
f"comfy.model_management.soft_empty_cache, which short-circuits in "
f"CPU mode (cpu_mode() check at comfy/model_management.py:1781). "
f"The unguarded torch.cuda.empty_cache() call at "
f"comfy_extras/nodes_seedvr.py:84 is the regression this test locks."
)
assert soft_empty_spy.call_count == 1, (
f"comfy.model_management.soft_empty_cache was called "
f"{soft_empty_spy.call_count} times; expected exactly 1. "
f"clear_vae_memory must dispatch its cache clear via the canonical "
f"per-backend helper at comfy/model_management.py:1780."
)

View File

@ -0,0 +1,356 @@
from unittest.mock import patch
import torch
import torch.nn as nn
from comfy.cli_args import args as cli_args
if not torch.cuda.is_available():
cli_args.cpu = True
import comfy.ldm.seedvr.vae as vae_mod # noqa: E402
import comfy.sd as sd_mod # noqa: E402
import nodes as nodes_mod # noqa: E402
def _lab_color_passthrough(content, style):
return content
def _decode_fingerprint(self, z, return_dict=True):
b, _, t, h, w = z.shape
out = torch.empty(b, 3, t, h * 8, w * 8, dtype=z.dtype, device=z.device)
for batch_idx in range(b):
out[batch_idx].fill_(float(batch_idx + 1))
return out
def _make_wrapper(b=2, t=3, enable_tiling=False):
wrapper = vae_mod.VideoAutoencoderKLWrapper.__new__(
vae_mod.VideoAutoencoderKLWrapper
)
nn.Module.__init__(wrapper)
wrapper.tiled_args = {"enable_tiling": enable_tiling}
wrapper.original_image_video = torch.zeros(b, 3, t, 16, 16)
wrapper.img_dims = (16, 16)
return wrapper
def test_seedvr2_decode_accepts_5d_bcthw_latents_and_preserves_batch_time_axes():
wrapper = _make_wrapper(b=2, t=3, enable_tiling=False)
latent = torch.zeros(2, 16, 3, 2, 2)
with patch.object(vae_mod.VideoAutoencoderKL, "decode_", _decode_fingerprint), \
patch.object(vae_mod, "lab_color_transfer", _lab_color_passthrough):
out = wrapper.decode(latent)
assert tuple(out.shape) == (2, 3, 3, 16, 16)
assert out[0, 0, 0, 0, 0].item() == 1.0
assert out[1, 0, 0, 0, 0].item() == 2.0
class _SeedVR2DecodeStub(vae_mod.VideoAutoencoderKLWrapper):
def __init__(self):
nn.Module.__init__(self)
self.tiled_args = {}
self.calls = []
self.original_image_video = torch.zeros(1, 3, 12, 16, 16)
self.spatial_downsample_factor = 8
self.temporal_downsample_factor = 4
def decode(self, z, seedvr2_tiling=None):
self.calls.append({"seedvr2_tiling": seedvr2_tiling, "shape": tuple(z.shape)})
return z
def test_vae_decode_tiled_allows_zero_temporal_controls_and_passes_them_through():
input_types = nodes_mod.VAEDecodeTiled.INPUT_TYPES()["required"]
assert input_types["temporal_size"][1]["min"] == 0
assert input_types["temporal_overlap"][1]["min"] == 0
assert "SeedVR2 allows 0" in input_types["temporal_size"][1]["tooltip"]
class _DecodeRecorder:
def __init__(self):
self.calls = []
def temporal_compression_decode(self):
return 4
def spacial_compression_decode(self):
return 8
def decode_tiled(self, samples, **kwargs):
self.calls.append({"shape": tuple(samples.shape), **kwargs})
return torch.zeros(1, 8, 8, 3)
recorder = _DecodeRecorder()
node = nodes_mod.VAEDecodeTiled()
node.decode(
recorder,
{"samples": torch.zeros(1, 16, 3, 32, 32)},
tile_size=256,
overlap=64,
temporal_size=0,
temporal_overlap=0,
)
assert recorder.calls == [
{
"shape": (1, 16, 3, 32, 32),
"tile_x": 32,
"tile_y": 32,
"overlap": 8,
"tile_t": 0,
"overlap_t": 0,
}
]
def test_vae_decode_tiled_preserves_positive_overlap_after_temporal_compression():
class _DecodeRecorder:
def __init__(self):
self.calls = []
def temporal_compression_decode(self):
return 8
def spacial_compression_decode(self):
return 8
def decode_tiled(self, samples, **kwargs):
self.calls.append(kwargs)
return torch.zeros(1, 8, 8, 3)
recorder = _DecodeRecorder()
nodes_mod.VAEDecodeTiled().decode(
recorder,
{"samples": torch.zeros(1, 16, 3, 32, 32)},
tile_size=256,
overlap=64,
temporal_size=64,
temporal_overlap=4,
)
assert recorder.calls[0]["tile_t"] == 8
assert recorder.calls[0]["overlap_t"] == 1
def test_seedvr2_decode_tiled_uses_seedvr2_path_not_generic_3d_tiler(monkeypatch):
vae = sd_mod.VAE.__new__(sd_mod.VAE)
vae.first_stage_model = _SeedVR2DecodeStub()
vae.vae_dtype = torch.float32
vae.device = torch.device("cpu")
vae.output_device = torch.device("cpu")
vae.disable_offload = True
vae.extra_1d_channel = None
vae.memory_used_decode = lambda shape, dtype: 1
vae.process_output = lambda x: x
vae.patcher = object()
monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None)
monkeypatch.setattr(sd_mod.VAE, "decode_tiled_3d", lambda *a, **k: (_ for _ in ()).throw(AssertionError("generic decode_tiled_3d called")))
latent = torch.zeros(1, 16, 3, 2, 2)
out = vae.decode_tiled(latent, tile_x=2, tile_y=2, overlap=1, tile_t=16, overlap_t=4)
assert tuple(out.shape) == (1, 3, 2, 2, 16)
assert vae.first_stage_model.calls == [
{
"shape": (1, 16, 3, 2, 2),
"seedvr2_tiling": {
"enable_tiling": True,
"tile_size": (16, 16),
"tile_overlap": (8, 8),
"temporal_size": 64,
"temporal_overlap": 16,
},
}
]
def test_seedvr2_decode_tiled_explicit_args_override_stale_tiled_args():
vae = sd_mod.VAE.__new__(sd_mod.VAE)
vae.first_stage_model = _SeedVR2DecodeStub()
vae.first_stage_model.tiled_args = {
"enable_tiling": False,
"tile_size": (384, 384),
"tile_overlap": (128, 128),
"temporal_size": 16,
"temporal_overlap": 4,
"preserved": "metadata",
}
vae.vae_dtype = torch.float32
vae.device = torch.device("cpu")
vae.output_device = torch.device("cpu")
vae.disable_offload = True
vae.extra_1d_channel = None
vae.memory_used_decode = lambda shape, dtype: 1
vae.process_output = lambda x: x
vae.patcher = object()
latent = torch.zeros(1, 16, 3, 2, 2)
vae.decode_tiled_seedvr2(
latent,
tile_x=32,
tile_y=32,
overlap=8,
tile_t=0,
overlap_t=0,
)
captured = vae.first_stage_model.calls[0]["seedvr2_tiling"]
assert captured["enable_tiling"] is True
assert captured["tile_size"] == (256, 256)
assert captured["tile_overlap"] == (64, 64)
assert captured["temporal_size"] == 0
assert captured["temporal_overlap"] == 0
assert "preserved" not in captured
assert vae.first_stage_model.tiled_args == {
"enable_tiling": False,
"tile_size": (384, 384),
"tile_overlap": (128, 128),
"temporal_size": 16,
"temporal_overlap": 4,
"preserved": "metadata",
}
def test_seedvr2_decode_preserves_requested_spatial_tile_above_512(monkeypatch):
wrapper = vae_mod.VideoAutoencoderKLWrapper.__new__(
vae_mod.VideoAutoencoderKLWrapper
)
nn.Module.__init__(wrapper)
captured = {}
def fake_tiled_vae(latent, model, **kwargs):
captured.update(kwargs)
return torch.zeros(1, 3, 1, 16, 16)
monkeypatch.setattr(vae_mod, "tiled_vae", fake_tiled_vae)
wrapper.decode(
torch.zeros(1, 16, 1, 2, 2),
seedvr2_tiling={
"enable_tiling": True,
"tile_size": (1024, 768),
"tile_overlap": (800, 800),
"temporal_size": 0,
"temporal_overlap": 0,
},
)
assert captured["tile_size"] == (1024, 768)
assert captured["tile_overlap"] == (800, 760)
def test_seedvr2_decode_tiled_preserves_ambiguous_channel_first_latents(monkeypatch):
vae = sd_mod.VAE.__new__(sd_mod.VAE)
vae.first_stage_model = _SeedVR2DecodeStub()
vae.vae_dtype = torch.float32
vae.device = torch.device("cpu")
vae.output_device = torch.device("cpu")
vae.disable_offload = True
vae.extra_1d_channel = None
vae.latent_channels = 16
vae.memory_used_decode = lambda shape, dtype: 1
vae.process_output = lambda x: x
vae.patcher = object()
monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None)
monkeypatch.setattr(sd_mod.VAE, "decode_tiled_3d", lambda *a, **k: (_ for _ in ()).throw(AssertionError("generic decode_tiled_3d called")))
latent = torch.zeros(1, 16, 8, 8, 16)
vae.decode_tiled(latent, tile_x=2, tile_y=2, overlap=1, tile_t=16, overlap_t=4)
assert vae.first_stage_model.calls[0]["shape"] == (1, 16, 8, 8, 16)
def test_seedvr2_decode_tiled_does_not_repair_latent_layout(monkeypatch):
vae = sd_mod.VAE.__new__(sd_mod.VAE)
vae.first_stage_model = _SeedVR2DecodeStub()
vae.vae_dtype = torch.float32
vae.device = torch.device("cpu")
vae.output_device = torch.device("cpu")
vae.disable_offload = True
vae.extra_1d_channel = None
vae.latent_channels = 16
vae.memory_used_decode = lambda shape, dtype: 1
vae.process_output = lambda x: x
vae.patcher = object()
monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None)
monkeypatch.setattr(sd_mod.VAE, "decode_tiled_3d", lambda *a, **k: (_ for _ in ()).throw(AssertionError("generic decode_tiled_3d called")))
latent = torch.zeros(1, 9, 8, 8, 16)
vae.decode_tiled(latent, tile_x=2, tile_y=2, overlap=1, tile_t=16, overlap_t=4)
assert vae.first_stage_model.calls[0]["shape"] == (1, 9, 8, 8, 16)
def test_seedvr2_decode_tiled_routes_collapsed_latents_to_seedvr2_tiler(monkeypatch):
vae = sd_mod.VAE.__new__(sd_mod.VAE)
vae.first_stage_model = _SeedVR2DecodeStub()
vae.vae_dtype = torch.float32
vae.device = torch.device("cpu")
vae.output_device = torch.device("cpu")
vae.disable_offload = True
vae.extra_1d_channel = None
vae.latent_channels = 16
vae.memory_used_decode = lambda shape, dtype: 1
vae.process_output = lambda x: x
vae.patcher = object()
monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None)
monkeypatch.setattr(sd_mod.VAE, "decode_tiled_", lambda *a, **k: (_ for _ in ()).throw(AssertionError("generic decode_tiled_ called")))
latent = torch.zeros(1, 48, 2, 2)
vae.decode_tiled(latent, tile_x=2, tile_y=2, overlap=1, tile_t=16, overlap_t=4)
assert vae.first_stage_model.calls[0]["shape"] == (1, 48, 2, 2)
assert vae.first_stage_model.calls[0]["seedvr2_tiling"]["temporal_overlap"] == 16
class _TemporalChunkRecorder(nn.Module):
def __init__(self):
super().__init__()
self.weight = nn.Parameter(torch.zeros(()))
self.device = "cpu"
self.spatial_downsample_factor = 1
self.temporal_downsample_factor = 4
self.chunks = []
def decode_(self, z):
self.chunks.append([int(v) for v in z[0, 0, :, 0, 0].tolist()])
pieces = [z[:, :1, :1]]
if z.shape[2] > 1:
pieces.append(z[:, :1, 1:].repeat_interleave(4, dim=2))
return torch.cat(pieces, dim=2)
def test_seedvr2_tiled_vae_decode_uses_single_slicing_call_per_spatial_tile():
"""After the temporal-stitching fix, run_temporal_chunks delegates to
the wrapper's slicing path with a single decode_ call per spatial tile
(rather than the old hand-rolled outer temporal chunking that reset
causal cache between chunks). Validate the new contract: recorder sees
one call covering the full temporal axis, output shape and value
pattern are equivalent to what the temporal-overlap path produced.
"""
recorder = _TemporalChunkRecorder()
latent = torch.arange(6, dtype=torch.float32).view(1, 1, 6, 1, 1)
out = vae_mod.tiled_vae(
latent,
recorder,
tile_size=(1, 1),
tile_overlap=(0, 0),
temporal_size=16,
temporal_overlap=4,
encode=False,
)
assert recorder.chunks == [[0, 1, 2, 3, 4, 5]]
assert tuple(out.shape) == (1, 1, 21, 1, 1)
assert [int(v) for v in out[0, 0, [0, 1, 5, 9, 13, 17], 0, 0].tolist()] == [0, 1, 2, 3, 4, 5]

View File

@ -0,0 +1,133 @@
from unittest.mock import patch
import torch
import torch.nn as nn
from comfy.cli_args import args as cli_args
if not torch.cuda.is_available():
cli_args.cpu = True
import comfy.ldm.seedvr.vae as vae_mod # noqa: E402
def _make_wrapper() -> vae_mod.VideoAutoencoderKLWrapper:
wrapper = vae_mod.VideoAutoencoderKLWrapper.__new__(
vae_mod.VideoAutoencoderKLWrapper
)
nn.Module.__init__(wrapper)
return wrapper
def _fingerprint_decode_(self, z, return_dict=True):
b = int(z.shape[0])
t = int(z.shape[2])
h = int(z.shape[3])
w = int(z.shape[4])
out = torch.empty(b, 3, t, h * 8, w * 8)
for batch_idx in range(b):
out[batch_idx].fill_(float(batch_idx + 1))
return out
def _decode_with_patches(wrapper, z):
with patch.object(vae_mod.VideoAutoencoderKL, "decode_", _fingerprint_decode_):
return wrapper.decode(z)
def test_decode_b1_t1_shape_and_ordering_correct():
wrapper = _make_wrapper()
out = _decode_with_patches(wrapper, torch.zeros(1, 16, 2, 2))
assert tuple(out.shape) == (1, 3, 1, 16, 16)
assert out[0, 0, 0, 0, 0].item() == 1.0
def test_decode_b1_t5_video_shape_unchanged():
wrapper = _make_wrapper()
out = _decode_with_patches(wrapper, torch.zeros(1, 16 * 5, 2, 2))
assert tuple(out.shape) == (1, 3, 5, 16, 16)
def test_decode_b2_t1_preserves_batch_time_axes():
wrapper = _make_wrapper()
out = _decode_with_patches(wrapper, torch.zeros(2, 16, 2, 2))
assert tuple(out.shape) == (2, 3, 1, 16, 16)
assert out[0, 0, 0, 0, 0].item() == 1.0
assert out[1, 0, 0, 0, 0].item() == 2.0
def test_decode_b4_t1_preserves_batch_time_axes():
wrapper = _make_wrapper()
out = _decode_with_patches(wrapper, torch.zeros(4, 16, 2, 2))
assert tuple(out.shape) == (4, 3, 1, 16, 16)
assert [out[b, 0, 0, 0, 0].item() for b in range(4)] == [1.0, 2.0, 3.0, 4.0]
def test_decode_b2_t3_multi_frame_batch_unchanged():
wrapper = _make_wrapper()
out = _decode_with_patches(wrapper, torch.zeros(2, 16 * 3, 2, 2))
assert tuple(out.shape) == (2, 3, 3, 16, 16)
def _tiled_vae_4d_stub(latent, vae_model, **kwargs):
b = int(latent.shape[0])
h = int(latent.shape[3]) * 8
w = int(latent.shape[4]) * 8
out = torch.empty(b, 3, h, w)
for batch_idx in range(b):
out[batch_idx].fill_(float(batch_idx + 1))
return out
def test_decode_tiled_single_frame_4d_output_normalized():
wrapper = _make_wrapper()
with patch.object(vae_mod, "tiled_vae", _tiled_vae_4d_stub):
out = wrapper.decode(torch.zeros(1, 16, 2, 2), seedvr2_tiling={"enable_tiling": True})
assert tuple(out.shape) == (1, 3, 1, 16, 16)
assert out[0, 0, 0, 0, 0].item() == 1.0
def test_decode_tiled_b2_t1_per_sample_ordering():
wrapper = _make_wrapper()
with patch.object(vae_mod, "tiled_vae", _tiled_vae_4d_stub):
out = wrapper.decode(torch.zeros(2, 16, 2, 2), seedvr2_tiling={"enable_tiling": True})
assert tuple(out.shape) == (2, 3, 1, 16, 16)
assert out[0, 0, 0, 0, 0].item() == 1.0
assert out[1, 0, 0, 0, 0].item() == 2.0
def test_decode_b2_t1_stacked_equals_individual_per_sample_ordering():
wrapper = _make_wrapper()
out_stacked = _decode_with_patches(wrapper, torch.zeros(2, 16, 2, 2))
def _decode_pinned(value):
def _stub(self, z, return_dict=True):
b = int(z.shape[0])
t = int(z.shape[2])
h = int(z.shape[3])
w = int(z.shape[4])
return torch.full((b, 3, t, h * 8, w * 8), value)
return _stub
with patch.object(vae_mod.VideoAutoencoderKL, "decode_", _decode_pinned(1.0)):
out_individual_0 = wrapper.decode(torch.zeros(1, 16, 2, 2))
with patch.object(vae_mod.VideoAutoencoderKL, "decode_", _decode_pinned(2.0)):
out_individual_1 = wrapper.decode(torch.zeros(1, 16, 2, 2))
assert torch.equal(out_stacked[0, :, 0, :, :], out_individual_0[0, :, 0, :, :])
assert torch.equal(out_stacked[1, :, 0, :, :], out_individual_1[0, :, 0, :, :])

View File

@ -0,0 +1,85 @@
from unittest.mock import patch
import pytest
import torch
import torch.nn as nn
from comfy.cli_args import args as cli_args
if not torch.cuda.is_available():
cli_args.cpu = True
import comfy.ldm.seedvr.vae as vae_mod # noqa: E402
class _Wrapper(vae_mod.VideoAutoencoderKLWrapper):
def __init__(self):
nn.Module.__init__(self)
self.calls = []
def parameters(self):
return iter([torch.nn.Parameter(torch.zeros(()))])
def _decode_stub(self, latent):
self.calls.append(tuple(latent.shape))
return torch.zeros(latent.shape[0], 3, latent.shape[2], latent.shape[3] * 8, latent.shape[4] * 8)
def test_seedvr2_wrapper_decode_accepts_5d_channel_first_latents_without_preprocessor_state():
wrapper = _Wrapper()
with patch.object(vae_mod.VideoAutoencoderKL, "decode_", _decode_stub):
out = wrapper.decode(torch.zeros(1, 16, 2, 4, 5))
assert tuple(out.shape) == (1, 3, 2, 32, 40)
assert wrapper.calls == [(1, 16, 2, 4, 5)]
def test_seedvr2_wrapper_decode_accepts_collapsed_4d_latents_without_preprocessor_state():
wrapper = _Wrapper()
with patch.object(vae_mod.VideoAutoencoderKL, "decode_", _decode_stub):
out = wrapper.decode(torch.zeros(1, 32, 4, 5))
assert tuple(out.shape) == (1, 3, 2, 32, 40)
assert wrapper.calls == [(1, 16, 2, 4, 5)]
def test_seedvr2_wrapper_decode_accepts_noncontiguous_collapsed_4d_latents():
wrapper = _Wrapper()
latent = torch.zeros(1, 4, 5, 32).permute(0, 3, 1, 2)
with patch.object(vae_mod.VideoAutoencoderKL, "decode_", _decode_stub):
out = wrapper.decode(latent)
assert not latent.is_contiguous()
assert tuple(out.shape) == (1, 3, 2, 32, 40)
assert wrapper.calls == [(1, 16, 2, 4, 5)]
def test_seedvr2_wrapper_decode_rejects_non_dict_tiling_options():
wrapper = _Wrapper()
with pytest.raises(RuntimeError, match="seedvr2_tiling.*dict"):
wrapper.decode(torch.zeros(1, 16, 2, 4, 5), seedvr2_tiling=True)
def test_seedvr2_wrapper_decode_rejects_wrong_5d_channel_count():
wrapper = _Wrapper()
with pytest.raises(RuntimeError, match="5-D latent input must have 16 channels"):
wrapper.decode(torch.zeros(1, 8, 2, 4, 5))
def test_seedvr2_wrapper_decode_rejects_misaligned_collapsed_4d_latents():
wrapper = _Wrapper()
with pytest.raises(RuntimeError, match=r"4-D latent input must use collapsed channel layout"):
wrapper.decode(torch.zeros(1, 17, 4, 5))
def test_seedvr2_wrapper_decode_rejects_wrong_rank_latents():
wrapper = _Wrapper()
with pytest.raises(RuntimeError, match=r"latent input must be 4-D collapsed .* or 5-D"):
wrapper.decode(torch.zeros(1, 16, 4))

View File

@ -0,0 +1,35 @@
import pytest
import torch
from comfy.cli_args import args as cli_args
if not torch.cuda.is_available():
cli_args.cpu = True
from comfy_extras import nodes_seedvr # noqa: E402
def _t_padded(t_in: int) -> int:
if t_in == 1:
return 1
if t_in <= 4:
return 5
if (t_in - 1) % 4 == 0:
return t_in
return t_in + (4 - ((t_in - 1) % 4))
@pytest.mark.parametrize("t_in", [1, 2, 3, 4, 5, 6, 7, 8])
def test_t_padded_matches_cut_videos(t_in):
dummy = torch.zeros(1, t_in, 1, 1, 1)
assert nodes_seedvr.cut_videos(dummy).shape[1] == _t_padded(t_in)
@pytest.mark.parametrize("t_in", [1, 2, 3, 4, 5, 6, 7, 8])
def test_post_processing_trims_decoded_video_to_explicit_reference_frames(t_in):
decoded = torch.zeros(1, _t_padded(t_in), 32, 32, 3)
original = torch.zeros(1, t_in, 32, 32, 3)
output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 32, "none").result[0]
assert tuple(output.shape) == (1, t_in, 32, 32, 3)

View File

@ -0,0 +1,165 @@
"""Regression test for ``comfy/sd.py``'s ``VAE.__init__`` loader — must
apply SeedVR2-specific metadata when the SeedVR2 magic key
``decoder.up_blocks.2.upsamplers.0.upscale_conv.weight`` is present in the
state dict.
Without the SeedVR2 elif branch the loader leaves ``latent_channels=4`` /
``latent_dim=2`` defaults, so down-stream consumers mis-shape the latent
buffer and crash with a channel-count mismatch. The expected behaviour
sets ``latent_channels=16``, ``latent_dim=3``, ``disable_offload=True``,
``downscale_index_formula=(4, 8, 8)``, ``upscale_index_formula=(4, 8,
8)``, plus the SeedVR2 ``memory_used_decode`` / ``memory_used_encode``
lambdas, the ``downscale_ratio`` / ``upscale_ratio`` tuples, and the
SeedVR2 ``process_input`` / ``crop_input=False`` overrides.
This module exercises the real ``VAE.__init__`` detection-and-load path
with a stubbed state dict containing only the SeedVR2 magic key, and
patches ``comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper`` with a tiny
``nn.Module`` subclass so the test stays CPU-only and weight-load-free
while still satisfying ``isinstance(...)`` against the real wrapper class
(see ``_StubVideoAutoencoderKLWrapper`` below).
"""
from unittest.mock import patch
import pytest
import torch
# CPU-only CI fix: ``comfy.sd`` transitively imports
# ``comfy.model_management``, whose import-time
# ``cpu_state = CPUState.CPU if args.cpu`` initialiser reads
# ``comfy.cli_args.args.cpu``. Match the pattern at
# ``tests-unit/comfy_test/test_seedvr_vae_decode_unpadded_t.py:33-44``: flip
# ``args.cpu`` BEFORE importing any ``comfy.sd`` / ``comfy.ldm.*`` symbol
# when CUDA is unavailable. Issue-191 AC-3 additionally requires the
# ``_cli_args.cpu = True`` assignment line number to precede every line
# matching ``^import comfy`` or ``^from comfy`` in the committed file, so
# the cli_args module is loaded via ``importlib`` here rather than via
# ``from comfy.cli_args import args``.
import importlib
_cli_args = importlib.import_module("comfy.cli_args").args
if not torch.cuda.is_available():
_cli_args.cpu = True
import torch.nn as nn # noqa: E402
import comfy.ldm.seedvr.vae as seedvr_vae # noqa: E402
import comfy.sd # noqa: E402
_SEEDVR2_MAGIC_KEY = "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight"
class _StubVideoAutoencoderKLWrapper(seedvr_vae.VideoAutoencoderKLWrapper):
"""Subclass that bypasses the real wrapper's heavy weight construction.
The downstream ``comfy.sd.VAE.__init__`` lifecycle after line 519 only
relies on ``nn.Module`` machinery ``.eval()``, ``.to(dtype)``,
``state_dict()`` for ``module_size``, and
``load_state_dict(strict=False)``. A bare ``nn.Module.__init__`` provides
all of that. Subclassing ``VideoAutoencoderKLWrapper`` keeps
``isinstance(stub_instance, VideoAutoencoderKLWrapper)`` ``True`` after
the patch context exits, so the AC-A isinstance assertion holds against
the real wrapper class.
"""
def __init__(self):
nn.Module.__init__(self)
def _build_seedvr2_stub_sd():
"""Minimum state dict that triggers the SeedVR2 elif branch in
``comfy/sd.py``. The detection is a pure ``in sd`` containment check
against the magic key at line 518; no other key is required to reach
that branch (the diffusers-convert early-out at lines 444-446 is
short-circuited by the ``is_seedvr2_vae`` flag set at line 443).
The ``load_state_dict`` call at line 884 uses ``strict=False`` so the
single magic key is accepted as ``unexpected`` against the empty stub
module without raising.
"""
return {_SEEDVR2_MAGIC_KEY: torch.zeros(1)}
@pytest.fixture(scope="module")
def seedvr2_vae():
"""Build a real ``comfy.sd.VAE`` instance through the detection-and-load
path with the SeedVR2 wrapper class stubbed for CPU-only execution.
"""
sd = _build_seedvr2_stub_sd()
with patch.object(
seedvr_vae,
"VideoAutoencoderKLWrapper",
_StubVideoAutoencoderKLWrapper,
):
vae = comfy.sd.VAE(sd=sd)
return vae
def test_seedvr2_loader_first_stage_model_is_video_autoencoder_kl_wrapper(
seedvr2_vae,
):
assert isinstance(
seedvr2_vae.first_stage_model, seedvr_vae.VideoAutoencoderKLWrapper
) is True, (
"Expected first_stage_model to be a VideoAutoencoderKLWrapper "
f"instance; got {type(seedvr2_vae.first_stage_model).__name__}. The "
"SeedVR2 elif branch at comfy/sd.py:518 may not have been taken."
)
def test_seedvr2_loader_sets_latent_channels_16(seedvr2_vae):
assert seedvr2_vae.latent_channels == 16, (
"Expected latent_channels=16 (set at comfy/sd.py:520 inside the "
f"SeedVR2 elif branch); got {seedvr2_vae.latent_channels}. SeedVR2's "
"VideoAutoencoderKL uses 16-channel latents per Wang et al., ICLR "
"2026 (arXiv 2506.05301) §3; the loader default of 4 (comfy/sd.py:457)"
" is wrong for the SeedVR2 path."
)
def test_seedvr2_loader_sets_latent_dim_3(seedvr2_vae):
assert seedvr2_vae.latent_dim == 3, (
"Expected latent_dim=3 (set at comfy/sd.py:521 inside the SeedVR2 "
f"elif branch); got {seedvr2_vae.latent_dim}. SeedVR2 latents are 3D "
"(T, H, W) per the upstream ByteDance-Seed/SeedVR "
"VideoAutoencoderKL contract; the loader default of 2 "
"(comfy/sd.py:458) is wrong for the SeedVR2 path."
)
def test_seedvr2_loader_sets_downscale_index_formula(seedvr2_vae):
assert seedvr2_vae.downscale_index_formula == (4, 8, 8), (
"Expected downscale_index_formula=(4, 8, 8) (set at "
f"comfy/sd.py:527); got {seedvr2_vae.downscale_index_formula}. "
"SeedVR2's spatial-temporal downscale ratio is 4× temporal × 8× "
"spatial × 8× spatial."
)
def test_seedvr2_loader_sets_upscale_index_formula(seedvr2_vae):
assert seedvr2_vae.upscale_index_formula == (4, 8, 8), (
"Expected upscale_index_formula=(4, 8, 8) (set at "
f"comfy/sd.py:529); got {seedvr2_vae.upscale_index_formula}. "
"SeedVR2's spatial-temporal upscale ratio is the inverse of its "
"downscale ratio: 4× temporal × 8× spatial × 8× spatial."
)
def test_seedvr2_loader_sets_disable_offload(seedvr2_vae):
assert seedvr2_vae.disable_offload is True, (
"Expected disable_offload=True (set at comfy/sd.py:522); got "
f"{seedvr2_vae.disable_offload}. SeedVR2 cannot tolerate CPU "
"offload during decode (the wrapper retains memory-state references "
"across slice boundaries — see VideoAutoencoderKL.slicing_decode)."
)
def test_seedvr2_loader_normalizes_comfy_pixels_at_vae_boundary(seedvr2_vae):
pixels = torch.tensor([0.0, 0.5, 1.0])
normalized = seedvr2_vae.process_input(pixels)
assert torch.equal(normalized, torch.tensor([-1.0, 0.0, 1.0]))

View File

@ -0,0 +1,11 @@
import re
from pathlib import Path
def test_seedvr_vae_decode_uses_explicit_tiling_options_not_object_state():
path = Path(__file__).resolve().parents[2] / "comfy" / "ldm" / "seedvr" / "vae.py"
src = path.read_text(encoding="utf-8")
assert not re.search(r"(?:self\.)?tiled_args\b", src), (
"VideoAutoencoderKLWrapper.decode must not read or mutate tiled_args "
f"object state. Source path: {path}"
)

View File

@ -0,0 +1,78 @@
from copy import deepcopy
def _valid_probe_payload():
sha = "0" * 64
return {
"torch_equal": True,
"non_tiled_sha256": sha,
"tiled_sha256": sha,
"dtype": "torch.float16",
"source_frames": 32,
"temporal_tile_size": 16,
"temporal_overlap": 4,
"generic_fallback_used": False,
}
def _assert_real_probe_json_contract(payload):
required = {
"torch_equal",
"non_tiled_sha256",
"tiled_sha256",
"dtype",
"source_frames",
"temporal_tile_size",
"temporal_overlap",
"generic_fallback_used",
}
missing = required.difference(payload)
if missing:
raise AssertionError(f"missing keys: {sorted(missing)}")
if payload["torch_equal"] is not True:
raise AssertionError("torch_equal must be true")
if payload["non_tiled_sha256"] != payload["tiled_sha256"]:
raise AssertionError("tensor sha256 values must match")
if payload["dtype"] != "torch.float16":
raise AssertionError("dtype must be torch.float16")
if payload["source_frames"] != 32:
raise AssertionError("source_frames must be 32")
if payload["temporal_tile_size"] != 16:
raise AssertionError("temporal_tile_size must be 16")
if payload["temporal_overlap"] != 4:
raise AssertionError("temporal_overlap must be 4")
if payload["generic_fallback_used"] is not False:
raise AssertionError("generic_fallback_used must be false")
def test_real_probe_json_contract():
valid = _valid_probe_payload()
_assert_real_probe_json_contract(valid)
for key in valid:
missing = deepcopy(valid)
missing.pop(key)
try:
_assert_real_probe_json_contract(missing)
except AssertionError:
pass
else:
raise AssertionError(f"accepted payload missing {key}")
invalid_values = {
"torch_equal": False,
"tiled_sha256": "1" * 64,
"dtype": "torch.float32",
"source_frames": 31,
"temporal_tile_size": 8,
"temporal_overlap": 0,
"generic_fallback_used": True,
}
for key, value in invalid_values.items():
invalid = deepcopy(valid)
invalid[key] = value
try:
_assert_real_probe_json_contract(invalid)
except AssertionError:
pass
else:
raise AssertionError(f"accepted payload with invalid {key}")

View File

@ -0,0 +1,86 @@
import torch
from comfy.cli_args import args as cli_args
if not torch.cuda.is_available():
cli_args.cpu = True
def test_runtime_decode_zero_temporal_size_disables_slicing_for_call():
from comfy.ldm.seedvr.vae import MemoryState, VideoAutoencoderKL, tiled_vae
class StubVAEModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.slicing_latent_min_size = 2
self.spatial_downsample_factor = 8
self.temporal_downsample_factor = 4
self.device = torch.device("cpu")
self.use_slicing = True
self._dummy = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32))
self.decode_min_sizes = []
self.memory_states = []
def decode_(self, t_chunk):
self.decode_min_sizes.append(self.slicing_latent_min_size)
return VideoAutoencoderKL.slicing_decode(self, t_chunk)
def _decode(self, z, memory_state=MemoryState.DISABLED):
self.memory_states.append(memory_state)
b, c, d, h, w = z.shape
return torch.zeros((b, 3, d, h * 8, w * 8), dtype=z.dtype)
vae = StubVAEModel()
z = torch.zeros((1, 16, 5, 8, 8), dtype=torch.float32)
tiled_vae(
z,
vae,
tile_size=(64, 64),
tile_overlap=(0, 0),
temporal_size=0,
temporal_overlap=0,
encode=False,
)
assert vae.decode_min_sizes == [5]
assert vae.memory_states == [MemoryState.DISABLED]
assert vae.slicing_latent_min_size == 2
def test_runtime_decode_preserves_min_size_when_decode_raises():
from comfy.ldm.seedvr.vae import tiled_vae
class RaisingVAEModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.slicing_latent_min_size = 2
self.spatial_downsample_factor = 8
self.temporal_downsample_factor = 4
self.device = torch.device("cpu")
self._dummy = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32))
def decode_(self, t_chunk):
raise RuntimeError("simulated decode failure")
vae = RaisingVAEModel()
z = torch.zeros((1, 16, 4, 8, 8), dtype=torch.float32)
raised = False
try:
tiled_vae(
z,
vae,
tile_size=(64, 64),
tile_overlap=(0, 0),
temporal_size=0,
temporal_overlap=0,
encode=False,
)
except RuntimeError as exc:
if "simulated decode failure" not in str(exc):
raise
raised = True
assert raised
assert vae.slicing_latent_min_size == 2

View File

@ -0,0 +1,89 @@
import torch
from comfy.cli_args import args as cli_args
if not torch.cuda.is_available():
cli_args.cpu = True
def test_slicing_encode_merges_runt_active_tail():
from comfy.ldm.seedvr.vae import MemoryState, VideoAutoencoderKL, tiled_vae
class StubVAEModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.slicing_sample_min_size = 4
self.spatial_downsample_factor = 8
self.temporal_downsample_factor = 4
self.device = torch.device("cpu")
self.use_slicing = True
self._dummy = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32))
self.memory_states = []
self.encode_t = []
def encode(self, t_chunk):
h = VideoAutoencoderKL.slicing_encode(self, t_chunk)
return (h, h)
def _encode(self, x, memory_state=MemoryState.DISABLED):
self.memory_states.append(memory_state)
self.encode_t.append(x.shape[2])
b, c, t_in, h, w = x.shape
target_d = max(1, (t_in + self.temporal_downsample_factor - 1) // self.temporal_downsample_factor)
target_h = (h + self.spatial_downsample_factor - 1) // self.spatial_downsample_factor
target_w = (w + self.spatial_downsample_factor - 1) // self.spatial_downsample_factor
return torch.zeros((b, 16, target_d, target_h, target_w), dtype=x.dtype)
vae = StubVAEModel()
x = torch.zeros((1, 3, 12, 64, 64), dtype=torch.float32)
tiled_vae(
x,
vae,
tile_size=(64, 64),
tile_overlap=(0, 0),
temporal_size=None,
encode=True,
)
assert vae.memory_states == [MemoryState.INITIALIZING, MemoryState.ACTIVE]
assert vae.encode_t == [5, 7]
assert min(vae.encode_t[1:]) >= vae.temporal_downsample_factor
def test_zero_temporal_size_preserves_min_size_when_encode_raises():
from comfy.ldm.seedvr.vae import tiled_vae
class RaisingVAEModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.slicing_sample_min_size = 4
self.spatial_downsample_factor = 8
self.temporal_downsample_factor = 4
self.device = torch.device("cpu")
self._dummy = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32))
def encode(self, t_chunk):
raise RuntimeError("simulated encode failure")
vae = RaisingVAEModel()
x = torch.zeros((1, 3, 12, 64, 64), dtype=torch.float32)
raised = False
try:
tiled_vae(
x,
vae,
tile_size=(64, 64),
tile_overlap=(0, 0),
temporal_size=0,
temporal_overlap=0,
encode=True,
)
except RuntimeError as exc:
if "simulated encode failure" not in str(exc):
raise
raised = True
assert raised
assert vae.slicing_sample_min_size == 4

View File

@ -0,0 +1,232 @@
from unittest.mock import patch
import torch
import torch.nn as nn
from comfy.cli_args import args as cli_args
if not torch.cuda.is_available():
cli_args.cpu = True
import comfy.ldm.seedvr.vae as vae_mod # noqa: E402
from comfy.ldm.seedvr.vae import MemoryState, tiled_vae # noqa: E402
class _SlicingDecodeVAE(nn.Module):
def __init__(self, slicing_latent_min_size):
super().__init__()
self.slicing_latent_min_size = slicing_latent_min_size
self.spatial_downsample_factor = 8
self.temporal_downsample_factor = 4
self.device = torch.device("cpu")
self.use_slicing = True
self._dummy = nn.Parameter(torch.zeros(1, dtype=torch.float32))
self.decode_min_sizes = []
self.memory_states = []
def decode_(self, z):
self.decode_min_sizes.append(self.slicing_latent_min_size)
return vae_mod.VideoAutoencoderKL.slicing_decode(self, z)
def _decode(self, z, memory_state=MemoryState.DISABLED):
self.memory_states.append(memory_state)
x = z[:, :1].repeat(
1,
3,
1,
self.spatial_downsample_factor,
self.spatial_downsample_factor,
)
return x
class _EncodeVAE(nn.Module):
def __init__(self, slicing_sample_min_size):
super().__init__()
self.slicing_sample_min_size = slicing_sample_min_size
self.spatial_downsample_factor = 8
self.temporal_downsample_factor = 4
self.device = torch.device("cpu")
self.use_slicing = True
self._dummy = nn.Parameter(torch.zeros(1, dtype=torch.float32))
self.memory_states = []
self.encoded_t = []
self.encode_min_sizes = []
def encode(self, t_chunk):
self.encode_min_sizes.append(self.slicing_sample_min_size)
h = vae_mod.VideoAutoencoderKL.slicing_encode(self, t_chunk)
return (h, h)
def _encode(self, x, memory_state=MemoryState.DISABLED):
self.memory_states.append(memory_state)
self.encoded_t.append(x.shape[2])
b, c, t_in, h, w = x.shape
target_d = max(1, (t_in + self.temporal_downsample_factor - 1) // self.temporal_downsample_factor)
target_h = (h + self.spatial_downsample_factor - 1) // self.spatial_downsample_factor
target_w = (w + self.spatial_downsample_factor - 1) // self.spatial_downsample_factor
z = torch.zeros((b, 16, target_d, target_h, target_w), dtype=x.dtype)
return z
class _LocalSpatialDecodeVAE(nn.Module):
def __init__(self):
super().__init__()
self.slicing_latent_min_size = 99
self.spatial_downsample_factor = 8
self.temporal_downsample_factor = 4
self.device = torch.device("cpu")
self._dummy = nn.Parameter(torch.zeros(1, dtype=torch.float32))
self.tile_shapes = []
def decode_(self, z):
self.tile_shapes.append(tuple(z.shape))
b, _, t, h, w = z.shape
width = w * self.spatial_downsample_factor
local_x = torch.arange(width, dtype=z.dtype).view(1, 1, 1, 1, width)
return local_x.expand(
b,
1,
t,
h * self.spatial_downsample_factor,
width,
).clone()
def test_decode_tiled_vae_maps_temporal_args_to_latent_slicing_min_size():
vae = _SlicingDecodeVAE(slicing_latent_min_size=2)
z = torch.arange(1 * 16 * 5 * 8 * 8, dtype=torch.float32).reshape(1, 16, 5, 8, 8)
tiled_vae(
z,
vae,
tile_size=(64, 64),
tile_overlap=(0, 0),
temporal_size=12,
temporal_overlap=4,
encode=False,
)
assert vae.decode_min_sizes == [2]
assert vae.memory_states == [MemoryState.INITIALIZING, MemoryState.ACTIVE]
assert vae.slicing_latent_min_size == 2
wrapper = vae_mod.VideoAutoencoderKLWrapper.__new__(
vae_mod.VideoAutoencoderKLWrapper
)
nn.Module.__init__(wrapper)
seedvr2_tiling = {
"enable_tiling": True,
"tile_size": (64, 64),
"tile_overlap": (0, 0),
"temporal_size": 8,
"temporal_overlap": 7,
}
captured = {}
def _fake_tiled_vae(latent, model, **kwargs):
captured.update(kwargs)
return torch.zeros(1, 3, 1, 16, 16)
with (
patch.object(vae_mod, "tiled_vae", side_effect=_fake_tiled_vae),
patch.object(vae_mod, "lab_color_transfer", side_effect=lambda content, style: content),
):
wrapper.decode(torch.zeros(1, 16, 2, 2), seedvr2_tiling=seedvr2_tiling)
assert captured["temporal_overlap"] == 7
def test_encode_tiled_vae_zero_temporal_size_disables_wrapper_slicing():
vae = _EncodeVAE(slicing_sample_min_size=4)
x = torch.zeros((1, 3, 12, 64, 64), dtype=torch.float32)
tiled_vae(
x,
vae,
tile_size=(64, 64),
tile_overlap=(0, 0),
temporal_size=0,
temporal_overlap=0,
encode=True,
)
assert vae.encode_min_sizes == [12]
assert vae.memory_states == [MemoryState.DISABLED]
assert vae.encoded_t == [12]
assert vae.slicing_sample_min_size == 4
def test_encode_tiled_vae_maps_temporal_args_to_sample_slicing_min_size():
vae = _EncodeVAE(slicing_sample_min_size=4)
x = torch.zeros((1, 3, 14, 64, 64), dtype=torch.float32)
tiled_vae(
x,
vae,
tile_size=(64, 64),
tile_overlap=(0, 0),
temporal_size=8,
temporal_overlap=2,
encode=True,
)
assert vae.encode_min_sizes == [6]
assert vae.memory_states == [MemoryState.INITIALIZING, MemoryState.ACTIVE]
assert vae.encoded_t == [7, 7]
assert vae.slicing_sample_min_size == 4
def test_boundary_reference_latent_no_periodic_temporal_tile_discontinuity():
z = torch.arange(1 * 16 * 7 * 8 * 8, dtype=torch.float32).reshape(1, 16, 7, 8, 8)
reference_vae = _SlicingDecodeVAE(slicing_latent_min_size=3)
expected = reference_vae.decode_(z)
tiled_vae_model = _SlicingDecodeVAE(slicing_latent_min_size=3)
actual = tiled_vae(
z,
tiled_vae_model,
tile_size=(64, 64),
tile_overlap=(0, 0),
temporal_size=0,
temporal_overlap=0,
encode=False,
)
assert torch.equal(actual, expected)
assert tiled_vae_model.decode_min_sizes == [7]
assert tiled_vae_model.memory_states == [MemoryState.DISABLED]
assert tiled_vae_model.slicing_latent_min_size == 3
spatial_vae = _LocalSpatialDecodeVAE()
spatial = tiled_vae(
torch.zeros(1, 16, 1, 8, 12),
spatial_vae,
tile_size=(64, 64),
tile_overlap=(0, 32),
encode=False,
)
ramp = 0.5 - 0.5 * torch.cos(torch.linspace(0, 1, steps=32) * torch.pi)
expected = (36.0 * (1.0 - ramp[4])) + (4.0 * ramp[4])
assert spatial_vae.tile_shapes == [
(1, 16, 1, 8, 8),
(1, 16, 1, 8, 8),
]
assert torch.isclose(spatial[0, 0, 0, 0, 36], expected)
def test_decode_tiled_vae_clamps_overlap_sized_tiles_to_preserve_coverage():
spatial_vae = _LocalSpatialDecodeVAE()
spatial = tiled_vae(
torch.zeros(1, 16, 1, 8, 12),
spatial_vae,
tile_size=(64, 64),
tile_overlap=(0, 128),
encode=False,
)
assert len(spatial_vae.tile_shapes) > 1
assert torch.count_nonzero(spatial[0, 0, 0, 0, 64:]) > 0

View File

@ -0,0 +1,165 @@
"""Unit test for the ``VAE.decode`` tiled-fallback dispatcher routing of
SeedVR2 latents in their 4D collapsed form ``(B, 16*T, H, W)``.
Regression: the dispatcher branch at ``comfy/sd.py``'s
``VAE.decode -> if do_tile: ... elif dims == 2`` previously routed
``ndim == 4`` SeedVR2 latents to the generic ``decode_tiled_``, whose
``tiled_scale`` mask broadcast does not understand the
``(16, T)`` channel-time collapse and crashed with
``"The size of tensor a (1024) must match the size of tensor b (256)
at non-singleton dimension 4"``.
Post-fix: when the wrapped model is a
``comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper`` and the input is 4D,
the dispatcher must route to ``decode_tiled_seedvr2`` instead. This
test verifies the dispatcher selection without invoking the actual VAE
math (which would require real model weights and a GPU): the two
candidate methods are patched, the regular decode is forced to OOM via
a stub, and the test asserts that ``decode_tiled_seedvr2`` is called
exactly once (and ``decode_tiled_`` zero times) for a 4D SeedVR2
input.
"""
from unittest.mock import MagicMock, patch
import torch
from comfy.cli_args import args as cli_args
if not torch.cuda.is_available():
cli_args.cpu = True
import comfy.ldm.seedvr.vae as seedvr_vae_mod # noqa: E402
import comfy.sd as sd_mod # noqa: E402
def _make_minimal_seedvr2_vae():
"""Construct a ``comfy.sd.VAE`` instance whose ``first_stage_model``
is a real ``VideoAutoencoderKLWrapper`` (built via ``__new__`` to
skip weight allocation), with the VAE's other attributes stubbed
to the minimum that ``VAE.decode``'s regular-decode setup path
requires before the OOM forced fallback.
"""
vae = sd_mod.VAE.__new__(sd_mod.VAE)
wrapper = seedvr_vae_mod.VideoAutoencoderKLWrapper.__new__(
seedvr_vae_mod.VideoAutoencoderKLWrapper
)
vae.first_stage_model = wrapper
# Minimum surface that ``VAE.decode`` touches before tiled fallback:
vae.patcher = MagicMock()
vae.patcher.get_free_memory = MagicMock(return_value=8 * 1024 * 1024 * 1024)
vae.device = torch.device("cpu")
vae.output_device = torch.device("cpu")
vae.vae_dtype = torch.float32
vae.disable_offload = True
vae.extra_1d_channel = None
vae.upscale_ratio = 8
vae.upscale_index_formula = None
vae.output_channels = 3
vae.latent_channels = 16
vae.latent_dim = 3 # SeedVR2 is a 3D-temporal latent format (T, H, W)
vae.downscale_ratio = 8
vae.downscale_index_formula = None
vae.vae_output_dtype = lambda: torch.float32
vae.spacial_compression_decode = lambda: 8
vae.process_input = lambda x: x
vae.process_output = lambda x: x
vae.throw_exception_if_invalid = lambda: None
vae.memory_used_decode = lambda *a, **k: 1
return vae
def _force_regular_decode_oom(*args, **kwargs):
"""Stub ``first_stage_model.decode`` to raise an OOM-shaped error
so ``VAE.decode``'s ``except`` branch sets ``do_tile = True`` and
falls into the tiled-fallback dispatcher.
"""
raise torch.cuda.OutOfMemoryError("forced OOM for dispatcher test")
def test_4d_seedvr2_latent_routes_to_decode_tiled_seedvr2():
vae = _make_minimal_seedvr2_vae()
samples_4d = torch.zeros(1, 16 * 3, 8, 8) # (B, 16*T, H, W), T=3
seedvr2_call = MagicMock(return_value=torch.zeros(1, 3, 9, 64, 64))
generic_call = MagicMock(return_value=torch.zeros(1, 3, 64, 64))
with patch.object(sd_mod.model_management, "raise_non_oom",
lambda e: None), \
patch.object(sd_mod.model_management, "load_models_gpu",
lambda *a, **k: None), \
patch.object(sd_mod.model_management, "soft_empty_cache",
lambda: None), \
patch.object(seedvr_vae_mod.VideoAutoencoderKLWrapper, "decode",
side_effect=_force_regular_decode_oom), \
patch.object(sd_mod.VAE, "decode_tiled_seedvr2", seedvr2_call), \
patch.object(sd_mod.VAE, "decode_tiled_", generic_call):
vae.decode(samples_4d)
assert seedvr2_call.call_count == 1, (
f"Expected decode_tiled_seedvr2 to be called once for a 4D SeedVR2 "
f"latent under tiled fallback; got {seedvr2_call.call_count} calls."
)
assert generic_call.call_count == 0, (
f"decode_tiled_ must NOT be called for a 4D SeedVR2 latent; got "
f"{generic_call.call_count} calls. Pre-fix dispatcher would route "
f"to this method and crash inside tiled_scale's mask broadcast."
)
def test_4d_non_seedvr2_latent_still_routes_to_generic_decode_tiled():
"""The dispatcher fix must NOT affect non-SeedVR2 4D latents: any
other VAE whose ``first_stage_model`` is not a
``VideoAutoencoderKLWrapper`` continues to route to the generic
``decode_tiled_``.
"""
vae = sd_mod.VAE.__new__(sd_mod.VAE)
vae.first_stage_model = MagicMock() # NOT a VideoAutoencoderKLWrapper
vae.patcher = MagicMock()
vae.patcher.get_free_memory = MagicMock(return_value=8 * 1024 * 1024 * 1024)
vae.device = torch.device("cpu")
vae.output_device = torch.device("cpu")
vae.vae_dtype = torch.float32
vae.disable_offload = True
vae.extra_1d_channel = None
vae.upscale_ratio = 8
vae.upscale_index_formula = None
vae.output_channels = 3
vae.latent_channels = 4
vae.latent_dim = 2
vae.downscale_ratio = 8
vae.downscale_index_formula = None
vae.vae_output_dtype = lambda: torch.float32
vae.spacial_compression_decode = lambda: 8
vae.process_output = lambda x: x
vae.throw_exception_if_invalid = lambda: None
vae.memory_used_decode = lambda *a, **k: 1
vae.first_stage_model.decode = MagicMock(
side_effect=_force_regular_decode_oom
)
samples_4d = torch.zeros(1, 4, 8, 8)
generic_call = MagicMock(return_value=torch.zeros(1, 3, 64, 64))
seedvr2_call = MagicMock(return_value=torch.zeros(1, 3, 9, 64, 64))
with patch.object(sd_mod.model_management, "raise_non_oom",
lambda e: None), \
patch.object(sd_mod.model_management, "load_models_gpu",
lambda *a, **k: None), \
patch.object(sd_mod.model_management, "soft_empty_cache",
lambda: None), \
patch.object(sd_mod.VAE, "decode_tiled_seedvr2", seedvr2_call), \
patch.object(sd_mod.VAE, "decode_tiled_", generic_call):
vae.decode(samples_4d)
assert generic_call.call_count == 1, (
f"Expected decode_tiled_ to be called once for a non-SeedVR2 4D "
f"latent; got {generic_call.call_count} calls."
)
assert seedvr2_call.call_count == 0, (
f"decode_tiled_seedvr2 must NOT be called for non-SeedVR2 latents; "
f"got {seedvr2_call.call_count} calls."
)

View File

@ -0,0 +1,119 @@
"""Unit tests for the explicit ``VAE.encode_tiled`` dispatcher routing of
SeedVR2 vs non-SeedVR2 3D inputs.
Mirrors the decode-side dispatcher contract in
``test_vae_decode_tiled_dispatcher_seedvr2_4d.py`` and the encode OOM
fallback contract in ``test_vae_encode_tiled_fallback_dispatcher_seedvr2.py``:
the two candidate methods (``encode_tiled_seedvr2``, ``encode_tiled_3d``)
are patched on the ``VAE`` class, ``encode_tiled`` is invoked directly,
and the test asserts the dispatcher selects the SeedVR2-aware tiler when
``first_stage_model`` is a ``VideoAutoencoderKLWrapper`` while preserving
the generic 3D tiler for non-SeedVR2 inputs.
"""
from unittest.mock import MagicMock, patch
import torch
from comfy.cli_args import args as cli_args
if not torch.cuda.is_available():
cli_args.cpu = True
import comfy.ldm.seedvr.vae as seedvr_vae_mod # noqa: E402
import comfy.sd as sd_mod # noqa: E402
def _populate_common_vae_attrs(vae):
vae.patcher = MagicMock()
vae.patcher.get_free_memory = MagicMock(return_value=8 * 1024 * 1024 * 1024)
vae.device = torch.device("cpu")
vae.output_device = torch.device("cpu")
vae.vae_dtype = torch.float32
vae.disable_offload = True
vae.extra_1d_channel = None
vae.upscale_ratio = [lambda x: x]
vae.upscale_index_formula = None
vae.output_channels = 3
vae.latent_channels = 16
vae.latent_dim = 3
vae.downscale_ratio = [lambda x: x]
vae.downscale_index_formula = None
vae.not_video = False
vae.crop_input = False
vae.pad_channel_value = None
vae.vae_output_dtype = lambda: torch.float32
vae.spacial_compression_encode = lambda: 8
vae.vae_encode_crop_pixels = lambda x: x
vae.throw_exception_if_invalid = lambda: None
vae.memory_used_encode = lambda *a, **k: 1
def _make_seedvr2_vae():
vae = sd_mod.VAE.__new__(sd_mod.VAE)
wrapper = seedvr_vae_mod.VideoAutoencoderKLWrapper.__new__(
seedvr_vae_mod.VideoAutoencoderKLWrapper
)
vae.first_stage_model = wrapper
_populate_common_vae_attrs(vae)
return vae
def _make_non_seedvr2_vae():
vae = sd_mod.VAE.__new__(sd_mod.VAE)
vae.first_stage_model = MagicMock()
_populate_common_vae_attrs(vae)
return vae
def test_explicit_encode_tiled_seedvr2_3d_routes_to_seedvr2_tiler():
vae = _make_seedvr2_vae()
pixel_samples = torch.zeros((1, 64, 64, 3))
seedvr2_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8))
generic_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8))
with patch.object(sd_mod.model_management, "load_models_gpu",
lambda *a, **k: None), \
patch.object(sd_mod.VAE, "encode_tiled_seedvr2", seedvr2_call,
create=True), \
patch.object(sd_mod.VAE, "encode_tiled_3d", generic_call):
vae.encode_tiled(pixel_samples)
assert seedvr2_call.call_count == 1, (
f"Expected encode_tiled_seedvr2 to be called once for a SeedVR2 3D "
f"input via explicit encode_tiled; got {seedvr2_call.call_count} calls."
)
assert generic_call.call_count == 0, (
f"encode_tiled_3d must NOT be called for a SeedVR2 input via explicit "
f"encode_tiled; got {generic_call.call_count} calls."
)
def test_explicit_encode_tiled_dispatcher_breakdown():
seedvr2_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8))
generic_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8))
seedvr2_vae = _make_seedvr2_vae()
non_seedvr2_vae = _make_non_seedvr2_vae()
pixel_samples = torch.zeros((1, 64, 64, 3))
with patch.object(sd_mod.model_management, "load_models_gpu",
lambda *a, **k: None), \
patch.object(sd_mod.VAE, "encode_tiled_seedvr2", seedvr2_call,
create=True), \
patch.object(sd_mod.VAE, "encode_tiled_3d", generic_call):
seedvr2_vae.encode_tiled(pixel_samples)
non_seedvr2_vae.encode_tiled(pixel_samples)
assert seedvr2_call.call_count == 1, (
f"Expected encode_tiled_seedvr2 called once across SeedVR2 + "
f"non-SeedVR2 explicit encode_tiled calls; got "
f"{seedvr2_call.call_count}."
)
assert generic_call.call_count == 1, (
f"Expected encode_tiled_3d called once across SeedVR2 + non-SeedVR2 "
f"explicit encode_tiled calls; got {generic_call.call_count}."
)

View File

@ -0,0 +1,184 @@
"""Unit tests for the ``VAE.encode`` OOM-fallback dispatcher routing of
SeedVR2 vs non-SeedVR2 3D inputs.
Mirrors the decode-side dispatcher contract in
``test_vae_decode_tiled_dispatcher_seedvr2_4d.py``: the two candidate
methods (``encode_tiled_seedvr2``, ``encode_tiled_3d``) are patched on
the ``VAE`` class, the regular encode is forced to OOM via a stub, and
the test asserts the dispatcher selects the SeedVR2-aware tiler when
``first_stage_model`` is a ``VideoAutoencoderKLWrapper`` while
preserving the generic 3D tiler for non-SeedVR2 inputs.
"""
from unittest.mock import MagicMock, patch
import torch
from comfy.cli_args import args as cli_args
if not torch.cuda.is_available():
cli_args.cpu = True
import comfy.ldm.seedvr.vae as seedvr_vae_mod # noqa: E402
import comfy.sd as sd_mod # noqa: E402
def _populate_common_vae_attrs(vae):
vae.patcher = MagicMock()
vae.patcher.get_free_memory = MagicMock(return_value=8 * 1024 * 1024 * 1024)
vae.device = torch.device("cpu")
vae.output_device = torch.device("cpu")
vae.vae_dtype = torch.float32
vae.disable_offload = True
vae.extra_1d_channel = None
vae.upscale_ratio = 8
vae.upscale_index_formula = None
vae.output_channels = 3
vae.latent_channels = 16
vae.latent_dim = 3
vae.downscale_ratio = 8
vae.downscale_index_formula = None
vae.not_video = False
vae.crop_input = False
vae.pad_channel_value = None
vae.vae_output_dtype = lambda: torch.float32
vae.spacial_compression_encode = lambda: 8
vae.process_input = lambda x: x
vae.process_output = lambda x: x
vae.throw_exception_if_invalid = lambda: None
vae.memory_used_encode = lambda *a, **k: 1
def _make_seedvr2_vae():
vae = sd_mod.VAE.__new__(sd_mod.VAE)
wrapper = seedvr_vae_mod.VideoAutoencoderKLWrapper.__new__(
seedvr_vae_mod.VideoAutoencoderKLWrapper
)
vae.first_stage_model = wrapper
_populate_common_vae_attrs(vae)
return vae
def _make_non_seedvr2_vae():
vae = sd_mod.VAE.__new__(sd_mod.VAE)
vae.first_stage_model = MagicMock()
_populate_common_vae_attrs(vae)
return vae
def _force_regular_encode_oom(*args, **kwargs):
raise torch.cuda.OutOfMemoryError("forced OOM for dispatcher test")
def test_seedvr2_3d_routes_to_encode_tiled_seedvr2_on_oom():
vae = _make_seedvr2_vae()
pixel_samples = torch.zeros((1, 8, 64, 64, 3))
seedvr2_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8))
generic_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8))
with patch.object(sd_mod.model_management, "raise_non_oom",
lambda e: None), \
patch.object(sd_mod.model_management, "load_models_gpu",
lambda *a, **k: None), \
patch.object(sd_mod.model_management, "soft_empty_cache",
lambda: None), \
patch.object(seedvr_vae_mod.VideoAutoencoderKLWrapper, "encode",
side_effect=_force_regular_encode_oom), \
patch.object(sd_mod.VAE, "encode_tiled_seedvr2", seedvr2_call,
create=True), \
patch.object(sd_mod.VAE, "encode_tiled_3d", generic_call):
vae.encode(pixel_samples)
assert seedvr2_call.call_count == 1, (
f"Expected encode_tiled_seedvr2 to be called once for a SeedVR2 3D "
f"input under OOM fallback; got {seedvr2_call.call_count} calls."
)
assert generic_call.call_count == 0, (
f"encode_tiled_3d must NOT be called for a SeedVR2 input; got "
f"{generic_call.call_count} calls."
)
def test_seedvr2_oom_fallback_uses_explicit_seedvr2_tile_defaults():
vae = _make_seedvr2_vae()
vae.first_stage_model.tiled_args = {
"tile_size": (128, 128),
"tile_overlap": (32, 32),
"temporal_size": 12,
"temporal_overlap": 4,
}
pixel_samples = torch.zeros((1, 8, 64, 64, 3))
seedvr2_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8))
with patch.object(sd_mod.model_management, "raise_non_oom",
lambda e: None), \
patch.object(sd_mod.model_management, "load_models_gpu",
lambda *a, **k: None), \
patch.object(sd_mod.model_management, "soft_empty_cache",
lambda: None), \
patch.object(seedvr_vae_mod.VideoAutoencoderKLWrapper, "encode",
side_effect=_force_regular_encode_oom), \
patch.object(sd_mod.VAE, "encode_tiled_seedvr2", seedvr2_call,
create=True):
vae.encode(pixel_samples)
assert seedvr2_call.call_count == 1
assert seedvr2_call.call_args.kwargs == {
"tile_x": 256,
"tile_y": 256,
"overlap": 64,
}
def test_oom_fallback_dispatcher_breakdown():
seedvr2_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8))
generic_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8))
seedvr2_vae = _make_seedvr2_vae()
non_seedvr2_vae = _make_non_seedvr2_vae()
non_seedvr2_vae.first_stage_model.encode = MagicMock(
side_effect=_force_regular_encode_oom
)
pixel_samples = torch.zeros((1, 8, 64, 64, 3))
with patch.object(sd_mod.model_management, "raise_non_oom",
lambda e: None), \
patch.object(sd_mod.model_management, "load_models_gpu",
lambda *a, **k: None), \
patch.object(sd_mod.model_management, "soft_empty_cache",
lambda: None), \
patch.object(seedvr_vae_mod.VideoAutoencoderKLWrapper, "encode",
side_effect=_force_regular_encode_oom), \
patch.object(sd_mod.VAE, "encode_tiled_seedvr2", seedvr2_call,
create=True), \
patch.object(sd_mod.VAE, "encode_tiled_3d", generic_call):
seedvr2_vae.encode(pixel_samples)
non_seedvr2_vae.encode(pixel_samples)
assert seedvr2_call.call_count == 1, (
f"Expected encode_tiled_seedvr2 called once across SeedVR2 + "
f"non-SeedVR2 OOM fallbacks; got {seedvr2_call.call_count}."
)
assert generic_call.call_count == 1, (
f"Expected encode_tiled_3d called once across SeedVR2 + non-SeedVR2 "
f"OOM fallbacks; got {generic_call.call_count}."
)
def test_non_seedvr2_encode_tiled_3d_default_overlap_is_concrete():
vae = _make_non_seedvr2_vae()
vae.downscale_ratio = (lambda a: max(1, a // 4), 8, 8)
vae.upscale_ratio = (lambda a: a * 4, 8, 8)
generic_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8))
pixel_samples = torch.zeros((1, 8, 64, 64, 3))
with patch.object(sd_mod.model_management, "load_models_gpu",
lambda *a, **k: None), \
patch.object(sd_mod.VAE, "encode_tiled_3d", generic_call):
vae.encode_tiled(pixel_samples)
assert generic_call.call_args.kwargs["overlap"] == (1, 64, 64)

View File

@ -0,0 +1,205 @@
"""Unit tests for ``VAE.encode_tiled_seedvr2``: existence with the
SeedVR2 tile-shape signature and delegation through
``comfy.ldm.seedvr.vae.tiled_vae(..., encode=True)`` with one call per
spatial tile.
Mirrors the decode-side method-existence + delegation contract for
``VAE.decode_tiled_seedvr2``; CPU-only via mocks and a
``VideoAutoencoderKLWrapper.__new__`` wrapper stub (no weights, no
GPU).
"""
import inspect
from unittest.mock import MagicMock, patch
import torch
from comfy.cli_args import args as cli_args
if not torch.cuda.is_available():
cli_args.cpu = True
import comfy.ldm.seedvr.vae as seedvr_vae_mod # noqa: E402
import comfy.sd as sd_mod # noqa: E402
import nodes as nodes_mod # noqa: E402
def _make_minimal_seedvr2_vae():
vae = sd_mod.VAE.__new__(sd_mod.VAE)
wrapper = seedvr_vae_mod.VideoAutoencoderKLWrapper.__new__(
seedvr_vae_mod.VideoAutoencoderKLWrapper
)
vae.first_stage_model = wrapper
vae.device = torch.device("cpu")
vae.output_device = torch.device("cpu")
vae.vae_dtype = torch.float32
vae.latent_channels = 16
vae.latent_dim = 3
vae.downscale_ratio = 8
vae.vae_output_dtype = lambda: torch.float32
vae.process_input = lambda x: x
return vae
def test_method_exists_with_seedvr2_signature():
assert hasattr(sd_mod.VAE, "encode_tiled_seedvr2"), (
"VAE.encode_tiled_seedvr2 must be defined on the VAE class."
)
sig = inspect.signature(sd_mod.VAE.encode_tiled_seedvr2)
params = list(sig.parameters)
for required in ("self", "pixel_samples", "tile_x", "tile_y",
"overlap", "tile_t", "overlap_t"):
assert required in params, (
f"VAE.encode_tiled_seedvr2 missing required parameter "
f"{required!r}; got parameters {params}."
)
def test_vae_encode_tiled_allows_zero_temporal_controls_and_passes_zero_through():
input_types = nodes_mod.VAEEncodeTiled.INPUT_TYPES()["required"]
assert input_types["temporal_size"][1]["min"] == 0
assert input_types["temporal_overlap"][1]["min"] == 0
assert "SeedVR2 allows 0" in input_types["temporal_size"][1]["tooltip"]
class _EncodeRecorder:
def __init__(self):
self.calls = []
def encode_tiled(self, pixels, **kwargs):
self.calls.append({"shape": tuple(pixels.shape), **kwargs})
return torch.zeros(1, 16, 1, 8, 8)
recorder = _EncodeRecorder()
node = nodes_mod.VAEEncodeTiled()
output = node.encode(
recorder,
torch.zeros(1, 64, 64, 3),
tile_size=256,
overlap=64,
temporal_size=0,
temporal_overlap=8,
)
assert recorder.calls == [
{
"shape": (1, 64, 64, 3),
"tile_x": 256,
"tile_y": 256,
"overlap": 64,
"tile_t": 0,
"overlap_t": 0,
}
]
assert torch.equal(output[0]["samples"], torch.zeros(1, 16, 1, 8, 8))
def test_method_routes_through_tiled_vae_encode_true():
vae = _make_minimal_seedvr2_vae()
pixel_samples = torch.zeros((1, 3, 8, 64, 64))
tiled_vae_mock = MagicMock(return_value=torch.zeros((1, 16, 2, 8, 8)))
with patch.object(seedvr_vae_mod, "tiled_vae", tiled_vae_mock):
vae.encode_tiled_seedvr2(pixel_samples)
assert tiled_vae_mock.call_count >= 1, (
f"Expected encode_tiled_seedvr2 to delegate to tiled_vae at "
f"least once; got {tiled_vae_mock.call_count} calls."
)
for call in tiled_vae_mock.call_args_list:
assert call.kwargs.get("encode") is True, (
f"Every tiled_vae delegation from encode_tiled_seedvr2 must "
f"pass encode=True; got kwargs={call.kwargs!r}."
)
def test_method_sets_wrapper_device_before_tiled_vae():
vae = _make_minimal_seedvr2_vae()
pixel_samples = torch.zeros((1, 3, 8, 64, 64))
assert not hasattr(vae.first_stage_model, "device")
def _assert_device_initialized(*args, **kwargs):
vae_model = args[1]
assert vae_model.device == vae.device
return torch.zeros((1, 16, 2, 8, 8))
with patch.object(seedvr_vae_mod, "tiled_vae",
MagicMock(side_effect=_assert_device_initialized)):
vae.encode_tiled_seedvr2(pixel_samples)
def test_method_honors_explicit_tile_parameters_over_stale_wrapper_args():
vae = _make_minimal_seedvr2_vae()
pixel_samples = torch.zeros((1, 3, 8, 64, 64))
vae.first_stage_model.tiled_args = {
"tile_size": (17, 19),
"tile_overlap": (3, 5),
"temporal_size": 7,
"temporal_overlap": 2,
"preserved": "value",
}
tiled_vae_mock = MagicMock(return_value=torch.zeros((1, 16, 2, 8, 8)))
with patch.object(seedvr_vae_mod, "tiled_vae", tiled_vae_mock):
vae.encode_tiled_seedvr2(
pixel_samples,
tile_x=96,
tile_y=80,
overlap=12,
tile_t=11,
overlap_t=4,
)
assert tiled_vae_mock.call_args.kwargs["tile_size"] == (80, 96)
assert tiled_vae_mock.call_args.kwargs["tile_overlap"] == (12, 12)
assert tiled_vae_mock.call_args.kwargs["temporal_size"] == 11
assert tiled_vae_mock.call_args.kwargs["temporal_overlap"] == 4
assert vae.first_stage_model.tiled_args["preserved"] == "value"
def test_method_uses_explicit_defaults_when_call_omits_tile_parameters():
vae = _make_minimal_seedvr2_vae()
pixel_samples = torch.zeros((1, 3, 8, 64, 64))
vae.first_stage_model.tiled_args = {
"tile_size": (128, 160),
"tile_overlap": (16, 24),
"temporal_size": 9,
"temporal_overlap": 1,
}
tiled_vae_mock = MagicMock(return_value=torch.zeros((1, 16, 2, 8, 8)))
with patch.object(seedvr_vae_mod, "tiled_vae", tiled_vae_mock):
vae.encode_tiled_seedvr2(pixel_samples)
assert tiled_vae_mock.call_args.kwargs["tile_size"] == (512, 512)
assert tiled_vae_mock.call_args.kwargs["tile_overlap"] == (64, 64)
assert tiled_vae_mock.call_args.kwargs["temporal_size"] == 9999
assert tiled_vae_mock.call_args.kwargs["temporal_overlap"] == 0
assert vae.first_stage_model.tiled_args == {
"tile_size": (128, 160),
"tile_overlap": (16, 24),
"temporal_size": 9,
"temporal_overlap": 1,
}
def test_method_clamps_overlap_below_tile_size():
vae = _make_minimal_seedvr2_vae()
pixel_samples = torch.zeros((1, 3, 8, 64, 64))
tiled_vae_mock = MagicMock(return_value=torch.zeros((1, 16, 2, 8, 8)))
with patch.object(seedvr_vae_mod, "tiled_vae", tiled_vae_mock):
vae.encode_tiled_seedvr2(
pixel_samples,
tile_x=64,
tile_y=48,
overlap=96,
)
assert tiled_vae_mock.call_args.kwargs["tile_overlap"] == (40, 56)