ComfyUI/tests-unit/comfy_test/test_seedvr2_internals.py
John Pollock fc4a135c04 Finalize SeedVR2 review additions
- Reduce SeedVR2 coverage down to production unit tests

- Route SeedVR2 7B through Comfy varlength attention

- Disable SeedVR2 RoPE cache reuse after the upstream DynamicVRAM change
2026-05-27 04:17:23 -05:00

390 lines
15 KiB
Python

"""Consolidated SeedVR2 internals regression tests.
Sources (all merged verbatim, helper names disambiguated where colliding):
* RoPE rewrite — NaMMRotaryEmbedding3d.forward must match the legacy
apply_rotary_emb wrapper oracle at fp32.
* GroupNorm limit gate — causal_norm_wrapper at vae.py:509 must compare
memory_occupy against get_norm_limit(), not float('inf').
* var_attention backend registry.
* var_attention_pytorch SeedVR2-named guard — present-API shape contract
with AST-level pinning of the guard ordering.
Pre-import CPU-only guard is required because comfy.ldm.seedvr.model and
comfy.ldm.modules.attention transitively pull in comfy.model_management,
which probes torch.cuda.current_device() at import time unless args.cpu is
set first.
"""
from __future__ import annotations
import ast
import inspect
import logging
import textwrap
import warnings
from unittest.mock import patch
import pytest
import torch
from comfy.cli_args import args
if not torch.cuda.is_available():
args.cpu = True
import comfy.ldm.seedvr.model as seedvr_model # noqa: E402
import comfy.ldm.seedvr.vae as vae_mod # noqa: E402
import comfy.ldm.modules.attention as attention # noqa: E402
import comfy.ops as comfy_ops # noqa: E402
from comfy.ldm.seedvr.model import ( # noqa: E402
Cache,
NaMMRotaryEmbedding3d,
)
from comfy.ldm.seedvr.vae import ( # noqa: E402
causal_norm_wrapper,
set_norm_limit,
)
from comfy.ldm.modules.attention import var_attention_pytorch # noqa: E402
# ---------------------------------------------------------------------------
# RoPE rewrite tests (test_seedvr_rope_rewrite.py)
# ---------------------------------------------------------------------------
# Shared geometry for the RoPE rig. dim=192 gives a per-axis rope dim of 64
# (even, which the lucidrains-style layout requires). A single video of
# (T, H, W) = (2, 4, 4) flattens to 32 tokens; the text sequence has 8.
_SEED = 0
_HEADS = 4
_DIM = 192
_TXT_L = 8
_VID_T = 2
_VID_H = 4
_VID_W = 4
_L_VID = _VID_T * _VID_H * _VID_W
def _make_inputs(dtype=torch.float32, device="cpu"):
    """Build the six ``forward`` tensors plus a disabled ``Cache``.

    Every random draw comes from a private ``torch.Generator`` seeded with
    ``_SEED``, so each call yields the same values and the global RNG state
    is left untouched.

    Returns:
        ``(vid_q, vid_k, vid_shape, txt_q, txt_k, txt_shape, cache)``.
    """
    rng = torch.Generator(device=device).manual_seed(_SEED)

    def _draw(num_tokens):
        # Token-major layout: (tokens, heads, dim).
        return torch.randn(
            num_tokens, _HEADS, _DIM,
            dtype=dtype, device=device, generator=rng,
        )

    # The draw order (vid_q, vid_k, txt_q, txt_k) determines which values
    # each tensor receives from the shared generator; keep it stable.
    vid_q = _draw(_L_VID)
    vid_k = _draw(_L_VID)
    txt_q = _draw(_TXT_L)
    txt_k = _draw(_TXT_L)

    # One sample per batch: a single (T, H, W) row and a single text length.
    vid_shape = torch.tensor(
        [[_VID_T, _VID_H, _VID_W]], dtype=torch.long, device=device
    )
    txt_shape = torch.tensor([[_TXT_L]], dtype=torch.long, device=device)

    return vid_q, vid_k, vid_shape, txt_q, txt_k, txt_shape, Cache(disable=True)
def _legacy_get_freqs(rope: NaMMRotaryEmbedding3d, vid_shape, txt_shape):
    """Recompute the pre-rewrite ``get_freqs`` result for use as an oracle.

    Only ``rope.get_axial_freqs`` is called on the module (the parent
    ``RotaryEmbeddingBase`` method, which the rewrite left unchanged).

    Args:
        rope: object exposing ``get_axial_freqs(*dims)``.
        vid_shape: tensor whose rows unpack as ``(f, h, w)``.
        txt_shape: tensor whose first column holds the text lengths.

    Returns:
        ``(vid_freqs, txt_freqs)``: per-sample slices concatenated on dim 0.
    """
    # Pair each video's (f, h, w) with its text length once; both passes
    # below walk the same pairs.
    samples = list(zip(vid_shape.tolist(), txt_shape[:, 0].tolist()))

    # Running maxima with a floor of 0, exactly as the legacy loop computed.
    max_temporal = max([0] + [txt_len + frames for (frames, _, _), txt_len in samples])
    max_height = max([0] + [height for (_, height, _), _ in samples])
    max_width = max([0] + [width for (_, _, width), _ in samples])
    max_txt_len = max([0] + [txt_len for _, txt_len in samples])

    with torch.amp.autocast(device_type="cuda", enabled=False):
        vid_freqs_full = rope.get_axial_freqs(
            min(max_temporal + 16, 1024),
            min(max_height + 4, 128),
            min(max_width + 4, 128),
        ).float()
        txt_freqs_full = rope.get_axial_freqs(min(max_txt_len + 16, 1024))

    freq_dim = vid_freqs_full.size(-1)
    vid_freq_list = []
    txt_freq_list = []
    for (frames, height, width), txt_len in samples:
        # Video frames are offset along the temporal axis by the text length.
        vid_window = vid_freqs_full[txt_len : txt_len + frames, :height, :width]
        vid_freq_list.append(vid_window.reshape(-1, freq_dim))
        # Text freqs are tiled three times to reach the video freq width.
        txt_window = txt_freqs_full[:txt_len].repeat(1, 3)
        txt_freq_list.append(txt_window.reshape(-1, freq_dim))

    return torch.cat(vid_freq_list, dim=0), torch.cat(txt_freq_list, dim=0)
def _legacy_forward(rope: NaMMRotaryEmbedding3d, vid_q, vid_k, vid_shape,
                    txt_q, txt_k, txt_shape):
    """Produce the oracle outputs the rewritten ``forward`` must match.

    Legacy-shape freqs from ``_legacy_get_freqs`` are fed through the
    unchanged ``seedvr_model.apply_rotary_emb`` wrapper; the wrapper itself
    is out of scope for the rewrite (Shape B).

    Returns:
        ``(vid_q_out, vid_k_out, txt_q_out, txt_k_out)``, each rearranged
        back to the ``L h d`` layout and cast to its input's dtype.
    """
    vid_freqs, txt_freqs = _legacy_get_freqs(rope, vid_shape, txt_shape)
    vid_freqs = vid_freqs.to(vid_q.device)
    txt_freqs = txt_freqs.to(txt_q.device)

    from einops import rearrange

    def _rotate(freqs, tokens):
        # The wrapper is applied heads-first and in fp32; the result is cast
        # back to the caller's dtype and returned token-major.
        heads_first = rearrange(tokens, "L h d -> h L d")
        rotated = seedvr_model.apply_rotary_emb(freqs, heads_first.float())
        rotated = rotated.to(heads_first.dtype)
        return rearrange(rotated, "h L d -> L h d")

    vid_q_out = _rotate(vid_freqs, vid_q)
    vid_k_out = _rotate(vid_freqs, vid_k)
    txt_q_out = _rotate(txt_freqs, txt_q)
    txt_k_out = _rotate(txt_freqs, txt_k)
    return vid_q_out, vid_k_out, txt_q_out, txt_k_out
def test_namm_forward_output_tensor_equal_against_legacy_oracle():
    """``NaMMRotaryEmbedding3d.forward`` must equal the legacy oracle exactly.

    Both paths receive clones of the same inputs, and the comparison uses
    zero tolerances (``rtol=0, atol=0``).
    """
    rope = NaMMRotaryEmbedding3d(dim=_DIM)
    vid_q, vid_k, vid_shape, txt_q, txt_k, txt_shape, cache = _make_inputs()

    # Clone for each path so neither can observe in-place edits by the other.
    expected_vid_q, expected_vid_k, expected_txt_q, expected_txt_k = _legacy_forward(
        rope,
        vid_q.clone(), vid_k.clone(), vid_shape,
        txt_q.clone(), txt_k.clone(), txt_shape,
    )
    actual_vid_q, actual_vid_k, actual_txt_q, actual_txt_k = rope.forward(
        vid_q.clone(), vid_k.clone(), vid_shape,
        txt_q.clone(), txt_k.clone(), txt_shape, cache,
    )

    comparisons = (
        (actual_vid_q, expected_vid_q, "vid_q output diverges from wrapper oracle"),
        (actual_vid_k, expected_vid_k, "vid_k output diverges from wrapper oracle"),
        (actual_txt_q, expected_txt_q, "txt_q output diverges from wrapper oracle"),
        (actual_txt_k, expected_txt_k, "txt_k output diverges from wrapper oracle"),
    )
    for got, want, failure_message in comparisons:
        torch.testing.assert_close(got, want, rtol=0, atol=0, msg=failure_message)
# ---------------------------------------------------------------------------
# GroupNorm limit tests (test_seedvr_groupnorm_limit.py)
# ---------------------------------------------------------------------------
# GroupNorm rig: an 8-channel, 4-group layer fed a 5-D tensor whose second
# dimension equals the channel count.
_TENSOR_SHAPE = (1, 8, 2, 4, 4)
_NUM_GROUPS = 4
_NUM_CHANNELS = 8
# One pytest param per comfy.ops namespace; the namespace name is the test id.
_GROUPNORM_SUBCLASSES = [
    pytest.param(getattr(comfy_ops, ops_namespace).GroupNorm, id=ops_namespace)
    for ops_namespace in ("disable_weight_init", "manual_cast")
]
@pytest.mark.parametrize("groupnorm_cls", _GROUPNORM_SUBCLASSES)
def test_seedvr_groupnorm_low_limit_uses_chunked_groupnorm_path(groupnorm_cls):
    """With a tiny norm limit, ``causal_norm_wrapper`` must chunk GroupNorm.

    Two probes run at once:

    * a forward hook on the module counts full-module forwards;
    * a spy on ``vae_mod.F.group_norm`` records the group count of every
      functional call, so calls with fewer than ``_NUM_GROUPS`` groups can
      be counted as chunked calls.

    The norm limit is reset with ``set_norm_limit(None)`` on exit, even on
    failure.
    """
    # Captured before patching so the spy can delegate to the real function.
    unpatched_group_norm = vae_mod.F.group_norm
    set_norm_limit(1e-9)
    try:
        norm_layer = groupnorm_cls(num_channels=_NUM_CHANNELS, num_groups=_NUM_GROUPS)
        norm_layer.eval()

        hooked_input_shapes = []
        observed_group_counts = []

        def _record_forward(module, inputs, output):
            hooked_input_shapes.append(tuple(inputs[0].shape))

        def _group_norm_spy(input_tensor, num_groups_arg, *extra_args, **extra_kwargs):
            observed_group_counts.append(int(num_groups_arg))
            return unpatched_group_norm(
                input_tensor, num_groups_arg, *extra_args, **extra_kwargs
            )

        hook_handle = norm_layer.register_forward_hook(_record_forward)
        try:
            with patch.object(vae_mod.F, "group_norm", side_effect=_group_norm_spy):
                out_tensor = causal_norm_wrapper(norm_layer, torch.randn(*_TENSOR_SHAPE))
        finally:
            hook_handle.remove()

        full_calls = len(hooked_input_shapes)
        chunked_calls = len(
            [count for count in observed_group_counts if count < _NUM_GROUPS]
        )

        assert tuple(int(s) for s in out_tensor.shape) == _TENSOR_SHAPE
        assert full_calls == 0, (
            f"low-limit GroupNorm gate must NOT take the full-forward path; got full_calls={full_calls}"
        )
        assert chunked_calls > 0, (
            f"low-limit GroupNorm gate must take the chunked path; got chunked_calls={chunked_calls}"
        )
    finally:
        set_norm_limit(None)
# ---------------------------------------------------------------------------
# var_attention backend tests (test_seedvr_var_attention_backends.py)
# ---------------------------------------------------------------------------
def test_var_attention_registry_contains_always_available_entries():
    """Each listed var-attention backend is registered under its own name."""
    registry = attention.REGISTERED_ATTENTION_FUNCTIONS
    backend_names = (
        "var_attention_pytorch",
        "var_attention_sub_quad",
        "var_attention_split",
    )
    for backend_name in backend_names:
        # Identity, not equality: the registry must hold the module-level
        # function object itself.
        assert registry[backend_name] is getattr(attention, backend_name)
def test_seedvr2_7b_swin_attention_forward_uses_optimized_var_attention(monkeypatch):
    """``NaSwinAttention`` must call ``optimized_var_attention`` exactly once.

    ``seedvr_model.optimized_var_attention`` is replaced by a recorder that
    returns its ``q`` argument. The test then checks the output shapes and
    the recorded keyword arguments (q/k/v shapes, head count, reshape flags
    and cumulative sequence offsets).
    """
    dim, heads, head_dim = 8, 2, 4

    attn = seedvr_model.NaSwinAttention(
        vid_dim=dim,
        txt_dim=dim,
        heads=heads,
        head_dim=head_dim,
        qk_bias=False,
        qk_norm=seedvr_model.CustomRMSNorm,
        qk_norm_eps=1e-6,
        rope_type=None,
        rope_dim=head_dim,
        shared_weights=False,
        window=(2, 1, 1),
        window_method="720pwin_by_size_bysize",
        version=True,
        device="cpu",
        dtype=torch.float32,
        operations=comfy_ops.disable_weight_init,
    )

    # Local generator keeps the inputs reproducible; vid is drawn before txt.
    rng = torch.Generator(device="cpu").manual_seed(11)
    vid = torch.randn(8, dim, generator=rng)
    txt = torch.randn(3, dim, generator=rng)
    vid_shape = torch.tensor([[2, 2, 2]], dtype=torch.long)
    txt_shape = torch.tensor([[3]], dtype=torch.long)

    recorded_kwargs = []

    def fake_optimized_var_attention(**kwargs):
        recorded_kwargs.append(kwargs)
        return kwargs["q"]

    monkeypatch.setattr(seedvr_model, "optimized_var_attention", fake_optimized_var_attention)

    vid_out, txt_out = attn(vid, txt, vid_shape, txt_shape, seedvr_model.Cache(disable=True))

    assert tuple(vid_out.shape) == (8, dim)
    assert tuple(txt_out.shape) == (3, dim)
    assert len(recorded_kwargs) == 1

    call = recorded_kwargs[0]
    # 14 packed tokens split at offsets [0, 7, 14]; presumably two windows of
    # 4 video tokens, each joined with the 3 text tokens. TODO confirm
    # against the window partitioning in NaSwinAttention.
    for tensor_key in ("q", "k", "v"):
        assert tuple(call[tensor_key].shape) == (14, heads, head_dim)
    assert call["heads"] == heads
    assert call["skip_reshape"] is True
    assert call["skip_output_reshape"] is True

    for offsets_key in ("cu_seqlens_q", "cu_seqlens_k"):
        torch.testing.assert_close(
            call[offsets_key],
            torch.tensor([0, 7, 14], dtype=torch.int32),
            rtol=0,
            atol=0,
        )
# ---------------------------------------------------------------------------
# var_attention_pytorch SeedVR2 guard tests
# (test_var_attention_pytorch_seedvr2_guard.py)
# ---------------------------------------------------------------------------
def _pytorch_guard_inputs():
    """Build deterministic inputs for the ``var_attention_pytorch`` tests.

    q/k/v are drawn from a private, seeded ``torch.Generator`` so that every
    call returns identical tensors and the global RNG state is not mutated.
    This matches the approach ``_make_inputs`` takes for the RoPE rig and
    keeps any failure reproducible.

    Returns:
        ``(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, total_tokens,
        embed_dim)``. q/k/v have shape ``(total_tokens, embed_dim)`` with
        ``embed_dim = heads * head_dim``. Both offset entries are the same
        int32 tensor ``[0, 3, 6]``, i.e. two sequences of three tokens each.
    """
    heads, head_dim, total_tokens = 2, 8, 6
    embed_dim = heads * head_dim

    # Local generator: reproducible values, no side effect on torch's global RNG.
    rng = torch.Generator(device="cpu").manual_seed(0)
    q = torch.randn(total_tokens, embed_dim, generator=rng)
    k = torch.randn(total_tokens, embed_dim, generator=rng)
    v = torch.randn(total_tokens, embed_dim, generator=rng)

    cu = torch.tensor([0, 3, 6], dtype=torch.int32)
    return q, k, v, heads, cu, cu, total_tokens, embed_dim
def _assert_guard_source_pin():
    """Pin the guard ordering inside ``var_attention_pytorch`` via its AST.

    The function's source is parsed and two sets of line numbers collected:
    every ``raise RuntimeError(...)`` statement and every attribute access
    named ``nested_tensor_from_jagged``. Both must exist, and the earliest
    raise must sit on an earlier line than the earliest attribute access.

    Raises:
        AssertionError: if either construct is missing or the order is wrong.
    """
    src = textwrap.dedent(inspect.getsource(var_attention_pytorch))
    all_nodes = list(ast.walk(ast.parse(src)))

    # Only `raise RuntimeError(<call>)` with a bare-name callee counts.
    raise_lines = [
        node.lineno
        for node in all_nodes
        if isinstance(node, ast.Raise)
        and isinstance(node.exc, ast.Call)
        and isinstance(node.exc.func, ast.Name)
        and node.exc.func.id == "RuntimeError"
    ]
    nested_lines = [
        node.lineno
        for node in all_nodes
        if isinstance(node, ast.Attribute) and node.attr == "nested_tensor_from_jagged"
    ]

    assert raise_lines, (
        "var_attention_pytorch has no `raise RuntimeError(...)` AST node; "
        f"the SeedVR2-named guard is missing.\n--- source ---\n{src}"
    )
    assert nested_lines, (
        "var_attention_pytorch source has no `nested_tensor_from_jagged` "
        f"attribute access; cannot pin guard ordering.\n"
        f"--- source ---\n{src}"
    )

    first_raise = min(raise_lines)
    first_nested = min(nested_lines)
    assert first_raise < first_nested, (
        f"`raise RuntimeError(...)` first appears at line {first_raise}, "
        f"but `torch.nested.nested_tensor_from_jagged` is referenced first "
        f"at line {first_nested}; the guard must precede the lookup.\n"
        f"--- source ---\n{src}"
    )
def test_missing_api_raises_seedvr2_runtime_error(monkeypatch):
    """Without ``torch.nested.nested_tensor_from_jagged`` the guard must fire.

    The attribute is removed for the duration of the test. The call must
    then raise a ``RuntimeError`` whose message matches
    ``SeedVR2.*nested_tensor_from_jagged``.
    """
    monkeypatch.delattr(torch.nested, "nested_tensor_from_jagged", raising=False)
    q, k, v, heads, cu_q, cu_k = _pytorch_guard_inputs()[:6]

    expected_message = r"SeedVR2.*nested_tensor_from_jagged"
    with pytest.raises(RuntimeError, match=expected_message):
        var_attention_pytorch(q, k, v, heads, cu_q, cu_k)

    _assert_guard_source_pin()
def test_missing_namespace_raises_seedvr2_runtime_error(monkeypatch):
    """Without the ``torch.nested`` namespace the guard must fire.

    The whole ``nested`` attribute is removed from ``torch`` for the
    duration of the test. The call must then raise a ``RuntimeError`` whose
    message matches ``SeedVR2.*nested_tensor_from_jagged``.
    """
    monkeypatch.delattr(torch, "nested", raising=False)
    q, k, v, heads, cu_q, cu_k = _pytorch_guard_inputs()[:6]

    expected_message = r"SeedVR2.*nested_tensor_from_jagged"
    with pytest.raises(RuntimeError, match=expected_message):
        var_attention_pytorch(q, k, v, heads, cu_q, cu_k)

    _assert_guard_source_pin()
def test_present_api_returns_expected_shape():
    """With the nested-tensor API present, the output is ``(tokens, embed)``.

    While the call runs, the ``torch.fx._symbolic_trace`` logger is raised
    to ERROR and the nested-tensor prototype ``UserWarning`` is filtered.
    The logger's previous level is restored afterwards, even on failure.
    """
    q, k, v, heads, cu_q, cu_k, total_tokens, embed_dim = _pytorch_guard_inputs()

    fx_logger = logging.getLogger("torch.fx._symbolic_trace")
    saved_level = fx_logger.level
    fx_logger.setLevel(logging.ERROR)
    try:
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                message="The PyTorch API of nested tensors is in prototype stage.*",
                category=UserWarning,
            )
            out = var_attention_pytorch(q, k, v, heads, cu_q, cu_k)
    finally:
        fx_logger.setLevel(saved_level)

    expected_shape = (total_tokens, embed_dim)
    assert tuple(out.shape) == expected_shape, (
        f"expected ({total_tokens}, {embed_dim}); got {tuple(out.shape)}"
    )
    _assert_guard_source_pin()
def test_malformed_offsets_propagates_torch_runtime_error():
    """Bad offsets must raise a ``RuntimeError`` that is not the SeedVR2 guard.

    The query offsets end at 7 although only 6 tokens exist. The resulting
    error message must not contain ``"SeedVR2"``.
    """
    q, k, v, heads = _pytorch_guard_inputs()[:4]
    cu_k_ok = torch.tensor([0, 3, 6], dtype=torch.int32)
    # The last offset overruns the 6 available tokens.
    cu_q_bad = torch.tensor([0, 3, 7], dtype=torch.int32)

    with pytest.raises(RuntimeError) as exc_info:
        var_attention_pytorch(q, k, v, heads, cu_q_bad, cu_k_ok)

    assert "SeedVR2" not in str(exc_info.value)
    _assert_guard_source_pin()