# ComfyUI/tests-unit/comfy_test/test_vae_encode_tiled_seedvr2_method.py
# 2026-05-26 00:28:36 -05:00
#
# 206 lines
# 6.7 KiB
# Python

"""Unit tests for ``VAE.encode_tiled_seedvr2``: existence with the
SeedVR2 tile-shape signature and delegation through
``comfy.ldm.seedvr.vae.tiled_vae(..., encode=True)`` with one call per
spatial tile.
Mirrors the decode-side method-existence + delegation contract for
``VAE.decode_tiled_seedvr2``; CPU-only via mocks and a
``VideoAutoencoderKLWrapper.__new__`` wrapper stub (no weights, no
GPU).
"""
import inspect
from unittest.mock import MagicMock, patch
import torch
from comfy.cli_args import args as cli_args
# Turn on the ``cpu`` CLI flag on hosts without CUDA. This runs before the
# comfy/nodes imports below (hence their E402 suppressions) -- presumably so
# those modules see the flag at import time; confirm against the imported
# modules.
if not torch.cuda.is_available():
    cli_args.cpu = True
import comfy.ldm.seedvr.vae as seedvr_vae_mod # noqa: E402
import comfy.sd as sd_mod # noqa: E402
import nodes as nodes_mod # noqa: E402
def _make_minimal_seedvr2_vae():
    """Return a CPU-only ``VAE`` stub wrapping a bare SeedVR2 wrapper.

    Both the ``VAE`` and the ``VideoAutoencoderKLWrapper`` are allocated with
    ``__new__`` so neither ``__init__`` runs; only the attributes assigned
    here exist on the returned object.
    """
    wrapper_cls = seedvr_vae_mod.VideoAutoencoderKLWrapper

    def _float32_output_dtype():
        return torch.float32

    def _passthrough(x):
        return x

    stub = sd_mod.VAE.__new__(sd_mod.VAE)
    stub.first_stage_model = wrapper_cls.__new__(wrapper_cls)
    stub.device = torch.device("cpu")
    stub.output_device = torch.device("cpu")
    stub.vae_dtype = torch.float32
    stub.latent_channels = 16
    stub.latent_dim = 3
    stub.downscale_ratio = 8
    # Instance-level callables standing in for the VAE's own helpers.
    stub.vae_output_dtype = _float32_output_dtype
    stub.process_input = _passthrough
    return stub
def test_method_exists_with_seedvr2_signature():
    """``VAE.encode_tiled_seedvr2`` exists and names every tile parameter."""
    method_name = "encode_tiled_seedvr2"
    assert hasattr(sd_mod.VAE, method_name), (
        "VAE.encode_tiled_seedvr2 must be defined on the VAE class."
    )

    expected_names = (
        "self",
        "pixel_samples",
        "tile_x",
        "tile_y",
        "overlap",
        "tile_t",
        "overlap_t",
    )
    # Only parameter names are checked; order and defaults are not.
    params = list(inspect.signature(sd_mod.VAE.encode_tiled_seedvr2).parameters)
    for required in expected_names:
        assert required in params, (
            f"VAE.encode_tiled_seedvr2 missing required parameter "
            f"{required!r}; got parameters {params}."
        )
def test_vae_encode_tiled_allows_zero_temporal_controls_and_passes_zero_through():
    """``VAEEncodeTiled`` accepts zero temporal inputs and forwards zeros.

    Checks the node's declared minimums and tooltip, then runs ``encode``
    against a recording stand-in with ``temporal_size=0`` and
    ``temporal_overlap=8`` and expects ``tile_t=0`` and ``overlap_t=0`` to
    reach ``encode_tiled``.
    """
    required_inputs = nodes_mod.VAEEncodeTiled.INPUT_TYPES()["required"]
    temporal_size_opts = required_inputs["temporal_size"][1]
    assert temporal_size_opts["min"] == 0
    assert required_inputs["temporal_overlap"][1]["min"] == 0
    assert "SeedVR2 allows 0" in temporal_size_opts["tooltip"]

    class _EncodeRecorder:
        """Stand-in VAE that logs each ``encode_tiled`` call."""

        def __init__(self):
            self.calls = []

        def encode_tiled(self, pixels, **kwargs):
            # Store the input shape alongside whatever keywords were passed.
            record = {"shape": tuple(pixels.shape)}
            record.update(kwargs)
            self.calls.append(record)
            return torch.zeros(1, 16, 1, 8, 8)

    recorder = _EncodeRecorder()
    result = nodes_mod.VAEEncodeTiled().encode(
        recorder,
        torch.zeros(1, 64, 64, 3),
        tile_size=256,
        overlap=64,
        temporal_size=0,
        temporal_overlap=8,
    )

    expected_call = {
        "shape": (1, 64, 64, 3),
        "tile_x": 256,
        "tile_y": 256,
        "overlap": 64,
        "tile_t": 0,
        "overlap_t": 0,
    }
    assert recorder.calls == [expected_call]
    assert torch.equal(result[0]["samples"], torch.zeros(1, 16, 1, 8, 8))
def test_method_routes_through_tiled_vae_encode_true():
    """``encode_tiled_seedvr2`` calls ``tiled_vae`` and always sets ``encode=True``."""
    vae = _make_minimal_seedvr2_vae()
    fake_latent = torch.zeros((1, 16, 2, 8, 8))
    tiled_vae_mock = MagicMock(return_value=fake_latent)

    # Swap the module-level ``tiled_vae`` for the mock during the call only.
    with patch.object(seedvr_vae_mod, "tiled_vae", new=tiled_vae_mock):
        vae.encode_tiled_seedvr2(torch.zeros((1, 3, 8, 64, 64)))

    assert tiled_vae_mock.call_count >= 1, (
        f"Expected encode_tiled_seedvr2 to delegate to tiled_vae at "
        f"least once; got {tiled_vae_mock.call_count} calls."
    )
    for call in tiled_vae_mock.call_args_list:
        assert call.kwargs.get("encode") is True, (
            f"Every tiled_vae delegation from encode_tiled_seedvr2 must "
            f"pass encode=True; got kwargs={call.kwargs!r}."
        )
def test_method_sets_wrapper_device_before_tiled_vae():
    """The wrapper has ``device`` set by the time ``tiled_vae`` is invoked.

    The device check runs inside the mock's ``side_effect``, so it only
    executes if ``tiled_vae`` is actually called. The final assertion
    guarantees the delegation happened; without it the test would pass
    vacuously when ``encode_tiled_seedvr2`` never reaches ``tiled_vae``.
    """
    vae = _make_minimal_seedvr2_vae()
    pixel_samples = torch.zeros((1, 3, 8, 64, 64))
    # Precondition: the ``__new__``-built wrapper starts without ``device``.
    assert not hasattr(vae.first_stage_model, "device")

    def _assert_device_initialized(*args, **kwargs):
        # The model is read from the second positional argument of the call.
        vae_model = args[1]
        assert vae_model.device == vae.device
        return torch.zeros((1, 16, 2, 8, 8))

    tiled_vae_mock = MagicMock(side_effect=_assert_device_initialized)
    with patch.object(seedvr_vae_mod, "tiled_vae", tiled_vae_mock):
        vae.encode_tiled_seedvr2(pixel_samples)

    assert tiled_vae_mock.called, (
        "encode_tiled_seedvr2 never called tiled_vae, so the wrapper "
        "device check did not run."
    )
def test_method_honors_explicit_tile_parameters_over_stale_wrapper_args():
    """Explicit call arguments win over values already in ``tiled_args``.

    The wrapper is seeded with stale tiling values plus an unrelated
    ``"preserved"`` key; the last ``tiled_vae`` call must carry the explicit
    arguments, and the unrelated key must still be present afterwards.
    """
    vae = _make_minimal_seedvr2_vae()
    vae.first_stage_model.tiled_args = {
        "tile_size": (17, 19),
        "tile_overlap": (3, 5),
        "temporal_size": 7,
        "temporal_overlap": 2,
        "preserved": "value",
    }
    frames = torch.zeros((1, 3, 8, 64, 64))
    tiled_vae_mock = MagicMock(return_value=torch.zeros((1, 16, 2, 8, 8)))

    with patch.object(seedvr_vae_mod, "tiled_vae", new=tiled_vae_mock):
        vae.encode_tiled_seedvr2(
            frames,
            tile_x=96,
            tile_y=80,
            overlap=12,
            tile_t=11,
            overlap_t=4,
        )

    forwarded = tiled_vae_mock.call_args.kwargs
    # ``tile_size`` is expected as (tile_y, tile_x).
    assert forwarded["tile_size"] == (80, 96)
    assert forwarded["tile_overlap"] == (12, 12)
    assert forwarded["temporal_size"] == 11
    assert forwarded["temporal_overlap"] == 4
    assert vae.first_stage_model.tiled_args["preserved"] == "value"
def test_method_uses_explicit_defaults_when_call_omits_tile_parameters():
    """With no tile arguments, the method's own defaults reach ``tiled_vae``.

    The wrapper's pre-existing ``tiled_args`` must neither leak into the
    forwarded keywords nor be modified by the call.
    """
    stale_args = {
        "tile_size": (128, 160),
        "tile_overlap": (16, 24),
        "temporal_size": 9,
        "temporal_overlap": 1,
    }
    vae = _make_minimal_seedvr2_vae()
    # Hand the wrapper its own copy so a mutation shows up in the comparison.
    vae.first_stage_model.tiled_args = dict(stale_args)
    tiled_vae_mock = MagicMock(return_value=torch.zeros((1, 16, 2, 8, 8)))

    with patch.object(seedvr_vae_mod, "tiled_vae", new=tiled_vae_mock):
        vae.encode_tiled_seedvr2(torch.zeros((1, 3, 8, 64, 64)))

    forwarded = tiled_vae_mock.call_args.kwargs
    assert forwarded["tile_size"] == (512, 512)
    assert forwarded["tile_overlap"] == (64, 64)
    assert forwarded["temporal_size"] == 9999
    assert forwarded["temporal_overlap"] == 0
    assert vae.first_stage_model.tiled_args == stale_args
def test_method_clamps_overlap_below_tile_size():
    """An overlap larger than the tile is reduced before reaching ``tiled_vae``.

    With ``tile_x=64``, ``tile_y=48`` and ``overlap=96`` the forwarded
    ``tile_overlap`` must be ``(40, 56)``, which is 8 less than each tile
    dimension. NOTE(review): the exact clamping rule lives in the
    implementation; confirm there.
    """
    vae = _make_minimal_seedvr2_vae()
    tiled_vae_mock = MagicMock(return_value=torch.zeros((1, 16, 2, 8, 8)))
    oversized_overlap = {"tile_x": 64, "tile_y": 48, "overlap": 96}

    with patch.object(seedvr_vae_mod, "tiled_vae", new=tiled_vae_mock):
        vae.encode_tiled_seedvr2(
            torch.zeros((1, 3, 8, 64, 64)),
            **oversized_overlap,
        )

    forwarded = tiled_vae_mock.call_args.kwargs
    assert forwarded["tile_overlap"] == (40, 56)