mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-31 03:17:23 +08:00
- Reduce SeedVR2 coverage down to production unit tests
- Route SeedVR2 7B through Comfy varlength attention
- Disable SeedVR2 RoPE cache reuse after the upstream DynamicVRAM change
309 lines
10 KiB
Python
309 lines
10 KiB
Python
"""Consolidated SeedVR2 model/graph/forward regression tests.

Merged from:
- seedvr_model_test.py
- test_seedvr_7b_final_block_text_path.py
- test_seedvr_forward_no_device_cast.py
- test_seedvr_latent_format.py
- test_seedvr2_vae_graph_boundaries.py
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from unittest.mock import MagicMock
|
|
|
|
import torch
|
|
from torch import nn
|
|
|
|
from comfy.cli_args import args
|
|
|
|
# Switch the parsed CLI args to CPU mode when no CUDA device is present.
# This runs before the ``comfy`` imports that follow it in this file.
# NOTE(review): presumably those modules read ``args.cpu`` at import time,
# which is why the later imports carry ``noqa: E402`` -- confirm.
if not torch.cuda.is_available():
    args.cpu = True
|
|
|
|
import comfy # noqa: E402
|
|
import comfy.latent_formats # noqa: E402
|
|
import comfy.ldm.seedvr.model # noqa: E402
|
|
import comfy.ldm.seedvr.model as seedvr_model # noqa: E402
|
|
import comfy.ldm.seedvr.vae as seedvr_vae_mod # noqa: E402
|
|
import comfy.model_management # noqa: E402
|
|
import comfy.sample # noqa: E402
|
|
import comfy.sd as sd_mod # noqa: E402
|
|
import nodes as nodes_mod # noqa: E402
|
|
from comfy.ldm.seedvr.model import NaDiT # noqa: E402
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers from seedvr_model_test.py
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _make_standin(positive_conditioning):
    """Return a tiny module that borrows ``NaDiT._resolve_text_conditioning``.

    The stand-in registers ``positive_conditioning`` as a buffer under that
    exact name and exposes the ``NaDiT`` method as its own, so the method can
    be called without constructing a full ``NaDiT``.

    Args:
        positive_conditioning: Tensor stored as the ``positive_conditioning``
            buffer of the returned module.

    Returns:
        A freshly constructed ``_StandIn`` module instance.
    """

    class _StandIn(nn.Module):
        # Attach the production implementation directly to this class so it
        # binds to stand-in instances like a normal method.
        _resolve_text_conditioning = NaDiT._resolve_text_conditioning

        def __init__(self):
            super().__init__()
            self.register_buffer("positive_conditioning", positive_conditioning)

    return _StandIn()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers from test_seedvr_7b_final_block_text_path.py
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class _StubModule(nn.Module):
|
|
def __init__(self, *args, **kwargs):
|
|
super().__init__()
|
|
|
|
|
|
def _capture_last_layer_flags(monkeypatch, vid_dim: int, txt_in_dim: int) -> list[bool]:
    """Construct a stubbed ``NaDiT`` and record each block's ``is_last_layer``.

    ``NaPatchIn``, ``NaPatchOut`` and ``TimeEmbedding`` in ``seedvr_model`` are
    replaced with ``_StubModule``; ``NaMMSRTransformerBlock`` is replaced with
    a recording stub. ``NaDiT`` is then instantiated once with four layers.

    Args:
        monkeypatch: pytest ``monkeypatch`` fixture used to swap the classes.
        vid_dim: Forwarded to ``NaDiT`` as ``vid_dim``.
        txt_in_dim: Forwarded to ``NaDiT`` as ``txt_in_dim``.

    Returns:
        The ``is_last_layer`` keyword values in the order the transformer
        blocks were constructed.
    """
    recorded = []

    class _Block(_StubModule):
        def __init__(self, *args, **kwargs):
            # Indexing (not .get) so a missing keyword surfaces as KeyError.
            recorded.append(kwargs["is_last_layer"])
            super().__init__()

    replacements = {
        "NaPatchIn": _StubModule,
        "NaPatchOut": _StubModule,
        "TimeEmbedding": _StubModule,
        "NaMMSRTransformerBlock": _Block,
    }
    for attr_name, stub_cls in replacements.items():
        monkeypatch.setattr(seedvr_model, attr_name, stub_cls)

    # Only construction matters; the resulting model object is discarded.
    nadit_kwargs = {
        "norm_eps": 1e-5,
        "qk_rope": None,
        "num_layers": 4,
        "mlp_type": "normal",
        "vid_dim": vid_dim,
        "txt_in_dim": txt_in_dim,
        "heads": 24,
        "mm_layers": 3,
    }
    seedvr_model.NaDiT(**nadit_kwargs)

    return recorded
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers from test_seedvr_latent_format.py
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class _Model:
|
|
def __init__(self, latent_format):
|
|
self._latent_format = latent_format
|
|
|
|
def get_model_object(self, name):
|
|
assert name == "latent_format"
|
|
return self._latent_format
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers from test_seedvr2_vae_graph_boundaries.py
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class _Patcher:
|
|
def get_free_memory(self, device):
|
|
return 1024 * 1024 * 1024
|
|
|
|
|
|
class _EncodeWrapper(seedvr_vae_mod.VideoAutoencoderKLWrapper):
    """Wrapper stub whose ``encode`` returns a canned latent tensor.

    Every ``encode`` call records the input shape in ``seen`` so tests can
    inspect what reached the wrapper.
    """

    def __init__(self, encoded):
        # Only nn.Module.__init__ is invoked; the wrapper's own constructor is
        # deliberately not called, so no real VAE layers are built.
        nn.Module.__init__(self)
        self.encoded = encoded
        self.spatial_downsample_factor = 8
        self.temporal_downsample_factor = 4
        self.seen = []

    def encode(self, x):
        """Record ``x``'s shape and return the canned latent on ``x``'s device/dtype."""
        input_shape = tuple(x.shape)
        self.seen.append(input_shape)
        canned = self.encoded.to(device=x.device, dtype=x.dtype)
        return canned
|
|
|
|
|
|
class _DecodeWrapper(seedvr_vae_mod.VideoAutoencoderKLWrapper):
    """Wrapper stub whose ``decode`` logs its arguments and returns zeros.

    Each call appends ``{"shape": ..., "seedvr2_tiling": ...}`` to ``calls``.
    """

    def __init__(self):
        # Only nn.Module.__init__ is invoked; the wrapper's own constructor is
        # deliberately not called.
        nn.Module.__init__(self)
        self.spatial_downsample_factor = 8
        self.temporal_downsample_factor = 4
        self.calls = []

    def decode(self, z, seedvr2_tiling=None):
        """Log the call and return a zero tensor upscaled by 8 in height and width.

        Args:
            z: Latent tensor, either 4-D or 5-D.
            seedvr2_tiling: Tiling options, stored verbatim in the call log.

        Returns:
            Zeros of shape ``(b, 3, t, h * 8, w * 8)`` with ``z``'s dtype and
            device.
        """
        latent_shape = tuple(z.shape)
        self.calls.append({"shape": latent_shape, "seedvr2_tiling": seedvr2_tiling})

        if z.ndim == 4:
            # 4-D input: the frame count is derived from dim 1 divided by 16.
            # NOTE(review): presumably channels and time are folded together
            # with 16 latent channels -- confirm against the production wrapper.
            batch, folded, height, width = latent_shape
            frames = folded // 16
        else:
            batch, _, frames, height, width = latent_shape

        out_shape = (batch, 3, frames, height * 8, width * 8)
        return torch.zeros(*out_shape, dtype=z.dtype, device=z.device)
|
|
|
|
|
|
def _make_vae(wrapper):
    """Build an ``sd_mod.VAE`` around ``wrapper`` without running ``VAE.__init__``.

    The instance is created with ``__new__`` and then populated attribute by
    attribute with CPU/float32 settings and trivial callables, so no weights
    are loaded and no memory management is performed.

    Args:
        wrapper: Module installed as ``first_stage_model``.

    Returns:
        The hand-populated ``sd_mod.VAE`` instance.
    """
    vae = sd_mod.VAE.__new__(sd_mod.VAE)

    # Insertion order mirrors the order in which the attributes are assigned.
    stubbed_state = {
        "first_stage_model": wrapper,
        "device": torch.device("cpu"),
        "output_device": torch.device("cpu"),
        "vae_dtype": torch.float32,
        "latent_channels": 16,
        "latent_dim": 3,
        # (temporal, height, width): temporal uses a callable, spatial is 8x.
        "downscale_ratio": (lambda a: max(0, (a + 3) // 4), 8, 8),
        "upscale_ratio": (lambda a: max(0, a * 4 - 3), 8, 8),
        "output_channels": 3,
        "disable_offload": True,
        "extra_1d_channel": None,
        "crop_input": False,
        "not_video": False,
        "patcher": _Patcher(),
        # Input passes through untouched; output is shifted/scaled by
        # (x + 1) / 2 and clamped to [0, 1].
        "process_input": lambda image: image,
        "process_output": lambda image: image.add(1.0).div(2.0).clamp(0.0, 1.0),
        "vae_output_dtype": lambda: torch.float32,
        # Memory estimates are a constant 1 regardless of shape and dtype.
        "memory_used_encode": lambda shape, dtype: 1,
        "memory_used_decode": lambda shape, dtype: 1,
        "throw_exception_if_invalid": lambda: None,
        "vae_encode_crop_pixels": lambda pixels: pixels,
        "spacial_compression_decode": lambda: 8,
        "temporal_compression_decode": lambda: 4,
    }
    for attr_name, value in stubbed_state.items():
        setattr(vae, attr_name, value)

    return vae
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tests from seedvr_model_test.py
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_missing_context_falls_back_to_positive_buffer():
    """AC: resolving with ``context=None`` yields the ``positive_conditioning`` buffer.

    The returned text tensor must carry the buffer's values unchanged (not
    zeros), and the returned shape tensor must be ``(1, 1)`` holding the
    token count.
    """
    token_count, feature_dim, fill_value = 58, 5120, 7.0
    standin = _make_standin(torch.full((token_count, feature_dim), fill_value))

    txt, txt_shape = standin._resolve_text_conditioning(None)

    assert txt.shape == (token_count, feature_dim)
    assert (txt == fill_value).all(), (
        "fallback path must use the positive_conditioning buffer "
        "verbatim, not a zero tensor"
    )
    assert txt_shape.shape == (1, 1)
    assert txt_shape[0, 0].item() == token_count
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tests from test_seedvr_7b_final_block_text_path.py
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_seedvr2_7b_keeps_final_block_text_path(monkeypatch):
    """With ``vid_dim == txt_in_dim == 3072``, none of the four blocks is flagged last."""
    flags = _capture_last_layer_flags(monkeypatch, vid_dim=3072, txt_in_dim=3072)

    # One entry per constructed block (num_layers=4 in the helper).
    assert flags == [False] * 4
|
|
|
|
|
|
def test_seedvr2_7b_rope3d_matches_wrapper_oracle():
    """The ``rope3d`` module must match a direct ``apply_rotary_emb`` call bit-for-bit.

    The oracle applies ``seedvr_model.apply_rotary_emb`` to the float copy of
    each tensor with its first two dims swapped, then casts and swaps back.
    The comparison uses zero tolerances, so any numerical drift fails.
    """
    rope = seedvr_model.get_na_rope("rope3d", dim=64)

    # q is drawn before k from the same seeded generator; keep this order.
    rng = torch.Generator(device="cpu").manual_seed(0)
    q = torch.randn(4, 2, 128, generator=rng)
    k = torch.randn(4, 2, 128, generator=rng)

    shape = torch.tensor([[1, 2, 2]], dtype=torch.long)
    freqs = rope.get_axial_freqs(1, 2, 2).reshape(4, -1)

    def oracle(tensor):
        swapped = tensor.permute(1, 0, 2).float()
        rotated = seedvr_model.apply_rotary_emb(freqs, swapped)
        return rotated.to(tensor.dtype).permute(1, 0, 2)

    expected_q = oracle(q)
    expected_k = oracle(k)

    # Clones are passed so the originals used by the oracle stay untouched.
    cache = seedvr_model.Cache(disable=True)
    actual_q, actual_k = rope(q.clone(), k.clone(), shape, cache)

    for actual, expected in ((actual_q, expected_q), (actual_k, expected_k)):
        torch.testing.assert_close(actual, expected, rtol=0, atol=0)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tests from test_seedvr_latent_format.py
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_seedvr2_latent_format_uses_16_channels_without_3d_empty_latent_expansion():
    """An empty 4-D latent is widened to 16 channels and stays 4-D.

    Also pins the ``SeedVR2`` latent format to 16 channels and 2 latent
    dimensions.
    """
    latent_format = comfy.latent_formats.SeedVR2()
    empty_latent = torch.zeros(1, 1, 4, 5)
    model = _Model(latent_format)

    fixed = comfy.sample.fix_empty_latent_channels(model, empty_latent)

    assert latent_format.latent_channels == 16
    assert latent_format.latent_dimensions == 2
    # Still four dims: no extra temporal axis was inserted.
    assert fixed.shape == (1, 16, 4, 5)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tests from test_seedvr2_vae_graph_boundaries.py
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_seedvr2_encode_and_encode_tiled_preserve_native_latent_contract(monkeypatch):
    """``VAEEncode`` and ``VAEEncodeTiled`` both return a bare 5-D float32 latent.

    For each node the output dict must contain only ``"samples"``, keep the
    ``(1, 16, 2, 4, 5)`` shape produced by the stub, and equal the stub's value
    multiplied by ``0.9152``.
    NOTE(review): 0.9152 is presumably the SeedVR2 latent scale factor -- confirm.
    """
    # Model loading is a no-op: the VAE here is a hand-built CPU stub.
    monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None)

    latent_shape = (1, 16, 2, 4, 5)
    vae = _make_vae(_EncodeWrapper(torch.full(latent_shape, 2.0)))
    pixels = torch.zeros(1, 5, 32, 40, 3)

    # --- plain encode -----------------------------------------------------
    node_output = nodes_mod.VAEEncode().encode(vae, pixels)[0]
    node_latent = node_output["samples"]
    assert set(node_output) == {"samples"}
    assert tuple(node_latent.shape) == latent_shape
    assert node_latent.dtype == torch.float32
    assert node_latent.stride()[-1] == 1
    assert torch.equal(node_latent, torch.full_like(node_latent, 2.0 * 0.9152))

    # --- tiled encode -----------------------------------------------------
    # The tiling helper is mocked to hand back a distinct constant latent.
    tiled = torch.full(latent_shape, 3.0)
    monkeypatch.setattr(seedvr_vae_mod, "tiled_vae", MagicMock(return_value=tiled))
    tiling_kwargs = {
        "tile_size": 512,
        "overlap": 64,
        "temporal_size": 16,
        "temporal_overlap": 4,
    }
    tiled_output = nodes_mod.VAEEncodeTiled().encode(vae, pixels, **tiling_kwargs)[0]
    tiled_latent = tiled_output["samples"]
    assert set(tiled_output) == {"samples"}
    assert tuple(tiled_latent.shape) == latent_shape
    assert tiled_latent.dtype == torch.float32
    assert torch.equal(tiled_latent, torch.full_like(tiled_latent, 3.0 * 0.9152))
|
|
|
|
|
|
def test_vaedecode_tiled_visible_inputs_are_seedvr2_decode_tiling_authority(monkeypatch):
    """``VAEDecodeTiled`` forwards its node inputs to the wrapper as ``seedvr2_tiling``.

    The wrapper must be called exactly once, with the original 5-D latent
    shape and a tiling dict built from the node's tile/overlap/temporal inputs.
    """
    # Model loading is a no-op: the VAE here is a hand-built CPU stub.
    monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None)
    vae = _make_vae(_DecodeWrapper())
    latent = {"samples": torch.zeros(1, 16, 2, 4, 5)}

    nodes_mod.VAEDecodeTiled().decode(
        vae,
        latent,
        tile_size=512,
        overlap=64,
        temporal_size=16,
        temporal_overlap=4,
    )

    expected_tiling = {
        "enable_tiling": True,
        "tile_size": (512, 512),
        "tile_overlap": (64, 64),
        "temporal_size": 16,
        "temporal_overlap": 4,
    }
    expected_call = {"shape": (1, 16, 2, 4, 5), "seedvr2_tiling": expected_tiling}
    assert vae.first_stage_model.calls == [expected_call]
|