# Mirror of https://github.com/comfyanonymous/ComfyUI.git
# (synced 2026-05-28 01:47:32 +08:00; 206 lines, 6.7 KiB, Python)
"""Unit tests for ``VAE.encode_tiled_seedvr2``: existence with the
|
|
SeedVR2 tile-shape signature and delegation through
|
|
``comfy.ldm.seedvr.vae.tiled_vae(..., encode=True)`` with one call per
|
|
spatial tile.
|
|
|
|
Mirrors the decode-side method-existence + delegation contract for
|
|
``VAE.decode_tiled_seedvr2``; CPU-only via mocks and a
|
|
``VideoAutoencoderKLWrapper.__new__`` wrapper stub (no weights, no
|
|
GPU).
|
|
"""
|
|
|
|
import inspect
from unittest.mock import MagicMock, patch

import torch

from comfy.cli_args import args as cli_args

# Force CPU mode when no CUDA device is available. This has to run before the
# comfy.sd / nodes imports below (hence the E402 suppressions) -- presumably
# those modules consult the parsed CLI args at import time; TODO confirm.
if not torch.cuda.is_available():
    cli_args.cpu = True

import comfy.ldm.seedvr.vae as seedvr_vae_mod  # noqa: E402
import comfy.sd as sd_mod  # noqa: E402
import nodes as nodes_mod  # noqa: E402
def _make_minimal_seedvr2_vae():
    """Return a ``VAE`` shell wired to a bare SeedVR2 wrapper.

    Both the ``VAE`` and its ``VideoAutoencoderKLWrapper`` are created with
    ``__new__``, so neither ``__init__`` runs and no weights are loaded.
    Only the attributes assigned below are populated: CPU devices, float32
    dtypes, 16 latent channels, a 3-D latent, an 8x downscale ratio and an
    identity ``process_input``.
    """
    shell = sd_mod.VAE.__new__(sd_mod.VAE)
    wrapper_cls = seedvr_vae_mod.VideoAutoencoderKLWrapper
    shell.first_stage_model = wrapper_cls.__new__(wrapper_cls)

    cpu = torch.device("cpu")
    shell.device = cpu
    shell.output_device = cpu
    shell.vae_dtype = torch.float32
    shell.latent_channels = 16
    shell.latent_dim = 3
    shell.downscale_ratio = 8

    def _output_dtype():
        return torch.float32

    def _passthrough(x):
        return x

    shell.vae_output_dtype = _output_dtype
    shell.process_input = _passthrough
    return shell
def test_method_exists_with_seedvr2_signature():
    """``VAE.encode_tiled_seedvr2`` exists and exposes every tile-shape parameter."""
    assert hasattr(sd_mod.VAE, "encode_tiled_seedvr2"), (
        "VAE.encode_tiled_seedvr2 must be defined on the VAE class."
    )
    params = [*inspect.signature(sd_mod.VAE.encode_tiled_seedvr2).parameters]
    expected_names = (
        "self", "pixel_samples", "tile_x", "tile_y", "overlap", "tile_t", "overlap_t",
    )
    # Report the first absent name, exactly as a per-name check would.
    missing = [name for name in expected_names if name not in params]
    assert not missing, (
        f"VAE.encode_tiled_seedvr2 missing required parameter "
        f"{missing[0]!r}; got parameters {params}."
    )
def test_vae_encode_tiled_allows_zero_temporal_controls_and_passes_zero_through():
    """``VAEEncodeTiled`` accepts ``temporal_size=0`` and forwards zero temporal tiling.

    The node is driven with a recording stand-in VAE. ``temporal_overlap=8``
    is supplied, yet the forwarded ``overlap_t`` is expected to be 0 because
    ``temporal_size`` is 0.
    """
    required = nodes_mod.VAEEncodeTiled.INPUT_TYPES()["required"]
    size_opts = required["temporal_size"][1]
    assert size_opts["min"] == 0
    assert required["temporal_overlap"][1]["min"] == 0
    assert "SeedVR2 allows 0" in size_opts["tooltip"]

    class _RecordingVAE:
        """Minimal VAE stand-in that logs every ``encode_tiled`` call."""

        def __init__(self):
            self.calls = []

        def encode_tiled(self, pixels, **kwargs):
            entry = {"shape": tuple(pixels.shape)}
            entry.update(kwargs)
            self.calls.append(entry)
            return torch.zeros(1, 16, 1, 8, 8)

    fake_vae = _RecordingVAE()
    result = nodes_mod.VAEEncodeTiled().encode(
        fake_vae,
        torch.zeros(1, 64, 64, 3),
        tile_size=256,
        overlap=64,
        temporal_size=0,
        temporal_overlap=8,
    )

    expected_call = {
        "shape": (1, 64, 64, 3),
        "tile_x": 256,
        "tile_y": 256,
        "overlap": 64,
        "tile_t": 0,
        "overlap_t": 0,
    }
    assert fake_vae.calls == [expected_call]
    assert torch.equal(result[0]["samples"], torch.zeros(1, 16, 1, 8, 8))
def test_method_routes_through_tiled_vae_encode_true():
    """``encode_tiled_seedvr2`` delegates to ``tiled_vae``, always with ``encode=True``."""
    vae = _make_minimal_seedvr2_vae()
    frames = torch.zeros((1, 3, 8, 64, 64))
    fake_tiled_vae = MagicMock(return_value=torch.zeros((1, 16, 2, 8, 8)))

    with patch.object(seedvr_vae_mod, "tiled_vae", fake_tiled_vae):
        vae.encode_tiled_seedvr2(frames)

    assert fake_tiled_vae.call_count >= 1, (
        f"Expected encode_tiled_seedvr2 to delegate to tiled_vae at "
        f"least once; got {fake_tiled_vae.call_count} calls."
    )
    # Collect kwargs of any delegation lacking encode=True; report the first.
    offending = [
        recorded.kwargs
        for recorded in fake_tiled_vae.call_args_list
        if recorded.kwargs.get("encode") is not True
    ]
    assert not offending, (
        f"Every tiled_vae delegation from encode_tiled_seedvr2 must "
        f"pass encode=True; got kwargs={offending[0]!r}."
    )
def test_method_sets_wrapper_device_before_tiled_vae():
    """The wrapper's ``device`` is set to ``vae.device`` before ``tiled_vae`` runs.

    The wrapper stub starts without a ``device`` attribute. The patched
    ``tiled_vae`` checks it on entry; this relies on the wrapper being the
    second positional argument. The test then asserts that the patched
    function was actually invoked, so the check cannot pass vacuously.
    """
    vae = _make_minimal_seedvr2_vae()
    pixel_samples = torch.zeros((1, 3, 8, 64, 64))
    assert not hasattr(vae.first_stage_model, "device")

    def _assert_device_initialized(*args, **kwargs):
        vae_model = args[1]
        assert vae_model.device == vae.device
        return torch.zeros((1, 16, 2, 8, 8))

    tiled_vae_mock = MagicMock(side_effect=_assert_device_initialized)
    with patch.object(seedvr_vae_mod, "tiled_vae", tiled_vae_mock):
        vae.encode_tiled_seedvr2(pixel_samples)

    # Without this, a method that never delegated to tiled_vae would skip the
    # side-effect check entirely and the test would still pass.
    assert tiled_vae_mock.call_count >= 1, (
        f"Expected encode_tiled_seedvr2 to call tiled_vae so the device "
        f"check could run; got {tiled_vae_mock.call_count} calls."
    )
def test_method_honors_explicit_tile_parameters_over_stale_wrapper_args():
    """Explicit tile arguments win over stale ``tiled_args`` on the wrapper.

    ``tile_size`` is forwarded as ``(tile_y, tile_x)``, and the single spatial
    overlap is duplicated for both axes. A ``tiled_args`` key unrelated to
    tiling must survive the call.
    """
    vae = _make_minimal_seedvr2_vae()
    frames = torch.zeros((1, 3, 8, 64, 64))
    vae.first_stage_model.tiled_args = dict(
        tile_size=(17, 19),
        tile_overlap=(3, 5),
        temporal_size=7,
        temporal_overlap=2,
        preserved="value",
    )
    fake_tiled_vae = MagicMock(return_value=torch.zeros((1, 16, 2, 8, 8)))

    with patch.object(seedvr_vae_mod, "tiled_vae", fake_tiled_vae):
        vae.encode_tiled_seedvr2(
            frames, tile_x=96, tile_y=80, overlap=12, tile_t=11, overlap_t=4,
        )

    forwarded = fake_tiled_vae.call_args.kwargs
    checked_keys = ("tile_size", "tile_overlap", "temporal_size", "temporal_overlap")
    assert {key: forwarded[key] for key in checked_keys} == {
        "tile_size": (80, 96),
        "tile_overlap": (12, 12),
        "temporal_size": 11,
        "temporal_overlap": 4,
    }
    assert vae.first_stage_model.tiled_args["preserved"] == "value"
def test_method_uses_explicit_defaults_when_call_omits_tile_parameters():
    """Omitted tile arguments fall back to the method's defaults, not ``tiled_args``.

    The expected forwarded values are 512x512 tiles with 64x64 overlap and
    temporal size 9999 with overlap 0. The wrapper's ``tiled_args`` must be
    left exactly as it was.
    """
    vae = _make_minimal_seedvr2_vae()
    frames = torch.zeros((1, 3, 8, 64, 64))
    snapshot = {
        "tile_size": (128, 160),
        "tile_overlap": (16, 24),
        "temporal_size": 9,
        "temporal_overlap": 1,
    }
    # Hand the wrapper its own copy so the final comparison still catches
    # in-place mutation of tiled_args.
    vae.first_stage_model.tiled_args = dict(snapshot)
    fake_tiled_vae = MagicMock(return_value=torch.zeros((1, 16, 2, 8, 8)))

    with patch.object(seedvr_vae_mod, "tiled_vae", fake_tiled_vae):
        vae.encode_tiled_seedvr2(frames)

    forwarded = fake_tiled_vae.call_args.kwargs
    checked_keys = ("tile_size", "tile_overlap", "temporal_size", "temporal_overlap")
    assert {key: forwarded[key] for key in checked_keys} == {
        "tile_size": (512, 512),
        "tile_overlap": (64, 64),
        "temporal_size": 9999,
        "temporal_overlap": 0,
    }
    assert vae.first_stage_model.tiled_args == snapshot
def test_method_clamps_overlap_below_tile_size():
    """An overlap larger than the tile is clamped below each tile dimension."""
    vae = _make_minimal_seedvr2_vae()
    frames = torch.zeros((1, 3, 8, 64, 64))
    fake_tiled_vae = MagicMock(return_value=torch.zeros((1, 16, 2, 8, 8)))

    with patch.object(seedvr_vae_mod, "tiled_vae", fake_tiled_vae):
        vae.encode_tiled_seedvr2(frames, tile_x=64, tile_y=48, overlap=96)

    # Tiles are (tile_y, tile_x) = (48, 64); the requested overlap of 96 is
    # expected to land 8 below each axis. The 8 presumably matches the
    # minimal VAE's downscale_ratio -- TODO confirm.
    clamped = fake_tiled_vae.call_args.kwargs["tile_overlap"]
    assert clamped == (40, 56)