mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-30 10:57:23 +08:00
- Reduce SeedVR2 coverage down to production unit tests - Route SeedVR2 7B through Comfy varlength attention - Disable SeedVR2 RoPE cache reuse after the upstream DynamicVRAM change
390 lines
15 KiB
Python
390 lines
15 KiB
Python
"""Consolidated SeedVR2 internals regression tests.

Sources (all merged verbatim, helper names disambiguated where colliding):

  * RoPE rewrite — NaMMRotaryEmbedding3d.forward must match the legacy
    apply_rotary_emb wrapper oracle exactly (zero tolerance) at fp32.

  * GroupNorm limit gate — causal_norm_wrapper in comfy/ldm/seedvr/vae.py
    must compare memory_occupy against get_norm_limit(), not float('inf'),
    so a near-zero limit forces the chunked GroupNorm path.

  * var_attention backend registry, plus the SeedVR2 7B NaSwinAttention
    forward routing through optimized_var_attention.

  * var_attention_pytorch SeedVR2-named guard — present-API shape contract
    with AST-level pinning of the guard ordering.

Pre-import CPU-only guard is required because comfy.ldm.seedvr.model and
comfy.ldm.modules.attention transitively pull in comfy.model_management,
which probes torch.cuda.current_device() at import time unless args.cpu is
set first.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import ast
|
|
import inspect
|
|
import logging
|
|
import textwrap
|
|
import warnings
|
|
from unittest.mock import patch
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
from comfy.cli_args import args
|
|
|
|
# Must execute before the comfy.ldm / comfy.ops imports below: those pull in
# comfy.model_management, which probes torch.cuda.current_device() at import
# time unless args.cpu is already set. Forcing CPU mode here lets the suite
# import on machines without CUDA.
if not torch.cuda.is_available():
    args.cpu = True
|
|
|
|
import comfy.ldm.seedvr.model as seedvr_model # noqa: E402
|
|
import comfy.ldm.seedvr.vae as vae_mod # noqa: E402
|
|
import comfy.ldm.modules.attention as attention # noqa: E402
|
|
import comfy.ops as comfy_ops # noqa: E402
|
|
from comfy.ldm.seedvr.model import ( # noqa: E402
|
|
Cache,
|
|
NaMMRotaryEmbedding3d,
|
|
)
|
|
from comfy.ldm.seedvr.vae import ( # noqa: E402
|
|
causal_norm_wrapper,
|
|
set_norm_limit,
|
|
)
|
|
from comfy.ldm.modules.attention import var_attention_pytorch # noqa: E402
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# RoPE rewrite tests (test_seedvr_rope_rewrite.py)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Test rig dimensions. dim=192 → per-axis rope dim = 64 (even, lucidrains
|
|
# requirement). vid_shape=(2,4,4) → L_vid = 32. txt_shape=(8,) → L_txt = 8.
|
|
_DIM = 192
|
|
_HEADS = 4
|
|
_VID_T, _VID_H, _VID_W = 2, 4, 4
|
|
_TXT_L = 8
|
|
_L_VID = _VID_T * _VID_H * _VID_W
|
|
_SEED = 0
|
|
|
|
|
|
def _make_inputs(dtype=torch.float32, device="cpu"):
    """Build deterministic inputs for one NaMMRotaryEmbedding3d forward call.

    Returns ``(vid_q, vid_k, vid_shape, txt_q, txt_k, txt_shape, cache)``:
    four random ``(tokens, _HEADS, _DIM)`` query/key tensors, the video
    ``[[T, H, W]]`` and text ``[[L]]`` shape tensors, and a disabled
    ``Cache``. All randomness comes from a private ``_SEED``-seeded
    Generator, so the global RNG stream is left untouched.
    """
    rng = torch.Generator(device=device).manual_seed(_SEED)

    def _draw(num_tokens):
        # Every draw shares ``rng``, so the order vid_q, vid_k, txt_q, txt_k matters.
        return torch.randn(
            num_tokens, _HEADS, _DIM, dtype=dtype, device=device, generator=rng
        )

    vid_q, vid_k = _draw(_L_VID), _draw(_L_VID)
    txt_q, txt_k = _draw(_TXT_L), _draw(_TXT_L)

    vid_shape = torch.tensor(
        [[_VID_T, _VID_H, _VID_W]], dtype=torch.long, device=device
    )
    txt_shape = torch.tensor([[_TXT_L]], dtype=torch.long, device=device)
    return vid_q, vid_k, vid_shape, txt_q, txt_k, txt_shape, Cache(disable=True)
|
|
|
|
|
|
def _legacy_get_freqs(rope: NaMMRotaryEmbedding3d, vid_shape, txt_shape):
    """Reproduce the pre-rewrite ``get_freqs`` body verbatim against
    ``self.get_axial_freqs`` (parent ``RotaryEmbeddingBase`` method,
    unchanged by the rewrite).

    This is a frozen copy of the old code that serves as an independent
    oracle. Do not restyle or "fix" it.

    Args:
        rope: embedding whose ``get_axial_freqs`` supplies the frequency tables.
        vid_shape: per-sample video shapes, iterated as ``(f, h, w)`` rows.
        txt_shape: per-sample text shapes; only column 0 (the length) is read.

    Returns:
        ``(vid_freqs, txt_freqs)``: per-token frequency rows for all samples,
        concatenated along dim 0. Both are reshaped to the video table's
        last-dim width.
    """
    # Batch-wide maxima. The temporal extent is ``l + f`` because video rows
    # are later sliced starting at temporal offset ``l`` (the text length).
    max_temporal = 0
    max_height = 0
    max_width = 0
    max_txt_len = 0
    for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()):
        max_temporal = max(max_temporal, l + f)
        max_height = max(max_height, h)
        max_width = max(max_width, w)
        max_txt_len = max(max_txt_len, l)
    # Build the tables with CUDA autocast disabled.
    with torch.amp.autocast(device_type="cuda", enabled=False):
        # Fixed headroom (+16 temporal, +4 spatial), capped at 1024 / 128 / 128.
        vid_freqs_full = rope.get_axial_freqs(
            min(max_temporal + 16, 1024),
            min(max_height + 4, 128),
            min(max_width + 4, 128),
        ).float()
        # NOTE: unlike the video table, the text table is not cast with .float().
        txt_freqs_full = rope.get_axial_freqs(min(max_txt_len + 16, 1024))
    vid_freq_list, txt_freq_list = [], []
    for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()):
        # Video: temporal rows [l, l + f) and the leading h x w spatial window.
        vid_freq = vid_freqs_full[l : l + f, :h, :w].reshape(-1, vid_freqs_full.size(-1))
        # Text: first l rows of the 1-axis table, tiled 3x along the last dim.
        # Presumably this matches the 3 axial components of the video table — TODO confirm.
        txt_freq = txt_freqs_full[:l].repeat(1, 3).reshape(-1, vid_freqs_full.size(-1))
        vid_freq_list.append(vid_freq)
        txt_freq_list.append(txt_freq)
    return torch.cat(vid_freq_list, dim=0), torch.cat(txt_freq_list, dim=0)
|
|
|
|
|
|
def _legacy_forward(rope: NaMMRotaryEmbedding3d, vid_q, vid_k, vid_shape,
                    txt_q, txt_k, txt_shape):
    """Oracle for ``NaMMRotaryEmbedding3d.forward``.

    Frequencies come from ``_legacy_get_freqs`` and are applied through the
    unchanged ``apply_rotary_emb`` wrapper, which is out of scope for the
    rewrite (Shape B). Each ``(L, h, d)`` input is viewed as ``(h, L, d)``,
    rotated in fp32, cast back to its original dtype, and returned in
    ``(L, h, d)`` layout.

    Returns ``(vid_q, vid_k, txt_q, txt_k)`` after rotation.
    """
    vid_freqs, txt_freqs = _legacy_get_freqs(rope, vid_shape, txt_shape)
    vid_freqs, txt_freqs = vid_freqs.to(vid_q.device), txt_freqs.to(txt_q.device)

    from einops import rearrange

    def _rotate(freqs, tokens):
        # Heads-first view for the wrapper; compute in fp32, restore dtype.
        heads_first = rearrange(tokens, "L h d -> h L d")
        rotated = seedvr_model.apply_rotary_emb(freqs, heads_first.float())
        return rearrange(rotated.to(heads_first.dtype), "h L d -> L h d")

    return (
        _rotate(vid_freqs, vid_q),
        _rotate(vid_freqs, vid_k),
        _rotate(txt_freqs, txt_q),
        _rotate(txt_freqs, txt_k),
    )
|
|
|
|
|
|
def test_namm_forward_output_tensor_equal_against_legacy_oracle():
    """NaMMRotaryEmbedding3d.forward must be bit-identical to the legacy oracle.

    Both paths receive fresh clones of the same deterministic inputs. Each of
    the four outputs (vid_q, vid_k, txt_q, txt_k) is compared with zero
    tolerance.
    """
    rope = NaMMRotaryEmbedding3d(dim=_DIM)
    vid_q, vid_k, vid_shape, txt_q, txt_k, txt_shape, cache = _make_inputs()

    want_vq, want_vk, want_tq, want_tk = _legacy_forward(
        rope,
        vid_q.clone(), vid_k.clone(), vid_shape,
        txt_q.clone(), txt_k.clone(), txt_shape,
    )
    got_vq, got_vk, got_tq, got_tk = rope.forward(
        vid_q.clone(), vid_k.clone(), vid_shape,
        txt_q.clone(), txt_k.clone(), txt_shape, cache,
    )

    comparisons = (
        (got_vq, want_vq, "vid_q output diverges from wrapper oracle"),
        (got_vk, want_vk, "vid_k output diverges from wrapper oracle"),
        (got_tq, want_tq, "txt_q output diverges from wrapper oracle"),
        (got_tk, want_tk, "txt_k output diverges from wrapper oracle"),
    )
    for actual, expected, message in comparisons:
        torch.testing.assert_close(actual, expected, rtol=0, atol=0, msg=message)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# GroupNorm limit tests (test_seedvr_groupnorm_limit.py)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
_NUM_CHANNELS = 8
|
|
_NUM_GROUPS = 4
|
|
_TENSOR_SHAPE = (1, 8, 2, 4, 4)
|
|
|
|
# The comfy.ops GroupNorm variants under test, paired with readable pytest ids.
_GROUPNORM_SUBCLASSES = [
    pytest.param(norm_cls, id=param_id)
    for param_id, norm_cls in (
        ("disable_weight_init", comfy_ops.disable_weight_init.GroupNorm),
        ("manual_cast", comfy_ops.manual_cast.GroupNorm),
    )
]
|
|
|
|
|
|
@pytest.mark.parametrize("groupnorm_cls", _GROUPNORM_SUBCLASSES)
def test_seedvr_groupnorm_low_limit_uses_chunked_groupnorm_path(groupnorm_cls):
    """A near-zero norm limit must route causal_norm_wrapper to the chunked path.

    With ``set_norm_limit(1e-9)`` the module's own forward must never run
    (zero forward-hook hits). ``F.group_norm`` must be called at least once
    with fewer than ``_NUM_GROUPS`` groups. The limit is reset to ``None``
    afterwards, even on failure.
    """
    real_group_norm = vae_mod.F.group_norm
    set_norm_limit(1e-9)
    try:
        gn = groupnorm_cls(num_channels=_NUM_CHANNELS, num_groups=_NUM_GROUPS)
        gn.eval()

        forward_hook_calls = []

        def _hook(module, inputs, output):
            forward_hook_calls.append(tuple(inputs[0].shape))

        spy_calls = []

        def _group_norm_spy(*args, **kwargs):
            # Accept num_groups positionally or by keyword, mirroring
            # F.group_norm(input, num_groups, ...), so a keyword call is
            # recorded instead of raising TypeError inside the spy.
            num_groups_arg = kwargs["num_groups"] if "num_groups" in kwargs else args[1]
            spy_calls.append({"num_groups": int(num_groups_arg)})
            return real_group_norm(*args, **kwargs)

        # Local generator: deterministic input, global RNG left untouched.
        generator = torch.Generator(device="cpu").manual_seed(0)
        x = torch.randn(*_TENSOR_SHAPE, generator=generator)

        handle = gn.register_forward_hook(_hook)
        try:
            with patch.object(vae_mod.F, "group_norm", side_effect=_group_norm_spy):
                out_tensor = causal_norm_wrapper(gn, x)
        finally:
            handle.remove()

        full_calls = len(forward_hook_calls)
        chunked_calls = sum(1 for entry in spy_calls if entry["num_groups"] < _NUM_GROUPS)

        assert tuple(int(s) for s in out_tensor.shape) == _TENSOR_SHAPE
        assert full_calls == 0, (
            f"low-limit GroupNorm gate must NOT take the full-forward path; got full_calls={full_calls}"
        )
        assert chunked_calls > 0, (
            f"low-limit GroupNorm gate must take the chunked path; got chunked_calls={chunked_calls}"
        )
    finally:
        set_norm_limit(None)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# var_attention backend tests (test_seedvr_var_attention_backends.py)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def test_var_attention_registry_contains_always_available_entries():
    """The always-available var_attention backends are registered by name.

    Each registry key must map to the identically named function object in
    ``comfy.ldm.modules.attention``.
    """
    registry = attention.REGISTERED_ATTENTION_FUNCTIONS
    for backend_name in (
        "var_attention_pytorch",
        "var_attention_sub_quad",
        "var_attention_split",
    ):
        assert registry[backend_name] is getattr(attention, backend_name)
|
|
|
|
|
|
def test_seedvr2_7b_swin_attention_forward_uses_optimized_var_attention(monkeypatch):
    """NaSwinAttention must dispatch exactly once to optimized_var_attention.

    The backend is replaced by a stub that records its keyword arguments and
    echoes ``q``. With 8 video tokens (2x2x2), 3 text tokens and a (2, 1, 1)
    window, the expected packing is 14 tokens with offsets [0, 7, 14].
    Presumably that is two windows of 4 video + 3 text tokens — TODO confirm.
    """
    dim, heads, head_dim = 8, 2, 4
    attn = seedvr_model.NaSwinAttention(
        vid_dim=dim,
        txt_dim=dim,
        heads=heads,
        head_dim=head_dim,
        qk_bias=False,
        qk_norm=seedvr_model.CustomRMSNorm,
        qk_norm_eps=1e-6,
        rope_type=None,
        rope_dim=head_dim,
        shared_weights=False,
        window=(2, 1, 1),
        window_method="720pwin_by_size_bysize",
        version=True,
        device="cpu",
        dtype=torch.float32,
        operations=comfy_ops.disable_weight_init,
    )
    rng = torch.Generator(device="cpu").manual_seed(11)
    vid = torch.randn(8, dim, generator=rng)
    txt = torch.randn(3, dim, generator=rng)
    vid_shape = torch.tensor([[2, 2, 2]], dtype=torch.long)
    txt_shape = torch.tensor([[3]], dtype=torch.long)

    recorded = []

    def fake_optimized_var_attention(**kwargs):
        recorded.append(kwargs)
        return kwargs["q"]

    monkeypatch.setattr(seedvr_model, "optimized_var_attention", fake_optimized_var_attention)

    vid_out, txt_out = attn(vid, txt, vid_shape, txt_shape, seedvr_model.Cache(disable=True))

    assert tuple(vid_out.shape) == (8, dim)
    assert tuple(txt_out.shape) == (3, dim)
    assert len(recorded) == 1
    call = recorded[0]

    for tensor_name in ("q", "k", "v"):
        assert tuple(call[tensor_name].shape) == (14, heads, head_dim)
    assert call["heads"] == heads
    assert call["skip_reshape"] is True
    assert call["skip_output_reshape"] is True

    expected_offsets = torch.tensor([0, 7, 14], dtype=torch.int32)
    for offsets_name in ("cu_seqlens_q", "cu_seqlens_k"):
        torch.testing.assert_close(call[offsets_name], expected_offsets, rtol=0, atol=0)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# var_attention_pytorch SeedVR2 guard tests
|
|
# (test_var_attention_pytorch_seedvr2_guard.py)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _pytorch_guard_inputs():
|
|
heads, head_dim, total_tokens = 2, 8, 6
|
|
embed_dim = heads * head_dim
|
|
q = torch.randn(total_tokens, embed_dim)
|
|
k = torch.randn(total_tokens, embed_dim)
|
|
v = torch.randn(total_tokens, embed_dim)
|
|
cu = torch.tensor([0, 3, 6], dtype=torch.int32)
|
|
return q, k, v, heads, cu, cu, total_tokens, embed_dim
|
|
|
|
|
|
def _assert_guard_source_pin():
    """Statically pin that var_attention_pytorch raises its guard before the jagged lookup.

    Parses the dedented source of ``var_attention_pytorch`` and requires:
      * at least one ``raise RuntimeError(...)`` statement,
      * at least one ``.nested_tensor_from_jagged`` attribute access,
      * the earliest such raise to sit on an earlier line than the earliest access.

    Line numbers are relative to the extracted function source. Every
    failure message includes that source for diagnosis.
    """
    src = textwrap.dedent(inspect.getsource(var_attention_pytorch))
    nodes = list(ast.walk(ast.parse(src)))

    def _is_runtime_error_raise(node):
        if not isinstance(node, ast.Raise) or not isinstance(node.exc, ast.Call):
            return False
        callee = node.exc.func
        return isinstance(callee, ast.Name) and callee.id == "RuntimeError"

    raise_lines = [node.lineno for node in nodes if _is_runtime_error_raise(node)]
    nested_lines = [
        node.lineno
        for node in nodes
        if isinstance(node, ast.Attribute) and node.attr == "nested_tensor_from_jagged"
    ]

    assert raise_lines, (
        "var_attention_pytorch has no `raise RuntimeError(...)` AST node; "
        f"the SeedVR2-named guard is missing.\n--- source ---\n{src}"
    )
    assert nested_lines, (
        "var_attention_pytorch source has no `nested_tensor_from_jagged` "
        f"attribute access; cannot pin guard ordering.\n"
        f"--- source ---\n{src}"
    )

    first_raise, first_nested = min(raise_lines), min(nested_lines)
    assert first_raise < first_nested, (
        f"`raise RuntimeError(...)` first appears at line {first_raise}, "
        f"but `torch.nested.nested_tensor_from_jagged` is referenced first "
        f"at line {first_nested}; the guard must precede the lookup.\n"
        f"--- source ---\n{src}"
    )
|
|
|
|
|
|
def test_missing_api_raises_seedvr2_runtime_error(monkeypatch):
    """Without ``torch.nested.nested_tensor_from_jagged``, raise the SeedVR2 guard.

    The attribute is removed only for this test (monkeypatch restores it);
    ``raising=False`` tolerates builds that never had it. The AST pin then
    confirms the guard precedes the lookup in the source.
    """
    monkeypatch.delattr(torch.nested, "nested_tensor_from_jagged", raising=False)
    query, key, value, n_heads, offsets_q, offsets_k, _, _ = _pytorch_guard_inputs()
    expected_message = r"SeedVR2.*nested_tensor_from_jagged"

    with pytest.raises(RuntimeError, match=expected_message):
        var_attention_pytorch(query, key, value, n_heads, offsets_q, offsets_k)

    _assert_guard_source_pin()
|
|
|
|
|
|
def test_missing_namespace_raises_seedvr2_runtime_error(monkeypatch):
    """Without the whole ``torch.nested`` namespace, raise the SeedVR2 guard.

    ``torch.nested`` is removed only for this test (monkeypatch restores it).
    The AST pin then confirms the guard precedes the lookup in the source.
    """
    monkeypatch.delattr(torch, "nested", raising=False)
    query, key, value, n_heads, offsets_q, offsets_k, _, _ = _pytorch_guard_inputs()
    expected_message = r"SeedVR2.*nested_tensor_from_jagged"

    with pytest.raises(RuntimeError, match=expected_message):
        var_attention_pytorch(query, key, value, n_heads, offsets_q, offsets_k)

    _assert_guard_source_pin()
|
|
|
|
|
|
def test_present_api_returns_expected_shape():
    """With the jagged API present, var_attention_pytorch returns (tokens, embed_dim).

    Skips on torch builds lacking ``torch.nested.nested_tensor_from_jagged``.
    On those builds the function raises the SeedVR2 guard by design (covered
    by the missing-API tests), so this shape contract does not apply.

    The nested-tensor prototype UserWarning and the torch.fx symbolic-trace
    logger are quieted for the call only; the logger level is restored
    afterwards.
    """
    nested = getattr(torch, "nested", None)
    if getattr(nested, "nested_tensor_from_jagged", None) is None:
        pytest.skip("torch.nested.nested_tensor_from_jagged is not available in this torch build")

    q, k, v, heads, cu_q, cu_k, total_tokens, embed_dim = _pytorch_guard_inputs()

    torch_fx_logger = logging.getLogger("torch.fx._symbolic_trace")
    old_torch_fx_level = torch_fx_logger.level
    torch_fx_logger.setLevel(logging.ERROR)
    try:
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                message="The PyTorch API of nested tensors is in prototype stage.*",
                category=UserWarning,
            )
            out = var_attention_pytorch(q, k, v, heads, cu_q, cu_k)
    finally:
        torch_fx_logger.setLevel(old_torch_fx_level)

    assert tuple(out.shape) == (total_tokens, embed_dim), (
        f"expected ({total_tokens}, {embed_dim}); got {tuple(out.shape)}"
    )

    _assert_guard_source_pin()
|
|
|
|
|
|
def test_malformed_offsets_propagates_torch_runtime_error():
    """Malformed offsets must surface a RuntimeError that is not the SeedVR2 guard.

    ``cu_q_bad`` ends at 7 although only 6 query tokens exist. The resulting
    error is presumably raised by torch's nested-tensor construction, and its
    message must not mention SeedVR2.

    Skips when ``torch.nested.nested_tensor_from_jagged`` is absent. In that
    case the SeedVR2 guard error would be raised instead (see the
    missing-API tests), which would make this check fail spuriously.
    """
    nested = getattr(torch, "nested", None)
    if getattr(nested, "nested_tensor_from_jagged", None) is None:
        pytest.skip("torch.nested.nested_tensor_from_jagged is not available in this torch build")

    q, k, v, heads, _, _, _, _ = _pytorch_guard_inputs()
    cu_q_bad = torch.tensor([0, 3, 7], dtype=torch.int32)
    cu_k_ok = torch.tensor([0, 3, 6], dtype=torch.int32)

    with pytest.raises(RuntimeError) as exc_info:
        var_attention_pytorch(q, k, v, heads, cu_q_bad, cu_k_ok)

    msg = str(exc_info.value)
    assert "SeedVR2" not in msg

    _assert_guard_source_pin()
|