# ComfyUI/tests-unit/comfy_test/seedvr_vae_wrapper_forward_test.py
# 2026-05-26 00:28:29 -05:00
#
# 64 lines
# 1.8 KiB
# Python

import inspect
import torch
import torch.nn as nn
from comfy.cli_args import args as cli_args
# Set the CLI `cpu` flag when CUDA is unavailable. This deliberately runs
# before the VAE module import below, which is why that import carries the
# E402 (import-not-at-top) suppression.
# NOTE(review): presumably the comfy import chain reads this flag at import
# time so it does not try to touch CUDA on CPU-only machines — confirm.
if not torch.cuda.is_available():
    cli_args.cpu = True
import comfy.ldm.seedvr.vae as vae_mod # noqa: E402
# Module-level alias for the class whose `forward` the tests below exercise.
VideoAutoencoderKLWrapper = vae_mod.VideoAutoencoderKLWrapper
# Shape of the all-zero tensors used both as the `forward` input and as the
# stand-in wrapper's `original_image_video` attribute.
_INPUT_SHAPE = (1, 3, 5, 16, 16)
# Shape of the canned posterior tensor returned by the stubbed `encode`
# (the stub also returns it with dim 2 squeezed out).
_POSTERIOR_SHAPE = (1, 16, 1, 2, 2)
# Shape of the canned tensor returned by the stubbed `decode`.
_DECODE_OUT_SHAPE = (1, 3, 5, 16, 16)
def _build_wrapper_standin() -> VideoAutoencoderKLWrapper:
    """Return a bare ``VideoAutoencoderKLWrapper`` for use in tests.

    The instance is allocated with ``__new__`` so the wrapper's own
    ``__init__`` never runs; only ``nn.Module.__init__`` is invoked on it.
    """
    wrapper_cls = VideoAutoencoderKLWrapper
    standin = wrapper_cls.__new__(wrapper_cls)
    # Initialise just the nn.Module base state on the otherwise empty object.
    nn.Module.__init__(standin)
    return standin
def test_wrapper_forward_returns_tensor_triple(monkeypatch):
    """Check that ``forward`` returns a 3-tuple of plain ``torch.Tensor``s.

    ``encode`` and ``decode`` are replaced on the class with stubs returning
    canned tensors, so the assertions only pin how ``forward`` passes those
    results through: the decode output, the squeezed posterior, and the
    posterior object itself.
    """
    expected_posterior = torch.full(_POSTERIOR_SHAPE, 7.0)
    expected_decoded = torch.full(_DECODE_OUT_SHAPE, 13.0)

    def stub_encode(self, x, orig_dims=None):
        # Canned (squeezed posterior, posterior) pair; inputs are ignored.
        return expected_posterior.squeeze(2), expected_posterior

    def stub_decode(self, z):
        # Canned decode output; input is ignored.
        return expected_decoded

    standin = _build_wrapper_standin()
    standin.original_image_video = torch.zeros(*_INPUT_SHAPE)
    standin.img_dims = (16, 16)
    standin.freeze_encoder = True

    for method_name, stub in (("encode", stub_encode), ("decode", stub_decode)):
        monkeypatch.setattr(VideoAutoencoderKLWrapper, method_name, stub)

    outputs = standin.forward(torch.zeros(*_INPUT_SHAPE))

    # Container contract: exactly a tuple of three items.
    assert isinstance(outputs, tuple)
    assert len(outputs) == 3

    recon, latent, post = outputs
    # Exact-type check: each item must be a torch.Tensor itself, not a subclass.
    for item in (recon, latent, post):
        assert type(item) is torch.Tensor

    expected_latent = expected_posterior.squeeze(2)
    assert recon.shape == expected_decoded.shape
    assert latent.shape == expected_latent.shape
    assert torch.equal(recon, expected_decoded)
    assert torch.equal(latent, expected_latent)
    # The third item must be the very object the encode stub handed back.
    assert post is expected_posterior
def test_wrapper_forward_source_has_no_sample_access():
    """Check that the source text of ``forward`` never contains ``.sample``."""
    forward_source = inspect.getsource(VideoAutoencoderKLWrapper.forward)
    # Collect offending lines so a failure shows where ".sample" appears.
    offending_lines = [
        line for line in forward_source.splitlines() if ".sample" in line
    ]
    assert not offending_lines