mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-27 17:37:39 +08:00
Add SeedVR2 node and sampler coverage
This commit is contained in:
parent
c3bfb743e8
commit
8ac1b59107
58
tests-unit/comfy_extras_test/test_seedvr2_node_boundaries.py
Normal file
58
tests-unit/comfy_extras_test/test_seedvr2_node_boundaries.py
Normal file
@ -0,0 +1,58 @@
|
||||
import ast
|
||||
import inspect
|
||||
import textwrap
|
||||
|
||||
import torch
|
||||
|
||||
from comfy.cli_args import args as cli_args
|
||||
|
||||
# No CUDA device is visible: set the ``cpu`` CLI flag. This runs before the
# ``comfy_extras`` import that follows.
if not torch.cuda.is_available():
    cli_args.cpu = True
|
||||
|
||||
from comfy_extras import nodes_seedvr # noqa: E402
|
||||
|
||||
|
||||
def _schema_ids(items):
    """Return the ``id`` attribute of each entry in ``items``, preserving order."""
    return [entry.id for entry in items]
|
||||
|
||||
|
||||
def test_resize_schemas_are_preprocess_only():
    """Pin the input/output ids and first output type of both resize node schemas."""
    simple_schema = nodes_seedvr.SeedVR2Resize.define_schema()
    advanced_schema = nodes_seedvr.SeedVR2ResizeAdvanced.define_schema()

    # Both nodes share the same outputs; only the sizing input differs.
    shared_outputs = ["input_pixels", "original_image", "upscaled_shorter_edge"]
    cases = (
        (simple_schema, ["images", "multiplier"]),
        (advanced_schema, ["images", "shorter_edge"]),
    )

    for schema, expected_inputs in cases:
        assert _schema_ids(schema.inputs) == expected_inputs
        assert _schema_ids(schema.outputs) == shared_outputs
        assert schema.outputs[0].get_io_type() == "IMAGE"
|
||||
|
||||
|
||||
def test_resize_nodes_do_not_call_encode_decode_or_color_transfer():
    """Statically check that neither resize node calls VAE or color-transfer code.

    The ``execute`` sources of both resize nodes are joined, dedented and
    parsed with ``ast``. Every call expression is inspected: a bare-name call
    is matched by its name, an attribute call by its final attribute. Calls
    made through any other kind of expression are ignored.

    All offending calls are collected so a failure names each forbidden call
    and its line within the joined source, instead of a bare assertion error.
    """
    forbidden_names = {
        "encode",
        "encode_tiled",
        "decode",
        "decode_tiled",
        "tiled_vae",
        "lab_color_transfer",
    }

    source = "\n".join(
        [
            inspect.getsource(nodes_seedvr.SeedVR2Resize.execute),
            inspect.getsource(nodes_seedvr.SeedVR2ResizeAdvanced.execute),
        ]
    )
    tree = ast.parse(textwrap.dedent(source))

    offenders = []
    for node in ast.walk(tree):
        if not isinstance(node, ast.Call):
            continue
        func = node.func
        if isinstance(func, ast.Name):
            name = func.id
        elif isinstance(func, ast.Attribute):
            name = func.attr
        else:
            # e.g. calling the result of a subscript or another call.
            continue
        if name in forbidden_names:
            offenders.append(f"{name} (line {node.lineno} of joined execute sources)")

    assert not offenders, f"resize nodes call forbidden functions: {offenders}"
|
||||
461
tests-unit/comfy_extras_test/test_seedvr2_post_processing.py
Normal file
461
tests-unit/comfy_extras_test/test_seedvr2_post_processing.py
Normal file
@ -0,0 +1,461 @@
|
||||
import inspect
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
|
||||
from comfy.cli_args import args as cli_args
|
||||
|
||||
# No CUDA device is visible: set the ``cpu`` CLI flag. This runs before the
# ``comfy_extras`` import that follows.
if not torch.cuda.is_available():
    cli_args.cpu = True
|
||||
|
||||
from comfy_extras import nodes_seedvr # noqa: E402
|
||||
|
||||
|
||||
def _schema_ids(items):
    """Return the ``id`` attribute of each entry in ``items``, preserving order."""
    return [entry.id for entry in items]
|
||||
|
||||
|
||||
def test_seedvr2_post_processing_schema():
    """Pin input ids, the widget settings of inputs 2 and 3, and the output type."""
    schema = nodes_seedvr.SeedVR2PostProcessing.define_schema()

    assert _schema_ids(schema.inputs) == [
        "decoded",
        "original_image",
        "upscaled_shorter_edge",
        "color_correction_method",
    ]

    edge_input = schema.inputs[2]
    assert edge_input.default is None
    assert edge_input.min == 2
    assert edge_input.force_input is True

    method_input = schema.inputs[3]
    assert method_input.options == ["lab", "wavelet", "adain", "none"]
    assert method_input.default == "lab"

    assert schema.outputs[0].get_io_type() == "IMAGE"
|
||||
|
||||
|
||||
def test_seedvr2_post_processing_color_correction_memory_multipliers_are_named():
    """The three per-method multipliers exist as module constants with pinned values."""
    pinned = {
        "LAB_SCALE_MULTIPLIER": 13,
        "WAVELET_SCALE_MULTIPLIER": 10,
        "ADAIN_SCALE_MULTIPLIER": 6,
    }
    for constant_name, expected_value in pinned.items():
        assert getattr(nodes_seedvr, constant_name) == expected_value
|
||||
|
||||
|
||||
def test_seedvr2_post_processing_lab_autochunks_from_memory_estimate(monkeypatch):
    """With ``get_free_memory`` stubbed to 1700, a 5-frame LAB pass sees one frame per call."""
    video_shape = (1, 5, 2, 2, 3)
    decoded = torch.full(video_shape, 0.25)
    original = torch.full(video_shape, 0.75)
    batch_sizes = []

    def _recording_lab(content, style):
        # Record the leading dimension of each call and pass content through.
        batch_sizes.append(content.shape[0])
        return content

    model_management = nodes_seedvr.comfy.model_management
    monkeypatch.setattr(model_management, "vae_device", lambda: torch.device("cpu"))
    monkeypatch.setattr(model_management, "get_free_memory", lambda device: 1700)

    with patch.object(nodes_seedvr, "lab_color_transfer", _recording_lab):
        output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 2, "lab").result[0]

    assert batch_sizes == [1, 1, 1, 1, 1]
    assert tuple(output.shape) == (1, 5, 2, 2, 3)
|
||||
|
||||
|
||||
def test_seedvr2_post_processing_lab_runs_each_frame_independently(monkeypatch):
    """Even with ``get_free_memory`` stubbed to 1_000_000, LAB is called once per frame."""
    video_shape = (1, 4, 2, 2, 3)
    decoded = torch.full(video_shape, 0.25)
    original = torch.full(video_shape, 0.75)
    batch_sizes = []

    def _recording_lab(content, style):
        # Record the leading dimension of each call and pass content through.
        batch_sizes.append(content.shape[0])
        return content

    model_management = nodes_seedvr.comfy.model_management
    monkeypatch.setattr(model_management, "vae_device", lambda: torch.device("cpu"))
    monkeypatch.setattr(model_management, "get_free_memory", lambda device: 1_000_000)

    with patch.object(nodes_seedvr, "lab_color_transfer", _recording_lab):
        output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 2, "lab").result[0]

    assert batch_sizes == [1, 1, 1, 1]
    assert tuple(output.shape) == (1, 4, 2, 2, 3)
|
||||
|
||||
|
||||
def test_seedvr2_post_processing_lab_derives_reference_from_original_and_upscaled_shorter_edge():
    """Check the shapes and values handed to LAB and the final output for edge=8.

    The fake LAB function returns zeros; the test expects the node output to be
    0.5 everywhere and the first call to receive content of -0.5 and style of
    0.5, both shaped (1, 3, 8, 10).
    """
    decoded = torch.full((1, 3, 9, 11, 3), 0.25)
    original = torch.full((1, 2, 16, 20, 3), 0.75)
    recorded = []

    def _zeroing_lab(content, style):
        recorded.append((content.clone(), style.clone()))
        return torch.zeros_like(content)

    with patch.object(nodes_seedvr, "lab_color_transfer", _zeroing_lab):
        output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 8, "lab").result[0]

    assert tuple(output.shape) == (1, 2, 8, 10, 3)
    assert torch.equal(output, torch.full_like(output, 0.5))
    assert len(recorded) == 2

    first_content, first_style = recorded[0]
    assert first_content.shape == (1, 3, 8, 10)
    assert first_style.shape == (1, 3, 8, 10)
    assert torch.equal(first_content, torch.full_like(first_content, -0.5))
    assert torch.allclose(first_style, torch.full_like(first_style, 0.5))
|
||||
|
||||
|
||||
def test_seedvr2_post_processing_lab_runs_color_transfer_on_vae_device():
    """Source-level check of which helpers and tensor operations each method uses.

    Inspects the text of ``execute``, ``_run_color_transfer_chunks`` and
    ``_lab_color_transfer_on_vae_device`` for required and forbidden snippets.
    """
    post_processing = nodes_seedvr.SeedVR2PostProcessing
    execute_src = inspect.getsource(post_processing.execute)
    chunk_src = inspect.getsource(post_processing._run_color_transfer_chunks)
    helper_src = inspect.getsource(post_processing._lab_color_transfer_on_vae_device)

    # execute delegates to the chunked entry point.
    assert "_color_transfer_chunked" in execute_src

    # The chunk runner calls the helper and fills a preallocated tensor.
    assert "_lab_color_transfer_on_vae_device" in chunk_src
    assert "torch.cat" not in chunk_src
    assert "torch.empty" in chunk_src
    assert ".copy_(" in chunk_src

    assert "reference_5d.to(device=decoded_5d.device)" not in execute_src

    # The helper resolves the VAE device and moves tensors to and from it.
    assert "comfy.model_management.vae_device()" in helper_src
    assert ".to(device=color_device)" in helper_src
    assert ".to(device=output_device)" in helper_src
|
||||
|
||||
|
||||
def test_seedvr2_post_processing_lab_chunking_is_frame_independent(monkeypatch):
    """``_run_color_transfer_chunks`` gives bit-identical output for last argument 1 and 3."""
    tensor_shape = (3, 3, 24, 24)
    element_count = 3 * 3 * 24 * 24
    decoded = torch.linspace(-0.9, 0.9, element_count).reshape(tensor_shape)
    reference = torch.linspace(0.8, -0.8, element_count).reshape(tensor_shape)

    monkeypatch.setattr(nodes_seedvr.comfy.model_management, "vae_device", lambda: torch.device("cpu"))

    def _run(chunk_arg):
        # Fresh clones per run so neither call can see the other's writes.
        # NOTE(review): the final positional argument is presumably the chunk size.
        return nodes_seedvr.SeedVR2PostProcessing._run_color_transfer_chunks(
            decoded.clone(), reference.clone(), torch.device("cpu"), "lab", chunk_arg,
        )

    one_frame = _run(1)
    multi_frame = _run(3)

    assert torch.equal(one_frame, multi_frame)
|
||||
|
||||
|
||||
def test_seedvr2_post_processing_lab_retry_does_not_mutate_reference(monkeypatch):
    """After a first-call OOM, the second LAB call must see an unmodified reference.

    The fake transfer function adds 10 to its ``style`` argument in place and
    raises ``torch.cuda.OutOfMemoryError`` on its first invocation only. The
    test expects exactly one ``soft_empty_cache`` call, an untouched caller
    ``reference`` tensor, and a second call whose style equals the first
    reference frame.
    """
    decoded = torch.full((2, 3, 4, 4), 0.25)
    reference = torch.full((2, 3, 4, 4), 0.75)
    pristine_reference = reference.clone()
    seen = []
    cache_clears = []

    def _fail_first_lab(content, style):
        seen.append((content.clone(), style.clone()))
        style.add_(10.0)
        if len(seen) == 1:
            raise torch.cuda.OutOfMemoryError("CUDA out of memory")
        return content

    model_management = nodes_seedvr.comfy.model_management
    monkeypatch.setattr(model_management, "vae_device", lambda: torch.device("cpu"))
    monkeypatch.setattr(model_management, "get_free_memory", lambda device: 1_000_000)
    monkeypatch.setattr(model_management, "soft_empty_cache", lambda: cache_clears.append(True))

    with patch.object(nodes_seedvr, "lab_color_transfer", _fail_first_lab):
        nodes_seedvr.SeedVR2PostProcessing._color_transfer_chunked(
            decoded, reference, torch.device("cpu"), "lab",
        )

    assert len(cache_clears) == 1
    assert torch.equal(reference, pristine_reference)
    second_call_style = seen[1][1]
    assert torch.equal(second_call_style, pristine_reference[0:1])
|
||||
|
||||
|
||||
def test_seedvr2_post_processing_oom_error_uses_color_correction_method(monkeypatch):
    """A LAB call that always OOMs must raise a RuntimeError naming ``color_correction_method=lab``."""
    decoded = torch.full((1, 3, 4, 4), 0.25)
    reference = torch.full((1, 3, 4, 4), 0.75)

    def _always_oom_lab(content, style):
        raise torch.cuda.OutOfMemoryError("CUDA out of memory")

    model_management = nodes_seedvr.comfy.model_management
    monkeypatch.setattr(model_management, "vae_device", lambda: torch.device("cpu"))
    monkeypatch.setattr(model_management, "get_free_memory", lambda device: 1_000_000)
    monkeypatch.setattr(model_management, "soft_empty_cache", lambda: None)

    raised = None
    with patch.object(nodes_seedvr, "lab_color_transfer", _always_oom_lab):
        try:
            nodes_seedvr.SeedVR2PostProcessing._color_transfer_chunked(
                decoded, reference, torch.device("cpu"), "lab",
            )
        except RuntimeError as exc:
            raised = exc

    if raised is None:
        raise AssertionError("expected RuntimeError for one-frame LAB OOM")

    message = str(raised)
    assert "color_correction_method=lab" in message
    # The leading space rules out the bare "method=lab" spelling.
    assert " method=lab" not in message
|
||||
|
||||
|
||||
def test_seedvr2_post_processing_raw_conversion_does_not_probe_full_tensor_range():
    """The source text of ``_to_seedvr2_raw`` must contain neither ``.amin`` nor ``.item``."""
    raw_source = inspect.getsource(nodes_seedvr.SeedVR2PostProcessing._to_seedvr2_raw)

    for banned_snippet in (".amin", ".item"):
        assert banned_snippet not in raw_source
|
||||
|
||||
|
||||
def test_seedvr2_post_processing_none_does_not_resize_reference_pixels():
    """With method ``"none"``, ``side_resize`` is never called and output is (1, 2, 8, 10, 3)."""
    decoded = torch.full((1, 2, 10, 12, 3), 0.25)
    original = torch.full((1, 2, 16, 20, 3), 0.75)

    with patch.object(nodes_seedvr, "side_resize") as side_resize_mock:
        output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 8, "none").result[0]

    side_resize_mock.assert_not_called()
    assert tuple(output.shape) == (1, 2, 8, 10, 3)
|
||||
|
||||
|
||||
def test_seedvr2_post_processing_rejects_invalid_upscaled_shorter_edge():
    """``None``, ``1`` and ``1.5`` must each raise a ValueError naming ``upscaled_shorter_edge``."""
    decoded = torch.full((1, 2, 10, 12, 3), 0.25)
    original = torch.full((1, 2, 16, 20, 3), 0.75)

    for edge in (None, 1, 1.5):
        raised = None
        try:
            nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, edge, "none")
        except ValueError as exc:
            raised = exc

        if raised is None:
            raise AssertionError(f"expected ValueError for upscaled_shorter_edge={edge!r}")
        assert "upscaled_shorter_edge" in str(raised)
|
||||
|
||||
|
||||
def test_seedvr2_post_processing_lab_resizes_full_reference_frame():
    """Check the two ``TVF.resize`` calls and the style tensor that reaches LAB.

    ``TVF.resize`` is replaced by a fake that records its arguments and returns
    a constant-0.5 tensor of the requested size; LAB is replaced by a fake that
    returns zeros.
    """
    decoded = torch.full((1, 2, 4, 5, 3), 0.25)
    original = torch.full((1, 2, 16, 20, 3), 0.75)
    resize_calls = []
    lab_calls = []

    def _resize(images, size, interpolation=None, antialias=None):
        resize_calls.append((images.clone(), size, interpolation, antialias))
        if isinstance(size, int):
            # Integer size: scale width by the same ratio as height.
            height = size
            width = round(images.shape[-1] * size / images.shape[-2])
        else:
            height, width = size[0], size[1]
        return torch.full((2, 3, height, width), 0.5)

    def _lab(content, style):
        lab_calls.append((content.clone(), style.clone()))
        return torch.zeros_like(content)

    with patch.object(nodes_seedvr.TVF, "resize", _resize), \
            patch.object(nodes_seedvr, "lab_color_transfer", _lab):
        output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 8, "lab").result[0]

    assert tuple(output.shape) == (1, 2, 4, 4, 3)
    assert torch.equal(output, torch.full_like(output, 0.5))

    first_resize, second_resize = resize_calls[0], resize_calls[1]
    assert first_resize[0].shape == (2, 3, 16, 20)
    assert first_resize[1] == 8
    assert second_resize[0].shape == (2, 3, 8, 10)
    assert second_resize[1] == (4, 5)

    assert len(lab_calls) == 2
    first_style = lab_calls[0][1]
    assert first_style.shape == (1, 3, 4, 5)
    assert torch.equal(first_style, torch.zeros_like(first_style))
|
||||
|
||||
|
||||
def test_seedvr2_post_processing_none_trims_and_crops_without_color_correction():
    """Method ``"none"`` returns ``decoded[:, :2, :8, :10, :]`` and never calls LAB."""
    decoded_shape = (1, 3, 9, 11, 3)
    decoded = torch.arange(1 * 3 * 9 * 11 * 3, dtype=torch.float32).reshape(decoded_shape)
    original = torch.zeros(1, 2, 16, 20, 3)

    with patch.object(nodes_seedvr, "lab_color_transfer") as lab_mock:
        output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 8, "none").result[0]

    assert lab_mock.call_count == 0
    assert tuple(output.shape) == (1, 2, 8, 10, 3)
    expected = decoded[:, :2, :8, :10, :]
    assert torch.equal(output, expected)
|
||||
|
||||
|
||||
def test_seedvr2_post_processing_restores_flattened_padded_batches_before_trimming():
    """A 4-D decoded tensor of 10 frames yields frames 0-1 and 5-6 for a (2, 2, ...) original."""
    decoded = torch.arange(10 * 4 * 6 * 1, dtype=torch.float32).reshape(10, 4, 6, 1)
    original = torch.zeros(2, 2, 4, 6, 1)

    output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 4, "none").result[0]

    kept_frames = (decoded[0:2], decoded[5:7])
    expected = torch.cat(kept_frames, dim=0)
    assert tuple(output.shape) == (4, 4, 6, 1)
    assert torch.equal(output, expected)
|
||||
|
||||
|
||||
def test_seedvr2_post_processing_none_preserves_decoded_spatial_size_when_reference_is_larger():
    """With edge=16 and an 8x10 decoded frame, only the frame count is trimmed."""
    decoded_shape = (1, 3, 8, 10, 3)
    decoded = torch.arange(1 * 3 * 8 * 10 * 3, dtype=torch.float32).reshape(decoded_shape)
    original = torch.zeros(1, 2, 16, 20, 3)

    with patch.object(nodes_seedvr, "lab_color_transfer") as lab_mock:
        output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 16, "none").result[0]

    assert lab_mock.call_count == 0
    assert tuple(output.shape) == (1, 2, 8, 10, 3)
    expected = decoded[:, :2, :, :, :]
    assert torch.equal(output, expected)
|
||||
|
||||
|
||||
def test_seedvr2_post_processing_crops_to_reference_tensor_when_reference_is_smaller():
    """A 720x1280 decoded frame with a 360x640 original and edge=360 yields 360x640."""
    reference_shape = (1, 1, 360, 640, 3)
    decoded = torch.ones((1, 1, 720, 1280, 3), dtype=torch.float32)
    original = torch.ones(reference_shape, dtype=torch.float32)

    result = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 360, "none")
    output = result.result[0]

    assert tuple(output.shape) == (1, 1, 360, 640, 3)
|
||||
|
||||
|
||||
def test_seedvr2_post_processing_uses_decoded_size_when_reference_is_larger():
    """A 128x160 decoded frame with a 480x640 original and edge=480 stays 128x160."""
    decoded_shape = (1, 1, 128, 160, 3)
    decoded = torch.ones(decoded_shape, dtype=torch.float32)
    original = torch.ones((1, 1, 480, 640, 3), dtype=torch.float32)

    result = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 480, "none")
    output = result.result[0]

    assert tuple(output.shape) == (1, 1, 128, 160, 3)
|
||||
|
||||
|
||||
def test_seedvr2_post_processing_derives_crop_from_upscaled_shorter_edge():
    """A 128x224 decoded frame with a 1080x1920 original and edge=120 yields 120x212."""
    decoded = torch.ones((1, 1, 128, 224, 3), dtype=torch.float32)
    original = torch.ones((1, 1, 1080, 1920, 3), dtype=torch.float32)

    result = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 120, "none")
    output = result.result[0]

    assert tuple(output.shape) == (1, 1, 120, 212, 3)
|
||||
|
||||
|
||||
def test_seedvr2_post_processing_uses_even_crop_from_odd_resized_width():
    """A 120x169 original with edge=120 yields an output width of 168, not 169."""
    decoded = torch.ones((1, 1, 128, 256, 3), dtype=torch.float32)
    original = torch.ones((1, 1, 120, 169, 3), dtype=torch.float32)

    result = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 120, "none")
    output = result.result[0]

    assert tuple(output.shape) == (1, 1, 120, 168, 3)
|
||||
|
||||
|
||||
def test_seedvr2_post_processing_none_preserves_black_bottom_row_content():
    """An original whose last row is -1.0 must not cause any cropping of ``decoded``."""
    frame_shape = (1, 2, 8, 10, 3)
    decoded = torch.ones(frame_shape, dtype=torch.float32)
    original = torch.ones(frame_shape, dtype=torch.float32)
    # Overwrite the bottom row of every frame in the original.
    original[:, :, -1, :, :] = -1.0

    output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 8, "none").result[0]

    assert tuple(output.shape) == (1, 2, 8, 10, 3)
    assert torch.equal(output, decoded)
|
||||
|
||||
|
||||
def test_seedvr2_post_processing_none_preserves_black_right_column_content():
    """An original whose last column is -1.0 must not cause any cropping of ``decoded``."""
    frame_shape = (1, 2, 8, 10, 3)
    decoded = torch.ones(frame_shape, dtype=torch.float32)
    original = torch.ones(frame_shape, dtype=torch.float32)
    # Overwrite the right-most column of every frame in the original.
    original[:, :, :, -1, :] = -1.0

    output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 8, "none").result[0]

    assert tuple(output.shape) == (1, 2, 8, 10, 3)
    assert torch.equal(output, decoded)
|
||||
|
||||
|
||||
def test_seedvr2_post_processing_wavelet_dispatch_routes_through_wavelet_color_transfer():
    """Method ``"wavelet"`` calls ``wavelet_color_transfer`` once and never LAB.

    Both transfer functions are replaced by recording fakes that return zeros.
    The single wavelet call is expected to receive (2, 3, 8, 10) tensors with
    content at -0.5 and style at 0.5.
    """
    decoded = torch.full((1, 3, 9, 11, 3), 0.25)
    original = torch.full((1, 2, 16, 20, 3), 0.75)
    wavelet_calls = []
    lab_calls = []

    def _recorder(log):
        # Build a fake transfer function that appends its inputs to ``log``.
        def _fake(content, style):
            log.append((content.clone(), style.clone()))
            return torch.zeros_like(content)
        return _fake

    with patch.object(nodes_seedvr, "wavelet_color_transfer", _recorder(wavelet_calls)), \
            patch.object(nodes_seedvr, "lab_color_transfer", _recorder(lab_calls)):
        output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 8, "wavelet").result[0]

    assert len(wavelet_calls) == 1
    assert len(lab_calls) == 0
    assert tuple(output.shape) == (1, 2, 8, 10, 3)
    assert torch.equal(output, torch.full_like(output, 0.5))

    content_seen, style_seen = wavelet_calls[0]
    assert content_seen.shape == (2, 3, 8, 10)
    assert style_seen.shape == (2, 3, 8, 10)
    assert torch.equal(content_seen, torch.full_like(content_seen, -0.5))
    assert torch.allclose(style_seen, torch.full_like(style_seen, 0.5))
|
||||
|
||||
|
||||
def test_seedvr2_post_processing_adain_dispatch_routes_through_adain_color_transfer():
    """Method ``"adain"`` calls ``adain_color_transfer`` once and never LAB.

    Both transfer functions are replaced by recording fakes that return zeros.
    The single AdaIN call is expected to receive (2, 3, 8, 10) tensors.
    """
    decoded = torch.full((1, 3, 9, 11, 3), 0.25)
    original = torch.full((1, 2, 16, 20, 3), 0.75)
    adain_calls = []
    lab_calls = []

    def _recorder(log):
        # Build a fake transfer function that appends its inputs to ``log``.
        def _fake(content, style):
            log.append((content.clone(), style.clone()))
            return torch.zeros_like(content)
        return _fake

    with patch.object(nodes_seedvr, "adain_color_transfer", _recorder(adain_calls)), \
            patch.object(nodes_seedvr, "lab_color_transfer", _recorder(lab_calls)):
        output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 8, "adain").result[0]

    assert len(adain_calls) == 1
    assert len(lab_calls) == 0
    assert tuple(output.shape) == (1, 2, 8, 10, 3)
    assert torch.equal(output, torch.full_like(output, 0.5))

    content_seen, style_seen = adain_calls[0]
    assert content_seen.shape == (2, 3, 8, 10)
    assert style_seen.shape == (2, 3, 8, 10)
|
||||
|
||||
|
||||
def test_seedvr2_color_transfer_helper_runs_on_vae_device():
    """Source-level check of ``SeedVR2PostProcessing._color_transfer_on_vae_device``.

    The helper's source text must look up the VAE device, move tensors to the
    color device and back to the output device, and reference ``transfer_fn``.

    Uses the module-level ``inspect`` import rather than re-importing it under
    a private alias inside the function.
    """
    helper_source = inspect.getsource(
        nodes_seedvr.SeedVR2PostProcessing._color_transfer_on_vae_device
    )

    required_snippets = (
        "comfy.model_management.vae_device()",
        ".to(device=color_device)",
        ".to(device=output_device)",
        "transfer_fn",
    )
    for snippet in required_snippets:
        assert snippet in helper_source
|
||||
|
||||
|
||||
def test_seedvr2_wavelet_color_transfer_matches_primary_source_reconstruction():
    """``wavelet_color_transfer`` must equal ``wavelet_reconstruction`` on the same inputs."""
    from comfy.ldm.seedvr import vae as seedvr_vae

    torch.manual_seed(0)
    sample_shape = (1, 3, 12, 16)
    # Uniform samples rescaled from [0, 1) to [-1, 1); content is drawn first.
    content = torch.rand(*sample_shape) * 2.0 - 1.0
    style = torch.rand(*sample_shape) * 2.0 - 1.0

    actual = seedvr_vae.wavelet_color_transfer(content, style)
    # Clones keep the reference call from seeing any in-place edits.
    reference = seedvr_vae.wavelet_reconstruction(content.clone(), style.clone())

    assert torch.equal(actual, reference)
|
||||
|
||||
|
||||
def test_seedvr2_adain_color_transfer_matches_huang_belongie_formula():
    """Compare ``adain_color_transfer`` against an inline reference computation.

    The reference normalises content by its per-(batch, channel) mean and
    population standard deviation (``correction=0``, eps 1e-5 added to the
    variance), rescales with the style statistics and clamps to [-1, 1].
    """
    from comfy.ldm.seedvr import vae as seedvr_vae

    torch.manual_seed(0)
    content = torch.rand(2, 3, 5, 7) * 2.0 - 1.0
    style = torch.rand(2, 3, 5, 7) * 2.0 - 1.0
    out = seedvr_vae.adain_color_transfer(content.clone(), style.clone())

    batch, channels = 2, 3
    eps = 1e-5

    def _channel_stats(tensor):
        # Mean and eps-stabilised population std over the flattened spatial axis.
        flat = tensor.float().reshape(batch, channels, -1)
        mean = flat.mean(dim=2).reshape(batch, channels, 1, 1)
        std = (flat.var(dim=2, correction=0) + eps).sqrt().reshape(batch, channels, 1, 1)
        return mean, std

    mu_c, sd_c = _channel_stats(content)
    mu_s, sd_s = _channel_stats(style)

    expected = ((content.float() - mu_c) / sd_c) * sd_s + mu_s
    expected = expected.clamp(-1.0, 1.0)
    assert torch.allclose(out, expected, atol=1e-6)
|
||||
|
||||
|
||||
def test_seedvr2_adain_single_pixel_uses_population_variance_without_nan():
    """For 1x1 inputs the result must be finite and exactly equal to ``style``."""
    from comfy.ldm.seedvr import vae as seedvr_vae

    # One pixel per channel, shape (1, 3, 1, 1).
    content = torch.tensor([[[[0.25]], [[-0.5]], [[0.75]]]], dtype=torch.float32)
    style = torch.tensor([[[[-0.25]], [[0.5]], [[-0.75]]]], dtype=torch.float32)

    result = seedvr_vae.adain_color_transfer(content, style)

    assert torch.isfinite(result).all()
    assert torch.equal(result, style)
|
||||
|
||||
|
||||
def test_seedvr2_adain_preserves_input_dtype():
    """float16 inputs to ``adain_color_transfer`` must produce a float16 result."""
    from comfy.ldm.seedvr import vae as seedvr_vae

    half = torch.float16
    content = (torch.rand(1, 3, 4, 4) * 2.0 - 1.0).to(half)
    style = (torch.rand(1, 3, 4, 4) * 2.0 - 1.0).to(half)

    result = seedvr_vae.adain_color_transfer(content, style)

    assert result.dtype == torch.float16
|
||||
|
||||
|
||||
def test_seedvr2_adain_resizes_mismatched_style_to_content_shape():
    """A 16x20 style with an 8x10 content must give a result shaped like the content."""
    from comfy.ldm.seedvr import vae as seedvr_vae

    content = torch.rand(1, 3, 8, 10) * 2.0 - 1.0
    larger_style = torch.rand(1, 3, 16, 20) * 2.0 - 1.0

    result = seedvr_vae.adain_color_transfer(content, larger_style)

    assert tuple(result.shape) == (1, 3, 8, 10)
|
||||
|
||||
|
||||
def test_seedvr2_post_processing_unknown_color_correction_method_raises():
    """Method ``"bogus"`` must raise a ValueError that names ``color_correction_method``."""
    decoded = torch.zeros(1, 2, 4, 4, 3)
    original = torch.zeros(1, 2, 4, 4, 3)

    raised = None
    try:
        nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 4, "bogus")
    except ValueError as exc:
        raised = exc

    if raised is None:
        raise AssertionError("expected ValueError for unknown color_correction_method")
    assert "color_correction_method" in str(raised)
|
||||
@ -0,0 +1,601 @@
|
||||
"""Regression tests for SeedVR2 conditioning model resolution and RoPE
|
||||
frequency cast.
|
||||
|
||||
Pin two behaviors:
|
||||
|
||||
1. ``_resolve_seedvr2_diffusion_model`` returns the inner diffusion-model
|
||||
for the expected ``model.model.diffusion_model`` shape and fails loud
|
||||
with a ``RuntimeError`` whose message begins with
|
||||
``_SEEDVR2_INVALID_MODEL_MSG_PREFIX`` for any other shape, including
|
||||
the four distinct missing-vs-None subcases of the chain.
|
||||
2. ``_apply_rope_freqs_float32_cast`` is idempotent **per-tensor by
|
||||
dtype check**, NOT per-instance by sentinel attribute. Every call
|
||||
walks the diffusion-model module tree and invokes ``.to(float32)``
|
||||
only on tensors whose dtype is not already ``float32``. A cache-by-
|
||||
attribute (sentinel) approach is rejected because the sentinel
|
||||
would survive ComfyUI's dynamic model unload/reload cycle while
|
||||
``rope.freqs`` itself is restored to the archived dtype, so the
|
||||
next call would short-circuit and leave RoPE running in fp16/bf16
|
||||
— the exact failure this helper is supposed to prevent. The dtype
|
||||
check is self-correcting against any weight-restore lifecycle
|
||||
event.
|
||||
|
||||
Import isolation: ``comfy.model_management`` is stubbed via direct
|
||||
``sys.modules`` assignment so importing ``comfy_extras.nodes_seedvr`` does
|
||||
not trigger GPU/server-side initialization. ``patch.dict`` is intentionally
|
||||
NOT used here because its snapshot/restore semantics evict transitively
|
||||
imported third-party modules (e.g. ``torchvision``) on exit, which causes
|
||||
``torch``'s global op-library Meta-key registrations to double-register on
|
||||
re-import. Module-level cached import + scoped restore of the four mocked
|
||||
entries avoids that hazard. See ``_import_nodes_seedvr_isolated``.
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import sys
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
# Unique "not present" marker used as the default for ``sys.modules.get`` and
# ``getattr`` lookups, so an absent entry can be told apart from one set to None.
_SENTINEL = object()
|
||||
|
||||
|
||||
def _import_nodes_seedvr_isolated():
    """Stub ``comfy.model_management``, import (or reuse a cached import of)
    ``comfy_extras.nodes_seedvr``, and return ``(module, restore)``.

    Four pieces of in-process import state are snapshotted before anything
    is modified, and ``restore()`` puts each of them back:

    1. ``sys.modules["comfy.model_management"]`` — replaced by the stub.
    2. The ``model_management`` attribute on the ``comfy`` package.
    3. ``sys.modules["comfy_extras.nodes_seedvr"]`` — the imported test
       target. If it stayed in ``sys.modules`` after the test, a later test
       importing ``comfy_extras.nodes_seedvr`` would get the copy that was
       bound to the stubbed ``comfy.model_management``.
    4. The ``nodes_seedvr`` attribute on the ``comfy_extras`` package.

    Each is restored verbatim if it was previously set and removed if it
    was previously unset.

    If installing the stub or importing the target raises, the same restore
    runs before the exception propagates. The caller never receives
    ``restore`` in that case, so without this a failed import would leave
    the stub installed for every later test.

    Returns:
        tuple: ``(nodes_seedvr_module, restore_callable)``.
    """
    # --- Snapshot prior state (``_SENTINEL`` means "was not present"). ---
    prior_comfy_mm = sys.modules.get("comfy.model_management", _SENTINEL)
    prior_comfy_mm_attr = _SENTINEL
    comfy_pkg = sys.modules.get("comfy")
    if comfy_pkg is not None:
        prior_comfy_mm_attr = getattr(comfy_pkg, "model_management", _SENTINEL)
    prior_nodes_seedvr_module = sys.modules.get(
        "comfy_extras.nodes_seedvr", _SENTINEL,
    )
    prior_nodes_seedvr_attr = _SENTINEL
    comfy_extras_pkg = sys.modules.get("comfy_extras")
    if comfy_extras_pkg is not None:
        prior_nodes_seedvr_attr = getattr(
            comfy_extras_pkg, "nodes_seedvr", _SENTINEL,
        )

    def _restore():
        # 1. comfy.model_management sys.modules entry
        if prior_comfy_mm is _SENTINEL:
            sys.modules.pop("comfy.model_management", None)
        else:
            sys.modules["comfy.model_management"] = prior_comfy_mm
        # 2. comfy.model_management package attribute on comfy
        comfy_pkg_now = sys.modules.get("comfy")
        if comfy_pkg_now is not None:
            if prior_comfy_mm_attr is _SENTINEL:
                if hasattr(comfy_pkg_now, "model_management"):
                    delattr(comfy_pkg_now, "model_management")
            else:
                setattr(comfy_pkg_now, "model_management", prior_comfy_mm_attr)
        # 3. comfy_extras.nodes_seedvr sys.modules entry
        if prior_nodes_seedvr_module is _SENTINEL:
            sys.modules.pop("comfy_extras.nodes_seedvr", None)
        else:
            sys.modules["comfy_extras.nodes_seedvr"] = prior_nodes_seedvr_module
        # 4. comfy_extras.nodes_seedvr package attribute on comfy_extras
        comfy_extras_pkg_now = sys.modules.get("comfy_extras")
        if comfy_extras_pkg_now is not None:
            if prior_nodes_seedvr_attr is _SENTINEL:
                if hasattr(comfy_extras_pkg_now, "nodes_seedvr"):
                    delattr(comfy_extras_pkg_now, "nodes_seedvr")
            else:
                setattr(
                    comfy_extras_pkg_now, "nodes_seedvr",
                    prior_nodes_seedvr_attr,
                )

    # ``comfy_extras.nodes_seedvr`` imports ``comfy.sample`` (added in PR
    # #59), which pulls in the samplers/k_diffusion/model_patcher chain.
    # That chain calls feature-detection predicates at module-init time
    # (``comfy/ldm/modules/attention.py`` ``xformers_enabled()``,
    # ``comfy/ldm/modules/diffusionmodules/model.py``
    # ``xformers_enabled_vae()``, etc.). A bare ``MagicMock()`` returns
    # truthy auto-attrs for those calls, which triggers real
    # ``import xformers`` calls that fail in the test environment. Pin the
    # boolean-returning predicates to ``False`` so the import chain follows
    # the no-extension path.
    mock_mm = MagicMock()
    mock_mm.xformers_enabled.return_value = False
    mock_mm.xformers_enabled_vae.return_value = False
    mock_mm.pytorch_attention_enabled.return_value = False
    mock_mm.pytorch_attention_enabled_vae.return_value = False
    mock_mm.sage_attention_enabled.return_value = False
    mock_mm.flash_attention_enabled.return_value = False
    torch_version_parts = torch.version.__version__.split(".")
    mock_mm.torch_version_numeric = (
        int(torch_version_parts[0]),
        int(torch_version_parts[1]),
    )
    mock_mm.WINDOWS = False
    mock_mm.is_intel_xpu.return_value = False

    try:
        sys.modules["comfy.model_management"] = mock_mm
        # The transitive import chain reaches code paths that do
        # ``comfy.model_management.<attr>`` (attribute access on the comfy
        # package, not a fresh import). Setting only ``sys.modules`` is not
        # enough — also bind the stub as the package attribute. If the
        # ``comfy`` package isn't imported yet at stub-time (cold first
        # run), importing it now is safe and idempotent.
        if comfy_pkg is None:
            importlib.import_module("comfy")
            comfy_pkg = sys.modules.get("comfy")
        if comfy_pkg is not None:
            setattr(comfy_pkg, "model_management", mock_mm)
        if "comfy_extras.nodes_seedvr" in sys.modules:
            nodes_seedvr = sys.modules["comfy_extras.nodes_seedvr"]
        else:
            nodes_seedvr = importlib.import_module("comfy_extras.nodes_seedvr")
    except BaseException:
        # Undo the stub before propagating; always re-raised, never swallowed.
        _restore()
        raise

    return nodes_seedvr, _restore
|
||||
|
||||
|
||||
class _Rope(nn.Module):
    """Minimal module exposing a 4-element ``freqs`` parameter initialised to zeros."""

    def __init__(self):
        super().__init__()
        initial_freqs = torch.zeros(4)
        self.freqs = nn.Parameter(initial_freqs)
|
||||
|
||||
|
||||
class _Block(nn.Module):
    """Minimal block whose only child module is a ``_Rope`` stored as ``rope``."""

    def __init__(self):
        super().__init__()
        rope_module = _Rope()
        self.rope = rope_module
|
||||
|
||||
|
||||
class _DiffusionModel(nn.Module):
    """Stand-in diffusion model: ``n_blocks`` ``_Block`` children plus two buffers.

    ``positive_conditioning`` is a (2, 4) buffer filled with zeros when
    ``zero_conditioning`` is true and with ones otherwise.
    ``negative_conditioning`` is always a (3, 4) buffer of zeros. Both use
    ``conditioning_dtype``.
    """

    def __init__(
        self,
        n_blocks=3,
        zero_conditioning=False,
        conditioning_dtype=torch.float32,
    ):
        super().__init__()
        self.blocks = nn.ModuleList([_Block() for _ in range(n_blocks)])

        # ``zero_conditioning`` simulates a numz-format DiT-only file loaded
        # via UNETLoader: ``register_buffer`` zero-init at
        # ``comfy/ldm/seedvr/model.py`` leaves the buffers at zero when
        # ``load_state_dict`` cannot find ``positive_conditioning`` /
        # ``negative_conditioning`` keys in the state_dict. The fail-loud
        # guard at ``SeedVR2Conditioning.execute`` distinguishes this from a
        # properly-baked file by ``abs().sum() == 0`` on both buffers.
        make_positive = torch.zeros if zero_conditioning else torch.ones

        # Only the positive buffer depends on the flag; registration order
        # (positive, then negative) is the same in both cases.
        self.register_buffer(
            "positive_conditioning",
            make_positive((2, 4), dtype=conditioning_dtype),
        )
        self.register_buffer(
            "negative_conditioning",
            torch.zeros((3, 4), dtype=conditioning_dtype),
        )
|
||||
|
||||
|
||||
class _ModelInner:
    """Stand-in for the object a patcher exposes as ``.model``."""

    def __init__(self, diffusion_model):
        # Stored as-is so tests can check identity of the resolved object.
        self.diffusion_model = diffusion_model
|
||||
|
||||
|
||||
class _ModelPatcher:
    """Stand-in patcher: wraps the diffusion model as ``self.model.diffusion_model``."""

    def __init__(self, diffusion_model):
        inner = _ModelInner(diffusion_model)
        self.model = inner
|
||||
|
||||
|
||||
def test_resolve_seedvr2_diffusion_model_returns_inner_when_valid():
    """The resolver hands back the exact ``diffusion_model`` nested in the patcher."""
    nodes_seedvr, restore = _import_nodes_seedvr_isolated()
    try:
        inner_model = _DiffusionModel()
        wrapped = _ModelPatcher(inner_model)
        assert nodes_seedvr._resolve_seedvr2_diffusion_model(wrapped) is inner_model
    finally:
        restore()
|
||||
|
||||
|
||||
def test_seedvr2_conditioning_schema_exposes_model_passthrough_output():
    """Pin the ``SeedVR2Conditioning`` schema: input ids and output display names."""
    nodes_seedvr, restore = _import_nodes_seedvr_isolated()
    try:
        schema = nodes_seedvr.SeedVR2Conditioning.define_schema()

        input_ids = [entry.id for entry in schema.inputs]
        assert input_ids == ["model", "vae_conditioning"]

        assert schema.inputs[1].display_name == "LATENT"

        output_names = [entry.display_name for entry in schema.outputs]
        assert output_names == ["model", "positive", "negative", "latent"]
    finally:
        restore()
|
||||
|
||||
|
||||
def test_seedvr2_conditioning_returns_packed_input_latent_deterministically():
    """Two identical ``execute`` calls must give the same packed tensors.

    The expected latent is the input samples reshaped from ``(1, 2, 3, 2, 2)``
    to ``(1, 6, 2, 2)``; the expected condition is the samples with one extra
    all-ones channel appended, reshaped to ``(1, 9, 2, 2)``.
    """
    nodes_seedvr, restore = _import_nodes_seedvr_isolated()
    try:
        patcher = _ModelPatcher(_DiffusionModel())
        samples = torch.arange(1, 25, dtype=torch.float32).reshape(1, 2, 3, 2, 2)
        vae_conditioning = {"samples": samples}

        # Run the node twice with identical inputs, keeping (positive,
        # negative, latent) from each run; the passthrough model is ignored.
        runs = []
        for _attempt in range(2):
            _model, positive, negative, latent = (
                nodes_seedvr.SeedVR2Conditioning.execute(patcher, vae_conditioning)
            )
            runs.append((positive, negative, latent))

        expected_latent = samples.reshape(1, 6, 2, 2)
        channel_last = samples.movedim(1, -1).contiguous()
        ones_channel = torch.ones((*channel_last.shape[:-1], 1))
        expected_condition = (
            torch.cat([channel_last, ones_channel], dim=-1)
            .movedim(-1, 1)
            .reshape(1, 9, 2, 2)
        )

        for _positive, _negative, latent in runs:
            assert torch.equal(latent["samples"], expected_latent)

        # Slot 0 is the positive conditioning, slot 1 the negative one.
        for slot in (0, 1):
            for run in runs:
                assert torch.equal(
                    run[slot][0][1]["condition"],
                    expected_condition,
                )
    finally:
        restore()
|
||||
|
||||
|
||||
def test_resolve_seedvr2_diffusion_model_raises_runtime_error_with_specific_prefix():
    """Pin all four failure modes of the resolver chain to the same error
    prefix and to message text that distinguishes 'attribute missing'
    from 'attribute present but None'. The four modes:

      mode 1: input has no 'model' attribute
      mode 2: input.model is None
      mode 3: 'model.model' has no 'diffusion_model' attribute
      mode 4: 'model.model.diffusion_model' is None
    """
    nodes_seedvr, restore = _import_nodes_seedvr_isolated()
    try:
        class _NoModelAttr:
            """Mode 1: no 'model' attribute at all."""

        class _ModelIsNone:
            """Mode 2: 'model' exists but is None."""

            def __init__(self):
                self.model = None

        class _NoDiffusionAttr:
            """Mode 3: 'model' exists, has no 'diffusion_model' attribute."""

            def __init__(self):
                self.model = object()

        class _DiffusionIsNoneInner:
            def __init__(self):
                self.diffusion_model = None

        class _DiffusionIsNone:
            """Mode 4: 'model.diffusion_model' exists but is None."""

            def __init__(self):
                self.model = _DiffusionIsNoneInner()

        # Each "present but None" mode must not be conflated with the
        # matching "attribute missing" mode, hence the distinct fragments.
        failure_modes = (
            (_NoModelAttr, "no 'model' attribute"),
            (_ModelIsNone, "input.model is None"),
            (_NoDiffusionAttr, "no 'diffusion_model' attribute"),
            (_DiffusionIsNone, "'model.model.diffusion_model' is None"),
        )
        for make_input, expected_fragment in failure_modes:
            with pytest.raises(RuntimeError) as excinfo:
                nodes_seedvr._resolve_seedvr2_diffusion_model(make_input())
            msg = str(excinfo.value)
            assert msg.startswith(nodes_seedvr._SEEDVR2_INVALID_MODEL_MSG_PREFIX)
            assert expected_fragment in msg
    finally:
        restore()
|
||||
|
||||
|
||||
def test_apply_rope_freqs_float32_cast_idempotent_on_unchanged_dtype():
    """Calling the helper twice on a model whose rope.freqs is already
    float32 must NOT re-allocate the tensor storage — the dtype check on
    every nested module short-circuits the .to() call when the tensor is
    already in float32.

    Storage identity is compared through ``data_ptr()``. ``Tensor.data``
    builds a fresh tensor object on every access, so ``id(freqs.data)`` is
    the id of a temporary and cannot tell whether storage was re-allocated.
    """
    nodes_seedvr, restore = _import_nodes_seedvr_isolated()
    try:
        diffusion_model = _DiffusionModel()

        def _rope_owners():
            # Re-walk the module tree on each use, as the helper itself does.
            return [
                m for m in diffusion_model.modules()
                if hasattr(m, "rope") and hasattr(m.rope, "freqs")
            ]

        # Starting dtype is non-float32 so the first call has work to do.
        for module in _rope_owners():
            module.rope.freqs.data = module.rope.freqs.data.to(torch.float64)

        nodes_seedvr._apply_rope_freqs_float32_cast(diffusion_model)
        first_call_data_ptrs = []
        for module in _rope_owners():
            assert module.rope.freqs.data.dtype == torch.float32
            first_call_data_ptrs.append(module.rope.freqs.data.data_ptr())
        # Guard against a vacuous pass if the mock ever loses its rope modules.
        assert first_call_data_ptrs, "mock diffusion model exposes no rope.freqs"

        # Second call on the same already-float32 model: every per-tensor
        # dtype check sees float32 and skips the .to() call. Tensor storage
        # must be preserved (no re-allocation).
        nodes_seedvr._apply_rope_freqs_float32_cast(diffusion_model)
        for module, prior_ptr in zip(
            _rope_owners(),
            first_call_data_ptrs,
            strict=True,
        ):
            assert module.rope.freqs.data.dtype == torch.float32
            assert module.rope.freqs.data.data_ptr() == prior_ptr, (
                "Already-float32 rope.freqs must not be re-allocated on "
                "subsequent calls; the per-tensor dtype check must skip the "
                ".to(float32) call when the tensor is already in float32."
            )
    finally:
        restore()
|
||||
|
||||
|
||||
def test_apply_rope_freqs_float32_cast_recovers_after_dtype_reset():
    """After a model unload/reload that restores rope.freqs from an
    archived non-float32 dtype, the next call must re-cast to float32.
    A bool-sentinel cache approach would short-circuit here and leave
    RoPE running in fp16/bf16.
    """
    nodes_seedvr, restore = _import_nodes_seedvr_isolated()
    try:
        diffusion_model = _DiffusionModel()

        def _all_rope_freqs():
            # Fresh walk each time so any module swap would be observed.
            return [
                m.rope.freqs for m in diffusion_model.modules()
                if hasattr(m, "rope") and hasattr(m.rope, "freqs")
            ]

        def _force_float64():
            for freqs in _all_rope_freqs():
                freqs.data = freqs.data.to(torch.float64)

        _force_float64()

        # First call casts to float32.
        nodes_seedvr._apply_rope_freqs_float32_cast(diffusion_model)
        for freqs in _all_rope_freqs():
            assert freqs.data.dtype == torch.float32

        # Simulate a Comfy dynamic unload/reload that restores rope.freqs
        # to the archived (non-float32) dtype.
        _force_float64()

        # Second call must detect the dtype regression and re-cast.
        nodes_seedvr._apply_rope_freqs_float32_cast(diffusion_model)
        for freqs in _all_rope_freqs():
            assert freqs.data.dtype == torch.float32, (
                "After a model unload/reload that resets rope.freqs to "
                "non-float32, the next _apply_rope_freqs_float32_cast "
                "call MUST re-cast to float32. A bool-sentinel cache "
                "would have short-circuited here."
            )
    finally:
        restore()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Fail-loud guard: zero-valued conditioning buffers
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_seedvr2_conditioning_fails_loud_on_zero_buffers():
    """A SeedVR2 model whose ``positive_conditioning`` AND
    ``negative_conditioning`` buffers are both zero-valued is an
    unrecoverable load state — a numz-format DiT-only ``.safetensors``
    file was loaded via ``UNETLoader`` without the SeedVR2 conditioning
    keys baked in. ``SeedVR2Conditioning.execute`` must raise
    ``RuntimeError`` carrying the standard SeedVR2 invalid-model prefix
    instead of letting the diffusion sampler run on null prompt
    conditioning (which silently produces wrong output).
    """
    nodes_seedvr, restore = _import_nodes_seedvr_isolated()
    try:
        patcher = _ModelPatcher(_DiffusionModel(zero_conditioning=True))
        vae_conditioning = {"samples": torch.zeros((1, 2, 1, 1, 1))}

        with pytest.raises(RuntimeError) as excinfo:
            nodes_seedvr.SeedVR2Conditioning.execute(
                patcher, vae_conditioning,
            )

        message = str(excinfo.value)
        has_prefix = message.startswith(
            nodes_seedvr._SEEDVR2_INVALID_MODEL_MSG_PREFIX
        )
        assert has_prefix, (
            "Fail-loud message must use the standard "
            "_SEEDVR2_INVALID_MODEL_MSG_PREFIX so callers/log scrapers "
            f"can match it. Got: {message!r}"
        )
        # Both buffer names must be spelled out in the diagnostic.
        for buffer_name in ("positive_conditioning", "negative_conditioning"):
            assert buffer_name in message
    finally:
        restore()
|
||||
|
||||
|
||||
def test_seedvr2_conditioning_fails_loud_on_fp8_zero_buffers():
    """The zero-buffer sentinel must reduce fp8 conditioning tensors
    without hitting PyTorch's unsupported float8 reductions.
    """
    # Older torch builds have no float8 dtype; nothing to exercise there.
    fp8_dtype = getattr(torch, "float8_e4m3fn", None)
    if fp8_dtype is None:
        pytest.skip("torch build does not expose float8_e4m3fn")

    nodes_seedvr, restore = _import_nodes_seedvr_isolated()
    try:
        fp8_model = _DiffusionModel(
            zero_conditioning=True,
            conditioning_dtype=fp8_dtype,
        )
        patcher = _ModelPatcher(fp8_model)
        vae_conditioning = {"samples": torch.zeros((1, 2, 1, 1, 1))}

        with pytest.raises(RuntimeError) as excinfo:
            nodes_seedvr.SeedVR2Conditioning.execute(
                patcher, vae_conditioning,
            )

        message = str(excinfo.value)
        expected_prefix = nodes_seedvr._SEEDVR2_INVALID_MODEL_MSG_PREFIX
        assert message.startswith(expected_prefix)
        assert "zero-valued" in message
    finally:
        restore()
|
||||
|
||||
|
||||
def test_seedvr2_conditioning_does_not_fire_on_partial_zero_buffers():
    """The guard checks BOTH buffers together: a model with zero
    ``negative_conditioning`` but non-zero ``positive_conditioning``
    (the existing baseline mock fixture) must NOT trigger the fail-loud
    path. This pins the AND-gating semantic and prevents a future
    regression to OR-gating from rejecting valid bundled checkpoints
    where one buffer happens to be all-zeros.
    """
    nodes_seedvr, restore = _import_nodes_seedvr_isolated()
    try:
        # Baseline _DiffusionModel has positive=ones, negative=zeros.
        patcher = _ModelPatcher(_DiffusionModel(zero_conditioning=False))
        vae_conditioning = {"samples": torch.zeros((1, 2, 1, 1, 1))}

        # Should not raise.
        result = nodes_seedvr.SeedVR2Conditioning.execute(
            patcher, vae_conditioning,
        )
        passthrough_model, positive, negative, latent = result

        assert positive[0][0].shape == (1, 2, 4)
        assert negative[0][0].shape == (1, 3, 4)
        assert passthrough_model is patcher
    finally:
        restore()
|
||||
|
||||
|
||||
def test_seedvr2_conditioning_fail_loud_never_exposes_safetensors_path():
    """The fail-loud message must not expose local model paths from
    ``cached_patcher_init``. Public runtime errors should describe the
    invalid SeedVR2 contract without making filesystem paths part of the
    public behavior contract.
    """
    nodes_seedvr, restore = _import_nodes_seedvr_isolated()
    try:
        patcher = _ModelPatcher(_DiffusionModel(zero_conditioning=True))
        # Mimic the ``cached_patcher_init`` shape comfy.sd attaches:
        # (function reference, args tuple holding the checkpoint path).
        patcher.cached_patcher_init = (
            object(),
            ("/some/models/diffusion_models/seedvr2_ema_7b_fp16.safetensors",),
        )
        vae_conditioning = {"samples": torch.zeros((1, 2, 1, 1, 1))}

        with pytest.raises(RuntimeError) as excinfo:
            nodes_seedvr.SeedVR2Conditioning.execute(
                patcher, vae_conditioning,
            )

        message = str(excinfo.value)
        # Nothing path-like may leak into the message.
        for leaked in (
            "/some/models/diffusion_models",
            "seedvr2_ema_7b_fp16.safetensors",
            "Source file:",
        ):
            assert leaked not in message
        # The diagnostic itself must still name both buffers.
        for required in ("positive_conditioning", "negative_conditioning"):
            assert required in message
    finally:
        restore()
|
||||
|
||||
|
||||
def test_seedvr2_conditioning_fail_loud_falls_back_when_path_unavailable():
    """When ``cached_patcher_init`` is missing or its tuple does not
    contain a ``.safetensors`` path, the fail-loud message still
    delivers the actionable diagnostic without leaking ``None`` or
    raising during message formatting.
    """
    nodes_seedvr, restore = _import_nodes_seedvr_isolated()
    try:
        # No cached_patcher_init is set on this patcher.
        patcher = _ModelPatcher(_DiffusionModel(zero_conditioning=True))
        vae_conditioning = {"samples": torch.zeros((1, 2, 1, 1, 1))}

        with pytest.raises(RuntimeError) as excinfo:
            nodes_seedvr.SeedVR2Conditioning.execute(
                patcher, vae_conditioning,
            )

        message = str(excinfo.value)
        assert "Source file:" not in message  # no empty path leak
        assert "Re-bake" in message  # actionable guidance still present
        assert "bf16 keys" not in message
    finally:
        restore()
|
||||
103
tests-unit/comfy_extras_test/test_seedvr_node_signature.py
Normal file
103
tests-unit/comfy_extras_test/test_seedvr_node_signature.py
Normal file
@ -0,0 +1,103 @@
|
||||
"""Regression test: SeedVR2 resize schema input ids must match
execute() positional parameter order. Drift between the two would silently
swap arguments at runtime; this test fails loudly on any future drift.

The schema input attribute is `.id` (verified live via Python introspection
on the upstream class -- there is no `.name`).

`comfy.model_management` is stubbed via `patch.dict(sys.modules, ...)` for
the import performed inside this test, so importing
`comfy_extras.nodes_seedvr` here does not call
`torch.cuda.is_available()` or trigger other GPU/server-side
initialization through that dependency. Live introspection indicated that
`comfy_extras.nodes_seedvr` pulls in `comfy.model_management`
transitively here (not `nodes`, not `server`).

The test snapshots three pieces of import state before patching and
restores all three in `finally` via a sentinel:

  1. `sys.modules["comfy_extras.nodes_seedvr"]`
  2. `comfy.model_management` package attribute on the `comfy` package
  3. `comfy_extras.nodes_seedvr` attribute on the `comfy_extras` package

If any of the three was set before the test, it is restored verbatim;
if it was unset, it is deleted on exit. This prevents the test from
clobbering a real `comfy.model_management` (or
`comfy_extras.nodes_seedvr`) module that another test may have
legitimately imported earlier in the same pytest process, while still
preventing the test's mock from leaking into later tests that import
the real `comfy_extras.nodes_seedvr`."""
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
import sys
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from comfy.cli_args import args as cli_args
|
||||
|
||||
|
||||
def test_seedvr_node_signature_matches_schema():
    """Schema input ids of both resize nodes must equal ``execute()`` parameter order.

    ``comfy_extras.nodes_seedvr`` is imported fresh while a ``MagicMock``
    stands in for ``comfy.model_management``. ``cli_args.cpu``, the
    ``sys.modules`` entry and the two package attributes are put back to
    their prior state in ``finally``.
    """
    sentinel = object()

    # Every attention-backend probe on the mock reports "disabled".
    mock_model_management = MagicMock()
    for probe_name in (
        "xformers_enabled",
        "xformers_enabled_vae",
        "sage_attention_enabled",
        "flash_attention_enabled",
    ):
        getattr(mock_model_management, probe_name).return_value = False

    def _snapshot_attr(package, attr_name):
        # ``sentinel`` means "package absent or attribute unset".
        if package is None:
            return sentinel
        return getattr(package, attr_name, sentinel)

    def _restore_attr(package_name, attr_name, prior_value):
        package = sys.modules.get(package_name)
        if package is None:
            return
        if prior_value is not sentinel:
            setattr(package, attr_name, prior_value)
        elif hasattr(package, attr_name):
            delattr(package, attr_name)

    prior_cpu = cli_args.cpu
    cli_args.cpu = True

    comfy_module_pre = sys.modules.get("comfy")
    comfy_extras_module_pre = sys.modules.get("comfy_extras")
    prior_comfy_mm_attr = _snapshot_attr(comfy_module_pre, "model_management")
    prior_comfy_extras_seedvr_attr = _snapshot_attr(
        comfy_extras_module_pre, "nodes_seedvr"
    )
    prior_comfy_extras_seedvr_module = sys.modules.get(
        "comfy_extras.nodes_seedvr", sentinel
    )

    with patch.dict(sys.modules, {"comfy.model_management": mock_model_management}):
        if comfy_module_pre is not None:
            comfy_module_pre.model_management = mock_model_management
        # Drop any cached module so the import below runs against the mock.
        sys.modules.pop("comfy_extras.nodes_seedvr", None)
        try:
            nodes_seedvr = importlib.import_module("comfy_extras.nodes_seedvr")
            node_classes = (
                nodes_seedvr.SeedVR2Resize,
                nodes_seedvr.SeedVR2ResizeAdvanced,
            )
            for node_cls in node_classes:
                schema_ids = [item.id for item in node_cls.define_schema().inputs]
                signature = inspect.signature(node_cls.execute)
                exec_params = [
                    name for name in signature.parameters.keys() if name != "cls"
                ]
                assert schema_ids == exec_params, (
                    f"{node_cls.__name__} schema input ids do not match "
                    f"execute() parameter order: schema_ids={schema_ids}, "
                    f"exec_params={exec_params}"
                )
        finally:
            if prior_comfy_extras_seedvr_module is sentinel:
                sys.modules.pop("comfy_extras.nodes_seedvr", None)
            else:
                sys.modules["comfy_extras.nodes_seedvr"] = prior_comfy_extras_seedvr_module
            cli_args.cpu = prior_cpu
            _restore_attr("comfy_extras", "nodes_seedvr", prior_comfy_extras_seedvr_attr)
            _restore_attr("comfy", "model_management", prior_comfy_mm_attr)
|
||||
@ -0,0 +1,40 @@
|
||||
import ast
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# Directory three levels up from this file's resolved path.
ROOT = Path(__file__).resolve().parents[2]

# Source files scanned by the hidden-state audit below.
FILES = [
    ROOT / relative_path
    for relative_path in (
        "comfy/ldm/seedvr/vae.py",
        "comfy/sd.py",
        "comfy_extras/nodes_seedvr.py",
    )
]

# Attribute names that must not appear as ``obj.<name>`` or as string constants.
FORBIDDEN_ATTRS = {"original_image_video", "img_dims", "tiled_args"}

# String constants that must not appear anywhere in the scanned files.
FORBIDDEN_KEYS = {
    "sampler_metadata",
    "latent_sidecar_metadata",
    "saved_latent_metadata",
    "workflow_hidden_state",
}

# Names that must not be used as the second argument of getattr/setattr/delattr.
FORBIDDEN_GETSET_KEYS = {"original_image_video", "img_dims", "tiled_args"}
|
||||
|
||||
|
||||
def test_seedvr2_decode_paths_do_not_use_hidden_vae_object_state():
    """Statically scan ``FILES`` for forbidden hidden-state names.

    A file fails on the first AST node that is a forbidden attribute
    access, a ``getattr``/``setattr``/``delattr`` call whose second argument
    is a forbidden constant, or a forbidden string constant.
    """
    accessor_names = {"getattr", "setattr", "delattr"}

    def _violation(node):
        # Returns the failure text for ``node``, or None when it is clean.
        if isinstance(node, ast.Attribute) and node.attr in FORBIDDEN_ATTRS:
            return f"forbidden VAE object state attr {node.attr}"
        if (
            isinstance(node, ast.Call)
            and isinstance(node.func, ast.Name)
            and node.func.id in accessor_names
            and len(node.args) >= 2
        ):
            key = node.args[1]
            if isinstance(key, ast.Constant) and key.value in FORBIDDEN_GETSET_KEYS:
                return f"forbidden VAE object state access {key.value}"
        if isinstance(node, ast.Constant) and isinstance(node.value, str):
            if node.value in FORBIDDEN_ATTRS or node.value in FORBIDDEN_KEYS:
                return f"forbidden hidden-state string {node.value}"
        return None

    for path in FILES:
        tree = ast.parse(path.read_text(encoding="utf-8"))
        for node in ast.walk(tree):
            problem = _violation(node)
            if problem is not None:
                pytest.fail(f"{path}: {problem}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Running the file directly executes just this module under pytest and
    # exits with pytest's status code.
    exit_status = pytest.main([__file__])
    raise SystemExit(exit_status)
|
||||
43
tests-unit/comfy_test/test_seedvr2_non_goal_static_audit.py
Normal file
43
tests-unit/comfy_test/test_seedvr2_non_goal_static_audit.py
Normal file
@ -0,0 +1,43 @@
|
||||
import os
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# Directory three levels up from this file's resolved path; passed to ``git -C``.
ROOT = Path(__file__).resolve().parents[2]

# Files this audit expects to remain unchanged in the working tree and index.
FORBIDDEN_FILES = {
    "comfy/ldm/seedvr/model.py",
    "comfy/ldm/modules/attention.py",
    "comfy/sample.py",
    "comfy/samplers.py",
}

# The audit is opt-in: every test in this module is skipped unless the
# environment variable is set to exactly "1".
pytestmark = pytest.mark.skipif(
    os.environ.get("SEEDVR2_NON_GOAL_STATIC_AUDIT") != "1",
    reason="SEEDVR2_NON_GOAL_STATIC_AUDIT=1 is required for git-index audit execution.",
)
|
||||
|
||||
|
||||
def _git_changed_paths(*args):
    """Return the set of paths reported by ``git diff --name-only`` under ``ROOT``.

    Args:
        *args: Extra arguments appended to the ``git diff`` command line
            (for example ``"--cached"``).

    Returns:
        A set with one entry per line of the command's standard output.

    The calling test is skipped, not failed, when git cannot be used: either
    the executable cannot be launched at all (``OSError``, which covers a
    missing ``git`` binary) or the command exits with a non-zero status.
    """
    command = ["git", "-C", str(ROOT), "diff", "--name-only", *args]
    try:
        result = subprocess.run(
            command,
            text=True,
            capture_output=True,
            check=False,
        )
    except OSError as exc:
        # No git executable (or it cannot be spawned): same outcome as a
        # failing ``git diff`` -- the audit cannot run here.
        pytest.skip(f"git diff unavailable: {exc}")
    if result.returncode != 0:
        pytest.skip(f"git diff unavailable: {result.stderr.strip()}")
    return set(result.stdout.splitlines())
|
||||
|
||||
|
||||
def test_seedvr2_non_goal_files_are_not_dirty():
    """None of ``FORBIDDEN_FILES`` may show up in the unstaged or staged git diff."""
    unstaged = _git_changed_paths()
    staged = _git_changed_paths("--cached")
    changed_forbidden = sorted(FORBIDDEN_FILES & (unstaged | staged))
    if changed_forbidden:
        pytest.fail(f"forbidden non-goal files changed: {changed_forbidden}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Running the file directly executes just this module under pytest and
    # exits with pytest's status code.
    exit_status = pytest.main([__file__])
    raise SystemExit(exit_status)
|
||||
@ -0,0 +1,110 @@
|
||||
import torch
|
||||
|
||||
from comfy.cli_args import args as cli_args
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
cli_args.cpu = True
|
||||
|
||||
import comfy_extras.nodes_seedvr as nodes_seedvr # noqa: E402
|
||||
|
||||
|
||||
def test_resize_simple_multiplier_resolves_upscaled_shorter_edge():
    """A 4x multiplier on a (1, 3, 16, 20, 3) zero input yields (1, 5, 64, 80, 3) zeros and edge 64."""
    images = torch.zeros(1, 3, 16, 20, 3)

    result = nodes_seedvr.SeedVR2Resize.execute(images, 4.0).result
    input_pixels, original_image, upscaled_shorter_edge = result

    assert tuple(input_pixels.shape) == (1, 5, 64, 80, 3)
    # The output must stay all-zero, like the input.
    lowest = input_pixels.min().item()
    highest = input_pixels.max().item()
    assert lowest == 0.0
    assert highest == 0.0
    assert original_image is images
    assert upscaled_shorter_edge == 64
|
||||
|
||||
|
||||
def test_resize_simple_silent_spatial_padding_keeps_unpadded_edge_output():
    """With multiplier 7.5 on 16x16 input the pixels are 128x128 while the reported edge is 120."""
    images = torch.zeros(1, 1, 16, 16, 3)

    result = nodes_seedvr.SeedVR2Resize.execute(images, 7.5).result
    input_pixels, original_image, upscaled_shorter_edge = result

    assert tuple(input_pixels.shape) == (1, 1, 128, 128, 3)
    assert original_image is images
    assert upscaled_shorter_edge == 120
|
||||
|
||||
|
||||
def test_resize_simple_rejects_non_positive_multiplier():
    """A multiplier of 0.0 must raise ``ValueError`` mentioning ``multiplier must be > 0``."""
    images = torch.zeros(1, 1, 16, 16, 3)

    rejection = None
    try:
        nodes_seedvr.SeedVR2Resize.execute(images, 0.0)
    except ValueError as err:
        rejection = err

    if rejection is None:
        raise AssertionError("non-positive multiplier was not rejected")
    assert "multiplier must be > 0" in str(rejection)
|
||||
|
||||
|
||||
def test_resize_simple_rejects_multiplier_resolving_to_too_small_edge():
    """A multiplier of 0.01 on 16x16 input must raise ``ValueError`` about the resolved edge size."""
    images = torch.zeros(1, 1, 16, 16, 3)

    rejection = None
    try:
        nodes_seedvr.SeedVR2Resize.execute(images, 0.01)
    except ValueError as err:
        rejection = err

    if rejection is None:
        raise AssertionError("too-small resolved edge was not rejected")
    text = str(rejection)
    assert "multiplier resolved upscaled_shorter_edge" in text
    assert "at least 2 pixels" in text
|
||||
|
||||
|
||||
def test_resize_advanced_takes_exact_shorter_edge():
    """The advanced node reports the requested edge (120) while pixels come out 128x128."""
    images = torch.zeros(1, 1, 16, 16, 3)

    result = nodes_seedvr.SeedVR2ResizeAdvanced.execute(images, 120).result
    input_pixels, original_image, upscaled_shorter_edge = result

    assert tuple(input_pixels.shape) == (1, 1, 128, 128, 3)
    assert original_image is images
    assert upscaled_shorter_edge == 120
|
||||
|
||||
|
||||
def test_resize_advanced_treats_4d_image_as_one_video_frame_sequence():
    """A 4-D (2, 16, 16, 3) input comes back as a single 5-frame sequence of 128x128 pixels."""
    images = torch.zeros(2, 16, 16, 3)

    result = nodes_seedvr.SeedVR2ResizeAdvanced.execute(images, 120).result
    input_pixels, original_image, upscaled_shorter_edge = result

    assert tuple(input_pixels.shape) == (1, 5, 128, 128, 3)
    assert original_image is images
    assert upscaled_shorter_edge == 120
|
||||
|
||||
|
||||
def test_resize_advanced_rejects_one_pixel_shorter_edge():
    """A requested shorter edge of 1 must raise ``ValueError`` naming the 2-pixel minimum."""
    images = torch.zeros(1, 1, 16, 16, 3)

    rejection = None
    try:
        nodes_seedvr.SeedVR2ResizeAdvanced.execute(images, 1)
    except ValueError as err:
        rejection = err

    if rejection is None:
        raise AssertionError("one-pixel shorter_edge was not rejected")
    assert "upscaled_shorter_edge must be at least 2 pixels" in str(rejection)
|
||||
|
||||
|
||||
def test_resize_node_schemas_and_execute_signatures_are_preprocess_only():
    """Pin input/output ids and selected input settings of both resize node schemas."""
    expected_outputs = ["input_pixels", "original_image", "upscaled_shorter_edge"]

    def _ids(items):
        return [item.id for item in items]

    simple = nodes_seedvr.SeedVR2Resize.define_schema()
    advanced = nodes_seedvr.SeedVR2ResizeAdvanced.define_schema()

    assert _ids(simple.inputs) == ["images", "multiplier"]
    assert simple.inputs[1].default == 4.0
    assert _ids(simple.outputs) == expected_outputs

    assert _ids(advanced.inputs) == ["images", "shorter_edge"]
    shorter_edge_input = advanced.inputs[1]
    assert shorter_edge_input.min == 2
    assert shorter_edge_input.step is None
    assert _ids(advanced.outputs) == expected_outputs
|
||||
@ -0,0 +1,38 @@
|
||||
import io
|
||||
|
||||
import torch
|
||||
|
||||
from comfy.cli_args import args as cli_args
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
cli_args.cpu = True
|
||||
|
||||
from comfy_extras import nodes_seedvr # noqa: E402
|
||||
import nodes as nodes_mod # noqa: E402
|
||||
|
||||
|
||||
class _DecodeOnlyVAE:
    """VAE stub that only implements ``decode`` and counts how often it is called."""

    def __init__(self):
        self.decode_calls = 0

    def decode(self, latent):
        """Return a constant 0.25 tensor shaped ``(b, tc // 16, h * 8, w * 8, 3)``.

        ``latent`` is read as ``(b, tc, h, w)``.
        """
        self.decode_calls += 1
        batch, packed_channels, height, width = latent.shape
        frames = packed_channels // 16
        out_shape = (batch, frames, height * 8, width * 8, 3)
        return torch.full(out_shape, 0.25)
|
||||
|
||||
|
||||
def test_saved_loaded_seedvr2_latent_decode_boundary_does_not_rerun_preprocessing():
    """A latent round-tripped through ``torch.save``/``torch.load`` decodes with exactly one ``decode`` call."""
    latent = {"samples": torch.zeros(1, 32, 4, 5)}

    # Serialize and reload the samples through an in-memory buffer.
    buffer = io.BytesIO()
    torch.save(latent["samples"], buffer)
    buffer.seek(0)
    reloaded_samples = torch.load(buffer, weights_only=True)
    loaded = {"samples": reloaded_samples}

    vae = _DecodeOnlyVAE()
    decoded = nodes_mod.VAEDecode().decode(vae, loaded)[0]
    original = torch.full((1, 2, 32, 40, 3), 0.75)
    post_processed = nodes_seedvr.SeedVR2PostProcessing.execute(
        decoded, original, 32, "none"
    )
    output = post_processed.result[0]

    assert vae.decode_calls == 1
    assert tuple(output.shape) == (2, 32, 40, 3)
|
||||
210
tests-unit/comfy_test/test_seedvr2_vae_graph_boundaries.py
Normal file
210
tests-unit/comfy_test/test_seedvr2_vae_graph_boundaries.py
Normal file
@ -0,0 +1,210 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from comfy.cli_args import args as cli_args
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
cli_args.cpu = True
|
||||
|
||||
import comfy.ldm.seedvr.vae as seedvr_vae_mod # noqa: E402
|
||||
import comfy.sd as sd_mod # noqa: E402
|
||||
import nodes as nodes_mod # noqa: E402
|
||||
|
||||
|
||||
class _Patcher:
    """Patcher stub that always reports 1 GiB of free memory."""

    def get_free_memory(self, device):
        # ``device`` is accepted for signature compatibility and ignored.
        one_gib = 1024 * 1024 * 1024
        return one_gib
|
||||
|
||||
|
||||
class _EncodeWrapper(seedvr_vae_mod.VideoAutoencoderKLWrapper):
    """Wrapper stub whose ``encode`` returns a preset tensor and records input shapes."""

    def __init__(self, encoded):
        # Call ``nn.Module.__init__`` directly, bypassing the parent
        # wrapper's own initializer.
        nn.Module.__init__(self)
        self.encoded = encoded
        self.spatial_downsample_factor, self.temporal_downsample_factor = 8, 4
        self.seen = []

    def encode(self, x):
        """Record ``x``'s shape and return the preset tensor on ``x``'s device and dtype."""
        observed_shape = tuple(x.shape)
        self.seen.append(observed_shape)
        return self.encoded.to(device=x.device, dtype=x.dtype)
|
||||
|
||||
|
||||
class _DecodeWrapper(seedvr_vae_mod.VideoAutoencoderKLWrapper):
    """Wrapper stub whose ``decode`` records each call and returns zeros."""

    def __init__(self):
        # Call ``nn.Module.__init__`` directly, bypassing the parent
        # wrapper's own initializer.
        nn.Module.__init__(self)
        self.spatial_downsample_factor, self.temporal_downsample_factor = 8, 4
        self.calls = []

    def decode(self, z, seedvr2_tiling=None):
        """Log the call and return zeros shaped ``(b, 3, t, h * 8, w * 8)``.

        A 4-D ``z`` is read as ``(b, tc, h, w)`` with ``t = tc // 16``; any
        other input is unpacked as ``(b, _, t, h, w)``.
        """
        self.calls.append({"shape": tuple(z.shape), "seedvr2_tiling": seedvr2_tiling})
        if z.ndim != 4:
            batch, _, frames, height, width = z.shape
        else:
            batch, packed_channels, height, width = z.shape
            frames = packed_channels // 16
        return torch.zeros(
            batch, 3, frames, height * 8, width * 8,
            dtype=z.dtype, device=z.device,
        )
|
||||
|
||||
|
||||
def _make_vae(wrapper):
    """Build an ``sd_mod.VAE`` around ``wrapper`` without running ``VAE.__init__``.

    The instance is created with ``__new__`` and every attribute is assigned
    by hand: CPU devices, float32 dtypes, 16 latent channels, 8x spatial and
    4x temporal ratios, and no-op or constant lambdas for the hooks.

    Args:
        wrapper: object installed as ``vae.first_stage_model``.

    Returns:
        The hand-populated ``sd_mod.VAE`` instance.
    """
    vae = sd_mod.VAE.__new__(sd_mod.VAE)
    vae.first_stage_model = wrapper
    vae.device = torch.device("cpu")
    vae.output_device = torch.device("cpu")
    vae.vae_dtype = torch.float32
    vae.latent_channels = 16
    vae.latent_dim = 3
    # First entry maps a frame count through the 4x temporal ratio; the
    # remaining two entries are the 8x spatial ratios.
    vae.downscale_ratio = (lambda a: max(0, (a + 3) // 4), 8, 8)
    vae.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
    vae.output_channels = 3
    vae.disable_offload = True
    vae.extra_1d_channel = None
    vae.crop_input = False
    vae.not_video = False
    vae.patcher = _Patcher()
    vae.process_input = lambda image: image
    # Maps values from [-1, 1] to [0, 1] and clamps.
    vae.process_output = lambda image: image.add(1.0).div(2.0).clamp(0.0, 1.0)
    vae.vae_output_dtype = lambda: torch.float32
    vae.memory_used_encode = lambda shape, dtype: 1
    vae.memory_used_decode = lambda shape, dtype: 1
    vae.throw_exception_if_invalid = lambda: None
    vae.vae_encode_crop_pixels = lambda pixels: pixels
    vae.spacial_compression_decode = lambda: 8
    vae.temporal_compression_decode = lambda: 4
    return vae
|
||||
|
||||
|
||||
def test_seedvr2_encode_and_encode_tiled_preserve_native_latent_contract(monkeypatch):
    """VAEEncode and VAEEncodeTiled return only ``samples`` with the native latent shape.

    Both paths are expected to yield a float32 ``(1, 16, 2, 4, 5)`` latent
    whose values equal the stubbed encoder output times 0.9152 (presumably
    the SeedVR2 latent scaling factor -- confirm against ``comfy/sd.py``).
    """
    # Model loading is stubbed out so no device management happens.
    monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None)

    encoded = torch.full((1, 16, 2, 4, 5), 2.0)
    vae = _make_vae(_EncodeWrapper(encoded))
    pixels = torch.zeros(1, 5, 32, 40, 3)

    node_output = nodes_mod.VAEEncode().encode(vae, pixels)[0]
    node_latent = node_output["samples"]
    assert set(node_output) == {"samples"}
    assert tuple(node_latent.shape) == (1, 16, 2, 4, 5)
    assert node_latent.dtype == torch.float32
    assert node_latent.stride()[-1] == 1
    assert torch.equal(node_latent, torch.full_like(node_latent, 2.0 * 0.9152))

    # The tiled path is fed a canned result through a mocked tiled_vae.
    tiled = torch.full((1, 16, 2, 4, 5), 3.0)
    monkeypatch.setattr(seedvr_vae_mod, "tiled_vae", MagicMock(return_value=tiled))
    tiled_output = nodes_mod.VAEEncodeTiled().encode(
        vae,
        pixels,
        tile_size=512,
        overlap=64,
        temporal_size=16,
        temporal_overlap=4,
    )[0]
    tiled_latent = tiled_output["samples"]
    assert set(tiled_output) == {"samples"}
    assert tuple(tiled_latent.shape) == (1, 16, 2, 4, 5)
    assert tiled_latent.dtype == torch.float32
    assert torch.equal(tiled_latent, torch.full_like(tiled_latent, 3.0 * 0.9152))
|
||||
|
||||
|
||||
def test_seedvr2_decode_and_decode_tiled_do_not_require_preprocessor_state(monkeypatch):
    """VAEDecode and VAEDecodeTiled both produce ``(2, 32, 40, 3)`` images.

    The plain path is given a 4-D ``(1, 32, 4, 5)`` latent and the tiled path
    a 5-D ``(1, 16, 2, 4, 5)`` latent; each dict carries only ``samples``.
    """
    # Model loading is stubbed out so no device management happens.
    monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None)
    vae = _make_vae(_DecodeWrapper())

    latent = {"samples": torch.zeros(1, 32, 4, 5)}
    decoded = nodes_mod.VAEDecode().decode(vae, latent)[0]
    assert tuple(decoded.shape) == (2, 32, 40, 3)

    tiled = nodes_mod.VAEDecodeTiled().decode(
        vae,
        {"samples": torch.zeros(1, 16, 2, 4, 5)},
        tile_size=512,
        overlap=64,
        temporal_size=16,
        temporal_overlap=4,
    )[0]
    assert tuple(tiled.shape) == (2, 32, 40, 3)
|
||||
|
||||
|
||||
def test_seedvr2_vaedecode_does_not_repair_latent_layout(monkeypatch):
    """VAEDecode hands a ``(1, 2, 4, 5, 16)`` latent to the model unchanged.

    The stub must see exactly one call, with that shape and no tiling argument.
    """
    # Model loading is stubbed out so no device management happens.
    monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None)
    vae = _make_vae(_DecodeWrapper())

    latent = {"samples": torch.zeros(1, 2, 4, 5, 16)}
    nodes_mod.VAEDecode().decode(vae, latent)

    assert vae.first_stage_model.calls == [{"shape": (1, 2, 4, 5, 16), "seedvr2_tiling": None}]
|
||||
|
||||
|
||||
def test_seedvr2_vaedecode_keeps_public_channel_first_width_16_latents(monkeypatch):
    """VAEDecode passes a ``(1, 16, 4, 5, 16)`` latent through with its shape intact.

    The last axis is also 16 here; the stub must still record the original
    shape and no tiling argument.
    """
    # Model loading is stubbed out so no device management happens.
    monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None)
    vae = _make_vae(_DecodeWrapper())

    nodes_mod.VAEDecode().decode(
        vae,
        {"samples": torch.zeros(1, 16, 4, 5, 16)},
    )

    assert vae.first_stage_model.calls == [{"shape": (1, 16, 4, 5, 16), "seedvr2_tiling": None}]
|
||||
|
||||
|
||||
def test_seedvr2_direct_decode_preserves_channel_first_width_16(monkeypatch):
    """Calling ``vae.decode`` directly forwards a ``(1, 16, 2, 4, 16)`` latent as-is.

    The stub must see exactly one call, with that shape and no tiling argument.
    """
    # Model loading is stubbed out so no device management happens.
    monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None)
    vae = _make_vae(_DecodeWrapper())

    vae.decode(torch.zeros(1, 16, 2, 4, 16))

    assert vae.first_stage_model.calls == [{"shape": (1, 16, 2, 4, 16), "seedvr2_tiling": None}]
|
||||
|
||||
|
||||
def test_seedvr2_decode_tiled_preserves_direct_channel_first_width_16(monkeypatch):
    """``vae.decode_tiled_seedvr2`` forwards a ``(1, 16, 2, 4, 16)`` latent as-is.

    Only the shape of the first recorded call is checked.
    """
    # Model loading is stubbed out so no device management happens.
    monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None)
    vae = _make_vae(_DecodeWrapper())

    vae.decode_tiled_seedvr2(torch.zeros(1, 16, 2, 4, 16))

    assert vae.first_stage_model.calls[0]["shape"] == (1, 16, 2, 4, 16)
|
||||
|
||||
|
||||
def test_seedvr2_vaedecode_tiled_keeps_public_channel_first_width_16_latents(monkeypatch):
    """VAEDecodeTiled passes a ``(1, 16, 4, 5, 16)`` latent through with its shape intact.

    Only the shape of the first recorded call is checked.
    """
    # Model loading is stubbed out so no device management happens.
    monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None)
    vae = _make_vae(_DecodeWrapper())

    nodes_mod.VAEDecodeTiled().decode(
        vae,
        {"samples": torch.zeros(1, 16, 4, 5, 16)},
        tile_size=512,
        overlap=64,
        temporal_size=16,
        temporal_overlap=4,
    )

    assert vae.first_stage_model.calls[0]["shape"] == (1, 16, 4, 5, 16)
|
||||
|
||||
|
||||
def test_vaedecode_tiled_visible_inputs_are_seedvr2_decode_tiling_authority(monkeypatch):
    """VAEDecodeTiled's own inputs reach the model as the ``seedvr2_tiling`` dict.

    The stub must see exactly one call whose tiling argument enables tiling
    and carries the node's tile size, overlap, temporal size and temporal
    overlap, with the two spatial values expanded to ``(value, value)`` pairs.
    """
    # Model loading is stubbed out so no device management happens.
    monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None)
    vae = _make_vae(_DecodeWrapper())

    nodes_mod.VAEDecodeTiled().decode(
        vae,
        {"samples": torch.zeros(1, 16, 2, 4, 5)},
        tile_size=512,
        overlap=64,
        temporal_size=16,
        temporal_overlap=4,
    )

    assert vae.first_stage_model.calls == [
        {
            "shape": (1, 16, 2, 4, 5),
            "seedvr2_tiling": {
                "enable_tiling": True,
                "tile_size": (512, 512),
                "tile_overlap": (64, 64),
                "temporal_size": 16,
                "temporal_overlap": 4,
            },
        }
    ]
|
||||
40
tests-unit/comfy_test/test_seedvr2_windows_static_verify.py
Normal file
40
tests-unit/comfy_test/test_seedvr2_windows_static_verify.py
Normal file
@ -0,0 +1,40 @@
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# Directory two levels above the one containing this file; _read() resolves
# its relative paths against it.
ROOT = Path(__file__).resolve().parents[2]
|
||||
|
||||
|
||||
def _read(relative):
    """Return the UTF-8 text of the file at ``ROOT / relative``."""
    return (ROOT / relative).read_text(encoding="utf-8")
|
||||
|
||||
|
||||
def test_seedvr2_windows_static_contract_tokens():
    """Check SeedVR2 source files for required and forbidden literal tokens.

    The text of ``comfy_extras/nodes_seedvr.py``, ``comfy/sd.py`` and
    ``comfy/ldm/seedvr/vae.py`` is searched as one combined string. The test
    fails on the first required token that is absent, or on the first
    forbidden token that is present.
    """
    nodes = _read("comfy_extras/nodes_seedvr.py")
    sd = _read("comfy/sd.py")
    vae = _read("comfy/ldm/seedvr/vae.py")
    # Build the combined text once instead of re-concatenating per needle.
    haystack = nodes + sd + vae

    required = [
        "SeedVR2Resize",
        "SeedVR2ResizeAdvanced",
        "SeedVR2PostProcessing",
        'io.Image.Input("decoded")',
        'io.Image.Input("original_image")',
        'io.Int.Input("upscaled_shorter_edge", min=2, force_input=True)',
        'io.Combo.Input("color_correction_method", options=["lab", "wavelet", "adain", "none"], default="lab")',
        "def _format_seedvr2_encoded_samples",
        "def decode(self, z, seedvr2_tiling=None)",
    ]
    for needle in required:
        if needle not in haystack:
            pytest.fail(f"missing required static token: {needle}")

    forbidden = ["original_image_video", "img_dims", "tiled_args"]
    for needle in forbidden:
        if needle in haystack:
            pytest.fail(f"forbidden hidden-state token remains: {needle}")
|
||||
|
||||
|
||||
# Allow running this file directly; pytest's exit status becomes the
# process exit code.
if __name__ == "__main__":
    raise SystemExit(pytest.main([__file__]))
|
||||
1070
tests-unit/comfy_test/test_seedvr_progressive_sampler.py
Normal file
1070
tests-unit/comfy_test/test_seedvr_progressive_sampler.py
Normal file
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user