From 9eb6c7fe9e9be91603a9ea052a7df60eb1198ccd Mon Sep 17 00:00:00 2001 From: John Pollock Date: Mon, 25 May 2026 22:12:12 -0500 Subject: [PATCH] Add SeedVR2 core coverage --- tests-unit/comfy_test/model_detection_test.py | 58 ++ tests-unit/comfy_test/seedvr_model_test.py | 192 +++++++ .../comfy_test/seedvr_vae_forward_test.py | 124 +++++ .../seedvr_vae_wrapper_forward_test.py | 63 +++ .../test_diffusers_metadata_guard.py | 105 ++++ tests-unit/comfy_test/test_seedvr2_dtype.py | 503 ++++++++++++++++++ .../test_seedvr_7b_final_block_text_path.py | 218 ++++++++ .../test_seedvr_forward_no_device_cast.py | 54 ++ .../comfy_test/test_seedvr_groupnorm_limit.py | 179 +++++++ .../comfy_test/test_seedvr_latent_format.py | 40 ++ .../comfy_test/test_seedvr_rope_delegation.py | 176 ++++++ .../comfy_test/test_seedvr_rope_rewrite.py | 335 ++++++++++++ .../test_seedvr_vae_attention_fence.py | 37 ++ .../test_seedvr_var_attention_backends.py | 476 +++++++++++++++++ ...est_var_attention_pytorch_seedvr2_guard.py | 167 ++++++ 15 files changed, 2727 insertions(+) create mode 100644 tests-unit/comfy_test/seedvr_model_test.py create mode 100644 tests-unit/comfy_test/seedvr_vae_forward_test.py create mode 100644 tests-unit/comfy_test/seedvr_vae_wrapper_forward_test.py create mode 100644 tests-unit/comfy_test/test_diffusers_metadata_guard.py create mode 100644 tests-unit/comfy_test/test_seedvr2_dtype.py create mode 100644 tests-unit/comfy_test/test_seedvr_7b_final_block_text_path.py create mode 100644 tests-unit/comfy_test/test_seedvr_forward_no_device_cast.py create mode 100644 tests-unit/comfy_test/test_seedvr_groupnorm_limit.py create mode 100644 tests-unit/comfy_test/test_seedvr_latent_format.py create mode 100644 tests-unit/comfy_test/test_seedvr_rope_delegation.py create mode 100644 tests-unit/comfy_test/test_seedvr_rope_rewrite.py create mode 100644 tests-unit/comfy_test/test_seedvr_vae_attention_fence.py create mode 100644 tests-unit/comfy_test/test_seedvr_var_attention_backends.py create mode 100644 tests-unit/comfy_test/test_var_attention_pytorch_seedvr2_guard.py diff --git a/tests-unit/comfy_test/model_detection_test.py b/tests-unit/comfy_test/model_detection_test.py index 4e9350602..cc64a2ce1 100644 --- a/tests-unit/comfy_test/model_detection_test.py +++ b/tests-unit/comfy_test/model_detection_test.py @@ -73,6 +73,24 @@ def _make_flux_schnell_comfyui_sd(): return sd +def _make_seedvr2_7b_separate_mm_sd(): + return { + "blocks.35.mlp.vid.proj_in.weight": torch.empty(1, 3072), + } + + +def _make_seedvr2_7b_shared_mm_sd(): + return { + "blocks.35.mlp.all.proj_in_gate.weight": torch.empty(1, 1), + } + + +def _make_seedvr2_3b_shared_mm_sd(): + return { + "blocks.31.mlp.all.proj_in_gate.weight": torch.empty(1, 1), + } + + class TestModelDetection: """Verify that first-match model detection selects the correct model based on list ordering and unet_config specificity.""" @@ -125,6 +143,46 @@ class TestModelDetection: assert model_config is not None assert type(model_config).__name__ == "FluxSchnell" + def test_seedvr2_7b_separate_mm_detection_config(self): + sd = _make_seedvr2_7b_separate_mm_sd() + unet_config = detect_unet_config(sd, "") + + assert unet_config is not None + assert unet_config["image_model"] == "seedvr2" + assert unet_config["vid_dim"] == 3072 + assert unet_config["heads"] == 24 + assert unet_config["num_layers"] == 36 + assert unet_config["mm_layers"] == 36 + assert unet_config["mlp_type"] == "normal" + assert unet_config["qk_rope"] is True + assert unet_config["rope_type"] == "rope3d" + assert unet_config["rope_dim"] == 64 + + def test_seedvr2_7b_shared_mm_detection_config(self): + sd = _make_seedvr2_7b_shared_mm_sd() + unet_config = detect_unet_config(sd, "") + + assert unet_config is not None + assert unet_config["image_model"] == "seedvr2" + assert unet_config["vid_dim"] == 3072 + assert unet_config["heads"] == 24 + assert unet_config["num_layers"] == 36 + assert unet_config["mm_layers"] == 10 + assert unet_config["mlp_type"] == "swiglu" + assert unet_config["qk_rope"] is True + + def test_seedvr2_3b_shared_mm_detection_config(self): + sd = _make_seedvr2_3b_shared_mm_sd() + unet_config = detect_unet_config(sd, "") + + assert unet_config is not None + assert unet_config["image_model"] == "seedvr2" + assert unet_config["vid_dim"] == 2560 + assert unet_config["heads"] == 20 + assert unet_config["num_layers"] == 32 + assert unet_config["mlp_type"] == "swiglu" + assert unet_config["qk_rope"] is None + def test_unet_config_and_required_keys_combination_is_unique(self): """Each model in the registry must have a unique combination of ``unet_config`` and ``required_keys``. If two models share the same diff --git a/tests-unit/comfy_test/seedvr_model_test.py b/tests-unit/comfy_test/seedvr_model_test.py new file mode 100644 index 000000000..bc25967ab --- /dev/null +++ b/tests-unit/comfy_test/seedvr_model_test.py @@ -0,0 +1,192 @@ +"""Regression tests for SeedVR2 conditioning split hardening. + +Two bare ``except:`` clauses in ``NaDiT.forward`` previously swallowed +every failure mode on (1) the input-side text-conditioning split and +(2) the output-side positive/negative split, silently substituting +wrong fallbacks: the ``positive_conditioning`` buffer (which prior to +explicit zero-init held **uninitialized** memory — NaN, residual heap +contents, never guaranteed-zero) for the input, and the un-split +tensor for the output. Real prompt-shape, dtype, OOM, and downstream +tensor failures were re-routed to "no prompt supplied" with arbitrary +buffer contents standing in for actual prompt embeddings, or to a +wrong-order output, with no diagnostic. + +The fix: + + 1. Input-side: explicit absence predicate (``context is None`` or + ``context.numel() == 0``) → fall back to ``positive_conditioning`` + buffer. Any other failure (wrong rank, odd batch, dtype, OOM) + propagates the original torch exception. + 2. Output-side: no try/except at all. ``out.chunk(2)`` of the + network output is a contract: an unsplittable result is a bug, + not a recoverable condition. + +The two blocks were extracted into named private methods on +``NaDiT`` (``_resolve_text_conditioning`` and ``_swap_pos_neg_halves``) +so the regression evidence drives the actual production code paths +without standing up a full transformer. The methods are called from +``forward`` exactly where the original try/except blocks lived. +""" + +from comfy.cli_args import args +import torch + +if not torch.cuda.is_available(): + args.cpu = True + +import ast # noqa: E402 +import inspect # noqa: E402 +import textwrap # noqa: E402 + +import pytest # noqa: E402 + +from comfy.ldm.seedvr.model import NaDiT # noqa: E402 + + +def _make_standin(positive_conditioning): + class _StandIn(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer( + "positive_conditioning", positive_conditioning + ) + + _resolve_text_conditioning = NaDiT._resolve_text_conditioning + _swap_pos_neg_halves = NaDiT._swap_pos_neg_halves + + return _StandIn() + + +def test_no_bare_except_in_forward_path(): + """Source-level pin: neither ``NaDiT.forward`` nor its split helpers + may carry the bare ``except:`` clauses that swallowed real torch + failures on the SeedVR2 conditioning paths. AST-walked rather than + substring-matched so that ``except:`` appearing in a docstring or + comment does not false-positive, and so that ``except Exception:`` + (a typed handler, fine to have) does not false-negative. + """ + sources = [ + inspect.getsource(NaDiT.forward), + inspect.getsource(NaDiT._resolve_text_conditioning), + inspect.getsource(NaDiT._swap_pos_neg_halves), + ] + for src in sources: + tree = ast.parse(textwrap.dedent(src)) + for node in ast.walk(tree): + if isinstance(node, ast.ExceptHandler): + assert node.type is not None, ( + "Bare 'except:' (ast.ExceptHandler with type=None) " + f"must not appear on the SeedVR2 forward path:\n{src}" + ) + + +def test_valid_context_splits_pos_neg(): + """AC: valid (neg, pos)-stacked context (shape ``(2, L, C)``) + produces a flattened ``[pos, neg]`` text tensor — first ``L`` rows + are positive, next ``L`` rows are negative — matching the original + semantics of the ``flatten([pos_cond, neg_cond])`` call. + """ + pos_buffer = torch.zeros((58, 5120)) + standin = _make_standin(pos_buffer) + seq_len, channels = 7, 5120 + neg = torch.full((1, seq_len, channels), -1.0) + pos = torch.full((1, seq_len, channels), 1.0) + context = torch.cat([neg, pos], dim=0) + txt, txt_shape = standin._resolve_text_conditioning(context) + assert txt.shape == (2 * seq_len, channels) + assert (txt[:seq_len] == 1.0).all(), "first half must be positive cond" + assert (txt[seq_len:] == -1.0).all(), "second half must be negative cond" + assert txt_shape.shape == (2, 1) + assert txt_shape[0].item() == seq_len + assert txt_shape[1].item() == seq_len + + +def test_missing_context_falls_back_to_positive_buffer(): + """AC: ``context is None`` falls back to the registered + ``positive_conditioning`` buffer and runs to completion — no + silent zero substitution, no raised exception. + """ + pos_buffer = torch.full((58, 5120), 7.0) + standin = _make_standin(pos_buffer) + txt, txt_shape = standin._resolve_text_conditioning(None) + assert txt.shape == (58, 5120) + assert (txt == 7.0).all(), ( + "fallback path must use the positive_conditioning buffer " + "verbatim, not a zero tensor" + ) + assert txt_shape.shape == (1, 1) + assert txt_shape[0, 0].item() == 58 + + +def test_empty_context_falls_back_to_positive_buffer(): + """AC: ``context.numel() == 0`` falls back to the registered + ``positive_conditioning`` buffer and runs to completion. + """ + pos_buffer = torch.full((58, 5120), 13.0) + standin = _make_standin(pos_buffer) + empty = torch.empty((0, 5120)) + assert empty.numel() == 0 + txt, txt_shape = standin._resolve_text_conditioning(empty) + assert txt.shape == (58, 5120) + assert (txt == 13.0).all() + assert txt_shape.shape == (1, 1) + assert txt_shape[0, 0].item() == 58 + + +def test_wrong_rank_context_raises_original_torch_exception(): + """AC: a 1-D context tensor cannot be split into ``[pos, neg]`` + via the ``chunk + squeeze + flatten`` chain; the original torch + exception must propagate rather than silently falling back. + """ + pos_buffer = torch.zeros((58, 5120)) + standin = _make_standin(pos_buffer) + bad = torch.zeros(10) + with pytest.raises((RuntimeError, IndexError, ValueError)): + standin._resolve_text_conditioning(bad) + + +def test_odd_batch_context_raises_original_exception(): + """AC: a context whose batch dim cannot be split into two equal + chunks (here batch=1 so ``chunk(2, dim=0)`` returns a single + tensor) must propagate the original exception — no silent fallback. + """ + pos_buffer = torch.zeros((58, 5120)) + standin = _make_standin(pos_buffer) + bad = torch.zeros((1, 7, 5120)) + with pytest.raises((RuntimeError, ValueError)): + standin._resolve_text_conditioning(bad) + + +def test_output_side_misshaped_tensor_raises(): + """AC: the post-network output split must raise on an unsplittable + tensor (no silent return of the un-split tensor in the wrong + order/shape). Here a batch=1 tensor cannot be ``chunk(2, dim=0)`` + into two halves; ``pos, neg = out.chunk(2, dim=0)`` raises on + unpacking — matching the production helper's explicit-dim contract + (``_swap_pos_neg_halves`` calls ``chunk(2, dim=0)`` and + ``torch.cat(..., dim=0)``). + """ + pos_buffer = torch.zeros((58, 5120)) + standin = _make_standin(pos_buffer) + bad_out = torch.zeros((1, 4, 8, 8)) + with pytest.raises((RuntimeError, ValueError)): + standin._swap_pos_neg_halves(bad_out) + + +def test_output_side_swaps_pos_neg_halves(): + """AC complement: ``_swap_pos_neg_halves`` reorders the post-network + output so the first half (positive) and second half (negative) trade + places. For a 2-batch tensor with distinguishable halves, the + returned tensor must be the swap — first half becomes negative, + second half becomes positive — matching the original + ``torch.cat([neg, pos])`` semantics from the pre-fix forward path. + """ + pos_buffer = torch.zeros((58, 5120)) + standin = _make_standin(pos_buffer) + pos_half = torch.full((1, 4, 8, 8), 1.0) + neg_half = torch.full((1, 4, 8, 8), -1.0) + out = torch.cat([pos_half, neg_half], dim=0) + swapped = standin._swap_pos_neg_halves(out) + assert swapped.shape == out.shape + assert (swapped[0] == -1.0).all(), "first half of swapped output must be the original negative half" + assert (swapped[1] == 1.0).all(), "second half of swapped output must be the original positive half" diff --git a/tests-unit/comfy_test/seedvr_vae_forward_test.py b/tests-unit/comfy_test/seedvr_vae_forward_test.py new file mode 100644 index 000000000..76fed86ed --- /dev/null +++ b/tests-unit/comfy_test/seedvr_vae_forward_test.py @@ -0,0 +1,124 @@ +"""Regression: ``comfy.ldm.seedvr.vae.VideoAutoencoderKL.forward`` must +honor the actual tensor/tuple return contract of ``encode()`` and +``decode_()`` and must NOT dereference diffusers-style ``.latent_dist`` +or ``.sample`` attributes on those returns. + +The pre-fix body raised ``AttributeError: 'Tensor' object has no +attribute 'latent_dist'`` for ``mode in {"encode", "all"}`` and +``AttributeError: 'VideoAutoencoderKL' object has no attribute 'decode'`` +for ``mode == "decode"`` (the class only defines ``decode_`` with a +trailing underscore). The post-fix body unwraps the optional one-element +tuple shape that ``return_dict=False`` produces and returns the tensor +directly. + +Tests construct a stub subclass of ``VideoAutoencoderKL`` that bypasses +the heavy ``__init__`` via ``torch.nn.Module.__init__(self)`` and +overrides ``encode``/``decode_`` with known tensors so the contract can +be probed without loading any real VAE weights. +""" + +import inspect +import re + +import torch +import torch.nn as nn + +from comfy.cli_args import args as cli_args + +if not torch.cuda.is_available(): + cli_args.cpu = True + +from comfy.ldm.seedvr.vae import VideoAutoencoderKL # noqa: E402 + + +_LATENT_SHAPE = (1, 16, 2, 2, 2) +_DECODED_SHAPE = (1, 3, 5, 16, 16) +_INPUT_ENCODE_SHAPE = (1, 3, 5, 16, 16) +_INPUT_DECODE_SHAPE = (1, 16, 2, 2, 2) + + +class _StubVAE(VideoAutoencoderKL): + def __init__(self): + nn.Module.__init__(self) + self._encode_out = torch.zeros(*_LATENT_SHAPE) + self._decode_out = torch.zeros(*_DECODED_SHAPE) + + def encode(self, x, return_dict=True): + return self._encode_out + + def decode_(self, z, return_dict=True): + return self._decode_out + + +def test_forward_encode_returns_tensor(): + vae = _StubVAE() + x = torch.zeros(*_INPUT_ENCODE_SHAPE) + result = vae.forward(x, mode="encode") + assert type(result) is torch.Tensor + assert result.shape == torch.Size(_LATENT_SHAPE) + + +def test_forward_decode_returns_tensor(): + vae = _StubVAE() + z = torch.zeros(*_INPUT_DECODE_SHAPE) + result = vae.forward(z, mode="decode") + assert type(result) is torch.Tensor + assert result.shape == torch.Size(_DECODED_SHAPE) + + +def test_forward_all_returns_tensor(): + vae = _StubVAE() + x = torch.zeros(*_INPUT_ENCODE_SHAPE) + result = vae.forward(x, mode="all") + assert type(result) is torch.Tensor + assert result.shape == torch.Size(_DECODED_SHAPE) + + +def test_forward_source_has_no_diffusers_attr_access(): + src = inspect.getsource(VideoAutoencoderKL.forward) + assert ".latent_dist" not in src + assert ".sample" not in src + assert re.search(r"self\.decode\(", src) is None + + +class _TupleReturningStubVAE(VideoAutoencoderKL): + """Stub variant whose ``encode``/``decode_`` return the + ``(tensor,)`` one-element tuple shape ``return_dict=False`` produces + in the parent class. Exercises the unwrap branch of + ``VideoAutoencoderKL.forward``. + """ + + def __init__(self): + nn.Module.__init__(self) + self._encode_tensor = torch.zeros(*_LATENT_SHAPE) + self._decode_tensor = torch.zeros(*_DECODED_SHAPE) + + def encode(self, x, return_dict=True): + return (self._encode_tensor,) + + def decode_(self, z, return_dict=True): + return (self._decode_tensor,) + + +def test_forward_encode_unwraps_one_tuple(): + vae = _TupleReturningStubVAE() + x = torch.zeros(*_INPUT_ENCODE_SHAPE) + result = vae.forward(x, mode="encode") + assert type(result) is torch.Tensor + assert result.shape == torch.Size(_LATENT_SHAPE) + + +def test_forward_decode_unwraps_one_tuple(): + vae = _TupleReturningStubVAE() + z = torch.zeros(*_INPUT_DECODE_SHAPE) + result = vae.forward(z, mode="decode") + assert type(result) is torch.Tensor + assert result.shape == torch.Size(_DECODED_SHAPE) + + +def test_forward_all_unwraps_one_tuple_at_each_step(): + vae = _TupleReturningStubVAE() + x = torch.zeros(*_INPUT_ENCODE_SHAPE) + result = vae.forward(x, mode="all") + assert type(result) is torch.Tensor + assert result.shape == torch.Size(_DECODED_SHAPE) diff --git a/tests-unit/comfy_test/seedvr_vae_wrapper_forward_test.py b/tests-unit/comfy_test/seedvr_vae_wrapper_forward_test.py new file mode 100644 index 000000000..7a4c32131 --- /dev/null +++ b/tests-unit/comfy_test/seedvr_vae_wrapper_forward_test.py @@ -0,0 +1,63 @@ +import inspect + +import torch +import torch.nn as nn + +from comfy.cli_args import args as cli_args + +if not torch.cuda.is_available(): + cli_args.cpu = True + +import comfy.ldm.seedvr.vae as vae_mod # noqa: E402 + +VideoAutoencoderKLWrapper = vae_mod.VideoAutoencoderKLWrapper + + +_INPUT_SHAPE = (1, 3, 5, 16, 16) +_POSTERIOR_SHAPE = (1, 16, 1, 2, 2) +_DECODE_OUT_SHAPE = (1, 3, 5, 16, 16) + + +def _build_wrapper_standin() -> VideoAutoencoderKLWrapper: + wrapper = VideoAutoencoderKLWrapper.__new__(VideoAutoencoderKLWrapper) + nn.Module.__init__(wrapper) + return wrapper + + +def test_wrapper_forward_returns_tensor_triple(monkeypatch): + wrapper = _build_wrapper_standin() + wrapper.original_image_video = torch.zeros(*_INPUT_SHAPE) + wrapper.img_dims = (16, 16) + wrapper.freeze_encoder = True + + posterior = torch.full(_POSTERIOR_SHAPE, 7.0) + decode_out = torch.full(_DECODE_OUT_SHAPE, 13.0) + + def stub_encode(self, x, orig_dims=None): + return posterior.squeeze(2), posterior + + def stub_decode(self, z): + return decode_out + + monkeypatch.setattr(VideoAutoencoderKLWrapper, "encode", stub_encode) + monkeypatch.setattr(VideoAutoencoderKLWrapper, "decode", stub_decode) + + x = torch.zeros(*_INPUT_SHAPE) + result = wrapper.forward(x) + + assert isinstance(result, tuple) + assert len(result) == 3 + x_out, z, p = result + assert type(x_out) is torch.Tensor + assert type(z) is torch.Tensor + assert type(p) is torch.Tensor + assert x_out.shape == decode_out.shape + assert z.shape == posterior.squeeze(2).shape + assert torch.equal(x_out, decode_out) + assert torch.equal(z, posterior.squeeze(2)) + assert p is posterior + + +def test_wrapper_forward_source_has_no_sample_access(): + src = inspect.getsource(VideoAutoencoderKLWrapper.forward) + assert ".sample" not in src diff --git a/tests-unit/comfy_test/test_diffusers_metadata_guard.py b/tests-unit/comfy_test/test_diffusers_metadata_guard.py new file mode 100644 index 000000000..597ef781f --- /dev/null +++ b/tests-unit/comfy_test/test_diffusers_metadata_guard.py @@ -0,0 +1,105 @@ +"""Regression tests for the diffusers-format guard inside ``comfy.sd.VAE.__init__``. + +The guard previously indexed ``metadata["keep_diffusers_format"]`` directly, +raising ``KeyError`` when ``metadata`` was non-``None`` but lacked that key. The +fixed guard uses ``metadata.get("keep_diffusers_format") != "true"``: a missing +key flows through to ``convert_vae_state_dict``; the explicit ``"true"`` value +bypasses it. + +Five cells exercise every reachable shape of the guard input — missing key, +explicit ``"true"``, ``None``, explicit non-``"true"``, empty dict — and halt +the constructor at the first post-guard call (``model_management.is_amd``). +``_make_standin`` borrows ``__init__`` onto a bare class, mirroring +``seedvr_model_test.py::_make_standin`` (#109). ``_exercise_guard`` single- +sources the patched-constructor harness so the cells stay synchronised. +""" + +from comfy.cli_args import args +import torch + +if not torch.cuda.is_available(): + args.cpu = True + +import contextlib # noqa: E402 +import unittest.mock # noqa: E402 + +import comfy.sd # noqa: E402 + + +_DIFFUSERS_TRIGGER_KEY = "decoder.up_blocks.0.resnets.0.norm1.weight" + + +class _PostGuardReached(Exception): + """Sentinel raised by the patched ``is_amd`` to halt ``__init__`` at the first post-guard statement.""" + + +def _make_standin(): + class _StandIn: + __init__ = comfy.sd.VAE.__init__ + + return _StandIn + + +def _exercise_guard(metadata): + """Drive ``VAE.__init__`` with the diffusers trigger key and the supplied + ``metadata``; halt at ``is_amd``. Returns ``(mock_convert, mock_is_amd)`` + for branch (call_count) + reach (called) assertions per cell. + """ + StandIn = _make_standin() + sd = {_DIFFUSERS_TRIGGER_KEY: torch.zeros(1)} + + with unittest.mock.patch.object( + comfy.sd.diffusers_convert, + "convert_vae_state_dict", + autospec=True, + side_effect=lambda state_dict: state_dict, + ) as mock_convert, unittest.mock.patch.object( + comfy.sd.model_management, + "is_amd", + autospec=True, + side_effect=_PostGuardReached("post-guard reached"), + ) as mock_is_amd: + with contextlib.suppress(_PostGuardReached): + StandIn(sd=sd, metadata=metadata) + + return mock_convert, mock_is_amd + + +def test_diffusers_guard_invokes_convert_when_metadata_missing_key(): + """AC1: metadata is non-None but lacks ``keep_diffusers_format`` → convert is invoked.""" + mock_convert, mock_is_amd = _exercise_guard({"unrelated_key": "value"}) + + assert mock_convert.call_count == 1 + assert mock_is_amd.called + + +def test_diffusers_guard_skips_convert_when_metadata_pins_keep_true(): + """AC2: metadata pins ``keep_diffusers_format == "true"`` → convert is skipped.""" + mock_convert, mock_is_amd = _exercise_guard({"keep_diffusers_format": "true"}) + + assert mock_convert.call_count == 0 + assert mock_is_amd.called + + +def test_diffusers_guard_invokes_convert_when_metadata_is_none(): + """AC3: metadata is ``None`` → first disjunct fires, convert is invoked.""" + mock_convert, mock_is_amd = _exercise_guard(None) + + assert mock_convert.call_count == 1 + assert mock_is_amd.called + + +def test_diffusers_guard_invokes_convert_when_metadata_pins_keep_false(): + """AC4: metadata pins a non-``"true"`` value → second disjunct fires, convert is invoked.""" + mock_convert, mock_is_amd = _exercise_guard({"keep_diffusers_format": "false"}) + + assert mock_convert.call_count == 1 + assert mock_is_amd.called + + +def test_diffusers_guard_invokes_convert_when_metadata_is_empty_dict(): + """AC5: metadata is ``{}`` (the ``convert_old_quants`` None→{} normalization shape) → convert is invoked.""" + mock_convert, mock_is_amd = _exercise_guard({}) + + assert mock_convert.call_count == 1 + assert mock_is_amd.called diff --git a/tests-unit/comfy_test/test_seedvr2_dtype.py b/tests-unit/comfy_test/test_seedvr2_dtype.py new file mode 100644 index 000000000..3ca0d0dd6 --- /dev/null +++ b/tests-unit/comfy_test/test_seedvr2_dtype.py @@ -0,0 +1,503 @@ +import inspect +import logging +import warnings +from pathlib import Path +from types import SimpleNamespace + +import torch + +from comfy.cli_args import args as cli_args + +if not torch.cuda.is_available(): + cli_args.cpu = True + +import comfy.ldm.modules.attention as attention +import comfy.sd +import comfy.supported_models +import comfy.ldm.seedvr.model as seedvr_model + + +def test_set_model_config_inference_dtype_preserves_legacy_signature(): + calls = [] + + class LegacyConfig: + def set_inference_dtype(self, dtype, manual_cast_dtype): + calls.append((dtype, manual_cast_dtype)) + + comfy.sd._set_model_config_inference_dtype(LegacyConfig(), torch.float16, None, object()) + + assert calls == [(torch.float16, None)] + + +def test_set_model_config_inference_dtype_passes_device_when_supported(): + calls = [] + device = object() + + class DeviceAwareConfig: + def set_inference_dtype(self, dtype, manual_cast_dtype, device=None): + calls.append((dtype, manual_cast_dtype, device)) + + comfy.sd._set_model_config_inference_dtype(DeviceAwareConfig(), torch.float16, None, device) + + assert calls == [(torch.float16, None, device)] + + +def test_set_model_config_inference_dtype_passes_device_to_kwargs_override(): + calls = [] + device = object() + + class KwargsConfig: + def set_inference_dtype(self, dtype, manual_cast_dtype, **kwargs): + calls.append((dtype, manual_cast_dtype, kwargs)) + + comfy.sd._set_model_config_inference_dtype(KwargsConfig(), torch.float16, None, device) + + assert calls == [(torch.float16, None, {"device": device})] + + +def test_seedvr2_fp16_manual_cast_only_for_bf16_device(monkeypatch): + bf16_device = object() + fp16_device = object() + + monkeypatch.setattr( + comfy.supported_models.comfy.model_management, + "should_use_bf16", + lambda device=None: device is bf16_device, + ) + + bf16_config = comfy.supported_models.SeedVR2({"image_model": "seedvr2"}) + bf16_config.set_inference_dtype(torch.float16, None, device=bf16_device) + assert bf16_config.manual_cast_dtype is torch.bfloat16 + + fp16_config = comfy.supported_models.SeedVR2({"image_model": "seedvr2"}) + fp16_config.set_inference_dtype(torch.float16, None, device=fp16_device) + assert fp16_config.manual_cast_dtype is None + + +def test_apply_rope1_partial_preserves_full_rotation_input_dtype(monkeypatch): + def fake_apply_rope1(t, freqs_cis): + return t.float() + 1.0 + + monkeypatch.setattr(seedvr_model, "apply_rope1", fake_apply_rope1) + + t = torch.arange(8, dtype=torch.float16).reshape(1, 2, 4) + original = t.clone() + freqs_cis = torch.zeros(1, 2, 2, 2) + + out = seedvr_model._apply_rope1_partial(t, freqs_cis) + + assert out.dtype is torch.float16 + torch.testing.assert_close(out, (original.float() + 1.0).to(torch.float16)) + + +def test_apply_rope1_partial_preserves_partial_rotation_input_dtype(monkeypatch): + def fake_apply_rope1(t, freqs_cis): + return t.float() + 1.0 + + monkeypatch.setattr(seedvr_model, "apply_rope1", fake_apply_rope1) + + t = torch.arange(12, dtype=torch.float16).reshape(1, 2, 6) + original = t.clone() + freqs_cis = torch.zeros(1, 2, 2, 2) + + out = seedvr_model._apply_rope1_partial(t, freqs_cis) + + assert out.dtype is torch.float16 + torch.testing.assert_close( + out[..., :4], + (original[..., :4].float() + 1.0).to(torch.float16), + ) + torch.testing.assert_close(out[..., 4:], original[..., 4:]) + + +def test_apply_rope1_partial_chunks_sequence_dimension(monkeypatch): + calls = [] + + def fake_apply_rope1(t, freqs_cis): + calls.append(t.shape[-2]) + return t.float() + 1.0 + + monkeypatch.setattr(seedvr_model, "apply_rope1", fake_apply_rope1) + monkeypatch.setattr(seedvr_model, "_ROPE1_PARTIAL_CHUNK_TOKENS", 2) + + t = torch.arange(30, dtype=torch.float16).reshape(1, 5, 6) + original = t.clone() + freqs_cis = torch.zeros(5, 2, 2, 2) + + out = seedvr_model._apply_rope1_partial(t, freqs_cis) + + assert calls == [2, 2, 1] + torch.testing.assert_close(out[..., :4], (original[..., :4].float() + 1.0).to(torch.float16)) + torch.testing.assert_close(out[..., 4:], original[..., 4:]) + + +def test_apply_rope1_partial_clones_training_tensor(monkeypatch): + def fake_apply_rope1(t, freqs_cis): + return t + 1.0 + + monkeypatch.setattr(seedvr_model, "apply_rope1", fake_apply_rope1) + + base = torch.arange(12, dtype=torch.float32, requires_grad=True) + t = base.reshape(1, 2, 6) + original = t.clone() + freqs_cis = torch.zeros(2, 2, 2, 2) + + out = seedvr_model._apply_rope1_partial(t, freqs_cis) + out.sum().backward() + + assert out is not t + torch.testing.assert_close(t, original) + torch.testing.assert_close(out[..., :4], original[..., :4] + 1.0) + torch.testing.assert_close(out[..., 4:], original[..., 4:]) + assert base.grad is not None + + +def test_seedvr2_text_conditioning_accepts_cfg1_single_branch(): + context = torch.arange(6, dtype=torch.float32).reshape(1, 3, 2) + + txt, txt_shape = seedvr_model.NaDiT._resolve_text_conditioning(object(), context, [0]) + + torch.testing.assert_close(txt, context.squeeze(0)) + torch.testing.assert_close(txt_shape, torch.tensor([[3]], device=context.device)) + + +def test_seedvr2_text_conditioning_accepts_batched_cfg1_single_branch(): + context = torch.arange(12, dtype=torch.float32).reshape(2, 3, 2) + + txt, txt_shape = seedvr_model.NaDiT._resolve_text_conditioning(object(), context, [0]) + + torch.testing.assert_close(txt, context.flatten(0, -2)) + torch.testing.assert_close(txt_shape, torch.tensor([[3], [3]], device=context.device)) + + +def test_seedvr2_text_conditioning_accepts_multi_entry_cfg1_single_branch(): + context = torch.arange(12, dtype=torch.float32).reshape(2, 3, 2) + + txt, txt_shape = seedvr_model.NaDiT._resolve_text_conditioning(object(), context, [0, 0]) + + torch.testing.assert_close(txt, context.flatten(0, -2)) + torch.testing.assert_close(txt_shape, torch.tensor([[3], [3]], device=context.device)) + + +def test_seedvr2_text_conditioning_preserves_two_branch_swap_contract(): + neg = torch.full((1, 3, 2), -1.0) + pos = torch.full((1, 3, 2), 1.0) + context = torch.cat([neg, pos], dim=0) + + txt, txt_shape = seedvr_model.NaDiT._resolve_text_conditioning(object(), context) + + torch.testing.assert_close(txt[:3], pos.squeeze(0)) + torch.testing.assert_close(txt[3:], neg.squeeze(0)) + torch.testing.assert_close(txt_shape, torch.tensor([[3], [3]], device=context.device)) + + +def test_seedvr2_text_conditioning_preserves_batched_two_branch_swap_contract(): + neg = torch.full((2, 3, 2), -1.0) + pos = torch.full((2, 3, 2), 1.0) + context = torch.cat([neg, pos], dim=0) + + txt, txt_shape = seedvr_model.NaDiT._resolve_text_conditioning(object(), context, [1, 0]) + + torch.testing.assert_close(txt[:6], pos.flatten(0, -2)) + torch.testing.assert_close(txt[6:], neg.flatten(0, -2)) + torch.testing.assert_close(txt_shape, torch.tensor([[3], [3], [3], [3]], device=context.device)) + + +def test_seedvr2_cfg1_single_branch_output_is_not_swapped(): + out = torch.arange(6, dtype=torch.float32).reshape(1, 6) + + swapped = seedvr_model.NaDiT._swap_pos_neg_halves(object(), out, [0]) + + torch.testing.assert_close(swapped, out) + + +def test_seedvr2_multi_entry_cfg1_output_is_not_swapped(): + out = torch.arange(12, dtype=torch.float32).reshape(2, 6) + + swapped = seedvr_model.NaDiT._swap_pos_neg_halves(object(), out, [0, 0]) + + torch.testing.assert_close(swapped, out) + + +def test_seedvr2_conditioning_keeps_comfy_cfg1_optimization_enabled(): + source = (Path(__file__).resolve().parents[2] / "comfy_extras" / "nodes_seedvr.py").read_text(encoding="utf-8") + + assert "disable_model_cfg1_optimization()" not in source + + +def test_seedvr2_split_var_attention_matches_nested_var_attention(): + torch.manual_seed(1) + q = torch.randn(5, 2, 4) + k = torch.randn(7, 2, 4) + v = torch.randn(7, 2, 4) + cu_q = torch.tensor([0, 2, 5], dtype=torch.int32) + cu_k = torch.tensor([0, 3, 7], dtype=torch.int32) + + torch_fx_logger = logging.getLogger("torch.fx._symbolic_trace") + old_torch_fx_level = torch_fx_logger.level + torch_fx_logger.setLevel(logging.ERROR) + try: + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message="The PyTorch API of nested tensors is in prototype stage.*", + category=UserWarning, + ) + nested = attention.var_attention_pytorch( + q, k, v, heads=2, cu_seqlens_q=cu_q, cu_seqlens_k=cu_k, + skip_reshape=True, skip_output_reshape=True, + ) + finally: + torch_fx_logger.setLevel(old_torch_fx_level) + split = attention.var_attention_pytorch_split( + q, k, v, heads=2, cu_seqlens_q=cu_q, cu_seqlens_k=cu_k, + skip_reshape=True, skip_output_reshape=True, + ) + + torch.testing.assert_close(split, nested, rtol=1e-5, atol=1e-5) + + +def test_seedvr2_split_var_attention_preserves_flat_output_shape(): + torch.manual_seed(2) + q = torch.randn(5, 8) + k = torch.randn(7, 8) + v = torch.randn(7, 8) + cu_q = torch.tensor([0, 1, 5], dtype=torch.int32) + cu_k = torch.tensor([0, 2, 7], dtype=torch.int32) + + nested = attention.var_attention_pytorch( + q, k, v, heads=2, cu_seqlens_q=cu_q, cu_seqlens_k=cu_k, + ) + split = attention.var_attention_pytorch_split( + q, k, v, heads=2, cu_seqlens_q=cu_q, cu_seqlens_k=cu_k, + ) + + assert split.shape == q.shape + torch.testing.assert_close(split, nested, rtol=1e-5, atol=1e-5) + + +def test_seedvr2_split_var_attention_rejects_mismatched_sequence_count(): + q = torch.randn(5, 2, 4) + k = torch.randn(7, 2, 4) + v = torch.randn(7, 2, 4) + cu_q = torch.tensor([0, 2, 5], dtype=torch.int32) + cu_k = torch.tensor([0, 3, 5, 7], dtype=torch.int32) + + try: + attention.var_attention_pytorch_split( + q, k, v, heads=2, cu_seqlens_q=cu_q, cu_seqlens_k=cu_k, + skip_reshape=True, skip_output_reshape=True, + ) + except ValueError as exc: + assert "same sequence count" in str(exc) + else: + raise AssertionError("mismatched cu_seqlens sequence counts must fail") + + +def test_seedvr2_split_var_attention_rejects_malformed_offsets(): + q = torch.randn(5, 2, 4) + k = torch.randn(7, 2, 4) + v = torch.randn(7, 2, 4) + cu_k = torch.tensor([0, 3, 7], dtype=torch.int32) + + malformed_cases = ( + (torch.tensor([1, 2, 5], dtype=torch.int32), "start at 0"), + (torch.tensor([0, 2, 2, 5], dtype=torch.int32), "strictly increasing"), + (torch.tensor([0.0, 2.0, 5.0], dtype=torch.float32), "integer dtype"), + ) + + for cu_q, message in malformed_cases: + try: + attention.var_attention_pytorch_split( + q, k, v, heads=2, cu_seqlens_q=cu_q, cu_seqlens_k=cu_k, + skip_reshape=True, skip_output_reshape=True, + ) + except ValueError as exc: + assert message in str(exc) + else: + raise AssertionError("malformed cu_seqlens must fail") + + +def test_seedvr2_7b_window_attention_handles_mm_rope_source(): + source = inspect.getsource(seedvr_model.NaSwinAttention.forward) + + assert "if self.rope.mm" in source + assert "txt_q_repeat" in source + + +def test_seedvr2_7b_window_attention_routes_to_split_var_attention(): + source = inspect.getsource(seedvr_model.NaSwinAttention.forward) + + assert "_seedvr2_7b_window_attention_split" in source + assert "if self.version_7b" in source + + +def test_seedvr2_7b_window_attention_split_matches_concat_path(): + torch.manual_seed(3) + vid_len_win = torch.tensor([1, 2, 3], dtype=torch.int64) + txt_len = torch.tensor([2, 3], dtype=torch.int64) + window_count = torch.tensor([2, 1], dtype=torch.int64) + heads = 2 + dim = 4 + + vid_total = int(vid_len_win.sum().item()) + txt_total = int(txt_len.sum().item()) + vid_q = torch.randn(vid_total, heads, dim) + vid_k = torch.randn(vid_total, heads, dim) + vid_v = torch.randn(vid_total, heads, dim) + txt_q = torch.randn(txt_total, heads, dim) + txt_k = torch.randn(txt_total, heads, dim) + txt_v = torch.randn(txt_total, heads, dim) + + concat_win, unconcat_win = seedvr_model.repeat_concat_idx(vid_len_win, txt_len, window_count) + all_len_win = vid_len_win + txt_len.repeat_interleave(window_count) + cu_seqlens = torch.nn.functional.pad(all_len_win.cumsum(0), (1, 0)).int() + concat_out = attention.var_attention_pytorch_split( + concat_win(vid_q, txt_q), + concat_win(vid_k, txt_k), + concat_win(vid_v, txt_v), + heads=heads, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + skip_reshape=True, + skip_output_reshape=True, + ) + expected_vid, expected_txt = unconcat_win(concat_out) + + split_vid, split_txt = seedvr_model._seedvr2_7b_window_attention_split( + vid_q, txt_q, vid_k, txt_k, vid_v, txt_v, + vid_len_win, txt_len, window_count, + ) + + torch.testing.assert_close(split_vid, expected_vid, rtol=1e-5, atol=1e-5) + torch.testing.assert_close(split_txt, expected_txt, rtol=1e-5, atol=1e-5) + + +def test_seedvr2_7b_window_attention_split_preserves_autograd(): + torch.manual_seed(4) + vid_len_win = torch.tensor([1, 2, 3], dtype=torch.int64) + txt_len = torch.tensor([2, 3], dtype=torch.int64) + window_count = torch.tensor([2, 1], dtype=torch.int64) + heads = 2 + dim = 4 + + vid_total = int(vid_len_win.sum().item()) + txt_total = int(txt_len.sum().item()) + vid_q = torch.randn(vid_total, heads, dim, requires_grad=True) + vid_k = torch.randn(vid_total, heads, dim, requires_grad=True) + vid_v = torch.randn(vid_total, heads, dim, requires_grad=True) + txt_q = torch.randn(txt_total, heads, dim, requires_grad=True) + txt_k = torch.randn(txt_total, heads, dim, requires_grad=True) + txt_v = torch.randn(txt_total, heads, dim, requires_grad=True) + + split_vid, split_txt = seedvr_model._seedvr2_7b_window_attention_split( + vid_q, txt_q, vid_k, txt_k, vid_v, txt_v, + vid_len_win, txt_len, window_count, + ) + (split_vid.sum() + split_txt.sum()).backward() + + for tensor in (vid_q, vid_k, vid_v, txt_q, txt_k, txt_v): + assert tensor.grad is not None + + +def test_seedvr2_7b_mlp_chunks_video_tokens(monkeypatch): + class TrackingModule(torch.nn.Module): + def __init__(self, scale): + super().__init__() + self.scale = scale + self.calls = [] + + def forward(self, x): + self.calls.append(x.shape[0]) + return x * self.scale + + monkeypatch.setattr(seedvr_model, "SEEDVR2_7B_MLP_CHUNK", 2) + + vid_module = TrackingModule(2.0) + txt_module = TrackingModule(3.0) + block = SimpleNamespace( + mlp=SimpleNamespace( + shared_weights=False, + vid_only=False, + vid=vid_module, + txt=txt_module, + ) + ) + vid = torch.arange(24, dtype=torch.float32).reshape(6, 4) + txt = torch.arange(12, dtype=torch.float32).reshape(3, 4) + + out_vid, out_txt = seedvr_model.NaMMSRTransformerBlock._seedvr2_7b_mlp(block, vid, txt) + + assert vid_module.calls == [2, 2, 2] + assert txt_module.calls == [3] + torch.testing.assert_close(out_vid, vid * 2.0) + torch.testing.assert_close(out_txt, txt * 3.0) + + +def test_seedvr2_7b_mlp_preserves_video_autograd(monkeypatch): + class TrackingModule(torch.nn.Module): + def forward(self, x): + return x * 2.0 + + monkeypatch.setattr(seedvr_model, "SEEDVR2_7B_MLP_CHUNK", 2) + + block = SimpleNamespace( + mlp=SimpleNamespace( + shared_weights=False, + vid_only=True, + vid=TrackingModule(), + ) + ) + vid_base = torch.arange(24, dtype=torch.float32, requires_grad=True) + vid = vid_base.reshape(6, 4) + txt = torch.arange(12, dtype=torch.float32).reshape(3, 4) + + out_vid, _ = seedvr_model.NaMMSRTransformerBlock._seedvr2_7b_mlp(block, vid, txt) + out_vid.sum().backward() + + assert vid_base.grad is not None + + +def test_seedvr2_7b_block_routes_mlp_to_chunk_helper(): + source = inspect.getsource(seedvr_model.NaMMSRTransformerBlock.forward) + + assert "if self.version" in source + assert "_seedvr2_7b_mlp" in source + + +def test_seedvr2_vae_decode_memory_covers_full_frame_lab_transfer(): + estimate = comfy.sd._seedvr2_vae_decode_memory_used((1, 16, 26, 120, 160)) + old_estimate = 16 * 120 * 160 * (4 * 8 * 8) * 2 + + assert estimate == 101 * 960 * 1280 * 160 + assert estimate > 15 * 1024 ** 3 + assert estimate > old_estimate * 100 + + +def test_seedvr2_vae_decode_memory_estimate_is_per_sample(): + single = comfy.sd._seedvr2_vae_decode_memory_used((1, 16, 26, 120, 160)) + batch = comfy.sd._seedvr2_vae_decode_memory_used((2, 16, 26, 120, 160)) + + assert batch == single + + +def test_seedvr2_vae_decode_memory_accepts_channel_last_tiled_latents(): + channel_first = comfy.sd._seedvr2_vae_decode_memory_used((1, 16, 26, 120, 160)) + channel_last = comfy.sd._seedvr2_vae_decode_memory_used((1, 26, 120, 160, 16)) + + assert channel_last == channel_first + + +def test_seedvr2_vae_decode_memory_rounds_malformed_collapsed_channels_up(): + malformed = comfy.sd._seedvr2_vae_decode_memory_used((1, 17, 120, 160)) + expected = comfy.sd._seedvr2_vae_decode_output_pixels(2, 120, 160) * comfy.sd.SEEDVR2_VAE_DECODE_BYTES_PER_OUTPUT_PIXEL + + assert malformed == expected + + +def test_seedvr2_vae_decode_memory_uses_conservative_ambiguous_5d_layout(): + ambiguous = comfy.sd._seedvr2_vae_decode_memory_used((1, 16, 120, 160, 16)) + channel_first = comfy.sd._seedvr2_vae_decode_output_pixels(120, 160, 16) * comfy.sd.SEEDVR2_VAE_DECODE_BYTES_PER_OUTPUT_PIXEL + channel_last = comfy.sd._seedvr2_vae_decode_output_pixels(16, 120, 160) * comfy.sd.SEEDVR2_VAE_DECODE_BYTES_PER_OUTPUT_PIXEL + + assert ambiguous == max(channel_first, channel_last) diff --git a/tests-unit/comfy_test/test_seedvr_7b_final_block_text_path.py b/tests-unit/comfy_test/test_seedvr_7b_final_block_text_path.py new file mode 100644 index 000000000..5d5847f8f --- /dev/null +++ b/tests-unit/comfy_test/test_seedvr_7b_final_block_text_path.py @@ -0,0 +1,218 @@ +from __future__ import annotations + +import torch +from torch import nn + +from comfy.cli_args import args + +if not torch.cuda.is_available(): + args.cpu = True + +import comfy.ldm.seedvr.model as seedvr_model # noqa: E402 + + +class _StubModule(nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + + +def _capture_last_layer_flags(monkeypatch, vid_dim: int, txt_in_dim: int) -> list[bool]: + flags = [] + + class _Block(_StubModule): + def __init__(self, *args, **kwargs): + flags.append(kwargs["is_last_layer"]) + super().__init__() + + monkeypatch.setattr(seedvr_model, "NaPatchIn", _StubModule) + monkeypatch.setattr(seedvr_model, "NaPatchOut", _StubModule) + monkeypatch.setattr(seedvr_model, "TimeEmbedding", _StubModule) + monkeypatch.setattr(seedvr_model, "NaMMSRTransformerBlock", _Block) + + seedvr_model.NaDiT( + norm_eps=1e-5, + qk_rope=None, + num_layers=4, + mlp_type="normal", + vid_dim=vid_dim, + txt_in_dim=txt_in_dim, + heads=24, + mm_layers=3, + ) + + return flags + + +def test_seedvr2_7b_keeps_final_block_text_path(monkeypatch): + assert _capture_last_layer_flags(monkeypatch, vid_dim=3072, txt_in_dim=3072) == [ + False, + False, + False, + False, + ] + + +def test_seedvr2_3b_keeps_final_block_vid_only_path(monkeypatch): + assert _capture_last_layer_flags(monkeypatch, vid_dim=2560, txt_in_dim=2560) == [ + False, + False, + False, + True, + ] + + +def _capture_block_attention_rope_type(monkeypatch, qk_rope): + rope_types = [] + + class _Attention(_StubModule): + def __init__(self, *args, **kwargs): + rope_types.append(kwargs["rope_type"]) + super().__init__() + + monkeypatch.setattr(seedvr_model, "MMModule", _StubModule) + monkeypatch.setattr(seedvr_model, "NaSwinAttention", _Attention) + + seedvr_model.NaMMSRTransformerBlock( + vid_dim=4, + txt_dim=4, + emb_dim=4, + heads=1, + head_dim=4, + expand_ratio=1, + norm=_StubModule, + norm_eps=1e-5, + ada=_StubModule, + qk_bias=False, + qk_rope=qk_rope, + qk_norm=_StubModule, + mlp_type="normal", + shared_weights=False, + rope_type="mmrope3d", + rope_dim=4, + is_last_layer=False, + device="cpu", + dtype=torch.float32, + operations=seedvr_model.comfy.ops.disable_weight_init, + ) + + return rope_types + + +def test_seedvr2_3b_qk_rope_none_preserves_checkpoint_rope_buffers(monkeypatch): + assert _capture_block_attention_rope_type(monkeypatch, qk_rope=None) == ["mmrope3d"] + + +def test_seedvr2_7b_qk_rope_true_preserves_attention_rope(monkeypatch): + assert _capture_block_attention_rope_type(monkeypatch, qk_rope=True) == ["mmrope3d"] + + +def test_seedvr2_7b_rope3d_matches_checkpoint_buffer_shape(): + rope = seedvr_model.get_na_rope("rope3d", dim=64) + + assert isinstance(rope, seedvr_model.NaRotaryEmbedding3d) + assert tuple(rope.rope.freqs.shape) == (10,) + + +def test_seedvr2_7b_rope3d_preserves_qk_shape(): + rope = seedvr_model.get_na_rope("rope3d", dim=64) + q = torch.randn(4, 2, 128) + k = torch.randn(4, 2, 128) + shape = torch.tensor([[1, 2, 2]], dtype=torch.long) + + q_out, k_out = rope(q, k, shape, seedvr_model.Cache(disable=True)) + + assert q_out.shape == q.shape + assert k_out.shape == k.shape + + +def test_seedvr2_7b_rope3d_matches_wrapper_oracle(): + rope = seedvr_model.get_na_rope("rope3d", dim=64) + generator = torch.Generator(device="cpu").manual_seed(0) + q = torch.randn(4, 2, 128, generator=generator) + k = torch.randn(4, 2, 128, generator=generator) + shape = torch.tensor([[1, 2, 2]], dtype=torch.long) + freqs = rope.get_axial_freqs(1, 2, 2).reshape(4, -1) + + expected_q = seedvr_model.apply_rotary_emb( + freqs, + q.permute(1, 0, 2).float(), + ).to(q.dtype).permute(1, 0, 2) + expected_k = seedvr_model.apply_rotary_emb( + freqs, + k.permute(1, 0, 2).float(), + ).to(k.dtype).permute(1, 0, 2) + + actual_q, actual_k = rope(q.clone(), k.clone(), shape, seedvr_model.Cache(disable=True)) + + torch.testing.assert_close(actual_q, expected_q, rtol=0, atol=0) + torch.testing.assert_close(actual_k, expected_k, rtol=0, atol=0) + + +def test_seedvr2_mmrope_handles_large_spatial_grid_without_truncation(): + rope = seedvr_model.NaMMRotaryEmbedding3d(dim=12) + vid_shape = torch.tensor([[1, 129, 130]], dtype=torch.long) + txt_shape = torch.tensor([[2]], dtype=torch.long) + vid_tokens = int(vid_shape.prod().item()) + txt_tokens = int(txt_shape.prod().item()) + vid_q = torch.zeros(vid_tokens, 1, 12) + vid_k = torch.zeros_like(vid_q) + txt_q = torch.zeros(txt_tokens, 1, 12) + txt_k = torch.zeros_like(txt_q) + + out = rope(vid_q, vid_k, vid_shape, txt_q, txt_k, txt_shape, seedvr_model.Cache(disable=True)) + + assert [tuple(t.shape) for t in out] == [ + tuple(vid_q.shape), + tuple(vid_k.shape), + tuple(txt_q.shape), + tuple(txt_k.shape), + ] + + +def test_adasingle_init_preserves_supported_dtype(): + ada = seedvr_model.AdaSingle( + dim=4, + emb_dim=24, + layers=["test"], + modes=["in", "out"], + device="cpu", + dtype=torch.bfloat16, + ) + + assert ada.test_shift.dtype is torch.bfloat16 + assert ada.test_scale.dtype is torch.bfloat16 + assert ada.test_gate.dtype is torch.bfloat16 + + +def test_adasingle_init_uses_default_dtype_for_fp8(): + if not hasattr(torch, "float8_e4m3fn"): + return + + ada = seedvr_model.AdaSingle( + dim=4, + emb_dim=24, + layers=["test"], + modes=["in", "out"], + device="cpu", + dtype=torch.float8_e4m3fn, + ) + + assert ada.test_shift.dtype is torch.float32 + assert ada.test_scale.dtype is torch.float32 + assert ada.test_gate.dtype is torch.float32 + + +def test_adasingle_init_and_forward_share_fp8_dtype_set(): + expected = { + getattr(torch, name) + for name in ( + "float8_e4m3fn", + "float8_e4m3fnuz", + "float8_e5m2", + "float8_e5m2fnuz", + "float8_e8m0fnu", + ) + if hasattr(torch, name) + } + + assert set(seedvr_model._torch_float8_types()) == expected diff --git a/tests-unit/comfy_test/test_seedvr_forward_no_device_cast.py b/tests-unit/comfy_test/test_seedvr_forward_no_device_cast.py new file mode 100644 index 000000000..802588ebd --- /dev/null +++ b/tests-unit/comfy_test/test_seedvr_forward_no_device_cast.py @@ -0,0 +1,54 @@ +from comfy.cli_args import args +import torch + +if not torch.cuda.is_available(): + args.cpu = True + +import ast # noqa: E402 +import inspect # noqa: E402 + +from torch import nn # noqa: E402 + +import comfy # noqa: E402 +import comfy.ldm.seedvr.model # noqa: E402 +import comfy.model_management # noqa: E402 +from comfy.ldm.seedvr.model import MMModule # noqa: E402 + + +def test_no_get_torch_device_in_forward_methods(): + tree = ast.parse(inspect.getsource(comfy.ldm.seedvr.model)) + assert [ + (n.lineno, i.lineno) + for n in ast.walk(tree) + if isinstance(n, ast.FunctionDef) and n.name == "forward" + for i in ast.walk(n) + if isinstance(i, ast.Call) + and isinstance(i.func, ast.Attribute) + and i.func.attr == "get_torch_device" + ] == [] + + +def test_mmmodule_forward_succeeds_without_get_torch_device_lookup(monkeypatch): + call_count = [0] + + def boom(): + call_count[0] += 1 + raise RuntimeError("MMModule.forward called get_torch_device()") + + monkeypatch.setattr(comfy.model_management, "get_torch_device", boom) + + class _IdentityCallable(nn.Module): + def forward(self, x, *args, **kwargs): + return x + + mm = MMModule(_IdentityCallable, shared_weights=False, vid_only=False) + + vid_in = torch.zeros(2, 4) + txt_in = torch.ones(2, 4) + vid_out, txt_out = mm.forward(vid_in, txt_in) + + assert call_count[0] == 0 + assert torch.equal(vid_out, vid_in) + assert torch.equal(txt_out, txt_in) + assert vid_out.device == vid_in.device + assert txt_out.device == txt_in.device diff --git a/tests-unit/comfy_test/test_seedvr_groupnorm_limit.py b/tests-unit/comfy_test/test_seedvr_groupnorm_limit.py new file mode 100644 index 000000000..e610bbbc4 --- /dev/null +++ b/tests-unit/comfy_test/test_seedvr_groupnorm_limit.py @@ -0,0 +1,179 @@ +"""Regression: ``comfy.ldm.seedvr.vae.causal_norm_wrapper`` 5D GroupNorm +gate at ``vae.py:509`` must compare ``memory_occupy`` against the configured +``get_norm_limit()`` accessor, not against a hardcoded ``float('inf')``. + +The original code path was ``... > float('inf')`` which is unreachable at any +finite ``memory_occupy`` value, so SeedVR2's ``norm_max_mem`` setting (wired +through ``set_norm_limit``) had no effect. + +This module locks in two complementary cases against any future regression, +parametrized over both ``ops.GroupNorm`` subclasses (``disable_weight_init`` and +``manual_cast``) since the production gate ``isinstance(norm_layer, ops.GroupNorm)`` +matches both. + +* ``test_seedvr_groupnorm_default_limit_uses_full_groupnorm_path`` — with + the limit at its default ``inf``, the full GroupNorm forward must run and + the chunked branch must NOT run, regardless of input tensor size. +* ``test_seedvr_groupnorm_low_limit_uses_chunked_groupnorm_path`` — with a + deliberately low limit (``1e-9 GiB``), the chunked branch must run and + the full GroupNorm forward must NOT run. + +Each case discriminates the two branches with two independent observers: + +1. ``nn.Module.register_forward_hook`` on the GroupNorm — fires only on the + full-path branch ``norm_layer(x)``; the chunked branch bypasses the + module ``__call__`` and goes through ``F.group_norm`` directly. +2. ``unittest.mock.patch.object(vae.F, 'group_norm', ...)`` spy with + ``side_effect`` delegating to the real ``torch.nn.functional.group_norm`` + — captures every direct ``F.group_norm`` call's ``num_groups`` argument. + Calls with ``num_groups < gn.num_groups`` come from the chunked branch + (``num_groups_per_chunk = gn.num_groups // num_chunks``). + +The spy uses ``*args, **kwargs`` passthrough so future ``F.group_norm`` kwargs +do not break the test. + +CPU-only by construction: the tests use a small float32 tensor and never +allocate a real model or GPU memory. +""" + +from unittest.mock import patch + +import pytest +import torch + +from comfy.cli_args import args as cli_args + +if not torch.cuda.is_available(): + cli_args.cpu = True + +import comfy.ops as comfy_ops # noqa: E402 +import comfy.ldm.seedvr.vae as vae_mod # noqa: E402 +from comfy.ldm.seedvr.vae import ( # noqa: E402 + causal_norm_wrapper, + set_norm_limit, +) + + +_NUM_CHANNELS = 8 +_NUM_GROUPS = 4 +_TENSOR_SHAPE = (1, 8, 2, 4, 4) + +# Both ``ops.GroupNorm`` subclasses appear in production paths depending on +# the active backend. The dispatch gate at ``vae.py:509`` reads +# ``isinstance(norm_layer, ops.GroupNorm)`` and matches both via MRO. +_GROUPNORM_SUBCLASSES = [ + pytest.param( + comfy_ops.disable_weight_init.GroupNorm, + id="disable_weight_init", + ), + pytest.param( + comfy_ops.manual_cast.GroupNorm, + id="manual_cast", + ), +] + + +@pytest.mark.parametrize("groupnorm_cls", _GROUPNORM_SUBCLASSES) +def test_seedvr_groupnorm_default_limit_uses_full_groupnorm_path(groupnorm_cls): + real_group_norm = vae_mod.F.group_norm + set_norm_limit(None) + try: + gn = groupnorm_cls( + num_channels=_NUM_CHANNELS, num_groups=_NUM_GROUPS + ) + gn.eval() + + forward_hook_calls = [] + + def _hook(module, inputs, output): + forward_hook_calls.append(tuple(inputs[0].shape)) + + spy_calls = [] + + def _group_norm_spy(input_tensor, num_groups_arg, *args, **kwargs): + spy_calls.append({ + "num_groups": int(num_groups_arg), + "input_shape": tuple(int(s) for s in input_tensor.shape), + }) + return real_group_norm(input_tensor, num_groups_arg, *args, **kwargs) + + handle = gn.register_forward_hook(_hook) + try: + with patch.object(vae_mod.F, "group_norm", side_effect=_group_norm_spy): + out_tensor = causal_norm_wrapper(gn, torch.randn(*_TENSOR_SHAPE)) + finally: + handle.remove() + + full_calls = len(forward_hook_calls) + chunked_calls = sum( + 1 for entry in spy_calls if entry["num_groups"] < _NUM_GROUPS + ) + + assert tuple(int(s) for s in out_tensor.shape) == _TENSOR_SHAPE, ( + f"causal_norm_wrapper output shape {tuple(out_tensor.shape)} does not " + f"match input shape {_TENSOR_SHAPE}" + ) + assert full_calls == 1, ( + f"default-limit (inf) GroupNorm gate must take the full-forward path " + f"(register_forward_hook fires exactly once); got full_calls={full_calls}" + ) + assert chunked_calls == 0, ( + f"default-limit (inf) GroupNorm gate must NOT take the chunked path " + f"(no F.group_norm call with num_groups<{_NUM_GROUPS}); got " + f"chunked_calls={chunked_calls}" + ) + finally: + set_norm_limit(None) + + +@pytest.mark.parametrize("groupnorm_cls", _GROUPNORM_SUBCLASSES) +def test_seedvr_groupnorm_low_limit_uses_chunked_groupnorm_path(groupnorm_cls): + real_group_norm = vae_mod.F.group_norm + set_norm_limit(1e-9) + try: + gn = groupnorm_cls( + num_channels=_NUM_CHANNELS, num_groups=_NUM_GROUPS + ) + gn.eval() + + forward_hook_calls = [] + + def _hook(module, inputs, output): + forward_hook_calls.append(tuple(inputs[0].shape)) + + spy_calls = [] + + def _group_norm_spy(input_tensor, num_groups_arg, *args, **kwargs): + spy_calls.append({ + "num_groups": int(num_groups_arg), + "input_shape": tuple(int(s) for s in input_tensor.shape), + }) + return real_group_norm(input_tensor, num_groups_arg, *args, **kwargs) + + handle = gn.register_forward_hook(_hook) + try: + with patch.object(vae_mod.F, "group_norm", side_effect=_group_norm_spy): + out_tensor = causal_norm_wrapper(gn, torch.randn(*_TENSOR_SHAPE)) + finally: + handle.remove() + + full_calls = len(forward_hook_calls) + chunked_calls = sum( + 1 for entry in spy_calls if entry["num_groups"] < _NUM_GROUPS + ) + + assert tuple(int(s) for s in out_tensor.shape) == _TENSOR_SHAPE, ( + f"causal_norm_wrapper output shape {tuple(out_tensor.shape)} does not " + f"match input shape {_TENSOR_SHAPE}" + ) + assert full_calls == 0, ( + f"low-limit GroupNorm gate must NOT take the full-forward path " + f"(register_forward_hook should not fire); got full_calls={full_calls}" + ) + assert chunked_calls > 0, ( + f"low-limit GroupNorm gate must take the chunked path " + f"(at least one F.group_norm call with num_groups<{_NUM_GROUPS}); got " + f"chunked_calls={chunked_calls}" + ) + finally: + set_norm_limit(None) diff --git a/tests-unit/comfy_test/test_seedvr_latent_format.py b/tests-unit/comfy_test/test_seedvr_latent_format.py new file mode 100644 index 000000000..998993c1d --- /dev/null +++ b/tests-unit/comfy_test/test_seedvr_latent_format.py @@ -0,0 +1,40 @@ +import torch + +from comfy.cli_args import args as cli_args + +if not torch.cuda.is_available(): + cli_args.cpu = True + +import comfy.latent_formats +import comfy.sample + + +class _Model: + def __init__(self, latent_format): + self._latent_format = latent_format + + def get_model_object(self, name): + assert name == "latent_format" + return self._latent_format + + +def test_seedvr2_latent_format_uses_16_channels_without_3d_empty_latent_expansion(): + latent_format = comfy.latent_formats.SeedVR2() + latent_image = torch.zeros(1, 1, 4, 5) + + fixed = comfy.sample.fix_empty_latent_channels(_Model(latent_format), latent_image) + + assert latent_format.latent_channels == 16 + assert latent_format.latent_dimensions == 2 + assert fixed.shape == (1, 16, 4, 5) + + +def test_seedvr2_empty_collapsed_latent_preserves_temporal_channel_multiples(): + latent_format = comfy.latent_formats.SeedVR2() + latent_image = torch.zeros(1, 48, 4, 5) + + fixed = comfy.sample.fix_empty_latent_channels(_Model(latent_format), latent_image) + + assert latent_format.preserve_empty_channel_multiples is True + assert fixed.shape == latent_image.shape + assert fixed.data_ptr() == latent_image.data_ptr() diff --git a/tests-unit/comfy_test/test_seedvr_rope_delegation.py b/tests-unit/comfy_test/test_seedvr_rope_delegation.py new file mode 100644 index 000000000..99d44f069 --- /dev/null +++ b/tests-unit/comfy_test/test_seedvr_rope_delegation.py @@ -0,0 +1,176 @@ +"""Regression test: ``comfy.ldm.seedvr.model.apply_rotary_emb`` must delegate +to ``comfy.ldm.flux.math.apply_rope1`` and produce exact-equality output +across the wrapper's slicing, scaling, and concatenation logic. Drift between +the wrapper and the delegate would silently corrupt SeedVR2's RoPE; this test +fails loudly on any future drift. + +Each parametrized case does both: + +1. Patches ``comfy.ldm.seedvr.model.apply_rope1`` with a ``wraps``-style spy + and asserts ``spy.call_count >= 1`` so a future change that inlines the + math and stops calling ``apply_rope1`` fails the test. +2. Compares the wrapper's output against a hand-rolled reproduction using + ``torch.testing.assert_close(rtol=0, atol=0)`` -- exact tensor equality, + not bit-equality (``+0.0`` vs ``-0.0`` and NaN payloads can still match); + the assertion catches any future kernel-precision drift in the + ``apply_rope1`` dispatch. + +The test uses a local ``torch.Generator`` so global RNG state is not mutated. +Parametrization covers non-default ``start_index`` and ``scale`` and a case +where ``freqs.shape[0] > t.shape[seq_dim]`` so the wrapper's +``slice_at_dim(freqs, slice(-seq_len, None), dim=0)`` path is exercised. +Imports are taken at module level. Heavy-import stubbing of +``comfy.model_management`` was attempted but is insufficient on this live +import chain (``comfy.ldm.seedvr.model`` pulls +``comfy.ldm.modules.diffusionmodules.model -> comfy.ops -> +comfy.memory_management -> comfy.quant_ops -> comfy_kitchen.tensor -> +torch._dynamo``), so this test intentionally runs against the real modules +to fail loudly if that import path or runtime state drifts. Other tests in +this repo (e.g. ``tests-unit/comfy_extras_test/image_stitch_test.py``) do +stub via ``patch.dict(sys.modules, ...)`` for narrower targets; the choice +here is local to this regression and not a repo-wide convention. +""" + +from unittest.mock import patch + +import pytest +import torch + +# CPU-only CI fix: ``comfy.ldm.seedvr.model`` transitively imports +# ``comfy.model_management``, whose import-time ``get_torch_device()`` call +# probes ``torch.cuda.current_device()`` unless ``comfy.cli_args.args.cpu`` is +# set. On a CPU-only build that probe can raise during test collection before +# the ``cuda`` case has had a chance to be skipped. Match the pattern used by +# ``tests-unit/comfy_quant/test_mixed_precision.py``: flip ``args.cpu`` before +# importing any ``comfy.ldm.*`` symbol. +from comfy.cli_args import args + +if not torch.cuda.is_available(): + args.cpu = True + +import comfy.ldm.seedvr.model as seedvr_model # noqa: E402 +from comfy.ldm.flux.math import apply_rope1 # noqa: E402 +from comfy.ldm.seedvr.model import apply_rotary_emb # noqa: E402 + + +def _direct_reproduction(freqs, t, start_index=0, scale=1.0, seq_dim=-2): + """Reproduce the body of ``apply_rotary_emb`` for the default case where + ``freqs.ndim == 2`` and ``t.ndim == 3`` (implicit ``freqs_seq_dim=0``). + Mirrors the wrapper's ``slice_at_dim(freqs, slice(-seq_len, None), dim=0)`` + step when freqs is longer than ``t`` along ``seq_dim``. Calls the real + ``apply_rope1`` via the test module's import (the test patches the + ``seedvr_model.apply_rope1`` attribute; this call uses the unpatched + ``flux.math`` symbol). + """ + if freqs.ndim == 2 and t.ndim == 3: + seq_len = t.shape[seq_dim] + freqs = freqs[-seq_len:] + + rot_feats = freqs.shape[-1] + end_index = start_index + rot_feats + t_left = t[..., :start_index] + t_middle = t[..., start_index:end_index] + t_right = t[..., end_index:] + angles = freqs.to(t_middle.device)[..., ::2] + cos = torch.cos(angles) * scale + sin = torch.sin(angles) * scale + col0 = torch.stack([cos, sin], dim=-1) + col1 = torch.stack([-sin, cos], dim=-1) + freqs_mat = torch.stack([col0, col1], dim=-1) + t_middle_out = apply_rope1(t_middle, freqs_mat) + return torch.cat((t_left, t_middle_out, t_right), dim=-1).type(t.dtype) + + +def _cpu_trig_supported(dtype): + """Return whether ``torch.cos`` (and by symmetry ``torch.sin``) is + implemented for the given dtype on CPU on the current runtime. Some + PyTorch CPU wheels don't implement trig ops for ``float16`` / ``bfloat16`` + and raise at runtime; the parametrized cases for those dtypes are skipped + when that's the case so CI remains stable across PyTorch builds. + """ + try: + torch.cos(torch.zeros(1, dtype=dtype)) + except (RuntimeError, TypeError): + return False + return True + + +_CPU_FP16_TRIG_OK = _cpu_trig_supported(torch.float16) +_CPU_BF16_TRIG_OK = _cpu_trig_supported(torch.bfloat16) + + +# (device, dtype, t_shape, freqs_shape, start_index, scale) +_CASES = [ + pytest.param("cpu", torch.float32, (1, 8, 16), (8, 16), 0, 1.0, + id="cpu-float32-base"), + pytest.param( + "cpu", torch.float16, (1, 8, 16), (8, 16), 0, 1.0, + id="cpu-float16-base", + marks=pytest.mark.skipif( + not _CPU_FP16_TRIG_OK, + reason="torch.cos/torch.sin unsupported for float16 tensors on CPU", + ), + ), + pytest.param( + "cpu", torch.bfloat16, (1, 8, 16), (8, 16), 0, 1.0, + id="cpu-bfloat16-base", + marks=pytest.mark.skipif( + not _CPU_BF16_TRIG_OK, + reason="torch.cos/torch.sin unsupported for bfloat16 tensors on CPU", + ), + ), + pytest.param("cpu", torch.float32, (2, 16, 32), (16, 32), 0, 1.0, + id="cpu-float32-larger"), + pytest.param("cpu", torch.float32, (1, 8, 24), (8, 16), 4, 1.0, + id="cpu-float32-non-empty-left-and-right-slices"), + pytest.param("cpu", torch.float32, (1, 8, 16), (8, 16), 0, 0.5, + id="cpu-float32-non-default-scale"), + pytest.param("cpu", torch.float32, (1, 8, 16), (12, 16), 0, 1.0, + id="cpu-float32-freqs-longer-than-seq"), + pytest.param( + "cuda", torch.float16, (1, 8, 16), (8, 16), 0, 1.0, + id="cuda-float16-base", + marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="no cuda"), + ), +] + + +@pytest.mark.parametrize("device,dtype,t_shape,freqs_shape,start_index,scale", _CASES) +def test_apply_rotary_emb_delegates_to_apply_rope1( + device, dtype, t_shape, freqs_shape, start_index, scale +): + generator = torch.Generator(device=device).manual_seed(0) + t = torch.randn(*t_shape, dtype=dtype, device=device, generator=generator) + freqs = torch.randn(*freqs_shape, dtype=dtype, device=device, generator=generator) + + # Patch the apply_rope1 symbol as imported into seedvr.model with a wraps + # spy: a future change that inlines the math and stops calling the + # imported apply_rope1 makes spy.call_count == 0 and fails the test. + with patch.object( + seedvr_model, "apply_rope1", wraps=seedvr_model.apply_rope1 + ) as spy: + wrapper_out = apply_rotary_emb( + freqs, t, start_index=start_index, scale=scale + ) + + assert spy.call_count >= 1, ( + "apply_rotary_emb did not call comfy.ldm.seedvr.model.apply_rope1; " + "the delegation invariant is broken" + ) + + direct_out = _direct_reproduction( + freqs, t, start_index=start_index, scale=scale + ) + + msg = ( + f"apply_rotary_emb output does not match direct apply_rope1 " + f"reproduction (device={device}, dtype={dtype}, t_shape={t_shape}, " + f"freqs_shape={freqs_shape}, start_index={start_index}, scale={scale})" + ) + torch.testing.assert_close( + wrapper_out, + direct_out, + rtol=0, + atol=0, + msg=msg, + ) diff --git a/tests-unit/comfy_test/test_seedvr_rope_rewrite.py b/tests-unit/comfy_test/test_seedvr_rope_rewrite.py new file mode 100644 index 000000000..5b06eed7d --- /dev/null +++ b/tests-unit/comfy_test/test_seedvr_rope_rewrite.py @@ -0,0 +1,335 @@ +"""Regression tests for the SeedVR2 native RoPE rewrite that replaces the +``apply_rotary_emb`` wrapper inside ``NaMMRotaryEmbedding3d.forward`` with +direct calls to ``comfy.ldm.flux.math.apply_rope1`` — matching the pattern +used by the other 7 ComfyUI native-DiT models (flux, hidream, kandinsky5, +lumina, qwen_image, wan, sam3). + +The wrapper builds a 2x2 ``freqs_mat`` and ends in ``torch.cat((t_left, +t_middle_out, t_right), dim=-1)``; that cat OOMs on the largest cell of the +SeedVR2 native_3b non-tiled corpus (VideoLQ_000 1280x960x100 on RTX 5090 +32GB). Canonical and numz pass the same cell because both call +``rotary_embedding_torch.apply_rotary_emb`` directly. The fix moves the +NaMMRotaryEmbedding3d path onto ``apply_rope1`` directly with freqs in +flux-canonical shape ``[..., d/2, 2, 2]`` (cos/-sin/sin/cos baked in). + +This test file pins four invariants the rewrite must satisfy: + +1. ``NaMMRotaryEmbedding3d.forward`` calls ``apply_rope1`` 4 times per + forward (vid_q, vid_k, txt_q, txt_k) and 0 times into the + ``apply_rotary_emb`` wrapper. +2. ``NaMMRotaryEmbedding3d.get_freqs`` returns freqs in flux-canonical shape + ``[..., d/2, 2, 2]`` with the cos/-sin/sin/cos pattern from + ``comfy/ldm/flux/math.py:rope`` (line 27). +3. The forward output is tensor-equal at fp32 against an oracle computed + from the unchanged ``apply_rotary_emb`` wrapper fed with the legacy + freqs layout — proving the rewrite is algorithmically lossless. +4. AST: no ``apply_rotary_emb`` call sites remain inside + ``NaMMRotaryEmbedding3d.forward``. + +The wrapper itself stays in the file (still used by +``RotaryEmbedding3d.forward`` lines 434-435 and the staticmethod +registration on lucidrains' ``RotaryEmbedding`` line 323). Out of scope +here. + +Pre-import CPU-only guard mirrors ``test_seedvr_rope_delegation.py`` — +``comfy.ldm.seedvr.model`` transitively imports ``comfy.model_management`` +which probes ``torch.cuda.current_device()`` at import time unless +``args.cpu`` is set first. +""" + +from __future__ import annotations + +import ast +import inspect +from pathlib import Path +from unittest.mock import patch + +import torch + +from comfy.cli_args import args + +if not torch.cuda.is_available(): + args.cpu = True + +import comfy.ldm.seedvr.model as seedvr_model # noqa: E402 +from comfy.ldm.seedvr.model import ( # noqa: E402 + Cache, + NaMMRotaryEmbedding3d, +) + + +# Test rig dimensions. dim=192 → per-axis rope dim = 64 (even, lucidrains +# requirement). vid_shape=(2,4,4) → L_vid = 32. txt_shape=(8,) → L_txt = 8. +# heads = 4. These are all small enough to run on CPU in milliseconds. +_DIM = 192 +_HEADS = 4 +_VID_T, _VID_H, _VID_W = 2, 4, 4 +_TXT_L = 8 +_L_VID = _VID_T * _VID_H * _VID_W +_SEED = 0 + + +def _make_inputs(dtype=torch.float32, device="cpu"): + """Construct the 6 forward inputs + cache. Deterministic via local + Generator so global RNG state is not mutated. + """ + g = torch.Generator(device=device).manual_seed(_SEED) + vid_q = torch.randn(_L_VID, _HEADS, _DIM, dtype=dtype, device=device, generator=g) + vid_k = torch.randn(_L_VID, _HEADS, _DIM, dtype=dtype, device=device, generator=g) + txt_q = torch.randn(_TXT_L, _HEADS, _DIM, dtype=dtype, device=device, generator=g) + txt_k = torch.randn(_TXT_L, _HEADS, _DIM, dtype=dtype, device=device, generator=g) + vid_shape = torch.tensor([[_VID_T, _VID_H, _VID_W]], dtype=torch.long, device=device) + txt_shape = torch.tensor([[_TXT_L]], dtype=torch.long, device=device) + cache = Cache(disable=True) + return vid_q, vid_k, vid_shape, txt_q, txt_k, txt_shape, cache + + +def _legacy_get_freqs(rope: NaMMRotaryEmbedding3d, vid_shape, txt_shape): + """Reproduce the pre-rewrite ``get_freqs`` body verbatim against + ``self.get_axial_freqs`` (parent ``RotaryEmbeddingBase`` method, + unchanged by the rewrite). Used by Test 3 to compute the oracle from + the wrapper path post-rewrite, when ``rope.get_freqs`` itself returns + the new flux-canonical shape. + """ + max_temporal = 0 + max_height = 0 + max_width = 0 + max_txt_len = 0 + for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()): + max_temporal = max(max_temporal, l + f) + max_height = max(max_height, h) + max_width = max(max_width, w) + max_txt_len = max(max_txt_len, l) + with torch.amp.autocast(device_type="cuda", enabled=False): + vid_freqs_full = rope.get_axial_freqs( + min(max_temporal + 16, 1024), + min(max_height + 4, 128), + min(max_width + 4, 128), + ).float() + txt_freqs_full = rope.get_axial_freqs(min(max_txt_len + 16, 1024)) + vid_freq_list, txt_freq_list = [], [] + for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()): + vid_freq = vid_freqs_full[l : l + f, :h, :w].reshape(-1, vid_freqs_full.size(-1)) + txt_freq = txt_freqs_full[:l].repeat(1, 3).reshape(-1, vid_freqs_full.size(-1)) + vid_freq_list.append(vid_freq) + txt_freq_list.append(txt_freq) + return torch.cat(vid_freq_list, dim=0), torch.cat(txt_freq_list, dim=0) + + +def _legacy_forward(rope: NaMMRotaryEmbedding3d, vid_q, vid_k, vid_shape, + txt_q, txt_k, txt_shape): + """Compute expected forward output via the unchanged + ``apply_rotary_emb`` wrapper fed with legacy-shape freqs. This is the + oracle. The wrapper itself is out of scope for the rewrite (Shape B). + """ + vid_freqs, txt_freqs = _legacy_get_freqs(rope, vid_shape, txt_shape) + vid_freqs = vid_freqs.to(vid_q.device) + txt_freqs = txt_freqs.to(txt_q.device) + + from einops import rearrange + + vid_q = rearrange(vid_q, "L h d -> h L d") + vid_k = rearrange(vid_k, "L h d -> h L d") + vid_q_out = seedvr_model.apply_rotary_emb(vid_freqs, vid_q.float()).to(vid_q.dtype) + vid_k_out = seedvr_model.apply_rotary_emb(vid_freqs, vid_k.float()).to(vid_k.dtype) + vid_q_out = rearrange(vid_q_out, "h L d -> L h d") + vid_k_out = rearrange(vid_k_out, "h L d -> L h d") + + txt_q = rearrange(txt_q, "L h d -> h L d") + txt_k = rearrange(txt_k, "L h d -> h L d") + txt_q_out = seedvr_model.apply_rotary_emb(txt_freqs, txt_q.float()).to(txt_q.dtype) + txt_k_out = seedvr_model.apply_rotary_emb(txt_freqs, txt_k.float()).to(txt_k.dtype) + txt_q_out = rearrange(txt_q_out, "h L d -> L h d") + txt_k_out = rearrange(txt_k_out, "h L d -> L h d") + return vid_q_out, vid_k_out, txt_q_out, txt_k_out + + +# Test 1 — drives AC-4 (call-graph): forward must reach apply_rope1 directly, +# never via the apply_rotary_emb wrapper. + +def test_namm_forward_calls_apply_rope1_directly(): + rope = NaMMRotaryEmbedding3d(dim=_DIM) + vid_q, vid_k, vid_shape, txt_q, txt_k, txt_shape, cache = _make_inputs() + + with patch.object( + seedvr_model, "apply_rotary_emb", wraps=seedvr_model.apply_rotary_emb + ) as wrapper_spy, patch.object( + seedvr_model, "apply_rope1", wraps=seedvr_model.apply_rope1 + ) as rope1_spy: + rope.forward(vid_q, vid_k, vid_shape, txt_q, txt_k, txt_shape, cache) + + assert wrapper_spy.call_count == 0, ( + f"NaMMRotaryEmbedding3d.forward must not call apply_rotary_emb " + f"(saw {wrapper_spy.call_count} calls); the rewrite must rewire " + f"the 4 forward sites to apply_rope1 directly" + ) + assert rope1_spy.call_count == 4, ( + f"NaMMRotaryEmbedding3d.forward must call apply_rope1 exactly 4 " + f"times (vid_q, vid_k, txt_q, txt_k); saw {rope1_spy.call_count}" + ) + + +# Test 2 — drives the get_freqs shape change to flux-canonical layout. + +def test_get_freqs_emits_flux_canonical_shape(): + rope = NaMMRotaryEmbedding3d(dim=_DIM) + vid_shape = torch.tensor([[_VID_T, _VID_H, _VID_W]], dtype=torch.long) + txt_shape = torch.tensor([[_TXT_L]], dtype=torch.long) + + vid_freqs, txt_freqs = rope.get_freqs(vid_shape, txt_shape) + + # Flux's `rope()` (comfy/ldm/flux/math.py:17-29) emits freqs in shape + # [..., d/2, 2, 2] via stack([cos, -sin, sin, cos], dim=-1) + + # rearrange("b n d (i j) -> b n d i j", i=2, j=2). The rewrite must + # match: ndim >= 4, last two dims both == 2. + assert vid_freqs.ndim >= 4, ( + f"vid_freqs.ndim must be >= 4 (flux-canonical layout has trailing " + f"[..., d/2, 2, 2]); got ndim={vid_freqs.ndim}, shape={tuple(vid_freqs.shape)}" + ) + assert vid_freqs.shape[-1] == 2, ( + f"vid_freqs.shape[-1] must be 2 (rotation matrix column); got " + f"shape={tuple(vid_freqs.shape)}" + ) + assert vid_freqs.shape[-2] == 2, ( + f"vid_freqs.shape[-2] must be 2 (rotation matrix row); got " + f"shape={tuple(vid_freqs.shape)}" + ) + assert txt_freqs.ndim >= 4, ( + f"txt_freqs must also be flux-canonical; got ndim={txt_freqs.ndim}, " + f"shape={tuple(txt_freqs.shape)}" + ) + assert txt_freqs.shape[-1] == 2 and txt_freqs.shape[-2] == 2, ( + f"txt_freqs trailing dims must be (2, 2); got shape={tuple(txt_freqs.shape)}" + ) + + # Verify the cos/-sin/sin/cos pattern at index 0: + # freqs_cis[..., 0, 0] = cos + # freqs_cis[..., 0, 1] = -sin + # freqs_cis[..., 1, 0] = sin + # freqs_cis[..., 1, 1] = cos + # so [0,0] == [1,1] (both cos) and [0,1] == -[1,0] (=-sin vs +sin). + cos_a = vid_freqs[..., 0, 0] + cos_b = vid_freqs[..., 1, 1] + neg_sin = vid_freqs[..., 0, 1] + sin = vid_freqs[..., 1, 0] + assert torch.allclose(cos_a, cos_b, rtol=0, atol=0), ( + "vid_freqs[..., 0, 0] must equal vid_freqs[..., 1, 1] (both = cos)" + ) + assert torch.allclose(neg_sin, -sin, rtol=0, atol=0), ( + "vid_freqs[..., 0, 1] must equal -vid_freqs[..., 1, 0] (= -sin vs +sin)" + ) + + +# Test 3 — drives AC-1: forward output is tensor-equal against the wrapper- +# fed oracle. Pre-rewrite: trivially passes (forward IS the wrapper path). +# Post-rewrite: must remain equal. Exact equality (rtol=atol=0) at fp32. + +def test_namm_forward_output_tensor_equal_against_legacy_oracle(): + rope = NaMMRotaryEmbedding3d(dim=_DIM) + vid_q, vid_k, vid_shape, txt_q, txt_k, txt_shape, cache = _make_inputs() + + # Oracle: the unchanged apply_rotary_emb wrapper fed with legacy-shape + # freqs produced by reproducing the pre-rewrite get_freqs body. + expected_vid_q, expected_vid_k, expected_txt_q, expected_txt_k = _legacy_forward( + rope, + vid_q.clone(), vid_k.clone(), vid_shape, + txt_q.clone(), txt_k.clone(), txt_shape, + ) + + # Actual: NaMMRotaryEmbedding3d.forward (under test). + actual_vid_q, actual_vid_k, actual_txt_q, actual_txt_k = rope.forward( + vid_q.clone(), vid_k.clone(), vid_shape, + txt_q.clone(), txt_k.clone(), txt_shape, cache, + ) + + torch.testing.assert_close(actual_vid_q, expected_vid_q, rtol=0, atol=0, + msg="vid_q output diverges from wrapper oracle") + torch.testing.assert_close(actual_vid_k, expected_vid_k, rtol=0, atol=0, + msg="vid_k output diverges from wrapper oracle") + torch.testing.assert_close(actual_txt_q, expected_txt_q, rtol=0, atol=0, + msg="txt_q output diverges from wrapper oracle") + torch.testing.assert_close(actual_txt_k, expected_txt_k, rtol=0, atol=0, + msg="txt_k output diverges from wrapper oracle") + + +# Test 5 — partial-rope coverage. The real SeedVR2-3B model is constructed +# with rope_dim=128, which integer-divides into 3 axes as 128//3 = 42 per- +# axis; total rope freq dims = 42*3 = 126. head_dim is 128, so the trailing +# 2 dims of each q/k must be passed through unrotated (matching the legacy +# wrapper's `t_right = t[..., end_index:]` behavior). The fp32-CPU oracle +# test (Test 3) uses dim=192 where rot_d == head_dim and the partial-rope +# path collapses to a single apply_rope1 call. This test exercises the +# partial path explicitly with dim=128 and asserts the rewired forward +# still tensor-equals the wrapper oracle in that regime. + +def test_namm_forward_partial_rope_passthrough_matches_wrapper_oracle(): + rope = NaMMRotaryEmbedding3d(dim=128) + g = torch.Generator(device="cpu").manual_seed(_SEED) + vid_q = torch.randn(_L_VID, _HEADS, 128, dtype=torch.float32, device="cpu", generator=g) + vid_k = torch.randn(_L_VID, _HEADS, 128, dtype=torch.float32, device="cpu", generator=g) + txt_q = torch.randn(_TXT_L, _HEADS, 128, dtype=torch.float32, device="cpu", generator=g) + txt_k = torch.randn(_TXT_L, _HEADS, 128, dtype=torch.float32, device="cpu", generator=g) + vid_shape = torch.tensor([[_VID_T, _VID_H, _VID_W]], dtype=torch.long) + txt_shape = torch.tensor([[_TXT_L]], dtype=torch.long) + cache = Cache(disable=True) + + expected_vid_q, expected_vid_k, expected_txt_q, expected_txt_k = _legacy_forward( + rope, vid_q.clone(), vid_k.clone(), vid_shape, txt_q.clone(), txt_k.clone(), txt_shape, + ) + actual_vid_q, actual_vid_k, actual_txt_q, actual_txt_k = rope.forward( + vid_q.clone(), vid_k.clone(), vid_shape, txt_q.clone(), txt_k.clone(), txt_shape, cache, + ) + + # Confirm the partial-rope contract: rot_d (= 2 * freqs_cis.shape[-3]) is + # 126 (= 42*3), strictly less than head_dim 128. The trailing 2 head-dims + # are pass-through. + vid_freqs, _ = rope.get_freqs(vid_shape, txt_shape) + rot_d = 2 * vid_freqs.shape[-3] + assert rot_d == 126, f"expected rot_d=126 for dim=128 model; got {rot_d}" + assert rot_d < 128, "partial-rope path must trigger (rot_d < head_dim)" + + torch.testing.assert_close(actual_vid_q, expected_vid_q, rtol=0, atol=0, + msg="vid_q partial-rope output diverges from wrapper oracle") + torch.testing.assert_close(actual_vid_k, expected_vid_k, rtol=0, atol=0, + msg="vid_k partial-rope output diverges from wrapper oracle") + torch.testing.assert_close(actual_txt_q, expected_txt_q, rtol=0, atol=0, + msg="txt_q partial-rope output diverges from wrapper oracle") + torch.testing.assert_close(actual_txt_k, expected_txt_k, rtol=0, atol=0, + msg="txt_k partial-rope output diverges from wrapper oracle") + + +# Test 4 — drives AC-4 statically: AST walk over NaMMRotaryEmbedding3d.forward +# must find zero references to the apply_rotary_emb symbol. + +def test_namm_forward_ast_has_no_apply_rotary_emb_calls(): + source_path = Path(inspect.getsourcefile(NaMMRotaryEmbedding3d)) + tree = ast.parse(source_path.read_text(encoding="utf-8")) + + namm_class = None + for node in ast.walk(tree): + if isinstance(node, ast.ClassDef) and node.name == "NaMMRotaryEmbedding3d": + namm_class = node + break + assert namm_class is not None, ( + f"could not locate class NaMMRotaryEmbedding3d in {source_path}" + ) + + forward_fn = None + for node in namm_class.body: + if isinstance(node, ast.FunctionDef) and node.name == "forward": + forward_fn = node + break + assert forward_fn is not None, ( + "could not locate NaMMRotaryEmbedding3d.forward" + ) + + offending = [] + for node in ast.walk(forward_fn): + if isinstance(node, ast.Name) and node.id == "apply_rotary_emb": + offending.append((node.lineno, node.col_offset)) + + assert not offending, ( + f"NaMMRotaryEmbedding3d.forward must not reference apply_rotary_emb; " + f"found {len(offending)} reference(s) at line:col positions {offending}. " + f"The rewrite must rewire to apply_rope1 directly." + ) diff --git a/tests-unit/comfy_test/test_seedvr_vae_attention_fence.py b/tests-unit/comfy_test/test_seedvr_vae_attention_fence.py new file mode 100644 index 000000000..e5340116f --- /dev/null +++ b/tests-unit/comfy_test/test_seedvr_vae_attention_fence.py @@ -0,0 +1,37 @@ +from unittest.mock import patch + +import torch +from torch import nn + +import comfy.ldm.seedvr.vae as seedvr_vae + + +def test_seedvr_vae_4d_self_attention_uses_vae_attention_with_channel_first_layout(): + calls = {} + + def vae_attention_spy(q, k, v): + calls["q"] = q.detach().clone() + calls["k"] = k.detach().clone() + calls["v"] = v.detach().clone() + return q + + def global_attention_forbidden(*args, **kwargs): + raise AssertionError("SeedVR2 VAE self-attention must not use global optimized_attention") + + with patch.object(seedvr_vae, "vae_attention", return_value=vae_attention_spy): + attention = seedvr_vae.Attention(query_dim=4, heads=1, dim_head=4) + + attention.to_q = nn.Identity() + attention.to_k = nn.Identity() + attention.to_v = nn.Identity() + attention.to_out[0] = nn.Identity() + + hidden_states = torch.arange(24, dtype=torch.float32).reshape(1, 4, 2, 3) + + with patch.object(seedvr_vae, "optimized_attention", global_attention_forbidden): + output = attention(hidden_states) + + assert torch.equal(calls["q"], hidden_states) + assert torch.equal(calls["k"], hidden_states) + assert torch.equal(calls["v"], hidden_states) + assert torch.equal(output, hidden_states) diff --git a/tests-unit/comfy_test/test_seedvr_var_attention_backends.py b/tests-unit/comfy_test/test_seedvr_var_attention_backends.py new file mode 100644 index 000000000..d62167b41 --- /dev/null +++ b/tests-unit/comfy_test/test_seedvr_var_attention_backends.py @@ -0,0 +1,476 @@ +import subprocess +import sys +import textwrap +import ast +import inspect + +import torch + +from comfy.cli_args import args + +if not torch.cuda.is_available(): + args.cpu = True + +import comfy.ldm.modules.attention as attention # noqa: E402 + + +_VAR_BACKENDS = ( + "var_attention_sage", + "var_attention_sage3", + "var_attention_flash", + "var_attention_flash3", + "var_attention_sub_quad", + "var_attention_split", +) + + +def _inputs(): + heads = 2 + head_dim = 4 + total = 6 + q = torch.randn(total, heads, head_dim) + k = torch.randn(total, heads, head_dim) + v = torch.randn(total, heads, head_dim) + cu = torch.tensor([0, 3, 6], dtype=torch.int32) + return q, k, v, heads, cu + + +def _has_dynamo_disable(decorator): + return ( + isinstance(decorator, ast.Attribute) + and decorator.attr == "disable" + and isinstance(decorator.value, ast.Attribute) + and decorator.value.attr == "_dynamo" + and isinstance(decorator.value.value, ast.Name) + and decorator.value.value.id == "torch" + ) + + +def test_var_attention_backend_functions_are_dynamo_disabled_and_signature_compatible(): + tree = ast.parse(inspect.getsource(attention)) + functions = {node.name: node for node in tree.body if isinstance(node, ast.FunctionDef)} + + for name in _VAR_BACKENDS: + node = functions[name] + positional = [arg.arg for arg in node.args.args[:6]] + keyword_only = {arg.arg for arg in node.args.kwonlyargs} + assert positional == ["q", "k", "v", "heads", "cu_seqlens_q", "cu_seqlens_k"] + assert node.args.vararg is not None + assert node.args.kwarg is not None + assert "skip_reshape" in keyword_only + assert "skip_output_reshape" in keyword_only + assert any(_has_dynamo_disable(decorator) for decorator in node.decorator_list) + + +def test_var_attention_registry_contains_always_available_entries(): + assert attention.REGISTERED_ATTENTION_FUNCTIONS["var_attention_pytorch"] is attention.var_attention_pytorch + assert attention.REGISTERED_ATTENTION_FUNCTIONS["var_attention_sub_quad"] is attention.var_attention_sub_quad + assert attention.REGISTERED_ATTENTION_FUNCTIONS["var_attention_split"] is attention.var_attention_split + + +def _run_attention_import(flag, fake_modules=True, fake_module_code=None): + argv = ["pytest-subprocess", "--cpu", "--disable-xformers"] + if flag: + argv.append(flag) + if fake_module_code is None: + fake_module_code = "" + if fake_modules and not fake_module_code: + fake_module_code = """ +import types + +sageattention = types.ModuleType("sageattention") +sageattention.sageattn = lambda *a, **k: a[0] +sageattention.sageattn_varlen = lambda *a, **k: a[0] +sys.modules["sageattention"] = sageattention + +sageattn3 = types.ModuleType("sageattn3") +sageattn3.sageattn3_blackwell = lambda *a, **k: a[0] +sys.modules["sageattn3"] = sageattn3 + +flash_attn = types.ModuleType("flash_attn") +flash_attn.flash_attn_func = lambda q, k, v, **kwargs: q +flash_attn.flash_attn_varlen_func = lambda **kwargs: kwargs["q"] +sys.modules["flash_attn"] = flash_attn + +flash_attn_interface = types.ModuleType("flash_attn_interface") +flash_attn_interface.flash_attn_varlen_func = lambda **kwargs: (kwargs["q"], None) +sys.modules["flash_attn_interface"] = flash_attn_interface +""" + code = ( + "import sys\n" + "import comfy.options\n" + "comfy.options.enable_args_parsing()\n" + f"sys.argv = {argv!r}\n" + f"{textwrap.dedent(fake_module_code)}\n" + "import comfy.ldm.modules.attention as attention\n" + "print(attention.optimized_var_attention.__name__)\n" + ) + return subprocess.run( + [sys.executable, "-c", code], + cwd=".", + text=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + check=False, + ) + + +def test_var_attention_rebind_sage_launch_flag(): + result = _run_attention_import("--use-sage-attention") + assert result.returncode == 0, result.stderr + assert result.stdout.strip() == "var_attention_sage" + + +def test_var_attention_rebind_flash_launch_flag_uses_pytorch_varlen_in_cpu_mode(): + result = _run_attention_import("--use-flash-attention") + assert result.returncode == 0, result.stderr + assert result.stdout.strip() == "var_attention_pytorch" + + +def test_var_attention_rebind_sage_launch_flag_without_varlen_uses_pytorch(): + result = _run_attention_import( + "--use-sage-attention", + fake_module_code=""" +import types + +sageattention = types.ModuleType("sageattention") +sageattention.sageattn = lambda *a, **k: a[0] +sys.modules["sageattention"] = sageattention +""", + ) + assert result.returncode == 0, result.stderr + assert result.stdout.strip() == "var_attention_pytorch" + + +def test_var_attention_rebind_flash_launch_flag_without_varlen_uses_pytorch(): + result = _run_attention_import( + "--use-flash-attention", + fake_module_code=""" +import types + +flash_attn = types.ModuleType("flash_attn") +flash_attn.flash_attn_func = lambda q, k, v, **kwargs: q +sys.modules["flash_attn"] = flash_attn +""", + ) + assert result.returncode == 0, result.stderr + assert result.stdout.strip() == "var_attention_pytorch" + + +def test_var_attention_rebind_pytorch_launch_flag(): + result = _run_attention_import("--use-pytorch-cross-attention") + assert result.returncode == 0, result.stderr + assert result.stdout.strip() == "var_attention_pytorch" + + +def test_var_attention_rebind_split_launch_flag(): + result = _run_attention_import("--use-split-cross-attention") + assert result.returncode == 0, result.stderr + assert result.stdout.strip() == "var_attention_split" + + +def test_var_attention_rebind_default_launch_flags(): + result = _run_attention_import("") + assert result.returncode == 0, result.stderr + assert result.stdout.strip() == "var_attention_sub_quad" + + +def test_var_attention_sage_uses_cu_seqlens_contract(monkeypatch): + q, k, v, heads, cu = _inputs() + captured = {} + + def fake_sageattn_varlen(q, k, v, cu_q, cu_k, max_q, max_k, is_causal, sm_scale): + captured.update(cu_q=cu_q, cu_k=cu_k, max_q=max_q, max_k=max_k, is_causal=is_causal) + return torch.zeros_like(q) + + monkeypatch.setattr(attention, "SAGE_ATTENTION_VARLEN_IS_AVAILABLE", True) + monkeypatch.setattr(attention, "sageattn_varlen", fake_sageattn_varlen, raising=False) + + out = attention.var_attention_sage(q, k, v, heads, cu, cu, skip_reshape=True, skip_output_reshape=True) + + assert tuple(out.shape) == tuple(q.shape) + assert torch.equal(captured["cu_q"], cu) + assert torch.equal(captured["cu_k"], cu) + assert captured["max_q"] == 3 + assert captured["max_k"] == 3 + assert captured["is_causal"] is False + + +def test_var_attention_sage_runtime_error_preserves_fallback_dtype(monkeypatch): + q, k, v, heads, cu = _inputs() + q = q.float() + k = k.half() + v = v.half() + captured = {} + + def failing_sageattn_varlen(*args, **kwargs): + raise RuntimeError("unsupported") + + def fake_var_attention_pytorch(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, skip_reshape=False, skip_output_reshape=False): + captured.update(dtype=q.dtype, k_dtype=k.dtype, v_dtype=v.dtype, skip_reshape=skip_reshape) + return torch.zeros_like(q) + + monkeypatch.setattr(attention, "SAGE_ATTENTION_VARLEN_IS_AVAILABLE", True) + monkeypatch.setattr(attention, "sageattn_varlen", failing_sageattn_varlen, raising=False) + monkeypatch.setattr(attention, "var_attention_pytorch", fake_var_attention_pytorch) + + out = attention.var_attention_sage(q, k, v, heads, cu, cu, skip_reshape=True, skip_output_reshape=True) + + assert out.dtype == torch.float32 + assert captured["dtype"] == torch.float32 + assert captured["k_dtype"] == torch.float32 + assert captured["v_dtype"] == torch.float32 + assert captured["skip_reshape"] is True + + +def test_var_attention_sage3_uses_cu_seqlens_contract(monkeypatch): + q, k, v, heads, cu = _inputs() + captured = {} + + def fake_sageattn3_blackwell(q, k, v, is_causal=False): + captured.update(shape=tuple(q.shape), is_causal=is_causal) + return torch.zeros_like(q) + + monkeypatch.setattr(attention, "SAGE_ATTENTION3_IS_AVAILABLE", True) + monkeypatch.setattr(attention, "sageattn3_blackwell", fake_sageattn3_blackwell, raising=False) + + out = attention.var_attention_sage3(q, k, v, heads, cu, cu, skip_reshape=True, skip_output_reshape=True) + + assert tuple(out.shape) == tuple(q.shape) + assert captured["shape"] == (2, heads, 3, 4) + assert captured["is_causal"] is False + + +def test_var_attention_sage3_runtime_error_falls_back(monkeypatch): + q, k, v, heads, cu = _inputs() + q = q.float() + k = k.half() + v = v.half() + captured = {} + + def failing_sageattn3_blackwell(*args, **kwargs): + raise RuntimeError("unsupported") + + def fake_var_attention_pytorch(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, skip_reshape=False, skip_output_reshape=False): + captured.update(cu_q=cu_seqlens_q, dtype=q.dtype, k_dtype=k.dtype, v_dtype=v.dtype, skip_reshape=skip_reshape) + return torch.zeros_like(q) + + monkeypatch.setattr(attention, "SAGE_ATTENTION_VARLEN_IS_AVAILABLE", False) + monkeypatch.setattr(attention, "SAGE_ATTENTION3_IS_AVAILABLE", True) + monkeypatch.setattr(attention, "sageattn3_blackwell", failing_sageattn3_blackwell, raising=False) + monkeypatch.setattr(attention, "var_attention_pytorch", fake_var_attention_pytorch) + + out = attention.var_attention_sage3(q, k, v, heads, cu, cu, skip_reshape=True, skip_output_reshape=True) + + assert tuple(out.shape) == tuple(q.shape) + assert torch.equal(captured["cu_q"], cu) + assert captured["dtype"] == torch.float32 + assert captured["k_dtype"] == torch.float32 + assert captured["v_dtype"] == torch.float32 + assert captured["skip_reshape"] is True + + +def test_var_attention_flash_uses_cu_seqlens_contract(monkeypatch): + q, k, v, heads, cu = _inputs() + captured = {} + + def fake_flash_attn_varlen_func(**kwargs): + captured.update(kwargs) + return torch.zeros_like(kwargs["q"]) + + monkeypatch.setattr(attention, "FLASH_ATTENTION_VARLEN_IS_AVAILABLE", True) + monkeypatch.setattr(attention, "flash_attn_varlen_func", fake_flash_attn_varlen_func, raising=False) + + out = attention.var_attention_flash(q, k, v, heads, cu, cu, skip_reshape=True, skip_output_reshape=True) + + assert tuple(out.shape) == tuple(q.shape) + assert torch.equal(captured["cu_seqlens_q"], cu) + assert torch.equal(captured["cu_seqlens_k"], cu) + assert captured["max_seqlen_q"] == 3 + assert captured["max_seqlen_k"] == 3 + + +def test_var_attention_flash_runtime_error_falls_back(monkeypatch): + q, k, v, heads, cu = _inputs() + captured = {} + + def failing_flash_attn_varlen_func(**kwargs): + raise NotImplementedError("cpu") + + def fake_var_attention_pytorch(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, skip_reshape=False, skip_output_reshape=False): + captured.update(cu_q=cu_seqlens_q, skip_reshape=skip_reshape) + return torch.zeros_like(q) + + monkeypatch.setattr(attention, "FLASH_ATTENTION_VARLEN_IS_AVAILABLE", True) + monkeypatch.setattr(attention, "flash_attn_varlen_func", failing_flash_attn_varlen_func, raising=False) + monkeypatch.setattr(attention, "var_attention_pytorch", fake_var_attention_pytorch) + + out = attention.var_attention_flash(q, k, v, heads, cu, cu, skip_reshape=True, skip_output_reshape=True) + + assert tuple(out.shape) == tuple(q.shape) + assert torch.equal(captured["cu_q"], cu) + assert captured["skip_reshape"] is True + + +def test_var_attention_flash3_uses_cu_seqlens_contract(monkeypatch): + q, k, v, heads, cu = _inputs() + captured = {} + + def fake_flash_attn3_varlen_func(**kwargs): + captured.update(kwargs) + return torch.zeros_like(kwargs["q"]), None + + monkeypatch.setattr(attention, "flash_attn3_varlen_func", fake_flash_attn3_varlen_func, raising=False) + monkeypatch.setattr(attention, "FLASH_ATTENTION3_IS_AVAILABLE", True) + + out = attention.var_attention_flash3( + q, + k, + v, + heads, + cu, + cu, + skip_reshape=True, + skip_output_reshape=True, + dropout_p=0.25, + window_size=(16, 16), + ) + + assert tuple(out.shape) == tuple(q.shape) + assert torch.equal(captured["cu_seqlens_q"], cu) + assert torch.equal(captured["cu_seqlens_k"], cu) + assert captured["max_seqlen_q"] == 3 + assert captured["max_seqlen_k"] == 3 + assert captured["seqused_q"] is None + assert captured["seqused_k"] is None + assert "dropout_p" not in captured + assert "window_size" not in captured + + +def test_var_attention_flash3_accepts_tensor_return(monkeypatch): + q, k, v, heads, cu = _inputs() + + def fake_flash_attn3_varlen_func(**kwargs): + return torch.zeros_like(kwargs["q"]) + + monkeypatch.setattr(attention, "flash_attn3_varlen_func", fake_flash_attn3_varlen_func, raising=False) + monkeypatch.setattr(attention, "FLASH_ATTENTION3_IS_AVAILABLE", True) + + out = attention.var_attention_flash3(q, k, v, heads, cu, cu, skip_reshape=True, skip_output_reshape=True) + + assert tuple(out.shape) == tuple(q.shape) + + +def test_var_attention_flash3_runtime_error_falls_back(monkeypatch): + q, k, v, heads, cu = _inputs() + captured = {} + + def failing_flash_attn3_varlen_func(**kwargs): + raise RuntimeError("unsupported") + + def fake_var_attention_pytorch(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, skip_reshape=False, skip_output_reshape=False): + captured.update(cu_q=cu_seqlens_q, skip_reshape=skip_reshape) + return torch.zeros_like(q) + + monkeypatch.setattr(attention, "FLASH_ATTENTION3_IS_AVAILABLE", True) + monkeypatch.setattr(attention, "flash_attn3_varlen_func", failing_flash_attn3_varlen_func, raising=False) + monkeypatch.setattr(attention, "var_attention_pytorch", fake_var_attention_pytorch) + + out = attention.var_attention_flash3(q, k, v, heads, cu, cu, skip_reshape=True, skip_output_reshape=True) + + assert tuple(out.shape) == tuple(q.shape) + assert torch.equal(captured["cu_q"], cu) + assert captured["skip_reshape"] is True + + +def test_var_attention_sub_quad_uses_cu_seqlens_contract(monkeypatch): + q, k, v, heads, cu = _inputs() + captured = {} + + def fake_var_attention_pytorch(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, skip_reshape=False, skip_output_reshape=False): + captured.update(cu_q=cu_seqlens_q, cu_k=cu_seqlens_k, skip_reshape=skip_reshape) + return torch.zeros_like(q) + + monkeypatch.setattr(attention, "var_attention_pytorch", fake_var_attention_pytorch) + + out = attention.var_attention_sub_quad(q, k, v, heads, cu, cu, skip_reshape=True, skip_output_reshape=True) + + assert tuple(out.shape) == tuple(q.shape) + assert torch.equal(captured["cu_q"], cu) + assert torch.equal(captured["cu_k"], cu) + assert captured["skip_reshape"] is True + + +def test_var_attention_split_uses_cu_seqlens_contract(monkeypatch): + q, k, v, heads, cu = _inputs() + captured = {} + + def fake_var_attention_pytorch_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, skip_reshape=False, skip_output_reshape=False): + captured.update(cu_q=cu_seqlens_q, cu_k=cu_seqlens_k, skip_reshape=skip_reshape) + return torch.zeros_like(q) + + def fail_var_attention_pytorch(*args, **kwargs): + raise AssertionError("split backend must not use nested-tensor pytorch var attention") + + monkeypatch.setattr(attention, "var_attention_pytorch", fail_var_attention_pytorch) + monkeypatch.setattr(attention, "var_attention_pytorch_split", fake_var_attention_pytorch_split) + + out = attention.var_attention_split(q, k, v, heads, cu, cu, skip_reshape=True, skip_output_reshape=True) + + assert tuple(out.shape) == tuple(q.shape) + assert torch.equal(captured["cu_q"], cu) + assert torch.equal(captured["cu_k"], cu) + assert captured["skip_reshape"] is True + + +def test_var_attention_pytorch_split_normalizes_split_indices_to_cpu(monkeypatch): + q, k, v, heads, cu = _inputs() + captured_devices = [] + real_tensor_split = torch.tensor_split + + def capture_tensor_split(input, indices_or_sections, dim=0): + if isinstance(indices_or_sections, torch.Tensor): + captured_devices.append(indices_or_sections.device.type) + return real_tensor_split(input, indices_or_sections, dim=dim) + + monkeypatch.setattr(torch, "tensor_split", capture_tensor_split) + + out = attention.var_attention_pytorch_split(q, k, v, heads, cu, cu, skip_reshape=True, skip_output_reshape=True) + + assert tuple(out.shape) == tuple(q.shape) + assert captured_devices == ["cpu", "cpu", "cpu"] + + +def test_missing_sage_package_guard_message_preserved(): + code = textwrap.dedent( + """ + import builtins + import sys + import comfy.options + + comfy.options.enable_args_parsing() + + real_import = builtins.__import__ + + def blocked_import(name, globals=None, locals=None, fromlist=(), level=0): + if name == "sageattention": + raise ImportError("No module named sageattention", name="sageattention") + return real_import(name, globals, locals, fromlist, level) + + builtins.__import__ = blocked_import + sys.argv = ["pytest-subprocess", "--cpu", "--disable-xformers", "--use-sage-attention"] + import comfy.ldm.modules.attention + """ + ) + result = subprocess.run( + [sys.executable, "-c", code], + cwd=".", + text=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + check=False, + ) + + assert result.returncode != 0 + assert "To use the `--use-sage-attention` feature" in result.stderr + assert "sageattention" in result.stderr diff --git a/tests-unit/comfy_test/test_var_attention_pytorch_seedvr2_guard.py b/tests-unit/comfy_test/test_var_attention_pytorch_seedvr2_guard.py new file mode 100644 index 000000000..f0ffe28ec --- /dev/null +++ b/tests-unit/comfy_test/test_var_attention_pytorch_seedvr2_guard.py @@ -0,0 +1,167 @@ +"""Regression tests for the SeedVR2-named guard inside +``comfy.ldm.modules.attention.var_attention_pytorch``. + +Contract: + + * If ``torch.nested.nested_tensor_from_jagged`` is unavailable on the + installed PyTorch build, ``var_attention_pytorch`` must raise + ``RuntimeError`` whose message contains both ``SeedVR2`` and + ``nested_tensor_from_jagged`` so the operator can identify the + failing attention path. A bare ``AttributeError`` from the + ``torch.nested`` lookup is non-conformant. The guard must also + cover the case where the ``torch.nested`` namespace itself is + absent (e.g. forks/builds that strip the module) — accessing + ``torch.nested`` directly would otherwise raise the same opaque + ``AttributeError`` the guard is meant to translate. + * If the API is present, the present-API path must produce the + canonical SeedVR2-inference output shape ``(total_tokens, + heads * head_dim)``. + * If the caller passes malformed offsets (off-end / non-monotonic / + size-mismatched), torch's own per-call ``RuntimeError`` propagates + unchanged: the SeedVR2-context guard fires only on the missing-API + path, never on torch's per-call shape errors. + +Each cell additionally pins the production guard at the AST level via +``inspect.getsource(var_attention_pytorch)`` so every AC fails +diagnostically on an unguarded base. +""" + +from comfy.cli_args import args +import torch + +if not torch.cuda.is_available(): + args.cpu = True + +import ast # noqa: E402 +import inspect # noqa: E402 +import logging # noqa: E402 +import textwrap # noqa: E402 +import warnings # noqa: E402 + +import pytest # noqa: E402 + +from comfy.ldm.modules.attention import var_attention_pytorch # noqa: E402 + + +def _inputs(): + """Canonical 2-D ``(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, + total_tokens, embed_dim)`` matching the live shape from GPT-3: + two segments of 3 tokens each, ``embed_dim = heads * head_dim = + 2 * 8 = 16``. + """ + heads, head_dim, total_tokens = 2, 8, 6 + embed_dim = heads * head_dim + q = torch.randn(total_tokens, embed_dim) + k = torch.randn(total_tokens, embed_dim) + v = torch.randn(total_tokens, embed_dim) + cu = torch.tensor([0, 3, 6], dtype=torch.int32) + return q, k, v, heads, cu, cu, total_tokens, embed_dim + + +def _assert_guard_source_pin(): + """Walk the AST of ``var_attention_pytorch`` and assert that the + first ``raise RuntimeError(...)`` statement appears strictly + before any attribute access named ``nested_tensor_from_jagged``. + + Substring-based source pinning (``src.index('raise RuntimeError(') + < src.index('nested_tensor_from_jagged')``) is fragile: it false- + positives on docstring or comment text containing the literal, + and false-negatives on a refactor that splits ``raise + RuntimeError(`` across lines or replaces it with a helper + raising ``RuntimeError`` from another scope. AST-walking the + function body collapses both failure modes onto the only + invariant we actually require — the guard statement dominates + the attribute access by line number. + """ + src = textwrap.dedent(inspect.getsource(var_attention_pytorch)) + tree = ast.parse(src) + raise_lines = [] + nested_lines = [] + for node in ast.walk(tree): + if isinstance(node, ast.Raise) and isinstance(node.exc, ast.Call): + func = node.exc.func + if isinstance(func, ast.Name) and func.id == "RuntimeError": + raise_lines.append(node.lineno) + if isinstance(node, ast.Attribute) and node.attr == "nested_tensor_from_jagged": + nested_lines.append(node.lineno) + assert raise_lines, ( + "var_attention_pytorch has no `raise RuntimeError(...)` AST node; " + f"the SeedVR2-named guard is missing.\n--- source ---\n{src}" + ) + assert nested_lines, ( + "var_attention_pytorch source has no `nested_tensor_from_jagged` " + f"attribute access; cannot pin guard ordering.\n" + f"--- source ---\n{src}" + ) + first_raise = min(raise_lines) + first_nested = min(nested_lines) + assert first_raise < first_nested, ( + f"`raise RuntimeError(...)` first appears at line {first_raise}, " + f"but `torch.nested.nested_tensor_from_jagged` is referenced first " + f"at line {first_nested}; the guard must precede the lookup.\n" + f"--- source ---\n{src}" + ) + + +def test_missing_api_raises_seedvr2_runtime_error(monkeypatch): + monkeypatch.delattr(torch.nested, "nested_tensor_from_jagged", raising=False) + q, k, v, heads, cu_q, cu_k, _, _ = _inputs() + + with pytest.raises(RuntimeError, match=r"SeedVR2.*nested_tensor_from_jagged"): + var_attention_pytorch(q, k, v, heads, cu_q, cu_k) + + _assert_guard_source_pin() + + +def test_missing_namespace_raises_seedvr2_runtime_error(monkeypatch): + monkeypatch.delattr(torch, "nested", raising=False) + q, k, v, heads, cu_q, cu_k, _, _ = _inputs() + + with pytest.raises(RuntimeError, match=r"SeedVR2.*nested_tensor_from_jagged"): + var_attention_pytorch(q, k, v, heads, cu_q, cu_k) + + _assert_guard_source_pin() + + +def test_present_api_returns_expected_shape(): + q, k, v, heads, cu_q, cu_k, total_tokens, embed_dim = _inputs() + + torch_fx_logger = logging.getLogger("torch.fx._symbolic_trace") + old_torch_fx_level = torch_fx_logger.level + torch_fx_logger.setLevel(logging.ERROR) + try: + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message="The PyTorch API of nested tensors is in prototype stage.*", + category=UserWarning, + ) + out = var_attention_pytorch(q, k, v, heads, cu_q, cu_k) + finally: + torch_fx_logger.setLevel(old_torch_fx_level) + + assert tuple(out.shape) == (total_tokens, embed_dim), ( + f"expected ({total_tokens}, {embed_dim}); got {tuple(out.shape)}" + ) + + _assert_guard_source_pin() + + +def test_malformed_offsets_propagates_torch_runtime_error(): + q, k, v, heads, _, _, _, _ = _inputs() + cu_q_bad = torch.tensor([0, 3, 7], dtype=torch.int32) + cu_k_ok = torch.tensor([0, 3, 6], dtype=torch.int32) + + with pytest.raises(RuntimeError) as exc_info: + var_attention_pytorch(q, k, v, heads, cu_q_bad, cu_k_ok) + + msg = str(exc_info.value) + assert "split_with_sizes" in msg, ( + f"expected torch's `split_with_sizes` error to propagate; got: {msg!r}" + ) + assert "SeedVR2" not in msg, ( + f"SeedVR2-context substring must not be substituted onto torch's " + f"per-call shape error; got: {msg!r}" + ) + + _assert_guard_source_pin()