mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-28 01:47:32 +08:00
177 lines
7.4 KiB
Python
177 lines
7.4 KiB
Python
"""Regression test: ``comfy.ldm.seedvr.model.apply_rotary_emb`` must delegate
|
|
to ``comfy.ldm.flux.math.apply_rope1`` and produce exact-equality output
|
|
across the wrapper's slicing, scaling, and concatenation logic. Drift between
|
|
the wrapper and the delegate would silently corrupt SeedVR2's RoPE; this test
|
|
fails loudly on any future drift.
|
|
|
|
Each parametrized case does both:
|
|
|
|
1. Patches ``comfy.ldm.seedvr.model.apply_rope1`` with a ``wraps``-style spy
|
|
and asserts ``spy.call_count >= 1`` so a future change that inlines the
|
|
math and stops calling ``apply_rope1`` fails the test.
|
|
2. Compares the wrapper's output against a hand-rolled reproduction using
|
|
``torch.testing.assert_close(rtol=0, atol=0)`` -- exact tensor equality,
|
|
not bit-equality (``+0.0`` and ``-0.0`` compare equal, whereas NaNs never match because ``equal_nan`` defaults to ``False``);
|
|
the assertion catches any future kernel-precision drift in the
|
|
``apply_rope1`` dispatch.
|
|
|
|
The test uses a local ``torch.Generator`` so global RNG state is not mutated.
|
|
Parametrization covers non-default ``start_index`` and ``scale`` and a case
|
|
where ``freqs.shape[0] > t.shape[seq_dim]`` so the wrapper's
|
|
``slice_at_dim(freqs, slice(-seq_len, None), dim=0)`` path is exercised.
|
|
Imports are taken at module level. Heavy-import stubbing of
|
|
``comfy.model_management`` was attempted but is insufficient on this live
|
|
import chain (``comfy.ldm.seedvr.model`` pulls
|
|
``comfy.ldm.modules.diffusionmodules.model -> comfy.ops ->
|
|
comfy.memory_management -> comfy.quant_ops -> comfy_kitchen.tensor ->
|
|
torch._dynamo``), so this test intentionally runs against the real modules
|
|
to fail loudly if that import path or runtime state drifts. Other tests in
|
|
this repo (e.g. ``tests-unit/comfy_extras_test/image_stitch_test.py``) do
|
|
stub via ``patch.dict(sys.modules, ...)`` for narrower targets; the choice
|
|
here is local to this regression and not a repo-wide convention.
|
|
"""
|
|
|
|
from unittest.mock import patch
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
# CPU-only CI fix: ``comfy.ldm.seedvr.model`` transitively imports
|
|
# ``comfy.model_management``, whose import-time ``get_torch_device()`` call
|
|
# probes ``torch.cuda.current_device()`` unless ``comfy.cli_args.args.cpu`` is
|
|
# set. On a CPU-only build that probe can raise during test collection before
|
|
# the ``cuda`` case has had a chance to be skipped. Match the pattern used by
|
|
# ``tests-unit/comfy_quant/test_mixed_precision.py``: flip ``args.cpu`` before
|
|
# importing any ``comfy.ldm.*`` symbol.
|
|
from comfy.cli_args import args
|
|
|
|
if not torch.cuda.is_available():
|
|
args.cpu = True
|
|
|
|
import comfy.ldm.seedvr.model as seedvr_model # noqa: E402
|
|
from comfy.ldm.flux.math import apply_rope1 # noqa: E402
|
|
from comfy.ldm.seedvr.model import apply_rotary_emb # noqa: E402
|
|
|
|
|
|
def _direct_reproduction(freqs, t, start_index=0, scale=1.0, seq_dim=-2):
    """Hand-rolled reference for ``apply_rotary_emb`` in its default layout.

    When ``freqs`` is 2-D and ``t`` is 3-D, ``freqs`` is first trimmed to its
    trailing ``t.shape[seq_dim]`` rows, which is this helper's stand-in for
    the wrapper's ``slice_at_dim(freqs, slice(-seq_len, None), dim=0)`` step.
    The last axis of ``t`` is then split into three pieces: everything before
    ``start_index``, a band as wide as ``freqs.shape[-1]``, and the remainder.
    Only the band is passed to ``apply_rope1``; the outer pieces are
    concatenated back unchanged.

    ``apply_rope1`` here is the name imported at the top of this test module
    from ``comfy.ldm.flux.math``. The test patches the attribute on
    ``seedvr_model`` instead, so this call is not routed through the spy.

    Returns the re-assembled tensor converted to ``t.dtype``.
    """
    if freqs.ndim == 2 and t.ndim == 3:
        # Keep only the trailing rows that line up with t's sequence axis.
        freqs = freqs[-t.shape[seq_dim]:]

    stop_index = start_index + freqs.shape[-1]
    head, band, tail = (
        t[..., :start_index],
        t[..., start_index:stop_index],
        t[..., stop_index:],
    )

    # Every second entry of the last axis supplies one angle.
    theta = freqs.to(band.device)[..., ::2]
    scaled_cos = torch.cos(theta) * scale
    scaled_sin = torch.sin(theta) * scale

    # Trailing 2x2 block per angle: rows are [cos, -sin] and [sin, cos].
    rotation = torch.stack(
        [
            torch.stack([scaled_cos, scaled_sin], dim=-1),
            torch.stack([-scaled_sin, scaled_cos], dim=-1),
        ],
        dim=-1,
    )

    rotated = apply_rope1(band, rotation)
    return torch.cat((head, rotated, tail), dim=-1).type(t.dtype)
|
|
|
|
|
|
def _cpu_trig_supported(dtype):
|
|
"""Return whether ``torch.cos`` (and by symmetry ``torch.sin``) is
|
|
implemented for the given dtype on CPU on the current runtime. Some
|
|
PyTorch CPU wheels don't implement trig ops for ``float16`` / ``bfloat16``
|
|
and raise at runtime; the parametrized cases for those dtypes are skipped
|
|
when that's the case so CI remains stable across PyTorch builds.
|
|
"""
|
|
try:
|
|
torch.cos(torch.zeros(1, dtype=dtype))
|
|
except (RuntimeError, TypeError):
|
|
return False
|
|
return True
|
|
|
|
|
|
# Probed once at import time so the reduced-precision CPU cases below can be
# skipped on runtimes where the trig probe fails.
_CPU_FP16_TRIG_OK = _cpu_trig_supported(torch.float16)
_CPU_BF16_TRIG_OK = _cpu_trig_supported(torch.bfloat16)

# Skip markers shared by the case table.
_SKIP_WITHOUT_CPU_FP16_TRIG = pytest.mark.skipif(
    not _CPU_FP16_TRIG_OK,
    reason="torch.cos/torch.sin unsupported for float16 tensors on CPU",
)
_SKIP_WITHOUT_CPU_BF16_TRIG = pytest.mark.skipif(
    not _CPU_BF16_TRIG_OK,
    reason="torch.cos/torch.sin unsupported for bfloat16 tensors on CPU",
)
_SKIP_WITHOUT_CUDA = pytest.mark.skipif(
    not torch.cuda.is_available(), reason="no cuda"
)


# Positional layout of every entry:
#   (device, dtype, t_shape, freqs_shape, start_index, scale)
_CASES = [
    # Baseline shapes across the three CPU dtypes.
    pytest.param(
        "cpu", torch.float32, (1, 8, 16), (8, 16), 0, 1.0,
        id="cpu-float32-base",
    ),
    pytest.param(
        "cpu", torch.float16, (1, 8, 16), (8, 16), 0, 1.0,
        id="cpu-float16-base",
        marks=_SKIP_WITHOUT_CPU_FP16_TRIG,
    ),
    pytest.param(
        "cpu", torch.bfloat16, (1, 8, 16), (8, 16), 0, 1.0,
        id="cpu-bfloat16-base",
        marks=_SKIP_WITHOUT_CPU_BF16_TRIG,
    ),
    # Larger batch, sequence and feature sizes.
    pytest.param(
        "cpu", torch.float32, (2, 16, 32), (16, 32), 0, 1.0,
        id="cpu-float32-larger",
    ),
    # start_index=4 with 16 rotated features out of 24 leaves pass-through
    # slices on both sides of the rotated band.
    pytest.param(
        "cpu", torch.float32, (1, 8, 24), (8, 16), 4, 1.0,
        id="cpu-float32-non-empty-left-and-right-slices",
    ),
    pytest.param(
        "cpu", torch.float32, (1, 8, 16), (8, 16), 0, 0.5,
        id="cpu-float32-non-default-scale",
    ),
    # freqs has 12 rows against a sequence length of 8, so the trailing-slice
    # path is exercised.
    pytest.param(
        "cpu", torch.float32, (1, 8, 16), (12, 16), 0, 1.0,
        id="cpu-float32-freqs-longer-than-seq",
    ),
    pytest.param(
        "cuda", torch.float16, (1, 8, 16), (8, 16), 0, 1.0,
        id="cuda-float16-base",
        marks=_SKIP_WITHOUT_CUDA,
    ),
]
|
|
|
|
|
|
@pytest.mark.parametrize("device,dtype,t_shape,freqs_shape,start_index,scale", _CASES)
def test_apply_rotary_emb_delegates_to_apply_rope1(
    device, dtype, t_shape, freqs_shape, start_index, scale
):
    """Check delegation to ``apply_rope1`` and exact output equality.

    Inputs come from a locally seeded ``torch.Generator``. The wrapper is run
    with ``seedvr_model.apply_rope1`` replaced by a pass-through spy, the spy
    must have been called at least once, and the wrapper's output must match
    ``_direct_reproduction`` with zero relative and absolute tolerance.
    """
    rng = torch.Generator(device=device).manual_seed(0)
    # Draw order matters for reproducibility: t first, then freqs.
    t = torch.randn(*t_shape, dtype=dtype, device=device, generator=rng)
    freqs = torch.randn(*freqs_shape, dtype=dtype, device=device, generator=rng)

    # Spy on the name as bound inside seedvr.model. ``wraps`` forwards every
    # call to the real function, so results are unchanged; if the wrapper ever
    # inlines the math instead of calling this name, the spy records no calls.
    real_apply_rope1 = seedvr_model.apply_rope1
    with patch.object(
        seedvr_model, "apply_rope1", wraps=real_apply_rope1
    ) as rope_spy:
        actual = apply_rotary_emb(freqs, t, start_index=start_index, scale=scale)

    assert rope_spy.call_count >= 1, (
        "apply_rotary_emb did not call comfy.ldm.seedvr.model.apply_rope1; "
        "the delegation invariant is broken"
    )

    expected = _direct_reproduction(
        freqs, t, start_index=start_index, scale=scale
    )

    failure_msg = (
        f"apply_rotary_emb output does not match direct apply_rope1 "
        f"reproduction (device={device}, dtype={dtype}, t_shape={t_shape}, "
        f"freqs_shape={freqs_shape}, start_index={start_index}, scale={scale})"
    )
    # Zero tolerances demand exact value equality.
    torch.testing.assert_close(actual, expected, rtol=0, atol=0, msg=failure_msg)
|