ComfyUI/tests-unit/comfy_test/seedvr_model_test.py
2026-05-26 00:28:29 -05:00

193 lines
7.8 KiB
Python

"""Regression tests for SeedVR2 conditioning split hardening.
Two bare ``except:`` clauses in ``NaDiT.forward`` previously swallowed
every failure mode on (1) the input-side text-conditioning split and
(2) the output-side positive/negative split, silently substituting
wrong fallbacks: the ``positive_conditioning`` buffer (which prior to
explicit zero-init held **uninitialized** memory — NaN, residual heap
contents, never guaranteed-zero) for the input, and the un-split
tensor for the output. Real prompt-shape, dtype, OOM, and downstream
tensor failures were re-routed to "no prompt supplied" with arbitrary
buffer contents standing in for actual prompt embeddings, or to a
wrong-order output, with no diagnostic.
The fix:
1. Input-side: explicit absence predicate (``context is None`` or
   ``context.numel() == 0``) → fall back to ``positive_conditioning``
   buffer. Any other failure (wrong rank, odd batch, dtype, OOM)
   propagates the original torch exception.
2. Output-side: no try/except at all. ``out.chunk(2)`` of the
   network output is a contract: an unsplittable result is a bug,
   not a recoverable condition.
The two blocks were extracted into named private methods on
``NaDiT`` (``_resolve_text_conditioning`` and ``_swap_pos_neg_halves``)
so the regression evidence drives the actual production code paths
without standing up a full transformer. The methods are called from
``forward`` exactly where the original try/except blocks lived.
"""
from comfy.cli_args import args
import torch
# No CUDA device is available: turn on the ``cpu`` flag of
# ``comfy.cli_args.args`` so the rest of the module runs on CPU.
# NOTE(review): presumably this has to happen before the comfy model
# import further down (hence the E402 suppressions on the later
# imports) — confirm before reordering.
if not torch.cuda.is_available():
    args.cpu = True
import ast # noqa: E402
import inspect # noqa: E402
import textwrap # noqa: E402
import pytest # noqa: E402
from comfy.ldm.seedvr.model import NaDiT # noqa: E402
def _make_standin(positive_conditioning):
    """Build a minimal module that exposes NaDiT's two split helpers.

    The returned ``torch.nn.Module`` registers ``positive_conditioning``
    as a buffer under that same name and carries
    ``NaDiT._resolve_text_conditioning`` and ``NaDiT._swap_pos_neg_halves``
    as class attributes, so the tests can invoke the production helpers
    without constructing a full ``NaDiT``.

    Args:
        positive_conditioning: tensor to register as the
            ``positive_conditioning`` buffer of the stand-in.

    Returns:
        A fresh instance of the locally defined ``_StandIn`` module.
    """

    class _StandIn(torch.nn.Module):
        # Reuse the production helpers unchanged rather than re-implementing
        # them, so the tests exercise the real code.
        _swap_pos_neg_halves = NaDiT._swap_pos_neg_halves
        _resolve_text_conditioning = NaDiT._resolve_text_conditioning

        def __init__(self):
            super().__init__()
            self.register_buffer("positive_conditioning", positive_conditioning)

    return _StandIn()
def test_no_bare_except_in_forward_path():
    """Source-level pin against bare ``except:`` on the forward path.

    Parses the source of ``NaDiT.forward``, ``_resolve_text_conditioning``
    and ``_swap_pos_neg_halves`` and fails if any ``ast.ExceptHandler``
    has ``type is None`` (i.e. a bare ``except:``). Working on the AST
    instead of the raw text means an ``except:`` inside a docstring or
    comment is ignored, and a typed handler such as ``except Exception:``
    is accepted.
    """
    inspected = (
        NaDiT.forward,
        NaDiT._resolve_text_conditioning,
        NaDiT._swap_pos_neg_halves,
    )
    # Collect every source up front, before any parsing starts.
    sources = [inspect.getsource(func) for func in inspected]

    for src in sources:
        # Method sources are indented; dedent so they parse standalone.
        module = ast.parse(textwrap.dedent(src))
        handlers = (
            candidate
            for candidate in ast.walk(module)
            if isinstance(candidate, ast.ExceptHandler)
        )
        for handler in handlers:
            assert handler.type is not None, (
                "Bare 'except:' (ast.ExceptHandler with type=None) "
                f"must not appear on the SeedVR2 forward path:\n{src}"
            )
def test_valid_context_splits_pos_neg():
    """AC: a (neg, pos)-stacked ``(2, L, C)`` context flattens to ``[pos, neg]``.

    The context is built with batch row 0 filled with ``-1.0`` (negative)
    and row 1 filled with ``1.0`` (positive). The returned text tensor
    must have ``2 * L`` rows with the positive rows first and the
    negative rows second, and ``txt_shape`` must be a ``(2, 1)`` tensor
    holding ``L`` in each row.
    """
    seq_len, channels = 7, 5120
    standin = _make_standin(torch.zeros((58, 5120)))

    # Distinct fill values make the two halves distinguishable afterwards.
    negative = torch.full((1, seq_len, channels), -1.0)
    positive = torch.full((1, seq_len, channels), 1.0)
    stacked = torch.cat((negative, positive), dim=0)

    txt, txt_shape = standin._resolve_text_conditioning(stacked)

    assert txt.shape == (2 * seq_len, channels)
    leading_rows = txt[:seq_len]
    trailing_rows = txt[seq_len:]
    assert (leading_rows == 1.0).all(), "first half must be positive cond"
    assert (trailing_rows == -1.0).all(), "second half must be negative cond"

    assert txt_shape.shape == (2, 1)
    for row in range(2):
        assert txt_shape[row].item() == seq_len
def test_missing_context_falls_back_to_positive_buffer():
    """AC: ``context is None`` yields the ``positive_conditioning`` buffer.

    The stand-in's buffer is filled with ``7.0`` so that a zero tensor
    substituted in its place would be detected. The call must complete
    without raising and report a ``(1, 1)`` ``txt_shape`` holding the
    buffer's row count.
    """
    sentinel_buffer = torch.full((58, 5120), 7.0)
    standin = _make_standin(sentinel_buffer)

    resolved = standin._resolve_text_conditioning(None)
    txt, txt_shape = resolved

    assert txt.shape == (58, 5120)
    assert (txt == 7.0).all(), (
        "fallback path must use the positive_conditioning buffer "
        "verbatim, not a zero tensor"
    )
    assert txt_shape.shape == (1, 1)
    assert txt_shape[0, 0].item() == 58
def test_empty_context_falls_back_to_positive_buffer():
    """AC: a context with ``numel() == 0`` yields the ``positive_conditioning`` buffer.

    The stand-in's buffer is filled with ``13.0``; the returned text
    tensor must match it in shape and value, and ``txt_shape`` must be a
    ``(1, 1)`` tensor holding the buffer's row count.
    """
    standin = _make_standin(torch.full((58, 5120), 13.0))

    zero_element_context = torch.empty((0, 5120))
    # Sanity-check the fixture itself before exercising the helper.
    assert zero_element_context.numel() == 0

    txt, txt_shape = standin._resolve_text_conditioning(zero_element_context)

    assert txt.shape == (58, 5120)
    assert (txt == 13.0).all()
    assert txt_shape.shape == (1, 1)
    assert txt_shape[0, 0].item() == 58
def test_wrong_rank_context_raises_original_torch_exception():
    """AC: a 1-D, non-empty context must raise instead of falling back.

    A rank-1 tensor of 10 elements is neither ``None`` nor empty, so the
    fallback must not trigger; the test accepts any of ``RuntimeError``,
    ``IndexError`` or ``ValueError`` propagating from the helper.
    """
    standin = _make_standin(torch.zeros((58, 5120)))
    rank_one_context = torch.zeros(10)

    acceptable_errors = (RuntimeError, IndexError, ValueError)
    with pytest.raises(acceptable_errors):
        standin._resolve_text_conditioning(rank_one_context)
def test_odd_batch_context_raises_original_exception():
    """AC: a batch-of-one context must raise instead of falling back.

    With batch size 1 the context cannot be divided into a positive and
    a negative half; the test accepts ``RuntimeError`` or ``ValueError``
    propagating from the helper.
    """
    standin = _make_standin(torch.zeros((58, 5120)))
    single_batch_context = torch.zeros((1, 7, 5120))

    acceptable_errors = (RuntimeError, ValueError)
    with pytest.raises(acceptable_errors):
        standin._resolve_text_conditioning(single_batch_context)
def test_output_side_misshaped_tensor_raises():
    """AC: the output-side split must raise on an unsplittable tensor.

    A batch-of-one output cannot be divided into positive and negative
    halves. ``_swap_pos_neg_halves`` must surface that as ``RuntimeError``
    or ``ValueError`` rather than quietly handing back the un-split
    tensor.
    """
    standin = _make_standin(torch.zeros((58, 5120)))
    single_batch_output = torch.zeros((1, 4, 8, 8))

    acceptable_errors = (RuntimeError, ValueError)
    with pytest.raises(acceptable_errors):
        standin._swap_pos_neg_halves(single_batch_output)
def test_output_side_swaps_pos_neg_halves():
    """AC complement: ``_swap_pos_neg_halves`` exchanges the two batch halves.

    The input is a 2-batch tensor whose first entry is filled with
    ``1.0`` (positive) and whose second entry is filled with ``-1.0``
    (negative). The result must keep the input's shape, with the
    negative entry first and the positive entry second.
    """
    standin = _make_standin(torch.zeros((58, 5120)))

    # Distinct fill values make the ordering observable after the swap.
    half_shape = (1, 4, 8, 8)
    positive_half = torch.full(half_shape, 1.0)
    negative_half = torch.full(half_shape, -1.0)
    network_out = torch.cat((positive_half, negative_half), dim=0)

    swapped = standin._swap_pos_neg_halves(network_out)

    assert swapped.shape == network_out.shape
    assert (swapped[0] == -1.0).all(), "first half of swapped output must be the original negative half"
    assert (swapped[1] == 1.0).all(), "second half of swapped output must be the original positive half"