"""Consolidated SeedVR2 internals regression tests. Sources (all merged verbatim, helper names disambiguated where colliding): * RoPE rewrite — NaMMRotaryEmbedding3d.forward must match the legacy apply_rotary_emb wrapper oracle at fp32. * GroupNorm limit gate — causal_norm_wrapper at vae.py:509 must compare memory_occupy against get_norm_limit(), not float('inf'). * var_attention backend registry. * var_attention_pytorch SeedVR2-named guard — present-API shape contract with AST-level pinning of the guard ordering. Pre-import CPU-only guard is required because comfy.ldm.seedvr.model and comfy.ldm.modules.attention transitively pull in comfy.model_management, which probes torch.cuda.current_device() at import time unless args.cpu is set first. """ from __future__ import annotations import ast import inspect import logging import textwrap import warnings from unittest.mock import patch import pytest import torch from comfy.cli_args import args if not torch.cuda.is_available(): args.cpu = True import comfy.ldm.seedvr.model as seedvr_model # noqa: E402 import comfy.ldm.seedvr.vae as vae_mod # noqa: E402 import comfy.ldm.modules.attention as attention # noqa: E402 import comfy.ops as comfy_ops # noqa: E402 from comfy.ldm.seedvr.model import ( # noqa: E402 Cache, NaMMRotaryEmbedding3d, ) from comfy.ldm.seedvr.vae import ( # noqa: E402 causal_norm_wrapper, set_norm_limit, ) from comfy.ldm.modules.attention import var_attention_pytorch # noqa: E402 # --------------------------------------------------------------------------- # RoPE rewrite tests (test_seedvr_rope_rewrite.py) # --------------------------------------------------------------------------- # Test rig dimensions. dim=192 → per-axis rope dim = 64 (even, lucidrains # requirement). vid_shape=(2,4,4) → L_vid = 32. txt_shape=(8,) → L_txt = 8. 
_DIM = 192
_HEADS = 4
_VID_T, _VID_H, _VID_W = 2, 4, 4
_TXT_L = 8
_L_VID = _VID_T * _VID_H * _VID_W
_SEED = 0


def _make_inputs(dtype=torch.float32, device="cpu"):
    """Build the six tensors passed to the RoPE forward call, plus a cache.

    Every random draw comes from a private, fixed-seed ``torch.Generator`` so
    the values are reproducible and the global RNG state is left untouched.

    Returns:
        ``(vid_q, vid_k, vid_shape, txt_q, txt_k, txt_shape, cache)``.
    """
    rng = torch.Generator(device=device).manual_seed(_SEED)

    def _draw(num_tokens):
        return torch.randn(
            num_tokens, _HEADS, _DIM, dtype=dtype, device=device, generator=rng
        )

    # Draw order (vid_q, vid_k, txt_q, txt_k) decides which values each
    # tensor receives from the shared generator, so it must not change.
    vid_q, vid_k = _draw(_L_VID), _draw(_L_VID)
    txt_q, txt_k = _draw(_TXT_L), _draw(_TXT_L)

    vid_shape = torch.tensor(
        [[_VID_T, _VID_H, _VID_W]], dtype=torch.long, device=device
    )
    txt_shape = torch.tensor([[_TXT_L]], dtype=torch.long, device=device)
    return vid_q, vid_k, vid_shape, txt_q, txt_k, txt_shape, Cache(disable=True)


def _legacy_get_freqs(rope: NaMMRotaryEmbedding3d, vid_shape, txt_shape):
    """Recompute the pre-rewrite ``get_freqs`` result via ``rope.get_axial_freqs``.

    The axial tables are requested at the batch maxima plus a margin (16 on
    the temporal/text axis, 4 on height and width, capped at 1024 and 128),
    then sliced once per sample.

    Returns:
        ``(vid_freqs, txt_freqs)``, each the per-sample slices concatenated
        along dim 0.
    """
    samples = list(zip(vid_shape.tolist(), txt_shape[:, 0].tolist()))

    # Seed every maximum with 0, matching a running max that starts at zero.
    max_temporal = max(
        [0] + [txt_len + frames for (frames, _, _), txt_len in samples]
    )
    max_height = max([0] + [height for (_, height, _), _ in samples])
    max_width = max([0] + [width for (_, _, width), _ in samples])
    max_txt_len = max([0] + [txt_len for _, txt_len in samples])

    with torch.amp.autocast(device_type="cuda", enabled=False):
        vid_table = rope.get_axial_freqs(
            min(max_temporal + 16, 1024),
            min(max_height + 4, 128),
            min(max_width + 4, 128),
        ).float()
        txt_table = rope.get_axial_freqs(min(max_txt_len + 16, 1024))

    freq_dim = vid_table.size(-1)
    # Video frequencies start at temporal offset ``txt_len``.
    vid_parts = [
        vid_table[txt_len : txt_len + frames, :height, :width].reshape(-1, freq_dim)
        for (frames, height, width), txt_len in samples
    ]
    # Text frequencies are tiled three times along the last axis before being
    # reshaped to the video table's feature width.
    txt_parts = [
        txt_table[:txt_len].repeat(1, 3).reshape(-1, freq_dim)
        for _, txt_len in samples
    ]
    return torch.cat(vid_parts, dim=0), torch.cat(txt_parts, dim=0)


def _legacy_forward(rope: NaMMRotaryEmbedding3d, vid_q, vid_k, vid_shape, txt_q, txt_k, txt_shape):
    """Oracle: expected forward output via ``seedvr_model.apply_rotary_emb``.

    Feeds the legacy-shape frequencies from ``_legacy_get_freqs`` to the
    unchanged wrapper. The wrapper itself is out of scope for the rewrite
    (Shape B).

    Returns:
        ``(vid_q_out, vid_k_out, txt_q_out, txt_k_out)`` laid out as
        ``L h d``, like the inputs.
    """
    vid_freqs, txt_freqs = _legacy_get_freqs(rope, vid_shape, txt_shape)
    vid_freqs = vid_freqs.to(vid_q.device)
    txt_freqs = txt_freqs.to(txt_q.device)

    from einops import rearrange

    def _rotate(freqs, tokens):
        # The wrapper is applied in ``h L d`` layout on a float32 copy; the
        # result is cast back to the input dtype and returned as ``L h d``.
        per_head = rearrange(tokens, "L h d -> h L d")
        rotated = seedvr_model.apply_rotary_emb(freqs, per_head.float())
        return rearrange(rotated.to(per_head.dtype), "h L d -> L h d")

    return (
        _rotate(vid_freqs, vid_q),
        _rotate(vid_freqs, vid_k),
        _rotate(txt_freqs, txt_q),
        _rotate(txt_freqs, txt_k),
    )


def test_namm_forward_output_tensor_equal_against_legacy_oracle():
    """``NaMMRotaryEmbedding3d.forward`` must equal the legacy oracle exactly."""
    rope = NaMMRotaryEmbedding3d(dim=_DIM)
    vid_q, vid_k, vid_shape, txt_q, txt_k, txt_shape, cache = _make_inputs()

    # Both paths get their own clones so neither can see in-place edits made
    # by the other.
    expected = _legacy_forward(
        rope,
        vid_q.clone(),
        vid_k.clone(),
        vid_shape,
        txt_q.clone(),
        txt_k.clone(),
        txt_shape,
    )
    actual = rope.forward(
        vid_q.clone(),
        vid_k.clone(),
        vid_shape,
        txt_q.clone(),
        txt_k.clone(),
        txt_shape,
        cache,
    )

    labels = ("vid_q", "vid_k", "txt_q", "txt_k")
    for label, got, want in zip(labels, actual, expected):
        torch.testing.assert_close(
            got,
            want,
            rtol=0,
            atol=0,
            msg=f"{label} output diverges from wrapper oracle",
        )
# ---------------------------------------------------------------------------
# GroupNorm limit tests (test_seedvr_groupnorm_limit.py)
# ---------------------------------------------------------------------------

_NUM_CHANNELS = 8
_NUM_GROUPS = 4
_TENSOR_SHAPE = (1, 8, 2, 4, 4)

_GROUPNORM_SUBCLASSES = [
    pytest.param(comfy_ops.disable_weight_init.GroupNorm, id="disable_weight_init"),
    pytest.param(comfy_ops.manual_cast.GroupNorm, id="manual_cast"),
]


@pytest.mark.parametrize("groupnorm_cls", _GROUPNORM_SUBCLASSES)
def test_seedvr_groupnorm_low_limit_uses_chunked_groupnorm_path(groupnorm_cls):
    """With a tiny norm limit, ``causal_norm_wrapper`` must take the chunked path.

    Two probes tell the paths apart:

    * a forward hook on the GroupNorm module records every call of the module
      itself (the "full-forward" path);
    * a spy wrapped around ``vae_mod.F.group_norm`` records the ``num_groups``
      of each functional call; calls with fewer groups than ``_NUM_GROUPS``
      are counted as chunked.
    """
    real_group_norm = vae_mod.F.group_norm
    set_norm_limit(1e-9)
    try:
        gn = groupnorm_cls(num_channels=_NUM_CHANNELS, num_groups=_NUM_GROUPS)
        gn.eval()

        forward_hook_calls = []

        def _hook(module, inputs, output):
            forward_hook_calls.append(tuple(inputs[0].shape))

        spy_calls = []

        # ``extra_args`` rather than ``args``: the latter would shadow the
        # module-level ``args`` imported from ``comfy.cli_args``.
        def _group_norm_spy(input_tensor, num_groups_arg, *extra_args, **kwargs):
            spy_calls.append({"num_groups": int(num_groups_arg)})
            return real_group_norm(input_tensor, num_groups_arg, *extra_args, **kwargs)

        # Local seeded generator: reproducible input, global RNG untouched.
        generator = torch.Generator(device="cpu").manual_seed(0)
        sample = torch.randn(*_TENSOR_SHAPE, generator=generator)

        handle = gn.register_forward_hook(_hook)
        try:
            with patch.object(vae_mod.F, "group_norm", side_effect=_group_norm_spy):
                out_tensor = causal_norm_wrapper(gn, sample)
        finally:
            handle.remove()

        full_calls = len(forward_hook_calls)
        chunked_calls = sum(
            1 for entry in spy_calls if entry["num_groups"] < _NUM_GROUPS
        )

        assert tuple(int(s) for s in out_tensor.shape) == _TENSOR_SHAPE
        assert full_calls == 0, (
            f"low-limit GroupNorm gate must NOT take the full-forward path; got full_calls={full_calls}"
        )
        assert chunked_calls > 0, (
            f"low-limit GroupNorm gate must take the chunked path; got chunked_calls={chunked_calls}"
        )
    finally:
        # Clear the override so later tests do not inherit the tiny limit.
        set_norm_limit(None)


# ---------------------------------------------------------------------------
# var_attention backend tests (test_seedvr_var_attention_backends.py)
# ---------------------------------------------------------------------------


def test_var_attention_registry_contains_always_available_entries():
    """The three always-available var_attention backends must be registered.

    Each registry entry has to be the very function object exposed on the
    ``attention`` module under the same name.
    """
    registry = attention.REGISTERED_ATTENTION_FUNCTIONS
    assert registry["var_attention_pytorch"] is attention.var_attention_pytorch
    assert registry["var_attention_sub_quad"] is attention.var_attention_sub_quad
    assert registry["var_attention_split"] is attention.var_attention_split


def test_seedvr2_7b_swin_attention_forward_uses_optimized_var_attention(monkeypatch):
    """``NaSwinAttention`` must call ``optimized_var_attention`` exactly once.

    ``seedvr_model.optimized_var_attention`` is replaced by a fake that records
    its keyword arguments and returns ``q`` unchanged, so the assertions below
    pin the call contract (shapes, flags, cumulative sequence lengths) rather
    than numerical attention output.
    """
    dim = 8
    heads = 2
    head_dim = 4

    attn = seedvr_model.NaSwinAttention(
        vid_dim=dim,
        txt_dim=dim,
        heads=heads,
        head_dim=head_dim,
        qk_bias=False,
        qk_norm=seedvr_model.CustomRMSNorm,
        qk_norm_eps=1e-6,
        rope_type=None,
        rope_dim=head_dim,
        shared_weights=False,
        window=(2, 1, 1),
        window_method="720pwin_by_size_bysize",
        version=True,
        device="cpu",
        dtype=torch.float32,
        operations=comfy_ops.disable_weight_init,
    )

    generator = torch.Generator(device="cpu").manual_seed(11)
    vid = torch.randn(8, dim, generator=generator)
    txt = torch.randn(3, dim, generator=generator)
    vid_shape = torch.tensor([[2, 2, 2]], dtype=torch.long)
    txt_shape = torch.tensor([[3]], dtype=torch.long)

    calls = []

    def fake_optimized_var_attention(**kwargs):
        calls.append(kwargs)
        return kwargs["q"]

    monkeypatch.setattr(seedvr_model, "optimized_var_attention", fake_optimized_var_attention)

    vid_out, txt_out = attn(vid, txt, vid_shape, txt_shape, seedvr_model.Cache(disable=True))

    assert tuple(vid_out.shape) == (8, dim)
    assert tuple(txt_out.shape) == (3, dim)
    assert len(calls) == 1

    call = calls[0]
    # NOTE(review): 14 tokens in two sequences of 7 presumably means two
    # windows of 4 video tokens, each paired with the 3 text tokens; confirm
    # against the window partitioning in NaSwinAttention.
    assert tuple(call["q"].shape) == (14, heads, head_dim)
    assert tuple(call["k"].shape) == (14, heads, head_dim)
    assert tuple(call["v"].shape) == (14, heads, head_dim)
    assert call["heads"] == heads
    assert call["skip_reshape"] is True
    assert call["skip_output_reshape"] is True
    torch.testing.assert_close(
        call["cu_seqlens_q"],
        torch.tensor([0, 7, 14], dtype=torch.int32),
        rtol=0,
        atol=0,
    )
    torch.testing.assert_close(
        call["cu_seqlens_k"],
        torch.tensor([0, 7, 14], dtype=torch.int32),
        rtol=0,
        atol=0,
    )


# ---------------------------------------------------------------------------
# var_attention_pytorch SeedVR2 guard tests
# (test_var_attention_pytorch_seedvr2_guard.py)
# ---------------------------------------------------------------------------


def _pytorch_guard_inputs():
    """Return a small, reproducible argument set for ``var_attention_pytorch``.

    Random tensors come from a local, fixed-seed generator so every test sees
    the same values and the global RNG state is not consumed.

    Returns:
        ``(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, total_tokens,
        embed_dim)`` where ``q``/``k``/``v`` are ``(total_tokens, embed_dim)``
        tensors and both cumulative-length entries are the same
        ``[0, 3, 6]`` int32 tensor.
    """
    heads, head_dim, total_tokens = 2, 8, 6
    embed_dim = heads * head_dim
    generator = torch.Generator(device="cpu").manual_seed(0)
    q = torch.randn(total_tokens, embed_dim, generator=generator)
    k = torch.randn(total_tokens, embed_dim, generator=generator)
    v = torch.randn(total_tokens, embed_dim, generator=generator)
    cu = torch.tensor([0, 3, 6], dtype=torch.int32)
    return q, k, v, heads, cu, cu, total_tokens, embed_dim


def _assert_guard_source_pin():
    """Pin the guard ordering in ``var_attention_pytorch`` at the AST level.

    Parses the function's source and asserts that the first
    ``raise RuntimeError(...)`` statement is on an earlier line than the first
    ``nested_tensor_from_jagged`` attribute access. Also fails when either
    construct is absent from the source.
    """
    src = textwrap.dedent(inspect.getsource(var_attention_pytorch))
    tree = ast.parse(src)

    raise_lines = []
    nested_lines = []
    for node in ast.walk(tree):
        if isinstance(node, ast.Raise) and isinstance(node.exc, ast.Call):
            func = node.exc.func
            if isinstance(func, ast.Name) and func.id == "RuntimeError":
                raise_lines.append(node.lineno)
        if isinstance(node, ast.Attribute) and node.attr == "nested_tensor_from_jagged":
            nested_lines.append(node.lineno)

    assert raise_lines, (
        "var_attention_pytorch has no `raise RuntimeError(...)` AST node; "
        f"the SeedVR2-named guard is missing.\n--- source ---\n{src}"
    )
    assert nested_lines, (
        "var_attention_pytorch source has no `nested_tensor_from_jagged` "
        "attribute access; cannot pin guard ordering.\n"
        f"--- source ---\n{src}"
    )

    first_raise = min(raise_lines)
    first_nested = min(nested_lines)
    assert first_raise < first_nested, (
        f"`raise RuntimeError(...)` first appears at line {first_raise}, "
        f"but `torch.nested.nested_tensor_from_jagged` is referenced first "
        f"at line {first_nested}; the guard must precede the lookup.\n"
        f"--- source ---\n{src}"
    )


def test_missing_api_raises_seedvr2_runtime_error(monkeypatch):
    """Without ``torch.nested.nested_tensor_from_jagged`` the guard must fire.

    The raised ``RuntimeError`` message has to mention both ``SeedVR2`` and
    ``nested_tensor_from_jagged``.
    """
    monkeypatch.delattr(torch.nested, "nested_tensor_from_jagged", raising=False)
    q, k, v, heads, cu_q, cu_k, _, _ = _pytorch_guard_inputs()

    with pytest.raises(RuntimeError, match=r"SeedVR2.*nested_tensor_from_jagged"):
        var_attention_pytorch(q, k, v, heads, cu_q, cu_k)

    _assert_guard_source_pin()


def test_missing_namespace_raises_seedvr2_runtime_error(monkeypatch):
    """Without the whole ``torch.nested`` attribute the guard must still fire.

    Same message contract as the missing-API case.
    """
    monkeypatch.delattr(torch, "nested", raising=False)
    q, k, v, heads, cu_q, cu_k, _, _ = _pytorch_guard_inputs()

    with pytest.raises(RuntimeError, match=r"SeedVR2.*nested_tensor_from_jagged"):
        var_attention_pytorch(q, k, v, heads, cu_q, cu_k)

    _assert_guard_source_pin()


def test_present_api_returns_expected_shape():
    """With the API present, the output shape is ``(total_tokens, embed_dim)``.

    The ``torch.fx._symbolic_trace`` logger is raised to ERROR and the
    nested-tensor prototype ``UserWarning`` is filtered for the duration of
    the call; the logger level is restored afterwards.
    """
    q, k, v, heads, cu_q, cu_k, total_tokens, embed_dim = _pytorch_guard_inputs()

    torch_fx_logger = logging.getLogger("torch.fx._symbolic_trace")
    old_torch_fx_level = torch_fx_logger.level
    torch_fx_logger.setLevel(logging.ERROR)
    try:
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                message="The PyTorch API of nested tensors is in prototype stage.*",
                category=UserWarning,
            )
            out = var_attention_pytorch(q, k, v, heads, cu_q, cu_k)
    finally:
        torch_fx_logger.setLevel(old_torch_fx_level)

    assert tuple(out.shape) == (total_tokens, embed_dim), (
        f"expected ({total_tokens}, {embed_dim}); got {tuple(out.shape)}"
    )
    _assert_guard_source_pin()


def test_malformed_offsets_propagates_torch_runtime_error():
    """Bad offsets must surface as a ``RuntimeError`` not branded ``SeedVR2``.

    ``cu_seqlens_q`` ends at 7 while only 6 tokens exist; the resulting error
    must not carry the SeedVR2 guard message, i.e. the guard is reserved for
    the missing-API condition.
    """
    q, k, v, heads, _, _, _, _ = _pytorch_guard_inputs()
    cu_q_bad = torch.tensor([0, 3, 7], dtype=torch.int32)
    cu_k_ok = torch.tensor([0, 3, 6], dtype=torch.int32)

    with pytest.raises(RuntimeError) as exc_info:
        var_attention_pytorch(q, k, v, heads, cu_q_bad, cu_k_ok)

    msg = str(exc_info.value)
    assert "SeedVR2" not in msg
    _assert_guard_source_pin()