mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 13:19:23 +08:00
322 lines
11 KiB
Python
322 lines
11 KiB
Python
"""SeedVR2 model, latent-format, and VAE graph regression tests."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from unittest.mock import MagicMock
|
|
|
|
import pytest
|
|
import torch
|
|
from torch import nn
|
|
|
|
from comfy.cli_args import args
|
|
|
|
# Fall back to CPU execution when no CUDA device is present.
# NOTE(review): this runs before the comfy imports below (which carry
# ``noqa: E402``), presumably so those modules see the flag at import
# time -- confirm against comfy.model_management.
if not torch.cuda.is_available():
    args.cpu = True
|
|
|
|
import comfy # noqa: E402
|
|
import comfy.latent_formats # noqa: E402
|
|
import comfy.ldm.seedvr.model as seedvr_model # noqa: E402
|
|
import comfy.ldm.seedvr.vae as seedvr_vae_mod # noqa: E402
|
|
import comfy.model_management # noqa: E402
|
|
import comfy.ops as comfy_ops # noqa: E402
|
|
import comfy.sample # noqa: E402
|
|
import comfy.sd as sd_mod # noqa: E402
|
|
import nodes as nodes_mod # noqa: E402
|
|
from comfy.ldm.seedvr.model import NaDiT # noqa: E402
|
|
|
|
|
|
# Latent channel count, read from the SeedVR2 VAE module so every shape
# used in these tests follows the library's own constant.
_LATENT_CHANNELS = seedvr_vae_mod.SEEDVR2_LATENT_CHANNELS
|
|
|
|
|
|
def _make_standin(positive_conditioning):
    """Return a tiny module that borrows ``NaDiT._resolve_text_conditioning``.

    The module registers ``positive_conditioning`` as a buffer under that
    same name, so the borrowed method can be exercised without building a
    full ``NaDiT``.
    """

    class _StandIn(torch.nn.Module):
        # Reuse the real implementation as an ordinary method of the stand-in.
        _resolve_text_conditioning = NaDiT._resolve_text_conditioning

        def __init__(self):
            super().__init__()
            self.register_buffer("positive_conditioning", positive_conditioning)

    standin = _StandIn()
    return standin
|
|
|
|
|
|
class _StubModule(nn.Module):
|
|
def __init__(self, *args, **kwargs):
|
|
super().__init__()
|
|
|
|
|
|
def _capture_last_layer_flags(monkeypatch, vid_dim: int, txt_in_dim: int) -> list[bool]:
    """Build a stubbed ``NaDiT`` and report each block's ``is_last_layer`` flag.

    The patch-in, patch-out, time-embedding and transformer-block classes in
    ``seedvr_model`` are replaced with stubs, so constructing the model only
    records the ``is_last_layer`` keyword passed to every transformer block.

    Args:
        monkeypatch: pytest ``monkeypatch`` fixture used to install the stubs.
        vid_dim: forwarded to ``NaDiT`` as ``vid_dim``.
        txt_in_dim: forwarded to ``NaDiT`` as ``txt_in_dim``.

    Returns:
        The recorded flags, in block construction order.
    """
    recorded = []

    class _Block(_StubModule):
        def __init__(self, *args, **kwargs):
            # Record the flag before initialising the (empty) stub module.
            recorded.append(kwargs["is_last_layer"])
            super().__init__()

    replacements = {
        "NaPatchIn": _StubModule,
        "NaPatchOut": _StubModule,
        "TimeEmbedding": _StubModule,
        "NaMMSRTransformerBlock": _Block,
    }
    for attr_name, stub_cls in replacements.items():
        monkeypatch.setattr(seedvr_model, attr_name, stub_cls)

    model_kwargs = {
        "norm_eps": 1e-5,
        "num_layers": 4,
        "mlp_type": "normal",
        "vid_dim": vid_dim,
        "txt_in_dim": txt_in_dim,
        "heads": 24,
        "mm_layers": 3,
        "operations": comfy_ops.disable_weight_init,
    }
    # The instance itself is not needed; construction populates ``recorded``.
    seedvr_model.NaDiT(**model_kwargs)

    return recorded
|
|
|
|
|
|
class _Model:
|
|
def __init__(self, latent_format):
|
|
self._latent_format = latent_format
|
|
|
|
def get_model_object(self, name):
|
|
assert name == "latent_format"
|
|
return self._latent_format
|
|
|
|
|
|
class _Patcher:
|
|
def get_free_memory(self, device):
|
|
return 1024 * 1024 * 1024
|
|
|
|
|
|
class _EncodeWrapper(seedvr_vae_mod.VideoAutoencoderKLWrapper):
    """Wrapper stub whose ``encode`` returns a canned latent.

    Every input shape passed to ``encode`` is appended to ``self.seen``.
    """

    def __init__(self, encoded):
        # Only nn.Module state is initialised; the real wrapper constructor
        # is deliberately not called.
        nn.Module.__init__(self)
        self.seen = []
        self.encoded = encoded
        self.temporal_downsample_factor = 4
        self.spatial_downsample_factor = 8

    def encode(self, x):
        """Record ``x``'s shape and return the canned latent on ``x``'s device/dtype."""
        self.seen.append(tuple(x.shape))
        canned = self.encoded
        return canned.to(device=x.device, dtype=x.dtype)
|
|
|
|
|
|
class _DecodeWrapper(seedvr_vae_mod.VideoAutoencoderKLWrapper):
    """Wrapper stub whose ``decode`` records its calls and returns zeros.

    Each call appends ``{"shape": ..., "seedvr2_tiling": ...}`` to
    ``self.calls`` so tests can inspect what the VAE front-end forwarded.
    """

    def __init__(self):
        # Only nn.Module state is initialised; the real wrapper constructor
        # is deliberately not called.
        nn.Module.__init__(self)
        self.spatial_downsample_factor = 8
        self.temporal_downsample_factor = 4
        self.calls = []

    def decode(self, z, seedvr2_tiling=None):
        """Record the call and return an all-zero 3-channel, 5-D tensor.

        Args:
            z: 4-D input whose second axis is divided by ``_LATENT_CHANNELS``
                to obtain the frame count, or a 5-D input whose third axis is
                used as the frame count directly.
            seedvr2_tiling: stored verbatim in the call record.

        Returns:
            Zeros of shape ``(b, 3, t, h * s, w * s)`` on ``z``'s dtype and
            device, where ``s`` is ``self.spatial_downsample_factor``.
        """
        self.calls.append({"shape": tuple(z.shape), "seedvr2_tiling": seedvr2_tiling})
        if z.ndim == 4:
            b, tc, h, w = z.shape
            t = tc // _LATENT_CHANNELS
        else:
            b, _, t, h, w = z.shape
        # Use the factor declared in __init__ rather than a second literal 8,
        # so the stub's output size cannot drift from its advertised factor.
        scale = self.spatial_downsample_factor
        return torch.zeros(b, 3, t, h * scale, w * scale, dtype=z.dtype, device=z.device)
|
|
|
|
|
|
def test_seedvr2_wrapper_public_encode_returns_tensor(monkeypatch):
    """Public ``encode`` must return a plain 4-D ``torch.Tensor``.

    The base-class ``encode`` is patched to return a fixed 5-D latent and to
    record the shape it receives; a 4-D image input must reach it as 5-D.
    """
    wrapper_cls = seedvr_vae_mod.VideoAutoencoderKLWrapper
    raw_latent = torch.full((1, _LATENT_CHANNELS, 1, 4, 5), 2.0)
    seen_shapes = []

    def base_encode(self, x):
        seen_shapes.append(tuple(x.shape))
        return raw_latent.to(device=x.device, dtype=x.dtype)

    monkeypatch.setattr(seedvr_vae_mod.VideoAutoencoderKL, "encode", base_encode)

    # Build a bare wrapper: skip its constructor and give it one parameter.
    vae = wrapper_cls.__new__(wrapper_cls)
    nn.Module.__init__(vae)
    vae._dummy = nn.Parameter(torch.zeros((), dtype=torch.float32))

    pixels = torch.zeros(1, 3, 32, 40)
    latent = vae.encode(pixels)

    assert type(latent) is torch.Tensor
    assert tuple(latent.shape) == (1, _LATENT_CHANNELS, 4, 5)
    assert seen_shapes == [(1, 3, 1, 32, 40)]
|
|
|
|
|
|
def test_seedvr2_wrapper_private_encode_helper_keeps_raw_latent(monkeypatch):
    """``_encode_with_raw_latent`` must return both the 4-D and the raw 5-D latent.

    The base-class ``encode`` is patched to return a fixed 5-D latent, which
    the helper is expected to pass through unchanged as its second result.
    """
    wrapper_cls = seedvr_vae_mod.VideoAutoencoderKLWrapper
    raw_latent = torch.full((1, _LATENT_CHANNELS, 1, 4, 5), 3.0)

    def base_encode(self, x):
        return raw_latent.to(device=x.device, dtype=x.dtype)

    monkeypatch.setattr(seedvr_vae_mod.VideoAutoencoderKL, "encode", base_encode)

    # Build a bare wrapper: skip its constructor and give it one parameter.
    vae = wrapper_cls.__new__(wrapper_cls)
    nn.Module.__init__(vae)
    vae._dummy = nn.Parameter(torch.zeros((), dtype=torch.float32))

    pixels = torch.zeros(1, 3, 32, 40)
    latent, raw = vae._encode_with_raw_latent(pixels)

    assert tuple(latent.shape) == (1, _LATENT_CHANNELS, 4, 5)
    assert tuple(raw.shape) == (1, _LATENT_CHANNELS, 1, 4, 5)
    assert torch.equal(raw, raw_latent)
|
|
|
|
|
|
def _make_vae(wrapper):
    """Assemble an ``sd_mod.VAE`` around ``wrapper`` without running its constructor.

    Every attribute is set by hand: CPU devices, float32 dtypes, fixed
    compression factors, and trivial memory / validation callbacks.
    ``handles_tiling`` is true only when ``wrapper`` is a
    ``VideoAutoencoderKLWrapper``.
    """
    cpu = torch.device("cpu")
    attributes = {
        "first_stage_model": wrapper,
        "device": cpu,
        "output_device": torch.device("cpu"),
        "vae_dtype": torch.float32,
        "latent_channels": _LATENT_CHANNELS,
        "latent_dim": 3,
        "downscale_ratio": (lambda a: max(0, (a + 3) // 4), 8, 8),
        "upscale_ratio": (lambda a: max(0, a * 4 - 3), 8, 8),
        "output_channels": 3,
        "disable_offload": True,
        "extra_1d_channel": None,
        "crop_input": False,
        "not_video": False,
        "handles_tiling": isinstance(wrapper, seedvr_vae_mod.VideoAutoencoderKLWrapper),
        "format_encoded": wrapper.comfy_format_encoded,
        "patcher": _Patcher(),
        "process_input": lambda image: image,
        "process_output": lambda image: image.add(1.0).div(2.0).clamp(0.0, 1.0),
        "vae_output_dtype": lambda: torch.float32,
        "memory_used_encode": lambda shape, dtype: 1,
        "memory_used_decode": lambda shape, dtype: 1,
        "throw_exception_if_invalid": lambda: None,
        "vae_encode_crop_pixels": lambda pixels: pixels,
        "spacial_compression_decode": lambda: 8,
        "temporal_compression_decode": lambda: 4,
    }

    vae = sd_mod.VAE.__new__(sd_mod.VAE)
    # Assign in the listed order, exactly as plain attribute statements would.
    for attr_name, attr_value in attributes.items():
        setattr(vae, attr_name, attr_value)
    return vae
|
|
|
|
|
|
def test_missing_context_falls_back_to_positive_buffer():
    """With no context, text conditioning must come from the positive buffer.

    The buffer is filled with 7.0 so a zero-tensor fallback would be detected.
    """
    token_count, width = 58, 5120
    pos_buffer = torch.full((token_count, width), 7.0)
    standin = _make_standin(pos_buffer)

    txt, txt_shape = standin._resolve_text_conditioning(None)

    assert txt.shape == (58, 5120)
    buffer_used_verbatim = (txt == 7.0).all()
    assert buffer_used_verbatim, (
        "fallback path must use the positive_conditioning buffer "
        "verbatim, not a zero tensor"
    )
    assert txt_shape.shape == (1, 1)
    assert txt_shape[0, 0].item() == 58
|
|
|
|
|
|
def test_seedvr2_7b_keeps_final_block_text_path(monkeypatch):
    """With ``vid_dim == txt_in_dim == 3072`` no block is flagged as the last layer.

    All four constructed blocks must receive ``is_last_layer=False``.
    """
    observed = _capture_last_layer_flags(monkeypatch, vid_dim=3072, txt_in_dim=3072)
    expected = [False] * 4
    assert observed == expected
|
|
|
|
|
|
def test_seedvr2_7b_rope3d_matches_wrapper_oracle():
    """``rope3d`` output must equal the direct rotary-embedding oracle exactly.

    The oracle applies ``_apply_seedvr2_rotary_emb`` to the axial frequencies
    and the permuted float input; the comparison uses zero tolerance.
    """
    rope = seedvr_model.get_na_rope("rope3d", dim=64)
    rng = torch.Generator(device="cpu").manual_seed(0)
    # Draw q before k so the seeded stream yields the same values each run.
    q = torch.randn(4, 2, 128, generator=rng)
    k = torch.randn(4, 2, 128, generator=rng)
    shape = torch.tensor([[1, 2, 2]], dtype=torch.long)
    freqs = rope.get_axial_freqs(1, 2, 2).reshape(4, -1)

    def oracle(tensor):
        # Swap the first two axes, rotate in float, then restore dtype/layout.
        swapped = tensor.permute(1, 0, 2).float()
        rotated = seedvr_model._apply_seedvr2_rotary_emb(freqs, swapped)
        return rotated.to(tensor.dtype).permute(1, 0, 2)

    expected_q = oracle(q)
    expected_k = oracle(k)

    cache = seedvr_model.Cache(disable=True)
    actual_q, actual_k = rope(q.clone(), k.clone(), shape, cache)

    for actual, expected in ((actual_q, expected_q), (actual_k, expected_k)):
        torch.testing.assert_close(actual, expected, rtol=0, atol=0)
|
|
|
|
|
|
def test_seedvr2_forward_requires_conditioning_latents():
    """``NaDiT.forward`` must raise ``ValueError`` when called without conditioning latents.

    The model is created with ``__new__`` only, so the error has to be raised
    before any constructed state is needed.
    """
    bare_model = NaDiT.__new__(NaDiT)
    latent = torch.zeros(1, _LATENT_CHANNELS, 1, 4, 5)
    timestep = torch.tensor([1.0])

    with pytest.raises(ValueError, match="requires conditioning latents"):
        NaDiT.forward(bare_model, latent, timestep=timestep, context=None)
|
|
|
|
|
|
def test_seedvr2_latent_format_uses_native_video_latent_shape():
    """The SeedVR2 latent format must turn an empty 4-D latent into the 5-D video shape.

    Also checks the format's declared channel count and latent dimensionality.
    """
    latent_format = comfy.latent_formats.SeedVR2()
    model = _Model(latent_format)
    empty_latent = torch.zeros(1, 1, 4, 5)

    fixed = comfy.sample.fix_empty_latent_channels(model, empty_latent)

    assert latent_format.latent_channels == _LATENT_CHANNELS
    assert latent_format.latent_dimensions == 3
    assert fixed.shape == (1, _LATENT_CHANNELS, 1, 4, 5)
|
|
|
|
|
|
def test_seedvr2_model_requires_native_5d_latent():
    """``_check_seedvr2_video_latent`` returns 5-D latents as-is and rejects 4-D ones."""
    check = NaDiT._check_seedvr2_video_latent

    native = torch.zeros(1, _LATENT_CHANNELS, 2, 4, 5)
    assert check(native, _LATENT_CHANNELS, "latent") is native

    # A 4-D tensor with doubled channels must not be accepted.
    folded = torch.zeros(1, _LATENT_CHANNELS * 2, 4, 5)
    with pytest.raises(ValueError, match="5-D native latent"):
        check(folded, _LATENT_CHANNELS, "latent")
|
|
|
|
|
|
def test_seedvr2_encode_and_encode_tiled_preserve_native_latent_contract(monkeypatch):
    """``VAEEncode`` and ``VAEEncodeTiled`` must both emit the 5-D scaled latent.

    The wrapper (plain path) and ``tiled_vae`` (tiled path) return constant
    latents; each node output must contain only ``"samples"``, keep the 5-D
    shape, be float32, and equal the constant times
    ``BYTEDANCE_VAE_SCALING_FACTOR``.
    """
    monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None)

    latent_shape = (1, _LATENT_CHANNELS, 2, 4, 5)
    pixels = torch.zeros(1, 5, 32, 40, 3)
    vae = _make_vae(_EncodeWrapper(torch.full(latent_shape, 2.0)))

    # --- plain encode path ---
    node_output = nodes_mod.VAEEncode().encode(vae, pixels)[0]
    node_latent = node_output["samples"]
    assert set(node_output) == {"samples"}
    assert tuple(node_latent.shape) == latent_shape
    assert node_latent.dtype == torch.float32
    assert node_latent.stride()[-1] == 1
    expected_plain = torch.full_like(node_latent, 2.0 * seedvr_vae_mod.BYTEDANCE_VAE_SCALING_FACTOR)
    assert torch.equal(node_latent, expected_plain)

    # --- tiled encode path ---
    tiled = torch.full(latent_shape, 3.0)
    monkeypatch.setattr(seedvr_vae_mod, "tiled_vae", MagicMock(return_value=tiled))
    tiling_args = {
        "tile_size": 512,
        "overlap": 64,
        "temporal_size": 16,
        "temporal_overlap": 4,
    }
    tiled_output = nodes_mod.VAEEncodeTiled().encode(vae, pixels, **tiling_args)[0]
    tiled_latent = tiled_output["samples"]
    assert set(tiled_output) == {"samples"}
    assert tuple(tiled_latent.shape) == latent_shape
    assert tiled_latent.dtype == torch.float32
    expected_tiled = torch.full_like(tiled_latent, 3.0 * seedvr_vae_mod.BYTEDANCE_VAE_SCALING_FACTOR)
    assert torch.equal(tiled_latent, expected_tiled)
|
|
|
|
|
|
def test_vaedecode_tiled_spatial_applies_temporal_discarded(monkeypatch):
    """``VAEDecodeTiled`` must forward spatial tiling and zero out temporal tiling.

    The wrapper stub records its single ``decode`` call, which is compared
    against the expected latent shape and ``seedvr2_tiling`` dictionary.
    """
    monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None)
    vae = _make_vae(_DecodeWrapper())
    latent = {"samples": torch.zeros(1, _LATENT_CHANNELS, 2, 4, 5)}

    nodes_mod.VAEDecodeTiled().decode(
        vae,
        latent,
        tile_size=512,
        overlap=64,
        temporal_size=16,
        temporal_overlap=4,
    )

    # Spatial inputs flow through; temporal inputs are discarded — SeedVR2 owns
    # temporal via the MemoryState causal cache, so VAEDecodeTiled's temporal
    # knobs are no-ops at the wrapper.
    expected_tiling = {
        "enable_tiling": True,
        "tile_size": (512, 512),
        "tile_overlap": (64, 64),
        "temporal_size": 0,
        "temporal_overlap": 0,
    }
    expected_call = {
        "shape": (1, _LATENT_CHANNELS, 2, 4, 5),
        "seedvr2_tiling": expected_tiling,
    }
    assert vae.first_stage_model.calls == [expected_call]
|