ComfyUI/tests-unit/comfy_extras_test/test_seedvr2_post_processing.py
2026-05-26 00:28:43 -05:00

462 lines
19 KiB
Python

import inspect
from unittest.mock import patch
import torch
from comfy.cli_args import args as cli_args
# Switch the CLI arguments to CPU mode when CUDA is unavailable. This runs
# before the ``nodes_seedvr`` import below (hence the E402 suppression) --
# presumably so that importing the comfy modules does not try to set up a GPU;
# confirm against comfy.model_management.
if not torch.cuda.is_available():
    cli_args.cpu = True
from comfy_extras import nodes_seedvr # noqa: E402
def _schema_ids(items):
return [item.id for item in items]
def test_seedvr2_post_processing_schema():
    """Pin the input ids, widget settings and output type of the node schema."""
    node_schema = nodes_seedvr.SeedVR2PostProcessing.define_schema()
    expected_ids = ["decoded", "original_image", "upscaled_shorter_edge", "color_correction_method"]
    assert _schema_ids(node_schema.inputs) == expected_ids

    # Third input (``upscaled_shorter_edge`` per the id list above).
    edge_input = node_schema.inputs[2]
    assert edge_input.default is None
    assert edge_input.min == 2
    assert edge_input.force_input is True

    # Fourth input (``color_correction_method`` per the id list above).
    method_input = node_schema.inputs[3]
    assert method_input.options == ["lab", "wavelet", "adain", "none"]
    assert method_input.default == "lab"

    first_output = node_schema.outputs[0]
    assert first_output.get_io_type() == "IMAGE"
def test_seedvr2_post_processing_color_correction_memory_multipliers_are_named():
    """Check that the three module-level multiplier constants exist with the pinned values."""
    pinned_values = {
        "LAB_SCALE_MULTIPLIER": 13,
        "WAVELET_SCALE_MULTIPLIER": 10,
        "ADAIN_SCALE_MULTIPLIER": 6,
    }
    for constant_name, pinned in pinned_values.items():
        assert getattr(nodes_seedvr, constant_name) == pinned
def test_seedvr2_post_processing_lab_autochunks_from_memory_estimate(monkeypatch):
    """LAB transfer is split into five single-entry calls when little memory is free.

    ``get_free_memory`` is stubbed to return 1700 and ``vae_device`` to return
    the CPU device. The LAB stub must then be invoked five times, each time
    with a leading dimension of 1, and the output must keep the input shape.
    """
    decoded_in = torch.full((1, 5, 2, 2, 3), 0.25)
    original_in = torch.full((1, 5, 2, 2, 3), 0.75)
    seen_batch_sizes = []

    def _lab(content, style):
        # Record the leading dimension of the chunk and pass it through untouched.
        seen_batch_sizes.append(content.shape[0])
        return content

    model_management = nodes_seedvr.comfy.model_management
    monkeypatch.setattr(model_management, "vae_device", lambda: torch.device("cpu"))
    monkeypatch.setattr(model_management, "get_free_memory", lambda device: 1700)

    with patch.object(nodes_seedvr, "lab_color_transfer", _lab):
        node_output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded_in, original_in, 2, "lab")
        result = node_output.result[0]

    assert seen_batch_sizes == [1] * 5
    assert tuple(result.shape) == (1, 5, 2, 2, 3)
def test_seedvr2_post_processing_lab_runs_each_frame_independently(monkeypatch):
    """LAB transfer is still called once per entry when plenty of memory is free.

    ``get_free_memory`` is stubbed to return 1_000_000; the LAB stub must be
    invoked four times with a leading dimension of 1 for a (1, 4, 2, 2, 3)
    input, and the output must keep that shape.
    """
    decoded_in = torch.full((1, 4, 2, 2, 3), 0.25)
    original_in = torch.full((1, 4, 2, 2, 3), 0.75)
    seen_batch_sizes = []

    def _lab(content, style):
        # Record the leading dimension of the chunk and pass it through untouched.
        seen_batch_sizes.append(content.shape[0])
        return content

    model_management = nodes_seedvr.comfy.model_management
    monkeypatch.setattr(model_management, "vae_device", lambda: torch.device("cpu"))
    monkeypatch.setattr(model_management, "get_free_memory", lambda device: 1_000_000)

    with patch.object(nodes_seedvr, "lab_color_transfer", _lab):
        node_output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded_in, original_in, 2, "lab")
        result = node_output.result[0]

    assert seen_batch_sizes == [1] * 4
    assert tuple(result.shape) == (1, 4, 2, 2, 3)
def test_seedvr2_post_processing_lab_derives_reference_from_original_and_upscaled_shorter_edge():
    """Check the tensors handed to LAB transfer when the reference must be derived.

    The decoded input is (1, 3, 9, 11, 3) and the original is (1, 2, 16, 20, 3).
    With ``upscaled_shorter_edge=8`` the LAB stub must be called twice with
    (1, 3, 8, 10) content and style tensors, and the node output must be
    (1, 2, 8, 10, 3).
    """
    decoded_in = torch.full((1, 3, 9, 11, 3), 0.25)
    original_in = torch.full((1, 2, 16, 20, 3), 0.75)
    recorded = []

    def _lab(content, style):
        # Snapshot both arguments, then hand back an all-zero result.
        recorded.append((content.clone(), style.clone()))
        return torch.zeros_like(content)

    with patch.object(nodes_seedvr, "lab_color_transfer", _lab):
        node_output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded_in, original_in, 8, "lab")
        result = node_output.result[0]

    assert tuple(result.shape) == (1, 2, 8, 10, 3)
    # The stub's zeros are expected to come back as 0.5 in the node output.
    assert torch.equal(result, torch.full_like(result, 0.5))

    assert len(recorded) == 2
    first_content, first_style = recorded[0]
    assert first_content.shape == (1, 3, 8, 10)
    assert first_style.shape == (1, 3, 8, 10)
    # The 0.25 decoded value must reach the stub as -0.5, the 0.75 original as 0.5.
    assert torch.equal(first_content, torch.full_like(first_content, -0.5))
    assert torch.allclose(first_style, torch.full_like(first_style, 0.5))
def test_seedvr2_post_processing_lab_runs_color_transfer_on_vae_device():
    """Inspect the node's source text for required and forbidden snippets.

    This is a textual check on ``execute``, ``_run_color_transfer_chunks`` and
    ``_lab_color_transfer_on_vae_device``; no tensor code is executed.
    """
    node = nodes_seedvr.SeedVR2PostProcessing
    execute_src = inspect.getsource(node.execute)
    chunk_src = inspect.getsource(node._run_color_transfer_chunks)
    helper_src = inspect.getsource(node._lab_color_transfer_on_vae_device)

    assert "_color_transfer_chunked" in execute_src
    assert "_lab_color_transfer_on_vae_device" in chunk_src
    assert "torch.cat" not in chunk_src
    for needle in ("torch.empty", ".copy_("):
        assert needle in chunk_src
    assert "reference_5d.to(device=decoded_5d.device)" not in execute_src
    for needle in (
        "comfy.model_management.vae_device()",
        ".to(device=color_device)",
        ".to(device=output_device)",
    ):
        assert needle in helper_src
def test_seedvr2_post_processing_lab_chunking_is_frame_independent(monkeypatch):
    """Two ``_run_color_transfer_chunks`` runs must give bit-identical LAB output.

    The runs differ only in the final positional argument (1 vs. 3, presumably
    the chunk size -- confirm against the helper's signature).
    """
    element_count = 3 * 3 * 24 * 24
    decoded = torch.linspace(-0.9, 0.9, element_count).reshape(3, 3, 24, 24)
    reference = torch.linspace(0.8, -0.8, element_count).reshape(3, 3, 24, 24)
    monkeypatch.setattr(nodes_seedvr.comfy.model_management, "vae_device", lambda: torch.device("cpu"))

    outputs = []
    for last_arg in (1, 3):
        # Fresh clones per run so neither run can see in-place edits from the other.
        outputs.append(
            nodes_seedvr.SeedVR2PostProcessing._run_color_transfer_chunks(
                decoded.clone(), reference.clone(), torch.device("cpu"), "lab", last_arg,
            )
        )
    one_frame, multi_frame = outputs
    assert torch.equal(one_frame, multi_frame)
def test_seedvr2_post_processing_lab_retry_does_not_mutate_reference(monkeypatch):
    """An OOM on the first LAB call must not leak in-place edits into the retry.

    The LAB stub adds 10 to ``style`` in place on every call and raises
    ``torch.cuda.OutOfMemoryError`` on the first call only. Afterwards the
    caller's ``reference`` tensor must be unchanged, the second call must have
    received the pristine first reference entry, and ``soft_empty_cache`` must
    have been called exactly once.
    """
    decoded = torch.full((2, 3, 4, 4), 0.25)
    reference = torch.full((2, 3, 4, 4), 0.75)
    pristine_reference = reference.clone()
    recorded = []
    cache_clear_hits = []

    def _lab(content, style):
        is_first_call = not recorded
        recorded.append((content.clone(), style.clone()))
        # Deliberately corrupt the style tensor in place to detect shared storage.
        style.add_(10.0)
        if is_first_call:
            raise torch.cuda.OutOfMemoryError("CUDA out of memory")
        return content

    def _note_cache_clear():
        cache_clear_hits.append(True)

    model_management = nodes_seedvr.comfy.model_management
    monkeypatch.setattr(model_management, "vae_device", lambda: torch.device("cpu"))
    monkeypatch.setattr(model_management, "get_free_memory", lambda device: 1_000_000)
    monkeypatch.setattr(model_management, "soft_empty_cache", _note_cache_clear)

    with patch.object(nodes_seedvr, "lab_color_transfer", _lab):
        nodes_seedvr.SeedVR2PostProcessing._color_transfer_chunked(
            decoded, reference, torch.device("cpu"), "lab",
        )

    assert len(cache_clear_hits) == 1
    assert torch.equal(reference, pristine_reference)
    retry_style = recorded[1][1]
    assert torch.equal(retry_style, pristine_reference[0:1])
def test_seedvr2_post_processing_oom_error_uses_color_correction_method(monkeypatch):
    """A LAB stub that always raises OOM must surface as a RuntimeError.

    The error message has to contain ``color_correction_method=lab`` and must
    not contain the substring `` method=lab`` (with a leading space).
    """
    decoded = torch.full((1, 3, 4, 4), 0.25)
    reference = torch.full((1, 3, 4, 4), 0.75)

    def _lab(content, style):
        raise torch.cuda.OutOfMemoryError("CUDA out of memory")

    model_management = nodes_seedvr.comfy.model_management
    monkeypatch.setattr(model_management, "vae_device", lambda: torch.device("cpu"))
    monkeypatch.setattr(model_management, "get_free_memory", lambda device: 1_000_000)
    monkeypatch.setattr(model_management, "soft_empty_cache", lambda: None)

    caught = None
    with patch.object(nodes_seedvr, "lab_color_transfer", _lab):
        try:
            nodes_seedvr.SeedVR2PostProcessing._color_transfer_chunked(
                decoded, reference, torch.device("cpu"), "lab",
            )
        except RuntimeError as exc:
            # Keep a reference; ``exc`` is unbound once the except block ends.
            caught = exc

    if caught is None:
        raise AssertionError("expected RuntimeError for one-frame LAB OOM")
    message = str(caught)
    assert "color_correction_method=lab" in message
    assert " method=lab" not in message
def test_seedvr2_post_processing_raw_conversion_does_not_probe_full_tensor_range():
    """The source text of ``_to_seedvr2_raw`` must not contain ``.amin`` or ``.item``."""
    raw_src = inspect.getsource(nodes_seedvr.SeedVR2PostProcessing._to_seedvr2_raw)
    for forbidden in (".amin", ".item"):
        assert forbidden not in raw_src
def test_seedvr2_post_processing_none_does_not_resize_reference_pixels():
    """With method ``"none"`` the module's ``side_resize`` must never be called."""
    decoded_in = torch.full((1, 2, 10, 12, 3), 0.25)
    original_in = torch.full((1, 2, 16, 20, 3), 0.75)

    with patch.object(nodes_seedvr, "side_resize") as resize_mock:
        node_output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded_in, original_in, 8, "none")
        result = node_output.result[0]

    resize_mock.assert_not_called()
    assert tuple(result.shape) == (1, 2, 8, 10, 3)
def test_seedvr2_post_processing_rejects_invalid_upscaled_shorter_edge():
    """``None``, ``1`` and ``1.5`` must each raise a ValueError naming the input."""
    decoded_in = torch.full((1, 2, 10, 12, 3), 0.25)
    original_in = torch.full((1, 2, 16, 20, 3), 0.75)

    for edge in (None, 1, 1.5):
        caught = None
        try:
            nodes_seedvr.SeedVR2PostProcessing.execute(decoded_in, original_in, edge, "none")
        except ValueError as exc:
            caught = exc
        if caught is None:
            raise AssertionError(f"expected ValueError for upscaled_shorter_edge={edge!r}")
        assert "upscaled_shorter_edge" in str(caught)
def test_seedvr2_post_processing_lab_resizes_full_reference_frame():
    """Trace the ``TVF.resize`` calls made while preparing the LAB reference.

    ``TVF.resize`` is replaced by a stub that records its arguments and returns
    a constant 0.5 tensor of the requested size. The first call must receive
    the (2, 3, 16, 20) original with size 8, the second a (2, 3, 8, 10) tensor
    with size (4, 5). The LAB stub must then be called twice with an all-zero
    (1, 3, 4, 5) style tensor.
    """
    decoded_in = torch.full((1, 2, 4, 5, 3), 0.25)
    original_in = torch.full((1, 2, 16, 20, 3), 0.75)
    resize_log = []
    lab_log = []

    def _resize(images, size, interpolation=None, antialias=None):
        resize_log.append((images.clone(), size, interpolation, antialias))
        if isinstance(size, int):
            # An int size becomes the second-to-last dim; the last dim is
            # scaled by the same ratio and rounded.
            out_rows = size
            out_cols = round(images.shape[-1] * size / images.shape[-2])
        else:
            out_rows, out_cols = size[0], size[1]
        return torch.full((2, 3, out_rows, out_cols), 0.5)

    def _lab(content, style):
        lab_log.append((content.clone(), style.clone()))
        return torch.zeros_like(content)

    with patch.object(nodes_seedvr.TVF, "resize", _resize), \
            patch.object(nodes_seedvr, "lab_color_transfer", _lab):
        node_output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded_in, original_in, 8, "lab")
        result = node_output.result[0]

    assert tuple(result.shape) == (1, 2, 4, 4, 3)
    assert torch.equal(result, torch.full_like(result, 0.5))

    first_resize, second_resize = resize_log[0], resize_log[1]
    assert first_resize[0].shape == (2, 3, 16, 20)
    assert first_resize[1] == 8
    assert second_resize[0].shape == (2, 3, 8, 10)
    assert second_resize[1] == (4, 5)

    assert len(lab_log) == 2
    first_style = lab_log[0][1]
    assert first_style.shape == (1, 3, 4, 5)
    assert torch.equal(first_style, torch.zeros_like(first_style))
def test_seedvr2_post_processing_none_trims_and_crops_without_color_correction():
    """Method ``"none"`` must return a plain slice of the decoded tensor.

    For a (1, 3, 9, 11, 3) decoded input and a (1, 2, 16, 20, 3) original with
    edge 8, the output must equal ``decoded[:, :2, :8, :10]`` and the LAB
    transfer must not be called.
    """
    decoded_shape = (1, 3, 9, 11, 3)
    element_count = torch.Size(decoded_shape).numel()
    decoded_in = torch.arange(element_count, dtype=torch.float32).reshape(decoded_shape)
    original_in = torch.zeros(1, 2, 16, 20, 3)

    with patch.object(nodes_seedvr, "lab_color_transfer") as lab_mock:
        node_output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded_in, original_in, 8, "none")
        result = node_output.result[0]

    assert lab_mock.call_count == 0
    assert tuple(result.shape) == (1, 2, 8, 10, 3)
    assert torch.equal(result, decoded_in[:, :2, :8, :10])
def test_seedvr2_post_processing_restores_flattened_padded_batches_before_trimming():
    """A 4-D decoded tensor must be regrouped before it is trimmed.

    With ten decoded entries and a (2, 2, 4, 6, 1) original, the output must be
    decoded entries 0-1 followed by entries 5-6.
    """
    decoded_shape = (10, 4, 6, 1)
    element_count = torch.Size(decoded_shape).numel()
    decoded_in = torch.arange(element_count, dtype=torch.float32).reshape(decoded_shape)
    original_in = torch.zeros(2, 2, 4, 6, 1)

    node_output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded_in, original_in, 4, "none")
    result = node_output.result[0]

    kept_entries = [decoded_in[:2], decoded_in[5:7]]
    expected = torch.cat(kept_entries)
    assert tuple(result.shape) == (4, 4, 6, 1)
    assert torch.equal(result, expected)
def test_seedvr2_post_processing_none_preserves_decoded_spatial_size_when_reference_is_larger():
    """With edge 16 and an 8x10 decoded tensor, only the second axis is trimmed.

    The output must equal ``decoded[:, :2]`` and the LAB transfer must not be
    called.
    """
    decoded_shape = (1, 3, 8, 10, 3)
    element_count = torch.Size(decoded_shape).numel()
    decoded_in = torch.arange(element_count, dtype=torch.float32).reshape(decoded_shape)
    original_in = torch.zeros(1, 2, 16, 20, 3)

    with patch.object(nodes_seedvr, "lab_color_transfer") as lab_mock:
        node_output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded_in, original_in, 16, "none")
        result = node_output.result[0]

    assert lab_mock.call_count == 0
    assert tuple(result.shape) == (1, 2, 8, 10, 3)
    assert torch.equal(result, decoded_in[:, :2])
def test_seedvr2_post_processing_crops_to_reference_tensor_when_reference_is_smaller():
    """A 720x1280 decoded tensor with a 360x640 original and edge 360 yields 360x640."""
    decoded_in = torch.ones((1, 1, 720, 1280, 3), dtype=torch.float32)
    original_in = torch.ones((1, 1, 360, 640, 3), dtype=torch.float32)

    node_output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded_in, original_in, 360, "none")
    result = node_output.result[0]

    assert tuple(result.shape) == (1, 1, 360, 640, 3)
def test_seedvr2_post_processing_uses_decoded_size_when_reference_is_larger():
    """A 128x160 decoded tensor with a 480x640 original and edge 480 stays 128x160."""
    decoded_in = torch.ones((1, 1, 128, 160, 3), dtype=torch.float32)
    original_in = torch.ones((1, 1, 480, 640, 3), dtype=torch.float32)

    node_output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded_in, original_in, 480, "none")
    result = node_output.result[0]

    assert tuple(result.shape) == (1, 1, 128, 160, 3)
def test_seedvr2_post_processing_derives_crop_from_upscaled_shorter_edge():
    """A 1080x1920 original with edge 120 must crop a 128x224 decoded tensor to 120x212."""
    decoded_in = torch.ones((1, 1, 128, 224, 3), dtype=torch.float32)
    original_in = torch.ones((1, 1, 1080, 1920, 3), dtype=torch.float32)

    node_output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded_in, original_in, 120, "none")
    result = node_output.result[0]

    assert tuple(result.shape) == (1, 1, 120, 212, 3)
def test_seedvr2_post_processing_uses_even_crop_from_odd_resized_width():
    """A 120x169 original with edge 120 must crop a 128x256 decoded tensor to 120x168."""
    decoded_in = torch.ones((1, 1, 128, 256, 3), dtype=torch.float32)
    original_in = torch.ones((1, 1, 120, 169, 3), dtype=torch.float32)

    node_output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded_in, original_in, 120, "none")
    result = node_output.result[0]

    assert tuple(result.shape) == (1, 1, 120, 168, 3)
def test_seedvr2_post_processing_none_preserves_black_bottom_row_content():
    """An original whose last row (third axis) is -1.0 must not cause any cropping.

    With matching (1, 2, 8, 10, 3) shapes and edge 8, the output must equal the
    decoded tensor exactly.
    """
    shape = (1, 2, 8, 10, 3)
    decoded_in = torch.ones(shape, dtype=torch.float32)
    original_in = torch.ones(shape, dtype=torch.float32)
    # Set the last index along the third axis to -1.0 everywhere.
    original_in[:, :, -1] = -1.0

    node_output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded_in, original_in, 8, "none")
    result = node_output.result[0]

    assert tuple(result.shape) == (1, 2, 8, 10, 3)
    assert torch.equal(result, decoded_in)
def test_seedvr2_post_processing_none_preserves_black_right_column_content():
    """An original whose last column (fourth axis) is -1.0 must not cause any cropping.

    With matching (1, 2, 8, 10, 3) shapes and edge 8, the output must equal the
    decoded tensor exactly.
    """
    shape = (1, 2, 8, 10, 3)
    decoded_in = torch.ones(shape, dtype=torch.float32)
    original_in = torch.ones(shape, dtype=torch.float32)
    # Set the last index along the fourth (second-to-last) axis to -1.0 everywhere.
    original_in[..., -1, :] = -1.0

    node_output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded_in, original_in, 8, "none")
    result = node_output.result[0]

    assert tuple(result.shape) == (1, 2, 8, 10, 3)
    assert torch.equal(result, decoded_in)
def test_seedvr2_post_processing_wavelet_dispatch_routes_through_wavelet_color_transfer():
    """Method ``"wavelet"`` must call the wavelet transfer once and never the LAB one.

    Both transfers are replaced by recording stubs that return zeros. The
    wavelet stub must receive (2, 3, 8, 10) content filled with -0.5 and style
    close to 0.5, and the node output must be (1, 2, 8, 10, 3) filled with 0.5.
    """
    decoded_in = torch.full((1, 3, 9, 11, 3), 0.25)
    original_in = torch.full((1, 2, 16, 20, 3), 0.75)
    wavelet_log = []
    lab_log = []

    def _recording_stub(log):
        # Build a transfer stand-in that snapshots its arguments into ``log``.
        def _stub(content, style):
            log.append((content.clone(), style.clone()))
            return torch.zeros_like(content)
        return _stub

    with patch.object(nodes_seedvr, "wavelet_color_transfer", _recording_stub(wavelet_log)), \
            patch.object(nodes_seedvr, "lab_color_transfer", _recording_stub(lab_log)):
        node_output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded_in, original_in, 8, "wavelet")
        result = node_output.result[0]

    assert len(wavelet_log) == 1
    assert len(lab_log) == 0
    assert tuple(result.shape) == (1, 2, 8, 10, 3)
    assert torch.equal(result, torch.full_like(result, 0.5))

    seen_content, seen_style = wavelet_log[0]
    assert seen_content.shape == (2, 3, 8, 10)
    assert seen_style.shape == (2, 3, 8, 10)
    assert torch.equal(seen_content, torch.full_like(seen_content, -0.5))
    assert torch.allclose(seen_style, torch.full_like(seen_style, 0.5))
def test_seedvr2_post_processing_adain_dispatch_routes_through_adain_color_transfer():
    """Method ``"adain"`` must call the AdaIN transfer once and never the LAB one.

    Both transfers are replaced by recording stubs that return zeros. The AdaIN
    stub must receive (2, 3, 8, 10) content and style tensors, and the node
    output must be (1, 2, 8, 10, 3) filled with 0.5.
    """
    decoded_in = torch.full((1, 3, 9, 11, 3), 0.25)
    original_in = torch.full((1, 2, 16, 20, 3), 0.75)
    adain_log = []
    lab_log = []

    def _recording_stub(log):
        # Build a transfer stand-in that snapshots its arguments into ``log``.
        def _stub(content, style):
            log.append((content.clone(), style.clone()))
            return torch.zeros_like(content)
        return _stub

    with patch.object(nodes_seedvr, "adain_color_transfer", _recording_stub(adain_log)), \
            patch.object(nodes_seedvr, "lab_color_transfer", _recording_stub(lab_log)):
        node_output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded_in, original_in, 8, "adain")
        result = node_output.result[0]

    assert len(adain_log) == 1
    assert len(lab_log) == 0
    assert tuple(result.shape) == (1, 2, 8, 10, 3)
    assert torch.equal(result, torch.full_like(result, 0.5))

    seen_content, seen_style = adain_log[0]
    assert seen_content.shape == (2, 3, 8, 10)
    assert seen_style.shape == (2, 3, 8, 10)
def test_seedvr2_color_transfer_helper_runs_on_vae_device():
    """Inspect ``_color_transfer_on_vae_device`` for the expected source snippets.

    This is a textual check only: the helper's source must mention the VAE
    device lookup, both ``.to(device=...)`` moves and the ``transfer_fn`` name.
    Uses the module-level ``inspect`` import rather than re-importing it
    locally under an alias.
    """
    helper_source = inspect.getsource(nodes_seedvr.SeedVR2PostProcessing._color_transfer_on_vae_device)
    assert "comfy.model_management.vae_device()" in helper_source
    assert ".to(device=color_device)" in helper_source
    assert ".to(device=output_device)" in helper_source
    assert "transfer_fn" in helper_source
def test_seedvr2_wavelet_color_transfer_matches_primary_source_reconstruction():
    """``wavelet_color_transfer`` must equal ``wavelet_reconstruction`` bit for bit.

    Both are called on the same seeded random (1, 3, 12, 16) inputs in [-1, 1);
    the reconstruction receives clones.
    """
    from comfy.ldm.seedvr import vae as seedvr_vae

    torch.manual_seed(0)
    shape = (1, 3, 12, 16)
    # Draw content first, then style, so the seeded values stay the same.
    content = torch.rand(*shape) * 2.0 - 1.0
    style = torch.rand(*shape) * 2.0 - 1.0

    actual = seedvr_vae.wavelet_color_transfer(content, style)
    reference_result = seedvr_vae.wavelet_reconstruction(content.clone(), style.clone())

    assert torch.equal(actual, reference_result)
def test_seedvr2_adain_color_transfer_matches_huang_belongie_formula():
    """Compare ``adain_color_transfer`` with an inline mean/std re-normalisation.

    The reference result normalises ``content`` by its own per-(batch, channel)
    mean and population standard deviation (``correction=0``, with ``eps=1e-5``
    added to the variance), rescales with the ``style`` statistics and clamps
    to [-1, 1]. The two results must agree within ``atol=1e-6``.
    """
    from comfy.ldm.seedvr import vae as seedvr_vae

    torch.manual_seed(0)
    # Draw content first, then style, so the seeded values stay the same.
    content = torch.rand(2, 3, 5, 7) * 2.0 - 1.0
    style = torch.rand(2, 3, 5, 7) * 2.0 - 1.0
    actual = seedvr_vae.adain_color_transfer(content.clone(), style.clone())

    batch, channels = 2, 3
    eps = 1e-5

    def _channel_stats(tensor):
        # Mean and eps-stabilised population std over all positions of each channel.
        flat = tensor.float().reshape(batch, channels, -1)
        mean = flat.mean(dim=2).reshape(batch, channels, 1, 1)
        std = (flat.var(dim=2, correction=0) + eps).sqrt().reshape(batch, channels, 1, 1)
        return mean, std

    content_mean, content_std = _channel_stats(content)
    style_mean, style_std = _channel_stats(style)

    expected = ((content.float() - content_mean) / content_std) * style_std + style_mean
    expected = expected.clamp(-1.0, 1.0)
    assert torch.allclose(actual, expected, atol=1e-6)
def test_seedvr2_adain_single_pixel_uses_population_variance_without_nan():
    """For (1, 3, 1, 1) inputs the AdaIN output must be finite and equal ``style``."""
    from comfy.ldm.seedvr import vae as seedvr_vae

    content = torch.tensor([0.25, -0.5, 0.75], dtype=torch.float32).reshape(1, 3, 1, 1)
    style = torch.tensor([-0.25, 0.5, -0.75], dtype=torch.float32).reshape(1, 3, 1, 1)

    result = seedvr_vae.adain_color_transfer(content, style)

    assert torch.isfinite(result).all()
    assert torch.equal(result, style)
def test_seedvr2_adain_preserves_input_dtype():
    """float16 inputs to ``adain_color_transfer`` must give a float16 output."""
    from comfy.ldm.seedvr import vae as seedvr_vae

    content_fp32 = torch.rand(1, 3, 4, 4) * 2.0 - 1.0
    content = content_fp32.half()
    style_fp32 = torch.rand(1, 3, 4, 4) * 2.0 - 1.0
    style = style_fp32.half()

    result = seedvr_vae.adain_color_transfer(content, style)

    assert result.dtype == torch.float16
def test_seedvr2_adain_resizes_mismatched_style_to_content_shape():
    """An 8x10 content with a 16x20 style must give an output shaped like the content."""
    from comfy.ldm.seedvr import vae as seedvr_vae

    content = torch.rand(1, 3, 8, 10) * 2.0 - 1.0
    style = torch.rand(1, 3, 16, 20) * 2.0 - 1.0

    result = seedvr_vae.adain_color_transfer(content, style)

    assert tuple(result.shape) == (1, 3, 8, 10)
def test_seedvr2_post_processing_unknown_color_correction_method_raises():
    """Method ``"bogus"`` must raise a ValueError that names ``color_correction_method``."""
    decoded_in = torch.zeros(1, 2, 4, 4, 3)
    original_in = torch.zeros(1, 2, 4, 4, 3)

    caught = None
    try:
        nodes_seedvr.SeedVR2PostProcessing.execute(decoded_in, original_in, 4, "bogus")
    except ValueError as exc:
        # Keep a reference; ``exc`` is unbound once the except block ends.
        caught = exc

    if caught is None:
        raise AssertionError("expected ValueError for unknown color_correction_method")
    assert "color_correction_method" in str(caught)