mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-28 01:47:32 +08:00
462 lines
19 KiB
Python
462 lines
19 KiB
Python
import inspect
|
|
from unittest.mock import patch
|
|
|
|
import torch
|
|
|
|
from comfy.cli_args import args as cli_args
|
|
|
|
# Force ComfyUI into CPU mode on machines without CUDA. This assignment must
# happen before nodes_seedvr is imported below (which is why that import needs
# the E402 suppression); presumably comfy consults cli_args while its modules
# load — TODO confirm against comfy.model_management.
if not torch.cuda.is_available():
    cli_args.cpu = True

from comfy_extras import nodes_seedvr # noqa: E402
|
|
|
|
|
|
def _schema_ids(items):
|
|
return [item.id for item in items]
|
|
|
|
|
|
def test_seedvr2_post_processing_schema():
    """The node exposes four inputs in a fixed order and a single IMAGE output."""
    schema = nodes_seedvr.SeedVR2PostProcessing.define_schema()

    assert _schema_ids(schema.inputs) == [
        "decoded",
        "original_image",
        "upscaled_shorter_edge",
        "color_correction_method",
    ]

    # upscaled_shorter_edge: no default, minimum of 2, must be wired in.
    edge_input = schema.inputs[2]
    assert edge_input.default is None
    assert edge_input.min == 2
    assert edge_input.force_input is True

    # color_correction_method: combo defaulting to LAB.
    method_input = schema.inputs[3]
    assert method_input.options == ["lab", "wavelet", "adain", "none"]
    assert method_input.default == "lab"

    assert schema.outputs[0].get_io_type() == "IMAGE"
|
|
|
|
|
|
def test_seedvr2_post_processing_color_correction_memory_multipliers_are_named():
    """Per-method memory multipliers are exported as named module-level constants."""
    expected_multipliers = {
        "LAB_SCALE_MULTIPLIER": 13,
        "WAVELET_SCALE_MULTIPLIER": 10,
        "ADAIN_SCALE_MULTIPLIER": 6,
    }
    for constant_name, expected_value in expected_multipliers.items():
        assert getattr(nodes_seedvr, constant_name) == expected_value
|
|
|
|
|
|
def test_seedvr2_post_processing_lab_autochunks_from_memory_estimate(monkeypatch):
    """When get_free_memory reports only 1700, LAB runs one frame per call."""
    clip_shape = (1, 5, 2, 2, 3)
    decoded = torch.full(clip_shape, 0.25)
    original = torch.full(clip_shape, 0.75)
    chunk_sizes = []

    def _lab(content, style):
        chunk_sizes.append(content.shape[0])
        return content

    model_management = nodes_seedvr.comfy.model_management
    monkeypatch.setattr(model_management, "vae_device", lambda: torch.device("cpu"))
    monkeypatch.setattr(model_management, "get_free_memory", lambda device: 1700)

    with patch.object(nodes_seedvr, "lab_color_transfer", _lab):
        node_output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 2, "lab")
        output = node_output.result[0]

    assert chunk_sizes == [1] * 5
    assert tuple(output.shape) == clip_shape
|
|
|
|
|
|
def test_seedvr2_post_processing_lab_runs_each_frame_independently(monkeypatch):
    """Even with ample reported free memory, LAB is invoked once per frame."""
    clip_shape = (1, 4, 2, 2, 3)
    decoded = torch.full(clip_shape, 0.25)
    original = torch.full(clip_shape, 0.75)
    chunk_sizes = []

    def _lab(content, style):
        chunk_sizes.append(content.shape[0])
        return content

    model_management = nodes_seedvr.comfy.model_management
    monkeypatch.setattr(model_management, "vae_device", lambda: torch.device("cpu"))
    monkeypatch.setattr(model_management, "get_free_memory", lambda device: 1_000_000)

    with patch.object(nodes_seedvr, "lab_color_transfer", _lab):
        node_output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 2, "lab")
        output = node_output.result[0]

    assert chunk_sizes == [1] * 4
    assert tuple(output.shape) == clip_shape
|
|
|
|
|
|
def test_seedvr2_post_processing_lab_derives_reference_from_original_and_upscaled_shorter_edge():
    """The LAB reference is the original resized to the requested shorter edge.

    Decoded frames are 9x11 and the original is 16x20. With an 8-pixel shorter
    edge, both sides of each transfer are 8x10, and the output is trimmed to
    the original's 2 frames.
    """
    decoded = torch.full((1, 3, 9, 11, 3), 0.25)
    original = torch.full((1, 2, 16, 20, 3), 0.75)
    recorded = []

    def _lab(content, style):
        recorded.append((content.clone(), style.clone()))
        return torch.zeros_like(content)

    with patch.object(nodes_seedvr, "lab_color_transfer", _lab):
        output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 8, "lab").result[0]

    assert tuple(output.shape) == (1, 2, 8, 10, 3)
    # The fake transfer returns zeros; the node maps that back to 0.5.
    assert torch.equal(output, torch.full_like(output, 0.5))
    assert len(recorded) == 2

    first_content, first_style = recorded[0]
    assert first_content.shape == (1, 3, 8, 10)
    assert first_style.shape == (1, 3, 8, 10)
    # Inputs 0.25 / 0.75 reach the transfer as -0.5 / 0.5.
    assert torch.equal(first_content, torch.full_like(first_content, -0.5))
    assert torch.allclose(first_style, torch.full_like(first_style, 0.5))
|
|
|
|
|
|
def test_seedvr2_post_processing_lab_runs_color_transfer_on_vae_device():
    """Source-level check of how LAB transfer is routed and assembled.

    LAB transfer must go through the VAE-device helper. Chunk results are
    written into a preallocated buffer (``torch.empty`` + ``.copy_``) rather
    than concatenated.
    """
    node_cls = nodes_seedvr.SeedVR2PostProcessing
    execute_src = inspect.getsource(node_cls.execute)
    chunks_src = inspect.getsource(node_cls._run_color_transfer_chunks)
    helper_src = inspect.getsource(node_cls._lab_color_transfer_on_vae_device)

    assert "_color_transfer_chunked" in execute_src
    assert "reference_5d.to(device=decoded_5d.device)" not in execute_src

    assert "torch.cat" not in chunks_src
    for required in ("_lab_color_transfer_on_vae_device", "torch.empty", ".copy_("):
        assert required in chunks_src

    for required in (
        "comfy.model_management.vae_device()",
        ".to(device=color_device)",
        ".to(device=output_device)",
    ):
        assert required in helper_src
|
|
|
|
|
|
def test_seedvr2_post_processing_lab_chunking_is_frame_independent(monkeypatch):
    """Chunk size must not change LAB output: 1-frame and 3-frame chunks agree exactly."""
    element_count = 3 * 3 * 24 * 24
    decoded = torch.linspace(-0.9, 0.9, element_count).reshape(3, 3, 24, 24)
    reference = torch.linspace(0.8, -0.8, element_count).reshape(3, 3, 24, 24)

    monkeypatch.setattr(nodes_seedvr.comfy.model_management, "vae_device", lambda: torch.device("cpu"))

    run_chunks = nodes_seedvr.SeedVR2PostProcessing._run_color_transfer_chunks
    # Fresh clones per run so neither call can observe the other's mutations.
    results_by_chunk = {
        chunk: run_chunks(decoded.clone(), reference.clone(), torch.device("cpu"), "lab", chunk)
        for chunk in (1, 3)
    }

    assert torch.equal(results_by_chunk[1], results_by_chunk[3])
|
|
|
|
|
|
def test_seedvr2_post_processing_lab_retry_does_not_mutate_reference(monkeypatch):
    """Check that an OOM retry leaves the caller's reference tensor untouched.

    The first attempt OOMs, which must trigger one cache clear and a retry.
    The fake transfer mutates ``style`` in place, yet neither the caller's
    reference nor the retry's input may carry that mutation.
    """
    decoded = torch.full((2, 3, 4, 4), 0.25)
    reference = torch.full((2, 3, 4, 4), 0.75)
    pristine_reference = reference.clone()
    seen = []
    cache_clears = []

    def _lab(content, style):
        seen.append((content.clone(), style.clone()))
        style.add_(10.0)  # deliberate in-place mutation the node must isolate
        if len(seen) == 1:
            raise torch.cuda.OutOfMemoryError("CUDA out of memory")
        return content

    model_management = nodes_seedvr.comfy.model_management
    monkeypatch.setattr(model_management, "vae_device", lambda: torch.device("cpu"))
    monkeypatch.setattr(model_management, "get_free_memory", lambda device: 1_000_000)
    monkeypatch.setattr(model_management, "soft_empty_cache", lambda: cache_clears.append(True))

    with patch.object(nodes_seedvr, "lab_color_transfer", _lab):
        nodes_seedvr.SeedVR2PostProcessing._color_transfer_chunked(
            decoded, reference, torch.device("cpu"), "lab",
        )

    assert len(cache_clears) == 1
    assert torch.equal(reference, pristine_reference)
    retry_style = seen[1][1]
    assert torch.equal(retry_style, pristine_reference[:1])
|
|
|
|
|
|
def test_seedvr2_post_processing_oom_error_uses_color_correction_method(monkeypatch):
    """A one-frame OOM surfaces as a RuntimeError naming the input by its UI name.

    The message must say ``color_correction_method=lab``, not a bare
    ``method=lab``.
    """
    decoded = torch.full((1, 3, 4, 4), 0.25)
    reference = torch.full((1, 3, 4, 4), 0.75)

    def _lab(content, style):
        raise torch.cuda.OutOfMemoryError("CUDA out of memory")

    model_management = nodes_seedvr.comfy.model_management
    monkeypatch.setattr(model_management, "vae_device", lambda: torch.device("cpu"))
    monkeypatch.setattr(model_management, "get_free_memory", lambda device: 1_000_000)
    monkeypatch.setattr(model_management, "soft_empty_cache", lambda: None)

    caught = None
    with patch.object(nodes_seedvr, "lab_color_transfer", _lab):
        try:
            nodes_seedvr.SeedVR2PostProcessing._color_transfer_chunked(
                decoded, reference, torch.device("cpu"), "lab",
            )
        except RuntimeError as exc:
            caught = exc

    if caught is None:
        raise AssertionError("expected RuntimeError for one-frame LAB OOM")
    message = str(caught)
    assert "color_correction_method=lab" in message
    assert " method=lab" not in message
|
|
|
|
|
|
def test_seedvr2_post_processing_raw_conversion_does_not_probe_full_tensor_range():
    """_to_seedvr2_raw must not call ``.amin`` or ``.item`` on its input."""
    raw_conversion_src = inspect.getsource(nodes_seedvr.SeedVR2PostProcessing._to_seedvr2_raw)
    for forbidden in (".amin", ".item"):
        assert forbidden not in raw_conversion_src
|
|
|
|
|
|
def test_seedvr2_post_processing_none_does_not_resize_reference_pixels():
    """With color correction disabled, side_resize is never called on the reference."""
    decoded = torch.full((1, 2, 10, 12, 3), 0.25)
    original = torch.full((1, 2, 16, 20, 3), 0.75)

    with patch.object(nodes_seedvr, "side_resize") as side_resize_mock:
        node_output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 8, "none")
        output = node_output.result[0]

    side_resize_mock.assert_not_called()
    assert tuple(output.shape) == (1, 2, 8, 10, 3)
|
|
|
|
|
|
def test_seedvr2_post_processing_rejects_invalid_upscaled_shorter_edge():
    """Missing, below-minimum, and non-integer edges each raise a ValueError naming the input."""
    decoded = torch.full((1, 2, 10, 12, 3), 0.25)
    original = torch.full((1, 2, 16, 20, 3), 0.75)
    invalid_edges = (None, 1, 1.5)

    for bad_edge in invalid_edges:
        try:
            nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, bad_edge, "none")
        except ValueError as exc:
            assert "upscaled_shorter_edge" in str(exc)
        else:
            raise AssertionError(f"expected ValueError for upscaled_shorter_edge={bad_edge!r}")
|
|
|
|
|
|
def test_seedvr2_post_processing_lab_resizes_full_reference_frame():
    """LAB resizes the whole original frame twice before color transfer.

    The first resize is to the requested shorter edge (16x20 -> 8 rows). The
    second is to the decoded frame size (4x5).
    """
    decoded = torch.full((1, 2, 4, 5, 3), 0.25)
    original = torch.full((1, 2, 16, 20, 3), 0.75)
    resize_log = []
    lab_log = []

    def _resize(images, size, interpolation=None, antialias=None):
        resize_log.append((images.clone(), size, interpolation, antialias))
        if isinstance(size, int):
            # Shorter-edge resize: keep the aspect ratio of the incoming frames.
            out_h = size
            out_w = round(images.shape[-1] * size / images.shape[-2])
        else:
            out_h, out_w = size[0], size[1]
        return torch.full((2, 3, out_h, out_w), 0.5)

    def _lab(content, style):
        lab_log.append((content.clone(), style.clone()))
        return torch.zeros_like(content)

    resize_patch = patch.object(nodes_seedvr.TVF, "resize", _resize)
    lab_patch = patch.object(nodes_seedvr, "lab_color_transfer", _lab)
    with resize_patch, lab_patch:
        output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 8, "lab").result[0]

    assert tuple(output.shape) == (1, 2, 4, 4, 3)
    assert torch.equal(output, torch.full_like(output, 0.5))

    shorter_edge_call, decoded_size_call = resize_log[0], resize_log[1]
    assert shorter_edge_call[0].shape == (2, 3, 16, 20)
    assert shorter_edge_call[1] == 8
    assert decoded_size_call[0].shape == (2, 3, 8, 10)
    assert decoded_size_call[1] == (4, 5)

    assert len(lab_log) == 2
    first_style = lab_log[0][1]
    assert first_style.shape == (1, 3, 4, 5)
    # The fake resize yields 0.5, which reaches the transfer as 0.
    assert torch.equal(first_style, torch.zeros_like(first_style))
|
|
|
|
|
|
def test_seedvr2_post_processing_none_trims_and_crops_without_color_correction():
    """With "none", the decoded pixels pass through unchanged.

    They are only trimmed to the original's 2 frames and cropped to 8x10;
    LAB is never called.
    """
    decoded_shape = (1, 3, 9, 11, 3)
    decoded = torch.arange(1 * 3 * 9 * 11 * 3, dtype=torch.float32).reshape(decoded_shape)
    original = torch.zeros(1, 2, 16, 20, 3)

    with patch.object(nodes_seedvr, "lab_color_transfer") as lab_mock:
        output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 8, "none").result[0]

    lab_mock.assert_not_called()
    assert tuple(output.shape) == (1, 2, 8, 10, 3)
    assert torch.equal(output, decoded[:, :2, :8, :10, :])
|
|
|
|
|
|
def test_seedvr2_post_processing_restores_flattened_padded_batches_before_trimming():
    """Check that a flattened decode is regrouped per video before trimming.

    The decode has 10 frames, read as 2 videos of 5 frames each. Each video is
    trimmed to the original's 2 frames independently.
    """
    decoded = torch.arange(10 * 4 * 6 * 1, dtype=torch.float32).reshape(10, 4, 6, 1)
    original = torch.zeros(2, 2, 4, 6, 1)

    output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 4, "none").result[0]

    frames_per_video = 5
    kept_frames = [
        decoded[video_start:video_start + 2]
        for video_start in range(0, decoded.shape[0], frames_per_video)
    ]
    expected = torch.cat(kept_frames, dim=0)
    assert tuple(output.shape) == (4, 4, 6, 1)
    assert torch.equal(output, expected)
|
|
|
|
|
|
def test_seedvr2_post_processing_none_preserves_decoded_spatial_size_when_reference_is_larger():
    """An 8x10 decode is never upscaled to a larger 16x20 reference; only frames are trimmed."""
    decoded_shape = (1, 3, 8, 10, 3)
    decoded = torch.arange(1 * 3 * 8 * 10 * 3, dtype=torch.float32).reshape(decoded_shape)
    original = torch.zeros(1, 2, 16, 20, 3)

    with patch.object(nodes_seedvr, "lab_color_transfer") as lab_mock:
        output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 16, "none").result[0]

    lab_mock.assert_not_called()
    assert tuple(output.shape) == (1, 2, 8, 10, 3)
    assert torch.equal(output, decoded[:, :2, :, :, :])
|
|
|
|
|
|
def test_seedvr2_post_processing_crops_to_reference_tensor_when_reference_is_smaller():
    """A 720x1280 decode is cropped down to a 360x640 reference at edge 360."""
    decoded = torch.ones((1, 1, 720, 1280, 3), dtype=torch.float32)
    original = torch.ones((1, 1, 360, 640, 3), dtype=torch.float32)

    node_output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 360, "none")

    assert tuple(node_output.result[0].shape) == (1, 1, 360, 640, 3)
|
|
|
|
|
|
def test_seedvr2_post_processing_uses_decoded_size_when_reference_is_larger():
    """A 128x160 decode keeps its size even though the 480x640 reference is larger."""
    decoded = torch.ones((1, 1, 128, 160, 3), dtype=torch.float32)
    original = torch.ones((1, 1, 480, 640, 3), dtype=torch.float32)

    node_output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 480, "none")

    assert tuple(node_output.result[0].shape) == (1, 1, 128, 160, 3)
|
|
|
|
|
|
def test_seedvr2_post_processing_derives_crop_from_upscaled_shorter_edge():
    """Check that the crop follows the original's aspect ratio at the requested edge.

    A 1080x1920 original at edge 120 gives a 120x212 crop, not the decoded
    128x224.
    """
    decoded = torch.ones((1, 1, 128, 224, 3), dtype=torch.float32)
    original = torch.ones((1, 1, 1080, 1920, 3), dtype=torch.float32)

    node_output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 120, "none")

    assert tuple(node_output.result[0].shape) == (1, 1, 120, 212, 3)
|
|
|
|
|
|
def test_seedvr2_post_processing_uses_even_crop_from_odd_resized_width():
    """An odd 169-pixel reference width at edge 120 yields an even 168-pixel crop."""
    decoded = torch.ones((1, 1, 128, 256, 3), dtype=torch.float32)
    original = torch.ones((1, 1, 120, 169, 3), dtype=torch.float32)

    node_output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 120, "none")

    assert tuple(node_output.result[0].shape) == (1, 1, 120, 168, 3)
|
|
|
|
|
|
def test_seedvr2_post_processing_none_preserves_black_bottom_row_content():
    """A -1.0 bottom row in the original must not cause the decode to be cropped."""
    clip_shape = (1, 2, 8, 10, 3)
    decoded = torch.ones(clip_shape, dtype=torch.float32)
    original = torch.ones(clip_shape, dtype=torch.float32)
    original[..., -1, :, :] = -1.0  # last row of every frame

    output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 8, "none").result[0]

    assert tuple(output.shape) == clip_shape
    assert torch.equal(output, decoded)
|
|
|
|
|
|
def test_seedvr2_post_processing_none_preserves_black_right_column_content():
    """A -1.0 right column in the original must not cause the decode to be cropped."""
    clip_shape = (1, 2, 8, 10, 3)
    decoded = torch.ones(clip_shape, dtype=torch.float32)
    original = torch.ones(clip_shape, dtype=torch.float32)
    original[..., -1, :] = -1.0  # last column of every frame

    output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 8, "none").result[0]

    assert tuple(output.shape) == clip_shape
    assert torch.equal(output, decoded)
|
|
|
|
|
|
def test_seedvr2_post_processing_wavelet_dispatch_routes_through_wavelet_color_transfer():
    """Check that "wavelet" dispatches to wavelet_color_transfer only.

    It must call wavelet_color_transfer once for both frames and never call
    LAB.
    """
    decoded = torch.full((1, 3, 9, 11, 3), 0.25)
    original = torch.full((1, 2, 16, 20, 3), 0.75)
    wavelet_calls = []
    lab_calls = []

    def _recorder(log):
        # Build a fake transfer that records its inputs and returns zeros.
        def _transfer(content, style):
            log.append((content.clone(), style.clone()))
            return torch.zeros_like(content)
        return _transfer

    wavelet_patch = patch.object(nodes_seedvr, "wavelet_color_transfer", _recorder(wavelet_calls))
    lab_patch = patch.object(nodes_seedvr, "lab_color_transfer", _recorder(lab_calls))
    with wavelet_patch, lab_patch:
        output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 8, "wavelet").result[0]

    assert len(wavelet_calls) == 1
    assert not lab_calls
    assert tuple(output.shape) == (1, 2, 8, 10, 3)
    assert torch.equal(output, torch.full_like(output, 0.5))

    content, style = wavelet_calls[0]
    assert content.shape == (2, 3, 8, 10)
    assert style.shape == (2, 3, 8, 10)
    assert torch.equal(content, torch.full_like(content, -0.5))
    assert torch.allclose(style, torch.full_like(style, 0.5))
|
|
|
|
|
|
def test_seedvr2_post_processing_adain_dispatch_routes_through_adain_color_transfer():
    """Check that "adain" dispatches to adain_color_transfer only.

    It must call adain_color_transfer once for both frames and never call
    LAB.
    """
    decoded = torch.full((1, 3, 9, 11, 3), 0.25)
    original = torch.full((1, 2, 16, 20, 3), 0.75)
    adain_calls = []
    lab_calls = []

    def _recorder(log):
        # Build a fake transfer that records its inputs and returns zeros.
        def _transfer(content, style):
            log.append((content.clone(), style.clone()))
            return torch.zeros_like(content)
        return _transfer

    adain_patch = patch.object(nodes_seedvr, "adain_color_transfer", _recorder(adain_calls))
    lab_patch = patch.object(nodes_seedvr, "lab_color_transfer", _recorder(lab_calls))
    with adain_patch, lab_patch:
        output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 8, "adain").result[0]

    assert len(adain_calls) == 1
    assert not lab_calls
    assert tuple(output.shape) == (1, 2, 8, 10, 3)
    assert torch.equal(output, torch.full_like(output, 0.5))

    content, style = adain_calls[0]
    assert content.shape == (2, 3, 8, 10)
    assert style.shape == (2, 3, 8, 10)
|
|
|
|
|
|
def test_seedvr2_color_transfer_helper_runs_on_vae_device():
    """Source-level check of the generic color-transfer helper.

    The helper must move inputs to the VAE device, apply the supplied
    ``transfer_fn``, and move the result back to the output device.
    """
    # Uses the module-level ``inspect`` import, like the other source checks;
    # the former function-local ``import inspect as _inspect`` was redundant.
    helper_source = inspect.getsource(nodes_seedvr.SeedVR2PostProcessing._color_transfer_on_vae_device)

    assert "comfy.model_management.vae_device()" in helper_source
    assert ".to(device=color_device)" in helper_source
    assert ".to(device=output_device)" in helper_source
    assert "transfer_fn" in helper_source
|
|
|
|
|
|
def test_seedvr2_wavelet_color_transfer_matches_primary_source_reconstruction():
    """wavelet_color_transfer equals wavelet_reconstruction bit-for-bit on seeded inputs."""
    from comfy.ldm.seedvr import vae as seedvr_vae

    torch.manual_seed(0)
    frame_shape = (1, 3, 12, 16)
    # Draw content first, then style, mapped from [0, 1) to [-1, 1).
    content = torch.rand(*frame_shape) * 2.0 - 1.0
    style = torch.rand(*frame_shape) * 2.0 - 1.0

    transferred = seedvr_vae.wavelet_color_transfer(content, style)
    reconstructed = seedvr_vae.wavelet_reconstruction(content.clone(), style.clone())

    assert torch.equal(transferred, reconstructed)
|
|
|
|
|
|
def test_seedvr2_adain_color_transfer_matches_huang_belongie_formula():
    """Check adain_color_transfer against a reference AdaIN.

    The reference is per-channel mean/std matching with population variance
    plus eps=1e-5, clamped to [-1, 1].
    """
    from comfy.ldm.seedvr import vae as seedvr_vae

    torch.manual_seed(0)
    content = torch.rand(2, 3, 5, 7) * 2.0 - 1.0
    style = torch.rand(2, 3, 5, 7) * 2.0 - 1.0
    out = seedvr_vae.adain_color_transfer(content.clone(), style.clone())

    batch, channels = content.shape[:2]
    eps = 1e-5

    def _channel_stats(x):
        # Per-(batch, channel) mean and population std, broadcastable over H, W.
        flat = x.float().reshape(batch, channels, -1)
        mean = flat.mean(dim=2).reshape(batch, channels, 1, 1)
        std = (flat.var(dim=2, correction=0) + eps).sqrt().reshape(batch, channels, 1, 1)
        return mean, std

    mu_c, sd_c = _channel_stats(content)
    mu_s, sd_s = _channel_stats(style)
    expected = (((content.float() - mu_c) / sd_c) * sd_s + mu_s).clamp(-1.0, 1.0)

    assert torch.allclose(out, expected, atol=1e-6)
|
|
|
|
|
|
def test_seedvr2_adain_single_pixel_uses_population_variance_without_nan():
    """Check that AdaIN on 1x1 frames is finite and returns the style pixel.

    With one pixel per channel, content is replaced exactly by the style
    value and no NaNs appear.
    """
    from comfy.ldm.seedvr import vae as seedvr_vae

    content = torch.tensor([0.25, -0.5, 0.75], dtype=torch.float32).reshape(1, 3, 1, 1)
    style = torch.tensor([-0.25, 0.5, -0.75], dtype=torch.float32).reshape(1, 3, 1, 1)

    transferred = seedvr_vae.adain_color_transfer(content, style)

    assert torch.isfinite(transferred).all()
    assert torch.equal(transferred, style)
|
|
|
|
|
|
def test_seedvr2_adain_preserves_input_dtype():
    """adain_color_transfer returns float16 when given float16 content and style."""
    from comfy.ldm.seedvr import vae as seedvr_vae

    # Seed like the sibling adain/wavelet tests so the inputs are reproducible.
    torch.manual_seed(0)
    content = (torch.rand(1, 3, 4, 4) * 2.0 - 1.0).to(torch.float16)
    style = (torch.rand(1, 3, 4, 4) * 2.0 - 1.0).to(torch.float16)

    out = seedvr_vae.adain_color_transfer(content, style)

    assert out.dtype == torch.float16
|
|
|
|
|
|
def test_seedvr2_adain_resizes_mismatched_style_to_content_shape():
    """A 16x20 style with 8x10 content still yields output shaped like the content."""
    from comfy.ldm.seedvr import vae as seedvr_vae

    # Seed like the sibling adain/wavelet tests so the inputs are reproducible.
    torch.manual_seed(0)
    content = torch.rand(1, 3, 8, 10) * 2.0 - 1.0
    style = torch.rand(1, 3, 16, 20) * 2.0 - 1.0

    out = seedvr_vae.adain_color_transfer(content, style)

    assert tuple(out.shape) == (1, 3, 8, 10)
|
|
|
|
|
|
def test_seedvr2_post_processing_unknown_color_correction_method_raises():
    """An unrecognized color_correction_method raises a ValueError naming the input."""
    decoded = torch.zeros(1, 2, 4, 4, 3)
    original = torch.zeros(1, 2, 4, 4, 3)

    caught = None
    try:
        nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 4, "bogus")
    except ValueError as exc:
        caught = exc

    if caught is None:
        raise AssertionError("expected ValueError for unknown color_correction_method")
    assert "color_correction_method" in str(caught)
|