mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-31 03:17:23 +08:00
- Reduce SeedVR2 coverage down to production unit tests - Route SeedVR2 7B through Comfy varlength attention - Disable SeedVR2 RoPE cache reuse after the upstream DynamicVRAM change
214 lines
7.0 KiB
Python
214 lines
7.0 KiB
Python
"""Consolidated SeedVR2 conditioning and refactor regression tests.
|
|
|
|
Merges the prior test_seedvr2_refactor_nodes.py and
test_seedvr_conditioning_hardening.py modules. Refactor tests use the
top-level comfy_extras.nodes_seedvr import; conditioning-hardening tests
use _import_nodes_seedvr_isolated() for sys.modules isolation when
mocking comfy.model_management.
|
|
"""
|
|
|
|
import importlib
|
|
import sys
|
|
from unittest.mock import MagicMock
|
|
|
|
import pytest
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from comfy.cli_args import args as cli_args
|
|
|
|
# torch reports no CUDA device: switch on the parsed CLI `cpu` flag.
# NOTE(review): presumably this makes comfy pick its CPU code path when the
# suite runs on a GPU-less machine -- confirm against comfy.cli_args.
if not torch.cuda.is_available():
    cli_args.cpu = True


# Unique marker meaning "nothing was there before we patched"; a dedicated
# object is used so it can never be confused with a real module or attribute.
_SENTINEL = object()

# (dotted module name, parent package name) pairs whose sys.modules entry and
# parent-package attribute are snapshotted and later restored by
# _import_nodes_seedvr_isolated().
_TARGETS = (
    ("comfy.model_management", "comfy"),
    ("comfy_extras.nodes_seedvr", "comfy_extras"),
)
|
|
|
|
|
|
def _import_nodes_seedvr_isolated():
    """Import comfy_extras.nodes_seedvr with comfy.model_management mocked.

    The current ``sys.modules`` entries and parent-package attributes for
    every pair in ``_TARGETS`` are snapshotted first. A ``MagicMock`` is then
    installed as ``comfy.model_management`` (both in ``sys.modules`` and as an
    attribute of the ``comfy`` package) before ``comfy_extras.nodes_seedvr``
    is fetched. If that module is already present in ``sys.modules`` it is
    reused as-is rather than re-imported.

    Returns:
        A ``(nodes_seedvr, restore)`` tuple. ``restore`` is a zero-argument
        callable that puts back the snapshotted state; callers are expected
        to invoke it in a ``finally`` block.

    Raises:
        Whatever the underlying import raises. In that case the snapshotted
        state is restored before the exception propagates, so a failed import
        does not leave the mock installed for unrelated tests.
    """
    # Snapshot the pre-patch state of each target module and of the matching
    # attribute on its parent package.
    priors = []
    for mod_name, parent_name in _TARGETS:
        prior_mod = sys.modules.get(mod_name, _SENTINEL)
        parent = sys.modules.get(parent_name)
        attr = mod_name.split(".")[-1]
        prior_attr = (
            getattr(parent, attr, _SENTINEL) if parent is not None else _SENTINEL
        )
        priors.append((mod_name, parent_name, attr, prior_mod, prior_attr))

    def _restore():
        """Put sys.modules and parent-package attributes back as snapshotted."""
        for mod_name, parent_name, attr, prior_mod, prior_attr in priors:
            if prior_mod is _SENTINEL:
                # The module was not loaded before: drop whatever is there now.
                sys.modules.pop(mod_name, None)
            else:
                sys.modules[mod_name] = prior_mod
            parent = sys.modules.get(parent_name)
            if parent is None:
                continue
            if prior_attr is _SENTINEL:
                if hasattr(parent, attr):
                    delattr(parent, attr)
            else:
                setattr(parent, attr, prior_attr)

    # Build the stand-in for comfy.model_management: every backend/attention
    # predicate answers False, and the plain attributes get concrete values
    # instead of auto-generated MagicMock children.
    mock_mm = MagicMock()
    for fn in (
        "xformers_enabled", "xformers_enabled_vae",
        "pytorch_attention_enabled", "pytorch_attention_enabled_vae",
        "sage_attention_enabled", "flash_attention_enabled",
        "is_intel_xpu",
    ):
        getattr(mock_mm, fn).return_value = False
    tv = torch.version.__version__.split(".")
    mock_mm.torch_version_numeric = (int(tv[0]), int(tv[1]))
    mock_mm.WINDOWS = False

    try:
        sys.modules["comfy.model_management"] = mock_mm
        if sys.modules.get("comfy") is None:
            import comfy as _comfy_pkg  # noqa: F401
        comfy_pkg = sys.modules.get("comfy")
        if comfy_pkg is not None:
            setattr(comfy_pkg, "model_management", mock_mm)
        nodes_seedvr = sys.modules.get("comfy_extras.nodes_seedvr") or (
            importlib.import_module("comfy_extras.nodes_seedvr")
        )
    except BaseException:
        # The caller never receives `_restore` when we raise, so undo the
        # global patching here; otherwise the mock would leak into every
        # test that runs afterwards.
        _restore()
        raise

    return nodes_seedvr, _restore
|
|
|
|
|
|
class _Rope(nn.Module):
    """Tiny RoPE stand-in whose only state is a `freqs` parameter."""

    def __init__(self):
        super().__init__()
        # Four zeros wrapped in a Parameter so the stub registers `freqs`
        # the way any nn.Module parameter is registered.
        initial_freqs = torch.zeros(4)
        self.freqs = nn.Parameter(initial_freqs)
|
|
|
|
|
|
class _Block(nn.Module):
    """Bare-bones transformer block stand-in that owns one `_Rope` as `.rope`."""

    def __init__(self):
        super().__init__()
        rope_stub = _Rope()
        self.rope = rope_stub
|
|
|
|
|
|
class _DiffusionModel(nn.Module):
    """Fake diffusion model: `n_blocks` `_Block`s plus two conditioning buffers.

    Args:
        n_blocks: Number of `_Block` instances stored in `self.blocks`.
        zero_conditioning: When true, `positive_conditioning` is all zeros;
            otherwise it is all ones. `negative_conditioning` is always zeros.
        conditioning_dtype: dtype used for both conditioning buffers.
    """

    def __init__(self, n_blocks=3, zero_conditioning=False, conditioning_dtype=torch.float32):
        super().__init__()

        self.blocks = nn.ModuleList()
        for _ in range(n_blocks):
            self.blocks.append(_Block())

        # Positive buffer is (2, 4); its fill value depends on the flag.
        if zero_conditioning:
            positive = torch.zeros((2, 4), dtype=conditioning_dtype)
        else:
            positive = torch.ones((2, 4), dtype=conditioning_dtype)
        # Negative buffer is (3, 4) and always zero-filled.
        negative = torch.zeros((3, 4), dtype=conditioning_dtype)

        self.register_buffer("positive_conditioning", positive)
        self.register_buffer("negative_conditioning", negative)
|
|
|
|
|
|
class _ModelInner:
    """Plain holder that exposes the wrapped object as `.diffusion_model`."""

    def __init__(self, diffusion_model):
        # Stored as-is; no copy or validation is performed.
        self.diffusion_model = diffusion_model
|
|
|
|
|
|
class _ModelPatcher:
    """ModelPatcher stand-in: `.model` is a `_ModelInner` around the given model.

    The wrapped object is therefore reachable as `.model.diffusion_model`.
    """

    def __init__(self, diffusion_model):
        inner = _ModelInner(diffusion_model)
        self.model = inner
|
|
|
|
|
|
def test_seedvr2_conditioning_schema_exposes_model_passthrough_output():
    """Pin the SeedVR2Conditioning schema: input ids and output display names.

    The first output being "model" is what the test name calls the model
    passthrough output.
    """
    module, undo = _import_nodes_seedvr_isolated()
    try:
        schema = module.SeedVR2Conditioning.define_schema()

        input_ids = [entry.id for entry in schema.inputs]
        assert input_ids == ["model", "vae_conditioning"]

        assert schema.inputs[1].display_name == "LATENT"

        output_names = [entry.display_name for entry in schema.outputs]
        assert output_names == ["model", "positive", "negative", "latent"]
    finally:
        # Always undo the sys.modules patching, even when an assert fails.
        undo()
|
|
|
|
|
|
def test_seedvr2_conditioning_returns_packed_input_latent_deterministically():
    """Two identical execute() calls must both match hand-computed tensors.

    The expected latent is the (1, 2, 3, 2, 2) input reshaped to
    (1, 6, 2, 2). The expected condition is the input moved to channel-last,
    extended with one trailing channel of ones, moved back to channel-first
    and reshaped to (1, 9, 2, 2). Both the positive and the negative output
    are compared against that same condition tensor.
    """
    module, undo = _import_nodes_seedvr_isolated()
    try:
        patcher = _ModelPatcher(_DiffusionModel())
        samples = torch.arange(1, 25, dtype=torch.float32).reshape(1, 2, 3, 2, 2)
        vae_conditioning = {"samples": samples}

        # Run the node twice on the very same inputs; the first returned
        # element is not inspected by this test.
        outputs = []
        for _ in range(2):
            _, positive, negative, latent = module.SeedVR2Conditioning.execute(
                patcher,
                vae_conditioning,
            )
            outputs.append((positive, negative, latent))

        expected_latent = samples.reshape(1, 6, 2, 2)

        channel_last = samples.movedim(1, -1).contiguous()
        ones_channel = torch.ones((*channel_last.shape[:-1], 1))
        extended = torch.cat([channel_last, ones_channel], dim=-1)
        expected_condition = extended.movedim(-1, 1).reshape(1, 9, 2, 2)

        for _, _, latent in outputs:
            assert torch.equal(latent["samples"], expected_latent)
        for positive, _, _ in outputs:
            assert torch.equal(positive[0][1]["condition"], expected_condition)
        for _, negative, _ in outputs:
            assert torch.equal(negative[0][1]["condition"], expected_condition)
    finally:
        # Always undo the sys.modules patching, even when an assert fails.
        undo()
|
|
|
|
|
|
def test_seedvr2_conditioning_fails_loud_on_zero_buffers():
    """execute() must raise RuntimeError when the stub's buffers are all zero.

    The error text has to begin with the module's
    `_SEEDVR2_INVALID_MODEL_MSG_PREFIX` and name both conditioning buffers.
    """
    module, undo = _import_nodes_seedvr_isolated()
    try:
        patcher = _ModelPatcher(_DiffusionModel(zero_conditioning=True))
        vae_conditioning = {"samples": torch.zeros((1, 2, 1, 1, 1))}

        with pytest.raises(RuntimeError) as excinfo:
            module.SeedVR2Conditioning.execute(patcher, vae_conditioning)

        message = str(excinfo.value)
        expected_prefix = module._SEEDVR2_INVALID_MODEL_MSG_PREFIX
        assert message.startswith(expected_prefix), (
            "Fail-loud message must use the standard "
            "_SEEDVR2_INVALID_MODEL_MSG_PREFIX so callers/log scrapers "
            f"can match it. Got: {message!r}"
        )
        for buffer_name in ("positive_conditioning", "negative_conditioning"):
            assert buffer_name in message
    finally:
        # Always undo the sys.modules patching, even when an assert fails.
        undo()
|