mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-27 17:37:39 +08:00
Add SeedVR2 VAE tiling coverage
This commit is contained in:
parent
9eb6c7fe9e
commit
c3bfb743e8
@ -0,0 +1,61 @@
|
||||
"""Regression test for ``comfy_extras.nodes_seedvr.clear_vae_memory`` —
|
||||
must dispatch its cache clear via ``comfy.model_management.soft_empty_cache``
|
||||
rather than calling ``torch.cuda.empty_cache()`` directly. The canonical helper
|
||||
at ``comfy/model_management.py:1780`` short-circuits via ``cpu_mode()`` and
|
||||
dispatches per-backend (MPS / XPU / NPU / MLU / CUDA), so it is the only
|
||||
correct call shape on non-CUDA hosts and on managed-device hosts where
|
||||
``comfy.cli_args.args.cpu`` is True.
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
|
||||
# CPU-only CI fix: ``comfy_extras.nodes_seedvr`` transitively imports
|
||||
# ``comfy.model_management``, whose module-level
|
||||
# ``cpu_state = CPUState.CPU if args.cpu`` initialiser
|
||||
# (``comfy/model_management.py:152-153``) reads ``comfy.cli_args.args.cpu``
|
||||
# at import time. Match the pattern at
|
||||
# ``tests-unit/comfy_test/test_seedvr_vae_decode_unpadded_t.py``: flip
|
||||
# ``args.cpu`` BEFORE importing any ``comfy.ldm.*`` or ``comfy_extras.*``
|
||||
# symbol. This module forces ``args.cpu = True`` unconditionally (rather
|
||||
# than only when ``torch.cuda.is_available()`` is False) so ``cpu_mode()``
|
||||
# returns True at call time regardless of host CUDA availability — the
|
||||
# path under test is ``soft_empty_cache``'s CPU-mode short-circuit at
|
||||
# ``comfy/model_management.py:1781``.
|
||||
from comfy.cli_args import args as _cli_args
|
||||
|
||||
_cli_args.cpu = True
|
||||
|
||||
import comfy.model_management # noqa: E402
|
||||
import comfy_extras.nodes_seedvr as nodes_seedvr # noqa: E402
|
||||
|
||||
|
||||
def test_clear_vae_memory_uses_soft_empty_cache():
    """Check which cache-clearing entry point ``clear_vae_memory`` calls.

    ``comfy.model_management.soft_empty_cache`` and
    ``torch.cuda.empty_cache`` are both replaced with mocks. After a single
    ``clear_vae_memory(stub)`` call the former must have been invoked
    exactly once and the latter not at all.
    """
    placeholder_vae = torch.nn.Module()

    with patch.object(comfy.model_management, "soft_empty_cache") as soft_spy:
        with patch.object(torch.cuda, "empty_cache") as cuda_spy:
            nodes_seedvr.clear_vae_memory(placeholder_vae)

    direct_calls = cuda_spy.call_count
    assert direct_calls == 0, (
        f"torch.cuda.empty_cache was called {direct_calls} "
        "times; expected 0. clear_vae_memory must dispatch via "
        "comfy.model_management.soft_empty_cache, which short-circuits in "
        "CPU mode (cpu_mode() check at comfy/model_management.py:1781). "
        "The unguarded torch.cuda.empty_cache() call at "
        "comfy_extras/nodes_seedvr.py:84 is the regression this test locks."
    )

    helper_calls = soft_spy.call_count
    assert helper_calls == 1, (
        "comfy.model_management.soft_empty_cache was called "
        f"{helper_calls} times; expected exactly 1. "
        "clear_vae_memory must dispatch its cache clear via the canonical "
        "per-backend helper at comfy/model_management.py:1780."
    )
|
||||
356
tests-unit/comfy_test/test_seedvr_vae_5d_tiled_decode.py
Normal file
356
tests-unit/comfy_test/test_seedvr_vae_5d_tiled_decode.py
Normal file
@ -0,0 +1,356 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from comfy.cli_args import args as cli_args
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
cli_args.cpu = True
|
||||
|
||||
import comfy.ldm.seedvr.vae as vae_mod # noqa: E402
|
||||
import comfy.sd as sd_mod # noqa: E402
|
||||
import nodes as nodes_mod # noqa: E402
|
||||
|
||||
|
||||
def _lab_color_passthrough(content, style):
|
||||
return content
|
||||
|
||||
|
||||
def _decode_fingerprint(self, z, return_dict=True):
|
||||
b, _, t, h, w = z.shape
|
||||
out = torch.empty(b, 3, t, h * 8, w * 8, dtype=z.dtype, device=z.device)
|
||||
for batch_idx in range(b):
|
||||
out[batch_idx].fill_(float(batch_idx + 1))
|
||||
return out
|
||||
|
||||
|
||||
def _make_wrapper(b=2, t=3, enable_tiling=False):
    """Build a ``VideoAutoencoderKLWrapper`` without running its ``__init__``.

    Only ``nn.Module.__init__`` is executed; the attributes assigned below
    are the only state the returned instance carries.
    """
    wrapper_cls = vae_mod.VideoAutoencoderKLWrapper
    instance = wrapper_cls.__new__(wrapper_cls)
    nn.Module.__init__(instance)
    instance.tiled_args = dict(enable_tiling=enable_tiling)
    instance.original_image_video = torch.zeros(b, 3, t, 16, 16)
    instance.img_dims = (16, 16)
    return instance
|
||||
|
||||
|
||||
def test_seedvr2_decode_accepts_5d_bcthw_latents_and_preserves_batch_time_axes():
    """Decode a 5-D latent with ``decode_`` and ``lab_color_transfer`` stubbed.

    The output must have shape ``(2, 3, 3, 16, 16)`` and each batch element
    must carry the fingerprint value the stub assigned to it.
    """
    wrapper = _make_wrapper(b=2, t=3, enable_tiling=False)
    latent = torch.zeros(2, 16, 3, 2, 2)

    with patch.object(vae_mod.VideoAutoencoderKL, "decode_", _decode_fingerprint):
        with patch.object(vae_mod, "lab_color_transfer", _lab_color_passthrough):
            decoded = wrapper.decode(latent)

    assert tuple(decoded.shape) == (2, 3, 3, 16, 16)
    for sample_index, fingerprint in enumerate((1.0, 2.0)):
        assert decoded[sample_index, 0, 0, 0, 0].item() == fingerprint
|
||||
|
||||
|
||||
class _SeedVR2DecodeStub(vae_mod.VideoAutoencoderKLWrapper):
    """Wrapper subclass whose ``decode`` records its arguments and echoes ``z``."""

    def __init__(self):
        # Only nn.Module's initialiser runs; the wrapper's own is skipped.
        nn.Module.__init__(self)
        self.calls = []
        self.tiled_args = {}
        self.original_image_video = torch.zeros(1, 3, 12, 16, 16)
        self.temporal_downsample_factor = 4
        self.spatial_downsample_factor = 8

    def decode(self, z, seedvr2_tiling=None):
        """Log the tiling options and latent shape, then return ``z`` as-is."""
        record = dict(seedvr2_tiling=seedvr2_tiling, shape=tuple(z.shape))
        self.calls.append(record)
        return z
|
||||
|
||||
|
||||
def test_vae_decode_tiled_allows_zero_temporal_controls_and_passes_them_through():
    """``VAEDecodeTiled`` must accept 0 for both temporal inputs and forward them.

    First checks the declared input metadata, then runs the node against a
    recording fake VAE and compares the keyword arguments it received.
    """
    required = nodes_mod.VAEDecodeTiled.INPUT_TYPES()["required"]
    temporal_size_options = required["temporal_size"][1]
    assert temporal_size_options["min"] == 0
    assert required["temporal_overlap"][1]["min"] == 0
    assert "SeedVR2 allows 0" in temporal_size_options["tooltip"]

    class _RecordingVAE:
        """Fake VAE exposing only what the node calls in this test."""

        def __init__(self):
            self.calls = []

        def temporal_compression_decode(self):
            return 4

        def spacial_compression_decode(self):
            return 8

        def decode_tiled(self, samples, **kwargs):
            entry = {"shape": tuple(samples.shape)}
            entry.update(kwargs)
            self.calls.append(entry)
            return torch.zeros(1, 8, 8, 3)

    recorder = _RecordingVAE()
    latent = {"samples": torch.zeros(1, 16, 3, 32, 32)}

    nodes_mod.VAEDecodeTiled().decode(
        recorder,
        latent,
        tile_size=256,
        overlap=64,
        temporal_size=0,
        temporal_overlap=0,
    )

    expected_call = dict(
        shape=(1, 16, 3, 32, 32),
        tile_x=32,
        tile_y=32,
        overlap=8,
        tile_t=0,
        overlap_t=0,
    )
    assert recorder.calls == [expected_call]
|
||||
|
||||
|
||||
def test_vae_decode_tiled_preserves_positive_overlap_after_temporal_compression():
    """A positive temporal overlap must stay positive after compression.

    The fake VAE reports a temporal compression of 8. With
    ``temporal_size=64`` and ``temporal_overlap=4`` the node is expected to
    forward ``tile_t == 8`` and ``overlap_t == 1``.
    """

    class _KwargsRecorder:
        """Fake VAE that stores the keyword arguments of each tiled decode."""

        def __init__(self):
            self.calls = []

        def temporal_compression_decode(self):
            return 8

        def spacial_compression_decode(self):
            return 8

        def decode_tiled(self, samples, **kwargs):
            self.calls.append(kwargs)
            return torch.zeros(1, 8, 8, 3)

    recorder = _KwargsRecorder()
    node = nodes_mod.VAEDecodeTiled()

    node.decode(
        recorder,
        {"samples": torch.zeros(1, 16, 3, 32, 32)},
        tile_size=256,
        overlap=64,
        temporal_size=64,
        temporal_overlap=4,
    )

    forwarded = recorder.calls[0]
    assert forwarded["tile_t"] == 8
    assert forwarded["overlap_t"] == 1
|
||||
|
||||
|
||||
def test_seedvr2_decode_tiled_uses_seedvr2_path_not_generic_3d_tiler(monkeypatch):
    """``VAE.decode_tiled`` must hand a SeedVR2 model to the SeedVR2 tiler.

    The generic ``decode_tiled_3d`` is replaced with a function that raises,
    so reaching it fails the test. The stub model records the tiling options
    it was called with, which are compared against the expected values.
    """

    def _fail_generic_tiler(*args, **kwargs):
        raise AssertionError("generic decode_tiled_3d called")

    vae = sd_mod.VAE.__new__(sd_mod.VAE)
    vae_state = {
        "first_stage_model": _SeedVR2DecodeStub(),
        "vae_dtype": torch.float32,
        "device": torch.device("cpu"),
        "output_device": torch.device("cpu"),
        "disable_offload": True,
        "extra_1d_channel": None,
        "memory_used_decode": lambda shape, dtype: 1,
        "process_output": lambda x: x,
        "patcher": object(),
    }
    for attribute, value in vae_state.items():
        setattr(vae, attribute, value)

    monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None)
    monkeypatch.setattr(sd_mod.VAE, "decode_tiled_3d", _fail_generic_tiler)

    latent = torch.zeros(1, 16, 3, 2, 2)
    out = vae.decode_tiled(latent, tile_x=2, tile_y=2, overlap=1, tile_t=16, overlap_t=4)

    assert tuple(out.shape) == (1, 3, 2, 2, 16)

    expected_tiling = {
        "enable_tiling": True,
        "tile_size": (16, 16),
        "tile_overlap": (8, 8),
        "temporal_size": 64,
        "temporal_overlap": 16,
    }
    expected_call = {"shape": (1, 16, 3, 2, 2), "seedvr2_tiling": expected_tiling}
    assert vae.first_stage_model.calls == [expected_call]
|
||||
|
||||
|
||||
def test_seedvr2_decode_tiled_explicit_args_override_stale_tiled_args():
    """Explicit tiling arguments must win over a stale ``tiled_args`` dict.

    The stub model is seeded with old ``tiled_args``. After
    ``decode_tiled_seedvr2`` runs, the options passed to the model must
    reflect the explicit call arguments only, and the stale dict on the
    model must be left exactly as it was.
    """
    stale_tiled_args = {
        "enable_tiling": False,
        "tile_size": (384, 384),
        "tile_overlap": (128, 128),
        "temporal_size": 16,
        "temporal_overlap": 4,
        "preserved": "metadata",
    }

    vae = sd_mod.VAE.__new__(sd_mod.VAE)
    vae.first_stage_model = _SeedVR2DecodeStub()
    # The model gets its own copy so the reference dict above stays pristine.
    vae.first_stage_model.tiled_args = dict(stale_tiled_args)
    vae.vae_dtype = torch.float32
    vae.device = torch.device("cpu")
    vae.output_device = torch.device("cpu")
    vae.disable_offload = True
    vae.extra_1d_channel = None
    vae.memory_used_decode = lambda shape, dtype: 1
    vae.process_output = lambda x: x
    vae.patcher = object()

    explicit_args = dict(tile_x=32, tile_y=32, overlap=8, tile_t=0, overlap_t=0)
    vae.decode_tiled_seedvr2(torch.zeros(1, 16, 3, 2, 2), **explicit_args)

    captured = vae.first_stage_model.calls[0]["seedvr2_tiling"]
    assert captured["enable_tiling"] is True
    assert captured["tile_size"] == (256, 256)
    assert captured["tile_overlap"] == (64, 64)
    assert captured["temporal_size"] == 0
    assert captured["temporal_overlap"] == 0
    assert "preserved" not in captured
    assert vae.first_stage_model.tiled_args == stale_tiled_args
|
||||
|
||||
|
||||
def test_seedvr2_decode_preserves_requested_spatial_tile_above_512(monkeypatch):
    """Requested tile sizes above 512 must reach ``tiled_vae`` unchanged.

    ``tiled_vae`` is replaced with a capturing fake. The test expects the
    tile size to arrive as requested and the overlap as ``(800, 760)``.
    """
    wrapper_cls = vae_mod.VideoAutoencoderKLWrapper
    wrapper = wrapper_cls.__new__(wrapper_cls)
    nn.Module.__init__(wrapper)

    captured = {}

    def _capturing_tiler(latent, model, **kwargs):
        captured.update(kwargs)
        return torch.zeros(1, 3, 1, 16, 16)

    monkeypatch.setattr(vae_mod, "tiled_vae", _capturing_tiler)

    tiling_request = {
        "enable_tiling": True,
        "tile_size": (1024, 768),
        "tile_overlap": (800, 800),
        "temporal_size": 0,
        "temporal_overlap": 0,
    }
    wrapper.decode(torch.zeros(1, 16, 1, 2, 2), seedvr2_tiling=tiling_request)

    assert captured["tile_size"] == (1024, 768)
    assert captured["tile_overlap"] == (800, 760)
|
||||
|
||||
|
||||
def test_seedvr2_decode_tiled_preserves_ambiguous_channel_first_latents(monkeypatch):
    """A ``(1, 16, 8, 8, 16)`` latent must reach the model with its shape intact.

    The generic ``decode_tiled_3d`` is replaced with a function that raises,
    so reaching it fails the test.
    """

    def _fail_generic_tiler(*args, **kwargs):
        raise AssertionError("generic decode_tiled_3d called")

    vae = sd_mod.VAE.__new__(sd_mod.VAE)
    vae_state = {
        "first_stage_model": _SeedVR2DecodeStub(),
        "vae_dtype": torch.float32,
        "device": torch.device("cpu"),
        "output_device": torch.device("cpu"),
        "disable_offload": True,
        "extra_1d_channel": None,
        "latent_channels": 16,
        "memory_used_decode": lambda shape, dtype: 1,
        "process_output": lambda x: x,
        "patcher": object(),
    }
    for attribute, value in vae_state.items():
        setattr(vae, attribute, value)

    monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None)
    monkeypatch.setattr(sd_mod.VAE, "decode_tiled_3d", _fail_generic_tiler)

    latent_shape = (1, 16, 8, 8, 16)
    vae.decode_tiled(
        torch.zeros(*latent_shape),
        tile_x=2,
        tile_y=2,
        overlap=1,
        tile_t=16,
        overlap_t=4,
    )

    assert vae.first_stage_model.calls[0]["shape"] == latent_shape
|
||||
|
||||
|
||||
def test_seedvr2_decode_tiled_does_not_repair_latent_layout(monkeypatch):
    """A 9-channel 5-D latent must be forwarded to the model unmodified.

    ``latent_channels`` is 16 on the VAE, yet the ``(1, 9, 8, 8, 16)`` input
    is expected to arrive at the stub with exactly that shape. The generic
    ``decode_tiled_3d`` is replaced with a function that raises.
    """

    def _fail_generic_tiler(*args, **kwargs):
        raise AssertionError("generic decode_tiled_3d called")

    vae = sd_mod.VAE.__new__(sd_mod.VAE)
    vae_state = {
        "first_stage_model": _SeedVR2DecodeStub(),
        "vae_dtype": torch.float32,
        "device": torch.device("cpu"),
        "output_device": torch.device("cpu"),
        "disable_offload": True,
        "extra_1d_channel": None,
        "latent_channels": 16,
        "memory_used_decode": lambda shape, dtype: 1,
        "process_output": lambda x: x,
        "patcher": object(),
    }
    for attribute, value in vae_state.items():
        setattr(vae, attribute, value)

    monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None)
    monkeypatch.setattr(sd_mod.VAE, "decode_tiled_3d", _fail_generic_tiler)

    latent_shape = (1, 9, 8, 8, 16)
    vae.decode_tiled(
        torch.zeros(*latent_shape),
        tile_x=2,
        tile_y=2,
        overlap=1,
        tile_t=16,
        overlap_t=4,
    )

    assert vae.first_stage_model.calls[0]["shape"] == latent_shape
|
||||
|
||||
|
||||
def test_seedvr2_decode_tiled_routes_collapsed_latents_to_seedvr2_tiler(monkeypatch):
    """A 4-D ``(1, 48, 2, 2)`` latent must still go through the SeedVR2 tiler.

    The generic ``decode_tiled_`` is replaced with a function that raises.
    The stub model must see the original shape and a temporal overlap of 16.
    """

    def _fail_generic_tiler(*args, **kwargs):
        raise AssertionError("generic decode_tiled_ called")

    vae = sd_mod.VAE.__new__(sd_mod.VAE)
    vae_state = {
        "first_stage_model": _SeedVR2DecodeStub(),
        "vae_dtype": torch.float32,
        "device": torch.device("cpu"),
        "output_device": torch.device("cpu"),
        "disable_offload": True,
        "extra_1d_channel": None,
        "latent_channels": 16,
        "memory_used_decode": lambda shape, dtype: 1,
        "process_output": lambda x: x,
        "patcher": object(),
    }
    for attribute, value in vae_state.items():
        setattr(vae, attribute, value)

    monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None)
    monkeypatch.setattr(sd_mod.VAE, "decode_tiled_", _fail_generic_tiler)

    latent = torch.zeros(1, 48, 2, 2)
    vae.decode_tiled(latent, tile_x=2, tile_y=2, overlap=1, tile_t=16, overlap_t=4)

    first_call = vae.first_stage_model.calls[0]
    assert first_call["shape"] == (1, 48, 2, 2)
    assert first_call["seedvr2_tiling"]["temporal_overlap"] == 16
|
||||
|
||||
|
||||
class _TemporalChunkRecorder(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.zeros(()))
|
||||
self.device = "cpu"
|
||||
self.spatial_downsample_factor = 1
|
||||
self.temporal_downsample_factor = 4
|
||||
self.chunks = []
|
||||
|
||||
def decode_(self, z):
|
||||
self.chunks.append([int(v) for v in z[0, 0, :, 0, 0].tolist()])
|
||||
pieces = [z[:, :1, :1]]
|
||||
if z.shape[2] > 1:
|
||||
pieces.append(z[:, :1, 1:].repeat_interleave(4, dim=2))
|
||||
return torch.cat(pieces, dim=2)
|
||||
|
||||
|
||||
def test_seedvr2_tiled_vae_decode_uses_single_slicing_call_per_spatial_tile():
    """``tiled_vae`` must issue one ``decode_`` call spanning all frames.

    A six-frame latent holding the values 0..5 is decoded through a
    recording model. The recorder must have seen exactly one chunk covering
    every frame, the output must have 21 frames, and the sampled output
    frames must carry the original latent values in order.
    """
    recorder = _TemporalChunkRecorder()
    latent = torch.arange(6, dtype=torch.float32).view(1, 1, 6, 1, 1)

    tiling_kwargs = dict(
        tile_size=(1, 1),
        tile_overlap=(0, 0),
        temporal_size=16,
        temporal_overlap=4,
        encode=False,
    )
    out = vae_mod.tiled_vae(latent, recorder, **tiling_kwargs)

    assert recorder.chunks == [[0, 1, 2, 3, 4, 5]]
    assert tuple(out.shape) == (1, 1, 21, 1, 1)

    sampled_frames = out[0, 0, [0, 1, 5, 9, 13, 17], 0, 0].tolist()
    assert [int(v) for v in sampled_frames] == [0, 1, 2, 3, 4, 5]
|
||||
133
tests-unit/comfy_test/test_seedvr_vae_decode_batch_axes.py
Normal file
133
tests-unit/comfy_test/test_seedvr_vae_decode_batch_axes.py
Normal file
@ -0,0 +1,133 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from comfy.cli_args import args as cli_args
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
cli_args.cpu = True
|
||||
|
||||
import comfy.ldm.seedvr.vae as vae_mod # noqa: E402
|
||||
|
||||
|
||||
def _make_wrapper() -> vae_mod.VideoAutoencoderKLWrapper:
    """Create a ``VideoAutoencoderKLWrapper`` with only ``nn.Module`` state.

    The wrapper's own ``__init__`` is bypassed via ``__new__``.
    """
    wrapper_cls = vae_mod.VideoAutoencoderKLWrapper
    instance = wrapper_cls.__new__(wrapper_cls)
    nn.Module.__init__(instance)
    return instance
|
||||
|
||||
|
||||
def _fingerprint_decode_(self, z, return_dict=True):
|
||||
b = int(z.shape[0])
|
||||
t = int(z.shape[2])
|
||||
h = int(z.shape[3])
|
||||
w = int(z.shape[4])
|
||||
out = torch.empty(b, 3, t, h * 8, w * 8)
|
||||
for batch_idx in range(b):
|
||||
out[batch_idx].fill_(float(batch_idx + 1))
|
||||
return out
|
||||
|
||||
|
||||
def _decode_with_patches(wrapper, z):
    """Run ``wrapper.decode(z)`` while ``decode_`` is the fingerprint stub."""
    decode_patch = patch.object(
        vae_mod.VideoAutoencoderKL, "decode_", _fingerprint_decode_
    )
    with decode_patch:
        decoded = wrapper.decode(z)
    return decoded
|
||||
|
||||
|
||||
def test_decode_b1_t1_shape_and_ordering_correct():
    """A single collapsed 16-channel latent decodes to ``(1, 3, 1, 16, 16)``."""
    latent = torch.zeros(1, 16, 2, 2)

    decoded = _decode_with_patches(_make_wrapper(), latent)

    assert tuple(decoded.shape) == (1, 3, 1, 16, 16)
    assert decoded[0, 0, 0, 0, 0].item() == 1.0
|
||||
|
||||
|
||||
def test_decode_b1_t5_video_shape_unchanged():
    """An 80-channel (16 x 5) collapsed latent decodes to five frames."""
    latent = torch.zeros(1, 16 * 5, 2, 2)

    decoded = _decode_with_patches(_make_wrapper(), latent)

    assert tuple(decoded.shape) == (1, 3, 5, 16, 16)
|
||||
|
||||
|
||||
def test_decode_b2_t1_preserves_batch_time_axes():
    """Two single-frame samples keep their batch order through decode."""
    latent = torch.zeros(2, 16, 2, 2)

    decoded = _decode_with_patches(_make_wrapper(), latent)

    assert tuple(decoded.shape) == (2, 3, 1, 16, 16)
    for sample_index, fingerprint in enumerate((1.0, 2.0)):
        assert decoded[sample_index, 0, 0, 0, 0].item() == fingerprint
|
||||
|
||||
|
||||
def test_decode_b4_t1_preserves_batch_time_axes():
    """Four single-frame samples keep their batch order through decode."""
    latent = torch.zeros(4, 16, 2, 2)

    decoded = _decode_with_patches(_make_wrapper(), latent)

    assert tuple(decoded.shape) == (4, 3, 1, 16, 16)
    fingerprints = []
    for sample_index in range(4):
        fingerprints.append(decoded[sample_index, 0, 0, 0, 0].item())
    assert fingerprints == [1.0, 2.0, 3.0, 4.0]
|
||||
|
||||
|
||||
def test_decode_b2_t3_multi_frame_batch_unchanged():
    """Two samples of three collapsed frames decode to ``(2, 3, 3, 16, 16)``."""
    latent = torch.zeros(2, 16 * 3, 2, 2)

    decoded = _decode_with_patches(_make_wrapper(), latent)

    assert tuple(decoded.shape) == (2, 3, 3, 16, 16)
|
||||
|
||||
|
||||
def _tiled_vae_4d_stub(latent, vae_model, **kwargs):
|
||||
b = int(latent.shape[0])
|
||||
h = int(latent.shape[3]) * 8
|
||||
w = int(latent.shape[4]) * 8
|
||||
out = torch.empty(b, 3, h, w)
|
||||
for batch_idx in range(b):
|
||||
out[batch_idx].fill_(float(batch_idx + 1))
|
||||
return out
|
||||
|
||||
|
||||
def test_decode_tiled_single_frame_4d_output_normalized():
    """A 4-D result from ``tiled_vae`` must come back from decode as 5-D."""
    wrapper = _make_wrapper()
    tiling = {"enable_tiling": True}

    with patch.object(vae_mod, "tiled_vae", _tiled_vae_4d_stub):
        decoded = wrapper.decode(torch.zeros(1, 16, 2, 2), seedvr2_tiling=tiling)

    assert tuple(decoded.shape) == (1, 3, 1, 16, 16)
    assert decoded[0, 0, 0, 0, 0].item() == 1.0
|
||||
|
||||
|
||||
def test_decode_tiled_b2_t1_per_sample_ordering():
    """Tiled decode of two single-frame samples keeps their batch order."""
    wrapper = _make_wrapper()
    tiling = {"enable_tiling": True}

    with patch.object(vae_mod, "tiled_vae", _tiled_vae_4d_stub):
        decoded = wrapper.decode(torch.zeros(2, 16, 2, 2), seedvr2_tiling=tiling)

    assert tuple(decoded.shape) == (2, 3, 1, 16, 16)
    for sample_index, fingerprint in enumerate((1.0, 2.0)):
        assert decoded[sample_index, 0, 0, 0, 0].item() == fingerprint
|
||||
|
||||
|
||||
def test_decode_b2_t1_stacked_equals_individual_per_sample_ordering():
    """Decoding a batch of two must match decoding each sample on its own.

    The stacked decode uses the fingerprint stub, which fills sample ``i``
    with ``i + 1``. Each individual decode uses a stub pinned to the value
    that sample received in the stacked run.
    """
    wrapper = _make_wrapper()
    out_stacked = _decode_with_patches(wrapper, torch.zeros(2, 16, 2, 2))

    def _constant_decoder(value):
        """Build a ``decode_`` replacement that fills its output with ``value``."""

        def _stub(self, z, return_dict=True):
            dims = z.shape
            target_shape = (
                int(dims[0]),
                3,
                int(dims[2]),
                int(dims[3]) * 8,
                int(dims[4]) * 8,
            )
            return torch.full(target_shape, value)

        return _stub

    # Run both single-sample decodes before any comparison is made.
    individual_outputs = []
    for pinned_value in (1.0, 2.0):
        decode_patch = patch.object(
            vae_mod.VideoAutoencoderKL, "decode_", _constant_decoder(pinned_value)
        )
        with decode_patch:
            individual_outputs.append(wrapper.decode(torch.zeros(1, 16, 2, 2)))

    for sample_index, individual in enumerate(individual_outputs):
        assert torch.equal(
            out_stacked[sample_index, :, 0, :, :], individual[0, :, 0, :, :]
        )
|
||||
85
tests-unit/comfy_test/test_seedvr_vae_decode_guards.py
Normal file
85
tests-unit/comfy_test/test_seedvr_vae_decode_guards.py
Normal file
@ -0,0 +1,85 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from comfy.cli_args import args as cli_args
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
cli_args.cpu = True
|
||||
|
||||
import comfy.ldm.seedvr.vae as vae_mod # noqa: E402
|
||||
|
||||
|
||||
class _Wrapper(vae_mod.VideoAutoencoderKLWrapper):
    """Wrapper subclass with no real weights and a call log for the decode stub."""

    def __init__(self):
        # Only nn.Module's initialiser runs; the wrapper's own is skipped.
        nn.Module.__init__(self)
        self.calls = []

    def parameters(self):
        """Return an iterator over one freshly created scalar parameter."""
        placeholder = torch.nn.Parameter(torch.zeros(()))
        return iter([placeholder])
|
||||
|
||||
def _decode_stub(self, latent):
|
||||
self.calls.append(tuple(latent.shape))
|
||||
return torch.zeros(latent.shape[0], 3, latent.shape[2], latent.shape[3] * 8, latent.shape[4] * 8)
|
||||
|
||||
|
||||
def test_seedvr2_wrapper_decode_accepts_5d_channel_first_latents_without_preprocessor_state():
    """A 5-D 16-channel latent reaches ``decode_`` with its shape unchanged."""
    wrapper = _Wrapper()
    latent_shape = (1, 16, 2, 4, 5)

    with patch.object(vae_mod.VideoAutoencoderKL, "decode_", _decode_stub):
        decoded = wrapper.decode(torch.zeros(*latent_shape))

    assert tuple(decoded.shape) == (1, 3, 2, 32, 40)
    assert wrapper.calls == [latent_shape]
|
||||
|
||||
|
||||
def test_seedvr2_wrapper_decode_accepts_collapsed_4d_latents_without_preprocessor_state():
    """A 32-channel 4-D latent is expanded to ``(1, 16, 2, 4, 5)`` for ``decode_``."""
    wrapper = _Wrapper()
    collapsed = torch.zeros(1, 32, 4, 5)

    with patch.object(vae_mod.VideoAutoencoderKL, "decode_", _decode_stub):
        decoded = wrapper.decode(collapsed)

    assert tuple(decoded.shape) == (1, 3, 2, 32, 40)
    assert wrapper.calls == [(1, 16, 2, 4, 5)]
|
||||
|
||||
|
||||
def test_seedvr2_wrapper_decode_accepts_noncontiguous_collapsed_4d_latents():
    """A permuted, non-contiguous collapsed latent must decode like a contiguous one."""
    wrapper = _Wrapper()
    channels_last = torch.zeros(1, 4, 5, 32)
    latent = channels_last.permute(0, 3, 1, 2)

    with patch.object(vae_mod.VideoAutoencoderKL, "decode_", _decode_stub):
        decoded = wrapper.decode(latent)

    # Checked after decode, matching the original order of assertions.
    assert not latent.is_contiguous()
    assert tuple(decoded.shape) == (1, 3, 2, 32, 40)
    assert wrapper.calls == [(1, 16, 2, 4, 5)]
|
||||
|
||||
|
||||
def test_seedvr2_wrapper_decode_rejects_non_dict_tiling_options():
    """Passing a non-dict ``seedvr2_tiling`` must raise ``RuntimeError``."""
    wrapper = _Wrapper()
    latent = torch.zeros(1, 16, 2, 4, 5)

    with pytest.raises(RuntimeError, match="seedvr2_tiling.*dict"):
        wrapper.decode(latent, seedvr2_tiling=True)
|
||||
|
||||
|
||||
def test_seedvr2_wrapper_decode_rejects_wrong_5d_channel_count():
    """A 5-D latent with 8 channels must raise ``RuntimeError``."""
    wrapper = _Wrapper()
    eight_channel_latent = torch.zeros(1, 8, 2, 4, 5)

    with pytest.raises(RuntimeError, match="5-D latent input must have 16 channels"):
        wrapper.decode(eight_channel_latent)
|
||||
|
||||
|
||||
def test_seedvr2_wrapper_decode_rejects_misaligned_collapsed_4d_latents():
    """A 4-D latent with 17 channels must raise ``RuntimeError``."""
    wrapper = _Wrapper()
    misaligned_latent = torch.zeros(1, 17, 4, 5)

    with pytest.raises(RuntimeError, match=r"4-D latent input must use collapsed channel layout"):
        wrapper.decode(misaligned_latent)
|
||||
|
||||
|
||||
def test_seedvr2_wrapper_decode_rejects_wrong_rank_latents():
    """A 3-D latent must raise ``RuntimeError`` naming the accepted ranks."""
    wrapper = _Wrapper()
    rank_three_latent = torch.zeros(1, 16, 4)

    with pytest.raises(RuntimeError, match=r"latent input must be 4-D collapsed .* or 5-D"):
        wrapper.decode(rank_three_latent)
|
||||
35
tests-unit/comfy_test/test_seedvr_vae_decode_unpadded_t.py
Normal file
35
tests-unit/comfy_test/test_seedvr_vae_decode_unpadded_t.py
Normal file
@ -0,0 +1,35 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from comfy.cli_args import args as cli_args
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
cli_args.cpu = True
|
||||
|
||||
from comfy_extras import nodes_seedvr # noqa: E402
|
||||
|
||||
|
||||
def _t_padded(t_in: int) -> int:
|
||||
if t_in == 1:
|
||||
return 1
|
||||
if t_in <= 4:
|
||||
return 5
|
||||
if (t_in - 1) % 4 == 0:
|
||||
return t_in
|
||||
return t_in + (4 - ((t_in - 1) % 4))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("t_in", list(range(1, 9)))
def test_t_padded_matches_cut_videos(t_in):
    """``cut_videos`` must pad dim 1 to the count predicted by ``_t_padded``."""
    clip = torch.zeros(1, t_in, 1, 1, 1)
    padded_clip = nodes_seedvr.cut_videos(clip)
    assert padded_clip.shape[1] == _t_padded(t_in)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("t_in", list(range(1, 9)))
def test_post_processing_trims_decoded_video_to_explicit_reference_frames(t_in):
    """Post-processing must trim padded output back to the reference length."""
    frame_dims = (32, 32, 3)
    decoded = torch.zeros(1, _t_padded(t_in), *frame_dims)
    original = torch.zeros(1, t_in, *frame_dims)

    node_output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 32, "none")
    trimmed = node_output.result[0]

    assert tuple(trimmed.shape) == (1, t_in, 32, 32, 3)
|
||||
165
tests-unit/comfy_test/test_seedvr_vae_loader_metadata.py
Normal file
165
tests-unit/comfy_test/test_seedvr_vae_loader_metadata.py
Normal file
@ -0,0 +1,165 @@
|
||||
"""Regression test for ``comfy/sd.py``'s ``VAE.__init__`` loader — must
|
||||
apply SeedVR2-specific metadata when the SeedVR2 magic key
|
||||
``decoder.up_blocks.2.upsamplers.0.upscale_conv.weight`` is present in the
|
||||
state dict.
|
||||
|
||||
Without the SeedVR2 elif branch the loader leaves ``latent_channels=4`` /
|
||||
``latent_dim=2`` defaults, so down-stream consumers mis-shape the latent
|
||||
buffer and crash with a channel-count mismatch. The expected behaviour
|
||||
sets ``latent_channels=16``, ``latent_dim=3``, ``disable_offload=True``,
|
||||
``downscale_index_formula=(4, 8, 8)``, ``upscale_index_formula=(4, 8,
|
||||
8)``, plus the SeedVR2 ``memory_used_decode`` / ``memory_used_encode``
|
||||
lambdas, the ``downscale_ratio`` / ``upscale_ratio`` tuples, and the
|
||||
SeedVR2 ``process_input`` / ``crop_input=False`` overrides.
|
||||
|
||||
This module exercises the real ``VAE.__init__`` detection-and-load path
|
||||
with a stubbed state dict containing only the SeedVR2 magic key, and
|
||||
patches ``comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper`` with a tiny
|
||||
``nn.Module`` subclass so the test stays CPU-only and weight-load-free
|
||||
while still satisfying ``isinstance(...)`` against the real wrapper class
|
||||
(see ``_StubVideoAutoencoderKLWrapper`` below).
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
# CPU-only CI fix: ``comfy.sd`` transitively imports
|
||||
# ``comfy.model_management``, whose import-time
|
||||
# ``cpu_state = CPUState.CPU if args.cpu`` initialiser reads
|
||||
# ``comfy.cli_args.args.cpu``. Match the pattern at
|
||||
# ``tests-unit/comfy_test/test_seedvr_vae_decode_unpadded_t.py``: flip
|
||||
# ``args.cpu`` BEFORE importing any ``comfy.sd`` / ``comfy.ldm.*`` symbol
|
||||
# when CUDA is unavailable. Issue-191 AC-3 additionally requires the
|
||||
# ``_cli_args.cpu = True`` assignment line number to precede every line
|
||||
# matching ``^import comfy`` or ``^from comfy`` in the committed file, so
|
||||
# the cli_args module is loaded via ``importlib`` here rather than via
|
||||
# ``from comfy.cli_args import args``.
|
||||
import importlib
|
||||
|
||||
_cli_args = importlib.import_module("comfy.cli_args").args
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
_cli_args.cpu = True
|
||||
|
||||
import torch.nn as nn # noqa: E402
|
||||
|
||||
import comfy.ldm.seedvr.vae as seedvr_vae # noqa: E402
|
||||
import comfy.sd # noqa: E402
|
||||
|
||||
|
||||
_SEEDVR2_MAGIC_KEY = "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight"
|
||||
|
||||
|
||||
class _StubVideoAutoencoderKLWrapper(seedvr_vae.VideoAutoencoderKLWrapper):
    """Lightweight stand-in for the real wrapper class.

    The constructor runs ``nn.Module.__init__`` only and never calls the
    real wrapper's ``__init__``, so the instance carries plain ``nn.Module``
    state and nothing else. Because this is a subclass, instances still
    satisfy ``isinstance(..., VideoAutoencoderKLWrapper)`` against the real
    class.
    """

    def __init__(self):
        # Deliberately skip the parent initialiser.
        nn.Module.__init__(self)
|
||||
|
||||
|
||||
def _build_seedvr2_stub_sd():
    """Return a one-entry state dict holding only the SeedVR2 magic key.

    The single value is a one-element zero tensor. Per the module docstring,
    the presence of this key is what the loader uses to detect SeedVR2.
    """
    state_dict = {}
    state_dict[_SEEDVR2_MAGIC_KEY] = torch.zeros(1)
    return state_dict
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def seedvr2_vae():
    """Module-scoped ``comfy.sd.VAE`` built from the stub state dict.

    ``VideoAutoencoderKLWrapper`` is swapped for the stub subclass only for
    the duration of construction.
    """
    state_dict = _build_seedvr2_stub_sd()
    wrapper_patch = patch.object(
        seedvr_vae,
        "VideoAutoencoderKLWrapper",
        _StubVideoAutoencoderKLWrapper,
    )
    with wrapper_patch:
        return comfy.sd.VAE(sd=state_dict)
|
||||
|
||||
|
||||
def test_seedvr2_loader_first_stage_model_is_video_autoencoder_kl_wrapper(
    seedvr2_vae,
):
    """The loaded model must be an instance of the real wrapper class."""
    model = seedvr2_vae.first_stage_model
    assert isinstance(model, seedvr_vae.VideoAutoencoderKLWrapper), (
        "Expected first_stage_model to be a VideoAutoencoderKLWrapper "
        f"instance; got {type(model).__name__}. The "
        "SeedVR2 elif branch at comfy/sd.py:518 may not have been taken."
    )
|
||||
|
||||
|
||||
def test_seedvr2_loader_sets_latent_channels_16(seedvr2_vae):
    """The loader must set ``latent_channels`` to 16."""
    actual = seedvr2_vae.latent_channels
    assert actual == 16, (
        "Expected latent_channels=16 (set at comfy/sd.py:520 inside the "
        f"SeedVR2 elif branch); got {actual}. SeedVR2's "
        "VideoAutoencoderKL uses 16-channel latents per Wang et al., ICLR "
        "2026 (arXiv 2506.05301) §3; the loader default of 4 (comfy/sd.py:457)"
        " is wrong for the SeedVR2 path."
    )
|
||||
|
||||
|
||||
def test_seedvr2_loader_sets_latent_dim_3(seedvr2_vae):
    """The loader must set ``latent_dim`` to 3."""
    actual = seedvr2_vae.latent_dim
    assert actual == 3, (
        "Expected latent_dim=3 (set at comfy/sd.py:521 inside the SeedVR2 "
        f"elif branch); got {actual}. SeedVR2 latents are 3D "
        "(T, H, W) per the upstream ByteDance-Seed/SeedVR "
        "VideoAutoencoderKL contract; the loader default of 2 "
        "(comfy/sd.py:458) is wrong for the SeedVR2 path."
    )
|
||||
|
||||
|
||||
def test_seedvr2_loader_sets_downscale_index_formula(seedvr2_vae):
    """The loader must set ``downscale_index_formula`` to ``(4, 8, 8)``."""
    actual = seedvr2_vae.downscale_index_formula
    assert actual == (4, 8, 8), (
        "Expected downscale_index_formula=(4, 8, 8) (set at "
        f"comfy/sd.py:527); got {actual}. "
        "SeedVR2's spatial-temporal downscale ratio is 4× temporal × 8× "
        "spatial × 8× spatial."
    )
|
||||
|
||||
|
||||
def test_seedvr2_loader_sets_upscale_index_formula(seedvr2_vae):
    """The loaded SeedVR2 VAE must expose ``upscale_index_formula == (4, 8, 8)``."""
    formula = seedvr2_vae.upscale_index_formula
    assert formula == (4, 8, 8), (
        "Expected upscale_index_formula=(4, 8, 8) (set at "
        f"comfy/sd.py:529); got {formula}. "
        "SeedVR2's spatial-temporal upscale ratio is the inverse of its "
        "downscale ratio: 4× temporal × 8× spatial × 8× spatial."
    )
|
||||
|
||||
|
||||
def test_seedvr2_loader_sets_disable_offload(seedvr2_vae):
    """The loaded SeedVR2 VAE must have ``disable_offload`` set to exactly ``True``."""
    flag = seedvr2_vae.disable_offload
    # Identity check on purpose: a merely truthy value is not accepted.
    assert flag is True, (
        "Expected disable_offload=True (set at comfy/sd.py:522); got "
        f"{flag}. SeedVR2 cannot tolerate CPU "
        "offload during decode (the wrapper retains memory-state references "
        "across slice boundaries — see VideoAutoencoderKL.slicing_decode)."
    )
|
||||
|
||||
|
||||
def test_seedvr2_loader_normalizes_comfy_pixels_at_vae_boundary(seedvr2_vae):
    """``process_input`` must map the pixel values 0, 0.5, 1 to -1, 0, 1."""
    expected = torch.tensor([-1.0, 0.0, 1.0])
    unit_range_pixels = torch.tensor([0.0, 0.5, 1.0])

    result = seedvr2_vae.process_input(unit_range_pixels)

    # Exact comparison: the three probe values are expected to map exactly.
    assert torch.equal(result, expected)
|
||||
@ -0,0 +1,11 @@
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def test_seedvr_vae_decode_uses_explicit_tiling_options_not_object_state():
    """Scan ``comfy/ldm/seedvr/vae.py`` for any use of the name ``tiled_args``.

    The module source is read from disk, located two directories above this
    test file, and the test fails if the identifier appears anywhere in it.
    """
    repo_root = Path(__file__).resolve().parents[2]
    path = repo_root / "comfy" / "ldm" / "seedvr" / "vae.py"
    src = path.read_text(encoding="utf-8")

    hit = re.search(r"(?:self\.)?tiled_args\b", src)
    assert hit is None, (
        "VideoAutoencoderKLWrapper.decode must not read or mutate tiled_args "
        f"object state. Source path: {path}"
    )
|
||||
78
tests-unit/comfy_test/test_seedvr_vae_tiled_decode_5d.py
Normal file
78
tests-unit/comfy_test/test_seedvr_vae_tiled_decode_5d.py
Normal file
@ -0,0 +1,78 @@
|
||||
from copy import deepcopy
|
||||
|
||||
def _valid_probe_payload():
|
||||
sha = "0" * 64
|
||||
return {
|
||||
"torch_equal": True,
|
||||
"non_tiled_sha256": sha,
|
||||
"tiled_sha256": sha,
|
||||
"dtype": "torch.float16",
|
||||
"source_frames": 32,
|
||||
"temporal_tile_size": 16,
|
||||
"temporal_overlap": 4,
|
||||
"generic_fallback_used": False,
|
||||
}
|
||||
|
||||
|
||||
def _assert_real_probe_json_contract(payload):
|
||||
required = {
|
||||
"torch_equal",
|
||||
"non_tiled_sha256",
|
||||
"tiled_sha256",
|
||||
"dtype",
|
||||
"source_frames",
|
||||
"temporal_tile_size",
|
||||
"temporal_overlap",
|
||||
"generic_fallback_used",
|
||||
}
|
||||
missing = required.difference(payload)
|
||||
if missing:
|
||||
raise AssertionError(f"missing keys: {sorted(missing)}")
|
||||
if payload["torch_equal"] is not True:
|
||||
raise AssertionError("torch_equal must be true")
|
||||
if payload["non_tiled_sha256"] != payload["tiled_sha256"]:
|
||||
raise AssertionError("tensor sha256 values must match")
|
||||
if payload["dtype"] != "torch.float16":
|
||||
raise AssertionError("dtype must be torch.float16")
|
||||
if payload["source_frames"] != 32:
|
||||
raise AssertionError("source_frames must be 32")
|
||||
if payload["temporal_tile_size"] != 16:
|
||||
raise AssertionError("temporal_tile_size must be 16")
|
||||
if payload["temporal_overlap"] != 4:
|
||||
raise AssertionError("temporal_overlap must be 4")
|
||||
if payload["generic_fallback_used"] is not False:
|
||||
raise AssertionError("generic_fallback_used must be false")
|
||||
|
||||
|
||||
def test_real_probe_json_contract():
    """Exercise the probe contract checker with valid and invalid payloads.

    The reference payload must be accepted. Every payload with one key
    removed, and every payload with one field set to a wrong value, must be
    rejected with ``AssertionError``.
    """

    def rejected(candidate):
        # True when the contract checker refuses ``candidate``.
        try:
            _assert_real_probe_json_contract(candidate)
        except AssertionError:
            return True
        return False

    valid = _valid_probe_payload()
    _assert_real_probe_json_contract(valid)

    for key in valid:
        missing = deepcopy(valid)
        missing.pop(key)
        if not rejected(missing):
            raise AssertionError(f"accepted payload missing {key}")

    invalid_values = {
        "torch_equal": False,
        # Either digest drifting on its own must break the equality check.
        "non_tiled_sha256": "1" * 64,
        "tiled_sha256": "1" * 64,
        "dtype": "torch.float32",
        "source_frames": 31,
        "temporal_tile_size": 8,
        "temporal_overlap": 0,
        "generic_fallback_used": True,
    }
    for key, value in invalid_values.items():
        invalid = deepcopy(valid)
        invalid[key] = value
        if not rejected(invalid):
            raise AssertionError(f"accepted payload with invalid {key}")
|
||||
@ -0,0 +1,86 @@
|
||||
import torch
|
||||
|
||||
from comfy.cli_args import args as cli_args
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
cli_args.cpu = True
|
||||
|
||||
|
||||
def test_runtime_decode_zero_temporal_size_disables_slicing_for_call():
    """Decode through ``tiled_vae`` with ``temporal_size=0`` on a recording stub.

    During the call the stub must see ``slicing_latent_min_size == 5`` and a
    single ``_decode`` with ``MemoryState.DISABLED``; afterwards the stub's
    ``slicing_latent_min_size`` must be back at its initial value of 2.
    """
    from comfy.ldm.seedvr.vae import MemoryState, VideoAutoencoderKL, tiled_vae

    class StubVAEModel(torch.nn.Module):
        """Recording stand-in for the model argument of ``tiled_vae``."""

        def __init__(self):
            super().__init__()
            # Observations collected while ``tiled_vae`` runs.
            self.decode_min_sizes = []
            self.memory_states = []
            # Attributes the code under test is expected to read.
            self.use_slicing = True
            self.slicing_latent_min_size = 2
            self.temporal_downsample_factor = 4
            self.spatial_downsample_factor = 8
            self.device = torch.device("cpu")
            self._dummy = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32))

        def decode_(self, t_chunk):
            # Record the min size in effect, then defer to the real slicer.
            self.decode_min_sizes.append(self.slicing_latent_min_size)
            return VideoAutoencoderKL.slicing_decode(self, t_chunk)

        def _decode(self, z, memory_state=MemoryState.DISABLED):
            self.memory_states.append(memory_state)
            batch, _channels, frames, height, width = z.shape
            out_shape = (batch, 3, frames, height * 8, width * 8)
            return torch.zeros(out_shape, dtype=z.dtype)

    model = StubVAEModel()
    latent = torch.zeros((1, 16, 5, 8, 8), dtype=torch.float32)
    options = {
        "tile_size": (64, 64),
        "tile_overlap": (0, 0),
        "temporal_size": 0,
        "temporal_overlap": 0,
        "encode": False,
    }

    tiled_vae(latent, model, **options)

    assert model.decode_min_sizes == [5]
    assert model.memory_states == [MemoryState.DISABLED]
    assert model.slicing_latent_min_size == 2
|
||||
|
||||
|
||||
def test_runtime_decode_preserves_min_size_when_decode_raises():
    """A failing ``decode_`` must not leave ``slicing_latent_min_size`` altered.

    The stub's ``decode_`` always raises; the error must propagate out of
    ``tiled_vae`` and the stub's min size must still be 2 afterwards.
    """
    from comfy.ldm.seedvr.vae import tiled_vae

    class RaisingVAEModel(torch.nn.Module):
        """Stub model whose decode path always fails."""

        def __init__(self):
            super().__init__()
            self.slicing_latent_min_size = 2
            self.temporal_downsample_factor = 4
            self.spatial_downsample_factor = 8
            self.device = torch.device("cpu")
            self._dummy = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32))

        def decode_(self, t_chunk):
            raise RuntimeError("simulated decode failure")

    model = RaisingVAEModel()
    latent = torch.zeros((1, 16, 4, 8, 8), dtype=torch.float32)
    options = {
        "tile_size": (64, 64),
        "tile_overlap": (0, 0),
        "temporal_size": 0,
        "temporal_overlap": 0,
        "encode": False,
    }

    raised = False
    try:
        tiled_vae(latent, model, **options)
    except RuntimeError as exc:
        # Only the simulated failure counts; anything else is re-raised.
        if "simulated decode failure" in str(exc):
            raised = True
        else:
            raise

    assert raised
    assert model.slicing_latent_min_size == 2
|
||||
@ -0,0 +1,89 @@
|
||||
import torch
|
||||
|
||||
from comfy.cli_args import args as cli_args
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
cli_args.cpu = True
|
||||
|
||||
|
||||
def test_slicing_encode_merges_runt_active_tail():
    """Encode 12 frames with a sample min size of 4 and inspect the slices.

    The stub must receive exactly two ``_encode`` calls, with temporal
    lengths 5 and 7 and memory states ``INITIALIZING`` then ``ACTIVE``, so no
    slice after the first is shorter than the temporal downsample factor.
    """
    from comfy.ldm.seedvr.vae import MemoryState, VideoAutoencoderKL, tiled_vae

    class StubVAEModel(torch.nn.Module):
        """Recording stand-in for the model argument of ``tiled_vae``."""

        def __init__(self):
            super().__init__()
            # Observations collected while ``tiled_vae`` runs.
            self.memory_states = []
            self.encode_t = []
            # Attributes the code under test is expected to read.
            self.use_slicing = True
            self.slicing_sample_min_size = 4
            self.temporal_downsample_factor = 4
            self.spatial_downsample_factor = 8
            self.device = torch.device("cpu")
            self._dummy = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32))

        def encode(self, t_chunk):
            # The same tensor is returned twice as a 2-tuple.
            latent = VideoAutoencoderKL.slicing_encode(self, t_chunk)
            return (latent, latent)

        def _encode(self, x, memory_state=MemoryState.DISABLED):
            self.memory_states.append(memory_state)
            self.encode_t.append(x.shape[2])
            batch, _channels, frames, height, width = x.shape
            t_factor = self.temporal_downsample_factor
            s_factor = self.spatial_downsample_factor
            # Ceiling division; the temporal size never drops below one.
            out_t = max(1, (frames + t_factor - 1) // t_factor)
            out_h = (height + s_factor - 1) // s_factor
            out_w = (width + s_factor - 1) // s_factor
            return torch.zeros((batch, 16, out_t, out_h, out_w), dtype=x.dtype)

    model = StubVAEModel()
    video = torch.zeros((1, 3, 12, 64, 64), dtype=torch.float32)
    options = {
        "tile_size": (64, 64),
        "tile_overlap": (0, 0),
        "temporal_size": None,
        "encode": True,
    }

    tiled_vae(video, model, **options)

    assert model.memory_states == [MemoryState.INITIALIZING, MemoryState.ACTIVE]
    assert model.encode_t == [5, 7]
    assert min(model.encode_t[1:]) >= model.temporal_downsample_factor
|
||||
|
||||
|
||||
def test_zero_temporal_size_preserves_min_size_when_encode_raises():
    """A failing ``encode`` must not leave ``slicing_sample_min_size`` altered.

    The stub's ``encode`` always raises; the error must propagate out of
    ``tiled_vae`` and the stub's min size must still be 4 afterwards.
    """
    from comfy.ldm.seedvr.vae import tiled_vae

    class RaisingVAEModel(torch.nn.Module):
        """Stub model whose encode path always fails."""

        def __init__(self):
            super().__init__()
            self.slicing_sample_min_size = 4
            self.temporal_downsample_factor = 4
            self.spatial_downsample_factor = 8
            self.device = torch.device("cpu")
            self._dummy = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32))

        def encode(self, t_chunk):
            raise RuntimeError("simulated encode failure")

    model = RaisingVAEModel()
    video = torch.zeros((1, 3, 12, 64, 64), dtype=torch.float32)
    options = {
        "tile_size": (64, 64),
        "tile_overlap": (0, 0),
        "temporal_size": 0,
        "temporal_overlap": 0,
        "encode": True,
    }

    raised = False
    try:
        tiled_vae(video, model, **options)
    except RuntimeError as exc:
        # Only the simulated failure counts; anything else is re-raised.
        if "simulated encode failure" in str(exc):
            raised = True
        else:
            raise

    assert raised
    assert model.slicing_sample_min_size == 4
|
||||
232
tests-unit/comfy_test/test_seedvr_vae_tiled_temporal_slicing.py
Normal file
232
tests-unit/comfy_test/test_seedvr_vae_tiled_temporal_slicing.py
Normal file
@ -0,0 +1,232 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from comfy.cli_args import args as cli_args
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
cli_args.cpu = True
|
||||
|
||||
import comfy.ldm.seedvr.vae as vae_mod # noqa: E402
|
||||
from comfy.ldm.seedvr.vae import MemoryState, tiled_vae # noqa: E402
|
||||
|
||||
|
||||
class _SlicingDecodeVAE(nn.Module):
    """Decode stub that records slicing state and memory-state arguments.

    ``decode_`` delegates to ``VideoAutoencoderKL.slicing_decode`` while
    ``_decode`` is a cheap deterministic replacement for the real decoder.
    """

    def __init__(self, slicing_latent_min_size):
        super().__init__()
        # Observations collected by the tests.
        self.decode_min_sizes = []
        self.memory_states = []
        # Attributes the code under test is expected to read.
        self.use_slicing = True
        self.slicing_latent_min_size = slicing_latent_min_size
        self.temporal_downsample_factor = 4
        self.spatial_downsample_factor = 8
        self.device = torch.device("cpu")
        self._dummy = nn.Parameter(torch.zeros(1, dtype=torch.float32))

    def decode_(self, z):
        # Record the min size in effect at call time before slicing.
        self.decode_min_sizes.append(self.slicing_latent_min_size)
        return vae_mod.VideoAutoencoderKL.slicing_decode(self, z)

    def _decode(self, z, memory_state=MemoryState.DISABLED):
        self.memory_states.append(memory_state)
        factor = self.spatial_downsample_factor
        # Channel 0 is tiled into 3 channels and ``factor`` times along each
        # of the last two axes; the third axis is left at its input length.
        first_channel = z[:, :1]
        return first_channel.repeat(1, 3, 1, factor, factor)
|
||||
|
||||
|
||||
class _EncodeVAE(nn.Module):
    """Encode stub that records slicing state, slice lengths and memory states.

    ``encode`` delegates to ``VideoAutoencoderKL.slicing_encode`` while
    ``_encode`` returns an all-zero latent of the downsampled shape.
    """

    def __init__(self, slicing_sample_min_size):
        super().__init__()
        # Observations collected by the tests.
        self.memory_states = []
        self.encoded_t = []
        self.encode_min_sizes = []
        # Attributes the code under test is expected to read.
        self.use_slicing = True
        self.slicing_sample_min_size = slicing_sample_min_size
        self.temporal_downsample_factor = 4
        self.spatial_downsample_factor = 8
        self.device = torch.device("cpu")
        self._dummy = nn.Parameter(torch.zeros(1, dtype=torch.float32))

    def encode(self, t_chunk):
        # Record the min size in effect at call time before slicing.
        self.encode_min_sizes.append(self.slicing_sample_min_size)
        latent = vae_mod.VideoAutoencoderKL.slicing_encode(self, t_chunk)
        return (latent, latent)

    def _encode(self, x, memory_state=MemoryState.DISABLED):
        self.memory_states.append(memory_state)
        self.encoded_t.append(x.shape[2])
        batch, _channels, frames, height, width = x.shape
        t_factor = self.temporal_downsample_factor
        s_factor = self.spatial_downsample_factor
        # Ceiling division; the temporal size never drops below one.
        out_t = max(1, (frames + t_factor - 1) // t_factor)
        out_h = (height + s_factor - 1) // s_factor
        out_w = (width + s_factor - 1) // s_factor
        return torch.zeros((batch, 16, out_t, out_h, out_w), dtype=x.dtype)
|
||||
|
||||
|
||||
class _LocalSpatialDecodeVAE(nn.Module):
    """Decode stub whose output encodes each tile's local column index.

    Every value in a decoded tile equals its column position inside that
    tile, and the shape of each tile passed to ``decode_`` is recorded.
    """

    def __init__(self):
        super().__init__()
        # Shapes of the tiles passed to ``decode_``, in call order.
        self.tile_shapes = []
        self.slicing_latent_min_size = 99
        self.temporal_downsample_factor = 4
        self.spatial_downsample_factor = 8
        self.device = torch.device("cpu")
        self._dummy = nn.Parameter(torch.zeros(1, dtype=torch.float32))

    def decode_(self, z):
        self.tile_shapes.append(tuple(z.shape))
        batch, _, frames, height, tile_w = z.shape
        factor = self.spatial_downsample_factor
        out_w = tile_w * factor
        # Column ramp 0..out_w-1, broadcast over all other axes.
        columns = torch.arange(out_w, dtype=z.dtype).view(1, 1, 1, 1, out_w)
        tiled = columns.expand(batch, 1, frames, height * factor, out_w)
        # ``expand`` yields a view; clone so the caller gets its own storage.
        return tiled.clone()
|
||||
|
||||
|
||||
def test_decode_tiled_vae_maps_temporal_args_to_latent_slicing_min_size():
    """Check temporal-argument handling in ``tiled_vae`` and in the wrapper.

    Part one decodes a 5-frame latent with ``temporal_size=12`` and
    ``temporal_overlap=4``: the stub must see a slicing min size of 2, two
    ``_decode`` calls (``INITIALIZING`` then ``ACTIVE``), and keep its own
    min size at 2. Part two calls ``VideoAutoencoderKLWrapper.decode`` with
    a ``seedvr2_tiling`` dict and checks that ``temporal_overlap`` reaches
    the patched ``tiled_vae`` unchanged.
    """
    model = _SlicingDecodeVAE(slicing_latent_min_size=2)
    latent = torch.arange(1 * 16 * 5 * 8 * 8, dtype=torch.float32).reshape(1, 16, 5, 8, 8)
    options = {
        "tile_size": (64, 64),
        "tile_overlap": (0, 0),
        "temporal_size": 12,
        "temporal_overlap": 4,
        "encode": False,
    }

    tiled_vae(latent, model, **options)

    assert model.decode_min_sizes == [2]
    assert model.memory_states == [MemoryState.INITIALIZING, MemoryState.ACTIVE]
    assert model.slicing_latent_min_size == 2

    # Build the wrapper without running its own ``__init__``; only the
    # ``nn.Module`` base initialiser is invoked.
    wrapper_cls = vae_mod.VideoAutoencoderKLWrapper
    wrapper = wrapper_cls.__new__(wrapper_cls)
    nn.Module.__init__(wrapper)
    seedvr2_tiling = dict(
        enable_tiling=True,
        tile_size=(64, 64),
        tile_overlap=(0, 0),
        temporal_size=8,
        temporal_overlap=7,
    )

    captured = {}

    def _fake_tiled_vae(latent, model, **kwargs):
        # Remember the keyword arguments the wrapper forwarded.
        captured.update(kwargs)
        return torch.zeros(1, 3, 1, 16, 16)

    fake_tiler = patch.object(vae_mod, "tiled_vae", side_effect=_fake_tiled_vae)
    passthrough_color = patch.object(
        vae_mod, "lab_color_transfer", side_effect=lambda content, style: content
    )
    with fake_tiler, passthrough_color:
        wrapper.decode(torch.zeros(1, 16, 2, 2), seedvr2_tiling=seedvr2_tiling)

    assert captured["temporal_overlap"] == 7
|
||||
|
||||
|
||||
def test_encode_tiled_vae_zero_temporal_size_disables_wrapper_slicing():
    """Encode 12 frames with ``temporal_size=0`` and expect a single pass.

    The stub must see a sample min size of 12 during the call, one
    ``_encode`` over all 12 frames with ``MemoryState.DISABLED``, and keep
    its own min size at 4 afterwards.
    """
    model = _EncodeVAE(slicing_sample_min_size=4)
    video = torch.zeros((1, 3, 12, 64, 64), dtype=torch.float32)
    options = {
        "tile_size": (64, 64),
        "tile_overlap": (0, 0),
        "temporal_size": 0,
        "temporal_overlap": 0,
        "encode": True,
    }

    tiled_vae(video, model, **options)

    assert model.encode_min_sizes == [12]
    assert model.memory_states == [MemoryState.DISABLED]
    assert model.encoded_t == [12]
    assert model.slicing_sample_min_size == 4
|
||||
|
||||
|
||||
def test_encode_tiled_vae_maps_temporal_args_to_sample_slicing_min_size():
    """Encode 14 frames with ``temporal_size=8`` and ``temporal_overlap=2``.

    The stub must see a sample min size of 6 during the call, two
    ``_encode`` calls of 7 frames each (``INITIALIZING`` then ``ACTIVE``),
    and keep its own min size at 4 afterwards.
    """
    model = _EncodeVAE(slicing_sample_min_size=4)
    video = torch.zeros((1, 3, 14, 64, 64), dtype=torch.float32)
    options = {
        "tile_size": (64, 64),
        "tile_overlap": (0, 0),
        "temporal_size": 8,
        "temporal_overlap": 2,
        "encode": True,
    }

    tiled_vae(video, model, **options)

    assert model.encode_min_sizes == [6]
    assert model.memory_states == [MemoryState.INITIALIZING, MemoryState.ACTIVE]
    assert model.encoded_t == [7, 7]
    assert model.slicing_sample_min_size == 4
|
||||
|
||||
|
||||
def test_boundary_reference_latent_no_periodic_temporal_tile_discontinuity():
    """Compare tiled decoding against a reference, then check spatial blending.

    Part one: decoding a 7-frame latent through ``tiled_vae`` with
    ``temporal_size=0`` must equal the result of calling ``decode_``
    directly, using a single ``DISABLED`` pass at min size 7. Part two: a
    12-column latent decoded with a 32-pixel horizontal overlap must be split
    into two 8-column tiles, and output column 36 must equal the expected
    cosine-ramp blend of the two tiles' local values.
    """
    latent = torch.arange(1 * 16 * 7 * 8 * 8, dtype=torch.float32).reshape(1, 16, 7, 8, 8)

    reference_model = _SlicingDecodeVAE(slicing_latent_min_size=3)
    reference = reference_model.decode_(latent)

    tiled_model = _SlicingDecodeVAE(slicing_latent_min_size=3)
    temporal_options = {
        "tile_size": (64, 64),
        "tile_overlap": (0, 0),
        "temporal_size": 0,
        "temporal_overlap": 0,
        "encode": False,
    }
    decoded = tiled_vae(latent, tiled_model, **temporal_options)

    assert torch.equal(decoded, reference)
    assert tiled_model.decode_min_sizes == [7]
    assert tiled_model.memory_states == [MemoryState.DISABLED]
    assert tiled_model.slicing_latent_min_size == 3

    spatial_model = _LocalSpatialDecodeVAE()
    spatial_options = {
        "tile_size": (64, 64),
        "tile_overlap": (0, 32),
        "encode": False,
    }
    blended = tiled_vae(torch.zeros(1, 16, 1, 8, 12), spatial_model, **spatial_options)

    # Raised-cosine ramp over the 32 overlapping columns.
    ramp = 0.5 - 0.5 * torch.cos(torch.linspace(0, 1, steps=32) * torch.pi)
    weight = ramp[4]
    # Mix of the local column values 36 and 4 at ramp position 4.
    expected_pixel = (36.0 * (1.0 - weight)) + (4.0 * weight)

    assert spatial_model.tile_shapes == [(1, 16, 1, 8, 8)] * 2
    assert torch.isclose(blended[0, 0, 0, 0, 36], expected_pixel)
|
||||
|
||||
|
||||
def test_decode_tiled_vae_clamps_overlap_sized_tiles_to_preserve_coverage():
    """An overlap larger than the tile size must still cover the whole width.

    With a 64-pixel tile and a 128-pixel horizontal overlap, more than one
    tile must be decoded and output columns from 64 onwards must contain
    non-zero values.
    """
    model = _LocalSpatialDecodeVAE()
    options = {
        "tile_size": (64, 64),
        "tile_overlap": (0, 128),
        "encode": False,
    }

    decoded = tiled_vae(torch.zeros(1, 16, 1, 8, 12), model, **options)

    assert len(model.tile_shapes) > 1
    right_part = decoded[0, 0, 0, 0, 64:]
    assert torch.count_nonzero(right_part) > 0
|
||||
@ -0,0 +1,165 @@
|
||||
"""Unit test for the ``VAE.decode`` tiled-fallback dispatcher routing of
|
||||
SeedVR2 latents in their 4D collapsed form ``(B, 16*T, H, W)``.
|
||||
|
||||
Regression: the dispatcher branch at ``comfy/sd.py``'s
|
||||
``VAE.decode -> if do_tile: ... elif dims == 2`` previously routed
|
||||
``ndim == 4`` SeedVR2 latents to the generic ``decode_tiled_``, whose
|
||||
``tiled_scale`` mask broadcast does not understand the
|
||||
``(16, T)`` channel-time collapse and crashed with
|
||||
``"The size of tensor a (1024) must match the size of tensor b (256)
|
||||
at non-singleton dimension 4"``.
|
||||
|
||||
Post-fix: when the wrapped model is a
|
||||
``comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper`` and the input is 4D,
|
||||
the dispatcher must route to ``decode_tiled_seedvr2`` instead. This
|
||||
test verifies the dispatcher selection without invoking the actual VAE
|
||||
math (which would require real model weights and a GPU): the two
|
||||
candidate methods are patched, the regular decode is forced to OOM via
|
||||
a stub, and the test asserts that ``decode_tiled_seedvr2`` is called
|
||||
exactly once (and ``decode_tiled_`` zero times) for a 4D SeedVR2
|
||||
input.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
|
||||
from comfy.cli_args import args as cli_args
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
cli_args.cpu = True
|
||||
|
||||
import comfy.ldm.seedvr.vae as seedvr_vae_mod # noqa: E402
|
||||
import comfy.sd as sd_mod # noqa: E402
|
||||
|
||||
|
||||
def _make_minimal_seedvr2_vae():
    """Build a bare ``comfy.sd.VAE`` around a SeedVR2 wrapper instance.

    Both the ``VAE`` and its ``VideoAutoencoderKLWrapper`` are created with
    ``__new__`` so neither ``__init__`` runs. The remaining attributes are
    stubs intended to be just enough for ``VAE.decode`` to reach its tiled
    fallback.
    """
    wrapper_cls = seedvr_vae_mod.VideoAutoencoderKLWrapper
    vae = sd_mod.VAE.__new__(sd_mod.VAE)
    vae.first_stage_model = wrapper_cls.__new__(wrapper_cls)

    patcher = MagicMock()
    patcher.get_free_memory = MagicMock(return_value=8 * 1024 * 1024 * 1024)

    cpu = torch.device("cpu")
    stub_attrs = {
        "patcher": patcher,
        "device": cpu,
        "output_device": cpu,
        "vae_dtype": torch.float32,
        "disable_offload": True,
        "extra_1d_channel": None,
        "upscale_ratio": 8,
        "upscale_index_formula": None,
        "output_channels": 3,
        "latent_channels": 16,
        # Three latent dimensions (T, H, W) for the SeedVR2 path.
        "latent_dim": 3,
        "downscale_ratio": 8,
        "downscale_index_formula": None,
        # Callable stubs replacing the corresponding VAE methods.
        "vae_output_dtype": lambda: torch.float32,
        "spacial_compression_decode": lambda: 8,
        "process_input": lambda x: x,
        "process_output": lambda x: x,
        "throw_exception_if_invalid": lambda: None,
        "memory_used_decode": lambda *a, **k: 1,
    }
    for attr_name, attr_value in stub_attrs.items():
        setattr(vae, attr_name, attr_value)
    return vae
|
||||
|
||||
|
||||
def _force_regular_decode_oom(*args, **kwargs):
|
||||
"""Stub ``first_stage_model.decode`` to raise an OOM-shaped error
|
||||
so ``VAE.decode``'s ``except`` branch sets ``do_tile = True`` and
|
||||
falls into the tiled-fallback dispatcher.
|
||||
"""
|
||||
raise torch.cuda.OutOfMemoryError("forced OOM for dispatcher test")
|
||||
|
||||
|
||||
def test_4d_seedvr2_latent_routes_to_decode_tiled_seedvr2():
    """A 4D SeedVR2 latent must fall back to ``decode_tiled_seedvr2``.

    The wrapper's regular ``decode`` is forced to raise an OOM error and
    both tiled decoders are replaced with mocks, so only the dispatcher's
    choice is observed.
    """
    vae = _make_minimal_seedvr2_vae()
    samples_4d = torch.zeros(1, 16 * 3, 8, 8)  # (B, 16*T, H, W) with T=3

    seedvr2_call = MagicMock(return_value=torch.zeros(1, 3, 9, 64, 64))
    generic_call = MagicMock(return_value=torch.zeros(1, 3, 64, 64))

    mm = sd_mod.model_management
    wrapper_cls = seedvr_vae_mod.VideoAutoencoderKLWrapper
    skip_non_oom = patch.object(mm, "raise_non_oom", lambda e: None)
    skip_load = patch.object(mm, "load_models_gpu", lambda *a, **k: None)
    skip_cache = patch.object(mm, "soft_empty_cache", lambda: None)
    oom_decode = patch.object(
        wrapper_cls, "decode", side_effect=_force_regular_decode_oom
    )
    mock_seedvr2 = patch.object(sd_mod.VAE, "decode_tiled_seedvr2", seedvr2_call)
    mock_generic = patch.object(sd_mod.VAE, "decode_tiled_", generic_call)

    with skip_non_oom, skip_load, skip_cache, oom_decode, mock_seedvr2, mock_generic:
        vae.decode(samples_4d)

    assert seedvr2_call.call_count == 1, (
        f"Expected decode_tiled_seedvr2 to be called once for a 4D SeedVR2 "
        f"latent under tiled fallback; got {seedvr2_call.call_count} calls."
    )
    assert generic_call.call_count == 0, (
        f"decode_tiled_ must NOT be called for a 4D SeedVR2 latent; got "
        f"{generic_call.call_count} calls. Pre-fix dispatcher would route "
        f"to this method and crash inside tiled_scale's mask broadcast."
    )
|
||||
|
||||
|
||||
def test_4d_non_seedvr2_latent_still_routes_to_generic_decode_tiled():
    """A 4D latent on a non-SeedVR2 model must keep using ``decode_tiled_``.

    ``first_stage_model`` is a plain ``MagicMock`` whose ``decode`` raises an
    OOM error; both tiled decoders are mocked so only the dispatcher's
    choice is observed.
    """
    vae = sd_mod.VAE.__new__(sd_mod.VAE)

    model = MagicMock()  # deliberately not a VideoAutoencoderKLWrapper
    model.decode = MagicMock(side_effect=_force_regular_decode_oom)

    patcher = MagicMock()
    patcher.get_free_memory = MagicMock(return_value=8 * 1024 * 1024 * 1024)

    cpu = torch.device("cpu")
    stub_attrs = {
        "first_stage_model": model,
        "patcher": patcher,
        "device": cpu,
        "output_device": cpu,
        "vae_dtype": torch.float32,
        "disable_offload": True,
        "extra_1d_channel": None,
        "upscale_ratio": 8,
        "upscale_index_formula": None,
        "output_channels": 3,
        "latent_channels": 4,
        "latent_dim": 2,
        "downscale_ratio": 8,
        "downscale_index_formula": None,
        # Callable stubs replacing the corresponding VAE methods.
        "vae_output_dtype": lambda: torch.float32,
        "spacial_compression_decode": lambda: 8,
        "process_output": lambda x: x,
        "throw_exception_if_invalid": lambda: None,
        "memory_used_decode": lambda *a, **k: 1,
    }
    for attr_name, attr_value in stub_attrs.items():
        setattr(vae, attr_name, attr_value)

    samples_4d = torch.zeros(1, 4, 8, 8)
    generic_call = MagicMock(return_value=torch.zeros(1, 3, 64, 64))
    seedvr2_call = MagicMock(return_value=torch.zeros(1, 3, 9, 64, 64))

    mm = sd_mod.model_management
    skip_non_oom = patch.object(mm, "raise_non_oom", lambda e: None)
    skip_load = patch.object(mm, "load_models_gpu", lambda *a, **k: None)
    skip_cache = patch.object(mm, "soft_empty_cache", lambda: None)
    mock_seedvr2 = patch.object(sd_mod.VAE, "decode_tiled_seedvr2", seedvr2_call)
    mock_generic = patch.object(sd_mod.VAE, "decode_tiled_", generic_call)

    with skip_non_oom, skip_load, skip_cache, mock_seedvr2, mock_generic:
        vae.decode(samples_4d)

    assert generic_call.call_count == 1, (
        f"Expected decode_tiled_ to be called once for a non-SeedVR2 4D "
        f"latent; got {generic_call.call_count} calls."
    )
    assert seedvr2_call.call_count == 0, (
        f"decode_tiled_seedvr2 must NOT be called for non-SeedVR2 latents; "
        f"got {seedvr2_call.call_count} calls."
    )
|
||||
@ -0,0 +1,119 @@
|
||||
"""Unit tests for the explicit ``VAE.encode_tiled`` dispatcher routing of
|
||||
SeedVR2 vs non-SeedVR2 3D inputs.
|
||||
|
||||
Mirrors the decode-side dispatcher contract in
|
||||
``test_vae_decode_tiled_dispatcher_seedvr2_4d.py`` and the encode OOM
|
||||
fallback contract in ``test_vae_encode_tiled_fallback_dispatcher_seedvr2.py``:
|
||||
the two candidate methods (``encode_tiled_seedvr2``, ``encode_tiled_3d``)
|
||||
are patched on the ``VAE`` class, ``encode_tiled`` is invoked directly,
|
||||
and the test asserts the dispatcher selects the SeedVR2-aware tiler when
|
||||
``first_stage_model`` is a ``VideoAutoencoderKLWrapper`` while preserving
|
||||
the generic 3D tiler for non-SeedVR2 inputs.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
|
||||
from comfy.cli_args import args as cli_args
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
cli_args.cpu = True
|
||||
|
||||
import comfy.ldm.seedvr.vae as seedvr_vae_mod # noqa: E402
|
||||
import comfy.sd as sd_mod # noqa: E402
|
||||
|
||||
|
||||
def _populate_common_vae_attrs(vae):
    """Attach the stub attributes shared by both test VAEs to ``vae`` in place.

    Nothing is returned; the attributes are set directly on ``vae``.
    """
    patcher = MagicMock()
    patcher.get_free_memory = MagicMock(return_value=8 * 1024 * 1024 * 1024)

    cpu = torch.device("cpu")
    stub_attrs = {
        "patcher": patcher,
        "device": cpu,
        "output_device": cpu,
        "vae_dtype": torch.float32,
        "disable_offload": True,
        "extra_1d_channel": None,
        "upscale_ratio": [lambda x: x],
        "upscale_index_formula": None,
        "output_channels": 3,
        "latent_channels": 16,
        "latent_dim": 3,
        "downscale_ratio": [lambda x: x],
        "downscale_index_formula": None,
        "not_video": False,
        "crop_input": False,
        "pad_channel_value": None,
        # Callable stubs replacing the corresponding VAE methods.
        "vae_output_dtype": lambda: torch.float32,
        "spacial_compression_encode": lambda: 8,
        "vae_encode_crop_pixels": lambda x: x,
        "throw_exception_if_invalid": lambda: None,
        "memory_used_encode": lambda *a, **k: 1,
    }
    for attr_name, attr_value in stub_attrs.items():
        setattr(vae, attr_name, attr_value)
|
||||
|
||||
|
||||
def _make_seedvr2_vae():
    """Return a bare ``VAE`` whose model is a ``VideoAutoencoderKLWrapper``.

    Both objects are created with ``__new__`` so neither ``__init__`` runs.
    """
    wrapper_cls = seedvr_vae_mod.VideoAutoencoderKLWrapper
    vae = sd_mod.VAE.__new__(sd_mod.VAE)
    vae.first_stage_model = wrapper_cls.__new__(wrapper_cls)
    _populate_common_vae_attrs(vae)
    return vae
|
||||
|
||||
|
||||
def _make_non_seedvr2_vae():
    """Return a bare ``VAE`` whose model is a plain ``MagicMock``.

    The ``VAE`` is created with ``__new__`` so its ``__init__`` does not run.
    """
    vae = sd_mod.VAE.__new__(sd_mod.VAE)
    generic_model = MagicMock()
    vae.first_stage_model = generic_model
    _populate_common_vae_attrs(vae)
    return vae
|
||||
|
||||
|
||||
def test_explicit_encode_tiled_seedvr2_3d_routes_to_seedvr2_tiler():
    """``VAE.encode_tiled`` on a SeedVR2 VAE must use ``encode_tiled_seedvr2``.

    Both candidate tilers are replaced with mocks, so only the dispatcher's
    choice is observed.
    """
    vae = _make_seedvr2_vae()
    pixel_samples = torch.zeros((1, 64, 64, 3))

    seedvr2_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8))
    generic_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8))

    skip_load = patch.object(
        sd_mod.model_management, "load_models_gpu", lambda *a, **k: None
    )
    # ``create=True`` lets the patch succeed even when the attribute is absent.
    mock_seedvr2 = patch.object(
        sd_mod.VAE, "encode_tiled_seedvr2", seedvr2_call, create=True
    )
    mock_generic = patch.object(sd_mod.VAE, "encode_tiled_3d", generic_call)

    with skip_load, mock_seedvr2, mock_generic:
        vae.encode_tiled(pixel_samples)

    assert seedvr2_call.call_count == 1, (
        f"Expected encode_tiled_seedvr2 to be called once for a SeedVR2 3D "
        f"input via explicit encode_tiled; got {seedvr2_call.call_count} calls."
    )
    assert generic_call.call_count == 0, (
        f"encode_tiled_3d must NOT be called for a SeedVR2 input via explicit "
        f"encode_tiled; got {generic_call.call_count} calls."
    )
|
||||
|
||||
|
||||
def test_explicit_encode_tiled_dispatcher_breakdown():
    """Each kind of VAE must reach its own tiler exactly once.

    ``encode_tiled`` is called once on a SeedVR2 VAE and once on a
    non-SeedVR2 VAE while both tilers are mocked; each mock must record a
    single call.
    """
    seedvr2_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8))
    generic_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8))

    seedvr2_vae = _make_seedvr2_vae()
    non_seedvr2_vae = _make_non_seedvr2_vae()
    pixel_samples = torch.zeros((1, 64, 64, 3))

    skip_load = patch.object(
        sd_mod.model_management, "load_models_gpu", lambda *a, **k: None
    )
    # ``create=True`` lets the patch succeed even when the attribute is absent.
    mock_seedvr2 = patch.object(
        sd_mod.VAE, "encode_tiled_seedvr2", seedvr2_call, create=True
    )
    mock_generic = patch.object(sd_mod.VAE, "encode_tiled_3d", generic_call)

    with skip_load, mock_seedvr2, mock_generic:
        seedvr2_vae.encode_tiled(pixel_samples)
        non_seedvr2_vae.encode_tiled(pixel_samples)

    assert seedvr2_call.call_count == 1, (
        f"Expected encode_tiled_seedvr2 called once across SeedVR2 + "
        f"non-SeedVR2 explicit encode_tiled calls; got "
        f"{seedvr2_call.call_count}."
    )
    assert generic_call.call_count == 1, (
        f"Expected encode_tiled_3d called once across SeedVR2 + non-SeedVR2 "
        f"explicit encode_tiled calls; got {generic_call.call_count}."
    )
|
||||
@ -0,0 +1,184 @@
|
||||
"""Unit tests for the ``VAE.encode`` OOM-fallback dispatcher routing of
|
||||
SeedVR2 vs non-SeedVR2 3D inputs.
|
||||
|
||||
Mirrors the decode-side dispatcher contract in
|
||||
``test_vae_decode_tiled_dispatcher_seedvr2_4d.py``: the two candidate
|
||||
methods (``encode_tiled_seedvr2``, ``encode_tiled_3d``) are patched on
|
||||
the ``VAE`` class, the regular encode is forced to OOM via a stub, and
|
||||
the test asserts the dispatcher selects the SeedVR2-aware tiler when
|
||||
``first_stage_model`` is a ``VideoAutoencoderKLWrapper`` while
|
||||
preserving the generic 3D tiler for non-SeedVR2 inputs.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
|
||||
from comfy.cli_args import args as cli_args
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
cli_args.cpu = True
|
||||
|
||||
import comfy.ldm.seedvr.vae as seedvr_vae_mod # noqa: E402
|
||||
import comfy.sd as sd_mod # noqa: E402
|
||||
|
||||
|
||||
def _populate_common_vae_attrs(vae):
    """Attach the stub attributes shared by both test VAEs to ``vae`` in place.

    Nothing is returned; the attributes are set directly on ``vae``.
    """
    patcher = MagicMock()
    patcher.get_free_memory = MagicMock(return_value=8 * 1024 * 1024 * 1024)

    cpu = torch.device("cpu")
    stub_attrs = {
        "patcher": patcher,
        "device": cpu,
        "output_device": cpu,
        "vae_dtype": torch.float32,
        "disable_offload": True,
        "extra_1d_channel": None,
        "upscale_ratio": 8,
        "upscale_index_formula": None,
        "output_channels": 3,
        "latent_channels": 16,
        "latent_dim": 3,
        "downscale_ratio": 8,
        "downscale_index_formula": None,
        "not_video": False,
        "crop_input": False,
        "pad_channel_value": None,
        # Callable stubs replacing the corresponding VAE methods.
        "vae_output_dtype": lambda: torch.float32,
        "spacial_compression_encode": lambda: 8,
        "process_input": lambda x: x,
        "process_output": lambda x: x,
        "throw_exception_if_invalid": lambda: None,
        "memory_used_encode": lambda *a, **k: 1,
    }
    for attr_name, attr_value in stub_attrs.items():
        setattr(vae, attr_name, attr_value)
|
||||
|
||||
|
||||
def _make_seedvr2_vae():
    """Return a stub ``VAE`` whose model is a bare SeedVR2 wrapper.

    Both the ``VAE`` and the ``VideoAutoencoderKLWrapper`` are created with
    ``__new__`` so neither ``__init__`` runs; the common stub attributes
    are then filled in by ``_populate_common_vae_attrs``.
    """
    wrapper_cls = seedvr_vae_mod.VideoAutoencoderKLWrapper
    stub_vae = sd_mod.VAE.__new__(sd_mod.VAE)
    stub_vae.first_stage_model = wrapper_cls.__new__(wrapper_cls)
    _populate_common_vae_attrs(stub_vae)
    return stub_vae
|
||||
|
||||
|
||||
def _make_non_seedvr2_vae():
    """Return a stub ``VAE`` whose ``first_stage_model`` is a ``MagicMock``.

    The ``VAE`` is created with ``__new__`` (no ``__init__``) and then
    given the common stub attributes.
    """
    stub_vae = sd_mod.VAE.__new__(sd_mod.VAE)
    stub_vae.first_stage_model = MagicMock()
    _populate_common_vae_attrs(stub_vae)
    return stub_vae
|
||||
|
||||
|
||||
def _force_regular_encode_oom(*args, **kwargs):
|
||||
raise torch.cuda.OutOfMemoryError("forced OOM for dispatcher test")
|
||||
|
||||
|
||||
def test_seedvr2_3d_routes_to_encode_tiled_seedvr2_on_oom():
    """A SeedVR2 VAE whose regular encode OOMs must use the SeedVR2 tiler.

    The model-management hooks are replaced with no-ops, the wrapper's
    ``encode`` is forced to raise OOM, and both tiled encoders are replaced
    by mocks.  After ``vae.encode`` the SeedVR2 mock must have been called
    exactly once and the generic 3D mock not at all.
    """
    vae = _make_seedvr2_vae()
    pixel_samples = torch.zeros((1, 8, 64, 64, 3))

    seedvr2_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8))
    generic_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8))

    management = sd_mod.model_management
    # Patchers are built up front and entered together, in this order.
    silence_raise = patch.object(management, "raise_non_oom", lambda e: None)
    skip_load = patch.object(management, "load_models_gpu",
                             lambda *a, **k: None)
    skip_cache_clear = patch.object(management, "soft_empty_cache",
                                    lambda: None)
    oom_encode = patch.object(seedvr_vae_mod.VideoAutoencoderKLWrapper,
                              "encode",
                              side_effect=_force_regular_encode_oom)
    seedvr2_tiler = patch.object(sd_mod.VAE, "encode_tiled_seedvr2",
                                 seedvr2_call, create=True)
    generic_tiler = patch.object(sd_mod.VAE, "encode_tiled_3d", generic_call)

    with silence_raise, skip_load, skip_cache_clear, oom_encode, \
            seedvr2_tiler, generic_tiler:
        vae.encode(pixel_samples)

    assert seedvr2_call.call_count == 1, (
        "Expected encode_tiled_seedvr2 to be called once for a SeedVR2 3D "
        f"input under OOM fallback; got {seedvr2_call.call_count} calls."
    )
    assert generic_call.call_count == 0, (
        "encode_tiled_3d must NOT be called for a SeedVR2 input; got "
        f"{generic_call.call_count} calls."
    )
|
||||
|
||||
|
||||
def test_seedvr2_oom_fallback_uses_explicit_seedvr2_tile_defaults():
    """The OOM fallback must call the SeedVR2 tiler with fixed tile kwargs.

    The wrapper stub carries a ``tiled_args`` dict with different values
    (128 / 32 / 12 / 4); the mocked ``encode_tiled_seedvr2`` must
    nevertheless receive exactly ``tile_x=256``, ``tile_y=256`` and
    ``overlap=64`` as keyword arguments.
    """
    vae = _make_seedvr2_vae()
    stale_tiled_args = {
        "tile_size": (128, 128),
        "tile_overlap": (32, 32),
        "temporal_size": 12,
        "temporal_overlap": 4,
    }
    vae.first_stage_model.tiled_args = stale_tiled_args
    pixel_samples = torch.zeros((1, 8, 64, 64, 3))

    seedvr2_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8))

    management = sd_mod.model_management
    silence_raise = patch.object(management, "raise_non_oom", lambda e: None)
    skip_load = patch.object(management, "load_models_gpu",
                             lambda *a, **k: None)
    skip_cache_clear = patch.object(management, "soft_empty_cache",
                                    lambda: None)
    oom_encode = patch.object(seedvr_vae_mod.VideoAutoencoderKLWrapper,
                              "encode",
                              side_effect=_force_regular_encode_oom)
    seedvr2_tiler = patch.object(sd_mod.VAE, "encode_tiled_seedvr2",
                                 seedvr2_call, create=True)

    with silence_raise, skip_load, skip_cache_clear, oom_encode, \
            seedvr2_tiler:
        vae.encode(pixel_samples)

    expected_kwargs = {
        "tile_x": 256,
        "tile_y": 256,
        "overlap": 64,
    }
    assert seedvr2_call.call_count == 1
    assert seedvr2_call.call_args.kwargs == expected_kwargs
|
||||
|
||||
|
||||
def test_oom_fallback_dispatcher_breakdown():
    """Check per-VAE routing of the OOM fallback for SeedVR2 and non-SeedVR2.

    Both stub VAEs have their regular encode forced to raise OOM and both
    tiled encoders are replaced by mocks.  The call counts are snapshotted
    right after the SeedVR2 encode so the test can attribute each tiler
    call to the VAE that triggered it; checking only the totals after both
    encodes would also pass if the two routes were swapped.
    """
    seedvr2_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8))
    generic_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8))

    seedvr2_vae = _make_seedvr2_vae()
    non_seedvr2_vae = _make_non_seedvr2_vae()
    non_seedvr2_vae.first_stage_model.encode = MagicMock(
        side_effect=_force_regular_encode_oom
    )

    pixel_samples = torch.zeros((1, 8, 64, 64, 3))

    with patch.object(sd_mod.model_management, "raise_non_oom",
                      lambda e: None), \
         patch.object(sd_mod.model_management, "load_models_gpu",
                      lambda *a, **k: None), \
         patch.object(sd_mod.model_management, "soft_empty_cache",
                      lambda: None), \
         patch.object(seedvr_vae_mod.VideoAutoencoderKLWrapper, "encode",
                      side_effect=_force_regular_encode_oom), \
         patch.object(sd_mod.VAE, "encode_tiled_seedvr2", seedvr2_call,
                      create=True), \
         patch.object(sd_mod.VAE, "encode_tiled_3d", generic_call):
        seedvr2_vae.encode(pixel_samples)
        # Snapshot before the second encode so routing is attributable.
        counts_after_seedvr2 = (seedvr2_call.call_count,
                                generic_call.call_count)
        non_seedvr2_vae.encode(pixel_samples)

    assert counts_after_seedvr2 == (1, 0), (
        "Expected the SeedVR2 VAE alone to produce (encode_tiled_seedvr2, "
        f"encode_tiled_3d) call counts of (1, 0); got {counts_after_seedvr2}."
    )
    # With (1, 0) after the first encode, totals of (1, 1) mean the
    # non-SeedVR2 VAE was routed to the generic 3D tiler.
    assert seedvr2_call.call_count == 1, (
        "Expected encode_tiled_seedvr2 called once across SeedVR2 + "
        f"non-SeedVR2 OOM fallbacks; got {seedvr2_call.call_count}."
    )
    assert generic_call.call_count == 1, (
        "Expected encode_tiled_3d called once across SeedVR2 + non-SeedVR2 "
        f"OOM fallbacks; got {generic_call.call_count}."
    )
|
||||
|
||||
|
||||
def test_non_seedvr2_encode_tiled_3d_default_overlap_is_concrete():
    """``encode_tiled`` without an explicit overlap must pass ``(1, 64, 64)``.

    The stub VAE gets tuple-valued down/upscale ratios whose first element
    is a callable; ``encode_tiled_3d`` is mocked and its ``overlap`` keyword
    argument is inspected after ``vae.encode_tiled`` runs.
    """
    vae = _make_non_seedvr2_vae()
    vae.downscale_ratio = (lambda a: max(1, a // 4), 8, 8)
    vae.upscale_ratio = (lambda a: a * 4, 8, 8)
    pixel_samples = torch.zeros((1, 8, 64, 64, 3))
    generic_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8))

    skip_load = patch.object(sd_mod.model_management, "load_models_gpu",
                             lambda *a, **k: None)
    generic_tiler = patch.object(sd_mod.VAE, "encode_tiled_3d", generic_call)
    with skip_load, generic_tiler:
        vae.encode_tiled(pixel_samples)

    forwarded = generic_call.call_args.kwargs
    assert forwarded["overlap"] == (1, 64, 64)
|
||||
205
tests-unit/comfy_test/test_vae_encode_tiled_seedvr2_method.py
Normal file
205
tests-unit/comfy_test/test_vae_encode_tiled_seedvr2_method.py
Normal file
@ -0,0 +1,205 @@
|
||||
"""Unit tests for ``VAE.encode_tiled_seedvr2``: existence with the
|
||||
SeedVR2 tile-shape signature and delegation through
|
||||
``comfy.ldm.seedvr.vae.tiled_vae(..., encode=True)`` with one call per
|
||||
spatial tile.
|
||||
|
||||
Mirrors the decode-side method-existence + delegation contract for
|
||||
``VAE.decode_tiled_seedvr2``; CPU-only via mocks and a
|
||||
``VideoAutoencoderKLWrapper.__new__`` wrapper stub (no weights, no
|
||||
GPU).
|
||||
"""
|
||||
|
||||
import inspect
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
|
||||
from comfy.cli_args import args as cli_args
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
cli_args.cpu = True
|
||||
|
||||
import comfy.ldm.seedvr.vae as seedvr_vae_mod # noqa: E402
|
||||
import comfy.sd as sd_mod # noqa: E402
|
||||
import nodes as nodes_mod # noqa: E402
|
||||
|
||||
|
||||
def _make_minimal_seedvr2_vae():
    """Return a stub ``VAE`` wrapping a bare ``VideoAutoencoderKLWrapper``.

    Both objects are created with ``__new__`` so no ``__init__`` runs.  The
    VAE gets only a small set of attributes: CPU devices, float32 dtypes,
    16 latent channels, ``latent_dim`` 3, ``downscale_ratio`` 8 and an
    identity ``process_input``.
    """
    wrapper_cls = seedvr_vae_mod.VideoAutoencoderKLWrapper
    stub_vae = sd_mod.VAE.__new__(sd_mod.VAE)
    stub_vae.first_stage_model = wrapper_cls.__new__(wrapper_cls)

    for attr_name, attr_value in (
        ("device", torch.device("cpu")),
        ("output_device", torch.device("cpu")),
        ("vae_dtype", torch.float32),
        ("latent_channels", 16),
        ("latent_dim", 3),
        ("downscale_ratio", 8),
    ):
        setattr(stub_vae, attr_name, attr_value)

    # Instance-level callables, invoked without ``self``.
    stub_vae.vae_output_dtype = lambda: torch.float32
    stub_vae.process_input = lambda x: x
    return stub_vae
|
||||
|
||||
|
||||
def test_method_exists_with_seedvr2_signature():
    """``VAE.encode_tiled_seedvr2`` must exist and accept the tile parameters.

    The signature is inspected for ``self``, ``pixel_samples``, ``tile_x``,
    ``tile_y``, ``overlap``, ``tile_t`` and ``overlap_t``.
    """
    method_is_defined = hasattr(sd_mod.VAE, "encode_tiled_seedvr2")
    assert method_is_defined, (
        "VAE.encode_tiled_seedvr2 must be defined on the VAE class."
    )

    signature = inspect.signature(sd_mod.VAE.encode_tiled_seedvr2)
    params = list(signature.parameters)
    expected_names = ("self", "pixel_samples", "tile_x", "tile_y",
                      "overlap", "tile_t", "overlap_t")
    for required in expected_names:
        assert required in params, (
            "VAE.encode_tiled_seedvr2 missing required parameter "
            f"{required!r}; got parameters {params}."
        )
|
||||
|
||||
|
||||
def test_vae_encode_tiled_allows_zero_temporal_controls_and_passes_zero_through():
    """``VAEEncodeTiled`` must accept and forward zero temporal settings.

    First the node's ``INPUT_TYPES`` are checked: both temporal inputs have
    ``min == 0`` and the ``temporal_size`` tooltip mentions that SeedVR2
    allows 0.  Then the node is run against a recording fake VAE with
    ``temporal_size=0`` and ``temporal_overlap=8``; the fake must receive
    ``tile_t=0`` and ``overlap_t=0``, and the node must return the fake's
    tensor under ``"samples"``.
    """
    required_inputs = nodes_mod.VAEEncodeTiled.INPUT_TYPES()["required"]
    temporal_size_options = required_inputs["temporal_size"][1]
    temporal_overlap_options = required_inputs["temporal_overlap"][1]
    assert temporal_size_options["min"] == 0
    assert temporal_overlap_options["min"] == 0
    assert "SeedVR2 allows 0" in temporal_size_options["tooltip"]

    class _EncodeRecorder:
        """Fake VAE that records each ``encode_tiled`` call."""

        def __init__(self):
            self.calls = []

        def encode_tiled(self, pixels, **kwargs):
            entry = {"shape": tuple(pixels.shape)}
            entry.update(kwargs)
            self.calls.append(entry)
            return torch.zeros(1, 16, 1, 8, 8)

    recorder = _EncodeRecorder()
    node = nodes_mod.VAEEncodeTiled()

    output = node.encode(
        recorder,
        torch.zeros(1, 64, 64, 3),
        tile_size=256,
        overlap=64,
        temporal_size=0,
        temporal_overlap=8,
    )

    expected_call = {
        "shape": (1, 64, 64, 3),
        "tile_x": 256,
        "tile_y": 256,
        "overlap": 64,
        "tile_t": 0,
        "overlap_t": 0,
    }
    assert recorder.calls == [expected_call]
    assert torch.equal(output[0]["samples"], torch.zeros(1, 16, 1, 8, 8))
|
||||
|
||||
|
||||
def test_method_routes_through_tiled_vae_encode_true():
    """``encode_tiled_seedvr2`` must delegate to ``tiled_vae(encode=True)``.

    ``seedvr_vae_mod.tiled_vae`` is replaced by a mock; it must be called
    at least once and every call must carry ``encode=True`` as a keyword.
    """
    vae = _make_minimal_seedvr2_vae()
    pixel_samples = torch.zeros((1, 3, 8, 64, 64))
    tiled_vae_mock = MagicMock(return_value=torch.zeros((1, 16, 2, 8, 8)))

    with patch.object(seedvr_vae_mod, "tiled_vae", tiled_vae_mock):
        vae.encode_tiled_seedvr2(pixel_samples)

    delegations = tiled_vae_mock.call_count
    assert delegations >= 1, (
        "Expected encode_tiled_seedvr2 to delegate to tiled_vae at "
        f"least once; got {delegations} calls."
    )
    for recorded in tiled_vae_mock.call_args_list:
        encode_flag = recorded.kwargs.get("encode")
        assert encode_flag is True, (
            "Every tiled_vae delegation from encode_tiled_seedvr2 must "
            f"pass encode=True; got kwargs={recorded.kwargs!r}."
        )
|
||||
|
||||
|
||||
def test_method_sets_wrapper_device_before_tiled_vae():
    """The wrapper must have ``device`` set by the time ``tiled_vae`` runs.

    The wrapper stub starts without a ``device`` attribute.  ``tiled_vae``
    is replaced by a mock whose side effect asserts that its second
    positional argument has ``device == vae.device``.  The mock is also
    required to have been called, because the device check lives inside
    the side effect and would otherwise be skipped silently.
    """
    vae = _make_minimal_seedvr2_vae()
    pixel_samples = torch.zeros((1, 3, 8, 64, 64))
    assert not hasattr(vae.first_stage_model, "device")

    def _assert_device_initialized(*args, **kwargs):
        # The VAE model is read from the second positional argument.
        vae_model = args[1]
        assert vae_model.device == vae.device
        return torch.zeros((1, 16, 2, 8, 8))

    tiled_vae_mock = MagicMock(side_effect=_assert_device_initialized)

    with patch.object(seedvr_vae_mod, "tiled_vae", tiled_vae_mock):
        vae.encode_tiled_seedvr2(pixel_samples)

    # Guard against a vacuous pass when tiled_vae is never invoked.
    assert tiled_vae_mock.call_count >= 1, (
        "Expected encode_tiled_seedvr2 to call tiled_vae at least once so "
        f"the device check runs; got {tiled_vae_mock.call_count} calls."
    )
|
||||
|
||||
|
||||
def test_method_honors_explicit_tile_parameters_over_stale_wrapper_args():
    """Explicit tile arguments must win over the wrapper's ``tiled_args``.

    The wrapper stub carries unrelated ``tiled_args`` values plus an extra
    ``"preserved"`` key.  Calling ``encode_tiled_seedvr2`` with explicit
    tile parameters must forward ``tile_size=(tile_y, tile_x)``,
    ``tile_overlap=(overlap, overlap)`` and the temporal values to the
    mocked ``tiled_vae``, and must leave the ``"preserved"`` entry intact.
    """
    vae = _make_minimal_seedvr2_vae()
    pixel_samples = torch.zeros((1, 3, 8, 64, 64))
    vae.first_stage_model.tiled_args = {
        "tile_size": (17, 19),
        "tile_overlap": (3, 5),
        "temporal_size": 7,
        "temporal_overlap": 2,
        "preserved": "value",
    }
    tiled_vae_mock = MagicMock(return_value=torch.zeros((1, 16, 2, 8, 8)))

    explicit_tiling = {
        "tile_x": 96,
        "tile_y": 80,
        "overlap": 12,
        "tile_t": 11,
        "overlap_t": 4,
    }
    with patch.object(seedvr_vae_mod, "tiled_vae", tiled_vae_mock):
        vae.encode_tiled_seedvr2(pixel_samples, **explicit_tiling)

    forwarded = tiled_vae_mock.call_args.kwargs
    assert forwarded["tile_size"] == (80, 96)
    assert forwarded["tile_overlap"] == (12, 12)
    assert forwarded["temporal_size"] == 11
    assert forwarded["temporal_overlap"] == 4
    assert vae.first_stage_model.tiled_args["preserved"] == "value"
|
||||
|
||||
|
||||
def test_method_uses_explicit_defaults_when_call_omits_tile_parameters():
    """Omitted tile arguments must fall back to fixed defaults.

    With no tile parameters passed, the mocked ``tiled_vae`` must receive
    ``tile_size=(512, 512)``, ``tile_overlap=(64, 64)``,
    ``temporal_size=9999`` and ``temporal_overlap=0`` rather than the
    values in the wrapper's ``tiled_args``, and that dict must still equal
    its initial contents afterwards.
    """
    vae = _make_minimal_seedvr2_vae()
    pixel_samples = torch.zeros((1, 3, 8, 64, 64))
    vae.first_stage_model.tiled_args = {
        "tile_size": (128, 160),
        "tile_overlap": (16, 24),
        "temporal_size": 9,
        "temporal_overlap": 1,
    }
    tiled_vae_mock = MagicMock(return_value=torch.zeros((1, 16, 2, 8, 8)))

    with patch.object(seedvr_vae_mod, "tiled_vae", tiled_vae_mock):
        vae.encode_tiled_seedvr2(pixel_samples)

    forwarded = tiled_vae_mock.call_args.kwargs
    assert forwarded["tile_size"] == (512, 512)
    assert forwarded["tile_overlap"] == (64, 64)
    assert forwarded["temporal_size"] == 9999
    assert forwarded["temporal_overlap"] == 0

    untouched_tiled_args = {
        "tile_size": (128, 160),
        "tile_overlap": (16, 24),
        "temporal_size": 9,
        "temporal_overlap": 1,
    }
    assert vae.first_stage_model.tiled_args == untouched_tiled_args
|
||||
|
||||
|
||||
def test_method_clamps_overlap_below_tile_size():
    """An overlap larger than the tile must be reduced before delegation.

    With ``tile_x=64``, ``tile_y=48`` and ``overlap=96`` the mocked
    ``tiled_vae`` must receive ``tile_overlap=(40, 56)``.
    """
    vae = _make_minimal_seedvr2_vae()
    pixel_samples = torch.zeros((1, 3, 8, 64, 64))
    tiled_vae_mock = MagicMock(return_value=torch.zeros((1, 16, 2, 8, 8)))

    oversized_request = {"tile_x": 64, "tile_y": 48, "overlap": 96}
    with patch.object(seedvr_vae_mod, "tiled_vae", tiled_vae_mock):
        vae.encode_tiled_seedvr2(pixel_samples, **oversized_request)

    forwarded = tiled_vae_mock.call_args.kwargs
    assert forwarded["tile_overlap"] == (40, 56)
|
||||
Loading…
Reference in New Issue
Block a user