mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-28 01:47:32 +08:00
Add SeedVR2 node and sampler coverage
This commit is contained in:
parent
c3bfb743e8
commit
8ac1b59107
58
tests-unit/comfy_extras_test/test_seedvr2_node_boundaries.py
Normal file
58
tests-unit/comfy_extras_test/test_seedvr2_node_boundaries.py
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
import ast
|
||||||
|
import inspect
|
||||||
|
import textwrap
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from comfy.cli_args import args as cli_args
|
||||||
|
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
cli_args.cpu = True
|
||||||
|
|
||||||
|
from comfy_extras import nodes_seedvr # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
|
def _schema_ids(items):
|
||||||
|
return [item.id for item in items]
|
||||||
|
|
||||||
|
|
||||||
|
def test_resize_schemas_are_preprocess_only():
|
||||||
|
simple = nodes_seedvr.SeedVR2Resize.define_schema()
|
||||||
|
advanced = nodes_seedvr.SeedVR2ResizeAdvanced.define_schema()
|
||||||
|
|
||||||
|
assert _schema_ids(simple.inputs) == ["images", "multiplier"]
|
||||||
|
assert _schema_ids(simple.outputs) == ["input_pixels", "original_image", "upscaled_shorter_edge"]
|
||||||
|
assert simple.outputs[0].get_io_type() == "IMAGE"
|
||||||
|
|
||||||
|
assert _schema_ids(advanced.inputs) == ["images", "shorter_edge"]
|
||||||
|
assert _schema_ids(advanced.outputs) == ["input_pixels", "original_image", "upscaled_shorter_edge"]
|
||||||
|
assert advanced.outputs[0].get_io_type() == "IMAGE"
|
||||||
|
|
||||||
|
|
||||||
|
def test_resize_nodes_do_not_call_encode_decode_or_color_transfer():
|
||||||
|
source = "\n".join(
|
||||||
|
[
|
||||||
|
inspect.getsource(nodes_seedvr.SeedVR2Resize.execute),
|
||||||
|
inspect.getsource(nodes_seedvr.SeedVR2ResizeAdvanced.execute),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
tree = ast.parse(textwrap.dedent(source))
|
||||||
|
forbidden_names = {
|
||||||
|
"encode",
|
||||||
|
"encode_tiled",
|
||||||
|
"decode",
|
||||||
|
"decode_tiled",
|
||||||
|
"tiled_vae",
|
||||||
|
"lab_color_transfer",
|
||||||
|
}
|
||||||
|
|
||||||
|
for node in ast.walk(tree):
|
||||||
|
if isinstance(node, ast.Call):
|
||||||
|
func = node.func
|
||||||
|
if isinstance(func, ast.Name):
|
||||||
|
name = func.id
|
||||||
|
elif isinstance(func, ast.Attribute):
|
||||||
|
name = func.attr
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
assert name not in forbidden_names
|
||||||
461
tests-unit/comfy_extras_test/test_seedvr2_post_processing.py
Normal file
461
tests-unit/comfy_extras_test/test_seedvr2_post_processing.py
Normal file
@ -0,0 +1,461 @@
|
|||||||
|
import inspect
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from comfy.cli_args import args as cli_args
|
||||||
|
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
cli_args.cpu = True
|
||||||
|
|
||||||
|
from comfy_extras import nodes_seedvr # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
|
def _schema_ids(items):
|
||||||
|
return [item.id for item in items]
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_post_processing_schema():
|
||||||
|
schema = nodes_seedvr.SeedVR2PostProcessing.define_schema()
|
||||||
|
|
||||||
|
assert _schema_ids(schema.inputs) == ["decoded", "original_image", "upscaled_shorter_edge", "color_correction_method"]
|
||||||
|
assert schema.inputs[2].default is None
|
||||||
|
assert schema.inputs[2].min == 2
|
||||||
|
assert schema.inputs[2].force_input is True
|
||||||
|
assert schema.inputs[3].options == ["lab", "wavelet", "adain", "none"]
|
||||||
|
assert schema.inputs[3].default == "lab"
|
||||||
|
assert schema.outputs[0].get_io_type() == "IMAGE"
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_post_processing_color_correction_memory_multipliers_are_named():
|
||||||
|
assert nodes_seedvr.LAB_SCALE_MULTIPLIER == 13
|
||||||
|
assert nodes_seedvr.WAVELET_SCALE_MULTIPLIER == 10
|
||||||
|
assert nodes_seedvr.ADAIN_SCALE_MULTIPLIER == 6
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_post_processing_lab_autochunks_from_memory_estimate(monkeypatch):
|
||||||
|
decoded = torch.full((1, 5, 2, 2, 3), 0.25)
|
||||||
|
original = torch.full((1, 5, 2, 2, 3), 0.75)
|
||||||
|
calls = []
|
||||||
|
|
||||||
|
def _lab(content, style):
|
||||||
|
calls.append(content.shape[0])
|
||||||
|
return content
|
||||||
|
|
||||||
|
monkeypatch.setattr(nodes_seedvr.comfy.model_management, "vae_device", lambda: torch.device("cpu"))
|
||||||
|
monkeypatch.setattr(nodes_seedvr.comfy.model_management, "get_free_memory", lambda device: 1700)
|
||||||
|
|
||||||
|
with patch.object(nodes_seedvr, "lab_color_transfer", _lab):
|
||||||
|
output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 2, "lab").result[0]
|
||||||
|
|
||||||
|
assert calls == [1, 1, 1, 1, 1]
|
||||||
|
assert tuple(output.shape) == (1, 5, 2, 2, 3)
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_post_processing_lab_runs_each_frame_independently(monkeypatch):
|
||||||
|
decoded = torch.full((1, 4, 2, 2, 3), 0.25)
|
||||||
|
original = torch.full((1, 4, 2, 2, 3), 0.75)
|
||||||
|
calls = []
|
||||||
|
|
||||||
|
def _lab(content, style):
|
||||||
|
calls.append(content.shape[0])
|
||||||
|
return content
|
||||||
|
|
||||||
|
monkeypatch.setattr(nodes_seedvr.comfy.model_management, "vae_device", lambda: torch.device("cpu"))
|
||||||
|
monkeypatch.setattr(nodes_seedvr.comfy.model_management, "get_free_memory", lambda device: 1_000_000)
|
||||||
|
|
||||||
|
with patch.object(nodes_seedvr, "lab_color_transfer", _lab):
|
||||||
|
output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 2, "lab").result[0]
|
||||||
|
|
||||||
|
assert calls == [1, 1, 1, 1]
|
||||||
|
assert tuple(output.shape) == (1, 4, 2, 2, 3)
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_post_processing_lab_derives_reference_from_original_and_upscaled_shorter_edge():
|
||||||
|
decoded = torch.full((1, 3, 9, 11, 3), 0.25)
|
||||||
|
original = torch.full((1, 2, 16, 20, 3), 0.75)
|
||||||
|
calls = []
|
||||||
|
|
||||||
|
def _lab(content, style):
|
||||||
|
calls.append((content.clone(), style.clone()))
|
||||||
|
return torch.zeros_like(content)
|
||||||
|
|
||||||
|
with patch.object(nodes_seedvr, "lab_color_transfer", _lab):
|
||||||
|
output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 8, "lab").result[0]
|
||||||
|
|
||||||
|
assert tuple(output.shape) == (1, 2, 8, 10, 3)
|
||||||
|
assert torch.equal(output, torch.full_like(output, 0.5))
|
||||||
|
assert len(calls) == 2
|
||||||
|
assert calls[0][0].shape == (1, 3, 8, 10)
|
||||||
|
assert calls[0][1].shape == (1, 3, 8, 10)
|
||||||
|
assert torch.equal(calls[0][0], torch.full_like(calls[0][0], -0.5))
|
||||||
|
assert torch.allclose(calls[0][1], torch.full_like(calls[0][1], 0.5))
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_post_processing_lab_runs_color_transfer_on_vae_device():
|
||||||
|
source = inspect.getsource(nodes_seedvr.SeedVR2PostProcessing.execute)
|
||||||
|
chunk_source = inspect.getsource(nodes_seedvr.SeedVR2PostProcessing._run_color_transfer_chunks)
|
||||||
|
helper_source = inspect.getsource(nodes_seedvr.SeedVR2PostProcessing._lab_color_transfer_on_vae_device)
|
||||||
|
|
||||||
|
assert "_color_transfer_chunked" in source
|
||||||
|
assert "_lab_color_transfer_on_vae_device" in chunk_source
|
||||||
|
assert "torch.cat" not in chunk_source
|
||||||
|
assert "torch.empty" in chunk_source
|
||||||
|
assert ".copy_(" in chunk_source
|
||||||
|
assert "reference_5d.to(device=decoded_5d.device)" not in source
|
||||||
|
assert "comfy.model_management.vae_device()" in helper_source
|
||||||
|
assert ".to(device=color_device)" in helper_source
|
||||||
|
assert ".to(device=output_device)" in helper_source
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_post_processing_lab_chunking_is_frame_independent(monkeypatch):
|
||||||
|
decoded = torch.linspace(-0.9, 0.9, 3 * 3 * 24 * 24).reshape(3, 3, 24, 24)
|
||||||
|
reference = torch.linspace(0.8, -0.8, 3 * 3 * 24 * 24).reshape(3, 3, 24, 24)
|
||||||
|
|
||||||
|
monkeypatch.setattr(nodes_seedvr.comfy.model_management, "vae_device", lambda: torch.device("cpu"))
|
||||||
|
|
||||||
|
one_frame = nodes_seedvr.SeedVR2PostProcessing._run_color_transfer_chunks(
|
||||||
|
decoded.clone(), reference.clone(), torch.device("cpu"), "lab", 1,
|
||||||
|
)
|
||||||
|
multi_frame = nodes_seedvr.SeedVR2PostProcessing._run_color_transfer_chunks(
|
||||||
|
decoded.clone(), reference.clone(), torch.device("cpu"), "lab", 3,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert torch.equal(one_frame, multi_frame)
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_post_processing_lab_retry_does_not_mutate_reference(monkeypatch):
|
||||||
|
decoded = torch.full((2, 3, 4, 4), 0.25)
|
||||||
|
reference = torch.full((2, 3, 4, 4), 0.75)
|
||||||
|
original_reference = reference.clone()
|
||||||
|
calls = []
|
||||||
|
cache_clears = []
|
||||||
|
|
||||||
|
def _lab(content, style):
|
||||||
|
calls.append((content.clone(), style.clone()))
|
||||||
|
style.add_(10.0)
|
||||||
|
if len(calls) == 1:
|
||||||
|
raise torch.cuda.OutOfMemoryError("CUDA out of memory")
|
||||||
|
return content
|
||||||
|
|
||||||
|
monkeypatch.setattr(nodes_seedvr.comfy.model_management, "vae_device", lambda: torch.device("cpu"))
|
||||||
|
monkeypatch.setattr(nodes_seedvr.comfy.model_management, "get_free_memory", lambda device: 1_000_000)
|
||||||
|
monkeypatch.setattr(nodes_seedvr.comfy.model_management, "soft_empty_cache", lambda: cache_clears.append(True))
|
||||||
|
|
||||||
|
with patch.object(nodes_seedvr, "lab_color_transfer", _lab):
|
||||||
|
nodes_seedvr.SeedVR2PostProcessing._color_transfer_chunked(
|
||||||
|
decoded, reference, torch.device("cpu"), "lab",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(cache_clears) == 1
|
||||||
|
assert torch.equal(reference, original_reference)
|
||||||
|
assert torch.equal(calls[1][1], original_reference[0:1])
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_post_processing_oom_error_uses_color_correction_method(monkeypatch):
|
||||||
|
decoded = torch.full((1, 3, 4, 4), 0.25)
|
||||||
|
reference = torch.full((1, 3, 4, 4), 0.75)
|
||||||
|
|
||||||
|
def _lab(content, style):
|
||||||
|
raise torch.cuda.OutOfMemoryError("CUDA out of memory")
|
||||||
|
|
||||||
|
monkeypatch.setattr(nodes_seedvr.comfy.model_management, "vae_device", lambda: torch.device("cpu"))
|
||||||
|
monkeypatch.setattr(nodes_seedvr.comfy.model_management, "get_free_memory", lambda device: 1_000_000)
|
||||||
|
monkeypatch.setattr(nodes_seedvr.comfy.model_management, "soft_empty_cache", lambda: None)
|
||||||
|
|
||||||
|
with patch.object(nodes_seedvr, "lab_color_transfer", _lab):
|
||||||
|
try:
|
||||||
|
nodes_seedvr.SeedVR2PostProcessing._color_transfer_chunked(
|
||||||
|
decoded, reference, torch.device("cpu"), "lab",
|
||||||
|
)
|
||||||
|
except RuntimeError as exc:
|
||||||
|
assert "color_correction_method=lab" in str(exc)
|
||||||
|
assert " method=lab" not in str(exc)
|
||||||
|
else:
|
||||||
|
raise AssertionError("expected RuntimeError for one-frame LAB OOM")
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_post_processing_raw_conversion_does_not_probe_full_tensor_range():
|
||||||
|
source = inspect.getsource(nodes_seedvr.SeedVR2PostProcessing._to_seedvr2_raw)
|
||||||
|
|
||||||
|
assert ".amin" not in source
|
||||||
|
assert ".item" not in source
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_post_processing_none_does_not_resize_reference_pixels():
|
||||||
|
decoded = torch.full((1, 2, 10, 12, 3), 0.25)
|
||||||
|
original = torch.full((1, 2, 16, 20, 3), 0.75)
|
||||||
|
|
||||||
|
with patch.object(nodes_seedvr, "side_resize") as resize:
|
||||||
|
output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 8, "none").result[0]
|
||||||
|
|
||||||
|
resize.assert_not_called()
|
||||||
|
assert tuple(output.shape) == (1, 2, 8, 10, 3)
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_post_processing_rejects_invalid_upscaled_shorter_edge():
|
||||||
|
decoded = torch.full((1, 2, 10, 12, 3), 0.25)
|
||||||
|
original = torch.full((1, 2, 16, 20, 3), 0.75)
|
||||||
|
|
||||||
|
for edge in (None, 1, 1.5):
|
||||||
|
try:
|
||||||
|
nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, edge, "none")
|
||||||
|
except ValueError as exc:
|
||||||
|
assert "upscaled_shorter_edge" in str(exc)
|
||||||
|
else:
|
||||||
|
raise AssertionError(f"expected ValueError for upscaled_shorter_edge={edge!r}")
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_post_processing_lab_resizes_full_reference_frame():
|
||||||
|
decoded = torch.full((1, 2, 4, 5, 3), 0.25)
|
||||||
|
original = torch.full((1, 2, 16, 20, 3), 0.75)
|
||||||
|
resize_calls = []
|
||||||
|
lab_calls = []
|
||||||
|
|
||||||
|
def _resize(images, size, interpolation=None, antialias=None):
|
||||||
|
resize_calls.append((images.clone(), size, interpolation, antialias))
|
||||||
|
if isinstance(size, int):
|
||||||
|
return torch.full((2, 3, size, round(images.shape[-1] * size / images.shape[-2])), 0.5)
|
||||||
|
return torch.full((2, 3, size[0], size[1]), 0.5)
|
||||||
|
|
||||||
|
def _lab(content, style):
|
||||||
|
lab_calls.append((content.clone(), style.clone()))
|
||||||
|
return torch.zeros_like(content)
|
||||||
|
|
||||||
|
with patch.object(nodes_seedvr.TVF, "resize", _resize):
|
||||||
|
with patch.object(nodes_seedvr, "lab_color_transfer", _lab):
|
||||||
|
output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 8, "lab").result[0]
|
||||||
|
|
||||||
|
assert tuple(output.shape) == (1, 2, 4, 4, 3)
|
||||||
|
assert torch.equal(output, torch.full_like(output, 0.5))
|
||||||
|
assert resize_calls[0][0].shape == (2, 3, 16, 20)
|
||||||
|
assert resize_calls[0][1] == 8
|
||||||
|
assert resize_calls[1][0].shape == (2, 3, 8, 10)
|
||||||
|
assert resize_calls[1][1] == (4, 5)
|
||||||
|
assert len(lab_calls) == 2
|
||||||
|
assert lab_calls[0][1].shape == (1, 3, 4, 5)
|
||||||
|
assert torch.equal(lab_calls[0][1], torch.zeros_like(lab_calls[0][1]))
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_post_processing_none_trims_and_crops_without_color_correction():
|
||||||
|
decoded = torch.arange(1 * 3 * 9 * 11 * 3, dtype=torch.float32).reshape(1, 3, 9, 11, 3)
|
||||||
|
original = torch.zeros(1, 2, 16, 20, 3)
|
||||||
|
|
||||||
|
with patch.object(nodes_seedvr, "lab_color_transfer") as lab:
|
||||||
|
output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 8, "none").result[0]
|
||||||
|
|
||||||
|
assert lab.call_count == 0
|
||||||
|
assert tuple(output.shape) == (1, 2, 8, 10, 3)
|
||||||
|
assert torch.equal(output, decoded[:, :2, :8, :10, :])
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_post_processing_restores_flattened_padded_batches_before_trimming():
|
||||||
|
decoded = torch.arange(10 * 4 * 6 * 1, dtype=torch.float32).reshape(10, 4, 6, 1)
|
||||||
|
original = torch.zeros(2, 2, 4, 6, 1)
|
||||||
|
|
||||||
|
output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 4, "none").result[0]
|
||||||
|
|
||||||
|
expected = torch.cat((decoded[0:2], decoded[5:7]), dim=0)
|
||||||
|
assert tuple(output.shape) == (4, 4, 6, 1)
|
||||||
|
assert torch.equal(output, expected)
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_post_processing_none_preserves_decoded_spatial_size_when_reference_is_larger():
|
||||||
|
decoded = torch.arange(1 * 3 * 8 * 10 * 3, dtype=torch.float32).reshape(1, 3, 8, 10, 3)
|
||||||
|
original = torch.zeros(1, 2, 16, 20, 3)
|
||||||
|
|
||||||
|
with patch.object(nodes_seedvr, "lab_color_transfer") as lab:
|
||||||
|
output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 16, "none").result[0]
|
||||||
|
|
||||||
|
assert lab.call_count == 0
|
||||||
|
assert tuple(output.shape) == (1, 2, 8, 10, 3)
|
||||||
|
assert torch.equal(output, decoded[:, :2, :, :, :])
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_post_processing_crops_to_reference_tensor_when_reference_is_smaller():
|
||||||
|
decoded = torch.ones((1, 1, 720, 1280, 3), dtype=torch.float32)
|
||||||
|
original = torch.ones((1, 1, 360, 640, 3), dtype=torch.float32)
|
||||||
|
|
||||||
|
output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 360, "none").result[0]
|
||||||
|
|
||||||
|
assert tuple(output.shape) == (1, 1, 360, 640, 3)
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_post_processing_uses_decoded_size_when_reference_is_larger():
|
||||||
|
decoded = torch.ones((1, 1, 128, 160, 3), dtype=torch.float32)
|
||||||
|
original = torch.ones((1, 1, 480, 640, 3), dtype=torch.float32)
|
||||||
|
|
||||||
|
output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 480, "none").result[0]
|
||||||
|
|
||||||
|
assert tuple(output.shape) == (1, 1, 128, 160, 3)
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_post_processing_derives_crop_from_upscaled_shorter_edge():
|
||||||
|
decoded = torch.ones((1, 1, 128, 224, 3), dtype=torch.float32)
|
||||||
|
original = torch.ones((1, 1, 1080, 1920, 3), dtype=torch.float32)
|
||||||
|
|
||||||
|
output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 120, "none").result[0]
|
||||||
|
|
||||||
|
assert tuple(output.shape) == (1, 1, 120, 212, 3)
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_post_processing_uses_even_crop_from_odd_resized_width():
|
||||||
|
decoded = torch.ones((1, 1, 128, 256, 3), dtype=torch.float32)
|
||||||
|
original = torch.ones((1, 1, 120, 169, 3), dtype=torch.float32)
|
||||||
|
|
||||||
|
output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 120, "none").result[0]
|
||||||
|
|
||||||
|
assert tuple(output.shape) == (1, 1, 120, 168, 3)
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_post_processing_none_preserves_black_bottom_row_content():
|
||||||
|
decoded = torch.ones((1, 2, 8, 10, 3), dtype=torch.float32)
|
||||||
|
original = torch.ones((1, 2, 8, 10, 3), dtype=torch.float32)
|
||||||
|
original[:, :, -1, :, :] = -1.0
|
||||||
|
|
||||||
|
output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 8, "none").result[0]
|
||||||
|
|
||||||
|
assert tuple(output.shape) == (1, 2, 8, 10, 3)
|
||||||
|
assert torch.equal(output, decoded)
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_post_processing_none_preserves_black_right_column_content():
|
||||||
|
decoded = torch.ones((1, 2, 8, 10, 3), dtype=torch.float32)
|
||||||
|
original = torch.ones((1, 2, 8, 10, 3), dtype=torch.float32)
|
||||||
|
original[:, :, :, -1, :] = -1.0
|
||||||
|
|
||||||
|
output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 8, "none").result[0]
|
||||||
|
|
||||||
|
assert tuple(output.shape) == (1, 2, 8, 10, 3)
|
||||||
|
assert torch.equal(output, decoded)
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_post_processing_wavelet_dispatch_routes_through_wavelet_color_transfer():
|
||||||
|
decoded = torch.full((1, 3, 9, 11, 3), 0.25)
|
||||||
|
original = torch.full((1, 2, 16, 20, 3), 0.75)
|
||||||
|
wavelet_calls = []
|
||||||
|
lab_calls = []
|
||||||
|
|
||||||
|
def _wavelet(content, style):
|
||||||
|
wavelet_calls.append((content.clone(), style.clone()))
|
||||||
|
return torch.zeros_like(content)
|
||||||
|
|
||||||
|
def _lab(content, style):
|
||||||
|
lab_calls.append((content.clone(), style.clone()))
|
||||||
|
return torch.zeros_like(content)
|
||||||
|
|
||||||
|
with patch.object(nodes_seedvr, "wavelet_color_transfer", _wavelet):
|
||||||
|
with patch.object(nodes_seedvr, "lab_color_transfer", _lab):
|
||||||
|
output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 8, "wavelet").result[0]
|
||||||
|
|
||||||
|
assert len(wavelet_calls) == 1
|
||||||
|
assert len(lab_calls) == 0
|
||||||
|
assert tuple(output.shape) == (1, 2, 8, 10, 3)
|
||||||
|
assert torch.equal(output, torch.full_like(output, 0.5))
|
||||||
|
assert wavelet_calls[0][0].shape == (2, 3, 8, 10)
|
||||||
|
assert wavelet_calls[0][1].shape == (2, 3, 8, 10)
|
||||||
|
assert torch.equal(wavelet_calls[0][0], torch.full_like(wavelet_calls[0][0], -0.5))
|
||||||
|
assert torch.allclose(wavelet_calls[0][1], torch.full_like(wavelet_calls[0][1], 0.5))
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_post_processing_adain_dispatch_routes_through_adain_color_transfer():
|
||||||
|
decoded = torch.full((1, 3, 9, 11, 3), 0.25)
|
||||||
|
original = torch.full((1, 2, 16, 20, 3), 0.75)
|
||||||
|
adain_calls = []
|
||||||
|
lab_calls = []
|
||||||
|
|
||||||
|
def _adain(content, style):
|
||||||
|
adain_calls.append((content.clone(), style.clone()))
|
||||||
|
return torch.zeros_like(content)
|
||||||
|
|
||||||
|
def _lab(content, style):
|
||||||
|
lab_calls.append((content.clone(), style.clone()))
|
||||||
|
return torch.zeros_like(content)
|
||||||
|
|
||||||
|
with patch.object(nodes_seedvr, "adain_color_transfer", _adain):
|
||||||
|
with patch.object(nodes_seedvr, "lab_color_transfer", _lab):
|
||||||
|
output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 8, "adain").result[0]
|
||||||
|
|
||||||
|
assert len(adain_calls) == 1
|
||||||
|
assert len(lab_calls) == 0
|
||||||
|
assert tuple(output.shape) == (1, 2, 8, 10, 3)
|
||||||
|
assert torch.equal(output, torch.full_like(output, 0.5))
|
||||||
|
assert adain_calls[0][0].shape == (2, 3, 8, 10)
|
||||||
|
assert adain_calls[0][1].shape == (2, 3, 8, 10)
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_color_transfer_helper_runs_on_vae_device():
|
||||||
|
import inspect as _inspect
|
||||||
|
helper_source = _inspect.getsource(nodes_seedvr.SeedVR2PostProcessing._color_transfer_on_vae_device)
|
||||||
|
assert "comfy.model_management.vae_device()" in helper_source
|
||||||
|
assert ".to(device=color_device)" in helper_source
|
||||||
|
assert ".to(device=output_device)" in helper_source
|
||||||
|
assert "transfer_fn" in helper_source
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_wavelet_color_transfer_matches_primary_source_reconstruction():
|
||||||
|
from comfy.ldm.seedvr import vae as seedvr_vae
|
||||||
|
torch.manual_seed(0)
|
||||||
|
content = torch.rand(1, 3, 12, 16) * 2.0 - 1.0
|
||||||
|
style = torch.rand(1, 3, 12, 16) * 2.0 - 1.0
|
||||||
|
out = seedvr_vae.wavelet_color_transfer(content, style)
|
||||||
|
expected = seedvr_vae.wavelet_reconstruction(content.clone(), style.clone())
|
||||||
|
assert torch.equal(out, expected)
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_adain_color_transfer_matches_huang_belongie_formula():
|
||||||
|
from comfy.ldm.seedvr import vae as seedvr_vae
|
||||||
|
torch.manual_seed(0)
|
||||||
|
content = torch.rand(2, 3, 5, 7) * 2.0 - 1.0
|
||||||
|
style = torch.rand(2, 3, 5, 7) * 2.0 - 1.0
|
||||||
|
out = seedvr_vae.adain_color_transfer(content.clone(), style.clone())
|
||||||
|
|
||||||
|
b, c = 2, 3
|
||||||
|
cf = content.float().reshape(b, c, -1)
|
||||||
|
sf = style.float().reshape(b, c, -1)
|
||||||
|
eps = 1e-5
|
||||||
|
mu_c = cf.mean(dim=2).reshape(b, c, 1, 1)
|
||||||
|
sd_c = (cf.var(dim=2, correction=0) + eps).sqrt().reshape(b, c, 1, 1)
|
||||||
|
mu_s = sf.mean(dim=2).reshape(b, c, 1, 1)
|
||||||
|
sd_s = (sf.var(dim=2, correction=0) + eps).sqrt().reshape(b, c, 1, 1)
|
||||||
|
expected = ((content.float() - mu_c) / sd_c) * sd_s + mu_s
|
||||||
|
expected = expected.clamp(-1.0, 1.0)
|
||||||
|
assert torch.allclose(out, expected, atol=1e-6)
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_adain_single_pixel_uses_population_variance_without_nan():
|
||||||
|
from comfy.ldm.seedvr import vae as seedvr_vae
|
||||||
|
content = torch.tensor([[[[0.25]], [[-0.5]], [[0.75]]]], dtype=torch.float32)
|
||||||
|
style = torch.tensor([[[[-0.25]], [[0.5]], [[-0.75]]]], dtype=torch.float32)
|
||||||
|
|
||||||
|
out = seedvr_vae.adain_color_transfer(content, style)
|
||||||
|
|
||||||
|
assert torch.isfinite(out).all()
|
||||||
|
assert torch.equal(out, style)
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_adain_preserves_input_dtype():
|
||||||
|
from comfy.ldm.seedvr import vae as seedvr_vae
|
||||||
|
content = (torch.rand(1, 3, 4, 4) * 2.0 - 1.0).to(torch.float16)
|
||||||
|
style = (torch.rand(1, 3, 4, 4) * 2.0 - 1.0).to(torch.float16)
|
||||||
|
out = seedvr_vae.adain_color_transfer(content, style)
|
||||||
|
assert out.dtype == torch.float16
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_adain_resizes_mismatched_style_to_content_shape():
|
||||||
|
from comfy.ldm.seedvr import vae as seedvr_vae
|
||||||
|
content = torch.rand(1, 3, 8, 10) * 2.0 - 1.0
|
||||||
|
style = torch.rand(1, 3, 16, 20) * 2.0 - 1.0
|
||||||
|
out = seedvr_vae.adain_color_transfer(content, style)
|
||||||
|
assert tuple(out.shape) == (1, 3, 8, 10)
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_post_processing_unknown_color_correction_method_raises():
|
||||||
|
decoded = torch.zeros(1, 2, 4, 4, 3)
|
||||||
|
original = torch.zeros(1, 2, 4, 4, 3)
|
||||||
|
try:
|
||||||
|
nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 4, "bogus")
|
||||||
|
except ValueError as exc:
|
||||||
|
assert "color_correction_method" in str(exc)
|
||||||
|
else:
|
||||||
|
raise AssertionError("expected ValueError for unknown color_correction_method")
|
||||||
@ -0,0 +1,601 @@
|
|||||||
|
"""Regression tests for SeedVR2 conditioning model resolution and RoPE
|
||||||
|
frequency cast.
|
||||||
|
|
||||||
|
Pin two behaviors:
|
||||||
|
|
||||||
|
1. ``_resolve_seedvr2_diffusion_model`` returns the inner diffusion-model
|
||||||
|
for the expected ``model.model.diffusion_model`` shape and fails loud
|
||||||
|
with a ``RuntimeError`` whose message begins with
|
||||||
|
``_SEEDVR2_INVALID_MODEL_MSG_PREFIX`` for any other shape, including
|
||||||
|
the four distinct missing-vs-None subcases of the chain.
|
||||||
|
2. ``_apply_rope_freqs_float32_cast`` is idempotent **per-tensor by
|
||||||
|
dtype check**, NOT per-instance by sentinel attribute. Every call
|
||||||
|
walks the diffusion-model module tree and invokes ``.to(float32)``
|
||||||
|
only on tensors whose dtype is not already ``float32``. A cache-by-
|
||||||
|
attribute (sentinel) approach is rejected because the sentinel
|
||||||
|
would survive ComfyUI's dynamic model unload/reload cycle while
|
||||||
|
``rope.freqs`` itself is restored to the archived dtype, so the
|
||||||
|
next call would short-circuit and leave RoPE running in fp16/bf16
|
||||||
|
— the exact failure this helper is supposed to prevent. The dtype
|
||||||
|
check is self-correcting against any weight-restore lifecycle
|
||||||
|
event.
|
||||||
|
|
||||||
|
Import isolation: ``comfy.model_management`` is stubbed via direct
|
||||||
|
``sys.modules`` assignment so importing ``comfy_extras.nodes_seedvr`` does
|
||||||
|
not trigger GPU/server-side initialization. ``patch.dict`` is intentionally
|
||||||
|
NOT used here because its snapshot/restore semantics evict transitively
|
||||||
|
imported third-party modules (e.g. ``torchvision``) on exit, which causes
|
||||||
|
``torch``'s global op-library Meta-key registrations to double-register on
|
||||||
|
re-import. Module-level cached import + scoped restore of the four mocked
|
||||||
|
entries avoids that hazard. See ``_import_nodes_seedvr_isolated``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import sys
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
_SENTINEL = object()
|
||||||
|
|
||||||
|
|
||||||
|
def _import_nodes_seedvr_isolated():
|
||||||
|
"""Stub ``comfy.model_management``, import (or reuse a cached import of)
|
||||||
|
``comfy_extras.nodes_seedvr``, and return ``(module, restore)``.
|
||||||
|
|
||||||
|
``restore()`` snapshots and restores three in-process import-state
|
||||||
|
surfaces:
|
||||||
|
|
||||||
|
1. ``sys.modules["comfy.model_management"]`` — the stubbed module.
|
||||||
|
2. ``sys.modules["comfy_extras.nodes_seedvr"]`` — the imported test
|
||||||
|
target. If we leave this in ``sys.modules`` after the test, a
|
||||||
|
later test importing the real ``comfy_extras.nodes_seedvr`` will
|
||||||
|
get our stubbed-``comfy.model_management`` cached version, which
|
||||||
|
does not re-resolve against the real ``comfy.model_management``.
|
||||||
|
3. ``comfy_extras.nodes_seedvr`` package attribute on the
|
||||||
|
``comfy_extras`` package, mirroring the existing
|
||||||
|
``comfy.model_management`` attribute restore.
|
||||||
|
|
||||||
|
All three are restored verbatim if previously set; deleted on exit
|
||||||
|
if previously unset. No global state leaks into later tests.
|
||||||
|
"""
|
||||||
|
prior_comfy_mm = sys.modules.get("comfy.model_management", _SENTINEL)
|
||||||
|
prior_comfy_mm_attr = _SENTINEL
|
||||||
|
comfy_pkg = sys.modules.get("comfy")
|
||||||
|
if comfy_pkg is not None:
|
||||||
|
prior_comfy_mm_attr = getattr(comfy_pkg, "model_management", _SENTINEL)
|
||||||
|
prior_nodes_seedvr_module = sys.modules.get(
|
||||||
|
"comfy_extras.nodes_seedvr", _SENTINEL,
|
||||||
|
)
|
||||||
|
prior_nodes_seedvr_attr = _SENTINEL
|
||||||
|
comfy_extras_pkg = sys.modules.get("comfy_extras")
|
||||||
|
if comfy_extras_pkg is not None:
|
||||||
|
prior_nodes_seedvr_attr = getattr(
|
||||||
|
comfy_extras_pkg, "nodes_seedvr", _SENTINEL,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ``comfy_extras.nodes_seedvr`` imports ``comfy.sample`` (added in PR
|
||||||
|
# #59) which pulls in the full samplers/k_diffusion/model_patcher
|
||||||
|
# transitive chain. That chain re-imports ``comfy.model_management``
|
||||||
|
# and calls feature-detection predicates like ``xformers_enabled()``
|
||||||
|
# in module-init code (``comfy/ldm/modules/attention.py:18``); a bare
|
||||||
|
# ``MagicMock()`` returns truthy for those calls and triggers a real
|
||||||
|
# ``import xformers`` that fails in the test environment. Pin the
|
||||||
|
# boolean-returning predicates to ``False`` so the import chain
|
||||||
|
# follows the no-extension path.
|
||||||
|
# Configure stub so every ``..._enabled[_*]()`` predicate returns
|
||||||
|
# False. The transitive import chain through ``comfy.sample`` → ...
|
||||||
|
# invokes several feature-detection predicates at module-init time
|
||||||
|
# (``comfy/ldm/modules/attention.py`` ``xformers_enabled()``,
|
||||||
|
# ``comfy/ldm/modules/diffusionmodules/model.py``
|
||||||
|
# ``xformers_enabled_vae()``, etc.). A bare ``MagicMock()`` returns
|
||||||
|
# truthy auto-attrs, which triggers real ``import xformers`` calls
|
||||||
|
# that fail in the test environment.
|
||||||
|
mock_mm = MagicMock()
|
||||||
|
mock_mm.xformers_enabled.return_value = False
|
||||||
|
mock_mm.xformers_enabled_vae.return_value = False
|
||||||
|
mock_mm.pytorch_attention_enabled.return_value = False
|
||||||
|
mock_mm.pytorch_attention_enabled_vae.return_value = False
|
||||||
|
mock_mm.sage_attention_enabled.return_value = False
|
||||||
|
mock_mm.flash_attention_enabled.return_value = False
|
||||||
|
torch_version_parts = torch.version.__version__.split(".")
|
||||||
|
mock_mm.torch_version_numeric = (
|
||||||
|
int(torch_version_parts[0]),
|
||||||
|
int(torch_version_parts[1]),
|
||||||
|
)
|
||||||
|
mock_mm.WINDOWS = False
|
||||||
|
mock_mm.is_intel_xpu.return_value = False
|
||||||
|
sys.modules["comfy.model_management"] = mock_mm
|
||||||
|
# The transitive import chain reaches code paths that do
|
||||||
|
# ``comfy.model_management.<attr>`` (attribute access on the comfy
|
||||||
|
# package, not a fresh import). Setting only ``sys.modules`` is not
|
||||||
|
# enough — also bind the stub as the package attribute. If the
|
||||||
|
# ``comfy`` package isn't imported yet at stub-time (cold first run),
|
||||||
|
# importing it now is safe and idempotent.
|
||||||
|
if comfy_pkg is None:
|
||||||
|
import comfy as _comfy_pkg # noqa: F401
|
||||||
|
comfy_pkg = sys.modules.get("comfy")
|
||||||
|
if comfy_pkg is not None:
|
||||||
|
setattr(comfy_pkg, "model_management", mock_mm)
|
||||||
|
if "comfy_extras.nodes_seedvr" in sys.modules:
|
||||||
|
nodes_seedvr = sys.modules["comfy_extras.nodes_seedvr"]
|
||||||
|
else:
|
||||||
|
nodes_seedvr = importlib.import_module("comfy_extras.nodes_seedvr")
|
||||||
|
|
||||||
|
def _restore():
|
||||||
|
# 1. comfy.model_management sys.modules entry
|
||||||
|
if prior_comfy_mm is _SENTINEL:
|
||||||
|
sys.modules.pop("comfy.model_management", None)
|
||||||
|
else:
|
||||||
|
sys.modules["comfy.model_management"] = prior_comfy_mm
|
||||||
|
# 2. comfy.model_management package attribute on comfy
|
||||||
|
comfy_pkg_now = sys.modules.get("comfy")
|
||||||
|
if comfy_pkg_now is not None:
|
||||||
|
if prior_comfy_mm_attr is _SENTINEL:
|
||||||
|
if hasattr(comfy_pkg_now, "model_management"):
|
||||||
|
delattr(comfy_pkg_now, "model_management")
|
||||||
|
else:
|
||||||
|
setattr(comfy_pkg_now, "model_management", prior_comfy_mm_attr)
|
||||||
|
# 3. comfy_extras.nodes_seedvr sys.modules entry
|
||||||
|
if prior_nodes_seedvr_module is _SENTINEL:
|
||||||
|
sys.modules.pop("comfy_extras.nodes_seedvr", None)
|
||||||
|
else:
|
||||||
|
sys.modules["comfy_extras.nodes_seedvr"] = prior_nodes_seedvr_module
|
||||||
|
# 4. comfy_extras.nodes_seedvr package attribute on comfy_extras
|
||||||
|
comfy_extras_pkg_now = sys.modules.get("comfy_extras")
|
||||||
|
if comfy_extras_pkg_now is not None:
|
||||||
|
if prior_nodes_seedvr_attr is _SENTINEL:
|
||||||
|
if hasattr(comfy_extras_pkg_now, "nodes_seedvr"):
|
||||||
|
delattr(comfy_extras_pkg_now, "nodes_seedvr")
|
||||||
|
else:
|
||||||
|
setattr(
|
||||||
|
comfy_extras_pkg_now, "nodes_seedvr",
|
||||||
|
prior_nodes_seedvr_attr,
|
||||||
|
)
|
||||||
|
|
||||||
|
return nodes_seedvr, _restore
|
||||||
|
|
||||||
|
|
||||||
|
class _Rope(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.freqs = nn.Parameter(torch.zeros(4))
|
||||||
|
|
||||||
|
|
||||||
|
class _Block(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.rope = _Rope()
|
||||||
|
|
||||||
|
|
||||||
|
class _DiffusionModel(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
n_blocks=3,
|
||||||
|
zero_conditioning=False,
|
||||||
|
conditioning_dtype=torch.float32,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.blocks = nn.ModuleList([_Block() for _ in range(n_blocks)])
|
||||||
|
if zero_conditioning:
|
||||||
|
# Simulates a numz-format DiT-only file loaded via UNETLoader:
|
||||||
|
# ``register_buffer`` zero-init at ``comfy/ldm/seedvr/model.py``
|
||||||
|
# leaves the buffers at zero when ``load_state_dict`` cannot
|
||||||
|
# find ``positive_conditioning`` / ``negative_conditioning``
|
||||||
|
# keys in the state_dict. The fail-loud guard at
|
||||||
|
# ``SeedVR2Conditioning.execute`` distinguishes this from a
|
||||||
|
# properly-baked file by ``abs().sum() == 0`` on both buffers.
|
||||||
|
self.register_buffer(
|
||||||
|
"positive_conditioning",
|
||||||
|
torch.zeros((2, 4), dtype=conditioning_dtype),
|
||||||
|
)
|
||||||
|
self.register_buffer(
|
||||||
|
"negative_conditioning",
|
||||||
|
torch.zeros((3, 4), dtype=conditioning_dtype),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.register_buffer(
|
||||||
|
"positive_conditioning",
|
||||||
|
torch.ones((2, 4), dtype=conditioning_dtype),
|
||||||
|
)
|
||||||
|
self.register_buffer(
|
||||||
|
"negative_conditioning",
|
||||||
|
torch.zeros((3, 4), dtype=conditioning_dtype),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _ModelInner:
|
||||||
|
def __init__(self, diffusion_model):
|
||||||
|
self.diffusion_model = diffusion_model
|
||||||
|
|
||||||
|
|
||||||
|
class _ModelPatcher:
|
||||||
|
def __init__(self, diffusion_model):
|
||||||
|
self.model = _ModelInner(diffusion_model)
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_seedvr2_diffusion_model_returns_inner_when_valid():
|
||||||
|
nodes_seedvr, restore = _import_nodes_seedvr_isolated()
|
||||||
|
try:
|
||||||
|
diffusion_model = _DiffusionModel()
|
||||||
|
patcher = _ModelPatcher(diffusion_model)
|
||||||
|
resolved = nodes_seedvr._resolve_seedvr2_diffusion_model(patcher)
|
||||||
|
assert resolved is diffusion_model
|
||||||
|
finally:
|
||||||
|
restore()
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_conditioning_schema_exposes_model_passthrough_output():
|
||||||
|
nodes_seedvr, restore = _import_nodes_seedvr_isolated()
|
||||||
|
try:
|
||||||
|
schema = nodes_seedvr.SeedVR2Conditioning.define_schema()
|
||||||
|
assert [input_item.id for input_item in schema.inputs] == [
|
||||||
|
"model",
|
||||||
|
"vae_conditioning",
|
||||||
|
]
|
||||||
|
assert schema.inputs[1].display_name == "LATENT"
|
||||||
|
assert [output.display_name for output in schema.outputs] == [
|
||||||
|
"model",
|
||||||
|
"positive",
|
||||||
|
"negative",
|
||||||
|
"latent",
|
||||||
|
]
|
||||||
|
finally:
|
||||||
|
restore()
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_conditioning_returns_packed_input_latent_deterministically():
|
||||||
|
nodes_seedvr, restore = _import_nodes_seedvr_isolated()
|
||||||
|
try:
|
||||||
|
diffusion_model = _DiffusionModel()
|
||||||
|
patcher = _ModelPatcher(diffusion_model)
|
||||||
|
samples = torch.arange(1, 25, dtype=torch.float32).reshape(1, 2, 3, 2, 2)
|
||||||
|
vae_conditioning = {"samples": samples}
|
||||||
|
|
||||||
|
_, first_positive, first_negative, first_latent = (
|
||||||
|
nodes_seedvr.SeedVR2Conditioning.execute(
|
||||||
|
patcher,
|
||||||
|
vae_conditioning,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
_, second_positive, second_negative, second_latent = (
|
||||||
|
nodes_seedvr.SeedVR2Conditioning.execute(
|
||||||
|
patcher,
|
||||||
|
vae_conditioning,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
expected_latent = samples.reshape(1, 6, 2, 2)
|
||||||
|
channel_last = samples.movedim(1, -1).contiguous()
|
||||||
|
expected_condition = torch.cat(
|
||||||
|
[
|
||||||
|
channel_last,
|
||||||
|
torch.ones((*channel_last.shape[:-1], 1)),
|
||||||
|
],
|
||||||
|
dim=-1,
|
||||||
|
).movedim(-1, 1).reshape(1, 9, 2, 2)
|
||||||
|
|
||||||
|
assert torch.equal(first_latent["samples"], expected_latent)
|
||||||
|
assert torch.equal(second_latent["samples"], expected_latent)
|
||||||
|
assert torch.equal(
|
||||||
|
first_positive[0][1]["condition"],
|
||||||
|
expected_condition,
|
||||||
|
)
|
||||||
|
assert torch.equal(
|
||||||
|
second_positive[0][1]["condition"],
|
||||||
|
expected_condition,
|
||||||
|
)
|
||||||
|
assert torch.equal(
|
||||||
|
first_negative[0][1]["condition"],
|
||||||
|
expected_condition,
|
||||||
|
)
|
||||||
|
assert torch.equal(
|
||||||
|
second_negative[0][1]["condition"],
|
||||||
|
expected_condition,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
restore()
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_seedvr2_diffusion_model_raises_runtime_error_with_specific_prefix():
|
||||||
|
"""Pin all four failure modes of the resolver chain to the same error
|
||||||
|
prefix and to message text that distinguishes 'attribute missing'
|
||||||
|
from 'attribute present but None'. The four modes:
|
||||||
|
|
||||||
|
mode 1: input has no 'model' attribute
|
||||||
|
mode 2: input.model is None
|
||||||
|
mode 3: 'model.model' has no 'diffusion_model' attribute
|
||||||
|
mode 4: 'model.model.diffusion_model' is None
|
||||||
|
"""
|
||||||
|
nodes_seedvr, restore = _import_nodes_seedvr_isolated()
|
||||||
|
try:
|
||||||
|
# Mode 1: model has no 'model' attribute at all.
|
||||||
|
class _NoModelAttr:
|
||||||
|
pass
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError) as excinfo:
|
||||||
|
nodes_seedvr._resolve_seedvr2_diffusion_model(_NoModelAttr())
|
||||||
|
msg = str(excinfo.value)
|
||||||
|
assert msg.startswith(nodes_seedvr._SEEDVR2_INVALID_MODEL_MSG_PREFIX)
|
||||||
|
assert "no 'model' attribute" in msg
|
||||||
|
|
||||||
|
# Mode 2: model.model exists but is None (must not be conflated
|
||||||
|
# with "no 'model' attribute").
|
||||||
|
class _ModelIsNone:
|
||||||
|
def __init__(self):
|
||||||
|
self.model = None
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError) as excinfo:
|
||||||
|
nodes_seedvr._resolve_seedvr2_diffusion_model(_ModelIsNone())
|
||||||
|
msg = str(excinfo.value)
|
||||||
|
assert msg.startswith(nodes_seedvr._SEEDVR2_INVALID_MODEL_MSG_PREFIX)
|
||||||
|
assert "input.model is None" in msg
|
||||||
|
|
||||||
|
# Mode 3: model.model exists, has no 'diffusion_model' attribute.
|
||||||
|
class _NoDiffusionAttr:
|
||||||
|
def __init__(self):
|
||||||
|
self.model = object()
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError) as excinfo:
|
||||||
|
nodes_seedvr._resolve_seedvr2_diffusion_model(_NoDiffusionAttr())
|
||||||
|
msg = str(excinfo.value)
|
||||||
|
assert msg.startswith(nodes_seedvr._SEEDVR2_INVALID_MODEL_MSG_PREFIX)
|
||||||
|
assert "no 'diffusion_model' attribute" in msg
|
||||||
|
|
||||||
|
# Mode 4: model.model.diffusion_model exists but is None (must not
|
||||||
|
# be conflated with "no 'diffusion_model' attribute").
|
||||||
|
class _DiffusionIsNoneInner:
|
||||||
|
def __init__(self):
|
||||||
|
self.diffusion_model = None
|
||||||
|
|
||||||
|
class _DiffusionIsNone:
|
||||||
|
def __init__(self):
|
||||||
|
self.model = _DiffusionIsNoneInner()
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError) as excinfo:
|
||||||
|
nodes_seedvr._resolve_seedvr2_diffusion_model(_DiffusionIsNone())
|
||||||
|
msg = str(excinfo.value)
|
||||||
|
assert msg.startswith(nodes_seedvr._SEEDVR2_INVALID_MODEL_MSG_PREFIX)
|
||||||
|
assert "'model.model.diffusion_model' is None" in msg
|
||||||
|
finally:
|
||||||
|
restore()
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_rope_freqs_float32_cast_idempotent_on_unchanged_dtype():
|
||||||
|
"""Calling the helper twice on a model whose rope.freqs is already
|
||||||
|
float32 must NOT mutate the tensor identity or contents — the dtype
|
||||||
|
check on every nested module short-circuits the .to() call when the
|
||||||
|
tensor is already in float32.
|
||||||
|
"""
|
||||||
|
nodes_seedvr, restore = _import_nodes_seedvr_isolated()
|
||||||
|
try:
|
||||||
|
diffusion_model = _DiffusionModel()
|
||||||
|
|
||||||
|
# Starting dtype is non-float32 so the first call has work to do.
|
||||||
|
for module in diffusion_model.modules():
|
||||||
|
if hasattr(module, "rope") and hasattr(module.rope, "freqs"):
|
||||||
|
module.rope.freqs.data = module.rope.freqs.data.to(torch.float64)
|
||||||
|
|
||||||
|
nodes_seedvr._apply_rope_freqs_float32_cast(diffusion_model)
|
||||||
|
first_call_data_ids = []
|
||||||
|
for module in diffusion_model.modules():
|
||||||
|
if hasattr(module, "rope") and hasattr(module.rope, "freqs"):
|
||||||
|
assert module.rope.freqs.data.dtype == torch.float32
|
||||||
|
first_call_data_ids.append(id(module.rope.freqs.data))
|
||||||
|
|
||||||
|
# Second call on the same already-float32 model: every per-tensor
|
||||||
|
# dtype check sees float32 and skips the .to() call. Tensor data
|
||||||
|
# identity must be preserved (no re-allocation).
|
||||||
|
nodes_seedvr._apply_rope_freqs_float32_cast(diffusion_model)
|
||||||
|
for module, prior_id in zip(
|
||||||
|
(m for m in diffusion_model.modules()
|
||||||
|
if hasattr(m, "rope") and hasattr(m.rope, "freqs")),
|
||||||
|
first_call_data_ids,
|
||||||
|
strict=True,
|
||||||
|
):
|
||||||
|
assert module.rope.freqs.data.dtype == torch.float32
|
||||||
|
assert id(module.rope.freqs.data) == prior_id, (
|
||||||
|
"Already-float32 rope.freqs must not be re-allocated on "
|
||||||
|
"subsequent calls; the per-tensor dtype check must skip the "
|
||||||
|
".to(float32) call when the tensor is already in float32."
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
restore()
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_rope_freqs_float32_cast_recovers_after_dtype_reset():
|
||||||
|
"""After a model unload/reload that restores rope.freqs from an
|
||||||
|
archived non-float32 dtype, the next call must re-cast to float32.
|
||||||
|
A bool-sentinel cache approach would short-circuit here and leave
|
||||||
|
RoPE running in fp16/bf16.
|
||||||
|
"""
|
||||||
|
nodes_seedvr, restore = _import_nodes_seedvr_isolated()
|
||||||
|
try:
|
||||||
|
diffusion_model = _DiffusionModel()
|
||||||
|
for module in diffusion_model.modules():
|
||||||
|
if hasattr(module, "rope") and hasattr(module.rope, "freqs"):
|
||||||
|
module.rope.freqs.data = module.rope.freqs.data.to(torch.float64)
|
||||||
|
|
||||||
|
# First call casts to float32.
|
||||||
|
nodes_seedvr._apply_rope_freqs_float32_cast(diffusion_model)
|
||||||
|
for module in diffusion_model.modules():
|
||||||
|
if hasattr(module, "rope") and hasattr(module.rope, "freqs"):
|
||||||
|
assert module.rope.freqs.data.dtype == torch.float32
|
||||||
|
|
||||||
|
# Simulate a Comfy dynamic unload/reload that restores rope.freqs
|
||||||
|
# to the archived (non-float32) dtype.
|
||||||
|
for module in diffusion_model.modules():
|
||||||
|
if hasattr(module, "rope") and hasattr(module.rope, "freqs"):
|
||||||
|
module.rope.freqs.data = module.rope.freqs.data.to(torch.float64)
|
||||||
|
|
||||||
|
# Second call must detect the dtype regression and re-cast.
|
||||||
|
nodes_seedvr._apply_rope_freqs_float32_cast(diffusion_model)
|
||||||
|
for module in diffusion_model.modules():
|
||||||
|
if hasattr(module, "rope") and hasattr(module.rope, "freqs"):
|
||||||
|
assert module.rope.freqs.data.dtype == torch.float32, (
|
||||||
|
"After a model unload/reload that resets rope.freqs to "
|
||||||
|
"non-float32, the next _apply_rope_freqs_float32_cast "
|
||||||
|
"call MUST re-cast to float32. A bool-sentinel cache "
|
||||||
|
"would have short-circuited here."
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
restore()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fail-loud guard: zero-valued conditioning buffers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_conditioning_fails_loud_on_zero_buffers():
|
||||||
|
"""A SeedVR2 model whose ``positive_conditioning`` AND
|
||||||
|
``negative_conditioning`` buffers are both zero-valued is an
|
||||||
|
unrecoverable load state — a numz-format DiT-only ``.safetensors``
|
||||||
|
file was loaded via ``UNETLoader`` without the SeedVR2 conditioning
|
||||||
|
keys baked in. ``SeedVR2Conditioning.execute`` must raise
|
||||||
|
``RuntimeError`` carrying the standard SeedVR2 invalid-model prefix
|
||||||
|
instead of letting the diffusion sampler run on null prompt
|
||||||
|
conditioning (which silently produces wrong output).
|
||||||
|
"""
|
||||||
|
nodes_seedvr, restore = _import_nodes_seedvr_isolated()
|
||||||
|
try:
|
||||||
|
diffusion_model = _DiffusionModel(zero_conditioning=True)
|
||||||
|
patcher = _ModelPatcher(diffusion_model)
|
||||||
|
vae_conditioning = {"samples": torch.zeros((1, 2, 1, 1, 1))}
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError) as excinfo:
|
||||||
|
nodes_seedvr.SeedVR2Conditioning.execute(
|
||||||
|
patcher, vae_conditioning,
|
||||||
|
)
|
||||||
|
|
||||||
|
message = str(excinfo.value)
|
||||||
|
assert message.startswith(
|
||||||
|
nodes_seedvr._SEEDVR2_INVALID_MODEL_MSG_PREFIX
|
||||||
|
), (
|
||||||
|
"Fail-loud message must use the standard "
|
||||||
|
"_SEEDVR2_INVALID_MODEL_MSG_PREFIX so callers/log scrapers "
|
||||||
|
f"can match it. Got: {message!r}"
|
||||||
|
)
|
||||||
|
assert "positive_conditioning" in message
|
||||||
|
assert "negative_conditioning" in message
|
||||||
|
finally:
|
||||||
|
restore()
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_conditioning_fails_loud_on_fp8_zero_buffers():
|
||||||
|
"""The zero-buffer sentinel must reduce fp8 conditioning tensors
|
||||||
|
without hitting PyTorch's unsupported float8 reductions.
|
||||||
|
"""
|
||||||
|
fp8_dtype = getattr(torch, "float8_e4m3fn", None)
|
||||||
|
if fp8_dtype is None:
|
||||||
|
pytest.skip("torch build does not expose float8_e4m3fn")
|
||||||
|
|
||||||
|
nodes_seedvr, restore = _import_nodes_seedvr_isolated()
|
||||||
|
try:
|
||||||
|
diffusion_model = _DiffusionModel(
|
||||||
|
zero_conditioning=True,
|
||||||
|
conditioning_dtype=fp8_dtype,
|
||||||
|
)
|
||||||
|
patcher = _ModelPatcher(diffusion_model)
|
||||||
|
vae_conditioning = {"samples": torch.zeros((1, 2, 1, 1, 1))}
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError) as excinfo:
|
||||||
|
nodes_seedvr.SeedVR2Conditioning.execute(
|
||||||
|
patcher, vae_conditioning,
|
||||||
|
)
|
||||||
|
|
||||||
|
message = str(excinfo.value)
|
||||||
|
assert message.startswith(
|
||||||
|
nodes_seedvr._SEEDVR2_INVALID_MODEL_MSG_PREFIX
|
||||||
|
)
|
||||||
|
assert "zero-valued" in message
|
||||||
|
finally:
|
||||||
|
restore()
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_conditioning_does_not_fire_on_partial_zero_buffers():
|
||||||
|
"""The guard checks BOTH buffers together: a model with zero
|
||||||
|
``negative_conditioning`` but non-zero ``positive_conditioning``
|
||||||
|
(the existing baseline mock fixture) must NOT trigger the fail-loud
|
||||||
|
path. This pins the AND-gating semantic and prevents a future
|
||||||
|
regression to OR-gating from rejecting valid bundled checkpoints
|
||||||
|
where one buffer happens to be all-zeros.
|
||||||
|
"""
|
||||||
|
nodes_seedvr, restore = _import_nodes_seedvr_isolated()
|
||||||
|
try:
|
||||||
|
# Baseline _DiffusionModel has positive=ones, negative=zeros.
|
||||||
|
diffusion_model = _DiffusionModel(zero_conditioning=False)
|
||||||
|
patcher = _ModelPatcher(diffusion_model)
|
||||||
|
vae_conditioning = {"samples": torch.zeros((1, 2, 1, 1, 1))}
|
||||||
|
|
||||||
|
# Should not raise.
|
||||||
|
passthrough_model, positive, negative, latent = (
|
||||||
|
nodes_seedvr.SeedVR2Conditioning.execute(
|
||||||
|
patcher, vae_conditioning,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
assert positive[0][0].shape == (1, 2, 4)
|
||||||
|
assert negative[0][0].shape == (1, 3, 4)
|
||||||
|
assert passthrough_model is patcher
|
||||||
|
finally:
|
||||||
|
restore()
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_conditioning_fail_loud_never_exposes_safetensors_path():
|
||||||
|
"""The fail-loud message must not expose local model paths from
|
||||||
|
``cached_patcher_init``. Public runtime errors should describe the
|
||||||
|
invalid SeedVR2 contract without making filesystem paths part of the
|
||||||
|
public behavior contract.
|
||||||
|
"""
|
||||||
|
nodes_seedvr, restore = _import_nodes_seedvr_isolated()
|
||||||
|
try:
|
||||||
|
diffusion_model = _DiffusionModel(zero_conditioning=True)
|
||||||
|
patcher = _ModelPatcher(diffusion_model)
|
||||||
|
# Mimic the ``cached_patcher_init`` shape comfy.sd attaches.
|
||||||
|
patcher.cached_patcher_init = (
|
||||||
|
object(), # function reference
|
||||||
|
("/some/models/diffusion_models/seedvr2_ema_7b_fp16.safetensors",),
|
||||||
|
)
|
||||||
|
vae_conditioning = {"samples": torch.zeros((1, 2, 1, 1, 1))}
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError) as excinfo:
|
||||||
|
nodes_seedvr.SeedVR2Conditioning.execute(
|
||||||
|
patcher, vae_conditioning,
|
||||||
|
)
|
||||||
|
|
||||||
|
message = str(excinfo.value)
|
||||||
|
assert "/some/models/diffusion_models" not in message
|
||||||
|
assert "seedvr2_ema_7b_fp16.safetensors" not in message
|
||||||
|
assert "Source file:" not in message
|
||||||
|
assert "positive_conditioning" in message
|
||||||
|
assert "negative_conditioning" in message
|
||||||
|
finally:
|
||||||
|
restore()
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_conditioning_fail_loud_falls_back_when_path_unavailable():
|
||||||
|
"""When ``cached_patcher_init`` is missing or its tuple does not
|
||||||
|
contain a ``.safetensors`` path, the fail-loud message still
|
||||||
|
delivers the actionable diagnostic without leaking ``None`` or
|
||||||
|
raising during message formatting.
|
||||||
|
"""
|
||||||
|
nodes_seedvr, restore = _import_nodes_seedvr_isolated()
|
||||||
|
try:
|
||||||
|
diffusion_model = _DiffusionModel(zero_conditioning=True)
|
||||||
|
patcher = _ModelPatcher(diffusion_model)
|
||||||
|
# No cached_patcher_init set on the patcher.
|
||||||
|
vae_conditioning = {"samples": torch.zeros((1, 2, 1, 1, 1))}
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError) as excinfo:
|
||||||
|
nodes_seedvr.SeedVR2Conditioning.execute(
|
||||||
|
patcher, vae_conditioning,
|
||||||
|
)
|
||||||
|
message = str(excinfo.value)
|
||||||
|
assert "Source file:" not in message # no empty path leak
|
||||||
|
assert "Re-bake" in message # actionable guidance still present
|
||||||
|
assert "bf16 keys" not in message
|
||||||
|
finally:
|
||||||
|
restore()
|
||||||
103
tests-unit/comfy_extras_test/test_seedvr_node_signature.py
Normal file
103
tests-unit/comfy_extras_test/test_seedvr_node_signature.py
Normal file
@ -0,0 +1,103 @@
|
|||||||
|
"""Regression test: SeedVR2 resize schema input ids must match
|
||||||
|
execute() positional parameter order. Drift between the two would silently
|
||||||
|
swap arguments at runtime; this test fails loudly on any future drift.
|
||||||
|
|
||||||
|
The schema input attribute is `.id` (verified live via Python introspection
|
||||||
|
on the upstream class -- there is no `.name`).
|
||||||
|
|
||||||
|
`comfy.model_management` is stubbed via `patch.dict(sys.modules, ...)` for
|
||||||
|
the import performed inside this test, so importing
|
||||||
|
`comfy_extras.nodes_seedvr` here does not call
|
||||||
|
`torch.cuda.is_available()` or trigger other GPU/server-side
|
||||||
|
initialization through that dependency. Live introspection indicated that
|
||||||
|
`comfy_extras.nodes_seedvr` pulls in `comfy.model_management`
|
||||||
|
transitively here (not `nodes`, not `server`).
|
||||||
|
|
||||||
|
The test snapshots three pieces of import state before patching and
|
||||||
|
restores all three in `finally` via a sentinel:
|
||||||
|
|
||||||
|
1. `sys.modules["comfy_extras.nodes_seedvr"]`
|
||||||
|
2. `comfy.model_management` package attribute on the `comfy` package
|
||||||
|
3. `comfy_extras.nodes_seedvr` attribute on the `comfy_extras` package
|
||||||
|
|
||||||
|
If any of the three was set before the test, it is restored verbatim;
|
||||||
|
if it was unset, it is deleted on exit. This prevents the test from
|
||||||
|
clobbering a real `comfy.model_management` (or
|
||||||
|
`comfy_extras.nodes_seedvr`) module that another test may have
|
||||||
|
legitimately imported earlier in the same pytest process, while still
|
||||||
|
preventing the test's mock from leaking into later tests that import
|
||||||
|
the real `comfy_extras.nodes_seedvr`."""
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import inspect
|
||||||
|
import sys
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from comfy.cli_args import args as cli_args
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr_node_signature_matches_schema():
|
||||||
|
mock_model_management = MagicMock()
|
||||||
|
mock_model_management.xformers_enabled.return_value = False
|
||||||
|
mock_model_management.xformers_enabled_vae.return_value = False
|
||||||
|
mock_model_management.sage_attention_enabled.return_value = False
|
||||||
|
mock_model_management.flash_attention_enabled.return_value = False
|
||||||
|
sentinel = object()
|
||||||
|
prior_cpu = cli_args.cpu
|
||||||
|
cli_args.cpu = True
|
||||||
|
|
||||||
|
comfy_module_pre = sys.modules.get("comfy")
|
||||||
|
comfy_extras_module_pre = sys.modules.get("comfy_extras")
|
||||||
|
prior_comfy_mm_attr = (
|
||||||
|
getattr(comfy_module_pre, "model_management", sentinel)
|
||||||
|
if comfy_module_pre is not None
|
||||||
|
else sentinel
|
||||||
|
)
|
||||||
|
prior_comfy_extras_seedvr_attr = (
|
||||||
|
getattr(comfy_extras_module_pre, "nodes_seedvr", sentinel)
|
||||||
|
if comfy_extras_module_pre is not None
|
||||||
|
else sentinel
|
||||||
|
)
|
||||||
|
prior_comfy_extras_seedvr_module = sys.modules.get("comfy_extras.nodes_seedvr", sentinel)
|
||||||
|
|
||||||
|
with patch.dict(sys.modules, {"comfy.model_management": mock_model_management}):
|
||||||
|
if comfy_module_pre is not None:
|
||||||
|
setattr(comfy_module_pre, "model_management", mock_model_management)
|
||||||
|
sys.modules.pop("comfy_extras.nodes_seedvr", None)
|
||||||
|
try:
|
||||||
|
nodes_seedvr = importlib.import_module("comfy_extras.nodes_seedvr")
|
||||||
|
for node_cls in (
|
||||||
|
nodes_seedvr.SeedVR2Resize,
|
||||||
|
nodes_seedvr.SeedVR2ResizeAdvanced,
|
||||||
|
):
|
||||||
|
schema_ids = [i.id for i in node_cls.define_schema().inputs]
|
||||||
|
exec_params = [
|
||||||
|
p
|
||||||
|
for p in inspect.signature(node_cls.execute).parameters.keys()
|
||||||
|
if p != "cls"
|
||||||
|
]
|
||||||
|
assert schema_ids == exec_params, (
|
||||||
|
f"{node_cls.__name__} schema input ids do not match "
|
||||||
|
f"execute() parameter order: schema_ids={schema_ids}, "
|
||||||
|
f"exec_params={exec_params}"
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
if prior_comfy_extras_seedvr_module is sentinel:
|
||||||
|
sys.modules.pop("comfy_extras.nodes_seedvr", None)
|
||||||
|
else:
|
||||||
|
sys.modules["comfy_extras.nodes_seedvr"] = prior_comfy_extras_seedvr_module
|
||||||
|
cli_args.cpu = prior_cpu
|
||||||
|
comfy_extras_module = sys.modules.get("comfy_extras")
|
||||||
|
if comfy_extras_module is not None:
|
||||||
|
if prior_comfy_extras_seedvr_attr is sentinel:
|
||||||
|
if hasattr(comfy_extras_module, "nodes_seedvr"):
|
||||||
|
delattr(comfy_extras_module, "nodes_seedvr")
|
||||||
|
else:
|
||||||
|
setattr(comfy_extras_module, "nodes_seedvr", prior_comfy_extras_seedvr_attr)
|
||||||
|
comfy_module = sys.modules.get("comfy")
|
||||||
|
if comfy_module is not None:
|
||||||
|
if prior_comfy_mm_attr is sentinel:
|
||||||
|
if hasattr(comfy_module, "model_management"):
|
||||||
|
delattr(comfy_module, "model_management")
|
||||||
|
else:
|
||||||
|
setattr(comfy_module, "model_management", prior_comfy_mm_attr)
|
||||||
@ -0,0 +1,40 @@
|
|||||||
|
import ast
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
ROOT = Path(__file__).resolve().parents[2]
|
||||||
|
FILES = [
|
||||||
|
ROOT / "comfy/ldm/seedvr/vae.py",
|
||||||
|
ROOT / "comfy/sd.py",
|
||||||
|
ROOT / "comfy_extras/nodes_seedvr.py",
|
||||||
|
]
|
||||||
|
FORBIDDEN_ATTRS = {"original_image_video", "img_dims", "tiled_args"}
|
||||||
|
FORBIDDEN_KEYS = {
|
||||||
|
"sampler_metadata",
|
||||||
|
"latent_sidecar_metadata",
|
||||||
|
"saved_latent_metadata",
|
||||||
|
"workflow_hidden_state",
|
||||||
|
}
|
||||||
|
FORBIDDEN_GETSET_KEYS = {"original_image_video", "img_dims", "tiled_args"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_decode_paths_do_not_use_hidden_vae_object_state():
|
||||||
|
for path in FILES:
|
||||||
|
tree = ast.parse(path.read_text(encoding="utf-8"))
|
||||||
|
for node in ast.walk(tree):
|
||||||
|
if isinstance(node, ast.Attribute) and node.attr in FORBIDDEN_ATTRS:
|
||||||
|
pytest.fail(f"{path}: forbidden VAE object state attr {node.attr}")
|
||||||
|
if isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
|
||||||
|
if node.func.id in {"getattr", "setattr", "delattr"} and len(node.args) >= 2:
|
||||||
|
key = node.args[1]
|
||||||
|
if isinstance(key, ast.Constant) and key.value in FORBIDDEN_GETSET_KEYS:
|
||||||
|
pytest.fail(f"{path}: forbidden VAE object state access {key.value}")
|
||||||
|
if isinstance(node, ast.Constant) and isinstance(node.value, str):
|
||||||
|
if node.value in FORBIDDEN_ATTRS or node.value in FORBIDDEN_KEYS:
|
||||||
|
pytest.fail(f"{path}: forbidden hidden-state string {node.value}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
raise SystemExit(pytest.main([__file__]))
|
||||||
43
tests-unit/comfy_test/test_seedvr2_non_goal_static_audit.py
Normal file
43
tests-unit/comfy_test/test_seedvr2_non_goal_static_audit.py
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
ROOT = Path(__file__).resolve().parents[2]
|
||||||
|
FORBIDDEN_FILES = {
|
||||||
|
"comfy/ldm/seedvr/model.py",
|
||||||
|
"comfy/ldm/modules/attention.py",
|
||||||
|
"comfy/sample.py",
|
||||||
|
"comfy/samplers.py",
|
||||||
|
}
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.skipif(
|
||||||
|
os.environ.get("SEEDVR2_NON_GOAL_STATIC_AUDIT") != "1",
|
||||||
|
reason="SEEDVR2_NON_GOAL_STATIC_AUDIT=1 is required for git-index audit execution.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _git_changed_paths(*args):
|
||||||
|
result = subprocess.run(
|
||||||
|
["git", "-C", str(ROOT), "diff", "--name-only", *args],
|
||||||
|
text=True,
|
||||||
|
capture_output=True,
|
||||||
|
check=False,
|
||||||
|
)
|
||||||
|
if result.returncode != 0:
|
||||||
|
pytest.skip(f"git diff unavailable: {result.stderr.strip()}")
|
||||||
|
return set(result.stdout.splitlines())
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_non_goal_files_are_not_dirty():
|
||||||
|
changed = _git_changed_paths()
|
||||||
|
changed.update(_git_changed_paths("--cached"))
|
||||||
|
changed_forbidden = sorted(FORBIDDEN_FILES.intersection(changed))
|
||||||
|
if changed_forbidden:
|
||||||
|
pytest.fail(f"forbidden non-goal files changed: {changed_forbidden}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
raise SystemExit(pytest.main([__file__]))
|
||||||
@ -0,0 +1,110 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
from comfy.cli_args import args as cli_args
|
||||||
|
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
cli_args.cpu = True
|
||||||
|
|
||||||
|
import comfy_extras.nodes_seedvr as nodes_seedvr # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
|
def test_resize_simple_multiplier_resolves_upscaled_shorter_edge():
|
||||||
|
images = torch.zeros(1, 3, 16, 20, 3)
|
||||||
|
|
||||||
|
output = nodes_seedvr.SeedVR2Resize.execute(images, 4.0)
|
||||||
|
|
||||||
|
input_pixels, original_image, upscaled_shorter_edge = output.result
|
||||||
|
assert tuple(input_pixels.shape) == (1, 5, 64, 80, 3)
|
||||||
|
assert input_pixels.min().item() == 0.0
|
||||||
|
assert input_pixels.max().item() == 0.0
|
||||||
|
assert original_image is images
|
||||||
|
assert upscaled_shorter_edge == 64
|
||||||
|
|
||||||
|
|
||||||
|
def test_resize_simple_silent_spatial_padding_keeps_unpadded_edge_output():
|
||||||
|
images = torch.zeros(1, 1, 16, 16, 3)
|
||||||
|
|
||||||
|
output = nodes_seedvr.SeedVR2Resize.execute(images, 7.5)
|
||||||
|
|
||||||
|
input_pixels, original_image, upscaled_shorter_edge = output.result
|
||||||
|
assert tuple(input_pixels.shape) == (1, 1, 128, 128, 3)
|
||||||
|
assert original_image is images
|
||||||
|
assert upscaled_shorter_edge == 120
|
||||||
|
|
||||||
|
|
||||||
|
def test_resize_simple_rejects_non_positive_multiplier():
|
||||||
|
images = torch.zeros(1, 1, 16, 16, 3)
|
||||||
|
|
||||||
|
try:
|
||||||
|
nodes_seedvr.SeedVR2Resize.execute(images, 0.0)
|
||||||
|
except ValueError as e:
|
||||||
|
assert "multiplier must be > 0" in str(e)
|
||||||
|
else:
|
||||||
|
raise AssertionError("non-positive multiplier was not rejected")
|
||||||
|
|
||||||
|
|
||||||
|
def test_resize_simple_rejects_multiplier_resolving_to_too_small_edge():
|
||||||
|
images = torch.zeros(1, 1, 16, 16, 3)
|
||||||
|
|
||||||
|
try:
|
||||||
|
nodes_seedvr.SeedVR2Resize.execute(images, 0.01)
|
||||||
|
except ValueError as e:
|
||||||
|
assert "multiplier resolved upscaled_shorter_edge" in str(e)
|
||||||
|
assert "at least 2 pixels" in str(e)
|
||||||
|
else:
|
||||||
|
raise AssertionError("too-small resolved edge was not rejected")
|
||||||
|
|
||||||
|
|
||||||
|
def test_resize_advanced_takes_exact_shorter_edge():
|
||||||
|
images = torch.zeros(1, 1, 16, 16, 3)
|
||||||
|
|
||||||
|
output = nodes_seedvr.SeedVR2ResizeAdvanced.execute(images, 120)
|
||||||
|
|
||||||
|
input_pixels, original_image, upscaled_shorter_edge = output.result
|
||||||
|
assert tuple(input_pixels.shape) == (1, 1, 128, 128, 3)
|
||||||
|
assert original_image is images
|
||||||
|
assert upscaled_shorter_edge == 120
|
||||||
|
|
||||||
|
|
||||||
|
def test_resize_advanced_treats_4d_image_as_one_video_frame_sequence():
|
||||||
|
images = torch.zeros(2, 16, 16, 3)
|
||||||
|
|
||||||
|
output = nodes_seedvr.SeedVR2ResizeAdvanced.execute(images, 120)
|
||||||
|
|
||||||
|
input_pixels, original_image, upscaled_shorter_edge = output.result
|
||||||
|
assert tuple(input_pixels.shape) == (1, 5, 128, 128, 3)
|
||||||
|
assert original_image is images
|
||||||
|
assert upscaled_shorter_edge == 120
|
||||||
|
|
||||||
|
|
||||||
|
def test_resize_advanced_rejects_one_pixel_shorter_edge():
|
||||||
|
images = torch.zeros(1, 1, 16, 16, 3)
|
||||||
|
|
||||||
|
try:
|
||||||
|
nodes_seedvr.SeedVR2ResizeAdvanced.execute(images, 1)
|
||||||
|
except ValueError as e:
|
||||||
|
assert "upscaled_shorter_edge must be at least 2 pixels" in str(e)
|
||||||
|
else:
|
||||||
|
raise AssertionError("one-pixel shorter_edge was not rejected")
|
||||||
|
|
||||||
|
|
||||||
|
def test_resize_node_schemas_and_execute_signatures_are_preprocess_only():
|
||||||
|
simple = nodes_seedvr.SeedVR2Resize.define_schema()
|
||||||
|
advanced = nodes_seedvr.SeedVR2ResizeAdvanced.define_schema()
|
||||||
|
|
||||||
|
assert [item.id for item in simple.inputs] == ["images", "multiplier"]
|
||||||
|
assert simple.inputs[1].default == 4.0
|
||||||
|
assert [item.id for item in simple.outputs] == [
|
||||||
|
"input_pixels",
|
||||||
|
"original_image",
|
||||||
|
"upscaled_shorter_edge",
|
||||||
|
]
|
||||||
|
|
||||||
|
assert [item.id for item in advanced.inputs] == ["images", "shorter_edge"]
|
||||||
|
assert advanced.inputs[1].min == 2
|
||||||
|
assert advanced.inputs[1].step is None
|
||||||
|
assert [item.id for item in advanced.outputs] == [
|
||||||
|
"input_pixels",
|
||||||
|
"original_image",
|
||||||
|
"upscaled_shorter_edge",
|
||||||
|
]
|
||||||
@ -0,0 +1,38 @@
|
|||||||
|
import io
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from comfy.cli_args import args as cli_args
|
||||||
|
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
cli_args.cpu = True
|
||||||
|
|
||||||
|
from comfy_extras import nodes_seedvr # noqa: E402
|
||||||
|
import nodes as nodes_mod # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
|
class _DecodeOnlyVAE:
|
||||||
|
def __init__(self):
|
||||||
|
self.decode_calls = 0
|
||||||
|
|
||||||
|
def decode(self, latent):
|
||||||
|
self.decode_calls += 1
|
||||||
|
b, tc, h, w = latent.shape
|
||||||
|
t = tc // 16
|
||||||
|
return torch.full((b, t, h * 8, w * 8, 3), 0.25)
|
||||||
|
|
||||||
|
|
||||||
|
def test_saved_loaded_seedvr2_latent_decode_boundary_does_not_rerun_preprocessing():
|
||||||
|
latent = {"samples": torch.zeros(1, 32, 4, 5)}
|
||||||
|
buffer = io.BytesIO()
|
||||||
|
torch.save(latent["samples"], buffer)
|
||||||
|
buffer.seek(0)
|
||||||
|
loaded = {"samples": torch.load(buffer, weights_only=True)}
|
||||||
|
|
||||||
|
vae = _DecodeOnlyVAE()
|
||||||
|
decoded = nodes_mod.VAEDecode().decode(vae, loaded)[0]
|
||||||
|
original = torch.full((1, 2, 32, 40, 3), 0.75)
|
||||||
|
output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 32, "none").result[0]
|
||||||
|
|
||||||
|
assert vae.decode_calls == 1
|
||||||
|
assert tuple(output.shape) == (2, 32, 40, 3)
|
||||||
210
tests-unit/comfy_test/test_seedvr2_vae_graph_boundaries.py
Normal file
210
tests-unit/comfy_test/test_seedvr2_vae_graph_boundaries.py
Normal file
@ -0,0 +1,210 @@
|
|||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from comfy.cli_args import args as cli_args
|
||||||
|
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
cli_args.cpu = True
|
||||||
|
|
||||||
|
import comfy.ldm.seedvr.vae as seedvr_vae_mod # noqa: E402
|
||||||
|
import comfy.sd as sd_mod # noqa: E402
|
||||||
|
import nodes as nodes_mod # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
|
class _Patcher:
|
||||||
|
def get_free_memory(self, device):
|
||||||
|
return 1024 * 1024 * 1024
|
||||||
|
|
||||||
|
|
||||||
|
class _EncodeWrapper(seedvr_vae_mod.VideoAutoencoderKLWrapper):
|
||||||
|
def __init__(self, encoded):
|
||||||
|
nn.Module.__init__(self)
|
||||||
|
self.encoded = encoded
|
||||||
|
self.spatial_downsample_factor = 8
|
||||||
|
self.temporal_downsample_factor = 4
|
||||||
|
self.seen = []
|
||||||
|
|
||||||
|
def encode(self, x):
|
||||||
|
self.seen.append(tuple(x.shape))
|
||||||
|
return self.encoded.to(device=x.device, dtype=x.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class _DecodeWrapper(seedvr_vae_mod.VideoAutoencoderKLWrapper):
|
||||||
|
def __init__(self):
|
||||||
|
nn.Module.__init__(self)
|
||||||
|
self.spatial_downsample_factor = 8
|
||||||
|
self.temporal_downsample_factor = 4
|
||||||
|
self.calls = []
|
||||||
|
|
||||||
|
def decode(self, z, seedvr2_tiling=None):
|
||||||
|
self.calls.append({"shape": tuple(z.shape), "seedvr2_tiling": seedvr2_tiling})
|
||||||
|
if z.ndim == 4:
|
||||||
|
b, tc, h, w = z.shape
|
||||||
|
t = tc // 16
|
||||||
|
else:
|
||||||
|
b, _, t, h, w = z.shape
|
||||||
|
return torch.zeros(b, 3, t, h * 8, w * 8, dtype=z.dtype, device=z.device)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_vae(wrapper):
|
||||||
|
vae = sd_mod.VAE.__new__(sd_mod.VAE)
|
||||||
|
vae.first_stage_model = wrapper
|
||||||
|
vae.device = torch.device("cpu")
|
||||||
|
vae.output_device = torch.device("cpu")
|
||||||
|
vae.vae_dtype = torch.float32
|
||||||
|
vae.latent_channels = 16
|
||||||
|
vae.latent_dim = 3
|
||||||
|
vae.downscale_ratio = (lambda a: max(0, (a + 3) // 4), 8, 8)
|
||||||
|
vae.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
|
||||||
|
vae.output_channels = 3
|
||||||
|
vae.disable_offload = True
|
||||||
|
vae.extra_1d_channel = None
|
||||||
|
vae.crop_input = False
|
||||||
|
vae.not_video = False
|
||||||
|
vae.patcher = _Patcher()
|
||||||
|
vae.process_input = lambda image: image
|
||||||
|
vae.process_output = lambda image: image.add(1.0).div(2.0).clamp(0.0, 1.0)
|
||||||
|
vae.vae_output_dtype = lambda: torch.float32
|
||||||
|
vae.memory_used_encode = lambda shape, dtype: 1
|
||||||
|
vae.memory_used_decode = lambda shape, dtype: 1
|
||||||
|
vae.throw_exception_if_invalid = lambda: None
|
||||||
|
vae.vae_encode_crop_pixels = lambda pixels: pixels
|
||||||
|
vae.spacial_compression_decode = lambda: 8
|
||||||
|
vae.temporal_compression_decode = lambda: 4
|
||||||
|
return vae
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_encode_and_encode_tiled_preserve_native_latent_contract(monkeypatch):
|
||||||
|
monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None)
|
||||||
|
|
||||||
|
encoded = torch.full((1, 16, 2, 4, 5), 2.0)
|
||||||
|
vae = _make_vae(_EncodeWrapper(encoded))
|
||||||
|
pixels = torch.zeros(1, 5, 32, 40, 3)
|
||||||
|
|
||||||
|
node_output = nodes_mod.VAEEncode().encode(vae, pixels)[0]
|
||||||
|
node_latent = node_output["samples"]
|
||||||
|
assert set(node_output) == {"samples"}
|
||||||
|
assert tuple(node_latent.shape) == (1, 16, 2, 4, 5)
|
||||||
|
assert node_latent.dtype == torch.float32
|
||||||
|
assert node_latent.stride()[-1] == 1
|
||||||
|
assert torch.equal(node_latent, torch.full_like(node_latent, 2.0 * 0.9152))
|
||||||
|
|
||||||
|
tiled = torch.full((1, 16, 2, 4, 5), 3.0)
|
||||||
|
monkeypatch.setattr(seedvr_vae_mod, "tiled_vae", MagicMock(return_value=tiled))
|
||||||
|
tiled_output = nodes_mod.VAEEncodeTiled().encode(
|
||||||
|
vae,
|
||||||
|
pixels,
|
||||||
|
tile_size=512,
|
||||||
|
overlap=64,
|
||||||
|
temporal_size=16,
|
||||||
|
temporal_overlap=4,
|
||||||
|
)[0]
|
||||||
|
tiled_latent = tiled_output["samples"]
|
||||||
|
assert set(tiled_output) == {"samples"}
|
||||||
|
assert tuple(tiled_latent.shape) == (1, 16, 2, 4, 5)
|
||||||
|
assert tiled_latent.dtype == torch.float32
|
||||||
|
assert torch.equal(tiled_latent, torch.full_like(tiled_latent, 3.0 * 0.9152))
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_decode_and_decode_tiled_do_not_require_preprocessor_state(monkeypatch):
|
||||||
|
monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None)
|
||||||
|
vae = _make_vae(_DecodeWrapper())
|
||||||
|
|
||||||
|
latent = {"samples": torch.zeros(1, 32, 4, 5)}
|
||||||
|
decoded = nodes_mod.VAEDecode().decode(vae, latent)[0]
|
||||||
|
assert tuple(decoded.shape) == (2, 32, 40, 3)
|
||||||
|
|
||||||
|
tiled = nodes_mod.VAEDecodeTiled().decode(
|
||||||
|
vae,
|
||||||
|
{"samples": torch.zeros(1, 16, 2, 4, 5)},
|
||||||
|
tile_size=512,
|
||||||
|
overlap=64,
|
||||||
|
temporal_size=16,
|
||||||
|
temporal_overlap=4,
|
||||||
|
)[0]
|
||||||
|
assert tuple(tiled.shape) == (2, 32, 40, 3)
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_vaedecode_does_not_repair_latent_layout(monkeypatch):
|
||||||
|
monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None)
|
||||||
|
vae = _make_vae(_DecodeWrapper())
|
||||||
|
|
||||||
|
latent = {"samples": torch.zeros(1, 2, 4, 5, 16)}
|
||||||
|
nodes_mod.VAEDecode().decode(vae, latent)
|
||||||
|
|
||||||
|
assert vae.first_stage_model.calls == [{"shape": (1, 2, 4, 5, 16), "seedvr2_tiling": None}]
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_vaedecode_keeps_public_channel_first_width_16_latents(monkeypatch):
|
||||||
|
monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None)
|
||||||
|
vae = _make_vae(_DecodeWrapper())
|
||||||
|
|
||||||
|
nodes_mod.VAEDecode().decode(
|
||||||
|
vae,
|
||||||
|
{"samples": torch.zeros(1, 16, 4, 5, 16)},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert vae.first_stage_model.calls == [{"shape": (1, 16, 4, 5, 16), "seedvr2_tiling": None}]
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_direct_decode_preserves_channel_first_width_16(monkeypatch):
|
||||||
|
monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None)
|
||||||
|
vae = _make_vae(_DecodeWrapper())
|
||||||
|
|
||||||
|
vae.decode(torch.zeros(1, 16, 2, 4, 16))
|
||||||
|
|
||||||
|
assert vae.first_stage_model.calls == [{"shape": (1, 16, 2, 4, 16), "seedvr2_tiling": None}]
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_decode_tiled_preserves_direct_channel_first_width_16(monkeypatch):
|
||||||
|
monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None)
|
||||||
|
vae = _make_vae(_DecodeWrapper())
|
||||||
|
|
||||||
|
vae.decode_tiled_seedvr2(torch.zeros(1, 16, 2, 4, 16))
|
||||||
|
|
||||||
|
assert vae.first_stage_model.calls[0]["shape"] == (1, 16, 2, 4, 16)
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_vaedecode_tiled_keeps_public_channel_first_width_16_latents(monkeypatch):
|
||||||
|
monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None)
|
||||||
|
vae = _make_vae(_DecodeWrapper())
|
||||||
|
|
||||||
|
nodes_mod.VAEDecodeTiled().decode(
|
||||||
|
vae,
|
||||||
|
{"samples": torch.zeros(1, 16, 4, 5, 16)},
|
||||||
|
tile_size=512,
|
||||||
|
overlap=64,
|
||||||
|
temporal_size=16,
|
||||||
|
temporal_overlap=4,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert vae.first_stage_model.calls[0]["shape"] == (1, 16, 4, 5, 16)
|
||||||
|
|
||||||
|
|
||||||
|
def test_vaedecode_tiled_visible_inputs_are_seedvr2_decode_tiling_authority(monkeypatch):
|
||||||
|
monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None)
|
||||||
|
vae = _make_vae(_DecodeWrapper())
|
||||||
|
|
||||||
|
nodes_mod.VAEDecodeTiled().decode(
|
||||||
|
vae,
|
||||||
|
{"samples": torch.zeros(1, 16, 2, 4, 5)},
|
||||||
|
tile_size=512,
|
||||||
|
overlap=64,
|
||||||
|
temporal_size=16,
|
||||||
|
temporal_overlap=4,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert vae.first_stage_model.calls == [
|
||||||
|
{
|
||||||
|
"shape": (1, 16, 2, 4, 5),
|
||||||
|
"seedvr2_tiling": {
|
||||||
|
"enable_tiling": True,
|
||||||
|
"tile_size": (512, 512),
|
||||||
|
"tile_overlap": (64, 64),
|
||||||
|
"temporal_size": 16,
|
||||||
|
"temporal_overlap": 4,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
40
tests-unit/comfy_test/test_seedvr2_windows_static_verify.py
Normal file
40
tests-unit/comfy_test/test_seedvr2_windows_static_verify.py
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
ROOT = Path(__file__).resolve().parents[2]
|
||||||
|
|
||||||
|
|
||||||
|
def _read(relative):
|
||||||
|
return (ROOT / relative).read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_windows_static_contract_tokens():
|
||||||
|
nodes = _read("comfy_extras/nodes_seedvr.py")
|
||||||
|
sd = _read("comfy/sd.py")
|
||||||
|
vae = _read("comfy/ldm/seedvr/vae.py")
|
||||||
|
|
||||||
|
required = [
|
||||||
|
"SeedVR2Resize",
|
||||||
|
"SeedVR2ResizeAdvanced",
|
||||||
|
"SeedVR2PostProcessing",
|
||||||
|
'io.Image.Input("decoded")',
|
||||||
|
'io.Image.Input("original_image")',
|
||||||
|
'io.Int.Input("upscaled_shorter_edge", min=2, force_input=True)',
|
||||||
|
'io.Combo.Input("color_correction_method", options=["lab", "wavelet", "adain", "none"], default="lab")',
|
||||||
|
"def _format_seedvr2_encoded_samples",
|
||||||
|
"def decode(self, z, seedvr2_tiling=None)",
|
||||||
|
]
|
||||||
|
for needle in required:
|
||||||
|
if needle not in nodes + sd + vae:
|
||||||
|
pytest.fail(f"missing required static token: {needle}")
|
||||||
|
|
||||||
|
forbidden = ["original_image_video", "img_dims", "tiled_args"]
|
||||||
|
for needle in forbidden:
|
||||||
|
if needle in nodes + sd + vae:
|
||||||
|
pytest.fail(f"forbidden hidden-state token remains: {needle}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
raise SystemExit(pytest.main([__file__]))
|
||||||
1070
tests-unit/comfy_test/test_seedvr_progressive_sampler.py
Normal file
1070
tests-unit/comfy_test/test_seedvr_progressive_sampler.py
Normal file
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user