# Mirror of https://github.com/comfyanonymous/ComfyUI.git
# Synced 2026-05-27 17:37:39 +08:00 — 228 lines, 7.3 KiB, Python
import torch
|
|
|
|
from comfy.cli_args import args as cli_args
|
|
|
|
# Set the `cpu` CLI flag when no CUDA device is available. This sits before the
# comfy_extras / nodes imports below — presumably because those modules read
# cli_args at import time (NOTE(review): confirm before reordering).
if not torch.cuda.is_available():
    cli_args.cpu = True
|
|
|
|
import comfy_extras.nodes_seedvr as nodes_seedvr
|
|
import nodes
|
|
|
|
|
|
def test_seedvr2_postprocessing_restores_flat_decoded_batch_time():
    """A flat six-frame decode passes through "none" post-processing unchanged.

    The reference tensor is (2, 3, 4, 6, 1) — presumably (batch, time, H, W, C).
    The decoded tensor holds the same 2 * 3 frames flattened into one leading
    axis. With no colour correction, the output must equal the decoded input.
    """
    frame_shape = (4, 6, 1)
    decoded = torch.arange(6 * 4 * 6 * 1, dtype=torch.float32).reshape(6, *frame_shape)
    reference = torch.ones((2, 3, *frame_shape), dtype=torch.float32)

    node_output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, reference, 4, "none")
    restored = node_output.result[0]

    assert restored.shape == (6, *frame_shape)
    torch.testing.assert_close(restored, decoded)
|
|
|
|
|
|
def test_seedvr2_postprocessing_crops_to_resized_original_size():
    """An oversized 128x176 decode is cropped down toward a 120x169 reference.

    The expected width is 168, one pixel short of the reference's 169.
    Presumably the node rounds the crop down to an even size (confirm against
    SeedVR2PostProcessing).
    """
    decoded = torch.ones((1, 128, 176, 3), dtype=torch.float32)
    reference = torch.full((1, 1, 120, 169, 3), 0.25, dtype=torch.float32)

    cropped = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, reference, 120, "none").result[0]

    expected_shape = (1, 120, 168, 3)
    assert cropped.shape == expected_shape
|
|
|
|
|
|
def test_seedvr2_postprocessing_uses_decoded_size_when_resized_original_is_larger():
    """The decode is never enlarged to match a bigger reference.

    A 480x640 reference with a size argument of 480 must leave the 128x160
    decoded frame at its own size.
    """
    decoded_hw = (128, 160)
    decoded = torch.ones((1, *decoded_hw, 3), dtype=torch.float32)
    reference = torch.full((1, 1, 480, 640, 3), 0.25, dtype=torch.float32)

    output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, reference, 480, "none").result[0]

    assert output.shape == (1, *decoded_hw, 3)
|
|
|
|
|
|
def test_seedvr2_postprocessing_does_not_trim_real_black_original_edges():
    """An all-black reference of the decode's own size is not cropped.

    Every reference pixel is zero. Per the test name, those black edges are
    real content and must not be trimmed, so the output keeps 128x176.
    """
    frame_hw = (128, 176)
    decoded = torch.ones((1, *frame_hw, 3), dtype=torch.float32)
    black_reference = torch.zeros((1, 1, *frame_hw, 3), dtype=torch.float32)

    output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, black_reference, 128, "none").result[0]

    assert output.shape == (1, *frame_hw, 3)
|
|
|
|
|
|
def test_seedvr2_postprocessing_crops_height_only_to_resized_original_size():
    """Only the height is cropped when the reference width already matches.

    The 128x176 decode is cut to the 120-row reference height, and the shared
    176-pixel width is left alone.
    """
    width = 176
    decoded = torch.ones((1, 128, width, 3), dtype=torch.float32)
    reference = torch.full((1, 1, 120, width, 3), 0.25, dtype=torch.float32)

    output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, reference, 120, "none").result[0]

    assert output.shape == (1, 120, width, 3)
|
|
|
|
|
|
def test_seedvr2_postprocessing_lab_uses_resized_original_size(monkeypatch):
    """LAB colour transfer receives tensors already sized to the reference.

    lab_color_transfer is replaced with an identity stub that records its input
    shapes. Both inputs must be (1, 3, 120, 169), with channels in dim 1. The
    final node output is still cropped to 120x168.
    """
    decoded = torch.ones((1, 128, 176, 3), dtype=torch.float32)
    reference = torch.full((1, 1, 120, 169, 3), 0.25, dtype=torch.float32)
    seen_shapes = []

    def recording_lab_transfer(decoded_flat, reference_flat):
        # Identity stub: remember what was passed in and hand the frames back untouched.
        seen_shapes.append((tuple(decoded_flat.shape), tuple(reference_flat.shape)))
        return decoded_flat

    monkeypatch.setattr(nodes_seedvr, "lab_color_transfer", recording_lab_transfer)

    output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, reference, 120, "lab").result[0]

    expected_input = (1, 3, 120, 169)
    assert seen_shapes == [(expected_input, expected_input)]
    assert output.shape == (1, 120, 168, 3)
|
|
|
|
|
|
def test_seedvr2_tiled_decode_node_ignores_seedvr2_sideband_metadata():
    """VAEDecodeTiled must not forward the "seedvr2_channel_last" latent key.

    The latent dict carries the sideband flag next to "samples". The fake VAE
    records the keyword arguments passed to decode_tiled, so the test can check
    that the flag was not passed through.
    """

    class FakeVAE:
        def __init__(self):
            # Filled in by decode_tiled; stays None if the node never calls it.
            self.decode_call = None

        def temporal_compression_decode(self):
            return 4

        def spacial_compression_decode(self):
            return 8

        def decode_tiled(self, samples, **kwargs):
            self.decode_call = kwargs
            return torch.zeros((1, 1, 2, 2, 3), dtype=torch.float32)

    vae = FakeVAE()
    samples = {
        "samples": torch.zeros((1, 16, 4, 4, 16), dtype=torch.float32),
        "seedvr2_channel_last": True,
    }

    nodes.VAEDecodeTiled().decode(
        vae,
        samples,
        tile_size=64,
        overlap=0,
        temporal_size=64,
        temporal_overlap=8,
    )

    # Fail with a clear message instead of a TypeError from `in None` when the
    # node never reached decode_tiled.
    assert vae.decode_call is not None, "VAEDecodeTiled never called vae.decode_tiled"
    assert "seedvr2_channel_last" not in vae.decode_call
|
|
|
|
|
|
def test_seedvr2_decode_node_ignores_seedvr2_sideband_metadata():
    """VAEDecode must not forward the "seedvr2_channel_last" latent key.

    The fake VAE records the keyword arguments passed to decode, so the test
    can check that the sideband flag from the latent dict was not passed on.
    """

    class FakeVAE:
        def __init__(self):
            # Filled in by decode; stays None if the node never calls it.
            self.decode_call = None

        def decode(self, samples, **kwargs):
            self.decode_call = kwargs
            return torch.zeros((1, 1, 2, 2, 3), dtype=torch.float32)

    vae = FakeVAE()
    samples = {
        "samples": torch.zeros((1, 16, 4, 4, 16), dtype=torch.float32),
        "seedvr2_channel_last": True,
    }

    nodes.VAEDecode().decode(vae, samples)

    # Fail with a clear message instead of a TypeError from `in None` when the
    # node never reached vae.decode.
    assert vae.decode_call is not None, "VAEDecode never called vae.decode"
    assert "seedvr2_channel_last" not in vae.decode_call
|
|
|
|
|
|
def test_seedvr2_decode_node_leaves_unmarked_ambiguous_latent_unforced():
    """VAEDecode must not invent a layout flag for an unmarked latent.

    The (1, 16, 4, 4, 16) latent has 16 in both dim 1 and the last dim, which
    presumably makes its channel layout ambiguous. No "seedvr2_channel_last"
    key is present, and none may be passed to the VAE.
    """

    class FakeVAE:
        def __init__(self):
            # Filled in by decode; stays None if the node never calls it.
            self.decode_call = None

        def decode(self, samples, **kwargs):
            self.decode_call = kwargs
            return torch.zeros((1, 1, 2, 2, 3), dtype=torch.float32)

    vae = FakeVAE()
    samples = {"samples": torch.zeros((1, 16, 4, 4, 16), dtype=torch.float32)}

    nodes.VAEDecode().decode(vae, samples)

    # Fail with a clear message instead of a TypeError from `in None` when the
    # node never reached vae.decode.
    assert vae.decode_call is not None, "VAEDecode never called vae.decode"
    assert "seedvr2_channel_last" not in vae.decode_call
|
|
|
|
|
|
def test_seedvr2_encode_node_does_not_mark_model_specific_layout_metadata():
    """VAEEncode returns a latent dict whose only key is "samples"."""

    class FakeVAE:
        def encode(self, pixels):
            # Fixed dummy 5-D latent; the values are irrelevant to this test.
            return torch.zeros((1, 16, 2, 3, 4), dtype=torch.float32)

    pixels = torch.zeros((1, 8, 8, 3))
    latent = nodes.VAEEncode().encode(FakeVAE(), pixels)[0]

    assert latent.keys() == {"samples"}
|
|
|
|
|
|
def test_seedvr2_tiled_encode_node_does_not_mark_model_specific_layout_metadata():
    """VAEEncodeTiled returns a latent dict whose only key is "samples"."""

    class FakeVAE:
        def encode_tiled(self, pixels, **kwargs):
            # Fixed dummy 5-D latent; the tiling kwargs are accepted and ignored.
            return torch.zeros((1, 16, 2, 3, 4), dtype=torch.float32)

    pixels = torch.zeros((1, 8, 8, 3))
    latent = nodes.VAEEncodeTiled().encode(FakeVAE(), pixels, 64, 0)[0]

    assert latent.keys() == {"samples"}
|
|
|
|
|
|
def test_seedvr2_saved_latent_does_not_persist_model_specific_layout_metadata(monkeypatch):
    """A SaveLatent -> LoadLatent round trip drops the SeedVR2 sideband flag.

    File I/O is stubbed out. The tensor dict passed to save_torch_file is
    captured in memory, and the stubbed load_file returns that same dict. The
    test therefore checks both what would be written and what is read back.
    """
    written = {}

    def fake_save_image_path(filename_prefix, output_dir):
        # NOTE(review): tuple order presumably mirrors get_save_image_path's
        # (folder, filename, counter, subfolder, prefix) — confirm.
        return output_dir, filename_prefix, 1, "", filename_prefix

    def fake_save_torch_file(output, file, metadata=None):
        written.update(output)

    def fake_load_file(latent_path, device="cpu"):
        return written

    replacements = (
        (nodes.folder_paths, "get_save_image_path", fake_save_image_path),
        (nodes.comfy.utils, "save_torch_file", fake_save_torch_file),
        (nodes.folder_paths, "get_annotated_filepath", lambda latent: latent),
        (nodes.safetensors.torch, "load_file", fake_load_file),
    )
    for target, attribute, stub in replacements:
        monkeypatch.setattr(target, attribute, stub)

    latent_tensor = torch.zeros((1, 16, 4, 4, 16), dtype=torch.float32)
    nodes.SaveLatent().save({"samples": latent_tensor, "seedvr2_channel_last": True}, "seedvr2_latent")
    restored = nodes.LoadLatent().load("seedvr2_latent")[0]

    assert "seedvr2_channel_last" not in written
    assert "seedvr2_channel_last" not in restored
    torch.testing.assert_close(restored["samples"], latent_tensor)
|
|
|
|
|
|
def test_seedvr2_tiled_decode_node_preserves_legacy_decode_tiled_signature():
    """VAEDecodeTiled still works with a decode_tiled that has no **kwargs.

    The fake VAE accepts only the five named tiling parameters and records
    them. The node must call it with exactly those keywords and these values.
    """

    class FakeVAE:
        def __init__(self):
            self.decode_call = None

        def temporal_compression_decode(self):
            return 4

        def spacial_compression_decode(self):
            return 8

        # Deliberately no **kwargs: any extra keyword from the node would raise TypeError.
        def decode_tiled(self, samples, tile_x, tile_y, overlap, tile_t, overlap_t):
            self.decode_call = dict(
                tile_x=tile_x,
                tile_y=tile_y,
                overlap=overlap,
                tile_t=tile_t,
                overlap_t=overlap_t,
            )
            return torch.zeros((1, 1, 2, 2, 3), dtype=torch.float32)

    vae = FakeVAE()
    latent = {"samples": torch.zeros((1, 16, 4, 4, 16), dtype=torch.float32)}

    nodes.VAEDecodeTiled().decode(
        vae,
        latent,
        tile_size=64,
        overlap=0,
        temporal_size=64,
        temporal_overlap=8,
    )

    # Presumably: 64 // spatial 8 -> 8, 64 // temporal 4 -> 16, 8 // 4 -> 2.
    expected_call = dict(tile_x=8, tile_y=8, overlap=0, tile_t=16, overlap_t=2)
    assert vae.decode_call == expected_call
|