mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-28 01:47:32 +08:00
193 lines
7.8 KiB
Python
193 lines
7.8 KiB
Python
"""Regression tests for SeedVR2 conditioning split hardening.

Two bare ``except:`` clauses in ``NaDiT.forward`` previously swallowed
every failure mode on (1) the input-side text-conditioning split and
(2) the output-side positive/negative split, silently substituting
wrong fallbacks: the ``positive_conditioning`` buffer (which prior to
explicit zero-init held **uninitialized** memory — NaN, residual heap
contents, never guaranteed-zero) for the input, and the un-split
tensor for the output. Real prompt-shape, dtype, OOM, and downstream
tensor failures were re-routed to "no prompt supplied" with arbitrary
buffer contents standing in for actual prompt embeddings, or to a
wrong-order output, with no diagnostic.

The fix:

1. Input-side: explicit absence predicate (``context is None`` or
   ``context.numel() == 0``) → fall back to ``positive_conditioning``
   buffer. Any other failure (wrong rank, odd batch, dtype, OOM)
   propagates the original torch exception.
2. Output-side: no try/except at all. ``out.chunk(2)`` of the
   network output is a contract: an unsplittable result is a bug,
   not a recoverable condition.

The two blocks were extracted into named private methods on
``NaDiT`` (``_resolve_text_conditioning`` and ``_swap_pos_neg_halves``)
so the regression evidence drives the actual production code paths
without standing up a full transformer. The methods are called from
``forward`` exactly where the original try/except blocks lived.
"""
|
|
|
|
from comfy.cli_args import args
|
|
import torch
|
|
|
|
# Set the ``cpu`` flag on the parsed ComfyUI CLI args when torch reports no
# CUDA device. This statement sits between the imports, which is why the
# imports that follow it carry ``noqa: E402``.
# NOTE(review): presumably the flag must be set before the comfy modules
# imported below are loaded so they pick up CPU mode — confirm.
if not torch.cuda.is_available():
    args.cpu = True
|
|
|
|
import ast # noqa: E402
|
|
import inspect # noqa: E402
|
|
import textwrap # noqa: E402
|
|
|
|
import pytest # noqa: E402
|
|
|
|
from comfy.ldm.seedvr.model import NaDiT # noqa: E402
|
|
|
|
|
|
def _make_standin(positive_conditioning):
    """Build a lightweight module that exposes NaDiT's two split helpers.

    The returned ``torch.nn.Module`` registers ``positive_conditioning``
    as a buffer under that same name and carries
    ``NaDiT._resolve_text_conditioning`` and ``NaDiT._swap_pos_neg_halves``
    as class attributes, so the tests can invoke the helpers on an object
    that is much cheaper to construct than a full ``NaDiT``.

    Args:
        positive_conditioning: tensor stored as the ``positive_conditioning``
            buffer of the stand-in.

    Returns:
        A fresh ``_StandIn`` instance.
    """

    class _StandIn(torch.nn.Module):
        # Borrow the production implementations straight from NaDiT so the
        # tests exercise the real code rather than a copy of it.
        _swap_pos_neg_halves = NaDiT._swap_pos_neg_halves
        _resolve_text_conditioning = NaDiT._resolve_text_conditioning

        def __init__(self):
            super().__init__()
            self.register_buffer("positive_conditioning", positive_conditioning)

    instance = _StandIn()
    return instance
|
|
|
|
|
|
def test_no_bare_except_in_forward_path():
    """Source-level pin against bare ``except:`` on the SeedVR2 forward path.

    The source of ``NaDiT.forward``, ``NaDiT._resolve_text_conditioning``
    and ``NaDiT._swap_pos_neg_halves`` is parsed with ``ast`` and every
    ``ast.ExceptHandler`` node must name an exception type. Working on
    the AST rather than on raw text means an ``except:`` mentioned in a
    docstring or comment is not mistaken for a handler, while typed
    handlers such as ``except Exception:`` are accepted.
    """
    inspected = (
        NaDiT.forward,
        NaDiT._resolve_text_conditioning,
        NaDiT._swap_pos_neg_halves,
    )
    sources = [inspect.getsource(member) for member in inspected]

    for src in sources:
        # Method source is indented inside its class; dedent so it parses
        # as a standalone snippet.
        parsed = ast.parse(textwrap.dedent(src))
        handlers = [
            candidate
            for candidate in ast.walk(parsed)
            if isinstance(candidate, ast.ExceptHandler)
        ]
        for handler in handlers:
            assert handler.type is not None, (
                "Bare 'except:' (ast.ExceptHandler with type=None) "
                f"must not appear on the SeedVR2 forward path:\n{src}"
            )
|
|
|
|
|
|
def test_valid_context_splits_pos_neg():
    """AC: a valid ``(2, L, C)`` context yields a flattened ``[pos, neg]`` text.

    The context is stacked negative-first along dim 0. The resolved text
    tensor must have ``2 * L`` rows, with the first ``L`` rows holding the
    positive values and the remaining ``L`` rows the negative values, and
    the returned shape tensor must be ``(2, 1)`` with ``L`` in both rows.
    """
    seq_len, channels = 7, 5120
    standin = _make_standin(torch.zeros((58, 5120)))

    # Distinct constant fills make the two halves distinguishable afterwards.
    half_shape = (1, seq_len, channels)
    negative = torch.full(half_shape, -1.0)
    positive = torch.full(half_shape, 1.0)
    stacked = torch.cat([negative, positive], dim=0)

    txt, txt_shape = standin._resolve_text_conditioning(stacked)

    assert txt.shape == (2 * seq_len, channels)
    assert (txt[:seq_len] == 1.0).all(), "first half must be positive cond"
    assert (txt[seq_len:] == -1.0).all(), "second half must be negative cond"
    assert txt_shape.shape == (2, 1)
    for row in (0, 1):
        assert txt_shape[row].item() == seq_len
|
|
|
|
|
|
def test_missing_context_falls_back_to_positive_buffer():
    """AC: ``context=None`` resolves to the ``positive_conditioning`` buffer.

    The buffer is filled with a non-zero sentinel so that a silent zero
    substitution would be detected. The call must complete without
    raising and report a ``(1, 1)`` shape tensor holding the buffer's
    row count.
    """
    rows, cols = 58, 5120
    sentinel_buffer = torch.full((rows, cols), 7.0)
    standin = _make_standin(sentinel_buffer)

    txt, txt_shape = standin._resolve_text_conditioning(None)

    assert txt.shape == (rows, cols)
    assert (txt == 7.0).all(), (
        "fallback path must use the positive_conditioning buffer "
        "verbatim, not a zero tensor"
    )
    assert txt_shape.shape == (1, 1)
    assert txt_shape[0, 0].item() == rows
|
|
|
|
|
|
def test_empty_context_falls_back_to_positive_buffer():
    """AC: a zero-element context resolves to the ``positive_conditioning`` buffer.

    A ``(0, 5120)`` tensor (``numel() == 0``) is passed as the context;
    the result must be the sentinel-filled buffer, with a ``(1, 1)``
    shape tensor holding the buffer's row count.
    """
    rows, cols = 58, 5120
    standin = _make_standin(torch.full((rows, cols), 13.0))

    empty = torch.empty((0, cols))
    # Guard the premise of the test: the input really has no elements.
    assert empty.numel() == 0

    txt, txt_shape = standin._resolve_text_conditioning(empty)

    assert txt.shape == (rows, cols)
    assert (txt == 13.0).all()
    assert txt_shape.shape == (1, 1)
    assert txt_shape[0, 0].item() == rows
|
|
|
|
|
|
def test_wrong_rank_context_raises_original_torch_exception():
    """AC: a 1-D context must raise instead of silently falling back.

    A rank-1 tensor is not a valid stacked context; resolving it is
    expected to surface a ``RuntimeError``, ``IndexError`` or
    ``ValueError`` from the underlying tensor operations.
    """
    standin = _make_standin(torch.zeros((58, 5120)))
    one_dimensional = torch.zeros(10)
    acceptable_errors = (RuntimeError, IndexError, ValueError)

    with pytest.raises(acceptable_errors):
        standin._resolve_text_conditioning(one_dimensional)
|
|
|
|
|
|
def test_odd_batch_context_raises_original_exception():
    """AC: a context that cannot be halved along dim 0 must raise.

    With a batch dimension of 1 there is only one chunk to unpack, so
    the helper is expected to raise ``RuntimeError`` or ``ValueError``
    rather than fall back silently.
    """
    standin = _make_standin(torch.zeros((58, 5120)))
    single_batch = torch.zeros((1, 7, 5120))
    acceptable_errors = (RuntimeError, ValueError)

    with pytest.raises(acceptable_errors):
        standin._resolve_text_conditioning(single_batch)
|
|
|
|
|
|
def test_output_side_misshaped_tensor_raises():
    """AC: the output-side swap must raise on an unsplittable tensor.

    A batch-of-one output cannot be divided into a positive and a
    negative half, so ``_swap_pos_neg_halves`` is expected to raise
    ``RuntimeError`` or ``ValueError`` instead of handing back the
    un-split tensor.
    """
    standin = _make_standin(torch.zeros((58, 5120)))
    single_batch_out = torch.zeros((1, 4, 8, 8))
    acceptable_errors = (RuntimeError, ValueError)

    with pytest.raises(acceptable_errors):
        standin._swap_pos_neg_halves(single_batch_out)
|
|
|
|
|
|
def test_output_side_swaps_pos_neg_halves():
    """AC complement: ``_swap_pos_neg_halves`` exchanges the two batch halves.

    The input is a 2-batch tensor whose first entry is filled with
    ``1.0`` (positive) and whose second entry is filled with ``-1.0``
    (negative). The result must keep the input's shape, with the
    negative values first and the positive values second.
    """
    standin = _make_standin(torch.zeros((58, 5120)))

    half_shape = (1, 4, 8, 8)
    positive = torch.full(half_shape, 1.0)
    negative = torch.full(half_shape, -1.0)
    out = torch.cat([positive, negative], dim=0)

    swapped = standin._swap_pos_neg_halves(out)

    assert swapped.shape == out.shape
    assert (swapped[0] == -1.0).all(), "first half of swapped output must be the original negative half"
    assert (swapped[1] == 1.0).all(), "second half of swapped output must be the original positive half"
|