mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-30 19:07:25 +08:00
- Reduce SeedVR2 test coverage to production unit tests
- Route SeedVR2 7B through Comfy varlength attention
- Disable SeedVR2 RoPE cache reuse after the upstream DynamicVRAM change
61 lines
2.2 KiB
Python
61 lines
2.2 KiB
Python
from unittest.mock import patch
|
|
|
|
import torch
|
|
|
|
from comfy.cli_args import args as cli_args
|
|
|
|
# Fall back to comfy's CPU mode when no CUDA device is available.
# This flag is set BEFORE importing comfy_extras below (hence the E402
# suppression on that import); presumably comfy reads the CLI args while its
# modules are imported, so setting it later would have no effect -- TODO confirm.
if not torch.cuda.is_available():
    cli_args.cpu = True

from comfy_extras import nodes_seedvr  # noqa: E402
|
|
|
|
|
|
def _schema_ids(items):
    """Return the ``id`` attribute of every entry in *items*, preserving order."""
    return [entry.id for entry in items]
|
|
|
|
|
|
def test_seedvr2_post_processing_schema():
    """Pin the public schema of ``SeedVR2PostProcessing``.

    Checks the exact input order, the widget settings of the
    ``upscaled_shorter_edge`` and ``color_correction_method`` inputs, and the
    IO type of the first output.
    """
    schema = nodes_seedvr.SeedVR2PostProcessing.define_schema()

    # The order check comes first: once it passes, the ids are known to be
    # unique, so they can safely be used as lookup keys below.
    assert _schema_ids(schema.inputs) == ["decoded", "original_image", "upscaled_shorter_edge", "color_correction_method"]
    inputs_by_id = {item.id: item for item in schema.inputs}

    shorter_edge = inputs_by_id["upscaled_shorter_edge"]
    assert shorter_edge.default is None
    assert shorter_edge.min == 2
    assert shorter_edge.force_input is True

    correction = inputs_by_id["color_correction_method"]
    assert correction.options == ["lab", "wavelet", "adain", "none"]
    assert correction.default == "lab"

    first_output = schema.outputs[0]
    assert first_output.get_io_type() == "IMAGE"
|
|
|
|
|
|
def test_seedvr2_post_processing_oom_error_uses_color_correction_method(monkeypatch):
    """A LAB colour transfer that OOMs on a one-frame chunk must surface as a
    RuntimeError whose message spells the option ``color_correction_method=lab``
    (and does not contain the bare `` method=lab`` spelling).

    The patched transfer records each call, so the test also proves that the
    RuntimeError really came from the OOM path. Without this, a RuntimeError
    raised before ``lab_color_transfer`` is ever reached would satisfy the test.
    """
    decoded = torch.full((1, 3, 4, 4), 0.25)
    reference = torch.full((1, 3, 4, 4), 0.75)
    lab_calls = []

    def _lab(content, style):
        # Record the invocation, then simulate a CUDA OOM from the transfer.
        lab_calls.append((content, style))
        raise torch.cuda.OutOfMemoryError("CUDA out of memory")

    # Keep the VAE device on CPU and stub the free-memory query / cache flush,
    # so the test does not depend on the host's real device state.
    monkeypatch.setattr(nodes_seedvr.comfy.model_management, "vae_device", lambda: torch.device("cpu"))
    monkeypatch.setattr(nodes_seedvr.comfy.model_management, "get_free_memory", lambda device: 1_000_000)
    monkeypatch.setattr(nodes_seedvr.comfy.model_management, "soft_empty_cache", lambda: None)

    with patch.object(nodes_seedvr, "lab_color_transfer", _lab):
        try:
            nodes_seedvr.SeedVR2PostProcessing._color_transfer_chunked(
                decoded, reference, torch.device("cpu"), "lab",
            )
        except RuntimeError as exc:
            message = str(exc)
            assert "color_correction_method=lab" in message
            assert " method=lab" not in message
            # The error must originate from the (patched) LAB transfer OOM.
            assert lab_calls, "RuntimeError was raised before lab_color_transfer was ever called"
        else:
            raise AssertionError("expected RuntimeError for one-frame LAB OOM")
|
|
|
|
|
|
def test_seedvr2_post_processing_unknown_color_correction_method_raises():
    """An unrecognised ``color_correction_method`` value must be rejected with a
    ValueError whose message names the ``color_correction_method`` option."""
    frames = torch.zeros(1, 2, 4, 4, 3)
    source_frames = torch.zeros(1, 2, 4, 4, 3)

    caught = None
    try:
        nodes_seedvr.SeedVR2PostProcessing.execute(frames, source_frames, 4, "bogus")
    except ValueError as err:
        # Keep a reference: ``err`` is unbound when the except clause ends.
        caught = err

    if caught is None:
        raise AssertionError("expected ValueError for unknown color_correction_method")
    assert "color_correction_method" in str(caught)
|