Add SeedVR2 core coverage

This commit is contained in:
John Pollock 2026-05-25 22:12:12 -05:00
parent 6e5186ddac
commit 9eb6c7fe9e
15 changed files with 2727 additions and 0 deletions

View File

@ -73,6 +73,24 @@ def _make_flux_schnell_comfyui_sd():
return sd return sd
def _make_seedvr2_7b_separate_mm_sd():
return {
"blocks.35.mlp.vid.proj_in.weight": torch.empty(1, 3072),
}
def _make_seedvr2_7b_shared_mm_sd():
return {
"blocks.35.mlp.all.proj_in_gate.weight": torch.empty(1, 1),
}
def _make_seedvr2_3b_shared_mm_sd():
return {
"blocks.31.mlp.all.proj_in_gate.weight": torch.empty(1, 1),
}
class TestModelDetection: class TestModelDetection:
"""Verify that first-match model detection selects the correct model """Verify that first-match model detection selects the correct model
based on list ordering and unet_config specificity.""" based on list ordering and unet_config specificity."""
@ -125,6 +143,46 @@ class TestModelDetection:
assert model_config is not None assert model_config is not None
assert type(model_config).__name__ == "FluxSchnell" assert type(model_config).__name__ == "FluxSchnell"
def test_seedvr2_7b_separate_mm_detection_config(self):
sd = _make_seedvr2_7b_separate_mm_sd()
unet_config = detect_unet_config(sd, "")
assert unet_config is not None
assert unet_config["image_model"] == "seedvr2"
assert unet_config["vid_dim"] == 3072
assert unet_config["heads"] == 24
assert unet_config["num_layers"] == 36
assert unet_config["mm_layers"] == 36
assert unet_config["mlp_type"] == "normal"
assert unet_config["qk_rope"] is True
assert unet_config["rope_type"] == "rope3d"
assert unet_config["rope_dim"] == 64
def test_seedvr2_7b_shared_mm_detection_config(self):
sd = _make_seedvr2_7b_shared_mm_sd()
unet_config = detect_unet_config(sd, "")
assert unet_config is not None
assert unet_config["image_model"] == "seedvr2"
assert unet_config["vid_dim"] == 3072
assert unet_config["heads"] == 24
assert unet_config["num_layers"] == 36
assert unet_config["mm_layers"] == 10
assert unet_config["mlp_type"] == "swiglu"
assert unet_config["qk_rope"] is True
def test_seedvr2_3b_shared_mm_detection_config(self):
sd = _make_seedvr2_3b_shared_mm_sd()
unet_config = detect_unet_config(sd, "")
assert unet_config is not None
assert unet_config["image_model"] == "seedvr2"
assert unet_config["vid_dim"] == 2560
assert unet_config["heads"] == 20
assert unet_config["num_layers"] == 32
assert unet_config["mlp_type"] == "swiglu"
assert unet_config["qk_rope"] is None
def test_unet_config_and_required_keys_combination_is_unique(self): def test_unet_config_and_required_keys_combination_is_unique(self):
"""Each model in the registry must have a unique combination of """Each model in the registry must have a unique combination of
``unet_config`` and ``required_keys``. If two models share the same ``unet_config`` and ``required_keys``. If two models share the same

View File

@ -0,0 +1,192 @@
"""Regression tests for SeedVR2 conditioning split hardening.
Two bare ``except:`` clauses in ``NaDiT.forward`` previously swallowed
every failure mode on (1) the input-side text-conditioning split and
(2) the output-side positive/negative split, silently substituting
wrong fallbacks: the ``positive_conditioning`` buffer (which prior to
explicit zero-init held **uninitialized** memory NaN, residual heap
contents, never guaranteed-zero) for the input, and the un-split
tensor for the output. Real prompt-shape, dtype, OOM, and downstream
tensor failures were re-routed to "no prompt supplied" with arbitrary
buffer contents standing in for actual prompt embeddings, or to a
wrong-order output, with no diagnostic.
The fix:
1. Input-side: explicit absence predicate (``context is None`` or
``context.numel() == 0``) fall back to ``positive_conditioning``
buffer. Any other failure (wrong rank, odd batch, dtype, OOM)
propagates the original torch exception.
2. Output-side: no try/except at all. ``out.chunk(2)`` of the
network output is a contract: an unsplittable result is a bug,
not a recoverable condition.
The two blocks were extracted into named private methods on
``NaDiT`` (``_resolve_text_conditioning`` and ``_swap_pos_neg_halves``)
so the regression evidence drives the actual production code paths
without standing up a full transformer. The methods are called from
``forward`` exactly where the original try/except blocks lived.
"""
from comfy.cli_args import args
import torch
if not torch.cuda.is_available():
args.cpu = True
import ast # noqa: E402
import inspect # noqa: E402
import textwrap # noqa: E402
import pytest # noqa: E402
from comfy.ldm.seedvr.model import NaDiT # noqa: E402
def _make_standin(positive_conditioning):
class _StandIn(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer(
"positive_conditioning", positive_conditioning
)
_resolve_text_conditioning = NaDiT._resolve_text_conditioning
_swap_pos_neg_halves = NaDiT._swap_pos_neg_halves
return _StandIn()
def test_no_bare_except_in_forward_path():
"""Source-level pin: neither ``NaDiT.forward`` nor its split helpers
may carry the bare ``except:`` clauses that swallowed real torch
failures on the SeedVR2 conditioning paths. AST-walked rather than
substring-matched so that ``except:`` appearing in a docstring or
comment does not false-positive, and so that ``except Exception:``
(a typed handler, fine to have) does not false-negative.
"""
sources = [
inspect.getsource(NaDiT.forward),
inspect.getsource(NaDiT._resolve_text_conditioning),
inspect.getsource(NaDiT._swap_pos_neg_halves),
]
for src in sources:
tree = ast.parse(textwrap.dedent(src))
for node in ast.walk(tree):
if isinstance(node, ast.ExceptHandler):
assert node.type is not None, (
"Bare 'except:' (ast.ExceptHandler with type=None) "
f"must not appear on the SeedVR2 forward path:\n{src}"
)
def test_valid_context_splits_pos_neg():
"""AC: valid (neg, pos)-stacked context (shape ``(2, L, C)``)
produces a flattened ``[pos, neg]`` text tensor first ``L`` rows
are positive, next ``L`` rows are negative matching the original
semantics of the ``flatten([pos_cond, neg_cond])`` call.
"""
pos_buffer = torch.zeros((58, 5120))
standin = _make_standin(pos_buffer)
seq_len, channels = 7, 5120
neg = torch.full((1, seq_len, channels), -1.0)
pos = torch.full((1, seq_len, channels), 1.0)
context = torch.cat([neg, pos], dim=0)
txt, txt_shape = standin._resolve_text_conditioning(context)
assert txt.shape == (2 * seq_len, channels)
assert (txt[:seq_len] == 1.0).all(), "first half must be positive cond"
assert (txt[seq_len:] == -1.0).all(), "second half must be negative cond"
assert txt_shape.shape == (2, 1)
assert txt_shape[0].item() == seq_len
assert txt_shape[1].item() == seq_len
def test_missing_context_falls_back_to_positive_buffer():
"""AC: ``context is None`` falls back to the registered
``positive_conditioning`` buffer and runs to completion no
silent zero substitution, no raised exception.
"""
pos_buffer = torch.full((58, 5120), 7.0)
standin = _make_standin(pos_buffer)
txt, txt_shape = standin._resolve_text_conditioning(None)
assert txt.shape == (58, 5120)
assert (txt == 7.0).all(), (
"fallback path must use the positive_conditioning buffer "
"verbatim, not a zero tensor"
)
assert txt_shape.shape == (1, 1)
assert txt_shape[0, 0].item() == 58
def test_empty_context_falls_back_to_positive_buffer():
"""AC: ``context.numel() == 0`` falls back to the registered
``positive_conditioning`` buffer and runs to completion.
"""
pos_buffer = torch.full((58, 5120), 13.0)
standin = _make_standin(pos_buffer)
empty = torch.empty((0, 5120))
assert empty.numel() == 0
txt, txt_shape = standin._resolve_text_conditioning(empty)
assert txt.shape == (58, 5120)
assert (txt == 13.0).all()
assert txt_shape.shape == (1, 1)
assert txt_shape[0, 0].item() == 58
def test_wrong_rank_context_raises_original_torch_exception():
"""AC: a 1-D context tensor cannot be split into ``[pos, neg]``
via the ``chunk + squeeze + flatten`` chain; the original torch
exception must propagate rather than silently falling back.
"""
pos_buffer = torch.zeros((58, 5120))
standin = _make_standin(pos_buffer)
bad = torch.zeros(10)
with pytest.raises((RuntimeError, IndexError, ValueError)):
standin._resolve_text_conditioning(bad)
def test_odd_batch_context_raises_original_exception():
"""AC: a context whose batch dim cannot be split into two equal
chunks (here batch=1 so ``chunk(2, dim=0)`` returns a single
tensor) must propagate the original exception no silent fallback.
"""
pos_buffer = torch.zeros((58, 5120))
standin = _make_standin(pos_buffer)
bad = torch.zeros((1, 7, 5120))
with pytest.raises((RuntimeError, ValueError)):
standin._resolve_text_conditioning(bad)
def test_output_side_misshaped_tensor_raises():
"""AC: the post-network output split must raise on an unsplittable
tensor (no silent return of the un-split tensor in the wrong
order/shape). Here a batch=1 tensor cannot be ``chunk(2, dim=0)``
into two halves; ``pos, neg = out.chunk(2, dim=0)`` raises on
unpacking matching the production helper's explicit-dim contract
(``_swap_pos_neg_halves`` calls ``chunk(2, dim=0)`` and
``torch.cat(..., dim=0)``).
"""
pos_buffer = torch.zeros((58, 5120))
standin = _make_standin(pos_buffer)
bad_out = torch.zeros((1, 4, 8, 8))
with pytest.raises((RuntimeError, ValueError)):
standin._swap_pos_neg_halves(bad_out)
def test_output_side_swaps_pos_neg_halves():
"""AC complement: ``_swap_pos_neg_halves`` reorders the post-network
output so the first half (positive) and second half (negative) trade
places. For a 2-batch tensor with distinguishable halves, the
returned tensor must be the swap first half becomes negative,
second half becomes positive matching the original
``torch.cat([neg, pos])`` semantics from the pre-fix forward path.
"""
pos_buffer = torch.zeros((58, 5120))
standin = _make_standin(pos_buffer)
pos_half = torch.full((1, 4, 8, 8), 1.0)
neg_half = torch.full((1, 4, 8, 8), -1.0)
out = torch.cat([pos_half, neg_half], dim=0)
swapped = standin._swap_pos_neg_halves(out)
assert swapped.shape == out.shape
assert (swapped[0] == -1.0).all(), "first half of swapped output must be the original negative half"
assert (swapped[1] == 1.0).all(), "second half of swapped output must be the original positive half"

View File

@ -0,0 +1,124 @@
"""Regression: ``comfy.ldm.seedvr.vae.VideoAutoencoderKL.forward`` must
honor the actual tensor/tuple return contract of ``encode()`` and
``decode_()`` and must NOT dereference diffusers-style ``.latent_dist``
or ``.sample`` attributes on those returns.
The pre-fix body raised ``AttributeError: 'Tensor' object has no
attribute 'latent_dist'`` for ``mode in {"encode", "all"}`` and
``AttributeError: 'VideoAutoencoderKL' object has no attribute 'decode'``
for ``mode == "decode"`` (the class only defines ``decode_`` with a
trailing underscore). The post-fix body unwraps the optional one-element
tuple shape that ``return_dict=False`` produces and returns the tensor
directly.
Tests construct a stub subclass of ``VideoAutoencoderKL`` that bypasses
the heavy ``__init__`` via ``torch.nn.Module.__init__(self)`` and
overrides ``encode``/``decode_`` with known tensors so the contract can
be probed without loading any real VAE weights.
"""
import inspect
import re
import torch
import torch.nn as nn
from comfy.cli_args import args as cli_args
if not torch.cuda.is_available():
cli_args.cpu = True
from comfy.ldm.seedvr.vae import VideoAutoencoderKL # noqa: E402
_LATENT_SHAPE = (1, 16, 2, 2, 2)
_DECODED_SHAPE = (1, 3, 5, 16, 16)
_INPUT_ENCODE_SHAPE = (1, 3, 5, 16, 16)
_INPUT_DECODE_SHAPE = (1, 16, 2, 2, 2)
class _StubVAE(VideoAutoencoderKL):
def __init__(self):
nn.Module.__init__(self)
self._encode_out = torch.zeros(*_LATENT_SHAPE)
self._decode_out = torch.zeros(*_DECODED_SHAPE)
def encode(self, x, return_dict=True):
return self._encode_out
def decode_(self, z, return_dict=True):
return self._decode_out
def test_forward_encode_returns_tensor():
vae = _StubVAE()
x = torch.zeros(*_INPUT_ENCODE_SHAPE)
result = vae.forward(x, mode="encode")
assert type(result) is torch.Tensor
assert result.shape == torch.Size(_LATENT_SHAPE)
def test_forward_decode_returns_tensor():
vae = _StubVAE()
z = torch.zeros(*_INPUT_DECODE_SHAPE)
result = vae.forward(z, mode="decode")
assert type(result) is torch.Tensor
assert result.shape == torch.Size(_DECODED_SHAPE)
def test_forward_all_returns_tensor():
vae = _StubVAE()
x = torch.zeros(*_INPUT_ENCODE_SHAPE)
result = vae.forward(x, mode="all")
assert type(result) is torch.Tensor
assert result.shape == torch.Size(_DECODED_SHAPE)
def test_forward_source_has_no_diffusers_attr_access():
src = inspect.getsource(VideoAutoencoderKL.forward)
assert ".latent_dist" not in src
assert ".sample" not in src
assert re.search(r"self\.decode\(", src) is None
class _TupleReturningStubVAE(VideoAutoencoderKL):
"""Stub variant whose ``encode``/``decode_`` return the
``(tensor,)`` one-element tuple shape ``return_dict=False`` produces
in the parent class. Exercises the unwrap branch of
``VideoAutoencoderKL.forward``.
"""
def __init__(self):
nn.Module.__init__(self)
self._encode_tensor = torch.zeros(*_LATENT_SHAPE)
self._decode_tensor = torch.zeros(*_DECODED_SHAPE)
def encode(self, x, return_dict=True):
return (self._encode_tensor,)
def decode_(self, z, return_dict=True):
return (self._decode_tensor,)
def test_forward_encode_unwraps_one_tuple():
vae = _TupleReturningStubVAE()
x = torch.zeros(*_INPUT_ENCODE_SHAPE)
result = vae.forward(x, mode="encode")
assert type(result) is torch.Tensor
assert result.shape == torch.Size(_LATENT_SHAPE)
def test_forward_decode_unwraps_one_tuple():
vae = _TupleReturningStubVAE()
z = torch.zeros(*_INPUT_DECODE_SHAPE)
result = vae.forward(z, mode="decode")
assert type(result) is torch.Tensor
assert result.shape == torch.Size(_DECODED_SHAPE)
def test_forward_all_unwraps_one_tuple_at_each_step():
vae = _TupleReturningStubVAE()
x = torch.zeros(*_INPUT_ENCODE_SHAPE)
result = vae.forward(x, mode="all")
assert type(result) is torch.Tensor
assert result.shape == torch.Size(_DECODED_SHAPE)

View File

@ -0,0 +1,63 @@
import inspect
import torch
import torch.nn as nn
from comfy.cli_args import args as cli_args
if not torch.cuda.is_available():
cli_args.cpu = True
import comfy.ldm.seedvr.vae as vae_mod # noqa: E402
VideoAutoencoderKLWrapper = vae_mod.VideoAutoencoderKLWrapper
_INPUT_SHAPE = (1, 3, 5, 16, 16)
_POSTERIOR_SHAPE = (1, 16, 1, 2, 2)
_DECODE_OUT_SHAPE = (1, 3, 5, 16, 16)
def _build_wrapper_standin() -> VideoAutoencoderKLWrapper:
wrapper = VideoAutoencoderKLWrapper.__new__(VideoAutoencoderKLWrapper)
nn.Module.__init__(wrapper)
return wrapper
def test_wrapper_forward_returns_tensor_triple(monkeypatch):
wrapper = _build_wrapper_standin()
wrapper.original_image_video = torch.zeros(*_INPUT_SHAPE)
wrapper.img_dims = (16, 16)
wrapper.freeze_encoder = True
posterior = torch.full(_POSTERIOR_SHAPE, 7.0)
decode_out = torch.full(_DECODE_OUT_SHAPE, 13.0)
def stub_encode(self, x, orig_dims=None):
return posterior.squeeze(2), posterior
def stub_decode(self, z):
return decode_out
monkeypatch.setattr(VideoAutoencoderKLWrapper, "encode", stub_encode)
monkeypatch.setattr(VideoAutoencoderKLWrapper, "decode", stub_decode)
x = torch.zeros(*_INPUT_SHAPE)
result = wrapper.forward(x)
assert isinstance(result, tuple)
assert len(result) == 3
x_out, z, p = result
assert type(x_out) is torch.Tensor
assert type(z) is torch.Tensor
assert type(p) is torch.Tensor
assert x_out.shape == decode_out.shape
assert z.shape == posterior.squeeze(2).shape
assert torch.equal(x_out, decode_out)
assert torch.equal(z, posterior.squeeze(2))
assert p is posterior
def test_wrapper_forward_source_has_no_sample_access():
src = inspect.getsource(VideoAutoencoderKLWrapper.forward)
assert ".sample" not in src

View File

@ -0,0 +1,105 @@
"""Regression tests for the diffusers-format guard inside ``comfy.sd.VAE.__init__``.
The guard previously indexed ``metadata["keep_diffusers_format"]`` directly,
raising ``KeyError`` when ``metadata`` was non-``None`` but lacked that key. The
fixed guard uses ``metadata.get("keep_diffusers_format") != "true"``: a missing
key flows through to ``convert_vae_state_dict``; the explicit ``"true"`` value
bypasses it.
Five cells exercise every reachable shape of the guard input missing key,
explicit ``"true"``, ``None``, explicit non-``"true"``, empty dict and halt
the constructor at the first post-guard call (``model_management.is_amd``).
``_make_standin`` borrows ``__init__`` onto a bare class, mirroring
``seedvr_model_test.py::_make_standin`` (#109). ``_exercise_guard`` single-
sources the patched-constructor harness so the cells stay synchronised.
"""
from comfy.cli_args import args
import torch
if not torch.cuda.is_available():
args.cpu = True
import contextlib # noqa: E402
import unittest.mock # noqa: E402
import comfy.sd # noqa: E402
_DIFFUSERS_TRIGGER_KEY = "decoder.up_blocks.0.resnets.0.norm1.weight"
class _PostGuardReached(Exception):
"""Sentinel raised by the patched ``is_amd`` to halt ``__init__`` at the first post-guard statement."""
def _make_standin():
class _StandIn:
__init__ = comfy.sd.VAE.__init__
return _StandIn
def _exercise_guard(metadata):
"""Drive ``VAE.__init__`` with the diffusers trigger key and the supplied
``metadata``; halt at ``is_amd``. Returns ``(mock_convert, mock_is_amd)``
for branch (call_count) + reach (called) assertions per cell.
"""
StandIn = _make_standin()
sd = {_DIFFUSERS_TRIGGER_KEY: torch.zeros(1)}
with unittest.mock.patch.object(
comfy.sd.diffusers_convert,
"convert_vae_state_dict",
autospec=True,
side_effect=lambda state_dict: state_dict,
) as mock_convert, unittest.mock.patch.object(
comfy.sd.model_management,
"is_amd",
autospec=True,
side_effect=_PostGuardReached("post-guard reached"),
) as mock_is_amd:
with contextlib.suppress(_PostGuardReached):
StandIn(sd=sd, metadata=metadata)
return mock_convert, mock_is_amd
def test_diffusers_guard_invokes_convert_when_metadata_missing_key():
"""AC1: metadata is non-None but lacks ``keep_diffusers_format`` → convert is invoked."""
mock_convert, mock_is_amd = _exercise_guard({"unrelated_key": "value"})
assert mock_convert.call_count == 1
assert mock_is_amd.called
def test_diffusers_guard_skips_convert_when_metadata_pins_keep_true():
"""AC2: metadata pins ``keep_diffusers_format == "true"`` → convert is skipped."""
mock_convert, mock_is_amd = _exercise_guard({"keep_diffusers_format": "true"})
assert mock_convert.call_count == 0
assert mock_is_amd.called
def test_diffusers_guard_invokes_convert_when_metadata_is_none():
"""AC3: metadata is ``None`` → first disjunct fires, convert is invoked."""
mock_convert, mock_is_amd = _exercise_guard(None)
assert mock_convert.call_count == 1
assert mock_is_amd.called
def test_diffusers_guard_invokes_convert_when_metadata_pins_keep_false():
"""AC4: metadata pins a non-``"true"`` value → second disjunct fires, convert is invoked."""
mock_convert, mock_is_amd = _exercise_guard({"keep_diffusers_format": "false"})
assert mock_convert.call_count == 1
assert mock_is_amd.called
def test_diffusers_guard_invokes_convert_when_metadata_is_empty_dict():
"""AC5: metadata is ``{}`` (the ``convert_old_quants`` None→{} normalization shape) → convert is invoked."""
mock_convert, mock_is_amd = _exercise_guard({})
assert mock_convert.call_count == 1
assert mock_is_amd.called

View File

@ -0,0 +1,503 @@
import inspect
import logging
import warnings
from pathlib import Path
from types import SimpleNamespace
import torch
from comfy.cli_args import args as cli_args
if not torch.cuda.is_available():
cli_args.cpu = True
import comfy.ldm.modules.attention as attention
import comfy.sd
import comfy.supported_models
import comfy.ldm.seedvr.model as seedvr_model
def test_set_model_config_inference_dtype_preserves_legacy_signature():
calls = []
class LegacyConfig:
def set_inference_dtype(self, dtype, manual_cast_dtype):
calls.append((dtype, manual_cast_dtype))
comfy.sd._set_model_config_inference_dtype(LegacyConfig(), torch.float16, None, object())
assert calls == [(torch.float16, None)]
def test_set_model_config_inference_dtype_passes_device_when_supported():
calls = []
device = object()
class DeviceAwareConfig:
def set_inference_dtype(self, dtype, manual_cast_dtype, device=None):
calls.append((dtype, manual_cast_dtype, device))
comfy.sd._set_model_config_inference_dtype(DeviceAwareConfig(), torch.float16, None, device)
assert calls == [(torch.float16, None, device)]
def test_set_model_config_inference_dtype_passes_device_to_kwargs_override():
calls = []
device = object()
class KwargsConfig:
def set_inference_dtype(self, dtype, manual_cast_dtype, **kwargs):
calls.append((dtype, manual_cast_dtype, kwargs))
comfy.sd._set_model_config_inference_dtype(KwargsConfig(), torch.float16, None, device)
assert calls == [(torch.float16, None, {"device": device})]
def test_seedvr2_fp16_manual_cast_only_for_bf16_device(monkeypatch):
bf16_device = object()
fp16_device = object()
monkeypatch.setattr(
comfy.supported_models.comfy.model_management,
"should_use_bf16",
lambda device=None: device is bf16_device,
)
bf16_config = comfy.supported_models.SeedVR2({"image_model": "seedvr2"})
bf16_config.set_inference_dtype(torch.float16, None, device=bf16_device)
assert bf16_config.manual_cast_dtype is torch.bfloat16
fp16_config = comfy.supported_models.SeedVR2({"image_model": "seedvr2"})
fp16_config.set_inference_dtype(torch.float16, None, device=fp16_device)
assert fp16_config.manual_cast_dtype is None
def test_apply_rope1_partial_preserves_full_rotation_input_dtype(monkeypatch):
def fake_apply_rope1(t, freqs_cis):
return t.float() + 1.0
monkeypatch.setattr(seedvr_model, "apply_rope1", fake_apply_rope1)
t = torch.arange(8, dtype=torch.float16).reshape(1, 2, 4)
original = t.clone()
freqs_cis = torch.zeros(1, 2, 2, 2)
out = seedvr_model._apply_rope1_partial(t, freqs_cis)
assert out.dtype is torch.float16
torch.testing.assert_close(out, (original.float() + 1.0).to(torch.float16))
def test_apply_rope1_partial_preserves_partial_rotation_input_dtype(monkeypatch):
def fake_apply_rope1(t, freqs_cis):
return t.float() + 1.0
monkeypatch.setattr(seedvr_model, "apply_rope1", fake_apply_rope1)
t = torch.arange(12, dtype=torch.float16).reshape(1, 2, 6)
original = t.clone()
freqs_cis = torch.zeros(1, 2, 2, 2)
out = seedvr_model._apply_rope1_partial(t, freqs_cis)
assert out.dtype is torch.float16
torch.testing.assert_close(
out[..., :4],
(original[..., :4].float() + 1.0).to(torch.float16),
)
torch.testing.assert_close(out[..., 4:], original[..., 4:])
def test_apply_rope1_partial_chunks_sequence_dimension(monkeypatch):
calls = []
def fake_apply_rope1(t, freqs_cis):
calls.append(t.shape[-2])
return t.float() + 1.0
monkeypatch.setattr(seedvr_model, "apply_rope1", fake_apply_rope1)
monkeypatch.setattr(seedvr_model, "_ROPE1_PARTIAL_CHUNK_TOKENS", 2)
t = torch.arange(30, dtype=torch.float16).reshape(1, 5, 6)
original = t.clone()
freqs_cis = torch.zeros(5, 2, 2, 2)
out = seedvr_model._apply_rope1_partial(t, freqs_cis)
assert calls == [2, 2, 1]
torch.testing.assert_close(out[..., :4], (original[..., :4].float() + 1.0).to(torch.float16))
torch.testing.assert_close(out[..., 4:], original[..., 4:])
def test_apply_rope1_partial_clones_training_tensor(monkeypatch):
def fake_apply_rope1(t, freqs_cis):
return t + 1.0
monkeypatch.setattr(seedvr_model, "apply_rope1", fake_apply_rope1)
base = torch.arange(12, dtype=torch.float32, requires_grad=True)
t = base.reshape(1, 2, 6)
original = t.clone()
freqs_cis = torch.zeros(2, 2, 2, 2)
out = seedvr_model._apply_rope1_partial(t, freqs_cis)
out.sum().backward()
assert out is not t
torch.testing.assert_close(t, original)
torch.testing.assert_close(out[..., :4], original[..., :4] + 1.0)
torch.testing.assert_close(out[..., 4:], original[..., 4:])
assert base.grad is not None
def test_seedvr2_text_conditioning_accepts_cfg1_single_branch():
context = torch.arange(6, dtype=torch.float32).reshape(1, 3, 2)
txt, txt_shape = seedvr_model.NaDiT._resolve_text_conditioning(object(), context, [0])
torch.testing.assert_close(txt, context.squeeze(0))
torch.testing.assert_close(txt_shape, torch.tensor([[3]], device=context.device))
def test_seedvr2_text_conditioning_accepts_batched_cfg1_single_branch():
context = torch.arange(12, dtype=torch.float32).reshape(2, 3, 2)
txt, txt_shape = seedvr_model.NaDiT._resolve_text_conditioning(object(), context, [0])
torch.testing.assert_close(txt, context.flatten(0, -2))
torch.testing.assert_close(txt_shape, torch.tensor([[3], [3]], device=context.device))
def test_seedvr2_text_conditioning_accepts_multi_entry_cfg1_single_branch():
context = torch.arange(12, dtype=torch.float32).reshape(2, 3, 2)
txt, txt_shape = seedvr_model.NaDiT._resolve_text_conditioning(object(), context, [0, 0])
torch.testing.assert_close(txt, context.flatten(0, -2))
torch.testing.assert_close(txt_shape, torch.tensor([[3], [3]], device=context.device))
def test_seedvr2_text_conditioning_preserves_two_branch_swap_contract():
neg = torch.full((1, 3, 2), -1.0)
pos = torch.full((1, 3, 2), 1.0)
context = torch.cat([neg, pos], dim=0)
txt, txt_shape = seedvr_model.NaDiT._resolve_text_conditioning(object(), context)
torch.testing.assert_close(txt[:3], pos.squeeze(0))
torch.testing.assert_close(txt[3:], neg.squeeze(0))
torch.testing.assert_close(txt_shape, torch.tensor([[3], [3]], device=context.device))
def test_seedvr2_text_conditioning_preserves_batched_two_branch_swap_contract():
neg = torch.full((2, 3, 2), -1.0)
pos = torch.full((2, 3, 2), 1.0)
context = torch.cat([neg, pos], dim=0)
txt, txt_shape = seedvr_model.NaDiT._resolve_text_conditioning(object(), context, [1, 0])
torch.testing.assert_close(txt[:6], pos.flatten(0, -2))
torch.testing.assert_close(txt[6:], neg.flatten(0, -2))
torch.testing.assert_close(txt_shape, torch.tensor([[3], [3], [3], [3]], device=context.device))
def test_seedvr2_cfg1_single_branch_output_is_not_swapped():
out = torch.arange(6, dtype=torch.float32).reshape(1, 6)
swapped = seedvr_model.NaDiT._swap_pos_neg_halves(object(), out, [0])
torch.testing.assert_close(swapped, out)
def test_seedvr2_multi_entry_cfg1_output_is_not_swapped():
out = torch.arange(12, dtype=torch.float32).reshape(2, 6)
swapped = seedvr_model.NaDiT._swap_pos_neg_halves(object(), out, [0, 0])
torch.testing.assert_close(swapped, out)
def test_seedvr2_conditioning_keeps_comfy_cfg1_optimization_enabled():
source = (Path(__file__).resolve().parents[2] / "comfy_extras" / "nodes_seedvr.py").read_text(encoding="utf-8")
assert "disable_model_cfg1_optimization()" not in source
def test_seedvr2_split_var_attention_matches_nested_var_attention():
torch.manual_seed(1)
q = torch.randn(5, 2, 4)
k = torch.randn(7, 2, 4)
v = torch.randn(7, 2, 4)
cu_q = torch.tensor([0, 2, 5], dtype=torch.int32)
cu_k = torch.tensor([0, 3, 7], dtype=torch.int32)
torch_fx_logger = logging.getLogger("torch.fx._symbolic_trace")
old_torch_fx_level = torch_fx_logger.level
torch_fx_logger.setLevel(logging.ERROR)
try:
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message="The PyTorch API of nested tensors is in prototype stage.*",
category=UserWarning,
)
nested = attention.var_attention_pytorch(
q, k, v, heads=2, cu_seqlens_q=cu_q, cu_seqlens_k=cu_k,
skip_reshape=True, skip_output_reshape=True,
)
finally:
torch_fx_logger.setLevel(old_torch_fx_level)
split = attention.var_attention_pytorch_split(
q, k, v, heads=2, cu_seqlens_q=cu_q, cu_seqlens_k=cu_k,
skip_reshape=True, skip_output_reshape=True,
)
torch.testing.assert_close(split, nested, rtol=1e-5, atol=1e-5)
def test_seedvr2_split_var_attention_preserves_flat_output_shape():
torch.manual_seed(2)
q = torch.randn(5, 8)
k = torch.randn(7, 8)
v = torch.randn(7, 8)
cu_q = torch.tensor([0, 1, 5], dtype=torch.int32)
cu_k = torch.tensor([0, 2, 7], dtype=torch.int32)
nested = attention.var_attention_pytorch(
q, k, v, heads=2, cu_seqlens_q=cu_q, cu_seqlens_k=cu_k,
)
split = attention.var_attention_pytorch_split(
q, k, v, heads=2, cu_seqlens_q=cu_q, cu_seqlens_k=cu_k,
)
assert split.shape == q.shape
torch.testing.assert_close(split, nested, rtol=1e-5, atol=1e-5)
def test_seedvr2_split_var_attention_rejects_mismatched_sequence_count():
q = torch.randn(5, 2, 4)
k = torch.randn(7, 2, 4)
v = torch.randn(7, 2, 4)
cu_q = torch.tensor([0, 2, 5], dtype=torch.int32)
cu_k = torch.tensor([0, 3, 5, 7], dtype=torch.int32)
try:
attention.var_attention_pytorch_split(
q, k, v, heads=2, cu_seqlens_q=cu_q, cu_seqlens_k=cu_k,
skip_reshape=True, skip_output_reshape=True,
)
except ValueError as exc:
assert "same sequence count" in str(exc)
else:
raise AssertionError("mismatched cu_seqlens sequence counts must fail")
def test_seedvr2_split_var_attention_rejects_malformed_offsets():
q = torch.randn(5, 2, 4)
k = torch.randn(7, 2, 4)
v = torch.randn(7, 2, 4)
cu_k = torch.tensor([0, 3, 7], dtype=torch.int32)
malformed_cases = (
(torch.tensor([1, 2, 5], dtype=torch.int32), "start at 0"),
(torch.tensor([0, 2, 2, 5], dtype=torch.int32), "strictly increasing"),
(torch.tensor([0.0, 2.0, 5.0], dtype=torch.float32), "integer dtype"),
)
for cu_q, message in malformed_cases:
try:
attention.var_attention_pytorch_split(
q, k, v, heads=2, cu_seqlens_q=cu_q, cu_seqlens_k=cu_k,
skip_reshape=True, skip_output_reshape=True,
)
except ValueError as exc:
assert message in str(exc)
else:
raise AssertionError("malformed cu_seqlens must fail")
def test_seedvr2_7b_window_attention_handles_mm_rope_source():
source = inspect.getsource(seedvr_model.NaSwinAttention.forward)
assert "if self.rope.mm" in source
assert "txt_q_repeat" in source
def test_seedvr2_7b_window_attention_routes_to_split_var_attention():
source = inspect.getsource(seedvr_model.NaSwinAttention.forward)
assert "_seedvr2_7b_window_attention_split" in source
assert "if self.version_7b" in source
def test_seedvr2_7b_window_attention_split_matches_concat_path():
torch.manual_seed(3)
vid_len_win = torch.tensor([1, 2, 3], dtype=torch.int64)
txt_len = torch.tensor([2, 3], dtype=torch.int64)
window_count = torch.tensor([2, 1], dtype=torch.int64)
heads = 2
dim = 4
vid_total = int(vid_len_win.sum().item())
txt_total = int(txt_len.sum().item())
vid_q = torch.randn(vid_total, heads, dim)
vid_k = torch.randn(vid_total, heads, dim)
vid_v = torch.randn(vid_total, heads, dim)
txt_q = torch.randn(txt_total, heads, dim)
txt_k = torch.randn(txt_total, heads, dim)
txt_v = torch.randn(txt_total, heads, dim)
concat_win, unconcat_win = seedvr_model.repeat_concat_idx(vid_len_win, txt_len, window_count)
all_len_win = vid_len_win + txt_len.repeat_interleave(window_count)
cu_seqlens = torch.nn.functional.pad(all_len_win.cumsum(0), (1, 0)).int()
concat_out = attention.var_attention_pytorch_split(
concat_win(vid_q, txt_q),
concat_win(vid_k, txt_k),
concat_win(vid_v, txt_v),
heads=heads,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
skip_reshape=True,
skip_output_reshape=True,
)
expected_vid, expected_txt = unconcat_win(concat_out)
split_vid, split_txt = seedvr_model._seedvr2_7b_window_attention_split(
vid_q, txt_q, vid_k, txt_k, vid_v, txt_v,
vid_len_win, txt_len, window_count,
)
torch.testing.assert_close(split_vid, expected_vid, rtol=1e-5, atol=1e-5)
torch.testing.assert_close(split_txt, expected_txt, rtol=1e-5, atol=1e-5)
def test_seedvr2_7b_window_attention_split_preserves_autograd():
torch.manual_seed(4)
vid_len_win = torch.tensor([1, 2, 3], dtype=torch.int64)
txt_len = torch.tensor([2, 3], dtype=torch.int64)
window_count = torch.tensor([2, 1], dtype=torch.int64)
heads = 2
dim = 4
vid_total = int(vid_len_win.sum().item())
txt_total = int(txt_len.sum().item())
vid_q = torch.randn(vid_total, heads, dim, requires_grad=True)
vid_k = torch.randn(vid_total, heads, dim, requires_grad=True)
vid_v = torch.randn(vid_total, heads, dim, requires_grad=True)
txt_q = torch.randn(txt_total, heads, dim, requires_grad=True)
txt_k = torch.randn(txt_total, heads, dim, requires_grad=True)
txt_v = torch.randn(txt_total, heads, dim, requires_grad=True)
split_vid, split_txt = seedvr_model._seedvr2_7b_window_attention_split(
vid_q, txt_q, vid_k, txt_k, vid_v, txt_v,
vid_len_win, txt_len, window_count,
)
(split_vid.sum() + split_txt.sum()).backward()
for tensor in (vid_q, vid_k, vid_v, txt_q, txt_k, txt_v):
assert tensor.grad is not None
def test_seedvr2_7b_mlp_chunks_video_tokens(monkeypatch):
class TrackingModule(torch.nn.Module):
def __init__(self, scale):
super().__init__()
self.scale = scale
self.calls = []
def forward(self, x):
self.calls.append(x.shape[0])
return x * self.scale
monkeypatch.setattr(seedvr_model, "SEEDVR2_7B_MLP_CHUNK", 2)
vid_module = TrackingModule(2.0)
txt_module = TrackingModule(3.0)
block = SimpleNamespace(
mlp=SimpleNamespace(
shared_weights=False,
vid_only=False,
vid=vid_module,
txt=txt_module,
)
)
vid = torch.arange(24, dtype=torch.float32).reshape(6, 4)
txt = torch.arange(12, dtype=torch.float32).reshape(3, 4)
out_vid, out_txt = seedvr_model.NaMMSRTransformerBlock._seedvr2_7b_mlp(block, vid, txt)
assert vid_module.calls == [2, 2, 2]
assert txt_module.calls == [3]
torch.testing.assert_close(out_vid, vid * 2.0)
torch.testing.assert_close(out_txt, txt * 3.0)
def test_seedvr2_7b_mlp_preserves_video_autograd(monkeypatch):
class TrackingModule(torch.nn.Module):
def forward(self, x):
return x * 2.0
monkeypatch.setattr(seedvr_model, "SEEDVR2_7B_MLP_CHUNK", 2)
block = SimpleNamespace(
mlp=SimpleNamespace(
shared_weights=False,
vid_only=True,
vid=TrackingModule(),
)
)
vid_base = torch.arange(24, dtype=torch.float32, requires_grad=True)
vid = vid_base.reshape(6, 4)
txt = torch.arange(12, dtype=torch.float32).reshape(3, 4)
out_vid, _ = seedvr_model.NaMMSRTransformerBlock._seedvr2_7b_mlp(block, vid, txt)
out_vid.sum().backward()
assert vid_base.grad is not None
def test_seedvr2_7b_block_routes_mlp_to_chunk_helper():
source = inspect.getsource(seedvr_model.NaMMSRTransformerBlock.forward)
assert "if self.version" in source
assert "_seedvr2_7b_mlp" in source
def test_seedvr2_vae_decode_memory_covers_full_frame_lab_transfer():
estimate = comfy.sd._seedvr2_vae_decode_memory_used((1, 16, 26, 120, 160))
old_estimate = 16 * 120 * 160 * (4 * 8 * 8) * 2
assert estimate == 101 * 960 * 1280 * 160
assert estimate > 15 * 1024 ** 3
assert estimate > old_estimate * 100
def test_seedvr2_vae_decode_memory_estimate_is_per_sample():
single = comfy.sd._seedvr2_vae_decode_memory_used((1, 16, 26, 120, 160))
batch = comfy.sd._seedvr2_vae_decode_memory_used((2, 16, 26, 120, 160))
assert batch == single
def test_seedvr2_vae_decode_memory_accepts_channel_last_tiled_latents():
channel_first = comfy.sd._seedvr2_vae_decode_memory_used((1, 16, 26, 120, 160))
channel_last = comfy.sd._seedvr2_vae_decode_memory_used((1, 26, 120, 160, 16))
assert channel_last == channel_first
def test_seedvr2_vae_decode_memory_rounds_malformed_collapsed_channels_up():
malformed = comfy.sd._seedvr2_vae_decode_memory_used((1, 17, 120, 160))
expected = comfy.sd._seedvr2_vae_decode_output_pixels(2, 120, 160) * comfy.sd.SEEDVR2_VAE_DECODE_BYTES_PER_OUTPUT_PIXEL
assert malformed == expected
def test_seedvr2_vae_decode_memory_uses_conservative_ambiguous_5d_layout():
ambiguous = comfy.sd._seedvr2_vae_decode_memory_used((1, 16, 120, 160, 16))
channel_first = comfy.sd._seedvr2_vae_decode_output_pixels(120, 160, 16) * comfy.sd.SEEDVR2_VAE_DECODE_BYTES_PER_OUTPUT_PIXEL
channel_last = comfy.sd._seedvr2_vae_decode_output_pixels(16, 120, 160) * comfy.sd.SEEDVR2_VAE_DECODE_BYTES_PER_OUTPUT_PIXEL
assert ambiguous == max(channel_first, channel_last)

View File

@ -0,0 +1,218 @@
from __future__ import annotations
import torch
from torch import nn
from comfy.cli_args import args
if not torch.cuda.is_available():
args.cpu = True
import comfy.ldm.seedvr.model as seedvr_model # noqa: E402
class _StubModule(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
def _capture_last_layer_flags(monkeypatch, vid_dim: int, txt_in_dim: int) -> list[bool]:
flags = []
class _Block(_StubModule):
def __init__(self, *args, **kwargs):
flags.append(kwargs["is_last_layer"])
super().__init__()
monkeypatch.setattr(seedvr_model, "NaPatchIn", _StubModule)
monkeypatch.setattr(seedvr_model, "NaPatchOut", _StubModule)
monkeypatch.setattr(seedvr_model, "TimeEmbedding", _StubModule)
monkeypatch.setattr(seedvr_model, "NaMMSRTransformerBlock", _Block)
seedvr_model.NaDiT(
norm_eps=1e-5,
qk_rope=None,
num_layers=4,
mlp_type="normal",
vid_dim=vid_dim,
txt_in_dim=txt_in_dim,
heads=24,
mm_layers=3,
)
return flags
def test_seedvr2_7b_keeps_final_block_text_path(monkeypatch):
assert _capture_last_layer_flags(monkeypatch, vid_dim=3072, txt_in_dim=3072) == [
False,
False,
False,
False,
]
def test_seedvr2_3b_keeps_final_block_vid_only_path(monkeypatch):
assert _capture_last_layer_flags(monkeypatch, vid_dim=2560, txt_in_dim=2560) == [
False,
False,
False,
True,
]
def _capture_block_attention_rope_type(monkeypatch, qk_rope):
rope_types = []
class _Attention(_StubModule):
def __init__(self, *args, **kwargs):
rope_types.append(kwargs["rope_type"])
super().__init__()
monkeypatch.setattr(seedvr_model, "MMModule", _StubModule)
monkeypatch.setattr(seedvr_model, "NaSwinAttention", _Attention)
seedvr_model.NaMMSRTransformerBlock(
vid_dim=4,
txt_dim=4,
emb_dim=4,
heads=1,
head_dim=4,
expand_ratio=1,
norm=_StubModule,
norm_eps=1e-5,
ada=_StubModule,
qk_bias=False,
qk_rope=qk_rope,
qk_norm=_StubModule,
mlp_type="normal",
shared_weights=False,
rope_type="mmrope3d",
rope_dim=4,
is_last_layer=False,
device="cpu",
dtype=torch.float32,
operations=seedvr_model.comfy.ops.disable_weight_init,
)
return rope_types
def test_seedvr2_3b_qk_rope_none_preserves_checkpoint_rope_buffers(monkeypatch):
assert _capture_block_attention_rope_type(monkeypatch, qk_rope=None) == ["mmrope3d"]
def test_seedvr2_7b_qk_rope_true_preserves_attention_rope(monkeypatch):
assert _capture_block_attention_rope_type(monkeypatch, qk_rope=True) == ["mmrope3d"]
def test_seedvr2_7b_rope3d_matches_checkpoint_buffer_shape():
rope = seedvr_model.get_na_rope("rope3d", dim=64)
assert isinstance(rope, seedvr_model.NaRotaryEmbedding3d)
assert tuple(rope.rope.freqs.shape) == (10,)
def test_seedvr2_7b_rope3d_preserves_qk_shape():
rope = seedvr_model.get_na_rope("rope3d", dim=64)
q = torch.randn(4, 2, 128)
k = torch.randn(4, 2, 128)
shape = torch.tensor([[1, 2, 2]], dtype=torch.long)
q_out, k_out = rope(q, k, shape, seedvr_model.Cache(disable=True))
assert q_out.shape == q.shape
assert k_out.shape == k.shape
def test_seedvr2_7b_rope3d_matches_wrapper_oracle():
rope = seedvr_model.get_na_rope("rope3d", dim=64)
generator = torch.Generator(device="cpu").manual_seed(0)
q = torch.randn(4, 2, 128, generator=generator)
k = torch.randn(4, 2, 128, generator=generator)
shape = torch.tensor([[1, 2, 2]], dtype=torch.long)
freqs = rope.get_axial_freqs(1, 2, 2).reshape(4, -1)
expected_q = seedvr_model.apply_rotary_emb(
freqs,
q.permute(1, 0, 2).float(),
).to(q.dtype).permute(1, 0, 2)
expected_k = seedvr_model.apply_rotary_emb(
freqs,
k.permute(1, 0, 2).float(),
).to(k.dtype).permute(1, 0, 2)
actual_q, actual_k = rope(q.clone(), k.clone(), shape, seedvr_model.Cache(disable=True))
torch.testing.assert_close(actual_q, expected_q, rtol=0, atol=0)
torch.testing.assert_close(actual_k, expected_k, rtol=0, atol=0)
def test_seedvr2_mmrope_handles_large_spatial_grid_without_truncation():
rope = seedvr_model.NaMMRotaryEmbedding3d(dim=12)
vid_shape = torch.tensor([[1, 129, 130]], dtype=torch.long)
txt_shape = torch.tensor([[2]], dtype=torch.long)
vid_tokens = int(vid_shape.prod().item())
txt_tokens = int(txt_shape.prod().item())
vid_q = torch.zeros(vid_tokens, 1, 12)
vid_k = torch.zeros_like(vid_q)
txt_q = torch.zeros(txt_tokens, 1, 12)
txt_k = torch.zeros_like(txt_q)
out = rope(vid_q, vid_k, vid_shape, txt_q, txt_k, txt_shape, seedvr_model.Cache(disable=True))
assert [tuple(t.shape) for t in out] == [
tuple(vid_q.shape),
tuple(vid_k.shape),
tuple(txt_q.shape),
tuple(txt_k.shape),
]
def test_adasingle_init_preserves_supported_dtype():
ada = seedvr_model.AdaSingle(
dim=4,
emb_dim=24,
layers=["test"],
modes=["in", "out"],
device="cpu",
dtype=torch.bfloat16,
)
assert ada.test_shift.dtype is torch.bfloat16
assert ada.test_scale.dtype is torch.bfloat16
assert ada.test_gate.dtype is torch.bfloat16
def test_adasingle_init_uses_default_dtype_for_fp8():
if not hasattr(torch, "float8_e4m3fn"):
return
ada = seedvr_model.AdaSingle(
dim=4,
emb_dim=24,
layers=["test"],
modes=["in", "out"],
device="cpu",
dtype=torch.float8_e4m3fn,
)
assert ada.test_shift.dtype is torch.float32
assert ada.test_scale.dtype is torch.float32
assert ada.test_gate.dtype is torch.float32
def test_adasingle_init_and_forward_share_fp8_dtype_set():
expected = {
getattr(torch, name)
for name in (
"float8_e4m3fn",
"float8_e4m3fnuz",
"float8_e5m2",
"float8_e5m2fnuz",
"float8_e8m0fnu",
)
if hasattr(torch, name)
}
assert set(seedvr_model._torch_float8_types()) == expected

View File

@ -0,0 +1,54 @@
from comfy.cli_args import args
import torch
if not torch.cuda.is_available():
args.cpu = True
import ast # noqa: E402
import inspect # noqa: E402
from torch import nn # noqa: E402
import comfy # noqa: E402
import comfy.ldm.seedvr.model # noqa: E402
import comfy.model_management # noqa: E402
from comfy.ldm.seedvr.model import MMModule # noqa: E402
def test_no_get_torch_device_in_forward_methods():
tree = ast.parse(inspect.getsource(comfy.ldm.seedvr.model))
assert [
(n.lineno, i.lineno)
for n in ast.walk(tree)
if isinstance(n, ast.FunctionDef) and n.name == "forward"
for i in ast.walk(n)
if isinstance(i, ast.Call)
and isinstance(i.func, ast.Attribute)
and i.func.attr == "get_torch_device"
] == []
def test_mmmodule_forward_succeeds_without_get_torch_device_lookup(monkeypatch):
call_count = [0]
def boom():
call_count[0] += 1
raise RuntimeError("MMModule.forward called get_torch_device()")
monkeypatch.setattr(comfy.model_management, "get_torch_device", boom)
class _IdentityCallable(nn.Module):
def forward(self, x, *args, **kwargs):
return x
mm = MMModule(_IdentityCallable, shared_weights=False, vid_only=False)
vid_in = torch.zeros(2, 4)
txt_in = torch.ones(2, 4)
vid_out, txt_out = mm.forward(vid_in, txt_in)
assert call_count[0] == 0
assert torch.equal(vid_out, vid_in)
assert torch.equal(txt_out, txt_in)
assert vid_out.device == vid_in.device
assert txt_out.device == txt_in.device

View File

@ -0,0 +1,179 @@
"""Regression: ``comfy.ldm.seedvr.vae.causal_norm_wrapper`` 5D GroupNorm
gate at ``vae.py:509`` must compare ``memory_occupy`` against the configured
``get_norm_limit()`` accessor, not against a hardcoded ``float('inf')``.
The original code path was ``... > float('inf')`` which is unreachable at any
finite ``memory_occupy`` value, so SeedVR2's ``norm_max_mem`` setting (wired
through ``set_norm_limit``) had no effect.
This module locks in two complementary cases against any future regression,
parametrized over both ``ops.GroupNorm`` subclasses (``disable_weight_init`` and
``manual_cast``) since the production gate ``isinstance(norm_layer, ops.GroupNorm)``
matches both.
* ``test_seedvr_groupnorm_default_limit_uses_full_groupnorm_path`` with
the limit at its default ``inf``, the full GroupNorm forward must run and
the chunked branch must NOT run, regardless of input tensor size.
* ``test_seedvr_groupnorm_low_limit_uses_chunked_groupnorm_path`` with a
deliberately low limit (``1e-9 GiB``), the chunked branch must run and
the full GroupNorm forward must NOT run.
Each case discriminates the two branches with two independent observers:
1. ``nn.Module.register_forward_hook`` on the GroupNorm fires only on the
full-path branch ``norm_layer(x)``; the chunked branch bypasses the
module ``__call__`` and goes through ``F.group_norm`` directly.
2. ``unittest.mock.patch.object(vae.F, 'group_norm', ...)`` spy with
``side_effect`` delegating to the real ``torch.nn.functional.group_norm``
captures every direct ``F.group_norm`` call's ``num_groups`` argument.
Calls with ``num_groups < gn.num_groups`` come from the chunked branch
(``num_groups_per_chunk = gn.num_groups // num_chunks``).
The spy uses ``*args, **kwargs`` passthrough so future ``F.group_norm`` kwargs
do not break the test.
CPU-only by construction: the tests use a small float32 tensor and never
allocate a real model or GPU memory.
"""
from unittest.mock import patch
import pytest
import torch
from comfy.cli_args import args as cli_args
if not torch.cuda.is_available():
cli_args.cpu = True
import comfy.ops as comfy_ops # noqa: E402
import comfy.ldm.seedvr.vae as vae_mod # noqa: E402
from comfy.ldm.seedvr.vae import ( # noqa: E402
causal_norm_wrapper,
set_norm_limit,
)
_NUM_CHANNELS = 8
_NUM_GROUPS = 4
_TENSOR_SHAPE = (1, 8, 2, 4, 4)
# Both ``ops.GroupNorm`` subclasses appear in production paths depending on
# the active backend. The dispatch gate at ``vae.py:509`` reads
# ``isinstance(norm_layer, ops.GroupNorm)`` and matches both via MRO.
_GROUPNORM_SUBCLASSES = [
pytest.param(
comfy_ops.disable_weight_init.GroupNorm,
id="disable_weight_init",
),
pytest.param(
comfy_ops.manual_cast.GroupNorm,
id="manual_cast",
),
]
@pytest.mark.parametrize("groupnorm_cls", _GROUPNORM_SUBCLASSES)
def test_seedvr_groupnorm_default_limit_uses_full_groupnorm_path(groupnorm_cls):
real_group_norm = vae_mod.F.group_norm
set_norm_limit(None)
try:
gn = groupnorm_cls(
num_channels=_NUM_CHANNELS, num_groups=_NUM_GROUPS
)
gn.eval()
forward_hook_calls = []
def _hook(module, inputs, output):
forward_hook_calls.append(tuple(inputs[0].shape))
spy_calls = []
def _group_norm_spy(input_tensor, num_groups_arg, *args, **kwargs):
spy_calls.append({
"num_groups": int(num_groups_arg),
"input_shape": tuple(int(s) for s in input_tensor.shape),
})
return real_group_norm(input_tensor, num_groups_arg, *args, **kwargs)
handle = gn.register_forward_hook(_hook)
try:
with patch.object(vae_mod.F, "group_norm", side_effect=_group_norm_spy):
out_tensor = causal_norm_wrapper(gn, torch.randn(*_TENSOR_SHAPE))
finally:
handle.remove()
full_calls = len(forward_hook_calls)
chunked_calls = sum(
1 for entry in spy_calls if entry["num_groups"] < _NUM_GROUPS
)
assert tuple(int(s) for s in out_tensor.shape) == _TENSOR_SHAPE, (
f"causal_norm_wrapper output shape {tuple(out_tensor.shape)} does not "
f"match input shape {_TENSOR_SHAPE}"
)
assert full_calls == 1, (
f"default-limit (inf) GroupNorm gate must take the full-forward path "
f"(register_forward_hook fires exactly once); got full_calls={full_calls}"
)
assert chunked_calls == 0, (
f"default-limit (inf) GroupNorm gate must NOT take the chunked path "
f"(no F.group_norm call with num_groups<{_NUM_GROUPS}); got "
f"chunked_calls={chunked_calls}"
)
finally:
set_norm_limit(None)
@pytest.mark.parametrize("groupnorm_cls", _GROUPNORM_SUBCLASSES)
def test_seedvr_groupnorm_low_limit_uses_chunked_groupnorm_path(groupnorm_cls):
real_group_norm = vae_mod.F.group_norm
set_norm_limit(1e-9)
try:
gn = groupnorm_cls(
num_channels=_NUM_CHANNELS, num_groups=_NUM_GROUPS
)
gn.eval()
forward_hook_calls = []
def _hook(module, inputs, output):
forward_hook_calls.append(tuple(inputs[0].shape))
spy_calls = []
def _group_norm_spy(input_tensor, num_groups_arg, *args, **kwargs):
spy_calls.append({
"num_groups": int(num_groups_arg),
"input_shape": tuple(int(s) for s in input_tensor.shape),
})
return real_group_norm(input_tensor, num_groups_arg, *args, **kwargs)
handle = gn.register_forward_hook(_hook)
try:
with patch.object(vae_mod.F, "group_norm", side_effect=_group_norm_spy):
out_tensor = causal_norm_wrapper(gn, torch.randn(*_TENSOR_SHAPE))
finally:
handle.remove()
full_calls = len(forward_hook_calls)
chunked_calls = sum(
1 for entry in spy_calls if entry["num_groups"] < _NUM_GROUPS
)
assert tuple(int(s) for s in out_tensor.shape) == _TENSOR_SHAPE, (
f"causal_norm_wrapper output shape {tuple(out_tensor.shape)} does not "
f"match input shape {_TENSOR_SHAPE}"
)
assert full_calls == 0, (
f"low-limit GroupNorm gate must NOT take the full-forward path "
f"(register_forward_hook should not fire); got full_calls={full_calls}"
)
assert chunked_calls > 0, (
f"low-limit GroupNorm gate must take the chunked path "
f"(at least one F.group_norm call with num_groups<{_NUM_GROUPS}); got "
f"chunked_calls={chunked_calls}"
)
finally:
set_norm_limit(None)

View File

@ -0,0 +1,40 @@
import torch
from comfy.cli_args import args as cli_args
if not torch.cuda.is_available():
cli_args.cpu = True
import comfy.latent_formats
import comfy.sample
class _Model:
def __init__(self, latent_format):
self._latent_format = latent_format
def get_model_object(self, name):
assert name == "latent_format"
return self._latent_format
def test_seedvr2_latent_format_uses_16_channels_without_3d_empty_latent_expansion():
latent_format = comfy.latent_formats.SeedVR2()
latent_image = torch.zeros(1, 1, 4, 5)
fixed = comfy.sample.fix_empty_latent_channels(_Model(latent_format), latent_image)
assert latent_format.latent_channels == 16
assert latent_format.latent_dimensions == 2
assert fixed.shape == (1, 16, 4, 5)
def test_seedvr2_empty_collapsed_latent_preserves_temporal_channel_multiples():
latent_format = comfy.latent_formats.SeedVR2()
latent_image = torch.zeros(1, 48, 4, 5)
fixed = comfy.sample.fix_empty_latent_channels(_Model(latent_format), latent_image)
assert latent_format.preserve_empty_channel_multiples is True
assert fixed.shape == latent_image.shape
assert fixed.data_ptr() == latent_image.data_ptr()

View File

@ -0,0 +1,176 @@
"""Regression test: ``comfy.ldm.seedvr.model.apply_rotary_emb`` must delegate
to ``comfy.ldm.flux.math.apply_rope1`` and produce exact-equality output
across the wrapper's slicing, scaling, and concatenation logic. Drift between
the wrapper and the delegate would silently corrupt SeedVR2's RoPE; this test
fails loudly on any future drift.
Each parametrized case does both:
1. Patches ``comfy.ldm.seedvr.model.apply_rope1`` with a ``wraps``-style spy
and asserts ``spy.call_count >= 1`` so a future change that inlines the
math and stops calling ``apply_rope1`` fails the test.
2. Compares the wrapper's output against a hand-rolled reproduction using
``torch.testing.assert_close(rtol=0, atol=0)`` -- exact tensor equality,
not bit-equality (``+0.0`` vs ``-0.0`` and NaN payloads can still match);
the assertion catches any future kernel-precision drift in the
``apply_rope1`` dispatch.
The test uses a local ``torch.Generator`` so global RNG state is not mutated.
Parametrization covers non-default ``start_index`` and ``scale`` and a case
where ``freqs.shape[0] > t.shape[seq_dim]`` so the wrapper's
``slice_at_dim(freqs, slice(-seq_len, None), dim=0)`` path is exercised.
Imports are taken at module level. Heavy-import stubbing of
``comfy.model_management`` was attempted but is insufficient on this live
import chain (``comfy.ldm.seedvr.model`` pulls
``comfy.ldm.modules.diffusionmodules.model -> comfy.ops ->
comfy.memory_management -> comfy.quant_ops -> comfy_kitchen.tensor ->
torch._dynamo``), so this test intentionally runs against the real modules
to fail loudly if that import path or runtime state drifts. Other tests in
this repo (e.g. ``tests-unit/comfy_extras_test/image_stitch_test.py``) do
stub via ``patch.dict(sys.modules, ...)`` for narrower targets; the choice
here is local to this regression and not a repo-wide convention.
"""
from unittest.mock import patch
import pytest
import torch
# CPU-only CI fix: ``comfy.ldm.seedvr.model`` transitively imports
# ``comfy.model_management``, whose import-time ``get_torch_device()`` call
# probes ``torch.cuda.current_device()`` unless ``comfy.cli_args.args.cpu`` is
# set. On a CPU-only build that probe can raise during test collection before
# the ``cuda`` case has had a chance to be skipped. Match the pattern used by
# ``tests-unit/comfy_quant/test_mixed_precision.py``: flip ``args.cpu`` before
# importing any ``comfy.ldm.*`` symbol.
from comfy.cli_args import args
if not torch.cuda.is_available():
args.cpu = True
import comfy.ldm.seedvr.model as seedvr_model # noqa: E402
from comfy.ldm.flux.math import apply_rope1 # noqa: E402
from comfy.ldm.seedvr.model import apply_rotary_emb # noqa: E402
def _direct_reproduction(freqs, t, start_index=0, scale=1.0, seq_dim=-2):
"""Reproduce the body of ``apply_rotary_emb`` for the default case where
``freqs.ndim == 2`` and ``t.ndim == 3`` (implicit ``freqs_seq_dim=0``).
Mirrors the wrapper's ``slice_at_dim(freqs, slice(-seq_len, None), dim=0)``
step when freqs is longer than ``t`` along ``seq_dim``. Calls the real
``apply_rope1`` via the test module's import (the test patches the
``seedvr_model.apply_rope1`` attribute; this call uses the unpatched
``flux.math`` symbol).
"""
if freqs.ndim == 2 and t.ndim == 3:
seq_len = t.shape[seq_dim]
freqs = freqs[-seq_len:]
rot_feats = freqs.shape[-1]
end_index = start_index + rot_feats
t_left = t[..., :start_index]
t_middle = t[..., start_index:end_index]
t_right = t[..., end_index:]
angles = freqs.to(t_middle.device)[..., ::2]
cos = torch.cos(angles) * scale
sin = torch.sin(angles) * scale
col0 = torch.stack([cos, sin], dim=-1)
col1 = torch.stack([-sin, cos], dim=-1)
freqs_mat = torch.stack([col0, col1], dim=-1)
t_middle_out = apply_rope1(t_middle, freqs_mat)
return torch.cat((t_left, t_middle_out, t_right), dim=-1).type(t.dtype)
def _cpu_trig_supported(dtype):
"""Return whether ``torch.cos`` (and by symmetry ``torch.sin``) is
implemented for the given dtype on CPU on the current runtime. Some
PyTorch CPU wheels don't implement trig ops for ``float16`` / ``bfloat16``
and raise at runtime; the parametrized cases for those dtypes are skipped
when that's the case so CI remains stable across PyTorch builds.
"""
try:
torch.cos(torch.zeros(1, dtype=dtype))
except (RuntimeError, TypeError):
return False
return True
_CPU_FP16_TRIG_OK = _cpu_trig_supported(torch.float16)
_CPU_BF16_TRIG_OK = _cpu_trig_supported(torch.bfloat16)
# (device, dtype, t_shape, freqs_shape, start_index, scale)
_CASES = [
pytest.param("cpu", torch.float32, (1, 8, 16), (8, 16), 0, 1.0,
id="cpu-float32-base"),
pytest.param(
"cpu", torch.float16, (1, 8, 16), (8, 16), 0, 1.0,
id="cpu-float16-base",
marks=pytest.mark.skipif(
not _CPU_FP16_TRIG_OK,
reason="torch.cos/torch.sin unsupported for float16 tensors on CPU",
),
),
pytest.param(
"cpu", torch.bfloat16, (1, 8, 16), (8, 16), 0, 1.0,
id="cpu-bfloat16-base",
marks=pytest.mark.skipif(
not _CPU_BF16_TRIG_OK,
reason="torch.cos/torch.sin unsupported for bfloat16 tensors on CPU",
),
),
pytest.param("cpu", torch.float32, (2, 16, 32), (16, 32), 0, 1.0,
id="cpu-float32-larger"),
pytest.param("cpu", torch.float32, (1, 8, 24), (8, 16), 4, 1.0,
id="cpu-float32-non-empty-left-and-right-slices"),
pytest.param("cpu", torch.float32, (1, 8, 16), (8, 16), 0, 0.5,
id="cpu-float32-non-default-scale"),
pytest.param("cpu", torch.float32, (1, 8, 16), (12, 16), 0, 1.0,
id="cpu-float32-freqs-longer-than-seq"),
pytest.param(
"cuda", torch.float16, (1, 8, 16), (8, 16), 0, 1.0,
id="cuda-float16-base",
marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="no cuda"),
),
]
@pytest.mark.parametrize("device,dtype,t_shape,freqs_shape,start_index,scale", _CASES)
def test_apply_rotary_emb_delegates_to_apply_rope1(
device, dtype, t_shape, freqs_shape, start_index, scale
):
generator = torch.Generator(device=device).manual_seed(0)
t = torch.randn(*t_shape, dtype=dtype, device=device, generator=generator)
freqs = torch.randn(*freqs_shape, dtype=dtype, device=device, generator=generator)
# Patch the apply_rope1 symbol as imported into seedvr.model with a wraps
# spy: a future change that inlines the math and stops calling the
# imported apply_rope1 makes spy.call_count == 0 and fails the test.
with patch.object(
seedvr_model, "apply_rope1", wraps=seedvr_model.apply_rope1
) as spy:
wrapper_out = apply_rotary_emb(
freqs, t, start_index=start_index, scale=scale
)
assert spy.call_count >= 1, (
"apply_rotary_emb did not call comfy.ldm.seedvr.model.apply_rope1; "
"the delegation invariant is broken"
)
direct_out = _direct_reproduction(
freqs, t, start_index=start_index, scale=scale
)
msg = (
f"apply_rotary_emb output does not match direct apply_rope1 "
f"reproduction (device={device}, dtype={dtype}, t_shape={t_shape}, "
f"freqs_shape={freqs_shape}, start_index={start_index}, scale={scale})"
)
torch.testing.assert_close(
wrapper_out,
direct_out,
rtol=0,
atol=0,
msg=msg,
)

View File

@ -0,0 +1,335 @@
"""Regression tests for the SeedVR2 native RoPE rewrite that replaces the
``apply_rotary_emb`` wrapper inside ``NaMMRotaryEmbedding3d.forward`` with
direct calls to ``comfy.ldm.flux.math.apply_rope1`` matching the pattern
used by the other 7 ComfyUI native-DiT models (flux, hidream, kandinsky5,
lumina, qwen_image, wan, sam3).
The wrapper builds a 2x2 ``freqs_mat`` and ends in ``torch.cat((t_left,
t_middle_out, t_right), dim=-1)``; that cat OOMs on the largest cell of the
SeedVR2 native_3b non-tiled corpus (VideoLQ_000 1280x960x100 on RTX 5090
32GB). Canonical and numz pass the same cell because both call
``rotary_embedding_torch.apply_rotary_emb`` directly. The fix moves the
NaMMRotaryEmbedding3d path onto ``apply_rope1`` directly with freqs in
flux-canonical shape ``[..., d/2, 2, 2]`` (cos/-sin/sin/cos baked in).
This test file pins four invariants the rewrite must satisfy:
1. ``NaMMRotaryEmbedding3d.forward`` calls ``apply_rope1`` 4 times per
forward (vid_q, vid_k, txt_q, txt_k) and 0 times into the
``apply_rotary_emb`` wrapper.
2. ``NaMMRotaryEmbedding3d.get_freqs`` returns freqs in flux-canonical shape
``[..., d/2, 2, 2]`` with the cos/-sin/sin/cos pattern from
``comfy/ldm/flux/math.py:rope`` (line 27).
3. The forward output is tensor-equal at fp32 against an oracle computed
from the unchanged ``apply_rotary_emb`` wrapper fed with the legacy
freqs layout proving the rewrite is algorithmically lossless.
4. AST: no ``apply_rotary_emb`` call sites remain inside
``NaMMRotaryEmbedding3d.forward``.
The wrapper itself stays in the file (still used by
``RotaryEmbedding3d.forward`` lines 434-435 and the staticmethod
registration on lucidrains' ``RotaryEmbedding`` line 323). Out of scope
here.
Pre-import CPU-only guard mirrors ``test_seedvr_rope_delegation.py``
``comfy.ldm.seedvr.model`` transitively imports ``comfy.model_management``
which probes ``torch.cuda.current_device()`` at import time unless
``args.cpu`` is set first.
"""
from __future__ import annotations
import ast
import inspect
from pathlib import Path
from unittest.mock import patch
import torch
from comfy.cli_args import args
if not torch.cuda.is_available():
args.cpu = True
import comfy.ldm.seedvr.model as seedvr_model # noqa: E402
from comfy.ldm.seedvr.model import ( # noqa: E402
Cache,
NaMMRotaryEmbedding3d,
)
# Test rig dimensions. dim=192 → per-axis rope dim = 64 (even, lucidrains
# requirement). vid_shape=(2,4,4) → L_vid = 32. txt_shape=(8,) → L_txt = 8.
# heads = 4. These are all small enough to run on CPU in milliseconds.
_DIM = 192
_HEADS = 4
_VID_T, _VID_H, _VID_W = 2, 4, 4
_TXT_L = 8
_L_VID = _VID_T * _VID_H * _VID_W
_SEED = 0
def _make_inputs(dtype=torch.float32, device="cpu"):
"""Construct the 6 forward inputs + cache. Deterministic via local
Generator so global RNG state is not mutated.
"""
g = torch.Generator(device=device).manual_seed(_SEED)
vid_q = torch.randn(_L_VID, _HEADS, _DIM, dtype=dtype, device=device, generator=g)
vid_k = torch.randn(_L_VID, _HEADS, _DIM, dtype=dtype, device=device, generator=g)
txt_q = torch.randn(_TXT_L, _HEADS, _DIM, dtype=dtype, device=device, generator=g)
txt_k = torch.randn(_TXT_L, _HEADS, _DIM, dtype=dtype, device=device, generator=g)
vid_shape = torch.tensor([[_VID_T, _VID_H, _VID_W]], dtype=torch.long, device=device)
txt_shape = torch.tensor([[_TXT_L]], dtype=torch.long, device=device)
cache = Cache(disable=True)
return vid_q, vid_k, vid_shape, txt_q, txt_k, txt_shape, cache
def _legacy_get_freqs(rope: NaMMRotaryEmbedding3d, vid_shape, txt_shape):
"""Reproduce the pre-rewrite ``get_freqs`` body verbatim against
``self.get_axial_freqs`` (parent ``RotaryEmbeddingBase`` method,
unchanged by the rewrite). Used by Test 3 to compute the oracle from
the wrapper path post-rewrite, when ``rope.get_freqs`` itself returns
the new flux-canonical shape.
"""
max_temporal = 0
max_height = 0
max_width = 0
max_txt_len = 0
for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()):
max_temporal = max(max_temporal, l + f)
max_height = max(max_height, h)
max_width = max(max_width, w)
max_txt_len = max(max_txt_len, l)
with torch.amp.autocast(device_type="cuda", enabled=False):
vid_freqs_full = rope.get_axial_freqs(
min(max_temporal + 16, 1024),
min(max_height + 4, 128),
min(max_width + 4, 128),
).float()
txt_freqs_full = rope.get_axial_freqs(min(max_txt_len + 16, 1024))
vid_freq_list, txt_freq_list = [], []
for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()):
vid_freq = vid_freqs_full[l : l + f, :h, :w].reshape(-1, vid_freqs_full.size(-1))
txt_freq = txt_freqs_full[:l].repeat(1, 3).reshape(-1, vid_freqs_full.size(-1))
vid_freq_list.append(vid_freq)
txt_freq_list.append(txt_freq)
return torch.cat(vid_freq_list, dim=0), torch.cat(txt_freq_list, dim=0)
def _legacy_forward(rope: NaMMRotaryEmbedding3d, vid_q, vid_k, vid_shape,
txt_q, txt_k, txt_shape):
"""Compute expected forward output via the unchanged
``apply_rotary_emb`` wrapper fed with legacy-shape freqs. This is the
oracle. The wrapper itself is out of scope for the rewrite (Shape B).
"""
vid_freqs, txt_freqs = _legacy_get_freqs(rope, vid_shape, txt_shape)
vid_freqs = vid_freqs.to(vid_q.device)
txt_freqs = txt_freqs.to(txt_q.device)
from einops import rearrange
vid_q = rearrange(vid_q, "L h d -> h L d")
vid_k = rearrange(vid_k, "L h d -> h L d")
vid_q_out = seedvr_model.apply_rotary_emb(vid_freqs, vid_q.float()).to(vid_q.dtype)
vid_k_out = seedvr_model.apply_rotary_emb(vid_freqs, vid_k.float()).to(vid_k.dtype)
vid_q_out = rearrange(vid_q_out, "h L d -> L h d")
vid_k_out = rearrange(vid_k_out, "h L d -> L h d")
txt_q = rearrange(txt_q, "L h d -> h L d")
txt_k = rearrange(txt_k, "L h d -> h L d")
txt_q_out = seedvr_model.apply_rotary_emb(txt_freqs, txt_q.float()).to(txt_q.dtype)
txt_k_out = seedvr_model.apply_rotary_emb(txt_freqs, txt_k.float()).to(txt_k.dtype)
txt_q_out = rearrange(txt_q_out, "h L d -> L h d")
txt_k_out = rearrange(txt_k_out, "h L d -> L h d")
return vid_q_out, vid_k_out, txt_q_out, txt_k_out
# Test 1 — drives AC-4 (call-graph): forward must reach apply_rope1 directly,
# never via the apply_rotary_emb wrapper.
def test_namm_forward_calls_apply_rope1_directly():
rope = NaMMRotaryEmbedding3d(dim=_DIM)
vid_q, vid_k, vid_shape, txt_q, txt_k, txt_shape, cache = _make_inputs()
with patch.object(
seedvr_model, "apply_rotary_emb", wraps=seedvr_model.apply_rotary_emb
) as wrapper_spy, patch.object(
seedvr_model, "apply_rope1", wraps=seedvr_model.apply_rope1
) as rope1_spy:
rope.forward(vid_q, vid_k, vid_shape, txt_q, txt_k, txt_shape, cache)
assert wrapper_spy.call_count == 0, (
f"NaMMRotaryEmbedding3d.forward must not call apply_rotary_emb "
f"(saw {wrapper_spy.call_count} calls); the rewrite must rewire "
f"the 4 forward sites to apply_rope1 directly"
)
assert rope1_spy.call_count == 4, (
f"NaMMRotaryEmbedding3d.forward must call apply_rope1 exactly 4 "
f"times (vid_q, vid_k, txt_q, txt_k); saw {rope1_spy.call_count}"
)
# Test 2 — drives the get_freqs shape change to flux-canonical layout.
def test_get_freqs_emits_flux_canonical_shape():
rope = NaMMRotaryEmbedding3d(dim=_DIM)
vid_shape = torch.tensor([[_VID_T, _VID_H, _VID_W]], dtype=torch.long)
txt_shape = torch.tensor([[_TXT_L]], dtype=torch.long)
vid_freqs, txt_freqs = rope.get_freqs(vid_shape, txt_shape)
# Flux's `rope()` (comfy/ldm/flux/math.py:17-29) emits freqs in shape
# [..., d/2, 2, 2] via stack([cos, -sin, sin, cos], dim=-1) +
# rearrange("b n d (i j) -> b n d i j", i=2, j=2). The rewrite must
# match: ndim >= 4, last two dims both == 2.
assert vid_freqs.ndim >= 4, (
f"vid_freqs.ndim must be >= 4 (flux-canonical layout has trailing "
f"[..., d/2, 2, 2]); got ndim={vid_freqs.ndim}, shape={tuple(vid_freqs.shape)}"
)
assert vid_freqs.shape[-1] == 2, (
f"vid_freqs.shape[-1] must be 2 (rotation matrix column); got "
f"shape={tuple(vid_freqs.shape)}"
)
assert vid_freqs.shape[-2] == 2, (
f"vid_freqs.shape[-2] must be 2 (rotation matrix row); got "
f"shape={tuple(vid_freqs.shape)}"
)
assert txt_freqs.ndim >= 4, (
f"txt_freqs must also be flux-canonical; got ndim={txt_freqs.ndim}, "
f"shape={tuple(txt_freqs.shape)}"
)
assert txt_freqs.shape[-1] == 2 and txt_freqs.shape[-2] == 2, (
f"txt_freqs trailing dims must be (2, 2); got shape={tuple(txt_freqs.shape)}"
)
# Verify the cos/-sin/sin/cos pattern at index 0:
# freqs_cis[..., 0, 0] = cos
# freqs_cis[..., 0, 1] = -sin
# freqs_cis[..., 1, 0] = sin
# freqs_cis[..., 1, 1] = cos
# so [0,0] == [1,1] (both cos) and [0,1] == -[1,0] (=-sin vs +sin).
cos_a = vid_freqs[..., 0, 0]
cos_b = vid_freqs[..., 1, 1]
neg_sin = vid_freqs[..., 0, 1]
sin = vid_freqs[..., 1, 0]
assert torch.allclose(cos_a, cos_b, rtol=0, atol=0), (
"vid_freqs[..., 0, 0] must equal vid_freqs[..., 1, 1] (both = cos)"
)
assert torch.allclose(neg_sin, -sin, rtol=0, atol=0), (
"vid_freqs[..., 0, 1] must equal -vid_freqs[..., 1, 0] (= -sin vs +sin)"
)
# Test 3 — drives AC-1: forward output is tensor-equal against the wrapper-
# fed oracle. Pre-rewrite: trivially passes (forward IS the wrapper path).
# Post-rewrite: must remain equal. Exact equality (rtol=atol=0) at fp32.
def test_namm_forward_output_tensor_equal_against_legacy_oracle():
rope = NaMMRotaryEmbedding3d(dim=_DIM)
vid_q, vid_k, vid_shape, txt_q, txt_k, txt_shape, cache = _make_inputs()
# Oracle: the unchanged apply_rotary_emb wrapper fed with legacy-shape
# freqs produced by reproducing the pre-rewrite get_freqs body.
expected_vid_q, expected_vid_k, expected_txt_q, expected_txt_k = _legacy_forward(
rope,
vid_q.clone(), vid_k.clone(), vid_shape,
txt_q.clone(), txt_k.clone(), txt_shape,
)
# Actual: NaMMRotaryEmbedding3d.forward (under test).
actual_vid_q, actual_vid_k, actual_txt_q, actual_txt_k = rope.forward(
vid_q.clone(), vid_k.clone(), vid_shape,
txt_q.clone(), txt_k.clone(), txt_shape, cache,
)
torch.testing.assert_close(actual_vid_q, expected_vid_q, rtol=0, atol=0,
msg="vid_q output diverges from wrapper oracle")
torch.testing.assert_close(actual_vid_k, expected_vid_k, rtol=0, atol=0,
msg="vid_k output diverges from wrapper oracle")
torch.testing.assert_close(actual_txt_q, expected_txt_q, rtol=0, atol=0,
msg="txt_q output diverges from wrapper oracle")
torch.testing.assert_close(actual_txt_k, expected_txt_k, rtol=0, atol=0,
msg="txt_k output diverges from wrapper oracle")
# Test 5 — partial-rope coverage. The real SeedVR2-3B model is constructed
# with rope_dim=128, which integer-divides into 3 axes as 128//3 = 42 per-
# axis; total rope freq dims = 42*3 = 126. head_dim is 128, so the trailing
# 2 dims of each q/k must be passed through unrotated (matching the legacy
# wrapper's `t_right = t[..., end_index:]` behavior). The fp32-CPU oracle
# test (Test 3) uses dim=192 where rot_d == head_dim and the partial-rope
# path collapses to a single apply_rope1 call. This test exercises the
# partial path explicitly with dim=128 and asserts the rewired forward
# still tensor-equals the wrapper oracle in that regime.
def test_namm_forward_partial_rope_passthrough_matches_wrapper_oracle():
rope = NaMMRotaryEmbedding3d(dim=128)
g = torch.Generator(device="cpu").manual_seed(_SEED)
vid_q = torch.randn(_L_VID, _HEADS, 128, dtype=torch.float32, device="cpu", generator=g)
vid_k = torch.randn(_L_VID, _HEADS, 128, dtype=torch.float32, device="cpu", generator=g)
txt_q = torch.randn(_TXT_L, _HEADS, 128, dtype=torch.float32, device="cpu", generator=g)
txt_k = torch.randn(_TXT_L, _HEADS, 128, dtype=torch.float32, device="cpu", generator=g)
vid_shape = torch.tensor([[_VID_T, _VID_H, _VID_W]], dtype=torch.long)
txt_shape = torch.tensor([[_TXT_L]], dtype=torch.long)
cache = Cache(disable=True)
expected_vid_q, expected_vid_k, expected_txt_q, expected_txt_k = _legacy_forward(
rope, vid_q.clone(), vid_k.clone(), vid_shape, txt_q.clone(), txt_k.clone(), txt_shape,
)
actual_vid_q, actual_vid_k, actual_txt_q, actual_txt_k = rope.forward(
vid_q.clone(), vid_k.clone(), vid_shape, txt_q.clone(), txt_k.clone(), txt_shape, cache,
)
# Confirm the partial-rope contract: rot_d (= 2 * freqs_cis.shape[-3]) is
# 126 (= 42*3), strictly less than head_dim 128. The trailing 2 head-dims
# are pass-through.
vid_freqs, _ = rope.get_freqs(vid_shape, txt_shape)
rot_d = 2 * vid_freqs.shape[-3]
assert rot_d == 126, f"expected rot_d=126 for dim=128 model; got {rot_d}"
assert rot_d < 128, "partial-rope path must trigger (rot_d < head_dim)"
torch.testing.assert_close(actual_vid_q, expected_vid_q, rtol=0, atol=0,
msg="vid_q partial-rope output diverges from wrapper oracle")
torch.testing.assert_close(actual_vid_k, expected_vid_k, rtol=0, atol=0,
msg="vid_k partial-rope output diverges from wrapper oracle")
torch.testing.assert_close(actual_txt_q, expected_txt_q, rtol=0, atol=0,
msg="txt_q partial-rope output diverges from wrapper oracle")
torch.testing.assert_close(actual_txt_k, expected_txt_k, rtol=0, atol=0,
msg="txt_k partial-rope output diverges from wrapper oracle")
# Test 4 — drives AC-4 statically: AST walk over NaMMRotaryEmbedding3d.forward
# must find zero references to the apply_rotary_emb symbol.
def test_namm_forward_ast_has_no_apply_rotary_emb_calls():
source_path = Path(inspect.getsourcefile(NaMMRotaryEmbedding3d))
tree = ast.parse(source_path.read_text(encoding="utf-8"))
namm_class = None
for node in ast.walk(tree):
if isinstance(node, ast.ClassDef) and node.name == "NaMMRotaryEmbedding3d":
namm_class = node
break
assert namm_class is not None, (
f"could not locate class NaMMRotaryEmbedding3d in {source_path}"
)
forward_fn = None
for node in namm_class.body:
if isinstance(node, ast.FunctionDef) and node.name == "forward":
forward_fn = node
break
assert forward_fn is not None, (
"could not locate NaMMRotaryEmbedding3d.forward"
)
offending = []
for node in ast.walk(forward_fn):
if isinstance(node, ast.Name) and node.id == "apply_rotary_emb":
offending.append((node.lineno, node.col_offset))
assert not offending, (
f"NaMMRotaryEmbedding3d.forward must not reference apply_rotary_emb; "
f"found {len(offending)} reference(s) at line:col positions {offending}. "
f"The rewrite must rewire to apply_rope1 directly."
)

View File

@ -0,0 +1,37 @@
from unittest.mock import patch
import torch
from torch import nn
import comfy.ldm.seedvr.vae as seedvr_vae
def test_seedvr_vae_4d_self_attention_uses_vae_attention_with_channel_first_layout():
calls = {}
def vae_attention_spy(q, k, v):
calls["q"] = q.detach().clone()
calls["k"] = k.detach().clone()
calls["v"] = v.detach().clone()
return q
def global_attention_forbidden(*args, **kwargs):
raise AssertionError("SeedVR2 VAE self-attention must not use global optimized_attention")
with patch.object(seedvr_vae, "vae_attention", return_value=vae_attention_spy):
attention = seedvr_vae.Attention(query_dim=4, heads=1, dim_head=4)
attention.to_q = nn.Identity()
attention.to_k = nn.Identity()
attention.to_v = nn.Identity()
attention.to_out[0] = nn.Identity()
hidden_states = torch.arange(24, dtype=torch.float32).reshape(1, 4, 2, 3)
with patch.object(seedvr_vae, "optimized_attention", global_attention_forbidden):
output = attention(hidden_states)
assert torch.equal(calls["q"], hidden_states)
assert torch.equal(calls["k"], hidden_states)
assert torch.equal(calls["v"], hidden_states)
assert torch.equal(output, hidden_states)

View File

@ -0,0 +1,476 @@
import subprocess
import sys
import textwrap
import ast
import inspect
import torch
from comfy.cli_args import args
if not torch.cuda.is_available():
args.cpu = True
import comfy.ldm.modules.attention as attention # noqa: E402
_VAR_BACKENDS = (
"var_attention_sage",
"var_attention_sage3",
"var_attention_flash",
"var_attention_flash3",
"var_attention_sub_quad",
"var_attention_split",
)
def _inputs():
heads = 2
head_dim = 4
total = 6
q = torch.randn(total, heads, head_dim)
k = torch.randn(total, heads, head_dim)
v = torch.randn(total, heads, head_dim)
cu = torch.tensor([0, 3, 6], dtype=torch.int32)
return q, k, v, heads, cu
def _has_dynamo_disable(decorator):
return (
isinstance(decorator, ast.Attribute)
and decorator.attr == "disable"
and isinstance(decorator.value, ast.Attribute)
and decorator.value.attr == "_dynamo"
and isinstance(decorator.value.value, ast.Name)
and decorator.value.value.id == "torch"
)
def test_var_attention_backend_functions_are_dynamo_disabled_and_signature_compatible():
tree = ast.parse(inspect.getsource(attention))
functions = {node.name: node for node in tree.body if isinstance(node, ast.FunctionDef)}
for name in _VAR_BACKENDS:
node = functions[name]
positional = [arg.arg for arg in node.args.args[:6]]
keyword_only = {arg.arg for arg in node.args.kwonlyargs}
assert positional == ["q", "k", "v", "heads", "cu_seqlens_q", "cu_seqlens_k"]
assert node.args.vararg is not None
assert node.args.kwarg is not None
assert "skip_reshape" in keyword_only
assert "skip_output_reshape" in keyword_only
assert any(_has_dynamo_disable(decorator) for decorator in node.decorator_list)
def test_var_attention_registry_contains_always_available_entries():
assert attention.REGISTERED_ATTENTION_FUNCTIONS["var_attention_pytorch"] is attention.var_attention_pytorch
assert attention.REGISTERED_ATTENTION_FUNCTIONS["var_attention_sub_quad"] is attention.var_attention_sub_quad
assert attention.REGISTERED_ATTENTION_FUNCTIONS["var_attention_split"] is attention.var_attention_split
def _run_attention_import(flag, fake_modules=True, fake_module_code=None):
argv = ["pytest-subprocess", "--cpu", "--disable-xformers"]
if flag:
argv.append(flag)
if fake_module_code is None:
fake_module_code = ""
if fake_modules and not fake_module_code:
fake_module_code = """
import types
sageattention = types.ModuleType("sageattention")
sageattention.sageattn = lambda *a, **k: a[0]
sageattention.sageattn_varlen = lambda *a, **k: a[0]
sys.modules["sageattention"] = sageattention
sageattn3 = types.ModuleType("sageattn3")
sageattn3.sageattn3_blackwell = lambda *a, **k: a[0]
sys.modules["sageattn3"] = sageattn3
flash_attn = types.ModuleType("flash_attn")
flash_attn.flash_attn_func = lambda q, k, v, **kwargs: q
flash_attn.flash_attn_varlen_func = lambda **kwargs: kwargs["q"]
sys.modules["flash_attn"] = flash_attn
flash_attn_interface = types.ModuleType("flash_attn_interface")
flash_attn_interface.flash_attn_varlen_func = lambda **kwargs: (kwargs["q"], None)
sys.modules["flash_attn_interface"] = flash_attn_interface
"""
code = (
"import sys\n"
"import comfy.options\n"
"comfy.options.enable_args_parsing()\n"
f"sys.argv = {argv!r}\n"
f"{textwrap.dedent(fake_module_code)}\n"
"import comfy.ldm.modules.attention as attention\n"
"print(attention.optimized_var_attention.__name__)\n"
)
return subprocess.run(
[sys.executable, "-c", code],
cwd=".",
text=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
check=False,
)
def test_var_attention_rebind_sage_launch_flag():
result = _run_attention_import("--use-sage-attention")
assert result.returncode == 0, result.stderr
assert result.stdout.strip() == "var_attention_sage"
def test_var_attention_rebind_flash_launch_flag_uses_pytorch_varlen_in_cpu_mode():
result = _run_attention_import("--use-flash-attention")
assert result.returncode == 0, result.stderr
assert result.stdout.strip() == "var_attention_pytorch"
def test_var_attention_rebind_sage_launch_flag_without_varlen_uses_pytorch():
result = _run_attention_import(
"--use-sage-attention",
fake_module_code="""
import types
sageattention = types.ModuleType("sageattention")
sageattention.sageattn = lambda *a, **k: a[0]
sys.modules["sageattention"] = sageattention
""",
)
assert result.returncode == 0, result.stderr
assert result.stdout.strip() == "var_attention_pytorch"
def test_var_attention_rebind_flash_launch_flag_without_varlen_uses_pytorch():
result = _run_attention_import(
"--use-flash-attention",
fake_module_code="""
import types
flash_attn = types.ModuleType("flash_attn")
flash_attn.flash_attn_func = lambda q, k, v, **kwargs: q
sys.modules["flash_attn"] = flash_attn
""",
)
assert result.returncode == 0, result.stderr
assert result.stdout.strip() == "var_attention_pytorch"
def test_var_attention_rebind_pytorch_launch_flag():
result = _run_attention_import("--use-pytorch-cross-attention")
assert result.returncode == 0, result.stderr
assert result.stdout.strip() == "var_attention_pytorch"
def test_var_attention_rebind_split_launch_flag():
result = _run_attention_import("--use-split-cross-attention")
assert result.returncode == 0, result.stderr
assert result.stdout.strip() == "var_attention_split"
def test_var_attention_rebind_default_launch_flags():
result = _run_attention_import("")
assert result.returncode == 0, result.stderr
assert result.stdout.strip() == "var_attention_sub_quad"
def test_var_attention_sage_uses_cu_seqlens_contract(monkeypatch):
q, k, v, heads, cu = _inputs()
captured = {}
def fake_sageattn_varlen(q, k, v, cu_q, cu_k, max_q, max_k, is_causal, sm_scale):
captured.update(cu_q=cu_q, cu_k=cu_k, max_q=max_q, max_k=max_k, is_causal=is_causal)
return torch.zeros_like(q)
monkeypatch.setattr(attention, "SAGE_ATTENTION_VARLEN_IS_AVAILABLE", True)
monkeypatch.setattr(attention, "sageattn_varlen", fake_sageattn_varlen, raising=False)
out = attention.var_attention_sage(q, k, v, heads, cu, cu, skip_reshape=True, skip_output_reshape=True)
assert tuple(out.shape) == tuple(q.shape)
assert torch.equal(captured["cu_q"], cu)
assert torch.equal(captured["cu_k"], cu)
assert captured["max_q"] == 3
assert captured["max_k"] == 3
assert captured["is_causal"] is False
def test_var_attention_sage_runtime_error_preserves_fallback_dtype(monkeypatch):
q, k, v, heads, cu = _inputs()
q = q.float()
k = k.half()
v = v.half()
captured = {}
def failing_sageattn_varlen(*args, **kwargs):
raise RuntimeError("unsupported")
def fake_var_attention_pytorch(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, skip_reshape=False, skip_output_reshape=False):
captured.update(dtype=q.dtype, k_dtype=k.dtype, v_dtype=v.dtype, skip_reshape=skip_reshape)
return torch.zeros_like(q)
monkeypatch.setattr(attention, "SAGE_ATTENTION_VARLEN_IS_AVAILABLE", True)
monkeypatch.setattr(attention, "sageattn_varlen", failing_sageattn_varlen, raising=False)
monkeypatch.setattr(attention, "var_attention_pytorch", fake_var_attention_pytorch)
out = attention.var_attention_sage(q, k, v, heads, cu, cu, skip_reshape=True, skip_output_reshape=True)
assert out.dtype == torch.float32
assert captured["dtype"] == torch.float32
assert captured["k_dtype"] == torch.float32
assert captured["v_dtype"] == torch.float32
assert captured["skip_reshape"] is True
def test_var_attention_sage3_uses_cu_seqlens_contract(monkeypatch):
q, k, v, heads, cu = _inputs()
captured = {}
def fake_sageattn3_blackwell(q, k, v, is_causal=False):
captured.update(shape=tuple(q.shape), is_causal=is_causal)
return torch.zeros_like(q)
monkeypatch.setattr(attention, "SAGE_ATTENTION3_IS_AVAILABLE", True)
monkeypatch.setattr(attention, "sageattn3_blackwell", fake_sageattn3_blackwell, raising=False)
out = attention.var_attention_sage3(q, k, v, heads, cu, cu, skip_reshape=True, skip_output_reshape=True)
assert tuple(out.shape) == tuple(q.shape)
assert captured["shape"] == (2, heads, 3, 4)
assert captured["is_causal"] is False
def test_var_attention_sage3_runtime_error_falls_back(monkeypatch):
q, k, v, heads, cu = _inputs()
q = q.float()
k = k.half()
v = v.half()
captured = {}
def failing_sageattn3_blackwell(*args, **kwargs):
raise RuntimeError("unsupported")
def fake_var_attention_pytorch(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, skip_reshape=False, skip_output_reshape=False):
captured.update(cu_q=cu_seqlens_q, dtype=q.dtype, k_dtype=k.dtype, v_dtype=v.dtype, skip_reshape=skip_reshape)
return torch.zeros_like(q)
monkeypatch.setattr(attention, "SAGE_ATTENTION_VARLEN_IS_AVAILABLE", False)
monkeypatch.setattr(attention, "SAGE_ATTENTION3_IS_AVAILABLE", True)
monkeypatch.setattr(attention, "sageattn3_blackwell", failing_sageattn3_blackwell, raising=False)
monkeypatch.setattr(attention, "var_attention_pytorch", fake_var_attention_pytorch)
out = attention.var_attention_sage3(q, k, v, heads, cu, cu, skip_reshape=True, skip_output_reshape=True)
assert tuple(out.shape) == tuple(q.shape)
assert torch.equal(captured["cu_q"], cu)
assert captured["dtype"] == torch.float32
assert captured["k_dtype"] == torch.float32
assert captured["v_dtype"] == torch.float32
assert captured["skip_reshape"] is True
def test_var_attention_flash_uses_cu_seqlens_contract(monkeypatch):
q, k, v, heads, cu = _inputs()
captured = {}
def fake_flash_attn_varlen_func(**kwargs):
captured.update(kwargs)
return torch.zeros_like(kwargs["q"])
monkeypatch.setattr(attention, "FLASH_ATTENTION_VARLEN_IS_AVAILABLE", True)
monkeypatch.setattr(attention, "flash_attn_varlen_func", fake_flash_attn_varlen_func, raising=False)
out = attention.var_attention_flash(q, k, v, heads, cu, cu, skip_reshape=True, skip_output_reshape=True)
assert tuple(out.shape) == tuple(q.shape)
assert torch.equal(captured["cu_seqlens_q"], cu)
assert torch.equal(captured["cu_seqlens_k"], cu)
assert captured["max_seqlen_q"] == 3
assert captured["max_seqlen_k"] == 3
def test_var_attention_flash_runtime_error_falls_back(monkeypatch):
q, k, v, heads, cu = _inputs()
captured = {}
def failing_flash_attn_varlen_func(**kwargs):
raise NotImplementedError("cpu")
def fake_var_attention_pytorch(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, skip_reshape=False, skip_output_reshape=False):
captured.update(cu_q=cu_seqlens_q, skip_reshape=skip_reshape)
return torch.zeros_like(q)
monkeypatch.setattr(attention, "FLASH_ATTENTION_VARLEN_IS_AVAILABLE", True)
monkeypatch.setattr(attention, "flash_attn_varlen_func", failing_flash_attn_varlen_func, raising=False)
monkeypatch.setattr(attention, "var_attention_pytorch", fake_var_attention_pytorch)
out = attention.var_attention_flash(q, k, v, heads, cu, cu, skip_reshape=True, skip_output_reshape=True)
assert tuple(out.shape) == tuple(q.shape)
assert torch.equal(captured["cu_q"], cu)
assert captured["skip_reshape"] is True
def test_var_attention_flash3_uses_cu_seqlens_contract(monkeypatch):
q, k, v, heads, cu = _inputs()
captured = {}
def fake_flash_attn3_varlen_func(**kwargs):
captured.update(kwargs)
return torch.zeros_like(kwargs["q"]), None
monkeypatch.setattr(attention, "flash_attn3_varlen_func", fake_flash_attn3_varlen_func, raising=False)
monkeypatch.setattr(attention, "FLASH_ATTENTION3_IS_AVAILABLE", True)
out = attention.var_attention_flash3(
q,
k,
v,
heads,
cu,
cu,
skip_reshape=True,
skip_output_reshape=True,
dropout_p=0.25,
window_size=(16, 16),
)
assert tuple(out.shape) == tuple(q.shape)
assert torch.equal(captured["cu_seqlens_q"], cu)
assert torch.equal(captured["cu_seqlens_k"], cu)
assert captured["max_seqlen_q"] == 3
assert captured["max_seqlen_k"] == 3
assert captured["seqused_q"] is None
assert captured["seqused_k"] is None
assert "dropout_p" not in captured
assert "window_size" not in captured
def test_var_attention_flash3_accepts_tensor_return(monkeypatch):
q, k, v, heads, cu = _inputs()
def fake_flash_attn3_varlen_func(**kwargs):
return torch.zeros_like(kwargs["q"])
monkeypatch.setattr(attention, "flash_attn3_varlen_func", fake_flash_attn3_varlen_func, raising=False)
monkeypatch.setattr(attention, "FLASH_ATTENTION3_IS_AVAILABLE", True)
out = attention.var_attention_flash3(q, k, v, heads, cu, cu, skip_reshape=True, skip_output_reshape=True)
assert tuple(out.shape) == tuple(q.shape)
def test_var_attention_flash3_runtime_error_falls_back(monkeypatch):
q, k, v, heads, cu = _inputs()
captured = {}
def failing_flash_attn3_varlen_func(**kwargs):
raise RuntimeError("unsupported")
def fake_var_attention_pytorch(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, skip_reshape=False, skip_output_reshape=False):
captured.update(cu_q=cu_seqlens_q, skip_reshape=skip_reshape)
return torch.zeros_like(q)
monkeypatch.setattr(attention, "FLASH_ATTENTION3_IS_AVAILABLE", True)
monkeypatch.setattr(attention, "flash_attn3_varlen_func", failing_flash_attn3_varlen_func, raising=False)
monkeypatch.setattr(attention, "var_attention_pytorch", fake_var_attention_pytorch)
out = attention.var_attention_flash3(q, k, v, heads, cu, cu, skip_reshape=True, skip_output_reshape=True)
assert tuple(out.shape) == tuple(q.shape)
assert torch.equal(captured["cu_q"], cu)
assert captured["skip_reshape"] is True
def test_var_attention_sub_quad_uses_cu_seqlens_contract(monkeypatch):
q, k, v, heads, cu = _inputs()
captured = {}
def fake_var_attention_pytorch(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, skip_reshape=False, skip_output_reshape=False):
captured.update(cu_q=cu_seqlens_q, cu_k=cu_seqlens_k, skip_reshape=skip_reshape)
return torch.zeros_like(q)
monkeypatch.setattr(attention, "var_attention_pytorch", fake_var_attention_pytorch)
out = attention.var_attention_sub_quad(q, k, v, heads, cu, cu, skip_reshape=True, skip_output_reshape=True)
assert tuple(out.shape) == tuple(q.shape)
assert torch.equal(captured["cu_q"], cu)
assert torch.equal(captured["cu_k"], cu)
assert captured["skip_reshape"] is True
def test_var_attention_split_uses_cu_seqlens_contract(monkeypatch):
q, k, v, heads, cu = _inputs()
captured = {}
def fake_var_attention_pytorch_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, skip_reshape=False, skip_output_reshape=False):
captured.update(cu_q=cu_seqlens_q, cu_k=cu_seqlens_k, skip_reshape=skip_reshape)
return torch.zeros_like(q)
def fail_var_attention_pytorch(*args, **kwargs):
raise AssertionError("split backend must not use nested-tensor pytorch var attention")
monkeypatch.setattr(attention, "var_attention_pytorch", fail_var_attention_pytorch)
monkeypatch.setattr(attention, "var_attention_pytorch_split", fake_var_attention_pytorch_split)
out = attention.var_attention_split(q, k, v, heads, cu, cu, skip_reshape=True, skip_output_reshape=True)
assert tuple(out.shape) == tuple(q.shape)
assert torch.equal(captured["cu_q"], cu)
assert torch.equal(captured["cu_k"], cu)
assert captured["skip_reshape"] is True
def test_var_attention_pytorch_split_normalizes_split_indices_to_cpu(monkeypatch):
q, k, v, heads, cu = _inputs()
captured_devices = []
real_tensor_split = torch.tensor_split
def capture_tensor_split(input, indices_or_sections, dim=0):
if isinstance(indices_or_sections, torch.Tensor):
captured_devices.append(indices_or_sections.device.type)
return real_tensor_split(input, indices_or_sections, dim=dim)
monkeypatch.setattr(torch, "tensor_split", capture_tensor_split)
out = attention.var_attention_pytorch_split(q, k, v, heads, cu, cu, skip_reshape=True, skip_output_reshape=True)
assert tuple(out.shape) == tuple(q.shape)
assert captured_devices == ["cpu", "cpu", "cpu"]
def test_missing_sage_package_guard_message_preserved():
code = textwrap.dedent(
"""
import builtins
import sys
import comfy.options
comfy.options.enable_args_parsing()
real_import = builtins.__import__
def blocked_import(name, globals=None, locals=None, fromlist=(), level=0):
if name == "sageattention":
raise ImportError("No module named sageattention", name="sageattention")
return real_import(name, globals, locals, fromlist, level)
builtins.__import__ = blocked_import
sys.argv = ["pytest-subprocess", "--cpu", "--disable-xformers", "--use-sage-attention"]
import comfy.ldm.modules.attention
"""
)
result = subprocess.run(
[sys.executable, "-c", code],
cwd=".",
text=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
check=False,
)
assert result.returncode != 0
assert "To use the `--use-sage-attention` feature" in result.stderr
assert "sageattention" in result.stderr

View File

@ -0,0 +1,167 @@
"""Regression tests for the SeedVR2-named guard inside
``comfy.ldm.modules.attention.var_attention_pytorch``.
Contract:
* If ``torch.nested.nested_tensor_from_jagged`` is unavailable on the
installed PyTorch build, ``var_attention_pytorch`` must raise
``RuntimeError`` whose message contains both ``SeedVR2`` and
``nested_tensor_from_jagged`` so the operator can identify the
failing attention path. A bare ``AttributeError`` from the
``torch.nested`` lookup is non-conformant. The guard must also
cover the case where the ``torch.nested`` namespace itself is
absent (e.g. forks/builds that strip the module) accessing
``torch.nested`` directly would otherwise raise the same opaque
``AttributeError`` the guard is meant to translate.
* If the API is present, the present-API path must produce the
canonical SeedVR2-inference output shape ``(total_tokens,
heads * head_dim)``.
* If the caller passes malformed offsets (off-end / non-monotonic /
size-mismatched), torch's own per-call ``RuntimeError`` propagates
unchanged: the SeedVR2-context guard fires only on the missing-API
path, never on torch's per-call shape errors.
Each cell additionally pins the production guard at the AST level via
``inspect.getsource(var_attention_pytorch)`` so every AC fails
diagnostically on an unguarded base.
"""
from comfy.cli_args import args
import torch
if not torch.cuda.is_available():
args.cpu = True
import ast # noqa: E402
import inspect # noqa: E402
import logging # noqa: E402
import textwrap # noqa: E402
import warnings # noqa: E402
import pytest # noqa: E402
from comfy.ldm.modules.attention import var_attention_pytorch # noqa: E402
def _inputs():
"""Canonical 2-D ``(q, k, v, heads, cu_seqlens_q, cu_seqlens_k,
total_tokens, embed_dim)`` matching the live shape from GPT-3:
two segments of 3 tokens each, ``embed_dim = heads * head_dim =
2 * 8 = 16``.
"""
heads, head_dim, total_tokens = 2, 8, 6
embed_dim = heads * head_dim
q = torch.randn(total_tokens, embed_dim)
k = torch.randn(total_tokens, embed_dim)
v = torch.randn(total_tokens, embed_dim)
cu = torch.tensor([0, 3, 6], dtype=torch.int32)
return q, k, v, heads, cu, cu, total_tokens, embed_dim
def _assert_guard_source_pin():
"""Walk the AST of ``var_attention_pytorch`` and assert that the
first ``raise RuntimeError(...)`` statement appears strictly
before any attribute access named ``nested_tensor_from_jagged``.
Substring-based source pinning (``src.index('raise RuntimeError(')
< src.index('nested_tensor_from_jagged')``) is fragile: it false-
positives on docstring or comment text containing the literal,
and false-negatives on a refactor that splits ``raise
RuntimeError(`` across lines or replaces it with a helper
raising ``RuntimeError`` from another scope. AST-walking the
function body collapses both failure modes onto the only
invariant we actually require the guard statement dominates
the attribute access by line number.
"""
src = textwrap.dedent(inspect.getsource(var_attention_pytorch))
tree = ast.parse(src)
raise_lines = []
nested_lines = []
for node in ast.walk(tree):
if isinstance(node, ast.Raise) and isinstance(node.exc, ast.Call):
func = node.exc.func
if isinstance(func, ast.Name) and func.id == "RuntimeError":
raise_lines.append(node.lineno)
if isinstance(node, ast.Attribute) and node.attr == "nested_tensor_from_jagged":
nested_lines.append(node.lineno)
assert raise_lines, (
"var_attention_pytorch has no `raise RuntimeError(...)` AST node; "
f"the SeedVR2-named guard is missing.\n--- source ---\n{src}"
)
assert nested_lines, (
"var_attention_pytorch source has no `nested_tensor_from_jagged` "
f"attribute access; cannot pin guard ordering.\n"
f"--- source ---\n{src}"
)
first_raise = min(raise_lines)
first_nested = min(nested_lines)
assert first_raise < first_nested, (
f"`raise RuntimeError(...)` first appears at line {first_raise}, "
f"but `torch.nested.nested_tensor_from_jagged` is referenced first "
f"at line {first_nested}; the guard must precede the lookup.\n"
f"--- source ---\n{src}"
)
def test_missing_api_raises_seedvr2_runtime_error(monkeypatch):
monkeypatch.delattr(torch.nested, "nested_tensor_from_jagged", raising=False)
q, k, v, heads, cu_q, cu_k, _, _ = _inputs()
with pytest.raises(RuntimeError, match=r"SeedVR2.*nested_tensor_from_jagged"):
var_attention_pytorch(q, k, v, heads, cu_q, cu_k)
_assert_guard_source_pin()
def test_missing_namespace_raises_seedvr2_runtime_error(monkeypatch):
monkeypatch.delattr(torch, "nested", raising=False)
q, k, v, heads, cu_q, cu_k, _, _ = _inputs()
with pytest.raises(RuntimeError, match=r"SeedVR2.*nested_tensor_from_jagged"):
var_attention_pytorch(q, k, v, heads, cu_q, cu_k)
_assert_guard_source_pin()
def test_present_api_returns_expected_shape():
q, k, v, heads, cu_q, cu_k, total_tokens, embed_dim = _inputs()
torch_fx_logger = logging.getLogger("torch.fx._symbolic_trace")
old_torch_fx_level = torch_fx_logger.level
torch_fx_logger.setLevel(logging.ERROR)
try:
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message="The PyTorch API of nested tensors is in prototype stage.*",
category=UserWarning,
)
out = var_attention_pytorch(q, k, v, heads, cu_q, cu_k)
finally:
torch_fx_logger.setLevel(old_torch_fx_level)
assert tuple(out.shape) == (total_tokens, embed_dim), (
f"expected ({total_tokens}, {embed_dim}); got {tuple(out.shape)}"
)
_assert_guard_source_pin()
def test_malformed_offsets_propagates_torch_runtime_error():
q, k, v, heads, _, _, _, _ = _inputs()
cu_q_bad = torch.tensor([0, 3, 7], dtype=torch.int32)
cu_k_ok = torch.tensor([0, 3, 6], dtype=torch.int32)
with pytest.raises(RuntimeError) as exc_info:
var_attention_pytorch(q, k, v, heads, cu_q_bad, cu_k_ok)
msg = str(exc_info.value)
assert "split_with_sizes" in msg, (
f"expected torch's `split_with_sizes` error to propagate; got: {msg!r}"
)
assert "SeedVR2" not in msg, (
f"SeedVR2-context substring must not be substituted onto torch's "
f"per-call shape error; got: {msg!r}"
)
_assert_guard_source_pin()