mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-24 00:39:30 +08:00
Add SeedVR2 node coverage
This commit is contained in:
parent
bed0cd2b8c
commit
7050bdc02b
213
tests-unit/comfy_extras_test/test_seedvr2_conditioning.py
Normal file
213
tests-unit/comfy_extras_test/test_seedvr2_conditioning.py
Normal file
@ -0,0 +1,213 @@
|
|||||||
|
"""Consolidated SeedVR2 conditioning and refactor regression tests.
|
||||||
|
|
||||||
|
Merges the prior test_seedvr2_refactor_nodes.py and
|
||||||
|
test_seedvr_conditioning_hardening.py modules. Refactor tests use the
|
||||||
|
top-level comfy_extras.nodes_seedvr import; conditioning-hardening tests
|
||||||
|
use _import_nodes_seedvr_isolated() for sys.modules isolation when
|
||||||
|
mocking comfy.model_management.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import sys
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from comfy.cli_args import args as cli_args
|
||||||
|
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
cli_args.cpu = True
|
||||||
|
|
||||||
|
|
||||||
|
_SENTINEL = object()
|
||||||
|
_TARGETS = (
|
||||||
|
("comfy.model_management", "comfy"),
|
||||||
|
("comfy_extras.nodes_seedvr", "comfy_extras"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _import_nodes_seedvr_isolated():
|
||||||
|
"""Import comfy_extras.nodes_seedvr with comfy.model_management mocked."""
|
||||||
|
priors = []
|
||||||
|
for mod_name, parent_name in _TARGETS:
|
||||||
|
prior_mod = sys.modules.get(mod_name, _SENTINEL)
|
||||||
|
parent = sys.modules.get(parent_name)
|
||||||
|
attr = mod_name.split(".")[-1]
|
||||||
|
prior_attr = (
|
||||||
|
getattr(parent, attr, _SENTINEL) if parent is not None else _SENTINEL
|
||||||
|
)
|
||||||
|
priors.append((mod_name, parent_name, attr, prior_mod, prior_attr))
|
||||||
|
|
||||||
|
mock_mm = MagicMock()
|
||||||
|
for fn in (
|
||||||
|
"xformers_enabled", "xformers_enabled_vae",
|
||||||
|
"pytorch_attention_enabled", "pytorch_attention_enabled_vae",
|
||||||
|
"sage_attention_enabled", "flash_attention_enabled",
|
||||||
|
"is_intel_xpu",
|
||||||
|
):
|
||||||
|
getattr(mock_mm, fn).return_value = False
|
||||||
|
tv = torch.version.__version__.split(".")
|
||||||
|
mock_mm.torch_version_numeric = (int(tv[0]), int(tv[1]))
|
||||||
|
mock_mm.WINDOWS = False
|
||||||
|
sys.modules["comfy.model_management"] = mock_mm
|
||||||
|
if sys.modules.get("comfy") is None:
|
||||||
|
import comfy as _comfy_pkg # noqa: F401
|
||||||
|
comfy_pkg = sys.modules.get("comfy")
|
||||||
|
if comfy_pkg is not None:
|
||||||
|
setattr(comfy_pkg, "model_management", mock_mm)
|
||||||
|
nodes_seedvr = sys.modules.get("comfy_extras.nodes_seedvr") or (
|
||||||
|
importlib.import_module("comfy_extras.nodes_seedvr")
|
||||||
|
)
|
||||||
|
|
||||||
|
def _restore():
|
||||||
|
for mod_name, parent_name, attr, prior_mod, prior_attr in priors:
|
||||||
|
if prior_mod is _SENTINEL:
|
||||||
|
sys.modules.pop(mod_name, None)
|
||||||
|
else:
|
||||||
|
sys.modules[mod_name] = prior_mod
|
||||||
|
parent = sys.modules.get(parent_name)
|
||||||
|
if parent is None:
|
||||||
|
continue
|
||||||
|
if prior_attr is _SENTINEL:
|
||||||
|
if hasattr(parent, attr):
|
||||||
|
delattr(parent, attr)
|
||||||
|
else:
|
||||||
|
setattr(parent, attr, prior_attr)
|
||||||
|
|
||||||
|
return nodes_seedvr, _restore
|
||||||
|
|
||||||
|
|
||||||
|
class _Rope(nn.Module):
|
||||||
|
"""Minimal RoPE stub exposing a `freqs` parameter."""
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.freqs = nn.Parameter(torch.zeros(4))
|
||||||
|
|
||||||
|
|
||||||
|
class _Block(nn.Module):
|
||||||
|
"""Minimal transformer block stub holding a `_Rope`."""
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.rope = _Rope()
|
||||||
|
|
||||||
|
|
||||||
|
class _DiffusionModel(nn.Module):
|
||||||
|
"""Stub diffusion model with N blocks and pos/neg conditioning buffers."""
|
||||||
|
def __init__(self, n_blocks=3, zero_conditioning=False, conditioning_dtype=torch.float32):
|
||||||
|
super().__init__()
|
||||||
|
self.blocks = nn.ModuleList([_Block() for _ in range(n_blocks)])
|
||||||
|
pos = torch.zeros if zero_conditioning else torch.ones
|
||||||
|
self.register_buffer("positive_conditioning", pos((2, 4), dtype=conditioning_dtype))
|
||||||
|
self.register_buffer("negative_conditioning", torch.zeros((3, 4), dtype=conditioning_dtype))
|
||||||
|
|
||||||
|
|
||||||
|
class _ModelInner:
|
||||||
|
"""Inner model wrapper exposing `.diffusion_model`."""
|
||||||
|
def __init__(self, diffusion_model):
|
||||||
|
self.diffusion_model = diffusion_model
|
||||||
|
|
||||||
|
|
||||||
|
class _ModelPatcher:
|
||||||
|
"""ModelPatcher stub exposing `.model._ModelInner`."""
|
||||||
|
def __init__(self, diffusion_model):
|
||||||
|
self.model = _ModelInner(diffusion_model)
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_conditioning_schema_exposes_model_passthrough_output():
|
||||||
|
nodes_seedvr, restore = _import_nodes_seedvr_isolated()
|
||||||
|
try:
|
||||||
|
schema = nodes_seedvr.SeedVR2Conditioning.define_schema()
|
||||||
|
assert [input_item.id for input_item in schema.inputs] == [
|
||||||
|
"model",
|
||||||
|
"vae_conditioning",
|
||||||
|
]
|
||||||
|
assert schema.inputs[1].display_name == "latent"
|
||||||
|
assert [output.display_name for output in schema.outputs] == [
|
||||||
|
"model",
|
||||||
|
"positive",
|
||||||
|
"negative",
|
||||||
|
"latent",
|
||||||
|
]
|
||||||
|
finally:
|
||||||
|
restore()
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_conditioning_returns_packed_input_latent_deterministically():
|
||||||
|
nodes_seedvr, restore = _import_nodes_seedvr_isolated()
|
||||||
|
try:
|
||||||
|
diffusion_model = _DiffusionModel()
|
||||||
|
patcher = _ModelPatcher(diffusion_model)
|
||||||
|
samples = torch.arange(1, 25, dtype=torch.float32).reshape(1, 2, 3, 2, 2)
|
||||||
|
vae_conditioning = {"samples": samples}
|
||||||
|
|
||||||
|
_, first_positive, first_negative, first_latent = (
|
||||||
|
nodes_seedvr.SeedVR2Conditioning.execute(
|
||||||
|
patcher,
|
||||||
|
vae_conditioning,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
_, second_positive, second_negative, second_latent = (
|
||||||
|
nodes_seedvr.SeedVR2Conditioning.execute(
|
||||||
|
patcher,
|
||||||
|
vae_conditioning,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
expected_latent = samples.reshape(1, 6, 2, 2)
|
||||||
|
channel_last = samples.movedim(1, -1).contiguous()
|
||||||
|
expected_condition = torch.cat(
|
||||||
|
[
|
||||||
|
channel_last,
|
||||||
|
torch.ones((*channel_last.shape[:-1], 1)),
|
||||||
|
],
|
||||||
|
dim=-1,
|
||||||
|
).movedim(-1, 1).reshape(1, 9, 2, 2)
|
||||||
|
|
||||||
|
assert torch.equal(first_latent["samples"], expected_latent)
|
||||||
|
assert torch.equal(second_latent["samples"], expected_latent)
|
||||||
|
assert torch.equal(
|
||||||
|
first_positive[0][1]["condition"],
|
||||||
|
expected_condition,
|
||||||
|
)
|
||||||
|
assert torch.equal(
|
||||||
|
second_positive[0][1]["condition"],
|
||||||
|
expected_condition,
|
||||||
|
)
|
||||||
|
assert torch.equal(
|
||||||
|
first_negative[0][1]["condition"],
|
||||||
|
expected_condition,
|
||||||
|
)
|
||||||
|
assert torch.equal(
|
||||||
|
second_negative[0][1]["condition"],
|
||||||
|
expected_condition,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
restore()
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_conditioning_fails_loud_on_zero_buffers():
|
||||||
|
nodes_seedvr, restore = _import_nodes_seedvr_isolated()
|
||||||
|
try:
|
||||||
|
diffusion_model = _DiffusionModel(zero_conditioning=True)
|
||||||
|
patcher = _ModelPatcher(diffusion_model)
|
||||||
|
vae_conditioning = {"samples": torch.zeros((1, 2, 1, 1, 1))}
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError) as excinfo:
|
||||||
|
nodes_seedvr.SeedVR2Conditioning.execute(
|
||||||
|
patcher, vae_conditioning,
|
||||||
|
)
|
||||||
|
|
||||||
|
message = str(excinfo.value)
|
||||||
|
assert message.startswith(
|
||||||
|
nodes_seedvr._SEEDVR2_INVALID_MODEL_MSG_PREFIX
|
||||||
|
), (
|
||||||
|
"Fail-loud message must use the standard "
|
||||||
|
"_SEEDVR2_INVALID_MODEL_MSG_PREFIX so callers/log scrapers "
|
||||||
|
f"can match it. Got: {message!r}"
|
||||||
|
)
|
||||||
|
assert "positive_conditioning" in message
|
||||||
|
assert "negative_conditioning" in message
|
||||||
|
finally:
|
||||||
|
restore()
|
||||||
55
tests-unit/comfy_extras_test/test_seedvr2_nodes.py
Normal file
55
tests-unit/comfy_extras_test/test_seedvr2_nodes.py
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
import importlib
|
||||||
|
import inspect
|
||||||
|
import sys
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from comfy.cli_args import args as cli_args
|
||||||
|
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
cli_args.cpu = True
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr_node_signature_matches_schema():
|
||||||
|
mock_mm = MagicMock()
|
||||||
|
mock_mm.xformers_enabled.return_value = False
|
||||||
|
mock_mm.xformers_enabled_vae.return_value = False
|
||||||
|
mock_mm.sage_attention_enabled.return_value = False
|
||||||
|
mock_mm.flash_attention_enabled.return_value = False
|
||||||
|
|
||||||
|
sentinel = object()
|
||||||
|
prior_cpu = cli_args.cpu
|
||||||
|
cli_args.cpu = True
|
||||||
|
prior_module = sys.modules.get("comfy_extras.nodes_seedvr", sentinel)
|
||||||
|
comfy_pkg = sys.modules.get("comfy")
|
||||||
|
prior_mm_attr = getattr(comfy_pkg, "model_management", sentinel) if comfy_pkg else sentinel
|
||||||
|
|
||||||
|
with patch.dict(sys.modules, {"comfy.model_management": mock_mm}):
|
||||||
|
if comfy_pkg is not None:
|
||||||
|
setattr(comfy_pkg, "model_management", mock_mm)
|
||||||
|
sys.modules.pop("comfy_extras.nodes_seedvr", None)
|
||||||
|
try:
|
||||||
|
nodes_seedvr = importlib.import_module("comfy_extras.nodes_seedvr")
|
||||||
|
for node_cls in (nodes_seedvr.SeedVR2Preprocess, nodes_seedvr.SeedVR2PostProcessing, nodes_seedvr.SeedVR2Conditioning, nodes_seedvr.SeedVR2ProgressiveSampler):
|
||||||
|
schema_ids = [i.id for i in node_cls.define_schema().inputs]
|
||||||
|
exec_params = [
|
||||||
|
p for p in inspect.signature(node_cls.execute).parameters.keys()
|
||||||
|
if p != "cls"
|
||||||
|
]
|
||||||
|
assert schema_ids == exec_params, (
|
||||||
|
f"{node_cls.__name__} schema/execute drift: "
|
||||||
|
f"schema_ids={schema_ids}, exec_params={exec_params}"
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
cli_args.cpu = prior_cpu
|
||||||
|
if prior_module is sentinel:
|
||||||
|
sys.modules.pop("comfy_extras.nodes_seedvr", None)
|
||||||
|
else:
|
||||||
|
sys.modules["comfy_extras.nodes_seedvr"] = prior_module
|
||||||
|
if comfy_pkg is not None:
|
||||||
|
if prior_mm_attr is sentinel:
|
||||||
|
if hasattr(comfy_pkg, "model_management"):
|
||||||
|
delattr(comfy_pkg, "model_management")
|
||||||
|
else:
|
||||||
|
setattr(comfy_pkg, "model_management", prior_mm_attr)
|
||||||
57
tests-unit/comfy_extras_test/test_seedvr2_post_processing.py
Normal file
57
tests-unit/comfy_extras_test/test_seedvr2_post_processing.py
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from comfy.cli_args import args as cli_args
|
||||||
|
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
cli_args.cpu = True
|
||||||
|
|
||||||
|
from comfy_extras import nodes_seedvr # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
|
def _schema_ids(items):
|
||||||
|
return [item.id for item in items]
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_post_processing_schema():
|
||||||
|
schema = nodes_seedvr.SeedVR2PostProcessing.define_schema()
|
||||||
|
|
||||||
|
assert _schema_ids(schema.inputs) == ["images", "original_resized_images", "color_correction_method"]
|
||||||
|
assert schema.inputs[2].options == ["lab", "wavelet", "adain", "none"]
|
||||||
|
assert schema.inputs[2].default == "lab"
|
||||||
|
assert schema.outputs[0].get_io_type() == "IMAGE"
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_post_processing_oom_error_uses_color_correction_method(monkeypatch):
|
||||||
|
decoded = torch.full((1, 3, 4, 4), 0.25)
|
||||||
|
reference = torch.full((1, 3, 4, 4), 0.75)
|
||||||
|
|
||||||
|
def _lab(content, style):
|
||||||
|
raise torch.cuda.OutOfMemoryError("CUDA out of memory")
|
||||||
|
|
||||||
|
monkeypatch.setattr(nodes_seedvr.comfy.model_management, "vae_device", lambda: torch.device("cpu"))
|
||||||
|
monkeypatch.setattr(nodes_seedvr.comfy.model_management, "get_free_memory", lambda device: 1_000_000)
|
||||||
|
monkeypatch.setattr(nodes_seedvr.comfy.model_management, "soft_empty_cache", lambda: None)
|
||||||
|
|
||||||
|
with patch.object(nodes_seedvr, "lab_color_transfer", _lab):
|
||||||
|
try:
|
||||||
|
nodes_seedvr.SeedVR2PostProcessing._color_transfer_chunked(
|
||||||
|
decoded, reference, torch.device("cpu"), "lab",
|
||||||
|
)
|
||||||
|
except RuntimeError as exc:
|
||||||
|
assert "color_correction_method=lab" in str(exc)
|
||||||
|
assert " method=lab" not in str(exc)
|
||||||
|
else:
|
||||||
|
raise AssertionError("expected RuntimeError for one-frame LAB OOM")
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_post_processing_unknown_color_correction_method_raises():
|
||||||
|
decoded = torch.zeros(1, 2, 4, 4, 3)
|
||||||
|
original = torch.zeros(1, 2, 4, 4, 3)
|
||||||
|
try:
|
||||||
|
nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, "bogus")
|
||||||
|
except ValueError as exc:
|
||||||
|
assert "color_correction_method" in str(exc)
|
||||||
|
else:
|
||||||
|
raise AssertionError("expected ValueError for unknown color_correction_method")
|
||||||
Loading…
Reference in New Issue
Block a user