mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 13:19:23 +08:00
75 lines
2.1 KiB
Python
75 lines
2.1 KiB
Python
"""Regression tests for the SeedVR2 VAE forward return contract."""
|
|
|
|
import pytest
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from comfy.cli_args import args as cli_args
|
|
|
|
# Switch comfy into CPU mode on machines without CUDA. This has to happen
# before the comfy.ldm.seedvr.vae import that follows (hence its noqa: E402);
# presumably something imported there reads cli_args.cpu at import time —
# TODO confirm.
_cuda_present = torch.cuda.is_available()
if not _cuda_present:
    cli_args.cpu = True
del _cuda_present
|
|
|
|
from comfy.ldm.seedvr.vae import SEEDVR2_LATENT_CHANNELS, VideoAutoencoderKL # noqa: E402
|
|
|
|
|
|
# Fixed tensor shapes shared by the stub VAEs and the tests in this module.
# NOTE(review): the axis order is presumably (batch, channels, frames, height,
# width); confirm against VideoAutoencoderKL.

# Latent-space shape. The stubs' encode() returns tensors of this shape, and
# the decode-path tests use it for their input.
_LATENT_SHAPE = (1, SEEDVR2_LATENT_CHANNELS, 2, 2, 2)
_INPUT_DECODE_SHAPE = _LATENT_SHAPE

# Pixel-space shape. The stubs' decode_() returns tensors of this shape, and
# the encode-path tests use the same shape for their input.
_DECODED_SHAPE = (1, 3, 5, 16, 16)
_INPUT_ENCODE_SHAPE = _DECODED_SHAPE
|
|
|
|
|
|
class _StubVAE(VideoAutoencoderKL):
    """VideoAutoencoderKL whose encode/decode_ hooks return bare zero tensors.

    ``nn.Module.__init__`` is called directly, so ``VideoAutoencoderKL.__init__``
    never runs (presumably to avoid building the real network). The tests only
    call ``forward``, which is inherited from the real class.
    """

    def __init__(self):
        # Intentionally skip VideoAutoencoderKL.__init__.
        nn.Module.__init__(self)
        # Canned outputs. The inputs passed to encode/decode_ are ignored.
        self._encode_out = torch.zeros(_LATENT_SHAPE)
        self._decode_out = torch.zeros(_DECODED_SHAPE)

    def encode(self, x, return_dict=True):
        """Ignore ``x`` and return the canned latent tensor as-is, not wrapped."""
        latent = self._encode_out
        return latent

    def decode_(self, z, return_dict=True):
        """Ignore ``z`` and return the canned decoded tensor as-is, not wrapped."""
        pixels = self._decode_out
        return pixels
|
|
|
|
|
|
def test_forward_encode_returns_tensor():
    """``forward(mode="encode")`` yields a bare tensor with the latent shape."""
    stub = _StubVAE()
    pixels = torch.zeros(_INPUT_ENCODE_SHAPE)
    latent = stub.forward(pixels, mode="encode")
    # Exact type check on purpose: a wrapper (tuple/dict) or a Tensor
    # subclass must not pass.
    assert type(latent) is torch.Tensor
    assert tuple(latent.shape) == _LATENT_SHAPE
|
|
|
|
|
|
def test_forward_decode_returns_tensor():
    """``forward(mode="decode")`` yields a bare tensor with the decoded shape."""
    stub = _StubVAE()
    latent = torch.zeros(_INPUT_DECODE_SHAPE)
    pixels = stub.forward(latent, mode="decode")
    # Exact type check on purpose: a wrapper (tuple/dict) or a Tensor
    # subclass must not pass.
    assert type(pixels) is torch.Tensor
    assert tuple(pixels.shape) == _DECODED_SHAPE
|
|
|
|
|
|
class _TupleReturningStubVAE(VideoAutoencoderKL):
    """Stub VAE whose encode/decode_ wrap their zero tensors in 1-tuples.

    The tests use it to check that ``forward`` unwraps such tuples.
    ``nn.Module.__init__`` is called directly, so ``VideoAutoencoderKL.__init__``
    never runs.
    """

    def __init__(self):
        # Intentionally skip VideoAutoencoderKL.__init__.
        nn.Module.__init__(self)
        # Canned payloads. The inputs passed to encode/decode_ are ignored.
        self._encode_tensor = torch.zeros(_LATENT_SHAPE)
        self._decode_tensor = torch.zeros(_DECODED_SHAPE)

    def encode(self, x, return_dict=True):
        """Ignore ``x`` and return the canned latent inside a 1-tuple."""
        return self._encode_tensor,

    def decode_(self, z, return_dict=True):
        """Ignore ``z`` and return the canned decoded tensor inside a 1-tuple."""
        return self._decode_tensor,
|
|
|
|
|
|
def test_forward_all_unwraps_one_tuple_at_each_step():
    """``forward(mode="all")`` returns a bare tensor even when both hooks return tuples."""
    stub = _TupleReturningStubVAE()
    pixels_in = torch.zeros(_INPUT_ENCODE_SHAPE)
    pixels_out = stub.forward(pixels_in, mode="all")
    # If either 1-tuple were left wrapped, the exact type check would fail.
    assert type(pixels_out) is torch.Tensor
    assert tuple(pixels_out.shape) == _DECODED_SHAPE
|
|
|
|
|
|
def test_forward_rejects_unknown_mode():
    """An unrecognised ``mode`` value makes ``forward`` raise ``ValueError``."""
    stub = _StubVAE()
    dummy = torch.zeros(_INPUT_ENCODE_SHAPE)
    # Keep only the call under test inside the raises-context.
    with pytest.raises(ValueError, match="Unknown SeedVR2 VAE forward mode"):
        stub.forward(dummy, mode="bogus")
|