From 08d93555d015ee0fd0a921097d5a67fd867698db Mon Sep 17 00:00:00 2001
From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com>
Date: Sat, 6 Dec 2025 23:18:10 +0200
Subject: [PATCH 01/35] init

---
 comfy/latent_formats.py        |    4 +
 comfy/ldm/modules/attention.py |    2 +-
 comfy/ldm/seedvr/model.py      | 1287 ++++++++++++++++++++++++++++++++
 comfy/ldm/seedvr/vae.py        | 1260 +++++++++++++++++++++++++++++++
 comfy/model_base.py            |    6 +
 comfy/model_detection.py       |   11 +
 comfy/sd.py                    |   14 +
 comfy/supported_models.py      |   17 +-
 8 files changed, 2599 insertions(+), 2 deletions(-)
 create mode 100644 comfy/ldm/seedvr/model.py
 create mode 100644 comfy/ldm/seedvr/vae.py

diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py
index 82d9f9bb8..f260528d4 100644
--- a/comfy/latent_formats.py
+++ b/comfy/latent_formats.py
@@ -470,3 +470,7 @@ class Hunyuan3Dv2mini(LatentFormat):
 class ACEAudio(LatentFormat):
     latent_channels = 8
     latent_dimensions = 2
+
+class SeedVR2(LatentFormat):
+    latent_channels = 16
+    latent_dimensions = 3  # (t, h, w) video latents; 16 is the channel count, not the rank
diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 35d2270ee..256f9a989 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -428,7 +428,7 @@ else:
 
 SDP_BATCH_LIMIT = 2**31
 
-def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
+def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
     if skip_reshape:
         b, _, _, dim_head = q.shape
     else:
diff --git a/comfy/ldm/seedvr/model.py b/comfy/ldm/seedvr/model.py
new file mode 100644
index 000000000..40a460d67
--- /dev/null
+++ b/comfy/ldm/seedvr/model.py
@@ -0,0 +1,1287 @@
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union, List, Dict, Any, Callable
+import einops
+from einops import rearrange, repeat
+# torch.einsum: the '..., f -> ... f' pattern used below is torch-style
+# (einops.einsum expects the pattern as the last argument)
+from torch import nn, einsum
+import torch.nn.functional as F
+from math import ceil, sqrt, pi
+import torch
+from itertools import chain
+from comfy.ldm.modules.diffusionmodules.model import get_timestep_embedding
+from comfy.ldm.modules.attention import optimized_attention
+from comfy.rmsnorm import RMSNorm
+from torch.nn.modules.utils import _triple
+
+class Cache:
+    def __init__(self, disable=False, prefix="", cache=None):
+        self.cache = cache if cache is not None else {}
+        self.disable = disable
+        self.prefix = prefix
+
+    def __call__(self, key: str, fn: Callable):
+        if self.disable:
+            return fn()
+
+        key = self.prefix + key
+        try:
+            result = self.cache[key]
+        except KeyError:
+            result = fn()
+            self.cache[key] = result
+        return result
+
+    def namespace(self, namespace: str):
+        return Cache(
+            disable=self.disable,
+            prefix=self.prefix + namespace + ".",
+            cache=self.cache,
+        )
+
+    def get(self, key: str):
+        key = self.prefix + key
+        return self.cache[key]
+
+def repeat_concat(
+    vid: torch.FloatTensor,  # (VL ... c)
+    txt: torch.FloatTensor,  # (TL ... c)
+    vid_len: torch.LongTensor,  # (n*b)
+    txt_len: torch.LongTensor,  # (b)
+    txt_repeat: List,  # (n)
+) -> torch.FloatTensor:  # (L ... c)
+    vid = torch.split(vid, vid_len.tolist())
+    txt = torch.split(txt, txt_len.tolist())
+    txt = [[x] * n for x, n in zip(txt, txt_repeat)]
+    txt = list(chain(*txt))
+    return torch.cat(list(chain(*zip(vid, txt))))
+
+def concat(
+    vid: torch.FloatTensor,  # (VL ... c)
+    txt: torch.FloatTensor,  # (TL ... c)
+    vid_len: torch.LongTensor,  # (b)
+    txt_len: torch.LongTensor,  # (b)
+) -> torch.FloatTensor:  # (L ... c)
+    vid = torch.split(vid, vid_len.tolist())
+    txt = torch.split(txt, txt_len.tolist())
+    return torch.cat(list(chain(*zip(vid, txt))))
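+
+# A minimal sketch of the packed layout these helpers assume (illustrative
+# shapes, not executed): tokens from all samples are flattened along dim 0
+# and per-sample lengths are tracked separately.
+#
+#   vid = torch.randn(7, 4)                 # 3 + 4 video tokens, two samples
+#   txt = torch.randn(5, 4)                 # 2 + 3 text tokens, same samples
+#   vid_len = torch.tensor([3, 4])
+#   txt_len = torch.tensor([2, 3])
+#   concat(vid, txt, vid_len, txt_len)      # rows: [vid0, txt0, vid1, txt1]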
+
+def concat_idx(
+    vid_len: torch.LongTensor,  # (b)
+    txt_len: torch.LongTensor,  # (b)
+) -> Tuple[
+    Callable,
+    Callable,
+]:
+    device = vid_len.device
+    vid_idx = torch.arange(vid_len.sum(), device=device)
+    txt_idx = torch.arange(len(vid_idx), len(vid_idx) + txt_len.sum(), device=device)
+    tgt_idx = concat(vid_idx, txt_idx, vid_len, txt_len)
+    src_idx = torch.argsort(tgt_idx)
+    return (
+        lambda vid, txt: torch.index_select(torch.cat([vid, txt]), 0, tgt_idx),
+        lambda all: torch.index_select(all, 0, src_idx).split([len(vid_idx), len(txt_idx)]),
+    )
+
+
+def repeat_concat_idx(
+    vid_len: torch.LongTensor,  # (n*b)
+    txt_len: torch.LongTensor,  # (b)
+    txt_repeat: torch.LongTensor,  # (n)
+) -> Tuple[
+    Callable,
+    Callable,
+]:
+    device = vid_len.device
+    vid_idx = torch.arange(vid_len.sum(), device=device)
+    txt_idx = torch.arange(len(vid_idx), len(vid_idx) + txt_len.sum(), device=device)
+    txt_repeat_list = txt_repeat.tolist()
+    tgt_idx = repeat_concat(vid_idx, txt_idx, vid_len, txt_len, txt_repeat)
+    src_idx = torch.argsort(tgt_idx)
+    txt_idx_len = len(tgt_idx) - len(vid_idx)
+    repeat_txt_len = (txt_len * txt_repeat).tolist()
+
+    def unconcat_coalesce(all):
+        vid_out, txt_out = all[src_idx].split([len(vid_idx), txt_idx_len])
+        txt_out_coalesced = []
+        for txt, repeat_time in zip(txt_out.split(repeat_txt_len), txt_repeat_list):
+            txt = txt.reshape(-1, repeat_time, *txt.shape[1:]).mean(1)
+            txt_out_coalesced.append(txt)
+        return vid_out, torch.cat(txt_out_coalesced)
+
+    return (
+        lambda vid, txt: torch.cat([vid, txt])[tgt_idx],
+        lambda all: unconcat_coalesce(all),
+    )
+
+@dataclass
+class MMArg:
+    vid: Any
+    txt: Any
+
+def safe_pad_operation(x, padding, mode='constant', value=0.0):
+    """Safe padding operation that falls back to float32 for modes without Half kernels"""
+    # modes that need the Half-precision workaround
+    problematic_modes = ['replicate', 'reflect', 'circular']
+
+    if mode in problematic_modes:
+        try:
+            return F.pad(x, padding, mode=mode, value=value)
+        except RuntimeError as e:
+            if "not implemented for 'Half'" in str(e):
+                original_dtype = x.dtype
+                return F.pad(x.float(), padding, mode=mode, value=value).to(original_dtype)
+            else:
+                raise e
+    else:
+        # 'constant' and the other compatible modes need no workaround
+        return F.pad(x, padding, mode=mode, value=value)
+
+
+def get_args(key: str, args: List[Any]) -> List[Any]:
+    return [getattr(v, key) if isinstance(v, MMArg) else v for v in args]
+
+
+def get_kwargs(key: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
+    return {k: getattr(v, key) if isinstance(v, MMArg) else v for k, v in kwargs.items()}
+
+
+def make_720Pwindows(size, num_windows, shift = False):
+    t, h, w = size
+    resized_nt, resized_nh, resized_nw = num_windows
+
+    # rescale the spatial grid to a ~45x80 (720p / 16) reference before sizing
+    # windows, so window counts stay resolution-independent
+    scale = sqrt((45 * 80) / (h * w))
+    resized_h, resized_w = round(h * scale), round(w * scale)
+
+    wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw)
+    wt = ceil(min(t, 30) / resized_nt)
+
+    st, sh, sw = (0.5 * shift if wt < t else 0,
+                  0.5 * shift if wh < h else 0,
+                  0.5 * shift if ww < w else 0)
+
+    nt, nh, nw = ceil((t - st) / wt), ceil((h - sh) / wh), ceil((w - sw) / ww)
+    if shift:
+        nt += 1 if st > 0 else 0
+        nh += 1 if sh > 0 else 0
+        nw += 1 if sw > 0 else 0
+
+    windows = []
+    for iw in range(nw):
+        w_start = max(int((iw - sw) * ww), 0)
+        w_end = min(int((iw - sw + 1) * ww), w)
+        if w_end <= w_start:
+            continue
+
+        for ih in 
range(nh):
+            h_start = max(int((ih - sh) * wh), 0)
+            h_end = min(int((ih - sh + 1) * wh), h)
+            if h_end <= h_start:
+                continue
+
+            for it in range(nt):
+                t_start = max(int((it - st) * wt), 0)
+                t_end = min(int((it - st + 1) * wt), t)
+                if t_end <= t_start:
+                    continue
+
+                windows.append((slice(t_start, t_end),
+                                slice(h_start, h_end),
+                                slice(w_start, w_end)))
+
+    return windows
+
+class RotaryEmbedding(nn.Module):
+    def __init__(
+        self,
+        dim,
+        custom_freqs = None,
+        freqs_for = 'lang',
+        theta = 10000,
+        max_freq = 10,
+        num_freqs = 1,
+        learned_freq = False,
+        use_xpos = False,
+        xpos_scale_base = 512,
+        interpolate_factor = 1.,
+        theta_rescale_factor = 1.,
+        seq_before_head_dim = False,
+        cache_if_possible = True,
+        cache_max_seq_len = 8192
+    ):
+        super().__init__()
+
+        theta *= theta_rescale_factor ** (dim / (dim - 2))
+
+        self.freqs_for = freqs_for
+
+        if exists(custom_freqs):
+            freqs = custom_freqs
+        elif freqs_for == 'lang':
+            freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
+        elif freqs_for == 'pixel':
+            freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
+        elif freqs_for == 'constant':
+            freqs = torch.ones(num_freqs).float()
+
+        self.cache_if_possible = cache_if_possible
+        self.cache_max_seq_len = cache_max_seq_len
+
+        self.register_buffer('cached_freqs', torch.zeros(cache_max_seq_len, dim), persistent = False)
+        self.cached_freqs_seq_len = 0
+
+        self.freqs = nn.Parameter(freqs, requires_grad = learned_freq)
+
+        self.learned_freq = learned_freq
+
+        # dummy for device
+
+        self.register_buffer('dummy', torch.tensor(0), persistent = False)
+
+        # default sequence dimension
+
+        self.seq_before_head_dim = seq_before_head_dim
+        self.default_seq_dim = -3 if seq_before_head_dim else -2
+
+        # interpolation factors
+
+        assert interpolate_factor >= 1.
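+        # interpolate_factor > 1. stretches positions for longer sequences
+        # (position-interpolation style); kept for parity with the upstream
+        # rotary-embedding implementation.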
+ self.interpolate_factor = interpolate_factor + + # xpos + + self.use_xpos = use_xpos + + if not use_xpos: + return + + scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) + self.scale_base = xpos_scale_base + + self.register_buffer('scale', scale, persistent = False) + self.register_buffer('cached_scales', torch.zeros(cache_max_seq_len, dim), persistent = False) + self.cached_scales_seq_len = 0 + + # add apply_rotary_emb as static method + + self.apply_rotary_emb = staticmethod(apply_rotary_emb) + + @property + def device(self): + return self.dummy.device + + def get_axial_freqs( + self, + *dims, + offsets = None + ): + Colon = slice(None) + all_freqs = [] + + # handle offset + + if exists(offsets): + assert len(offsets) == len(dims) + + # get frequencies for each axis + + for ind, dim in enumerate(dims): + + offset = 0 + if exists(offsets): + offset = offsets[ind] + + if self.freqs_for == 'pixel': + pos = torch.linspace(-1, 1, steps = dim, device = self.device) + else: + pos = torch.arange(dim, device = self.device) + + pos = pos + offset + + freqs = self.forward(pos, seq_len = dim) + + all_axis = [None] * len(dims) + all_axis[ind] = Colon + + new_axis_slice = (Ellipsis, *all_axis, Colon) + all_freqs.append(freqs[new_axis_slice]) + + # concat all freqs + + all_freqs = torch.broadcast_tensors(*all_freqs) + return torch.cat(all_freqs, dim = -1) + + def forward( + self, + t, + seq_len: int | None = None, + offset = 0 + ): + should_cache = ( + self.cache_if_possible and + not self.learned_freq and + exists(seq_len) and + self.freqs_for != 'pixel' and + (offset + seq_len) <= self.cache_max_seq_len + ) + + if ( + should_cache and \ + exists(self.cached_freqs) and \ + (offset + seq_len) <= self.cached_freqs_seq_len + ): + return self.cached_freqs[offset:(offset + seq_len)].detach() + + freqs = self.freqs + + freqs = einsum('..., f -> ... f', t.type(freqs.dtype), freqs) + freqs = repeat(freqs, '... n -> ... 
(n r)', r = 2) + + if should_cache and offset == 0: + self.cached_freqs[:seq_len] = freqs.detach() + self.cached_freqs_seq_len = seq_len + + return freqs + +class RotaryEmbeddingBase(nn.Module): + def __init__(self, dim: int, rope_dim: int): + super().__init__() + self.rope = RotaryEmbedding( + dim=dim // rope_dim, + freqs_for="pixel", + max_freq=256, + ) + freqs = self.rope.freqs + del self.rope.freqs + self.rope.register_buffer("freqs", freqs.data) + + def get_axial_freqs(self, *dims): + return self.rope.get_axial_freqs(*dims) + + +class RotaryEmbedding3d(RotaryEmbeddingBase): + def __init__(self, dim: int): + super().__init__(dim, rope_dim=3) + self.mm = False + + def forward( + self, + q: torch.FloatTensor, # b h l d + k: torch.FloatTensor, # b h l d + size: Tuple[int, int, int], + ) -> Tuple[ + torch.FloatTensor, + torch.FloatTensor, + ]: + T, H, W = size + freqs = self.get_axial_freqs(T, H, W) + q = rearrange(q, "b h (T H W) d -> b h T H W d", T=T, H=H, W=W) + k = rearrange(k, "b h (T H W) d -> b h T H W d", T=T, H=H, W=W) + q = apply_rotary_emb(freqs, q.float()).to(q.dtype) + k = apply_rotary_emb(freqs, k.float()).to(k.dtype) + q = rearrange(q, "b h T H W d -> b h (T H W) d") + k = rearrange(k, "b h T H W d -> b h (T H W) d") + return q, k + + +class MMRotaryEmbeddingBase(RotaryEmbeddingBase): + def __init__(self, dim: int, rope_dim: int): + super().__init__(dim, rope_dim) + self.rope = RotaryEmbedding( + dim=dim // rope_dim, + freqs_for="lang", + theta=10000, + ) + freqs = self.rope.freqs + del self.rope.freqs + self.rope.register_buffer("freqs", freqs.data) + self.mm = True + +def slice_at_dim(t, dim_slice: slice, *, dim): + dim += (t.ndim if dim < 0 else 0) + colons = [slice(None)] * t.ndim + colons[dim] = dim_slice + return t[tuple(colons)] + +# rotary embedding helper functions + +def rotate_half(x): + x = rearrange(x, '... (d r) -> ... d r', r = 2) + x1, x2 = x.unbind(dim = -1) + x = torch.stack((-x2, x1), dim = -1) + return rearrange(x, '... d r -> ... 
(d r)') +def exists(val): + return val is not None + +def apply_rotary_emb( + freqs, + t, + start_index = 0, + scale = 1., + seq_dim = -2, + freqs_seq_dim = None +): + dtype = t.dtype + + if not exists(freqs_seq_dim): + if freqs.ndim == 2 or t.ndim == 3: + freqs_seq_dim = 0 + + if t.ndim == 3 or exists(freqs_seq_dim): + seq_len = t.shape[seq_dim] + freqs = slice_at_dim(freqs, slice(-seq_len, None), dim = freqs_seq_dim) + + rot_dim = freqs.shape[-1] + end_index = start_index + rot_dim + + assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}' + + t_left = t[..., :start_index] + t_middle = t[..., start_index:end_index] + t_right = t[..., end_index:] + + t_transformed = (t_middle * freqs.cos() * scale) + (rotate_half(t_middle) * freqs.sin() * scale) + + out = torch.cat((t_left, t_transformed, t_right), dim=-1) + + return out.type(dtype) + +class NaMMRotaryEmbedding3d(MMRotaryEmbeddingBase): + def __init__(self, dim: int): + super().__init__(dim, rope_dim=3) + + def forward( + self, + vid_q: torch.FloatTensor, # L h d + vid_k: torch.FloatTensor, # L h d + vid_shape: torch.LongTensor, # B 3 + txt_q: torch.FloatTensor, # L h d + txt_k: torch.FloatTensor, # L h d + txt_shape: torch.LongTensor, # B 1 + cache: Cache, + ) -> Tuple[ + torch.FloatTensor, + torch.FloatTensor, + torch.FloatTensor, + torch.FloatTensor, + ]: + vid_freqs, txt_freqs = cache( + "mmrope_freqs_3d", + lambda: self.get_freqs(vid_shape, txt_shape), + ) + vid_q = rearrange(vid_q, "L h d -> h L d") + vid_k = rearrange(vid_k, "L h d -> h L d") + vid_q = apply_rotary_emb(vid_freqs, vid_q.float()).to(vid_q.dtype) + vid_k = apply_rotary_emb(vid_freqs, vid_k.float()).to(vid_k.dtype) + vid_q = rearrange(vid_q, "h L d -> L h d") + vid_k = rearrange(vid_k, "h L d -> L h d") + + txt_q = rearrange(txt_q, "L h d -> h L d") + txt_k = rearrange(txt_k, "L h d -> h L d") + txt_q = apply_rotary_emb(txt_freqs, txt_q.float()).to(txt_q.dtype) + txt_k = apply_rotary_emb(txt_freqs, txt_k.float()).to(txt_k.dtype) + txt_q = rearrange(txt_q, "h L d -> L h d") + txt_k = rearrange(txt_k, "h L d -> L h d") + return vid_q, vid_k, txt_q, txt_k + + def get_freqs( + self, + vid_shape: torch.LongTensor, + txt_shape: torch.LongTensor, + ) -> Tuple[ + torch.Tensor, + torch.Tensor, + ]: + vid_freqs = self.get_axial_freqs(1024, 128, 128) + txt_freqs = self.get_axial_freqs(1024) + vid_freq_list, txt_freq_list = [], [] + for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()): + vid_freq = vid_freqs[l : l + f, :h, :w].reshape(-1, vid_freqs.size(-1)) + txt_freq = txt_freqs[:l].repeat(1, 3).reshape(-1, vid_freqs.size(-1)) + vid_freq_list.append(vid_freq) + txt_freq_list.append(txt_freq) + return torch.cat(vid_freq_list, dim=0), torch.cat(txt_freq_list, dim=0) + +class MMModule(nn.Module): + def __init__( + self, + module: Callable[..., nn.Module], + *args, + shared_weights: bool = False, + vid_only: bool = False, + **kwargs, + ): + super().__init__() + self.shared_weights = shared_weights + self.vid_only = vid_only + if self.shared_weights: + assert get_args("vid", args) == get_args("txt", args) + assert get_kwargs("vid", kwargs) == get_kwargs("txt", kwargs) + self.all = module(*get_args("vid", args), **get_kwargs("vid", kwargs)) + else: + self.vid = module(*get_args("vid", args), **get_kwargs("vid", kwargs)) + self.txt = ( + module(*get_args("txt", args), **get_kwargs("txt", kwargs)) + if not vid_only + else None + ) + + def forward( + self, + vid: torch.FloatTensor, + txt: 
torch.FloatTensor,
+        *args,
+        **kwargs,
+    ) -> Tuple[
+        torch.FloatTensor,
+        torch.FloatTensor,
+    ]:
+        vid_module = self.vid if not self.shared_weights else self.all
+        vid = vid_module(vid, *get_args("vid", args), **get_kwargs("vid", kwargs))
+        if not self.vid_only:
+            txt_module = self.txt if not self.shared_weights else self.all
+            txt = txt_module(txt, *get_args("txt", args), **get_kwargs("txt", kwargs))
+        return vid, txt
+
+def get_na_rope(rope_type: Optional[str], dim: int):
+    # the 7b variant doesn't use rope
+    if rope_type is None:
+        return None
+    if rope_type == "mmrope3d":
+        return NaMMRotaryEmbedding3d(dim=dim)
+    raise ValueError(f"unknown rope_type: {rope_type}")
+
+class NaMMAttention(nn.Module):
+    def __init__(
+        self,
+        vid_dim: int,
+        txt_dim: int,
+        heads: int,
+        head_dim: int,
+        qk_bias: bool,
+        qk_norm,
+        qk_norm_eps: float,
+        rope_type: Optional[str],
+        rope_dim: int,
+        shared_weights: bool,
+        **kwargs,
+    ):
+        super().__init__()
+        dim = MMArg(vid_dim, txt_dim)
+        inner_dim = heads * head_dim
+        qkv_dim = inner_dim * 3
+        self.head_dim = head_dim
+        self.proj_qkv = MMModule(
+            nn.Linear, dim, qkv_dim, bias=qk_bias, shared_weights=shared_weights
+        )
+        self.proj_out = MMModule(nn.Linear, inner_dim, dim, shared_weights=shared_weights)
+        self.norm_q = MMModule(
+            qk_norm,
+            dim=head_dim,
+            eps=qk_norm_eps,
+            elementwise_affine=True,
+            shared_weights=shared_weights,
+        )
+        self.norm_k = MMModule(
+            qk_norm,
+            dim=head_dim,
+            eps=qk_norm_eps,
+            elementwise_affine=True,
+            shared_weights=shared_weights,
+        )
+
+        self.rope = get_na_rope(rope_type=rope_type, dim=rope_dim)
+
+    def forward(
+        self,
+        vid: torch.FloatTensor,  # l c
+        txt: torch.FloatTensor,  # l c
+        vid_shape: torch.LongTensor,  # b 3
+        txt_shape: torch.LongTensor,  # b 1
+        cache: Cache,
+    ) -> Tuple[
+        torch.FloatTensor,
+        torch.FloatTensor,
+    ]:
+
+        vid_qkv, txt_qkv = self.proj_qkv(vid, txt)
+        vid_qkv = rearrange(vid_qkv, "l (o h d) -> l o h d", o=3, d=self.head_dim)
+        txt_qkv = rearrange(txt_qkv, "l (o h d) -> l o h d", o=3, d=self.head_dim)
+
+        vid_q, vid_k, vid_v = vid_qkv.unbind(1)
+        txt_q, txt_k, txt_v = txt_qkv.unbind(1)
+
+        vid_q, txt_q = self.norm_q(vid_q, txt_q)
+        vid_k, txt_k = self.norm_k(vid_k, txt_k)
+
+        if self.rope:
+            if self.rope.mm:
+                vid_q, vid_k, txt_q, txt_k = self.rope(
+                    vid_q, vid_k, vid_shape, txt_q, txt_k, txt_shape, cache
+                )
+            else:
+                vid_q, vid_k = self.rope(vid_q, vid_k, vid_shape, cache)
+
+        vid_len = cache("vid_len", lambda: vid_shape.prod(-1))
+        txt_len = cache("txt_len", lambda: txt_shape.prod(-1))
+        all_len = cache("all_len", lambda: vid_len + txt_len)
+
+        # the batched .view() below assumes every sample shares the same vid/txt
+        # token counts, so each batch row holds [vid tokens..., txt tokens...]
+        b = len(vid_len)
+        vq, vk, vv = [t.view(b, -1, *vid_q.shape[1:]) for t in (vid_q, vid_k, vid_v)]
+        tq, tk, tv = [t.view(b, -1, *txt_q.shape[1:]) for t in (txt_q, txt_k, txt_v)]
+
+        # optimized_attention with skip_reshape expects (b, heads, seq, head_dim)
+        q = torch.cat([vq, tq], dim=1).transpose(1, 2)
+        k = torch.cat([vk, tk], dim=1).transpose(1, 2)
+        v = torch.cat([vv, tv], dim=1).transpose(1, 2)
+
+        _, unconcat = cache("mm_pnp", lambda: concat_idx(vid_len, txt_len))
+
+        attn = optimized_attention(q, k, v, heads=q.shape[1], skip_reshape=True, skip_output_reshape=True)
+        attn = attn.transpose(1, 2).flatten(0, 1)  # back to packed (tokens, heads, head_dim)
+
+        attn = rearrange(attn, "l h d -> l (h d)")
+        vid_out, txt_out = unconcat(attn)
+
+        vid_out, txt_out = self.proj_out(vid_out, txt_out)
+        return vid_out, txt_out
+
+def window(
+    hid: torch.FloatTensor,  # (L c)
+    hid_shape: torch.LongTensor,  # (b n)
+    window_fn: Callable[[torch.Tensor], List[torch.Tensor]],
+):
+    hid = unflatten(hid, hid_shape)
+    hid = list(map(window_fn, hid))
+    hid_windows = torch.tensor(list(map(len, hid)), device=hid_shape.device)
+    hid, hid_shape = flatten(list(chain(*hid)))
+    return hid, hid_shape, hid_windows
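+
+# Partition sketch (illustrative): make_720Pwindows((16, 90, 160), (1, 3, 3))
+# returns (t, h, w) slice triples tiling the latent grid into roughly 3x3
+# spatial windows sized against a 45x80 (720p) reference; `window` applies
+# such a partition to packed tokens and also reports how many windows each
+# sample produced (used later to repeat the text stream per window).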
+
+def window_idx(
+    hid_shape: torch.LongTensor,  # (b n)
+    window_fn: Callable[[torch.Tensor], List[torch.Tensor]],
+):
+    hid_idx = torch.arange(hid_shape.prod(-1).sum(), device=hid_shape.device).unsqueeze(-1)
+    tgt_idx, tgt_shape, tgt_windows = window(hid_idx, hid_shape, window_fn)
+    tgt_idx = tgt_idx.squeeze(-1)
+    src_idx = torch.argsort(tgt_idx)
+    return (
+        lambda hid: torch.index_select(hid, 0, tgt_idx),
+        lambda hid: torch.index_select(hid, 0, src_idx),
+        tgt_shape,
+        tgt_windows,
+    )
+
+class NaSwinAttention(NaMMAttention):
+    def __init__(
+        self,
+        *args,
+        window: Union[int, Tuple[int, int, int]],
+        window_method: bool,  # shifted windows or not
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+        self.window = _triple(window)
+        self.window_method = window_method
+        assert all(map(lambda v: isinstance(v, int) and v >= 0, self.window))
+
+        # make_720Pwindows is the only partition scheme in this file; bind the
+        # shift flag here (window_method acts as the "shifted" switch)
+        self.window_op = lambda size, num_windows: make_720Pwindows(
+            size, num_windows, shift=bool(window_method)
+        )
+
+    def forward(
+        self,
+        vid: torch.FloatTensor,  # l c
+        txt: torch.FloatTensor,  # l c
+        vid_shape: torch.LongTensor,  # b 3
+        txt_shape: torch.LongTensor,  # b 1
+        cache: Cache,
+    ) -> Tuple[
+        torch.FloatTensor,
+        torch.FloatTensor,
+    ]:
+
+        vid_qkv, txt_qkv = self.proj_qkv(vid, txt)
+
+        # re-org the input seq for window attn
+        cache_win = cache.namespace(f"{self.window_method}_{self.window}_sd3")
+
+        def make_window(x: torch.Tensor):
+            t, h, w, _ = x.shape
+            window_slices = self.window_op((t, h, w), self.window)
+            return [x[st, sh, sw] for (st, sh, sw) in window_slices]
+
+        window_partition, window_reverse, window_shape, window_count = cache_win(
+            "win_transform",
+            lambda: window_idx(vid_shape, make_window),
+        )
+        vid_qkv_win = window_partition(vid_qkv)
+
+        vid_qkv_win = rearrange(vid_qkv_win, "l (o h d) -> l o h d", o=3, d=self.head_dim)
+        txt_qkv = rearrange(txt_qkv, "l (o h d) -> l o h d", o=3, d=self.head_dim)
+
+        vid_q, vid_k, vid_v = vid_qkv_win.unbind(1)
+        txt_q, txt_k, txt_v = txt_qkv.unbind(1)
+
+        vid_q, txt_q = self.norm_q(vid_q, txt_q)
+        vid_k, txt_k = self.norm_k(vid_k, txt_k)
+
+        txt_len = cache("txt_len", lambda: txt_shape.prod(-1))
+
+        vid_len_win = cache_win("vid_len", lambda: window_shape.prod(-1))
+        txt_len_win = cache_win("txt_len", lambda: txt_len.repeat_interleave(window_count))
+        all_len_win = cache_win("all_len", lambda: vid_len_win + txt_len_win)
+        concat_win, unconcat_win = cache_win(
+            "mm_pnp", lambda: repeat_concat_idx(vid_len_win, txt_len, window_count)
+        )
+
+        # window rope
+        if self.rope:
+            if self.rope.mm:
+                # repeat text q and k for window mmrope
+                _, num_h, _ = txt_q.shape
+                txt_q_repeat = rearrange(txt_q, "l h d -> l (h d)")
+                txt_q_repeat = unflatten(txt_q_repeat, txt_shape)
+                txt_q_repeat = [[x] * n for x, n in zip(txt_q_repeat, window_count)]
+                txt_q_repeat = list(chain(*txt_q_repeat))
+                txt_q_repeat, txt_shape_repeat = flatten(txt_q_repeat)
+                txt_q_repeat = rearrange(txt_q_repeat, "l (h d) -> l h d", h=num_h)
+
+                txt_k_repeat = rearrange(txt_k, "l h d -> l (h d)")
+                txt_k_repeat = unflatten(txt_k_repeat, txt_shape)
+                txt_k_repeat = [[x] * n for x, n in zip(txt_k_repeat, window_count)]
+                txt_k_repeat = list(chain(*txt_k_repeat))
+                txt_k_repeat, _ = flatten(txt_k_repeat)
+                txt_k_repeat = rearrange(txt_k_repeat, "l (h d) -> l h d", h=num_h)
+
+                vid_q, vid_k, txt_q, txt_k = self.rope(
+                    vid_q, vid_k, window_shape, txt_q_repeat, txt_k_repeat, txt_shape_repeat, cache_win
+                )
+            else:
+                vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win)
+
+        # varlen flash attention isn't available here, so run attention window
+        # by window with optimized_attention (each window is one sequence)
+        q_all = concat_win(vid_q, txt_q)
+        k_all = concat_win(vid_k, txt_k)
+        v_all = concat_win(vid_v, txt_v)
+        seqlens = cache_win("all_len_list", lambda: all_len_win.tolist())
+        out = []
+        for q_i, k_i, v_i in zip(q_all.split(seqlens), k_all.split(seqlens), v_all.split(seqlens)):
+            o_i = optimized_attention(
+                q_i.transpose(0, 1).unsqueeze(0),  # (l h d) -> (1 h l d)
+                k_i.transpose(0, 1).unsqueeze(0),
+                v_i.transpose(0, 1).unsqueeze(0),
+                heads=q_i.shape[1],
+                skip_reshape=True,
+                skip_output_reshape=True,
+            )
+            out.append(o_i.squeeze(0).transpose(0, 1))  # back to (l h d)
+        out = torch.cat(out).type_as(vid_q)
+
+        # text pooling
+        vid_out, txt_out = unconcat_win(out)
+
+        vid_out = rearrange(vid_out, "l h d -> l (h d)")
+        txt_out = rearrange(txt_out, "l h d -> l (h d)")
+        vid_out = window_reverse(vid_out)
+
+        vid_out, txt_out = self.proj_out(vid_out, txt_out)
+
+        return vid_out, txt_out
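+
+# Two feed-forward variants follow, selected by get_mlp below: a plain GELU
+# MLP and SwiGLU, where out = proj_out(SiLU(proj_in_gate(x)) * proj_in(x))
+# and the hidden width is rounded up to a multiple of 256.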
+
+class MLP(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        expand_ratio: int,
+    ):
+        super().__init__()
+        self.proj_in = nn.Linear(dim, dim * expand_ratio)
+        self.act = nn.GELU("tanh")
+        self.proj_out = nn.Linear(dim * expand_ratio, dim)
+
+    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
+        x = self.proj_in(x)
+        x = self.act(x)
+        x = self.proj_out(x)
+        return x
+
+
+class SwiGLUMLP(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        expand_ratio: int,
+        multiple_of: int = 256,
+    ):
+        super().__init__()
+        hidden_dim = int(2 * dim * expand_ratio / 3)
+        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+        self.proj_in_gate = nn.Linear(dim, hidden_dim, bias=False)
+        self.proj_out = nn.Linear(hidden_dim, dim, bias=False)
+        self.proj_in = nn.Linear(dim, hidden_dim, bias=False)
+
+    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
+        x = self.proj_out(F.silu(self.proj_in_gate(x)) * self.proj_in(x))
+        return x
+
+def get_mlp(mlp_type: Optional[str] = "normal"):
+    # the 3b and 7b variants use different MLP types
+    if mlp_type == "normal":
+        return MLP
+    elif mlp_type == "swiglu":
+        return SwiGLUMLP
+    raise ValueError(f"unknown mlp_type: {mlp_type}")
+
+class NaMMSRTransformerBlock(nn.Module):
+    def __init__(
+        self,
+        *,
+        vid_dim: int,
+        txt_dim: int,
+        emb_dim: int,
+        heads: int,
+        head_dim: int,
+        expand_ratio: int,
+        norm,
+        norm_eps: float,
+        ada,
+        qk_bias: bool,
+        qk_norm,
+        mlp_type: str,
+        shared_weights: bool,
+        rope_type: str,
+        rope_dim: int,
+        is_last_layer: bool,
+        **kwargs,
+    ):
+        super().__init__()
+        dim = MMArg(vid_dim, txt_dim)
+        self.attn_norm = MMModule(norm, dim=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights,)
+
+        self.attn = NaSwinAttention(
+            vid_dim=vid_dim,
+            txt_dim=txt_dim,
+            heads=heads,
+            head_dim=head_dim,
+            qk_bias=qk_bias,
+            qk_norm=qk_norm,
+            qk_norm_eps=norm_eps,
+            rope_type=rope_type,
+            rope_dim=rope_dim,
+            shared_weights=shared_weights,
+            window=kwargs.pop("window", None),
+            window_method=kwargs.pop("window_method", None),
+        )
+
+        self.mlp_norm = MMModule(norm, dim=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights, vid_only=is_last_layer)
+        self.mlp = MMModule(
+            get_mlp(mlp_type),
+            dim=dim,
+            expand_ratio=expand_ratio,
+            shared_weights=shared_weights,
+            vid_only=is_last_layer
+        )
+        self.ada = MMModule(ada, dim=dim, emb_dim=emb_dim, layers=["attn", "mlp"], shared_weights=shared_weights, vid_only=is_last_layer)
+        self.is_last_layer = is_last_layer
+
+    def forward(
+        self,
+        vid: torch.FloatTensor,  # l c
+        txt: torch.FloatTensor,  # l c
+        vid_shape: torch.LongTensor,  # b 3
+        txt_shape: torch.LongTensor,  # b 1
+        emb: torch.FloatTensor,
+        cache: Cache,
+    ) -> Tuple[
+        torch.FloatTensor,
+        torch.FloatTensor,
+        torch.LongTensor,
+        torch.LongTensor,
+    ]:
+        hid_len = MMArg(
+            cache("vid_len", lambda: 
vid_shape.prod(-1)), + cache("txt_len", lambda: txt_shape.prod(-1)), + ) + ada_kwargs = { + "emb": emb, + "hid_len": hid_len, + "cache": cache, + "branch_tag": MMArg("vid", "txt"), + } + + vid_attn, txt_attn = self.attn_norm(vid, txt) + vid_attn, txt_attn = self.ada(vid_attn, txt_attn, layer="attn", mode="in", **ada_kwargs) + vid_attn, txt_attn = self.attn(vid_attn, txt_attn, vid_shape, txt_shape, cache) + vid_attn, txt_attn = self.ada(vid_attn, txt_attn, layer="attn", mode="out", **ada_kwargs) + vid_attn, txt_attn = (vid_attn + vid), (txt_attn + txt) + + vid_mlp, txt_mlp = self.mlp_norm(vid_attn, txt_attn) + vid_mlp, txt_mlp = self.ada(vid_mlp, txt_mlp, layer="mlp", mode="in", **ada_kwargs) + vid_mlp, txt_mlp = self.mlp(vid_mlp, txt_mlp) + vid_mlp, txt_mlp = self.ada(vid_mlp, txt_mlp, layer="mlp", mode="out", **ada_kwargs) + vid_mlp, txt_mlp = (vid_mlp + vid_attn), (txt_mlp + txt_attn) + + return vid_mlp, txt_mlp, vid_shape, txt_shape + +class PatchOut(nn.Module): + def __init__( + self, + out_channels: int, + patch_size: Union[int, Tuple[int, int, int]], + dim: int, + ): + super().__init__() + t, h, w = _triple(patch_size) + self.patch_size = t, h, w + self.proj = nn.Linear(dim, out_channels * t * h * w) + + def forward( + self, + vid: torch.Tensor, + ) -> torch.Tensor: + t, h, w = self.patch_size + vid = self.proj(vid) + vid = rearrange(vid, "b T H W (t h w c) -> b c (T t) (H h) (W w)", t=t, h=h, w=w) + if t > 1: + vid = vid[:, :, (t - 1) :] + return vid + +class NaPatchOut(PatchOut): + def forward( + self, + vid: torch.FloatTensor, # l c + vid_shape: torch.LongTensor, + cache: Cache = Cache(disable=True), # for test + ) -> Tuple[ + torch.FloatTensor, + torch.LongTensor, + ]: + cache = cache.namespace("patch") + vid_shape_before_patchify = cache.get("vid_shape_before_patchify") + + t, h, w = self.patch_size + vid = self.proj(vid) + + if not (t == h == w == 1): + vid = unflatten(vid, vid_shape) + for i in range(len(vid)): + vid[i] = rearrange(vid[i], "T H W (t h w c) -> (T t) (H h) (W w) c", t=t, h=h, w=w) + if t > 1 and vid_shape_before_patchify[i, 0] % t != 0: + vid[i] = vid[i][(t - vid_shape_before_patchify[i, 0] % t) :] + vid, vid_shape = flatten(vid) + + return vid, vid_shape + +class PatchIn(nn.Module): + def __init__( + self, + in_channels: int, + patch_size: Union[int, Tuple[int, int, int]], + dim: int, + ): + super().__init__() + t, h, w = _triple(patch_size) + self.patch_size = t, h, w + self.proj = nn.Linear(in_channels * t * h * w, dim) + + def forward( + self, + vid: torch.Tensor, + ) -> torch.Tensor: + t, h, w = self.patch_size + if t > 1: + assert vid.size(2) % t == 1 + vid = torch.cat([vid[:, :, :1]] * (t - 1) + [vid], dim=2) + vid = rearrange(vid, "b c (T t) (H h) (W w) -> b T H W (t h w c)", t=t, h=h, w=w) + vid = self.proj(vid) + return vid + +class NaPatchIn(PatchIn): + def forward( + self, + vid: torch.Tensor, # l c + vid_shape: torch.LongTensor, + cache: Cache = Cache(disable=True), # for test + ) -> torch.Tensor: + cache = cache.namespace("patch") + vid_shape_before_patchify = cache("vid_shape_before_patchify", lambda: vid_shape) + t, h, w = self.patch_size + if not (t == h == w == 1): + vid = unflatten(vid, vid_shape) + for i in range(len(vid)): + if t > 1 and vid_shape_before_patchify[i, 0] % t != 0: + vid[i] = torch.cat([vid[i][:1]] * (t - vid[i].size(0) % t) + [vid[i]], dim=0) + vid[i] = rearrange(vid[i], "(T t) (H h) (W w) c -> T H W (t h w c)", t=t, h=h, w=w) + vid, vid_shape = flatten(vid) + + vid = self.proj(vid) + return vid, vid_shape + +def 
expand_dims(x: torch.Tensor, dim: int, ndim: int): + shape = x.shape + shape = shape[:dim] + (1,) * (ndim - len(shape)) + shape[dim:] + return x.reshape(shape) + + +class AdaSingle(nn.Module): + def __init__( + self, + dim: int, + emb_dim: int, + layers: List[str], + modes: List[str] = ["in", "out"], + ): + assert emb_dim == 6 * dim, "AdaSingle requires emb_dim == 6 * dim" + super().__init__() + self.dim = dim + self.emb_dim = emb_dim + self.layers = layers + for l in layers: + if "in" in modes: + self.register_parameter(f"{l}_shift", nn.Parameter(torch.randn(dim) / dim**0.5)) + self.register_parameter( + f"{l}_scale", nn.Parameter(torch.randn(dim) / dim**0.5 + 1) + ) + if "out" in modes: + self.register_parameter(f"{l}_gate", nn.Parameter(torch.randn(dim) / dim**0.5)) + + def forward( + self, + hid: torch.FloatTensor, # b ... c + emb: torch.FloatTensor, # b d + layer: str, + mode: str, + cache: Cache = Cache(disable=True), + branch_tag: str = "", + hid_len: Optional[torch.LongTensor] = None, # b + ) -> torch.FloatTensor: + idx = self.layers.index(layer) + emb = rearrange(emb, "b (d l g) -> b d l g", l=len(self.layers), g=3)[..., idx, :] + emb = expand_dims(emb, 1, hid.ndim + 1) + + shiftA, scaleA, gateA = emb.unbind(-1) + shiftB, scaleB, gateB = ( + getattr(self, f"{layer}_shift", None), + getattr(self, f"{layer}_scale", None), + getattr(self, f"{layer}_gate", None), + ) + + if mode == "in": + return hid.mul_(scaleA + scaleB).add_(shiftA + shiftB) + if mode == "out": + return hid.mul_(gateA + gateB) + raise NotImplementedError + + +def emb_add(emb1: torch.Tensor, emb2: Optional[torch.Tensor]): + return emb1 if emb2 is None else emb1 + emb2 + + +class TimeEmbedding(nn.Module): + def __init__( + self, + sinusoidal_dim: int, + hidden_dim: int, + output_dim: int, + ): + super().__init__() + self.sinusoidal_dim = sinusoidal_dim + self.proj_in = nn.Linear(sinusoidal_dim, hidden_dim) + self.proj_hid = nn.Linear(hidden_dim, hidden_dim) + self.proj_out = nn.Linear(hidden_dim, output_dim) + self.act = nn.SiLU() + + def forward( + self, + timestep: Union[int, float, torch.IntTensor, torch.FloatTensor], + device: torch.device, + dtype: torch.dtype, + ) -> torch.FloatTensor: + if not torch.is_tensor(timestep): + timestep = torch.tensor([timestep], device=device, dtype=dtype) + if timestep.ndim == 0: + timestep = timestep[None] + + emb = get_timestep_embedding( + timesteps=timestep, + embedding_dim=self.sinusoidal_dim, + flip_sin_to_cos=False, + downscale_freq_shift=0, + ) + emb = emb.to(dtype) + emb = self.proj_in(emb) + emb = self.act(emb) + emb = self.proj_hid(emb) + emb = self.act(emb) + emb = self.proj_out(emb) + return emb + +def flatten( + hid: List[torch.FloatTensor], # List of (*** c) +) -> Tuple[ + torch.FloatTensor, # (L c) + torch.LongTensor, # (b n) +]: + assert len(hid) > 0 + shape = torch.stack([torch.tensor(x.shape[:-1], device=hid[0].device) for x in hid]) + hid = torch.cat([x.flatten(0, -2) for x in hid]) + return hid, shape + + +def unflatten( + hid: torch.FloatTensor, # (L c) or (L ... c) + hid_shape: torch.LongTensor, # (b n) +) -> List[torch.Tensor]: # List of (*** c) or (*** ... 
c)
+    hid_len = hid_shape.prod(-1)
+    hid = hid.split(hid_len.tolist())
+    hid = [x.unflatten(0, s.tolist()) for x, s in zip(hid, hid_shape)]
+    return hid
+
+def repeat_flat(
+    hid: torch.FloatTensor,  # (L c)
+    hid_shape: torch.LongTensor,  # (b n)
+    pattern: str,
+    **kwargs: Dict[str, torch.LongTensor],  # (b)
+) -> Tuple[
+    torch.FloatTensor,
+    torch.LongTensor,
+]:
+    # named repeat_flat so it doesn't shadow einops.repeat, which the rotary
+    # embeddings above rely on
+    hid = unflatten(hid, hid_shape)
+    kwargs = [{k: v[i].item() for k, v in kwargs.items()} for i in range(len(hid))]
+    return flatten([einops.repeat(h, pattern, **a) for h, a in zip(hid, kwargs)])
+
+@dataclass
+class NaDiTOutput:
+    vid_sample: torch.Tensor
+
+
+class NaDiT(nn.Module):
+
+    def __init__(
+        self,
+        norm_eps,
+        qk_rope,
+        num_layers,
+        mlp_type,
+        vid_in_channels = 33,
+        vid_out_channels = 16,
+        vid_dim = 2560,
+        txt_in_dim = 5120,
+        heads = 20,
+        head_dim = 128,
+        expand_ratio = 4,
+        qk_bias = False,
+        patch_size = [1, 2, 2],
+        shared_qkv: bool = False,
+        shared_mlp: bool = False,
+        window_method: Optional[Tuple[str]] = None,
+        temporal_window_size: int = None,
+        temporal_shifted: bool = False,
+        **kwargs,
+    ):
+        txt_dim = vid_dim
+        emb_dim = vid_dim * 6
+        block_type = ["mmdit_sr"] * num_layers
+        window = num_layers * [(4, 3, 3)]
+        ada = AdaSingle
+        norm = RMSNorm
+        qk_norm = RMSNorm
+        if isinstance(block_type, str):
+            block_type = [block_type] * num_layers
+        elif len(block_type) != num_layers:
+            raise ValueError("The ``block_type`` list must have length ``num_layers``.")
+        super().__init__()
+        self.vid_in = NaPatchIn(
+            in_channels=vid_in_channels,
+            patch_size=patch_size,
+            dim=vid_dim,
+        )
+        self.txt_in = (
+            nn.Linear(txt_in_dim, txt_dim)
+            if txt_in_dim and txt_in_dim != txt_dim
+            else nn.Identity()
+        )
+        self.emb_in = TimeEmbedding(
+            sinusoidal_dim=256,
+            hidden_dim=max(vid_dim, txt_dim),
+            output_dim=emb_dim,
+        )
+
+        if window is None or isinstance(window[0], int):
+            window = [window] * num_layers
+        if window_method is None or isinstance(window_method, str):
+            window_method = [window_method] * num_layers
+        if temporal_window_size is None or isinstance(temporal_window_size, int):
+            temporal_window_size = [temporal_window_size] * num_layers
+        if temporal_shifted is None or isinstance(temporal_shifted, bool):
+            temporal_shifted = [temporal_shifted] * num_layers
+
+        self.blocks = nn.ModuleList(
+            [
+                NaMMSRTransformerBlock(
+                    vid_dim=vid_dim,
+                    txt_dim=txt_dim,
+                    emb_dim=emb_dim,
+                    heads=heads,
+                    head_dim=head_dim,
+                    expand_ratio=expand_ratio,
+                    norm=norm,
+                    norm_eps=norm_eps,
+                    ada=ada,
+                    qk_bias=qk_bias,
+                    qk_norm=qk_norm,
+                    # map the config flags onto the block's expected kwargs:
+                    # qk_rope names the rope variant, rope runs per attention
+                    # head, and weights are shared only when both sharing
+                    # flags are set
+                    rope_type=qk_rope,
+                    rope_dim=head_dim,
+                    shared_weights=shared_qkv and shared_mlp,
+                    mlp_type=mlp_type,
+                    is_last_layer=(i == num_layers - 1),
+                    window=window[i],
+                    window_method=window_method[i],
+                    temporal_window_size=temporal_window_size[i],
+                    temporal_shifted=temporal_shifted[i],
+                    **kwargs,
+                )
+                for i in range(num_layers)
+            ]
+        )
+        self.vid_out = NaPatchOut(
+            out_channels=vid_out_channels,
+            patch_size=patch_size,
+            dim=vid_dim,
+        )
+
+        self.need_txt_repeat = block_type[0] in [
+            "mmdit_stwin",
+            "mmdit_stwin_spatial",
+            "mmdit_stwin_3d_spatial",
+        ]
+
+    def set_gradient_checkpointing(self, enable: bool):
+        self.gradient_checkpointing = enable
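+
+    # Calling convention (illustrative): inputs arrive packed as flat token
+    # rows plus per-sample shape tables, e.g. for two clips:
+    #   vid:       (T1*H1*W1 + T2*H2*W2, vid_in_channels) rows
+    #   vid_shape: tensor([[T1, H1, W1], [T2, H2, W2]])
+    #   txt:       (L1 + L2, txt_in_dim) rows, txt_shape: tensor([[L1], [L2]])
+    # flatten()/unflatten() above convert between this packed form and
+    # per-sample tensors.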
+    def forward(
+        self,
+        vid: torch.FloatTensor,  # l c
+        txt: torch.FloatTensor,  # l c
+        vid_shape: torch.LongTensor,  # b 3
+        txt_shape: torch.LongTensor,  # b 1
+        timestep: Union[int, float, torch.IntTensor, torch.FloatTensor],  # b
+        disable_cache: bool = True,  # for test
+    ):
+        # Text input.
+        if txt_shape.size(-1) == 1 and self.need_txt_repeat:
+            txt, txt_shape = repeat_flat(txt, txt_shape, "l c -> t l c", t=vid_shape[:, 0])
+        # slice vid after patching in when using sequence parallelism
+        txt = self.txt_in(txt)
+
+        # Video input.
+        # Sequence parallel slicing is done inside patching class.
+        vid, vid_shape = self.vid_in(vid, vid_shape)
+
+        # Embedding input.
+        emb = self.emb_in(timestep, device=vid.device, dtype=vid.dtype)
+
+        # Body
+        cache = Cache(disable=disable_cache)
+        for i, block in enumerate(self.blocks):
+            vid, txt, vid_shape, txt_shape = block(
+                vid=vid,
+                txt=txt,
+                vid_shape=vid_shape,
+                txt_shape=txt_shape,
+                emb=emb,
+                cache=cache,
+            )
+
+        vid, vid_shape = self.vid_out(vid, vid_shape, cache)
+        return NaDiTOutput(vid_sample=vid)
diff --git a/comfy/ldm/seedvr/vae.py b/comfy/ldm/seedvr/vae.py
new file mode 100644
index 000000000..eb74e9442
--- /dev/null
+++ b/comfy/ldm/seedvr/vae.py
@@ -0,0 +1,1260 @@
+from contextlib import nullcontext
+from typing import Literal, Optional, Tuple
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from diffusers.models.attention_processor import Attention
+from diffusers.models.upsampling import Upsample2D
+from einops import rearrange
+
+from comfy.ldm.seedvr.model import safe_pad_operation
+from comfy.ldm.hunyuan3d.vae import DiagonalGaussianDistribution
+
+class SpatialNorm(nn.Module):
+    def __init__(
+        self,
+        f_channels: int,
+        zq_channels: int,
+    ):
+        super().__init__()
+        self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
+        self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
+        self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
+
+    def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
+        f_size = f.shape[-2:]
+        zq = F.interpolate(zq, size=f_size, mode="nearest")
+        norm_f = self.norm_layer(f)
+        new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
+        return new_f
+
+def causal_norm_wrapper(norm_layer: nn.Module, x: torch.Tensor) -> torch.Tensor:
+    input_dtype = x.dtype
+    if isinstance(norm_layer, (nn.LayerNorm, nn.RMSNorm)):
+        if x.ndim == 4:
+            x = rearrange(x, "b c h w -> b h w c")
+            x = norm_layer(x)
+            x = rearrange(x, "b h w c -> b c h w")
+            return x.to(input_dtype)
+        if x.ndim == 5:
+            x = rearrange(x, "b c t h w -> b t h w c")
+            x = norm_layer(x)
+            x = rearrange(x, "b t h w c -> b c t h w")
+            return x.to(input_dtype)
+    if isinstance(norm_layer, (nn.GroupNorm, nn.BatchNorm2d, nn.SyncBatchNorm)):
+        if x.ndim <= 4:
+            return norm_layer(x).to(input_dtype)
+        if x.ndim == 5:
+            t = x.size(2)
+            x = rearrange(x, "b c t h w -> (b t) c h w")
+            memory_occupy = x.numel() * x.element_size() / 1024**3
+            if isinstance(norm_layer, nn.GroupNorm) and memory_occupy > float("inf"):  # TODO: this may be set dynamically from the vae
+                num_chunks = min(4 if x.element_size() == 2 else 2, norm_layer.num_groups)
+                assert norm_layer.num_groups % num_chunks == 0
+                num_groups_per_chunk = norm_layer.num_groups // num_chunks
+
+                x = list(x.chunk(num_chunks, dim=1))
+                weights = norm_layer.weight.chunk(num_chunks, dim=0)
+                biases = norm_layer.bias.chunk(num_chunks, dim=0)
+                for i, (w, b) in enumerate(zip(weights, biases)):
+                    x[i] = F.group_norm(x[i], num_groups_per_chunk, w, b, norm_layer.eps)
+                    x[i] = x[i].to(input_dtype)
+                x = torch.cat(x, dim=1)
+            else:
+                x = norm_layer(x)
+            x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
+            return x.to(input_dtype)
+    raise NotImplementedError
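+
+# Usage sketch (illustrative): GroupNorm is defined on 4D maps, so 5D video
+# tensors are folded to (b*t, c, h, w) before normalization and restored
+# afterwards; output matches the input shape and dtype.
+#   norm = nn.GroupNorm(num_groups=32, num_channels=64)
+#   x = torch.randn(1, 64, 5, 32, 32)    # b c t h w
+#   y = causal_norm_wrapper(norm, x)     # (1, 64, 5, 32, 32)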
+
+def safe_interpolate_operation(x, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None):
+    """Safe interpolate operation that handles Half precision for problematic modes"""
+    # modes that can fail under Half precision on some backends
+    problematic_modes = ['bilinear', 'bicubic', 'trilinear']
+
+    if mode in problematic_modes:
+        try:
+            return F.interpolate(
+                x,
+                size=size,
+                scale_factor=scale_factor,
+                mode=mode,
+                align_corners=align_corners,
+                recompute_scale_factor=recompute_scale_factor
+            )
+        except RuntimeError as e:
+            if ("not implemented for 'Half'" in str(e) or
+                "compute_indices_weights" in str(e)):
+                original_dtype = x.dtype
+                return F.interpolate(
+                    x.float(),
+                    size=size,
+                    scale_factor=scale_factor,
+                    mode=mode,
+                    align_corners=align_corners,
+                    recompute_scale_factor=recompute_scale_factor
+                ).to(original_dtype)
+            else:
+                raise e
+    else:
+        # 'nearest' and the other compatible modes need no fix
+        return F.interpolate(
+            x,
+            size=size,
+            scale_factor=scale_factor,
+            mode=mode,
+            align_corners=align_corners,
+            recompute_scale_factor=recompute_scale_factor
+        )
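+
+# Sketch of the fallback behaviour (illustrative): bilinear interpolation of
+# a float16 tensor on backends without Half kernels is retried in float32 and
+# cast back, so callers keep their original dtype.
+#   x = torch.randn(1, 3, 8, 8, dtype=torch.float16)
+#   y = safe_interpolate_operation(x, scale_factor=2, mode="bilinear")
+#   assert y.dtype == torch.float16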
+
+_receptive_field_t = Literal["half", "full"]
+
+class InflatedCausalConv3d(nn.Conv3d):
+    def __init__(
+        self,
+        *args,
+        inflation_mode,
+        **kwargs,
+    ):
+        self.inflation_mode = inflation_mode
+        self.memory = None
+        super().__init__(*args, **kwargs)
+        # fold the temporal padding out of the conv itself; it is applied
+        # causally (past frames only) in forward() instead
+        self.temporal_padding = self.padding[0]
+        self.padding = (0, *self.padding[1:])
+        self.memory_limit = float("inf")
+
+    def forward(
+        self,
+        input,
+    ):
+        if self.temporal_padding > 0:
+            # replicate the first frame so the receptive field never reaches
+            # into future frames and the frame count is preserved
+            input = safe_pad_operation(
+                input, (0, 0, 0, 0, 2 * self.temporal_padding, 0), mode="replicate"
+            )
+        return super().forward(input)
+
+    def _load_from_state_dict(
+        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+    ):
+        # weirdly, inflation_mode is "pad" here, which would trip an assert in
+        # the weight-inflation helpers, so inflation is skipped on load:
+        #if self.inflation_mode != "none":
+        #    state_dict = modify_state_dict(
+        #        self,
+        #        state_dict,
+        #        prefix,
+        #        inflate_weight_fn=inflate_weight,
+        #        inflate_bias_fn=inflate_bias,
+        #    )
+        super()._load_from_state_dict(
+            state_dict,
+            prefix,
+            local_metadata,
+            (strict and self.inflation_mode == "none"),
+            missing_keys,
+            unexpected_keys,
+            error_msgs,
+        )
+
+class Upsample3D(nn.Module):
+
+    def __init__(
+        self,
+        channels,
+        out_channels = None,
+        inflation_mode = "tail",
+        temporal_up: bool = False,
+        spatial_up: bool = True,
+        slicing: bool = False,
+        interpolate = True,
+        name: str = "conv",
+        use_conv_transpose = False,
+        use_conv: bool = False,
+        padding = 1,
+        bias = True,
+        kernel_size = None,
+        **kwargs,
+    ):
+        super().__init__()
+        self.interpolate = interpolate
+        self.channels = channels
+        self.out_channels = out_channels or channels
+        self.use_conv_transpose = use_conv_transpose
+        self.use_conv = use_conv
+        self.name = name
+
+        self.conv = None
+        if use_conv_transpose:
+            if kernel_size is None:
+                kernel_size = 4
+            self.conv = nn.ConvTranspose2d(
+                channels, self.out_channels, kernel_size=kernel_size, stride=2, padding=padding, bias=bias
+            )
+        elif use_conv:
+            if kernel_size is None:
+                kernel_size = 3
+            self.conv = nn.Conv2d(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias)
+
+        conv = self.conv if self.name == "conv" else getattr(self, "Conv2d_0", None)
+
+        assert type(conv) is not nn.ConvTranspose2d
+        # Note: lora_layer is not passed into constructor in the original implementation.
+        # So we make a simplification.
+        conv = InflatedCausalConv3d(
+            self.channels,
+            self.out_channels,
+            3,
+            padding=1,
+            inflation_mode=inflation_mode,
+        )
+
+        self.temporal_up = temporal_up
+        self.spatial_up = spatial_up
+        self.temporal_ratio = 2 if temporal_up else 1
+        self.spatial_ratio = 2 if spatial_up else 1
+        self.slicing = slicing
+
+        assert not self.interpolate
+        # [Override] MAGViT v2 implementation
+        if not self.interpolate:
+            upscale_ratio = (self.spatial_ratio**2) * self.temporal_ratio
+            self.upscale_conv = nn.Conv3d(
+                self.channels, self.channels * upscale_ratio, kernel_size=1, padding=0
+            )
+            identity = (
+                torch.eye(self.channels)
+                .repeat(upscale_ratio, 1)
+                .reshape_as(self.upscale_conv.weight)
+            )
+            self.upscale_conv.weight.data.copy_(identity)
+
+        if self.name == "conv":
+            self.conv = conv
+        else:
+            self.Conv2d_0 = conv
+
+        # None (not False) so the hasattr/None check in forward() skips it
+        self.norm = None
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        **kwargs,
+    ) -> torch.FloatTensor:
+        assert hidden_states.shape[1] == self.channels
+
+        if hasattr(self, "norm") and self.norm is not None:
+            # [Overridden] change to causal norm.
+            hidden_states = causal_norm_wrapper(self.norm, hidden_states)
+
+        if self.use_conv_transpose:
+            return self.conv(hidden_states)
+
+        if self.slicing:
+            split_size = hidden_states.size(2) // 2
+            hidden_states = list(
+                hidden_states.split([split_size, hidden_states.size(2) - split_size], dim=2)
+            )
+        else:
+            hidden_states = [hidden_states]
+
+        for i in range(len(hidden_states)):
+            hidden_states[i] = self.upscale_conv(hidden_states[i])
+            hidden_states[i] = rearrange(
+                hidden_states[i],
+                "b (x y z c) f h w -> b c (f z) (h x) (w y)",
+                x=self.spatial_ratio,
+                y=self.spatial_ratio,
+                z=self.temporal_ratio,
+            )
+
+        if not self.slicing:
+            hidden_states = hidden_states[0]
+
+        if self.use_conv:
+            if self.name == "conv":
+                hidden_states = self.conv(hidden_states)
+            else:
+                hidden_states = self.Conv2d_0(hidden_states)
+
+        if not self.slicing:
+            return hidden_states
+        else:
+            return torch.cat(hidden_states, dim=2)
+
+
+class Downsample3D(nn.Module):
+    """A 3D downsampling layer with an optional convolution."""
+
+    def __init__(
+        self,
+        channels,
+        out_channels = None,
+        use_conv: bool = True,
+        inflation_mode = "tail",
+        spatial_down: bool = False,
+        temporal_down: bool = False,
+        name: str = "conv",
+        padding = 1,
+        **kwargs,
+    ):
+        super().__init__()
+        self.padding = padding
+        self.name = name
+        self.channels = channels
+        self.out_channels = out_channels or channels
+        # use_conv selects between a strided causal conv and average pooling
+        self.use_conv = use_conv
+        self.temporal_down = temporal_down
+        self.spatial_down = spatial_down
+
+        self.temporal_ratio = 2 if temporal_down else 1
+        self.spatial_ratio = 2 if spatial_down else 1
+
+        self.temporal_kernel = 3 if temporal_down else 1
+        self.spatial_kernel = 3 if spatial_down else 1
+
+        if use_conv:
+            # Note: lora_layer is not passed into constructor in the original implementation.
+            # So we make a simplification.
+            conv = InflatedCausalConv3d(
+                self.channels,
+                self.out_channels,
+                kernel_size=(self.temporal_kernel, self.spatial_kernel, self.spatial_kernel),
+                stride=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio),
+                padding=(
+                    1 if self.temporal_down else 0,
+                    self.padding if self.spatial_down else 0,
+                    self.padding if self.spatial_down else 0,
+                ),
+                inflation_mode=inflation_mode,
+            )
+        else:
+            assert self.channels == self.out_channels
+            conv = nn.AvgPool3d(
+                kernel_size=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio),
+                stride=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio),
+            )
+
+        if self.name == "conv":
+            self.Conv2d_0 = conv
+            self.conv = conv
+        else:
+            self.conv = conv
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        **kwargs,
+    ) -> torch.FloatTensor:
+
+        assert hidden_states.shape[1] == self.channels
+
+        if hasattr(self, "norm") and self.norm is not None:
+            # [Overridden] change to causal norm.
+            hidden_states = causal_norm_wrapper(self.norm, hidden_states)
+
+        if self.use_conv and self.padding == 0 and self.spatial_down:
+            pad = (0, 1, 0, 1)
+            hidden_states = safe_pad_operation(hidden_states, pad, mode="constant", value=0)
+
+        hidden_states = self.conv(hidden_states)
+
+        return hidden_states
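+
+# Upsample3D uses a MAGViT-v2 style learnable depth-to-space: a 1x1x1 conv
+# multiplies channels by (spatial_ratio^2 * temporal_ratio), then rearrange
+# unfolds them into space/time. Rough sketch of the shapes:
+#   up = Upsample3D(8, temporal_up=True, spatial_up=True, interpolate=False)
+#   x = torch.randn(1, 8, 4, 16, 16)    # b c t h w
+#   up(x).shape                         # -> (1, 8, 8, 32, 32)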
+
+class ResnetBlock3D(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: Optional[int] = None,
+        conv_shortcut: bool = False,
+        dropout: float = 0.0,
+        temb_channels: int = 512,
+        groups: int = 32,
+        groups_out: Optional[int] = None,
+        eps: float = 1e-6,
+        non_linearity: str = "swish",
+        time_embedding_norm: str = "default",
+        output_scale_factor: float = 1.0,
+        skip_time_act: bool = False,
+        use_in_shortcut: Optional[bool] = None,
+        up: bool = False,
+        down: bool = False,
+        conv_shortcut_bias: bool = True,
+        conv_2d_out_channels: Optional[int] = None,
+        inflation_mode = "tail",
+        time_receptive_field: _receptive_field_t = "half",
+        slicing: bool = False,
+        **kwargs,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        out_channels = in_channels if out_channels is None else out_channels
+        self.out_channels = out_channels
+        self.up = up
+        self.down = down
+        self.output_scale_factor = output_scale_factor
+        self.skip_time_act = skip_time_act
+        self.nonlinearity = nn.SiLU()
+
+        if groups_out is None:
+            groups_out = groups
+
+        self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
+        self.norm2 = nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
+        self.dropout = nn.Dropout(dropout)
+
+        if temb_channels is not None:
+            self.time_emb_proj = nn.Linear(temb_channels, out_channels)
+        else:
+            self.time_emb_proj = None
+
+        self.conv1 = InflatedCausalConv3d(
+            in_channels,
+            out_channels,
+            kernel_size=(1, 3, 3) if time_receptive_field == "half" else (3, 3, 3),
+            stride=1,
+            padding=(0, 1, 1) if time_receptive_field == "half" else (1, 1, 1),
+            inflation_mode=inflation_mode,
+        )
+
+        conv_2d_out_channels = conv_2d_out_channels or out_channels
+        self.conv2 = InflatedCausalConv3d(
+            out_channels,
+            conv_2d_out_channels,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+            inflation_mode=inflation_mode,
+        )
+
+        self.upsample = self.downsample = None
+        if self.up:
+            self.upsample = Upsample3D(
+                in_channels,
+                use_conv=False,
+                interpolate=False,
+                inflation_mode=inflation_mode,
+                slicing=slicing,
+            )
+        elif self.down:
+            self.downsample = Downsample3D(
+                in_channels,
+                use_conv=False,
+                padding=1,
+                name="op",
+                inflation_mode=inflation_mode,
+            )
+
+        self.use_in_shortcut = (
+            self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut
+        )
+        self.conv_shortcut = None
+        if self.use_in_shortcut:
+            self.conv_shortcut = InflatedCausalConv3d(
+                in_channels,
+                conv_2d_out_channels,
+                kernel_size=1,
+                stride=1,
+                padding=0,
+                bias=conv_shortcut_bias,
+                inflation_mode=inflation_mode,
+            )
+
+    def forward(
+        self, input_tensor, temb, **kwargs
+    ):
+        hidden_states = input_tensor
+
+        hidden_states = causal_norm_wrapper(self.norm1, hidden_states)
+
+        hidden_states = self.nonlinearity(hidden_states)
+
+        if self.upsample is not None:
+            if hidden_states.shape[0] >= 64:
+                input_tensor = input_tensor.contiguous()
+                hidden_states = hidden_states.contiguous()
+            input_tensor = self.upsample(input_tensor)
+            hidden_states = self.upsample(hidden_states)
+        elif self.downsample is not None:
+            input_tensor = self.downsample(input_tensor)
+            hidden_states = self.downsample(hidden_states)
+
+        hidden_states = self.conv1(hidden_states)
+
+        if self.time_emb_proj is not None:
+            if not self.skip_time_act:
+                temb = self.nonlinearity(temb)
+            # broadcast over the (t, h, w) dims of the 5D feature map
+            temb = self.time_emb_proj(temb)[:, :, None, None, None]
+
+        if temb is not None:
+            hidden_states = hidden_states + temb
+
+        hidden_states = causal_norm_wrapper(self.norm2, hidden_states)
+
+        hidden_states = self.nonlinearity(hidden_states)
+
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.conv2(hidden_states)
+
+        if self.conv_shortcut is not None:
+            input_tensor = self.conv_shortcut(input_tensor)
+
+        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
+
+        return output_tensor
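+
+# time_receptive_field selects the temporal kernel of the first conv: "half"
+# keeps it (1, 3, 3) (purely spatial), "full" uses (3, 3, 3) so the block also
+# mixes neighbouring frames. Sketch:
+#   block = ResnetBlock3D(in_channels=64, out_channels=64, temb_channels=None,
+#                         time_receptive_field="full")
+#   block(torch.randn(1, 64, 5, 16, 16), temb=None).shape   # shape unchanged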
+
+class DownEncoderBlock3D(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
+        output_scale_factor: float = 1.0,
+        add_downsample: bool = True,
+        downsample_padding: int = 1,
+        inflation_mode = "tail",
+        time_receptive_field: _receptive_field_t = "half",
+        temporal_down: bool = True,
+        spatial_down: bool = True,
+    ):
+        super().__init__()
+        resnets = []
+        temporal_modules = []
+
+        for i in range(num_layers):
+            in_channels = in_channels if i == 0 else out_channels
+            resnets.append(
+                # [Override] Replace module.
+                ResnetBlock3D(
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    temb_channels=None,
+                    eps=resnet_eps,
+                    groups=resnet_groups,
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                    inflation_mode=inflation_mode,
+                    time_receptive_field=time_receptive_field,
+                )
+            )
+            temporal_modules.append(nn.Identity())
+
+        self.resnets = nn.ModuleList(resnets)
+        self.temporal_modules = nn.ModuleList(temporal_modules)
+
+        if add_downsample:
+            self.downsamplers = nn.ModuleList(
+                [
+                    # [Override] Replace module.
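+                    # temporal_down is enabled only for the later stages (see
+                    # Encoder3D), so early stages halve H/W but keep the frame
+                    # count; the final stage adds no downsampler at all.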
+ Downsample3D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + temporal_down=temporal_down, + spatial_down=spatial_down, + inflation_mode=inflation_mode, + ) + ] + ) + else: + self.downsamplers = None + + def forward( + self, + hidden_states: torch.FloatTensor, + **kwargs, + ) -> torch.FloatTensor: + for resnet, temporal in zip(self.resnets, self.temporal_modules): + hidden_states = resnet(hidden_states, temb=None) + hidden_states = temporal(hidden_states) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states + + +class UpDecoderBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", # default, spatial + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + temb_channels: Optional[int] = None, + inflation_mode = "tail", + time_receptive_field: _receptive_field_t = "half", + temporal_up: bool = True, + spatial_up: bool = True, + slicing: bool = False, + ): + super().__init__() + resnets = [] + temporal_modules = [] + + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + resnets.append( + # [Override] Replace module. + ResnetBlock3D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + slicing=slicing, + ) + ) + + temporal_modules.append(nn.Identity()) + + self.resnets = nn.ModuleList(resnets) + self.temporal_modules = nn.ModuleList(temporal_modules) + + if add_upsample: + # [Override] Replace module & use learnable upsample + self.upsamplers = nn.ModuleList( + [ + Upsample3D( + out_channels, + use_conv=True, + out_channels=out_channels, + temporal_up=temporal_up, + spatial_up=spatial_up, + interpolate=False, + inflation_mode=inflation_mode, + slicing=slicing, + ) + ] + ) + else: + self.upsamplers = None + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + for resnet, temporal in zip(self.resnets, self.temporal_modules): + hidden_states = resnet(hidden_states, temb=None) + hidden_states = temporal(hidden_states) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class UNetMidBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", # default, spatial + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + add_attention: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + inflation_mode = "tail", + time_receptive_field: _receptive_field_t = "half", + ): + super().__init__() + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + self.add_attention = add_attention + + # there is always at least one resnet + resnets = [ + 
# [Override] Replace module. + ResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + ) + ] + attentions = [] + + if attention_head_dim is None: + print( + f"It is not recommend to pass `attention_head_dim=None`. " + f"Defaulting `attention_head_dim` to `in_channels`: {in_channels}." + ) + attention_head_dim = in_channels + + for _ in range(num_layers): + if self.add_attention: + attentions.append( + Attention( + in_channels, + heads=in_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=( + resnet_groups if resnet_time_scale_shift == "default" else None + ), + spatial_norm_dim=( + temb_channels if resnet_time_scale_shift == "spatial" else None + ), + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + else: + attentions.append(None) + + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states, temb=None): + video_length, frame_height, frame_width = hidden_states.size()[-3:] + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if attn is not None: + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + hidden_states = attn(hidden_states, temb=temb) + hidden_states = rearrange( + hidden_states, "(b f) c h w -> b c f h w", f=video_length + ) + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class Encoder3D(nn.Module): + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str, ...] = ("DownEncoderBlock3D",), + block_out_channels: Tuple[int, ...] 
= (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + double_z: bool = True, + mid_block_add_attention=True, + # [Override] add extra_cond_dim, temporal down num + temporal_down_num: int = 2, + extra_cond_dim: int = None, + gradient_checkpoint: bool = False, + inflation_mode = "tail", + time_receptive_field: _receptive_field_t = "half", + ): + super().__init__() + self.layers_per_block = layers_per_block + self.temporal_down_num = temporal_down_num + + self.conv_in = InflatedCausalConv3d( + in_channels, + block_out_channels[0], + kernel_size=3, + stride=1, + padding=1, + inflation_mode=inflation_mode, + ) + + self.mid_block = None + self.down_blocks = nn.ModuleList([]) + self.extra_cond_dim = extra_cond_dim + + self.conv_extra_cond = nn.ModuleList([]) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + # [Override] to support temporal down block design + is_temporal_down_block = i >= len(block_out_channels) - self.temporal_down_num - 1 + # Note: take the last ones + + assert down_block_type == "DownEncoderBlock3D" + + down_block = DownEncoderBlock3D( + num_layers=self.layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + add_downsample=not is_final_block, + resnet_eps=1e-6, + downsample_padding=0, + # Note: Don't know why set it as 0 + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + temporal_down=is_temporal_down_block, + spatial_down=True, + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + ) + self.down_blocks.append(down_block) + + def zero_module(module): + # Zero out the parameters of a module and return it. 
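+            # Zero-initialising these 1x1x1 convs makes the extra-cond branch a
+            # no-op at the start of training (in the spirit of ControlNet-style
+            # zero convolutions): `extra_block(extra_cond)` contributes nothing
+            # until the weights move away from zero, so pretrained encoder
+            # behaviour is preserved.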
+ for p in module.parameters(): + p.detach().zero_() + return module + + self.conv_extra_cond.append( + zero_module( + nn.Conv3d(extra_cond_dim, output_channel, kernel_size=1, stride=1, padding=0) + ) + if self.extra_cond_dim is not None and self.extra_cond_dim > 0 + else None + ) + + # mid + self.mid_block = UNetMidBlock3D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default", + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + temb_channels=None, + add_attention=mid_block_add_attention, + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + ) + + # out + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6 + ) + self.conv_act = nn.SiLU() + + conv_out_channels = 2 * out_channels if double_z else out_channels + self.conv_out = InflatedCausalConv3d( + block_out_channels[-1], conv_out_channels, 3, padding=1, inflation_mode=inflation_mode + ) + + self.gradient_checkpointing = gradient_checkpoint + + def forward( + self, + sample: torch.FloatTensor, + extra_cond=None, + ) -> torch.FloatTensor: + r"""The forward method of the `Encoder` class.""" + sample = self.conv_in(sample) + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + # down + # [Override] add extra block and extra cond + for down_block, extra_block in zip(self.down_blocks, self.conv_extra_cond): + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(down_block), sample, use_reentrant=False + ) + if extra_block is not None: + sample = sample + safe_interpolate_operation(extra_block(extra_cond), size=sample.shape[2:]) + + # middle + sample = self.mid_block(sample) + + else: + # down + # [Override] add extra block and extra cond + for down_block, extra_block in zip(self.down_blocks, self.conv_extra_cond): + sample = down_block(sample) + if extra_block is not None: + sample = sample + safe_interpolate_operation(extra_block(extra_cond), size=sample.shape[2:]) + + # middle + sample = self.mid_block(sample) + + # post-process + sample = causal_norm_wrapper(self.conv_norm_out, sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + return sample + + +class Decoder3D(nn.Module): + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + up_block_types: Tuple[str, ...] = ("UpDecoderBlock3D",), + block_out_channels: Tuple[int, ...] 
= (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + norm_type: str = "group", # group, spatial + mid_block_add_attention=True, + # [Override] add temporal up block + inflation_mode = "tail", + time_receptive_field: _receptive_field_t = "half", + temporal_up_num: int = 2, + slicing_up_num: int = 0, + gradient_checkpoint: bool = False, + ): + super().__init__() + self.layers_per_block = layers_per_block + self.temporal_up_num = temporal_up_num + + self.conv_in = InflatedCausalConv3d( + in_channels, + block_out_channels[-1], + kernel_size=3, + stride=1, + padding=1, + inflation_mode=inflation_mode, + ) + + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + temb_channels = in_channels if norm_type == "spatial" else None + + # mid + self.mid_block = UNetMidBlock3D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default" if norm_type == "group" else norm_type, + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + temb_channels=temb_channels, + add_attention=mid_block_add_attention, + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + print(f"slicing_up_num: {slicing_up_num}") + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + + is_final_block = i == len(block_out_channels) - 1 + is_temporal_up_block = i < self.temporal_up_num + is_slicing_up_block = i >= len(block_out_channels) - slicing_up_num + # Note: Keep symmetric + + assert up_block_type == "UpDecoderBlock3D" + up_block = UpDecoderBlock3D( + num_layers=self.layers_per_block + 1, + in_channels=prev_output_channel, + out_channels=output_channel, + add_upsample=not is_final_block, + resnet_eps=1e-6, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + resnet_time_scale_shift=norm_type, + temb_channels=temb_channels, + temporal_up=is_temporal_up_block, + slicing=is_slicing_up_block, + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_type == "spatial": + self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels) + else: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6 + ) + self.conv_act = nn.SiLU() + self.conv_out = InflatedCausalConv3d( + block_out_channels[0], out_channels, 3, padding=1, inflation_mode=inflation_mode + ) + + self.gradient_checkpointing = gradient_checkpoint + + # Note: Just copy from Decoder. 
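+    # Decoding mirrors Encoder3D in reverse: conv_in -> mid block (with
+    # attention) -> up blocks (only the first `temporal_up_num` of them also
+    # upsample time) -> causal group norm -> SiLU -> conv_out. `latent_embeds`
+    # is only consumed by the spatial-norm variant; with norm_type="group" it
+    # stays None throughout.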
+ def forward( + self, + sample: torch.FloatTensor, + latent_embeds: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + + sample = self.conv_in(sample) + + upscale_dtype = next(iter(self.up_blocks.parameters())).dtype + # middle + sample = self.mid_block(sample, latent_embeds) + sample = sample.to(upscale_dtype) + + # up + for up_block in self.up_blocks: + sample = up_block(sample, latent_embeds) + + # post-process + sample = causal_norm_wrapper(self.conv_norm_out, sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + return sample + +class VideoAutoencoderKL(nn.Module): + """ + We simply inherit the model code from diffusers + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str] = ("DownEncoderBlock3D",), + up_block_types: Tuple[str] = ("UpDecoderBlock3D",), + block_out_channels: Tuple[int] = (64,), + layers_per_block: int = 1, + act_fn: str = "silu", + latent_channels: int = 4, + norm_num_groups: int = 32, + attention: bool = True, + temporal_scale_num: int = 2, + slicing_up_num: int = 0, + gradient_checkpoint: bool = False, + inflation_mode = "tail", + time_receptive_field: _receptive_field_t = "full", + use_quant_conv: bool = True, + use_post_quant_conv: bool = True, + *args, + **kwargs, + ): + extra_cond_dim = kwargs.pop("extra_cond_dim") if "extra_cond_dim" in kwargs else None + super().__init__() + + # pass init params to Encoder + self.encoder = Encoder3D( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + double_z=True, + extra_cond_dim=extra_cond_dim, + # [Override] add temporal_down_num parameter + temporal_down_num=temporal_scale_num, + gradient_checkpoint=gradient_checkpoint, + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + ) + + # pass init params to Decoder + self.decoder = Decoder3D( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + norm_num_groups=norm_num_groups, + act_fn=act_fn, + # [Override] add temporal_up_num parameter + temporal_up_num=temporal_scale_num, + slicing_up_num=slicing_up_num, + gradient_checkpoint=gradient_checkpoint, + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + ) + + self.quant_conv = ( + InflatedCausalConv3d( + in_channels=2 * latent_channels, + out_channels=2 * latent_channels, + kernel_size=1, + inflation_mode=inflation_mode, + ) + if use_quant_conv + else None + ) + self.post_quant_conv = ( + InflatedCausalConv3d( + in_channels=latent_channels, + out_channels=latent_channels, + kernel_size=1, + inflation_mode=inflation_mode, + ) + if use_post_quant_conv + else None + ) + + # A hacky way to remove attention. 
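+        # Works because UNetMidBlock3D.forward skips attention entries that
+        # are None, so only the resnet path stays active.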
+ if not attention: + self.encoder.mid_block.attentions = torch.nn.ModuleList([None]) + self.decoder.mid_block.attentions = torch.nn.ModuleList([None]) + + def encode(self, x: torch.FloatTensor, return_dict: bool = True): + h = self.slicing_encode(x) + posterior = DiagonalGaussianDistribution(h).sample() + + if not return_dict: + return (posterior,) + + return posterior + + def decode( + self, z: torch.Tensor, return_dict: bool = True + ): + decoded = self.slicing_decode(z) + + if not return_dict: + return (decoded,) + + return decoded + + def _encode( + self, x: torch.Tensor + ) -> torch.Tensor: + _x = x.to(self.device) + h = self.encoder(_x,) + if self.quant_conv is not None: + output = self.quant_conv(h) + else: + output = h + return output.to(x.device) + + def _decode( + self, z: torch.Tensor + ) -> torch.Tensor: + _z = z.to(self.device) + if self.post_quant_conv is not None: + _z = self.post_quant_conv(_z) + output = self.decoder(_z) + return output.to(z.device) + + def slicing_encode(self, x: torch.Tensor) -> torch.Tensor: + return self._encode(x) + + def slicing_decode(self, z: torch.Tensor) -> torch.Tensor: + return self._decode(z) + + def tiled_encode(self, x: torch.Tensor, **kwargs) -> torch.Tensor: + raise NotImplementedError + + def tiled_decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor: + raise NotImplementedError + + def forward( + self, x: torch.FloatTensor, mode: Literal["encode", "decode", "all"] = "all", **kwargs + ): + # x: [b c t h w] + if mode == "encode": + h = self.encode(x) + return h.latent_dist + elif mode == "decode": + h = self.decode(x) + return h.sample + else: + h = self.encode(x) + h = self.decode(h.latent_dist.mode()) + return h.sample + + def load_state_dict(self, state_dict, strict=False): + # Newer version of diffusers changed the model keys, + # causing incompatibility with old checkpoints. + # They provided a method for conversion. + # We call conversion before loading state_dict. 
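+        # Guarded with getattr: the conversion helper comes from diffusers'
+        # ModelMixin and may be absent on this plain nn.Module port, in which
+        # case loading proceeds without conversion.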
+ convert_deprecated_attention_blocks = getattr( + self, "_convert_deprecated_attention_blocks", None + ) + if callable(convert_deprecated_attention_blocks): + convert_deprecated_attention_blocks(state_dict) + return super().load_state_dict(state_dict, strict) + + +class VideoAutoencoderKLWrapper(VideoAutoencoderKL): + def __init__( + self, + *args, + spatial_downsample_factor = 8, + temporal_downsample_factor = 4, + freeze_encoder = True, + **kwargs, + ): + self.spatial_downsample_factor = spatial_downsample_factor + self.temporal_downsample_factor = temporal_downsample_factor + self.freeze_encoder = freeze_encoder + super().__init__(*args, **kwargs) + + def forward(self, x: torch.FloatTensor): + with torch.no_grad() if self.freeze_encoder else nullcontext(): + z, p = self.encode(x) + x = self.decode(z).sample + return x, z, p + + def encode(self, x: torch.FloatTensor): + if x.ndim == 4: + x = x.unsqueeze(2) + p = super().encode(x).latent_dist + z = p.sample().squeeze(2) + return z, p + + def decode(self, z: torch.FloatTensor): + if z.ndim == 4: + z = z.unsqueeze(2) + x = super().decode(z).sample.squeeze(2) + return x + + def preprocess(self, x: torch.Tensor): + # x should in [B, C, T, H, W], [B, C, H, W] + assert x.ndim == 4 or x.size(2) % 4 == 1 + return x + + def postprocess(self, x: torch.Tensor): + # x should in [B, C, T, H, W], [B, C, H, W] + return x + + def set_memory_limit(self, conv_max_mem: Optional[float], norm_max_mem: Optional[float]): + # TODO + #set_norm_limit(norm_max_mem) + for m in self.modules(): + if isinstance(m, InflatedCausalConv3d): + m.set_memory_limit(conv_max_mem if conv_max_mem is not None else float("inf")) \ No newline at end of file diff --git a/comfy/model_base.py b/comfy/model_base.py index 4392355ea..bbab8627a 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -42,6 +42,7 @@ import comfy.ldm.hidream.model import comfy.ldm.chroma.model import comfy.ldm.ace.model import comfy.ldm.omnigen.omnigen2 +import comfy.ldm.seedvr.model import comfy.model_management import comfy.patcher_extension @@ -793,6 +794,11 @@ class HunyuanDiT(BaseModel): out['image_meta_size'] = comfy.conds.CONDRegular(torch.FloatTensor([[height, width, target_height, target_width, 0, 0]])) return out + +class SeedVR2(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device, comfy.ldm.seedvr.model.NaDiT) + # TODO: extra_conds could be needed to add class PixArt(BaseModel): def __init__(self, model_config, model_type=ModelType.EPS, device=None): diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 18232ade3..600c089fa 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -341,6 +341,17 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["axes_dims"] = [32, 32, 32] dit_config["axes_lens"] = [300, 512, 512] return dit_config + + elif "{}blocks.31.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 3b + dit_config = {} + dit_config["vid_dim"] = 2560 + dit_config["heads"] = 20 + dit_config["num_layers"] = 32 + dit_config["norm_eps"] = 1.0e-05 + dit_config["qk_rope"] = None + dit_config["mlp_type"] = "swiglu" + + return dit_config if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1 dit_config = {} diff --git a/comfy/sd.py b/comfy/sd.py index 5b95cf75a..79b17073f 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -15,6 +15,7 @@ import comfy.ldm.lightricks.vae.causal_video_autoencoder import 
comfy.ldm.cosmos.vae import comfy.ldm.wan.vae import comfy.ldm.hunyuan3d.vae +import comfy.ldm.seedvr.vae import comfy.ldm.ace.vae.music_dcae_pipeline import yaml import math @@ -391,6 +392,19 @@ class VAE: self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32) self.downscale_index_formula = (8, 32, 32) self.working_dtypes = [torch.bfloat16, torch.float32] + + elif "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd: # seedvr2 + self.first_stage_model = comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper() + ddconfig["conv3d"] = True + ddconfig["time_compress"] = 4 + self.memory_used_decode = lambda shape, dtype: (2000 * shape[2] * shape[3] * shape[4] * (4 * 8 * 8)) * model_management.dtype_size(dtype) + self.memory_used_encode = lambda shape, dtype: (1000 * max(shape[2], 5) * shape[3] * shape[4]) * model_management.dtype_size(dtype) + self.working_dtypes = [torch.bfloat16, torch.float32] + self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8) + self.downscale_index_formula = (4, 8, 8) + self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8) + self.upscale_index_formula = (4, 8, 8) + elif "decoder.conv_in.conv.weight" in sd: ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} ddconfig["conv3d"] = True diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 2669ca01e..2301b1188 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1153,6 +1153,21 @@ class Chroma(supported_models_base.BASE): pref = self.text_encoder_key_prefix[0] t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect)) + +class SeedVR2(supported_models_base.Base): + unet_config = { + "image_mode": "seedvr2" + } + latent_format = comfy.latent_formats.SeedVR2 + + vae_key_prefix = ["vae."] + supported_inference_dtypes = [torch.bfloat16, torch.float32] + + def get_model(self, state_dict, prefix = "", device=None): + out = model_base.SeedVR2(self, device=device) + return out + def clip_target(self, state_dict={}): + return None class ACEStep(supported_models_base.BASE): unet_config = { @@ -1217,6 +1232,6 @@ class Omnigen2(supported_models_base.BASE): return supported_models_base.ClipTarget(comfy.text_encoders.omnigen2.LuminaTokenizer, comfy.text_encoders.omnigen2.te(**hunyuan_detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, 
FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2, SeedVR2] models += [SVD_img2vid] From 041dbd6a8a241eccb8a35eddc80b78176f42b7f0 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Sun, 7 Dec 2025 01:00:08 +0200 Subject: [PATCH 02/35] add nodes --- comfy/ldm/seedvr/vae.py | 10 +-- comfy_extras/nodes_seedvr.py | 116 +++++++++++++++++++++++++++++++++++ 2 files changed, 121 insertions(+), 5 deletions(-) create mode 100644 comfy_extras/nodes_seedvr.py diff --git a/comfy/ldm/seedvr/vae.py b/comfy/ldm/seedvr/vae.py index eb74e9442..51c5b2578 100644 --- a/comfy/ldm/seedvr/vae.py +++ b/comfy/ldm/seedvr/vae.py @@ -1051,10 +1051,9 @@ class VideoAutoencoderKL(nn.Module): out_channels: int = 3, down_block_types: Tuple[str] = ("DownEncoderBlock3D",), up_block_types: Tuple[str] = ("UpDecoderBlock3D",), - block_out_channels: Tuple[int] = (64,), - layers_per_block: int = 1, + layers_per_block: int = 2, act_fn: str = "silu", - latent_channels: int = 4, + latent_channels: int = 16, norm_num_groups: int = 32, attention: bool = True, temporal_scale_num: int = 2, @@ -1062,12 +1061,13 @@ class VideoAutoencoderKL(nn.Module): gradient_checkpoint: bool = False, inflation_mode = "tail", time_receptive_field: _receptive_field_t = "full", - use_quant_conv: bool = True, - use_post_quant_conv: bool = True, + use_quant_conv: bool = False, + use_post_quant_conv: bool = False, *args, **kwargs, ): extra_cond_dim = kwargs.pop("extra_cond_dim") if "extra_cond_dim" in kwargs else None + block_out_channels = (128, 256, 512, 512) super().__init__() # pass init params to Encoder diff --git a/comfy_extras/nodes_seedvr.py b/comfy_extras/nodes_seedvr.py new file mode 100644 index 000000000..60bd551dd --- /dev/null +++ b/comfy_extras/nodes_seedvr.py @@ -0,0 +1,116 @@ + +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io, ui +import torch +import math +from einops import rearrange + +from torchvision.transforms import functional as TVF +from torchvision.transforms import Lambda, Normalize +from torchvision.transforms.functional import InterpolationMode + + +def area_resize(image, max_area): + + height, width = image.shape[-2:] + scale = math.sqrt(max_area / (height * width)) + + resized_height, resized_width = round(height * scale), round(width * scale) + + return TVF.resize( + image, + size=(resized_height, resized_width), + interpolation=InterpolationMode.BICUBIC, + ) + +def crop(image, factor): + height_factor, width_factor = factor + height, width = image.shape[-2:] + + cropped_height = height - (height % height_factor) + cropped_width = width - (width % width_factor) + + image = TVF.center_crop(img=image, output_size=(cropped_height, cropped_width)) + return image + +def cut_videos(videos): + t = videos.size(1) + if t == 1: + return videos + if t <= 4 : + padding = [videos[:, -1].unsqueeze(1)] * (4 - t + 1) + padding = torch.cat(padding, dim=1) + videos = torch.cat([videos, padding], dim=1) + return videos + if (t - 1) % (4) == 0: + return videos + else: + padding = [videos[:, -1].unsqueeze(1)] * ( + 4 - ((t - 1) % (4)) + ) + padding = torch.cat(padding, dim=1) + videos = torch.cat([videos, padding], dim=1) + assert (videos.size(1) - 1) % (4) == 0 + return videos + +class 
SeedVR2InputProcessing(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id = "SeedVR2InputProcessing", + category="image/video", + inputs = [ + io.Image.Input("images"), + io.Int.Input("resolution_height"), + io.Int.Input("resolution_width") + ], + outputs = [ + io.Image.Output("images") + ] + ) + + @classmethod + def execute(cls, images, resolution_height, resolution_width): + max_area = ((resolution_height * resolution_width)** 0.5) ** 2 + clip = Lambda(lambda x: torch.clamp(x, 0.0, 1.0)) + normalize = Normalize(0.5, 0.5) + images = area_resize(images, max_area) + images = clip(images) + images = crop(images, (16, 16)) + images = normalize(images) + images = rearrange(images, "t c h w -> c t h w") + images = cut_videos(images) + return + +class SeedVR2Conditioning(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SeedVR2Conditioning", + category="image/video", + inputs=[ + io.Conditioning.Input("text_positive_conditioning"), + io.Conditioning.Input("text_negative_conditioning"), + io.Conditioning.Input("vae_conditioning") + ], + outputs=[io.Conditioning.Output("positive"), io.Conditioning.Output("negative")], + ) + + @classmethod + def execute(cls, text_positive_conditioning, text_negative_conditioning, vae_conditioning) -> io.NodeOutput: + # TODO + pos_cond = text_positive_conditioning[0][0] + neg_cond = text_negative_conditioning[0][0] + + return io.NodeOutput() + +class SeedVRExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + SeedVR2Conditioning, + SeedVR2InputProcessing + ] + +async def comfy_entrypoint() -> SeedVRExtension: + return SeedVRExtension() \ No newline at end of file From 4b9332cc215a8ab12163908044286ec4fc9bab87 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Sun, 7 Dec 2025 21:41:14 +0200 Subject: [PATCH 03/35] continue building nodes / testing vae --- comfy/ldm/seedvr/model.py | 40 ++--- comfy/ldm/seedvr/vae.py | 315 +++++++++++++++++++++++++++++++---- comfy_extras/nodes_seedvr.py | 79 ++++++++- 3 files changed, 378 insertions(+), 56 deletions(-) diff --git a/comfy/ldm/seedvr/model.py b/comfy/ldm/seedvr/model.py index 40a460d67..cf6287b03 100644 --- a/comfy/ldm/seedvr/model.py +++ b/comfy/ldm/seedvr/model.py @@ -1141,11 +1141,6 @@ def repeat( kwargs = [{k: v[i].item() for k, v in kwargs.items()} for i in range(len(hid))] return flatten([einops.repeat(h, pattern, **a) for h, a in zip(hid, kwargs)]) -@dataclass -class NaDiTOutput: - vid_sample: torch.Tensor - - class NaDiT(nn.Module): def __init__( @@ -1246,26 +1241,32 @@ class NaDiT(nn.Module): "mmdit_stwin_3d_spatial", ] - def set_gradient_checkpointing(self, enable: bool): - self.gradient_checkpointing = enable - def forward( self, - vid: torch.FloatTensor, # l c - txt: torch.FloatTensor, # l c - vid_shape: torch.LongTensor, # b 3 - txt_shape: torch.LongTensor, # b 1 - timestep: Union[int, float, torch.IntTensor, torch.FloatTensor], # b - disable_cache: bool = True, # for test - ): - # Text input. + x, + timestep, + context, # l c + txt_shape, # b 1 + disable_cache: bool = True, # for test # TODO ? 
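+        # disable_cache=True recomputes the per-forward Cache entries (index
+        # permutations, rope tables, sequence lengths) on every call; results
+        # are identical, it only trades speed for simplicity during testing.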
+ ): + pos_cond, neg_cond = context.chunk(2, dim=0) + pos_cond, pos_shape = flatten(pos_cond) + neg_cond, neg_shape = flatten(neg_cond) + diff = abs(pos_shape.shape[1] - neg_shape.shape[1]) + if pos_shape.shape[1] > neg_shape.shape[1]: + neg_shape = F.pad(neg_shape, (0, 0, 0, diff)) + neg_cond = F.pad(neg_cond, (0, 0, 0, diff)) + else: + pos_shape = F.pad(pos_shape, (0, 0, 0, diff)) + pos_cond = F.pad(pos_cond, (0, 0, 0, diff)) + vid = x + txt = context + vid, vid_shape = flatten(x) if txt_shape.size(-1) == 1 and self.need_txt_repeat: txt, txt_shape = repeat(txt, txt_shape, "l c -> t l c", t=vid_shape[:, 0]) # slice vid after patching in when using sequence parallelism txt = self.txt_in(txt) - # Video input. - # Sequence parallel slicing is done inside patching class. vid, vid_shape = self.vid_in(vid, vid_shape) # Embedding input. @@ -1284,4 +1285,5 @@ class NaDiT(nn.Module): ) vid, vid_shape = self.vid_out(vid, vid_shape, cache) - return NaDiTOutput(vid_sample=vid) + vid = unflatten(vid, vid_shape) + return vid diff --git a/comfy/ldm/seedvr/vae.py b/comfy/ldm/seedvr/vae.py index 51c5b2578..3a0f8cfed 100644 --- a/comfy/ldm/seedvr/vae.py +++ b/comfy/ldm/seedvr/vae.py @@ -4,11 +4,11 @@ import torch import torch.nn as nn import torch.nn.functional as F from diffusers.models.attention_processor import Attention -from diffusers.models.upsampling import Upsample2D from einops import rearrange from model import safe_pad_operation from comfy.ldm.hunyuan3d.vae import DiagonalGaussianDistribution +from comfy.ldm.modules.attention import optimized_attention class SpatialNorm(nn.Module): def __init__( @@ -28,6 +28,259 @@ class SpatialNorm(nn.Module): new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) return new_f +# partial implementation of diffusers's Attention for comfyui +class Attention(nn.Module): + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + kv_heads: Optional[int] = None, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + added_kv_proj_dim: Optional[int] = None, + added_proj_bias: Optional[bool] = True, + norm_num_groups: Optional[int] = None, + spatial_norm_dim: Optional[int] = None, + out_bias: bool = True, + scale_qk: bool = True, + only_cross_attention: bool = False, + eps: float = 1e-5, + rescale_output_factor: float = 1.0, + residual_connection: bool = False, + _from_deprecated_attn_block: bool = False, + out_dim: int = None, + out_context_dim: int = None, + context_pre_only=None, + pre_only=False, + is_causal: bool = False, + ): + super().__init__() + + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads + self.query_dim = query_dim + self.use_bias = bias + self.is_cross_attention = cross_attention_dim is not None + self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + self.rescale_output_factor = rescale_output_factor + self.residual_connection = residual_connection + self.dropout = dropout + self.fused_projections = False + self.out_dim = out_dim if out_dim is not None else query_dim + self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim + self.context_pre_only = context_pre_only + self.pre_only = pre_only + self.is_causal = is_causal + + # we make use of this private variable to know 
whether this class is loaded + # with an deprecated state dict so that we can convert it on the fly + self._from_deprecated_attn_block = _from_deprecated_attn_block + + self.scale_qk = scale_qk + self.scale = dim_head**-0.5 if self.scale_qk else 1.0 + + self.heads = out_dim // dim_head if out_dim is not None else heads + self.sliceable_head_dim = heads + + self.added_kv_proj_dim = added_kv_proj_dim + self.only_cross_attention = only_cross_attention + + if norm_num_groups is not None: + self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True) + else: + self.group_norm = None + + if spatial_norm_dim is not None: + self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim) + else: + self.spatial_norm = None + + self.norm_q = None + self.norm_k = None + + self.norm_cross = None + self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) + + if not self.only_cross_attention: + # only relevant for the `AddedKVProcessor` classes + self.to_k = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) + self.to_v = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) + else: + self.to_k = None + self.to_v = None + + self.added_proj_bias = added_proj_bias + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) + self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) + if self.context_pre_only is not None: + self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + else: + self.add_q_proj = None + self.add_k_proj = None + self.add_v_proj = None + + if not self.pre_only: + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(nn.Dropout(dropout)) + else: + self.to_out = None + + if self.context_pre_only is not None and not self.context_pre_only: + self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias) + else: + self.to_add_out = None + + self.norm_added_q = None + self.norm_added_k = None + + def __call__( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + *args, + **kwargs, + ) -> torch.Tensor: + + residual = hidden_states + if self.spatial_norm is not None: + hidden_states = self.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = self.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attention_mask.view(batch_size, self.heads, -1, attention_mask.shape[-1]) + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = self.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif self.norm_cross: + encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states) + + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // 
self.heads + + query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + + if self.norm_q is not None: + query = self.norm_q(query) + if self.norm_k is not None: + key = self.norm_k(key) + + hidden_states = optimized_attention(query, key, value, heads = self.heads, mask = attention_mask, skip_reshape=True, skip_output_reshape=True) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + # dropout + hidden_states = self.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if self.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / self.rescale_output_factor + + return hidden_states + + +def inflate_weight(weight_2d: torch.Tensor, weight_3d: torch.Tensor, inflation_mode: str): + """ + Inflate a 2D convolution weight matrix to a 3D one. + Parameters: + weight_2d: The weight matrix of 2D conv to be inflated. + weight_3d: The weight matrix of 3D conv to be initialized. + inflation_mode: the mode of inflation + """ + assert inflation_mode in ["tail", "replicate"] + assert weight_3d.shape[:2] == weight_2d.shape[:2] + with torch.no_grad(): + if inflation_mode == "replicate": + depth = weight_3d.size(2) + weight_3d.copy_(weight_2d.unsqueeze(2).repeat(1, 1, depth, 1, 1) / depth) + else: + weight_3d.fill_(0.0) + weight_3d[:, :, -1].copy_(weight_2d) + return weight_3d + + +def inflate_bias(bias_2d: torch.Tensor, bias_3d: torch.Tensor, inflation_mode: str): + """ + Inflate a 2D convolution bias tensor to a 3D one + Parameters: + bias_2d: The bias tensor of 2D conv to be inflated. + bias_3d: The bias tensor of 3D conv to be initialized. + inflation_mode: Placeholder to align `inflate_weight`. + """ + assert bias_3d.shape == bias_2d.shape + with torch.no_grad(): + bias_3d.copy_(bias_2d) + return bias_3d + + +def modify_state_dict(layer, state_dict, prefix, inflate_weight_fn, inflate_bias_fn): + """ + the main function to inflated 2D parameters to 3D. + """ + weight_name = prefix + "weight" + bias_name = prefix + "bias" + if weight_name in state_dict: + weight_2d = state_dict[weight_name] + if weight_2d.dim() == 4: + # Assuming the 2D weights are 4D tensors (out_channels, in_channels, h, w) + weight_3d = inflate_weight_fn( + weight_2d=weight_2d, + weight_3d=layer.weight, + inflation_mode=layer.inflation_mode, + ) + state_dict[weight_name] = weight_3d + else: + return state_dict + # It's a 3d state dict, should not do inflation on both bias and weight. 
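+    # Bias vectors are 1-D for 2D and 3D convs alike, so "inflation" reduces
+    # to a plain copy (see inflate_bias above).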
+ if bias_name in state_dict: + bias_2d = state_dict[bias_name] + if bias_2d.dim() == 1: + # Assuming the 2D biases are 1D tensors (out_channels,) + bias_3d = inflate_bias_fn( + bias_2d=bias_2d, + bias_3d=layer.bias, + inflation_mode=layer.inflation_mode, + ) + state_dict[bias_name] = bias_3d + return state_dict + def causal_norm_wrapper(norm_layer: nn.Module, x: torch.Tensor) -> torch.Tensor: input_dtype = x.dtype if isinstance(norm_layer, (nn.LayerNorm, nn.RMSNorm)): @@ -131,15 +384,14 @@ class InflatedCausalConv3d(nn.Conv3d): def _load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ): - # wirdly inflation_mode is pad, which would cause an assert error - #if self.inflation_mode != "none": - # state_dict = modify_state_dict( - # self, - # state_dict, - # prefix, - # inflate_weight_fn=inflate_weight, - # inflate_bias_fn=inflate_bias, - # ) + if self.inflation_mode != "none": + state_dict = modify_state_dict( + self, + state_dict, + prefix, + inflate_weight_fn=inflate_weight, + inflate_bias_fn=inflate_bias, + ) super()._load_from_state_dict( state_dict, prefix, @@ -287,7 +539,10 @@ class Downsample3D(nn.Module): spatial_down: bool = False, temporal_down: bool = False, name: str = "conv", + kernel_size=3, + use_conv: bool = False, padding = 1, + bias=True, **kwargs, ): super().__init__() @@ -295,7 +550,6 @@ class Downsample3D(nn.Module): self.name = name self.channels = channels self.out_channels = out_channels or channels - conv = self.conv self.temporal_down = temporal_down self.spatial_down = spatial_down @@ -305,9 +559,7 @@ class Downsample3D(nn.Module): self.temporal_kernel = 3 if temporal_down else 1 self.spatial_kernel = 3 if spatial_down else 1 - if type(conv) in [nn.Conv2d]: - # Note: lora_layer is not passed into constructor in the original implementation. - # So we make a simplification. 
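+        # The nn.Conv2d / nn.AvgPool2d type dispatch inherited from diffusers
+        # is replaced by an explicit use_conv flag, since this port constructs
+        # its conv directly instead of inflating an existing 2D layer.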
+ if use_conv: conv = InflatedCausalConv3d( self.channels, self.out_channels, @@ -320,20 +572,15 @@ class Downsample3D(nn.Module): ), inflation_mode=inflation_mode, ) - elif type(conv) is nn.AvgPool2d: + else: assert self.channels == self.out_channels conv = nn.AvgPool3d( kernel_size=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio), stride=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio), ) - else: - raise NotImplementedError + + self.conv = conv - if self.name == "conv": - self.Conv2d_0 = conv - self.conv = conv - else: - self.conv = conv def forward( self, @@ -386,6 +633,9 @@ class ResnetBlock3D(nn.Module): super().__init__() self.up = up self.down = down + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + conv_2d_out_channels = conv_2d_out_channels or out_channels self.use_in_shortcut = use_in_shortcut self.output_scale_factor = output_scale_factor self.skip_time_act = skip_time_act @@ -394,6 +644,12 @@ class ResnetBlock3D(nn.Module): self.time_emb_proj = nn.Linear(temb_channels, out_channels) else: self.time_emb_proj = None + self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + if groups_out is None: + groups_out = groups + self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) + self.use_in_shortcut = self.in_channels != out_channels + self.dropout = torch.nn.Dropout(dropout) self.conv1 = InflatedCausalConv3d( self.in_channels, self.out_channels, @@ -405,7 +661,7 @@ class ResnetBlock3D(nn.Module): self.conv2 = InflatedCausalConv3d( self.out_channels, - self.conv2.out_channels, + conv_2d_out_channels, kernel_size=3, stride=1, padding=1, @@ -431,11 +687,11 @@ class ResnetBlock3D(nn.Module): if self.use_in_shortcut: self.conv_shortcut = InflatedCausalConv3d( self.in_channels, - self.conv_shortcut.out_channels, + conv_2d_out_channels, kernel_size=1, stride=1, padding=0, - bias=(self.conv_shortcut.bias is not None), + bias=True, inflation_mode=inflation_mode, ) @@ -534,7 +790,6 @@ class DownEncoderBlock3D(nn.Module): if add_downsample: self.downsamplers = nn.ModuleList( [ - # [Override] Replace module. 
Downsample3D( out_channels, use_conv=True, @@ -1049,8 +1304,6 @@ class VideoAutoencoderKL(nn.Module): self, in_channels: int = 3, out_channels: int = 3, - down_block_types: Tuple[str] = ("DownEncoderBlock3D",), - up_block_types: Tuple[str] = ("UpDecoderBlock3D",), layers_per_block: int = 2, act_fn: str = "silu", latent_channels: int = 16, @@ -1059,7 +1312,7 @@ class VideoAutoencoderKL(nn.Module): temporal_scale_num: int = 2, slicing_up_num: int = 0, gradient_checkpoint: bool = False, - inflation_mode = "tail", + inflation_mode = "pad", time_receptive_field: _receptive_field_t = "full", use_quant_conv: bool = False, use_post_quant_conv: bool = False, @@ -1068,6 +1321,8 @@ class VideoAutoencoderKL(nn.Module): ): extra_cond_dim = kwargs.pop("extra_cond_dim") if "extra_cond_dim" in kwargs else None block_out_channels = (128, 256, 512, 512) + down_block_types = ("DownEncoderBlock3D",) * 4 + up_block_types = ("UpDecoderBlock3D",) * 4 super().__init__() # pass init params to Encoder @@ -1257,4 +1512,4 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL): #set_norm_limit(norm_max_mem) for m in self.modules(): if isinstance(m, InflatedCausalConv3d): - m.set_memory_limit(conv_max_mem if conv_max_mem is not None else float("inf")) \ No newline at end of file + m.set_memory_limit(conv_max_mem if conv_max_mem is not None else float("inf")) diff --git a/comfy_extras/nodes_seedvr.py b/comfy_extras/nodes_seedvr.py index 60bd551dd..9d4e8bf34 100644 --- a/comfy_extras/nodes_seedvr.py +++ b/comfy_extras/nodes_seedvr.py @@ -1,6 +1,5 @@ - from typing_extensions import override -from comfy_api.latest import ComfyExtension, io, ui +from comfy_api.latest import ComfyExtension, io import torch import math from einops import rearrange @@ -9,7 +8,51 @@ from torchvision.transforms import functional as TVF from torchvision.transforms import Lambda, Normalize from torchvision.transforms.functional import InterpolationMode +def expand_dims(tensor, ndim): + shape = tensor.shape + (1,) * (ndim - tensor.ndim) + return tensor.reshape(shape) +def get_conditions(latent, latent_blur): + t, h, w, c = latent.shape + cond = torch.zeros([t, h, w, c + 1], device=latent.device, dtype=latent.dtype) + cond[:, ..., :-1] = latent_blur[:] + cond[:, ..., -1:] = 1.0 + return cond + +def timestep_transform(timesteps, latents_shapes): + vt = 4 + vs = 8 + frames = (latents_shapes[:, 0] - 1) * vt + 1 + heights = latents_shapes[:, 1] * vs + widths = latents_shapes[:, 2] * vs + + # Compute shift factor. + def get_lin_function(x1, y1, x2, y2): + m = (y2 - y1) / (x2 - x1) + b = y1 - m * x1 + return lambda x: m * x + b + + img_shift_fn = get_lin_function(x1=256 * 256, y1=1.0, x2=1024 * 1024, y2=3.2) + vid_shift_fn = get_lin_function(x1=256 * 256 * 37, y1=1.0, x2=1280 * 720 * 145, y2=5.0) + shift = torch.where( + frames > 1, + vid_shift_fn(heights * widths * frames), + img_shift_fn(heights * widths), + ) + + # Shift timesteps. 
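+    # This is the SD3/Flux-style resolution-dependent shift,
+    #   t' = s * t / (1 + (s - 1) * t)   (on t normalised to [0, 1]),
+    # with s linear in pixel count: a 256x256 image keeps s = 1.0, a 1024x1024
+    # image gets s = 3.2, and a 145-frame 720p video gets s = 5.0.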
+ T = 1000.0 + timesteps = timesteps / T + timesteps = shift * timesteps / (1 + (shift - 1) * timesteps) + timesteps = timesteps * T + return timesteps + +def inter(x_0, x_T, t): + t = expand_dims(t, x_0.ndim) + T = 1000.0 + B = lambda t: t / T + A = lambda t: 1 - (t / T) + return A(t) * x_0 + B(t) * x_T def area_resize(image, max_area): height, width = image.shape[-2:] @@ -80,7 +123,7 @@ class SeedVR2InputProcessing(io.ComfyNode): images = normalize(images) images = rearrange(images, "t c h w -> c t h w") images = cut_videos(images) - return + return io.NodeOutput(images) class SeedVR2Conditioning(io.ComfyNode): @classmethod @@ -93,16 +136,38 @@ class SeedVR2Conditioning(io.ComfyNode): io.Conditioning.Input("text_negative_conditioning"), io.Conditioning.Input("vae_conditioning") ], - outputs=[io.Conditioning.Output("positive"), io.Conditioning.Output("negative")], + outputs=[io.Conditioning.Output(display_name = "positive"), + io.Conditioning.Output(display_name = "negative"), + io.Latent.Output(display_name = "latent")], ) @classmethod def execute(cls, text_positive_conditioning, text_negative_conditioning, vae_conditioning) -> io.NodeOutput: - # TODO + # TODO: should do the flattening logic as with the original code pos_cond = text_positive_conditioning[0][0] neg_cond = text_negative_conditioning[0][0] - return io.NodeOutput() + noises = [torch.randn_like(latent) for latent in vae_conditioning] + aug_noises = [torch.randn_like(latent) for latent in vae_conditioning] + + cond_noise_scale = 0.0 + t = ( + torch.tensor([1000.0]) + * cond_noise_scale + ) + shape = torch.tensor(vae_conditioning.shape[1:])[None] + t = timestep_transform(t, shape) + cond = inter(vae_conditioning, aug_noises, t) + condition = get_conditions(noises, cond) + + # TODO / FIXME + pos_cond = torch.cat([condition, pos_cond], dim = 0) + neg_cond = torch.cat([condition, neg_cond], dim = 0) + + negative = [[pos_cond, {}]] + positive = [[neg_cond, {}]] + + return io.NodeOutput(positive, negative, noises) class SeedVRExtension(ComfyExtension): @override @@ -113,4 +178,4 @@ class SeedVRExtension(ComfyExtension): ] async def comfy_entrypoint() -> SeedVRExtension: - return SeedVRExtension() \ No newline at end of file + return SeedVRExtension() From 44a5bf353af34f248b137ee7fcbad912b9f6c09b Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Sun, 7 Dec 2025 23:43:49 +0200 Subject: [PATCH 04/35] testing the model --- comfy/ldm/modules/attention.py | 2 +- comfy/ldm/seedvr/model.py | 178 ++++++++++++++++++++++----------- 2 files changed, 119 insertions(+), 61 deletions(-) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 256f9a989..35d2270ee 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -428,7 +428,7 @@ else: SDP_BATCH_LIMIT = 2**31 -def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=Falsez): +def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): if skip_reshape: b, _, _, dim_head = q.shape else: diff --git a/comfy/ldm/seedvr/model.py b/comfy/ldm/seedvr/model.py index cf6287b03..86836468f 100644 --- a/comfy/ldm/seedvr/model.py +++ b/comfy/ldm/seedvr/model.py @@ -145,56 +145,77 @@ def get_kwargs(key: str, kwargs: Dict[str, Any]) -> Dict[str, Any]: return {k: getattr(v, key) if isinstance(v, MMArg) else v for k, v in kwargs.items()} -def make_720Pwindows(size, num_windows, shift = False): +def 
get_window_op(name: str):
+    if name == "720pwin_by_size_bysize":
+        return make_720Pwindows_bysize
+    if name == "720pswin_by_size_bysize":
+        return make_shifted_720Pwindows_bysize
+    raise ValueError(f"Unknown windowing method: {name}")
+
+
+# -------------------------------- Windowing -------------------------------- #
+def make_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tuple[int, int, int]):
     t, h, w = size
     resized_nt, resized_nh, resized_nw = num_windows
-
-    scale = sqrt((45 * 80) / (h * w))
+    # Compute window sizes at a 720p-equivalent scale (45x80 token grid).
+    scale = sqrt((45 * 80) / (h * w))
     resized_h, resized_w = round(h * scale), round(w * scale)
+    wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw)  # window size.
+    wt = ceil(min(t, 30) / resized_nt)  # window size.
+    nt, nh, nw = ceil(t / wt), ceil(h / wh), ceil(w / ww)  # number of windows.
+    return [
+        (
+            slice(it * wt, min((it + 1) * wt, t)),
+            slice(ih * wh, min((ih + 1) * wh, h)),
+            slice(iw * ww, min((iw + 1) * ww, w)),
+        )
+        for iw in range(nw)
+        if min((iw + 1) * ww, w) > iw * ww
+        for ih in range(nh)
+        if min((ih + 1) * wh, h) > ih * wh
+        for it in range(nt)
+        if min((it + 1) * wt, t) > it * wt
+    ]
 
-    wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw)
-    wt = ceil(min(t, 30) / resized_nt)
+def make_shifted_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tuple[int, int, int]):
+    t, h, w = size
+    resized_nt, resized_nh, resized_nw = num_windows
+    # Compute window sizes at a 720p-equivalent scale (45x80 token grid).
+    scale = sqrt((45 * 80) / (h * w))
+    resized_h, resized_w = round(h * scale), round(w * scale)
+    wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw)  # window size.
+    wt = ceil(min(t, 30) / resized_nt)  # window size.
 
-    st, sh, sw = (0.5 * shift if wt < t else 0,
-                  0.5 * shift if wh < h else 0,
-                  0.5 * shift if ww < w else 0)
-    
-    nt, nh, nw = ceil((t - st) / wt), ceil((h - sh) / wh), ceil((w - sw) / ww)
-    if shift:
-        nt += 1 if st > 0 else 0
-        nh += 1 if sh > 0 else 0
-        nw += 1 if sw > 0 else 0
-
-    windows = []
-    for iw in range(nw):
-        w_start = max(int((iw - sw) * ww), 0)
-        w_end = min(int((iw - sw + 1) * ww), w)
-        if w_end <= w_start:
-            continue
-
-        for ih in range(nh):
-            h_start = max(int((ih - sh) * wh), 0)
-            h_end = min(int((ih - sh + 1) * wh), h)
-            if h_end <= h_start:
-                continue
-
-            for it in range(nt):
-                t_start = max(int((it - st) * wt), 0)
-                t_end = min(int((it - st + 1) * wt), t)
-                if t_end <= t_start:
-                    continue
-
-                windows.append((slice(t_start, t_end),
-                                slice(h_start, h_end),
-                                slice(w_start, w_end)))
-
-    return windows
+    st, sh, sw = (  # shift size.
+        0.5 if wt < t else 0,
+        0.5 if wh < h else 0,
+        0.5 if ww < w else 0,
+    )
+    nt, nh, nw = ceil((t - st) / wt), ceil((h - sh) / wh), ceil((w - sw) / ww)  # number of windows.
+    nt, nh, nw = (  # number of windows (one extra along shifted axes).
+        nt + 1 if st > 0 else 1,
+        nh + 1 if sh > 0 else 1,
+        nw + 1 if sw > 0 else 1,
+    )
+    return [
+        (
+            slice(max(int((it - st) * wt), 0), min(int((it - st + 1) * wt), t)),
+            slice(max(int((ih - sh) * wh), 0), min(int((ih - sh + 1) * wh), h)),
+            slice(max(int((iw - sw) * ww), 0), min(int((iw - sw + 1) * ww), w)),
+        )
+        for iw in range(nw)
+        if min(int((iw - sw + 1) * ww), w) > max(int((iw - sw) * ww), 0)
+        for ih in range(nh)
+        if min(int((ih - sh + 1) * wh), h) > max(int((ih - sh) * wh), 0)
+        for it in range(nt)
+        if min(int((it - st + 1) * wt), t) > max(int((it - st) * wt), 0)
+    ]
 
 class RotaryEmbedding(nn.Module):
     def __init__(
         self,
         dim,
-        custom_freqs,
+        custom_freqs = None,
         freqs_for = 'lang',
         theta = 10000,
         max_freq = 10,
@@ -566,6 +587,7 @@ class NaMMAttention(nn.Module):
     ):
         super().__init__()
         dim = MMArg(vid_dim, txt_dim)
+        self.heads = heads
         inner_dim = heads * head_dim
         qkv_dim = inner_dim * 3
         self.head_dim = head_dim
@@ -575,19 +597,20 @@ class NaMMAttention(nn.Module):
         self.proj_out = MMModule(nn.Linear, inner_dim, dim, shared_weights=shared_weights)
         self.norm_q = MMModule(
             qk_norm,
-            dim=head_dim,
+            normalized_shape=head_dim,
             eps=qk_norm_eps,
             elementwise_affine=True,
             shared_weights=shared_weights,
         )
         self.norm_k = MMModule(
             qk_norm,
-            dim=head_dim,
+            normalized_shape=head_dim,
             eps=qk_norm_eps,
             elementwise_affine=True,
             shared_weights=shared_weights,
         )
+
         self.rope = get_na_rope(rope_type=rope_type, dim=rope_dim)
 
     def forward(
@@ -634,7 +657,7 @@ class NaMMAttention(nn.Module):
 
         _, unconcat = cache("mm_pnp", lambda: concat_idx(vid_len, txt_len))
 
-        attn = optimized_attention(q, k, v, skip_reshape=True, skip_output_reshape=True)
+        attn = optimized_attention(q, k, v, heads = self.heads, skip_reshape=True, skip_output_reshape=True)
 
         attn = attn.flatten(0, 1) # to continue working with the rest of the code
         attn = rearrange(attn, "l h d -> l (h d)")
@@ -682,7 +705,7 @@ class NaSwinAttention(NaMMAttention):
         self.window_method = window_method
 
         assert all(map(lambda v: isinstance(v, int) and v >= 0, self.window))
-        self.window_op = window_method
+        self.window_op = get_window_op(window_method)
 
     def forward(
         self,
@@ -754,20 +777,17 @@ class NaSwinAttention(NaMMAttention):
             )
         else:
             vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win)
-
-        out = self.attn(
-            q=concat_win(vid_q, txt_q).bfloat16(),
-            k=concat_win(vid_k, txt_k).bfloat16(),
-            v=concat_win(vid_v, txt_v).bfloat16(),
-            cu_seqlens_q=cache_win(
-                "vid_seqlens_q", lambda: safe_pad_operation(all_len_win.cumsum(0), (1, 0)).int()
-            ),
-            cu_seqlens_k=cache_win(
-                "vid_seqlens_k", lambda: safe_pad_operation(all_len_win.cumsum(0), (1, 0)).int()
-            ),
-            max_seqlen_q=cache_win("vid_max_seqlen_q", lambda: all_len_win.max().item()),
-            max_seqlen_k=cache_win("vid_max_seqlen_k", lambda: all_len_win.max().item()),
-        ).type_as(vid_q)
+
+        # TODO: continue testing
+        b = len(vid_len_win)
+        vq, vk, vv = [t.view(b, -1, *vid_q.shape[1:]) for t in (vid_q, vid_k, vid_v)]
+        tq, tk, tv = [t.view(b, -1, *txt_q.shape[1:]) for t in (txt_q, txt_k, txt_v)]
+
+        q = torch.cat([vq, tq], dim=1)
+        k = torch.cat([vk, tk], dim=1)
+        v = torch.cat([vv, tv], dim=1)
+        out = optimized_attention(q, k, v, heads=self.heads, skip_reshape=True, skip_output_reshape=True)
+        out = out.flatten(0, 1)
 
         # text pooling
         vid_out, txt_out = unconcat_win(out)
@@ -847,7 +867,7 @@ class NaMMSRTransformerBlock(nn.Module):
    ):
         super().__init__()
         dim = MMArg(vid_dim, txt_dim)
-        self.attn_norm = MMModule(norm, dim=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights,)
+ 
self.attn_norm = MMModule(norm, normalized_shape=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights,) self.attn = NaSwinAttention( vid_dim=vid_dim, @@ -864,7 +884,7 @@ class NaMMSRTransformerBlock(nn.Module): window_method=kwargs.pop("window_method", None), ) - self.mlp_norm = MMModule(norm, dim=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights, vid_only=is_last_layer) + self.mlp_norm = MMModule(norm, normalized_shape=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights, vid_only=is_last_layer) self.mlp = MMModule( get_mlp(mlp_type), dim=dim, @@ -1155,6 +1175,7 @@ class NaDiT(nn.Module): txt_in_dim = 5120, heads = 20, head_dim = 128, + mm_layers = 10, expand_ratio = 4, qk_bias = False, patch_size = [ 1,2,2 ], @@ -1163,8 +1184,12 @@ class NaDiT(nn.Module): window_method: Optional[Tuple[str]] = None, temporal_window_size: int = None, temporal_shifted: bool = False, + rope_dim = 128, + rope_type = "mmrope3d", + vid_out_norm: Optional[str] = None, **kwargs, ): + window_method = num_layers // 2 * ["720pwin_by_size_bysize","720pswin_by_size_bysize"] txt_dim = vid_dim emb_dim = vid_dim * 6 block_type = ["mmdit_sr"] * num_layers @@ -1202,6 +1227,7 @@ class NaDiT(nn.Module): if temporal_shifted is None or isinstance(temporal_shifted, bool): temporal_shifted = [temporal_shifted] * num_layers + rope_dim = rope_dim if rope_dim is not None else head_dim // 2 self.blocks = nn.ModuleList( [ NaMMSRTransformerBlock( @@ -1220,10 +1246,16 @@ class NaDiT(nn.Module): shared_qkv=shared_qkv, shared_mlp=shared_mlp, mlp_type=mlp_type, + rope_dim = rope_dim, window=window[i], window_method=window_method[i], temporal_window_size=temporal_window_size[i], temporal_shifted=temporal_shifted[i], + is_last_layer=(i == num_layers - 1), + rope_type = rope_type, + shared_weights=not ( + (i < mm_layers) if isinstance(mm_layers, int) else mm_layers[i] + ), **kwargs, ) for i in range(num_layers) @@ -1241,6 +1273,20 @@ class NaDiT(nn.Module): "mmdit_stwin_3d_spatial", ] + self.vid_out_norm = None + if vid_out_norm is not None: + self.vid_out_norm = RMSNorm( + normalized_shape=vid_dim, + eps=norm_eps, + elementwise_affine=True, + ) + self.vid_out_ada = ada( + dim=vid_dim, + emb_dim=emb_dim, + layers=["out"], + modes=["in"], + ) + def forward( self, x, @@ -1284,6 +1330,18 @@ class NaDiT(nn.Module): cache=cache, ) + if self.vid_out_norm: + vid = self.vid_out_norm(vid) + vid = self.vid_out_ada( + vid, + emb=emb, + layer="out", + mode="in", + hid_len=cache("vid_len", lambda: vid_shape.prod(-1)), + cache=cache, + branch_tag="vid", + ) + vid, vid_shape = self.vid_out(vid, vid_shape, cache) vid = unflatten(vid, vid_shape) return vid From f030b3afc8ba99637f13dbf84dc54c0cb74f0bfc Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Tue, 9 Dec 2025 00:16:17 +0200 Subject: [PATCH 05/35] mostly fixing mistakes --- comfy/ldm/seedvr/model.py | 32 ++++++++++++---------- comfy/ldm/seedvr/vae.py | 52 +++++------------------------------- comfy/model_detection.py | 2 ++ comfy/supported_models.py | 7 ++--- comfy_extras/nodes_seedvr.py | 25 ++++++++++------- nodes.py | 3 ++- 6 files changed, 49 insertions(+), 72 deletions(-) diff --git a/comfy/ldm/seedvr/model.py b/comfy/ldm/seedvr/model.py index 86836468f..42567fa30 100644 --- a/comfy/ldm/seedvr/model.py +++ b/comfy/ldm/seedvr/model.py @@ -1187,8 +1187,12 @@ class NaDiT(nn.Module): rope_dim = 128, rope_type = "mmrope3d", vid_out_norm: Optional[str] = None, + device = None, + dtype = 
None, + operations = None, **kwargs, ): + self.dtype = dtype window_method = num_layers // 2 * ["720pwin_by_size_bysize","720pswin_by_size_bysize"] txt_dim = vid_dim emb_dim = vid_dim * 6 @@ -1292,33 +1296,33 @@ class NaDiT(nn.Module): x, timestep, context, # l c - txt_shape, # b 1 disable_cache: bool = True, # for test # TODO ? + **kwargs ): + transformer_options = kwargs.get("transformer_options", {}) + c_or_u_list = transformer_options.get("cond_or_uncond", []) + cond_latent = c_or_u_list[0]["condition"] + pos_cond, neg_cond = context.chunk(2, dim=0) - pos_cond, pos_shape = flatten(pos_cond) - neg_cond, neg_shape = flatten(neg_cond) - diff = abs(pos_shape.shape[1] - neg_shape.shape[1]) - if pos_shape.shape[1] > neg_shape.shape[1]: - neg_shape = F.pad(neg_shape, (0, 0, 0, diff)) - neg_cond = F.pad(neg_cond, (0, 0, 0, diff)) - else: - pos_shape = F.pad(pos_shape, (0, 0, 0, diff)) - pos_cond = F.pad(pos_cond, (0, 0, 0, diff)) + # txt_shape should be the same for both + pos_cond, txt_shape = flatten(pos_cond) + neg_cond, _ = flatten(neg_cond) + txt = torch.cat([pos_cond, neg_cond], dim = 0) + txt_shape[0] *= 2 + vid = x - txt = context vid, vid_shape = flatten(x) + + vid = torch.cat([cond_latent, vid]) if txt_shape.size(-1) == 1 and self.need_txt_repeat: txt, txt_shape = repeat(txt, txt_shape, "l c -> t l c", t=vid_shape[:, 0]) - # slice vid after patching in when using sequence parallelism + txt = self.txt_in(txt) vid, vid_shape = self.vid_in(vid, vid_shape) - # Embedding input. emb = self.emb_in(timestep, device=vid.device, dtype=vid.dtype) - # Body cache = Cache(disable=disable_cache) for i, block in enumerate(self.blocks): vid, txt, vid_shape, txt_shape = block( diff --git a/comfy/ldm/seedvr/vae.py b/comfy/ldm/seedvr/vae.py index 3a0f8cfed..40c592a2b 100644 --- a/comfy/ldm/seedvr/vae.py +++ b/comfy/ldm/seedvr/vae.py @@ -3,10 +3,9 @@ from typing import Literal, Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F -from diffusers.models.attention_processor import Attention from einops import rearrange -from model import safe_pad_operation +from comfy.ldm.seedvr.model import safe_pad_operation from comfy.ldm.hunyuan3d.vae import DiagonalGaussianDistribution from comfy.ldm.modules.attention import optimized_attention @@ -216,67 +215,37 @@ class Attention(nn.Module): return hidden_states -def inflate_weight(weight_2d: torch.Tensor, weight_3d: torch.Tensor, inflation_mode: str): - """ - Inflate a 2D convolution weight matrix to a 3D one. - Parameters: - weight_2d: The weight matrix of 2D conv to be inflated. - weight_3d: The weight matrix of 3D conv to be initialized. - inflation_mode: the mode of inflation - """ - assert inflation_mode in ["tail", "replicate"] - assert weight_3d.shape[:2] == weight_2d.shape[:2] +def inflate_weight(weight_2d: torch.Tensor, weight_3d: torch.Tensor): with torch.no_grad(): - if inflation_mode == "replicate": - depth = weight_3d.size(2) - weight_3d.copy_(weight_2d.unsqueeze(2).repeat(1, 1, depth, 1, 1) / depth) - else: - weight_3d.fill_(0.0) - weight_3d[:, :, -1].copy_(weight_2d) + depth = weight_3d.size(2) + weight_3d.copy_(weight_2d.unsqueeze(2).repeat(1, 1, depth, 1, 1) / depth) return weight_3d - -def inflate_bias(bias_2d: torch.Tensor, bias_3d: torch.Tensor, inflation_mode: str): - """ - Inflate a 2D convolution bias tensor to a 3D one - Parameters: - bias_2d: The bias tensor of 2D conv to be inflated. - bias_3d: The bias tensor of 3D conv to be initialized. - inflation_mode: Placeholder to align `inflate_weight`. 
- """ - assert bias_3d.shape == bias_2d.shape +def inflate_bias(bias_2d: torch.Tensor, bias_3d: torch.Tensor): with torch.no_grad(): bias_3d.copy_(bias_2d) return bias_3d def modify_state_dict(layer, state_dict, prefix, inflate_weight_fn, inflate_bias_fn): - """ - the main function to inflated 2D parameters to 3D. - """ weight_name = prefix + "weight" bias_name = prefix + "bias" if weight_name in state_dict: weight_2d = state_dict[weight_name] if weight_2d.dim() == 4: - # Assuming the 2D weights are 4D tensors (out_channels, in_channels, h, w) weight_3d = inflate_weight_fn( weight_2d=weight_2d, weight_3d=layer.weight, - inflation_mode=layer.inflation_mode, ) state_dict[weight_name] = weight_3d else: return state_dict - # It's a 3d state dict, should not do inflation on both bias and weight. if bias_name in state_dict: bias_2d = state_dict[bias_name] if bias_2d.dim() == 1: - # Assuming the 2D biases are 1D tensors (out_channels,) bias_3d = inflate_bias_fn( bias_2d=bias_2d, bias_3d=layer.bias, - inflation_mode=layer.inflation_mode, ) state_dict[bias_name] = bias_3d return state_dict @@ -384,19 +353,12 @@ class InflatedCausalConv3d(nn.Conv3d): def _load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ): - if self.inflation_mode != "none": - state_dict = modify_state_dict( - self, - state_dict, - prefix, - inflate_weight_fn=inflate_weight, - inflate_bias_fn=inflate_bias, - ) + super()._load_from_state_dict( state_dict, prefix, local_metadata, - (strict and self.inflation_mode == "none"), + strict, missing_keys, unexpected_keys, error_msgs, diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 600c089fa..804878432 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -344,12 +344,14 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): elif "{}blocks.31.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 3b dit_config = {} + dit_config["image_model"] = "seedvr2" dit_config["vid_dim"] = 2560 dit_config["heads"] = 20 dit_config["num_layers"] = 32 dit_config["norm_eps"] = 1.0e-05 dit_config["qk_rope"] = None dit_config["mlp_type"] = "swiglu" + dit_config["vid_out_norm"] = True return dit_config diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 2301b1188..4162a1f5e 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1154,20 +1154,21 @@ class Chroma(supported_models_base.BASE): t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect)) -class SeedVR2(supported_models_base.Base): +class SeedVR2(supported_models_base.BASE): unet_config = { - "image_mode": "seedvr2" + "image_model": "seedvr2" } latent_format = comfy.latent_formats.SeedVR2 vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoders."] supported_inference_dtypes = [torch.bfloat16, torch.float32] def get_model(self, state_dict, prefix = "", device=None): out = model_base.SeedVR2(self, device=device) return out def clip_target(self, state_dict={}): - return None + return supported_models_base.ClipTarget(comfy.text_encoders.sd3_clip.SD3Tokenizer, comfy.text_encoders.sd3_clip.SD3ClipModel) class ACEStep(supported_models_base.BASE): unet_config = { diff --git a/comfy_extras/nodes_seedvr.py b/comfy_extras/nodes_seedvr.py index 9d4e8bf34..e2fa10427 100644 --- 
a/comfy_extras/nodes_seedvr.py +++ b/comfy_extras/nodes_seedvr.py @@ -4,6 +4,7 @@ import torch import math from einops import rearrange +import torch.nn.functional as F from torchvision.transforms import functional as TVF from torchvision.transforms import Lambda, Normalize from torchvision.transforms.functional import InterpolationMode @@ -108,12 +109,13 @@ class SeedVR2InputProcessing(io.ComfyNode): io.Int.Input("resolution_width") ], outputs = [ - io.Image.Output("images") + io.Image.Output("processed_images") ] ) @classmethod def execute(cls, images, resolution_height, resolution_width): + images = images.permute(0, 3, 1, 2) max_area = ((resolution_height * resolution_width)** 0.5) ** 2 clip = Lambda(lambda x: torch.clamp(x, 0.0, 1.0)) normalize = Normalize(0.5, 0.5) @@ -134,7 +136,7 @@ class SeedVR2Conditioning(io.ComfyNode): inputs=[ io.Conditioning.Input("text_positive_conditioning"), io.Conditioning.Input("text_negative_conditioning"), - io.Conditioning.Input("vae_conditioning") + io.Latent.Input("vae_conditioning") ], outputs=[io.Conditioning.Output(display_name = "positive"), io.Conditioning.Output(display_name = "negative"), @@ -143,7 +145,8 @@ class SeedVR2Conditioning(io.ComfyNode): @classmethod def execute(cls, text_positive_conditioning, text_negative_conditioning, vae_conditioning) -> io.NodeOutput: - # TODO: should do the flattening logic as with the original code + + vae_conditioning = vae_conditioning["samples"] pos_cond = text_positive_conditioning[0][0] neg_cond = text_negative_conditioning[0][0] @@ -160,14 +163,18 @@ class SeedVR2Conditioning(io.ComfyNode): cond = inter(vae_conditioning, aug_noises, t) condition = get_conditions(noises, cond) - # TODO / FIXME - pos_cond = torch.cat([condition, pos_cond], dim = 0) - neg_cond = torch.cat([condition, neg_cond], dim = 0) + pos_shape = pos_cond.shape[1] + neg_shape = neg_shape.shape[1] + diff = abs(pos_shape.shape[1] - neg_shape.shape[1]) + if pos_shape.shape[1] > neg_shape.shape[1]: + neg_cond = F.pad(neg_cond, (0, 0, 0, diff)) + else: + pos_cond = F.pad(pos_cond, (0, 0, 0, diff)) - negative = [[pos_cond, {}]] - positive = [[neg_cond, {}]] + negative = [[pos_cond, {"condition": condition}]] + positive = [[neg_cond, {"condition": condition}]] - return io.NodeOutput(positive, negative, noises) + return io.NodeOutput(positive, negative, {"samples": noises}) class SeedVRExtension(ComfyExtension): @override diff --git a/nodes.py b/nodes.py index 1b465b9e6..72e9c6066 100644 --- a/nodes.py +++ b/nodes.py @@ -2283,7 +2283,8 @@ def init_builtin_extra_nodes(): "nodes_string.py", "nodes_camera_trajectory.py", "nodes_edit_model.py", - "nodes_tcfg.py" + "nodes_tcfg.py", + "nodes_seedvr.py" ] import_failed = [] From d12702ee0b8324bb28b62bc3cbb75ba2d1179975 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Tue, 9 Dec 2025 23:54:56 +0200 Subject: [PATCH 06/35] fixed some issues --- comfy/ldm/seedvr/vae.py | 2 +- comfy/sd.py | 25 +++++++++++----------- comfy_extras/nodes_seedvr.py | 41 ++++++++++++++++++++++++++---------- 3 files changed, 44 insertions(+), 24 deletions(-) diff --git a/comfy/ldm/seedvr/vae.py b/comfy/ldm/seedvr/vae.py index 40c592a2b..4a503dde4 100644 --- a/comfy/ldm/seedvr/vae.py +++ b/comfy/ldm/seedvr/vae.py @@ -1189,7 +1189,6 @@ class Decoder3D(nn.Module): # up reversed_block_out_channels = list(reversed(block_out_channels)) output_channel = reversed_block_out_channels[0] - print(f"slicing_up_num: {slicing_up_num}") for i, up_block_type in enumerate(up_block_types): 
prev_output_channel = output_channel output_channel = reversed_block_out_channels[i] @@ -1450,6 +1449,7 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL): def encode(self, x: torch.FloatTensor): if x.ndim == 4: x = x.unsqueeze(2) + x = x.to(next(self.parameters()).dtype) p = super().encode(x).latent_dist z = p.sample().squeeze(2) return z, p diff --git a/comfy/sd.py b/comfy/sd.py index 79b17073f..186a69703 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -268,7 +268,10 @@ class CLIP: class VAE: def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None): if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format - sd = diffusers_convert.convert_vae_state_dict(sd) + if (metadata is not None and metadata["keep_diffusers_format"] == "true"): + pass + else: + sd = diffusers_convert.convert_vae_state_dict(sd) self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower) self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype) @@ -326,6 +329,15 @@ class VAE: self.first_stage_model = StageC_coder() self.downscale_ratio = 32 self.latent_channels = 16 + elif "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd: # seedvr2 + self.first_stage_model = comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper() + self.memory_used_decode = lambda shape, dtype: (2000 * shape[1] * shape[2] * shape[3] * (4 * 8 * 8)) * model_management.dtype_size(dtype) + self.memory_used_encode = lambda shape, dtype: (1000 * max(shape[1], 5) * shape[2] * shape[3]) * model_management.dtype_size(dtype) + self.working_dtypes = [torch.bfloat16, torch.float32] + self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8) + self.downscale_index_formula = (4, 8, 8) + self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8) + self.upscale_index_formula = (4, 8, 8) elif "decoder.conv_in.weight" in sd: #default SD1.x/SD2.x VAE parameters ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} @@ -393,17 +405,6 @@ class VAE: self.downscale_index_formula = (8, 32, 32) self.working_dtypes = [torch.bfloat16, torch.float32] - elif "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd: # seedvr2 - self.first_stage_model = comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper() - ddconfig["conv3d"] = True - ddconfig["time_compress"] = 4 - self.memory_used_decode = lambda shape, dtype: (2000 * shape[2] * shape[3] * shape[4] * (4 * 8 * 8)) * model_management.dtype_size(dtype) - self.memory_used_encode = lambda shape, dtype: (1000 * max(shape[2], 5) * shape[3] * shape[4]) * model_management.dtype_size(dtype) - self.working_dtypes = [torch.bfloat16, torch.float32] - self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8) - self.downscale_index_formula = (4, 8, 8) - self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8) - self.upscale_index_formula = (4, 8, 8) elif "decoder.conv_in.conv.weight" in sd: ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} diff --git a/comfy_extras/nodes_seedvr.py b/comfy_extras/nodes_seedvr.py index e2fa10427..e83e37c1d 100644 --- a/comfy_extras/nodes_seedvr.py +++ b/comfy_extras/nodes_seedvr.py 
@@ -105,27 +105,46 @@ class SeedVR2InputProcessing(io.ComfyNode): category="image/video", inputs = [ io.Image.Input("images"), - io.Int.Input("resolution_height"), - io.Int.Input("resolution_width") + io.Vae.Input("vae"), + io.Int.Input("resolution_height", default = 1280, min = 120), # // + io.Int.Input("resolution_width", default = 720, min = 120) # just non-zero value ], outputs = [ - io.Image.Output("processed_images") + io.Latent.Output("vae_conditioning") ] ) @classmethod - def execute(cls, images, resolution_height, resolution_width): - images = images.permute(0, 3, 1, 2) + def execute(cls, images, vae, resolution_height, resolution_width): + vae = vae.first_stage_model + scale = 0.9152; shift = 0 + + if images.dim() != 5: # add the t dim + images = images.unsqueeze(0) + images = images.permute(0, 1, 4, 2, 3) + + b, t, c, h, w = images.shape + images = images.reshape(b * t, c, h, w) + max_area = ((resolution_height * resolution_width)** 0.5) ** 2 clip = Lambda(lambda x: torch.clamp(x, 0.0, 1.0)) normalize = Normalize(0.5, 0.5) images = area_resize(images, max_area) + images = clip(images) images = crop(images, (16, 16)) images = normalize(images) - images = rearrange(images, "t c h w -> c t h w") + _, _, new_h, new_w = images.shape + + images = images.reshape(b, t, c, new_h, new_w) images = cut_videos(images) - return io.NodeOutput(images) + + images = rearrange(images, "b t c h w -> b c t h w") + latent = vae.encode(images)[0] + + latent = (latent - shift) * scale + + return io.NodeOutput({"samples": latent}) class SeedVR2Conditioning(io.ComfyNode): @classmethod @@ -150,8 +169,8 @@ class SeedVR2Conditioning(io.ComfyNode): pos_cond = text_positive_conditioning[0][0] neg_cond = text_negative_conditioning[0][0] - noises = [torch.randn_like(latent) for latent in vae_conditioning] - aug_noises = [torch.randn_like(latent) for latent in vae_conditioning] + noises = torch.randn_like(vae_conditioning) + aug_noises = torch.randn_like(vae_conditioning) cond_noise_scale = 0.0 t = ( @@ -165,8 +184,8 @@ class SeedVR2Conditioning(io.ComfyNode): pos_shape = pos_cond.shape[1] neg_shape = neg_shape.shape[1] - diff = abs(pos_shape.shape[1] - neg_shape.shape[1]) - if pos_shape.shape[1] > neg_shape.shape[1]: + diff = abs(pos_shape - neg_shape) + if pos_shape > neg_shape: neg_cond = F.pad(neg_cond, (0, 0, 0, diff)) else: pos_cond = F.pad(pos_cond, (0, 0, 0, diff)) From 413ee3f687a5ec9d287e58a73379a28e1eabb44a Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Wed, 10 Dec 2025 22:58:53 +0200 Subject: [PATCH 07/35] . 
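Make the SeedVR2 VAE convolutions causal in time and route the conditioning latent through `extra_conds`. `InflatedCausalConv3d` now pads by repeating the first frame (`extend_head`) instead of relying on implicit zero padding.

A minimal standalone sketch of that padding rule, assuming a `(b, c, t, h, w)` layout (illustrative only, not the module in the diff below):

    import torch

    def extend_head(x: torch.Tensor, times: int) -> torch.Tensor:
        # tile frame 0 `times` times along t; with times = 2 * temporal_padding,
        # a kernel of temporal extent 2 * padding + 1 never reads future frames
        if times == 0:
            return x
        head = x[:, :, :1].repeat(1, 1, times, 1, 1)
        return torch.cat((head, x), dim=2)

    x = torch.randn(1, 4, 5, 8, 8)
    assert extend_head(x, 2).shape[2] == 7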
--- comfy/ldm/seedvr/model.py | 10 +++++++--- comfy/ldm/seedvr/vae.py | 22 ++++++++++++++++++++-- comfy/model_base.py | 6 +++++- comfy_extras/nodes_seedvr.py | 16 ++++++++++------ 4 files changed, 42 insertions(+), 12 deletions(-) diff --git a/comfy/ldm/seedvr/model.py b/comfy/ldm/seedvr/model.py index 42567fa30..98121f26f 100644 --- a/comfy/ldm/seedvr/model.py +++ b/comfy/ldm/seedvr/model.py @@ -1300,8 +1300,7 @@ class NaDiT(nn.Module): **kwargs ): transformer_options = kwargs.get("transformer_options", {}) - c_or_u_list = transformer_options.get("cond_or_uncond", []) - cond_latent = c_or_u_list[0]["condition"] + conditions = kwargs.get("condition") pos_cond, neg_cond = context.chunk(2, dim=0) # txt_shape should be the same for both @@ -1312,11 +1311,16 @@ class NaDiT(nn.Module): vid = x vid, vid_shape = flatten(x) + cond_latent, _ = flatten(conditions) - vid = torch.cat([cond_latent, vid]) + vid = torch.cat([cond_latent, vid], dim=-1) if txt_shape.size(-1) == 1 and self.need_txt_repeat: txt, txt_shape = repeat(txt, txt_shape, "l c -> t l c", t=vid_shape[:, 0]) + device = next(self.parameters()).device + dtype = next(self.parameters()).dtype + txt = txt.to(device).to(dtype) + vid = vid.to(device).to(dtype) txt = self.txt_in(txt) vid, vid_shape = self.vid_in(vid, vid_shape) diff --git a/comfy/ldm/seedvr/vae.py b/comfy/ldm/seedvr/vae.py index 4a503dde4..1086f9adc 100644 --- a/comfy/ldm/seedvr/vae.py +++ b/comfy/ldm/seedvr/vae.py @@ -330,6 +330,17 @@ def safe_interpolate_operation(x, size=None, scale_factor=None, mode='nearest', _receptive_field_t = Literal["half", "full"] +def extend_head(tensor, times: int = 2, memory = None): + if memory is not None: + return torch.cat((memory.to(tensor), tensor), dim=2) + assert times >= 0, "Invalid input for function 'extend_head'!" 
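+    # with times = temporal_padding * 2, prepending copies of the first frame keeps the conv causal: no kernel position can read future frames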
+ if times == 0: + return tensor + else: + tile_repeat = [1] * tensor.ndim + tile_repeat[2] = times + return torch.cat(tensors=(torch.tile(tensor[:, :, :1], tile_repeat), tensor), dim=2) + class InflatedCausalConv3d(nn.Conv3d): def __init__( self, @@ -348,6 +359,7 @@ class InflatedCausalConv3d(nn.Conv3d): self, input, ): + input = extend_head(input, times=self.temporal_padding * 2) return super().forward(input) def _load_from_state_dict( @@ -514,6 +526,8 @@ class Downsample3D(nn.Module): self.out_channels = out_channels or channels self.temporal_down = temporal_down self.spatial_down = spatial_down + self.use_conv = use_conv + self.padding = padding self.temporal_ratio = 2 if temporal_down else 1 self.spatial_ratio = 2 if spatial_down else 1 @@ -630,6 +644,7 @@ class ResnetBlock3D(nn.Module): inflation_mode=inflation_mode, ) + self.upsample = self.downsample = None if self.up: self.upsample = Upsample3D( self.in_channels, @@ -646,6 +661,7 @@ class ResnetBlock3D(nn.Module): inflation_mode=inflation_mode, ) + self.conv_shortcut = None if self.use_in_shortcut: self.conv_shortcut = InflatedCausalConv3d( self.in_channels, @@ -1093,6 +1109,7 @@ class Encoder3D(nn.Module): extra_cond=None, ) -> torch.FloatTensor: r"""The forward method of the `Encoder` class.""" + sample = sample.to(next(self.parameters()).device) sample = self.conv_in(sample) if self.training and self.gradient_checkpointing: @@ -1450,8 +1467,9 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL): if x.ndim == 4: x = x.unsqueeze(2) x = x.to(next(self.parameters()).dtype) - p = super().encode(x).latent_dist - z = p.sample().squeeze(2) + x = x.to(next(self.parameters()).device) + p = super().encode(x) + z = p.squeeze(2) return z, p def decode(self, z: torch.FloatTensor): diff --git a/comfy/model_base.py b/comfy/model_base.py index bbab8627a..f9cc26bfb 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -798,7 +798,11 @@ class HunyuanDiT(BaseModel): class SeedVR2(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device, comfy.ldm.seedvr.model.NaDiT) - # TODO: extra_conds could be needed to add + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + condition = kwargs.get("condition", None) + out["condition"] = comfy.conds.CONDRegular(condition) + return out class PixArt(BaseModel): def __init__(self, model_config, model_type=ModelType.EPS, device=None): diff --git a/comfy_extras/nodes_seedvr.py b/comfy_extras/nodes_seedvr.py index e83e37c1d..9e8429b66 100644 --- a/comfy_extras/nodes_seedvr.py +++ b/comfy_extras/nodes_seedvr.py @@ -39,7 +39,7 @@ def timestep_transform(timesteps, latents_shapes): frames > 1, vid_shift_fn(heights * widths * frames), img_shift_fn(heights * widths), - ) + ).to(timesteps.device) # Shift timesteps. 
T = 1000.0 @@ -116,6 +116,7 @@ class SeedVR2InputProcessing(io.ComfyNode): @classmethod def execute(cls, images, vae, resolution_height, resolution_width): + device = vae.patcher.load_device vae = vae.first_stage_model scale = 0.9152; shift = 0 @@ -140,6 +141,8 @@ class SeedVR2InputProcessing(io.ComfyNode): images = cut_videos(images) images = rearrange(images, "b t c h w -> b c t h w") + vae = vae.to(device) + images = images.to(device) latent = vae.encode(images)[0] latent = (latent - shift) * scale @@ -166,24 +169,25 @@ class SeedVR2Conditioning(io.ComfyNode): def execute(cls, text_positive_conditioning, text_negative_conditioning, vae_conditioning) -> io.NodeOutput: vae_conditioning = vae_conditioning["samples"] + device = vae_conditioning.device pos_cond = text_positive_conditioning[0][0] neg_cond = text_negative_conditioning[0][0] - noises = torch.randn_like(vae_conditioning) - aug_noises = torch.randn_like(vae_conditioning) + noises = torch.randn_like(vae_conditioning).to(device) + aug_noises = torch.randn_like(vae_conditioning).to(device) cond_noise_scale = 0.0 t = ( torch.tensor([1000.0]) * cond_noise_scale - ) - shape = torch.tensor(vae_conditioning.shape[1:])[None] + ).to(device) + shape = torch.tensor(vae_conditioning.shape[1:]).to(device)[None] t = timestep_transform(t, shape) cond = inter(vae_conditioning, aug_noises, t) condition = get_conditions(noises, cond) pos_shape = pos_cond.shape[1] - neg_shape = neg_shape.shape[1] + neg_shape = neg_cond.shape[1] diff = abs(pos_shape - neg_shape) if pos_shape > neg_shape: neg_cond = F.pad(neg_cond, (0, 0, 0, diff)) From d629c8f910593b8aee6f4b03ec8deea6baacfdd8 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 12 Dec 2025 00:46:23 +0200 Subject: [PATCH 08/35] testing --- comfy/ldm/modules/diffusionmodules/model.py | 6 ++-- comfy/ldm/seedvr/model.py | 26 +++++++++++------ comfy/model_base.py | 3 +- comfy/supported_models.py | 2 +- comfy_extras/nodes_seedvr.py | 32 +++++++++++---------- 5 files changed, 41 insertions(+), 28 deletions(-) diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 8162742cf..aa37b09bb 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -13,7 +13,7 @@ if model_management.xformers_enabled_vae(): import xformers import xformers.ops -def get_timestep_embedding(timesteps, embedding_dim): +def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos = False, downscale_freq_shift = 1): """ This matches the implementation in Denoising Diffusion Probabilistic Models: From Fairseq. 
@@ -24,11 +24,13 @@ def get_timestep_embedding(timesteps, embedding_dim): assert len(timesteps.shape) == 1 half_dim = embedding_dim // 2 - emb = math.log(10000) / (half_dim - 1) + emb = math.log(10000) / (half_dim - downscale_freq_shift) emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) emb = emb.to(device=timesteps.device) emb = timesteps.float()[:, None] * emb[None, :] emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) if embedding_dim % 2 == 1: # zero pad emb = torch.nn.functional.pad(emb, (0,1,0,0)) return emb diff --git a/comfy/ldm/seedvr/model.py b/comfy/ldm/seedvr/model.py index 98121f26f..7444e2823 100644 --- a/comfy/ldm/seedvr/model.py +++ b/comfy/ldm/seedvr/model.py @@ -1,10 +1,10 @@ from dataclasses import dataclass from typing import Optional, Tuple, Union, List, Dict, Any, Callable import einops -from einops import rearrange, einsum +from einops import rearrange, einsum, repeat from torch import nn import torch.nn.functional as F -from math import ceil, sqrt, pi +from math import ceil, pi import torch from itertools import chain from comfy.ldm.modules.diffusionmodules.model import get_timestep_embedding @@ -12,6 +12,7 @@ from comfy.ldm.modules.attention import optimized_attention from comfy.rmsnorm import RMSNorm from torch.nn.modules.utils import _triple from torch import nn +import math class Cache: def __init__(self, disable=False, prefix="", cache=None): @@ -354,8 +355,8 @@ class RotaryEmbedding(nn.Module): freqs = self.freqs - freqs = einsum('..., f -> ... f', t.type(freqs.dtype), freqs) - freqs = repeat(freqs, '... n -> ... (n r)', r = 2) + freqs = torch.einsum('..., f -> ... f', t.type(freqs.dtype), freqs) + freqs = einops.repeat(freqs, '... n -> ... 
(n r)', r = 2) if should_cache and offset == 0: self.cached_freqs[:seq_len] = freqs.detach() @@ -460,6 +461,7 @@ def apply_rotary_emb( t_middle = t[..., start_index:end_index] t_right = t[..., end_index:] + freqs = freqs.to(t_middle.device) t_transformed = (t_middle * freqs.cos() * scale) + (rotate_half(t_middle) * freqs.sin() * scale) out = torch.cat((t_left, t_transformed, t_right), dim=-1) @@ -560,6 +562,7 @@ class MMModule(nn.Module): vid = vid_module(vid, *get_args("vid", args), **get_kwargs("vid", kwargs)) if not self.vid_only: txt_module = self.txt if not self.shared_weights else self.all + txt = txt.to(device=vid.device, dtype=vid.dtype) txt = txt_module(txt, *get_args("txt", args), **get_kwargs("txt", kwargs)) return vid, txt @@ -747,6 +750,7 @@ class NaSwinAttention(NaMMAttention): txt_len = cache("txt_len", lambda: txt_shape.prod(-1)) vid_len_win = cache_win("vid_len", lambda: window_shape.prod(-1)) + txt_len = txt_len.to(window_count.device) txt_len_win = cache_win("txt_len", lambda: txt_len.repeat_interleave(window_count)) all_len_win = cache_win("all_len", lambda: vid_len_win + txt_len_win) concat_win, unconcat_win = cache_win( @@ -1122,8 +1126,12 @@ class TimeEmbedding(nn.Module): emb = emb.to(dtype) emb = self.proj_in(emb) emb = self.act(emb) + device = next(self.proj_hid.parameters()).device + emb = emb.to(device) emb = self.proj_hid(emb) emb = self.act(emb) + device = next(self.proj_out.parameters()).device + emb = emb.to(device) emb = self.proj_out(emb) return emb @@ -1206,6 +1214,8 @@ class NaDiT(nn.Module): elif len(block_type) != num_layers: raise ValueError("The ``block_type`` list should equal to ``num_layers``.") super().__init__() + self.register_parameter("positive_conditioning", torch.empty((58, 5120))) + self.register_parameter("negative_conditioning", torch.empty((64, 5120))) self.vid_in = NaPatchIn( in_channels=vid_in_channels, patch_size=patch_size, @@ -1303,11 +1313,9 @@ class NaDiT(nn.Module): conditions = kwargs.get("condition") pos_cond, neg_cond = context.chunk(2, dim=0) - # txt_shape should be the same for both - pos_cond, txt_shape = flatten(pos_cond) - neg_cond, _ = flatten(neg_cond) + pos_cond, txt_shape = flatten([pos_cond]) + neg_cond, _ = flatten([neg_cond]) txt = torch.cat([pos_cond, neg_cond], dim = 0) - txt_shape[0] *= 2 vid = x vid, vid_shape = flatten(x) @@ -1321,7 +1329,7 @@ class NaDiT(nn.Module): dtype = next(self.parameters()).dtype txt = txt.to(device).to(dtype) vid = vid.to(device).to(dtype) - txt = self.txt_in(txt) + txt = self.txt_in(txt.to(next(self.txt_in.parameters()).device)) vid, vid_shape = self.vid_in(vid, vid_shape) diff --git a/comfy/model_base.py b/comfy/model_base.py index f9cc26bfb..f685ba161 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -801,7 +801,8 @@ class SeedVR2(BaseModel): def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) condition = kwargs.get("condition", None) - out["condition"] = comfy.conds.CONDRegular(condition) + if condition is not None: + out["condition"] = comfy.conds.CONDRegular(condition) return out class PixArt(BaseModel): diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 4162a1f5e..1cab38f97 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1168,7 +1168,7 @@ class SeedVR2(supported_models_base.BASE): out = model_base.SeedVR2(self, device=device) return out def clip_target(self, state_dict={}): - return supported_models_base.ClipTarget(comfy.text_encoders.sd3_clip.SD3Tokenizer, comfy.text_encoders.sd3_clip.SD3ClipModel) 
+ return None class ACEStep(supported_models_base.BASE): unet_config = { diff --git a/comfy_extras/nodes_seedvr.py b/comfy_extras/nodes_seedvr.py index 9e8429b66..8a108f37e 100644 --- a/comfy_extras/nodes_seedvr.py +++ b/comfy_extras/nodes_seedvr.py @@ -15,9 +15,9 @@ def expand_dims(tensor, ndim): def get_conditions(latent, latent_blur): t, h, w, c = latent.shape - cond = torch.zeros([t, h, w, c + 1], device=latent.device, dtype=latent.dtype) - cond[:, ..., :-1] = latent_blur[:] - cond[:, ..., -1:] = 1.0 + cond = torch.ones([t, h, w, 1], device=latent.device, dtype=latent.dtype) + #cond[:, ..., :-1] = latent_blur[:] + #cond[:, ..., -1:] = 1.0 return cond def timestep_transform(timesteps, latents_shapes): @@ -144,6 +144,8 @@ class SeedVR2InputProcessing(io.ComfyNode): vae = vae.to(device) images = images.to(device) latent = vae.encode(images)[0] + latent = latent.unsqueeze(2) if latent.ndim == 4 else latent + latent = rearrange(latent, "b c ... -> b ... c") latent = (latent - shift) * scale @@ -156,9 +158,8 @@ class SeedVR2Conditioning(io.ComfyNode): node_id="SeedVR2Conditioning", category="image/video", inputs=[ - io.Conditioning.Input("text_positive_conditioning"), - io.Conditioning.Input("text_negative_conditioning"), - io.Latent.Input("vae_conditioning") + io.Latent.Input("vae_conditioning"), + io.Model.Input("model"), ], outputs=[io.Conditioning.Output(display_name = "positive"), io.Conditioning.Output(display_name = "negative"), @@ -166,12 +167,13 @@ class SeedVR2Conditioning(io.ComfyNode): ) @classmethod - def execute(cls, text_positive_conditioning, text_negative_conditioning, vae_conditioning) -> io.NodeOutput: + def execute(cls, vae_conditioning, model) -> io.NodeOutput: vae_conditioning = vae_conditioning["samples"] device = vae_conditioning.device - pos_cond = text_positive_conditioning[0][0] - neg_cond = text_negative_conditioning[0][0] + model = model.model.diffusion_model + pos_cond = model.positive_conditioning + neg_cond = model.negative_conditioning noises = torch.randn_like(vae_conditioning).to(device) aug_noises = torch.randn_like(vae_conditioning).to(device) @@ -181,21 +183,21 @@ class SeedVR2Conditioning(io.ComfyNode): torch.tensor([1000.0]) * cond_noise_scale ).to(device) - shape = torch.tensor(vae_conditioning.shape[1:]).to(device)[None] + shape = torch.tensor(vae_conditioning.shape[1:]).to(device)[None] # avoid batch dim t = timestep_transform(t, shape) cond = inter(vae_conditioning, aug_noises, t) - condition = get_conditions(noises, cond) + condition = torch.stack([get_conditions(noise, c) for noise, c in zip(noises, cond)]) - pos_shape = pos_cond.shape[1] - neg_shape = neg_cond.shape[1] + pos_shape = pos_cond.shape[0] + neg_shape = neg_cond.shape[0] diff = abs(pos_shape - neg_shape) if pos_shape > neg_shape: neg_cond = F.pad(neg_cond, (0, 0, 0, diff)) else: pos_cond = F.pad(pos_cond, (0, 0, 0, diff)) - negative = [[pos_cond, {"condition": condition}]] - positive = [[neg_cond, {"condition": condition}]] + negative = [[neg_cond, {"condition": condition}]] + positive = [[pos_cond, {"condition": condition}]] return io.NodeOutput(positive, negative, {"samples": noises}) From 768c9cedf801f97bb3f3786302a18f6c1ed68465 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 12 Dec 2025 20:51:40 +0200 Subject: [PATCH 09/35] .. 
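Rework the windowed attention path to handle ragged window lengths: per-window q/k/v are split, right-padded with `pad_sequence`, and attended under an additive key-padding mask; `DiagonalGaussianDistribution` is also inlined so the VAE no longer imports it from the hunyuan3d module.

A self-contained sketch of the masking scheme, assuming the `(B, heads, L, d)` layout used in the block below (illustrative, not the block itself):

    import torch
    import torch.nn.functional as F
    import torch.nn.utils.rnn as rnn_utils

    lens = [3, 5, 2]  # ragged window lengths
    heads, d = 2, 4
    seqs = [torch.randn(n, heads, d) for n in lens]
    q = rnn_utils.pad_sequence(seqs, batch_first=True).transpose(1, 2)  # (B, heads, L, d)

    B, _, max_l, _ = q.shape
    idx = torch.arange(max_l).expand(B, max_l)
    pad = idx >= torch.tensor(lens).unsqueeze(1)  # True on padded positions
    mask = torch.zeros(B, 1, 1, max_l)
    mask.masked_fill_(pad[:, None, None, :], float("-inf"))

    # padded keys get -inf scores, so softmax distributes only over real tokens
    out = F.scaled_dot_product_attention(q, q, q, attn_mask=mask)
    assert out.shape == (B, heads, max_l, d)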
--- comfy/ldm/seedvr/model.py | 163 ++++++++++++++++++++--------------- comfy/ldm/seedvr/vae.py | 41 +++++++-- comfy_extras/nodes_seedvr.py | 13 +-- 3 files changed, 136 insertions(+), 81 deletions(-) diff --git a/comfy/ldm/seedvr/model.py b/comfy/ldm/seedvr/model.py index 7444e2823..cbf1383d3 100644 --- a/comfy/ldm/seedvr/model.py +++ b/comfy/ldm/seedvr/model.py @@ -1,8 +1,10 @@ from dataclasses import dataclass from typing import Optional, Tuple, Union, List, Dict, Any, Callable import einops -from einops import rearrange, einsum, repeat +from einops import rearrange, repeat +import comfy.model_management from torch import nn +import torch.nn.utils.rnn as rnn_utils import torch.nn.functional as F from math import ceil, pi import torch @@ -559,6 +561,8 @@ class MMModule(nn.Module): torch.FloatTensor, ]: vid_module = self.vid if not self.shared_weights else self.all + device = comfy.model_management.get_torch_device() + vid = vid.to(device) vid = vid_module(vid, *get_args("vid", args), **get_kwargs("vid", kwargs)) if not self.vid_only: txt_module = self.txt if not self.shared_weights else self.all @@ -616,58 +620,8 @@ class NaMMAttention(nn.Module): self.rope = get_na_rope(rope_type=rope_type, dim=rope_dim) - def forward( - self, - vid: torch.FloatTensor, # l c - txt: torch.FloatTensor, # l c - vid_shape: torch.LongTensor, # b 3 - txt_shape: torch.LongTensor, # b 1 - cache: Cache, - ) -> Tuple[ - torch.FloatTensor, - torch.FloatTensor, - ]: - - vid_qkv, txt_qkv = self.proj_qkv(vid, txt) - vid_qkv = rearrange(vid_qkv, "l (o h d) -> l o h d", o=3, d=self.head_dim) - txt_qkv = rearrange(txt_qkv, "l (o h d) -> l o h d", o=3, d=self.head_dim) - - vid_q, vid_k, vid_v = vid_qkv.unbind(1) - txt_q, txt_k, txt_v = txt_qkv.unbind(1) - - vid_q, txt_q = self.norm_q(vid_q, txt_q) - vid_k, txt_k = self.norm_k(vid_k, txt_k) - - if self.rope: - if self.rope.mm: - vid_q, vid_k, txt_q, txt_k = self.rope( - vid_q, vid_k, vid_shape, txt_q, txt_k, txt_shape, cache - ) - else: - vid_q, vid_k = self.rope(vid_q, vid_k, vid_shape, cache) - - vid_len = cache("vid_len", lambda: vid_shape.prod(-1)) - txt_len = cache("txt_len", lambda: txt_shape.prod(-1)) - all_len = cache("all_len", lambda: vid_len + txt_len) - - b = len(vid_len) - vq, vk, vv = [t.view(b, -1, *vid_q.shape[1:]) for t in (vid_q, vid_k, vid_v)] - tq, tk, tv = [t.view(b, -1, *txt_q.shape[1:]) for t in (txt_q, txt_v, txt_v)] - - q = torch.cat([vq, tq], dim=1) - k = torch.cat([vk, tk], dim=1) - v = torch.cat([vv, tv], dim=1) - - _, unconcat = cache("mm_pnp", lambda: concat_idx(vid_len, txt_len)) - - attn = optimized_attention(q, k, v, heads = self.heads, skip_reshape=True, skip_output_reshape=True) - attn = attn.flatten(0, 1) # to continue working with the rest of the code - - attn = rearrange(attn, "l h d -> l (h d)") - vid_out, txt_out = unconcat(attn) - - vid_out, txt_out = self.proj_out(vid_out, txt_out) - return vid_out, txt_out + def forward(self): + pass def window( hid: torch.FloatTensor, # (L c) @@ -783,23 +737,78 @@ class NaSwinAttention(NaMMAttention): vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win) # TODO: continue testing - b = len(vid_len_win) - vq, vk, vv = [t.view(b, -1, *vid_q.shape[1:]) for t in (vid_q, vid_k, vid_v)] - tq, tk, tv = [t.view(b, -1, *txt_q.shape[1:]) for t in (txt_q, txt_v, txt_v)] + v_lens = vid_len_win.cpu().tolist() + t_lens_batch = txt_len.cpu().tolist() + win_counts = window_count.cpu().tolist() - q = torch.cat([vq, tq], dim=1) - k = torch.cat([vk, tk], dim=1) - v = torch.cat([vv, tv], dim=1) - out = 
optimized_attention(q, k, v, heads=self.heads, skip_reshape=True, skip_output_reshape=True) - out = out.flatten(0, 1) + vq_l = torch.split(vid_q, v_lens) + vk_l = torch.split(vid_k, v_lens) + vv_l = torch.split(vid_v, v_lens) + + tv_batch = torch.split(txt_v, t_lens_batch) + tv_l = [] + for i, count in enumerate(win_counts): + tv_l.extend([tv_batch[i]] * count) + + current_txt_len = txt_q.shape[0] + expected_batch_len = sum(t_lens_batch) + + if current_txt_len != expected_batch_len: + t_lens_win = txt_len_win.cpu().tolist() + + tq_l = torch.split(txt_q, t_lens_win) + tk_l = torch.split(txt_k, t_lens_win) + else: + tq_batch = torch.split(txt_q, t_lens_batch) + tk_batch = torch.split(txt_k, t_lens_batch) + + tq_l = [] + tk_l = [] + for i, count in enumerate(win_counts): + tq_l.extend([tq_batch[i]] * count) + tk_l.extend([tk_batch[i]] * count) + + q_list = [torch.cat([v, t], dim=0) for v, t in zip(vq_l, tq_l)] + k_list = [torch.cat([v, t], dim=0) for v, t in zip(vk_l, tk_l)] + v_list = [torch.cat([v, t], dim=0) for v, t in zip(vv_l, tv_l)] + + q = rnn_utils.pad_sequence(q_list, batch_first=True) + k = rnn_utils.pad_sequence(k_list, batch_first=True) + v = rnn_utils.pad_sequence(v_list, batch_first=True) + + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + B, Heads, Max_L, _ = q.shape + combined_lens = [v.shape[0] + t.shape[0] for v, t in zip(vq_l, tq_l)] + + attn_mask = torch.zeros((B, 1, 1, Max_L), device=q.device, dtype=q.dtype) + idx = torch.arange(Max_L, device=q.device).unsqueeze(0).expand(B, Max_L) + len_tensor = torch.tensor(combined_lens, device=q.device).unsqueeze(1) + + padding_mask = idx >= len_tensor + attn_mask.masked_fill_(padding_mask.unsqueeze(1).unsqueeze(1), float('-inf')) + + out = optimized_attention(q, k, v, heads=self.heads, mask=attn_mask, skip_reshape=True, skip_output_reshape=True) + + out = out.transpose(1, 2) + + out_flat_list = [] + for i, length in enumerate(combined_lens): + out_flat_list.append(out[i, :length]) + + out = torch.cat(out_flat_list, dim=0) - # text pooling vid_out, txt_out = unconcat_win(out) vid_out = rearrange(vid_out, "l h d -> l (h d)") txt_out = rearrange(txt_out, "l h d -> l (h d)") vid_out = window_reverse(vid_out) + device = comfy.model_management.get_torch_device() + vid_out, txt_out = vid_out.to(device), txt_out.to(device) + self.proj_out = self.proj_out.to(device) vid_out, txt_out = self.proj_out(vid_out, txt_out) return vid_out, txt_out @@ -837,6 +846,8 @@ class SwiGLUMLP(nn.Module): self.proj_in = nn.Linear(dim, hidden_dim, bias=False) def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: + x = x.to(next(self.proj_in.parameters()).device) + self.proj_out = self.proj_out.to(x.device) x = self.proj_out(F.silu(self.proj_in_gate(x)) * self.proj_in(x)) return x @@ -928,6 +939,7 @@ class NaMMSRTransformerBlock(nn.Module): vid_attn, txt_attn = self.ada(vid_attn, txt_attn, layer="attn", mode="in", **ada_kwargs) vid_attn, txt_attn = self.attn(vid_attn, txt_attn, vid_shape, txt_shape, cache) vid_attn, txt_attn = self.ada(vid_attn, txt_attn, layer="attn", mode="out", **ada_kwargs) + txt = txt.to(txt_attn.device) vid_attn, txt_attn = (vid_attn + vid), (txt_attn + txt) vid_mlp, txt_mlp = self.mlp_norm(vid_attn, txt_attn) @@ -967,12 +979,11 @@ class NaPatchOut(PatchOut): vid: torch.FloatTensor, # l c vid_shape: torch.LongTensor, cache: Cache = Cache(disable=True), # for test + vid_shape_before_patchify = None ) -> Tuple[ torch.FloatTensor, torch.LongTensor, ]: - cache = cache.namespace("patch") - 
vid_shape_before_patchify = cache.get("vid_shape_before_patchify") t, h, w = self.patch_size vid = self.proj(vid) @@ -1074,6 +1085,16 @@ class AdaSingle(nn.Module): emb = rearrange(emb, "b (d l g) -> b d l g", l=len(self.layers), g=3)[..., idx, :] emb = expand_dims(emb, 1, hid.ndim + 1) + if hid_len is not None: + slice_inputs = lambda x, dim: x + emb = cache( + f"emb_repeat_{idx}_{branch_tag}", + lambda: slice_inputs( + torch.cat([e.repeat(l, *([1] * e.ndim)) for e, l in zip(emb, hid_len)]), + dim=0, + ), + ) + shiftA, scaleA, gateA = emb.unbind(-1) shiftB, scaleB, gateB = ( getattr(self, f"{layer}_shift", None), @@ -1214,8 +1235,8 @@ class NaDiT(nn.Module): elif len(block_type) != num_layers: raise ValueError("The ``block_type`` list should equal to ``num_layers``.") super().__init__() - self.register_parameter("positive_conditioning", torch.empty((58, 5120))) - self.register_parameter("negative_conditioning", torch.empty((64, 5120))) + self.register_buffer("positive_conditioning", torch.empty((58, 5120))) + self.register_buffer("negative_conditioning", torch.empty((64, 5120))) self.vid_in = NaPatchIn( in_channels=vid_in_channels, patch_size=patch_size, @@ -1306,13 +1327,14 @@ class NaDiT(nn.Module): x, timestep, context, # l c - disable_cache: bool = True, # for test # TODO ? + disable_cache: bool = False, # for test # TODO ? // gives an error when set to True **kwargs ): transformer_options = kwargs.get("transformer_options", {}) conditions = kwargs.get("condition") - pos_cond, neg_cond = context.chunk(2, dim=0) + pos_cond, neg_cond = context.squeeze(0).chunk(2, dim=0) + pos_cond, neg_cond = pos_cond.squeeze(0), neg_cond.squeeze(0) pos_cond, txt_shape = flatten([pos_cond]) neg_cond, _ = flatten([neg_cond]) txt = torch.cat([pos_cond, neg_cond], dim = 0) @@ -1331,6 +1353,7 @@ class NaDiT(nn.Module): vid = vid.to(device).to(dtype) txt = self.txt_in(txt.to(next(self.txt_in.parameters()).device)) + vid_shape_before_patchify = vid_shape vid, vid_shape = self.vid_in(vid, vid_shape) emb = self.emb_in(timestep, device=vid.device, dtype=vid.dtype) @@ -1358,6 +1381,6 @@ class NaDiT(nn.Module): branch_tag="vid", ) - vid, vid_shape = self.vid_out(vid, vid_shape, cache) + vid, vid_shape = self.vid_out(vid, vid_shape, cache, vid_shape_before_patchify = vid_shape_before_patchify) vid = unflatten(vid, vid_shape) - return vid + return vid[0] diff --git a/comfy/ldm/seedvr/vae.py b/comfy/ldm/seedvr/vae.py index 1086f9adc..6c58f044b 100644 --- a/comfy/ldm/seedvr/vae.py +++ b/comfy/ldm/seedvr/vae.py @@ -6,9 +6,31 @@ import torch.nn.functional as F from einops import rearrange from comfy.ldm.seedvr.model import safe_pad_operation -from comfy.ldm.hunyuan3d.vae import DiagonalGaussianDistribution from comfy.ldm.modules.attention import optimized_attention +class DiagonalGaussianDistribution(object): + def __init__(self, parameters: torch.Tensor, deterministic: bool = False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like( + self.mean, device=self.parameters.device, dtype=self.parameters.dtype + ) + + def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor: + sample = torch.randn( + self.mean.shape, + generator=generator, + device=self.parameters.device, + dtype=self.parameters.dtype, + ) + x = self.mean + self.std 
* sample + return x + class SpatialNorm(nn.Module): def __init__( self, @@ -453,7 +475,7 @@ class Upsample3D(nn.Module): else: self.Conv2d_0 = conv - self.norm = False + self.norm = None def forward( self, @@ -1255,6 +1277,7 @@ class Decoder3D(nn.Module): latent_embeds: Optional[torch.FloatTensor] = None, ) -> torch.FloatTensor: + sample = sample.to(next(self.parameters()).device) sample = self.conv_in(sample) upscale_dtype = next(iter(self.up_blocks.parameters())).dtype @@ -1397,10 +1420,10 @@ class VideoAutoencoderKL(nn.Module): def _decode( self, z: torch.Tensor ) -> torch.Tensor: - _z = z.to(self.device) + latent = z.to(self.device) if self.post_quant_conv is not None: - _z = self.post_quant_conv(_z) - output = self.decoder(_z) + latent = self.post_quant_conv(latent) + output = self.decoder(latent) return output.to(z.device) def slicing_encode(self, x: torch.Tensor) -> torch.Tensor: @@ -1473,9 +1496,15 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL): return z, p def decode(self, z: torch.FloatTensor): + latent = z.unsqueeze(0) + scale = 0.9152 + shift = 0 + latent = latent / scale + shift + latent = rearrange(latent, "b ... c -> b c ...") + latent = latent.squeeze(2) if z.ndim == 4: z = z.unsqueeze(2) - x = super().decode(z).sample.squeeze(2) + x = super().decode(latent).squeeze(2) return x def preprocess(self, x: torch.Tensor): diff --git a/comfy_extras/nodes_seedvr.py b/comfy_extras/nodes_seedvr.py index 8a108f37e..eebcb7dc0 100644 --- a/comfy_extras/nodes_seedvr.py +++ b/comfy_extras/nodes_seedvr.py @@ -15,9 +15,9 @@ def expand_dims(tensor, ndim): def get_conditions(latent, latent_blur): t, h, w, c = latent.shape - cond = torch.ones([t, h, w, 1], device=latent.device, dtype=latent.dtype) - #cond[:, ..., :-1] = latent_blur[:] - #cond[:, ..., -1:] = 1.0 + cond = torch.ones([t, h, w, c + 1], device=latent.device, dtype=latent.dtype) + cond[:, ..., :-1] = latent_blur[:] + cond[:, ..., -1:] = 1.0 return cond def timestep_transform(timesteps, latents_shapes): @@ -117,6 +117,7 @@ class SeedVR2InputProcessing(io.ComfyNode): @classmethod def execute(cls, images, vae, resolution_height, resolution_width): device = vae.patcher.load_device + offload_device = vae.patcher.offload_device vae = vae.first_stage_model scale = 0.9152; shift = 0 @@ -144,6 +145,7 @@ class SeedVR2InputProcessing(io.ComfyNode): vae = vae.to(device) images = images.to(device) latent = vae.encode(images)[0] + vae = vae.to(offload_device) latent = latent.unsqueeze(2) if latent.ndim == 4 else latent latent = rearrange(latent, "b c ... -> b ... 
c") @@ -196,8 +198,9 @@ class SeedVR2Conditioning(io.ComfyNode): else: pos_cond = F.pad(pos_cond, (0, 0, 0, diff)) - negative = [[neg_cond, {"condition": condition}]] - positive = [[pos_cond, {"condition": condition}]] + cond = torch.cat([pos_cond.unsqueeze(0), neg_cond.unsqueeze(0)]).unsqueeze(0) + negative = [[cond, {"condition": condition}]] + positive = [[cond, {"condition": condition}]] return io.NodeOutput(positive, negative, {"samples": noises}) From 58e7cea79635f78a05cfc9e2d936655e1a026520 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Sat, 13 Dec 2025 19:48:57 +0200 Subject: [PATCH 10/35] lora, 7b model, cfg --- comfy/ldm/seedvr/model.py | 49 +++++++++++++++++++++++++----------- comfy/model_detection.py | 11 +++++++- comfy/supported_models.py | 3 +++ comfy_extras/nodes_seedvr.py | 5 ++-- 4 files changed, 50 insertions(+), 18 deletions(-) diff --git a/comfy/ldm/seedvr/model.py b/comfy/ldm/seedvr/model.py index cbf1383d3..9b69c85a1 100644 --- a/comfy/ldm/seedvr/model.py +++ b/comfy/ldm/seedvr/model.py @@ -1331,15 +1331,14 @@ class NaDiT(nn.Module): **kwargs ): transformer_options = kwargs.get("transformer_options", {}) + patches_replace = transformer_options.get("patches_replace", {}) + blocks_replace = patches_replace.get("dit", {}) conditions = kwargs.get("condition") - pos_cond, neg_cond = context.squeeze(0).chunk(2, dim=0) + pos_cond, neg_cond = context.chunk(2, dim=0) pos_cond, neg_cond = pos_cond.squeeze(0), neg_cond.squeeze(0) - pos_cond, txt_shape = flatten([pos_cond]) - neg_cond, _ = flatten([neg_cond]) - txt = torch.cat([pos_cond, neg_cond], dim = 0) + txt, txt_shape = flatten([pos_cond, neg_cond]) - vid = x vid, vid_shape = flatten(x) cond_latent, _ = flatten(conditions) @@ -1360,14 +1359,36 @@ class NaDiT(nn.Module): cache = Cache(disable=disable_cache) for i, block in enumerate(self.blocks): - vid, txt, vid_shape, txt_shape = block( - vid=vid, - txt=txt, - vid_shape=vid_shape, - txt_shape=txt_shape, - emb=emb, - cache=cache, - ) + if ("block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["vid"], out["txt"], out["vid_shape"], out["txt_shape"] = block( + vid=args["vid"], + txt=args["txt"], + vid_shape=args["vid_shape"], + txt_shape=args["txt_shape"], + emb=args["emb"], + cache=args["cache"], + ) + return out + out = blocks_replace[("block", i)]({ + "vid":vid, + "txt":txt, + "vid_shape":vid_shape, + "txt_shape":txt_shape, + "emb":emb, + "cache":cache, + }, {"original_block": block_wrap}) + vid, txt, vid_shape, txt_shape = out["vid"], out["txt"], out["vid_shape"], out["txt_shape"] + else: + vid, txt, vid_shape, txt_shape = block( + vid=vid, + txt=txt, + vid_shape=vid_shape, + txt_shape=txt_shape, + emb=emb, + cache=cache, + ) if self.vid_out_norm: vid = self.vid_out_norm(vid) @@ -1383,4 +1404,4 @@ class NaDiT(nn.Module): vid, vid_shape = self.vid_out(vid, vid_shape, cache, vid_shape_before_patchify = vid_shape_before_patchify) vid = unflatten(vid, vid_shape) - return vid[0] + return torch.stack(vid) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 804878432..22e774730 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -342,6 +342,16 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["axes_lens"] = [300, 512, 512] return dit_config + elif "{}blocks.36.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 7b + dit_config = {} + dit_config["image_model"] = "seedvr2" + dit_config["vid_dim"] = 3072 + 
dit_config["heads"] = 24 + dit_config["num_layers"] = 36 + dit_config["norm_eps"] = 1e-5 + dit_config["qk_rope"] = True + dit_config["mlp_type"] = "normal" + return dit_config elif "{}blocks.31.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 3b dit_config = {} dit_config["image_model"] = "seedvr2" @@ -352,7 +362,6 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["qk_rope"] = None dit_config["mlp_type"] = "swiglu" dit_config["vid_out_norm"] = True - return dit_config if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1 diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 1cab38f97..a5f116327 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1163,6 +1163,9 @@ class SeedVR2(supported_models_base.BASE): vae_key_prefix = ["vae."] text_encoder_key_prefix = ["text_encoders."] supported_inference_dtypes = [torch.bfloat16, torch.float32] + sampling_settings = { + "shift": 1.0, + } def get_model(self, state_dict, prefix = "", device=None): out = model_base.SeedVR2(self, device=device) diff --git a/comfy_extras/nodes_seedvr.py b/comfy_extras/nodes_seedvr.py index eebcb7dc0..08009b4d9 100644 --- a/comfy_extras/nodes_seedvr.py +++ b/comfy_extras/nodes_seedvr.py @@ -198,9 +198,8 @@ class SeedVR2Conditioning(io.ComfyNode): else: pos_cond = F.pad(pos_cond, (0, 0, 0, diff)) - cond = torch.cat([pos_cond.unsqueeze(0), neg_cond.unsqueeze(0)]).unsqueeze(0) - negative = [[cond, {"condition": condition}]] - positive = [[cond, {"condition": condition}]] + negative = [[neg_cond.unsqueeze(0), {"condition": condition}]] + positive = [[pos_cond.unsqueeze(0), {"condition": condition}]] return io.NodeOutput(positive, negative, {"samples": noises}) From ebd945ce3d22d50ed95d95f361b72ea1e259b5c9 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Wed, 17 Dec 2025 00:09:38 +0200 Subject: [PATCH 11/35] vae fix --- comfy/ldm/seedvr/vae.py | 71 +++++++++++++++++++++++++++++++++++------ 1 file changed, 61 insertions(+), 10 deletions(-) diff --git a/comfy/ldm/seedvr/vae.py b/comfy/ldm/seedvr/vae.py index 6c58f044b..ef07b24e0 100644 --- a/comfy/ldm/seedvr/vae.py +++ b/comfy/ldm/seedvr/vae.py @@ -4,6 +4,7 @@ import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange +from torch import Tensor from comfy.ldm.seedvr.model import safe_pad_operation from comfy.ldm.modules.attention import optimized_attention @@ -398,6 +399,11 @@ class InflatedCausalConv3d(nn.Conv3d): error_msgs, ) +def remove_head(tensor: Tensor, times: int = 1) -> Tensor: + if times == 0: + return tensor + return torch.cat(tensors=(tensor[:, :, :1], tensor[:, :, times + 1 :]), dim=2) + class Upsample3D(nn.Module): def __init__( @@ -509,6 +515,9 @@ class Upsample3D(nn.Module): z=self.temporal_ratio, ) + if self.temporal_up: + hidden_states[0] = remove_head(hidden_states[0]) + if not self.slicing: hidden_states = hidden_states[0] @@ -1296,11 +1305,55 @@ class Decoder3D(nn.Module): return sample -class VideoAutoencoderKL(nn.Module): +def wavelet_blur(image: Tensor, radius: int): """ - We simply inherit the model code from diffusers + Apply wavelet blur to the input tensor. 
""" + # input shape: (1, 3, H, W) + # convolution kernel + kernel_vals = [ + [0.0625, 0.125, 0.0625], + [0.125, 0.25, 0.125], + [0.0625, 0.125, 0.0625], + ] + kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device) + # add channel dimensions to the kernel to make it a 4D tensor + kernel = kernel[None, None] + # repeat the kernel across all input channels + kernel = kernel.repeat(3, 1, 1, 1) + image = F.pad(image, (radius, radius, radius, radius), mode='replicate') + # apply convolution + output = F.conv2d(image, kernel, groups=3, dilation=radius) + return output +def wavelet_decomposition(image: Tensor, levels=5): + """ + Apply wavelet decomposition to the input tensor. + This function only returns the low frequency & the high frequency. + """ + high_freq = torch.zeros_like(image) + for i in range(levels): + radius = 2 ** i + low_freq = wavelet_blur(image, radius) + high_freq += (image - low_freq) + image = low_freq + + return high_freq, low_freq + +def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor): + """ + Apply wavelet decomposition, so that the content will have the same color as the style. + """ + # calculate the wavelet decomposition of the content feature + content_high_freq, content_low_freq = wavelet_decomposition(content_feat) + del content_low_freq + # calculate the wavelet decomposition of the style feature + style_high_freq, style_low_freq = wavelet_decomposition(style_feat) + del style_high_freq + # reconstruct the content feature with the style's high frequency + return content_high_freq + style_low_freq + +class VideoAutoencoderKL(nn.Module): def __init__( self, in_channels: int = 3, @@ -1478,6 +1531,7 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL): self.spatial_downsample_factor = spatial_downsample_factor self.temporal_downsample_factor = temporal_downsample_factor self.freeze_encoder = freeze_encoder + self.original_image_video = None super().__init__(*args, **kwargs) def forward(self, x: torch.FloatTensor): @@ -1487,6 +1541,8 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL): return x, z, p def encode(self, x: torch.FloatTensor): + # we need to keep a reference to the image/video so we later can do a colour fix later + self.original_image_video = x if x.ndim == 4: x = x.unsqueeze(2) x = x.to(next(self.parameters()).dtype) @@ -1502,18 +1558,13 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL): latent = latent / scale + shift latent = rearrange(latent, "b ... 
c -> b c ...") latent = latent.squeeze(2) + if z.ndim == 4: z = z.unsqueeze(2) x = super().decode(latent).squeeze(2) - return x - def preprocess(self, x: torch.Tensor): - # x should in [B, C, T, H, W], [B, C, H, W] - assert x.ndim == 4 or x.size(2) % 4 == 1 - return x - - def postprocess(self, x: torch.Tensor): - # x should in [B, C, T, H, W], [B, C, H, W] + input = rearrange(self.original_image_video[0], "c t h w -> t c h w") + x = wavelet_reconstruction(x, input) return x def set_memory_limit(self, conv_max_mem: Optional[float], norm_max_mem: Optional[float]): From d9f71da998d4754c229c0edbde72c97b727c4a66 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Thu, 18 Dec 2025 00:32:14 +0200 Subject: [PATCH 12/35] works --- comfy/ldm/seedvr/model.py | 26 +++++++++++++++++++++----- comfy/samplers.py | 13 +++++++++++++ comfy_extras/nodes_model_advanced.py | 17 +++++++++++++++++ 3 files changed, 51 insertions(+), 5 deletions(-) diff --git a/comfy/ldm/seedvr/model.py b/comfy/ldm/seedvr/model.py index 9b69c85a1..e44048447 100644 --- a/comfy/ldm/seedvr/model.py +++ b/comfy/ldm/seedvr/model.py @@ -1321,6 +1321,14 @@ class NaDiT(nn.Module): layers=["out"], modes=["in"], ) + + self.stop_cfg_index = -1 + + def set_cfg_stop_index(self, cfg): + self.stop_cfg_index = cfg + + def get_cfg_stop_index(self): + return self.stop_cfg_index def forward( self, @@ -1335,14 +1343,17 @@ class NaDiT(nn.Module): **kwargs ): transformer_options = kwargs.get("transformer_options", {}) patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) conditions = kwargs.get("condition") - pos_cond, neg_cond = context.chunk(2, dim=0) - pos_cond, neg_cond = pos_cond.squeeze(0), neg_cond.squeeze(0) - txt, txt_shape = flatten([pos_cond, neg_cond]) + try: + neg_cond, pos_cond = context.chunk(2, dim=0) + pos_cond, neg_cond = pos_cond.squeeze(0), neg_cond.squeeze(0) + txt, txt_shape = flatten([pos_cond, neg_cond]) + except: + txt, txt_shape = flatten([context.squeeze(0)]) vid, vid_shape = flatten(x) cond_latent, _ = flatten(conditions) - vid = torch.cat([cond_latent, vid], dim=-1) + vid = torch.cat([vid, cond_latent], dim=-1) if txt_shape.size(-1) == 1 and self.need_txt_repeat: txt, txt_shape = repeat(txt, txt_shape, "l c -> t l c", t=vid_shape[:, 0]) @@ -1404,4 +1415,9 @@ class NaDiT(nn.Module): vid, vid_shape = self.vid_out(vid, vid_shape, cache, vid_shape_before_patchify = vid_shape_before_patchify) vid = unflatten(vid, vid_shape) - return torch.stack(vid) + out = torch.stack(vid) + try: + pos, neg = out.chunk(2) + return torch.cat([neg, pos]) + except: + return out diff --git a/comfy/samplers.py b/comfy/samplers.py index 25ccaf39f..c159055dd 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -947,8 +947,21 @@ class CFGGuider: def __call__(self, *args, **kwargs): return self.predict_noise(*args, **kwargs) + + def handle_dynamic_cfg(self, timestep, model_options): + if hasattr(self.model_patcher.model.diffusion_model, "stop_cfg_index"): + stop_index = self.model_patcher.model.diffusion_model.stop_cfg_index + transformer_options = model_options.get("transformer_options", {}) + sigmas = transformer_options.get("sample_sigmas", None) + if sigmas is not None and self.cfg != 1.0: + dist = torch.abs(sigmas - timestep) + i = torch.argmin(dist).item() + + if stop_index == i or (stop_index == -1 and i == len(sigmas) - 2): + self.set_cfg(1.0) def predict_noise(self, x, timestep, model_options={}, seed=None): + self.handle_dynamic_cfg(timestep, model_options) return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None),
self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed) def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed): diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py index ae5d2c563..a42bf2b6a 100644 --- a/comfy_extras/nodes_model_advanced.py +++ b/comfy_extras/nodes_model_advanced.py @@ -5,6 +5,22 @@ import nodes import torch import node_helpers +class CFGCutoff: + @classmethod + def INPUT_TYPES(s): + return {"required": {"model": ("MODEL",), "cfg_stop_index": ("INT", {"default": -1, "min": -1, })}} + + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "advanced/model" + + def patch(self, model, cfg_stop_index): + diff_model = model.model.diffusion_model + if hasattr(diff_model, "set_cfg_stop_index"): + diff_model.set_cfg_stop_index(cfg_stop_index) + + return (model,) class LCM(comfy.model_sampling.EPS): def calculate_denoised(self, sigma, model_output, model_input): @@ -326,4 +342,5 @@ NODE_CLASS_MAPPINGS = { "ModelSamplingFlux": ModelSamplingFlux, "RescaleCFG": RescaleCFG, "ModelComputeDtype": ModelComputeDtype, + "CFGCutoff": CFGCutoff } From db74a2787097dc70db53c1c8abcf4c77547db190 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Thu, 18 Dec 2025 14:13:41 +0200 Subject: [PATCH 13/35] fix vae issue --- comfy/ldm/seedvr/model.py | 9 +++++++-- comfy/ldm/seedvr/vae.py | 9 +++++++-- comfy_extras/nodes_seedvr.py | 16 ++++++++++------ 3 files changed, 24 insertions(+), 10 deletions(-) diff --git a/comfy/ldm/seedvr/model.py b/comfy/ldm/seedvr/model.py index e44048447..119799592 100644 --- a/comfy/ldm/seedvr/model.py +++ b/comfy/ldm/seedvr/model.py @@ -1342,6 +1342,8 @@ class NaDiT(nn.Module): patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) conditions = kwargs.get("condition") + x = x.movedim(1, -1) + conditions = conditions.movedim(1, -1) try: neg_cond, pos_cond = context.chunk(2, dim=0) @@ -1418,6 +1420,9 @@ class NaDiT(nn.Module): out = torch.stack(vid) try: pos, neg = out.chunk(2) - return torch.cat([neg, pos]) - except: + ut = torch.cat([neg, pos]) + out = out.movedim(-1, 1) + return out + except: + out = out.movedim(-1, 1) return out diff --git a/comfy/ldm/seedvr/vae.py b/comfy/ldm/seedvr/vae.py index ef07b24e0..277f7a697 100644 --- a/comfy/ldm/seedvr/vae.py +++ b/comfy/ldm/seedvr/vae.py @@ -6,6 +6,7 @@ import torch.nn.functional as F from einops import rearrange from torch import Tensor +import comfy.model_management from comfy.ldm.seedvr.model import safe_pad_operation from comfy.ldm.modules.attention import optimized_attention @@ -1552,15 +1553,19 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL): return z, p def decode(self, z: torch.FloatTensor): + z = z.movedim(1, -1) latent = z.unsqueeze(0) scale = 0.9152 shift = 0 latent = latent / scale + shift latent = rearrange(latent, "b ... 
c -> b c ...") latent = latent.squeeze(2) + + if latent.ndim == 4: + latent = latent.unsqueeze(2) - if z.ndim == 4: - z = z.unsqueeze(2) + target_device = comfy.model_management.get_torch_device() + self.to(target_device) x = super().decode(latent).squeeze(2) input = rearrange(self.original_image_video[0], "c t h w -> t c h w") diff --git a/comfy_extras/nodes_seedvr.py b/comfy_extras/nodes_seedvr.py index 08009b4d9..e4022d209 100644 --- a/comfy_extras/nodes_seedvr.py +++ b/comfy_extras/nodes_seedvr.py @@ -4,6 +4,7 @@ import torch import math from einops import rearrange +import comfy.model_management import torch.nn.functional as F from torchvision.transforms import functional as TVF from torchvision.transforms import Lambda, Normalize @@ -116,11 +117,12 @@ class SeedVR2InputProcessing(io.ComfyNode): @classmethod def execute(cls, images, vae, resolution_height, resolution_width): + comfy.model_management.load_models_gpu([vae.patcher], force_full_load=True) device = vae.patcher.load_device - offload_device = vae.patcher.offload_device - vae = vae.first_stage_model - scale = 0.9152; shift = 0 + offload_device = comfy.model_management.intermediate_device() + vae_model = vae.first_stage_model + scale = 0.9152; shift = 0 if images.dim() != 5: # add the t dim images = images.unsqueeze(0) images = images.permute(0, 1, 4, 2, 3) @@ -142,14 +144,14 @@ class SeedVR2InputProcessing(io.ComfyNode): images = cut_videos(images) images = rearrange(images, "b t c h w -> b c t h w") - vae = vae.to(device) images = images.to(device) - latent = vae.encode(images)[0] - vae = vae.to(offload_device) + latent = vae_model.encode(images)[0] + latent = latent.unsqueeze(2) if latent.ndim == 4 else latent latent = rearrange(latent, "b c ... -> b ... c") latent = (latent - shift) * scale + latent = latent.to(offload_device) return io.NodeOutput({"samples": latent}) @@ -189,6 +191,8 @@ class SeedVR2Conditioning(io.ComfyNode): t = timestep_transform(t, shape) cond = inter(vae_conditioning, aug_noises, t) condition = torch.stack([get_conditions(noise, c) for noise, c in zip(noises, cond)]) + condition = condition.movedim(-1, 1) + noises = noises.movedim(-1, 1) pos_shape = pos_cond.shape[0] neg_shape = neg_cond.shape[0] From 74621b9d86a9fd45ca376b152777a5c89850bf78 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Thu, 18 Dec 2025 14:52:10 +0200 Subject: [PATCH 14/35] . 
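This loosens the VAE preload introduced in the previous patch (PATCH 13): force_full_load=True pinned every VAE weight on the load device before encoding, while the plain call lets ComfyUI's memory manager decide how much stays resident. A minimal sketch of the two call styles, using only the load_models_gpu API these patches already exercise (the partial-load behaviour of the default path is an assumption about the memory manager, not something this patch defines; vae here stands for the VAE object handled by the node above):

    import comfy.model_management

    # PATCH 13 behaviour: pin the whole VAE on the load device up front.
    comfy.model_management.load_models_gpu([vae.patcher], force_full_load=True)

    # This patch: load what fits; the manager may keep some weights offloaded.
    comfy.model_management.load_models_gpu([vae.patcher])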
--- comfy_extras/nodes_seedvr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_seedvr.py b/comfy_extras/nodes_seedvr.py index e4022d209..e758ad516 100644 --- a/comfy_extras/nodes_seedvr.py +++ b/comfy_extras/nodes_seedvr.py @@ -117,7 +117,7 @@ class SeedVR2InputProcessing(io.ComfyNode): @classmethod def execute(cls, images, vae, resolution_height, resolution_width): - comfy.model_management.load_models_gpu([vae.patcher], force_full_load=True) + comfy.model_management.load_models_gpu([vae.patcher]) device = vae.patcher.load_device offload_device = comfy.model_management.intermediate_device() From 7e62f8cc9fa8aa973072388fd27ef1a9ab3a4cc0 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 19 Dec 2025 20:23:39 +0200 Subject: [PATCH 15/35] added var length attention and fixed the vae issue --- comfy/ldm/modules/attention.py | 46 ++++++++++++++++-- comfy/ldm/seedvr/model.py | 85 +++++++++------------------------- comfy/ldm/seedvr/vae.py | 2 +- 3 files changed, 63 insertions(+), 70 deletions(-) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index a8800ded0..332c65ffb 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -32,7 +32,7 @@ except ImportError as e: FLASH_ATTENTION_IS_AVAILABLE = False try: - from flash_attn import flash_attn_func + from flash_attn import flash_attn_func, flash_attn_varlen_func FLASH_ATTENTION_IS_AVAILABLE = True except ImportError: if model_management.flash_attention_enabled(): @@ -473,8 +473,29 @@ else: SDP_BATCH_LIMIT = 2**31 @wrap_attn -def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): - if skip_reshape: +def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, var_length=False, **kwargs): + if var_length: + cu_seqlens_q = kwargs.get("cu_seqlens_q", None) + cu_seqlens_k = kwargs.get("cu_seqlens_k", cu_seqlens_q) + assert cu_seqlens_q != None, "cu_seqlens_q shouldn't be None when var_length is True" + if not skip_reshape: + # assumes 2D q, k,v [total_tokens, embed_dim] + total_tokens, embed_dim = q.shape + head_dim = embed_dim // heads + q = q.view(total_tokens, heads, head_dim) + k = k.view(k.shape[0], heads, head_dim) + v = v.view(v.shape[0], heads, head_dim) + + b = q.size(0); dim_head = q.shape[-1] + q = torch.nested.nested_tensor_from_jagged(q, offsets=cu_seqlens_q.long()) + k = torch.nested.nested_tensor_from_jagged(k, offsets=cu_seqlens_k.long()) + v = torch.nested.nested_tensor_from_jagged(v, offsets=cu_seqlens_k.long()) + + mask = None + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + elif skip_reshape: b, _, _, dim_head = q.shape else: b, _, dim_head = q.shape @@ -492,8 +513,10 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha if mask.ndim == 3: mask = mask.unsqueeze(1) - if SDP_BATCH_LIMIT >= b: + if SDP_BATCH_LIMIT >= b or var_length: out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) + if var_length: + return out.contiguous().transpose(1, 2).values() if not skip_output_reshape: out = ( out.transpose(1, 2).reshape(b, -1, heads * dim_head) @@ -583,7 +606,20 @@ except AttributeError as error: assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}" @wrap_attn -def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, 
skip_output_reshape=False, **kwargs): +def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, var_length=False, **kwargs): + if var_length: + cu_seqlens_q = kwargs.get("cu_seqlens_q", None) + cu_seqlens_k = kwargs.get("cu_seqlens_k", cu_seqlens_q) + max_seqlen_q = kwargs.get("max_seqlen_q", None) + max_seqlen_k = kwargs.get("max_seqlen_k", max_seqlen_q) + assert max_seqlen_q != None, "max_seqlen_q shouldn't be None when var_length is True" + assert cu_seqlens_q != None, "cu_seqlens_q shouldn't be None when var_length is True" + return flash_attn_varlen_func( + q=q, k=k, v=v, + cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k, + dropout_p=0.0, softmax_scale=None, causal=False + ) if skip_reshape: b, _, _, dim_head = q.shape else: diff --git a/comfy/ldm/seedvr/model.py b/comfy/ldm/seedvr/model.py index 119799592..0825a12ba 100644 --- a/comfy/ldm/seedvr/model.py +++ b/comfy/ldm/seedvr/model.py @@ -15,6 +15,11 @@ from comfy.rmsnorm import RMSNorm from torch.nn.modules.utils import _triple from torch import nn import math +import logging +try: + from flash_attn import flash_attn_varlen_func +except: + logging.warning("Best results will be achieved with flash attention enabled for SeedVR2") class Cache: def __init__(self, disable=False, prefix="", cache=None): @@ -735,70 +740,21 @@ class NaSwinAttention(NaMMAttention): ) else: vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win) - - # TODO: continue testing - v_lens = vid_len_win.cpu().tolist() - t_lens_batch = txt_len.cpu().tolist() - win_counts = window_count.cpu().tolist() - vq_l = torch.split(vid_q, v_lens) - vk_l = torch.split(vid_k, v_lens) - vv_l = torch.split(vid_v, v_lens) - - tv_batch = torch.split(txt_v, t_lens_batch) - tv_l = [] - for i, count in enumerate(win_counts): - tv_l.extend([tv_batch[i]] * count) - - current_txt_len = txt_q.shape[0] - expected_batch_len = sum(t_lens_batch) - - if current_txt_len != expected_batch_len: - t_lens_win = txt_len_win.cpu().tolist() - - tq_l = torch.split(txt_q, t_lens_win) - tk_l = torch.split(txt_k, t_lens_win) - else: - tq_batch = torch.split(txt_q, t_lens_batch) - tk_batch = torch.split(txt_k, t_lens_batch) - - tq_l = [] - tk_l = [] - for i, count in enumerate(win_counts): - tq_l.extend([tq_batch[i]] * count) - tk_l.extend([tk_batch[i]] * count) - - q_list = [torch.cat([v, t], dim=0) for v, t in zip(vq_l, tq_l)] - k_list = [torch.cat([v, t], dim=0) for v, t in zip(vk_l, tk_l)] - v_list = [torch.cat([v, t], dim=0) for v, t in zip(vv_l, tv_l)] - - q = rnn_utils.pad_sequence(q_list, batch_first=True) - k = rnn_utils.pad_sequence(k_list, batch_first=True) - v = rnn_utils.pad_sequence(v_list, batch_first=True) - - q = q.transpose(1, 2) - k = k.transpose(1, 2) - v = v.transpose(1, 2) - - B, Heads, Max_L, _ = q.shape - combined_lens = [v.shape[0] + t.shape[0] for v, t in zip(vq_l, tq_l)] - - attn_mask = torch.zeros((B, 1, 1, Max_L), device=q.device, dtype=q.dtype) - idx = torch.arange(Max_L, device=q.device).unsqueeze(0).expand(B, Max_L) - len_tensor = torch.tensor(combined_lens, device=q.device).unsqueeze(1) - - padding_mask = idx >= len_tensor - attn_mask.masked_fill_(padding_mask.unsqueeze(1).unsqueeze(1), float('-inf')) - - out = optimized_attention(q, k, v, heads=self.heads, mask=attn_mask, skip_reshape=True, skip_output_reshape=True) - - out = out.transpose(1, 2) - - out_flat_list = [] - for i, length in enumerate(combined_lens): - out_flat_list.append(out[i, 
:length]) - - out = torch.cat(out_flat_list, dim=0) + out = optimized_attention( + q=concat_win(vid_q, txt_q).bfloat16(), + k=concat_win(vid_k, txt_k).bfloat16(), + v=concat_win(vid_v, txt_v).bfloat16(), + heads=self.heads, skip_reshape=True, var_length = True, + cu_seqlens_q=cache_win( + "vid_seqlens_q", lambda: safe_pad_operation(all_len_win.cumsum(0), (1, 0)).int() + ), + cu_seqlens_k=cache_win( + "vid_seqlens_k", lambda: safe_pad_operation(all_len_win.cumsum(0), (1, 0)).int() + ), + max_seqlen_q=cache_win("vid_max_seqlen_q", lambda: all_len_win.max().item()), + max_seqlen_k=cache_win("vid_max_seqlen_k", lambda: all_len_win.max().item()), + ) vid_out, txt_out = unconcat_win(out) @@ -807,7 +763,8 @@ class NaSwinAttention(NaMMAttention): vid_out = window_reverse(vid_out) device = comfy.model_management.get_torch_device() - vid_out, txt_out = vid_out.to(device), txt_out.to(device) + dtype = next(self.proj_out.parameters()).dtype + vid_out, txt_out = vid_out.to(device=device, dtype=dtype), txt_out.to(device=device, dtype=dtype) self.proj_out = self.proj_out.to(device) vid_out, txt_out = self.proj_out(vid_out, txt_out) diff --git a/comfy/ldm/seedvr/vae.py b/comfy/ldm/seedvr/vae.py index 277f7a697..ac5e20b8d 100644 --- a/comfy/ldm/seedvr/vae.py +++ b/comfy/ldm/seedvr/vae.py @@ -1565,7 +1565,7 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL): latent = latent.unsqueeze(2) target_device = comfy.model_management.get_torch_device() - self.to(target_device) + self.decoder.to(target_device) x = super().decode(latent).squeeze(2) input = rearrange(self.original_image_video[0], "c t h w -> t c h w") From 0d2044a7781b19677f22881e114c7b9159854bfa Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 19 Dec 2025 20:28:09 +0200 Subject: [PATCH 16/35] ... --- comfy_extras/nodes_seedvr.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/comfy_extras/nodes_seedvr.py b/comfy_extras/nodes_seedvr.py index e758ad516..bff358424 100644 --- a/comfy_extras/nodes_seedvr.py +++ b/comfy_extras/nodes_seedvr.py @@ -117,7 +117,6 @@ class SeedVR2InputProcessing(io.ComfyNode): @classmethod def execute(cls, images, vae, resolution_height, resolution_width): - comfy.model_management.load_models_gpu([vae.patcher]) device = vae.patcher.load_device offload_device = comfy.model_management.intermediate_device() @@ -145,7 +144,9 @@ class SeedVR2InputProcessing(io.ComfyNode): images = rearrange(images, "b t c h w -> b c t h w") images = images.to(device) + vae_model = vae_model.to(device) latent = vae_model.encode(images)[0] + vae_model = vae_model.to(offload_device) latent = latent.unsqueeze(2) if latent.ndim == 4 else latent latent = rearrange(latent, "b c ... -> b ... 
c") From 4fe772fae9e780d7e46c0fa602d28ecd05819dd5 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Sat, 20 Dec 2025 23:20:45 +0200 Subject: [PATCH 17/35] improvements --- comfy/ldm/seedvr/model.py | 2 +- comfy/ldm/seedvr/vae.py | 5 ++++- comfy/sd.py | 2 ++ comfy_extras/nodes_seedvr.py | 22 ++++++++++++++++------ 4 files changed, 23 insertions(+), 8 deletions(-) diff --git a/comfy/ldm/seedvr/model.py b/comfy/ldm/seedvr/model.py index 0825a12ba..c1b8a1738 100644 --- a/comfy/ldm/seedvr/model.py +++ b/comfy/ldm/seedvr/model.py @@ -1377,7 +1377,7 @@ class NaDiT(nn.Module): out = torch.stack(vid) try: pos, neg = out.chunk(2) - ut = torch.cat([neg, pos]) + out = torch.cat([neg, pos]) out = out.movedim(-1, 1) return out except: diff --git a/comfy/ldm/seedvr/vae.py b/comfy/ldm/seedvr/vae.py index ac5e20b8d..d3786e85d 100644 --- a/comfy/ldm/seedvr/vae.py +++ b/comfy/ldm/seedvr/vae.py @@ -1541,9 +1541,10 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL): x = self.decode(z).sample return x, z, p - def encode(self, x: torch.FloatTensor): + def encode(self, x, orig_dims): # we need to keep a reference to the image/video so we later can do a colour fix later self.original_image_video = x + self.img_dims = orig_dims if x.ndim == 4: x = x.unsqueeze(2) x = x.to(next(self.parameters()).dtype) @@ -1570,6 +1571,8 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL): input = rearrange(self.original_image_video[0], "c t h w -> t c h w") x = wavelet_reconstruction(x, input) + o_h, o_w = self.img_dims + x = x[..., :o_h, :o_w] return x def set_memory_limit(self, conv_max_mem: Optional[float], norm_max_mem: Optional[float]): diff --git a/comfy/sd.py b/comfy/sd.py index 86b5ff2ad..be2ce30a8 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -386,6 +386,8 @@ class VAE: self.downscale_index_formula = (4, 8, 8) self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8) self.upscale_index_formula = (4, 8, 8) + self.process_input = lambda image: image + self.crop_input = False elif "decoder.conv_in.weight" in sd: if sd['decoder.conv_in.weight'].shape[1] == 64: ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True} diff --git a/comfy_extras/nodes_seedvr.py b/comfy_extras/nodes_seedvr.py index bff358424..2b1d41174 100644 --- a/comfy_extras/nodes_seedvr.py +++ b/comfy_extras/nodes_seedvr.py @@ -68,14 +68,21 @@ def area_resize(image, max_area): interpolation=InterpolationMode.BICUBIC, ) -def crop(image, factor): +def div_pad(image, factor): + height_factor, width_factor = factor height, width = image.shape[-2:] - cropped_height = height - (height % height_factor) - cropped_width = width - (width % width_factor) + pad_height = (height_factor - (height % height_factor)) % height_factor + pad_width = (width_factor - (width % width_factor)) % width_factor + + if pad_height == 0 and pad_width == 0: + return image + + if isinstance(image, torch.Tensor): + padding = (0, pad_width, 0, pad_height) + image = torch.nn.functional.pad(image, padding, mode='constant', value=0.0) - image = TVF.center_crop(img=image, output_size=(cropped_height, cropped_width)) return image def cut_videos(videos): @@ -120,6 +127,8 @@ class SeedVR2InputProcessing(io.ComfyNode): device = vae.patcher.load_device offload_device = comfy.model_management.intermediate_device() + main_device = comfy.model_management.get_torch_device() + images = 
images.to(main_device) vae_model = vae.first_stage_model scale = 0.9152; shift = 0 if images.dim() != 5: # add the t dim @@ -135,7 +144,8 @@ class SeedVR2InputProcessing(io.ComfyNode): images = area_resize(images, max_area) images = clip(images) - images = crop(images, (16, 16)) + o_h, o_w = images.shape[-2:] + images = div_pad(images, (16, 16)) images = normalize(images) _, _, new_h, new_w = images.shape @@ -145,7 +155,7 @@ class SeedVR2InputProcessing(io.ComfyNode): images = rearrange(images, "b t c h w -> b c t h w") images = images.to(device) vae_model = vae_model.to(device) - latent = vae_model.encode(images)[0] + latent = vae_model.encode(images, [o_h, o_w])[0] vae_model = vae_model.to(offload_device) latent = latent.unsqueeze(2) if latent.ndim == 4 else latent From a4e9d071e833b920d8520df6f040ed380d4562e8 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Mon, 22 Dec 2025 18:12:46 +0200 Subject: [PATCH 18/35] video works --- comfy/ldm/seedvr/model.py | 11 ++- comfy/ldm/seedvr/vae.py | 35 ++++++-- comfy_extras/nodes_seedvr.py | 154 ++++++++++++++++++++++++++++++++++- 3 files changed, 185 insertions(+), 15 deletions(-) diff --git a/comfy/ldm/seedvr/model.py b/comfy/ldm/seedvr/model.py index c1b8a1738..716d728c2 100644 --- a/comfy/ldm/seedvr/model.py +++ b/comfy/ldm/seedvr/model.py @@ -16,10 +16,6 @@ from torch.nn.modules.utils import _triple from torch import nn import math import logging -try: - from flash_attn import flash_attn_varlen_func -except: - logging.warning("Best results will be achieved with flash attention enabled for SeedVR2") class Cache: def __init__(self, disable=False, prefix="", cache=None): @@ -1299,6 +1295,9 @@ class NaDiT(nn.Module): patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) conditions = kwargs.get("condition") + b, tc, h, w = x.shape + x = x.view(b, 16, -1, h, w) + conditions = conditions.view(b, 17, -1, h, w) x = x.movedim(1, -1) conditions = conditions.movedim(1, -1) @@ -1375,11 +1374,11 @@ class NaDiT(nn.Module): vid, vid_shape = self.vid_out(vid, vid_shape, cache, vid_shape_before_patchify = vid_shape_before_patchify) vid = unflatten(vid, vid_shape) out = torch.stack(vid) + out = out.movedim(-1, 1) + out = rearrange(out, "b c t h w -> b (c t) h w") try: pos, neg = out.chunk(2) out = torch.cat([neg, pos]) - out = out.movedim(-1, 1) return out except: - out = out.movedim(-1, 1) return out diff --git a/comfy/ldm/seedvr/vae.py b/comfy/ldm/seedvr/vae.py index d3786e85d..a8f8c31f2 100644 --- a/comfy/ldm/seedvr/vae.py +++ b/comfy/ldm/seedvr/vae.py @@ -9,6 +9,7 @@ from torch import Tensor import comfy.model_management from comfy.ldm.seedvr.model import safe_pad_operation from comfy.ldm.modules.attention import optimized_attention +from comfy_extras.nodes_seedvr import tiled_vae class DiagonalGaussianDistribution(object): def __init__(self, parameters: torch.Tensor, deterministic: bool = False): @@ -1450,7 +1451,7 @@ class VideoAutoencoderKL(nn.Module): return posterior - def decode( + def decode_( self, z: torch.Tensor, return_dict: bool = True ): decoded = self.slicing_decode(z) @@ -1541,10 +1542,11 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL): x = self.decode(z).sample return x, z, p - def encode(self, x, orig_dims): + def encode(self, x, orig_dims=None): # we need to keep a reference to the image/video so we later can do a colour fix later - self.original_image_video = x - self.img_dims = orig_dims + #self.original_image_video = x + if 
orig_dims is not None: + self.img_dims = orig_dims if x.ndim == 4: x = x.unsqueeze(2) x = x.to(next(self.parameters()).dtype) @@ -1554,6 +1556,8 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL): return z, p def decode(self, z: torch.FloatTensor): + b, tc, h, w = z.shape + z = z.view(b, 16, -1, h, w) z = z.movedim(1, -1) latent = z.unsqueeze(0) scale = 0.9152 @@ -1567,12 +1571,31 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL): target_device = comfy.model_management.get_torch_device() self.decoder.to(target_device) - x = super().decode(latent).squeeze(2) + x = tiled_vae(latent, self, **self.tiled_args, encode=False).squeeze(2) + #x = super().decode(latent).squeeze(2) + + input = rearrange(self.original_image_video, "b c t h w -> (b t) c h w") + if x.ndim == 4: + x = x.unsqueeze(0) + + # in case of padded frames + t = input.size(0) + x = x[:, :, :t] + + x = rearrange(x, "b c t h w -> (b t) c h w") - input = rearrange(self.original_image_video[0], "c t h w -> t c h w") x = wavelet_reconstruction(x, input) + x = x.unsqueeze(0) o_h, o_w = self.img_dims x = x[..., :o_h, :o_w] + x = rearrange(x, "b t c h w -> b c t h w") + + # ensure even dims for save video + h, w = x.shape[-2:] + w2 = w - (w % 2) + h2 = h - (h % 2) + x = x[..., :h2, :w2] + return x def set_memory_limit(self, conv_max_mem: Optional[float], norm_max_mem: Optional[float]): diff --git a/comfy_extras/nodes_seedvr.py b/comfy_extras/nodes_seedvr.py index 2b1d41174..e3281b5f3 100644 --- a/comfy_extras/nodes_seedvr.py +++ b/comfy_extras/nodes_seedvr.py @@ -4,12 +4,146 @@ import torch import math from einops import rearrange +import gc import comfy.model_management +from comfy.utils import ProgressBar + import torch.nn.functional as F from torchvision.transforms import functional as TVF from torchvision.transforms import Lambda, Normalize from torchvision.transforms.functional import InterpolationMode +@torch.inference_mode() +def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), temporal_size=16, temporal_overlap=4, encode=True): + + gc.collect() + torch.cuda.empty_cache() + + if x.ndim != 5: + x = x.unsqueeze(2) + + b, c, d, h, w = x.shape + + sf_s = getattr(vae_model, "spatial_downsample_factor", 8) + sf_t = getattr(vae_model, "temporal_downsample_factor", 4) + + if encode: + ti_h, ti_w = tile_size + ov_h, ov_w = tile_overlap + ti_t = temporal_size + ov_t = temporal_overlap + + target_d = (d + sf_t - 1) // sf_t + target_h = (h + sf_s - 1) // sf_s + target_w = (w + sf_s - 1) // sf_s + else: + ti_h = max(1, tile_size[0] // sf_s) + ti_w = max(1, tile_size[1] // sf_s) + ov_h = max(0, tile_overlap[0] // sf_s) + ov_w = max(0, tile_overlap[1] // sf_s) + ti_t = max(1, temporal_size // sf_t) + ov_t = max(0, temporal_overlap // sf_t) + + target_d = d * sf_t + target_h = h * sf_s + target_w = w * sf_s + + stride_t = max(1, ti_t - ov_t) + stride_h = max(1, ti_h - ov_h) + stride_w = max(1, ti_w - ov_w) + + storage_device = torch.device("cpu") + result = None + count = None + + ramp_cache = {} + def get_ramp(steps): + if steps not in ramp_cache: + t = torch.linspace(0, 1, steps=steps, device=storage_device, dtype=torch.float32) + ramp_cache[steps] = 0.5 - 0.5 * torch.cos(t * torch.pi) + return ramp_cache[steps] + + bar = ProgressBar(d // stride_t) + for t_idx in range(0, d, stride_t): + t_end = min(t_idx + ti_t, d) + + for y_idx in range(0, h, stride_h): + y_end = min(y_idx + ti_h, h) + + for x_idx in range(0, w, stride_w): + x_end = min(x_idx + ti_w, w) + + tile_x = x[:, :, t_idx:t_end, y_idx:y_end, x_idx:x_end] + + 
if encode: + tile_out = vae_model.encode(tile_x)[0] + else: + tile_out = vae_model.decode_(tile_x) + + if tile_out.ndim == 4: + tile_out = tile_out.unsqueeze(2) + + tile_out = tile_out.to(storage_device).float() + + if result is None: + b_out, c_out = tile_out.shape[0], tile_out.shape[1] + result = torch.zeros((b_out, c_out, target_d, target_h, target_w), device=storage_device, dtype=torch.float32) + count = torch.zeros((1, 1, target_d, target_h, target_w), device=storage_device, dtype=torch.float32) + + if encode: + ts, te = t_idx // sf_t, (t_idx // sf_t) + tile_out.shape[2] + ys, ye = y_idx // sf_s, (y_idx // sf_s) + tile_out.shape[3] + xs, xe = x_idx // sf_s, (x_idx // sf_s) + tile_out.shape[4] + + cur_ov_t = max(0, min(ov_t // sf_t, tile_out.shape[2] // 2)) + cur_ov_h = max(0, min(ov_h // sf_s, tile_out.shape[3] // 2)) + cur_ov_w = max(0, min(ov_w // sf_s, tile_out.shape[4] // 2)) + else: + ts, te = t_idx * sf_t, (t_idx * sf_t) + tile_out.shape[2] + ys, ye = y_idx * sf_s, (y_idx * sf_s) + tile_out.shape[3] + xs, xe = x_idx * sf_s, (x_idx * sf_s) + tile_out.shape[4] + + cur_ov_t = max(0, min(ov_t, tile_out.shape[2] // 2)) + cur_ov_h = max(0, min(ov_h, tile_out.shape[3] // 2)) + cur_ov_w = max(0, min(ov_w, tile_out.shape[4] // 2)) + + w_t = torch.ones((tile_out.shape[2],), device=storage_device) + w_h = torch.ones((tile_out.shape[3],), device=storage_device) + w_w = torch.ones((tile_out.shape[4],), device=storage_device) + + if cur_ov_t > 0: + r = get_ramp(cur_ov_t) + if t_idx > 0: w_t[:cur_ov_t] = r + if t_end < d: w_t[-cur_ov_t:] = 1.0 - r + + if cur_ov_h > 0: + r = get_ramp(cur_ov_h) + if y_idx > 0: w_h[:cur_ov_h] = r + if y_end < h: w_h[-cur_ov_h:] = 1.0 - r + + if cur_ov_w > 0: + r = get_ramp(cur_ov_w) + if x_idx > 0: w_w[:cur_ov_w] = r + if x_end < w: w_w[-cur_ov_w:] = 1.0 - r + + final_weight = w_t.view(1,1,-1,1,1) * w_h.view(1,1,1,-1,1) * w_w.view(1,1,1,1,-1) + + tile_out.mul_(final_weight) + result[:, :, ts:te, ys:ye, xs:xe] += tile_out + count[:, :, ts:te, ys:ye, xs:xe] += final_weight + + del tile_out, final_weight, tile_x, w_t, w_h, w_w + bar.update(1) + result.div_(count.clamp(min=1e-6)) + + if result.device != x.device: + result = result.to(x.device).to(x.dtype) + + if x.shape[2] == 1 and sf_t == 1: + result = result.squeeze(2) + + return result + def expand_dims(tensor, ndim): shape = tensor.shape + (1,) * (ndim - tensor.ndim) return tensor.reshape(shape) @@ -115,7 +249,11 @@ class SeedVR2InputProcessing(io.ComfyNode): io.Image.Input("images"), io.Vae.Input("vae"), io.Int.Input("resolution_height", default = 1280, min = 120), # // - io.Int.Input("resolution_width", default = 720, min = 120) # just non-zero value + io.Int.Input("resolution_width", default = 720, min = 120), # just non-zero value + io.Int.Input("spatial_tile_size", default = 512, min = -1), + io.Int.Input("temporal_tile_size", default = 8, min = -1), + io.Int.Input("spatial_overlap", default = 64, min = -1), + io.Int.Input("temporal_overlap", default = 8, min = -1), ], outputs = [ io.Latent.Output("vae_conditioning") @@ -123,7 +261,7 @@ class SeedVR2InputProcessing(io.ComfyNode): ) @classmethod - def execute(cls, images, vae, resolution_height, resolution_width): + def execute(cls, images, vae, resolution_height, resolution_width, spatial_tile_size, temporal_tile_size, spatial_overlap, temporal_overlap): device = vae.patcher.load_device offload_device = comfy.model_management.intermediate_device() @@ -155,8 +293,15 @@ class SeedVR2InputProcessing(io.ComfyNode): images = rearrange(images, "b t c h w -> b c t 
h w") images = images.to(device) vae_model = vae_model.to(device) - latent = vae_model.encode(images, [o_h, o_w])[0] + vae_model.original_image_video = images + + args = {"tile_size": (spatial_tile_size, spatial_tile_size), "tile_overlap": (spatial_overlap, spatial_overlap), + "temporal_size":temporal_tile_size, "temporal_overlap": temporal_overlap} + vae_model.tiled_args = args + latent = tiled_vae(images, vae_model, encode=True, **args) + vae_model = vae_model.to(offload_device) + vae_model.img_dims = [o_h, o_w] latent = latent.unsqueeze(2) if latent.ndim == 4 else latent latent = rearrange(latent, "b c ... -> b ... c") @@ -213,6 +358,9 @@ class SeedVR2Conditioning(io.ComfyNode): else: pos_cond = F.pad(pos_cond, (0, 0, 0, diff)) + noises = rearrange(noises, "b c t h w -> b (c t) h w") + condition = rearrange(condition, "b c t h w -> b (c t) h w") + negative = [[neg_cond.unsqueeze(0), {"condition": condition}]] positive = [[pos_cond.unsqueeze(0), {"condition": condition}]] From 5db5da790f9980a1fa188074b15770c473b70e5a Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Mon, 22 Dec 2025 21:15:12 +0200 Subject: [PATCH 19/35] remove cfg cutoff node --- comfy/samplers.py | 17 +---------------- comfy_extras/nodes_model_advanced.py | 17 ----------------- 2 files changed, 1 insertion(+), 33 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index 934310930..fa09c71af 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -949,19 +949,8 @@ class CFGGuider: for k in conds: self.original_conds[k] = comfy.sampler_helpers.convert_cond(conds[k]) - def __call__(self, *args, **kwargs): + def __call__(self, *args, **kwargs): return self.outer_predict_noise(*args, **kwargs) - def handle_dynamic_cfg(self, timestep, model_options): - if hasattr(self.model_patcher.model.diffusion_model, "stop_cfg_index"): - stop_index = self.model_patcher.model.diffusion_model.stop_cfg_index - transformer_options = model_options.get("transformer_options", {}) - sigmas = transformer_options.get("sample_sigmas", None) - if sigmas is not None or self.cfg != 1.0: - dist = torch.abs(sigmas - timestep) - i = torch.argmin(dist).item() - - if stop_index == i or (stop_index == -1 and i == len(sigmas) - 2): - self.set_cfg(1.0) def outer_predict_noise(self, x, timestep, model_options={}, seed=None): return comfy.patcher_extension.WrapperExecutor.new_class_executor( @@ -971,7 +960,6 @@ class CFGGuider: ).execute(x, timestep, model_options, seed) def predict_noise(self, x, timestep, model_options={}, seed=None): - self.handle_dynamic_cfg(timestep, model_options) return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed) def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=None): @@ -996,9 +984,6 @@ class CFGGuider: self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options) device = self.model_patcher.load_device - if denoise_mask is not None: - denoise_mask = comfy.sampler_helpers.prepare_mask(denoise_mask, noise.shape, device) - noise = noise.to(device) latent_image = latent_image.to(device) sigmas = sigmas.to(device) diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py index a42bf2b6a..ae5d2c563 100644 --- a/comfy_extras/nodes_model_advanced.py +++ 
b/comfy_extras/nodes_model_advanced.py @@ -5,22 +5,6 @@ import nodes import torch import node_helpers -class CFGCutoff: - @classmethod - def INPUT_TYPES(s): - return {"required": {"model": ("MODEL",), "cfg_stop_index": ("INT", {"default": -1, "min": -1, })}} - - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" - - CATEGORY = "advanced/model" - - def patch(self, model, cfg_stop_index): - diff_model = model.model.diffusion_model - if hasattr(diff_model, "set_cfg_stop_index"): - diff_model.set_cfg_stop_index(cfg_stop_index) - - return (model,) class LCM(comfy.model_sampling.EPS): def calculate_denoised(self, sigma, model_output, model_input): @@ -342,5 +326,4 @@ NODE_CLASS_MAPPINGS = { "ModelSamplingFlux": ModelSamplingFlux, "RescaleCFG": RescaleCFG, "ModelComputeDtype": ModelComputeDtype, - "CFGCutoff": CFGCutoff } From fc5fabb6291feffb19ad967c3ec080390ae5d435 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Mon, 22 Dec 2025 21:16:21 +0200 Subject: [PATCH 20/35] . --- comfy/samplers.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/comfy/samplers.py b/comfy/samplers.py index fa09c71af..fa4640842 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -984,6 +984,9 @@ class CFGGuider: self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options) device = self.model_patcher.load_device + if denoise_mask is not None: + denoise_mask = comfy.sampler_helpers.prepare_mask(denoise_mask, noise.shape, device) + noise = noise.to(device) latent_image = latent_image.to(device) sigmas = sigmas.to(device) From 98b6bfcb71b72655ac13976fe3f69faeaff55309 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Mon, 22 Dec 2025 21:46:40 +0200 Subject: [PATCH 21/35] revert file perm. --- comfy_extras/nodes_model_advanced.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) mode change 100644 => 100755 comfy_extras/nodes_model_advanced.py diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py old mode 100644 new mode 100755 From e30298dda2b6b1844d9104fb28339b3caeca1e36 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Mon, 22 Dec 2025 21:49:48 +0200 Subject: [PATCH 22/35] .. 
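A note on the CFGGuider cleanup in PATCH 19: the deleted cfg-stop logic does not need to live inside the sampler, because ComfyUI already routes guidance through the sampler_cfg_function hook that RescaleCFG in this same file registers via set_model_sampler_cfg_function. Below is a sketch of an equivalent cutoff expressed as a model patch; it assumes the hook's args expose "cond", "uncond", "cond_scale" and "sigma" the way RescaleCFG consumes them, and it swaps the removed step-index bookkeeping for a plain sigma threshold:

    def cfg_cutoff(model, stop_sigma=0.0):
        m = model.clone()

        def cutoff(args):
            cond, uncond = args["cond"], args["uncond"]
            # At or below the threshold, return the conditional branch alone,
            # which is exactly what cfg = 1.0 reduces to.
            if args["sigma"].max().item() <= stop_sigma:
                return cond
            return uncond + args["cond_scale"] * (cond - uncond)

        m.set_model_sampler_cfg_function(cutoff)
        return m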
--- comfy_extras/nodes_model_advanced.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) mode change 100755 => 100644 comfy_extras/nodes_model_advanced.py diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py old mode 100755 new mode 100644 From 5b0c80a09363f057084d29deef51fbb35f52c502 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Tue, 23 Dec 2025 12:35:00 +0200 Subject: [PATCH 23/35] ruff --- comfy/ldm/seedvr/model.py | 19 ++++++++----------- comfy/ldm/seedvr/vae.py | 36 ++++++++++++++++++------------------ comfy/model_base.py | 2 +- comfy/model_detection.py | 2 +- comfy/supported_models.py | 2 +- comfy_extras/nodes_seedvr.py | 28 ++++++++++++++-------------- 6 files changed, 43 insertions(+), 46 deletions(-) diff --git a/comfy/ldm/seedvr/model.py b/comfy/ldm/seedvr/model.py index 716d728c2..cf3ebd520 100644 --- a/comfy/ldm/seedvr/model.py +++ b/comfy/ldm/seedvr/model.py @@ -1,10 +1,8 @@ from dataclasses import dataclass from typing import Optional, Tuple, Union, List, Dict, Any, Callable import einops -from einops import rearrange, repeat +from einops import rearrange import comfy.model_management -from torch import nn -import torch.nn.utils.rnn as rnn_utils import torch.nn.functional as F from math import ceil, pi import torch @@ -15,7 +13,6 @@ from comfy.rmsnorm import RMSNorm from torch.nn.modules.utils import _triple from torch import nn import math -import logging class Cache: def __init__(self, disable=False, prefix="", cache=None): @@ -126,7 +123,7 @@ def safe_pad_operation(x, padding, mode='constant', value=0.0): """Safe padding operation that handles Half precision only for problematic modes""" # Modes qui nécessitent le fix Half precision problematic_modes = ['replicate', 'reflect', 'circular'] - + if mode in problematic_modes: try: return F.pad(x, padding, mode=mode, value=value) @@ -189,7 +186,7 @@ def make_shifted_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tup resized_h, resized_w = round(h * scale), round(w * scale) wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw) # window size. wt = ceil(min(t, 30) / resized_nt) # window size. - + st, sh, sw = ( # shift size. 0.5 if wt < t else 0, 0.5 if wh < h else 0, @@ -466,7 +463,7 @@ def apply_rotary_emb( freqs = freqs.to(t_middle.device) t_transformed = (t_middle * freqs.cos() * scale) + (rotate_half(t_middle) * freqs.sin() * scale) - + out = torch.cat((t_left, t_transformed, t_right), dim=-1) return out.type(dtype) @@ -655,7 +652,7 @@ class NaSwinAttention(NaMMAttention): self, *args, window: Union[int, Tuple[int, int, int]], - window_method: bool, # shifted or not + window_method: bool, # shifted or not **kwargs, ): super().__init__(*args, **kwargs) @@ -765,7 +762,7 @@ class NaSwinAttention(NaMMAttention): vid_out, txt_out = self.proj_out(vid_out, txt_out) return vid_out, txt_out - + class MLP(nn.Module): def __init__( self, @@ -1274,7 +1271,7 @@ class NaDiT(nn.Module): layers=["out"], modes=["in"], ) - + self.stop_cfg_index = -1 def set_cfg_stop_index(self, cfg): @@ -1290,7 +1287,7 @@ class NaDiT(nn.Module): context, # l c disable_cache: bool = False, # for test # TODO ? 
// gives an error when set to True **kwargs - ): + ): transformer_options = kwargs.get("transformer_options", {}) patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) diff --git a/comfy/ldm/seedvr/vae.py b/comfy/ldm/seedvr/vae.py index a8f8c31f2..f30646dda 100644 --- a/comfy/ldm/seedvr/vae.py +++ b/comfy/ldm/seedvr/vae.py @@ -317,26 +317,26 @@ def safe_interpolate_operation(x, size=None, scale_factor=None, mode='nearest', """Safe interpolate operation that handles Half precision for problematic modes""" # Modes qui peuvent causer des problèmes avec Half precision problematic_modes = ['bilinear', 'bicubic', 'trilinear'] - + if mode in problematic_modes: try: return F.interpolate( - x, - size=size, - scale_factor=scale_factor, - mode=mode, + x, + size=size, + scale_factor=scale_factor, + mode=mode, align_corners=align_corners, recompute_scale_factor=recompute_scale_factor ) except RuntimeError as e: - if ("not implemented for 'Half'" in str(e) or + if ("not implemented for 'Half'" in str(e) or "compute_indices_weights" in str(e)): original_dtype = x.dtype return F.interpolate( - x.float(), - size=size, - scale_factor=scale_factor, - mode=mode, + x.float(), + size=size, + scale_factor=scale_factor, + mode=mode, align_corners=align_corners, recompute_scale_factor=recompute_scale_factor ).to(original_dtype) @@ -345,10 +345,10 @@ def safe_interpolate_operation(x, size=None, scale_factor=None, mode='nearest', else: # Pour 'nearest' et autres modes compatibles, pas de fix nécessaire return F.interpolate( - x, - size=size, - scale_factor=scale_factor, - mode=mode, + x, + size=size, + scale_factor=scale_factor, + mode=mode, align_corners=align_corners, recompute_scale_factor=recompute_scale_factor ) @@ -426,7 +426,7 @@ class Upsample3D(nn.Module): **kwargs, ): super().__init__() - self.interpolate = interpolate + self.interpolate = interpolate self.channels = channels self.out_channels = out_channels or channels self.use_conv_transpose = use_conv_transpose @@ -444,7 +444,7 @@ class Upsample3D(nn.Module): if kernel_size is None: kernel_size = 3 self.conv = nn.Conv2d(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias) - + conv = self.conv if self.name == "conv" else self.Conv2d_0 assert type(conv) is not nn.ConvTranspose2d @@ -587,7 +587,7 @@ class Downsample3D(nn.Module): kernel_size=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio), stride=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio), ) - + self.conv = conv @@ -1565,7 +1565,7 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL): latent = latent / scale + shift latent = rearrange(latent, "b ... 
c -> b c ...") latent = latent.squeeze(2) - + if latent.ndim == 4: latent = latent.unsqueeze(2) diff --git a/comfy/model_base.py b/comfy/model_base.py index 2b354f418..53f953710 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -815,7 +815,7 @@ class HunyuanDiT(BaseModel): out['image_meta_size'] = comfy.conds.CONDRegular(torch.FloatTensor([[height, width, target_height, target_width, 0, 0]])) return out - + class SeedVR2(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device, comfy.ldm.seedvr.model.NaDiT) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index f1312c3ab..886409d47 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -445,7 +445,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["pad_tokens_multiple"] = 32 return dit_config - + elif "{}blocks.36.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 7b dit_config = {} dit_config["image_model"] = "seedvr2" diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 1c325524d..9bbb1d0cd 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1287,7 +1287,7 @@ class Chroma(supported_models_base.BASE): pref = self.text_encoder_key_prefix[0] t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect)) - + class SeedVR2(supported_models_base.BASE): unet_config = { "image_model": "seedvr2" diff --git a/comfy_extras/nodes_seedvr.py b/comfy_extras/nodes_seedvr.py index e3281b5f3..ce5437517 100644 --- a/comfy_extras/nodes_seedvr.py +++ b/comfy_extras/nodes_seedvr.py @@ -15,7 +15,7 @@ from torchvision.transforms.functional import InterpolationMode @torch.inference_mode() def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), temporal_size=16, temporal_overlap=4, encode=True): - + gc.collect() torch.cuda.empty_cache() @@ -23,7 +23,7 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora x = x.unsqueeze(2) b, c, d, h, w = x.shape - + sf_s = getattr(vae_model, "spatial_downsample_factor", 8) sf_t = getattr(vae_model, "temporal_downsample_factor", 4) @@ -32,7 +32,7 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora ov_h, ov_w = tile_overlap ti_t = temporal_size ov_t = temporal_overlap - + target_d = (d + sf_t - 1) // sf_t target_h = (h + sf_s - 1) // sf_s target_w = (w + sf_s - 1) // sf_s @@ -55,7 +55,7 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora storage_device = torch.device("cpu") result = None count = None - + ramp_cache = {} def get_ramp(steps): if steps not in ramp_cache: @@ -66,10 +66,10 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora bar = ProgressBar(d // stride_t) for t_idx in range(0, d, stride_t): t_end = min(t_idx + ti_t, d) - + for y_idx in range(0, h, stride_h): y_end = min(y_idx + ti_h, h) - + for x_idx in range(0, w, stride_w): x_end = min(x_idx + ti_w, w) @@ -94,7 +94,7 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora ts, te = t_idx // sf_t, (t_idx // sf_t) + tile_out.shape[2] ys, ye = y_idx // sf_s, (y_idx // sf_s) + tile_out.shape[3] xs, xe = x_idx // sf_s, (x_idx // sf_s) + tile_out.shape[4] - + cur_ov_t = max(0, min(ov_t // sf_t, 
tile_out.shape[2] // 2)) cur_ov_h = max(0, min(ov_h // sf_s, tile_out.shape[3] // 2)) cur_ov_w = max(0, min(ov_w // sf_s, tile_out.shape[4] // 2)) @@ -115,7 +115,7 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora r = get_ramp(cur_ov_t) if t_idx > 0: w_t[:cur_ov_t] = r if t_end < d: w_t[-cur_ov_t:] = 1.0 - r - + if cur_ov_h > 0: r = get_ramp(cur_ov_h) if y_idx > 0: w_h[:cur_ov_h] = r @@ -131,11 +131,11 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora tile_out.mul_(final_weight) result[:, :, ts:te, ys:ye, xs:xe] += tile_out count[:, :, ts:te, ys:ye, xs:xe] += final_weight - + del tile_out, final_weight, tile_x, w_t, w_h, w_w bar.update(1) result.div_(count.clamp(min=1e-6)) - + if result.device != x.device: result = result.to(x.device).to(x.dtype) @@ -238,7 +238,7 @@ def cut_videos(videos): videos = torch.cat([videos, padding], dim=1) assert (videos.size(1) - 1) % (4) == 0 return videos - + class SeedVR2InputProcessing(io.ComfyNode): @classmethod def define_schema(cls): @@ -259,7 +259,7 @@ class SeedVR2InputProcessing(io.ComfyNode): io.Latent.Output("vae_conditioning") ] ) - + @classmethod def execute(cls, images, vae, resolution_height, resolution_width, spatial_tile_size, temporal_tile_size, spatial_overlap, temporal_overlap): device = vae.patcher.load_device @@ -271,7 +271,7 @@ class SeedVR2InputProcessing(io.ComfyNode): scale = 0.9152; shift = 0 if images.dim() != 5: # add the t dim images = images.unsqueeze(0) - images = images.permute(0, 1, 4, 2, 3) + images = images.permute(0, 1, 4, 2, 3) b, t, c, h, w = images.shape images = images.reshape(b * t, c, h, w) @@ -328,7 +328,7 @@ class SeedVR2Conditioning(io.ComfyNode): @classmethod def execute(cls, vae_conditioning, model) -> io.NodeOutput: - + vae_conditioning = vae_conditioning["samples"] device = vae_conditioning.device model = model.model.diffusion_model From d41b1111eb2cf6ba4557e2fe1cf36c083870567f Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Tue, 23 Dec 2025 12:36:10 +0200 Subject: [PATCH 24/35] removed print statement --- comfy/ldm/seedvr/vae.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/comfy/ldm/seedvr/vae.py b/comfy/ldm/seedvr/vae.py index f30646dda..da2bb2c2f 100644 --- a/comfy/ldm/seedvr/vae.py +++ b/comfy/ldm/seedvr/vae.py @@ -962,10 +962,6 @@ class UNetMidBlock3D(nn.Module): attentions = [] if attention_head_dim is None: - print( - f"It is not recommend to pass `attention_head_dim=None`. " - f"Defaulting `attention_head_dim` to `in_channels`: {in_channels}." 
- ) attention_head_dim = in_channels for _ in range(num_layers): From 1afc2ed8e60b5206c2475825b4191309a6ad5234 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Wed, 24 Dec 2025 02:23:57 +0200 Subject: [PATCH 25/35] fixed the speed issue --- comfy/ldm/seedvr/vae.py | 488 ++++++++++++++++++++++++++++------- comfy_extras/nodes_seedvr.py | 157 +++++------ 2 files changed, 485 insertions(+), 160 deletions(-) diff --git a/comfy/ldm/seedvr/vae.py b/comfy/ldm/seedvr/vae.py index da2bb2c2f..0c7fa5c5f 100644 --- a/comfy/ldm/seedvr/vae.py +++ b/comfy/ldm/seedvr/vae.py @@ -5,12 +5,57 @@ import torch.nn as nn import torch.nn.functional as F from einops import rearrange from torch import Tensor +from contextlib import contextmanager import comfy.model_management from comfy.ldm.seedvr.model import safe_pad_operation from comfy.ldm.modules.attention import optimized_attention from comfy_extras.nodes_seedvr import tiled_vae +import math +from enum import Enum +from comfy.ops import NVIDIA_MEMORY_CONV_BUG_WORKAROUND + +_NORM_LIMIT = float("inf") + + +def get_norm_limit(): + return _NORM_LIMIT + + +def set_norm_limit(value: Optional[float] = None): + global _NORM_LIMIT + if value is None: + value = float("inf") + _NORM_LIMIT = value + +@contextmanager +def ignore_padding(model): + orig_padding = model.padding + model.padding = (0, 0, 0) + try: + yield + finally: + model.padding = orig_padding + +class MemoryState(Enum): + DISABLED = 0 + INITIALIZING = 1 + ACTIVE = 2 + UNSET = 3 + +def get_cache_size(conv_module, input_len, pad_len, dim=0): + dilated_kernel_size = conv_module.dilation[dim] * (conv_module.kernel_size[dim] - 1) + 1 + output_len = (input_len + pad_len - dilated_kernel_size) // conv_module.stride[dim] + 1 + remain_len = ( + input_len + pad_len - ((output_len - 1) * conv_module.stride[dim] + dilated_kernel_size) + ) + overlap_len = dilated_kernel_size - conv_module.stride[dim] + cache_len = overlap_len + remain_len # >= 0 + + assert output_len > 0 + return cache_len + class DiagonalGaussianDistribution(object): def __init__(self, parameters: torch.Tensor, deterministic: bool = False): self.parameters = parameters @@ -34,6 +79,9 @@ class DiagonalGaussianDistribution(object): x = self.mean + self.std * sample return x + def mode(self): + return self.mean + class SpatialNorm(nn.Module): def __init__( self, @@ -366,41 +414,233 @@ def extend_head(tensor, times: int = 2, memory = None): tile_repeat[2] = times return torch.cat(tensors=(torch.tile(tensor[:, :, :1], tile_repeat), tensor), dim=2) -class InflatedCausalConv3d(nn.Conv3d): +def cache_send_recv(tensor, cache_size, times, memory=None): + # Single GPU inference - simplified cache handling + recv_buffer = None + + # Handle memory buffer for single GPU case + if memory is not None: + recv_buffer = memory.to(tensor[0]) + elif times > 0: + tile_repeat = [1] * tensor[0].ndim + tile_repeat[2] = times + recv_buffer = torch.tile(tensor[0][:, :, :1], tile_repeat) + + return recv_buffer + +class InflatedCausalConv3d(torch.nn.Conv3d): def __init__( self, *args, inflation_mode, + memory_device = "same", **kwargs, ): self.inflation_mode = inflation_mode self.memory = None super().__init__(*args, **kwargs) self.temporal_padding = self.padding[0] + self.memory_device = memory_device self.padding = (0, *self.padding[1:]) self.memory_limit = float("inf") + def set_memory_limit(self, value: float): + self.memory_limit = value + + def set_memory_device(self, memory_device): + self.memory_device = memory_device
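    # Cache arithmetic used by the chunked paths below: a temporal conv with
    # kernel k, stride s and dilation d reads d*(k-1)+1 frames per output step,
    # so get_cache_size() returns the frames one chunk must hand to the next:
    # the receptive-field overlap (d*(k-1)+1 - s) plus any trailing input the
    # final output step left unconsumed. Carrying exactly that many frames
    # keeps a chunked forward identical to a full-sequence forward.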
+ def _conv_forward(self, input, weight, bias, *args, **kwargs): + if (NVIDIA_MEMORY_CONV_BUG_WORKAROUND and + weight.dtype in (torch.float16, torch.bfloat16) and + hasattr(torch.backends.cudnn, 'is_available') and + torch.backends.cudnn.is_available() and + getattr(torch.backends.cudnn, 'enabled', True)): + try: + out = torch.cudnn_convolution( + input, weight, self.padding, self.stride, self.dilation, self.groups, + benchmark=False, deterministic=False, allow_tf32=True + ) + if bias is not None: + out += bias.reshape((1, -1) + (1,) * (out.ndim - 2)) + return out + except RuntimeError: + pass + + return super()._conv_forward(input, weight, bias, *args, **kwargs) + + def memory_limit_conv( + self, + x, + *, + split_dim=3, + padding=(0, 0, 0, 0, 0, 0), + prev_cache=None, + ): + # Compatible with no limit. + if math.isinf(self.memory_limit): + if prev_cache is not None: + x = torch.cat([prev_cache, x], dim=split_dim - 1) + return super().forward(x) + + # Compute tensor shape after concat & padding. + shape = torch.tensor(x.size()) + if prev_cache is not None: + shape[split_dim - 1] += prev_cache.size(split_dim - 1) + shape[-3:] += torch.tensor(padding).view(3, 2).sum(-1).flip(0) + memory_occupy = shape.prod() * x.element_size() / 1024**3 # GiB + if memory_occupy < self.memory_limit or split_dim == x.ndim: + x_concat = x + if prev_cache is not None: + x_concat = torch.cat([prev_cache, x], dim=split_dim - 1) + + def pad_and_forward(): + padded = safe_pad_operation(x_concat, padding, mode='constant', value=0.0) + with ignore_padding(self): + return torch.nn.Conv3d.forward(self, padded) + + return pad_and_forward() + + num_splits = math.ceil(memory_occupy / self.memory_limit) + size_per_split = x.size(split_dim) // num_splits + split_sizes = [size_per_split] * (num_splits - 1) + split_sizes += [x.size(split_dim) - sum(split_sizes)] + + x = list(x.split(split_sizes, dim=split_dim)) + if prev_cache is not None: + prev_cache = list(prev_cache.split(split_sizes, dim=split_dim)) + cache = None + for idx in range(len(x)): + if prev_cache is not None: + x[idx] = torch.cat([prev_cache[idx], x[idx]], dim=split_dim - 1) + + lpad_dim = (x[idx].ndim - split_dim - 1) * 2 + rpad_dim = lpad_dim + 1 + padding = list(padding) + padding[lpad_dim] = self.padding[split_dim - 2] if idx == 0 else 0 + padding[rpad_dim] = self.padding[split_dim - 2] if idx == len(x) - 1 else 0 + pad_len = padding[lpad_dim] + padding[rpad_dim] + padding = tuple(padding) + + next_cache = None + cache_len = cache.size(split_dim) if cache is not None else 0 + next_cache_size = get_cache_size( + conv_module=self, + input_len=x[idx].size(split_dim) + cache_len, + pad_len=pad_len, + dim=split_dim - 2, + ) + if next_cache_size != 0: + assert next_cache_size <= x[idx].size(split_dim) + next_cache = ( + x[idx].transpose(0, split_dim)[-next_cache_size:].transpose(0, split_dim) + ) + + x[idx] = self.memory_limit_conv( + x[idx], + split_dim=split_dim + 1, + padding=padding, + prev_cache=cache + ) + + cache = next_cache + + output = torch.cat(x, dim=split_dim) + return output + def forward( self, input, - ): - input = extend_head(input, times=self.temporal_padding * 2) + memory_state: MemoryState = MemoryState.UNSET + ) -> Tensor: + assert memory_state != MemoryState.UNSET + if memory_state != MemoryState.ACTIVE: + self.memory = None + if ( + math.isinf(self.memory_limit) + and torch.is_tensor(input) + ): + return self.basic_forward(input, memory_state) + return self.slicing_forward(input, memory_state) + + def basic_forward(self, input: Tensor,
memory_state: MemoryState = MemoryState.UNSET): + mem_size = self.stride[0] - self.kernel_size[0] + if (self.memory is not None) and (memory_state == MemoryState.ACTIVE): + input = extend_head(input, memory=self.memory, times=-1) + else: + input = extend_head(input, times=self.temporal_padding * 2) + memory = ( + input[:, :, mem_size:].detach() + if (mem_size != 0 and memory_state != MemoryState.DISABLED) + else None + ) + if ( + memory_state != MemoryState.DISABLED + and not self.training + and (self.memory_device is not None) + ): + self.memory = memory + if self.memory_device == "cpu" and self.memory is not None: + self.memory = self.memory.to("cpu") return super().forward(input) - def _load_from_state_dict( - self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs - ): + def slicing_forward( + self, + input, + memory_state: MemoryState = MemoryState.UNSET, + ) -> Tensor: + squeeze_out = False + if torch.is_tensor(input): + input = [input] + squeeze_out = True - super()._load_from_state_dict( - state_dict, - prefix, - local_metadata, - strict, - missing_keys, - unexpected_keys, - error_msgs, + cache_size = self.kernel_size[0] - self.stride[0] + cache = cache_send_recv( + input, cache_size=cache_size, memory=self.memory, times=self.temporal_padding * 2 ) + # Single GPU inference - simplified memory management + if ( + memory_state in [MemoryState.INITIALIZING, MemoryState.ACTIVE] # use_slicing + and not self.training + and (self.memory_device is not None) + and cache_size != 0 + ): + if cache_size > input[-1].size(2) and cache is not None and len(input) == 1: + input[0] = torch.cat([cache, input[0]], dim=2) + cache = None + if cache_size <= input[-1].size(2): + self.memory = input[-1][:, :, -cache_size:].detach().contiguous() + if self.memory_device == "cpu" and self.memory is not None: + self.memory = self.memory.to("cpu") + + padding = tuple(x for x in reversed(self.padding) for _ in range(2)) + for i in range(len(input)): + # Prepare cache for next input slice. + next_cache = None + cache_size = 0 + if i < len(input) - 1: + cache_len = cache.size(2) if cache is not None else 0 + cache_size = get_cache_size(self, input[i].size(2) + cache_len, pad_len=0) + if cache_size != 0: + if cache_size > input[i].size(2) and cache is not None: + input[i] = torch.cat([cache, input[i]], dim=2) + cache = None + assert cache_size <= input[i].size(2), f"{cache_size} > {input[i].size(2)}" + next_cache = input[i][:, :, -cache_size:] + + # Conv forward for this input slice. + input[i] = self.memory_limit_conv( + input[i], + padding=padding, + prev_cache=cache + ) + + # Update cache. 
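    # Rotating the carry: the trailing frames captured from this slice become
    # the prefix history of the next one, so each seam sees the same receptive
    # field it would in an unsliced pass.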
+ cache = next_cache + + return input[0] if squeeze_out else input + def remove_head(tensor: Tensor, times: int = 1) -> Tensor: if times == 0: return tensor @@ -488,6 +728,7 @@ class Upsample3D(nn.Module): def forward( self, hidden_states: torch.FloatTensor, + memory_state=None, **kwargs, ) -> torch.FloatTensor: assert hidden_states.shape[1] == self.channels @@ -517,7 +758,7 @@ class Upsample3D(nn.Module): z=self.temporal_ratio, ) - if self.temporal_up: + if self.temporal_up and memory_state != MemoryState.ACTIVE: hidden_states[0] = remove_head(hidden_states[0]) if not self.slicing: @@ -525,9 +766,9 @@ class Upsample3D(nn.Module): if self.use_conv: if self.name == "conv": - hidden_states = self.conv(hidden_states) + hidden_states = self.conv(hidden_states, memory_state=memory_state) else: - hidden_states = self.Conv2d_0(hidden_states) + hidden_states = self.Conv2d_0(hidden_states, memory_state=memory_state) if not self.slicing: return hidden_states @@ -594,6 +835,7 @@ class Downsample3D(nn.Module): def forward( self, hidden_states: torch.FloatTensor, + memory_state = None, **kwargs, ) -> torch.FloatTensor: @@ -609,7 +851,7 @@ class Downsample3D(nn.Module): assert hidden_states.shape[1] == self.channels - hidden_states = self.conv(hidden_states) + hidden_states = self.conv(hidden_states, memory_state=memory_state) return hidden_states @@ -707,7 +949,7 @@ class ResnetBlock3D(nn.Module): ) def forward( - self, input_tensor, temb, **kwargs + self, input_tensor, temb, memory_state = None, **kwargs ): hidden_states = input_tensor @@ -719,13 +961,13 @@ class ResnetBlock3D(nn.Module): if hidden_states.shape[0] >= 64: input_tensor = input_tensor.contiguous() hidden_states = hidden_states.contiguous() - input_tensor = self.upsample(input_tensor) - hidden_states = self.upsample(hidden_states) + input_tensor = self.upsample(input_tensor, memory_state=memory_state) + hidden_states = self.upsample(hidden_states, memory_state=memory_state) elif self.downsample is not None: - input_tensor = self.downsample(input_tensor) - hidden_states = self.downsample(hidden_states) + input_tensor = self.downsample(input_tensor, memory_state=memory_state) + hidden_states = self.downsample(hidden_states, memory_state=memory_state) - hidden_states = self.conv1(hidden_states) + hidden_states = self.conv1(hidden_states, memory_state=memory_state) if self.time_emb_proj is not None: if not self.skip_time_act: @@ -740,10 +982,10 @@ class ResnetBlock3D(nn.Module): hidden_states = self.nonlinearity(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.conv2(hidden_states) + hidden_states = self.conv2(hidden_states, memory_state=memory_state) if self.conv_shortcut is not None: - input_tensor = self.conv_shortcut(input_tensor) + input_tensor = self.conv_shortcut(input_tensor, memory_state=memory_state) output_tensor = (input_tensor + hidden_states) / self.output_scale_factor @@ -819,15 +1061,16 @@ class DownEncoderBlock3D(nn.Module): def forward( self, hidden_states: torch.FloatTensor, + memory_state = None, **kwargs, ) -> torch.FloatTensor: for resnet, temporal in zip(self.resnets, self.temporal_modules): - hidden_states = resnet(hidden_states, temb=None) + hidden_states = resnet(hidden_states, temb=None, memory_state=memory_state) hidden_states = temporal(hidden_states) if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) + hidden_states = downsampler(hidden_states, memory_state=memory_state) return hidden_states @@ -907,14 +1150,15 @@ 
class UpDecoderBlock3D(nn.Module): self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, + memory_state=None ) -> torch.FloatTensor: for resnet, temporal in zip(self.resnets, self.temporal_modules): - hidden_states = resnet(hidden_states, temb=None) + hidden_states = resnet(hidden_states, temb=None, memory_state=memory_state) hidden_states = temporal(hidden_states) if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states) + hidden_states = upsampler(hidden_states, memory_state=memory_state) return hidden_states @@ -1008,9 +1252,9 @@ class UNetMidBlock3D(nn.Module): self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) - def forward(self, hidden_states, temb=None): + def forward(self, hidden_states, temb=None, memory_state=None): video_length, frame_height, frame_width = hidden_states.size()[-3:] - hidden_states = self.resnets[0](hidden_states, temb) + hidden_states = self.resnets[0](hidden_states, temb, memory_state=memory_state) for attn, resnet in zip(self.attentions, self.resnets[1:]): if attn is not None: hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") @@ -1018,7 +1262,7 @@ class UNetMidBlock3D(nn.Module): hidden_states = rearrange( hidden_states, "(b f) c h w -> b c f h w", f=video_length ) - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(hidden_states, temb, memory_state=memory_state) return hidden_states @@ -1136,10 +1380,11 @@ class Encoder3D(nn.Module): self, sample: torch.FloatTensor, extra_cond=None, + memory_state = None ) -> torch.FloatTensor: r"""The forward method of the `Encoder` class.""" sample = sample.to(next(self.parameters()).device) - sample = self.conv_in(sample) + sample = self.conv_in(sample, memory_state = memory_state) if self.training and self.gradient_checkpointing: def create_custom_forward(module): @@ -1164,17 +1409,17 @@ class Encoder3D(nn.Module): # down # [Override] add extra block and extra cond for down_block, extra_block in zip(self.down_blocks, self.conv_extra_cond): - sample = down_block(sample) + sample = down_block(sample, memory_state=memory_state) if extra_block is not None: sample = sample + safe_interpolate_operation(extra_block(extra_cond), size=sample.shape[2:]) # middle - sample = self.mid_block(sample) + sample = self.mid_block(sample, memory_state=memory_state) # post-process sample = causal_norm_wrapper(self.conv_norm_out, sample) sample = self.conv_act(sample) - sample = self.conv_out(sample) + sample = self.conv_out(sample, memory_state = memory_state) return sample @@ -1282,74 +1527,90 @@ class Decoder3D(nn.Module): self, sample: torch.FloatTensor, latent_embeds: Optional[torch.FloatTensor] = None, + memory_state = None, ) -> torch.FloatTensor: sample = sample.to(next(self.parameters()).device) - sample = self.conv_in(sample) + sample = self.conv_in(sample, memory_state=memory_state) upscale_dtype = next(iter(self.up_blocks.parameters())).dtype # middle - sample = self.mid_block(sample, latent_embeds) + sample = self.mid_block(sample, latent_embeds, memory_state=memory_state) sample = sample.to(upscale_dtype) # up for up_block in self.up_blocks: - sample = up_block(sample, latent_embeds) + sample = up_block(sample, latent_embeds, memory_state=memory_state) # post-process sample = causal_norm_wrapper(self.conv_norm_out, sample) sample = self.conv_act(sample) - sample = self.conv_out(sample) + sample = self.conv_out(sample, memory_state=memory_state) return sample -def wavelet_blur(image: 
Tensor, radius: int): - """ - Apply wavelet blur to the input tensor. - """ - # input shape: (1, 3, H, W) - # convolution kernel +def wavelet_blur(image: Tensor, radius): + max_safe_radius = max(1, min(image.shape[-2:]) // 8) + if radius > max_safe_radius: + radius = max_safe_radius + + num_channels = image.shape[1] + kernel_vals = [ [0.0625, 0.125, 0.0625], - [0.125, 0.25, 0.125], + [0.125, 0.25, 0.125], [0.0625, 0.125, 0.0625], ] kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device) - # add channel dimensions to the kernel to make it a 4D tensor - kernel = kernel[None, None] - # repeat the kernel across all input channels - kernel = kernel.repeat(3, 1, 1, 1) - image = F.pad(image, (radius, radius, radius, radius), mode='replicate') - # apply convolution - output = F.conv2d(image, kernel, groups=3, dilation=radius) + kernel = kernel[None, None].repeat(num_channels, 1, 1, 1) + + image = safe_pad_operation(image, (radius, radius, radius, radius), mode='replicate') + output = F.conv2d(image, kernel, groups=num_channels, dilation=radius) + return output -def wavelet_decomposition(image: Tensor, levels=5): - """ - Apply wavelet decomposition to the input tensor. - This function only returns the low frequency & the high frequency. - """ +def wavelet_decomposition(image: Tensor, levels: int = 5): high_freq = torch.zeros_like(image) + for i in range(levels): radius = 2 ** i low_freq = wavelet_blur(image, radius) - high_freq += (image - low_freq) + high_freq.add_(image).sub_(low_freq) image = low_freq - + return high_freq, low_freq -def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor): - """ - Apply wavelet decomposition, so that the content will have the same color as the style. - """ - # calculate the wavelet decomposition of the content feature +def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor: + + if content_feat.shape != style_feat.shape: + # Resize style to match content spatial dimensions + if len(content_feat.shape) >= 3: + # safe_interpolate_operation handles FP16 conversion automatically + style_feat = safe_interpolate_operation( + style_feat, + size=content_feat.shape[-2:], + mode='bilinear', + align_corners=False + ) + + # Decompose both features into frequency components content_high_freq, content_low_freq = wavelet_decomposition(content_feat) - del content_low_freq - # calculate the wavelet decomposition of the style feature - style_high_freq, style_low_freq = wavelet_decomposition(style_feat) - del style_high_freq - # reconstruct the content feature with the style's high frequency - return content_high_freq + style_low_freq + del content_low_freq # Free memory immediately + + style_high_freq, style_low_freq = wavelet_decomposition(style_feat) + del style_high_freq # Free memory immediately + + if content_high_freq.shape != style_low_freq.shape: + style_low_freq = safe_interpolate_operation( + style_low_freq, + size=content_high_freq.shape[-2:], + mode='bilinear', + align_corners=False + ) + + content_high_freq.add_(style_low_freq) + + return content_high_freq.clamp_(-1.0, 1.0) class VideoAutoencoderKL(nn.Module): def __init__( @@ -1368,9 +1629,12 @@ class VideoAutoencoderKL(nn.Module): time_receptive_field: _receptive_field_t = "full", use_quant_conv: bool = False, use_post_quant_conv: bool = False, + slicing_sample_min_size = 4, *args, **kwargs, ): + self.slicing_sample_min_size = slicing_sample_min_size + self.slicing_latent_min_size = slicing_sample_min_size // (2**temporal_scale_num) extra_cond_dim = 
kwargs.pop("extra_cond_dim") if "extra_cond_dim" in kwargs else None block_out_channels = (128, 256, 512, 512) down_block_types = ("DownEncoderBlock3D",) * 4 @@ -1438,9 +1702,11 @@ class VideoAutoencoderKL(nn.Module): self.encoder.mid_block.attentions = torch.nn.ModuleList([None]) self.decoder.mid_block.attentions = torch.nn.ModuleList([None]) + self.use_slicing = True + def encode(self, x: torch.FloatTensor, return_dict: bool = True): h = self.slicing_encode(x) - posterior = DiagonalGaussianDistribution(h).sample() + posterior = DiagonalGaussianDistribution(h).mode() if not return_dict: return (posterior,) @@ -1458,30 +1724,72 @@ class VideoAutoencoderKL(nn.Module): return decoded def _encode( - self, x: torch.Tensor + self, x, memory_state ) -> torch.Tensor: _x = x.to(self.device) - h = self.encoder(_x,) + h = self.encoder(_x, memory_state=memory_state) if self.quant_conv is not None: - output = self.quant_conv(h) + output = self.quant_conv(h, memory_state=memory_state) else: output = h return output.to(x.device) def _decode( - self, z: torch.Tensor + self, z, memory_state ) -> torch.Tensor: - latent = z.to(self.device) + _z = z.to(self.device) + if self.post_quant_conv is not None: - latent = self.post_quant_conv(latent) - output = self.decoder(latent) + _z = self.post_quant_conv(_z, memory_state=memory_state) + + output = self.decoder(_z, memory_state=memory_state) return output.to(z.device) def slicing_encode(self, x: torch.Tensor) -> torch.Tensor: - return self._encode(x) + sp_size =1 + if self.use_slicing and (x.shape[2] - 1) > self.slicing_sample_min_size * sp_size: + x_slices = x[:, :, 1:].split(split_size=self.slicing_sample_min_size * sp_size, dim=2) + encoded_slices = [ + self._encode( + torch.cat((x[:, :, :1], x_slices[0]), dim=2), + memory_state=MemoryState.INITIALIZING, + ) + ] + for x_idx in range(1, len(x_slices)): + encoded_slices.append( + self._encode(x_slices[x_idx], memory_state=MemoryState.ACTIVE) + ) + out = torch.cat(encoded_slices, dim=2) + modules_with_memory = [m for m in self.modules() + if isinstance(m, InflatedCausalConv3d) and m.memory is not None] + for m in modules_with_memory: + m.memory = None + return out + else: + return self._encode(x) def slicing_decode(self, z: torch.Tensor) -> torch.Tensor: - return self._decode(z) + sp_size = 1 + if self.use_slicing and (z.shape[2] - 1) > self.slicing_latent_min_size * sp_size: + z_slices = z[:, :, 1:].split(split_size=self.slicing_latent_min_size * sp_size, dim=2) + decoded_slices = [ + self._decode( + torch.cat((z[:, :, :1], z_slices[0]), dim=2), + memory_state=MemoryState.INITIALIZING + ) + ] + for z_idx in range(1, len(z_slices)): + decoded_slices.append( + self._decode(z_slices[z_idx], memory_state=MemoryState.ACTIVE) + ) + out = torch.cat(decoded_slices, dim=2) + modules_with_memory = [m for m in self.modules() + if isinstance(m, InflatedCausalConv3d) and m.memory is not None] + for m in modules_with_memory: + m.memory = None + return out + else: + return self._decode(z) def tiled_encode(self, x: torch.Tensor, **kwargs) -> torch.Tensor: raise NotImplementedError @@ -1531,6 +1839,7 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL): self.freeze_encoder = freeze_encoder self.original_image_video = None super().__init__(*args, **kwargs) + self.set_memory_limit(0.5, 0.5) def forward(self, x: torch.FloatTensor): with torch.no_grad() if self.freeze_encoder else nullcontext(): @@ -1567,8 +1876,13 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL): target_device = comfy.model_management.get_torch_device() 
self.decoder.to(target_device) - x = tiled_vae(latent, self, **self.tiled_args, encode=False).squeeze(2) - #x = super().decode(latent).squeeze(2) + if self.tiled_args.get("enable_tiling", None) is not None: + self.enable_tiling = self.tiled_args.pop("enable_tiling", False) + + if self.enable_tiling: + x = tiled_vae(latent, self, **self.tiled_args, encode=False).squeeze(2) + else: + x = super().decode_(latent).squeeze(2) input = rearrange(self.original_image_video, "b c t h w -> (b t) c h w") if x.ndim == 4: @@ -1581,6 +1895,7 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL): x = rearrange(x, "b c t h w -> (b t) c h w") x = wavelet_reconstruction(x, input) + x = x.unsqueeze(0) o_h, o_w = self.img_dims x = x[..., :o_h, :o_w] @@ -1595,8 +1910,7 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL): return x def set_memory_limit(self, conv_max_mem: Optional[float], norm_max_mem: Optional[float]): - # TODO - #set_norm_limit(norm_max_mem) + set_norm_limit(norm_max_mem) for m in self.modules(): if isinstance(m, InflatedCausalConv3d): m.set_memory_limit(conv_max_mem if conv_max_mem is not None else float("inf")) diff --git a/comfy_extras/nodes_seedvr.py b/comfy_extras/nodes_seedvr.py index ce5437517..8380e4feb 100644 --- a/comfy_extras/nodes_seedvr.py +++ b/comfy_extras/nodes_seedvr.py @@ -14,25 +14,23 @@ from torchvision.transforms import Lambda, Normalize from torchvision.transforms.functional import InterpolationMode @torch.inference_mode() -def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), temporal_size=16, temporal_overlap=4, encode=True): +def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), temporal_size=16, encode=True): gc.collect() torch.cuda.empty_cache() + x = x.to(next(vae_model.parameters()).dtype) if x.ndim != 5: x = x.unsqueeze(2) b, c, d, h, w = x.shape - + sf_s = getattr(vae_model, "spatial_downsample_factor", 8) sf_t = getattr(vae_model, "temporal_downsample_factor", 4) if encode: ti_h, ti_w = tile_size ov_h, ov_w = tile_overlap - ti_t = temporal_size - ov_t = temporal_overlap - target_d = (d + sf_t - 1) // sf_t target_h = (h + sf_s - 1) // sf_s target_w = (w + sf_s - 1) // sf_s @@ -41,21 +39,44 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora ti_w = max(1, tile_size[1] // sf_s) ov_h = max(0, tile_overlap[0] // sf_s) ov_w = max(0, tile_overlap[1] // sf_s) - ti_t = max(1, temporal_size // sf_t) - ov_t = max(0, temporal_overlap // sf_t) - + target_d = d * sf_t target_h = h * sf_s target_w = w * sf_s - stride_t = max(1, ti_t - ov_t) stride_h = max(1, ti_h - ov_h) stride_w = max(1, ti_w - ov_w) storage_device = torch.device("cpu") + result = None count = None + def run_temporal_chunks(spatial_tile): + chunk_results = [] + t_dim_size = spatial_tile.shape[2] + + if encode: + input_chunk = temporal_size + else: + input_chunk = max(1, temporal_size // sf_t) + + for i in range(0, t_dim_size, input_chunk): + t_chunk = spatial_tile[:, :, i : i + input_chunk, :, :] + + if encode: + out = vae_model.encode(t_chunk) + else: + out = vae_model.decode_(t_chunk) + + if isinstance(out, (tuple, list)): out = out[0] + + if out.ndim == 4: out = out.unsqueeze(2) + + chunk_results.append(out.to(storage_device)) + + return torch.cat(chunk_results, dim=2) + ramp_cache = {} def get_ramp(steps): if steps not in ramp_cache: @@ -63,79 +84,64 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora ramp_cache[steps] = 0.5 - 0.5 * torch.cos(t * torch.pi) return ramp_cache[steps] - bar = ProgressBar(d 
// stride_t) - for t_idx in range(0, d, stride_t): - t_end = min(t_idx + ti_t, d) + total_tiles = len(range(0, h, stride_h)) * len(range(0, w, stride_w)) + bar = ProgressBar(total_tiles) - for y_idx in range(0, h, stride_h): - y_end = min(y_idx + ti_h, h) + for y_idx in range(0, h, stride_h): + y_end = min(y_idx + ti_h, h) + + for x_idx in range(0, w, stride_w): + x_end = min(x_idx + ti_w, w) - for x_idx in range(0, w, stride_w): - x_end = min(x_idx + ti_w, w) + tile_x = x[:, :, :, y_idx:y_end, x_idx:x_end] - tile_x = x[:, :, t_idx:t_end, y_idx:y_end, x_idx:x_end] + # Run VAE + tile_out = run_temporal_chunks(tile_x) - if encode: - tile_out = vae_model.encode(tile_x)[0] - else: - tile_out = vae_model.decode_(tile_x) + if result is None: + b_out, c_out = tile_out.shape[0], tile_out.shape[1] + result = torch.zeros((b_out, c_out, target_d, target_h, target_w), device=storage_device, dtype=torch.float32) + count = torch.zeros((1, 1, 1, target_h, target_w), device=storage_device, dtype=torch.float32) - if tile_out.ndim == 4: - tile_out = tile_out.unsqueeze(2) + if encode: + ys, ye = y_idx // sf_s, (y_idx // sf_s) + tile_out.shape[3] + xs, xe = x_idx // sf_s, (x_idx // sf_s) + tile_out.shape[4] + cur_ov_h = max(0, min(ov_h // sf_s, tile_out.shape[3] // 2)) + cur_ov_w = max(0, min(ov_w // sf_s, tile_out.shape[4] // 2)) + else: + ys, ye = y_idx * sf_s, (y_idx * sf_s) + tile_out.shape[3] + xs, xe = x_idx * sf_s, (x_idx * sf_s) + tile_out.shape[4] + cur_ov_h = max(0, min(ov_h, tile_out.shape[3] // 2)) + cur_ov_w = max(0, min(ov_w, tile_out.shape[4] // 2)) - tile_out = tile_out.to(storage_device).float() + w_h = torch.ones((tile_out.shape[3],), device=storage_device) + w_w = torch.ones((tile_out.shape[4],), device=storage_device) - if result is None: - b_out, c_out = tile_out.shape[0], tile_out.shape[1] - result = torch.zeros((b_out, c_out, target_d, target_h, target_w), device=storage_device, dtype=torch.float32) - count = torch.zeros((1, 1, target_d, target_h, target_w), device=storage_device, dtype=torch.float32) + if cur_ov_h > 0: + r = get_ramp(cur_ov_h) + if y_idx > 0: w_h[:cur_ov_h] = r + if y_end < h: w_h[-cur_ov_h:] = 1.0 - r - if encode: - ts, te = t_idx // sf_t, (t_idx // sf_t) + tile_out.shape[2] - ys, ye = y_idx // sf_s, (y_idx // sf_s) + tile_out.shape[3] - xs, xe = x_idx // sf_s, (x_idx // sf_s) + tile_out.shape[4] + if cur_ov_w > 0: + r = get_ramp(cur_ov_w) + if x_idx > 0: w_w[:cur_ov_w] = r + if x_end < w: w_w[-cur_ov_w:] = 1.0 - r - cur_ov_t = max(0, min(ov_t // sf_t, tile_out.shape[2] // 2)) - cur_ov_h = max(0, min(ov_h // sf_s, tile_out.shape[3] // 2)) - cur_ov_w = max(0, min(ov_w // sf_s, tile_out.shape[4] // 2)) - else: - ts, te = t_idx * sf_t, (t_idx * sf_t) + tile_out.shape[2] - ys, ye = y_idx * sf_s, (y_idx * sf_s) + tile_out.shape[3] - xs, xe = x_idx * sf_s, (x_idx * sf_s) + tile_out.shape[4] + final_weight = w_h.view(1,1,1,-1,1) * w_w.view(1,1,1,1,-1) - cur_ov_t = max(0, min(ov_t, tile_out.shape[2] // 2)) - cur_ov_h = max(0, min(ov_h, tile_out.shape[3] // 2)) - cur_ov_w = max(0, min(ov_w, tile_out.shape[4] // 2)) + valid_d = min(tile_out.shape[2], result.shape[2]) + tile_out = tile_out[:, :, :valid_d, :, :] + + tile_out.mul_(final_weight) + + result[:, :, :valid_d, ys:ye, xs:xe] += tile_out + count[:, :, :, ys:ye, xs:xe] += final_weight - w_t = torch.ones((tile_out.shape[2],), device=storage_device) - w_h = torch.ones((tile_out.shape[3],), device=storage_device) - w_w = torch.ones((tile_out.shape[4],), device=storage_device) + del tile_out, final_weight, tile_x, w_h, w_w + 
bar.update(1) - if cur_ov_t > 0: - r = get_ramp(cur_ov_t) - if t_idx > 0: w_t[:cur_ov_t] = r - if t_end < d: w_t[-cur_ov_t:] = 1.0 - r - - if cur_ov_h > 0: - r = get_ramp(cur_ov_h) - if y_idx > 0: w_h[:cur_ov_h] = r - if y_end < h: w_h[-cur_ov_h:] = 1.0 - r - - if cur_ov_w > 0: - r = get_ramp(cur_ov_w) - if x_idx > 0: w_w[:cur_ov_w] = r - if x_end < w: w_w[-cur_ov_w:] = 1.0 - r - - final_weight = w_t.view(1,1,-1,1,1) * w_h.view(1,1,1,-1,1) * w_w.view(1,1,1,1,-1) - - tile_out.mul_(final_weight) - result[:, :, ts:te, ys:ye, xs:xe] += tile_out - count[:, :, ts:te, ys:ye, xs:xe] += final_weight - - del tile_out, final_weight, tile_x, w_t, w_h, w_w - bar.update(1) result.div_(count.clamp(min=1e-6)) - + if result.device != x.device: result = result.to(x.device).to(x.dtype) @@ -253,7 +259,7 @@ class SeedVR2InputProcessing(io.ComfyNode): io.Int.Input("spatial_tile_size", default = 512, min = -1), io.Int.Input("temporal_tile_size", default = 8, min = -1), io.Int.Input("spatial_overlap", default = 64, min = -1), - io.Int.Input("temporal_overlap", default = 8, min = -1), + io.Boolean.Input("enable_tiling", default=False) ], outputs = [ io.Latent.Output("vae_conditioning") @@ -261,7 +267,7 @@ class SeedVR2InputProcessing(io.ComfyNode): ) @classmethod - def execute(cls, images, vae, resolution_height, resolution_width, spatial_tile_size, temporal_tile_size, spatial_overlap, temporal_overlap): + def execute(cls, images, vae, resolution_height, resolution_width, spatial_tile_size, temporal_tile_size, spatial_overlap, enable_tiling): device = vae.patcher.load_device offload_device = comfy.model_management.intermediate_device() @@ -296,9 +302,14 @@ class SeedVR2InputProcessing(io.ComfyNode): vae_model.original_image_video = images args = {"tile_size": (spatial_tile_size, spatial_tile_size), "tile_overlap": (spatial_overlap, spatial_overlap), - "temporal_size":temporal_tile_size, "temporal_overlap": temporal_overlap} + "temporal_size":temporal_tile_size} + if enable_tiling: + latent = tiled_vae(images, vae_model, encode=True, **args) + else: + latent = vae_model.encode(images, orig_dims = [o_h, o_w])[0] + + args["enable_tiling"] = enable_tiling vae_model.tiled_args = args - latent = tiled_vae(images, vae_model, encode=True, **args) vae_model = vae_model.to(offload_device) vae_model.img_dims = [o_h, o_w] From 7b2e5ef0af9b92427b73a93517d15983e24c98ae Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Wed, 24 Dec 2025 22:15:27 +0200 Subject: [PATCH 26/35] outputs/speed/memory match custom node --- comfy/ldm/seedvr/model.py | 31 +++++++++++++++++++++++++++++-- comfy_extras/nodes_seedvr.py | 14 ++++++++++---- 2 files changed, 39 insertions(+), 6 deletions(-) diff --git a/comfy/ldm/seedvr/model.py b/comfy/ldm/seedvr/model.py index cf3ebd520..7578c0be5 100644 --- a/comfy/ldm/seedvr/model.py +++ b/comfy/ldm/seedvr/model.py @@ -491,6 +491,11 @@ class NaMMRotaryEmbedding3d(MMRotaryEmbeddingBase): "mmrope_freqs_3d", lambda: self.get_freqs(vid_shape, txt_shape), ) + target_device = vid_q.device + if vid_freqs.device != target_device: + vid_freqs = vid_freqs.to(target_device) + if txt_freqs.device != target_device: + txt_freqs = txt_freqs.to(target_device) vid_q = rearrange(vid_q, "L h d -> h L d") vid_k = rearrange(vid_k, "L h d -> h L d") vid_q = apply_rotary_emb(vid_freqs, vid_q.float()).to(vid_q.dtype) @@ -506,6 +511,7 @@ class NaMMRotaryEmbedding3d(MMRotaryEmbeddingBase): txt_k = rearrange(txt_k, "h L d -> L h d") return vid_q, vid_k, txt_q, txt_k + 
@torch._dynamo.disable # Disable compilation: .tolist() is data-dependent and causes graph breaks def get_freqs( self, vid_shape: torch.LongTensor, @@ -514,8 +520,29 @@ class NaMMRotaryEmbedding3d(MMRotaryEmbeddingBase): torch.Tensor, torch.Tensor, ]: - vid_freqs = self.get_axial_freqs(1024, 128, 128) - txt_freqs = self.get_axial_freqs(1024) + + # Calculate actual max dimensions needed for this batch + max_temporal = 0 + max_height = 0 + max_width = 0 + max_txt_len = 0 + + for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()): + max_temporal = max(max_temporal, l + f) # Need up to l+f for temporal + max_height = max(max_height, h) + max_width = max(max_width, w) + max_txt_len = max(max_txt_len, l) + + # Compute frequencies for actual max dimensions needed + # Add small buffer to improve cache hits across similar batches + vid_freqs = self.get_axial_freqs( + min(max_temporal + 16, 1024), # Cap at 1024, add small buffer + min(max_height + 4, 128), # Cap at 128, add small buffer + min(max_width + 4, 128) # Cap at 128, add small buffer + ) + txt_freqs = self.get_axial_freqs(min(max_txt_len + 16, 1024)) + + # Now slice as before vid_freq_list, txt_freq_list = [], [] for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()): vid_freq = vid_freqs[l : l + f, :h, :w].reshape(-1, vid_freqs.size(-1)) diff --git a/comfy_extras/nodes_seedvr.py b/comfy_extras/nodes_seedvr.py index 8380e4feb..e6ccd44c1 100644 --- a/comfy_extras/nodes_seedvr.py +++ b/comfy_extras/nodes_seedvr.py @@ -65,9 +65,9 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora t_chunk = spatial_tile[:, :, i : i + input_chunk, :, :] if encode: - out = vae_model.encode(t_chunk) + out = vae_model.slicing_encode(t_chunk) else: - out = vae_model.decode_(t_chunk) + out = vae_model.slicing_decode(t_chunk) if isinstance(out, (tuple, list)): out = out[0] @@ -245,6 +245,11 @@ def cut_videos(videos): assert (videos.size(1) - 1) % (4) == 0 return videos +def side_resize(image, size): + antialias = not (isinstance(image, torch.Tensor) and image.device.type == 'mps') + resized = TVF.resize(image, size, InterpolationMode.BICUBIC, antialias=antialias) + return resized + class SeedVR2InputProcessing(io.ComfyNode): @classmethod def define_schema(cls): @@ -285,7 +290,8 @@ class SeedVR2InputProcessing(io.ComfyNode): max_area = ((resolution_height * resolution_width)** 0.5) ** 2 clip = Lambda(lambda x: torch.clamp(x, 0.0, 1.0)) normalize = Normalize(0.5, 0.5) - images = area_resize(images, max_area) + #images = area_resize(images, max_area) + images = side_resize(images, resolution_height) images = clip(images) o_h, o_w = images.shape[-2:] @@ -348,7 +354,7 @@ class SeedVR2Conditioning(io.ComfyNode): noises = torch.randn_like(vae_conditioning).to(device) aug_noises = torch.randn_like(vae_conditioning).to(device) - + aug_noises = noises * 0.1 + aug_noises * 0.05 cond_noise_scale = 0.0 t = ( torch.tensor([1000.0]) From 21bc67d7db037d652d3d5fc65087261fc5411b96 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 26 Dec 2025 02:08:59 +0200 Subject: [PATCH 27/35] final changes --- comfy/ldm/seedvr/model.py | 70 +++++++++++++++++++----------- comfy/ldm/seedvr/vae.py | 83 +++++++++++++++++++----------------- comfy_extras/nodes_seedvr.py | 61 +++++++++++++++----------- 3 files changed, 125 insertions(+), 89 deletions(-) diff --git a/comfy/ldm/seedvr/model.py b/comfy/ldm/seedvr/model.py index 7578c0be5..bd0057332 100644 --- a/comfy/ldm/seedvr/model.py +++ 
b/comfy/ldm/seedvr/model.py @@ -526,22 +526,22 @@ class NaMMRotaryEmbedding3d(MMRotaryEmbeddingBase): max_height = 0 max_width = 0 max_txt_len = 0 - + for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()): max_temporal = max(max_temporal, l + f) # Need up to l+f for temporal max_height = max(max_height, h) max_width = max(max_width, w) max_txt_len = max(max_txt_len, l) - + # Compute frequencies for actual max dimensions needed # Add small buffer to improve cache hits across similar batches vid_freqs = self.get_axial_freqs( min(max_temporal + 16, 1024), # Cap at 1024, add small buffer - min(max_height + 4, 128), # Cap at 128, add small buffer + min(max_height + 4, 128), # Cap at 128, add small buffer min(max_width + 4, 128) # Cap at 128, add small buffer ) txt_freqs = self.get_axial_freqs(min(max_txt_len + 16, 1024)) - + # Now slice as before vid_freq_list, txt_freq_list = [], [] for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()): @@ -615,6 +615,7 @@ class NaMMAttention(nn.Module): rope_type: Optional[str], rope_dim: int, shared_weights: bool, + device, dtype, operations, **kwargs, ): super().__init__() @@ -624,15 +625,16 @@ class NaMMAttention(nn.Module): qkv_dim = inner_dim * 3 self.head_dim = head_dim self.proj_qkv = MMModule( - nn.Linear, dim, qkv_dim, bias=qk_bias, shared_weights=shared_weights + operations.Linear, dim, qkv_dim, bias=qk_bias, shared_weights=shared_weights, device=device, dtype=dtype ) - self.proj_out = MMModule(nn.Linear, inner_dim, dim, shared_weights=shared_weights) + self.proj_out = MMModule(operations.Linear, inner_dim, dim, shared_weights=shared_weights, device=device, dtype=dtype) self.norm_q = MMModule( qk_norm, normalized_shape=head_dim, eps=qk_norm_eps, elementwise_affine=True, shared_weights=shared_weights, + device=device, dtype=dtype ) self.norm_k = MMModule( qk_norm, @@ -640,6 +642,7 @@ class NaMMAttention(nn.Module): eps=qk_norm_eps, elementwise_affine=True, shared_weights=shared_weights, + device=device, dtype=dtype ) @@ -795,11 +798,12 @@ class MLP(nn.Module): self, dim: int, expand_ratio: int, + device, dtype, operations ): super().__init__() - self.proj_in = nn.Linear(dim, dim * expand_ratio) + self.proj_in = operations.Linear(dim, dim * expand_ratio, device=device, dtype=dtype) self.act = nn.GELU("tanh") - self.proj_out = nn.Linear(dim * expand_ratio, dim) + self.proj_out = operations.Linear(dim * expand_ratio, dim, device=device, dtype=dtype) def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: x = self.proj_in(x) @@ -814,13 +818,14 @@ class SwiGLUMLP(nn.Module): dim: int, expand_ratio: int, multiple_of: int = 256, + device=None, dtype=None, operations=None ): super().__init__() hidden_dim = int(2 * dim * expand_ratio / 3) hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) - self.proj_in_gate = nn.Linear(dim, hidden_dim, bias=False) - self.proj_out = nn.Linear(hidden_dim, dim, bias=False) - self.proj_in = nn.Linear(dim, hidden_dim, bias=False) + self.proj_in_gate = operations.Linear(dim, hidden_dim, bias=False, device=device, dtype=dtype) + self.proj_out = operations.Linear(hidden_dim, dim, bias=False, device=device, dtype=dtype) + self.proj_in = operations.Linear(dim, hidden_dim, bias=False, device=device, dtype=dtype) def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: x = x.to(next(self.proj_in.parameters()).device) @@ -855,11 +860,12 @@ class NaMMSRTransformerBlock(nn.Module): rope_type: str, rope_dim: int, is_last_layer: bool, + device, dtype, operations, **kwargs, ): 
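        # device/dtype and the `operations` namespace are threaded into every
        # submodule created below, so layer implementations can be swapped by the
        # caller (e.g. comfy.ops Linear variants) without touching the block itself.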
super().__init__() dim = MMArg(vid_dim, txt_dim) - self.attn_norm = MMModule(norm, normalized_shape=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights,) + self.attn_norm = MMModule(norm, normalized_shape=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights, device=device, dtype=dtype) self.attn = NaSwinAttention( vid_dim=vid_dim, @@ -874,17 +880,19 @@ class NaMMSRTransformerBlock(nn.Module): shared_weights=shared_weights, window=kwargs.pop("window", None), window_method=kwargs.pop("window_method", None), + device=device, dtype=dtype, operations=operations ) - self.mlp_norm = MMModule(norm, normalized_shape=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights, vid_only=is_last_layer) + self.mlp_norm = MMModule(norm, normalized_shape=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights, vid_only=is_last_layer, device=device, dtype=dtype) self.mlp = MMModule( get_mlp(mlp_type), dim=dim, expand_ratio=expand_ratio, shared_weights=shared_weights, - vid_only=is_last_layer + vid_only=is_last_layer, + device=device, dtype=dtype, operations=operations ) - self.ada = MMModule(ada, dim=dim, emb_dim=emb_dim, layers=["attn", "mlp"], shared_weights=shared_weights, vid_only=is_last_layer) + self.ada = MMModule(ada, dim=dim, emb_dim=emb_dim, layers=["attn", "mlp"], shared_weights=shared_weights, vid_only=is_last_layer, device=device, dtype=dtype) self.is_last_layer = is_last_layer def forward( @@ -933,11 +941,12 @@ class PatchOut(nn.Module): out_channels: int, patch_size: Union[int, Tuple[int, int, int]], dim: int, + device, dtype, operations ): super().__init__() t, h, w = _triple(patch_size) self.patch_size = t, h, w - self.proj = nn.Linear(dim, out_channels * t * h * w) + self.proj = operations.Linear(dim, out_channels * t * h * w, device=device, dtype=dtype) def forward( self, @@ -981,11 +990,12 @@ class PatchIn(nn.Module): in_channels: int, patch_size: Union[int, Tuple[int, int, int]], dim: int, + device, dtype, operations ): super().__init__() t, h, w = _triple(patch_size) self.patch_size = t, h, w - self.proj = nn.Linear(in_channels * t * h * w, dim) + self.proj = operations.Linear(in_channels * t * h * w, dim, device=device, dtype=dtype) def forward( self, @@ -1033,6 +1043,7 @@ class AdaSingle(nn.Module): emb_dim: int, layers: List[str], modes: List[str] = ["in", "out"], + device = None, dtype = None, ): assert emb_dim == 6 * dim, "AdaSingle requires emb_dim == 6 * dim" super().__init__() @@ -1041,12 +1052,12 @@ class AdaSingle(nn.Module): self.layers = layers for l in layers: if "in" in modes: - self.register_parameter(f"{l}_shift", nn.Parameter(torch.randn(dim) / dim**0.5)) + self.register_parameter(f"{l}_shift", nn.Parameter(torch.randn(dim, device=device, dtype=dtype) / dim**0.5)) self.register_parameter( f"{l}_scale", nn.Parameter(torch.randn(dim) / dim**0.5 + 1) ) if "out" in modes: - self.register_parameter(f"{l}_gate", nn.Parameter(torch.randn(dim) / dim**0.5)) + self.register_parameter(f"{l}_gate", nn.Parameter(torch.randn(dim, device=device, dtype=dtype) / dim**0.5)) def forward( self, @@ -1096,12 +1107,13 @@ class TimeEmbedding(nn.Module): sinusoidal_dim: int, hidden_dim: int, output_dim: int, + device, dtype, operations ): super().__init__() self.sinusoidal_dim = sinusoidal_dim - self.proj_in = nn.Linear(sinusoidal_dim, hidden_dim) - self.proj_hid = nn.Linear(hidden_dim, hidden_dim) - self.proj_out = nn.Linear(hidden_dim, output_dim) + self.proj_in = operations.Linear(sinusoidal_dim, hidden_dim, 
device=device, dtype=dtype) + self.proj_hid = operations.Linear(hidden_dim, hidden_dim, device=device, dtype=dtype) + self.proj_out = operations.Linear(hidden_dim, output_dim, device=device, dtype=dtype) self.act = nn.SiLU() def forward( @@ -1199,6 +1211,7 @@ class NaDiT(nn.Module): **kwargs, ): self.dtype = dtype + factory_kwargs = {"device": device, "dtype": dtype} window_method = num_layers // 2 * ["720pwin_by_size_bysize","720pswin_by_size_bysize"] txt_dim = vid_dim emb_dim = vid_dim * 6 @@ -1212,15 +1225,16 @@ class NaDiT(nn.Module): elif len(block_type) != num_layers: raise ValueError("The ``block_type`` list should equal to ``num_layers``.") super().__init__() - self.register_buffer("positive_conditioning", torch.empty((58, 5120))) - self.register_buffer("negative_conditioning", torch.empty((64, 5120))) + self.register_buffer("positive_conditioning", torch.empty((58, 5120), device=device, dtype=dtype)) + self.register_buffer("negative_conditioning", torch.empty((64, 5120), device=device, dtype=dtype)) self.vid_in = NaPatchIn( in_channels=vid_in_channels, patch_size=patch_size, dim=vid_dim, + device=device, dtype=dtype, operations=operations ) self.txt_in = ( - nn.Linear(txt_in_dim, txt_dim) + operations.Linear(txt_in_dim, txt_dim, **factory_kwargs) if txt_in_dim and txt_in_dim != txt_dim else nn.Identity() ) @@ -1228,6 +1242,7 @@ class NaDiT(nn.Module): sinusoidal_dim=256, hidden_dim=max(vid_dim, txt_dim), output_dim=emb_dim, + device=device, dtype=dtype, operations=operations ) if window is None or isinstance(window[0], int): @@ -1268,7 +1283,9 @@ class NaDiT(nn.Module): shared_weights=not ( (i < mm_layers) if isinstance(mm_layers, int) else mm_layers[i] ), + operations = operations, **kwargs, + **factory_kwargs ) for i in range(num_layers) ] @@ -1277,6 +1294,7 @@ class NaDiT(nn.Module): out_channels=vid_out_channels, patch_size=patch_size, dim=vid_dim, + device=device, dtype=dtype, operations=operations ) self.need_txt_repeat = block_type[0] in [ @@ -1291,12 +1309,14 @@ class NaDiT(nn.Module): normalized_shape=vid_dim, eps=norm_eps, elementwise_affine=True, + device=device, dtype=dtype ) self.vid_out_ada = ada( dim=vid_dim, emb_dim=emb_dim, layers=["out"], modes=["in"], + device=device, dtype=dtype ) self.stop_cfg_index = -1 diff --git a/comfy/ldm/seedvr/vae.py b/comfy/ldm/seedvr/vae.py index 0c7fa5c5f..9fcea60ad 100644 --- a/comfy/ldm/seedvr/vae.py +++ b/comfy/ldm/seedvr/vae.py @@ -16,6 +16,9 @@ import math from enum import Enum from comfy.ops import NVIDIA_MEMORY_CONV_BUG_WORKAROUND +import comfy.ops +ops = comfy.ops.disable_weight_init + _NORM_LIMIT = float("inf") @@ -89,9 +92,9 @@ class SpatialNorm(nn.Module): zq_channels: int, ): super().__init__() - self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True) - self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) - self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) + self.norm_layer = ops.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True) + self.conv_y = ops.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) + self.conv_b = ops.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor: f_size = f.shape[-2:] @@ -164,7 +167,7 @@ class Attention(nn.Module): self.only_cross_attention = only_cross_attention if norm_num_groups is not None: - self.group_norm = nn.GroupNorm(num_channels=query_dim, 
num_groups=norm_num_groups, eps=eps, affine=True) + self.group_norm = ops.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True) else: self.group_norm = None @@ -177,22 +180,22 @@ class Attention(nn.Module): self.norm_k = None self.norm_cross = None - self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_q = ops.Linear(query_dim, self.inner_dim, bias=bias) if not self.only_cross_attention: # only relevant for the `AddedKVProcessor` classes - self.to_k = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) - self.to_v = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) + self.to_k = ops.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) + self.to_v = ops.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) else: self.to_k = None self.to_v = None self.added_proj_bias = added_proj_bias if self.added_kv_proj_dim is not None: - self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) - self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) + self.add_k_proj = ops.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) + self.add_v_proj = ops.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) if self.context_pre_only is not None: - self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.add_q_proj = ops.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) else: self.add_q_proj = None self.add_k_proj = None @@ -200,13 +203,13 @@ class Attention(nn.Module): if not self.pre_only: self.to_out = nn.ModuleList([]) - self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(ops.Linear(self.inner_dim, self.out_dim, bias=out_bias)) self.to_out.append(nn.Dropout(dropout)) else: self.to_out = None if self.context_pre_only is not None and not self.context_pre_only: - self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias) + self.to_add_out = ops.Linear(self.inner_dim, self.out_context_dim, bias=out_bias) else: self.to_add_out = None @@ -325,7 +328,7 @@ def modify_state_dict(layer, state_dict, prefix, inflate_weight_fn, inflate_bias def causal_norm_wrapper(norm_layer: nn.Module, x: torch.Tensor) -> torch.Tensor: input_dtype = x.dtype - if isinstance(norm_layer, (nn.LayerNorm, nn.RMSNorm)): + if isinstance(norm_layer, (ops.LayerNorm, ops.RMSNorm)): if x.ndim == 4: x = rearrange(x, "b c h w -> b h w c") x = norm_layer(x) @@ -336,14 +339,14 @@ def causal_norm_wrapper(norm_layer: nn.Module, x: torch.Tensor) -> torch.Tensor: x = norm_layer(x) x = rearrange(x, "b t h w c -> b c t h w") return x.to(input_dtype) - if isinstance(norm_layer, (nn.GroupNorm, nn.BatchNorm2d, nn.SyncBatchNorm)): + if isinstance(norm_layer, (ops.GroupNorm, nn.BatchNorm2d, nn.SyncBatchNorm)): if x.ndim <= 4: return norm_layer(x).to(input_dtype) if x.ndim == 5: t = x.size(2) x = rearrange(x, "b c t h w -> (b t) c h w") memory_occupy = x.numel() * x.element_size() / 1024**3 - if isinstance(norm_layer, nn.GroupNorm) and memory_occupy > float("inf"): # TODO: this may be set dynamically from the vae + if isinstance(norm_layer, ops.GroupNorm) and memory_occupy > float("inf"): # TODO: this may be set dynamically from the vae num_chunks = min(4 if x.element_size() == 2 else 2, norm_layer.num_groups) assert norm_layer.num_groups % num_chunks == 0 num_groups_per_chunk = norm_layer.num_groups // num_chunks @@ -428,7 +431,7 @@ def cache_send_recv(tensor, 
cache_size, times, memory=None): return recv_buffer -class InflatedCausalConv3d(torch.nn.Conv3d): +class InflatedCausalConv3d(ops.Conv3d): def __init__( self, *args, @@ -677,17 +680,16 @@ class Upsample3D(nn.Module): if use_conv_transpose: if kernel_size is None: kernel_size = 4 - self.conv = nn.ConvTranspose2d( + self.conv = ops.ConvTranspose2d( channels, self.out_channels, kernel_size=kernel_size, stride=2, padding=padding, bias=bias ) elif use_conv: if kernel_size is None: kernel_size = 3 - self.conv = nn.Conv2d(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias) + self.conv = ops.Conv2d(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias) conv = self.conv if self.name == "conv" else self.Conv2d_0 - assert type(conv) is not nn.ConvTranspose2d # Note: lora_layer is not passed into constructor in the original implementation. # So we make a simplification. conv = InflatedCausalConv3d( @@ -708,7 +710,7 @@ class Upsample3D(nn.Module): # [Override] MAGViT v2 implementation if not self.interpolate: upscale_ratio = (self.spatial_ratio**2) * self.temporal_ratio - self.upscale_conv = nn.Conv3d( + self.upscale_conv = ops.Conv3d( self.channels, self.channels * upscale_ratio, kernel_size=1, padding=0 ) identity = ( @@ -892,13 +894,13 @@ class ResnetBlock3D(nn.Module): self.skip_time_act = skip_time_act self.nonlinearity = nn.SiLU() if temb_channels is not None: - self.time_emb_proj = nn.Linear(temb_channels, out_channels) + self.time_emb_proj = ops.Linear(temb_channels, out_channels) else: self.time_emb_proj = None - self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + self.norm1 = ops.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) if groups_out is None: groups_out = groups - self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) + self.norm2 = ops.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) self.use_in_shortcut = self.in_channels != out_channels self.dropout = torch.nn.Dropout(dropout) self.conv1 = InflatedCausalConv3d( @@ -1342,7 +1344,7 @@ class Encoder3D(nn.Module): self.conv_extra_cond.append( zero_module( - nn.Conv3d(extra_cond_dim, output_channel, kernel_size=1, stride=1, padding=0) + ops.Conv3d(extra_cond_dim, output_channel, kernel_size=1, stride=1, padding=0) ) if self.extra_cond_dim is not None and self.extra_cond_dim > 0 else None @@ -1364,7 +1366,7 @@ class Encoder3D(nn.Module): ) # out - self.conv_norm_out = nn.GroupNorm( + self.conv_norm_out = ops.GroupNorm( num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6 ) self.conv_act = nn.SiLU() @@ -1512,7 +1514,7 @@ class Decoder3D(nn.Module): if norm_type == "spatial": self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels) else: - self.conv_norm_out = nn.GroupNorm( + self.conv_norm_out = ops.GroupNorm( num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6 ) self.conv_act = nn.SiLU() @@ -1553,9 +1555,9 @@ def wavelet_blur(image: Tensor, radius): max_safe_radius = max(1, min(image.shape[-2:]) // 8) if radius > max_safe_radius: radius = max_safe_radius - + num_channels = image.shape[1] - + kernel_vals = [ [0.0625, 0.125, 0.0625], [0.125, 0.25, 0.125], @@ -1563,21 +1565,21 @@ def wavelet_blur(image: Tensor, radius): ] kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device) kernel = kernel[None, None].repeat(num_channels, 1, 1, 1) - + image = 
safe_pad_operation(image, (radius, radius, radius, radius), mode='replicate') output = F.conv2d(image, kernel, groups=num_channels, dilation=radius) - + return output def wavelet_decomposition(image: Tensor, levels: int = 5): high_freq = torch.zeros_like(image) - + for i in range(levels): radius = 2 ** i low_freq = wavelet_blur(image, radius) high_freq.add_(image).sub_(low_freq) image = low_freq - + return high_freq, low_freq def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor: @@ -1587,19 +1589,19 @@ def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor: if len(content_feat.shape) >= 3: # safe_interpolate_operation handles FP16 conversion automatically style_feat = safe_interpolate_operation( - style_feat, + style_feat, size=content_feat.shape[-2:], - mode='bilinear', + mode='bilinear', align_corners=False ) - + # Decompose both features into frequency components content_high_freq, content_low_freq = wavelet_decomposition(content_feat) del content_low_freq # Free memory immediately - - style_high_freq, style_low_freq = wavelet_decomposition(style_feat) + + style_high_freq, style_low_freq = wavelet_decomposition(style_feat) del style_high_freq # Free memory immediately - + if content_high_freq.shape != style_low_freq.shape: style_low_freq = safe_interpolate_operation( style_low_freq, @@ -1607,9 +1609,9 @@ def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor: mode='bilinear', align_corners=False ) - + content_high_freq.add_(style_low_freq) - + return content_high_freq.clamp_(-1.0, 1.0) class VideoAutoencoderKL(nn.Module): @@ -1894,6 +1896,7 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL): x = rearrange(x, "b c t h w -> (b t) c h w") + input = input.to(x.device) x = wavelet_reconstruction(x, input) x = x.unsqueeze(0) diff --git a/comfy_extras/nodes_seedvr.py b/comfy_extras/nodes_seedvr.py index e6ccd44c1..bd0c6037a 100644 --- a/comfy_extras/nodes_seedvr.py +++ b/comfy_extras/nodes_seedvr.py @@ -24,7 +24,7 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora x = x.unsqueeze(2) b, c, d, h, w = x.shape - + sf_s = getattr(vae_model, "spatial_downsample_factor", 8) sf_t = getattr(vae_model, "temporal_downsample_factor", 4) @@ -39,7 +39,7 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora ti_w = max(1, tile_size[1] // sf_s) ov_h = max(0, tile_overlap[0] // sf_s) ov_w = max(0, tile_overlap[1] // sf_s) - + target_d = d * sf_t target_h = h * sf_s target_w = w * sf_s @@ -47,15 +47,14 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora stride_h = max(1, ti_h - ov_h) stride_w = max(1, ti_w - ov_w) - storage_device = torch.device("cpu") - + storage_device = vae_model.device result = None count = None def run_temporal_chunks(spatial_tile): chunk_results = [] t_dim_size = spatial_tile.shape[2] - + if encode: input_chunk = temporal_size else: @@ -63,18 +62,18 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora for i in range(0, t_dim_size, input_chunk): t_chunk = spatial_tile[:, :, i : i + input_chunk, :, :] - + if encode: - out = vae_model.slicing_encode(t_chunk) + out = vae_model.encode(t_chunk) else: - out = vae_model.slicing_decode(t_chunk) - + out = vae_model.decode_(t_chunk) + if isinstance(out, (tuple, list)): out = out[0] - + if out.ndim == 4: out = out.unsqueeze(2) - - chunk_results.append(out.to(storage_device)) - + + chunk_results.append(out.to(storage_device)) + return 
torch.cat(chunk_results, dim=2) ramp_cache = {} @@ -89,7 +88,7 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora for y_idx in range(0, h, stride_h): y_end = min(y_idx + ti_h, h) - + for x_idx in range(0, w, stride_w): x_end = min(x_idx + ti_w, w) @@ -131,9 +130,9 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora valid_d = min(tile_out.shape[2], result.shape[2]) tile_out = tile_out[:, :, :valid_d, :, :] - + tile_out.mul_(final_weight) - + result[:, :, :valid_d, ys:ye, xs:xe] += tile_out count[:, :, :, ys:ye, xs:xe] += final_weight @@ -141,7 +140,7 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora bar.update(1) result.div_(count.clamp(min=1e-6)) - + if result.device != x.device: result = result.to(x.device).to(x.dtype) @@ -150,6 +149,18 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora return result +def clear_vae_memory(vae_model): + for module in vae_model.modules(): + if hasattr(module, "memory"): + module.memory = None + if hasattr(vae_model, "original_image_video"): + del vae_model.original_image_video + + if hasattr(vae_model, "tiled_args"): + del vae_model.tiled_args + gc.collect() + torch.cuda.empty_cache() + def expand_dims(tensor, ndim): shape = tensor.shape + (1,) * (ndim - tensor.ndim) return tensor.reshape(shape) @@ -261,9 +272,9 @@ class SeedVR2InputProcessing(io.ComfyNode): io.Vae.Input("vae"), io.Int.Input("resolution_height", default = 1280, min = 120), # // io.Int.Input("resolution_width", default = 720, min = 120), # just non-zero value - io.Int.Input("spatial_tile_size", default = 512, min = -1), - io.Int.Input("temporal_tile_size", default = 8, min = -1), - io.Int.Input("spatial_overlap", default = 64, min = -1), + io.Int.Input("spatial_tile_size", default = 512, min = 1), + io.Int.Input("temporal_tile_size", default = 8, min = 1), + io.Int.Input("spatial_overlap", default = 64, min = 1), io.Boolean.Input("enable_tiling", default=False) ], outputs = [ @@ -305,7 +316,6 @@ class SeedVR2InputProcessing(io.ComfyNode): images = rearrange(images, "b t c h w -> b c t h w") images = images.to(device) vae_model = vae_model.to(device) - vae_model.original_image_video = images args = {"tile_size": (spatial_tile_size, spatial_tile_size), "tile_overlap": (spatial_overlap, spatial_overlap), "temporal_size":temporal_tile_size} @@ -314,11 +324,14 @@ class SeedVR2InputProcessing(io.ComfyNode): else: latent = vae_model.encode(images, orig_dims = [o_h, o_w])[0] + clear_vae_memory(vae_model) + #images = images.to(offload_device) + #vae_model = vae_model.to(offload_device) + + vae_model.img_dims = [o_h, o_w] args["enable_tiling"] = enable_tiling vae_model.tiled_args = args - - vae_model = vae_model.to(offload_device) - vae_model.img_dims = [o_h, o_w] + vae_model.original_image_video = images latent = latent.unsqueeze(2) if latent.ndim == 4 else latent latent = rearrange(latent, "b c ... -> b ... c") From 4d7012ecdacd99e150f8c1e147b672f85e8e7e0b Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 26 Dec 2025 02:23:51 +0200 Subject: [PATCH 28/35] . 
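Drop the dead area-resize path: frames now go through side_resize(), which
hands a single int to torchvision's resize so the clip's shorter edge is
pinned to `resolution` while the aspect ratio is preserved. A minimal sketch
of that behavior (shapes are illustrative, not taken from a real run):

    import torch
    import torchvision.transforms.functional as TVF
    from torchvision.transforms.functional import InterpolationMode

    frames = torch.rand(16, 3, 1080, 1920)  # (t, c, h, w) video batch
    out = TVF.resize(frames, 1280, InterpolationMode.BICUBIC, antialias=True)
    # an int size scales the shorter side to 1280 and keeps the aspect ratio:
    # out.shape -> torch.Size([16, 3, 1280, 2275])  (width = int(1280 * 1920 / 1080))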
--- comfy_extras/nodes_seedvr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_seedvr.py b/comfy_extras/nodes_seedvr.py index bd0c6037a..22b117872 100644 --- a/comfy_extras/nodes_seedvr.py +++ b/comfy_extras/nodes_seedvr.py @@ -298,7 +298,7 @@ class SeedVR2InputProcessing(io.ComfyNode): b, t, c, h, w = images.shape images = images.reshape(b * t, c, h, w) - max_area = ((resolution_height * resolution_width)** 0.5) ** 2 + #max_area = ((resolution_height * resolution_width)** 0.5) ** 2 clip = Lambda(lambda x: torch.clamp(x, 0.0, 1.0)) normalize = Normalize(0.5, 0.5) #images = area_resize(images, max_area) From 9b573da39b5ca9d08104840206ec76d4b6601c8e Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 26 Dec 2025 21:16:36 +0200 Subject: [PATCH 29/35] added other types of attention + compatibility with images --- comfy/ldm/modules/attention.py | 83 ++++++++++++++++++++++++++-------- comfy/ldm/seedvr/vae.py | 17 +++++-- comfy_extras/nodes_seedvr.py | 24 ++++++---- 3 files changed, 93 insertions(+), 31 deletions(-) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 332c65ffb..c7a15a5c8 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -19,9 +19,15 @@ if model_management.xformers_enabled(): import xformers.ops SAGE_ATTENTION_IS_AVAILABLE = False +SAGE_ATTENTION_VAR_LENGTH_AVAILABLE = False try: from sageattention import sageattn SAGE_ATTENTION_IS_AVAILABLE = True + try: + from sageattention import sageattn_varlen + SAGE_ATTENTION_VAR_LENGTH_AVAILABLE = True + except: + pass except ImportError as e: if model_management.sage_attention_enabled(): if e.name == "sageattention": @@ -80,7 +86,13 @@ def default(val, d): return val return d - +def var_attn_arg(kwargs): + cu_seqlens_q = kwargs.get("cu_seqlens_q", None) + cu_seqlens_k = kwargs.get("cu_seqlens_k", cu_seqlens_q) + max_seqlen_q = kwargs.get("max_seqlen_q", None) + max_seqlen_k = kwargs.get("max_seqlen_k", max_seqlen_q) + assert cu_seqlens_q != None, "cu_seqlens_q shouldn't be None when var_length is True" + return cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k # feedforward class GEGLU(nn.Module): def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=ops): @@ -404,14 +416,14 @@ except: pass @wrap_attn -def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): +def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, var_length=False, **kwargs): b = q.shape[0] dim_head = q.shape[-1] # check to make sure xformers isn't broken disabled_xformers = False if BROKEN_XFORMERS: - if b * heads > 65535: + if b * heads > 65535 and not var_length: disabled_xformers = True if not disabled_xformers: @@ -419,9 +431,24 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh disabled_xformers = True if disabled_xformers: - return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape, **kwargs) + return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape, var_length=var_length, **kwargs) - if skip_reshape: + if var_length: + if not skip_reshape: + total_tokens, hidden_dim = q.shape + dim_head = hidden_dim // heads + q = q.view(1, total_tokens, heads, dim_head) + k = k.view(1, total_tokens, heads, dim_head) + v = v.view(1, total_tokens, heads, dim_head) + else: + if q.ndim == 3: q = q.unsqueeze(0) + 
if k.ndim == 3: k = k.unsqueeze(0) + if v.ndim == 3: v = v.unsqueeze(0) + dim_head = q.shape[-1] + + target_output_shape = (q.shape[1], -1) + b = 1 + elif skip_reshape: # b h k d -> b k h d q, k, v = map( lambda t: t.permute(0, 2, 1, 3), @@ -435,7 +462,11 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh (q, k, v), ) - if mask is not None: + if var_length: + cu_seqlens_q, _, _, _ = var_attn_arg(kwargs) + seq_lens = cu_seqlens_q[1:] - cu_seqlens_q[:-1] + mask = xformers.ops.BlockDiagonalMask.from_seqlens(seq_lens_q=seq_lens, seq_lens_k=seq_lens) + elif mask is not None: # add a singleton batch dimension if mask.ndim == 2: mask = mask.unsqueeze(0) @@ -457,6 +488,8 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask) + if var_length: + return out.reshape(*target_output_shape) if skip_output_reshape: out = out.permute(0, 2, 1, 3) else: @@ -475,9 +508,7 @@ else: @wrap_attn def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, var_length=False, **kwargs): if var_length: - cu_seqlens_q = kwargs.get("cu_seqlens_q", None) - cu_seqlens_k = kwargs.get("cu_seqlens_k", cu_seqlens_q) - assert cu_seqlens_q != None, "cu_seqlens_q shouldn't be None when var_length is True" + cu_seqlens_q, cu_seqlens_k, _, _ = var_attn_arg(kwargs) if not skip_reshape: # assumes 2D q, k,v [total_tokens, embed_dim] total_tokens, embed_dim = q.shape @@ -539,9 +570,19 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha return out @wrap_attn -def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): +def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, var_length = False, **kwargs): exception_fallback = False - if skip_reshape: + if var_length: + if not skip_reshape: + total_tokens, hidden_dim = q.shape + dim_head = hidden_dim // heads + q, k, v = [t.view(total_tokens, heads, dim_head) for t in (q, k, v)] + b, _, dim_head = q.shape + # skips batched code + mask = None + tensor_layout = "VAR" + target_output_shape = (q.shape[0], -1) + elif skip_reshape: b, _, _, dim_head = q.shape tensor_layout = "HND" else: @@ -562,7 +603,14 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape= mask = mask.unsqueeze(1) try: - out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout) + if var_length and not SAGE_ATTENTION_VAR_LENGTH_AVAILABLE: + raise ValueError("Sage Attention two is required to run variable length attention.") + elif var_length: + cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k = var_attn_arg(kwargs) + sm_scale = 1.0 / (q.shape[-1] ** 0.5) + out = sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, is_causal=False, sm_scale=sm_scale) + else: + out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout) except Exception as e: logging.error("Error running sage attention: {}, using pytorch attention instead.".format(e)) exception_fallback = True @@ -572,7 +620,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape= lambda t: t.transpose(1, 2), (q, k, v), ) - return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape, **kwargs) + return attention_pytorch(q, k, v, heads, mask=mask, 
skip_reshape=True, skip_output_reshape=skip_output_reshape, var_length=var_length, **kwargs) if tensor_layout == "HND": if not skip_output_reshape: @@ -583,6 +631,8 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape= if skip_output_reshape: out = out.transpose(1, 2) else: + if var_length: + return out.view(*target_output_shape) out = out.reshape(b, -1, heads * dim_head) return out @@ -608,12 +658,7 @@ except AttributeError as error: @wrap_attn def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, var_length=False, **kwargs): if var_length: - cu_seqlens_q = kwargs.get("cu_seqlens_q", None) - cu_seqlens_k = kwargs.get("cu_seqlens_k", cu_seqlens_q) - max_seqlen_q = kwargs.get("max_seqlen_q", None) - max_seqlen_k = kwargs.get("max_seqlen_k", max_seqlen_q) - assert max_seqlen_q != None, "max_seqlen_q shouldn't be None when var_length is True" - assert cu_seqlens_q != None, "cu_seqlens_q shouldn't be None when var_length is True" + cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k = var_attn_arg(kwargs) return flash_attn_varlen_func( q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, diff --git a/comfy/ldm/seedvr/vae.py b/comfy/ldm/seedvr/vae.py index 9fcea60ad..c9fef0677 100644 --- a/comfy/ldm/seedvr/vae.py +++ b/comfy/ldm/seedvr/vae.py @@ -499,6 +499,8 @@ class InflatedCausalConv3d(ops.Conv3d): def pad_and_forward(): padded = safe_pad_operation(x_concat, padding, mode='constant', value=0.0) + if not padded.is_contiguous(): + padded = padded.contiguous() with ignore_padding(self): return torch.nn.Conv3d.forward(self, padded) @@ -1726,7 +1728,7 @@ class VideoAutoencoderKL(nn.Module): return decoded def _encode( - self, x, memory_state + self, x, memory_state = MemoryState.DISABLED ) -> torch.Tensor: _x = x.to(self.device) h = self.encoder(_x, memory_state=memory_state) @@ -1737,7 +1739,7 @@ class VideoAutoencoderKL(nn.Module): return output.to(x.device) def _decode( - self, z, memory_state + self, z, memory_state = MemoryState.DISABLED ) -> torch.Tensor: _z = z.to(self.device) @@ -1892,9 +1894,16 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL): # in case of padded frames t = input.size(0) - x = x[:, :, :t] + if t != 1: + x = x[:, :, :t] + if t == 1 and x.size(2) == 4: + x = x[:, :, :t] - x = rearrange(x, "b c t h w -> (b t) c h w") + if x.size(1) == 1: + exp = "b t c h w -> (b t) c h w" + else: + exp = "b c t h w -> (b t) c h w" + x = rearrange(x, exp) input = input.to(x.device) x = wavelet_reconstruction(x, input) diff --git a/comfy_extras/nodes_seedvr.py b/comfy_extras/nodes_seedvr.py index 22b117872..4ec089dde 100644 --- a/comfy_extras/nodes_seedvr.py +++ b/comfy_extras/nodes_seedvr.py @@ -270,12 +270,11 @@ class SeedVR2InputProcessing(io.ComfyNode): inputs = [ io.Image.Input("images"), io.Vae.Input("vae"), - io.Int.Input("resolution_height", default = 1280, min = 120), # // - io.Int.Input("resolution_width", default = 720, min = 120), # just non-zero value + io.Int.Input("resolution", default = 1280, min = 120), # just non-zero value io.Int.Input("spatial_tile_size", default = 512, min = 1), - io.Int.Input("temporal_tile_size", default = 8, min = 1), io.Int.Input("spatial_overlap", default = 64, min = 1), - io.Boolean.Input("enable_tiling", default=False) + io.Int.Input("temporal_tile_size", default = 8, min = 1), + io.Boolean.Input("enable_tiling", default=False), ], outputs = [ io.Latent.Output("vae_conditioning") @@ -283,7 +282,7 @@ class 
SeedVR2InputProcessing(io.ComfyNode): ) @classmethod - def execute(cls, images, vae, resolution_height, resolution_width, spatial_tile_size, temporal_tile_size, spatial_overlap, enable_tiling): + def execute(cls, images, vae, resolution, spatial_tile_size, temporal_tile_size, spatial_overlap, enable_tiling): device = vae.patcher.load_device offload_device = comfy.model_management.intermediate_device() @@ -298,11 +297,9 @@ class SeedVR2InputProcessing(io.ComfyNode): b, t, c, h, w = images.shape images = images.reshape(b * t, c, h, w) - #max_area = ((resolution_height * resolution_width)** 0.5) ** 2 clip = Lambda(lambda x: torch.clamp(x, 0.0, 1.0)) normalize = Normalize(0.5, 0.5) - #images = area_resize(images, max_area) - images = side_resize(images, resolution_height) + images = side_resize(images, resolution) images = clip(images) o_h, o_w = images.shape[-2:] @@ -317,6 +314,17 @@ class SeedVR2InputProcessing(io.ComfyNode): images = images.to(device) vae_model = vae_model.to(device) + # in case the user passes an incompatible number for tiling + def make_divisible(val, divisor): + return max(divisor, round(val / divisor) * divisor) + + temporal_tile_size = make_divisible(temporal_tile_size, 4) + spatial_tile_size = make_divisible(spatial_tile_size, 32) + spatial_overlap = make_divisible(spatial_overlap, 32) + + if spatial_overlap >= spatial_tile_size: + spatial_overlap = max(0, spatial_tile_size - 8) + args = {"tile_size": (spatial_tile_size, spatial_tile_size), "tile_overlap": (spatial_overlap, spatial_overlap), "temporal_size":temporal_tile_size} if enable_tiling: From 3039c7ba149435b92b81ec3e6f46d99b6d9aca13 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 26 Dec 2025 23:12:45 +0200 Subject: [PATCH 30/35] tile edge case handled by padding vid --- comfy_extras/nodes_seedvr.py | 68 ++++++++++++++++++++++++++++++------ 1 file changed, 58 insertions(+), 10 deletions(-) diff --git a/comfy_extras/nodes_seedvr.py b/comfy_extras/nodes_seedvr.py index 4ec089dde..314100324 100644 --- a/comfy_extras/nodes_seedvr.py +++ b/comfy_extras/nodes_seedvr.py @@ -59,19 +59,37 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora input_chunk = temporal_size else: input_chunk = max(1, temporal_size // sf_t) - for i in range(0, t_dim_size, input_chunk): t_chunk = spatial_tile[:, :, i : i + input_chunk, :, :] + current_valid_len = t_chunk.shape[2] + + pad_amount = 0 + if current_valid_len < input_chunk: + pad_amount = input_chunk - current_valid_len + + last_frame = t_chunk[:, :, -1:, :, :] + padding = last_frame.repeat(1, 1, pad_amount, 1, 1) + + t_chunk = torch.cat([t_chunk, padding], dim=2) + t_chunk = t_chunk.contiguous() if encode: - out = vae_model.encode(t_chunk) + out = vae_model.encode(t_chunk)[0] else: out = vae_model.decode_(t_chunk) if isinstance(out, (tuple, list)): out = out[0] - if out.ndim == 4: out = out.unsqueeze(2) + if pad_amount > 0: + if encode: + expected_valid_out = (current_valid_len + sf_t - 1) // sf_t + out = out[:, :, :expected_valid_out, :, :] + + else: + expected_valid_out = current_valid_len * sf_t + out = out[:, :, :expected_valid_out, :, :] + chunk_results.append(out.to(storage_device)) return torch.cat(chunk_results, dim=2) @@ -149,15 +167,46 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora return result +def pad_video_temporal(videos: torch.Tensor, count: int = 0, temporal_dim: int = 1, prepend: bool = False): + t = videos.size(temporal_dim) + + if count == 0 and not 
prepend: + if t % 4 == 1: + return videos + count = ((t - 1) // 4 + 1) * 4 + 1 - t + + if count <= 0: + return videos + + def select(start, end): + return videos[start:end] if temporal_dim == 0 else videos[:, start:end] + + if count >= t: + repeat_count = count - t + 1 + last = select(-1, None) + + if temporal_dim == 0: + repeated = last.repeat(repeat_count, 1, 1, 1) + reversed_frames = select(1, None).flip(temporal_dim) if t > 1 else last[:0] + else: + repeated = last.expand(-1, repeat_count, -1, -1).contiguous() + reversed_frames = select(1, None).flip(temporal_dim) if t > 1 else last[:, :0] + + return torch.cat([repeated, reversed_frames, videos] if prepend else + [videos, reversed_frames, repeated], dim=temporal_dim) + + if prepend: + reversed_frames = select(1, count+1).flip(temporal_dim) + else: + reversed_frames = select(-count-1, -1).flip(temporal_dim) + + return torch.cat([reversed_frames, videos] if prepend else + [videos, reversed_frames], dim=temporal_dim) + def clear_vae_memory(vae_model): for module in vae_model.modules(): if hasattr(module, "memory"): module.memory = None - if hasattr(vae_model, "original_image_video"): - del vae_model.original_image_video - - if hasattr(vae_model, "tiled_args"): - del vae_model.tiled_args gc.collect() torch.cuda.empty_cache() @@ -273,7 +322,7 @@ class SeedVR2InputProcessing(io.ComfyNode): io.Int.Input("resolution", default = 1280, min = 120), # just non-zero value io.Int.Input("spatial_tile_size", default = 512, min = 1), io.Int.Input("spatial_overlap", default = 64, min = 1), - io.Int.Input("temporal_tile_size", default = 8, min = 1), + io.Int.Input("temporal_tile_size", default=5, min=1, max=16384, step=4), io.Boolean.Input("enable_tiling", default=False), ], outputs = [ @@ -318,7 +367,6 @@ class SeedVR2InputProcessing(io.ComfyNode): def make_divisible(val, divisor): return max(divisor, round(val / divisor) * divisor) - temporal_tile_size = make_divisible(temporal_tile_size, 4) spatial_tile_size = make_divisible(spatial_tile_size, 32) spatial_overlap = make_divisible(spatial_overlap, 32) From fadc7839cc88acd822efe46adf8529c6d8c11fb5 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 26 Dec 2025 23:14:33 +0200 Subject: [PATCH 31/35] ruff --- comfy/ldm/modules/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index c7a15a5c8..6163aec22 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -446,7 +446,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh if v.ndim == 3: v = v.unsqueeze(0) dim_head = q.shape[-1] - target_output_shape = (q.shape[1], -1) + target_output_shape = (q.shape[1], -1) b = 1 elif skip_reshape: # b h k d -> b k h d From 84fa1550717b776072b9ae2783be5542fa64f4a9 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Tue, 30 Dec 2025 18:44:57 +0200 Subject: [PATCH 32/35] fixed manual vae loading --- comfy/ldm/seedvr/vae.py | 2 -- comfy/sd.py | 4 ++-- comfy_extras/nodes_seedvr.py | 8 +------- 3 files changed, 3 insertions(+), 11 deletions(-) diff --git a/comfy/ldm/seedvr/vae.py b/comfy/ldm/seedvr/vae.py index c9fef0677..292958a88 100644 --- a/comfy/ldm/seedvr/vae.py +++ b/comfy/ldm/seedvr/vae.py @@ -1878,8 +1878,6 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL): if latent.ndim == 4: latent = latent.unsqueeze(2) - target_device = comfy.model_management.get_torch_device() 
- self.decoder.to(target_device) if self.tiled_args.get("enable_tiling", None) is not None: self.enable_tiling = self.tiled_args.pop("enable_tiling", False) diff --git a/comfy/sd.py b/comfy/sd.py index be2ce30a8..5f89d2c82 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -379,8 +379,8 @@ class VAE: self.latent_channels = 16 elif "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd: # seedvr2 self.first_stage_model = comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper() - self.memory_used_decode = lambda shape, dtype: (2000 * shape[1] * shape[2] * shape[3] * (4 * 8 * 8)) * model_management.dtype_size(dtype) - self.memory_used_encode = lambda shape, dtype: (1000 * max(shape[1], 5) * shape[2] * shape[3]) * model_management.dtype_size(dtype) + self.memory_used_decode = lambda shape, dtype: (10 * shape[1] * shape[2] * shape[3] * (4 * 8 * 8)) * model_management.dtype_size(dtype) + self.memory_used_encode = lambda shape, dtype: (10 * max(shape[1], 5) * shape[2] * shape[3]) * model_management.dtype_size(dtype) self.working_dtypes = [torch.bfloat16, torch.float32] self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8) self.downscale_index_formula = (4, 8, 8) diff --git a/comfy_extras/nodes_seedvr.py b/comfy_extras/nodes_seedvr.py index 314100324..945cf966b 100644 --- a/comfy_extras/nodes_seedvr.py +++ b/comfy_extras/nodes_seedvr.py @@ -332,11 +332,8 @@ class SeedVR2InputProcessing(io.ComfyNode): @classmethod def execute(cls, images, vae, resolution, spatial_tile_size, temporal_tile_size, spatial_overlap, enable_tiling): - device = vae.patcher.load_device - offload_device = comfy.model_management.intermediate_device() - main_device = comfy.model_management.get_torch_device() - images = images.to(main_device) + comfy.model_management.load_models_gpu([vae.patcher]) vae_model = vae.first_stage_model scale = 0.9152; shift = 0 if images.dim() != 5: # add the t dim @@ -360,8 +357,6 @@ class SeedVR2InputProcessing(io.ComfyNode): images = cut_videos(images) images = rearrange(images, "b t c h w -> b c t h w") - images = images.to(device) - vae_model = vae_model.to(device) # in case the user passes an incompatible number for tiling def make_divisible(val, divisor): @@ -393,7 +388,6 @@ class SeedVR2InputProcessing(io.ComfyNode): latent = rearrange(latent, "b c ... -> b ... 
c") latent = (latent - shift) * scale - latent = latent.to(offload_device) return io.NodeOutput({"samples": latent}) From 31d358c78c4414090e191b3aaceea42378de8c58 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Sun, 4 Jan 2026 19:15:53 +0200 Subject: [PATCH 33/35] rope, attetntion update | vae on cpu warning --- comfy/ldm/modules/attention.py | 12 ++++++++---- comfy/ldm/seedvr/model.py | 19 +++++++++++-------- comfy/ldm/seedvr/vae.py | 14 ++++++++++++-- comfy/sd.py | 4 ++-- 4 files changed, 33 insertions(+), 16 deletions(-) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 6163aec22..be253a010 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -416,7 +416,8 @@ except: pass @wrap_attn -def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, var_length=False, **kwargs): +def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): + var_length = kwargs.get("var_length", False) b = q.shape[0] dim_head = q.shape[-1] # check to make sure xformers isn't broken @@ -506,7 +507,8 @@ else: SDP_BATCH_LIMIT = 2**31 @wrap_attn -def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, var_length=False, **kwargs): +def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): + var_length = kwargs.get("var_length", False) if var_length: cu_seqlens_q, cu_seqlens_k, _, _ = var_attn_arg(kwargs) if not skip_reshape: @@ -570,7 +572,8 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha return out @wrap_attn -def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, var_length = False, **kwargs): +def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): + var_length = kwargs.get("var_length", False) exception_fallback = False if var_length: if not skip_reshape: @@ -656,7 +659,8 @@ except AttributeError as error: assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}" @wrap_attn -def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, var_length=False, **kwargs): +def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): + var_length = kwargs.get("var_length", False) if var_length: cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k = var_attn_arg(kwargs) return flash_attn_varlen_func( diff --git a/comfy/ldm/seedvr/model.py b/comfy/ldm/seedvr/model.py index bd0057332..6c3e9c526 100644 --- a/comfy/ldm/seedvr/model.py +++ b/comfy/ldm/seedvr/model.py @@ -13,6 +13,7 @@ from comfy.rmsnorm import RMSNorm from torch.nn.modules.utils import _triple from torch import nn import math +from comfy.ldm.flux.math import apply_rope1 class Cache: def __init__(self, disable=False, prefix="", cache=None): @@ -443,7 +444,6 @@ def apply_rotary_emb( freqs_seq_dim = None ): dtype = t.dtype - if not exists(freqs_seq_dim): if freqs.ndim == 2 or t.ndim == 3: freqs_seq_dim = 0 @@ -452,20 +452,23 @@ def apply_rotary_emb( seq_len = t.shape[seq_dim] freqs = slice_at_dim(freqs, slice(-seq_len, None), dim = freqs_seq_dim) - rot_dim = freqs.shape[-1] - end_index = start_index + 
rot_dim - - assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}' + rot_feats = freqs.shape[-1] + end_index = start_index + rot_feats t_left = t[..., :start_index] t_middle = t[..., start_index:end_index] t_right = t[..., end_index:] - freqs = freqs.to(t_middle.device) - t_transformed = (t_middle * freqs.cos() * scale) + (rotate_half(t_middle) * freqs.sin() * scale) + angles = freqs.to(t_middle.device)[..., ::2] + cos = torch.cos(angles) * scale + sin = torch.sin(angles) * scale - out = torch.cat((t_left, t_transformed, t_right), dim=-1) + col0 = torch.stack([cos, sin], dim=-1) + col1 = torch.stack([-sin, cos], dim=-1) + freqs_mat = torch.stack([col0, col1], dim=-1) + t_middle_out = apply_rope1(t_middle, freqs_mat) + out = torch.cat((t_left, t_middle_out, t_right), dim=-1) return out.type(dtype) class NaMMRotaryEmbedding3d(MMRotaryEmbeddingBase): diff --git a/comfy/ldm/seedvr/vae.py b/comfy/ldm/seedvr/vae.py index 292958a88..d218b90e9 100644 --- a/comfy/ldm/seedvr/vae.py +++ b/comfy/ldm/seedvr/vae.py @@ -16,6 +16,7 @@ import math from enum import Enum from comfy.ops import NVIDIA_MEMORY_CONV_BUG_WORKAROUND +import logging import comfy.ops ops = comfy.ops.disable_weight_init @@ -446,6 +447,7 @@ class InflatedCausalConv3d(ops.Conv3d): self.memory_device = memory_device self.padding = (0, *self.padding[1:]) self.memory_limit = float("inf") + self.logged_once = False def set_memory_limit(self, value: float): self.memory_limit = value @@ -469,8 +471,16 @@ class InflatedCausalConv3d(ops.Conv3d): return out except RuntimeError: pass - - return super()._conv_forward(input, weight, bias, *args, **kwargs) + except NotImplementedError: + pass + try: + return super()._conv_forward(input, weight, bias, *args, **kwargs) + except NotImplementedError: + # for: Could not run 'aten::cudnn_convolution' with arguments from the 'CPU' backend + if not self.logged_once: + logging.warning("VAE is on CPU for decoding. 
This is most likely due to being not enough memory") + self.logged_once = True + return F.conv3d(input, weight, bias, *args, **kwargs) def memory_limit_conv( self, diff --git a/comfy/sd.py b/comfy/sd.py index 69ec40756..102d1a026 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -382,8 +382,8 @@ class VAE: self.latent_channels = 16 elif "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd: # seedvr2 self.first_stage_model = comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper() - self.memory_used_decode = lambda shape, dtype: (10 * shape[1] * shape[2] * shape[3] * (4 * 8 * 8)) * model_management.dtype_size(dtype) - self.memory_used_encode = lambda shape, dtype: (10 * max(shape[1], 5) * shape[2] * shape[3]) * model_management.dtype_size(dtype) + self.memory_used_decode = lambda shape, dtype: (shape[1] * shape[2] * shape[3] * (4 * 8 * 8)) * model_management.dtype_size(dtype) + self.memory_used_encode = lambda shape, dtype: (max(shape[1], 5) * shape[2] * shape[3]) * model_management.dtype_size(dtype) self.working_dtypes = [torch.bfloat16, torch.float32] self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8) self.downscale_index_formula = (4, 8, 8) From f588e6c821b75cced850b3eafd9d13b5b7b43428 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Sun, 4 Jan 2026 20:30:24 +0200 Subject: [PATCH 34/35] ruff --- comfy/ldm/modules/attention.py | 14 +++++++++----- comfy_extras/nodes_seedvr.py | 21 ++++++++++++++------- 2 files changed, 23 insertions(+), 12 deletions(-) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 0e3821ef0..dd8c6ba72 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -98,7 +98,7 @@ def var_attn_arg(kwargs): cu_seqlens_k = kwargs.get("cu_seqlens_k", cu_seqlens_q) max_seqlen_q = kwargs.get("max_seqlen_q", None) max_seqlen_k = kwargs.get("max_seqlen_k", max_seqlen_q) - assert cu_seqlens_q != None, "cu_seqlens_q shouldn't be None when var_length is True" + assert cu_seqlens_q is not None, "cu_seqlens_q shouldn't be None when var_length is True" return cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k # feedforward class GEGLU(nn.Module): @@ -449,9 +449,12 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh k = k.view(1, total_tokens, heads, dim_head) v = v.view(1, total_tokens, heads, dim_head) else: - if q.ndim == 3: q = q.unsqueeze(0) - if k.ndim == 3: k = k.unsqueeze(0) - if v.ndim == 3: v = v.unsqueeze(0) + if q.ndim == 3: + q = q.unsqueeze(0) + if k.ndim == 3: + k = k.unsqueeze(0) + if v.ndim == 3: + v = v.unsqueeze(0) dim_head = q.shape[-1] target_output_shape = (q.shape[1], -1) @@ -526,7 +529,8 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha k = k.view(k.shape[0], heads, head_dim) v = v.view(v.shape[0], heads, head_dim) - b = q.size(0); dim_head = q.shape[-1] + b = q.size(0) + dim_head = q.shape[-1] q = torch.nested.nested_tensor_from_jagged(q, offsets=cu_seqlens_q.long()) k = torch.nested.nested_tensor_from_jagged(k, offsets=cu_seqlens_k.long()) v = torch.nested.nested_tensor_from_jagged(v, offsets=cu_seqlens_k.long()) diff --git a/comfy_extras/nodes_seedvr.py b/comfy_extras/nodes_seedvr.py index 945cf966b..c4e8f3958 100644 --- a/comfy_extras/nodes_seedvr.py +++ b/comfy_extras/nodes_seedvr.py @@ -78,8 +78,10 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora else: out = vae_model.decode_(t_chunk) - if isinstance(out, (tuple, list)): out = out[0] 
- if out.ndim == 4: out = out.unsqueeze(2) + if isinstance(out, (tuple, list)): + out = out[0] + if out.ndim == 4: + out = out.unsqueeze(2) if pad_amount > 0: if encode: @@ -136,13 +138,17 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora if cur_ov_h > 0: r = get_ramp(cur_ov_h) - if y_idx > 0: w_h[:cur_ov_h] = r - if y_end < h: w_h[-cur_ov_h:] = 1.0 - r + if y_idx > 0: + w_h[:cur_ov_h] = r + if y_end < h: + w_h[-cur_ov_h:] = 1.0 - r if cur_ov_w > 0: r = get_ramp(cur_ov_w) - if x_idx > 0: w_w[:cur_ov_w] = r - if x_end < w: w_w[-cur_ov_w:] = 1.0 - r + if x_idx > 0: + w_w[:cur_ov_w] = r + if x_end < w: + w_w[-cur_ov_w:] = 1.0 - r final_weight = w_h.view(1,1,1,-1,1) * w_w.view(1,1,1,1,-1) @@ -335,7 +341,8 @@ class SeedVR2InputProcessing(io.ComfyNode): comfy.model_management.load_models_gpu([vae.patcher]) vae_model = vae.first_stage_model - scale = 0.9152; shift = 0 + scale = 0.9152 + shift = 0 if images.dim() != 5: # add the t dim images = images.unsqueeze(0) images = images.permute(0, 1, 4, 2, 3) From 72ca18acc2fc92d60b942c08c8822c377a1120ed Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Sun, 4 Jan 2026 20:32:38 +0200 Subject: [PATCH 35/35] . --- comfy/ldm/seedvr/vae.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/ldm/seedvr/vae.py b/comfy/ldm/seedvr/vae.py index d218b90e9..aa69bfb80 100644 --- a/comfy/ldm/seedvr/vae.py +++ b/comfy/ldm/seedvr/vae.py @@ -478,7 +478,7 @@ class InflatedCausalConv3d(ops.Conv3d): except NotImplementedError: # for: Could not run 'aten::cudnn_convolution' with arguments from the 'CPU' backend if not self.logged_once: - logging.warning("VAE is on CPU for decoding. This is most likely due to being not enough memory") + logging.warning("VAE is on CPU for decoding. This is most likely due to not enough memory") self.logged_once = True return F.conv3d(input, weight, bias, *args, **kwargs)
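Note on the variable-length attention path threaded through the backends above: flash, sage, xformers, and the nested-tensor PyTorch fallback all share one packed-sequence convention, recovered from kwargs by var_attn_arg(). A minimal sketch of how those arguments are built on the caller side; the seq_lens values are illustrative, not from the patch:

    import torch
    import torch.nn.functional as F

    # Packed ("varlen") layout: sequences are concatenated along dim 0, so
    # q/k/v arrive as [total_tokens, heads * dim_head] with no padding.
    seq_lens = torch.tensor([7, 3, 12])  # per-sample token counts

    # cu_seqlens = cumulative lengths with a leading zero; varlen kernels
    # (flash_attn_varlen_func, sageattn_varlen) use it to locate starts.
    cu_seqlens = F.pad(seq_lens.cumsum(0), (1, 0)).to(torch.int32)
    # -> tensor([ 0,  7, 10, 22], dtype=torch.int32)
    max_seqlen = int(seq_lens.max())

    kwargs = {"cu_seqlens_q": cu_seqlens, "max_seqlen_q": max_seqlen}
    # var_attn_arg() defaults the k side to the q side (self-attention):
    cu_q = kwargs.get("cu_seqlens_q")
    cu_k = kwargs.get("cu_seqlens_k", cu_q)

    # The xformers branch inverts the mapping to get per-sequence lengths
    # for its block-diagonal attention bias:
    per_seq = (cu_q[1:] - cu_q[:-1]).tolist()  # back to [7, 3, 12]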
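The temporal-chunk edge case in tiled_vae (PATCH 30) pads a short final chunk by repeating its last frame, runs the model, then trims the output back to the frames that were actually present. The trim arithmetic, condensed into a standalone helper for clarity (the helper name is mine, not the patch's):

    def trim_padded_chunk(out, valid_in: int, sf_t: int, encode: bool):
        # valid_in: frames in the chunk before padding; sf_t: temporal
        # scale factor. Encoding compresses time by sf_t (ceil division
        # keeps the partial group); decoding expands it by sf_t.
        valid_out = (valid_in + sf_t - 1) // sf_t if encode else valid_in * sf_t
        return out[:, :, :valid_out, :, :]

For example, with sf_t = 4 a 3-frame tail chunk padded to 8 frames encodes to 2 latent frames, of which ceil(3/4) = 1 is kept before the chunks are concatenated.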
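pad_video_temporal (also PATCH 30) targets clip lengths of the form 4k + 1 because the downscale ratio registered in sd.py, max(0, floor((a + 3) / 4)), turns 4k + 1 frames into exactly k + 1 latent frames with nothing left over. A quick check of the count computation for the default count == 0 path:

    def padded_length(t: int) -> int:
        # Mirrors the count computed in pad_video_temporal when count == 0:
        # pad up to the next length congruent to 1 mod 4.
        if t % 4 == 1:
            return t
        return ((t - 1) // 4 + 1) * 4 + 1

    assert [padded_length(t) for t in (1, 2, 5, 6, 9, 16)] == [1, 5, 5, 9, 9, 17]

The padding frames themselves are taken by reflecting the frames nearest the edge (the flip() calls), so the padded tail continues the motion instead of freezing, except in the count >= t branch where repeating the last frame is unavoidable.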
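PATCH 33 replaces the rotate_half formulation of apply_rotary_emb with comfy.ldm.flux.math.apply_rope1, which consumes explicit 2x2 rotation matrices per feature pair; freqs[..., ::2] recovers the raw angles because the old layout stored each angle twice, once per element of its pair. A self-contained sketch of the equivalence being relied on, assuming the interleaved pair layout used by rotary_embedding_torch-style helpers:

    import torch

    def pair_rotation(t, angles):
        # What the apply_rope1 path computes: rotate each interleaved
        # (even, odd) feature pair by the matrix [[cos, -sin], [sin, cos]].
        cos, sin = angles.cos(), angles.sin()
        a, b = t[..., 0::2], t[..., 1::2]
        return torch.stack([a * cos - b * sin, a * sin + b * cos], dim=-1).flatten(-2)

    def rotate_half_form(t, freqs):
        # The formulation the patch removes; freqs holds every angle twice
        # ([th0, th0, th1, th1, ...]), hence the new code's freqs[..., ::2].
        x = t.reshape(*t.shape[:-1], -1, 2)
        rot = torch.stack((-x[..., 1], x[..., 0]), dim=-1).reshape(t.shape)
        return t * freqs.cos() + rot * freqs.sin()

    t = torch.randn(4, 8)
    angles = torch.randn(4, 4)
    freqs = angles.repeat_interleave(2, dim=-1)
    assert torch.allclose(pair_rotation(t, angles), rotate_half_form(t, freqs), atol=1e-6)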