ComfyUI/tests-unit/comfy_test/test_seedvr2_model.py
2026-07-02 22:59:38 -04:00

322 lines
11 KiB
Python

"""SeedVR2 model, latent-format, and VAE graph regression tests."""
from __future__ import annotations
from unittest.mock import MagicMock
import pytest
import torch
from torch import nn
from comfy.cli_args import args
if not torch.cuda.is_available():
args.cpu = True
import comfy # noqa: E402
import comfy.latent_formats # noqa: E402
import comfy.ldm.seedvr.model as seedvr_model # noqa: E402
import comfy.ldm.seedvr.vae as seedvr_vae_mod # noqa: E402
import comfy.model_management # noqa: E402
import comfy.ops as comfy_ops # noqa: E402
import comfy.sample # noqa: E402
import comfy.sd as sd_mod # noqa: E402
import nodes as nodes_mod # noqa: E402
from comfy.ldm.seedvr.model import NaDiT # noqa: E402
_LATENT_CHANNELS = seedvr_vae_mod.SEEDVR2_LATENT_CHANNELS
def _make_standin(positive_conditioning):
class _StandIn(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer(
"positive_conditioning", positive_conditioning
)
_resolve_text_conditioning = NaDiT._resolve_text_conditioning
return _StandIn()
class _StubModule(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
def _capture_last_layer_flags(monkeypatch, vid_dim: int, txt_in_dim: int) -> list[bool]:
flags = []
class _Block(_StubModule):
def __init__(self, *args, **kwargs):
flags.append(kwargs["is_last_layer"])
super().__init__()
monkeypatch.setattr(seedvr_model, "NaPatchIn", _StubModule)
monkeypatch.setattr(seedvr_model, "NaPatchOut", _StubModule)
monkeypatch.setattr(seedvr_model, "TimeEmbedding", _StubModule)
monkeypatch.setattr(seedvr_model, "NaMMSRTransformerBlock", _Block)
seedvr_model.NaDiT(
norm_eps=1e-5,
num_layers=4,
mlp_type="normal",
vid_dim=vid_dim,
txt_in_dim=txt_in_dim,
heads=24,
mm_layers=3,
operations=comfy_ops.disable_weight_init,
)
return flags
class _Model:
def __init__(self, latent_format):
self._latent_format = latent_format
def get_model_object(self, name):
assert name == "latent_format"
return self._latent_format
class _Patcher:
def get_free_memory(self, device):
return 1024 * 1024 * 1024
class _EncodeWrapper(seedvr_vae_mod.VideoAutoencoderKLWrapper):
def __init__(self, encoded):
nn.Module.__init__(self)
self.encoded = encoded
self.spatial_downsample_factor = 8
self.temporal_downsample_factor = 4
self.seen = []
def encode(self, x):
self.seen.append(tuple(x.shape))
return self.encoded.to(device=x.device, dtype=x.dtype)
class _DecodeWrapper(seedvr_vae_mod.VideoAutoencoderKLWrapper):
def __init__(self):
nn.Module.__init__(self)
self.spatial_downsample_factor = 8
self.temporal_downsample_factor = 4
self.calls = []
def decode(self, z, seedvr2_tiling=None):
self.calls.append({"shape": tuple(z.shape), "seedvr2_tiling": seedvr2_tiling})
if z.ndim == 4:
b, tc, h, w = z.shape
t = tc // _LATENT_CHANNELS
else:
b, _, t, h, w = z.shape
return torch.zeros(b, 3, t, h * 8, w * 8, dtype=z.dtype, device=z.device)
def test_seedvr2_wrapper_public_encode_returns_tensor(monkeypatch):
raw_latent = torch.full((1, _LATENT_CHANNELS, 1, 4, 5), 2.0)
seen_shapes = []
def base_encode(self, x):
seen_shapes.append(tuple(x.shape))
return raw_latent.to(device=x.device, dtype=x.dtype)
monkeypatch.setattr(seedvr_vae_mod.VideoAutoencoderKL, "encode", base_encode)
vae = seedvr_vae_mod.VideoAutoencoderKLWrapper.__new__(seedvr_vae_mod.VideoAutoencoderKLWrapper)
nn.Module.__init__(vae)
vae._dummy = nn.Parameter(torch.zeros((), dtype=torch.float32))
latent = vae.encode(torch.zeros(1, 3, 32, 40))
assert type(latent) is torch.Tensor
assert tuple(latent.shape) == (1, _LATENT_CHANNELS, 4, 5)
assert seen_shapes == [(1, 3, 1, 32, 40)]
def test_seedvr2_wrapper_private_encode_helper_keeps_raw_latent(monkeypatch):
raw_latent = torch.full((1, _LATENT_CHANNELS, 1, 4, 5), 3.0)
def base_encode(self, x):
return raw_latent.to(device=x.device, dtype=x.dtype)
monkeypatch.setattr(seedvr_vae_mod.VideoAutoencoderKL, "encode", base_encode)
vae = seedvr_vae_mod.VideoAutoencoderKLWrapper.__new__(seedvr_vae_mod.VideoAutoencoderKLWrapper)
nn.Module.__init__(vae)
vae._dummy = nn.Parameter(torch.zeros((), dtype=torch.float32))
latent, raw = vae._encode_with_raw_latent(torch.zeros(1, 3, 32, 40))
assert tuple(latent.shape) == (1, _LATENT_CHANNELS, 4, 5)
assert tuple(raw.shape) == (1, _LATENT_CHANNELS, 1, 4, 5)
assert torch.equal(raw, raw_latent)
def _make_vae(wrapper):
vae = sd_mod.VAE.__new__(sd_mod.VAE)
vae.first_stage_model = wrapper
vae.device = torch.device("cpu")
vae.output_device = torch.device("cpu")
vae.vae_dtype = torch.float32
vae.latent_channels = _LATENT_CHANNELS
vae.latent_dim = 3
vae.downscale_ratio = (lambda a: max(0, (a + 3) // 4), 8, 8)
vae.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
vae.output_channels = 3
vae.disable_offload = True
vae.extra_1d_channel = None
vae.crop_input = False
vae.not_video = False
vae.handles_tiling = isinstance(wrapper, seedvr_vae_mod.VideoAutoencoderKLWrapper)
vae.format_encoded = wrapper.comfy_format_encoded
vae.patcher = _Patcher()
vae.process_input = lambda image: image
vae.process_output = lambda image: image.add(1.0).div(2.0).clamp(0.0, 1.0)
vae.vae_output_dtype = lambda: torch.float32
vae.memory_used_encode = lambda shape, dtype: 1
vae.memory_used_decode = lambda shape, dtype: 1
vae.throw_exception_if_invalid = lambda: None
vae.vae_encode_crop_pixels = lambda pixels: pixels
vae.spacial_compression_decode = lambda: 8
vae.temporal_compression_decode = lambda: 4
return vae
def test_missing_context_falls_back_to_positive_buffer():
pos_buffer = torch.full((58, 5120), 7.0)
standin = _make_standin(pos_buffer)
txt, txt_shape = standin._resolve_text_conditioning(None)
assert txt.shape == (58, 5120)
assert (txt == 7.0).all(), (
"fallback path must use the positive_conditioning buffer "
"verbatim, not a zero tensor"
)
assert txt_shape.shape == (1, 1)
assert txt_shape[0, 0].item() == 58
def test_seedvr2_7b_keeps_final_block_text_path(monkeypatch):
assert _capture_last_layer_flags(monkeypatch, vid_dim=3072, txt_in_dim=3072) == [
False,
False,
False,
False,
]
def test_seedvr2_7b_rope3d_matches_wrapper_oracle():
rope = seedvr_model.get_na_rope("rope3d", dim=64)
generator = torch.Generator(device="cpu").manual_seed(0)
q = torch.randn(4, 2, 128, generator=generator)
k = torch.randn(4, 2, 128, generator=generator)
shape = torch.tensor([[1, 2, 2]], dtype=torch.long)
freqs = rope.get_axial_freqs(1, 2, 2).reshape(4, -1)
expected_q = seedvr_model._apply_seedvr2_rotary_emb(
freqs,
q.permute(1, 0, 2).float(),
).to(q.dtype).permute(1, 0, 2)
expected_k = seedvr_model._apply_seedvr2_rotary_emb(
freqs,
k.permute(1, 0, 2).float(),
).to(k.dtype).permute(1, 0, 2)
actual_q, actual_k = rope(q.clone(), k.clone(), shape, seedvr_model.Cache(disable=True))
torch.testing.assert_close(actual_q, expected_q, rtol=0, atol=0)
torch.testing.assert_close(actual_k, expected_k, rtol=0, atol=0)
def test_seedvr2_forward_requires_conditioning_latents():
model = NaDiT.__new__(NaDiT)
x = torch.zeros(1, _LATENT_CHANNELS, 1, 4, 5)
with pytest.raises(ValueError, match="requires conditioning latents"):
NaDiT.forward(model, x, timestep=torch.tensor([1.0]), context=None)
def test_seedvr2_latent_format_uses_native_video_latent_shape():
latent_format = comfy.latent_formats.SeedVR2()
latent_image = torch.zeros(1, 1, 4, 5)
fixed = comfy.sample.fix_empty_latent_channels(_Model(latent_format), latent_image)
assert latent_format.latent_channels == _LATENT_CHANNELS
assert latent_format.latent_dimensions == 3
assert fixed.shape == (1, _LATENT_CHANNELS, 1, 4, 5)
def test_seedvr2_model_requires_native_5d_latent():
latent = torch.zeros(1, _LATENT_CHANNELS, 2, 4, 5)
assert NaDiT._check_seedvr2_video_latent(latent, _LATENT_CHANNELS, "latent") is latent
with pytest.raises(ValueError, match="5-D native latent"):
NaDiT._check_seedvr2_video_latent(torch.zeros(1, _LATENT_CHANNELS * 2, 4, 5), _LATENT_CHANNELS, "latent")
def test_seedvr2_encode_and_encode_tiled_preserve_native_latent_contract(monkeypatch):
monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None)
encoded = torch.full((1, _LATENT_CHANNELS, 2, 4, 5), 2.0)
vae = _make_vae(_EncodeWrapper(encoded))
pixels = torch.zeros(1, 5, 32, 40, 3)
node_output = nodes_mod.VAEEncode().encode(vae, pixels)[0]
node_latent = node_output["samples"]
assert set(node_output) == {"samples"}
assert tuple(node_latent.shape) == (1, _LATENT_CHANNELS, 2, 4, 5)
assert node_latent.dtype == torch.float32
assert node_latent.stride()[-1] == 1
assert torch.equal(node_latent, torch.full_like(node_latent, 2.0 * seedvr_vae_mod.BYTEDANCE_VAE_SCALING_FACTOR))
tiled = torch.full((1, _LATENT_CHANNELS, 2, 4, 5), 3.0)
monkeypatch.setattr(seedvr_vae_mod, "tiled_vae", MagicMock(return_value=tiled))
tiled_output = nodes_mod.VAEEncodeTiled().encode(
vae,
pixels,
tile_size=512,
overlap=64,
temporal_size=16,
temporal_overlap=4,
)[0]
tiled_latent = tiled_output["samples"]
assert set(tiled_output) == {"samples"}
assert tuple(tiled_latent.shape) == (1, _LATENT_CHANNELS, 2, 4, 5)
assert tiled_latent.dtype == torch.float32
assert torch.equal(tiled_latent, torch.full_like(tiled_latent, 3.0 * seedvr_vae_mod.BYTEDANCE_VAE_SCALING_FACTOR))
def test_vaedecode_tiled_spatial_applies_temporal_discarded(monkeypatch):
monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None)
vae = _make_vae(_DecodeWrapper())
nodes_mod.VAEDecodeTiled().decode(
vae,
{"samples": torch.zeros(1, _LATENT_CHANNELS, 2, 4, 5)},
tile_size=512,
overlap=64,
temporal_size=16,
temporal_overlap=4,
)
# Spatial inputs flow through; temporal inputs are discarded — SeedVR2 owns
# temporal via the MemoryState causal cache, so VAEDecodeTiled's temporal
# knobs are no-ops at the wrapper.
assert vae.first_stage_model.calls == [
{
"shape": (1, _LATENT_CHANNELS, 2, 4, 5),
"seedvr2_tiling": {
"enable_tiling": True,
"tile_size": (512, 512),
"tile_overlap": (64, 64),
"temporal_size": 0,
"temporal_overlap": 0,
},
}
]