# Mirror of https://github.com/comfyanonymous/ComfyUI.git
# Synced 2026-05-27 17:37:39 +08:00 · 64 lines · 1.8 KiB · Python
import ast
import inspect
import textwrap

import torch
import torch.nn as nn

from comfy.cli_args import args as cli_args
|
|
|
|
# Force CPU mode when no CUDA device is present so the tests can run on
# CPU-only machines. This must execute before the `comfy.ldm.seedvr.vae`
# import further down (hence that line's `# noqa: E402`); presumably comfy
# modules consult `cli_args.cpu` at import time to pick a device — TODO confirm.
if not torch.cuda.is_available():
    cli_args.cpu = True
|
|
|
|
import comfy.ldm.seedvr.vae as vae_mod # noqa: E402
|
|
|
|
# Module-level alias so the tests can construct, monkeypatch and inspect
# the wrapper class directly.
VideoAutoencoderKLWrapper = vae_mod.VideoAutoencoderKLWrapper

# Small synthetic 5-D shapes used by the forward() test; they keep tensors tiny.
_INPUT_SHAPE = (1, 3, 5, 16, 16)
_POSTERIOR_SHAPE = (1, 16, 1, 2, 2)
# The stubbed decoder output has exactly the same shape as the input.
_DECODE_OUT_SHAPE = _INPUT_SHAPE
|
|
|
|
|
|
def _build_wrapper_standin() -> VideoAutoencoderKLWrapper:
    """Return a bare ``VideoAutoencoderKLWrapper`` for unit tests.

    The wrapper's own ``__init__`` is bypassed; only ``nn.Module.__init__``
    runs, so the instance has the base-class bookkeeping ``nn.Module``
    expects but none of the state the wrapper's constructor would set up.
    Callers assign whatever attributes they need by hand.
    """
    cls = VideoAutoencoderKLWrapper
    instance = cls.__new__(cls)
    nn.Module.__init__(instance)
    return instance
|
|
|
|
|
|
def test_wrapper_forward_returns_tensor_triple(monkeypatch):
    """``forward`` must return a plain ``(x_out, z, p)`` tuple of tensors.

    ``encode`` and ``decode`` are replaced on the class with stubs that return
    recognisable constant tensors (7.0 for the posterior, 13.0 for the decoded
    output). The assertions then check that ``forward`` hands these back
    unchanged instead of wrapping them in some other object.
    """
    wrapper = _build_wrapper_standin()
    # Instance state forward() may read; the wrapper's own __init__ never ran.
    wrapper.original_image_video = torch.zeros(*_INPUT_SHAPE)
    wrapper.img_dims = (16, 16)
    wrapper.freeze_encoder = True

    posterior = torch.full(_POSTERIOR_SHAPE, 7.0)
    decode_out = torch.full(_DECODE_OUT_SHAPE, 13.0)

    # Keep these parameter names as-is: forward() might pass arguments by keyword.
    def fake_encode(self, x, orig_dims=None):
        return posterior.squeeze(2), posterior

    def fake_decode(self, z):
        return decode_out

    for method_name, replacement in (("encode", fake_encode), ("decode", fake_decode)):
        monkeypatch.setattr(VideoAutoencoderKLWrapper, method_name, replacement)

    result = wrapper.forward(torch.zeros(*_INPUT_SHAPE))

    assert isinstance(result, tuple)
    assert len(result) == 3
    x_out, z, p = result

    # Exact type on purpose: a Tensor subclass or wrapper object must fail here.
    for item in (x_out, z, p):
        assert type(item) is torch.Tensor

    expected_z = posterior.squeeze(2)
    assert x_out.shape == decode_out.shape
    assert z.shape == expected_z.shape
    assert torch.equal(x_out, decode_out)
    assert torch.equal(z, expected_z)
    assert p is posterior
|
|
|
|
|
|
def test_wrapper_forward_source_has_no_sample_access():
    """``VideoAutoencoderKLWrapper.forward`` must not access a ``.sample`` attribute.

    The method's source is parsed and real attribute accesses are inspected,
    rather than searching for the substring ``".sample"``. A plain substring
    search also matches comments, docstrings, string literals and unrelated
    attributes such as ``.samples`` or ``.sampler``. Every genuine ``.sample``
    access is still reported, with its line number.
    """
    # getsource() returns the method with its class-body indentation, which
    # ast.parse() rejects, so dedent first.
    src = textwrap.dedent(inspect.getsource(VideoAutoencoderKLWrapper.forward))
    offending_lines = [
        node.lineno
        for node in ast.walk(ast.parse(src))
        if isinstance(node, ast.Attribute) and node.attr == "sample"
    ]
    assert not offending_lines, (
        f"forward() accesses `.sample` at line(s) {offending_lines} of its source"
    )
|