mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-30 10:57:23 +08:00
- Reduce SeedVR2 coverage down to production unit tests - Route SeedVR2 7B through Comfy varlength attention - Disable SeedVR2 RoPE cache reuse after the upstream DynamicVRAM change
92 lines
2.5 KiB
Python
92 lines
2.5 KiB
Python
from unittest.mock import patch
|
|
|
|
import pytest
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from comfy.cli_args import args as cli_args
|
|
|
|
if not torch.cuda.is_available():
|
|
cli_args.cpu = True
|
|
|
|
import comfy.ldm.seedvr.vae as vae_mod # noqa: E402
|
|
from comfy_extras import nodes_seedvr # noqa: E402
|
|
|
|
|
|
def _make_wrapper() -> vae_mod.VideoAutoencoderKLWrapper:
|
|
wrapper = vae_mod.VideoAutoencoderKLWrapper.__new__(
|
|
vae_mod.VideoAutoencoderKLWrapper
|
|
)
|
|
nn.Module.__init__(wrapper)
|
|
return wrapper
|
|
|
|
|
|
def _fingerprint_decode_(self, z, return_dict=True):
|
|
b = int(z.shape[0])
|
|
t = int(z.shape[2])
|
|
h = int(z.shape[3])
|
|
w = int(z.shape[4])
|
|
out = torch.empty(b, 3, t, h * 8, w * 8)
|
|
for batch_idx in range(b):
|
|
out[batch_idx].fill_(float(batch_idx + 1))
|
|
return out
|
|
|
|
|
|
def _decode_with_patches(wrapper, z):
|
|
with patch.object(vae_mod.VideoAutoencoderKL, "decode_", _fingerprint_decode_):
|
|
return wrapper.decode(z)
|
|
|
|
|
|
def test_decode_b2_t3_multi_frame_batch_unchanged():
|
|
wrapper = _make_wrapper()
|
|
|
|
out = _decode_with_patches(wrapper, torch.zeros(2, 16 * 3, 2, 2))
|
|
|
|
assert tuple(out.shape) == (2, 3, 3, 16, 16)
|
|
|
|
|
|
class _Wrapper(vae_mod.VideoAutoencoderKLWrapper):
|
|
def __init__(self):
|
|
nn.Module.__init__(self)
|
|
self.calls = []
|
|
|
|
def parameters(self):
|
|
return iter([torch.nn.Parameter(torch.zeros(()))])
|
|
|
|
def _decode_stub(self, latent):
|
|
self.calls.append(tuple(latent.shape))
|
|
return torch.zeros(latent.shape[0], 3, latent.shape[2], latent.shape[3] * 8, latent.shape[4] * 8)
|
|
|
|
|
|
def test_seedvr2_wrapper_decode_accepts_5d_channel_first_latents_without_preprocessor_state():
|
|
wrapper = _Wrapper()
|
|
|
|
with patch.object(vae_mod.VideoAutoencoderKL, "decode_", _decode_stub):
|
|
out = wrapper.decode(torch.zeros(1, 16, 2, 4, 5))
|
|
|
|
assert tuple(out.shape) == (1, 3, 2, 32, 40)
|
|
assert wrapper.calls == [(1, 16, 2, 4, 5)]
|
|
|
|
|
|
def test_seedvr2_wrapper_decode_rejects_wrong_rank_latents():
|
|
wrapper = _Wrapper()
|
|
|
|
with pytest.raises(RuntimeError, match=r"latent input must be 4-D collapsed .* or 5-D"):
|
|
wrapper.decode(torch.zeros(1, 16, 4))
|
|
|
|
|
|
def _t_padded(t_in: int) -> int:
|
|
if t_in == 1:
|
|
return 1
|
|
if t_in <= 4:
|
|
return 5
|
|
if (t_in - 1) % 4 == 0:
|
|
return t_in
|
|
return t_in + (4 - ((t_in - 1) % 4))
|
|
|
|
|
|
@pytest.mark.parametrize("t_in", [1, 5, 9])
|
|
def test_t_padded_matches_cut_videos(t_in):
|
|
dummy = torch.zeros(1, t_in, 1, 1, 1)
|
|
assert nodes_seedvr.cut_videos(dummy).shape[1] == _t_padded(t_in)
|