# Mirror of https://github.com/comfyanonymous/ComfyUI.git
# Synced 2026-05-28 01:47:32 +08:00
# 134 lines, 4.0 KiB, Python
from unittest.mock import patch
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from comfy.cli_args import args as cli_args
|
|
|
|
# When torch reports no CUDA device, flag CPU mode on the shared CLI args.
# NOTE(review): this sits ahead of the `comfy.ldm.seedvr.vae` import (which
# carries `noqa: E402`), presumably so that import sees the flag - confirm.
if not torch.cuda.is_available():
    cli_args.cpu = True
|
|
|
|
import comfy.ldm.seedvr.vae as vae_mod # noqa: E402
|
|
|
|
|
|
def _make_wrapper() -> vae_mod.VideoAutoencoderKLWrapper:
    """Build a bare ``VideoAutoencoderKLWrapper`` without running its ``__init__``.

    The instance is allocated through ``__new__`` and only
    ``nn.Module.__init__`` is invoked on it, so the wrapper's own constructor
    logic never executes.

    Returns:
        The freshly allocated wrapper instance.
    """
    wrapper_cls = vae_mod.VideoAutoencoderKLWrapper
    instance = wrapper_cls.__new__(wrapper_cls)
    # Only the base-module initialiser runs; the wrapper constructor is skipped.
    nn.Module.__init__(instance)
    return instance
|
|
|
|
|
|
def _fingerprint_decode_(self, z, return_dict=True):
|
|
b = int(z.shape[0])
|
|
t = int(z.shape[2])
|
|
h = int(z.shape[3])
|
|
w = int(z.shape[4])
|
|
out = torch.empty(b, 3, t, h * 8, w * 8)
|
|
for batch_idx in range(b):
|
|
out[batch_idx].fill_(float(batch_idx + 1))
|
|
return out
|
|
|
|
|
|
def _decode_with_patches(wrapper, z):
    """Run ``wrapper.decode(z)`` with ``VideoAutoencoderKL.decode_`` stubbed out.

    For the duration of the call, ``decode_`` on ``vae_mod.VideoAutoencoderKL``
    is replaced by ``_fingerprint_decode_``; the patch is undone on exit.

    Args:
        wrapper: Object whose ``decode`` method is invoked.
        z: Latent passed straight through to ``wrapper.decode``.

    Returns:
        Whatever ``wrapper.decode(z)`` returns.
    """
    stubbed_decode = patch.object(
        vae_mod.VideoAutoencoderKL, "decode_", _fingerprint_decode_
    )
    with stubbed_decode:
        result = wrapper.decode(z)
    return result
|
|
|
|
|
|
def test_decode_b1_t1_shape_and_ordering_correct():
    """A ``(1, 16, 2, 2)`` latent must decode to ``(1, 3, 1, 16, 16)`` stamped 1.0."""
    latent = torch.zeros(1, 16, 2, 2)
    decoded = _decode_with_patches(_make_wrapper(), latent)

    expected_shape = (1, 3, 1, 16, 16)
    assert tuple(decoded.shape) == expected_shape
    # The fingerprint stub marks batch entry 0 with the value 1.0.
    assert decoded[0, 0, 0, 0, 0].item() == 1.0
|
|
|
|
|
|
def test_decode_b1_t5_video_shape_unchanged():
    """A ``(1, 16 * 5, 2, 2)`` latent must decode to ``(1, 3, 5, 16, 16)``.

    NOTE(review): the channel axis presumably packs 5 frames of 16 latent
    channels each - confirm against ``wrapper.decode``.
    """
    latent = torch.zeros(1, 16 * 5, 2, 2)
    decoded = _decode_with_patches(_make_wrapper(), latent)

    expected_shape = (1, 3, 5, 16, 16)
    assert tuple(decoded.shape) == expected_shape
|
|
|
|
|
|
def test_decode_b2_t1_preserves_batch_time_axes():
    """A batch of two ``(16, 2, 2)`` latents keeps its batch order after decode."""
    latent = torch.zeros(2, 16, 2, 2)
    decoded = _decode_with_patches(_make_wrapper(), latent)

    assert tuple(decoded.shape) == (2, 3, 1, 16, 16)
    # The fingerprint stub stamps batch entry i with i + 1.
    for sample_idx, expected_stamp in enumerate((1.0, 2.0)):
        assert decoded[sample_idx, 0, 0, 0, 0].item() == expected_stamp
|
|
|
|
|
|
def test_decode_b4_t1_preserves_batch_time_axes():
    """A batch of four ``(16, 2, 2)`` latents keeps its batch order after decode."""
    latent = torch.zeros(4, 16, 2, 2)
    decoded = _decode_with_patches(_make_wrapper(), latent)

    assert tuple(decoded.shape) == (4, 3, 1, 16, 16)
    # Read the first element of every batch entry in one go.
    stamps = decoded[:, 0, 0, 0, 0].tolist()
    assert stamps == [1.0, 2.0, 3.0, 4.0]
|
|
|
|
|
|
def test_decode_b2_t3_multi_frame_batch_unchanged():
    """A ``(2, 16 * 3, 2, 2)`` latent must decode to ``(2, 3, 3, 16, 16)``.

    NOTE(review): the channel axis presumably packs 3 frames of 16 latent
    channels each - confirm against ``wrapper.decode``.
    """
    latent = torch.zeros(2, 16 * 3, 2, 2)
    decoded = _decode_with_patches(_make_wrapper(), latent)

    expected_shape = (2, 3, 3, 16, 16)
    assert tuple(decoded.shape) == expected_shape
|
|
|
|
|
|
def _tiled_vae_4d_stub(latent, vae_model, **kwargs):
|
|
b = int(latent.shape[0])
|
|
h = int(latent.shape[3]) * 8
|
|
w = int(latent.shape[4]) * 8
|
|
out = torch.empty(b, 3, h, w)
|
|
for batch_idx in range(b):
|
|
out[batch_idx].fill_(float(batch_idx + 1))
|
|
return out
|
|
|
|
|
|
def test_decode_tiled_single_frame_4d_output_normalized():
    """With tiling enabled and a 4-D ``tiled_vae`` result, decode must return 5-D.

    ``vae_mod.tiled_vae`` is replaced by ``_tiled_vae_4d_stub``, which returns
    ``(b, 3, H, W)``; the test expects ``(1, 3, 1, 16, 16)`` from ``decode``.
    """
    tiling_options = {"enable_tiling": True}
    latent = torch.zeros(1, 16, 2, 2)
    wrapper = _make_wrapper()

    with patch.object(vae_mod, "tiled_vae", _tiled_vae_4d_stub):
        decoded = wrapper.decode(latent, seedvr2_tiling=tiling_options)

    expected_shape = (1, 3, 1, 16, 16)
    assert tuple(decoded.shape) == expected_shape
    assert decoded[0, 0, 0, 0, 0].item() == 1.0
|
|
|
|
|
|
def test_decode_tiled_b2_t1_per_sample_ordering():
    """With tiling enabled, a batch of two keeps its per-sample order.

    ``vae_mod.tiled_vae`` is replaced by ``_tiled_vae_4d_stub``, which stamps
    batch entry ``i`` with ``i + 1``.
    """
    tiling_options = {"enable_tiling": True}
    latent = torch.zeros(2, 16, 2, 2)
    wrapper = _make_wrapper()

    with patch.object(vae_mod, "tiled_vae", _tiled_vae_4d_stub):
        decoded = wrapper.decode(latent, seedvr2_tiling=tiling_options)

    assert tuple(decoded.shape) == (2, 3, 1, 16, 16)
    for sample_idx, expected_stamp in enumerate((1.0, 2.0)):
        assert decoded[sample_idx, 0, 0, 0, 0].item() == expected_stamp
|
|
|
|
|
|
def test_decode_b2_t1_stacked_equals_individual_per_sample_ordering():
    """Decoding a batch of two must match decoding each sample on its own.

    The stacked decode uses the fingerprint stub (entry ``i`` stamped
    ``i + 1``); each individual decode uses a stub pinned to the matching
    constant, and the first frame of each result is compared per sample.
    """

    def _constant_decode(fill_value):
        # Build a ``decode_`` replacement whose output is ``fill_value`` everywhere.
        def _fake_decode_(self, z, return_dict=True):
            batch, frames, rows, cols = (
                int(z.shape[axis]) for axis in (0, 2, 3, 4)
            )
            return torch.full((batch, 3, frames, rows * 8, cols * 8), fill_value)

        return _fake_decode_

    wrapper = _make_wrapper()
    stacked = _decode_with_patches(wrapper, torch.zeros(2, 16, 2, 2))

    singles = []
    for fill_value in (1.0, 2.0):
        with patch.object(
            vae_mod.VideoAutoencoderKL, "decode_", _constant_decode(fill_value)
        ):
            singles.append(wrapper.decode(torch.zeros(1, 16, 2, 2)))

    for sample_idx, single in enumerate(singles):
        assert torch.equal(stacked[sample_idx, :, 0, :, :], single[0, :, 0, :, :])
|