Cleanups using AGENTS.md

This commit is contained in:
comfyanonymous 2026-07-01 22:17:51 -04:00
parent e595965392
commit f437d87155
15 changed files with 313 additions and 489 deletions

View File

@ -60,14 +60,7 @@ def var_attention_optimized_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *a
q_i = q_i.permute(1, 0, 2).unsqueeze(0)
k_i = k_i.permute(1, 0, 2).unsqueeze(0)
v_i = v_i.permute(1, 0, 2).unsqueeze(0)
out_dtype = q_i.dtype
if _attention.optimized_attention is _attention.attention_sage and q_i.dtype not in (torch.float16, torch.bfloat16):
q_i = q_i.to(torch.bfloat16)
k_i = k_i.to(torch.bfloat16)
v_i = v_i.to(torch.bfloat16)
out_i = _attention.optimized_attention(q_i, k_i, v_i, heads, skip_reshape=True, skip_output_reshape=True)
if out_i.dtype != out_dtype:
out_i = out_i.to(out_dtype)
out.append(out_i.squeeze(0).permute(1, 0, 2))
out = torch.cat(out, dim=0)

View File

@ -2,8 +2,6 @@ import torch
import torch.nn.functional as F
from torch import Tensor
from comfy.ldm.seedvr.model import safe_pad_operation
from comfy.ldm.seedvr.vae import safe_interpolate_operation
from comfy.ldm.seedvr.constants import (
CIELAB_DELTA,
CIELAB_KAPPA,
@ -28,7 +26,7 @@ def wavelet_blur(image: Tensor, radius):
kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
kernel = kernel[None, None].repeat(num_channels, 1, 1, 1)
image = safe_pad_operation(image, (radius, radius, radius, radius), mode='replicate')
image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
output = F.conv2d(image, kernel, groups=num_channels, dilation=radius)
return output
@ -49,8 +47,7 @@ def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor:
if content_feat.shape != style_feat.shape:
# Resize style to match content spatial dimensions
if len(content_feat.shape) >= 3:
# safe_interpolate_operation handles FP16 conversion automatically
style_feat = safe_interpolate_operation(
style_feat = F.interpolate(
style_feat,
size=content_feat.shape[-2:],
mode='bilinear',
@ -65,7 +62,7 @@ def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor:
del style_high_freq # Free memory immediately
if content_high_freq.shape != style_low_freq.shape:
style_low_freq = safe_interpolate_operation(
style_low_freq = F.interpolate(
style_low_freq,
size=content_high_freq.shape[-2:],
mode='bilinear',
@ -227,7 +224,7 @@ def lab_color_transfer(
content_feat = wavelet_reconstruction(content_feat, style_feat)
if content_feat.shape != style_feat.shape:
style_feat = safe_interpolate_operation(
style_feat = F.interpolate(
style_feat,
size=content_feat.shape[-2:],
mode='bilinear',
@ -308,7 +305,7 @@ def wavelet_color_transfer(content_feat: Tensor, style_feat: Tensor) -> Tensor:
def adain_color_transfer(content_feat: Tensor, style_feat: Tensor, eps: float = 1e-5) -> Tensor:
if content_feat.shape != style_feat.shape:
style_feat = safe_interpolate_operation(
style_feat = F.interpolate(
style_feat,
size=content_feat.shape[-2:],
mode='bilinear',

View File

@ -1,7 +1,5 @@
from dataclasses import dataclass
from typing import Optional, Tuple, Union, List, Dict, Any, Callable
import einops
from einops import rearrange
import torch.nn.functional as F
from math import ceil, pi
import torch
@ -23,52 +21,6 @@ from comfy.ldm.seedvr.constants import (
SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS,
)
import comfy.model_management
import numbers
def _torch_float8_types():
return tuple(
getattr(torch, name)
for name in (
"float8_e4m3fn",
"float8_e4m3fnuz",
"float8_e5m2",
"float8_e5m2fnuz",
"float8_e8m0fnu",
)
if hasattr(torch, name)
)
class CustomRMSNorm(nn.Module):
def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True, device=None, dtype=None):
super(CustomRMSNorm, self).__init__()
if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape,)
self.normalized_shape = torch.Size(normalized_shape)
self.eps = eps
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
self.weight = nn.Parameter(torch.ones(*normalized_shape, device=device, dtype=dtype))
else:
self.register_parameter('weight', None)
def forward(self, input):
dims = tuple(range(-len(self.normalized_shape), 0))
# Norm statistics in fp32 (fp16 variance underflows); activations return
# in the input dtype so downstream linears run at the model compute dtype.
normalized = input.float()
variance = normalized.pow(2).mean(dim=dims, keepdim=True)
rms = torch.sqrt(variance + self.eps)
normalized = normalized / rms
if self.elementwise_affine:
return (normalized * self.weight.to(torch.float32)).to(input.dtype)
return normalized.to(input.dtype)
class Cache:
def __init__(self, disable=False, prefix="", cache=None):
@ -81,12 +33,10 @@ class Cache:
return fn()
key = self.prefix + key
try:
result = self.cache[key]
except KeyError:
if key not in self.cache:
result = fn()
self.cache[key] = result
return result
return self.cache[key]
def namespace(self, namespace: str):
return Cache(
@ -144,15 +94,6 @@ class MMArg:
vid: Any
txt: Any
def safe_pad_operation(x, padding, mode='constant', value=0.0):
try:
return F.pad(x, padding, mode=mode, value=value)
except RuntimeError as e:
if "not implemented for" in str(e) and x.dtype in (torch.float16, torch.bfloat16):
return F.pad(x.float(), padding, mode=mode, value=value).to(x.dtype)
raise
def get_args(key: str, args: List[Any]) -> List[Any]:
return [getattr(v, key) if isinstance(v, MMArg) else v for v in args]
@ -235,8 +176,6 @@ class RotaryEmbedding(nn.Module):
theta = 10000,
max_freq = 10,
learned_freq = False,
cache_if_possible = True,
cache_max_seq_len = 8192
):
super().__init__()
@ -247,12 +186,6 @@ class RotaryEmbedding(nn.Module):
elif freqs_for == 'pixel':
freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
self.cache_if_possible = cache_if_possible
self.cache_max_seq_len = cache_max_seq_len
self.register_buffer('cached_freqs', torch.zeros(cache_max_seq_len, dim), persistent = False)
self.cached_freqs_seq_len = 0
self.freqs = nn.Parameter(freqs, requires_grad = learned_freq)
self.learned_freq = learned_freq
@ -310,29 +243,10 @@ class RotaryEmbedding(nn.Module):
seq_len: int | None = None,
offset = 0
):
should_cache = (
self.cache_if_possible and
not self.learned_freq and
exists(seq_len) and
self.freqs_for != 'pixel' and
(offset + seq_len) <= self.cache_max_seq_len
)
if (
should_cache and \
exists(self.cached_freqs) and \
(offset + seq_len) <= self.cached_freqs_seq_len
):
return self.cached_freqs[offset:(offset + seq_len)].detach()
freqs = self.freqs
freqs = torch.einsum('..., f -> ... f', t.type(freqs.dtype), freqs)
freqs = einops.repeat(freqs, '... n -> ... (n r)', r = 2)
if should_cache and offset == 0:
self.cached_freqs[:seq_len] = freqs.detach()
self.cached_freqs_seq_len = seq_len
freqs = freqs.unsqueeze(-1).expand(*freqs.shape, 2).flatten(-2)
return freqs
@ -346,7 +260,7 @@ class RotaryEmbeddingBase(nn.Module):
)
freqs = self.rope.freqs
del self.rope.freqs
self.rope.register_buffer("freqs", freqs.data)
self.rope.register_buffer("freqs", freqs.detach())
def get_axial_freqs(self, *dims):
return self.rope.get_axial_freqs(*dims)
@ -371,12 +285,12 @@ class NaRotaryEmbedding3d(RotaryEmbedding3d):
]:
freqs = cache("rope_freqs_3d", lambda: self.get_freqs(shape))
freqs = freqs.to(device=q.device)
q = rearrange(q, "L h d -> h L d")
k = rearrange(k, "L h d -> h L d")
q = q.transpose(0, 1)
k = k.transpose(0, 1)
q = _apply_seedvr2_rotary_emb(freqs, q.float()).to(q.dtype)
k = _apply_seedvr2_rotary_emb(freqs, k.float()).to(k.dtype)
q = rearrange(q, "h L d -> L h d")
k = rearrange(k, "h L d -> L h d")
q = q.transpose(0, 1)
k = k.transpose(0, 1)
return q, k
@torch._dynamo.disable
@ -407,11 +321,10 @@ class MMRotaryEmbeddingBase(RotaryEmbeddingBase):
dim=dim // rope_dim,
freqs_for="lang",
theta=ROPE_THETA,
cache_if_possible=False,
)
freqs = self.rope.freqs
del self.rope.freqs
self.rope.register_buffer("freqs", freqs.data)
self.rope.register_buffer("freqs", freqs.detach())
self.mm = True
def slice_at_dim(t, dim_slice: slice, *, dim):
@ -423,10 +336,10 @@ def slice_at_dim(t, dim_slice: slice, *, dim):
# rotary embedding helper functions
def rotate_half(x):
x = rearrange(x, '... (d r) -> ... d r', r = 2)
x = x.reshape(*x.shape[:-1], x.shape[-1] // 2, 2)
x1, x2 = x.unbind(dim = -1)
x = torch.stack((-x2, x1), dim = -1)
return rearrange(x, '... d r -> ... (d r)')
return x.flatten(-2)
def exists(val):
return val is not None
@ -465,7 +378,7 @@ def _to_flux_freqs_cis(freqs_interleaved: torch.Tensor) -> torch.Tensor:
cos = torch.cos(angles)
sin = torch.sin(angles)
out = torch.stack([cos, -sin, sin, cos], dim=-1)
return rearrange(out, "... d (i j) -> ... d i j", i=2, j=2)
return out.reshape(*out.shape[:-1], 2, 2)
def _apply_rope1_partial(t: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
@ -516,19 +429,19 @@ class NaMMRotaryEmbedding3d(MMRotaryEmbeddingBase):
vid_freqs = vid_freqs.to(target_device)
if txt_freqs.device != target_device:
txt_freqs = txt_freqs.to(target_device)
vid_q = rearrange(vid_q, "L h d -> h L d")
vid_k = rearrange(vid_k, "L h d -> h L d")
vid_q = vid_q.transpose(0, 1)
vid_k = vid_k.transpose(0, 1)
vid_q = _apply_rope1_partial(vid_q, vid_freqs)
vid_k = _apply_rope1_partial(vid_k, vid_freqs)
vid_q = rearrange(vid_q, "h L d -> L h d")
vid_k = rearrange(vid_k, "h L d -> L h d")
vid_q = vid_q.transpose(0, 1)
vid_k = vid_k.transpose(0, 1)
txt_q = rearrange(txt_q, "L h d -> h L d")
txt_k = rearrange(txt_k, "L h d -> h L d")
txt_q = txt_q.transpose(0, 1)
txt_k = txt_k.transpose(0, 1)
txt_q = _apply_rope1_partial(txt_q, txt_freqs)
txt_k = _apply_rope1_partial(txt_k, txt_freqs)
txt_q = rearrange(txt_q, "h L d -> L h d")
txt_k = rearrange(txt_k, "h L d -> L h d")
txt_q = txt_q.transpose(0, 1)
txt_k = txt_k.transpose(0, 1)
return vid_q, vid_k, txt_q, txt_k
@torch._dynamo.disable # Disable compilation: .tolist() is data-dependent and causes graph breaks
@ -684,7 +597,7 @@ def window(
):
hid = unflatten(hid, hid_shape)
hid = list(map(window_fn, hid))
hid_windows = torch.tensor(list(map(len, hid)), device=hid_shape.device)
hid_windows = torch.as_tensor([len(x) for x in hid], device=hid_shape.device)
hid, hid_shape = flatten(list(chain(*hid)))
return hid, hid_shape, hid_windows
@ -747,8 +660,8 @@ class NaSwinAttention(NaMMAttention):
)
vid_qkv_win = window_partition(vid_qkv)
vid_qkv_win = rearrange(vid_qkv_win, "l (o h d) -> l o h d", o=3, d=self.head_dim)
txt_qkv = rearrange(txt_qkv, "l (o h d) -> l o h d", o=3, d=self.head_dim)
vid_qkv_win = vid_qkv_win.reshape(vid_qkv_win.shape[0], 3, self.heads, self.head_dim)
txt_qkv = txt_qkv.reshape(txt_qkv.shape[0], 3, self.heads, self.head_dim)
vid_q, vid_k, vid_v = vid_qkv_win.unbind(1)
txt_q, txt_k, txt_v = txt_qkv.unbind(1)
@ -768,19 +681,19 @@ class NaSwinAttention(NaMMAttention):
elif self.rope.mm:
# repeat text q and k for window mmrope
_, num_h, _ = txt_q.shape
txt_q_repeat = rearrange(txt_q, "l h d -> l (h d)")
txt_q_repeat = txt_q.flatten(1, 2)
txt_q_repeat = unflatten(txt_q_repeat, txt_shape)
txt_q_repeat = [[x] * n for x, n in zip(txt_q_repeat, window_count)]
txt_q_repeat = list(chain(*txt_q_repeat))
txt_q_repeat, txt_shape_repeat = flatten(txt_q_repeat)
txt_q_repeat = rearrange(txt_q_repeat, "l (h d) -> l h d", h=num_h)
txt_q_repeat = txt_q_repeat.reshape(txt_q_repeat.shape[0], num_h, self.head_dim)
txt_k_repeat = rearrange(txt_k, "l h d -> l (h d)")
txt_k_repeat = txt_k.flatten(1, 2)
txt_k_repeat = unflatten(txt_k_repeat, txt_shape)
txt_k_repeat = [[x] * n for x, n in zip(txt_k_repeat, window_count)]
txt_k_repeat = list(chain(*txt_k_repeat))
txt_k_repeat, _ = flatten(txt_k_repeat)
txt_k_repeat = rearrange(txt_k_repeat, "l (h d) -> l h d", h=num_h)
txt_k_repeat = txt_k_repeat.reshape(txt_k_repeat.shape[0], num_h, self.head_dim)
vid_q, vid_k, txt_q, txt_k = self.rope(
vid_q, vid_k, window_shape, txt_q_repeat, txt_k_repeat, txt_shape_repeat, cache_win
@ -799,16 +712,16 @@ class NaSwinAttention(NaMMAttention):
v=concat_win(vid_v, txt_v),
heads=self.heads, skip_reshape=True, skip_output_reshape=True,
cu_seqlens_q=cache_win(
"vid_seqlens_q", lambda: safe_pad_operation(all_len_win.cumsum(0), (1, 0)).int()
"vid_seqlens_q", lambda: F.pad(all_len_win.cumsum(0), (1, 0)).int()
),
cu_seqlens_k=cache_win(
"vid_seqlens_k", lambda: safe_pad_operation(all_len_win.cumsum(0), (1, 0)).int()
"vid_seqlens_k", lambda: F.pad(all_len_win.cumsum(0), (1, 0)).int()
),
)
vid_out, txt_out = unconcat_win(out)
vid_out = rearrange(vid_out, "l h d -> l (h d)")
txt_out = rearrange(txt_out, "l h d -> l (h d)")
vid_out = vid_out.flatten(1, 2)
txt_out = txt_out.flatten(1, 2)
vid_out = window_reverse(vid_out)
vid_out, txt_out = self.proj_out(vid_out, txt_out)
@ -1005,7 +918,9 @@ class PatchOut(nn.Module):
) -> torch.Tensor:
t, h, w = self.patch_size
vid = self.proj(vid)
vid = rearrange(vid, "b T H W (t h w c) -> b c (T t) (H h) (W w)", t=t, h=h, w=w)
b, T, H, W, channels = vid.shape
c = channels // (t * h * w)
vid = vid.view(b, T, H, W, t, h, w, c).permute(0, 7, 1, 4, 2, 5, 3, 6).reshape(b, c, T * t, H * h, W * w)
if t > 1:
vid = vid[:, :, (t - 1) :]
return vid
@ -1015,7 +930,7 @@ class NaPatchOut(PatchOut):
self,
vid: torch.FloatTensor, # l c
vid_shape: torch.LongTensor,
cache: Cache = Cache(disable=True), # for test
cache: Cache = Cache(disable=True),
vid_shape_before_patchify = None
) -> Tuple[
torch.FloatTensor,
@ -1028,7 +943,9 @@ class NaPatchOut(PatchOut):
if not (t == h == w == 1):
vid = unflatten(vid, vid_shape)
for i in range(len(vid)):
vid[i] = rearrange(vid[i], "T H W (t h w c) -> (T t) (H h) (W w) c", t=t, h=h, w=w)
T, H, W, channels = vid[i].shape
c = channels // (t * h * w)
vid[i] = vid[i].view(T, H, W, t, h, w, c).permute(0, 3, 1, 4, 2, 5, 6).reshape(T * t, H * h, W * w, c)
if t > 1 and vid_shape_before_patchify[i, 0] % t != 0:
vid[i] = vid[i][(t - vid_shape_before_patchify[i, 0] % t) :]
vid, vid_shape = flatten(vid)
@ -1056,7 +973,8 @@ class PatchIn(nn.Module):
if t > 1:
assert vid.size(2) % t == 1
vid = torch.cat([vid[:, :, :1]] * (t - 1) + [vid], dim=2)
vid = rearrange(vid, "b c (T t) (H h) (W w) -> b T H W (t h w c)", t=t, h=h, w=w)
b, c, Tt, Hh, Ww = vid.shape
vid = vid.view(b, c, Tt // t, t, Hh // h, h, Ww // w, w).permute(0, 2, 4, 6, 3, 5, 7, 1).reshape(b, Tt // t, Hh // h, Ww // w, t * h * w * c)
vid = self.proj(vid)
return vid
@ -1065,7 +983,7 @@ class NaPatchIn(PatchIn):
self,
vid: torch.Tensor, # l c
vid_shape: torch.LongTensor,
cache: Cache = Cache(disable=True), # for test
cache: Cache = Cache(disable=True),
) -> torch.Tensor:
cache = cache.namespace("patch")
vid_shape_before_patchify = cache("vid_shape_before_patchify", lambda: vid_shape)
@ -1075,7 +993,8 @@ class NaPatchIn(PatchIn):
for i in range(len(vid)):
if t > 1 and vid_shape_before_patchify[i, 0] % t != 0:
vid[i] = torch.cat([vid[i][:1]] * (t - vid[i].size(0) % t) + [vid[i]], dim=0)
vid[i] = rearrange(vid[i], "(T t) (H h) (W w) c -> T H W (t h w c)", t=t, h=h, w=w)
Tt, Hh, Ww, c = vid[i].shape
vid[i] = vid[i].view(Tt // t, t, Hh // h, h, Ww // w, w, c).permute(0, 2, 4, 1, 3, 5, 6).reshape(Tt // t, Hh // h, Ww // w, t * h * w * c)
vid, vid_shape = flatten(vid)
vid = self.proj(vid)
@ -1102,17 +1021,14 @@ class AdaSingle(nn.Module):
self.emb_dim = emb_dim
self.layers = layers
param_kwargs = {"device": device}
fp8_types = _torch_float8_types()
if dtype is not None and dtype not in fp8_types:
param_kwargs["dtype"] = dtype
param_kwargs = {"device": device, "dtype": dtype}
for l in layers:
if "in" in modes:
self.register_parameter(f"{l}_shift", nn.Parameter(torch.zeros(dim, **param_kwargs)))
self.register_parameter(f"{l}_scale", nn.Parameter(torch.ones(dim, **param_kwargs)))
self.register_parameter(f"{l}_shift", nn.Parameter(torch.empty(dim, **param_kwargs)))
self.register_parameter(f"{l}_scale", nn.Parameter(torch.empty(dim, **param_kwargs)))
if "out" in modes:
self.register_parameter(f"{l}_gate", nn.Parameter(torch.zeros(dim, **param_kwargs)))
self.register_parameter(f"{l}_gate", nn.Parameter(torch.empty(dim, **param_kwargs)))
def forward(
self,
@ -1125,7 +1041,7 @@ class AdaSingle(nn.Module):
hid_len: Optional[torch.LongTensor] = None, # b
) -> torch.FloatTensor:
idx = self.layers.index(layer)
emb = rearrange(emb, "b (d l g) -> b d l g", l=len(self.layers), g=3)[..., idx, :]
emb = emb.reshape(emb.shape[0], -1, len(self.layers), 3)[:, :, idx, :]
emb = expand_dims(emb, 1, hid.ndim + 1)
if hid_len is not None:
@ -1145,17 +1061,6 @@ class AdaSingle(nn.Module):
getattr(self, f"{layer}_gate", None),
)
fp8_types = _torch_float8_types()
if fp8_types:
target_dtype = hid.dtype
if shiftB is not None and shiftB.dtype in fp8_types:
shiftB = shiftB.to(target_dtype)
if scaleB is not None and scaleB.dtype in fp8_types:
scaleB = scaleB.to(target_dtype)
if gateB is not None and gateB.dtype in fp8_types:
gateB = gateB.to(target_dtype)
if mode == "in":
return hid.mul_(scaleA + scaleB).add_(shiftA + shiftB)
if mode == "out":
@ -1213,7 +1118,7 @@ def flatten(
torch.LongTensor, # (b n)
]:
assert len(hid) > 0
shape = torch.stack([torch.tensor(x.shape[:-1], device=hid[0].device) for x in hid])
shape = torch.as_tensor([x.shape[:-1] for x in hid], device=hid[0].device)
hid = torch.cat([x.flatten(0, -2) for x in hid])
return hid, shape
@ -1227,19 +1132,6 @@ def unflatten(
hid = [x.unflatten(0, s.tolist()) for x, s in zip(hid, hid_shape)]
return hid
def repeat(
hid: torch.FloatTensor, # (L c)
hid_shape: torch.LongTensor, # (b n)
pattern: str,
**kwargs: Dict[str, torch.LongTensor], # (b)
) -> Tuple[
torch.FloatTensor,
torch.LongTensor,
]:
hid = unflatten(hid, hid_shape)
kwargs = [{k: v[i].item() for k, v in kwargs.items()} for i in range(len(hid))]
return flatten([einops.repeat(h, pattern, **a) for h, a in zip(hid, kwargs)])
class NaDiT(nn.Module):
def __init__(
@ -1275,23 +1167,11 @@ class NaDiT(nn.Module):
emb_dim = vid_dim * 6
window = num_layers * [(4,3,3)]
ada = AdaSingle
norm = CustomRMSNorm
qk_norm = CustomRMSNorm
norm = operations.RMSNorm
qk_norm = operations.RMSNorm
super().__init__()
# ``torch.empty`` returns uninitialized memory, not zeros. The
# SeedVR2Conditioning fail-loud guard at
# ``comfy_extras/nodes_seedvr.py`` distinguishes "buffer was loaded"
# from "buffer was never populated by the file" by checking
# ``positive_conditioning.abs().sum() == 0``. That sentinel is only
# reliable if the post-construction buffer state is deterministically
# zero, so explicitly zero-fill here rather than relying on the
# allocator's zero-on-alloc behavior (allocator-dependent and not
# contractual). When ``load_state_dict`` populates these buffers
# from a properly-baked SeedVR2 .safetensors, the in-place copy
# overwrites the zeros with the universal SeedVR2 conditioning
# tensors (shape (58, 5120) and (64, 5120) bf16).
self.register_buffer("positive_conditioning", torch.zeros((58, 5120), device=device, dtype=dtype))
self.register_buffer("negative_conditioning", torch.zeros((64, 5120), device=device, dtype=dtype))
self.register_buffer("positive_conditioning", torch.empty((58, 5120), device=device, dtype=dtype))
self.register_buffer("negative_conditioning", torch.empty((64, 5120), device=device, dtype=dtype))
self.vid_in = NaPatchIn(
in_channels=vid_in_channels,
patch_size=patch_size,
@ -1354,7 +1234,7 @@ class NaDiT(nn.Module):
self.vid_out_norm = None
if vid_out_norm is not None:
self.vid_out_norm = CustomRMSNorm(
self.vid_out_norm = operations.RMSNorm(
normalized_shape=vid_dim,
eps=norm_eps,
elementwise_affine=True,
@ -1369,7 +1249,7 @@ class NaDiT(nn.Module):
)
def _resolve_text_conditioning(self, context, cond_or_uncond=None):
if context is None or getattr(context, "numel", lambda: None)() == 0:
if context is None or context.numel() == 0:
context = self.positive_conditioning
return flatten([context])
if NaDiT._seedvr2_is_single_conditioning_branch(cond_or_uncond):
@ -1407,7 +1287,7 @@ class NaDiT(nn.Module):
x,
timestep,
context, # l c
disable_cache: bool = False, # for test # TODO ? // gives an error when set to True
disable_cache: bool = False,
**kwargs
):
transformer_options = kwargs.get("transformer_options", {})
@ -1483,5 +1363,5 @@ class NaDiT(nn.Module):
vid = unflatten(vid, vid_shape)
out = torch.stack(vid)
out = out.movedim(-1, 1)
out = rearrange(out, "b c t h w -> b (c t) h w")
out = out.reshape(out.shape[0], out.shape[1] * out.shape[2], out.shape[3], out.shape[4])
return self._swap_pos_neg_halves(out, transformer_options.get("cond_or_uncond"))

View File

@ -1,15 +1,11 @@
from contextlib import nullcontext
from typing import Literal, Optional, Tuple
import gc
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor
from contextlib import contextmanager
from comfy.utils import ProgressBar
from comfy.ldm.seedvr.model import safe_pad_operation
from comfy.ldm.seedvr.constants import (
BYTEDANCE_BLOCK_OUT_CHANNELS,
BYTEDANCE_GN_CHUNKS_FP16,
@ -58,13 +54,6 @@ def _seedvr2_clamped_spatial_overlap(overlap, tile_size):
return min(overlap, tile_size - 1)
def _seedvr2_clear_temporal_memory(model):
for module in model.modules():
if hasattr(module, "memory"):
module.memory = None
@torch.inference_mode()
def tiled_vae(
x,
vae_model,
@ -75,10 +64,6 @@ def tiled_vae(
encode=True,
**kwargs,
):
gc.collect()
comfy.model_management.soft_empty_cache()
x = x.to(next(vae_model.parameters()).dtype)
if x.ndim != 5:
x = x.unsqueeze(2)
@ -121,7 +106,6 @@ def tiled_vae(
count = None
def run_temporal_chunks(spatial_tile, model=vae_model, device=storage_device):
device = torch.device(device)
_seedvr2_clear_temporal_memory(model)
t_chunk = spatial_tile.to(device=device, dtype=next(model.parameters()).dtype, non_blocking=True).contiguous()
old_device = getattr(model, "device", None)
model.device = device
@ -133,7 +117,7 @@ def tiled_vae(
setattr(model, slicing_attr, slicing_min_size)
try:
if encode:
out = model.encode(t_chunk)[0]
out = model.encode(t_chunk)
else:
out = model.decode_(t_chunk)
finally:
@ -141,8 +125,6 @@ def tiled_vae(
setattr(model, slicing_attr, old_slicing_min_size)
if old_device is not None:
model.device = old_device
if isinstance(out, (tuple, list)):
out = out[0]
if out.ndim == 4:
out = out.unsqueeze(2)
return out.to(storage_device)
@ -169,8 +151,6 @@ def tiled_vae(
bar = ProgressBar(total_tiles)
single_spatial_tile = h <= ti_h and w <= ti_w
_seedvr2_clear_temporal_memory(vae_model)
def run_tile(tile_index, tile_range):
y_idx, y_end, x_idx, x_end = tile_range
tile_x = x[:, :, :, y_idx:y_end, x_idx:x_end]
@ -186,7 +166,6 @@ def tiled_vae(
if single_spatial_tile:
result = tile_out[:, :, :target_d, :target_h, :target_w]
_seedvr2_clear_temporal_memory(vae_model)
if result.device != x.device:
result = result.to(x.device).to(x.dtype)
if x.shape[2] == 1 and sf_t == 1:
@ -241,7 +220,6 @@ def tiled_vae(
bar.update(1)
result.div_(count.clamp(min=1e-6))
_seedvr2_clear_temporal_memory(vae_model)
if result.device != x.device:
result = result.to(x.device).to(x.dtype)
@ -336,7 +314,6 @@ class Attention(nn.Module):
eps: float = 1e-5,
rescale_output_factor: float = 1.0,
residual_connection: bool = False,
_from_deprecated_attn_block: bool = False,
out_dim: int = None,
pre_only=False,
):
@ -356,10 +333,6 @@ class Attention(nn.Module):
self.out_dim = out_dim if out_dim is not None else query_dim
self.pre_only = pre_only
# we make use of this private variable to know whether this class is loaded
# with an deprecated state dict so that we can convert it on the fly
self._from_deprecated_attn_block = _from_deprecated_attn_block
self.scale_qk = scale_qk
self.scale = dim_head**-0.5 if self.scale_qk else 1.0
@ -480,21 +453,21 @@ def causal_norm_wrapper(norm_layer: nn.Module, x: torch.Tensor) -> torch.Tensor:
input_dtype = x.dtype
if isinstance(norm_layer, (ops.LayerNorm, ops.RMSNorm)):
if x.ndim == 4:
x = rearrange(x, "b c h w -> b h w c")
x = x.permute(0, 2, 3, 1)
x = norm_layer(x)
x = rearrange(x, "b h w c -> b c h w")
x = x.permute(0, 3, 1, 2)
return x.to(input_dtype)
if x.ndim == 5:
x = rearrange(x, "b c t h w -> b t h w c")
x = x.permute(0, 2, 3, 4, 1)
x = norm_layer(x)
x = rearrange(x, "b t h w c -> b c t h w")
x = x.permute(0, 4, 1, 2, 3)
return x.to(input_dtype)
if isinstance(norm_layer, (ops.GroupNorm, nn.BatchNorm2d, nn.SyncBatchNorm)):
if x.ndim <= 4:
return norm_layer(x).to(input_dtype)
if x.ndim == 5:
t = x.size(2)
x = rearrange(x, "b c t h w -> (b t) c h w")
b, c, t, h, w = x.shape
x = x.transpose(1, 2).reshape(b * t, c, h, w)
memory_occupy = x.numel() * x.element_size() / 1024**3
if isinstance(norm_layer, ops.GroupNorm) and memory_occupy > get_norm_limit():
num_chunks = min(BYTEDANCE_GN_CHUNKS_FP16 if x.element_size() == 2 else BYTEDANCE_GN_CHUNKS_FP32, norm_layer.num_groups)
@ -504,54 +477,16 @@ def causal_norm_wrapper(norm_layer: nn.Module, x: torch.Tensor) -> torch.Tensor:
x = list(x.chunk(num_chunks, dim=1))
weights = norm_layer.weight.chunk(num_chunks, dim=0)
biases = norm_layer.bias.chunk(num_chunks, dim=0)
for i, (w, b) in enumerate(zip(weights, biases)):
x[i] = F.group_norm(x[i], num_groups_per_chunk, w, b, norm_layer.eps)
for i, (w, bias) in enumerate(zip(weights, biases)):
x[i] = F.group_norm(x[i], num_groups_per_chunk, w, bias, norm_layer.eps)
x[i] = x[i].to(input_dtype)
x = torch.cat(x, dim=1)
else:
x = norm_layer(x)
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
x = x.reshape((b, t, x.size(1), x.size(2), x.size(3))).transpose(1, 2)
return x.to(input_dtype)
raise NotImplementedError
def safe_interpolate_operation(x, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None):
problematic_modes = ['bilinear', 'bicubic', 'trilinear']
if mode in problematic_modes:
try:
return F.interpolate(
x,
size=size,
scale_factor=scale_factor,
mode=mode,
align_corners=align_corners,
recompute_scale_factor=recompute_scale_factor
)
except RuntimeError as e:
if ("not implemented for 'Half'" in str(e) or
"compute_indices_weights" in str(e)):
original_dtype = x.dtype
return F.interpolate(
x.float(),
size=size,
scale_factor=scale_factor,
mode=mode,
align_corners=align_corners,
recompute_scale_factor=recompute_scale_factor
).to(original_dtype)
else:
raise e
else:
# Pour 'nearest' et autres modes compatibles, pas de fix nécessaire
return F.interpolate(
x,
size=size,
scale_factor=scale_factor,
mode=mode,
align_corners=align_corners,
recompute_scale_factor=recompute_scale_factor
)
_receptive_field_t = Literal["half", "full"]
def extend_head(tensor, times: int = 2, memory = None):
@ -585,7 +520,6 @@ class InflatedCausalConv3d(ops.Conv3d):
**kwargs,
):
self.inflation_mode = inflation_mode
self.memory = None
super().__init__(*args, **kwargs)
self.temporal_padding = self.padding[0]
self.padding = (0, *self.padding[1:])
@ -620,18 +554,19 @@ class InflatedCausalConv3d(ops.Conv3d):
return super().forward(x)
# Compute tensor shape after concat & padding.
shape = torch.tensor(x.size())
shape = list(x.size())
if prev_cache is not None:
shape[split_dim - 1] += prev_cache.size(split_dim - 1)
shape[-3:] += torch.tensor(padding).view(3, 2).sum(-1).flip(0)
memory_occupy = shape.prod() * x.element_size() / 1024**3 # GiB
for i, pad_sum in enumerate((padding[4] + padding[5], padding[2] + padding[3], padding[0] + padding[1])):
shape[-3 + i] += pad_sum
memory_occupy = math.prod(shape) * x.element_size() / 1024**3 # GiB
if memory_occupy < self.memory_limit or split_dim == x.ndim:
x_concat = x
if prev_cache is not None:
x_concat = torch.cat([prev_cache, x], dim=split_dim - 1)
def pad_and_forward():
padded = safe_pad_operation(x_concat, padding, mode='constant', value=0.0)
padded = F.pad(x_concat, padding, mode='constant', value=0.0)
if not padded.is_contiguous():
padded = padded.contiguous()
with ignore_padding(self):
@ -689,46 +624,57 @@ class InflatedCausalConv3d(ops.Conv3d):
def forward(
self,
input,
memory_state: MemoryState = MemoryState.UNSET
memory_state: MemoryState = MemoryState.UNSET,
memory_cache = None,
) -> Tensor:
assert memory_state != MemoryState.UNSET
if memory_cache is None:
memory_cache = {}
if memory_state != MemoryState.ACTIVE:
self.memory = None
memory_cache.pop(self, None)
if (
math.isinf(self.memory_limit)
and torch.is_tensor(input)
):
return self.basic_forward(input, memory_state)
return self.slicing_forward(input, memory_state)
return self.basic_forward(input, memory_state, memory_cache)
return self.slicing_forward(input, memory_state, memory_cache)
def basic_forward(self, input: Tensor, memory_state: MemoryState = MemoryState.UNSET):
def basic_forward(self, input: Tensor, memory_state: MemoryState = MemoryState.UNSET, memory_cache = None):
mem_size = self.stride[0] - self.kernel_size[0]
if (self.memory is not None) and (memory_state == MemoryState.ACTIVE):
input = extend_head(input, memory=self.memory, times=-1)
memory = memory_cache.get(self) if memory_cache is not None else None
if (memory is not None) and (memory_state == MemoryState.ACTIVE):
input = extend_head(input, memory=memory, times=-1)
else:
input = extend_head(input, times=self.temporal_padding * 2)
memory = (
next_memory = (
input[:, :, mem_size:].detach()
if (mem_size != 0 and memory_state != MemoryState.DISABLED)
else None
)
if memory_state != MemoryState.DISABLED:
self.memory = memory
if memory_cache is not None and memory_state != MemoryState.DISABLED:
if next_memory is None:
memory_cache.pop(self, None)
else:
memory_cache[self] = next_memory
return super().forward(input)
def slicing_forward(
self,
input,
memory_state: MemoryState = MemoryState.UNSET,
memory_cache = None,
) -> Tensor:
if memory_cache is None:
memory_cache = {}
squeeze_out = False
if torch.is_tensor(input):
input = [input]
squeeze_out = True
cache_size = self.kernel_size[0] - self.stride[0]
memory = memory_cache.get(self) if memory_cache is not None else None
cache = cache_send_recv(
input, cache_size=cache_size, memory=self.memory, times=self.temporal_padding * 2
input, cache_size=cache_size, memory=memory, times=self.temporal_padding * 2
)
# Single GPU inference - simplified memory management
@ -740,7 +686,7 @@ class InflatedCausalConv3d(ops.Conv3d):
input[0] = torch.cat([cache, input[0]], dim=2)
cache = None
if cache_size <= input[-1].size(2):
self.memory = input[-1][:, :, -cache_size:].detach().contiguous()
memory_cache[self] = input[-1][:, :, -cache_size:].detach().contiguous()
padding = tuple(x for x in reversed(self.padding) for _ in range(2))
for i in range(len(input)):
@ -802,17 +748,10 @@ class Upsample3D(nn.Module):
self.temporal_ratio = 2 if temporal_up else 1
self.spatial_ratio = 2 if spatial_up else 1
# [Override] MAGViT v2 learnable upsample
upscale_ratio = (self.spatial_ratio**2) * self.temporal_ratio
self.upscale_conv = ops.Conv3d(
self.channels, self.channels * upscale_ratio, kernel_size=1, padding=0
)
identity = (
torch.eye(self.channels)
.repeat(upscale_ratio, 1)
.reshape_as(self.upscale_conv.weight)
)
self.upscale_conv.weight.data.copy_(identity)
self.conv = conv
@ -820,23 +759,27 @@ class Upsample3D(nn.Module):
self,
hidden_states: torch.FloatTensor,
memory_state=None,
memory_cache=None,
**kwargs,
) -> torch.FloatTensor:
assert hidden_states.shape[1] == self.channels
hidden_states = self.upscale_conv(hidden_states)
hidden_states = rearrange(
hidden_states,
"b (x y z c) f h w -> b c (f z) (h x) (w y)",
x=self.spatial_ratio,
y=self.spatial_ratio,
z=self.temporal_ratio,
b, channels, f, h, w = hidden_states.shape
c = channels // (self.spatial_ratio * self.spatial_ratio * self.temporal_ratio)
hidden_states = hidden_states.view(b, self.spatial_ratio, self.spatial_ratio, self.temporal_ratio, c, f, h, w)
hidden_states = hidden_states.permute(0, 4, 5, 3, 6, 1, 7, 2).reshape(
b,
c,
f * self.temporal_ratio,
h * self.spatial_ratio,
w * self.spatial_ratio,
)
if self.temporal_up and memory_state != MemoryState.ACTIVE:
hidden_states = remove_head(hidden_states)
hidden_states = self.conv(hidden_states, memory_state=memory_state)
hidden_states = self.conv(hidden_states, memory_state=memory_state, memory_cache=memory_cache)
return hidden_states
@ -879,6 +822,7 @@ class Downsample3D(nn.Module):
self,
hidden_states: torch.FloatTensor,
memory_state = None,
memory_cache = None,
**kwargs,
) -> torch.FloatTensor:
@ -890,11 +834,11 @@ class Downsample3D(nn.Module):
if self.spatial_down:
pad = (0, 1, 0, 1)
hidden_states = safe_pad_operation(hidden_states, pad, mode="constant", value=0)
hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
assert hidden_states.shape[1] == self.channels
hidden_states = self.conv(hidden_states, memory_state=memory_state)
hidden_states = self.conv(hidden_states, memory_state=memory_state, memory_cache=memory_cache)
return hidden_states
@ -962,7 +906,7 @@ class ResnetBlock3D(nn.Module):
)
def forward(
self, input_tensor, temb, memory_state = None, **kwargs
self, input_tensor, temb, memory_state = None, memory_cache = None, **kwargs
):
hidden_states = input_tensor
@ -970,7 +914,7 @@ class ResnetBlock3D(nn.Module):
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv1(hidden_states, memory_state=memory_state)
hidden_states = self.conv1(hidden_states, memory_state=memory_state, memory_cache=memory_cache)
if self.time_emb_proj is not None:
if not self.skip_time_act:
@ -985,10 +929,10 @@ class ResnetBlock3D(nn.Module):
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states, memory_state=memory_state)
hidden_states = self.conv2(hidden_states, memory_state=memory_state, memory_cache=memory_cache)
if self.conv_shortcut is not None:
input_tensor = self.conv_shortcut(input_tensor, memory_state=memory_state)
input_tensor = self.conv_shortcut(input_tensor, memory_state=memory_state, memory_cache=memory_cache)
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
@ -1055,15 +999,16 @@ class DownEncoderBlock3D(nn.Module):
self,
hidden_states: torch.FloatTensor,
memory_state = None,
memory_cache = None,
**kwargs,
) -> torch.FloatTensor:
for resnet, temporal in zip(self.resnets, self.temporal_modules):
hidden_states = resnet(hidden_states, temb=None, memory_state=memory_state)
hidden_states = resnet(hidden_states, temb=None, memory_state=memory_state, memory_cache=memory_cache)
hidden_states = temporal(hidden_states)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states, memory_state=memory_state)
hidden_states = downsampler(hidden_states, memory_state=memory_state, memory_cache=memory_cache)
return hidden_states
@ -1132,15 +1077,16 @@ class UpDecoderBlock3D(nn.Module):
self,
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
memory_state=None
memory_state=None,
memory_cache=None,
) -> torch.FloatTensor:
for resnet, temporal in zip(self.resnets, self.temporal_modules):
hidden_states = resnet(hidden_states, temb=None, memory_state=memory_state)
hidden_states = resnet(hidden_states, temb=None, memory_state=memory_state, memory_cache=memory_cache)
hidden_states = temporal(hidden_states)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, memory_state=memory_state)
hidden_states = upsampler(hidden_states, memory_state=memory_state, memory_cache=memory_cache)
return hidden_states
@ -1203,7 +1149,6 @@ class UNetMidBlock3D(nn.Module):
residual_connection=True,
bias=True,
upcast_softmax=True,
_from_deprecated_attn_block=True,
)
)
else:
@ -1226,17 +1171,16 @@ class UNetMidBlock3D(nn.Module):
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states, temb=None, memory_state=None):
def forward(self, hidden_states, temb=None, memory_state=None, memory_cache=None):
video_length, frame_height, frame_width = hidden_states.size()[-3:]
hidden_states = self.resnets[0](hidden_states, temb, memory_state=memory_state)
hidden_states = self.resnets[0](hidden_states, temb, memory_state=memory_state, memory_cache=memory_cache)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if attn is not None:
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
b, c, f, h, w = hidden_states.shape
hidden_states = hidden_states.transpose(1, 2).reshape(b * f, c, h, w)
hidden_states = attn(hidden_states, temb=temb)
hidden_states = rearrange(
hidden_states, "(b f) c h w -> b c f h w", f=video_length
)
hidden_states = resnet(hidden_states, temb, memory_state=memory_state)
hidden_states = hidden_states.reshape(b, video_length, c, h, w).transpose(1, 2)
hidden_states = resnet(hidden_states, temb, memory_state=memory_state, memory_cache=memory_cache)
return hidden_states
@ -1327,22 +1271,23 @@ class Encoder3D(nn.Module):
def forward(
self,
sample: torch.FloatTensor,
memory_state = None
memory_state = None,
memory_cache = None,
) -> torch.FloatTensor:
r"""The forward method of the `Encoder` class."""
sample = sample.to(next(self.parameters()).device)
sample = self.conv_in(sample, memory_state = memory_state)
sample = self.conv_in(sample, memory_state=memory_state, memory_cache=memory_cache)
# down
for down_block in self.down_blocks:
sample = down_block(sample, memory_state=memory_state)
sample = down_block(sample, memory_state=memory_state, memory_cache=memory_cache)
# middle
sample = self.mid_block(sample, memory_state=memory_state)
sample = self.mid_block(sample, memory_state=memory_state, memory_cache=memory_cache)
# post-process
sample = causal_norm_wrapper(self.conv_norm_out, sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample, memory_state = memory_state)
sample = self.conv_out(sample, memory_state=memory_state, memory_cache=memory_cache)
return sample
@ -1436,24 +1381,25 @@ class Decoder3D(nn.Module):
sample: torch.FloatTensor,
latent_embeds: Optional[torch.FloatTensor] = None,
memory_state = None,
memory_cache = None,
) -> torch.FloatTensor:
sample = sample.to(next(self.parameters()).device)
sample = self.conv_in(sample, memory_state=memory_state)
sample = self.conv_in(sample, memory_state=memory_state, memory_cache=memory_cache)
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
# middle
sample = self.mid_block(sample, latent_embeds, memory_state=memory_state)
sample = self.mid_block(sample, latent_embeds, memory_state=memory_state, memory_cache=memory_cache)
sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks:
sample = up_block(sample, latent_embeds, memory_state=memory_state)
sample = up_block(sample, latent_embeds, memory_state=memory_state, memory_cache=memory_cache)
# post-process
sample = causal_norm_wrapper(self.conv_norm_out, sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample, memory_state=memory_state)
sample = self.conv_out(sample, memory_state=memory_state, memory_cache=memory_cache)
return sample
@ -1529,22 +1475,23 @@ class VideoAutoencoderKL(nn.Module):
return decoded
def _encode(
self, x, memory_state = MemoryState.DISABLED
self, x, memory_state = MemoryState.DISABLED, memory_cache = None
) -> torch.Tensor:
_x = x.to(self.device)
h = self.encoder(_x, memory_state=memory_state)
h = self.encoder(_x, memory_state=memory_state, memory_cache=memory_cache)
return h.to(x.device)
def _decode(
self, z, memory_state = MemoryState.DISABLED
self, z, memory_state = MemoryState.DISABLED, memory_cache = None
) -> torch.Tensor:
_z = z.to(self.device)
output = self.decoder(_z, memory_state=memory_state)
output = self.decoder(_z, memory_state=memory_state, memory_cache=memory_cache)
return output.to(z.device)
def slicing_encode(self, x: torch.Tensor) -> torch.Tensor:
sp_size =1
if self.use_slicing and (x.shape[2] - 1) > self.slicing_sample_min_size * sp_size:
memory_cache = {}
split_size = max(
self.slicing_sample_min_size * sp_size,
getattr(self, "temporal_downsample_factor", 1),
@ -1558,17 +1505,14 @@ class VideoAutoencoderKL(nn.Module):
self._encode(
torch.cat((x[:, :, :1], x_slices[0]), dim=2),
memory_state=MemoryState.INITIALIZING,
memory_cache=memory_cache,
)
]
for x_idx in range(1, len(x_slices)):
encoded_slices.append(
self._encode(x_slices[x_idx], memory_state=MemoryState.ACTIVE)
self._encode(x_slices[x_idx], memory_state=MemoryState.ACTIVE, memory_cache=memory_cache)
)
out = torch.cat(encoded_slices, dim=2)
modules_with_memory = [m for m in self.modules()
if isinstance(m, InflatedCausalConv3d) and m.memory is not None]
for m in modules_with_memory:
m.memory = None
return out
else:
return self._encode(x)
@ -1576,22 +1520,20 @@ class VideoAutoencoderKL(nn.Module):
def slicing_decode(self, z: torch.Tensor) -> torch.Tensor:
sp_size = 1
if self.use_slicing and (z.shape[2] - 1) > self.slicing_latent_min_size * sp_size:
memory_cache = {}
z_slices = z[:, :, 1:].split(split_size=self.slicing_latent_min_size * sp_size, dim=2)
decoded_slices = [
self._decode(
torch.cat((z[:, :, :1], z_slices[0]), dim=2),
memory_state=MemoryState.INITIALIZING
memory_state=MemoryState.INITIALIZING,
memory_cache=memory_cache,
)
]
for z_idx in range(1, len(z_slices)):
decoded_slices.append(
self._decode(z_slices[z_idx], memory_state=MemoryState.ACTIVE)
self._decode(z_slices[z_idx], memory_state=MemoryState.ACTIVE, memory_cache=memory_cache)
)
out = torch.cat(decoded_slices, dim=2)
modules_with_memory = [m for m in self.modules()
if isinstance(m, InflatedCausalConv3d) and m.memory is not None]
for m in modules_with_memory:
m.memory = None
return out
else:
return self._decode(z)
@ -1612,32 +1554,25 @@ class VideoAutoencoderKL(nn.Module):
return _unwrap(self.decode_(latent))
class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
# Signals to comfy.sd.VAE that this model performs its own VAE tiling, so the
# generic tiled-decode/encode dispatch defers to decode_tiled/encode_tiled below.
comfy_handles_tiling = True
def __init__(
self,
*args,
spatial_downsample_factor = 8,
temporal_downsample_factor = 4,
freeze_encoder = True,
**kwargs,
):
self.spatial_downsample_factor = spatial_downsample_factor
self.temporal_downsample_factor = temporal_downsample_factor
self.freeze_encoder = freeze_encoder
self.enable_tiling = False
super().__init__(*args, **kwargs)
self.set_memory_limit(BYTEDANCE_VAE_CONV_MEM_GIB, BYTEDANCE_VAE_NORM_MEM_GIB)
def forward(self, x: torch.FloatTensor):
with torch.no_grad() if self.freeze_encoder else nullcontext():
z, p = self.encode(x)
z, p = self._encode_with_raw_latent(x)
x = self.decode(z)
return x, z, p
def encode(self, x, orig_dims=None):
def _encode_with_raw_latent(self, x):
if x.ndim == 4:
x = x.unsqueeze(2)
x = x.to(dtype=next(self.parameters()).dtype)
@ -1646,6 +1581,10 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
z = p.squeeze(2)
return z, p
def encode(self, x, orig_dims=None):
z, _ = self._encode_with_raw_latent(x)
return z
def decode(self, z, seedvr2_tiling=None):
seedvr2_tiling = {} if seedvr2_tiling is None else seedvr2_tiling
if not isinstance(seedvr2_tiling, dict):

View File

@ -1151,9 +1151,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
return unet_config
def model_config_from_unet_config(unet_config, state_dict=None):
def model_config_from_unet_config(unet_config, state_dict=None, unet_key_prefix=""):
for model_config in comfy.supported_models.models:
if model_config.matches(unet_config, state_dict):
if model_config.matches(unet_config, state_dict, unet_key_prefix=unet_key_prefix):
return model_config(unet_config)
logging.error("no match {}".format(unet_config))
@ -1163,7 +1163,7 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
unet_config = detect_unet_config(state_dict, unet_key_prefix, metadata=metadata)
if unet_config is None:
return None
model_config = model_config_from_unet_config(unet_config, state_dict)
model_config = model_config_from_unet_config(unet_config, state_dict, unet_key_prefix)
if model_config is None and use_base_if_no_match:
model_config = comfy.supported_models_base.BASE(unet_config)

View File

@ -1,4 +1,3 @@
import inspect
import json
import torch
from enum import Enum
@ -500,6 +499,8 @@ class VAE:
self.upscale_index_formula = None
self.extra_1d_channel = None
self.crop_input = True
self.handles_tiling = False
self.format_encoded = None
self.audio_sample_rate = 44100
@ -554,6 +555,8 @@ class VAE:
self.memory_used_decode = lambda shape, dtype: self.first_stage_model.comfy_memory_used_decode(shape)
self.memory_used_encode = lambda shape, dtype: (max(shape[2], 5) * shape[3] * shape[4] * 64) * model_management.dtype_size(dtype)
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
self.handles_tiling = True
self.format_encoded = self.first_stage_model.comfy_format_encoded
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
self.downscale_index_formula = (4, 8, 8)
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
@ -1118,7 +1121,7 @@ class VAE:
if dims == 1 or self.extra_1d_channel is not None:
pixel_samples = self.decode_tiled_1d(samples_in)
elif dims == 2:
if getattr(self.first_stage_model, "comfy_handles_tiling", False):
if self.handles_tiling:
tile = 256 // self.spacial_compression_decode()
overlap = tile // 4
pixel_samples = self._decode_tiled_owned(samples_in, tile_x=tile, tile_y=tile, overlap=overlap)
@ -1127,7 +1130,7 @@ class VAE:
elif dims == 3:
tile = 256 // self.spacial_compression_decode()
overlap = tile // 4
if getattr(self.first_stage_model, "comfy_handles_tiling", False):
if self.handles_tiling:
pixel_samples = self._decode_tiled_owned(samples_in, tile_x=tile, tile_y=tile, overlap=overlap)
else:
pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
@ -1149,7 +1152,7 @@ class VAE:
args["overlap"] = overlap
with model_management.cuda_device_context(self.device):
if getattr(self.first_stage_model, "comfy_handles_tiling", False) and dims in (2, 3):
if self.handles_tiling and dims in (2, 3):
tiled_args = {}
if tile_x is not None:
tiled_args["tile_x"] = tile_x
@ -1204,8 +1207,6 @@ class VAE:
else:
pixels_in = pixels_in.to(self.device)
out = self.first_stage_model.encode(pixels_in)
if isinstance(out, tuple):
out = out[0]
out = out.to(self.output_device).to(dtype=self.vae_output_dtype())
if samples is None:
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
@ -1225,7 +1226,7 @@ class VAE:
if self.latent_dim == 3:
tile = 256
overlap = tile // 4
if getattr(self.first_stage_model, "comfy_handles_tiling", False):
if self.handles_tiling:
samples = self._encode_tiled_owned(pixel_samples, tile_x=tile, tile_y=tile, overlap=overlap)
else:
samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
@ -1234,9 +1235,8 @@ class VAE:
else:
samples = self.encode_tiled_(pixel_samples)
formatter = getattr(self.first_stage_model, "comfy_format_encoded", None)
if formatter is not None:
samples = formatter(samples)
if self.format_encoded is not None:
samples = self.format_encoded(samples)
return samples
def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
@ -1268,7 +1268,7 @@ class VAE:
elif dims == 2:
samples = self.encode_tiled_(pixel_samples, **args)
elif dims == 3:
if getattr(self.first_stage_model, "comfy_handles_tiling", False):
if self.handles_tiling:
tiled_args = {}
if tile_x is not None:
tiled_args["tile_x"] = tile_x
@ -1298,9 +1298,8 @@ class VAE:
samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args)
formatter = getattr(self.first_stage_model, "comfy_format_encoded", None)
if formatter is not None:
samples = formatter(samples)
if self.format_encoded is not None:
samples = self.format_encoded(samples)
return samples
def get_sd(self):
@ -1852,16 +1851,6 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
return (model, clip, vae)
def _set_model_config_inference_dtype(model_config, dtype, manual_cast_dtype, device):
set_dtype = model_config.set_inference_dtype
parameters = inspect.signature(set_dtype).parameters
supports_device = "device" in parameters or any(p.kind == inspect.Parameter.VAR_KEYWORD for p in parameters.values())
if supports_device:
set_dtype(dtype, manual_cast_dtype, device=device)
else:
set_dtype(dtype, manual_cast_dtype)
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, disable_dynamic=False):
sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic)
@ -1969,7 +1958,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
else:
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
_set_model_config_inference_dtype(model_config, unet_dtype, manual_cast_dtype, load_device)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype, device=load_device)
if model_config.clip_vision_prefix is not None:
if output_clipvision:
@ -2110,7 +2099,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable
manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
else:
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
_set_model_config_inference_dtype(model_config, unet_dtype, manual_cast_dtype, load_device)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype, device=load_device)
if custom_operations is not None:
model_config.custom_operations = custom_operations

View File

@ -1688,6 +1688,10 @@ class SeedVR2(supported_models_base.BASE):
unet_config = {
"image_model": "seedvr2"
}
required_keys = {
"{}positive_conditioning",
"{}negative_conditioning",
}
latent_format = comfy.latent_formats.SeedVR2
vae_key_prefix = ["vae."]

View File

@ -54,13 +54,13 @@ class BASE:
optimizations = {"fp8": False}
@classmethod
def matches(s, unet_config, state_dict=None):
def matches(s, unet_config, state_dict=None, unet_key_prefix=""):
for k in s.unet_config:
if k not in unet_config or s.unet_config[k] != unet_config[k]:
return False
if state_dict is not None:
for k in s.required_keys:
if k not in state_dict:
if k.format(unet_key_prefix) not in state_dict:
return False
return True

View File

@ -3,7 +3,6 @@ from comfy_api.latest import ComfyExtension, io
import torch
import math
import logging
from einops import rearrange
import comfy.model_management
import comfy.sample
@ -101,14 +100,6 @@ def _resolve_seedvr2_diffusion_model(model):
return diffusion_model
def _apply_rope_freqs_float32_cast(diffusion_model):
"""Cast every module's ``rope.freqs`` to float32; the per-tensor dtype check (not a sentinel attr) self-corrects across Comfy's unload/reload, which would otherwise restore the archived fp16/bf16 dtype."""
for module in diffusion_model.modules():
if hasattr(module, 'rope') and hasattr(module.rope, 'freqs'):
if module.rope.freqs.data.dtype != torch.float32:
module.rope.freqs.data = module.rope.freqs.data.to(torch.float32)
def get_conditions(latent, latent_blur):
t, h, w, c = latent.shape
cond = torch.ones([t, h, w, c + 1], device=latent.device, dtype=latent.dtype)
@ -193,7 +184,7 @@ def _seedvr2_pad(images, upscaled_shorter_edge, node_name):
images = images.reshape(b, t, c, new_h, new_w)
images = cut_videos(images)
images_bthwc = rearrange(images, "b t c h w -> b t h w c")
images_bthwc = images.permute(0, 1, 3, 4, 2).contiguous()
return io.NodeOutput(images_bthwc)
@ -265,12 +256,12 @@ class SeedVR2PostProcessing(io.ComfyNode):
output_device = decoded_5d.device
decoded_raw = cls._to_seedvr2_raw(decoded_5d)
reference_raw = cls._to_seedvr2_raw(reference_5d)
decoded_flat = rearrange(decoded_raw, "b t h w c -> (b t) c h w")
reference_flat = rearrange(reference_raw, "b t h w c -> (b t) c h w")
decoded_flat = decoded_raw.permute(0, 1, 4, 2, 3).reshape(b * t, decoded_raw.shape[4], target_h, target_w)
reference_flat = reference_raw.permute(0, 1, 4, 2, 3).reshape(b * t, reference_raw.shape[4], target_h, target_w)
output = cls._color_transfer_chunked(
decoded_flat, reference_flat, output_device, color_correction_method,
)
output = rearrange(output, "(b t) c h w -> b t h w c", b=b, t=t)
output = output.reshape(b, t, output.shape[1], output.shape[2], output.shape[3]).permute(0, 1, 3, 4, 2)
output = output.add(1.0).div(2.0).clamp(0.0, 1.0)
elif color_correction_method == "none":
output = decoded_5d
@ -359,7 +350,6 @@ class SeedVR2PostProcessing(io.ComfyNode):
) from e
next_chunk_size = max(1, chunk_size // SEEDVR2_OOM_BACKOFF_DIVISOR)
comfy.model_management.soft_empty_cache()
chunk_size = next_chunk_size
@classmethod
@ -419,14 +409,14 @@ class SeedVR2PostProcessing(io.ComfyNode):
if reference.shape[2] == height and reference.shape[3] == width:
return reference
b, t = reference.shape[:2]
reference_flat = rearrange(reference, "b t h w c -> (b t) c h w")
reference_flat = reference.permute(0, 1, 4, 2, 3).reshape(b * t, reference.shape[4], reference.shape[2], reference.shape[3])
resized = TVF.resize(
reference_flat,
size=(height, width),
interpolation=InterpolationMode.BICUBIC,
antialias=not (isinstance(reference_flat, torch.Tensor) and reference_flat.device.type == "mps"),
)
return rearrange(resized, "(b t) c h w -> b t h w c", b=b, t=t)
return resized.reshape(b, t, resized.shape[1], height, width).permute(0, 1, 3, 4, 2)
class SeedVR2Conditioning(io.ComfyNode):
@ -471,39 +461,12 @@ class SeedVR2Conditioning(io.ComfyNode):
pos_cond = model.positive_conditioning
neg_cond = model.negative_conditioning
# Fail-loud guard against silently-wrong output when a
# DiT-only ``.safetensors`` (no ``positive_conditioning`` /
# ``negative_conditioning`` keys) is loaded via ``UNETLoader``.
# ``NaDiT.__init__`` zero-fills the buffers via ``torch.zeros`` (see
# ``comfy/ldm/seedvr/model.py``); ``load_state_dict(strict=False)``
# leaves them at zero when the keys are absent. Detect that state
# here rather than at ``BaseModel.extra_conds`` (per sampling step,
# wasteful) or at the resolver helper (mixes structural shape with
# semantic content). Both buffers must be checked together — partial
# bake regressions could populate one but not the other.
if (
pos_cond.float().abs().sum().item() == 0
and neg_cond.float().abs().sum().item() == 0
):
raise RuntimeError(
f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: positive_conditioning "
f"and negative_conditioning buffers are zero-valued — model "
f"file appears to be a DiT-only export missing "
f"the SeedVR2 conditioning tensors. "
f"Re-bake the file with ``positive_conditioning`` (58, 5120) "
f"and ``negative_conditioning`` (64, 5120) keys at top level, "
f"or load via CheckpointLoaderSimple from a bundled "
f"checkpoint."
)
_apply_rope_freqs_float32_cast(model)
condition = torch.stack([get_conditions(c, c) for c in vae_conditioning])
condition = condition.movedim(-1, 1)
latent = vae_conditioning.movedim(-1, 1)
latent = rearrange(latent, "b c t h w -> b (c t) h w")
condition = rearrange(condition, "b c t h w -> b (c t) h w")
latent = latent.reshape(latent.shape[0], latent.shape[1] * latent.shape[2], latent.shape[3], latent.shape[4])
condition = condition.reshape(condition.shape[0], condition.shape[1] * condition.shape[2], condition.shape[3], condition.shape[4])
negative = [[neg_cond.unsqueeze(0), {"condition": condition}]]
positive = [[pos_cond.unsqueeze(0), {"condition": condition}]]
@ -723,7 +686,7 @@ class SeedVR2ProgressiveSampler(io.ComfyNode):
Drop-in replacement for ``KSampler`` in SeedVR2 native workflows that
OOM on long sequences. The latent enters the sampler in SeedVR2's
collapsed form ``(B, 16*T, H, W)`` (collapsed by ``SeedVR2Conditioning``
at ``rearrange(b c t h w -> b (c t) h w)``); this node slices that
at ``reshape(b, c * t, h, w)``); this node slices that
tensor along the temporal axis, runs the configured inner sampler
sequentially per chunk against the standard ``comfy.sample.sample``
entry point, and concatenates per-chunk outputs back into a single
@ -882,7 +845,6 @@ class SeedVR2ProgressiveSampler(io.ComfyNode):
"frames_per_chunk=%s.",
attempt_frames_per_chunk, attempts[i + 1],
)
comfy.model_management.soft_empty_cache()
# Short-circuit: total fits in one chunk -> standard path with no
# chunking overhead. Output of this branch is byte-identical to the

View File

@ -11,7 +11,6 @@ import importlib
import sys
from unittest.mock import MagicMock
import pytest
import torch
import torch.nn as nn
@ -53,7 +52,7 @@ def _import_nodes_seedvr_isolated():
mock_mm.WINDOWS = False
sys.modules["comfy.model_management"] = mock_mm
if sys.modules.get("comfy") is None:
import comfy as _comfy_pkg # noqa: F401
importlib.import_module("comfy")
comfy_pkg = sys.modules.get("comfy")
if comfy_pkg is not None:
setattr(comfy_pkg, "model_management", mock_mm)
@ -95,11 +94,10 @@ class _Block(nn.Module):
class _DiffusionModel(nn.Module):
"""Stub diffusion model with N blocks and pos/neg conditioning buffers."""
def __init__(self, n_blocks=3, zero_conditioning=False, conditioning_dtype=torch.float32):
def __init__(self, n_blocks=3, conditioning_dtype=torch.float32):
super().__init__()
self.blocks = nn.ModuleList([_Block() for _ in range(n_blocks)])
pos = torch.zeros if zero_conditioning else torch.ones
self.register_buffer("positive_conditioning", pos((2, 4), dtype=conditioning_dtype))
self.register_buffer("positive_conditioning", torch.ones((2, 4), dtype=conditioning_dtype))
self.register_buffer("negative_conditioning", torch.zeros((3, 4), dtype=conditioning_dtype))
@ -185,29 +183,3 @@ def test_seedvr2_conditioning_returns_packed_input_latent_deterministically():
)
finally:
restore()
def test_seedvr2_conditioning_fails_loud_on_zero_buffers():
nodes_seedvr, restore = _import_nodes_seedvr_isolated()
try:
diffusion_model = _DiffusionModel(zero_conditioning=True)
patcher = _ModelPatcher(diffusion_model)
vae_conditioning = {"samples": torch.zeros((1, 2, 1, 1, 1))}
with pytest.raises(RuntimeError) as excinfo:
nodes_seedvr.SeedVR2Conditioning.execute(
patcher, vae_conditioning,
)
message = str(excinfo.value)
assert message.startswith(
nodes_seedvr._SEEDVR2_INVALID_MODEL_MSG_PREFIX
), (
"Fail-loud message must use the standard "
"_SEEDVR2_INVALID_MODEL_MSG_PREFIX so callers/log scrapers "
f"can match it. Got: {message!r}"
)
assert "positive_conditioning" in message
assert "negative_conditioning" in message
finally:
restore()

View File

@ -1,5 +1,6 @@
from unittest.mock import patch
import pytest
import torch
from comfy.cli_args import args as cli_args
@ -32,26 +33,19 @@ def test_seedvr2_post_processing_oom_error_uses_color_correction_method(monkeypa
monkeypatch.setattr(nodes_seedvr.comfy.model_management, "vae_device", lambda: torch.device("cpu"))
monkeypatch.setattr(nodes_seedvr.comfy.model_management, "get_free_memory", lambda device: 1_000_000)
monkeypatch.setattr(nodes_seedvr.comfy.model_management, "soft_empty_cache", lambda: None)
with patch.object(nodes_seedvr, "lab_color_transfer", _lab):
try:
with pytest.raises(RuntimeError) as excinfo:
nodes_seedvr.SeedVR2PostProcessing._color_transfer_chunked(
decoded, reference, torch.device("cpu"), "lab",
)
except RuntimeError as exc:
assert "color_correction_method=lab" in str(exc)
assert " method=lab" not in str(exc)
else:
raise AssertionError("expected RuntimeError for one-frame LAB OOM")
assert "color_correction_method=lab" in str(excinfo.value)
assert " method=lab" not in str(excinfo.value)
def test_seedvr2_post_processing_unknown_color_correction_method_raises():
decoded = torch.zeros(1, 2, 4, 4, 3)
original = torch.zeros(1, 2, 4, 4, 3)
try:
with pytest.raises(ValueError) as excinfo:
nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, "bogus")
except ValueError as exc:
assert "color_correction_method" in str(exc)
else:
raise AssertionError("expected ValueError for unknown color_correction_method")
assert "color_correction_method" in str(excinfo.value)

View File

@ -2,7 +2,7 @@ from collections import defaultdict
import torch
from comfy.model_detection import detect_unet_config, model_config_from_unet_config
from comfy.model_detection import detect_unet_config, model_config_from_unet, model_config_from_unet_config
import comfy.supported_models
@ -76,21 +76,31 @@ def _make_flux_schnell_comfyui_sd():
def _make_seedvr2_7b_separate_mm_sd():
return {
"blocks.35.mlp.vid.proj_in.weight": torch.empty(1, 3072),
"positive_conditioning": torch.empty(58, 5120),
"negative_conditioning": torch.empty(64, 5120),
}
def _make_seedvr2_7b_shared_mm_sd():
return {
"blocks.35.mlp.all.proj_in_gate.weight": torch.empty(1, 1),
"positive_conditioning": torch.empty(58, 5120),
"negative_conditioning": torch.empty(64, 5120),
}
def _make_seedvr2_3b_shared_mm_sd():
return {
"blocks.31.mlp.all.proj_in_gate.weight": torch.empty(1, 1),
"positive_conditioning": torch.empty(58, 5120),
"negative_conditioning": torch.empty(64, 5120),
}
def _add_model_diffusion_prefix(sd):
return {f"model.diffusion_model.{k}": v for k, v in sd.items()}
class TestModelDetection:
"""Verify that first-match model detection selects the correct model
based on list ordering and unet_config specificity."""
@ -182,6 +192,20 @@ class TestModelDetection:
assert unet_config["num_layers"] == 32
assert unet_config["mlp_type"] == "swiglu"
def test_seedvr2_model_match_requires_conditioning_tensors(self):
sd = _make_seedvr2_7b_shared_mm_sd()
unet_config = detect_unet_config(sd, "")
assert type(model_config_from_unet_config(unet_config, sd)).__name__ == "SeedVR2"
del sd["positive_conditioning"]
assert model_config_from_unet_config(unet_config, sd) is None
def test_seedvr2_model_match_accepts_full_checkpoint_prefix(self):
sd = _add_model_diffusion_prefix(_make_seedvr2_7b_shared_mm_sd())
assert type(model_config_from_unet(sd, "model.diffusion_model.")).__name__ == "SeedVR2"
def test_unet_config_and_required_keys_combination_is_unique(self):
"""Each model in the registry must have a unique combination of
``unet_config`` and ``required_keys``. If two models share the same

View File

@ -103,7 +103,7 @@ def test_seedvr2_7b_swin_attention_forward_uses_optimized_var_attention(monkeypa
heads=heads,
head_dim=head_dim,
qk_bias=False,
qk_norm=seedvr_model.CustomRMSNorm,
qk_norm=comfy_ops.disable_weight_init.RMSNorm,
qk_norm_eps=1e-6,
rope_type=None,
rope_dim=head_dim,

View File

@ -26,6 +26,7 @@ import comfy.ldm.seedvr.model # noqa: E402
import comfy.ldm.seedvr.model as seedvr_model # noqa: E402
import comfy.ldm.seedvr.vae as seedvr_vae_mod # noqa: E402
import comfy.model_management # noqa: E402
import comfy.ops as comfy_ops # noqa: E402
import comfy.sample # noqa: E402
import comfy.sd as sd_mod # noqa: E402
import nodes as nodes_mod # noqa: E402
@ -81,6 +82,7 @@ def _capture_last_layer_flags(monkeypatch, vid_dim: int, txt_in_dim: int) -> lis
txt_in_dim=txt_in_dim,
heads=24,
mm_layers=3,
operations=comfy_ops.disable_weight_init,
)
return flags
@ -140,6 +142,46 @@ class _DecodeWrapper(seedvr_vae_mod.VideoAutoencoderKLWrapper):
return torch.zeros(b, 3, t, h * 8, w * 8, dtype=z.dtype, device=z.device)
def test_seedvr2_wrapper_public_encode_returns_tensor(monkeypatch):
raw_latent = torch.full((1, 16, 1, 4, 5), 2.0)
seen_shapes = []
def base_encode(self, x):
seen_shapes.append(tuple(x.shape))
return raw_latent.to(device=x.device, dtype=x.dtype)
monkeypatch.setattr(seedvr_vae_mod.VideoAutoencoderKL, "encode", base_encode)
vae = seedvr_vae_mod.VideoAutoencoderKLWrapper.__new__(seedvr_vae_mod.VideoAutoencoderKLWrapper)
nn.Module.__init__(vae)
vae._dummy = nn.Parameter(torch.zeros((), dtype=torch.float32))
latent = vae.encode(torch.zeros(1, 3, 32, 40))
assert type(latent) is torch.Tensor
assert tuple(latent.shape) == (1, 16, 4, 5)
assert seen_shapes == [(1, 3, 1, 32, 40)]
def test_seedvr2_wrapper_private_encode_helper_keeps_raw_latent(monkeypatch):
raw_latent = torch.full((1, 16, 1, 4, 5), 3.0)
def base_encode(self, x):
return raw_latent.to(device=x.device, dtype=x.dtype)
monkeypatch.setattr(seedvr_vae_mod.VideoAutoencoderKL, "encode", base_encode)
vae = seedvr_vae_mod.VideoAutoencoderKLWrapper.__new__(seedvr_vae_mod.VideoAutoencoderKLWrapper)
nn.Module.__init__(vae)
vae._dummy = nn.Parameter(torch.zeros((), dtype=torch.float32))
latent, raw = vae._encode_with_raw_latent(torch.zeros(1, 3, 32, 40))
assert tuple(latent.shape) == (1, 16, 4, 5)
assert tuple(raw.shape) == (1, 16, 1, 4, 5)
assert torch.equal(raw, raw_latent)
def _make_vae(wrapper):
vae = sd_mod.VAE.__new__(sd_mod.VAE)
vae.first_stage_model = wrapper
@ -155,6 +197,8 @@ def _make_vae(wrapper):
vae.extra_1d_channel = None
vae.crop_input = False
vae.not_video = False
vae.handles_tiling = isinstance(wrapper, seedvr_vae_mod.VideoAutoencoderKLWrapper)
vae.format_encoded = wrapper.comfy_format_encoded
vae.patcher = _Patcher()
vae.process_input = lambda image: image
vae.process_output = lambda image: image.add(1.0).div(2.0).clamp(0.0, 1.0)

View File

@ -1,6 +1,7 @@
from contextlib import ExitStack
from unittest.mock import MagicMock, patch
import pytest
import torch
import torch.nn as nn
@ -21,8 +22,6 @@ from comfy.ldm.seedvr.vae import MemoryState, tiled_vae # noqa: E402
def test_runtime_decode_zero_temporal_size_disables_slicing_for_call():
from comfy.ldm.seedvr.vae import MemoryState, VideoAutoencoderKL, tiled_vae
class StubVAEModel(torch.nn.Module):
def __init__(self):
super().__init__()
@ -37,9 +36,9 @@ def test_runtime_decode_zero_temporal_size_disables_slicing_for_call():
def decode_(self, t_chunk):
self.decode_min_sizes.append(self.slicing_latent_min_size)
return VideoAutoencoderKL.slicing_decode(self, t_chunk)
return vae_mod.VideoAutoencoderKL.slicing_decode(self, t_chunk)
def _decode(self, z, memory_state=MemoryState.DISABLED):
def _decode(self, z, memory_state=MemoryState.DISABLED, memory_cache=None):
self.memory_states.append(memory_state)
b, c, d, h, w = z.shape
return torch.zeros((b, 3, d, h * 8, w * 8), dtype=z.dtype)
@ -68,8 +67,6 @@ def test_runtime_decode_zero_temporal_size_disables_slicing_for_call():
def test_zero_temporal_size_preserves_min_size_when_encode_raises():
from comfy.ldm.seedvr.vae import tiled_vae
class RaisingVAEModel(torch.nn.Module):
def __init__(self):
super().__init__()
@ -85,8 +82,7 @@ def test_zero_temporal_size_preserves_min_size_when_encode_raises():
vae = RaisingVAEModel()
x = torch.zeros((1, 3, 12, 64, 64), dtype=torch.float32)
raised = False
try:
with pytest.raises(RuntimeError, match="simulated encode failure"):
tiled_vae(
x,
vae,
@ -96,15 +92,43 @@ def test_zero_temporal_size_preserves_min_size_when_encode_raises():
temporal_overlap=0,
encode=True,
)
except RuntimeError as exc:
if "simulated encode failure" not in str(exc):
raise
raised = True
assert raised
assert vae.slicing_sample_min_size == 4
def test_tiled_vae_encode_uses_tensor_return_without_indexing():
class TensorEncodeVAEModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.slicing_sample_min_size = 4
self.spatial_downsample_factor = 8
self.temporal_downsample_factor = 4
self.device = torch.device("cpu")
self._dummy = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32))
self.calls = []
def encode(self, t_chunk):
self.calls.append(tuple(t_chunk.shape))
b, _, _, h, w = t_chunk.shape
return torch.ones((b, 16, 1, h // 8, w // 8), dtype=t_chunk.dtype)
vae = TensorEncodeVAEModel()
x = torch.zeros((2, 3, 1, 64, 64), dtype=torch.float32)
out = tiled_vae(
x,
vae,
tile_size=(64, 64),
tile_overlap=(0, 0),
temporal_size=0,
temporal_overlap=0,
encode=True,
)
assert vae.calls == [(2, 3, 1, 64, 64)]
assert tuple(out.shape) == (2, 16, 1, 8, 8)
# ---------------------------------------------------------------------------
# From test_seedvr_vae_tiled_temporal_slicing.py
# ---------------------------------------------------------------------------
@ -126,7 +150,7 @@ class _SlicingDecodeVAE(nn.Module):
self.decode_min_sizes.append(self.slicing_latent_min_size)
return vae_mod.VideoAutoencoderKL.slicing_decode(self, z)
def _decode(self, z, memory_state=MemoryState.DISABLED):
def _decode(self, z, memory_state=MemoryState.DISABLED, memory_cache=None):
self.memory_states.append(memory_state)
x = z[:, :1].repeat(
1,
@ -205,6 +229,8 @@ def _make_vae(first_stage_model, latent_channels, latent_dim):
vae.latent_dim = latent_dim
vae.vae_output_dtype = lambda: torch.float32
vae.spacial_compression_decode = lambda: 8
vae.handles_tiling = isinstance(first_stage_model, seedvr_vae_mod.VideoAutoencoderKLWrapper)
vae.format_encoded = None
vae.process_input = lambda x: x
vae.process_output = lambda x: x
vae.throw_exception_if_invalid = lambda: None
@ -240,7 +266,6 @@ def test_4d_seedvr2_latent_routes_to_owned_decode_tiled():
def test_4d_non_seedvr2_latent_still_routes_to_generic_decode_tiled():
first_stage = MagicMock()
first_stage.comfy_handles_tiling = False
first_stage.decode = MagicMock(side_effect=_force_oom)
vae = _make_vae(first_stage, latent_channels=4, latent_dim=2)
seedvr2_call = MagicMock(return_value=torch.zeros(1, 3, 9, 64, 64))
@ -273,6 +298,8 @@ def _populate_common_vae_attrs_fallback(vae):
vae.not_video = False
vae.crop_input = False
vae.pad_channel_value = None
vae.handles_tiling = isinstance(vae.first_stage_model, seedvr_vae_mod.VideoAutoencoderKLWrapper)
vae.format_encoded = None
vae.vae_output_dtype = lambda: torch.float32
vae.spacial_compression_encode = lambda: 8
@ -295,7 +322,6 @@ def _make_seedvr2_vae_fallback():
def _make_non_seedvr2_vae_fallback():
vae = sd_mod.VAE.__new__(sd_mod.VAE)
vae.first_stage_model = MagicMock()
vae.first_stage_model.comfy_handles_tiling = False
_populate_common_vae_attrs_fallback(vae)
return vae