# Mirror of https://github.com/comfyanonymous/ComfyUI.git
# (synced 2026-05-28 01:47:32 +08:00; 211 lines, 7.0 KiB, Python)
from unittest.mock import MagicMock
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from comfy.cli_args import args as cli_args
|
|
|
|
# Without a CUDA device, put comfy into CPU mode. This runs before the comfy
# modules below are imported (hence their `noqa: E402`); presumably those
# modules read `cli_args` at import time — confirm against comfy.model_management.
_cuda_present = torch.cuda.is_available()
if not _cuda_present:
    cli_args.cpu = True
|
|
|
|
import comfy.ldm.seedvr.vae as seedvr_vae_mod # noqa: E402
|
|
import comfy.sd as sd_mod # noqa: E402
|
|
import nodes as nodes_mod # noqa: E402
|
|
|
|
|
|
class _Patcher:
|
|
def get_free_memory(self, device):
|
|
return 1024 * 1024 * 1024
|
|
|
|
|
|
class _EncodeWrapper(seedvr_vae_mod.VideoAutoencoderKLWrapper):
    """Fake SeedVR2 autoencoder whose ``encode`` returns a canned latent.

    ``nn.Module.__init__`` is called directly, so the real wrapper's
    constructor never runs (presumably to avoid building the real model).
    The shape of every tensor passed to ``encode`` is appended to ``self.seen``.
    """

    def __init__(self, encoded):
        nn.Module.__init__(self)
        self.encoded = encoded
        self.spatial_downsample_factor, self.temporal_downsample_factor = 8, 4
        self.seen = []

    def encode(self, x):
        """Record ``x``'s shape, then return ``self.encoded`` cast to ``x``'s device and dtype."""
        self.seen.append(tuple(x.shape))
        placement = dict(device=x.device, dtype=x.dtype)
        return self.encoded.to(**placement)
|
|
|
|
|
|
class _DecodeWrapper(seedvr_vae_mod.VideoAutoencoderKLWrapper):
    """Fake SeedVR2 autoencoder whose ``decode`` returns zero-filled pixels.

    Each ``decode`` call is appended to ``self.calls`` as a dict holding the
    input shape and the ``seedvr2_tiling`` argument. Tests use it to check
    exactly what the VAE front-end passed to the model.
    """

    def __init__(self):
        nn.Module.__init__(self)
        self.spatial_downsample_factor = 8
        self.temporal_downsample_factor = 4
        self.calls = []

    def decode(self, z, seedvr2_tiling=None):
        """Log the call and return a zero tensor of shape ``(b, 3, t, 8*h, 8*w)``.

        A 4-D ``z`` is unpacked as ``(b, folded, h, w)`` with ``t = folded // 16``.
        Any other ``z`` is unpacked as ``(b, _, t, h, w)``. The frame count
        ``t`` is returned as-is, with no temporal upsampling.
        """
        self.calls.append({"shape": tuple(z.shape), "seedvr2_tiling": seedvr2_tiling})
        if z.ndim != 4:
            batch, _, frames, height, width = z.shape
        else:
            batch, folded, height, width = z.shape
            frames = folded // 16
        return torch.zeros(
            batch, 3, frames, height * 8, width * 8, dtype=z.dtype, device=z.device
        )
|
|
|
|
|
|
def _make_vae(wrapper):
    """Return a ``comfy.sd.VAE`` around ``wrapper`` without running ``VAE.__init__``.

    The instance is created with ``__new__`` and its attributes are assigned
    directly:
    - CPU device and output device, float32 compute and output dtype.
    - 16 latent channels, 3-D latents, 4x temporal / 8x spatial compression.
    - Offloading disabled.
    - Trivial memory-estimate, crop, and validation hooks.
    """
    vae = sd_mod.VAE.__new__(sd_mod.VAE)
    # Applied in this exact order via setattr below.
    attributes = {
        "first_stage_model": wrapper,
        "device": torch.device("cpu"),
        "output_device": torch.device("cpu"),
        "vae_dtype": torch.float32,
        "latent_channels": 16,
        "latent_dim": 3,
        # Frame maps: pixels -> latents is ceil(t / 4); latents -> pixels is 4t - 3.
        "downscale_ratio": (lambda a: max(0, (a + 3) // 4), 8, 8),
        "upscale_ratio": (lambda a: max(0, a * 4 - 3), 8, 8),
        "output_channels": 3,
        "disable_offload": True,
        "extra_1d_channel": None,
        "crop_input": False,
        "not_video": False,
        "patcher": _Patcher(),
        "process_input": lambda image: image,
        # Maps decoder output from [-1, 1] to [0, 1].
        "process_output": lambda image: image.add(1.0).div(2.0).clamp(0.0, 1.0),
        "vae_output_dtype": lambda: torch.float32,
        "memory_used_encode": lambda shape, dtype: 1,
        "memory_used_decode": lambda shape, dtype: 1,
        "throw_exception_if_invalid": lambda: None,
        "vae_encode_crop_pixels": lambda pixels: pixels,
        "spacial_compression_decode": lambda: 8,
        "temporal_compression_decode": lambda: 4,
    }
    for attr_name, attr_value in attributes.items():
        setattr(vae, attr_name, attr_value)
    return vae
|
|
|
|
|
|
def test_seedvr2_encode_and_encode_tiled_preserve_native_latent_contract(monkeypatch):
    """VAEEncode and VAEEncodeTiled both return a bare ``{"samples": ...}`` latent.

    The latent keeps the native ``(1, 16, 2, 4, 5)`` shape and float32 dtype.
    It equals the wrapper's (or stubbed tiler's) output multiplied by 0.9152.
    """
    monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None)

    latent_shape = (1, 16, 2, 4, 5)
    scale = 0.9152
    vae = _make_vae(_EncodeWrapper(torch.full(latent_shape, 2.0)))
    pixels = torch.zeros(1, 5, 32, 40, 3)

    # Untiled path: the wrapper's canned latent (all 2.0) comes back scaled.
    plain_out = nodes_mod.VAEEncode().encode(vae, pixels)[0]
    plain = plain_out["samples"]
    assert set(plain_out) == {"samples"}
    assert tuple(plain.shape) == latent_shape
    assert plain.dtype == torch.float32
    assert plain.stride()[-1] == 1
    assert torch.equal(plain, torch.full_like(plain, 2.0 * scale))

    # Tiled path: stub the SeedVR2 tiler so its (all 3.0) output is what gets scaled.
    fake_tiler = MagicMock(return_value=torch.full(latent_shape, 3.0))
    monkeypatch.setattr(seedvr_vae_mod, "tiled_vae", fake_tiler)
    tiled_out = nodes_mod.VAEEncodeTiled().encode(
        vae,
        pixels,
        tile_size=512,
        overlap=64,
        temporal_size=16,
        temporal_overlap=4,
    )[0]
    tiled = tiled_out["samples"]
    assert set(tiled_out) == {"samples"}
    assert tuple(tiled.shape) == latent_shape
    assert tiled.dtype == torch.float32
    assert torch.equal(tiled, torch.full_like(tiled, 3.0 * scale))
|
|
|
|
|
|
def test_seedvr2_decode_and_decode_tiled_do_not_require_preprocessor_state(monkeypatch):
    """Both decode nodes work on a VAE built by ``_make_vae`` and yield 2 frames of 32x40x3.

    VAEDecode receives a 4-D folded latent; VAEDecodeTiled receives a 5-D latent.
    """
    monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None)
    vae = _make_vae(_DecodeWrapper())
    expected_images = (2, 32, 40, 3)

    folded_latent = {"samples": torch.zeros(1, 32, 4, 5)}
    images = nodes_mod.VAEDecode().decode(vae, folded_latent)[0]
    assert tuple(images.shape) == expected_images

    tiling = dict(tile_size=512, overlap=64, temporal_size=16, temporal_overlap=4)
    video_latent = {"samples": torch.zeros(1, 16, 2, 4, 5)}
    tiled_images = nodes_mod.VAEDecodeTiled().decode(vae, video_latent, **tiling)[0]
    assert tuple(tiled_images.shape) == expected_images
|
|
|
|
|
|
def test_seedvr2_vaedecode_does_not_repair_latent_layout(monkeypatch):
    """A ``(1, 2, 4, 5, 16)`` latent reaches the wrapper unchanged.

    VAEDecode must not permute it, and must make a single untiled call.
    """
    monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None)
    vae = _make_vae(_DecodeWrapper())
    odd_layout = (1, 2, 4, 5, 16)

    nodes_mod.VAEDecode().decode(vae, {"samples": torch.zeros(odd_layout)})

    assert vae.first_stage_model.calls == [{"shape": odd_layout, "seedvr2_tiling": None}]
|
|
|
|
|
|
def test_seedvr2_vaedecode_keeps_public_channel_first_width_16_latents(monkeypatch):
    """VAEDecode passes a channel-first latent whose last dim is 16 through as-is."""
    monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None)
    vae = _make_vae(_DecodeWrapper())
    channel_first = (1, 16, 4, 5, 16)

    nodes_mod.VAEDecode().decode(vae, {"samples": torch.zeros(channel_first)})

    assert vae.first_stage_model.calls == [{"shape": channel_first, "seedvr2_tiling": None}]
|
|
|
|
|
|
def test_seedvr2_direct_decode_preserves_channel_first_width_16(monkeypatch):
    """``VAE.decode`` hands a channel-first, width-16 latent to the wrapper untouched."""
    monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None)
    vae = _make_vae(_DecodeWrapper())
    latent_shape = (1, 16, 2, 4, 16)

    vae.decode(torch.zeros(latent_shape))

    assert vae.first_stage_model.calls == [{"shape": latent_shape, "seedvr2_tiling": None}]
|
|
|
|
|
|
def test_seedvr2_decode_tiled_preserves_direct_channel_first_width_16(monkeypatch):
    """``VAE.decode_tiled_seedvr2`` keeps a channel-first, width-16 latent's layout.

    Only the first wrapper call is checked.
    """
    monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None)
    vae = _make_vae(_DecodeWrapper())
    latent_shape = (1, 16, 2, 4, 16)

    vae.decode_tiled_seedvr2(torch.zeros(latent_shape))

    first_call = vae.first_stage_model.calls[0]
    assert first_call["shape"] == latent_shape
|
|
|
|
|
|
def test_seedvr2_vaedecode_tiled_keeps_public_channel_first_width_16_latents(monkeypatch):
    """VAEDecodeTiled passes a channel-first, width-16 latent through unpermuted.

    Only the first wrapper call is checked.
    """
    monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None)
    vae = _make_vae(_DecodeWrapper())
    channel_first = (1, 16, 4, 5, 16)
    tiling = dict(tile_size=512, overlap=64, temporal_size=16, temporal_overlap=4)

    nodes_mod.VAEDecodeTiled().decode(vae, {"samples": torch.zeros(channel_first)}, **tiling)

    first_call = vae.first_stage_model.calls[0]
    assert first_call["shape"] == channel_first
|
|
|
|
|
|
def test_vaedecode_tiled_visible_inputs_are_seedvr2_decode_tiling_authority(monkeypatch):
    """VAEDecodeTiled forwards its inputs to the wrapper as ``seedvr2_tiling``.

    Spatial values are expanded to ``(h, w)`` pairs. The test expects exactly
    one decode call with the latent shape unchanged.
    """
    monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None)
    vae = _make_vae(_DecodeWrapper())
    latent_shape = (1, 16, 2, 4, 5)

    nodes_mod.VAEDecodeTiled().decode(
        vae,
        {"samples": torch.zeros(latent_shape)},
        tile_size=512,
        overlap=64,
        temporal_size=16,
        temporal_overlap=4,
    )

    expected_tiling = {
        "enable_tiling": True,
        "tile_size": (512, 512),
        "tile_overlap": (64, 64),
        "temporal_size": 16,
        "temporal_overlap": 4,
    }
    expected_calls = [{"shape": latent_shape, "seedvr2_tiling": expected_tiling}]
    assert vae.first_stage_model.calls == expected_calls
|