ComfyUI/tests-unit/comfy_extras_test/test_seedvr_conditioning_hardening.py
2026-05-26 00:28:43 -05:00

602 lines
24 KiB
Python

"""Regression tests for SeedVR2 conditioning model resolution and RoPE
frequency cast.
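Pin two core behaviors (later tests in this file additionally cover the
``SeedVR2Conditioning`` schema, latent packing, and the fail-loud guard
for zero-valued conditioning buffers):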
1. ``_resolve_seedvr2_diffusion_model`` returns the inner diffusion-model
for the expected ``model.model.diffusion_model`` shape and fails loud
with a ``RuntimeError`` whose message begins with
``_SEEDVR2_INVALID_MODEL_MSG_PREFIX`` for any other shape, including
the four distinct missing-vs-None subcases of the chain.
2. ``_apply_rope_freqs_float32_cast`` is idempotent **per-tensor by
dtype check**, NOT per-instance by sentinel attribute. Every call
walks the diffusion-model module tree and invokes ``.to(float32)``
only on tensors whose dtype is not already ``float32``. A cache-by-
attribute (sentinel) approach is rejected because the sentinel
would survive ComfyUI's dynamic model unload/reload cycle while
``rope.freqs`` itself is restored to the archived dtype, so the
next call would short-circuit and leave RoPE running in fp16/bf16
— the exact failure this helper is supposed to prevent. The dtype
check is self-correcting against any weight-restore lifecycle
event.
Import isolation: ``comfy.model_management`` is stubbed via direct
``sys.modules`` assignment so importing ``comfy_extras.nodes_seedvr`` does
not trigger GPU/server-side initialization. ``patch.dict`` is intentionally
NOT used here because its snapshot/restore semantics evict transitively
imported third-party modules (e.g. ``torchvision``) on exit, which causes
``torch``'s global op-library Meta-key registrations to double-register on
re-import. Module-level cached import + scoped restore of the four mocked
entries avoids that hazard. See ``_import_nodes_seedvr_isolated``.
"""
import importlib
import sys
from unittest.mock import MagicMock
import pytest
import torch
import torch.nn as nn
_SENTINEL = object()


def _import_nodes_seedvr_isolated():
    """Stub ``comfy.model_management``, import (or reuse a cached import of)
    ``comfy_extras.nodes_seedvr``, and return ``(module, restore)``.

    Four in-process import-state surfaces are snapshotted before anything
    is stubbed, and ``restore()`` puts each of them back:

    1. ``sys.modules["comfy.model_management"]`` — the stubbed module.
    2. The ``model_management`` attribute on the ``comfy`` package.
    3. ``sys.modules["comfy_extras.nodes_seedvr"]`` — the imported test
       target. If we leave this in ``sys.modules`` after the test, a
       later test importing the real ``comfy_extras.nodes_seedvr`` will
       get our stubbed-``comfy.model_management`` cached version, which
       does not re-resolve against the real ``comfy.model_management``.
    4. The ``nodes_seedvr`` attribute on the ``comfy_extras`` package.

    Each surface is restored verbatim if previously set and deleted if
    previously unset.

    Returns:
        A ``(nodes_seedvr, restore)`` tuple. Callers are expected to
        invoke ``restore()`` from a ``finally`` block.

    Raises:
        Whatever importing ``comfy`` or ``comfy_extras.nodes_seedvr``
        raises. The snapshot is restored before the exception
        propagates, so a failed import does not leave the ``MagicMock``
        stub installed for later tests.
    """
    # Snapshot every surface *before* touching anything.
    prior_comfy_mm = sys.modules.get("comfy.model_management", _SENTINEL)
    prior_comfy_mm_attr = _SENTINEL
    comfy_pkg = sys.modules.get("comfy")
    if comfy_pkg is not None:
        prior_comfy_mm_attr = getattr(comfy_pkg, "model_management", _SENTINEL)
    prior_nodes_seedvr_module = sys.modules.get(
        "comfy_extras.nodes_seedvr", _SENTINEL,
    )
    prior_nodes_seedvr_attr = _SENTINEL
    comfy_extras_pkg = sys.modules.get("comfy_extras")
    if comfy_extras_pkg is not None:
        prior_nodes_seedvr_attr = getattr(
            comfy_extras_pkg, "nodes_seedvr", _SENTINEL,
        )

    def _restore():
        # 1. comfy.model_management sys.modules entry
        if prior_comfy_mm is _SENTINEL:
            sys.modules.pop("comfy.model_management", None)
        else:
            sys.modules["comfy.model_management"] = prior_comfy_mm
        # 2. comfy.model_management package attribute on comfy
        comfy_pkg_now = sys.modules.get("comfy")
        if comfy_pkg_now is not None:
            if prior_comfy_mm_attr is _SENTINEL:
                if hasattr(comfy_pkg_now, "model_management"):
                    delattr(comfy_pkg_now, "model_management")
            else:
                setattr(comfy_pkg_now, "model_management", prior_comfy_mm_attr)
        # 3. comfy_extras.nodes_seedvr sys.modules entry
        if prior_nodes_seedvr_module is _SENTINEL:
            sys.modules.pop("comfy_extras.nodes_seedvr", None)
        else:
            sys.modules["comfy_extras.nodes_seedvr"] = prior_nodes_seedvr_module
        # 4. comfy_extras.nodes_seedvr package attribute on comfy_extras
        comfy_extras_pkg_now = sys.modules.get("comfy_extras")
        if comfy_extras_pkg_now is not None:
            if prior_nodes_seedvr_attr is _SENTINEL:
                if hasattr(comfy_extras_pkg_now, "nodes_seedvr"):
                    delattr(comfy_extras_pkg_now, "nodes_seedvr")
            else:
                setattr(
                    comfy_extras_pkg_now, "nodes_seedvr",
                    prior_nodes_seedvr_attr,
                )

    # ``comfy_extras.nodes_seedvr`` imports ``comfy.sample`` (added in PR
    # #59), which pulls in the full samplers/k_diffusion/model_patcher
    # transitive chain. That chain re-imports ``comfy.model_management``
    # and calls feature-detection predicates at module-init time
    # (``comfy/ldm/modules/attention.py`` ``xformers_enabled()``,
    # ``comfy/ldm/modules/diffusionmodules/model.py``
    # ``xformers_enabled_vae()``, etc.). A bare ``MagicMock()`` returns
    # truthy auto-attrs for those calls, which triggers a real
    # ``import xformers`` that fails in the test environment. Pin the
    # boolean-returning predicates to ``False`` so the import chain
    # follows the no-extension path.
    mock_mm = MagicMock()
    mock_mm.xformers_enabled.return_value = False
    mock_mm.xformers_enabled_vae.return_value = False
    mock_mm.pytorch_attention_enabled.return_value = False
    mock_mm.pytorch_attention_enabled_vae.return_value = False
    mock_mm.sage_attention_enabled.return_value = False
    mock_mm.flash_attention_enabled.return_value = False
    torch_version_parts = torch.version.__version__.split(".")
    mock_mm.torch_version_numeric = (
        int(torch_version_parts[0]),
        int(torch_version_parts[1]),
    )
    mock_mm.WINDOWS = False
    mock_mm.is_intel_xpu.return_value = False

    try:
        sys.modules["comfy.model_management"] = mock_mm
        # The transitive import chain reaches code paths that do
        # ``comfy.model_management.<attr>`` (attribute access on the comfy
        # package, not a fresh import). Setting only ``sys.modules`` is not
        # enough — also bind the stub as the package attribute. If the
        # ``comfy`` package isn't imported yet at stub-time (cold first
        # run), importing it now is safe and idempotent.
        if comfy_pkg is None:
            import comfy as _comfy_pkg  # noqa: F401
            comfy_pkg = sys.modules.get("comfy")
        if comfy_pkg is not None:
            setattr(comfy_pkg, "model_management", mock_mm)
        if "comfy_extras.nodes_seedvr" in sys.modules:
            nodes_seedvr = sys.modules["comfy_extras.nodes_seedvr"]
        else:
            nodes_seedvr = importlib.import_module("comfy_extras.nodes_seedvr")
    except BaseException:
        # The caller never receives ``restore`` when an import fails, so
        # undo the stubbing here rather than leaking the mock.
        _restore()
        raise
    return nodes_seedvr, _restore
class _Rope(nn.Module):
    """Minimal rotary-embedding stand-in exposing a ``freqs`` parameter."""

    def __init__(self):
        super().__init__()
        initial_freqs = torch.zeros(4)
        self.freqs = nn.Parameter(initial_freqs)
class _Block(nn.Module):
    """Block stand-in whose only child module is a ``_Rope`` at ``.rope``."""

    def __init__(self):
        super().__init__()
        rope = _Rope()
        self.rope = rope
class _DiffusionModel(nn.Module):
    """Diffusion-model stand-in: a list of ``_Block`` children plus the two
    conditioning buffers.

    ``negative_conditioning`` is always a zero ``(3, 4)`` buffer.
    ``positive_conditioning`` is a ``(2, 4)`` buffer of ones by default and
    of zeros when ``zero_conditioning`` is set. Both buffers are created
    with ``conditioning_dtype``.
    """

    def __init__(
        self,
        n_blocks=3,
        zero_conditioning=False,
        conditioning_dtype=torch.float32,
    ):
        super().__init__()
        self.blocks = nn.ModuleList([_Block() for _ in range(n_blocks)])
        # ``zero_conditioning`` simulates a numz-format DiT-only file loaded
        # via UNETLoader: ``register_buffer`` zero-init at
        # ``comfy/ldm/seedvr/model.py`` leaves the buffers at zero when
        # ``load_state_dict`` cannot find ``positive_conditioning`` /
        # ``negative_conditioning`` keys in the state_dict. The fail-loud
        # guard at ``SeedVR2Conditioning.execute`` distinguishes this from a
        # properly-baked file by ``abs().sum() == 0`` on both buffers.
        positive_factory = torch.zeros if zero_conditioning else torch.ones
        self.register_buffer(
            "positive_conditioning",
            positive_factory((2, 4), dtype=conditioning_dtype),
        )
        self.register_buffer(
            "negative_conditioning",
            torch.zeros((3, 4), dtype=conditioning_dtype),
        )
class _ModelInner:
    """Stand-in for the object a patcher exposes at ``.model``."""

    def __init__(self, diffusion_model):
        # Stored as-is: the tests compare the resolved object by identity.
        self.diffusion_model = diffusion_model
class _ModelPatcher:
    """Patcher stand-in exposing ``.model.diffusion_model``."""

    def __init__(self, diffusion_model):
        inner = _ModelInner(diffusion_model)
        self.model = inner
def test_resolve_seedvr2_diffusion_model_returns_inner_when_valid():
    """The resolver returns the exact object stored at
    ``model.model.diffusion_model``.
    """
    nodes_seedvr, restore = _import_nodes_seedvr_isolated()
    try:
        inner = _DiffusionModel()
        wrapped = _ModelPatcher(inner)
        assert nodes_seedvr._resolve_seedvr2_diffusion_model(wrapped) is inner
    finally:
        restore()
def test_seedvr2_conditioning_schema_exposes_model_passthrough_output():
    """Pin the node schema: input ids, the display name of the second
    input, and the ordered output display names (``model`` first).
    """
    nodes_seedvr, restore = _import_nodes_seedvr_isolated()
    try:
        schema = nodes_seedvr.SeedVR2Conditioning.define_schema()

        input_ids = [entry.id for entry in schema.inputs]
        assert input_ids == ["model", "vae_conditioning"]

        second_input = schema.inputs[1]
        assert second_input.display_name == "LATENT"

        output_names = [entry.display_name for entry in schema.outputs]
        assert output_names == ["model", "positive", "negative", "latent"]
    finally:
        restore()
def test_seedvr2_conditioning_returns_packed_input_latent_deterministically():
    """Two ``execute`` calls on the same inputs must yield the same packed
    latent and the same ``condition`` tensor on both the positive and the
    negative conditioning.
    """
    nodes_seedvr, restore = _import_nodes_seedvr_isolated()
    try:
        patcher = _ModelPatcher(_DiffusionModel())
        samples = torch.arange(1, 25, dtype=torch.float32).reshape(1, 2, 3, 2, 2)
        vae_conditioning = {"samples": samples}

        # Run the node twice, keeping (positive, negative, latent) per run.
        runs = []
        for _ in range(2):
            _, positive, negative, latent = (
                nodes_seedvr.SeedVR2Conditioning.execute(
                    patcher,
                    vae_conditioning,
                )
            )
            runs.append((positive, negative, latent))

        # Expected latent: the two leading non-batch axes folded together.
        expected_latent = samples.reshape(1, 6, 2, 2)

        # Expected condition: a ones channel appended in channel-last
        # layout, moved back to axis 1 and folded the same way.
        channel_last = samples.movedim(1, -1).contiguous()
        ones_channel = torch.ones((*channel_last.shape[:-1], 1))
        with_ones = torch.cat([channel_last, ones_channel], dim=-1)
        expected_condition = with_ones.movedim(-1, 1).reshape(1, 9, 2, 2)

        for _, _, latent in runs:
            assert torch.equal(latent["samples"], expected_latent)
        for positive, _, _ in runs:
            assert torch.equal(
                positive[0][1]["condition"],
                expected_condition,
            )
        for _, negative, _ in runs:
            assert torch.equal(
                negative[0][1]["condition"],
                expected_condition,
            )
    finally:
        restore()
def test_resolve_seedvr2_diffusion_model_raises_runtime_error_with_specific_prefix():
    """Pin all four failure modes of the resolver chain to the same error
    prefix and to message text that distinguishes 'attribute missing'
    from 'attribute present but None'. The four modes:

    mode 1: input has no 'model' attribute
    mode 2: input.model is None
    mode 3: 'model.model' has no 'diffusion_model' attribute
    mode 4: 'model.model.diffusion_model' is None
    """
    nodes_seedvr, restore = _import_nodes_seedvr_isolated()
    try:
        class _Bare:
            """Object with no attributes of interest."""

        class _HasModel:
            def __init__(self, model):
                self.model = model

        class _HasDiffusionModel:
            def __init__(self, diffusion_model):
                self.diffusion_model = diffusion_model

        # (malformed input, fragment the message must contain), in mode
        # order. The "is None" modes must not be conflated with the
        # "attribute missing" modes.
        failure_modes = (
            (_Bare(), "no 'model' attribute"),
            (_HasModel(None), "input.model is None"),
            (_HasModel(object()), "no 'diffusion_model' attribute"),
            (
                _HasModel(_HasDiffusionModel(None)),
                "'model.model.diffusion_model' is None",
            ),
        )

        for malformed, expected_fragment in failure_modes:
            with pytest.raises(RuntimeError) as excinfo:
                nodes_seedvr._resolve_seedvr2_diffusion_model(malformed)
            msg = str(excinfo.value)
            assert msg.startswith(nodes_seedvr._SEEDVR2_INVALID_MODEL_MSG_PREFIX)
            assert expected_fragment in msg
    finally:
        restore()
def test_apply_rope_freqs_float32_cast_idempotent_on_unchanged_dtype():
    """Once a first call has cast every ``rope.freqs`` to float32, a second
    call must leave each tensor's dtype and underlying storage untouched.

    The fixture starts in float64 so the first call has real work to do.
    Storage identity is compared through ``data_ptr()``. ``.data`` builds a
    fresh Python tensor wrapper on every access, so ``id(freqs.data)`` only
    identifies a temporary that is freed immediately and whose address may
    be reused, making an ``id()`` comparison unreliable.
    """
    nodes_seedvr, restore = _import_nodes_seedvr_isolated()
    try:
        diffusion_model = _DiffusionModel()

        def rope_freqs():
            # Re-walk the tree on every call so a helper that swaps the
            # parameter object is still observed.
            return [
                module.rope.freqs
                for module in diffusion_model.modules()
                if hasattr(module, "rope") and hasattr(module.rope, "freqs")
            ]

        # Starting dtype is non-float32 so the first call has work to do.
        for freqs in rope_freqs():
            freqs.data = freqs.data.to(torch.float64)

        nodes_seedvr._apply_rope_freqs_float32_cast(diffusion_model)
        after_first_call = rope_freqs()
        # Guard against a vacuous pass if the fixture ever loses its ropes.
        assert after_first_call, "fixture must expose at least one rope.freqs"
        first_call_data_ptrs = []
        for freqs in after_first_call:
            assert freqs.data.dtype == torch.float32
            first_call_data_ptrs.append(freqs.data.data_ptr())

        # Second call on the same already-float32 model: every per-tensor
        # dtype check sees float32 and skips the .to() call. The underlying
        # storage must be preserved (no re-allocation).
        nodes_seedvr._apply_rope_freqs_float32_cast(diffusion_model)
        for freqs, prior_ptr in zip(
            rope_freqs(),
            first_call_data_ptrs,
            strict=True,
        ):
            assert freqs.data.dtype == torch.float32
            assert freqs.data.data_ptr() == prior_ptr, (
                "Already-float32 rope.freqs must not be re-allocated on "
                "subsequent calls; the per-tensor dtype check must skip the "
                ".to(float32) call when the tensor is already in float32."
            )
    finally:
        restore()
def test_apply_rope_freqs_float32_cast_recovers_after_dtype_reset():
    """When ``rope.freqs`` is put back to a non-float32 dtype after a first
    cast (as a model unload/reload restoring archived weights would do),
    the next call must cast it to float32 again. A bool-sentinel cache
    approach would short-circuit here and leave RoPE running in
    fp16/bf16.
    """
    nodes_seedvr, restore = _import_nodes_seedvr_isolated()
    try:
        diffusion_model = _DiffusionModel()

        def rope_owners():
            return [
                owner
                for owner in diffusion_model.modules()
                if hasattr(owner, "rope") and hasattr(owner.rope, "freqs")
            ]

        def reset_to_float64():
            for owner in rope_owners():
                owner.rope.freqs.data = owner.rope.freqs.data.to(torch.float64)

        reset_to_float64()

        # First call casts to float32.
        nodes_seedvr._apply_rope_freqs_float32_cast(diffusion_model)
        for owner in rope_owners():
            assert owner.rope.freqs.data.dtype == torch.float32

        # Simulate a Comfy dynamic unload/reload that restores rope.freqs
        # to the archived (non-float32) dtype.
        reset_to_float64()

        # Second call must detect the dtype regression and re-cast.
        nodes_seedvr._apply_rope_freqs_float32_cast(diffusion_model)
        for owner in rope_owners():
            assert owner.rope.freqs.data.dtype == torch.float32, (
                "After a model unload/reload that resets rope.freqs to "
                "non-float32, the next _apply_rope_freqs_float32_cast "
                "call MUST re-cast to float32. A bool-sentinel cache "
                "would have short-circuited here."
            )
    finally:
        restore()
# ---------------------------------------------------------------------------
# Fail-loud guard: zero-valued conditioning buffers
# ---------------------------------------------------------------------------
def test_seedvr2_conditioning_fails_loud_on_zero_buffers():
    """With both ``positive_conditioning`` and ``negative_conditioning``
    all-zero, ``SeedVR2Conditioning.execute`` must raise ``RuntimeError``
    whose message starts with the standard SeedVR2 invalid-model prefix
    and names both buffers.

    That state stands for a numz-format DiT-only ``.safetensors`` file
    loaded via ``UNETLoader`` without the conditioning keys baked in;
    sampling on null prompt conditioning would silently produce wrong
    output.
    """
    nodes_seedvr, restore = _import_nodes_seedvr_isolated()
    try:
        patcher = _ModelPatcher(_DiffusionModel(zero_conditioning=True))
        vae_conditioning = {"samples": torch.zeros((1, 2, 1, 1, 1))}

        with pytest.raises(RuntimeError) as raised:
            nodes_seedvr.SeedVR2Conditioning.execute(
                patcher, vae_conditioning,
            )

        message = str(raised.value)
        expected_prefix = nodes_seedvr._SEEDVR2_INVALID_MODEL_MSG_PREFIX
        assert message.startswith(expected_prefix), (
            "Fail-loud message must use the standard "
            "_SEEDVR2_INVALID_MODEL_MSG_PREFIX so callers/log scrapers "
            f"can match it. Got: {message!r}"
        )
        for buffer_name in ("positive_conditioning", "negative_conditioning"):
            assert buffer_name in message
    finally:
        restore()
def test_seedvr2_conditioning_fails_loud_on_fp8_zero_buffers():
    """The zero-buffer guard must also fire for fp8 conditioning buffers,
    i.e. it must reduce them without hitting PyTorch's unsupported float8
    reductions. Skipped when the torch build has no ``float8_e4m3fn``.
    """
    fp8_dtype = getattr(torch, "float8_e4m3fn", None)
    if fp8_dtype is None:
        pytest.skip("torch build does not expose float8_e4m3fn")

    nodes_seedvr, restore = _import_nodes_seedvr_isolated()
    try:
        fp8_model = _DiffusionModel(
            zero_conditioning=True,
            conditioning_dtype=fp8_dtype,
        )
        patcher = _ModelPatcher(fp8_model)
        vae_conditioning = {"samples": torch.zeros((1, 2, 1, 1, 1))}

        with pytest.raises(RuntimeError) as raised:
            nodes_seedvr.SeedVR2Conditioning.execute(
                patcher, vae_conditioning,
            )

        message = str(raised.value)
        expected_prefix = nodes_seedvr._SEEDVR2_INVALID_MODEL_MSG_PREFIX
        assert message.startswith(expected_prefix)
        assert "zero-valued" in message
    finally:
        restore()
def test_seedvr2_conditioning_does_not_fire_on_partial_zero_buffers():
    """The guard is AND-gated over both buffers: zero
    ``negative_conditioning`` combined with non-zero
    ``positive_conditioning`` (the baseline fixture) must NOT raise.

    This prevents a future regression to OR-gating from rejecting valid
    bundled checkpoints where one buffer happens to be all-zeros.
    """
    nodes_seedvr, restore = _import_nodes_seedvr_isolated()
    try:
        # Baseline _DiffusionModel has positive=ones, negative=zeros.
        patcher = _ModelPatcher(_DiffusionModel(zero_conditioning=False))
        vae_conditioning = {"samples": torch.zeros((1, 2, 1, 1, 1))}

        # Should not raise.
        outputs = nodes_seedvr.SeedVR2Conditioning.execute(
            patcher, vae_conditioning,
        )
        passthrough_model, positive, negative, _latent = outputs

        assert positive[0][0].shape == (1, 2, 4)
        assert negative[0][0].shape == (1, 3, 4)
        assert passthrough_model is patcher
    finally:
        restore()
def test_seedvr2_conditioning_fail_loud_never_exposes_safetensors_path():
    """Even when the patcher carries a ``cached_patcher_init`` tuple holding
    a local model path, the fail-loud message must not contain that path
    or a ``Source file:`` label, while still naming both conditioning
    buffers. Filesystem paths are not part of the public error contract.
    """
    nodes_seedvr, restore = _import_nodes_seedvr_isolated()
    try:
        patcher = _ModelPatcher(_DiffusionModel(zero_conditioning=True))
        # Mimic the ``cached_patcher_init`` shape comfy.sd attaches:
        # (function reference, args tuple containing the model path).
        loader_args = (
            "/some/models/diffusion_models/seedvr2_ema_7b_fp16.safetensors",
        )
        patcher.cached_patcher_init = (object(), loader_args)
        vae_conditioning = {"samples": torch.zeros((1, 2, 1, 1, 1))}

        with pytest.raises(RuntimeError) as raised:
            nodes_seedvr.SeedVR2Conditioning.execute(
                patcher, vae_conditioning,
            )

        message = str(raised.value)
        must_be_absent = (
            "/some/models/diffusion_models",
            "seedvr2_ema_7b_fp16.safetensors",
            "Source file:",
        )
        for fragment in must_be_absent:
            assert fragment not in message
        must_be_present = ("positive_conditioning", "negative_conditioning")
        for fragment in must_be_present:
            assert fragment in message
    finally:
        restore()
def test_seedvr2_conditioning_fail_loud_falls_back_when_path_unavailable():
    """With no ``cached_patcher_init`` on the patcher, the fail-loud path
    must still raise ``RuntimeError`` with the actionable ``Re-bake``
    guidance, and must not emit a ``Source file:`` label or the
    ``bf16 keys`` wording.
    """
    nodes_seedvr, restore = _import_nodes_seedvr_isolated()
    try:
        # Deliberately no cached_patcher_init on this patcher.
        patcher = _ModelPatcher(_DiffusionModel(zero_conditioning=True))
        vae_conditioning = {"samples": torch.zeros((1, 2, 1, 1, 1))}

        with pytest.raises(RuntimeError) as raised:
            nodes_seedvr.SeedVR2Conditioning.execute(
                patcher, vae_conditioning,
            )

        message = str(raised.value)
        assert "Source file:" not in message  # no empty path leak
        assert "Re-bake" in message  # actionable guidance still present
        assert "bf16 keys" not in message
    finally:
        restore()