mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-28 01:47:32 +08:00
Add SeedVR2 core coverage
This commit is contained in:
parent
6e5186ddac
commit
9eb6c7fe9e
@ -73,6 +73,24 @@ def _make_flux_schnell_comfyui_sd():
|
|||||||
return sd
|
return sd
|
||||||
|
|
||||||
|
|
||||||
|
def _make_seedvr2_7b_separate_mm_sd():
|
||||||
|
return {
|
||||||
|
"blocks.35.mlp.vid.proj_in.weight": torch.empty(1, 3072),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _make_seedvr2_7b_shared_mm_sd():
|
||||||
|
return {
|
||||||
|
"blocks.35.mlp.all.proj_in_gate.weight": torch.empty(1, 1),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _make_seedvr2_3b_shared_mm_sd():
|
||||||
|
return {
|
||||||
|
"blocks.31.mlp.all.proj_in_gate.weight": torch.empty(1, 1),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class TestModelDetection:
|
class TestModelDetection:
|
||||||
"""Verify that first-match model detection selects the correct model
|
"""Verify that first-match model detection selects the correct model
|
||||||
based on list ordering and unet_config specificity."""
|
based on list ordering and unet_config specificity."""
|
||||||
@ -125,6 +143,46 @@ class TestModelDetection:
|
|||||||
assert model_config is not None
|
assert model_config is not None
|
||||||
assert type(model_config).__name__ == "FluxSchnell"
|
assert type(model_config).__name__ == "FluxSchnell"
|
||||||
|
|
||||||
|
def test_seedvr2_7b_separate_mm_detection_config(self):
|
||||||
|
sd = _make_seedvr2_7b_separate_mm_sd()
|
||||||
|
unet_config = detect_unet_config(sd, "")
|
||||||
|
|
||||||
|
assert unet_config is not None
|
||||||
|
assert unet_config["image_model"] == "seedvr2"
|
||||||
|
assert unet_config["vid_dim"] == 3072
|
||||||
|
assert unet_config["heads"] == 24
|
||||||
|
assert unet_config["num_layers"] == 36
|
||||||
|
assert unet_config["mm_layers"] == 36
|
||||||
|
assert unet_config["mlp_type"] == "normal"
|
||||||
|
assert unet_config["qk_rope"] is True
|
||||||
|
assert unet_config["rope_type"] == "rope3d"
|
||||||
|
assert unet_config["rope_dim"] == 64
|
||||||
|
|
||||||
|
def test_seedvr2_7b_shared_mm_detection_config(self):
|
||||||
|
sd = _make_seedvr2_7b_shared_mm_sd()
|
||||||
|
unet_config = detect_unet_config(sd, "")
|
||||||
|
|
||||||
|
assert unet_config is not None
|
||||||
|
assert unet_config["image_model"] == "seedvr2"
|
||||||
|
assert unet_config["vid_dim"] == 3072
|
||||||
|
assert unet_config["heads"] == 24
|
||||||
|
assert unet_config["num_layers"] == 36
|
||||||
|
assert unet_config["mm_layers"] == 10
|
||||||
|
assert unet_config["mlp_type"] == "swiglu"
|
||||||
|
assert unet_config["qk_rope"] is True
|
||||||
|
|
||||||
|
def test_seedvr2_3b_shared_mm_detection_config(self):
|
||||||
|
sd = _make_seedvr2_3b_shared_mm_sd()
|
||||||
|
unet_config = detect_unet_config(sd, "")
|
||||||
|
|
||||||
|
assert unet_config is not None
|
||||||
|
assert unet_config["image_model"] == "seedvr2"
|
||||||
|
assert unet_config["vid_dim"] == 2560
|
||||||
|
assert unet_config["heads"] == 20
|
||||||
|
assert unet_config["num_layers"] == 32
|
||||||
|
assert unet_config["mlp_type"] == "swiglu"
|
||||||
|
assert unet_config["qk_rope"] is None
|
||||||
|
|
||||||
def test_unet_config_and_required_keys_combination_is_unique(self):
|
def test_unet_config_and_required_keys_combination_is_unique(self):
|
||||||
"""Each model in the registry must have a unique combination of
|
"""Each model in the registry must have a unique combination of
|
||||||
``unet_config`` and ``required_keys``. If two models share the same
|
``unet_config`` and ``required_keys``. If two models share the same
|
||||||
|
|||||||
192
tests-unit/comfy_test/seedvr_model_test.py
Normal file
192
tests-unit/comfy_test/seedvr_model_test.py
Normal file
@ -0,0 +1,192 @@
|
|||||||
|
"""Regression tests for SeedVR2 conditioning split hardening.
|
||||||
|
|
||||||
|
Two bare ``except:`` clauses in ``NaDiT.forward`` previously swallowed
|
||||||
|
every failure mode on (1) the input-side text-conditioning split and
|
||||||
|
(2) the output-side positive/negative split, silently substituting
|
||||||
|
wrong fallbacks: the ``positive_conditioning`` buffer (which prior to
|
||||||
|
explicit zero-init held **uninitialized** memory — NaN, residual heap
|
||||||
|
contents, never guaranteed-zero) for the input, and the un-split
|
||||||
|
tensor for the output. Real prompt-shape, dtype, OOM, and downstream
|
||||||
|
tensor failures were re-routed to "no prompt supplied" with arbitrary
|
||||||
|
buffer contents standing in for actual prompt embeddings, or to a
|
||||||
|
wrong-order output, with no diagnostic.
|
||||||
|
|
||||||
|
The fix:
|
||||||
|
|
||||||
|
1. Input-side: explicit absence predicate (``context is None`` or
|
||||||
|
``context.numel() == 0``) → fall back to ``positive_conditioning``
|
||||||
|
buffer. Any other failure (wrong rank, odd batch, dtype, OOM)
|
||||||
|
propagates the original torch exception.
|
||||||
|
2. Output-side: no try/except at all. ``out.chunk(2)`` of the
|
||||||
|
network output is a contract: an unsplittable result is a bug,
|
||||||
|
not a recoverable condition.
|
||||||
|
|
||||||
|
The two blocks were extracted into named private methods on
|
||||||
|
``NaDiT`` (``_resolve_text_conditioning`` and ``_swap_pos_neg_halves``)
|
||||||
|
so the regression evidence drives the actual production code paths
|
||||||
|
without standing up a full transformer. The methods are called from
|
||||||
|
``forward`` exactly where the original try/except blocks lived.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from comfy.cli_args import args
|
||||||
|
import torch
|
||||||
|
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
args.cpu = True
|
||||||
|
|
||||||
|
import ast # noqa: E402
|
||||||
|
import inspect # noqa: E402
|
||||||
|
import textwrap # noqa: E402
|
||||||
|
|
||||||
|
import pytest # noqa: E402
|
||||||
|
|
||||||
|
from comfy.ldm.seedvr.model import NaDiT # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
|
def _make_standin(positive_conditioning):
|
||||||
|
class _StandIn(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.register_buffer(
|
||||||
|
"positive_conditioning", positive_conditioning
|
||||||
|
)
|
||||||
|
|
||||||
|
_resolve_text_conditioning = NaDiT._resolve_text_conditioning
|
||||||
|
_swap_pos_neg_halves = NaDiT._swap_pos_neg_halves
|
||||||
|
|
||||||
|
return _StandIn()
|
||||||
|
|
||||||
|
|
||||||
|
def test_no_bare_except_in_forward_path():
|
||||||
|
"""Source-level pin: neither ``NaDiT.forward`` nor its split helpers
|
||||||
|
may carry the bare ``except:`` clauses that swallowed real torch
|
||||||
|
failures on the SeedVR2 conditioning paths. AST-walked rather than
|
||||||
|
substring-matched so that ``except:`` appearing in a docstring or
|
||||||
|
comment does not false-positive, and so that ``except Exception:``
|
||||||
|
(a typed handler, fine to have) does not false-negative.
|
||||||
|
"""
|
||||||
|
sources = [
|
||||||
|
inspect.getsource(NaDiT.forward),
|
||||||
|
inspect.getsource(NaDiT._resolve_text_conditioning),
|
||||||
|
inspect.getsource(NaDiT._swap_pos_neg_halves),
|
||||||
|
]
|
||||||
|
for src in sources:
|
||||||
|
tree = ast.parse(textwrap.dedent(src))
|
||||||
|
for node in ast.walk(tree):
|
||||||
|
if isinstance(node, ast.ExceptHandler):
|
||||||
|
assert node.type is not None, (
|
||||||
|
"Bare 'except:' (ast.ExceptHandler with type=None) "
|
||||||
|
f"must not appear on the SeedVR2 forward path:\n{src}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_valid_context_splits_pos_neg():
|
||||||
|
"""AC: valid (neg, pos)-stacked context (shape ``(2, L, C)``)
|
||||||
|
produces a flattened ``[pos, neg]`` text tensor — first ``L`` rows
|
||||||
|
are positive, next ``L`` rows are negative — matching the original
|
||||||
|
semantics of the ``flatten([pos_cond, neg_cond])`` call.
|
||||||
|
"""
|
||||||
|
pos_buffer = torch.zeros((58, 5120))
|
||||||
|
standin = _make_standin(pos_buffer)
|
||||||
|
seq_len, channels = 7, 5120
|
||||||
|
neg = torch.full((1, seq_len, channels), -1.0)
|
||||||
|
pos = torch.full((1, seq_len, channels), 1.0)
|
||||||
|
context = torch.cat([neg, pos], dim=0)
|
||||||
|
txt, txt_shape = standin._resolve_text_conditioning(context)
|
||||||
|
assert txt.shape == (2 * seq_len, channels)
|
||||||
|
assert (txt[:seq_len] == 1.0).all(), "first half must be positive cond"
|
||||||
|
assert (txt[seq_len:] == -1.0).all(), "second half must be negative cond"
|
||||||
|
assert txt_shape.shape == (2, 1)
|
||||||
|
assert txt_shape[0].item() == seq_len
|
||||||
|
assert txt_shape[1].item() == seq_len
|
||||||
|
|
||||||
|
|
||||||
|
def test_missing_context_falls_back_to_positive_buffer():
|
||||||
|
"""AC: ``context is None`` falls back to the registered
|
||||||
|
``positive_conditioning`` buffer and runs to completion — no
|
||||||
|
silent zero substitution, no raised exception.
|
||||||
|
"""
|
||||||
|
pos_buffer = torch.full((58, 5120), 7.0)
|
||||||
|
standin = _make_standin(pos_buffer)
|
||||||
|
txt, txt_shape = standin._resolve_text_conditioning(None)
|
||||||
|
assert txt.shape == (58, 5120)
|
||||||
|
assert (txt == 7.0).all(), (
|
||||||
|
"fallback path must use the positive_conditioning buffer "
|
||||||
|
"verbatim, not a zero tensor"
|
||||||
|
)
|
||||||
|
assert txt_shape.shape == (1, 1)
|
||||||
|
assert txt_shape[0, 0].item() == 58
|
||||||
|
|
||||||
|
|
||||||
|
def test_empty_context_falls_back_to_positive_buffer():
|
||||||
|
"""AC: ``context.numel() == 0`` falls back to the registered
|
||||||
|
``positive_conditioning`` buffer and runs to completion.
|
||||||
|
"""
|
||||||
|
pos_buffer = torch.full((58, 5120), 13.0)
|
||||||
|
standin = _make_standin(pos_buffer)
|
||||||
|
empty = torch.empty((0, 5120))
|
||||||
|
assert empty.numel() == 0
|
||||||
|
txt, txt_shape = standin._resolve_text_conditioning(empty)
|
||||||
|
assert txt.shape == (58, 5120)
|
||||||
|
assert (txt == 13.0).all()
|
||||||
|
assert txt_shape.shape == (1, 1)
|
||||||
|
assert txt_shape[0, 0].item() == 58
|
||||||
|
|
||||||
|
|
||||||
|
def test_wrong_rank_context_raises_original_torch_exception():
|
||||||
|
"""AC: a 1-D context tensor cannot be split into ``[pos, neg]``
|
||||||
|
via the ``chunk + squeeze + flatten`` chain; the original torch
|
||||||
|
exception must propagate rather than silently falling back.
|
||||||
|
"""
|
||||||
|
pos_buffer = torch.zeros((58, 5120))
|
||||||
|
standin = _make_standin(pos_buffer)
|
||||||
|
bad = torch.zeros(10)
|
||||||
|
with pytest.raises((RuntimeError, IndexError, ValueError)):
|
||||||
|
standin._resolve_text_conditioning(bad)
|
||||||
|
|
||||||
|
|
||||||
|
def test_odd_batch_context_raises_original_exception():
|
||||||
|
"""AC: a context whose batch dim cannot be split into two equal
|
||||||
|
chunks (here batch=1 so ``chunk(2, dim=0)`` returns a single
|
||||||
|
tensor) must propagate the original exception — no silent fallback.
|
||||||
|
"""
|
||||||
|
pos_buffer = torch.zeros((58, 5120))
|
||||||
|
standin = _make_standin(pos_buffer)
|
||||||
|
bad = torch.zeros((1, 7, 5120))
|
||||||
|
with pytest.raises((RuntimeError, ValueError)):
|
||||||
|
standin._resolve_text_conditioning(bad)
|
||||||
|
|
||||||
|
|
||||||
|
def test_output_side_misshaped_tensor_raises():
|
||||||
|
"""AC: the post-network output split must raise on an unsplittable
|
||||||
|
tensor (no silent return of the un-split tensor in the wrong
|
||||||
|
order/shape). Here a batch=1 tensor cannot be ``chunk(2, dim=0)``
|
||||||
|
into two halves; ``pos, neg = out.chunk(2, dim=0)`` raises on
|
||||||
|
unpacking — matching the production helper's explicit-dim contract
|
||||||
|
(``_swap_pos_neg_halves`` calls ``chunk(2, dim=0)`` and
|
||||||
|
``torch.cat(..., dim=0)``).
|
||||||
|
"""
|
||||||
|
pos_buffer = torch.zeros((58, 5120))
|
||||||
|
standin = _make_standin(pos_buffer)
|
||||||
|
bad_out = torch.zeros((1, 4, 8, 8))
|
||||||
|
with pytest.raises((RuntimeError, ValueError)):
|
||||||
|
standin._swap_pos_neg_halves(bad_out)
|
||||||
|
|
||||||
|
|
||||||
|
def test_output_side_swaps_pos_neg_halves():
|
||||||
|
"""AC complement: ``_swap_pos_neg_halves`` reorders the post-network
|
||||||
|
output so the first half (positive) and second half (negative) trade
|
||||||
|
places. For a 2-batch tensor with distinguishable halves, the
|
||||||
|
returned tensor must be the swap — first half becomes negative,
|
||||||
|
second half becomes positive — matching the original
|
||||||
|
``torch.cat([neg, pos])`` semantics from the pre-fix forward path.
|
||||||
|
"""
|
||||||
|
pos_buffer = torch.zeros((58, 5120))
|
||||||
|
standin = _make_standin(pos_buffer)
|
||||||
|
pos_half = torch.full((1, 4, 8, 8), 1.0)
|
||||||
|
neg_half = torch.full((1, 4, 8, 8), -1.0)
|
||||||
|
out = torch.cat([pos_half, neg_half], dim=0)
|
||||||
|
swapped = standin._swap_pos_neg_halves(out)
|
||||||
|
assert swapped.shape == out.shape
|
||||||
|
assert (swapped[0] == -1.0).all(), "first half of swapped output must be the original negative half"
|
||||||
|
assert (swapped[1] == 1.0).all(), "second half of swapped output must be the original positive half"
|
||||||
124
tests-unit/comfy_test/seedvr_vae_forward_test.py
Normal file
124
tests-unit/comfy_test/seedvr_vae_forward_test.py
Normal file
@ -0,0 +1,124 @@
|
|||||||
|
"""Regression: ``comfy.ldm.seedvr.vae.VideoAutoencoderKL.forward`` must
|
||||||
|
honor the actual tensor/tuple return contract of ``encode()`` and
|
||||||
|
``decode_()`` and must NOT dereference diffusers-style ``.latent_dist``
|
||||||
|
or ``.sample`` attributes on those returns.
|
||||||
|
|
||||||
|
The pre-fix body raised ``AttributeError: 'Tensor' object has no
|
||||||
|
attribute 'latent_dist'`` for ``mode in {"encode", "all"}`` and
|
||||||
|
``AttributeError: 'VideoAutoencoderKL' object has no attribute 'decode'``
|
||||||
|
for ``mode == "decode"`` (the class only defines ``decode_`` with a
|
||||||
|
trailing underscore). The post-fix body unwraps the optional one-element
|
||||||
|
tuple shape that ``return_dict=False`` produces and returns the tensor
|
||||||
|
directly.
|
||||||
|
|
||||||
|
Tests construct a stub subclass of ``VideoAutoencoderKL`` that bypasses
|
||||||
|
the heavy ``__init__`` via ``torch.nn.Module.__init__(self)`` and
|
||||||
|
overrides ``encode``/``decode_`` with known tensors so the contract can
|
||||||
|
be probed without loading any real VAE weights.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
import re
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from comfy.cli_args import args as cli_args
|
||||||
|
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
cli_args.cpu = True
|
||||||
|
|
||||||
|
from comfy.ldm.seedvr.vae import VideoAutoencoderKL # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
|
_LATENT_SHAPE = (1, 16, 2, 2, 2)
|
||||||
|
_DECODED_SHAPE = (1, 3, 5, 16, 16)
|
||||||
|
_INPUT_ENCODE_SHAPE = (1, 3, 5, 16, 16)
|
||||||
|
_INPUT_DECODE_SHAPE = (1, 16, 2, 2, 2)
|
||||||
|
|
||||||
|
|
||||||
|
class _StubVAE(VideoAutoencoderKL):
|
||||||
|
def __init__(self):
|
||||||
|
nn.Module.__init__(self)
|
||||||
|
self._encode_out = torch.zeros(*_LATENT_SHAPE)
|
||||||
|
self._decode_out = torch.zeros(*_DECODED_SHAPE)
|
||||||
|
|
||||||
|
def encode(self, x, return_dict=True):
|
||||||
|
return self._encode_out
|
||||||
|
|
||||||
|
def decode_(self, z, return_dict=True):
|
||||||
|
return self._decode_out
|
||||||
|
|
||||||
|
|
||||||
|
def test_forward_encode_returns_tensor():
|
||||||
|
vae = _StubVAE()
|
||||||
|
x = torch.zeros(*_INPUT_ENCODE_SHAPE)
|
||||||
|
result = vae.forward(x, mode="encode")
|
||||||
|
assert type(result) is torch.Tensor
|
||||||
|
assert result.shape == torch.Size(_LATENT_SHAPE)
|
||||||
|
|
||||||
|
|
||||||
|
def test_forward_decode_returns_tensor():
|
||||||
|
vae = _StubVAE()
|
||||||
|
z = torch.zeros(*_INPUT_DECODE_SHAPE)
|
||||||
|
result = vae.forward(z, mode="decode")
|
||||||
|
assert type(result) is torch.Tensor
|
||||||
|
assert result.shape == torch.Size(_DECODED_SHAPE)
|
||||||
|
|
||||||
|
|
||||||
|
def test_forward_all_returns_tensor():
|
||||||
|
vae = _StubVAE()
|
||||||
|
x = torch.zeros(*_INPUT_ENCODE_SHAPE)
|
||||||
|
result = vae.forward(x, mode="all")
|
||||||
|
assert type(result) is torch.Tensor
|
||||||
|
assert result.shape == torch.Size(_DECODED_SHAPE)
|
||||||
|
|
||||||
|
|
||||||
|
def test_forward_source_has_no_diffusers_attr_access():
|
||||||
|
src = inspect.getsource(VideoAutoencoderKL.forward)
|
||||||
|
assert ".latent_dist" not in src
|
||||||
|
assert ".sample" not in src
|
||||||
|
assert re.search(r"self\.decode\(", src) is None
|
||||||
|
|
||||||
|
|
||||||
|
class _TupleReturningStubVAE(VideoAutoencoderKL):
|
||||||
|
"""Stub variant whose ``encode``/``decode_`` return the
|
||||||
|
``(tensor,)`` one-element tuple shape ``return_dict=False`` produces
|
||||||
|
in the parent class. Exercises the unwrap branch of
|
||||||
|
``VideoAutoencoderKL.forward``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
nn.Module.__init__(self)
|
||||||
|
self._encode_tensor = torch.zeros(*_LATENT_SHAPE)
|
||||||
|
self._decode_tensor = torch.zeros(*_DECODED_SHAPE)
|
||||||
|
|
||||||
|
def encode(self, x, return_dict=True):
|
||||||
|
return (self._encode_tensor,)
|
||||||
|
|
||||||
|
def decode_(self, z, return_dict=True):
|
||||||
|
return (self._decode_tensor,)
|
||||||
|
|
||||||
|
|
||||||
|
def test_forward_encode_unwraps_one_tuple():
|
||||||
|
vae = _TupleReturningStubVAE()
|
||||||
|
x = torch.zeros(*_INPUT_ENCODE_SHAPE)
|
||||||
|
result = vae.forward(x, mode="encode")
|
||||||
|
assert type(result) is torch.Tensor
|
||||||
|
assert result.shape == torch.Size(_LATENT_SHAPE)
|
||||||
|
|
||||||
|
|
||||||
|
def test_forward_decode_unwraps_one_tuple():
|
||||||
|
vae = _TupleReturningStubVAE()
|
||||||
|
z = torch.zeros(*_INPUT_DECODE_SHAPE)
|
||||||
|
result = vae.forward(z, mode="decode")
|
||||||
|
assert type(result) is torch.Tensor
|
||||||
|
assert result.shape == torch.Size(_DECODED_SHAPE)
|
||||||
|
|
||||||
|
|
||||||
|
def test_forward_all_unwraps_one_tuple_at_each_step():
|
||||||
|
vae = _TupleReturningStubVAE()
|
||||||
|
x = torch.zeros(*_INPUT_ENCODE_SHAPE)
|
||||||
|
result = vae.forward(x, mode="all")
|
||||||
|
assert type(result) is torch.Tensor
|
||||||
|
assert result.shape == torch.Size(_DECODED_SHAPE)
|
||||||
63
tests-unit/comfy_test/seedvr_vae_wrapper_forward_test.py
Normal file
63
tests-unit/comfy_test/seedvr_vae_wrapper_forward_test.py
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
import inspect
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from comfy.cli_args import args as cli_args
|
||||||
|
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
cli_args.cpu = True
|
||||||
|
|
||||||
|
import comfy.ldm.seedvr.vae as vae_mod # noqa: E402
|
||||||
|
|
||||||
|
VideoAutoencoderKLWrapper = vae_mod.VideoAutoencoderKLWrapper
|
||||||
|
|
||||||
|
|
||||||
|
_INPUT_SHAPE = (1, 3, 5, 16, 16)
|
||||||
|
_POSTERIOR_SHAPE = (1, 16, 1, 2, 2)
|
||||||
|
_DECODE_OUT_SHAPE = (1, 3, 5, 16, 16)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_wrapper_standin() -> VideoAutoencoderKLWrapper:
|
||||||
|
wrapper = VideoAutoencoderKLWrapper.__new__(VideoAutoencoderKLWrapper)
|
||||||
|
nn.Module.__init__(wrapper)
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def test_wrapper_forward_returns_tensor_triple(monkeypatch):
|
||||||
|
wrapper = _build_wrapper_standin()
|
||||||
|
wrapper.original_image_video = torch.zeros(*_INPUT_SHAPE)
|
||||||
|
wrapper.img_dims = (16, 16)
|
||||||
|
wrapper.freeze_encoder = True
|
||||||
|
|
||||||
|
posterior = torch.full(_POSTERIOR_SHAPE, 7.0)
|
||||||
|
decode_out = torch.full(_DECODE_OUT_SHAPE, 13.0)
|
||||||
|
|
||||||
|
def stub_encode(self, x, orig_dims=None):
|
||||||
|
return posterior.squeeze(2), posterior
|
||||||
|
|
||||||
|
def stub_decode(self, z):
|
||||||
|
return decode_out
|
||||||
|
|
||||||
|
monkeypatch.setattr(VideoAutoencoderKLWrapper, "encode", stub_encode)
|
||||||
|
monkeypatch.setattr(VideoAutoencoderKLWrapper, "decode", stub_decode)
|
||||||
|
|
||||||
|
x = torch.zeros(*_INPUT_SHAPE)
|
||||||
|
result = wrapper.forward(x)
|
||||||
|
|
||||||
|
assert isinstance(result, tuple)
|
||||||
|
assert len(result) == 3
|
||||||
|
x_out, z, p = result
|
||||||
|
assert type(x_out) is torch.Tensor
|
||||||
|
assert type(z) is torch.Tensor
|
||||||
|
assert type(p) is torch.Tensor
|
||||||
|
assert x_out.shape == decode_out.shape
|
||||||
|
assert z.shape == posterior.squeeze(2).shape
|
||||||
|
assert torch.equal(x_out, decode_out)
|
||||||
|
assert torch.equal(z, posterior.squeeze(2))
|
||||||
|
assert p is posterior
|
||||||
|
|
||||||
|
|
||||||
|
def test_wrapper_forward_source_has_no_sample_access():
|
||||||
|
src = inspect.getsource(VideoAutoencoderKLWrapper.forward)
|
||||||
|
assert ".sample" not in src
|
||||||
105
tests-unit/comfy_test/test_diffusers_metadata_guard.py
Normal file
105
tests-unit/comfy_test/test_diffusers_metadata_guard.py
Normal file
@ -0,0 +1,105 @@
|
|||||||
|
"""Regression tests for the diffusers-format guard inside ``comfy.sd.VAE.__init__``.
|
||||||
|
|
||||||
|
The guard previously indexed ``metadata["keep_diffusers_format"]`` directly,
|
||||||
|
raising ``KeyError`` when ``metadata`` was non-``None`` but lacked that key. The
|
||||||
|
fixed guard uses ``metadata.get("keep_diffusers_format") != "true"``: a missing
|
||||||
|
key flows through to ``convert_vae_state_dict``; the explicit ``"true"`` value
|
||||||
|
bypasses it.
|
||||||
|
|
||||||
|
Five cells exercise every reachable shape of the guard input — missing key,
|
||||||
|
explicit ``"true"``, ``None``, explicit non-``"true"``, empty dict — and halt
|
||||||
|
the constructor at the first post-guard call (``model_management.is_amd``).
|
||||||
|
``_make_standin`` borrows ``__init__`` onto a bare class, mirroring
|
||||||
|
``seedvr_model_test.py::_make_standin`` (#109). ``_exercise_guard`` single-
|
||||||
|
sources the patched-constructor harness so the cells stay synchronised.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from comfy.cli_args import args
|
||||||
|
import torch
|
||||||
|
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
args.cpu = True
|
||||||
|
|
||||||
|
import contextlib # noqa: E402
|
||||||
|
import unittest.mock # noqa: E402
|
||||||
|
|
||||||
|
import comfy.sd # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
|
_DIFFUSERS_TRIGGER_KEY = "decoder.up_blocks.0.resnets.0.norm1.weight"
|
||||||
|
|
||||||
|
|
||||||
|
class _PostGuardReached(Exception):
|
||||||
|
"""Sentinel raised by the patched ``is_amd`` to halt ``__init__`` at the first post-guard statement."""
|
||||||
|
|
||||||
|
|
||||||
|
def _make_standin():
|
||||||
|
class _StandIn:
|
||||||
|
__init__ = comfy.sd.VAE.__init__
|
||||||
|
|
||||||
|
return _StandIn
|
||||||
|
|
||||||
|
|
||||||
|
def _exercise_guard(metadata):
|
||||||
|
"""Drive ``VAE.__init__`` with the diffusers trigger key and the supplied
|
||||||
|
``metadata``; halt at ``is_amd``. Returns ``(mock_convert, mock_is_amd)``
|
||||||
|
for branch (call_count) + reach (called) assertions per cell.
|
||||||
|
"""
|
||||||
|
StandIn = _make_standin()
|
||||||
|
sd = {_DIFFUSERS_TRIGGER_KEY: torch.zeros(1)}
|
||||||
|
|
||||||
|
with unittest.mock.patch.object(
|
||||||
|
comfy.sd.diffusers_convert,
|
||||||
|
"convert_vae_state_dict",
|
||||||
|
autospec=True,
|
||||||
|
side_effect=lambda state_dict: state_dict,
|
||||||
|
) as mock_convert, unittest.mock.patch.object(
|
||||||
|
comfy.sd.model_management,
|
||||||
|
"is_amd",
|
||||||
|
autospec=True,
|
||||||
|
side_effect=_PostGuardReached("post-guard reached"),
|
||||||
|
) as mock_is_amd:
|
||||||
|
with contextlib.suppress(_PostGuardReached):
|
||||||
|
StandIn(sd=sd, metadata=metadata)
|
||||||
|
|
||||||
|
return mock_convert, mock_is_amd
|
||||||
|
|
||||||
|
|
||||||
|
def test_diffusers_guard_invokes_convert_when_metadata_missing_key():
|
||||||
|
"""AC1: metadata is non-None but lacks ``keep_diffusers_format`` → convert is invoked."""
|
||||||
|
mock_convert, mock_is_amd = _exercise_guard({"unrelated_key": "value"})
|
||||||
|
|
||||||
|
assert mock_convert.call_count == 1
|
||||||
|
assert mock_is_amd.called
|
||||||
|
|
||||||
|
|
||||||
|
def test_diffusers_guard_skips_convert_when_metadata_pins_keep_true():
|
||||||
|
"""AC2: metadata pins ``keep_diffusers_format == "true"`` → convert is skipped."""
|
||||||
|
mock_convert, mock_is_amd = _exercise_guard({"keep_diffusers_format": "true"})
|
||||||
|
|
||||||
|
assert mock_convert.call_count == 0
|
||||||
|
assert mock_is_amd.called
|
||||||
|
|
||||||
|
|
||||||
|
def test_diffusers_guard_invokes_convert_when_metadata_is_none():
|
||||||
|
"""AC3: metadata is ``None`` → first disjunct fires, convert is invoked."""
|
||||||
|
mock_convert, mock_is_amd = _exercise_guard(None)
|
||||||
|
|
||||||
|
assert mock_convert.call_count == 1
|
||||||
|
assert mock_is_amd.called
|
||||||
|
|
||||||
|
|
||||||
|
def test_diffusers_guard_invokes_convert_when_metadata_pins_keep_false():
|
||||||
|
"""AC4: metadata pins a non-``"true"`` value → second disjunct fires, convert is invoked."""
|
||||||
|
mock_convert, mock_is_amd = _exercise_guard({"keep_diffusers_format": "false"})
|
||||||
|
|
||||||
|
assert mock_convert.call_count == 1
|
||||||
|
assert mock_is_amd.called
|
||||||
|
|
||||||
|
|
||||||
|
def test_diffusers_guard_invokes_convert_when_metadata_is_empty_dict():
|
||||||
|
"""AC5: metadata is ``{}`` (the ``convert_old_quants`` None→{} normalization shape) → convert is invoked."""
|
||||||
|
mock_convert, mock_is_amd = _exercise_guard({})
|
||||||
|
|
||||||
|
assert mock_convert.call_count == 1
|
||||||
|
assert mock_is_amd.called
|
||||||
503
tests-unit/comfy_test/test_seedvr2_dtype.py
Normal file
503
tests-unit/comfy_test/test_seedvr2_dtype.py
Normal file
@ -0,0 +1,503 @@
|
|||||||
|
import inspect
|
||||||
|
import logging
|
||||||
|
import warnings
|
||||||
|
from pathlib import Path
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from comfy.cli_args import args as cli_args
|
||||||
|
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
cli_args.cpu = True
|
||||||
|
|
||||||
|
import comfy.ldm.modules.attention as attention
|
||||||
|
import comfy.sd
|
||||||
|
import comfy.supported_models
|
||||||
|
import comfy.ldm.seedvr.model as seedvr_model
|
||||||
|
|
||||||
|
|
||||||
|
def test_set_model_config_inference_dtype_preserves_legacy_signature():
|
||||||
|
calls = []
|
||||||
|
|
||||||
|
class LegacyConfig:
|
||||||
|
def set_inference_dtype(self, dtype, manual_cast_dtype):
|
||||||
|
calls.append((dtype, manual_cast_dtype))
|
||||||
|
|
||||||
|
comfy.sd._set_model_config_inference_dtype(LegacyConfig(), torch.float16, None, object())
|
||||||
|
|
||||||
|
assert calls == [(torch.float16, None)]
|
||||||
|
|
||||||
|
|
||||||
|
def test_set_model_config_inference_dtype_passes_device_when_supported():
|
||||||
|
calls = []
|
||||||
|
device = object()
|
||||||
|
|
||||||
|
class DeviceAwareConfig:
|
||||||
|
def set_inference_dtype(self, dtype, manual_cast_dtype, device=None):
|
||||||
|
calls.append((dtype, manual_cast_dtype, device))
|
||||||
|
|
||||||
|
comfy.sd._set_model_config_inference_dtype(DeviceAwareConfig(), torch.float16, None, device)
|
||||||
|
|
||||||
|
assert calls == [(torch.float16, None, device)]
|
||||||
|
|
||||||
|
|
||||||
|
def test_set_model_config_inference_dtype_passes_device_to_kwargs_override():
|
||||||
|
calls = []
|
||||||
|
device = object()
|
||||||
|
|
||||||
|
class KwargsConfig:
|
||||||
|
def set_inference_dtype(self, dtype, manual_cast_dtype, **kwargs):
|
||||||
|
calls.append((dtype, manual_cast_dtype, kwargs))
|
||||||
|
|
||||||
|
comfy.sd._set_model_config_inference_dtype(KwargsConfig(), torch.float16, None, device)
|
||||||
|
|
||||||
|
assert calls == [(torch.float16, None, {"device": device})]
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_fp16_manual_cast_only_for_bf16_device(monkeypatch):
|
||||||
|
bf16_device = object()
|
||||||
|
fp16_device = object()
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
comfy.supported_models.comfy.model_management,
|
||||||
|
"should_use_bf16",
|
||||||
|
lambda device=None: device is bf16_device,
|
||||||
|
)
|
||||||
|
|
||||||
|
bf16_config = comfy.supported_models.SeedVR2({"image_model": "seedvr2"})
|
||||||
|
bf16_config.set_inference_dtype(torch.float16, None, device=bf16_device)
|
||||||
|
assert bf16_config.manual_cast_dtype is torch.bfloat16
|
||||||
|
|
||||||
|
fp16_config = comfy.supported_models.SeedVR2({"image_model": "seedvr2"})
|
||||||
|
fp16_config.set_inference_dtype(torch.float16, None, device=fp16_device)
|
||||||
|
assert fp16_config.manual_cast_dtype is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_rope1_partial_preserves_full_rotation_input_dtype(monkeypatch):
|
||||||
|
def fake_apply_rope1(t, freqs_cis):
|
||||||
|
return t.float() + 1.0
|
||||||
|
|
||||||
|
monkeypatch.setattr(seedvr_model, "apply_rope1", fake_apply_rope1)
|
||||||
|
|
||||||
|
t = torch.arange(8, dtype=torch.float16).reshape(1, 2, 4)
|
||||||
|
original = t.clone()
|
||||||
|
freqs_cis = torch.zeros(1, 2, 2, 2)
|
||||||
|
|
||||||
|
out = seedvr_model._apply_rope1_partial(t, freqs_cis)
|
||||||
|
|
||||||
|
assert out.dtype is torch.float16
|
||||||
|
torch.testing.assert_close(out, (original.float() + 1.0).to(torch.float16))
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_rope1_partial_preserves_partial_rotation_input_dtype(monkeypatch):
|
||||||
|
def fake_apply_rope1(t, freqs_cis):
|
||||||
|
return t.float() + 1.0
|
||||||
|
|
||||||
|
monkeypatch.setattr(seedvr_model, "apply_rope1", fake_apply_rope1)
|
||||||
|
|
||||||
|
t = torch.arange(12, dtype=torch.float16).reshape(1, 2, 6)
|
||||||
|
original = t.clone()
|
||||||
|
freqs_cis = torch.zeros(1, 2, 2, 2)
|
||||||
|
|
||||||
|
out = seedvr_model._apply_rope1_partial(t, freqs_cis)
|
||||||
|
|
||||||
|
assert out.dtype is torch.float16
|
||||||
|
torch.testing.assert_close(
|
||||||
|
out[..., :4],
|
||||||
|
(original[..., :4].float() + 1.0).to(torch.float16),
|
||||||
|
)
|
||||||
|
torch.testing.assert_close(out[..., 4:], original[..., 4:])
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_rope1_partial_chunks_sequence_dimension(monkeypatch):
|
||||||
|
calls = []
|
||||||
|
|
||||||
|
def fake_apply_rope1(t, freqs_cis):
|
||||||
|
calls.append(t.shape[-2])
|
||||||
|
return t.float() + 1.0
|
||||||
|
|
||||||
|
monkeypatch.setattr(seedvr_model, "apply_rope1", fake_apply_rope1)
|
||||||
|
monkeypatch.setattr(seedvr_model, "_ROPE1_PARTIAL_CHUNK_TOKENS", 2)
|
||||||
|
|
||||||
|
t = torch.arange(30, dtype=torch.float16).reshape(1, 5, 6)
|
||||||
|
original = t.clone()
|
||||||
|
freqs_cis = torch.zeros(5, 2, 2, 2)
|
||||||
|
|
||||||
|
out = seedvr_model._apply_rope1_partial(t, freqs_cis)
|
||||||
|
|
||||||
|
assert calls == [2, 2, 1]
|
||||||
|
torch.testing.assert_close(out[..., :4], (original[..., :4].float() + 1.0).to(torch.float16))
|
||||||
|
torch.testing.assert_close(out[..., 4:], original[..., 4:])
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_rope1_partial_clones_training_tensor(monkeypatch):
|
||||||
|
def fake_apply_rope1(t, freqs_cis):
|
||||||
|
return t + 1.0
|
||||||
|
|
||||||
|
monkeypatch.setattr(seedvr_model, "apply_rope1", fake_apply_rope1)
|
||||||
|
|
||||||
|
base = torch.arange(12, dtype=torch.float32, requires_grad=True)
|
||||||
|
t = base.reshape(1, 2, 6)
|
||||||
|
original = t.clone()
|
||||||
|
freqs_cis = torch.zeros(2, 2, 2, 2)
|
||||||
|
|
||||||
|
out = seedvr_model._apply_rope1_partial(t, freqs_cis)
|
||||||
|
out.sum().backward()
|
||||||
|
|
||||||
|
assert out is not t
|
||||||
|
torch.testing.assert_close(t, original)
|
||||||
|
torch.testing.assert_close(out[..., :4], original[..., :4] + 1.0)
|
||||||
|
torch.testing.assert_close(out[..., 4:], original[..., 4:])
|
||||||
|
assert base.grad is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_text_conditioning_accepts_cfg1_single_branch():
|
||||||
|
context = torch.arange(6, dtype=torch.float32).reshape(1, 3, 2)
|
||||||
|
|
||||||
|
txt, txt_shape = seedvr_model.NaDiT._resolve_text_conditioning(object(), context, [0])
|
||||||
|
|
||||||
|
torch.testing.assert_close(txt, context.squeeze(0))
|
||||||
|
torch.testing.assert_close(txt_shape, torch.tensor([[3]], device=context.device))
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_text_conditioning_accepts_batched_cfg1_single_branch():
|
||||||
|
context = torch.arange(12, dtype=torch.float32).reshape(2, 3, 2)
|
||||||
|
|
||||||
|
txt, txt_shape = seedvr_model.NaDiT._resolve_text_conditioning(object(), context, [0])
|
||||||
|
|
||||||
|
torch.testing.assert_close(txt, context.flatten(0, -2))
|
||||||
|
torch.testing.assert_close(txt_shape, torch.tensor([[3], [3]], device=context.device))
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_text_conditioning_accepts_multi_entry_cfg1_single_branch():
|
||||||
|
context = torch.arange(12, dtype=torch.float32).reshape(2, 3, 2)
|
||||||
|
|
||||||
|
txt, txt_shape = seedvr_model.NaDiT._resolve_text_conditioning(object(), context, [0, 0])
|
||||||
|
|
||||||
|
torch.testing.assert_close(txt, context.flatten(0, -2))
|
||||||
|
torch.testing.assert_close(txt_shape, torch.tensor([[3], [3]], device=context.device))
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_text_conditioning_preserves_two_branch_swap_contract():
|
||||||
|
neg = torch.full((1, 3, 2), -1.0)
|
||||||
|
pos = torch.full((1, 3, 2), 1.0)
|
||||||
|
context = torch.cat([neg, pos], dim=0)
|
||||||
|
|
||||||
|
txt, txt_shape = seedvr_model.NaDiT._resolve_text_conditioning(object(), context)
|
||||||
|
|
||||||
|
torch.testing.assert_close(txt[:3], pos.squeeze(0))
|
||||||
|
torch.testing.assert_close(txt[3:], neg.squeeze(0))
|
||||||
|
torch.testing.assert_close(txt_shape, torch.tensor([[3], [3]], device=context.device))
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_text_conditioning_preserves_batched_two_branch_swap_contract():
|
||||||
|
neg = torch.full((2, 3, 2), -1.0)
|
||||||
|
pos = torch.full((2, 3, 2), 1.0)
|
||||||
|
context = torch.cat([neg, pos], dim=0)
|
||||||
|
|
||||||
|
txt, txt_shape = seedvr_model.NaDiT._resolve_text_conditioning(object(), context, [1, 0])
|
||||||
|
|
||||||
|
torch.testing.assert_close(txt[:6], pos.flatten(0, -2))
|
||||||
|
torch.testing.assert_close(txt[6:], neg.flatten(0, -2))
|
||||||
|
torch.testing.assert_close(txt_shape, torch.tensor([[3], [3], [3], [3]], device=context.device))
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_cfg1_single_branch_output_is_not_swapped():
|
||||||
|
out = torch.arange(6, dtype=torch.float32).reshape(1, 6)
|
||||||
|
|
||||||
|
swapped = seedvr_model.NaDiT._swap_pos_neg_halves(object(), out, [0])
|
||||||
|
|
||||||
|
torch.testing.assert_close(swapped, out)
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_multi_entry_cfg1_output_is_not_swapped():
|
||||||
|
out = torch.arange(12, dtype=torch.float32).reshape(2, 6)
|
||||||
|
|
||||||
|
swapped = seedvr_model.NaDiT._swap_pos_neg_halves(object(), out, [0, 0])
|
||||||
|
|
||||||
|
torch.testing.assert_close(swapped, out)
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_conditioning_keeps_comfy_cfg1_optimization_enabled():
|
||||||
|
source = (Path(__file__).resolve().parents[2] / "comfy_extras" / "nodes_seedvr.py").read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
assert "disable_model_cfg1_optimization()" not in source
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_split_var_attention_matches_nested_var_attention():
|
||||||
|
torch.manual_seed(1)
|
||||||
|
q = torch.randn(5, 2, 4)
|
||||||
|
k = torch.randn(7, 2, 4)
|
||||||
|
v = torch.randn(7, 2, 4)
|
||||||
|
cu_q = torch.tensor([0, 2, 5], dtype=torch.int32)
|
||||||
|
cu_k = torch.tensor([0, 3, 7], dtype=torch.int32)
|
||||||
|
|
||||||
|
torch_fx_logger = logging.getLogger("torch.fx._symbolic_trace")
|
||||||
|
old_torch_fx_level = torch_fx_logger.level
|
||||||
|
torch_fx_logger.setLevel(logging.ERROR)
|
||||||
|
try:
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.filterwarnings(
|
||||||
|
"ignore",
|
||||||
|
message="The PyTorch API of nested tensors is in prototype stage.*",
|
||||||
|
category=UserWarning,
|
||||||
|
)
|
||||||
|
nested = attention.var_attention_pytorch(
|
||||||
|
q, k, v, heads=2, cu_seqlens_q=cu_q, cu_seqlens_k=cu_k,
|
||||||
|
skip_reshape=True, skip_output_reshape=True,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
torch_fx_logger.setLevel(old_torch_fx_level)
|
||||||
|
split = attention.var_attention_pytorch_split(
|
||||||
|
q, k, v, heads=2, cu_seqlens_q=cu_q, cu_seqlens_k=cu_k,
|
||||||
|
skip_reshape=True, skip_output_reshape=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
torch.testing.assert_close(split, nested, rtol=1e-5, atol=1e-5)
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_split_var_attention_preserves_flat_output_shape():
|
||||||
|
torch.manual_seed(2)
|
||||||
|
q = torch.randn(5, 8)
|
||||||
|
k = torch.randn(7, 8)
|
||||||
|
v = torch.randn(7, 8)
|
||||||
|
cu_q = torch.tensor([0, 1, 5], dtype=torch.int32)
|
||||||
|
cu_k = torch.tensor([0, 2, 7], dtype=torch.int32)
|
||||||
|
|
||||||
|
nested = attention.var_attention_pytorch(
|
||||||
|
q, k, v, heads=2, cu_seqlens_q=cu_q, cu_seqlens_k=cu_k,
|
||||||
|
)
|
||||||
|
split = attention.var_attention_pytorch_split(
|
||||||
|
q, k, v, heads=2, cu_seqlens_q=cu_q, cu_seqlens_k=cu_k,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert split.shape == q.shape
|
||||||
|
torch.testing.assert_close(split, nested, rtol=1e-5, atol=1e-5)
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_split_var_attention_rejects_mismatched_sequence_count():
|
||||||
|
q = torch.randn(5, 2, 4)
|
||||||
|
k = torch.randn(7, 2, 4)
|
||||||
|
v = torch.randn(7, 2, 4)
|
||||||
|
cu_q = torch.tensor([0, 2, 5], dtype=torch.int32)
|
||||||
|
cu_k = torch.tensor([0, 3, 5, 7], dtype=torch.int32)
|
||||||
|
|
||||||
|
try:
|
||||||
|
attention.var_attention_pytorch_split(
|
||||||
|
q, k, v, heads=2, cu_seqlens_q=cu_q, cu_seqlens_k=cu_k,
|
||||||
|
skip_reshape=True, skip_output_reshape=True,
|
||||||
|
)
|
||||||
|
except ValueError as exc:
|
||||||
|
assert "same sequence count" in str(exc)
|
||||||
|
else:
|
||||||
|
raise AssertionError("mismatched cu_seqlens sequence counts must fail")
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_split_var_attention_rejects_malformed_offsets():
|
||||||
|
q = torch.randn(5, 2, 4)
|
||||||
|
k = torch.randn(7, 2, 4)
|
||||||
|
v = torch.randn(7, 2, 4)
|
||||||
|
cu_k = torch.tensor([0, 3, 7], dtype=torch.int32)
|
||||||
|
|
||||||
|
malformed_cases = (
|
||||||
|
(torch.tensor([1, 2, 5], dtype=torch.int32), "start at 0"),
|
||||||
|
(torch.tensor([0, 2, 2, 5], dtype=torch.int32), "strictly increasing"),
|
||||||
|
(torch.tensor([0.0, 2.0, 5.0], dtype=torch.float32), "integer dtype"),
|
||||||
|
)
|
||||||
|
|
||||||
|
for cu_q, message in malformed_cases:
|
||||||
|
try:
|
||||||
|
attention.var_attention_pytorch_split(
|
||||||
|
q, k, v, heads=2, cu_seqlens_q=cu_q, cu_seqlens_k=cu_k,
|
||||||
|
skip_reshape=True, skip_output_reshape=True,
|
||||||
|
)
|
||||||
|
except ValueError as exc:
|
||||||
|
assert message in str(exc)
|
||||||
|
else:
|
||||||
|
raise AssertionError("malformed cu_seqlens must fail")
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_7b_window_attention_handles_mm_rope_source():
|
||||||
|
source = inspect.getsource(seedvr_model.NaSwinAttention.forward)
|
||||||
|
|
||||||
|
assert "if self.rope.mm" in source
|
||||||
|
assert "txt_q_repeat" in source
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_7b_window_attention_routes_to_split_var_attention():
|
||||||
|
source = inspect.getsource(seedvr_model.NaSwinAttention.forward)
|
||||||
|
|
||||||
|
assert "_seedvr2_7b_window_attention_split" in source
|
||||||
|
assert "if self.version_7b" in source
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_7b_window_attention_split_matches_concat_path():
|
||||||
|
torch.manual_seed(3)
|
||||||
|
vid_len_win = torch.tensor([1, 2, 3], dtype=torch.int64)
|
||||||
|
txt_len = torch.tensor([2, 3], dtype=torch.int64)
|
||||||
|
window_count = torch.tensor([2, 1], dtype=torch.int64)
|
||||||
|
heads = 2
|
||||||
|
dim = 4
|
||||||
|
|
||||||
|
vid_total = int(vid_len_win.sum().item())
|
||||||
|
txt_total = int(txt_len.sum().item())
|
||||||
|
vid_q = torch.randn(vid_total, heads, dim)
|
||||||
|
vid_k = torch.randn(vid_total, heads, dim)
|
||||||
|
vid_v = torch.randn(vid_total, heads, dim)
|
||||||
|
txt_q = torch.randn(txt_total, heads, dim)
|
||||||
|
txt_k = torch.randn(txt_total, heads, dim)
|
||||||
|
txt_v = torch.randn(txt_total, heads, dim)
|
||||||
|
|
||||||
|
concat_win, unconcat_win = seedvr_model.repeat_concat_idx(vid_len_win, txt_len, window_count)
|
||||||
|
all_len_win = vid_len_win + txt_len.repeat_interleave(window_count)
|
||||||
|
cu_seqlens = torch.nn.functional.pad(all_len_win.cumsum(0), (1, 0)).int()
|
||||||
|
concat_out = attention.var_attention_pytorch_split(
|
||||||
|
concat_win(vid_q, txt_q),
|
||||||
|
concat_win(vid_k, txt_k),
|
||||||
|
concat_win(vid_v, txt_v),
|
||||||
|
heads=heads,
|
||||||
|
cu_seqlens_q=cu_seqlens,
|
||||||
|
cu_seqlens_k=cu_seqlens,
|
||||||
|
skip_reshape=True,
|
||||||
|
skip_output_reshape=True,
|
||||||
|
)
|
||||||
|
expected_vid, expected_txt = unconcat_win(concat_out)
|
||||||
|
|
||||||
|
split_vid, split_txt = seedvr_model._seedvr2_7b_window_attention_split(
|
||||||
|
vid_q, txt_q, vid_k, txt_k, vid_v, txt_v,
|
||||||
|
vid_len_win, txt_len, window_count,
|
||||||
|
)
|
||||||
|
|
||||||
|
torch.testing.assert_close(split_vid, expected_vid, rtol=1e-5, atol=1e-5)
|
||||||
|
torch.testing.assert_close(split_txt, expected_txt, rtol=1e-5, atol=1e-5)
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_7b_window_attention_split_preserves_autograd():
|
||||||
|
torch.manual_seed(4)
|
||||||
|
vid_len_win = torch.tensor([1, 2, 3], dtype=torch.int64)
|
||||||
|
txt_len = torch.tensor([2, 3], dtype=torch.int64)
|
||||||
|
window_count = torch.tensor([2, 1], dtype=torch.int64)
|
||||||
|
heads = 2
|
||||||
|
dim = 4
|
||||||
|
|
||||||
|
vid_total = int(vid_len_win.sum().item())
|
||||||
|
txt_total = int(txt_len.sum().item())
|
||||||
|
vid_q = torch.randn(vid_total, heads, dim, requires_grad=True)
|
||||||
|
vid_k = torch.randn(vid_total, heads, dim, requires_grad=True)
|
||||||
|
vid_v = torch.randn(vid_total, heads, dim, requires_grad=True)
|
||||||
|
txt_q = torch.randn(txt_total, heads, dim, requires_grad=True)
|
||||||
|
txt_k = torch.randn(txt_total, heads, dim, requires_grad=True)
|
||||||
|
txt_v = torch.randn(txt_total, heads, dim, requires_grad=True)
|
||||||
|
|
||||||
|
split_vid, split_txt = seedvr_model._seedvr2_7b_window_attention_split(
|
||||||
|
vid_q, txt_q, vid_k, txt_k, vid_v, txt_v,
|
||||||
|
vid_len_win, txt_len, window_count,
|
||||||
|
)
|
||||||
|
(split_vid.sum() + split_txt.sum()).backward()
|
||||||
|
|
||||||
|
for tensor in (vid_q, vid_k, vid_v, txt_q, txt_k, txt_v):
|
||||||
|
assert tensor.grad is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_7b_mlp_chunks_video_tokens(monkeypatch):
|
||||||
|
class TrackingModule(torch.nn.Module):
|
||||||
|
def __init__(self, scale):
|
||||||
|
super().__init__()
|
||||||
|
self.scale = scale
|
||||||
|
self.calls = []
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
self.calls.append(x.shape[0])
|
||||||
|
return x * self.scale
|
||||||
|
|
||||||
|
monkeypatch.setattr(seedvr_model, "SEEDVR2_7B_MLP_CHUNK", 2)
|
||||||
|
|
||||||
|
vid_module = TrackingModule(2.0)
|
||||||
|
txt_module = TrackingModule(3.0)
|
||||||
|
block = SimpleNamespace(
|
||||||
|
mlp=SimpleNamespace(
|
||||||
|
shared_weights=False,
|
||||||
|
vid_only=False,
|
||||||
|
vid=vid_module,
|
||||||
|
txt=txt_module,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
vid = torch.arange(24, dtype=torch.float32).reshape(6, 4)
|
||||||
|
txt = torch.arange(12, dtype=torch.float32).reshape(3, 4)
|
||||||
|
|
||||||
|
out_vid, out_txt = seedvr_model.NaMMSRTransformerBlock._seedvr2_7b_mlp(block, vid, txt)
|
||||||
|
|
||||||
|
assert vid_module.calls == [2, 2, 2]
|
||||||
|
assert txt_module.calls == [3]
|
||||||
|
torch.testing.assert_close(out_vid, vid * 2.0)
|
||||||
|
torch.testing.assert_close(out_txt, txt * 3.0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_7b_mlp_preserves_video_autograd(monkeypatch):
|
||||||
|
class TrackingModule(torch.nn.Module):
|
||||||
|
def forward(self, x):
|
||||||
|
return x * 2.0
|
||||||
|
|
||||||
|
monkeypatch.setattr(seedvr_model, "SEEDVR2_7B_MLP_CHUNK", 2)
|
||||||
|
|
||||||
|
block = SimpleNamespace(
|
||||||
|
mlp=SimpleNamespace(
|
||||||
|
shared_weights=False,
|
||||||
|
vid_only=True,
|
||||||
|
vid=TrackingModule(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
vid_base = torch.arange(24, dtype=torch.float32, requires_grad=True)
|
||||||
|
vid = vid_base.reshape(6, 4)
|
||||||
|
txt = torch.arange(12, dtype=torch.float32).reshape(3, 4)
|
||||||
|
|
||||||
|
out_vid, _ = seedvr_model.NaMMSRTransformerBlock._seedvr2_7b_mlp(block, vid, txt)
|
||||||
|
out_vid.sum().backward()
|
||||||
|
|
||||||
|
assert vid_base.grad is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_7b_block_routes_mlp_to_chunk_helper():
|
||||||
|
source = inspect.getsource(seedvr_model.NaMMSRTransformerBlock.forward)
|
||||||
|
|
||||||
|
assert "if self.version" in source
|
||||||
|
assert "_seedvr2_7b_mlp" in source
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_vae_decode_memory_covers_full_frame_lab_transfer():
|
||||||
|
estimate = comfy.sd._seedvr2_vae_decode_memory_used((1, 16, 26, 120, 160))
|
||||||
|
old_estimate = 16 * 120 * 160 * (4 * 8 * 8) * 2
|
||||||
|
|
||||||
|
assert estimate == 101 * 960 * 1280 * 160
|
||||||
|
assert estimate > 15 * 1024 ** 3
|
||||||
|
assert estimate > old_estimate * 100
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_vae_decode_memory_estimate_is_per_sample():
|
||||||
|
single = comfy.sd._seedvr2_vae_decode_memory_used((1, 16, 26, 120, 160))
|
||||||
|
batch = comfy.sd._seedvr2_vae_decode_memory_used((2, 16, 26, 120, 160))
|
||||||
|
|
||||||
|
assert batch == single
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_vae_decode_memory_accepts_channel_last_tiled_latents():
|
||||||
|
channel_first = comfy.sd._seedvr2_vae_decode_memory_used((1, 16, 26, 120, 160))
|
||||||
|
channel_last = comfy.sd._seedvr2_vae_decode_memory_used((1, 26, 120, 160, 16))
|
||||||
|
|
||||||
|
assert channel_last == channel_first
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_vae_decode_memory_rounds_malformed_collapsed_channels_up():
|
||||||
|
malformed = comfy.sd._seedvr2_vae_decode_memory_used((1, 17, 120, 160))
|
||||||
|
expected = comfy.sd._seedvr2_vae_decode_output_pixels(2, 120, 160) * comfy.sd.SEEDVR2_VAE_DECODE_BYTES_PER_OUTPUT_PIXEL
|
||||||
|
|
||||||
|
assert malformed == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_vae_decode_memory_uses_conservative_ambiguous_5d_layout():
|
||||||
|
ambiguous = comfy.sd._seedvr2_vae_decode_memory_used((1, 16, 120, 160, 16))
|
||||||
|
channel_first = comfy.sd._seedvr2_vae_decode_output_pixels(120, 160, 16) * comfy.sd.SEEDVR2_VAE_DECODE_BYTES_PER_OUTPUT_PIXEL
|
||||||
|
channel_last = comfy.sd._seedvr2_vae_decode_output_pixels(16, 120, 160) * comfy.sd.SEEDVR2_VAE_DECODE_BYTES_PER_OUTPUT_PIXEL
|
||||||
|
|
||||||
|
assert ambiguous == max(channel_first, channel_last)
|
||||||
218
tests-unit/comfy_test/test_seedvr_7b_final_block_text_path.py
Normal file
218
tests-unit/comfy_test/test_seedvr_7b_final_block_text_path.py
Normal file
@ -0,0 +1,218 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from comfy.cli_args import args
|
||||||
|
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
args.cpu = True
|
||||||
|
|
||||||
|
import comfy.ldm.seedvr.model as seedvr_model # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
|
class _StubModule(nn.Module):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
|
||||||
|
def _capture_last_layer_flags(monkeypatch, vid_dim: int, txt_in_dim: int) -> list[bool]:
|
||||||
|
flags = []
|
||||||
|
|
||||||
|
class _Block(_StubModule):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
flags.append(kwargs["is_last_layer"])
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
monkeypatch.setattr(seedvr_model, "NaPatchIn", _StubModule)
|
||||||
|
monkeypatch.setattr(seedvr_model, "NaPatchOut", _StubModule)
|
||||||
|
monkeypatch.setattr(seedvr_model, "TimeEmbedding", _StubModule)
|
||||||
|
monkeypatch.setattr(seedvr_model, "NaMMSRTransformerBlock", _Block)
|
||||||
|
|
||||||
|
seedvr_model.NaDiT(
|
||||||
|
norm_eps=1e-5,
|
||||||
|
qk_rope=None,
|
||||||
|
num_layers=4,
|
||||||
|
mlp_type="normal",
|
||||||
|
vid_dim=vid_dim,
|
||||||
|
txt_in_dim=txt_in_dim,
|
||||||
|
heads=24,
|
||||||
|
mm_layers=3,
|
||||||
|
)
|
||||||
|
|
||||||
|
return flags
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_7b_keeps_final_block_text_path(monkeypatch):
|
||||||
|
assert _capture_last_layer_flags(monkeypatch, vid_dim=3072, txt_in_dim=3072) == [
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_3b_keeps_final_block_vid_only_path(monkeypatch):
|
||||||
|
assert _capture_last_layer_flags(monkeypatch, vid_dim=2560, txt_in_dim=2560) == [
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
True,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _capture_block_attention_rope_type(monkeypatch, qk_rope):
|
||||||
|
rope_types = []
|
||||||
|
|
||||||
|
class _Attention(_StubModule):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
rope_types.append(kwargs["rope_type"])
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
monkeypatch.setattr(seedvr_model, "MMModule", _StubModule)
|
||||||
|
monkeypatch.setattr(seedvr_model, "NaSwinAttention", _Attention)
|
||||||
|
|
||||||
|
seedvr_model.NaMMSRTransformerBlock(
|
||||||
|
vid_dim=4,
|
||||||
|
txt_dim=4,
|
||||||
|
emb_dim=4,
|
||||||
|
heads=1,
|
||||||
|
head_dim=4,
|
||||||
|
expand_ratio=1,
|
||||||
|
norm=_StubModule,
|
||||||
|
norm_eps=1e-5,
|
||||||
|
ada=_StubModule,
|
||||||
|
qk_bias=False,
|
||||||
|
qk_rope=qk_rope,
|
||||||
|
qk_norm=_StubModule,
|
||||||
|
mlp_type="normal",
|
||||||
|
shared_weights=False,
|
||||||
|
rope_type="mmrope3d",
|
||||||
|
rope_dim=4,
|
||||||
|
is_last_layer=False,
|
||||||
|
device="cpu",
|
||||||
|
dtype=torch.float32,
|
||||||
|
operations=seedvr_model.comfy.ops.disable_weight_init,
|
||||||
|
)
|
||||||
|
|
||||||
|
return rope_types
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_3b_qk_rope_none_preserves_checkpoint_rope_buffers(monkeypatch):
|
||||||
|
assert _capture_block_attention_rope_type(monkeypatch, qk_rope=None) == ["mmrope3d"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_7b_qk_rope_true_preserves_attention_rope(monkeypatch):
|
||||||
|
assert _capture_block_attention_rope_type(monkeypatch, qk_rope=True) == ["mmrope3d"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_7b_rope3d_matches_checkpoint_buffer_shape():
|
||||||
|
rope = seedvr_model.get_na_rope("rope3d", dim=64)
|
||||||
|
|
||||||
|
assert isinstance(rope, seedvr_model.NaRotaryEmbedding3d)
|
||||||
|
assert tuple(rope.rope.freqs.shape) == (10,)
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_7b_rope3d_preserves_qk_shape():
|
||||||
|
rope = seedvr_model.get_na_rope("rope3d", dim=64)
|
||||||
|
q = torch.randn(4, 2, 128)
|
||||||
|
k = torch.randn(4, 2, 128)
|
||||||
|
shape = torch.tensor([[1, 2, 2]], dtype=torch.long)
|
||||||
|
|
||||||
|
q_out, k_out = rope(q, k, shape, seedvr_model.Cache(disable=True))
|
||||||
|
|
||||||
|
assert q_out.shape == q.shape
|
||||||
|
assert k_out.shape == k.shape
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_7b_rope3d_matches_wrapper_oracle():
|
||||||
|
rope = seedvr_model.get_na_rope("rope3d", dim=64)
|
||||||
|
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||||
|
q = torch.randn(4, 2, 128, generator=generator)
|
||||||
|
k = torch.randn(4, 2, 128, generator=generator)
|
||||||
|
shape = torch.tensor([[1, 2, 2]], dtype=torch.long)
|
||||||
|
freqs = rope.get_axial_freqs(1, 2, 2).reshape(4, -1)
|
||||||
|
|
||||||
|
expected_q = seedvr_model.apply_rotary_emb(
|
||||||
|
freqs,
|
||||||
|
q.permute(1, 0, 2).float(),
|
||||||
|
).to(q.dtype).permute(1, 0, 2)
|
||||||
|
expected_k = seedvr_model.apply_rotary_emb(
|
||||||
|
freqs,
|
||||||
|
k.permute(1, 0, 2).float(),
|
||||||
|
).to(k.dtype).permute(1, 0, 2)
|
||||||
|
|
||||||
|
actual_q, actual_k = rope(q.clone(), k.clone(), shape, seedvr_model.Cache(disable=True))
|
||||||
|
|
||||||
|
torch.testing.assert_close(actual_q, expected_q, rtol=0, atol=0)
|
||||||
|
torch.testing.assert_close(actual_k, expected_k, rtol=0, atol=0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_mmrope_handles_large_spatial_grid_without_truncation():
|
||||||
|
rope = seedvr_model.NaMMRotaryEmbedding3d(dim=12)
|
||||||
|
vid_shape = torch.tensor([[1, 129, 130]], dtype=torch.long)
|
||||||
|
txt_shape = torch.tensor([[2]], dtype=torch.long)
|
||||||
|
vid_tokens = int(vid_shape.prod().item())
|
||||||
|
txt_tokens = int(txt_shape.prod().item())
|
||||||
|
vid_q = torch.zeros(vid_tokens, 1, 12)
|
||||||
|
vid_k = torch.zeros_like(vid_q)
|
||||||
|
txt_q = torch.zeros(txt_tokens, 1, 12)
|
||||||
|
txt_k = torch.zeros_like(txt_q)
|
||||||
|
|
||||||
|
out = rope(vid_q, vid_k, vid_shape, txt_q, txt_k, txt_shape, seedvr_model.Cache(disable=True))
|
||||||
|
|
||||||
|
assert [tuple(t.shape) for t in out] == [
|
||||||
|
tuple(vid_q.shape),
|
||||||
|
tuple(vid_k.shape),
|
||||||
|
tuple(txt_q.shape),
|
||||||
|
tuple(txt_k.shape),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_adasingle_init_preserves_supported_dtype():
|
||||||
|
ada = seedvr_model.AdaSingle(
|
||||||
|
dim=4,
|
||||||
|
emb_dim=24,
|
||||||
|
layers=["test"],
|
||||||
|
modes=["in", "out"],
|
||||||
|
device="cpu",
|
||||||
|
dtype=torch.bfloat16,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert ada.test_shift.dtype is torch.bfloat16
|
||||||
|
assert ada.test_scale.dtype is torch.bfloat16
|
||||||
|
assert ada.test_gate.dtype is torch.bfloat16
|
||||||
|
|
||||||
|
|
||||||
|
def test_adasingle_init_uses_default_dtype_for_fp8():
|
||||||
|
if not hasattr(torch, "float8_e4m3fn"):
|
||||||
|
return
|
||||||
|
|
||||||
|
ada = seedvr_model.AdaSingle(
|
||||||
|
dim=4,
|
||||||
|
emb_dim=24,
|
||||||
|
layers=["test"],
|
||||||
|
modes=["in", "out"],
|
||||||
|
device="cpu",
|
||||||
|
dtype=torch.float8_e4m3fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert ada.test_shift.dtype is torch.float32
|
||||||
|
assert ada.test_scale.dtype is torch.float32
|
||||||
|
assert ada.test_gate.dtype is torch.float32
|
||||||
|
|
||||||
|
|
||||||
|
def test_adasingle_init_and_forward_share_fp8_dtype_set():
|
||||||
|
expected = {
|
||||||
|
getattr(torch, name)
|
||||||
|
for name in (
|
||||||
|
"float8_e4m3fn",
|
||||||
|
"float8_e4m3fnuz",
|
||||||
|
"float8_e5m2",
|
||||||
|
"float8_e5m2fnuz",
|
||||||
|
"float8_e8m0fnu",
|
||||||
|
)
|
||||||
|
if hasattr(torch, name)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert set(seedvr_model._torch_float8_types()) == expected
|
||||||
54
tests-unit/comfy_test/test_seedvr_forward_no_device_cast.py
Normal file
54
tests-unit/comfy_test/test_seedvr_forward_no_device_cast.py
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
from comfy.cli_args import args
|
||||||
|
import torch
|
||||||
|
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
args.cpu = True
|
||||||
|
|
||||||
|
import ast # noqa: E402
|
||||||
|
import inspect # noqa: E402
|
||||||
|
|
||||||
|
from torch import nn # noqa: E402
|
||||||
|
|
||||||
|
import comfy # noqa: E402
|
||||||
|
import comfy.ldm.seedvr.model # noqa: E402
|
||||||
|
import comfy.model_management # noqa: E402
|
||||||
|
from comfy.ldm.seedvr.model import MMModule # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
|
def test_no_get_torch_device_in_forward_methods():
|
||||||
|
tree = ast.parse(inspect.getsource(comfy.ldm.seedvr.model))
|
||||||
|
assert [
|
||||||
|
(n.lineno, i.lineno)
|
||||||
|
for n in ast.walk(tree)
|
||||||
|
if isinstance(n, ast.FunctionDef) and n.name == "forward"
|
||||||
|
for i in ast.walk(n)
|
||||||
|
if isinstance(i, ast.Call)
|
||||||
|
and isinstance(i.func, ast.Attribute)
|
||||||
|
and i.func.attr == "get_torch_device"
|
||||||
|
] == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_mmmodule_forward_succeeds_without_get_torch_device_lookup(monkeypatch):
|
||||||
|
call_count = [0]
|
||||||
|
|
||||||
|
def boom():
|
||||||
|
call_count[0] += 1
|
||||||
|
raise RuntimeError("MMModule.forward called get_torch_device()")
|
||||||
|
|
||||||
|
monkeypatch.setattr(comfy.model_management, "get_torch_device", boom)
|
||||||
|
|
||||||
|
class _IdentityCallable(nn.Module):
|
||||||
|
def forward(self, x, *args, **kwargs):
|
||||||
|
return x
|
||||||
|
|
||||||
|
mm = MMModule(_IdentityCallable, shared_weights=False, vid_only=False)
|
||||||
|
|
||||||
|
vid_in = torch.zeros(2, 4)
|
||||||
|
txt_in = torch.ones(2, 4)
|
||||||
|
vid_out, txt_out = mm.forward(vid_in, txt_in)
|
||||||
|
|
||||||
|
assert call_count[0] == 0
|
||||||
|
assert torch.equal(vid_out, vid_in)
|
||||||
|
assert torch.equal(txt_out, txt_in)
|
||||||
|
assert vid_out.device == vid_in.device
|
||||||
|
assert txt_out.device == txt_in.device
|
||||||
179
tests-unit/comfy_test/test_seedvr_groupnorm_limit.py
Normal file
179
tests-unit/comfy_test/test_seedvr_groupnorm_limit.py
Normal file
@ -0,0 +1,179 @@
|
|||||||
|
"""Regression: ``comfy.ldm.seedvr.vae.causal_norm_wrapper`` 5D GroupNorm
|
||||||
|
gate at ``vae.py:509`` must compare ``memory_occupy`` against the configured
|
||||||
|
``get_norm_limit()`` accessor, not against a hardcoded ``float('inf')``.
|
||||||
|
|
||||||
|
The original code path was ``... > float('inf')`` which is unreachable at any
|
||||||
|
finite ``memory_occupy`` value, so SeedVR2's ``norm_max_mem`` setting (wired
|
||||||
|
through ``set_norm_limit``) had no effect.
|
||||||
|
|
||||||
|
This module locks in two complementary cases against any future regression,
|
||||||
|
parametrized over both ``ops.GroupNorm`` subclasses (``disable_weight_init`` and
|
||||||
|
``manual_cast``) since the production gate ``isinstance(norm_layer, ops.GroupNorm)``
|
||||||
|
matches both.
|
||||||
|
|
||||||
|
* ``test_seedvr_groupnorm_default_limit_uses_full_groupnorm_path`` — with
|
||||||
|
the limit at its default ``inf``, the full GroupNorm forward must run and
|
||||||
|
the chunked branch must NOT run, regardless of input tensor size.
|
||||||
|
* ``test_seedvr_groupnorm_low_limit_uses_chunked_groupnorm_path`` — with a
|
||||||
|
deliberately low limit (``1e-9 GiB``), the chunked branch must run and
|
||||||
|
the full GroupNorm forward must NOT run.
|
||||||
|
|
||||||
|
Each case discriminates the two branches with two independent observers:
|
||||||
|
|
||||||
|
1. ``nn.Module.register_forward_hook`` on the GroupNorm — fires only on the
|
||||||
|
full-path branch ``norm_layer(x)``; the chunked branch bypasses the
|
||||||
|
module ``__call__`` and goes through ``F.group_norm`` directly.
|
||||||
|
2. ``unittest.mock.patch.object(vae.F, 'group_norm', ...)`` spy with
|
||||||
|
``side_effect`` delegating to the real ``torch.nn.functional.group_norm``
|
||||||
|
— captures every direct ``F.group_norm`` call's ``num_groups`` argument.
|
||||||
|
Calls with ``num_groups < gn.num_groups`` come from the chunked branch
|
||||||
|
(``num_groups_per_chunk = gn.num_groups // num_chunks``).
|
||||||
|
|
||||||
|
The spy uses ``*args, **kwargs`` passthrough so future ``F.group_norm`` kwargs
|
||||||
|
do not break the test.
|
||||||
|
|
||||||
|
CPU-only by construction: the tests use a small float32 tensor and never
|
||||||
|
allocate a real model or GPU memory.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from comfy.cli_args import args as cli_args
|
||||||
|
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
cli_args.cpu = True
|
||||||
|
|
||||||
|
import comfy.ops as comfy_ops # noqa: E402
|
||||||
|
import comfy.ldm.seedvr.vae as vae_mod # noqa: E402
|
||||||
|
from comfy.ldm.seedvr.vae import ( # noqa: E402
|
||||||
|
causal_norm_wrapper,
|
||||||
|
set_norm_limit,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_NUM_CHANNELS = 8
|
||||||
|
_NUM_GROUPS = 4
|
||||||
|
_TENSOR_SHAPE = (1, 8, 2, 4, 4)
|
||||||
|
|
||||||
|
# Both ``ops.GroupNorm`` subclasses appear in production paths depending on
|
||||||
|
# the active backend. The dispatch gate at ``vae.py:509`` reads
|
||||||
|
# ``isinstance(norm_layer, ops.GroupNorm)`` and matches both via MRO.
|
||||||
|
_GROUPNORM_SUBCLASSES = [
|
||||||
|
pytest.param(
|
||||||
|
comfy_ops.disable_weight_init.GroupNorm,
|
||||||
|
id="disable_weight_init",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
comfy_ops.manual_cast.GroupNorm,
|
||||||
|
id="manual_cast",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("groupnorm_cls", _GROUPNORM_SUBCLASSES)
|
||||||
|
def test_seedvr_groupnorm_default_limit_uses_full_groupnorm_path(groupnorm_cls):
|
||||||
|
real_group_norm = vae_mod.F.group_norm
|
||||||
|
set_norm_limit(None)
|
||||||
|
try:
|
||||||
|
gn = groupnorm_cls(
|
||||||
|
num_channels=_NUM_CHANNELS, num_groups=_NUM_GROUPS
|
||||||
|
)
|
||||||
|
gn.eval()
|
||||||
|
|
||||||
|
forward_hook_calls = []
|
||||||
|
|
||||||
|
def _hook(module, inputs, output):
|
||||||
|
forward_hook_calls.append(tuple(inputs[0].shape))
|
||||||
|
|
||||||
|
spy_calls = []
|
||||||
|
|
||||||
|
def _group_norm_spy(input_tensor, num_groups_arg, *args, **kwargs):
|
||||||
|
spy_calls.append({
|
||||||
|
"num_groups": int(num_groups_arg),
|
||||||
|
"input_shape": tuple(int(s) for s in input_tensor.shape),
|
||||||
|
})
|
||||||
|
return real_group_norm(input_tensor, num_groups_arg, *args, **kwargs)
|
||||||
|
|
||||||
|
handle = gn.register_forward_hook(_hook)
|
||||||
|
try:
|
||||||
|
with patch.object(vae_mod.F, "group_norm", side_effect=_group_norm_spy):
|
||||||
|
out_tensor = causal_norm_wrapper(gn, torch.randn(*_TENSOR_SHAPE))
|
||||||
|
finally:
|
||||||
|
handle.remove()
|
||||||
|
|
||||||
|
full_calls = len(forward_hook_calls)
|
||||||
|
chunked_calls = sum(
|
||||||
|
1 for entry in spy_calls if entry["num_groups"] < _NUM_GROUPS
|
||||||
|
)
|
||||||
|
|
||||||
|
assert tuple(int(s) for s in out_tensor.shape) == _TENSOR_SHAPE, (
|
||||||
|
f"causal_norm_wrapper output shape {tuple(out_tensor.shape)} does not "
|
||||||
|
f"match input shape {_TENSOR_SHAPE}"
|
||||||
|
)
|
||||||
|
assert full_calls == 1, (
|
||||||
|
f"default-limit (inf) GroupNorm gate must take the full-forward path "
|
||||||
|
f"(register_forward_hook fires exactly once); got full_calls={full_calls}"
|
||||||
|
)
|
||||||
|
assert chunked_calls == 0, (
|
||||||
|
f"default-limit (inf) GroupNorm gate must NOT take the chunked path "
|
||||||
|
f"(no F.group_norm call with num_groups<{_NUM_GROUPS}); got "
|
||||||
|
f"chunked_calls={chunked_calls}"
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
set_norm_limit(None)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("groupnorm_cls", _GROUPNORM_SUBCLASSES)
|
||||||
|
def test_seedvr_groupnorm_low_limit_uses_chunked_groupnorm_path(groupnorm_cls):
|
||||||
|
real_group_norm = vae_mod.F.group_norm
|
||||||
|
set_norm_limit(1e-9)
|
||||||
|
try:
|
||||||
|
gn = groupnorm_cls(
|
||||||
|
num_channels=_NUM_CHANNELS, num_groups=_NUM_GROUPS
|
||||||
|
)
|
||||||
|
gn.eval()
|
||||||
|
|
||||||
|
forward_hook_calls = []
|
||||||
|
|
||||||
|
def _hook(module, inputs, output):
|
||||||
|
forward_hook_calls.append(tuple(inputs[0].shape))
|
||||||
|
|
||||||
|
spy_calls = []
|
||||||
|
|
||||||
|
def _group_norm_spy(input_tensor, num_groups_arg, *args, **kwargs):
|
||||||
|
spy_calls.append({
|
||||||
|
"num_groups": int(num_groups_arg),
|
||||||
|
"input_shape": tuple(int(s) for s in input_tensor.shape),
|
||||||
|
})
|
||||||
|
return real_group_norm(input_tensor, num_groups_arg, *args, **kwargs)
|
||||||
|
|
||||||
|
handle = gn.register_forward_hook(_hook)
|
||||||
|
try:
|
||||||
|
with patch.object(vae_mod.F, "group_norm", side_effect=_group_norm_spy):
|
||||||
|
out_tensor = causal_norm_wrapper(gn, torch.randn(*_TENSOR_SHAPE))
|
||||||
|
finally:
|
||||||
|
handle.remove()
|
||||||
|
|
||||||
|
full_calls = len(forward_hook_calls)
|
||||||
|
chunked_calls = sum(
|
||||||
|
1 for entry in spy_calls if entry["num_groups"] < _NUM_GROUPS
|
||||||
|
)
|
||||||
|
|
||||||
|
assert tuple(int(s) for s in out_tensor.shape) == _TENSOR_SHAPE, (
|
||||||
|
f"causal_norm_wrapper output shape {tuple(out_tensor.shape)} does not "
|
||||||
|
f"match input shape {_TENSOR_SHAPE}"
|
||||||
|
)
|
||||||
|
assert full_calls == 0, (
|
||||||
|
f"low-limit GroupNorm gate must NOT take the full-forward path "
|
||||||
|
f"(register_forward_hook should not fire); got full_calls={full_calls}"
|
||||||
|
)
|
||||||
|
assert chunked_calls > 0, (
|
||||||
|
f"low-limit GroupNorm gate must take the chunked path "
|
||||||
|
f"(at least one F.group_norm call with num_groups<{_NUM_GROUPS}); got "
|
||||||
|
f"chunked_calls={chunked_calls}"
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
set_norm_limit(None)
|
||||||
40
tests-unit/comfy_test/test_seedvr_latent_format.py
Normal file
40
tests-unit/comfy_test/test_seedvr_latent_format.py
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
from comfy.cli_args import args as cli_args
|
||||||
|
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
cli_args.cpu = True
|
||||||
|
|
||||||
|
import comfy.latent_formats
|
||||||
|
import comfy.sample
|
||||||
|
|
||||||
|
|
||||||
|
class _Model:
|
||||||
|
def __init__(self, latent_format):
|
||||||
|
self._latent_format = latent_format
|
||||||
|
|
||||||
|
def get_model_object(self, name):
|
||||||
|
assert name == "latent_format"
|
||||||
|
return self._latent_format
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_latent_format_uses_16_channels_without_3d_empty_latent_expansion():
|
||||||
|
latent_format = comfy.latent_formats.SeedVR2()
|
||||||
|
latent_image = torch.zeros(1, 1, 4, 5)
|
||||||
|
|
||||||
|
fixed = comfy.sample.fix_empty_latent_channels(_Model(latent_format), latent_image)
|
||||||
|
|
||||||
|
assert latent_format.latent_channels == 16
|
||||||
|
assert latent_format.latent_dimensions == 2
|
||||||
|
assert fixed.shape == (1, 16, 4, 5)
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_empty_collapsed_latent_preserves_temporal_channel_multiples():
|
||||||
|
latent_format = comfy.latent_formats.SeedVR2()
|
||||||
|
latent_image = torch.zeros(1, 48, 4, 5)
|
||||||
|
|
||||||
|
fixed = comfy.sample.fix_empty_latent_channels(_Model(latent_format), latent_image)
|
||||||
|
|
||||||
|
assert latent_format.preserve_empty_channel_multiples is True
|
||||||
|
assert fixed.shape == latent_image.shape
|
||||||
|
assert fixed.data_ptr() == latent_image.data_ptr()
|
||||||
176
tests-unit/comfy_test/test_seedvr_rope_delegation.py
Normal file
176
tests-unit/comfy_test/test_seedvr_rope_delegation.py
Normal file
@ -0,0 +1,176 @@
|
|||||||
|
"""Regression test: ``comfy.ldm.seedvr.model.apply_rotary_emb`` must delegate
|
||||||
|
to ``comfy.ldm.flux.math.apply_rope1`` and produce exact-equality output
|
||||||
|
across the wrapper's slicing, scaling, and concatenation logic. Drift between
|
||||||
|
the wrapper and the delegate would silently corrupt SeedVR2's RoPE; this test
|
||||||
|
fails loudly on any future drift.
|
||||||
|
|
||||||
|
Each parametrized case does both:
|
||||||
|
|
||||||
|
1. Patches ``comfy.ldm.seedvr.model.apply_rope1`` with a ``wraps``-style spy
|
||||||
|
and asserts ``spy.call_count >= 1`` so a future change that inlines the
|
||||||
|
math and stops calling ``apply_rope1`` fails the test.
|
||||||
|
2. Compares the wrapper's output against a hand-rolled reproduction using
|
||||||
|
``torch.testing.assert_close(rtol=0, atol=0)`` -- exact tensor equality,
|
||||||
|
not bit-equality (``+0.0`` vs ``-0.0`` and NaN payloads can still match);
|
||||||
|
the assertion catches any future kernel-precision drift in the
|
||||||
|
``apply_rope1`` dispatch.
|
||||||
|
|
||||||
|
The test uses a local ``torch.Generator`` so global RNG state is not mutated.
|
||||||
|
Parametrization covers non-default ``start_index`` and ``scale`` and a case
|
||||||
|
where ``freqs.shape[0] > t.shape[seq_dim]`` so the wrapper's
|
||||||
|
``slice_at_dim(freqs, slice(-seq_len, None), dim=0)`` path is exercised.
|
||||||
|
Imports are taken at module level. Heavy-import stubbing of
|
||||||
|
``comfy.model_management`` was attempted but is insufficient on this live
|
||||||
|
import chain (``comfy.ldm.seedvr.model`` pulls
|
||||||
|
``comfy.ldm.modules.diffusionmodules.model -> comfy.ops ->
|
||||||
|
comfy.memory_management -> comfy.quant_ops -> comfy_kitchen.tensor ->
|
||||||
|
torch._dynamo``), so this test intentionally runs against the real modules
|
||||||
|
to fail loudly if that import path or runtime state drifts. Other tests in
|
||||||
|
this repo (e.g. ``tests-unit/comfy_extras_test/image_stitch_test.py``) do
|
||||||
|
stub via ``patch.dict(sys.modules, ...)`` for narrower targets; the choice
|
||||||
|
here is local to this regression and not a repo-wide convention.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# CPU-only CI fix: ``comfy.ldm.seedvr.model`` transitively imports
|
||||||
|
# ``comfy.model_management``, whose import-time ``get_torch_device()`` call
|
||||||
|
# probes ``torch.cuda.current_device()`` unless ``comfy.cli_args.args.cpu`` is
|
||||||
|
# set. On a CPU-only build that probe can raise during test collection before
|
||||||
|
# the ``cuda`` case has had a chance to be skipped. Match the pattern used by
|
||||||
|
# ``tests-unit/comfy_quant/test_mixed_precision.py``: flip ``args.cpu`` before
|
||||||
|
# importing any ``comfy.ldm.*`` symbol.
|
||||||
|
from comfy.cli_args import args
|
||||||
|
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
args.cpu = True
|
||||||
|
|
||||||
|
import comfy.ldm.seedvr.model as seedvr_model # noqa: E402
|
||||||
|
from comfy.ldm.flux.math import apply_rope1 # noqa: E402
|
||||||
|
from comfy.ldm.seedvr.model import apply_rotary_emb # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
|
def _direct_reproduction(freqs, t, start_index=0, scale=1.0, seq_dim=-2):
|
||||||
|
"""Reproduce the body of ``apply_rotary_emb`` for the default case where
|
||||||
|
``freqs.ndim == 2`` and ``t.ndim == 3`` (implicit ``freqs_seq_dim=0``).
|
||||||
|
Mirrors the wrapper's ``slice_at_dim(freqs, slice(-seq_len, None), dim=0)``
|
||||||
|
step when freqs is longer than ``t`` along ``seq_dim``. Calls the real
|
||||||
|
``apply_rope1`` via the test module's import (the test patches the
|
||||||
|
``seedvr_model.apply_rope1`` attribute; this call uses the unpatched
|
||||||
|
``flux.math`` symbol).
|
||||||
|
"""
|
||||||
|
if freqs.ndim == 2 and t.ndim == 3:
|
||||||
|
seq_len = t.shape[seq_dim]
|
||||||
|
freqs = freqs[-seq_len:]
|
||||||
|
|
||||||
|
rot_feats = freqs.shape[-1]
|
||||||
|
end_index = start_index + rot_feats
|
||||||
|
t_left = t[..., :start_index]
|
||||||
|
t_middle = t[..., start_index:end_index]
|
||||||
|
t_right = t[..., end_index:]
|
||||||
|
angles = freqs.to(t_middle.device)[..., ::2]
|
||||||
|
cos = torch.cos(angles) * scale
|
||||||
|
sin = torch.sin(angles) * scale
|
||||||
|
col0 = torch.stack([cos, sin], dim=-1)
|
||||||
|
col1 = torch.stack([-sin, cos], dim=-1)
|
||||||
|
freqs_mat = torch.stack([col0, col1], dim=-1)
|
||||||
|
t_middle_out = apply_rope1(t_middle, freqs_mat)
|
||||||
|
return torch.cat((t_left, t_middle_out, t_right), dim=-1).type(t.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def _cpu_trig_supported(dtype):
|
||||||
|
"""Return whether ``torch.cos`` (and by symmetry ``torch.sin``) is
|
||||||
|
implemented for the given dtype on CPU on the current runtime. Some
|
||||||
|
PyTorch CPU wheels don't implement trig ops for ``float16`` / ``bfloat16``
|
||||||
|
and raise at runtime; the parametrized cases for those dtypes are skipped
|
||||||
|
when that's the case so CI remains stable across PyTorch builds.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
torch.cos(torch.zeros(1, dtype=dtype))
|
||||||
|
except (RuntimeError, TypeError):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
_CPU_FP16_TRIG_OK = _cpu_trig_supported(torch.float16)
|
||||||
|
_CPU_BF16_TRIG_OK = _cpu_trig_supported(torch.bfloat16)
|
||||||
|
|
||||||
|
|
||||||
|
# (device, dtype, t_shape, freqs_shape, start_index, scale)
|
||||||
|
_CASES = [
|
||||||
|
pytest.param("cpu", torch.float32, (1, 8, 16), (8, 16), 0, 1.0,
|
||||||
|
id="cpu-float32-base"),
|
||||||
|
pytest.param(
|
||||||
|
"cpu", torch.float16, (1, 8, 16), (8, 16), 0, 1.0,
|
||||||
|
id="cpu-float16-base",
|
||||||
|
marks=pytest.mark.skipif(
|
||||||
|
not _CPU_FP16_TRIG_OK,
|
||||||
|
reason="torch.cos/torch.sin unsupported for float16 tensors on CPU",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
"cpu", torch.bfloat16, (1, 8, 16), (8, 16), 0, 1.0,
|
||||||
|
id="cpu-bfloat16-base",
|
||||||
|
marks=pytest.mark.skipif(
|
||||||
|
not _CPU_BF16_TRIG_OK,
|
||||||
|
reason="torch.cos/torch.sin unsupported for bfloat16 tensors on CPU",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
pytest.param("cpu", torch.float32, (2, 16, 32), (16, 32), 0, 1.0,
|
||||||
|
id="cpu-float32-larger"),
|
||||||
|
pytest.param("cpu", torch.float32, (1, 8, 24), (8, 16), 4, 1.0,
|
||||||
|
id="cpu-float32-non-empty-left-and-right-slices"),
|
||||||
|
pytest.param("cpu", torch.float32, (1, 8, 16), (8, 16), 0, 0.5,
|
||||||
|
id="cpu-float32-non-default-scale"),
|
||||||
|
pytest.param("cpu", torch.float32, (1, 8, 16), (12, 16), 0, 1.0,
|
||||||
|
id="cpu-float32-freqs-longer-than-seq"),
|
||||||
|
pytest.param(
|
||||||
|
"cuda", torch.float16, (1, 8, 16), (8, 16), 0, 1.0,
|
||||||
|
id="cuda-float16-base",
|
||||||
|
marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="no cuda"),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("device,dtype,t_shape,freqs_shape,start_index,scale", _CASES)
|
||||||
|
def test_apply_rotary_emb_delegates_to_apply_rope1(
|
||||||
|
device, dtype, t_shape, freqs_shape, start_index, scale
|
||||||
|
):
|
||||||
|
generator = torch.Generator(device=device).manual_seed(0)
|
||||||
|
t = torch.randn(*t_shape, dtype=dtype, device=device, generator=generator)
|
||||||
|
freqs = torch.randn(*freqs_shape, dtype=dtype, device=device, generator=generator)
|
||||||
|
|
||||||
|
# Patch the apply_rope1 symbol as imported into seedvr.model with a wraps
|
||||||
|
# spy: a future change that inlines the math and stops calling the
|
||||||
|
# imported apply_rope1 makes spy.call_count == 0 and fails the test.
|
||||||
|
with patch.object(
|
||||||
|
seedvr_model, "apply_rope1", wraps=seedvr_model.apply_rope1
|
||||||
|
) as spy:
|
||||||
|
wrapper_out = apply_rotary_emb(
|
||||||
|
freqs, t, start_index=start_index, scale=scale
|
||||||
|
)
|
||||||
|
|
||||||
|
assert spy.call_count >= 1, (
|
||||||
|
"apply_rotary_emb did not call comfy.ldm.seedvr.model.apply_rope1; "
|
||||||
|
"the delegation invariant is broken"
|
||||||
|
)
|
||||||
|
|
||||||
|
direct_out = _direct_reproduction(
|
||||||
|
freqs, t, start_index=start_index, scale=scale
|
||||||
|
)
|
||||||
|
|
||||||
|
msg = (
|
||||||
|
f"apply_rotary_emb output does not match direct apply_rope1 "
|
||||||
|
f"reproduction (device={device}, dtype={dtype}, t_shape={t_shape}, "
|
||||||
|
f"freqs_shape={freqs_shape}, start_index={start_index}, scale={scale})"
|
||||||
|
)
|
||||||
|
torch.testing.assert_close(
|
||||||
|
wrapper_out,
|
||||||
|
direct_out,
|
||||||
|
rtol=0,
|
||||||
|
atol=0,
|
||||||
|
msg=msg,
|
||||||
|
)
|
||||||
335
tests-unit/comfy_test/test_seedvr_rope_rewrite.py
Normal file
335
tests-unit/comfy_test/test_seedvr_rope_rewrite.py
Normal file
@ -0,0 +1,335 @@
|
|||||||
|
"""Regression tests for the SeedVR2 native RoPE rewrite that replaces the
|
||||||
|
``apply_rotary_emb`` wrapper inside ``NaMMRotaryEmbedding3d.forward`` with
|
||||||
|
direct calls to ``comfy.ldm.flux.math.apply_rope1`` — matching the pattern
|
||||||
|
used by the other 7 ComfyUI native-DiT models (flux, hidream, kandinsky5,
|
||||||
|
lumina, qwen_image, wan, sam3).
|
||||||
|
|
||||||
|
The wrapper builds a 2x2 ``freqs_mat`` and ends in ``torch.cat((t_left,
|
||||||
|
t_middle_out, t_right), dim=-1)``; that cat OOMs on the largest cell of the
|
||||||
|
SeedVR2 native_3b non-tiled corpus (VideoLQ_000 1280x960x100 on RTX 5090
|
||||||
|
32GB). Canonical and numz pass the same cell because both call
|
||||||
|
``rotary_embedding_torch.apply_rotary_emb`` directly. The fix moves the
|
||||||
|
NaMMRotaryEmbedding3d path onto ``apply_rope1`` directly with freqs in
|
||||||
|
flux-canonical shape ``[..., d/2, 2, 2]`` (cos/-sin/sin/cos baked in).
|
||||||
|
|
||||||
|
This test file pins four invariants the rewrite must satisfy:
|
||||||
|
|
||||||
|
1. ``NaMMRotaryEmbedding3d.forward`` calls ``apply_rope1`` 4 times per
|
||||||
|
forward (vid_q, vid_k, txt_q, txt_k) and 0 times into the
|
||||||
|
``apply_rotary_emb`` wrapper.
|
||||||
|
2. ``NaMMRotaryEmbedding3d.get_freqs`` returns freqs in flux-canonical shape
|
||||||
|
``[..., d/2, 2, 2]`` with the cos/-sin/sin/cos pattern from
|
||||||
|
``comfy/ldm/flux/math.py:rope`` (line 27).
|
||||||
|
3. The forward output is tensor-equal at fp32 against an oracle computed
|
||||||
|
from the unchanged ``apply_rotary_emb`` wrapper fed with the legacy
|
||||||
|
freqs layout — proving the rewrite is algorithmically lossless.
|
||||||
|
4. AST: no ``apply_rotary_emb`` call sites remain inside
|
||||||
|
``NaMMRotaryEmbedding3d.forward``.
|
||||||
|
|
||||||
|
The wrapper itself stays in the file (still used by
|
||||||
|
``RotaryEmbedding3d.forward`` lines 434-435 and the staticmethod
|
||||||
|
registration on lucidrains' ``RotaryEmbedding`` line 323). Out of scope
|
||||||
|
here.
|
||||||
|
|
||||||
|
Pre-import CPU-only guard mirrors ``test_seedvr_rope_delegation.py`` —
|
||||||
|
``comfy.ldm.seedvr.model`` transitively imports ``comfy.model_management``
|
||||||
|
which probes ``torch.cuda.current_device()`` at import time unless
|
||||||
|
``args.cpu`` is set first.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import ast
|
||||||
|
import inspect
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from comfy.cli_args import args
|
||||||
|
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
args.cpu = True
|
||||||
|
|
||||||
|
import comfy.ldm.seedvr.model as seedvr_model # noqa: E402
|
||||||
|
from comfy.ldm.seedvr.model import ( # noqa: E402
|
||||||
|
Cache,
|
||||||
|
NaMMRotaryEmbedding3d,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Test rig dimensions. dim=192 → per-axis rope dim = 64 (even, lucidrains
|
||||||
|
# requirement). vid_shape=(2,4,4) → L_vid = 32. txt_shape=(8,) → L_txt = 8.
|
||||||
|
# heads = 4. These are all small enough to run on CPU in milliseconds.
|
||||||
|
_DIM = 192
|
||||||
|
_HEADS = 4
|
||||||
|
_VID_T, _VID_H, _VID_W = 2, 4, 4
|
||||||
|
_TXT_L = 8
|
||||||
|
_L_VID = _VID_T * _VID_H * _VID_W
|
||||||
|
_SEED = 0
|
||||||
|
|
||||||
|
|
||||||
|
def _make_inputs(dtype=torch.float32, device="cpu"):
|
||||||
|
"""Construct the 6 forward inputs + cache. Deterministic via local
|
||||||
|
Generator so global RNG state is not mutated.
|
||||||
|
"""
|
||||||
|
g = torch.Generator(device=device).manual_seed(_SEED)
|
||||||
|
vid_q = torch.randn(_L_VID, _HEADS, _DIM, dtype=dtype, device=device, generator=g)
|
||||||
|
vid_k = torch.randn(_L_VID, _HEADS, _DIM, dtype=dtype, device=device, generator=g)
|
||||||
|
txt_q = torch.randn(_TXT_L, _HEADS, _DIM, dtype=dtype, device=device, generator=g)
|
||||||
|
txt_k = torch.randn(_TXT_L, _HEADS, _DIM, dtype=dtype, device=device, generator=g)
|
||||||
|
vid_shape = torch.tensor([[_VID_T, _VID_H, _VID_W]], dtype=torch.long, device=device)
|
||||||
|
txt_shape = torch.tensor([[_TXT_L]], dtype=torch.long, device=device)
|
||||||
|
cache = Cache(disable=True)
|
||||||
|
return vid_q, vid_k, vid_shape, txt_q, txt_k, txt_shape, cache
|
||||||
|
|
||||||
|
|
||||||
|
def _legacy_get_freqs(rope: NaMMRotaryEmbedding3d, vid_shape, txt_shape):
|
||||||
|
"""Reproduce the pre-rewrite ``get_freqs`` body verbatim against
|
||||||
|
``self.get_axial_freqs`` (parent ``RotaryEmbeddingBase`` method,
|
||||||
|
unchanged by the rewrite). Used by Test 3 to compute the oracle from
|
||||||
|
the wrapper path post-rewrite, when ``rope.get_freqs`` itself returns
|
||||||
|
the new flux-canonical shape.
|
||||||
|
"""
|
||||||
|
max_temporal = 0
|
||||||
|
max_height = 0
|
||||||
|
max_width = 0
|
||||||
|
max_txt_len = 0
|
||||||
|
for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()):
|
||||||
|
max_temporal = max(max_temporal, l + f)
|
||||||
|
max_height = max(max_height, h)
|
||||||
|
max_width = max(max_width, w)
|
||||||
|
max_txt_len = max(max_txt_len, l)
|
||||||
|
with torch.amp.autocast(device_type="cuda", enabled=False):
|
||||||
|
vid_freqs_full = rope.get_axial_freqs(
|
||||||
|
min(max_temporal + 16, 1024),
|
||||||
|
min(max_height + 4, 128),
|
||||||
|
min(max_width + 4, 128),
|
||||||
|
).float()
|
||||||
|
txt_freqs_full = rope.get_axial_freqs(min(max_txt_len + 16, 1024))
|
||||||
|
vid_freq_list, txt_freq_list = [], []
|
||||||
|
for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()):
|
||||||
|
vid_freq = vid_freqs_full[l : l + f, :h, :w].reshape(-1, vid_freqs_full.size(-1))
|
||||||
|
txt_freq = txt_freqs_full[:l].repeat(1, 3).reshape(-1, vid_freqs_full.size(-1))
|
||||||
|
vid_freq_list.append(vid_freq)
|
||||||
|
txt_freq_list.append(txt_freq)
|
||||||
|
return torch.cat(vid_freq_list, dim=0), torch.cat(txt_freq_list, dim=0)
|
||||||
|
|
||||||
|
|
||||||
|
def _legacy_forward(rope: NaMMRotaryEmbedding3d, vid_q, vid_k, vid_shape,
|
||||||
|
txt_q, txt_k, txt_shape):
|
||||||
|
"""Compute expected forward output via the unchanged
|
||||||
|
``apply_rotary_emb`` wrapper fed with legacy-shape freqs. This is the
|
||||||
|
oracle. The wrapper itself is out of scope for the rewrite (Shape B).
|
||||||
|
"""
|
||||||
|
vid_freqs, txt_freqs = _legacy_get_freqs(rope, vid_shape, txt_shape)
|
||||||
|
vid_freqs = vid_freqs.to(vid_q.device)
|
||||||
|
txt_freqs = txt_freqs.to(txt_q.device)
|
||||||
|
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
|
vid_q = rearrange(vid_q, "L h d -> h L d")
|
||||||
|
vid_k = rearrange(vid_k, "L h d -> h L d")
|
||||||
|
vid_q_out = seedvr_model.apply_rotary_emb(vid_freqs, vid_q.float()).to(vid_q.dtype)
|
||||||
|
vid_k_out = seedvr_model.apply_rotary_emb(vid_freqs, vid_k.float()).to(vid_k.dtype)
|
||||||
|
vid_q_out = rearrange(vid_q_out, "h L d -> L h d")
|
||||||
|
vid_k_out = rearrange(vid_k_out, "h L d -> L h d")
|
||||||
|
|
||||||
|
txt_q = rearrange(txt_q, "L h d -> h L d")
|
||||||
|
txt_k = rearrange(txt_k, "L h d -> h L d")
|
||||||
|
txt_q_out = seedvr_model.apply_rotary_emb(txt_freqs, txt_q.float()).to(txt_q.dtype)
|
||||||
|
txt_k_out = seedvr_model.apply_rotary_emb(txt_freqs, txt_k.float()).to(txt_k.dtype)
|
||||||
|
txt_q_out = rearrange(txt_q_out, "h L d -> L h d")
|
||||||
|
txt_k_out = rearrange(txt_k_out, "h L d -> L h d")
|
||||||
|
return vid_q_out, vid_k_out, txt_q_out, txt_k_out
|
||||||
|
|
||||||
|
|
||||||
|
# Test 1 — drives AC-4 (call-graph): forward must reach apply_rope1 directly,
|
||||||
|
# never via the apply_rotary_emb wrapper.
|
||||||
|
|
||||||
|
def test_namm_forward_calls_apply_rope1_directly():
|
||||||
|
rope = NaMMRotaryEmbedding3d(dim=_DIM)
|
||||||
|
vid_q, vid_k, vid_shape, txt_q, txt_k, txt_shape, cache = _make_inputs()
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
seedvr_model, "apply_rotary_emb", wraps=seedvr_model.apply_rotary_emb
|
||||||
|
) as wrapper_spy, patch.object(
|
||||||
|
seedvr_model, "apply_rope1", wraps=seedvr_model.apply_rope1
|
||||||
|
) as rope1_spy:
|
||||||
|
rope.forward(vid_q, vid_k, vid_shape, txt_q, txt_k, txt_shape, cache)
|
||||||
|
|
||||||
|
assert wrapper_spy.call_count == 0, (
|
||||||
|
f"NaMMRotaryEmbedding3d.forward must not call apply_rotary_emb "
|
||||||
|
f"(saw {wrapper_spy.call_count} calls); the rewrite must rewire "
|
||||||
|
f"the 4 forward sites to apply_rope1 directly"
|
||||||
|
)
|
||||||
|
assert rope1_spy.call_count == 4, (
|
||||||
|
f"NaMMRotaryEmbedding3d.forward must call apply_rope1 exactly 4 "
|
||||||
|
f"times (vid_q, vid_k, txt_q, txt_k); saw {rope1_spy.call_count}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Test 2 — drives the get_freqs shape change to flux-canonical layout.
|
||||||
|
|
||||||
|
def test_get_freqs_emits_flux_canonical_shape():
|
||||||
|
rope = NaMMRotaryEmbedding3d(dim=_DIM)
|
||||||
|
vid_shape = torch.tensor([[_VID_T, _VID_H, _VID_W]], dtype=torch.long)
|
||||||
|
txt_shape = torch.tensor([[_TXT_L]], dtype=torch.long)
|
||||||
|
|
||||||
|
vid_freqs, txt_freqs = rope.get_freqs(vid_shape, txt_shape)
|
||||||
|
|
||||||
|
# Flux's `rope()` (comfy/ldm/flux/math.py:17-29) emits freqs in shape
|
||||||
|
# [..., d/2, 2, 2] via stack([cos, -sin, sin, cos], dim=-1) +
|
||||||
|
# rearrange("b n d (i j) -> b n d i j", i=2, j=2). The rewrite must
|
||||||
|
# match: ndim >= 4, last two dims both == 2.
|
||||||
|
assert vid_freqs.ndim >= 4, (
|
||||||
|
f"vid_freqs.ndim must be >= 4 (flux-canonical layout has trailing "
|
||||||
|
f"[..., d/2, 2, 2]); got ndim={vid_freqs.ndim}, shape={tuple(vid_freqs.shape)}"
|
||||||
|
)
|
||||||
|
assert vid_freqs.shape[-1] == 2, (
|
||||||
|
f"vid_freqs.shape[-1] must be 2 (rotation matrix column); got "
|
||||||
|
f"shape={tuple(vid_freqs.shape)}"
|
||||||
|
)
|
||||||
|
assert vid_freqs.shape[-2] == 2, (
|
||||||
|
f"vid_freqs.shape[-2] must be 2 (rotation matrix row); got "
|
||||||
|
f"shape={tuple(vid_freqs.shape)}"
|
||||||
|
)
|
||||||
|
assert txt_freqs.ndim >= 4, (
|
||||||
|
f"txt_freqs must also be flux-canonical; got ndim={txt_freqs.ndim}, "
|
||||||
|
f"shape={tuple(txt_freqs.shape)}"
|
||||||
|
)
|
||||||
|
assert txt_freqs.shape[-1] == 2 and txt_freqs.shape[-2] == 2, (
|
||||||
|
f"txt_freqs trailing dims must be (2, 2); got shape={tuple(txt_freqs.shape)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the cos/-sin/sin/cos pattern at index 0:
|
||||||
|
# freqs_cis[..., 0, 0] = cos
|
||||||
|
# freqs_cis[..., 0, 1] = -sin
|
||||||
|
# freqs_cis[..., 1, 0] = sin
|
||||||
|
# freqs_cis[..., 1, 1] = cos
|
||||||
|
# so [0,0] == [1,1] (both cos) and [0,1] == -[1,0] (=-sin vs +sin).
|
||||||
|
cos_a = vid_freqs[..., 0, 0]
|
||||||
|
cos_b = vid_freqs[..., 1, 1]
|
||||||
|
neg_sin = vid_freqs[..., 0, 1]
|
||||||
|
sin = vid_freqs[..., 1, 0]
|
||||||
|
assert torch.allclose(cos_a, cos_b, rtol=0, atol=0), (
|
||||||
|
"vid_freqs[..., 0, 0] must equal vid_freqs[..., 1, 1] (both = cos)"
|
||||||
|
)
|
||||||
|
assert torch.allclose(neg_sin, -sin, rtol=0, atol=0), (
|
||||||
|
"vid_freqs[..., 0, 1] must equal -vid_freqs[..., 1, 0] (= -sin vs +sin)"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Test 3 — drives AC-1: forward output is tensor-equal against the wrapper-
|
||||||
|
# fed oracle. Pre-rewrite: trivially passes (forward IS the wrapper path).
|
||||||
|
# Post-rewrite: must remain equal. Exact equality (rtol=atol=0) at fp32.
|
||||||
|
|
||||||
|
def test_namm_forward_output_tensor_equal_against_legacy_oracle():
|
||||||
|
rope = NaMMRotaryEmbedding3d(dim=_DIM)
|
||||||
|
vid_q, vid_k, vid_shape, txt_q, txt_k, txt_shape, cache = _make_inputs()
|
||||||
|
|
||||||
|
# Oracle: the unchanged apply_rotary_emb wrapper fed with legacy-shape
|
||||||
|
# freqs produced by reproducing the pre-rewrite get_freqs body.
|
||||||
|
expected_vid_q, expected_vid_k, expected_txt_q, expected_txt_k = _legacy_forward(
|
||||||
|
rope,
|
||||||
|
vid_q.clone(), vid_k.clone(), vid_shape,
|
||||||
|
txt_q.clone(), txt_k.clone(), txt_shape,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Actual: NaMMRotaryEmbedding3d.forward (under test).
|
||||||
|
actual_vid_q, actual_vid_k, actual_txt_q, actual_txt_k = rope.forward(
|
||||||
|
vid_q.clone(), vid_k.clone(), vid_shape,
|
||||||
|
txt_q.clone(), txt_k.clone(), txt_shape, cache,
|
||||||
|
)
|
||||||
|
|
||||||
|
torch.testing.assert_close(actual_vid_q, expected_vid_q, rtol=0, atol=0,
|
||||||
|
msg="vid_q output diverges from wrapper oracle")
|
||||||
|
torch.testing.assert_close(actual_vid_k, expected_vid_k, rtol=0, atol=0,
|
||||||
|
msg="vid_k output diverges from wrapper oracle")
|
||||||
|
torch.testing.assert_close(actual_txt_q, expected_txt_q, rtol=0, atol=0,
|
||||||
|
msg="txt_q output diverges from wrapper oracle")
|
||||||
|
torch.testing.assert_close(actual_txt_k, expected_txt_k, rtol=0, atol=0,
|
||||||
|
msg="txt_k output diverges from wrapper oracle")
|
||||||
|
|
||||||
|
|
||||||
|
# Test 5 — partial-rope coverage. The real SeedVR2-3B model is constructed
|
||||||
|
# with rope_dim=128, which integer-divides into 3 axes as 128//3 = 42 per-
|
||||||
|
# axis; total rope freq dims = 42*3 = 126. head_dim is 128, so the trailing
|
||||||
|
# 2 dims of each q/k must be passed through unrotated (matching the legacy
|
||||||
|
# wrapper's `t_right = t[..., end_index:]` behavior). The fp32-CPU oracle
|
||||||
|
# test (Test 3) uses dim=192 where rot_d == head_dim and the partial-rope
|
||||||
|
# path collapses to a single apply_rope1 call. This test exercises the
|
||||||
|
# partial path explicitly with dim=128 and asserts the rewired forward
|
||||||
|
# still tensor-equals the wrapper oracle in that regime.
|
||||||
|
|
||||||
|
def test_namm_forward_partial_rope_passthrough_matches_wrapper_oracle():
|
||||||
|
rope = NaMMRotaryEmbedding3d(dim=128)
|
||||||
|
g = torch.Generator(device="cpu").manual_seed(_SEED)
|
||||||
|
vid_q = torch.randn(_L_VID, _HEADS, 128, dtype=torch.float32, device="cpu", generator=g)
|
||||||
|
vid_k = torch.randn(_L_VID, _HEADS, 128, dtype=torch.float32, device="cpu", generator=g)
|
||||||
|
txt_q = torch.randn(_TXT_L, _HEADS, 128, dtype=torch.float32, device="cpu", generator=g)
|
||||||
|
txt_k = torch.randn(_TXT_L, _HEADS, 128, dtype=torch.float32, device="cpu", generator=g)
|
||||||
|
vid_shape = torch.tensor([[_VID_T, _VID_H, _VID_W]], dtype=torch.long)
|
||||||
|
txt_shape = torch.tensor([[_TXT_L]], dtype=torch.long)
|
||||||
|
cache = Cache(disable=True)
|
||||||
|
|
||||||
|
expected_vid_q, expected_vid_k, expected_txt_q, expected_txt_k = _legacy_forward(
|
||||||
|
rope, vid_q.clone(), vid_k.clone(), vid_shape, txt_q.clone(), txt_k.clone(), txt_shape,
|
||||||
|
)
|
||||||
|
actual_vid_q, actual_vid_k, actual_txt_q, actual_txt_k = rope.forward(
|
||||||
|
vid_q.clone(), vid_k.clone(), vid_shape, txt_q.clone(), txt_k.clone(), txt_shape, cache,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Confirm the partial-rope contract: rot_d (= 2 * freqs_cis.shape[-3]) is
|
||||||
|
# 126 (= 42*3), strictly less than head_dim 128. The trailing 2 head-dims
|
||||||
|
# are pass-through.
|
||||||
|
vid_freqs, _ = rope.get_freqs(vid_shape, txt_shape)
|
||||||
|
rot_d = 2 * vid_freqs.shape[-3]
|
||||||
|
assert rot_d == 126, f"expected rot_d=126 for dim=128 model; got {rot_d}"
|
||||||
|
assert rot_d < 128, "partial-rope path must trigger (rot_d < head_dim)"
|
||||||
|
|
||||||
|
torch.testing.assert_close(actual_vid_q, expected_vid_q, rtol=0, atol=0,
|
||||||
|
msg="vid_q partial-rope output diverges from wrapper oracle")
|
||||||
|
torch.testing.assert_close(actual_vid_k, expected_vid_k, rtol=0, atol=0,
|
||||||
|
msg="vid_k partial-rope output diverges from wrapper oracle")
|
||||||
|
torch.testing.assert_close(actual_txt_q, expected_txt_q, rtol=0, atol=0,
|
||||||
|
msg="txt_q partial-rope output diverges from wrapper oracle")
|
||||||
|
torch.testing.assert_close(actual_txt_k, expected_txt_k, rtol=0, atol=0,
|
||||||
|
msg="txt_k partial-rope output diverges from wrapper oracle")
|
||||||
|
|
||||||
|
|
||||||
|
# Test 4 — drives AC-4 statically: AST walk over NaMMRotaryEmbedding3d.forward
|
||||||
|
# must find zero references to the apply_rotary_emb symbol.
|
||||||
|
|
||||||
|
def test_namm_forward_ast_has_no_apply_rotary_emb_calls():
|
||||||
|
source_path = Path(inspect.getsourcefile(NaMMRotaryEmbedding3d))
|
||||||
|
tree = ast.parse(source_path.read_text(encoding="utf-8"))
|
||||||
|
|
||||||
|
namm_class = None
|
||||||
|
for node in ast.walk(tree):
|
||||||
|
if isinstance(node, ast.ClassDef) and node.name == "NaMMRotaryEmbedding3d":
|
||||||
|
namm_class = node
|
||||||
|
break
|
||||||
|
assert namm_class is not None, (
|
||||||
|
f"could not locate class NaMMRotaryEmbedding3d in {source_path}"
|
||||||
|
)
|
||||||
|
|
||||||
|
forward_fn = None
|
||||||
|
for node in namm_class.body:
|
||||||
|
if isinstance(node, ast.FunctionDef) and node.name == "forward":
|
||||||
|
forward_fn = node
|
||||||
|
break
|
||||||
|
assert forward_fn is not None, (
|
||||||
|
"could not locate NaMMRotaryEmbedding3d.forward"
|
||||||
|
)
|
||||||
|
|
||||||
|
offending = []
|
||||||
|
for node in ast.walk(forward_fn):
|
||||||
|
if isinstance(node, ast.Name) and node.id == "apply_rotary_emb":
|
||||||
|
offending.append((node.lineno, node.col_offset))
|
||||||
|
|
||||||
|
assert not offending, (
|
||||||
|
f"NaMMRotaryEmbedding3d.forward must not reference apply_rotary_emb; "
|
||||||
|
f"found {len(offending)} reference(s) at line:col positions {offending}. "
|
||||||
|
f"The rewrite must rewire to apply_rope1 directly."
|
||||||
|
)
|
||||||
37
tests-unit/comfy_test/test_seedvr_vae_attention_fence.py
Normal file
37
tests-unit/comfy_test/test_seedvr_vae_attention_fence.py
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
import comfy.ldm.seedvr.vae as seedvr_vae
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr_vae_4d_self_attention_uses_vae_attention_with_channel_first_layout():
|
||||||
|
calls = {}
|
||||||
|
|
||||||
|
def vae_attention_spy(q, k, v):
|
||||||
|
calls["q"] = q.detach().clone()
|
||||||
|
calls["k"] = k.detach().clone()
|
||||||
|
calls["v"] = v.detach().clone()
|
||||||
|
return q
|
||||||
|
|
||||||
|
def global_attention_forbidden(*args, **kwargs):
|
||||||
|
raise AssertionError("SeedVR2 VAE self-attention must not use global optimized_attention")
|
||||||
|
|
||||||
|
with patch.object(seedvr_vae, "vae_attention", return_value=vae_attention_spy):
|
||||||
|
attention = seedvr_vae.Attention(query_dim=4, heads=1, dim_head=4)
|
||||||
|
|
||||||
|
attention.to_q = nn.Identity()
|
||||||
|
attention.to_k = nn.Identity()
|
||||||
|
attention.to_v = nn.Identity()
|
||||||
|
attention.to_out[0] = nn.Identity()
|
||||||
|
|
||||||
|
hidden_states = torch.arange(24, dtype=torch.float32).reshape(1, 4, 2, 3)
|
||||||
|
|
||||||
|
with patch.object(seedvr_vae, "optimized_attention", global_attention_forbidden):
|
||||||
|
output = attention(hidden_states)
|
||||||
|
|
||||||
|
assert torch.equal(calls["q"], hidden_states)
|
||||||
|
assert torch.equal(calls["k"], hidden_states)
|
||||||
|
assert torch.equal(calls["v"], hidden_states)
|
||||||
|
assert torch.equal(output, hidden_states)
|
||||||
476
tests-unit/comfy_test/test_seedvr_var_attention_backends.py
Normal file
476
tests-unit/comfy_test/test_seedvr_var_attention_backends.py
Normal file
@ -0,0 +1,476 @@
|
|||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import textwrap
|
||||||
|
import ast
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from comfy.cli_args import args
|
||||||
|
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
args.cpu = True
|
||||||
|
|
||||||
|
import comfy.ldm.modules.attention as attention # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
|
_VAR_BACKENDS = (
|
||||||
|
"var_attention_sage",
|
||||||
|
"var_attention_sage3",
|
||||||
|
"var_attention_flash",
|
||||||
|
"var_attention_flash3",
|
||||||
|
"var_attention_sub_quad",
|
||||||
|
"var_attention_split",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _inputs():
|
||||||
|
heads = 2
|
||||||
|
head_dim = 4
|
||||||
|
total = 6
|
||||||
|
q = torch.randn(total, heads, head_dim)
|
||||||
|
k = torch.randn(total, heads, head_dim)
|
||||||
|
v = torch.randn(total, heads, head_dim)
|
||||||
|
cu = torch.tensor([0, 3, 6], dtype=torch.int32)
|
||||||
|
return q, k, v, heads, cu
|
||||||
|
|
||||||
|
|
||||||
|
def _has_dynamo_disable(decorator):
|
||||||
|
return (
|
||||||
|
isinstance(decorator, ast.Attribute)
|
||||||
|
and decorator.attr == "disable"
|
||||||
|
and isinstance(decorator.value, ast.Attribute)
|
||||||
|
and decorator.value.attr == "_dynamo"
|
||||||
|
and isinstance(decorator.value.value, ast.Name)
|
||||||
|
and decorator.value.value.id == "torch"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_var_attention_backend_functions_are_dynamo_disabled_and_signature_compatible():
|
||||||
|
tree = ast.parse(inspect.getsource(attention))
|
||||||
|
functions = {node.name: node for node in tree.body if isinstance(node, ast.FunctionDef)}
|
||||||
|
|
||||||
|
for name in _VAR_BACKENDS:
|
||||||
|
node = functions[name]
|
||||||
|
positional = [arg.arg for arg in node.args.args[:6]]
|
||||||
|
keyword_only = {arg.arg for arg in node.args.kwonlyargs}
|
||||||
|
assert positional == ["q", "k", "v", "heads", "cu_seqlens_q", "cu_seqlens_k"]
|
||||||
|
assert node.args.vararg is not None
|
||||||
|
assert node.args.kwarg is not None
|
||||||
|
assert "skip_reshape" in keyword_only
|
||||||
|
assert "skip_output_reshape" in keyword_only
|
||||||
|
assert any(_has_dynamo_disable(decorator) for decorator in node.decorator_list)
|
||||||
|
|
||||||
|
|
||||||
|
def test_var_attention_registry_contains_always_available_entries():
|
||||||
|
assert attention.REGISTERED_ATTENTION_FUNCTIONS["var_attention_pytorch"] is attention.var_attention_pytorch
|
||||||
|
assert attention.REGISTERED_ATTENTION_FUNCTIONS["var_attention_sub_quad"] is attention.var_attention_sub_quad
|
||||||
|
assert attention.REGISTERED_ATTENTION_FUNCTIONS["var_attention_split"] is attention.var_attention_split
|
||||||
|
|
||||||
|
|
||||||
|
def _run_attention_import(flag, fake_modules=True, fake_module_code=None):
|
||||||
|
argv = ["pytest-subprocess", "--cpu", "--disable-xformers"]
|
||||||
|
if flag:
|
||||||
|
argv.append(flag)
|
||||||
|
if fake_module_code is None:
|
||||||
|
fake_module_code = ""
|
||||||
|
if fake_modules and not fake_module_code:
|
||||||
|
fake_module_code = """
|
||||||
|
import types
|
||||||
|
|
||||||
|
sageattention = types.ModuleType("sageattention")
|
||||||
|
sageattention.sageattn = lambda *a, **k: a[0]
|
||||||
|
sageattention.sageattn_varlen = lambda *a, **k: a[0]
|
||||||
|
sys.modules["sageattention"] = sageattention
|
||||||
|
|
||||||
|
sageattn3 = types.ModuleType("sageattn3")
|
||||||
|
sageattn3.sageattn3_blackwell = lambda *a, **k: a[0]
|
||||||
|
sys.modules["sageattn3"] = sageattn3
|
||||||
|
|
||||||
|
flash_attn = types.ModuleType("flash_attn")
|
||||||
|
flash_attn.flash_attn_func = lambda q, k, v, **kwargs: q
|
||||||
|
flash_attn.flash_attn_varlen_func = lambda **kwargs: kwargs["q"]
|
||||||
|
sys.modules["flash_attn"] = flash_attn
|
||||||
|
|
||||||
|
flash_attn_interface = types.ModuleType("flash_attn_interface")
|
||||||
|
flash_attn_interface.flash_attn_varlen_func = lambda **kwargs: (kwargs["q"], None)
|
||||||
|
sys.modules["flash_attn_interface"] = flash_attn_interface
|
||||||
|
"""
|
||||||
|
code = (
|
||||||
|
"import sys\n"
|
||||||
|
"import comfy.options\n"
|
||||||
|
"comfy.options.enable_args_parsing()\n"
|
||||||
|
f"sys.argv = {argv!r}\n"
|
||||||
|
f"{textwrap.dedent(fake_module_code)}\n"
|
||||||
|
"import comfy.ldm.modules.attention as attention\n"
|
||||||
|
"print(attention.optimized_var_attention.__name__)\n"
|
||||||
|
)
|
||||||
|
return subprocess.run(
|
||||||
|
[sys.executable, "-c", code],
|
||||||
|
cwd=".",
|
||||||
|
text=True,
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
stderr=subprocess.PIPE,
|
||||||
|
check=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_var_attention_rebind_sage_launch_flag():
|
||||||
|
result = _run_attention_import("--use-sage-attention")
|
||||||
|
assert result.returncode == 0, result.stderr
|
||||||
|
assert result.stdout.strip() == "var_attention_sage"
|
||||||
|
|
||||||
|
|
||||||
|
def test_var_attention_rebind_flash_launch_flag_uses_pytorch_varlen_in_cpu_mode():
|
||||||
|
result = _run_attention_import("--use-flash-attention")
|
||||||
|
assert result.returncode == 0, result.stderr
|
||||||
|
assert result.stdout.strip() == "var_attention_pytorch"
|
||||||
|
|
||||||
|
|
||||||
|
def test_var_attention_rebind_sage_launch_flag_without_varlen_uses_pytorch():
|
||||||
|
result = _run_attention_import(
|
||||||
|
"--use-sage-attention",
|
||||||
|
fake_module_code="""
|
||||||
|
import types
|
||||||
|
|
||||||
|
sageattention = types.ModuleType("sageattention")
|
||||||
|
sageattention.sageattn = lambda *a, **k: a[0]
|
||||||
|
sys.modules["sageattention"] = sageattention
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
assert result.returncode == 0, result.stderr
|
||||||
|
assert result.stdout.strip() == "var_attention_pytorch"
|
||||||
|
|
||||||
|
|
||||||
|
def test_var_attention_rebind_flash_launch_flag_without_varlen_uses_pytorch():
|
||||||
|
result = _run_attention_import(
|
||||||
|
"--use-flash-attention",
|
||||||
|
fake_module_code="""
|
||||||
|
import types
|
||||||
|
|
||||||
|
flash_attn = types.ModuleType("flash_attn")
|
||||||
|
flash_attn.flash_attn_func = lambda q, k, v, **kwargs: q
|
||||||
|
sys.modules["flash_attn"] = flash_attn
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
assert result.returncode == 0, result.stderr
|
||||||
|
assert result.stdout.strip() == "var_attention_pytorch"
|
||||||
|
|
||||||
|
|
||||||
|
def test_var_attention_rebind_pytorch_launch_flag():
|
||||||
|
result = _run_attention_import("--use-pytorch-cross-attention")
|
||||||
|
assert result.returncode == 0, result.stderr
|
||||||
|
assert result.stdout.strip() == "var_attention_pytorch"
|
||||||
|
|
||||||
|
|
||||||
|
def test_var_attention_rebind_split_launch_flag():
|
||||||
|
result = _run_attention_import("--use-split-cross-attention")
|
||||||
|
assert result.returncode == 0, result.stderr
|
||||||
|
assert result.stdout.strip() == "var_attention_split"
|
||||||
|
|
||||||
|
|
||||||
|
def test_var_attention_rebind_default_launch_flags():
|
||||||
|
result = _run_attention_import("")
|
||||||
|
assert result.returncode == 0, result.stderr
|
||||||
|
assert result.stdout.strip() == "var_attention_sub_quad"
|
||||||
|
|
||||||
|
|
||||||
|
def test_var_attention_sage_uses_cu_seqlens_contract(monkeypatch):
|
||||||
|
q, k, v, heads, cu = _inputs()
|
||||||
|
captured = {}
|
||||||
|
|
||||||
|
def fake_sageattn_varlen(q, k, v, cu_q, cu_k, max_q, max_k, is_causal, sm_scale):
|
||||||
|
captured.update(cu_q=cu_q, cu_k=cu_k, max_q=max_q, max_k=max_k, is_causal=is_causal)
|
||||||
|
return torch.zeros_like(q)
|
||||||
|
|
||||||
|
monkeypatch.setattr(attention, "SAGE_ATTENTION_VARLEN_IS_AVAILABLE", True)
|
||||||
|
monkeypatch.setattr(attention, "sageattn_varlen", fake_sageattn_varlen, raising=False)
|
||||||
|
|
||||||
|
out = attention.var_attention_sage(q, k, v, heads, cu, cu, skip_reshape=True, skip_output_reshape=True)
|
||||||
|
|
||||||
|
assert tuple(out.shape) == tuple(q.shape)
|
||||||
|
assert torch.equal(captured["cu_q"], cu)
|
||||||
|
assert torch.equal(captured["cu_k"], cu)
|
||||||
|
assert captured["max_q"] == 3
|
||||||
|
assert captured["max_k"] == 3
|
||||||
|
assert captured["is_causal"] is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_var_attention_sage_runtime_error_preserves_fallback_dtype(monkeypatch):
|
||||||
|
q, k, v, heads, cu = _inputs()
|
||||||
|
q = q.float()
|
||||||
|
k = k.half()
|
||||||
|
v = v.half()
|
||||||
|
captured = {}
|
||||||
|
|
||||||
|
def failing_sageattn_varlen(*args, **kwargs):
|
||||||
|
raise RuntimeError("unsupported")
|
||||||
|
|
||||||
|
def fake_var_attention_pytorch(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, skip_reshape=False, skip_output_reshape=False):
|
||||||
|
captured.update(dtype=q.dtype, k_dtype=k.dtype, v_dtype=v.dtype, skip_reshape=skip_reshape)
|
||||||
|
return torch.zeros_like(q)
|
||||||
|
|
||||||
|
monkeypatch.setattr(attention, "SAGE_ATTENTION_VARLEN_IS_AVAILABLE", True)
|
||||||
|
monkeypatch.setattr(attention, "sageattn_varlen", failing_sageattn_varlen, raising=False)
|
||||||
|
monkeypatch.setattr(attention, "var_attention_pytorch", fake_var_attention_pytorch)
|
||||||
|
|
||||||
|
out = attention.var_attention_sage(q, k, v, heads, cu, cu, skip_reshape=True, skip_output_reshape=True)
|
||||||
|
|
||||||
|
assert out.dtype == torch.float32
|
||||||
|
assert captured["dtype"] == torch.float32
|
||||||
|
assert captured["k_dtype"] == torch.float32
|
||||||
|
assert captured["v_dtype"] == torch.float32
|
||||||
|
assert captured["skip_reshape"] is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_var_attention_sage3_uses_cu_seqlens_contract(monkeypatch):
|
||||||
|
q, k, v, heads, cu = _inputs()
|
||||||
|
captured = {}
|
||||||
|
|
||||||
|
def fake_sageattn3_blackwell(q, k, v, is_causal=False):
|
||||||
|
captured.update(shape=tuple(q.shape), is_causal=is_causal)
|
||||||
|
return torch.zeros_like(q)
|
||||||
|
|
||||||
|
monkeypatch.setattr(attention, "SAGE_ATTENTION3_IS_AVAILABLE", True)
|
||||||
|
monkeypatch.setattr(attention, "sageattn3_blackwell", fake_sageattn3_blackwell, raising=False)
|
||||||
|
|
||||||
|
out = attention.var_attention_sage3(q, k, v, heads, cu, cu, skip_reshape=True, skip_output_reshape=True)
|
||||||
|
|
||||||
|
assert tuple(out.shape) == tuple(q.shape)
|
||||||
|
assert captured["shape"] == (2, heads, 3, 4)
|
||||||
|
assert captured["is_causal"] is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_var_attention_sage3_runtime_error_falls_back(monkeypatch):
|
||||||
|
q, k, v, heads, cu = _inputs()
|
||||||
|
q = q.float()
|
||||||
|
k = k.half()
|
||||||
|
v = v.half()
|
||||||
|
captured = {}
|
||||||
|
|
||||||
|
def failing_sageattn3_blackwell(*args, **kwargs):
|
||||||
|
raise RuntimeError("unsupported")
|
||||||
|
|
||||||
|
def fake_var_attention_pytorch(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, skip_reshape=False, skip_output_reshape=False):
|
||||||
|
captured.update(cu_q=cu_seqlens_q, dtype=q.dtype, k_dtype=k.dtype, v_dtype=v.dtype, skip_reshape=skip_reshape)
|
||||||
|
return torch.zeros_like(q)
|
||||||
|
|
||||||
|
monkeypatch.setattr(attention, "SAGE_ATTENTION_VARLEN_IS_AVAILABLE", False)
|
||||||
|
monkeypatch.setattr(attention, "SAGE_ATTENTION3_IS_AVAILABLE", True)
|
||||||
|
monkeypatch.setattr(attention, "sageattn3_blackwell", failing_sageattn3_blackwell, raising=False)
|
||||||
|
monkeypatch.setattr(attention, "var_attention_pytorch", fake_var_attention_pytorch)
|
||||||
|
|
||||||
|
out = attention.var_attention_sage3(q, k, v, heads, cu, cu, skip_reshape=True, skip_output_reshape=True)
|
||||||
|
|
||||||
|
assert tuple(out.shape) == tuple(q.shape)
|
||||||
|
assert torch.equal(captured["cu_q"], cu)
|
||||||
|
assert captured["dtype"] == torch.float32
|
||||||
|
assert captured["k_dtype"] == torch.float32
|
||||||
|
assert captured["v_dtype"] == torch.float32
|
||||||
|
assert captured["skip_reshape"] is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_var_attention_flash_uses_cu_seqlens_contract(monkeypatch):
|
||||||
|
q, k, v, heads, cu = _inputs()
|
||||||
|
captured = {}
|
||||||
|
|
||||||
|
def fake_flash_attn_varlen_func(**kwargs):
|
||||||
|
captured.update(kwargs)
|
||||||
|
return torch.zeros_like(kwargs["q"])
|
||||||
|
|
||||||
|
monkeypatch.setattr(attention, "FLASH_ATTENTION_VARLEN_IS_AVAILABLE", True)
|
||||||
|
monkeypatch.setattr(attention, "flash_attn_varlen_func", fake_flash_attn_varlen_func, raising=False)
|
||||||
|
|
||||||
|
out = attention.var_attention_flash(q, k, v, heads, cu, cu, skip_reshape=True, skip_output_reshape=True)
|
||||||
|
|
||||||
|
assert tuple(out.shape) == tuple(q.shape)
|
||||||
|
assert torch.equal(captured["cu_seqlens_q"], cu)
|
||||||
|
assert torch.equal(captured["cu_seqlens_k"], cu)
|
||||||
|
assert captured["max_seqlen_q"] == 3
|
||||||
|
assert captured["max_seqlen_k"] == 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_var_attention_flash_runtime_error_falls_back(monkeypatch):
|
||||||
|
q, k, v, heads, cu = _inputs()
|
||||||
|
captured = {}
|
||||||
|
|
||||||
|
def failing_flash_attn_varlen_func(**kwargs):
|
||||||
|
raise NotImplementedError("cpu")
|
||||||
|
|
||||||
|
def fake_var_attention_pytorch(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, skip_reshape=False, skip_output_reshape=False):
|
||||||
|
captured.update(cu_q=cu_seqlens_q, skip_reshape=skip_reshape)
|
||||||
|
return torch.zeros_like(q)
|
||||||
|
|
||||||
|
monkeypatch.setattr(attention, "FLASH_ATTENTION_VARLEN_IS_AVAILABLE", True)
|
||||||
|
monkeypatch.setattr(attention, "flash_attn_varlen_func", failing_flash_attn_varlen_func, raising=False)
|
||||||
|
monkeypatch.setattr(attention, "var_attention_pytorch", fake_var_attention_pytorch)
|
||||||
|
|
||||||
|
out = attention.var_attention_flash(q, k, v, heads, cu, cu, skip_reshape=True, skip_output_reshape=True)
|
||||||
|
|
||||||
|
assert tuple(out.shape) == tuple(q.shape)
|
||||||
|
assert torch.equal(captured["cu_q"], cu)
|
||||||
|
assert captured["skip_reshape"] is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_var_attention_flash3_uses_cu_seqlens_contract(monkeypatch):
|
||||||
|
q, k, v, heads, cu = _inputs()
|
||||||
|
captured = {}
|
||||||
|
|
||||||
|
def fake_flash_attn3_varlen_func(**kwargs):
|
||||||
|
captured.update(kwargs)
|
||||||
|
return torch.zeros_like(kwargs["q"]), None
|
||||||
|
|
||||||
|
monkeypatch.setattr(attention, "flash_attn3_varlen_func", fake_flash_attn3_varlen_func, raising=False)
|
||||||
|
monkeypatch.setattr(attention, "FLASH_ATTENTION3_IS_AVAILABLE", True)
|
||||||
|
|
||||||
|
out = attention.var_attention_flash3(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
heads,
|
||||||
|
cu,
|
||||||
|
cu,
|
||||||
|
skip_reshape=True,
|
||||||
|
skip_output_reshape=True,
|
||||||
|
dropout_p=0.25,
|
||||||
|
window_size=(16, 16),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert tuple(out.shape) == tuple(q.shape)
|
||||||
|
assert torch.equal(captured["cu_seqlens_q"], cu)
|
||||||
|
assert torch.equal(captured["cu_seqlens_k"], cu)
|
||||||
|
assert captured["max_seqlen_q"] == 3
|
||||||
|
assert captured["max_seqlen_k"] == 3
|
||||||
|
assert captured["seqused_q"] is None
|
||||||
|
assert captured["seqused_k"] is None
|
||||||
|
assert "dropout_p" not in captured
|
||||||
|
assert "window_size" not in captured
|
||||||
|
|
||||||
|
|
||||||
|
def test_var_attention_flash3_accepts_tensor_return(monkeypatch):
|
||||||
|
q, k, v, heads, cu = _inputs()
|
||||||
|
|
||||||
|
def fake_flash_attn3_varlen_func(**kwargs):
|
||||||
|
return torch.zeros_like(kwargs["q"])
|
||||||
|
|
||||||
|
monkeypatch.setattr(attention, "flash_attn3_varlen_func", fake_flash_attn3_varlen_func, raising=False)
|
||||||
|
monkeypatch.setattr(attention, "FLASH_ATTENTION3_IS_AVAILABLE", True)
|
||||||
|
|
||||||
|
out = attention.var_attention_flash3(q, k, v, heads, cu, cu, skip_reshape=True, skip_output_reshape=True)
|
||||||
|
|
||||||
|
assert tuple(out.shape) == tuple(q.shape)
|
||||||
|
|
||||||
|
|
||||||
|
def test_var_attention_flash3_runtime_error_falls_back(monkeypatch):
|
||||||
|
q, k, v, heads, cu = _inputs()
|
||||||
|
captured = {}
|
||||||
|
|
||||||
|
def failing_flash_attn3_varlen_func(**kwargs):
|
||||||
|
raise RuntimeError("unsupported")
|
||||||
|
|
||||||
|
def fake_var_attention_pytorch(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, skip_reshape=False, skip_output_reshape=False):
|
||||||
|
captured.update(cu_q=cu_seqlens_q, skip_reshape=skip_reshape)
|
||||||
|
return torch.zeros_like(q)
|
||||||
|
|
||||||
|
monkeypatch.setattr(attention, "FLASH_ATTENTION3_IS_AVAILABLE", True)
|
||||||
|
monkeypatch.setattr(attention, "flash_attn3_varlen_func", failing_flash_attn3_varlen_func, raising=False)
|
||||||
|
monkeypatch.setattr(attention, "var_attention_pytorch", fake_var_attention_pytorch)
|
||||||
|
|
||||||
|
out = attention.var_attention_flash3(q, k, v, heads, cu, cu, skip_reshape=True, skip_output_reshape=True)
|
||||||
|
|
||||||
|
assert tuple(out.shape) == tuple(q.shape)
|
||||||
|
assert torch.equal(captured["cu_q"], cu)
|
||||||
|
assert captured["skip_reshape"] is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_var_attention_sub_quad_uses_cu_seqlens_contract(monkeypatch):
|
||||||
|
q, k, v, heads, cu = _inputs()
|
||||||
|
captured = {}
|
||||||
|
|
||||||
|
def fake_var_attention_pytorch(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, skip_reshape=False, skip_output_reshape=False):
|
||||||
|
captured.update(cu_q=cu_seqlens_q, cu_k=cu_seqlens_k, skip_reshape=skip_reshape)
|
||||||
|
return torch.zeros_like(q)
|
||||||
|
|
||||||
|
monkeypatch.setattr(attention, "var_attention_pytorch", fake_var_attention_pytorch)
|
||||||
|
|
||||||
|
out = attention.var_attention_sub_quad(q, k, v, heads, cu, cu, skip_reshape=True, skip_output_reshape=True)
|
||||||
|
|
||||||
|
assert tuple(out.shape) == tuple(q.shape)
|
||||||
|
assert torch.equal(captured["cu_q"], cu)
|
||||||
|
assert torch.equal(captured["cu_k"], cu)
|
||||||
|
assert captured["skip_reshape"] is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_var_attention_split_uses_cu_seqlens_contract(monkeypatch):
|
||||||
|
q, k, v, heads, cu = _inputs()
|
||||||
|
captured = {}
|
||||||
|
|
||||||
|
def fake_var_attention_pytorch_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, skip_reshape=False, skip_output_reshape=False):
|
||||||
|
captured.update(cu_q=cu_seqlens_q, cu_k=cu_seqlens_k, skip_reshape=skip_reshape)
|
||||||
|
return torch.zeros_like(q)
|
||||||
|
|
||||||
|
def fail_var_attention_pytorch(*args, **kwargs):
|
||||||
|
raise AssertionError("split backend must not use nested-tensor pytorch var attention")
|
||||||
|
|
||||||
|
monkeypatch.setattr(attention, "var_attention_pytorch", fail_var_attention_pytorch)
|
||||||
|
monkeypatch.setattr(attention, "var_attention_pytorch_split", fake_var_attention_pytorch_split)
|
||||||
|
|
||||||
|
out = attention.var_attention_split(q, k, v, heads, cu, cu, skip_reshape=True, skip_output_reshape=True)
|
||||||
|
|
||||||
|
assert tuple(out.shape) == tuple(q.shape)
|
||||||
|
assert torch.equal(captured["cu_q"], cu)
|
||||||
|
assert torch.equal(captured["cu_k"], cu)
|
||||||
|
assert captured["skip_reshape"] is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_var_attention_pytorch_split_normalizes_split_indices_to_cpu(monkeypatch):
|
||||||
|
q, k, v, heads, cu = _inputs()
|
||||||
|
captured_devices = []
|
||||||
|
real_tensor_split = torch.tensor_split
|
||||||
|
|
||||||
|
def capture_tensor_split(input, indices_or_sections, dim=0):
|
||||||
|
if isinstance(indices_or_sections, torch.Tensor):
|
||||||
|
captured_devices.append(indices_or_sections.device.type)
|
||||||
|
return real_tensor_split(input, indices_or_sections, dim=dim)
|
||||||
|
|
||||||
|
monkeypatch.setattr(torch, "tensor_split", capture_tensor_split)
|
||||||
|
|
||||||
|
out = attention.var_attention_pytorch_split(q, k, v, heads, cu, cu, skip_reshape=True, skip_output_reshape=True)
|
||||||
|
|
||||||
|
assert tuple(out.shape) == tuple(q.shape)
|
||||||
|
assert captured_devices == ["cpu", "cpu", "cpu"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_missing_sage_package_guard_message_preserved():
|
||||||
|
code = textwrap.dedent(
|
||||||
|
"""
|
||||||
|
import builtins
|
||||||
|
import sys
|
||||||
|
import comfy.options
|
||||||
|
|
||||||
|
comfy.options.enable_args_parsing()
|
||||||
|
|
||||||
|
real_import = builtins.__import__
|
||||||
|
|
||||||
|
def blocked_import(name, globals=None, locals=None, fromlist=(), level=0):
|
||||||
|
if name == "sageattention":
|
||||||
|
raise ImportError("No module named sageattention", name="sageattention")
|
||||||
|
return real_import(name, globals, locals, fromlist, level)
|
||||||
|
|
||||||
|
builtins.__import__ = blocked_import
|
||||||
|
sys.argv = ["pytest-subprocess", "--cpu", "--disable-xformers", "--use-sage-attention"]
|
||||||
|
import comfy.ldm.modules.attention
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
result = subprocess.run(
|
||||||
|
[sys.executable, "-c", code],
|
||||||
|
cwd=".",
|
||||||
|
text=True,
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
stderr=subprocess.PIPE,
|
||||||
|
check=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.returncode != 0
|
||||||
|
assert "To use the `--use-sage-attention` feature" in result.stderr
|
||||||
|
assert "sageattention" in result.stderr
|
||||||
@ -0,0 +1,167 @@
|
|||||||
|
"""Regression tests for the SeedVR2-named guard inside
|
||||||
|
``comfy.ldm.modules.attention.var_attention_pytorch``.
|
||||||
|
|
||||||
|
Contract:
|
||||||
|
|
||||||
|
* If ``torch.nested.nested_tensor_from_jagged`` is unavailable on the
|
||||||
|
installed PyTorch build, ``var_attention_pytorch`` must raise
|
||||||
|
``RuntimeError`` whose message contains both ``SeedVR2`` and
|
||||||
|
``nested_tensor_from_jagged`` so the operator can identify the
|
||||||
|
failing attention path. A bare ``AttributeError`` from the
|
||||||
|
``torch.nested`` lookup is non-conformant. The guard must also
|
||||||
|
cover the case where the ``torch.nested`` namespace itself is
|
||||||
|
absent (e.g. forks/builds that strip the module) — accessing
|
||||||
|
``torch.nested`` directly would otherwise raise the same opaque
|
||||||
|
``AttributeError`` the guard is meant to translate.
|
||||||
|
* If the API is present, the present-API path must produce the
|
||||||
|
canonical SeedVR2-inference output shape ``(total_tokens,
|
||||||
|
heads * head_dim)``.
|
||||||
|
* If the caller passes malformed offsets (off-end / non-monotonic /
|
||||||
|
size-mismatched), torch's own per-call ``RuntimeError`` propagates
|
||||||
|
unchanged: the SeedVR2-context guard fires only on the missing-API
|
||||||
|
path, never on torch's per-call shape errors.
|
||||||
|
|
||||||
|
Each cell additionally pins the production guard at the AST level via
|
||||||
|
``inspect.getsource(var_attention_pytorch)`` so every AC fails
|
||||||
|
diagnostically on an unguarded base.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from comfy.cli_args import args
|
||||||
|
import torch
|
||||||
|
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
args.cpu = True
|
||||||
|
|
||||||
|
import ast # noqa: E402
|
||||||
|
import inspect # noqa: E402
|
||||||
|
import logging # noqa: E402
|
||||||
|
import textwrap # noqa: E402
|
||||||
|
import warnings # noqa: E402
|
||||||
|
|
||||||
|
import pytest # noqa: E402
|
||||||
|
|
||||||
|
from comfy.ldm.modules.attention import var_attention_pytorch # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
|
def _inputs():
|
||||||
|
"""Canonical 2-D ``(q, k, v, heads, cu_seqlens_q, cu_seqlens_k,
|
||||||
|
total_tokens, embed_dim)`` matching the live shape from GPT-3:
|
||||||
|
two segments of 3 tokens each, ``embed_dim = heads * head_dim =
|
||||||
|
2 * 8 = 16``.
|
||||||
|
"""
|
||||||
|
heads, head_dim, total_tokens = 2, 8, 6
|
||||||
|
embed_dim = heads * head_dim
|
||||||
|
q = torch.randn(total_tokens, embed_dim)
|
||||||
|
k = torch.randn(total_tokens, embed_dim)
|
||||||
|
v = torch.randn(total_tokens, embed_dim)
|
||||||
|
cu = torch.tensor([0, 3, 6], dtype=torch.int32)
|
||||||
|
return q, k, v, heads, cu, cu, total_tokens, embed_dim
|
||||||
|
|
||||||
|
|
||||||
|
def _assert_guard_source_pin():
|
||||||
|
"""Walk the AST of ``var_attention_pytorch`` and assert that the
|
||||||
|
first ``raise RuntimeError(...)`` statement appears strictly
|
||||||
|
before any attribute access named ``nested_tensor_from_jagged``.
|
||||||
|
|
||||||
|
Substring-based source pinning (``src.index('raise RuntimeError(')
|
||||||
|
< src.index('nested_tensor_from_jagged')``) is fragile: it false-
|
||||||
|
positives on docstring or comment text containing the literal,
|
||||||
|
and false-negatives on a refactor that splits ``raise
|
||||||
|
RuntimeError(`` across lines or replaces it with a helper
|
||||||
|
raising ``RuntimeError`` from another scope. AST-walking the
|
||||||
|
function body collapses both failure modes onto the only
|
||||||
|
invariant we actually require — the guard statement dominates
|
||||||
|
the attribute access by line number.
|
||||||
|
"""
|
||||||
|
src = textwrap.dedent(inspect.getsource(var_attention_pytorch))
|
||||||
|
tree = ast.parse(src)
|
||||||
|
raise_lines = []
|
||||||
|
nested_lines = []
|
||||||
|
for node in ast.walk(tree):
|
||||||
|
if isinstance(node, ast.Raise) and isinstance(node.exc, ast.Call):
|
||||||
|
func = node.exc.func
|
||||||
|
if isinstance(func, ast.Name) and func.id == "RuntimeError":
|
||||||
|
raise_lines.append(node.lineno)
|
||||||
|
if isinstance(node, ast.Attribute) and node.attr == "nested_tensor_from_jagged":
|
||||||
|
nested_lines.append(node.lineno)
|
||||||
|
assert raise_lines, (
|
||||||
|
"var_attention_pytorch has no `raise RuntimeError(...)` AST node; "
|
||||||
|
f"the SeedVR2-named guard is missing.\n--- source ---\n{src}"
|
||||||
|
)
|
||||||
|
assert nested_lines, (
|
||||||
|
"var_attention_pytorch source has no `nested_tensor_from_jagged` "
|
||||||
|
f"attribute access; cannot pin guard ordering.\n"
|
||||||
|
f"--- source ---\n{src}"
|
||||||
|
)
|
||||||
|
first_raise = min(raise_lines)
|
||||||
|
first_nested = min(nested_lines)
|
||||||
|
assert first_raise < first_nested, (
|
||||||
|
f"`raise RuntimeError(...)` first appears at line {first_raise}, "
|
||||||
|
f"but `torch.nested.nested_tensor_from_jagged` is referenced first "
|
||||||
|
f"at line {first_nested}; the guard must precede the lookup.\n"
|
||||||
|
f"--- source ---\n{src}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_missing_api_raises_seedvr2_runtime_error(monkeypatch):
|
||||||
|
monkeypatch.delattr(torch.nested, "nested_tensor_from_jagged", raising=False)
|
||||||
|
q, k, v, heads, cu_q, cu_k, _, _ = _inputs()
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError, match=r"SeedVR2.*nested_tensor_from_jagged"):
|
||||||
|
var_attention_pytorch(q, k, v, heads, cu_q, cu_k)
|
||||||
|
|
||||||
|
_assert_guard_source_pin()
|
||||||
|
|
||||||
|
|
||||||
|
def test_missing_namespace_raises_seedvr2_runtime_error(monkeypatch):
|
||||||
|
monkeypatch.delattr(torch, "nested", raising=False)
|
||||||
|
q, k, v, heads, cu_q, cu_k, _, _ = _inputs()
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError, match=r"SeedVR2.*nested_tensor_from_jagged"):
|
||||||
|
var_attention_pytorch(q, k, v, heads, cu_q, cu_k)
|
||||||
|
|
||||||
|
_assert_guard_source_pin()
|
||||||
|
|
||||||
|
|
||||||
|
def test_present_api_returns_expected_shape():
|
||||||
|
q, k, v, heads, cu_q, cu_k, total_tokens, embed_dim = _inputs()
|
||||||
|
|
||||||
|
torch_fx_logger = logging.getLogger("torch.fx._symbolic_trace")
|
||||||
|
old_torch_fx_level = torch_fx_logger.level
|
||||||
|
torch_fx_logger.setLevel(logging.ERROR)
|
||||||
|
try:
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.filterwarnings(
|
||||||
|
"ignore",
|
||||||
|
message="The PyTorch API of nested tensors is in prototype stage.*",
|
||||||
|
category=UserWarning,
|
||||||
|
)
|
||||||
|
out = var_attention_pytorch(q, k, v, heads, cu_q, cu_k)
|
||||||
|
finally:
|
||||||
|
torch_fx_logger.setLevel(old_torch_fx_level)
|
||||||
|
|
||||||
|
assert tuple(out.shape) == (total_tokens, embed_dim), (
|
||||||
|
f"expected ({total_tokens}, {embed_dim}); got {tuple(out.shape)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
_assert_guard_source_pin()
|
||||||
|
|
||||||
|
|
||||||
|
def test_malformed_offsets_propagates_torch_runtime_error():
|
||||||
|
q, k, v, heads, _, _, _, _ = _inputs()
|
||||||
|
cu_q_bad = torch.tensor([0, 3, 7], dtype=torch.int32)
|
||||||
|
cu_k_ok = torch.tensor([0, 3, 6], dtype=torch.int32)
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError) as exc_info:
|
||||||
|
var_attention_pytorch(q, k, v, heads, cu_q_bad, cu_k_ok)
|
||||||
|
|
||||||
|
msg = str(exc_info.value)
|
||||||
|
assert "split_with_sizes" in msg, (
|
||||||
|
f"expected torch's `split_with_sizes` error to propagate; got: {msg!r}"
|
||||||
|
)
|
||||||
|
assert "SeedVR2" not in msg, (
|
||||||
|
f"SeedVR2-context substring must not be substituted onto torch's "
|
||||||
|
f"per-call shape error; got: {msg!r}"
|
||||||
|
)
|
||||||
|
|
||||||
|
_assert_guard_source_pin()
|
||||||
Loading…
Reference in New Issue
Block a user