mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-23 16:29:25 +08:00
Add SeedVR2 core coverage
This commit is contained in:
parent
d54ce3d781
commit
0fdbc5d260
@ -73,6 +73,24 @@ def _make_flux_schnell_comfyui_sd():
|
|||||||
return sd
|
return sd
|
||||||
|
|
||||||
|
|
||||||
|
def _make_seedvr2_7b_separate_mm_sd():
|
||||||
|
return {
|
||||||
|
"blocks.35.mlp.vid.proj_in.weight": torch.empty(1, 3072),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _make_seedvr2_7b_shared_mm_sd():
|
||||||
|
return {
|
||||||
|
"blocks.35.mlp.all.proj_in_gate.weight": torch.empty(1, 1),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _make_seedvr2_3b_shared_mm_sd():
|
||||||
|
return {
|
||||||
|
"blocks.31.mlp.all.proj_in_gate.weight": torch.empty(1, 1),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class TestModelDetection:
|
class TestModelDetection:
|
||||||
"""Verify that first-match model detection selects the correct model
|
"""Verify that first-match model detection selects the correct model
|
||||||
based on list ordering and unet_config specificity."""
|
based on list ordering and unet_config specificity."""
|
||||||
@ -125,6 +143,45 @@ class TestModelDetection:
|
|||||||
assert model_config is not None
|
assert model_config is not None
|
||||||
assert type(model_config).__name__ == "FluxSchnell"
|
assert type(model_config).__name__ == "FluxSchnell"
|
||||||
|
|
||||||
|
def test_seedvr2_7b_separate_mm_detection_config(self):
|
||||||
|
sd = _make_seedvr2_7b_separate_mm_sd()
|
||||||
|
unet_config = detect_unet_config(sd, "")
|
||||||
|
|
||||||
|
assert unet_config is not None
|
||||||
|
assert unet_config["image_model"] == "seedvr2"
|
||||||
|
assert unet_config["vid_dim"] == 3072
|
||||||
|
assert unet_config["heads"] == 24
|
||||||
|
assert unet_config["num_layers"] == 36
|
||||||
|
assert unet_config["mm_layers"] == 36
|
||||||
|
assert unet_config["mlp_type"] == "normal"
|
||||||
|
assert unet_config["rope_type"] == "rope3d"
|
||||||
|
assert unet_config["rope_dim"] == 64
|
||||||
|
|
||||||
|
def test_seedvr2_7b_shared_mm_detection_config(self):
|
||||||
|
sd = _make_seedvr2_7b_shared_mm_sd()
|
||||||
|
unet_config = detect_unet_config(sd, "")
|
||||||
|
|
||||||
|
assert unet_config is not None
|
||||||
|
assert unet_config["image_model"] == "seedvr2"
|
||||||
|
assert unet_config["vid_dim"] == 3072
|
||||||
|
assert unet_config["heads"] == 24
|
||||||
|
assert unet_config["num_layers"] == 36
|
||||||
|
assert unet_config["mm_layers"] == 10
|
||||||
|
assert unet_config["mlp_type"] == "swiglu"
|
||||||
|
assert unet_config["rope_type"] == "rope3d"
|
||||||
|
assert unet_config["rope_dim"] == 64
|
||||||
|
|
||||||
|
def test_seedvr2_3b_shared_mm_detection_config(self):
|
||||||
|
sd = _make_seedvr2_3b_shared_mm_sd()
|
||||||
|
unet_config = detect_unet_config(sd, "")
|
||||||
|
|
||||||
|
assert unet_config is not None
|
||||||
|
assert unet_config["image_model"] == "seedvr2"
|
||||||
|
assert unet_config["vid_dim"] == 2560
|
||||||
|
assert unet_config["heads"] == 20
|
||||||
|
assert unet_config["num_layers"] == 32
|
||||||
|
assert unet_config["mlp_type"] == "swiglu"
|
||||||
|
|
||||||
def test_unet_config_and_required_keys_combination_is_unique(self):
|
def test_unet_config_and_required_keys_combination_is_unique(self):
|
||||||
"""Each model in the registry must have a unique combination of
|
"""Each model in the registry must have a unique combination of
|
||||||
``unet_config`` and ``required_keys``. If two models share the same
|
``unet_config`` and ``required_keys``. If two models share the same
|
||||||
|
|||||||
49
tests-unit/comfy_test/test_seedvr2_dtype.py
Normal file
49
tests-unit/comfy_test/test_seedvr2_dtype.py
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
from comfy.cli_args import args as cli_args
|
||||||
|
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
cli_args.cpu = True
|
||||||
|
|
||||||
|
import comfy.sd
|
||||||
|
import comfy.supported_models
|
||||||
|
import comfy.ldm.seedvr.model as seedvr_model
|
||||||
|
import comfy.ldm.seedvr.vae as seedvr_vae
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_fp16_manual_cast_only_for_bf16_device(monkeypatch):
|
||||||
|
bf16_device = object()
|
||||||
|
fp16_device = object()
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
comfy.supported_models.comfy.model_management,
|
||||||
|
"should_use_bf16",
|
||||||
|
lambda device=None: device is bf16_device,
|
||||||
|
)
|
||||||
|
|
||||||
|
bf16_config = comfy.supported_models.SeedVR2({"image_model": "seedvr2"})
|
||||||
|
bf16_config.set_inference_dtype(torch.float16, None, device=bf16_device)
|
||||||
|
assert bf16_config.manual_cast_dtype is torch.bfloat16
|
||||||
|
|
||||||
|
fp16_config = comfy.supported_models.SeedVR2({"image_model": "seedvr2"})
|
||||||
|
fp16_config.set_inference_dtype(torch.float16, None, device=fp16_device)
|
||||||
|
assert fp16_config.manual_cast_dtype is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_text_conditioning_accepts_cfg1_single_branch():
|
||||||
|
context = torch.arange(6, dtype=torch.float32).reshape(1, 3, 2)
|
||||||
|
|
||||||
|
txt, txt_shape = seedvr_model.NaDiT._resolve_text_conditioning(object(), context, [0])
|
||||||
|
|
||||||
|
torch.testing.assert_close(txt, context.squeeze(0))
|
||||||
|
torch.testing.assert_close(txt_shape, torch.tensor([[3]], device=context.device))
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_vae_decode_memory_covers_full_frame_lab_transfer():
|
||||||
|
wrapper = seedvr_vae.VideoAutoencoderKLWrapper.__new__(seedvr_vae.VideoAutoencoderKLWrapper)
|
||||||
|
estimate = wrapper.comfy_memory_used_decode((1, 16, 26, 120, 160))
|
||||||
|
old_estimate = 16 * 120 * 160 * (4 * 8 * 8) * 2
|
||||||
|
|
||||||
|
assert estimate == 101 * 960 * 1280 * 160
|
||||||
|
assert estimate > 15 * 1024 ** 3
|
||||||
|
assert estimate > old_estimate * 100
|
||||||
216
tests-unit/comfy_test/test_seedvr2_internals.py
Normal file
216
tests-unit/comfy_test/test_seedvr2_internals.py
Normal file
@ -0,0 +1,216 @@
|
|||||||
|
"""Consolidated SeedVR2 internals regression tests.
|
||||||
|
|
||||||
|
Sources (all merged verbatim, helper names disambiguated where colliding):
|
||||||
|
|
||||||
|
* GroupNorm limit gate — causal_norm_wrapper at vae.py:509 must compare
|
||||||
|
memory_occupy against get_norm_limit(), not float('inf').
|
||||||
|
* SeedVR2 variable-length attention split-loop contract.
|
||||||
|
|
||||||
|
Pre-import CPU-only guard is required because comfy.ldm.seedvr.model and
|
||||||
|
comfy.ldm.modules.attention transitively pull in comfy.model_management,
|
||||||
|
which probes torch.cuda.current_device() at import time unless args.cpu is
|
||||||
|
set first.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from comfy.cli_args import args
|
||||||
|
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
args.cpu = True
|
||||||
|
|
||||||
|
import comfy.ldm.seedvr.model as seedvr_model # noqa: E402
|
||||||
|
import comfy.ldm.seedvr.vae as vae_mod # noqa: E402
|
||||||
|
import comfy.ldm.modules.attention as attention # noqa: E402
|
||||||
|
import comfy.ops as comfy_ops # noqa: E402
|
||||||
|
from comfy.ldm.seedvr.vae import ( # noqa: E402
|
||||||
|
causal_norm_wrapper,
|
||||||
|
set_norm_limit,
|
||||||
|
)
|
||||||
|
from comfy.ldm.seedvr.attention import var_attention_optimized_split # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# GroupNorm limit tests (test_seedvr_groupnorm_limit.py)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_NUM_CHANNELS = 8
|
||||||
|
_NUM_GROUPS = 4
|
||||||
|
_TENSOR_SHAPE = (1, 8, 2, 4, 4)
|
||||||
|
|
||||||
|
_GROUPNORM_SUBCLASSES = [
|
||||||
|
pytest.param(comfy_ops.disable_weight_init.GroupNorm, id="disable_weight_init"),
|
||||||
|
pytest.param(comfy_ops.manual_cast.GroupNorm, id="manual_cast"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("groupnorm_cls", _GROUPNORM_SUBCLASSES)
|
||||||
|
def test_seedvr_groupnorm_low_limit_uses_chunked_groupnorm_path(groupnorm_cls):
|
||||||
|
real_group_norm = vae_mod.F.group_norm
|
||||||
|
set_norm_limit(1e-9)
|
||||||
|
try:
|
||||||
|
gn = groupnorm_cls(num_channels=_NUM_CHANNELS, num_groups=_NUM_GROUPS)
|
||||||
|
gn.eval()
|
||||||
|
|
||||||
|
forward_hook_calls = []
|
||||||
|
|
||||||
|
def _hook(module, inputs, output):
|
||||||
|
forward_hook_calls.append(tuple(inputs[0].shape))
|
||||||
|
|
||||||
|
spy_calls = []
|
||||||
|
|
||||||
|
def _group_norm_spy(input_tensor, num_groups_arg, *args, **kwargs):
|
||||||
|
spy_calls.append({"num_groups": int(num_groups_arg)})
|
||||||
|
return real_group_norm(input_tensor, num_groups_arg, *args, **kwargs)
|
||||||
|
|
||||||
|
handle = gn.register_forward_hook(_hook)
|
||||||
|
try:
|
||||||
|
with patch.object(vae_mod.F, "group_norm", side_effect=_group_norm_spy):
|
||||||
|
out_tensor = causal_norm_wrapper(gn, torch.randn(*_TENSOR_SHAPE))
|
||||||
|
finally:
|
||||||
|
handle.remove()
|
||||||
|
|
||||||
|
full_calls = len(forward_hook_calls)
|
||||||
|
chunked_calls = sum(1 for entry in spy_calls if entry["num_groups"] < _NUM_GROUPS)
|
||||||
|
|
||||||
|
assert tuple(int(s) for s in out_tensor.shape) == _TENSOR_SHAPE
|
||||||
|
assert full_calls == 0, (
|
||||||
|
f"low-limit GroupNorm gate must NOT take the full-forward path; got full_calls={full_calls}"
|
||||||
|
)
|
||||||
|
assert chunked_calls > 0, (
|
||||||
|
f"low-limit GroupNorm gate must take the chunked path; got chunked_calls={chunked_calls}"
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
set_norm_limit(None)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# SeedVR2 var_attention split-loop tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def test_seedvr2_7b_swin_attention_forward_uses_optimized_var_attention(monkeypatch):
|
||||||
|
dim = 8
|
||||||
|
heads = 2
|
||||||
|
head_dim = 4
|
||||||
|
attn = seedvr_model.NaSwinAttention(
|
||||||
|
vid_dim=dim,
|
||||||
|
txt_dim=dim,
|
||||||
|
heads=heads,
|
||||||
|
head_dim=head_dim,
|
||||||
|
qk_bias=False,
|
||||||
|
qk_norm=seedvr_model.CustomRMSNorm,
|
||||||
|
qk_norm_eps=1e-6,
|
||||||
|
rope_type=None,
|
||||||
|
rope_dim=head_dim,
|
||||||
|
shared_weights=False,
|
||||||
|
window=(2, 1, 1),
|
||||||
|
window_method="720pwin_by_size_bysize",
|
||||||
|
version=True,
|
||||||
|
device="cpu",
|
||||||
|
dtype=torch.float32,
|
||||||
|
operations=comfy_ops.disable_weight_init,
|
||||||
|
)
|
||||||
|
generator = torch.Generator(device="cpu").manual_seed(11)
|
||||||
|
vid = torch.randn(8, dim, generator=generator)
|
||||||
|
txt = torch.randn(3, dim, generator=generator)
|
||||||
|
vid_shape = torch.tensor([[2, 2, 2]], dtype=torch.long)
|
||||||
|
txt_shape = torch.tensor([[3]], dtype=torch.long)
|
||||||
|
calls = []
|
||||||
|
|
||||||
|
def fake_optimized_var_attention(**kwargs):
|
||||||
|
calls.append(kwargs)
|
||||||
|
return kwargs["q"]
|
||||||
|
|
||||||
|
monkeypatch.setattr(seedvr_model, "optimized_var_attention", fake_optimized_var_attention)
|
||||||
|
|
||||||
|
vid_out, txt_out = attn(vid, txt, vid_shape, txt_shape, seedvr_model.Cache(disable=True))
|
||||||
|
|
||||||
|
assert tuple(vid_out.shape) == (8, dim)
|
||||||
|
assert tuple(txt_out.shape) == (3, dim)
|
||||||
|
assert len(calls) == 1
|
||||||
|
call = calls[0]
|
||||||
|
assert tuple(call["q"].shape) == (14, heads, head_dim)
|
||||||
|
assert tuple(call["k"].shape) == (14, heads, head_dim)
|
||||||
|
assert tuple(call["v"].shape) == (14, heads, head_dim)
|
||||||
|
assert call["heads"] == heads
|
||||||
|
assert call["skip_reshape"] is True
|
||||||
|
assert call["skip_output_reshape"] is True
|
||||||
|
torch.testing.assert_close(
|
||||||
|
call["cu_seqlens_q"],
|
||||||
|
torch.tensor([0, 7, 14], dtype=torch.int32),
|
||||||
|
rtol=0,
|
||||||
|
atol=0,
|
||||||
|
)
|
||||||
|
torch.testing.assert_close(
|
||||||
|
call["cu_seqlens_k"],
|
||||||
|
torch.tensor([0, 7, 14], dtype=torch.int32),
|
||||||
|
rtol=0,
|
||||||
|
atol=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_var_attention_optimized_split_calls_dense_backend_per_window(monkeypatch):
|
||||||
|
heads = 2
|
||||||
|
head_dim = 3
|
||||||
|
q = torch.arange(30, dtype=torch.float32).reshape(5, heads, head_dim)
|
||||||
|
k = q + 100
|
||||||
|
v = q + 200
|
||||||
|
cu = torch.tensor([0, 2, 5], dtype=torch.int32)
|
||||||
|
calls = []
|
||||||
|
|
||||||
|
def fake_optimized_attention(q_arg, k_arg, v_arg, heads_arg, **kwargs):
|
||||||
|
calls.append(
|
||||||
|
{
|
||||||
|
"q_shape": tuple(q_arg.shape),
|
||||||
|
"k_shape": tuple(k_arg.shape),
|
||||||
|
"v_shape": tuple(v_arg.shape),
|
||||||
|
"heads": heads_arg,
|
||||||
|
"kwargs": kwargs,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return q_arg + v_arg
|
||||||
|
|
||||||
|
monkeypatch.setattr(attention, "optimized_attention", fake_optimized_attention)
|
||||||
|
|
||||||
|
out = var_attention_optimized_split(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
heads,
|
||||||
|
cu,
|
||||||
|
cu,
|
||||||
|
skip_reshape=True,
|
||||||
|
skip_output_reshape=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert tuple(out.shape) == (5, heads, head_dim)
|
||||||
|
assert len(calls) == 2
|
||||||
|
assert calls[0]["q_shape"] == (1, heads, 2, head_dim)
|
||||||
|
assert calls[1]["q_shape"] == (1, heads, 3, head_dim)
|
||||||
|
assert all(call["heads"] == heads for call in calls)
|
||||||
|
assert all(call["kwargs"]["skip_reshape"] is True for call in calls)
|
||||||
|
assert all(call["kwargs"]["skip_output_reshape"] is True for call in calls)
|
||||||
|
torch.testing.assert_close(out, q + v, rtol=0, atol=0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_var_attention_optimized_split_rejects_bad_offsets():
|
||||||
|
q = torch.randn(5, 2, 3)
|
||||||
|
cu_bad = torch.tensor([0, 2, 6], dtype=torch.int32)
|
||||||
|
cu_ok = torch.tensor([0, 2, 5], dtype=torch.int32)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="cu_seqlens_q does not match token count"):
|
||||||
|
var_attention_optimized_split(
|
||||||
|
q,
|
||||||
|
q,
|
||||||
|
q,
|
||||||
|
2,
|
||||||
|
cu_bad,
|
||||||
|
cu_ok,
|
||||||
|
skip_reshape=True,
|
||||||
|
skip_output_reshape=True,
|
||||||
|
)
|
||||||
307
tests-unit/comfy_test/test_seedvr2_model.py
Normal file
307
tests-unit/comfy_test/test_seedvr2_model.py
Normal file
@ -0,0 +1,307 @@
|
|||||||
|
"""Consolidated SeedVR2 model/graph/forward regression tests.
|
||||||
|
|
||||||
|
Merged from:
|
||||||
|
- seedvr_model_test.py
|
||||||
|
- test_seedvr_7b_final_block_text_path.py
|
||||||
|
- test_seedvr_forward_no_device_cast.py
|
||||||
|
- test_seedvr_latent_format.py
|
||||||
|
- test_seedvr2_vae_graph_boundaries.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from comfy.cli_args import args
|
||||||
|
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
args.cpu = True
|
||||||
|
|
||||||
|
import comfy # noqa: E402
|
||||||
|
import comfy.latent_formats # noqa: E402
|
||||||
|
import comfy.ldm.seedvr.model # noqa: E402
|
||||||
|
import comfy.ldm.seedvr.model as seedvr_model # noqa: E402
|
||||||
|
import comfy.ldm.seedvr.vae as seedvr_vae_mod # noqa: E402
|
||||||
|
import comfy.model_management # noqa: E402
|
||||||
|
import comfy.sample # noqa: E402
|
||||||
|
import comfy.sd as sd_mod # noqa: E402
|
||||||
|
import nodes as nodes_mod # noqa: E402
|
||||||
|
from comfy.ldm.seedvr.model import NaDiT # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers from seedvr_model_test.py
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _make_standin(positive_conditioning):
|
||||||
|
class _StandIn(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.register_buffer(
|
||||||
|
"positive_conditioning", positive_conditioning
|
||||||
|
)
|
||||||
|
|
||||||
|
_resolve_text_conditioning = NaDiT._resolve_text_conditioning
|
||||||
|
|
||||||
|
return _StandIn()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers from test_seedvr_7b_final_block_text_path.py
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class _StubModule(nn.Module):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
|
||||||
|
def _capture_last_layer_flags(monkeypatch, vid_dim: int, txt_in_dim: int) -> list[bool]:
|
||||||
|
flags = []
|
||||||
|
|
||||||
|
class _Block(_StubModule):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
flags.append(kwargs["is_last_layer"])
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
monkeypatch.setattr(seedvr_model, "NaPatchIn", _StubModule)
|
||||||
|
monkeypatch.setattr(seedvr_model, "NaPatchOut", _StubModule)
|
||||||
|
monkeypatch.setattr(seedvr_model, "TimeEmbedding", _StubModule)
|
||||||
|
monkeypatch.setattr(seedvr_model, "NaMMSRTransformerBlock", _Block)
|
||||||
|
|
||||||
|
seedvr_model.NaDiT(
|
||||||
|
norm_eps=1e-5,
|
||||||
|
num_layers=4,
|
||||||
|
mlp_type="normal",
|
||||||
|
vid_dim=vid_dim,
|
||||||
|
txt_in_dim=txt_in_dim,
|
||||||
|
heads=24,
|
||||||
|
mm_layers=3,
|
||||||
|
)
|
||||||
|
|
||||||
|
return flags
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers from test_seedvr_latent_format.py
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class _Model:
|
||||||
|
def __init__(self, latent_format):
|
||||||
|
self._latent_format = latent_format
|
||||||
|
|
||||||
|
def get_model_object(self, name):
|
||||||
|
assert name == "latent_format"
|
||||||
|
return self._latent_format
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers from test_seedvr2_vae_graph_boundaries.py
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class _Patcher:
|
||||||
|
def get_free_memory(self, device):
|
||||||
|
return 1024 * 1024 * 1024
|
||||||
|
|
||||||
|
|
||||||
|
class _EncodeWrapper(seedvr_vae_mod.VideoAutoencoderKLWrapper):
|
||||||
|
def __init__(self, encoded):
|
||||||
|
nn.Module.__init__(self)
|
||||||
|
self.encoded = encoded
|
||||||
|
self.spatial_downsample_factor = 8
|
||||||
|
self.temporal_downsample_factor = 4
|
||||||
|
self.seen = []
|
||||||
|
|
||||||
|
def encode(self, x):
|
||||||
|
self.seen.append(tuple(x.shape))
|
||||||
|
return self.encoded.to(device=x.device, dtype=x.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class _DecodeWrapper(seedvr_vae_mod.VideoAutoencoderKLWrapper):
|
||||||
|
def __init__(self):
|
||||||
|
nn.Module.__init__(self)
|
||||||
|
self.spatial_downsample_factor = 8
|
||||||
|
self.temporal_downsample_factor = 4
|
||||||
|
self.calls = []
|
||||||
|
|
||||||
|
def decode(self, z, seedvr2_tiling=None):
|
||||||
|
self.calls.append({"shape": tuple(z.shape), "seedvr2_tiling": seedvr2_tiling})
|
||||||
|
if z.ndim == 4:
|
||||||
|
b, tc, h, w = z.shape
|
||||||
|
t = tc // 16
|
||||||
|
else:
|
||||||
|
b, _, t, h, w = z.shape
|
||||||
|
return torch.zeros(b, 3, t, h * 8, w * 8, dtype=z.dtype, device=z.device)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_vae(wrapper):
|
||||||
|
vae = sd_mod.VAE.__new__(sd_mod.VAE)
|
||||||
|
vae.first_stage_model = wrapper
|
||||||
|
vae.device = torch.device("cpu")
|
||||||
|
vae.output_device = torch.device("cpu")
|
||||||
|
vae.vae_dtype = torch.float32
|
||||||
|
vae.latent_channels = 16
|
||||||
|
vae.latent_dim = 3
|
||||||
|
vae.downscale_ratio = (lambda a: max(0, (a + 3) // 4), 8, 8)
|
||||||
|
vae.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
|
||||||
|
vae.output_channels = 3
|
||||||
|
vae.disable_offload = True
|
||||||
|
vae.extra_1d_channel = None
|
||||||
|
vae.crop_input = False
|
||||||
|
vae.not_video = False
|
||||||
|
vae.patcher = _Patcher()
|
||||||
|
vae.process_input = lambda image: image
|
||||||
|
vae.process_output = lambda image: image.add(1.0).div(2.0).clamp(0.0, 1.0)
|
||||||
|
vae.vae_output_dtype = lambda: torch.float32
|
||||||
|
vae.memory_used_encode = lambda shape, dtype: 1
|
||||||
|
vae.memory_used_decode = lambda shape, dtype: 1
|
||||||
|
vae.throw_exception_if_invalid = lambda: None
|
||||||
|
vae.vae_encode_crop_pixels = lambda pixels: pixels
|
||||||
|
vae.spacial_compression_decode = lambda: 8
|
||||||
|
vae.temporal_compression_decode = lambda: 4
|
||||||
|
return vae
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tests from seedvr_model_test.py
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_missing_context_falls_back_to_positive_buffer():
|
||||||
|
"""``context is None`` falls back to the registered ``positive_conditioning`` buffer and runs to completion."""
|
||||||
|
pos_buffer = torch.full((58, 5120), 7.0)
|
||||||
|
standin = _make_standin(pos_buffer)
|
||||||
|
txt, txt_shape = standin._resolve_text_conditioning(None)
|
||||||
|
assert txt.shape == (58, 5120)
|
||||||
|
assert (txt == 7.0).all(), (
|
||||||
|
"fallback path must use the positive_conditioning buffer "
|
||||||
|
"verbatim, not a zero tensor"
|
||||||
|
)
|
||||||
|
assert txt_shape.shape == (1, 1)
|
||||||
|
assert txt_shape[0, 0].item() == 58
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tests from test_seedvr_7b_final_block_text_path.py
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_7b_keeps_final_block_text_path(monkeypatch):
|
||||||
|
assert _capture_last_layer_flags(monkeypatch, vid_dim=3072, txt_in_dim=3072) == [
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_7b_rope3d_matches_wrapper_oracle():
|
||||||
|
rope = seedvr_model.get_na_rope("rope3d", dim=64)
|
||||||
|
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||||
|
q = torch.randn(4, 2, 128, generator=generator)
|
||||||
|
k = torch.randn(4, 2, 128, generator=generator)
|
||||||
|
shape = torch.tensor([[1, 2, 2]], dtype=torch.long)
|
||||||
|
freqs = rope.get_axial_freqs(1, 2, 2).reshape(4, -1)
|
||||||
|
|
||||||
|
expected_q = seedvr_model._apply_seedvr2_rotary_emb(
|
||||||
|
freqs,
|
||||||
|
q.permute(1, 0, 2).float(),
|
||||||
|
).to(q.dtype).permute(1, 0, 2)
|
||||||
|
expected_k = seedvr_model._apply_seedvr2_rotary_emb(
|
||||||
|
freqs,
|
||||||
|
k.permute(1, 0, 2).float(),
|
||||||
|
).to(k.dtype).permute(1, 0, 2)
|
||||||
|
|
||||||
|
actual_q, actual_k = rope(q.clone(), k.clone(), shape, seedvr_model.Cache(disable=True))
|
||||||
|
|
||||||
|
torch.testing.assert_close(actual_q, expected_q, rtol=0, atol=0)
|
||||||
|
torch.testing.assert_close(actual_k, expected_k, rtol=0, atol=0)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tests from test_seedvr_latent_format.py
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_latent_format_uses_16_channels_without_3d_empty_latent_expansion():
|
||||||
|
latent_format = comfy.latent_formats.SeedVR2()
|
||||||
|
latent_image = torch.zeros(1, 1, 4, 5)
|
||||||
|
|
||||||
|
fixed = comfy.sample.fix_empty_latent_channels(_Model(latent_format), latent_image)
|
||||||
|
|
||||||
|
assert latent_format.latent_channels == 16
|
||||||
|
assert latent_format.latent_dimensions == 2
|
||||||
|
assert fixed.shape == (1, 16, 4, 5)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tests from test_seedvr2_vae_graph_boundaries.py
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_encode_and_encode_tiled_preserve_native_latent_contract(monkeypatch):
|
||||||
|
monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None)
|
||||||
|
|
||||||
|
encoded = torch.full((1, 16, 2, 4, 5), 2.0)
|
||||||
|
vae = _make_vae(_EncodeWrapper(encoded))
|
||||||
|
pixels = torch.zeros(1, 5, 32, 40, 3)
|
||||||
|
|
||||||
|
node_output = nodes_mod.VAEEncode().encode(vae, pixels)[0]
|
||||||
|
node_latent = node_output["samples"]
|
||||||
|
assert set(node_output) == {"samples"}
|
||||||
|
assert tuple(node_latent.shape) == (1, 16, 2, 4, 5)
|
||||||
|
assert node_latent.dtype == torch.float32
|
||||||
|
assert node_latent.stride()[-1] == 1
|
||||||
|
assert torch.equal(node_latent, torch.full_like(node_latent, 2.0 * 0.9152))
|
||||||
|
|
||||||
|
tiled = torch.full((1, 16, 2, 4, 5), 3.0)
|
||||||
|
monkeypatch.setattr(seedvr_vae_mod, "tiled_vae", MagicMock(return_value=tiled))
|
||||||
|
tiled_output = nodes_mod.VAEEncodeTiled().encode(
|
||||||
|
vae,
|
||||||
|
pixels,
|
||||||
|
tile_size=512,
|
||||||
|
overlap=64,
|
||||||
|
temporal_size=16,
|
||||||
|
temporal_overlap=4,
|
||||||
|
)[0]
|
||||||
|
tiled_latent = tiled_output["samples"]
|
||||||
|
assert set(tiled_output) == {"samples"}
|
||||||
|
assert tuple(tiled_latent.shape) == (1, 16, 2, 4, 5)
|
||||||
|
assert tiled_latent.dtype == torch.float32
|
||||||
|
assert torch.equal(tiled_latent, torch.full_like(tiled_latent, 3.0 * 0.9152))
|
||||||
|
|
||||||
|
|
||||||
|
def test_vaedecode_tiled_spatial_applies_temporal_discarded(monkeypatch):
|
||||||
|
monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None)
|
||||||
|
vae = _make_vae(_DecodeWrapper())
|
||||||
|
|
||||||
|
nodes_mod.VAEDecodeTiled().decode(
|
||||||
|
vae,
|
||||||
|
{"samples": torch.zeros(1, 16, 2, 4, 5)},
|
||||||
|
tile_size=512,
|
||||||
|
overlap=64,
|
||||||
|
temporal_size=16,
|
||||||
|
temporal_overlap=4,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Spatial inputs flow through; temporal inputs are discarded — SeedVR2 owns
|
||||||
|
# temporal via the MemoryState causal cache, so VAEDecodeTiled's temporal
|
||||||
|
# knobs are no-ops at the wrapper.
|
||||||
|
assert vae.first_stage_model.calls == [
|
||||||
|
{
|
||||||
|
"shape": (1, 16, 2, 4, 5),
|
||||||
|
"seedvr2_tiling": {
|
||||||
|
"enable_tiling": True,
|
||||||
|
"tile_size": (512, 512),
|
||||||
|
"tile_overlap": (64, 64),
|
||||||
|
"temporal_size": 0,
|
||||||
|
"temporal_overlap": 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
Loading…
Reference in New Issue
Block a user