From 08d93555d015ee0fd0a921097d5a67fd867698db Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Sat, 6 Dec 2025 23:18:10 +0200 Subject: [PATCH] init --- comfy/latent_formats.py | 4 + comfy/ldm/modules/attention.py | 2 +- comfy/ldm/seedvr/model.py | 1287 ++++++++++++++++++++++++++++++++ comfy/ldm/seedvr/vae.py | 1260 +++++++++++++++++++++++++++++++ comfy/model_base.py | 6 + comfy/model_detection.py | 11 + comfy/sd.py | 14 + comfy/supported_models.py | 17 +- 8 files changed, 2599 insertions(+), 2 deletions(-) create mode 100644 comfy/ldm/seedvr/model.py create mode 100644 comfy/ldm/seedvr/vae.py diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 82d9f9bb8..f260528d4 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -470,3 +470,7 @@ class Hunyuan3Dv2mini(LatentFormat): class ACEAudio(LatentFormat): latent_channels = 8 latent_dimensions = 2 + +class SeedVR2(LatentFormat): + latent_channels = 16 + latent_dimensions = 16 \ No newline at end of file diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 35d2270ee..256f9a989 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -428,7 +428,7 @@ else: SDP_BATCH_LIMIT = 2**31 -def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): +def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=Falsez): if skip_reshape: b, _, _, dim_head = q.shape else: diff --git a/comfy/ldm/seedvr/model.py b/comfy/ldm/seedvr/model.py new file mode 100644 index 000000000..40a460d67 --- /dev/null +++ b/comfy/ldm/seedvr/model.py @@ -0,0 +1,1287 @@ +from dataclasses import dataclass +from typing import Optional, Tuple, Union, List, Dict, Any, Callable +import einops +from einops import rearrange, einsum +from torch import nn +import torch.nn.functional as F +from math import ceil, sqrt, pi +import torch +from itertools import chain +from comfy.ldm.modules.diffusionmodules.model import get_timestep_embedding +from comfy.ldm.modules.attention import optimized_attention +from comfy.rmsnorm import RMSNorm +from torch.nn.modules.utils import _triple +from torch import nn + +class Cache: + def __init__(self, disable=False, prefix="", cache=None): + self.cache = cache if cache is not None else {} + self.disable = disable + self.prefix = prefix + + def __call__(self, key: str, fn: Callable): + if self.disable: + return fn() + + key = self.prefix + key + try: + result = self.cache[key] + except KeyError: + result = fn() + self.cache[key] = result + return result + + def namespace(self, namespace: str): + return Cache( + disable=self.disable, + prefix=self.prefix + namespace + ".", + cache=self.cache, + ) + + def get(self, key: str): + key = self.prefix + key + return self.cache[key] + +def repeat_concat( + vid: torch.FloatTensor, # (VL ... c) + txt: torch.FloatTensor, # (TL ... c) + vid_len: torch.LongTensor, # (n*b) + txt_len: torch.LongTensor, # (b) + txt_repeat: List, # (n) +) -> torch.FloatTensor: # (L ... c) + vid = torch.split(vid, vid_len.tolist()) + txt = torch.split(txt, txt_len.tolist()) + txt = [[x] * n for x, n in zip(txt, txt_repeat)] + txt = list(chain(*txt)) + return torch.cat(list(chain(*zip(vid, txt)))) + +def concat( + vid: torch.FloatTensor, # (VL ... c) + txt: torch.FloatTensor, # (TL ... c) + vid_len: torch.LongTensor, # (b) + txt_len: torch.LongTensor, # (b) +) -> torch.FloatTensor: # (L ... 
c) + vid = torch.split(vid, vid_len.tolist()) + txt = torch.split(txt, txt_len.tolist()) + return torch.cat(list(chain(*zip(vid, txt)))) + +def concat_idx( + vid_len: torch.LongTensor, # (b) + txt_len: torch.LongTensor, # (b) +) -> Tuple[ + Callable, + Callable, +]: + device = vid_len.device + vid_idx = torch.arange(vid_len.sum(), device=device) + txt_idx = torch.arange(len(vid_idx), len(vid_idx) + txt_len.sum(), device=device) + tgt_idx = concat(vid_idx, txt_idx, vid_len, txt_len) + src_idx = torch.argsort(tgt_idx) + return ( + lambda vid, txt: torch.index_select(torch.cat([vid, txt]), 0, tgt_idx), + lambda all: torch.index_select(all, 0, src_idx).split([len(vid_idx), len(txt_idx)]), + ) + + +def repeat_concat_idx( + vid_len: torch.LongTensor, # (n*b) + txt_len: torch.LongTensor, # (b) + txt_repeat: torch.LongTensor, # (n) +) -> Tuple[ + Callable, + Callable, +]: + device = vid_len.device + vid_idx = torch.arange(vid_len.sum(), device=device) + txt_idx = torch.arange(len(vid_idx), len(vid_idx) + txt_len.sum(), device=device) + txt_repeat_list = txt_repeat.tolist() + tgt_idx = repeat_concat(vid_idx, txt_idx, vid_len, txt_len, txt_repeat) + src_idx = torch.argsort(tgt_idx) + txt_idx_len = len(tgt_idx) - len(vid_idx) + repeat_txt_len = (txt_len * txt_repeat).tolist() + + def unconcat_coalesce(all): + vid_out, txt_out = all[src_idx].split([len(vid_idx), txt_idx_len]) + txt_out_coalesced = [] + for txt, repeat_time in zip(txt_out.split(repeat_txt_len), txt_repeat_list): + txt = txt.reshape(-1, repeat_time, *txt.shape[1:]).mean(1) + txt_out_coalesced.append(txt) + return vid_out, torch.cat(txt_out_coalesced) + + return ( + lambda vid, txt: torch.cat([vid, txt])[tgt_idx], + lambda all: unconcat_coalesce(all), + ) + +@dataclass +class MMArg: + vid: Any + txt: Any + +def safe_pad_operation(x, padding, mode='constant', value=0.0): + """Safe padding operation that handles Half precision only for problematic modes""" + # Modes qui nécessitent le fix Half precision + problematic_modes = ['replicate', 'reflect', 'circular'] + + if mode in problematic_modes: + try: + return F.pad(x, padding, mode=mode, value=value) + except RuntimeError as e: + if "not implemented for 'Half'" in str(e): + original_dtype = x.dtype + return F.pad(x.float(), padding, mode=mode, value=value).to(original_dtype) + else: + raise e + else: + # Pour 'constant' et autres modes compatibles, pas de fix nécessaire + return F.pad(x, padding, mode=mode, value=value) + + +def get_args(key: str, args: List[Any]) -> List[Any]: + return [getattr(v, key) if isinstance(v, MMArg) else v for v in args] + + +def get_kwargs(key: str, kwargs: Dict[str, Any]) -> Dict[str, Any]: + return {k: getattr(v, key) if isinstance(v, MMArg) else v for k, v in kwargs.items()} + + +def make_720Pwindows(size, num_windows, shift = False): + t, h, w = size + resized_nt, resized_nh, resized_nw = num_windows + + scale = sqrt((45 * 80) / (h * w)) + resized_h, resized_w = round(h * scale), round(w * scale) + + wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw) + wt = ceil(min(t, 30) / resized_nt) + + st, sh, sw = (0.5 * shift if wt < t else 0, + 0.5 * shift if wh < h else 0, + 0.5 * shift if ww < w else 0) + + nt, nh, nw = ceil((t - st) / wt), ceil((h - sh) / wh), ceil((w - sw) / ww) + if shift: + nt += 1 if st > 0 else 0 + nh += 1 if sh > 0 else 0 + nw += 1 if sw > 0 else 0 + + windows = [] + for iw in range(nw): + w_start = max(int((iw - sw) * ww), 0) + w_end = min(int((iw - sw + 1) * ww), w) + if w_end <= w_start: + continue + + for ih in 
range(nh): + h_start = max(int((ih - sh) * wh), 0) + h_end = min(int((ih - sh + 1) * wh), h) + if h_end <= h_start: + continue + + for it in range(nt): + t_start = max(int((it - st) * wt), 0) + t_end = min(int((it - st + 1) * wt), t) + if t_end <= t_start: + continue + + windows.append((slice(t_start, t_end), + slice(h_start, h_end), + slice(w_start, w_end))) + + return windows + +class RotaryEmbedding(nn.Module): + def __init__( + self, + dim, + custom_freqs, + freqs_for = 'lang', + theta = 10000, + max_freq = 10, + num_freqs = 1, + learned_freq = False, + use_xpos = False, + xpos_scale_base = 512, + interpolate_factor = 1., + theta_rescale_factor = 1., + seq_before_head_dim = False, + cache_if_possible = True, + cache_max_seq_len = 8192 + ): + super().__init__() + + theta *= theta_rescale_factor ** (dim / (dim - 2)) + + self.freqs_for = freqs_for + + if exists(custom_freqs): + freqs = custom_freqs + elif freqs_for == 'lang': + freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) + elif freqs_for == 'pixel': + freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi + elif freqs_for == 'constant': + freqs = torch.ones(num_freqs).float() + + self.cache_if_possible = cache_if_possible + self.cache_max_seq_len = cache_max_seq_len + + self.register_buffer('cached_freqs', torch.zeros(cache_max_seq_len, dim), persistent = False) + self.cached_freqs_seq_len = 0 + + self.freqs = nn.Parameter(freqs, requires_grad = learned_freq) + + self.learned_freq = learned_freq + + # dummy for device + + self.register_buffer('dummy', torch.tensor(0), persistent = False) + + # default sequence dimension + + self.seq_before_head_dim = seq_before_head_dim + self.default_seq_dim = -3 if seq_before_head_dim else -2 + + # interpolation factors + + assert interpolate_factor >= 1. 
+ self.interpolate_factor = interpolate_factor + + # xpos + + self.use_xpos = use_xpos + + if not use_xpos: + return + + scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) + self.scale_base = xpos_scale_base + + self.register_buffer('scale', scale, persistent = False) + self.register_buffer('cached_scales', torch.zeros(cache_max_seq_len, dim), persistent = False) + self.cached_scales_seq_len = 0 + + # add apply_rotary_emb as static method + + self.apply_rotary_emb = staticmethod(apply_rotary_emb) + + @property + def device(self): + return self.dummy.device + + def get_axial_freqs( + self, + *dims, + offsets = None + ): + Colon = slice(None) + all_freqs = [] + + # handle offset + + if exists(offsets): + assert len(offsets) == len(dims) + + # get frequencies for each axis + + for ind, dim in enumerate(dims): + + offset = 0 + if exists(offsets): + offset = offsets[ind] + + if self.freqs_for == 'pixel': + pos = torch.linspace(-1, 1, steps = dim, device = self.device) + else: + pos = torch.arange(dim, device = self.device) + + pos = pos + offset + + freqs = self.forward(pos, seq_len = dim) + + all_axis = [None] * len(dims) + all_axis[ind] = Colon + + new_axis_slice = (Ellipsis, *all_axis, Colon) + all_freqs.append(freqs[new_axis_slice]) + + # concat all freqs + + all_freqs = torch.broadcast_tensors(*all_freqs) + return torch.cat(all_freqs, dim = -1) + + def forward( + self, + t, + seq_len: int | None = None, + offset = 0 + ): + should_cache = ( + self.cache_if_possible and + not self.learned_freq and + exists(seq_len) and + self.freqs_for != 'pixel' and + (offset + seq_len) <= self.cache_max_seq_len + ) + + if ( + should_cache and \ + exists(self.cached_freqs) and \ + (offset + seq_len) <= self.cached_freqs_seq_len + ): + return self.cached_freqs[offset:(offset + seq_len)].detach() + + freqs = self.freqs + + freqs = einsum('..., f -> ... f', t.type(freqs.dtype), freqs) + freqs = repeat(freqs, '... n -> ... 
(n r)', r = 2) + + if should_cache and offset == 0: + self.cached_freqs[:seq_len] = freqs.detach() + self.cached_freqs_seq_len = seq_len + + return freqs + +class RotaryEmbeddingBase(nn.Module): + def __init__(self, dim: int, rope_dim: int): + super().__init__() + self.rope = RotaryEmbedding( + dim=dim // rope_dim, + freqs_for="pixel", + max_freq=256, + ) + freqs = self.rope.freqs + del self.rope.freqs + self.rope.register_buffer("freqs", freqs.data) + + def get_axial_freqs(self, *dims): + return self.rope.get_axial_freqs(*dims) + + +class RotaryEmbedding3d(RotaryEmbeddingBase): + def __init__(self, dim: int): + super().__init__(dim, rope_dim=3) + self.mm = False + + def forward( + self, + q: torch.FloatTensor, # b h l d + k: torch.FloatTensor, # b h l d + size: Tuple[int, int, int], + ) -> Tuple[ + torch.FloatTensor, + torch.FloatTensor, + ]: + T, H, W = size + freqs = self.get_axial_freqs(T, H, W) + q = rearrange(q, "b h (T H W) d -> b h T H W d", T=T, H=H, W=W) + k = rearrange(k, "b h (T H W) d -> b h T H W d", T=T, H=H, W=W) + q = apply_rotary_emb(freqs, q.float()).to(q.dtype) + k = apply_rotary_emb(freqs, k.float()).to(k.dtype) + q = rearrange(q, "b h T H W d -> b h (T H W) d") + k = rearrange(k, "b h T H W d -> b h (T H W) d") + return q, k + + +class MMRotaryEmbeddingBase(RotaryEmbeddingBase): + def __init__(self, dim: int, rope_dim: int): + super().__init__(dim, rope_dim) + self.rope = RotaryEmbedding( + dim=dim // rope_dim, + freqs_for="lang", + theta=10000, + ) + freqs = self.rope.freqs + del self.rope.freqs + self.rope.register_buffer("freqs", freqs.data) + self.mm = True + +def slice_at_dim(t, dim_slice: slice, *, dim): + dim += (t.ndim if dim < 0 else 0) + colons = [slice(None)] * t.ndim + colons[dim] = dim_slice + return t[tuple(colons)] + +# rotary embedding helper functions + +def rotate_half(x): + x = rearrange(x, '... (d r) -> ... d r', r = 2) + x1, x2 = x.unbind(dim = -1) + x = torch.stack((-x2, x1), dim = -1) + return rearrange(x, '... d r -> ... 
(d r)') +def exists(val): + return val is not None + +def apply_rotary_emb( + freqs, + t, + start_index = 0, + scale = 1., + seq_dim = -2, + freqs_seq_dim = None +): + dtype = t.dtype + + if not exists(freqs_seq_dim): + if freqs.ndim == 2 or t.ndim == 3: + freqs_seq_dim = 0 + + if t.ndim == 3 or exists(freqs_seq_dim): + seq_len = t.shape[seq_dim] + freqs = slice_at_dim(freqs, slice(-seq_len, None), dim = freqs_seq_dim) + + rot_dim = freqs.shape[-1] + end_index = start_index + rot_dim + + assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}' + + t_left = t[..., :start_index] + t_middle = t[..., start_index:end_index] + t_right = t[..., end_index:] + + t_transformed = (t_middle * freqs.cos() * scale) + (rotate_half(t_middle) * freqs.sin() * scale) + + out = torch.cat((t_left, t_transformed, t_right), dim=-1) + + return out.type(dtype) + +class NaMMRotaryEmbedding3d(MMRotaryEmbeddingBase): + def __init__(self, dim: int): + super().__init__(dim, rope_dim=3) + + def forward( + self, + vid_q: torch.FloatTensor, # L h d + vid_k: torch.FloatTensor, # L h d + vid_shape: torch.LongTensor, # B 3 + txt_q: torch.FloatTensor, # L h d + txt_k: torch.FloatTensor, # L h d + txt_shape: torch.LongTensor, # B 1 + cache: Cache, + ) -> Tuple[ + torch.FloatTensor, + torch.FloatTensor, + torch.FloatTensor, + torch.FloatTensor, + ]: + vid_freqs, txt_freqs = cache( + "mmrope_freqs_3d", + lambda: self.get_freqs(vid_shape, txt_shape), + ) + vid_q = rearrange(vid_q, "L h d -> h L d") + vid_k = rearrange(vid_k, "L h d -> h L d") + vid_q = apply_rotary_emb(vid_freqs, vid_q.float()).to(vid_q.dtype) + vid_k = apply_rotary_emb(vid_freqs, vid_k.float()).to(vid_k.dtype) + vid_q = rearrange(vid_q, "h L d -> L h d") + vid_k = rearrange(vid_k, "h L d -> L h d") + + txt_q = rearrange(txt_q, "L h d -> h L d") + txt_k = rearrange(txt_k, "L h d -> h L d") + txt_q = apply_rotary_emb(txt_freqs, txt_q.float()).to(txt_q.dtype) + txt_k = apply_rotary_emb(txt_freqs, txt_k.float()).to(txt_k.dtype) + txt_q = rearrange(txt_q, "h L d -> L h d") + txt_k = rearrange(txt_k, "h L d -> L h d") + return vid_q, vid_k, txt_q, txt_k + + def get_freqs( + self, + vid_shape: torch.LongTensor, + txt_shape: torch.LongTensor, + ) -> Tuple[ + torch.Tensor, + torch.Tensor, + ]: + vid_freqs = self.get_axial_freqs(1024, 128, 128) + txt_freqs = self.get_axial_freqs(1024) + vid_freq_list, txt_freq_list = [], [] + for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()): + vid_freq = vid_freqs[l : l + f, :h, :w].reshape(-1, vid_freqs.size(-1)) + txt_freq = txt_freqs[:l].repeat(1, 3).reshape(-1, vid_freqs.size(-1)) + vid_freq_list.append(vid_freq) + txt_freq_list.append(txt_freq) + return torch.cat(vid_freq_list, dim=0), torch.cat(txt_freq_list, dim=0) + +class MMModule(nn.Module): + def __init__( + self, + module: Callable[..., nn.Module], + *args, + shared_weights: bool = False, + vid_only: bool = False, + **kwargs, + ): + super().__init__() + self.shared_weights = shared_weights + self.vid_only = vid_only + if self.shared_weights: + assert get_args("vid", args) == get_args("txt", args) + assert get_kwargs("vid", kwargs) == get_kwargs("txt", kwargs) + self.all = module(*get_args("vid", args), **get_kwargs("vid", kwargs)) + else: + self.vid = module(*get_args("vid", args), **get_kwargs("vid", kwargs)) + self.txt = ( + module(*get_args("txt", args), **get_kwargs("txt", kwargs)) + if not vid_only + else None + ) + + def forward( + self, + vid: torch.FloatTensor, + txt: 
torch.FloatTensor, + *args, + **kwargs, + ) -> Tuple[ + torch.FloatTensor, + torch.FloatTensor, + ]: + vid_module = self.vid if not self.shared_weights else self.all + vid = vid_module(vid, *get_args("vid", args), **get_kwargs("vid", kwargs)) + if not self.vid_only: + txt_module = self.txt if not self.shared_weights else self.all + txt = txt_module(txt, *get_args("txt", args), **get_kwargs("txt", kwargs)) + return vid, txt + +def get_na_rope(rope_type: Optional[str], dim: int): + # 7b doesn't use rope + if rope_type is None: + return None + if rope_type == "mmrope3d": + return NaMMRotaryEmbedding3d(dim=dim) + +class NaMMAttention(nn.Module): + def __init__( + self, + vid_dim: int, + txt_dim: int, + heads: int, + head_dim: int, + qk_bias: bool, + qk_norm, + qk_norm_eps: float, + rope_type: Optional[str], + rope_dim: int, + shared_weights: bool, + **kwargs, + ): + super().__init__() + dim = MMArg(vid_dim, txt_dim) + inner_dim = heads * head_dim + qkv_dim = inner_dim * 3 + self.head_dim = head_dim + self.proj_qkv = MMModule( + nn.Linear, dim, qkv_dim, bias=qk_bias, shared_weights=shared_weights + ) + self.proj_out = MMModule(nn.Linear, inner_dim, dim, shared_weights=shared_weights) + self.norm_q = MMModule( + qk_norm, + dim=head_dim, + eps=qk_norm_eps, + elementwise_affine=True, + shared_weights=shared_weights, + ) + self.norm_k = MMModule( + qk_norm, + dim=head_dim, + eps=qk_norm_eps, + elementwise_affine=True, + shared_weights=shared_weights, + ) + + self.rope = get_na_rope(rope_type=rope_type, dim=rope_dim) + + def forward( + self, + vid: torch.FloatTensor, # l c + txt: torch.FloatTensor, # l c + vid_shape: torch.LongTensor, # b 3 + txt_shape: torch.LongTensor, # b 1 + cache: Cache, + ) -> Tuple[ + torch.FloatTensor, + torch.FloatTensor, + ]: + + vid_qkv, txt_qkv = self.proj_qkv(vid, txt) + vid_qkv = rearrange(vid_qkv, "l (o h d) -> l o h d", o=3, d=self.head_dim) + txt_qkv = rearrange(txt_qkv, "l (o h d) -> l o h d", o=3, d=self.head_dim) + + vid_q, vid_k, vid_v = vid_qkv.unbind(1) + txt_q, txt_k, txt_v = txt_qkv.unbind(1) + + vid_q, txt_q = self.norm_q(vid_q, txt_q) + vid_k, txt_k = self.norm_k(vid_k, txt_k) + + if self.rope: + if self.rope.mm: + vid_q, vid_k, txt_q, txt_k = self.rope( + vid_q, vid_k, vid_shape, txt_q, txt_k, txt_shape, cache + ) + else: + vid_q, vid_k = self.rope(vid_q, vid_k, vid_shape, cache) + + vid_len = cache("vid_len", lambda: vid_shape.prod(-1)) + txt_len = cache("txt_len", lambda: txt_shape.prod(-1)) + all_len = cache("all_len", lambda: vid_len + txt_len) + + b = len(vid_len) + vq, vk, vv = [t.view(b, -1, *vid_q.shape[1:]) for t in (vid_q, vid_k, vid_v)] + tq, tk, tv = [t.view(b, -1, *txt_q.shape[1:]) for t in (txt_q, txt_v, txt_v)] + + q = torch.cat([vq, tq], dim=1) + k = torch.cat([vk, tk], dim=1) + v = torch.cat([vv, tv], dim=1) + + _, unconcat = cache("mm_pnp", lambda: concat_idx(vid_len, txt_len)) + + attn = optimized_attention(q, k, v, skip_reshape=True, skip_output_reshape=True) + attn = attn.flatten(0, 1) # to continue working with the rest of the code + + attn = rearrange(attn, "l h d -> l (h d)") + vid_out, txt_out = unconcat(attn) + + vid_out, txt_out = self.proj_out(vid_out, txt_out) + return vid_out, txt_out + +def window( + hid: torch.FloatTensor, # (L c) + hid_shape: torch.LongTensor, # (b n) + window_fn: Callable[[torch.Tensor], List[torch.Tensor]], +): + hid = unflatten(hid, hid_shape) + hid = list(map(window_fn, hid)) + hid_windows = torch.tensor(list(map(len, hid)), device=hid_shape.device) + hid, hid_shape = flatten(list(chain(*hid))) + 
return hid, hid_shape, hid_windows + +def window_idx( + hid_shape: torch.LongTensor, # (b n) + window_fn: Callable[[torch.Tensor], List[torch.Tensor]], +): + hid_idx = torch.arange(hid_shape.prod(-1).sum(), device=hid_shape.device).unsqueeze(-1) + tgt_idx, tgt_shape, tgt_windows = window(hid_idx, hid_shape, window_fn) + tgt_idx = tgt_idx.squeeze(-1) + src_idx = torch.argsort(tgt_idx) + return ( + lambda hid: torch.index_select(hid, 0, tgt_idx), + lambda hid: torch.index_select(hid, 0, src_idx), + tgt_shape, + tgt_windows, + ) + +class NaSwinAttention(NaMMAttention): + def __init__( + self, + *args, + window: Union[int, Tuple[int, int, int]], + window_method: bool, # shifted or not + **kwargs, + ): + super().__init__(*args, **kwargs) + self.window = _triple(window) + self.window_method = window_method + assert all(map(lambda v: isinstance(v, int) and v >= 0, self.window)) + + self.window_op = window_method + + def forward( + self, + vid: torch.FloatTensor, # l c + txt: torch.FloatTensor, # l c + vid_shape: torch.LongTensor, # b 3 + txt_shape: torch.LongTensor, # b 1 + cache: Cache, + ) -> Tuple[ + torch.FloatTensor, + torch.FloatTensor, + ]: + + vid_qkv, txt_qkv = self.proj_qkv(vid, txt) + + # re-org the input seq for window attn + cache_win = cache.namespace(f"{self.window_method}_{self.window}_sd3") + + def make_window(x: torch.Tensor): + t, h, w, _ = x.shape + window_slices = self.window_op((t, h, w), self.window) + return [x[st, sh, sw] for (st, sh, sw) in window_slices] + + window_partition, window_reverse, window_shape, window_count = cache_win( + "win_transform", + lambda: window_idx(vid_shape, make_window), + ) + vid_qkv_win = window_partition(vid_qkv) + + vid_qkv_win = rearrange(vid_qkv_win, "l (o h d) -> l o h d", o=3, d=self.head_dim) + txt_qkv = rearrange(txt_qkv, "l (o h d) -> l o h d", o=3, d=self.head_dim) + + vid_q, vid_k, vid_v = vid_qkv_win.unbind(1) + txt_q, txt_k, txt_v = txt_qkv.unbind(1) + + vid_q, txt_q = self.norm_q(vid_q, txt_q) + vid_k, txt_k = self.norm_k(vid_k, txt_k) + + txt_len = cache("txt_len", lambda: txt_shape.prod(-1)) + + vid_len_win = cache_win("vid_len", lambda: window_shape.prod(-1)) + txt_len_win = cache_win("txt_len", lambda: txt_len.repeat_interleave(window_count)) + all_len_win = cache_win("all_len", lambda: vid_len_win + txt_len_win) + concat_win, unconcat_win = cache_win( + "mm_pnp", lambda: repeat_concat_idx(vid_len_win, txt_len, window_count) + ) + + # window rope + if self.rope: + if self.rope.mm: + # repeat text q and k for window mmrope + _, num_h, _ = txt_q.shape + txt_q_repeat = rearrange(txt_q, "l h d -> l (h d)") + txt_q_repeat = unflatten(txt_q_repeat, txt_shape) + txt_q_repeat = [[x] * n for x, n in zip(txt_q_repeat, window_count)] + txt_q_repeat = list(chain(*txt_q_repeat)) + txt_q_repeat, txt_shape_repeat = flatten(txt_q_repeat) + txt_q_repeat = rearrange(txt_q_repeat, "l (h d) -> l h d", h=num_h) + + txt_k_repeat = rearrange(txt_k, "l h d -> l (h d)") + txt_k_repeat = unflatten(txt_k_repeat, txt_shape) + txt_k_repeat = [[x] * n for x, n in zip(txt_k_repeat, window_count)] + txt_k_repeat = list(chain(*txt_k_repeat)) + txt_k_repeat, _ = flatten(txt_k_repeat) + txt_k_repeat = rearrange(txt_k_repeat, "l (h d) -> l h d", h=num_h) + + vid_q, vid_k, txt_q, txt_k = self.rope( + vid_q, vid_k, window_shape, txt_q_repeat, txt_k_repeat, txt_shape_repeat, cache_win + ) + else: + vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win) + + out = self.attn( + q=concat_win(vid_q, txt_q).bfloat16(), + k=concat_win(vid_k, txt_k).bfloat16(), + 
v=concat_win(vid_v, txt_v).bfloat16(), + cu_seqlens_q=cache_win( + "vid_seqlens_q", lambda: safe_pad_operation(all_len_win.cumsum(0), (1, 0)).int() + ), + cu_seqlens_k=cache_win( + "vid_seqlens_k", lambda: safe_pad_operation(all_len_win.cumsum(0), (1, 0)).int() + ), + max_seqlen_q=cache_win("vid_max_seqlen_q", lambda: all_len_win.max().item()), + max_seqlen_k=cache_win("vid_max_seqlen_k", lambda: all_len_win.max().item()), + ).type_as(vid_q) + + # text pooling + vid_out, txt_out = unconcat_win(out) + + vid_out = rearrange(vid_out, "l h d -> l (h d)") + txt_out = rearrange(txt_out, "l h d -> l (h d)") + vid_out = window_reverse(vid_out) + + vid_out, txt_out = self.proj_out(vid_out, txt_out) + + return vid_out, txt_out + +class MLP(nn.Module): + def __init__( + self, + dim: int, + expand_ratio: int, + ): + super().__init__() + self.proj_in = nn.Linear(dim, dim * expand_ratio) + self.act = nn.GELU("tanh") + self.proj_out = nn.Linear(dim * expand_ratio, dim) + + def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: + x = self.proj_in(x) + x = self.act(x) + x = self.proj_out(x) + return x + + +class SwiGLUMLP(nn.Module): + def __init__( + self, + dim: int, + expand_ratio: int, + multiple_of: int = 256, + ): + super().__init__() + hidden_dim = int(2 * dim * expand_ratio / 3) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + self.proj_in_gate = nn.Linear(dim, hidden_dim, bias=False) + self.proj_out = nn.Linear(hidden_dim, dim, bias=False) + self.proj_in = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: + x = self.proj_out(F.silu(self.proj_in_gate(x)) * self.proj_in(x)) + return x + +def get_mlp(mlp_type: Optional[str] = "normal"): + # 3b and 7b uses different mlp types + if mlp_type == "normal": + return MLP + elif mlp_type == "swiglu": + return SwiGLUMLP + +class NaMMSRTransformerBlock(nn.Module): + def __init__( + self, + *, + vid_dim: int, + txt_dim: int, + emb_dim: int, + heads: int, + head_dim: int, + expand_ratio: int, + norm, + norm_eps: float, + ada, + qk_bias: bool, + qk_norm, + mlp_type: str, + shared_weights: bool, + rope_type: str, + rope_dim: int, + is_last_layer: bool, + **kwargs, + ): + super().__init__() + dim = MMArg(vid_dim, txt_dim) + self.attn_norm = MMModule(norm, dim=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights,) + + self.attn = NaSwinAttention( + vid_dim=vid_dim, + txt_dim=txt_dim, + heads=heads, + head_dim=head_dim, + qk_bias=qk_bias, + qk_norm=qk_norm, + qk_norm_eps=norm_eps, + rope_type=rope_type, + rope_dim=rope_dim, + shared_weights=shared_weights, + window=kwargs.pop("window", None), + window_method=kwargs.pop("window_method", None), + ) + + self.mlp_norm = MMModule(norm, dim=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights, vid_only=is_last_layer) + self.mlp = MMModule( + get_mlp(mlp_type), + dim=dim, + expand_ratio=expand_ratio, + shared_weights=shared_weights, + vid_only=is_last_layer + ) + self.ada = MMModule(ada, dim=dim, emb_dim=emb_dim, layers=["attn", "mlp"], shared_weights=shared_weights, vid_only=is_last_layer) + self.is_last_layer = is_last_layer + + def forward( + self, + vid: torch.FloatTensor, # l c + txt: torch.FloatTensor, # l c + vid_shape: torch.LongTensor, # b 3 + txt_shape: torch.LongTensor, # b 1 + emb: torch.FloatTensor, + cache: Cache, + ) -> Tuple[ + torch.FloatTensor, + torch.FloatTensor, + torch.LongTensor, + torch.LongTensor, + ]: + hid_len = MMArg( + cache("vid_len", lambda: 
vid_shape.prod(-1)), + cache("txt_len", lambda: txt_shape.prod(-1)), + ) + ada_kwargs = { + "emb": emb, + "hid_len": hid_len, + "cache": cache, + "branch_tag": MMArg("vid", "txt"), + } + + vid_attn, txt_attn = self.attn_norm(vid, txt) + vid_attn, txt_attn = self.ada(vid_attn, txt_attn, layer="attn", mode="in", **ada_kwargs) + vid_attn, txt_attn = self.attn(vid_attn, txt_attn, vid_shape, txt_shape, cache) + vid_attn, txt_attn = self.ada(vid_attn, txt_attn, layer="attn", mode="out", **ada_kwargs) + vid_attn, txt_attn = (vid_attn + vid), (txt_attn + txt) + + vid_mlp, txt_mlp = self.mlp_norm(vid_attn, txt_attn) + vid_mlp, txt_mlp = self.ada(vid_mlp, txt_mlp, layer="mlp", mode="in", **ada_kwargs) + vid_mlp, txt_mlp = self.mlp(vid_mlp, txt_mlp) + vid_mlp, txt_mlp = self.ada(vid_mlp, txt_mlp, layer="mlp", mode="out", **ada_kwargs) + vid_mlp, txt_mlp = (vid_mlp + vid_attn), (txt_mlp + txt_attn) + + return vid_mlp, txt_mlp, vid_shape, txt_shape + +class PatchOut(nn.Module): + def __init__( + self, + out_channels: int, + patch_size: Union[int, Tuple[int, int, int]], + dim: int, + ): + super().__init__() + t, h, w = _triple(patch_size) + self.patch_size = t, h, w + self.proj = nn.Linear(dim, out_channels * t * h * w) + + def forward( + self, + vid: torch.Tensor, + ) -> torch.Tensor: + t, h, w = self.patch_size + vid = self.proj(vid) + vid = rearrange(vid, "b T H W (t h w c) -> b c (T t) (H h) (W w)", t=t, h=h, w=w) + if t > 1: + vid = vid[:, :, (t - 1) :] + return vid + +class NaPatchOut(PatchOut): + def forward( + self, + vid: torch.FloatTensor, # l c + vid_shape: torch.LongTensor, + cache: Cache = Cache(disable=True), # for test + ) -> Tuple[ + torch.FloatTensor, + torch.LongTensor, + ]: + cache = cache.namespace("patch") + vid_shape_before_patchify = cache.get("vid_shape_before_patchify") + + t, h, w = self.patch_size + vid = self.proj(vid) + + if not (t == h == w == 1): + vid = unflatten(vid, vid_shape) + for i in range(len(vid)): + vid[i] = rearrange(vid[i], "T H W (t h w c) -> (T t) (H h) (W w) c", t=t, h=h, w=w) + if t > 1 and vid_shape_before_patchify[i, 0] % t != 0: + vid[i] = vid[i][(t - vid_shape_before_patchify[i, 0] % t) :] + vid, vid_shape = flatten(vid) + + return vid, vid_shape + +class PatchIn(nn.Module): + def __init__( + self, + in_channels: int, + patch_size: Union[int, Tuple[int, int, int]], + dim: int, + ): + super().__init__() + t, h, w = _triple(patch_size) + self.patch_size = t, h, w + self.proj = nn.Linear(in_channels * t * h * w, dim) + + def forward( + self, + vid: torch.Tensor, + ) -> torch.Tensor: + t, h, w = self.patch_size + if t > 1: + assert vid.size(2) % t == 1 + vid = torch.cat([vid[:, :, :1]] * (t - 1) + [vid], dim=2) + vid = rearrange(vid, "b c (T t) (H h) (W w) -> b T H W (t h w c)", t=t, h=h, w=w) + vid = self.proj(vid) + return vid + +class NaPatchIn(PatchIn): + def forward( + self, + vid: torch.Tensor, # l c + vid_shape: torch.LongTensor, + cache: Cache = Cache(disable=True), # for test + ) -> torch.Tensor: + cache = cache.namespace("patch") + vid_shape_before_patchify = cache("vid_shape_before_patchify", lambda: vid_shape) + t, h, w = self.patch_size + if not (t == h == w == 1): + vid = unflatten(vid, vid_shape) + for i in range(len(vid)): + if t > 1 and vid_shape_before_patchify[i, 0] % t != 0: + vid[i] = torch.cat([vid[i][:1]] * (t - vid[i].size(0) % t) + [vid[i]], dim=0) + vid[i] = rearrange(vid[i], "(T t) (H h) (W w) c -> T H W (t h w c)", t=t, h=h, w=w) + vid, vid_shape = flatten(vid) + + vid = self.proj(vid) + return vid, vid_shape + +def 
expand_dims(x: torch.Tensor, dim: int, ndim: int): + shape = x.shape + shape = shape[:dim] + (1,) * (ndim - len(shape)) + shape[dim:] + return x.reshape(shape) + + +class AdaSingle(nn.Module): + def __init__( + self, + dim: int, + emb_dim: int, + layers: List[str], + modes: List[str] = ["in", "out"], + ): + assert emb_dim == 6 * dim, "AdaSingle requires emb_dim == 6 * dim" + super().__init__() + self.dim = dim + self.emb_dim = emb_dim + self.layers = layers + for l in layers: + if "in" in modes: + self.register_parameter(f"{l}_shift", nn.Parameter(torch.randn(dim) / dim**0.5)) + self.register_parameter( + f"{l}_scale", nn.Parameter(torch.randn(dim) / dim**0.5 + 1) + ) + if "out" in modes: + self.register_parameter(f"{l}_gate", nn.Parameter(torch.randn(dim) / dim**0.5)) + + def forward( + self, + hid: torch.FloatTensor, # b ... c + emb: torch.FloatTensor, # b d + layer: str, + mode: str, + cache: Cache = Cache(disable=True), + branch_tag: str = "", + hid_len: Optional[torch.LongTensor] = None, # b + ) -> torch.FloatTensor: + idx = self.layers.index(layer) + emb = rearrange(emb, "b (d l g) -> b d l g", l=len(self.layers), g=3)[..., idx, :] + emb = expand_dims(emb, 1, hid.ndim + 1) + + shiftA, scaleA, gateA = emb.unbind(-1) + shiftB, scaleB, gateB = ( + getattr(self, f"{layer}_shift", None), + getattr(self, f"{layer}_scale", None), + getattr(self, f"{layer}_gate", None), + ) + + if mode == "in": + return hid.mul_(scaleA + scaleB).add_(shiftA + shiftB) + if mode == "out": + return hid.mul_(gateA + gateB) + raise NotImplementedError + + +def emb_add(emb1: torch.Tensor, emb2: Optional[torch.Tensor]): + return emb1 if emb2 is None else emb1 + emb2 + + +class TimeEmbedding(nn.Module): + def __init__( + self, + sinusoidal_dim: int, + hidden_dim: int, + output_dim: int, + ): + super().__init__() + self.sinusoidal_dim = sinusoidal_dim + self.proj_in = nn.Linear(sinusoidal_dim, hidden_dim) + self.proj_hid = nn.Linear(hidden_dim, hidden_dim) + self.proj_out = nn.Linear(hidden_dim, output_dim) + self.act = nn.SiLU() + + def forward( + self, + timestep: Union[int, float, torch.IntTensor, torch.FloatTensor], + device: torch.device, + dtype: torch.dtype, + ) -> torch.FloatTensor: + if not torch.is_tensor(timestep): + timestep = torch.tensor([timestep], device=device, dtype=dtype) + if timestep.ndim == 0: + timestep = timestep[None] + + emb = get_timestep_embedding( + timesteps=timestep, + embedding_dim=self.sinusoidal_dim, + flip_sin_to_cos=False, + downscale_freq_shift=0, + ) + emb = emb.to(dtype) + emb = self.proj_in(emb) + emb = self.act(emb) + emb = self.proj_hid(emb) + emb = self.act(emb) + emb = self.proj_out(emb) + return emb + +def flatten( + hid: List[torch.FloatTensor], # List of (*** c) +) -> Tuple[ + torch.FloatTensor, # (L c) + torch.LongTensor, # (b n) +]: + assert len(hid) > 0 + shape = torch.stack([torch.tensor(x.shape[:-1], device=hid[0].device) for x in hid]) + hid = torch.cat([x.flatten(0, -2) for x in hid]) + return hid, shape + + +def unflatten( + hid: torch.FloatTensor, # (L c) or (L ... c) + hid_shape: torch.LongTensor, # (b n) +) -> List[torch.Tensor]: # List of (*** c) or (*** ... 
c) + hid_len = hid_shape.prod(-1) + hid = hid.split(hid_len.tolist()) + hid = [x.unflatten(0, s.tolist()) for x, s in zip(hid, hid_shape)] + return hid + +def repeat( + hid: torch.FloatTensor, # (L c) + hid_shape: torch.LongTensor, # (b n) + pattern: str, + **kwargs: Dict[str, torch.LongTensor], # (b) +) -> Tuple[ + torch.FloatTensor, + torch.LongTensor, +]: + hid = unflatten(hid, hid_shape) + kwargs = [{k: v[i].item() for k, v in kwargs.items()} for i in range(len(hid))] + return flatten([einops.repeat(h, pattern, **a) for h, a in zip(hid, kwargs)]) + +@dataclass +class NaDiTOutput: + vid_sample: torch.Tensor + + +class NaDiT(nn.Module): + + def __init__( + self, + norm_eps, + qk_rope, + num_layers, + mlp_type, + vid_in_channels = 33, + vid_out_channels = 16, + vid_dim = 2560, + txt_in_dim = 5120, + heads = 20, + head_dim = 128, + expand_ratio = 4, + qk_bias = False, + patch_size = [ 1,2,2 ], + shared_qkv: bool = False, + shared_mlp: bool = False, + window_method: Optional[Tuple[str]] = None, + temporal_window_size: int = None, + temporal_shifted: bool = False, + **kwargs, + ): + txt_dim = vid_dim + emb_dim = vid_dim * 6 + block_type = ["mmdit_sr"] * num_layers + window = num_layers * [(4,3,3)] + ada = AdaSingle + norm = RMSNorm + qk_norm = RMSNorm + if isinstance(block_type, str): + block_type = [block_type] * num_layers + elif len(block_type) != num_layers: + raise ValueError("The ``block_type`` list should equal to ``num_layers``.") + super().__init__() + self.vid_in = NaPatchIn( + in_channels=vid_in_channels, + patch_size=patch_size, + dim=vid_dim, + ) + self.txt_in = ( + nn.Linear(txt_in_dim, txt_dim) + if txt_in_dim and txt_in_dim != txt_dim + else nn.Identity() + ) + self.emb_in = TimeEmbedding( + sinusoidal_dim=256, + hidden_dim=max(vid_dim, txt_dim), + output_dim=emb_dim, + ) + + if window is None or isinstance(window[0], int): + window = [window] * num_layers + if window_method is None or isinstance(window_method, str): + window_method = [window_method] * num_layers + if temporal_window_size is None or isinstance(temporal_window_size, int): + temporal_window_size = [temporal_window_size] * num_layers + if temporal_shifted is None or isinstance(temporal_shifted, bool): + temporal_shifted = [temporal_shifted] * num_layers + + self.blocks = nn.ModuleList( + [ + NaMMSRTransformerBlock( + vid_dim=vid_dim, + txt_dim=txt_dim, + emb_dim=emb_dim, + heads=heads, + head_dim=head_dim, + expand_ratio=expand_ratio, + norm=norm, + norm_eps=norm_eps, + ada=ada, + qk_bias=qk_bias, + qk_rope=qk_rope, + qk_norm=qk_norm, + shared_qkv=shared_qkv, + shared_mlp=shared_mlp, + mlp_type=mlp_type, + window=window[i], + window_method=window_method[i], + temporal_window_size=temporal_window_size[i], + temporal_shifted=temporal_shifted[i], + **kwargs, + ) + for i in range(num_layers) + ] + ) + self.vid_out = NaPatchOut( + out_channels=vid_out_channels, + patch_size=patch_size, + dim=vid_dim, + ) + + self.need_txt_repeat = block_type[0] in [ + "mmdit_stwin", + "mmdit_stwin_spatial", + "mmdit_stwin_3d_spatial", + ] + + def set_gradient_checkpointing(self, enable: bool): + self.gradient_checkpointing = enable + + def forward( + self, + vid: torch.FloatTensor, # l c + txt: torch.FloatTensor, # l c + vid_shape: torch.LongTensor, # b 3 + txt_shape: torch.LongTensor, # b 1 + timestep: Union[int, float, torch.IntTensor, torch.FloatTensor], # b + disable_cache: bool = True, # for test + ): + # Text input. 
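+        # For the spatio-temporal window (stwin) block types the text embedding is
+        # repeated once per latent frame; the mmdit_sr blocks built above keep a
+        # single text copy per sample, so need_txt_repeat is False and this branch is skipped here.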
+ if txt_shape.size(-1) == 1 and self.need_txt_repeat: + txt, txt_shape = repeat(txt, txt_shape, "l c -> t l c", t=vid_shape[:, 0]) + # slice vid after patching in when using sequence parallelism + txt = self.txt_in(txt) + + # Video input. + # Sequence parallel slicing is done inside patching class. + vid, vid_shape = self.vid_in(vid, vid_shape) + + # Embedding input. + emb = self.emb_in(timestep, device=vid.device, dtype=vid.dtype) + + # Body + cache = Cache(disable=disable_cache) + for i, block in enumerate(self.blocks): + vid, txt, vid_shape, txt_shape = block( + vid=vid, + txt=txt, + vid_shape=vid_shape, + txt_shape=txt_shape, + emb=emb, + cache=cache, + ) + + vid, vid_shape = self.vid_out(vid, vid_shape, cache) + return NaDiTOutput(vid_sample=vid) diff --git a/comfy/ldm/seedvr/vae.py b/comfy/ldm/seedvr/vae.py new file mode 100644 index 000000000..eb74e9442 --- /dev/null +++ b/comfy/ldm/seedvr/vae.py @@ -0,0 +1,1260 @@ +from contextlib import nullcontext +from typing import Literal, Optional, Tuple +import torch +import torch.nn as nn +import torch.nn.functional as F +from diffusers.models.attention_processor import Attention +from diffusers.models.upsampling import Upsample2D +from einops import rearrange + +from model import safe_pad_operation +from comfy.ldm.hunyuan3d.vae import DiagonalGaussianDistribution + +class SpatialNorm(nn.Module): + def __init__( + self, + f_channels: int, + zq_channels: int, + ): + super().__init__() + self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True) + self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) + self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor: + f_size = f.shape[-2:] + zq = F.interpolate(zq, size=f_size, mode="nearest") + norm_f = self.norm_layer(f) + new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) + return new_f + +def causal_norm_wrapper(norm_layer: nn.Module, x: torch.Tensor) -> torch.Tensor: + input_dtype = x.dtype + if isinstance(norm_layer, (nn.LayerNorm, nn.RMSNorm)): + if x.ndim == 4: + x = rearrange(x, "b c h w -> b h w c") + x = norm_layer(x) + x = rearrange(x, "b h w c -> b c h w") + return x.to(input_dtype) + if x.ndim == 5: + x = rearrange(x, "b c t h w -> b t h w c") + x = norm_layer(x) + x = rearrange(x, "b t h w c -> b c t h w") + return x.to(input_dtype) + if isinstance(norm_layer, (nn.GroupNorm, nn.BatchNorm2d, nn.SyncBatchNorm)): + if x.ndim <= 4: + return norm_layer(x).to(input_dtype) + if x.ndim == 5: + t = x.size(2) + x = rearrange(x, "b c t h w -> (b t) c h w") + memory_occupy = x.numel() * x.element_size() / 1024**3 + if isinstance(norm_layer, nn.GroupNorm) and memory_occupy > float("inf"): # TODO: this may be set dynamically from the vae + num_chunks = min(4 if x.element_size() == 2 else 2, norm_layer.num_groups) + assert norm_layer.num_groups % num_chunks == 0 + num_groups_per_chunk = norm_layer.num_groups // num_chunks + + x = list(x.chunk(num_chunks, dim=1)) + weights = norm_layer.weight.chunk(num_chunks, dim=0) + biases = norm_layer.bias.chunk(num_chunks, dim=0) + for i, (w, b) in enumerate(zip(weights, biases)): + x[i] = F.group_norm(x[i], num_groups_per_chunk, w, b, norm_layer.eps) + x[i] = x[i].to(input_dtype) + x = torch.cat(x, dim=1) + else: + x = norm_layer(x) + x = rearrange(x, "(b t) c h w -> b c t h w", t=t) + return x.to(input_dtype) + raise NotImplementedError + +def safe_interpolate_operation(x, size=None, 
scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None): + """Safe interpolate operation that handles Half precision for problematic modes""" + # Modes qui peuvent causer des problèmes avec Half precision + problematic_modes = ['bilinear', 'bicubic', 'trilinear'] + + if mode in problematic_modes: + try: + return F.interpolate( + x, + size=size, + scale_factor=scale_factor, + mode=mode, + align_corners=align_corners, + recompute_scale_factor=recompute_scale_factor + ) + except RuntimeError as e: + if ("not implemented for 'Half'" in str(e) or + "compute_indices_weights" in str(e)): + original_dtype = x.dtype + return F.interpolate( + x.float(), + size=size, + scale_factor=scale_factor, + mode=mode, + align_corners=align_corners, + recompute_scale_factor=recompute_scale_factor + ).to(original_dtype) + else: + raise e + else: + # Pour 'nearest' et autres modes compatibles, pas de fix nécessaire + return F.interpolate( + x, + size=size, + scale_factor=scale_factor, + mode=mode, + align_corners=align_corners, + recompute_scale_factor=recompute_scale_factor + ) + +_receptive_field_t = Literal["half", "full"] + +class InflatedCausalConv3d(nn.Conv3d): + def __init__( + self, + *args, + inflation_mode, + **kwargs, + ): + self.inflation_mode = inflation_mode + self.memory = None + super().__init__(*args, **kwargs) + self.temporal_padding = self.padding[0] + self.padding = (0, *self.padding[1:]) + self.memory_limit = float("inf") + + def forward( + self, + input, + ): + return super().forward(input) + + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): + # wirdly inflation_mode is pad, which would cause an assert error + #if self.inflation_mode != "none": + # state_dict = modify_state_dict( + # self, + # state_dict, + # prefix, + # inflate_weight_fn=inflate_weight, + # inflate_bias_fn=inflate_bias, + # ) + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + (strict and self.inflation_mode == "none"), + missing_keys, + unexpected_keys, + error_msgs, + ) + +class Upsample3D(nn.Module): + + def __init__( + self, + channels, + out_channels = None, + inflation_mode = "tail", + temporal_up: bool = False, + spatial_up: bool = True, + slicing: bool = False, + interpolate = True, + name: str = "conv", + use_conv_transpose = False, + use_conv: bool = False, + padding = 1, + bias = True, + kernel_size = None, + **kwargs, + ): + super().__init__() + self.interpolate = interpolate + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv_transpose = use_conv_transpose + self.use_conv = use_conv + self.name = name + + self.conv = None + if use_conv_transpose: + if kernel_size is None: + kernel_size = 4 + self.conv = nn.ConvTranspose2d( + channels, self.out_channels, kernel_size=kernel_size, stride=2, padding=padding, bias=bias + ) + elif use_conv: + if kernel_size is None: + kernel_size = 3 + self.conv = nn.Conv2d(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias) + + conv = self.conv if self.name == "conv" else self.Conv2d_0 + + assert type(conv) is not nn.ConvTranspose2d + # Note: lora_layer is not passed into constructor in the original implementation. + # So we make a simplification. 
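+        # The 2D conv is replaced by an inflated causal 3D conv (3x3x3, padding 1)
+        # so that upsampling does not leak information from future frames.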
+ conv = InflatedCausalConv3d( + self.channels, + self.out_channels, + 3, + padding=1, + inflation_mode=inflation_mode, + ) + + self.temporal_up = temporal_up + self.spatial_up = spatial_up + self.temporal_ratio = 2 if temporal_up else 1 + self.spatial_ratio = 2 if spatial_up else 1 + self.slicing = slicing + + assert not self.interpolate + # [Override] MAGViT v2 implementation + if not self.interpolate: + upscale_ratio = (self.spatial_ratio**2) * self.temporal_ratio + self.upscale_conv = nn.Conv3d( + self.channels, self.channels * upscale_ratio, kernel_size=1, padding=0 + ) + identity = ( + torch.eye(self.channels) + .repeat(upscale_ratio, 1) + .reshape_as(self.upscale_conv.weight) + ) + self.upscale_conv.weight.data.copy_(identity) + + if self.name == "conv": + self.conv = conv + else: + self.Conv2d_0 = conv + + self.norm = False + + def forward( + self, + hidden_states: torch.FloatTensor, + **kwargs, + ) -> torch.FloatTensor: + assert hidden_states.shape[1] == self.channels + + if hasattr(self, "norm") and self.norm is not None: + # [Overridden] change to causal norm. + hidden_states = causal_norm_wrapper(self.norm, hidden_states) + + if self.use_conv_transpose: + return self.conv(hidden_states) + + if self.slicing: + split_size = hidden_states.size(2) // 2 + hidden_states = list( + hidden_states.split([split_size, hidden_states.size(2) - split_size], dim=2) + ) + else: + hidden_states = [hidden_states] + + for i in range(len(hidden_states)): + hidden_states[i] = self.upscale_conv(hidden_states[i]) + hidden_states[i] = rearrange( + hidden_states[i], + "b (x y z c) f h w -> b c (f z) (h x) (w y)", + x=self.spatial_ratio, + y=self.spatial_ratio, + z=self.temporal_ratio, + ) + + if not self.slicing: + hidden_states = hidden_states[0] + + if self.use_conv: + if self.name == "conv": + hidden_states = self.conv(hidden_states) + else: + hidden_states = self.Conv2d_0(hidden_states) + + if not self.slicing: + return hidden_states + else: + return torch.cat(hidden_states, dim=2) + + +class Downsample3D(nn.Module): + """A 3D downsampling layer with an optional convolution.""" + + def __init__( + self, + channels, + out_channels = None, + inflation_mode = "tail", + spatial_down: bool = False, + temporal_down: bool = False, + name: str = "conv", + padding = 1, + **kwargs, + ): + super().__init__() + self.padding = padding + self.name = name + self.channels = channels + self.out_channels = out_channels or channels + conv = self.conv + self.temporal_down = temporal_down + self.spatial_down = spatial_down + + self.temporal_ratio = 2 if temporal_down else 1 + self.spatial_ratio = 2 if spatial_down else 1 + + self.temporal_kernel = 3 if temporal_down else 1 + self.spatial_kernel = 3 if spatial_down else 1 + + if type(conv) in [nn.Conv2d]: + # Note: lora_layer is not passed into constructor in the original implementation. + # So we make a simplification. 
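+            # Inflate the 2D downsampling conv into a causal 3D conv; stride is 2 only
+            # on the axes that are actually downsampled (temporal_down / spatial_down).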
+ conv = InflatedCausalConv3d( + self.channels, + self.out_channels, + kernel_size=(self.temporal_kernel, self.spatial_kernel, self.spatial_kernel), + stride=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio), + padding=( + 1 if self.temporal_down else 0, + self.padding if self.spatial_down else 0, + self.padding if self.spatial_down else 0, + ), + inflation_mode=inflation_mode, + ) + elif type(conv) is nn.AvgPool2d: + assert self.channels == self.out_channels + conv = nn.AvgPool3d( + kernel_size=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio), + stride=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio), + ) + else: + raise NotImplementedError + + if self.name == "conv": + self.Conv2d_0 = conv + self.conv = conv + else: + self.conv = conv + + def forward( + self, + hidden_states: torch.FloatTensor, + **kwargs, + ) -> torch.FloatTensor: + + assert hidden_states.shape[1] == self.channels + + if hasattr(self, "norm") and self.norm is not None: + # [Overridden] change to causal norm. + hidden_states = causal_norm_wrapper(self.norm, hidden_states) + + if self.use_conv and self.padding == 0 and self.spatial_down: + pad = (0, 1, 0, 1) + hidden_states = safe_pad_operation(hidden_states, pad, mode="constant", value=0) + + assert hidden_states.shape[1] == self.channels + + hidden_states = self.conv(hidden_states) + + return hidden_states + + +class ResnetBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + conv_shortcut: bool = False, + dropout: float = 0.0, + temb_channels: int = 512, + groups: int = 32, + groups_out: Optional[int] = None, + eps: float = 1e-6, + non_linearity: str = "swish", + time_embedding_norm: str = "default", + output_scale_factor: float = 1.0, + skip_time_act: bool = False, + use_in_shortcut: Optional[bool] = None, + up: bool = False, + down: bool = False, + conv_shortcut_bias: bool = True, + conv_2d_out_channels: Optional[int] = None, + inflation_mode = "tail", + time_receptive_field: _receptive_field_t = "half", + slicing: bool = False, + **kwargs, + ): + super().__init__() + self.up = up + self.down = down + self.use_in_shortcut = use_in_shortcut + self.output_scale_factor = output_scale_factor + self.skip_time_act = skip_time_act + self.nonlinearity = nn.SiLU() + if temb_channels is not None: + self.time_emb_proj = nn.Linear(temb_channels, out_channels) + else: + self.time_emb_proj = None + self.conv1 = InflatedCausalConv3d( + self.in_channels, + self.out_channels, + kernel_size=(1, 3, 3) if time_receptive_field == "half" else (3, 3, 3), + stride=1, + padding=(0, 1, 1) if time_receptive_field == "half" else (1, 1, 1), + inflation_mode=inflation_mode, + ) + + self.conv2 = InflatedCausalConv3d( + self.out_channels, + self.conv2.out_channels, + kernel_size=3, + stride=1, + padding=1, + inflation_mode=inflation_mode, + ) + + if self.up: + self.upsample = Upsample3D( + self.in_channels, + use_conv=False, + inflation_mode=inflation_mode, + slicing=slicing, + ) + elif self.down: + self.downsample = Downsample3D( + self.in_channels, + use_conv=False, + padding=1, + name="op", + inflation_mode=inflation_mode, + ) + + if self.use_in_shortcut: + self.conv_shortcut = InflatedCausalConv3d( + self.in_channels, + self.conv_shortcut.out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=(self.conv_shortcut.bias is not None), + inflation_mode=inflation_mode, + ) + + def forward( + self, input_tensor, temb, **kwargs + ): + hidden_states = input_tensor + + hidden_states = 
causal_norm_wrapper(self.norm1, hidden_states) + + hidden_states = self.nonlinearity(hidden_states) + + if self.upsample is not None: + if hidden_states.shape[0] >= 64: + input_tensor = input_tensor.contiguous() + hidden_states = hidden_states.contiguous() + input_tensor = self.upsample(input_tensor) + hidden_states = self.upsample(hidden_states) + elif self.downsample is not None: + input_tensor = self.downsample(input_tensor) + hidden_states = self.downsample(hidden_states) + + hidden_states = self.conv1(hidden_states) + + if self.time_emb_proj is not None: + if not self.skip_time_act: + temb = self.nonlinearity(temb) + temb = self.time_emb_proj(temb)[:, :, None, None] + + if temb is not None: + hidden_states = hidden_states + temb + + hidden_states = causal_norm_wrapper(self.norm2, hidden_states) + + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = (input_tensor + hidden_states) / self.output_scale_factor + + return output_tensor + + +class DownEncoderBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + downsample_padding: int = 1, + inflation_mode = "tail", + time_receptive_field: _receptive_field_t = "half", + temporal_down: bool = True, + spatial_down: bool = True, + ): + super().__init__() + resnets = [] + temporal_modules = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + # [Override] Replace module. + ResnetBlock3D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + ) + ) + temporal_modules.append(nn.Identity()) + + self.resnets = nn.ModuleList(resnets) + self.temporal_modules = nn.ModuleList(temporal_modules) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + # [Override] Replace module. 
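+                    # temporal_down / spatial_down decide which axes get stride-2
+                    # downsampling in this stage; see Downsample3D above.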
+ Downsample3D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + temporal_down=temporal_down, + spatial_down=spatial_down, + inflation_mode=inflation_mode, + ) + ] + ) + else: + self.downsamplers = None + + def forward( + self, + hidden_states: torch.FloatTensor, + **kwargs, + ) -> torch.FloatTensor: + for resnet, temporal in zip(self.resnets, self.temporal_modules): + hidden_states = resnet(hidden_states, temb=None) + hidden_states = temporal(hidden_states) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states + + +class UpDecoderBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", # default, spatial + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + temb_channels: Optional[int] = None, + inflation_mode = "tail", + time_receptive_field: _receptive_field_t = "half", + temporal_up: bool = True, + spatial_up: bool = True, + slicing: bool = False, + ): + super().__init__() + resnets = [] + temporal_modules = [] + + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + resnets.append( + # [Override] Replace module. + ResnetBlock3D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + slicing=slicing, + ) + ) + + temporal_modules.append(nn.Identity()) + + self.resnets = nn.ModuleList(resnets) + self.temporal_modules = nn.ModuleList(temporal_modules) + + if add_upsample: + # [Override] Replace module & use learnable upsample + self.upsamplers = nn.ModuleList( + [ + Upsample3D( + out_channels, + use_conv=True, + out_channels=out_channels, + temporal_up=temporal_up, + spatial_up=spatial_up, + interpolate=False, + inflation_mode=inflation_mode, + slicing=slicing, + ) + ] + ) + else: + self.upsamplers = None + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + for resnet, temporal in zip(self.resnets, self.temporal_modules): + hidden_states = resnet(hidden_states, temb=None) + hidden_states = temporal(hidden_states) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class UNetMidBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", # default, spatial + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + add_attention: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + inflation_mode = "tail", + time_receptive_field: _receptive_field_t = "half", + ): + super().__init__() + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + self.add_attention = add_attention + + # there is always at least one resnet + resnets = [ + 
# [Override] Replace module. + ResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + ) + ] + attentions = [] + + if attention_head_dim is None: + print( + f"It is not recommend to pass `attention_head_dim=None`. " + f"Defaulting `attention_head_dim` to `in_channels`: {in_channels}." + ) + attention_head_dim = in_channels + + for _ in range(num_layers): + if self.add_attention: + attentions.append( + Attention( + in_channels, + heads=in_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=( + resnet_groups if resnet_time_scale_shift == "default" else None + ), + spatial_norm_dim=( + temb_channels if resnet_time_scale_shift == "spatial" else None + ), + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + else: + attentions.append(None) + + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states, temb=None): + video_length, frame_height, frame_width = hidden_states.size()[-3:] + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if attn is not None: + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + hidden_states = attn(hidden_states, temb=temb) + hidden_states = rearrange( + hidden_states, "(b f) c h w -> b c f h w", f=video_length + ) + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class Encoder3D(nn.Module): + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str, ...] = ("DownEncoderBlock3D",), + block_out_channels: Tuple[int, ...] 
= (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + double_z: bool = True, + mid_block_add_attention=True, + # [Override] add extra_cond_dim, temporal down num + temporal_down_num: int = 2, + extra_cond_dim: int = None, + gradient_checkpoint: bool = False, + inflation_mode = "tail", + time_receptive_field: _receptive_field_t = "half", + ): + super().__init__() + self.layers_per_block = layers_per_block + self.temporal_down_num = temporal_down_num + + self.conv_in = InflatedCausalConv3d( + in_channels, + block_out_channels[0], + kernel_size=3, + stride=1, + padding=1, + inflation_mode=inflation_mode, + ) + + self.mid_block = None + self.down_blocks = nn.ModuleList([]) + self.extra_cond_dim = extra_cond_dim + + self.conv_extra_cond = nn.ModuleList([]) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + # [Override] to support temporal down block design + is_temporal_down_block = i >= len(block_out_channels) - self.temporal_down_num - 1 + # Note: take the last ones + + assert down_block_type == "DownEncoderBlock3D" + + down_block = DownEncoderBlock3D( + num_layers=self.layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + add_downsample=not is_final_block, + resnet_eps=1e-6, + downsample_padding=0, + # Note: Don't know why set it as 0 + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + temporal_down=is_temporal_down_block, + spatial_down=True, + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + ) + self.down_blocks.append(down_block) + + def zero_module(module): + # Zero out the parameters of a module and return it. 
+ for p in module.parameters(): + p.detach().zero_() + return module + + self.conv_extra_cond.append( + zero_module( + nn.Conv3d(extra_cond_dim, output_channel, kernel_size=1, stride=1, padding=0) + ) + if self.extra_cond_dim is not None and self.extra_cond_dim > 0 + else None + ) + + # mid + self.mid_block = UNetMidBlock3D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default", + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + temb_channels=None, + add_attention=mid_block_add_attention, + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + ) + + # out + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6 + ) + self.conv_act = nn.SiLU() + + conv_out_channels = 2 * out_channels if double_z else out_channels + self.conv_out = InflatedCausalConv3d( + block_out_channels[-1], conv_out_channels, 3, padding=1, inflation_mode=inflation_mode + ) + + self.gradient_checkpointing = gradient_checkpoint + + def forward( + self, + sample: torch.FloatTensor, + extra_cond=None, + ) -> torch.FloatTensor: + r"""The forward method of the `Encoder` class.""" + sample = self.conv_in(sample) + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + # down + # [Override] add extra block and extra cond + for down_block, extra_block in zip(self.down_blocks, self.conv_extra_cond): + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(down_block), sample, use_reentrant=False + ) + if extra_block is not None: + sample = sample + safe_interpolate_operation(extra_block(extra_cond), size=sample.shape[2:]) + + # middle + sample = self.mid_block(sample) + + else: + # down + # [Override] add extra block and extra cond + for down_block, extra_block in zip(self.down_blocks, self.conv_extra_cond): + sample = down_block(sample) + if extra_block is not None: + sample = sample + safe_interpolate_operation(extra_block(extra_cond), size=sample.shape[2:]) + + # middle + sample = self.mid_block(sample) + + # post-process + sample = causal_norm_wrapper(self.conv_norm_out, sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + return sample + + +class Decoder3D(nn.Module): + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + up_block_types: Tuple[str, ...] = ("UpDecoderBlock3D",), + block_out_channels: Tuple[int, ...] 
= (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + norm_type: str = "group", # group, spatial + mid_block_add_attention=True, + # [Override] add temporal up block + inflation_mode = "tail", + time_receptive_field: _receptive_field_t = "half", + temporal_up_num: int = 2, + slicing_up_num: int = 0, + gradient_checkpoint: bool = False, + ): + super().__init__() + self.layers_per_block = layers_per_block + self.temporal_up_num = temporal_up_num + + self.conv_in = InflatedCausalConv3d( + in_channels, + block_out_channels[-1], + kernel_size=3, + stride=1, + padding=1, + inflation_mode=inflation_mode, + ) + + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + temb_channels = in_channels if norm_type == "spatial" else None + + # mid + self.mid_block = UNetMidBlock3D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default" if norm_type == "group" else norm_type, + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + temb_channels=temb_channels, + add_attention=mid_block_add_attention, + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + print(f"slicing_up_num: {slicing_up_num}") + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + + is_final_block = i == len(block_out_channels) - 1 + is_temporal_up_block = i < self.temporal_up_num + is_slicing_up_block = i >= len(block_out_channels) - slicing_up_num + # Note: Keep symmetric + + assert up_block_type == "UpDecoderBlock3D" + up_block = UpDecoderBlock3D( + num_layers=self.layers_per_block + 1, + in_channels=prev_output_channel, + out_channels=output_channel, + add_upsample=not is_final_block, + resnet_eps=1e-6, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + resnet_time_scale_shift=norm_type, + temb_channels=temb_channels, + temporal_up=is_temporal_up_block, + slicing=is_slicing_up_block, + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_type == "spatial": + self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels) + else: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6 + ) + self.conv_act = nn.SiLU() + self.conv_out = InflatedCausalConv3d( + block_out_channels[0], out_channels, 3, padding=1, inflation_mode=inflation_mode + ) + + self.gradient_checkpointing = gradient_checkpoint + + # Note: Just copy from Decoder. 
+ def forward( + self, + sample: torch.FloatTensor, + latent_embeds: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + + sample = self.conv_in(sample) + + upscale_dtype = next(iter(self.up_blocks.parameters())).dtype + # middle + sample = self.mid_block(sample, latent_embeds) + sample = sample.to(upscale_dtype) + + # up + for up_block in self.up_blocks: + sample = up_block(sample, latent_embeds) + + # post-process + sample = causal_norm_wrapper(self.conv_norm_out, sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + return sample + +class VideoAutoencoderKL(nn.Module): + """ + We simply inherit the model code from diffusers + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str] = ("DownEncoderBlock3D",), + up_block_types: Tuple[str] = ("UpDecoderBlock3D",), + block_out_channels: Tuple[int] = (64,), + layers_per_block: int = 1, + act_fn: str = "silu", + latent_channels: int = 4, + norm_num_groups: int = 32, + attention: bool = True, + temporal_scale_num: int = 2, + slicing_up_num: int = 0, + gradient_checkpoint: bool = False, + inflation_mode = "tail", + time_receptive_field: _receptive_field_t = "full", + use_quant_conv: bool = True, + use_post_quant_conv: bool = True, + *args, + **kwargs, + ): + extra_cond_dim = kwargs.pop("extra_cond_dim") if "extra_cond_dim" in kwargs else None + super().__init__() + + # pass init params to Encoder + self.encoder = Encoder3D( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + double_z=True, + extra_cond_dim=extra_cond_dim, + # [Override] add temporal_down_num parameter + temporal_down_num=temporal_scale_num, + gradient_checkpoint=gradient_checkpoint, + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + ) + + # pass init params to Decoder + self.decoder = Decoder3D( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + norm_num_groups=norm_num_groups, + act_fn=act_fn, + # [Override] add temporal_up_num parameter + temporal_up_num=temporal_scale_num, + slicing_up_num=slicing_up_num, + gradient_checkpoint=gradient_checkpoint, + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + ) + + self.quant_conv = ( + InflatedCausalConv3d( + in_channels=2 * latent_channels, + out_channels=2 * latent_channels, + kernel_size=1, + inflation_mode=inflation_mode, + ) + if use_quant_conv + else None + ) + self.post_quant_conv = ( + InflatedCausalConv3d( + in_channels=latent_channels, + out_channels=latent_channels, + kernel_size=1, + inflation_mode=inflation_mode, + ) + if use_post_quant_conv + else None + ) + + # A hacky way to remove attention. 
+ if not attention: + self.encoder.mid_block.attentions = torch.nn.ModuleList([None]) + self.decoder.mid_block.attentions = torch.nn.ModuleList([None]) + + def encode(self, x: torch.FloatTensor, return_dict: bool = True): + h = self.slicing_encode(x) + posterior = DiagonalGaussianDistribution(h).sample() + + if not return_dict: + return (posterior,) + + return posterior + + def decode( + self, z: torch.Tensor, return_dict: bool = True + ): + decoded = self.slicing_decode(z) + + if not return_dict: + return (decoded,) + + return decoded + + def _encode( + self, x: torch.Tensor + ) -> torch.Tensor: + _x = x.to(self.device) + h = self.encoder(_x,) + if self.quant_conv is not None: + output = self.quant_conv(h) + else: + output = h + return output.to(x.device) + + def _decode( + self, z: torch.Tensor + ) -> torch.Tensor: + _z = z.to(self.device) + if self.post_quant_conv is not None: + _z = self.post_quant_conv(_z) + output = self.decoder(_z) + return output.to(z.device) + + def slicing_encode(self, x: torch.Tensor) -> torch.Tensor: + return self._encode(x) + + def slicing_decode(self, z: torch.Tensor) -> torch.Tensor: + return self._decode(z) + + def tiled_encode(self, x: torch.Tensor, **kwargs) -> torch.Tensor: + raise NotImplementedError + + def tiled_decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor: + raise NotImplementedError + + def forward( + self, x: torch.FloatTensor, mode: Literal["encode", "decode", "all"] = "all", **kwargs + ): + # x: [b c t h w] + if mode == "encode": + h = self.encode(x) + return h.latent_dist + elif mode == "decode": + h = self.decode(x) + return h.sample + else: + h = self.encode(x) + h = self.decode(h.latent_dist.mode()) + return h.sample + + def load_state_dict(self, state_dict, strict=False): + # Newer version of diffusers changed the model keys, + # causing incompatibility with old checkpoints. + # They provided a method for conversion. + # We call conversion before loading state_dict. 
+        convert_deprecated_attention_blocks = getattr(
+            self, "_convert_deprecated_attention_blocks", None
+        )
+        if callable(convert_deprecated_attention_blocks):
+            convert_deprecated_attention_blocks(state_dict)
+        return super().load_state_dict(state_dict, strict)
+
+
+class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
+    def __init__(
+        self,
+        *args,
+        spatial_downsample_factor=8,
+        temporal_downsample_factor=4,
+        freeze_encoder=True,
+        **kwargs,
+    ):
+        self.spatial_downsample_factor = spatial_downsample_factor
+        self.temporal_downsample_factor = temporal_downsample_factor
+        self.freeze_encoder = freeze_encoder
+        super().__init__(*args, **kwargs)
+
+    def forward(self, x: torch.FloatTensor):
+        with torch.no_grad() if self.freeze_encoder else nullcontext():
+            z, p = self.encode(x)
+        x = self.decode(z).sample
+        return x, z, p
+
+    def encode(self, x: torch.FloatTensor):
+        if x.ndim == 4:
+            x = x.unsqueeze(2)
+        p = super().encode(x).latent_dist
+        z = p.sample().squeeze(2)
+        return z, p
+
+    def decode(self, z: torch.FloatTensor):
+        if z.ndim == 4:
+            z = z.unsqueeze(2)
+        x = super().decode(z).sample.squeeze(2)
+        return x
+
+    def preprocess(self, x: torch.Tensor):
+        # x should be in [B, C, T, H, W] or [B, C, H, W]
+        assert x.ndim == 4 or x.size(2) % 4 == 1
+        return x
+
+    def postprocess(self, x: torch.Tensor):
+        # x should be in [B, C, T, H, W] or [B, C, H, W]
+        return x
+
+    def set_memory_limit(self, conv_max_mem: Optional[float], norm_max_mem: Optional[float]):
+        # TODO
+        #set_norm_limit(norm_max_mem)
+        for m in self.modules():
+            if isinstance(m, InflatedCausalConv3d):
+                m.set_memory_limit(conv_max_mem if conv_max_mem is not None else float("inf"))
\ No newline at end of file
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 4392355ea..bbab8627a 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -42,6 +42,7 @@ import comfy.ldm.hidream.model
 import comfy.ldm.chroma.model
 import comfy.ldm.ace.model
 import comfy.ldm.omnigen.omnigen2
+import comfy.ldm.seedvr.model
 
 import comfy.model_management
 import comfy.patcher_extension
@@ -793,6 +794,11 @@ class HunyuanDiT(BaseModel):
         out['image_meta_size'] = comfy.conds.CONDRegular(torch.FloatTensor([[height, width, target_height, target_width, 0, 0]]))
         return out
+
+class SeedVR2(BaseModel):
+    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+        super().__init__(model_config, model_type, device, comfy.ldm.seedvr.model.NaDiT)
+    # TODO: extra_conds may need to be added
 
 
 class PixArt(BaseModel):
     def __init__(self, model_config, model_type=ModelType.EPS, device=None):
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index 18232ade3..600c089fa 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -341,6 +341,17 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["axes_dims"] = [32, 32, 32]
         dit_config["axes_lens"] = [300, 512, 512]
         return dit_config
+
+    elif "{}blocks.31.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 3b
+        dit_config = {}
+        dit_config["image_model"] = "seedvr2"
+        dit_config["vid_dim"] = 2560
+        dit_config["heads"] = 20
+        dit_config["num_layers"] = 32
+        dit_config["norm_eps"] = 1.0e-05
+        dit_config["qk_rope"] = None
+        dit_config["mlp_type"] = "swiglu"
+        return dit_config
 
     if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
         dit_config = {}
diff --git a/comfy/sd.py b/comfy/sd.py
index 5b95cf75a..79b17073f 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -15,6 +15,7 @@ import comfy.ldm.lightricks.vae.causal_video_autoencoder
 import comfy.ldm.cosmos.vae
 import comfy.ldm.wan.vae
 import comfy.ldm.hunyuan3d.vae
+import comfy.ldm.seedvr.vae
 import comfy.ldm.ace.vae.music_dcae_pipeline
 import yaml
 import math
@@ -391,6 +392,19 @@ class VAE:
                 self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32)
                 self.downscale_index_formula = (8, 32, 32)
                 self.working_dtypes = [torch.bfloat16, torch.float32]
+
+            elif "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd: # seedvr2
+                self.first_stage_model = comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper()
+                ddconfig["conv3d"] = True
+                ddconfig["time_compress"] = 4
+                self.memory_used_decode = lambda shape, dtype: (2000 * shape[2] * shape[3] * shape[4] * (4 * 8 * 8)) * model_management.dtype_size(dtype)
+                self.memory_used_encode = lambda shape, dtype: (1000 * max(shape[2], 5) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
+                self.working_dtypes = [torch.bfloat16, torch.float32]
+                self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
+                self.downscale_index_formula = (4, 8, 8)
+                self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
+                self.upscale_index_formula = (4, 8, 8)
+
             elif "decoder.conv_in.conv.weight" in sd:
                 ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
                 ddconfig["conv3d"] = True
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index 2669ca01e..2301b1188 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -1153,6 +1153,21 @@ class Chroma(supported_models_base.BASE):
         pref = self.text_encoder_key_prefix[0]
         t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
         return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect))
+
+class SeedVR2(supported_models_base.BASE):
+    unet_config = {
+        "image_model": "seedvr2"
+    }
+    latent_format = comfy.latent_formats.SeedVR2
+
+    vae_key_prefix = ["vae."]
+    supported_inference_dtypes = [torch.bfloat16, torch.float32]
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.SeedVR2(self, device=device)
+        return out
+    def clip_target(self, state_dict={}):
+        return None
 
 class ACEStep(supported_models_base.BASE):
     unet_config = {
@@ -1217,6 +1232,6 @@ class Omnigen2(supported_models_base.BASE):
         return supported_models_base.ClipTarget(comfy.text_encoders.omnigen2.LuminaTokenizer, comfy.text_encoders.omnigen2.te(**hunyuan_detect))
 
 
-models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2]
+models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, 
FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2, SeedVR2] models += [SVD_img2vid]
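
The downscale_ratio / upscale_ratio lambdas added to comfy/sd.py above encode SeedVR2's causal 4x temporal / 8x spatial VAE compression, matching the temporal_downsample_factor=4 and spatial_downsample_factor=8 defaults on VideoAutoencoderKLWrapper. A minimal self-contained sketch of the frame and size bookkeeping those formulas imply (plain Python; the helper names are illustrative and not part of this patch):

    import math

    def seedvr2_latent_shape(frames, height, width):
        # Mirrors downscale_ratio from the sd.py hunk:
        # time: max(0, floor((frames + 3) / 4)); space: divide by 8.
        return max(0, math.floor((frames + 3) / 4)), height // 8, width // 8

    def seedvr2_pixel_frames(latent_frames):
        # Mirrors upscale_ratio: max(0, latent_frames * 4 - 3). The first latent
        # frame decodes to a single pixel frame (causal VAE), which is why
        # VideoAutoencoderKLWrapper.preprocess asserts frames % 4 == 1.
        return max(0, latent_frames * 4 - 3)

    # Example: a 17-frame 720x1280 clip maps to a (5, 90, 160) latent and back to 17 frames.
    assert seedvr2_latent_shape(17, 720, 1280) == (5, 90, 160)
    assert seedvr2_pixel_frames(5) == 17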