diff --git a/comfy/ldm/seedvr/attention.py b/comfy/ldm/seedvr/attention.py index 29ffded38..5d4054ab9 100644 --- a/comfy/ldm/seedvr/attention.py +++ b/comfy/ldm/seedvr/attention.py @@ -60,14 +60,7 @@ def var_attention_optimized_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *a q_i = q_i.permute(1, 0, 2).unsqueeze(0) k_i = k_i.permute(1, 0, 2).unsqueeze(0) v_i = v_i.permute(1, 0, 2).unsqueeze(0) - out_dtype = q_i.dtype - if _attention.optimized_attention is _attention.attention_sage and q_i.dtype not in (torch.float16, torch.bfloat16): - q_i = q_i.to(torch.bfloat16) - k_i = k_i.to(torch.bfloat16) - v_i = v_i.to(torch.bfloat16) out_i = _attention.optimized_attention(q_i, k_i, v_i, heads, skip_reshape=True, skip_output_reshape=True) - if out_i.dtype != out_dtype: - out_i = out_i.to(out_dtype) out.append(out_i.squeeze(0).permute(1, 0, 2)) out = torch.cat(out, dim=0) diff --git a/comfy/ldm/seedvr/color_fix.py b/comfy/ldm/seedvr/color_fix.py index 7ddfc03af..440b3d26c 100644 --- a/comfy/ldm/seedvr/color_fix.py +++ b/comfy/ldm/seedvr/color_fix.py @@ -2,8 +2,6 @@ import torch import torch.nn.functional as F from torch import Tensor -from comfy.ldm.seedvr.model import safe_pad_operation -from comfy.ldm.seedvr.vae import safe_interpolate_operation from comfy.ldm.seedvr.constants import ( CIELAB_DELTA, CIELAB_KAPPA, @@ -28,7 +26,7 @@ def wavelet_blur(image: Tensor, radius): kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device) kernel = kernel[None, None].repeat(num_channels, 1, 1, 1) - image = safe_pad_operation(image, (radius, radius, radius, radius), mode='replicate') + image = F.pad(image, (radius, radius, radius, radius), mode='replicate') output = F.conv2d(image, kernel, groups=num_channels, dilation=radius) return output @@ -49,8 +47,7 @@ def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor: if content_feat.shape != style_feat.shape: # Resize style to match content spatial dimensions if len(content_feat.shape) >= 3: - # safe_interpolate_operation handles FP16 conversion automatically - style_feat = safe_interpolate_operation( + style_feat = F.interpolate( style_feat, size=content_feat.shape[-2:], mode='bilinear', @@ -65,7 +62,7 @@ def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor: del style_high_freq # Free memory immediately if content_high_freq.shape != style_low_freq.shape: - style_low_freq = safe_interpolate_operation( + style_low_freq = F.interpolate( style_low_freq, size=content_high_freq.shape[-2:], mode='bilinear', @@ -227,7 +224,7 @@ def lab_color_transfer( content_feat = wavelet_reconstruction(content_feat, style_feat) if content_feat.shape != style_feat.shape: - style_feat = safe_interpolate_operation( + style_feat = F.interpolate( style_feat, size=content_feat.shape[-2:], mode='bilinear', @@ -308,7 +305,7 @@ def wavelet_color_transfer(content_feat: Tensor, style_feat: Tensor) -> Tensor: def adain_color_transfer(content_feat: Tensor, style_feat: Tensor, eps: float = 1e-5) -> Tensor: if content_feat.shape != style_feat.shape: - style_feat = safe_interpolate_operation( + style_feat = F.interpolate( style_feat, size=content_feat.shape[-2:], mode='bilinear', diff --git a/comfy/ldm/seedvr/model.py b/comfy/ldm/seedvr/model.py index e7d3deb35..ee50449a4 100644 --- a/comfy/ldm/seedvr/model.py +++ b/comfy/ldm/seedvr/model.py @@ -1,7 +1,5 @@ from dataclasses import dataclass from typing import Optional, Tuple, Union, List, Dict, Any, Callable -import einops -from einops import rearrange import torch.nn.functional as F from math import ceil, pi import torch @@ -23,52 +21,6 @@ from comfy.ldm.seedvr.constants import ( SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS, ) import comfy.model_management -import numbers - -def _torch_float8_types(): - return tuple( - getattr(torch, name) - for name in ( - "float8_e4m3fn", - "float8_e4m3fnuz", - "float8_e5m2", - "float8_e5m2fnuz", - "float8_e8m0fnu", - ) - if hasattr(torch, name) - ) - -class CustomRMSNorm(nn.Module): - - def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True, device=None, dtype=None): - super(CustomRMSNorm, self).__init__() - - if isinstance(normalized_shape, numbers.Integral): - normalized_shape = (normalized_shape,) - self.normalized_shape = torch.Size(normalized_shape) - self.eps = eps - self.elementwise_affine = elementwise_affine - - if self.elementwise_affine: - self.weight = nn.Parameter(torch.ones(*normalized_shape, device=device, dtype=dtype)) - else: - self.register_parameter('weight', None) - - def forward(self, input): - - dims = tuple(range(-len(self.normalized_shape), 0)) - - # Norm statistics in fp32 (fp16 variance underflows); activations return - # in the input dtype so downstream linears run at the model compute dtype. - normalized = input.float() - variance = normalized.pow(2).mean(dim=dims, keepdim=True) - rms = torch.sqrt(variance + self.eps) - - normalized = normalized / rms - - if self.elementwise_affine: - return (normalized * self.weight.to(torch.float32)).to(input.dtype) - return normalized.to(input.dtype) class Cache: def __init__(self, disable=False, prefix="", cache=None): @@ -81,12 +33,10 @@ class Cache: return fn() key = self.prefix + key - try: - result = self.cache[key] - except KeyError: + if key not in self.cache: result = fn() self.cache[key] = result - return result + return self.cache[key] def namespace(self, namespace: str): return Cache( @@ -144,15 +94,6 @@ class MMArg: vid: Any txt: Any -def safe_pad_operation(x, padding, mode='constant', value=0.0): - try: - return F.pad(x, padding, mode=mode, value=value) - except RuntimeError as e: - if "not implemented for" in str(e) and x.dtype in (torch.float16, torch.bfloat16): - return F.pad(x.float(), padding, mode=mode, value=value).to(x.dtype) - raise - - def get_args(key: str, args: List[Any]) -> List[Any]: return [getattr(v, key) if isinstance(v, MMArg) else v for v in args] @@ -235,8 +176,6 @@ class RotaryEmbedding(nn.Module): theta = 10000, max_freq = 10, learned_freq = False, - cache_if_possible = True, - cache_max_seq_len = 8192 ): super().__init__() @@ -247,12 +186,6 @@ class RotaryEmbedding(nn.Module): elif freqs_for == 'pixel': freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi - self.cache_if_possible = cache_if_possible - self.cache_max_seq_len = cache_max_seq_len - - self.register_buffer('cached_freqs', torch.zeros(cache_max_seq_len, dim), persistent = False) - self.cached_freqs_seq_len = 0 - self.freqs = nn.Parameter(freqs, requires_grad = learned_freq) self.learned_freq = learned_freq @@ -310,29 +243,10 @@ class RotaryEmbedding(nn.Module): seq_len: int | None = None, offset = 0 ): - should_cache = ( - self.cache_if_possible and - not self.learned_freq and - exists(seq_len) and - self.freqs_for != 'pixel' and - (offset + seq_len) <= self.cache_max_seq_len - ) - - if ( - should_cache and \ - exists(self.cached_freqs) and \ - (offset + seq_len) <= self.cached_freqs_seq_len - ): - return self.cached_freqs[offset:(offset + seq_len)].detach() - freqs = self.freqs freqs = torch.einsum('..., f -> ... f', t.type(freqs.dtype), freqs) - freqs = einops.repeat(freqs, '... n -> ... (n r)', r = 2) - - if should_cache and offset == 0: - self.cached_freqs[:seq_len] = freqs.detach() - self.cached_freqs_seq_len = seq_len + freqs = freqs.unsqueeze(-1).expand(*freqs.shape, 2).flatten(-2) return freqs @@ -346,7 +260,7 @@ class RotaryEmbeddingBase(nn.Module): ) freqs = self.rope.freqs del self.rope.freqs - self.rope.register_buffer("freqs", freqs.data) + self.rope.register_buffer("freqs", freqs.detach()) def get_axial_freqs(self, *dims): return self.rope.get_axial_freqs(*dims) @@ -371,12 +285,12 @@ class NaRotaryEmbedding3d(RotaryEmbedding3d): ]: freqs = cache("rope_freqs_3d", lambda: self.get_freqs(shape)) freqs = freqs.to(device=q.device) - q = rearrange(q, "L h d -> h L d") - k = rearrange(k, "L h d -> h L d") + q = q.transpose(0, 1) + k = k.transpose(0, 1) q = _apply_seedvr2_rotary_emb(freqs, q.float()).to(q.dtype) k = _apply_seedvr2_rotary_emb(freqs, k.float()).to(k.dtype) - q = rearrange(q, "h L d -> L h d") - k = rearrange(k, "h L d -> L h d") + q = q.transpose(0, 1) + k = k.transpose(0, 1) return q, k @torch._dynamo.disable @@ -407,11 +321,10 @@ class MMRotaryEmbeddingBase(RotaryEmbeddingBase): dim=dim // rope_dim, freqs_for="lang", theta=ROPE_THETA, - cache_if_possible=False, ) freqs = self.rope.freqs del self.rope.freqs - self.rope.register_buffer("freqs", freqs.data) + self.rope.register_buffer("freqs", freqs.detach()) self.mm = True def slice_at_dim(t, dim_slice: slice, *, dim): @@ -423,10 +336,10 @@ def slice_at_dim(t, dim_slice: slice, *, dim): # rotary embedding helper functions def rotate_half(x): - x = rearrange(x, '... (d r) -> ... d r', r = 2) + x = x.reshape(*x.shape[:-1], x.shape[-1] // 2, 2) x1, x2 = x.unbind(dim = -1) x = torch.stack((-x2, x1), dim = -1) - return rearrange(x, '... d r -> ... (d r)') + return x.flatten(-2) def exists(val): return val is not None @@ -465,7 +378,7 @@ def _to_flux_freqs_cis(freqs_interleaved: torch.Tensor) -> torch.Tensor: cos = torch.cos(angles) sin = torch.sin(angles) out = torch.stack([cos, -sin, sin, cos], dim=-1) - return rearrange(out, "... d (i j) -> ... d i j", i=2, j=2) + return out.reshape(*out.shape[:-1], 2, 2) def _apply_rope1_partial(t: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: @@ -516,19 +429,19 @@ class NaMMRotaryEmbedding3d(MMRotaryEmbeddingBase): vid_freqs = vid_freqs.to(target_device) if txt_freqs.device != target_device: txt_freqs = txt_freqs.to(target_device) - vid_q = rearrange(vid_q, "L h d -> h L d") - vid_k = rearrange(vid_k, "L h d -> h L d") + vid_q = vid_q.transpose(0, 1) + vid_k = vid_k.transpose(0, 1) vid_q = _apply_rope1_partial(vid_q, vid_freqs) vid_k = _apply_rope1_partial(vid_k, vid_freqs) - vid_q = rearrange(vid_q, "h L d -> L h d") - vid_k = rearrange(vid_k, "h L d -> L h d") + vid_q = vid_q.transpose(0, 1) + vid_k = vid_k.transpose(0, 1) - txt_q = rearrange(txt_q, "L h d -> h L d") - txt_k = rearrange(txt_k, "L h d -> h L d") + txt_q = txt_q.transpose(0, 1) + txt_k = txt_k.transpose(0, 1) txt_q = _apply_rope1_partial(txt_q, txt_freqs) txt_k = _apply_rope1_partial(txt_k, txt_freqs) - txt_q = rearrange(txt_q, "h L d -> L h d") - txt_k = rearrange(txt_k, "h L d -> L h d") + txt_q = txt_q.transpose(0, 1) + txt_k = txt_k.transpose(0, 1) return vid_q, vid_k, txt_q, txt_k @torch._dynamo.disable # Disable compilation: .tolist() is data-dependent and causes graph breaks @@ -684,7 +597,7 @@ def window( ): hid = unflatten(hid, hid_shape) hid = list(map(window_fn, hid)) - hid_windows = torch.tensor(list(map(len, hid)), device=hid_shape.device) + hid_windows = torch.as_tensor([len(x) for x in hid], device=hid_shape.device) hid, hid_shape = flatten(list(chain(*hid))) return hid, hid_shape, hid_windows @@ -747,8 +660,8 @@ class NaSwinAttention(NaMMAttention): ) vid_qkv_win = window_partition(vid_qkv) - vid_qkv_win = rearrange(vid_qkv_win, "l (o h d) -> l o h d", o=3, d=self.head_dim) - txt_qkv = rearrange(txt_qkv, "l (o h d) -> l o h d", o=3, d=self.head_dim) + vid_qkv_win = vid_qkv_win.reshape(vid_qkv_win.shape[0], 3, self.heads, self.head_dim) + txt_qkv = txt_qkv.reshape(txt_qkv.shape[0], 3, self.heads, self.head_dim) vid_q, vid_k, vid_v = vid_qkv_win.unbind(1) txt_q, txt_k, txt_v = txt_qkv.unbind(1) @@ -768,19 +681,19 @@ class NaSwinAttention(NaMMAttention): elif self.rope.mm: # repeat text q and k for window mmrope _, num_h, _ = txt_q.shape - txt_q_repeat = rearrange(txt_q, "l h d -> l (h d)") + txt_q_repeat = txt_q.flatten(1, 2) txt_q_repeat = unflatten(txt_q_repeat, txt_shape) txt_q_repeat = [[x] * n for x, n in zip(txt_q_repeat, window_count)] txt_q_repeat = list(chain(*txt_q_repeat)) txt_q_repeat, txt_shape_repeat = flatten(txt_q_repeat) - txt_q_repeat = rearrange(txt_q_repeat, "l (h d) -> l h d", h=num_h) + txt_q_repeat = txt_q_repeat.reshape(txt_q_repeat.shape[0], num_h, self.head_dim) - txt_k_repeat = rearrange(txt_k, "l h d -> l (h d)") + txt_k_repeat = txt_k.flatten(1, 2) txt_k_repeat = unflatten(txt_k_repeat, txt_shape) txt_k_repeat = [[x] * n for x, n in zip(txt_k_repeat, window_count)] txt_k_repeat = list(chain(*txt_k_repeat)) txt_k_repeat, _ = flatten(txt_k_repeat) - txt_k_repeat = rearrange(txt_k_repeat, "l (h d) -> l h d", h=num_h) + txt_k_repeat = txt_k_repeat.reshape(txt_k_repeat.shape[0], num_h, self.head_dim) vid_q, vid_k, txt_q, txt_k = self.rope( vid_q, vid_k, window_shape, txt_q_repeat, txt_k_repeat, txt_shape_repeat, cache_win @@ -799,16 +712,16 @@ class NaSwinAttention(NaMMAttention): v=concat_win(vid_v, txt_v), heads=self.heads, skip_reshape=True, skip_output_reshape=True, cu_seqlens_q=cache_win( - "vid_seqlens_q", lambda: safe_pad_operation(all_len_win.cumsum(0), (1, 0)).int() + "vid_seqlens_q", lambda: F.pad(all_len_win.cumsum(0), (1, 0)).int() ), cu_seqlens_k=cache_win( - "vid_seqlens_k", lambda: safe_pad_operation(all_len_win.cumsum(0), (1, 0)).int() + "vid_seqlens_k", lambda: F.pad(all_len_win.cumsum(0), (1, 0)).int() ), ) vid_out, txt_out = unconcat_win(out) - vid_out = rearrange(vid_out, "l h d -> l (h d)") - txt_out = rearrange(txt_out, "l h d -> l (h d)") + vid_out = vid_out.flatten(1, 2) + txt_out = txt_out.flatten(1, 2) vid_out = window_reverse(vid_out) vid_out, txt_out = self.proj_out(vid_out, txt_out) @@ -1005,7 +918,9 @@ class PatchOut(nn.Module): ) -> torch.Tensor: t, h, w = self.patch_size vid = self.proj(vid) - vid = rearrange(vid, "b T H W (t h w c) -> b c (T t) (H h) (W w)", t=t, h=h, w=w) + b, T, H, W, channels = vid.shape + c = channels // (t * h * w) + vid = vid.view(b, T, H, W, t, h, w, c).permute(0, 7, 1, 4, 2, 5, 3, 6).reshape(b, c, T * t, H * h, W * w) if t > 1: vid = vid[:, :, (t - 1) :] return vid @@ -1015,7 +930,7 @@ class NaPatchOut(PatchOut): self, vid: torch.FloatTensor, # l c vid_shape: torch.LongTensor, - cache: Cache = Cache(disable=True), # for test + cache: Cache = Cache(disable=True), vid_shape_before_patchify = None ) -> Tuple[ torch.FloatTensor, @@ -1028,7 +943,9 @@ class NaPatchOut(PatchOut): if not (t == h == w == 1): vid = unflatten(vid, vid_shape) for i in range(len(vid)): - vid[i] = rearrange(vid[i], "T H W (t h w c) -> (T t) (H h) (W w) c", t=t, h=h, w=w) + T, H, W, channels = vid[i].shape + c = channels // (t * h * w) + vid[i] = vid[i].view(T, H, W, t, h, w, c).permute(0, 3, 1, 4, 2, 5, 6).reshape(T * t, H * h, W * w, c) if t > 1 and vid_shape_before_patchify[i, 0] % t != 0: vid[i] = vid[i][(t - vid_shape_before_patchify[i, 0] % t) :] vid, vid_shape = flatten(vid) @@ -1056,7 +973,8 @@ class PatchIn(nn.Module): if t > 1: assert vid.size(2) % t == 1 vid = torch.cat([vid[:, :, :1]] * (t - 1) + [vid], dim=2) - vid = rearrange(vid, "b c (T t) (H h) (W w) -> b T H W (t h w c)", t=t, h=h, w=w) + b, c, Tt, Hh, Ww = vid.shape + vid = vid.view(b, c, Tt // t, t, Hh // h, h, Ww // w, w).permute(0, 2, 4, 6, 3, 5, 7, 1).reshape(b, Tt // t, Hh // h, Ww // w, t * h * w * c) vid = self.proj(vid) return vid @@ -1065,7 +983,7 @@ class NaPatchIn(PatchIn): self, vid: torch.Tensor, # l c vid_shape: torch.LongTensor, - cache: Cache = Cache(disable=True), # for test + cache: Cache = Cache(disable=True), ) -> torch.Tensor: cache = cache.namespace("patch") vid_shape_before_patchify = cache("vid_shape_before_patchify", lambda: vid_shape) @@ -1075,7 +993,8 @@ class NaPatchIn(PatchIn): for i in range(len(vid)): if t > 1 and vid_shape_before_patchify[i, 0] % t != 0: vid[i] = torch.cat([vid[i][:1]] * (t - vid[i].size(0) % t) + [vid[i]], dim=0) - vid[i] = rearrange(vid[i], "(T t) (H h) (W w) c -> T H W (t h w c)", t=t, h=h, w=w) + Tt, Hh, Ww, c = vid[i].shape + vid[i] = vid[i].view(Tt // t, t, Hh // h, h, Ww // w, w, c).permute(0, 2, 4, 1, 3, 5, 6).reshape(Tt // t, Hh // h, Ww // w, t * h * w * c) vid, vid_shape = flatten(vid) vid = self.proj(vid) @@ -1102,17 +1021,14 @@ class AdaSingle(nn.Module): self.emb_dim = emb_dim self.layers = layers - param_kwargs = {"device": device} - fp8_types = _torch_float8_types() - if dtype is not None and dtype not in fp8_types: - param_kwargs["dtype"] = dtype + param_kwargs = {"device": device, "dtype": dtype} for l in layers: if "in" in modes: - self.register_parameter(f"{l}_shift", nn.Parameter(torch.zeros(dim, **param_kwargs))) - self.register_parameter(f"{l}_scale", nn.Parameter(torch.ones(dim, **param_kwargs))) + self.register_parameter(f"{l}_shift", nn.Parameter(torch.empty(dim, **param_kwargs))) + self.register_parameter(f"{l}_scale", nn.Parameter(torch.empty(dim, **param_kwargs))) if "out" in modes: - self.register_parameter(f"{l}_gate", nn.Parameter(torch.zeros(dim, **param_kwargs))) + self.register_parameter(f"{l}_gate", nn.Parameter(torch.empty(dim, **param_kwargs))) def forward( self, @@ -1125,7 +1041,7 @@ class AdaSingle(nn.Module): hid_len: Optional[torch.LongTensor] = None, # b ) -> torch.FloatTensor: idx = self.layers.index(layer) - emb = rearrange(emb, "b (d l g) -> b d l g", l=len(self.layers), g=3)[..., idx, :] + emb = emb.reshape(emb.shape[0], -1, len(self.layers), 3)[:, :, idx, :] emb = expand_dims(emb, 1, hid.ndim + 1) if hid_len is not None: @@ -1145,17 +1061,6 @@ class AdaSingle(nn.Module): getattr(self, f"{layer}_gate", None), ) - fp8_types = _torch_float8_types() - if fp8_types: - target_dtype = hid.dtype - - if shiftB is not None and shiftB.dtype in fp8_types: - shiftB = shiftB.to(target_dtype) - if scaleB is not None and scaleB.dtype in fp8_types: - scaleB = scaleB.to(target_dtype) - if gateB is not None and gateB.dtype in fp8_types: - gateB = gateB.to(target_dtype) - if mode == "in": return hid.mul_(scaleA + scaleB).add_(shiftA + shiftB) if mode == "out": @@ -1213,7 +1118,7 @@ def flatten( torch.LongTensor, # (b n) ]: assert len(hid) > 0 - shape = torch.stack([torch.tensor(x.shape[:-1], device=hid[0].device) for x in hid]) + shape = torch.as_tensor([x.shape[:-1] for x in hid], device=hid[0].device) hid = torch.cat([x.flatten(0, -2) for x in hid]) return hid, shape @@ -1227,19 +1132,6 @@ def unflatten( hid = [x.unflatten(0, s.tolist()) for x, s in zip(hid, hid_shape)] return hid -def repeat( - hid: torch.FloatTensor, # (L c) - hid_shape: torch.LongTensor, # (b n) - pattern: str, - **kwargs: Dict[str, torch.LongTensor], # (b) -) -> Tuple[ - torch.FloatTensor, - torch.LongTensor, -]: - hid = unflatten(hid, hid_shape) - kwargs = [{k: v[i].item() for k, v in kwargs.items()} for i in range(len(hid))] - return flatten([einops.repeat(h, pattern, **a) for h, a in zip(hid, kwargs)]) - class NaDiT(nn.Module): def __init__( @@ -1275,23 +1167,11 @@ class NaDiT(nn.Module): emb_dim = vid_dim * 6 window = num_layers * [(4,3,3)] ada = AdaSingle - norm = CustomRMSNorm - qk_norm = CustomRMSNorm + norm = operations.RMSNorm + qk_norm = operations.RMSNorm super().__init__() - # ``torch.empty`` returns uninitialized memory, not zeros. The - # SeedVR2Conditioning fail-loud guard at - # ``comfy_extras/nodes_seedvr.py`` distinguishes "buffer was loaded" - # from "buffer was never populated by the file" by checking - # ``positive_conditioning.abs().sum() == 0``. That sentinel is only - # reliable if the post-construction buffer state is deterministically - # zero, so explicitly zero-fill here rather than relying on the - # allocator's zero-on-alloc behavior (allocator-dependent and not - # contractual). When ``load_state_dict`` populates these buffers - # from a properly-baked SeedVR2 .safetensors, the in-place copy - # overwrites the zeros with the universal SeedVR2 conditioning - # tensors (shape (58, 5120) and (64, 5120) bf16). - self.register_buffer("positive_conditioning", torch.zeros((58, 5120), device=device, dtype=dtype)) - self.register_buffer("negative_conditioning", torch.zeros((64, 5120), device=device, dtype=dtype)) + self.register_buffer("positive_conditioning", torch.empty((58, 5120), device=device, dtype=dtype)) + self.register_buffer("negative_conditioning", torch.empty((64, 5120), device=device, dtype=dtype)) self.vid_in = NaPatchIn( in_channels=vid_in_channels, patch_size=patch_size, @@ -1354,7 +1234,7 @@ class NaDiT(nn.Module): self.vid_out_norm = None if vid_out_norm is not None: - self.vid_out_norm = CustomRMSNorm( + self.vid_out_norm = operations.RMSNorm( normalized_shape=vid_dim, eps=norm_eps, elementwise_affine=True, @@ -1369,7 +1249,7 @@ class NaDiT(nn.Module): ) def _resolve_text_conditioning(self, context, cond_or_uncond=None): - if context is None or getattr(context, "numel", lambda: None)() == 0: + if context is None or context.numel() == 0: context = self.positive_conditioning return flatten([context]) if NaDiT._seedvr2_is_single_conditioning_branch(cond_or_uncond): @@ -1407,7 +1287,7 @@ class NaDiT(nn.Module): x, timestep, context, # l c - disable_cache: bool = False, # for test # TODO ? // gives an error when set to True + disable_cache: bool = False, **kwargs ): transformer_options = kwargs.get("transformer_options", {}) @@ -1483,5 +1363,5 @@ class NaDiT(nn.Module): vid = unflatten(vid, vid_shape) out = torch.stack(vid) out = out.movedim(-1, 1) - out = rearrange(out, "b c t h w -> b (c t) h w") + out = out.reshape(out.shape[0], out.shape[1] * out.shape[2], out.shape[3], out.shape[4]) return self._swap_pos_neg_halves(out, transformer_options.get("cond_or_uncond")) diff --git a/comfy/ldm/seedvr/vae.py b/comfy/ldm/seedvr/vae.py index 501896516..5daab022a 100644 --- a/comfy/ldm/seedvr/vae.py +++ b/comfy/ldm/seedvr/vae.py @@ -1,15 +1,11 @@ -from contextlib import nullcontext from typing import Literal, Optional, Tuple -import gc import torch import torch.nn as nn import torch.nn.functional as F -from einops import rearrange from torch import Tensor from contextlib import contextmanager from comfy.utils import ProgressBar -from comfy.ldm.seedvr.model import safe_pad_operation from comfy.ldm.seedvr.constants import ( BYTEDANCE_BLOCK_OUT_CHANNELS, BYTEDANCE_GN_CHUNKS_FP16, @@ -58,13 +54,6 @@ def _seedvr2_clamped_spatial_overlap(overlap, tile_size): return min(overlap, tile_size - 1) -def _seedvr2_clear_temporal_memory(model): - for module in model.modules(): - if hasattr(module, "memory"): - module.memory = None - - -@torch.inference_mode() def tiled_vae( x, vae_model, @@ -75,10 +64,6 @@ def tiled_vae( encode=True, **kwargs, ): - gc.collect() - comfy.model_management.soft_empty_cache() - - x = x.to(next(vae_model.parameters()).dtype) if x.ndim != 5: x = x.unsqueeze(2) @@ -121,7 +106,6 @@ def tiled_vae( count = None def run_temporal_chunks(spatial_tile, model=vae_model, device=storage_device): device = torch.device(device) - _seedvr2_clear_temporal_memory(model) t_chunk = spatial_tile.to(device=device, dtype=next(model.parameters()).dtype, non_blocking=True).contiguous() old_device = getattr(model, "device", None) model.device = device @@ -133,7 +117,7 @@ def tiled_vae( setattr(model, slicing_attr, slicing_min_size) try: if encode: - out = model.encode(t_chunk)[0] + out = model.encode(t_chunk) else: out = model.decode_(t_chunk) finally: @@ -141,8 +125,6 @@ def tiled_vae( setattr(model, slicing_attr, old_slicing_min_size) if old_device is not None: model.device = old_device - if isinstance(out, (tuple, list)): - out = out[0] if out.ndim == 4: out = out.unsqueeze(2) return out.to(storage_device) @@ -169,8 +151,6 @@ def tiled_vae( bar = ProgressBar(total_tiles) single_spatial_tile = h <= ti_h and w <= ti_w - _seedvr2_clear_temporal_memory(vae_model) - def run_tile(tile_index, tile_range): y_idx, y_end, x_idx, x_end = tile_range tile_x = x[:, :, :, y_idx:y_end, x_idx:x_end] @@ -186,7 +166,6 @@ def tiled_vae( if single_spatial_tile: result = tile_out[:, :, :target_d, :target_h, :target_w] - _seedvr2_clear_temporal_memory(vae_model) if result.device != x.device: result = result.to(x.device).to(x.dtype) if x.shape[2] == 1 and sf_t == 1: @@ -241,7 +220,6 @@ def tiled_vae( bar.update(1) result.div_(count.clamp(min=1e-6)) - _seedvr2_clear_temporal_memory(vae_model) if result.device != x.device: result = result.to(x.device).to(x.dtype) @@ -336,7 +314,6 @@ class Attention(nn.Module): eps: float = 1e-5, rescale_output_factor: float = 1.0, residual_connection: bool = False, - _from_deprecated_attn_block: bool = False, out_dim: int = None, pre_only=False, ): @@ -356,10 +333,6 @@ class Attention(nn.Module): self.out_dim = out_dim if out_dim is not None else query_dim self.pre_only = pre_only - # we make use of this private variable to know whether this class is loaded - # with an deprecated state dict so that we can convert it on the fly - self._from_deprecated_attn_block = _from_deprecated_attn_block - self.scale_qk = scale_qk self.scale = dim_head**-0.5 if self.scale_qk else 1.0 @@ -480,21 +453,21 @@ def causal_norm_wrapper(norm_layer: nn.Module, x: torch.Tensor) -> torch.Tensor: input_dtype = x.dtype if isinstance(norm_layer, (ops.LayerNorm, ops.RMSNorm)): if x.ndim == 4: - x = rearrange(x, "b c h w -> b h w c") + x = x.permute(0, 2, 3, 1) x = norm_layer(x) - x = rearrange(x, "b h w c -> b c h w") + x = x.permute(0, 3, 1, 2) return x.to(input_dtype) if x.ndim == 5: - x = rearrange(x, "b c t h w -> b t h w c") + x = x.permute(0, 2, 3, 4, 1) x = norm_layer(x) - x = rearrange(x, "b t h w c -> b c t h w") + x = x.permute(0, 4, 1, 2, 3) return x.to(input_dtype) if isinstance(norm_layer, (ops.GroupNorm, nn.BatchNorm2d, nn.SyncBatchNorm)): if x.ndim <= 4: return norm_layer(x).to(input_dtype) if x.ndim == 5: - t = x.size(2) - x = rearrange(x, "b c t h w -> (b t) c h w") + b, c, t, h, w = x.shape + x = x.transpose(1, 2).reshape(b * t, c, h, w) memory_occupy = x.numel() * x.element_size() / 1024**3 if isinstance(norm_layer, ops.GroupNorm) and memory_occupy > get_norm_limit(): num_chunks = min(BYTEDANCE_GN_CHUNKS_FP16 if x.element_size() == 2 else BYTEDANCE_GN_CHUNKS_FP32, norm_layer.num_groups) @@ -504,54 +477,16 @@ def causal_norm_wrapper(norm_layer: nn.Module, x: torch.Tensor) -> torch.Tensor: x = list(x.chunk(num_chunks, dim=1)) weights = norm_layer.weight.chunk(num_chunks, dim=0) biases = norm_layer.bias.chunk(num_chunks, dim=0) - for i, (w, b) in enumerate(zip(weights, biases)): - x[i] = F.group_norm(x[i], num_groups_per_chunk, w, b, norm_layer.eps) + for i, (w, bias) in enumerate(zip(weights, biases)): + x[i] = F.group_norm(x[i], num_groups_per_chunk, w, bias, norm_layer.eps) x[i] = x[i].to(input_dtype) x = torch.cat(x, dim=1) else: x = norm_layer(x) - x = rearrange(x, "(b t) c h w -> b c t h w", t=t) + x = x.reshape((b, t, x.size(1), x.size(2), x.size(3))).transpose(1, 2) return x.to(input_dtype) raise NotImplementedError -def safe_interpolate_operation(x, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None): - problematic_modes = ['bilinear', 'bicubic', 'trilinear'] - - if mode in problematic_modes: - try: - return F.interpolate( - x, - size=size, - scale_factor=scale_factor, - mode=mode, - align_corners=align_corners, - recompute_scale_factor=recompute_scale_factor - ) - except RuntimeError as e: - if ("not implemented for 'Half'" in str(e) or - "compute_indices_weights" in str(e)): - original_dtype = x.dtype - return F.interpolate( - x.float(), - size=size, - scale_factor=scale_factor, - mode=mode, - align_corners=align_corners, - recompute_scale_factor=recompute_scale_factor - ).to(original_dtype) - else: - raise e - else: - # Pour 'nearest' et autres modes compatibles, pas de fix nécessaire - return F.interpolate( - x, - size=size, - scale_factor=scale_factor, - mode=mode, - align_corners=align_corners, - recompute_scale_factor=recompute_scale_factor - ) - _receptive_field_t = Literal["half", "full"] def extend_head(tensor, times: int = 2, memory = None): @@ -585,7 +520,6 @@ class InflatedCausalConv3d(ops.Conv3d): **kwargs, ): self.inflation_mode = inflation_mode - self.memory = None super().__init__(*args, **kwargs) self.temporal_padding = self.padding[0] self.padding = (0, *self.padding[1:]) @@ -620,18 +554,19 @@ class InflatedCausalConv3d(ops.Conv3d): return super().forward(x) # Compute tensor shape after concat & padding. - shape = torch.tensor(x.size()) + shape = list(x.size()) if prev_cache is not None: shape[split_dim - 1] += prev_cache.size(split_dim - 1) - shape[-3:] += torch.tensor(padding).view(3, 2).sum(-1).flip(0) - memory_occupy = shape.prod() * x.element_size() / 1024**3 # GiB + for i, pad_sum in enumerate((padding[4] + padding[5], padding[2] + padding[3], padding[0] + padding[1])): + shape[-3 + i] += pad_sum + memory_occupy = math.prod(shape) * x.element_size() / 1024**3 # GiB if memory_occupy < self.memory_limit or split_dim == x.ndim: x_concat = x if prev_cache is not None: x_concat = torch.cat([prev_cache, x], dim=split_dim - 1) def pad_and_forward(): - padded = safe_pad_operation(x_concat, padding, mode='constant', value=0.0) + padded = F.pad(x_concat, padding, mode='constant', value=0.0) if not padded.is_contiguous(): padded = padded.contiguous() with ignore_padding(self): @@ -689,46 +624,57 @@ class InflatedCausalConv3d(ops.Conv3d): def forward( self, input, - memory_state: MemoryState = MemoryState.UNSET + memory_state: MemoryState = MemoryState.UNSET, + memory_cache = None, ) -> Tensor: assert memory_state != MemoryState.UNSET + if memory_cache is None: + memory_cache = {} if memory_state != MemoryState.ACTIVE: - self.memory = None + memory_cache.pop(self, None) if ( math.isinf(self.memory_limit) and torch.is_tensor(input) ): - return self.basic_forward(input, memory_state) - return self.slicing_forward(input, memory_state) + return self.basic_forward(input, memory_state, memory_cache) + return self.slicing_forward(input, memory_state, memory_cache) - def basic_forward(self, input: Tensor, memory_state: MemoryState = MemoryState.UNSET): + def basic_forward(self, input: Tensor, memory_state: MemoryState = MemoryState.UNSET, memory_cache = None): mem_size = self.stride[0] - self.kernel_size[0] - if (self.memory is not None) and (memory_state == MemoryState.ACTIVE): - input = extend_head(input, memory=self.memory, times=-1) + memory = memory_cache.get(self) if memory_cache is not None else None + if (memory is not None) and (memory_state == MemoryState.ACTIVE): + input = extend_head(input, memory=memory, times=-1) else: input = extend_head(input, times=self.temporal_padding * 2) - memory = ( + next_memory = ( input[:, :, mem_size:].detach() if (mem_size != 0 and memory_state != MemoryState.DISABLED) else None ) - if memory_state != MemoryState.DISABLED: - self.memory = memory + if memory_cache is not None and memory_state != MemoryState.DISABLED: + if next_memory is None: + memory_cache.pop(self, None) + else: + memory_cache[self] = next_memory return super().forward(input) def slicing_forward( self, input, memory_state: MemoryState = MemoryState.UNSET, + memory_cache = None, ) -> Tensor: + if memory_cache is None: + memory_cache = {} squeeze_out = False if torch.is_tensor(input): input = [input] squeeze_out = True cache_size = self.kernel_size[0] - self.stride[0] + memory = memory_cache.get(self) if memory_cache is not None else None cache = cache_send_recv( - input, cache_size=cache_size, memory=self.memory, times=self.temporal_padding * 2 + input, cache_size=cache_size, memory=memory, times=self.temporal_padding * 2 ) # Single GPU inference - simplified memory management @@ -740,7 +686,7 @@ class InflatedCausalConv3d(ops.Conv3d): input[0] = torch.cat([cache, input[0]], dim=2) cache = None if cache_size <= input[-1].size(2): - self.memory = input[-1][:, :, -cache_size:].detach().contiguous() + memory_cache[self] = input[-1][:, :, -cache_size:].detach().contiguous() padding = tuple(x for x in reversed(self.padding) for _ in range(2)) for i in range(len(input)): @@ -802,17 +748,10 @@ class Upsample3D(nn.Module): self.temporal_ratio = 2 if temporal_up else 1 self.spatial_ratio = 2 if spatial_up else 1 - # [Override] MAGViT v2 learnable upsample upscale_ratio = (self.spatial_ratio**2) * self.temporal_ratio self.upscale_conv = ops.Conv3d( self.channels, self.channels * upscale_ratio, kernel_size=1, padding=0 ) - identity = ( - torch.eye(self.channels) - .repeat(upscale_ratio, 1) - .reshape_as(self.upscale_conv.weight) - ) - self.upscale_conv.weight.data.copy_(identity) self.conv = conv @@ -820,23 +759,27 @@ class Upsample3D(nn.Module): self, hidden_states: torch.FloatTensor, memory_state=None, + memory_cache=None, **kwargs, ) -> torch.FloatTensor: assert hidden_states.shape[1] == self.channels hidden_states = self.upscale_conv(hidden_states) - hidden_states = rearrange( - hidden_states, - "b (x y z c) f h w -> b c (f z) (h x) (w y)", - x=self.spatial_ratio, - y=self.spatial_ratio, - z=self.temporal_ratio, + b, channels, f, h, w = hidden_states.shape + c = channels // (self.spatial_ratio * self.spatial_ratio * self.temporal_ratio) + hidden_states = hidden_states.view(b, self.spatial_ratio, self.spatial_ratio, self.temporal_ratio, c, f, h, w) + hidden_states = hidden_states.permute(0, 4, 5, 3, 6, 1, 7, 2).reshape( + b, + c, + f * self.temporal_ratio, + h * self.spatial_ratio, + w * self.spatial_ratio, ) if self.temporal_up and memory_state != MemoryState.ACTIVE: hidden_states = remove_head(hidden_states) - hidden_states = self.conv(hidden_states, memory_state=memory_state) + hidden_states = self.conv(hidden_states, memory_state=memory_state, memory_cache=memory_cache) return hidden_states @@ -879,6 +822,7 @@ class Downsample3D(nn.Module): self, hidden_states: torch.FloatTensor, memory_state = None, + memory_cache = None, **kwargs, ) -> torch.FloatTensor: @@ -890,11 +834,11 @@ class Downsample3D(nn.Module): if self.spatial_down: pad = (0, 1, 0, 1) - hidden_states = safe_pad_operation(hidden_states, pad, mode="constant", value=0) + hidden_states = F.pad(hidden_states, pad, mode="constant", value=0) assert hidden_states.shape[1] == self.channels - hidden_states = self.conv(hidden_states, memory_state=memory_state) + hidden_states = self.conv(hidden_states, memory_state=memory_state, memory_cache=memory_cache) return hidden_states @@ -962,7 +906,7 @@ class ResnetBlock3D(nn.Module): ) def forward( - self, input_tensor, temb, memory_state = None, **kwargs + self, input_tensor, temb, memory_state = None, memory_cache = None, **kwargs ): hidden_states = input_tensor @@ -970,7 +914,7 @@ class ResnetBlock3D(nn.Module): hidden_states = self.nonlinearity(hidden_states) - hidden_states = self.conv1(hidden_states, memory_state=memory_state) + hidden_states = self.conv1(hidden_states, memory_state=memory_state, memory_cache=memory_cache) if self.time_emb_proj is not None: if not self.skip_time_act: @@ -985,10 +929,10 @@ class ResnetBlock3D(nn.Module): hidden_states = self.nonlinearity(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.conv2(hidden_states, memory_state=memory_state) + hidden_states = self.conv2(hidden_states, memory_state=memory_state, memory_cache=memory_cache) if self.conv_shortcut is not None: - input_tensor = self.conv_shortcut(input_tensor, memory_state=memory_state) + input_tensor = self.conv_shortcut(input_tensor, memory_state=memory_state, memory_cache=memory_cache) output_tensor = (input_tensor + hidden_states) / self.output_scale_factor @@ -1055,15 +999,16 @@ class DownEncoderBlock3D(nn.Module): self, hidden_states: torch.FloatTensor, memory_state = None, + memory_cache = None, **kwargs, ) -> torch.FloatTensor: for resnet, temporal in zip(self.resnets, self.temporal_modules): - hidden_states = resnet(hidden_states, temb=None, memory_state=memory_state) + hidden_states = resnet(hidden_states, temb=None, memory_state=memory_state, memory_cache=memory_cache) hidden_states = temporal(hidden_states) if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states, memory_state=memory_state) + hidden_states = downsampler(hidden_states, memory_state=memory_state, memory_cache=memory_cache) return hidden_states @@ -1132,15 +1077,16 @@ class UpDecoderBlock3D(nn.Module): self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, - memory_state=None + memory_state=None, + memory_cache=None, ) -> torch.FloatTensor: for resnet, temporal in zip(self.resnets, self.temporal_modules): - hidden_states = resnet(hidden_states, temb=None, memory_state=memory_state) + hidden_states = resnet(hidden_states, temb=None, memory_state=memory_state, memory_cache=memory_cache) hidden_states = temporal(hidden_states) if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, memory_state=memory_state) + hidden_states = upsampler(hidden_states, memory_state=memory_state, memory_cache=memory_cache) return hidden_states @@ -1203,7 +1149,6 @@ class UNetMidBlock3D(nn.Module): residual_connection=True, bias=True, upcast_softmax=True, - _from_deprecated_attn_block=True, ) ) else: @@ -1226,17 +1171,16 @@ class UNetMidBlock3D(nn.Module): self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) - def forward(self, hidden_states, temb=None, memory_state=None): + def forward(self, hidden_states, temb=None, memory_state=None, memory_cache=None): video_length, frame_height, frame_width = hidden_states.size()[-3:] - hidden_states = self.resnets[0](hidden_states, temb, memory_state=memory_state) + hidden_states = self.resnets[0](hidden_states, temb, memory_state=memory_state, memory_cache=memory_cache) for attn, resnet in zip(self.attentions, self.resnets[1:]): if attn is not None: - hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + b, c, f, h, w = hidden_states.shape + hidden_states = hidden_states.transpose(1, 2).reshape(b * f, c, h, w) hidden_states = attn(hidden_states, temb=temb) - hidden_states = rearrange( - hidden_states, "(b f) c h w -> b c f h w", f=video_length - ) - hidden_states = resnet(hidden_states, temb, memory_state=memory_state) + hidden_states = hidden_states.reshape(b, video_length, c, h, w).transpose(1, 2) + hidden_states = resnet(hidden_states, temb, memory_state=memory_state, memory_cache=memory_cache) return hidden_states @@ -1327,22 +1271,23 @@ class Encoder3D(nn.Module): def forward( self, sample: torch.FloatTensor, - memory_state = None + memory_state = None, + memory_cache = None, ) -> torch.FloatTensor: r"""The forward method of the `Encoder` class.""" sample = sample.to(next(self.parameters()).device) - sample = self.conv_in(sample, memory_state = memory_state) + sample = self.conv_in(sample, memory_state=memory_state, memory_cache=memory_cache) # down for down_block in self.down_blocks: - sample = down_block(sample, memory_state=memory_state) + sample = down_block(sample, memory_state=memory_state, memory_cache=memory_cache) # middle - sample = self.mid_block(sample, memory_state=memory_state) + sample = self.mid_block(sample, memory_state=memory_state, memory_cache=memory_cache) # post-process sample = causal_norm_wrapper(self.conv_norm_out, sample) sample = self.conv_act(sample) - sample = self.conv_out(sample, memory_state = memory_state) + sample = self.conv_out(sample, memory_state=memory_state, memory_cache=memory_cache) return sample @@ -1436,24 +1381,25 @@ class Decoder3D(nn.Module): sample: torch.FloatTensor, latent_embeds: Optional[torch.FloatTensor] = None, memory_state = None, + memory_cache = None, ) -> torch.FloatTensor: sample = sample.to(next(self.parameters()).device) - sample = self.conv_in(sample, memory_state=memory_state) + sample = self.conv_in(sample, memory_state=memory_state, memory_cache=memory_cache) upscale_dtype = next(iter(self.up_blocks.parameters())).dtype # middle - sample = self.mid_block(sample, latent_embeds, memory_state=memory_state) + sample = self.mid_block(sample, latent_embeds, memory_state=memory_state, memory_cache=memory_cache) sample = sample.to(upscale_dtype) # up for up_block in self.up_blocks: - sample = up_block(sample, latent_embeds, memory_state=memory_state) + sample = up_block(sample, latent_embeds, memory_state=memory_state, memory_cache=memory_cache) # post-process sample = causal_norm_wrapper(self.conv_norm_out, sample) sample = self.conv_act(sample) - sample = self.conv_out(sample, memory_state=memory_state) + sample = self.conv_out(sample, memory_state=memory_state, memory_cache=memory_cache) return sample @@ -1529,22 +1475,23 @@ class VideoAutoencoderKL(nn.Module): return decoded def _encode( - self, x, memory_state = MemoryState.DISABLED + self, x, memory_state = MemoryState.DISABLED, memory_cache = None ) -> torch.Tensor: _x = x.to(self.device) - h = self.encoder(_x, memory_state=memory_state) + h = self.encoder(_x, memory_state=memory_state, memory_cache=memory_cache) return h.to(x.device) def _decode( - self, z, memory_state = MemoryState.DISABLED + self, z, memory_state = MemoryState.DISABLED, memory_cache = None ) -> torch.Tensor: _z = z.to(self.device) - output = self.decoder(_z, memory_state=memory_state) + output = self.decoder(_z, memory_state=memory_state, memory_cache=memory_cache) return output.to(z.device) def slicing_encode(self, x: torch.Tensor) -> torch.Tensor: sp_size =1 if self.use_slicing and (x.shape[2] - 1) > self.slicing_sample_min_size * sp_size: + memory_cache = {} split_size = max( self.slicing_sample_min_size * sp_size, getattr(self, "temporal_downsample_factor", 1), @@ -1558,17 +1505,14 @@ class VideoAutoencoderKL(nn.Module): self._encode( torch.cat((x[:, :, :1], x_slices[0]), dim=2), memory_state=MemoryState.INITIALIZING, + memory_cache=memory_cache, ) ] for x_idx in range(1, len(x_slices)): encoded_slices.append( - self._encode(x_slices[x_idx], memory_state=MemoryState.ACTIVE) + self._encode(x_slices[x_idx], memory_state=MemoryState.ACTIVE, memory_cache=memory_cache) ) out = torch.cat(encoded_slices, dim=2) - modules_with_memory = [m for m in self.modules() - if isinstance(m, InflatedCausalConv3d) and m.memory is not None] - for m in modules_with_memory: - m.memory = None return out else: return self._encode(x) @@ -1576,22 +1520,20 @@ class VideoAutoencoderKL(nn.Module): def slicing_decode(self, z: torch.Tensor) -> torch.Tensor: sp_size = 1 if self.use_slicing and (z.shape[2] - 1) > self.slicing_latent_min_size * sp_size: + memory_cache = {} z_slices = z[:, :, 1:].split(split_size=self.slicing_latent_min_size * sp_size, dim=2) decoded_slices = [ self._decode( torch.cat((z[:, :, :1], z_slices[0]), dim=2), - memory_state=MemoryState.INITIALIZING + memory_state=MemoryState.INITIALIZING, + memory_cache=memory_cache, ) ] for z_idx in range(1, len(z_slices)): decoded_slices.append( - self._decode(z_slices[z_idx], memory_state=MemoryState.ACTIVE) + self._decode(z_slices[z_idx], memory_state=MemoryState.ACTIVE, memory_cache=memory_cache) ) out = torch.cat(decoded_slices, dim=2) - modules_with_memory = [m for m in self.modules() - if isinstance(m, InflatedCausalConv3d) and m.memory is not None] - for m in modules_with_memory: - m.memory = None return out else: return self._decode(z) @@ -1612,32 +1554,25 @@ class VideoAutoencoderKL(nn.Module): return _unwrap(self.decode_(latent)) class VideoAutoencoderKLWrapper(VideoAutoencoderKL): - # Signals to comfy.sd.VAE that this model performs its own VAE tiling, so the - # generic tiled-decode/encode dispatch defers to decode_tiled/encode_tiled below. - comfy_handles_tiling = True - def __init__( self, *args, spatial_downsample_factor = 8, temporal_downsample_factor = 4, - freeze_encoder = True, **kwargs, ): self.spatial_downsample_factor = spatial_downsample_factor self.temporal_downsample_factor = temporal_downsample_factor - self.freeze_encoder = freeze_encoder self.enable_tiling = False super().__init__(*args, **kwargs) self.set_memory_limit(BYTEDANCE_VAE_CONV_MEM_GIB, BYTEDANCE_VAE_NORM_MEM_GIB) def forward(self, x: torch.FloatTensor): - with torch.no_grad() if self.freeze_encoder else nullcontext(): - z, p = self.encode(x) + z, p = self._encode_with_raw_latent(x) x = self.decode(z) return x, z, p - def encode(self, x, orig_dims=None): + def _encode_with_raw_latent(self, x): if x.ndim == 4: x = x.unsqueeze(2) x = x.to(dtype=next(self.parameters()).dtype) @@ -1646,6 +1581,10 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL): z = p.squeeze(2) return z, p + def encode(self, x, orig_dims=None): + z, _ = self._encode_with_raw_latent(x) + return z + def decode(self, z, seedvr2_tiling=None): seedvr2_tiling = {} if seedvr2_tiling is None else seedvr2_tiling if not isinstance(seedvr2_tiling, dict): diff --git a/comfy/model_detection.py b/comfy/model_detection.py index bcca99251..bf44b832c 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -1151,9 +1151,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): return unet_config -def model_config_from_unet_config(unet_config, state_dict=None): +def model_config_from_unet_config(unet_config, state_dict=None, unet_key_prefix=""): for model_config in comfy.supported_models.models: - if model_config.matches(unet_config, state_dict): + if model_config.matches(unet_config, state_dict, unet_key_prefix=unet_key_prefix): return model_config(unet_config) logging.error("no match {}".format(unet_config)) @@ -1163,7 +1163,7 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal unet_config = detect_unet_config(state_dict, unet_key_prefix, metadata=metadata) if unet_config is None: return None - model_config = model_config_from_unet_config(unet_config, state_dict) + model_config = model_config_from_unet_config(unet_config, state_dict, unet_key_prefix) if model_config is None and use_base_if_no_match: model_config = comfy.supported_models_base.BASE(unet_config) diff --git a/comfy/sd.py b/comfy/sd.py index 6e1340ea8..06c6196d3 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1,4 +1,3 @@ -import inspect import json import torch from enum import Enum @@ -500,6 +499,8 @@ class VAE: self.upscale_index_formula = None self.extra_1d_channel = None self.crop_input = True + self.handles_tiling = False + self.format_encoded = None self.audio_sample_rate = 44100 @@ -554,6 +555,8 @@ class VAE: self.memory_used_decode = lambda shape, dtype: self.first_stage_model.comfy_memory_used_decode(shape) self.memory_used_encode = lambda shape, dtype: (max(shape[2], 5) * shape[3] * shape[4] * 64) * model_management.dtype_size(dtype) self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] + self.handles_tiling = True + self.format_encoded = self.first_stage_model.comfy_format_encoded self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8) self.downscale_index_formula = (4, 8, 8) self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8) @@ -1118,7 +1121,7 @@ class VAE: if dims == 1 or self.extra_1d_channel is not None: pixel_samples = self.decode_tiled_1d(samples_in) elif dims == 2: - if getattr(self.first_stage_model, "comfy_handles_tiling", False): + if self.handles_tiling: tile = 256 // self.spacial_compression_decode() overlap = tile // 4 pixel_samples = self._decode_tiled_owned(samples_in, tile_x=tile, tile_y=tile, overlap=overlap) @@ -1127,7 +1130,7 @@ class VAE: elif dims == 3: tile = 256 // self.spacial_compression_decode() overlap = tile // 4 - if getattr(self.first_stage_model, "comfy_handles_tiling", False): + if self.handles_tiling: pixel_samples = self._decode_tiled_owned(samples_in, tile_x=tile, tile_y=tile, overlap=overlap) else: pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap)) @@ -1149,7 +1152,7 @@ class VAE: args["overlap"] = overlap with model_management.cuda_device_context(self.device): - if getattr(self.first_stage_model, "comfy_handles_tiling", False) and dims in (2, 3): + if self.handles_tiling and dims in (2, 3): tiled_args = {} if tile_x is not None: tiled_args["tile_x"] = tile_x @@ -1204,8 +1207,6 @@ class VAE: else: pixels_in = pixels_in.to(self.device) out = self.first_stage_model.encode(pixels_in) - if isinstance(out, tuple): - out = out[0] out = out.to(self.output_device).to(dtype=self.vae_output_dtype()) if samples is None: samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype()) @@ -1225,7 +1226,7 @@ class VAE: if self.latent_dim == 3: tile = 256 overlap = tile // 4 - if getattr(self.first_stage_model, "comfy_handles_tiling", False): + if self.handles_tiling: samples = self._encode_tiled_owned(pixel_samples, tile_x=tile, tile_y=tile, overlap=overlap) else: samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap)) @@ -1234,9 +1235,8 @@ class VAE: else: samples = self.encode_tiled_(pixel_samples) - formatter = getattr(self.first_stage_model, "comfy_format_encoded", None) - if formatter is not None: - samples = formatter(samples) + if self.format_encoded is not None: + samples = self.format_encoded(samples) return samples def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None): @@ -1268,7 +1268,7 @@ class VAE: elif dims == 2: samples = self.encode_tiled_(pixel_samples, **args) elif dims == 3: - if getattr(self.first_stage_model, "comfy_handles_tiling", False): + if self.handles_tiling: tiled_args = {} if tile_x is not None: tiled_args["tile_x"] = tile_x @@ -1298,9 +1298,8 @@ class VAE: samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args) - formatter = getattr(self.first_stage_model, "comfy_format_encoded", None) - if formatter is not None: - samples = formatter(samples) + if self.format_encoded is not None: + samples = self.format_encoded(samples) return samples def get_sd(self): @@ -1852,16 +1851,6 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl return (model, clip, vae) -def _set_model_config_inference_dtype(model_config, dtype, manual_cast_dtype, device): - set_dtype = model_config.set_inference_dtype - parameters = inspect.signature(set_dtype).parameters - supports_device = "device" in parameters or any(p.kind == inspect.Parameter.VAR_KEYWORD for p in parameters.values()) - if supports_device: - set_dtype(dtype, manual_cast_dtype, device=device) - else: - set_dtype(dtype, manual_cast_dtype) - - def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, disable_dynamic=False): sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True) out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic) @@ -1969,7 +1958,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes) else: manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) - _set_model_config_inference_dtype(model_config, unet_dtype, manual_cast_dtype, load_device) + model_config.set_inference_dtype(unet_dtype, manual_cast_dtype, device=load_device) if model_config.clip_vision_prefix is not None: if output_clipvision: @@ -2110,7 +2099,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes) else: manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) - _set_model_config_inference_dtype(model_config, unet_dtype, manual_cast_dtype, load_device) + model_config.set_inference_dtype(unet_dtype, manual_cast_dtype, device=load_device) if custom_operations is not None: model_config.custom_operations = custom_operations diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 1ce5f8c91..5c849358e 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1688,6 +1688,10 @@ class SeedVR2(supported_models_base.BASE): unet_config = { "image_model": "seedvr2" } + required_keys = { + "{}positive_conditioning", + "{}negative_conditioning", + } latent_format = comfy.latent_formats.SeedVR2 vae_key_prefix = ["vae."] diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index 572f9984e..e3a8e131f 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -54,13 +54,13 @@ class BASE: optimizations = {"fp8": False} @classmethod - def matches(s, unet_config, state_dict=None): + def matches(s, unet_config, state_dict=None, unet_key_prefix=""): for k in s.unet_config: if k not in unet_config or s.unet_config[k] != unet_config[k]: return False if state_dict is not None: for k in s.required_keys: - if k not in state_dict: + if k.format(unet_key_prefix) not in state_dict: return False return True diff --git a/comfy_extras/nodes_seedvr.py b/comfy_extras/nodes_seedvr.py index 978de3e41..1fb44ac36 100644 --- a/comfy_extras/nodes_seedvr.py +++ b/comfy_extras/nodes_seedvr.py @@ -3,7 +3,6 @@ from comfy_api.latest import ComfyExtension, io import torch import math import logging -from einops import rearrange import comfy.model_management import comfy.sample @@ -101,14 +100,6 @@ def _resolve_seedvr2_diffusion_model(model): return diffusion_model -def _apply_rope_freqs_float32_cast(diffusion_model): - """Cast every module's ``rope.freqs`` to float32; the per-tensor dtype check (not a sentinel attr) self-corrects across Comfy's unload/reload, which would otherwise restore the archived fp16/bf16 dtype.""" - for module in diffusion_model.modules(): - if hasattr(module, 'rope') and hasattr(module.rope, 'freqs'): - if module.rope.freqs.data.dtype != torch.float32: - module.rope.freqs.data = module.rope.freqs.data.to(torch.float32) - - def get_conditions(latent, latent_blur): t, h, w, c = latent.shape cond = torch.ones([t, h, w, c + 1], device=latent.device, dtype=latent.dtype) @@ -193,7 +184,7 @@ def _seedvr2_pad(images, upscaled_shorter_edge, node_name): images = images.reshape(b, t, c, new_h, new_w) images = cut_videos(images) - images_bthwc = rearrange(images, "b t c h w -> b t h w c") + images_bthwc = images.permute(0, 1, 3, 4, 2).contiguous() return io.NodeOutput(images_bthwc) @@ -265,12 +256,12 @@ class SeedVR2PostProcessing(io.ComfyNode): output_device = decoded_5d.device decoded_raw = cls._to_seedvr2_raw(decoded_5d) reference_raw = cls._to_seedvr2_raw(reference_5d) - decoded_flat = rearrange(decoded_raw, "b t h w c -> (b t) c h w") - reference_flat = rearrange(reference_raw, "b t h w c -> (b t) c h w") + decoded_flat = decoded_raw.permute(0, 1, 4, 2, 3).reshape(b * t, decoded_raw.shape[4], target_h, target_w) + reference_flat = reference_raw.permute(0, 1, 4, 2, 3).reshape(b * t, reference_raw.shape[4], target_h, target_w) output = cls._color_transfer_chunked( decoded_flat, reference_flat, output_device, color_correction_method, ) - output = rearrange(output, "(b t) c h w -> b t h w c", b=b, t=t) + output = output.reshape(b, t, output.shape[1], output.shape[2], output.shape[3]).permute(0, 1, 3, 4, 2) output = output.add(1.0).div(2.0).clamp(0.0, 1.0) elif color_correction_method == "none": output = decoded_5d @@ -359,7 +350,6 @@ class SeedVR2PostProcessing(io.ComfyNode): ) from e next_chunk_size = max(1, chunk_size // SEEDVR2_OOM_BACKOFF_DIVISOR) - comfy.model_management.soft_empty_cache() chunk_size = next_chunk_size @classmethod @@ -419,14 +409,14 @@ class SeedVR2PostProcessing(io.ComfyNode): if reference.shape[2] == height and reference.shape[3] == width: return reference b, t = reference.shape[:2] - reference_flat = rearrange(reference, "b t h w c -> (b t) c h w") + reference_flat = reference.permute(0, 1, 4, 2, 3).reshape(b * t, reference.shape[4], reference.shape[2], reference.shape[3]) resized = TVF.resize( reference_flat, size=(height, width), interpolation=InterpolationMode.BICUBIC, antialias=not (isinstance(reference_flat, torch.Tensor) and reference_flat.device.type == "mps"), ) - return rearrange(resized, "(b t) c h w -> b t h w c", b=b, t=t) + return resized.reshape(b, t, resized.shape[1], height, width).permute(0, 1, 3, 4, 2) class SeedVR2Conditioning(io.ComfyNode): @@ -471,39 +461,12 @@ class SeedVR2Conditioning(io.ComfyNode): pos_cond = model.positive_conditioning neg_cond = model.negative_conditioning - # Fail-loud guard against silently-wrong output when a - # DiT-only ``.safetensors`` (no ``positive_conditioning`` / - # ``negative_conditioning`` keys) is loaded via ``UNETLoader``. - # ``NaDiT.__init__`` zero-fills the buffers via ``torch.zeros`` (see - # ``comfy/ldm/seedvr/model.py``); ``load_state_dict(strict=False)`` - # leaves them at zero when the keys are absent. Detect that state - # here rather than at ``BaseModel.extra_conds`` (per sampling step, - # wasteful) or at the resolver helper (mixes structural shape with - # semantic content). Both buffers must be checked together — partial - # bake regressions could populate one but not the other. - if ( - pos_cond.float().abs().sum().item() == 0 - and neg_cond.float().abs().sum().item() == 0 - ): - raise RuntimeError( - f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: positive_conditioning " - f"and negative_conditioning buffers are zero-valued — model " - f"file appears to be a DiT-only export missing " - f"the SeedVR2 conditioning tensors. " - f"Re-bake the file with ``positive_conditioning`` (58, 5120) " - f"and ``negative_conditioning`` (64, 5120) keys at top level, " - f"or load via CheckpointLoaderSimple from a bundled " - f"checkpoint." - ) - - _apply_rope_freqs_float32_cast(model) - condition = torch.stack([get_conditions(c, c) for c in vae_conditioning]) condition = condition.movedim(-1, 1) latent = vae_conditioning.movedim(-1, 1) - latent = rearrange(latent, "b c t h w -> b (c t) h w") - condition = rearrange(condition, "b c t h w -> b (c t) h w") + latent = latent.reshape(latent.shape[0], latent.shape[1] * latent.shape[2], latent.shape[3], latent.shape[4]) + condition = condition.reshape(condition.shape[0], condition.shape[1] * condition.shape[2], condition.shape[3], condition.shape[4]) negative = [[neg_cond.unsqueeze(0), {"condition": condition}]] positive = [[pos_cond.unsqueeze(0), {"condition": condition}]] @@ -723,7 +686,7 @@ class SeedVR2ProgressiveSampler(io.ComfyNode): Drop-in replacement for ``KSampler`` in SeedVR2 native workflows that OOM on long sequences. The latent enters the sampler in SeedVR2's collapsed form ``(B, 16*T, H, W)`` (collapsed by ``SeedVR2Conditioning`` - at ``rearrange(b c t h w -> b (c t) h w)``); this node slices that + at ``reshape(b, c * t, h, w)``); this node slices that tensor along the temporal axis, runs the configured inner sampler sequentially per chunk against the standard ``comfy.sample.sample`` entry point, and concatenates per-chunk outputs back into a single @@ -882,7 +845,6 @@ class SeedVR2ProgressiveSampler(io.ComfyNode): "frames_per_chunk=%s.", attempt_frames_per_chunk, attempts[i + 1], ) - comfy.model_management.soft_empty_cache() # Short-circuit: total fits in one chunk -> standard path with no # chunking overhead. Output of this branch is byte-identical to the diff --git a/tests-unit/comfy_extras_test/test_seedvr2_conditioning.py b/tests-unit/comfy_extras_test/test_seedvr2_conditioning.py index 2a6e3d430..d36e50428 100644 --- a/tests-unit/comfy_extras_test/test_seedvr2_conditioning.py +++ b/tests-unit/comfy_extras_test/test_seedvr2_conditioning.py @@ -11,7 +11,6 @@ import importlib import sys from unittest.mock import MagicMock -import pytest import torch import torch.nn as nn @@ -53,7 +52,7 @@ def _import_nodes_seedvr_isolated(): mock_mm.WINDOWS = False sys.modules["comfy.model_management"] = mock_mm if sys.modules.get("comfy") is None: - import comfy as _comfy_pkg # noqa: F401 + importlib.import_module("comfy") comfy_pkg = sys.modules.get("comfy") if comfy_pkg is not None: setattr(comfy_pkg, "model_management", mock_mm) @@ -95,11 +94,10 @@ class _Block(nn.Module): class _DiffusionModel(nn.Module): """Stub diffusion model with N blocks and pos/neg conditioning buffers.""" - def __init__(self, n_blocks=3, zero_conditioning=False, conditioning_dtype=torch.float32): + def __init__(self, n_blocks=3, conditioning_dtype=torch.float32): super().__init__() self.blocks = nn.ModuleList([_Block() for _ in range(n_blocks)]) - pos = torch.zeros if zero_conditioning else torch.ones - self.register_buffer("positive_conditioning", pos((2, 4), dtype=conditioning_dtype)) + self.register_buffer("positive_conditioning", torch.ones((2, 4), dtype=conditioning_dtype)) self.register_buffer("negative_conditioning", torch.zeros((3, 4), dtype=conditioning_dtype)) @@ -185,29 +183,3 @@ def test_seedvr2_conditioning_returns_packed_input_latent_deterministically(): ) finally: restore() - - -def test_seedvr2_conditioning_fails_loud_on_zero_buffers(): - nodes_seedvr, restore = _import_nodes_seedvr_isolated() - try: - diffusion_model = _DiffusionModel(zero_conditioning=True) - patcher = _ModelPatcher(diffusion_model) - vae_conditioning = {"samples": torch.zeros((1, 2, 1, 1, 1))} - - with pytest.raises(RuntimeError) as excinfo: - nodes_seedvr.SeedVR2Conditioning.execute( - patcher, vae_conditioning, - ) - - message = str(excinfo.value) - assert message.startswith( - nodes_seedvr._SEEDVR2_INVALID_MODEL_MSG_PREFIX - ), ( - "Fail-loud message must use the standard " - "_SEEDVR2_INVALID_MODEL_MSG_PREFIX so callers/log scrapers " - f"can match it. Got: {message!r}" - ) - assert "positive_conditioning" in message - assert "negative_conditioning" in message - finally: - restore() diff --git a/tests-unit/comfy_extras_test/test_seedvr2_post_processing.py b/tests-unit/comfy_extras_test/test_seedvr2_post_processing.py index a27a8f8df..6c821136d 100644 --- a/tests-unit/comfy_extras_test/test_seedvr2_post_processing.py +++ b/tests-unit/comfy_extras_test/test_seedvr2_post_processing.py @@ -1,5 +1,6 @@ from unittest.mock import patch +import pytest import torch from comfy.cli_args import args as cli_args @@ -32,26 +33,19 @@ def test_seedvr2_post_processing_oom_error_uses_color_correction_method(monkeypa monkeypatch.setattr(nodes_seedvr.comfy.model_management, "vae_device", lambda: torch.device("cpu")) monkeypatch.setattr(nodes_seedvr.comfy.model_management, "get_free_memory", lambda device: 1_000_000) - monkeypatch.setattr(nodes_seedvr.comfy.model_management, "soft_empty_cache", lambda: None) with patch.object(nodes_seedvr, "lab_color_transfer", _lab): - try: + with pytest.raises(RuntimeError) as excinfo: nodes_seedvr.SeedVR2PostProcessing._color_transfer_chunked( decoded, reference, torch.device("cpu"), "lab", ) - except RuntimeError as exc: - assert "color_correction_method=lab" in str(exc) - assert " method=lab" not in str(exc) - else: - raise AssertionError("expected RuntimeError for one-frame LAB OOM") + assert "color_correction_method=lab" in str(excinfo.value) + assert " method=lab" not in str(excinfo.value) def test_seedvr2_post_processing_unknown_color_correction_method_raises(): decoded = torch.zeros(1, 2, 4, 4, 3) original = torch.zeros(1, 2, 4, 4, 3) - try: + with pytest.raises(ValueError) as excinfo: nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, "bogus") - except ValueError as exc: - assert "color_correction_method" in str(exc) - else: - raise AssertionError("expected ValueError for unknown color_correction_method") + assert "color_correction_method" in str(excinfo.value) diff --git a/tests-unit/comfy_test/model_detection_test.py b/tests-unit/comfy_test/model_detection_test.py index 109e2b13b..587c393c9 100644 --- a/tests-unit/comfy_test/model_detection_test.py +++ b/tests-unit/comfy_test/model_detection_test.py @@ -2,7 +2,7 @@ from collections import defaultdict import torch -from comfy.model_detection import detect_unet_config, model_config_from_unet_config +from comfy.model_detection import detect_unet_config, model_config_from_unet, model_config_from_unet_config import comfy.supported_models @@ -76,21 +76,31 @@ def _make_flux_schnell_comfyui_sd(): def _make_seedvr2_7b_separate_mm_sd(): return { "blocks.35.mlp.vid.proj_in.weight": torch.empty(1, 3072), + "positive_conditioning": torch.empty(58, 5120), + "negative_conditioning": torch.empty(64, 5120), } def _make_seedvr2_7b_shared_mm_sd(): return { "blocks.35.mlp.all.proj_in_gate.weight": torch.empty(1, 1), + "positive_conditioning": torch.empty(58, 5120), + "negative_conditioning": torch.empty(64, 5120), } def _make_seedvr2_3b_shared_mm_sd(): return { "blocks.31.mlp.all.proj_in_gate.weight": torch.empty(1, 1), + "positive_conditioning": torch.empty(58, 5120), + "negative_conditioning": torch.empty(64, 5120), } +def _add_model_diffusion_prefix(sd): + return {f"model.diffusion_model.{k}": v for k, v in sd.items()} + + class TestModelDetection: """Verify that first-match model detection selects the correct model based on list ordering and unet_config specificity.""" @@ -182,6 +192,20 @@ class TestModelDetection: assert unet_config["num_layers"] == 32 assert unet_config["mlp_type"] == "swiglu" + def test_seedvr2_model_match_requires_conditioning_tensors(self): + sd = _make_seedvr2_7b_shared_mm_sd() + unet_config = detect_unet_config(sd, "") + + assert type(model_config_from_unet_config(unet_config, sd)).__name__ == "SeedVR2" + + del sd["positive_conditioning"] + assert model_config_from_unet_config(unet_config, sd) is None + + def test_seedvr2_model_match_accepts_full_checkpoint_prefix(self): + sd = _add_model_diffusion_prefix(_make_seedvr2_7b_shared_mm_sd()) + + assert type(model_config_from_unet(sd, "model.diffusion_model.")).__name__ == "SeedVR2" + def test_unet_config_and_required_keys_combination_is_unique(self): """Each model in the registry must have a unique combination of ``unet_config`` and ``required_keys``. If two models share the same diff --git a/tests-unit/comfy_test/test_seedvr2_internals.py b/tests-unit/comfy_test/test_seedvr2_internals.py index dd3121428..966e9465d 100644 --- a/tests-unit/comfy_test/test_seedvr2_internals.py +++ b/tests-unit/comfy_test/test_seedvr2_internals.py @@ -103,7 +103,7 @@ def test_seedvr2_7b_swin_attention_forward_uses_optimized_var_attention(monkeypa heads=heads, head_dim=head_dim, qk_bias=False, - qk_norm=seedvr_model.CustomRMSNorm, + qk_norm=comfy_ops.disable_weight_init.RMSNorm, qk_norm_eps=1e-6, rope_type=None, rope_dim=head_dim, diff --git a/tests-unit/comfy_test/test_seedvr2_model.py b/tests-unit/comfy_test/test_seedvr2_model.py index feae2211f..06b2f1564 100644 --- a/tests-unit/comfy_test/test_seedvr2_model.py +++ b/tests-unit/comfy_test/test_seedvr2_model.py @@ -26,6 +26,7 @@ import comfy.ldm.seedvr.model # noqa: E402 import comfy.ldm.seedvr.model as seedvr_model # noqa: E402 import comfy.ldm.seedvr.vae as seedvr_vae_mod # noqa: E402 import comfy.model_management # noqa: E402 +import comfy.ops as comfy_ops # noqa: E402 import comfy.sample # noqa: E402 import comfy.sd as sd_mod # noqa: E402 import nodes as nodes_mod # noqa: E402 @@ -81,6 +82,7 @@ def _capture_last_layer_flags(monkeypatch, vid_dim: int, txt_in_dim: int) -> lis txt_in_dim=txt_in_dim, heads=24, mm_layers=3, + operations=comfy_ops.disable_weight_init, ) return flags @@ -140,6 +142,46 @@ class _DecodeWrapper(seedvr_vae_mod.VideoAutoencoderKLWrapper): return torch.zeros(b, 3, t, h * 8, w * 8, dtype=z.dtype, device=z.device) +def test_seedvr2_wrapper_public_encode_returns_tensor(monkeypatch): + raw_latent = torch.full((1, 16, 1, 4, 5), 2.0) + seen_shapes = [] + + def base_encode(self, x): + seen_shapes.append(tuple(x.shape)) + return raw_latent.to(device=x.device, dtype=x.dtype) + + monkeypatch.setattr(seedvr_vae_mod.VideoAutoencoderKL, "encode", base_encode) + + vae = seedvr_vae_mod.VideoAutoencoderKLWrapper.__new__(seedvr_vae_mod.VideoAutoencoderKLWrapper) + nn.Module.__init__(vae) + vae._dummy = nn.Parameter(torch.zeros((), dtype=torch.float32)) + + latent = vae.encode(torch.zeros(1, 3, 32, 40)) + + assert type(latent) is torch.Tensor + assert tuple(latent.shape) == (1, 16, 4, 5) + assert seen_shapes == [(1, 3, 1, 32, 40)] + + +def test_seedvr2_wrapper_private_encode_helper_keeps_raw_latent(monkeypatch): + raw_latent = torch.full((1, 16, 1, 4, 5), 3.0) + + def base_encode(self, x): + return raw_latent.to(device=x.device, dtype=x.dtype) + + monkeypatch.setattr(seedvr_vae_mod.VideoAutoencoderKL, "encode", base_encode) + + vae = seedvr_vae_mod.VideoAutoencoderKLWrapper.__new__(seedvr_vae_mod.VideoAutoencoderKLWrapper) + nn.Module.__init__(vae) + vae._dummy = nn.Parameter(torch.zeros((), dtype=torch.float32)) + + latent, raw = vae._encode_with_raw_latent(torch.zeros(1, 3, 32, 40)) + + assert tuple(latent.shape) == (1, 16, 4, 5) + assert tuple(raw.shape) == (1, 16, 1, 4, 5) + assert torch.equal(raw, raw_latent) + + def _make_vae(wrapper): vae = sd_mod.VAE.__new__(sd_mod.VAE) vae.first_stage_model = wrapper @@ -155,6 +197,8 @@ def _make_vae(wrapper): vae.extra_1d_channel = None vae.crop_input = False vae.not_video = False + vae.handles_tiling = isinstance(wrapper, seedvr_vae_mod.VideoAutoencoderKLWrapper) + vae.format_encoded = wrapper.comfy_format_encoded vae.patcher = _Patcher() vae.process_input = lambda image: image vae.process_output = lambda image: image.add(1.0).div(2.0).clamp(0.0, 1.0) diff --git a/tests-unit/comfy_test/test_seedvr2_vae_tiled.py b/tests-unit/comfy_test/test_seedvr2_vae_tiled.py index ced2fe34f..0d3c97e4a 100644 --- a/tests-unit/comfy_test/test_seedvr2_vae_tiled.py +++ b/tests-unit/comfy_test/test_seedvr2_vae_tiled.py @@ -1,6 +1,7 @@ from contextlib import ExitStack from unittest.mock import MagicMock, patch +import pytest import torch import torch.nn as nn @@ -21,8 +22,6 @@ from comfy.ldm.seedvr.vae import MemoryState, tiled_vae # noqa: E402 def test_runtime_decode_zero_temporal_size_disables_slicing_for_call(): - from comfy.ldm.seedvr.vae import MemoryState, VideoAutoencoderKL, tiled_vae - class StubVAEModel(torch.nn.Module): def __init__(self): super().__init__() @@ -37,9 +36,9 @@ def test_runtime_decode_zero_temporal_size_disables_slicing_for_call(): def decode_(self, t_chunk): self.decode_min_sizes.append(self.slicing_latent_min_size) - return VideoAutoencoderKL.slicing_decode(self, t_chunk) + return vae_mod.VideoAutoencoderKL.slicing_decode(self, t_chunk) - def _decode(self, z, memory_state=MemoryState.DISABLED): + def _decode(self, z, memory_state=MemoryState.DISABLED, memory_cache=None): self.memory_states.append(memory_state) b, c, d, h, w = z.shape return torch.zeros((b, 3, d, h * 8, w * 8), dtype=z.dtype) @@ -68,8 +67,6 @@ def test_runtime_decode_zero_temporal_size_disables_slicing_for_call(): def test_zero_temporal_size_preserves_min_size_when_encode_raises(): - from comfy.ldm.seedvr.vae import tiled_vae - class RaisingVAEModel(torch.nn.Module): def __init__(self): super().__init__() @@ -85,8 +82,7 @@ def test_zero_temporal_size_preserves_min_size_when_encode_raises(): vae = RaisingVAEModel() x = torch.zeros((1, 3, 12, 64, 64), dtype=torch.float32) - raised = False - try: + with pytest.raises(RuntimeError, match="simulated encode failure"): tiled_vae( x, vae, @@ -96,15 +92,43 @@ def test_zero_temporal_size_preserves_min_size_when_encode_raises(): temporal_overlap=0, encode=True, ) - except RuntimeError as exc: - if "simulated encode failure" not in str(exc): - raise - raised = True - assert raised assert vae.slicing_sample_min_size == 4 +def test_tiled_vae_encode_uses_tensor_return_without_indexing(): + class TensorEncodeVAEModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.slicing_sample_min_size = 4 + self.spatial_downsample_factor = 8 + self.temporal_downsample_factor = 4 + self.device = torch.device("cpu") + self._dummy = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32)) + self.calls = [] + + def encode(self, t_chunk): + self.calls.append(tuple(t_chunk.shape)) + b, _, _, h, w = t_chunk.shape + return torch.ones((b, 16, 1, h // 8, w // 8), dtype=t_chunk.dtype) + + vae = TensorEncodeVAEModel() + x = torch.zeros((2, 3, 1, 64, 64), dtype=torch.float32) + + out = tiled_vae( + x, + vae, + tile_size=(64, 64), + tile_overlap=(0, 0), + temporal_size=0, + temporal_overlap=0, + encode=True, + ) + + assert vae.calls == [(2, 3, 1, 64, 64)] + assert tuple(out.shape) == (2, 16, 1, 8, 8) + + # --------------------------------------------------------------------------- # From test_seedvr_vae_tiled_temporal_slicing.py # --------------------------------------------------------------------------- @@ -126,7 +150,7 @@ class _SlicingDecodeVAE(nn.Module): self.decode_min_sizes.append(self.slicing_latent_min_size) return vae_mod.VideoAutoencoderKL.slicing_decode(self, z) - def _decode(self, z, memory_state=MemoryState.DISABLED): + def _decode(self, z, memory_state=MemoryState.DISABLED, memory_cache=None): self.memory_states.append(memory_state) x = z[:, :1].repeat( 1, @@ -205,6 +229,8 @@ def _make_vae(first_stage_model, latent_channels, latent_dim): vae.latent_dim = latent_dim vae.vae_output_dtype = lambda: torch.float32 vae.spacial_compression_decode = lambda: 8 + vae.handles_tiling = isinstance(first_stage_model, seedvr_vae_mod.VideoAutoencoderKLWrapper) + vae.format_encoded = None vae.process_input = lambda x: x vae.process_output = lambda x: x vae.throw_exception_if_invalid = lambda: None @@ -240,7 +266,6 @@ def test_4d_seedvr2_latent_routes_to_owned_decode_tiled(): def test_4d_non_seedvr2_latent_still_routes_to_generic_decode_tiled(): first_stage = MagicMock() - first_stage.comfy_handles_tiling = False first_stage.decode = MagicMock(side_effect=_force_oom) vae = _make_vae(first_stage, latent_channels=4, latent_dim=2) seedvr2_call = MagicMock(return_value=torch.zeros(1, 3, 9, 64, 64)) @@ -273,6 +298,8 @@ def _populate_common_vae_attrs_fallback(vae): vae.not_video = False vae.crop_input = False vae.pad_channel_value = None + vae.handles_tiling = isinstance(vae.first_stage_model, seedvr_vae_mod.VideoAutoencoderKLWrapper) + vae.format_encoded = None vae.vae_output_dtype = lambda: torch.float32 vae.spacial_compression_encode = lambda: 8 @@ -295,7 +322,6 @@ def _make_seedvr2_vae_fallback(): def _make_non_seedvr2_vae_fallback(): vae = sd_mod.VAE.__new__(sd_mod.VAE) vae.first_stage_model = MagicMock() - vae.first_stage_model.comfy_handles_tiling = False _populate_common_vae_attrs_fallback(vae) return vae