# ComfyUI/tests-unit/comfy_test/test_seedvr_vae_decode_batch_axes.py
# Last modified: 2026-05-26 00:28:36 -05:00
# Size: 134 lines, 4.0 KiB (Python)

from unittest.mock import patch
import torch
import torch.nn as nn
from comfy.cli_args import args as cli_args
# Without CUDA, switch ComfyUI into CPU mode *before* comfy.ldm.seedvr.vae is
# imported below (that import carries ``noqa: E402`` because of this ordering).
# NOTE(review): presumably comfy reads cli_args while importing -- confirm.
_HAS_CUDA = torch.cuda.is_available()
if not _HAS_CUDA:
    cli_args.cpu = True
import comfy.ldm.seedvr.vae as vae_mod # noqa: E402
def _make_wrapper() -> vae_mod.VideoAutoencoderKLWrapper:
    """Build a ``VideoAutoencoderKLWrapper`` without running its constructor.

    ``__new__`` allocates the instance but skips the wrapper's own
    ``__init__`` (and whatever model construction it would perform). Only
    ``nn.Module.__init__`` is run, so the object gets the internal state that
    ``nn.Module`` attribute handling needs.
    """
    cls = vae_mod.VideoAutoencoderKLWrapper
    instance = cls.__new__(cls)
    nn.Module.__init__(instance)
    return instance
def _fingerprint_decode_(self, z, return_dict=True):
    """Stand-in for ``VideoAutoencoderKL.decode_`` that tags every sample.

    Reads ``z.shape`` at indices 0, 2, 3 and 4 as batch, frames, height and
    width. Returns a tensor of shape ``(batch, 3, frames, 8*height, 8*width)``
    in which every element of sample ``i`` equals ``i + 1``, so tests can tell
    which output slot each input sample landed in. ``self`` and
    ``return_dict`` exist only to match the patched method's signature and are
    ignored.
    """
    batch, frames, height, width = [int(z.shape[axis]) for axis in (0, 2, 3, 4)]
    out = torch.empty(batch, 3, frames, height * 8, width * 8)
    # Iterating a tensor yields views along dim 0, so fill_ writes into ``out``.
    for idx, sample in enumerate(out):
        sample.fill_(float(idx + 1))
    return out
def _decode_with_patches(wrapper, z):
    """Return ``wrapper.decode(z)`` with ``VideoAutoencoderKL.decode_`` stubbed.

    The class attribute is swapped for ``_fingerprint_decode_`` only for the
    duration of the call, so the real ``decode_`` is not invoked.
    """
    fake_decode = patch.object(
        vae_mod.VideoAutoencoderKL, "decode_", _fingerprint_decode_
    )
    with fake_decode:
        result = wrapper.decode(z)
    return result
def test_decode_b1_t1_shape_and_ordering_correct():
    """A (1, 16, 2, 2) latent decodes to shape (1, 3, 1, 16, 16)."""
    decoder = _make_wrapper()
    decoded = _decode_with_patches(decoder, torch.zeros(1, 16, 2, 2))
    assert tuple(decoded.shape) == (1, 3, 1, 16, 16)
    # The fingerprint stub fills sample 0 with 1.0.
    assert decoded[0, 0, 0, 0, 0].item() == 1.0
def test_decode_b1_t5_video_shape_unchanged():
    """A (1, 16 * 5, 2, 2) latent decodes to a 5-frame output."""
    frames = 5
    decoder = _make_wrapper()
    latent = torch.zeros(1, 16 * frames, 2, 2)
    decoded = _decode_with_patches(decoder, latent)
    assert tuple(decoded.shape) == (1, 3, frames, 16, 16)
def test_decode_b2_t1_preserves_batch_time_axes():
    """Two single-frame latents stay on the batch axis, in input order."""
    decoder = _make_wrapper()
    decoded = _decode_with_patches(decoder, torch.zeros(2, 16, 2, 2))
    assert tuple(decoded.shape) == (2, 3, 1, 16, 16)
    # Fingerprint stub: sample i is filled with i + 1.
    for idx, expected in enumerate((1.0, 2.0)):
        assert decoded[idx, 0, 0, 0, 0].item() == expected
def test_decode_b4_t1_preserves_batch_time_axes():
    """Four single-frame latents keep their batch order after decoding."""
    batch = 4
    decoder = _make_wrapper()
    decoded = _decode_with_patches(decoder, torch.zeros(batch, 16, 2, 2))
    assert tuple(decoded.shape) == (batch, 3, 1, 16, 16)
    markers = [decoded[idx, 0, 0, 0, 0].item() for idx in range(batch)]
    assert markers == [float(idx + 1) for idx in range(batch)]
def test_decode_b2_t3_multi_frame_batch_unchanged():
    """A (2, 16 * 3, 2, 2) latent decodes to two 3-frame outputs."""
    decoder = _make_wrapper()
    latent = torch.zeros(2, 16 * 3, 2, 2)
    decoded = _decode_with_patches(decoder, latent)
    assert tuple(decoded.shape) == (2, 3, 3, 16, 16)
def _tiled_vae_4d_stub(latent, vae_model, **kwargs):
    """Stand-in for ``vae_mod.tiled_vae`` that returns a 4-D tensor with no frame axis.

    The output shape is ``(b, 3, 8 * latent.shape[3], 8 * latent.shape[4])``,
    where ``b = latent.shape[0]``. Every element of sample ``i`` equals
    ``i + 1``. ``vae_model`` and any keyword arguments are ignored.
    """
    batch = int(latent.shape[0])
    size = (batch, 3, int(latent.shape[3]) * 8, int(latent.shape[4]) * 8)
    out = torch.empty(size)
    for idx in range(batch):
        # Assigning a Python scalar to an index fills that whole sample.
        out[idx] = float(idx + 1)
    return out
def test_decode_tiled_single_frame_4d_output_normalized():
    """Tiled path: a 4-D tiled_vae result comes back as (1, 3, 1, 16, 16)."""
    decoder = _make_wrapper()
    tiling = {"enable_tiling": True}
    with patch.object(vae_mod, "tiled_vae", _tiled_vae_4d_stub):
        decoded = decoder.decode(torch.zeros(1, 16, 2, 2), seedvr2_tiling=tiling)
    assert tuple(decoded.shape) == (1, 3, 1, 16, 16)
    assert decoded[0, 0, 0, 0, 0].item() == 1.0
def test_decode_tiled_b2_t1_per_sample_ordering():
    """Tiled path: two single-frame samples keep their batch order."""
    decoder = _make_wrapper()
    tiling = {"enable_tiling": True}
    with patch.object(vae_mod, "tiled_vae", _tiled_vae_4d_stub):
        decoded = decoder.decode(torch.zeros(2, 16, 2, 2), seedvr2_tiling=tiling)
    assert tuple(decoded.shape) == (2, 3, 1, 16, 16)
    # The stub fills sample i with i + 1.
    for idx, expected in enumerate((1.0, 2.0)):
        assert decoded[idx, 0, 0, 0, 0].item() == expected
def test_decode_b2_t1_stacked_equals_individual_per_sample_ordering():
    """Decoding a batch of two matches decoding each sample on its own.

    The stacked run uses the fingerprint stub (sample i -> i + 1). Each
    individual run uses a stub that fills its whole output with the value the
    stacked run should have produced for that slot.
    """
    decoder = _make_wrapper()
    out_stacked = _decode_with_patches(decoder, torch.zeros(2, 16, 2, 2))

    def _constant_decode(value):
        """Return a ``decode_`` replacement whose output is filled with *value*."""
        def _stub(self, z, return_dict=True):
            shape = (
                int(z.shape[0]),
                3,
                int(z.shape[2]),
                int(z.shape[3]) * 8,
                int(z.shape[4]) * 8,
            )
            return torch.full(shape, value)
        return _stub

    singles = []
    for value in (1.0, 2.0):
        with patch.object(vae_mod.VideoAutoencoderKL, "decode_", _constant_decode(value)):
            singles.append(decoder.decode(torch.zeros(1, 16, 2, 2)))

    for idx, single in enumerate(singles):
        assert torch.equal(out_stacked[idx, :, 0, :, :], single[0, :, 0, :, :])