Mirror of https://github.com/comfyanonymous/ComfyUI.git
Synced 2026-05-28 01:47:32 +08:00
86 lines, 2.6 KiB, Python
from unittest.mock import patch
|
|
|
|
import pytest
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from comfy.cli_args import args as cli_args
|
|
|
|
# Fall back to CPU mode when no CUDA device is available. This must run
# before `comfy.ldm.seedvr.vae` is imported below (hence the `noqa: E402` on
# that import). Presumably comfy reads these CLI args at import time, e.g. when
# selecting a device -- TODO confirm.
if not torch.cuda.is_available():
    cli_args.cpu = True
|
|
|
|
import comfy.ldm.seedvr.vae as vae_mod # noqa: E402
|
|
|
|
|
|
class _Wrapper(vae_mod.VideoAutoencoderKLWrapper):
    """Lightweight ``VideoAutoencoderKLWrapper`` test double.

    ``__init__`` calls ``nn.Module.__init__`` directly. This deliberately skips
    the constructors of ``VideoAutoencoderKLWrapper`` and its bases, so none of
    the state they would normally set up (model layers, preprocessor settings,
    ...) exists on the instance. Tests patch ``decode_`` with a stub that
    appends the shape of every latent it receives to ``self.calls``.
    """

    def __init__(self):
        nn.Module.__init__(self)
        # Shapes of latents seen by the patched ``decode_``, in call order.
        self.calls: list[tuple[int, ...]] = []

    def parameters(self, recurse=True):
        """Return an iterator over one scalar placeholder parameter.

        ``recurse`` is accepted only to keep the ``nn.Module.parameters``
        signature, so callers that pass it do not fail with a TypeError. It is
        ignored because this double has no submodules. Presumably the wrapper's
        ``decode`` reads a parameter to pick a device/dtype -- TODO confirm.

        Args:
            recurse: Ignored; kept for signature compatibility.

        Returns:
            An iterator yielding a single zero-dimensional ``nn.Parameter``.
        """
        return iter([torch.nn.Parameter(torch.zeros(()))])
|
|
|
|
def _decode_stub(self, latent):
    """Stand-in for ``VideoAutoencoderKL.decode_``, installed via ``patch.object``.

    Appends ``tuple(latent.shape)`` to ``self.calls`` and returns a zero tensor
    shaped ``(latent.shape[0], 3, latent.shape[2], latent.shape[3] * 8,
    latent.shape[4] * 8)``. Dim 2 is kept and the last two dims are scaled by 8.

    The result is allocated with ``latent.new_zeros``, so it shares the input's
    dtype and device instead of using the global default dtype/device.

    Args:
        self: Object with a ``calls`` list (the wrapper under test).
        latent: Tensor with at least 5 dimensions.

    Returns:
        Zero tensor with the shape described above.
    """
    self.calls.append(tuple(latent.shape))
    shape = latent.shape
    return latent.new_zeros(shape[0], 3, shape[2], shape[3] * 8, shape[4] * 8)
|
|
|
|
|
|
def test_seedvr2_wrapper_decode_accepts_5d_channel_first_latents_without_preprocessor_state():
    """A 16-channel 5-D latent reaches ``decode_`` with its shape unchanged.

    ``_Wrapper`` skips the real wrapper constructor, so this also checks that
    the 5-D path works without any of the state that constructor would set up.
    """
    subject = _Wrapper()
    five_d_latent = torch.zeros(1, 16, 2, 4, 5)

    with patch.object(vae_mod.VideoAutoencoderKL, "decode_", _decode_stub):
        decoded = subject.decode(five_d_latent)

    expected_output_shape = (1, 3, 2, 32, 40)
    assert tuple(decoded.shape) == expected_output_shape
    assert subject.calls == [(1, 16, 2, 4, 5)]
|
|
|
|
|
|
def test_seedvr2_wrapper_decode_accepts_collapsed_4d_latents_without_preprocessor_state():
    """A collapsed 4-D latent is unfolded to 5-D before it reaches ``decode_``.

    Dim 1 of the input holds 32 = 16 * 2 entries. ``decode_`` must receive the
    unfolded (1, 16, 2, 4, 5) shape even though the wrapper was built without
    its usual constructor state.
    """
    subject = _Wrapper()
    collapsed_latent = torch.zeros(1, 32, 4, 5)

    with patch.object(vae_mod.VideoAutoencoderKL, "decode_", _decode_stub):
        decoded = subject.decode(collapsed_latent)

    expected_output_shape = (1, 3, 2, 32, 40)
    assert tuple(decoded.shape) == expected_output_shape
    assert subject.calls == [(1, 16, 2, 4, 5)]
|
|
|
|
|
|
def test_seedvr2_wrapper_decode_accepts_noncontiguous_collapsed_4d_latents():
    """A non-contiguous collapsed 4-D latent decodes like the contiguous case.

    The latent is created as (1, 4, 5, 32) and permuted to (1, 32, 4, 5), which
    yields a view with non-contiguous strides. The expected shapes match the
    contiguous collapsed-4-D test.
    """
    wrapper = _Wrapper()
    latent = torch.zeros(1, 4, 5, 32).permute(0, 3, 1, 2)

    # Validate the fixture before exercising ``decode``. If the permute ever
    # stops producing a non-contiguous (1, 32, 4, 5) view, the test no longer
    # covers the case it is named for and should say so up front.
    assert not latent.is_contiguous()
    assert tuple(latent.shape) == (1, 32, 4, 5)

    with patch.object(vae_mod.VideoAutoencoderKL, "decode_", _decode_stub):
        out = wrapper.decode(latent)

    assert tuple(out.shape) == (1, 3, 2, 32, 40)
    assert wrapper.calls == [(1, 16, 2, 4, 5)]
|
|
|
|
|
|
def test_seedvr2_wrapper_decode_rejects_non_dict_tiling_options():
    """Passing a non-dict ``seedvr2_tiling`` value raises a RuntimeError."""
    subject = _Wrapper()
    valid_latent = torch.zeros(1, 16, 2, 4, 5)

    with pytest.raises(RuntimeError, match="seedvr2_tiling.*dict"):
        subject.decode(valid_latent, seedvr2_tiling=True)
|
|
|
|
|
|
def test_seedvr2_wrapper_decode_rejects_wrong_5d_channel_count():
    """A 5-D latent whose dim 1 is not 16 raises a RuntimeError."""
    subject = _Wrapper()
    eight_channel_latent = torch.zeros(1, 8, 2, 4, 5)

    with pytest.raises(RuntimeError, match="5-D latent input must have 16 channels"):
        subject.decode(eight_channel_latent)
|
|
|
|
|
|
def test_seedvr2_wrapper_decode_rejects_misaligned_collapsed_4d_latents():
    """A 4-D latent whose dim 1 cannot be split into 16-channel groups is rejected.

    17 does not divide evenly by 16, unlike the 32 = 16 * 2 layout that the
    collapsed-4-D acceptance tests use.
    """
    subject = _Wrapper()
    misaligned_latent = torch.zeros(1, 17, 4, 5)

    with pytest.raises(RuntimeError, match=r"4-D latent input must use collapsed channel layout"):
        subject.decode(misaligned_latent)
|
|
|
|
|
|
def test_seedvr2_wrapper_decode_rejects_wrong_rank_latents():
    """A latent that is neither 4-D nor 5-D (here 3-D) raises a RuntimeError."""
    subject = _Wrapper()
    three_d_latent = torch.zeros(1, 16, 4)

    with pytest.raises(RuntimeError, match=r"latent input must be 4-D collapsed .* or 5-D"):
        subject.decode(three_d_latent)
|