from dataclasses import dataclass
from itertools import chain
from math import ceil, pi, sqrt
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import einops
import torch
import torch.nn.functional as F
from einops import einsum, rearrange
from torch import nn
from torch.nn.modules.utils import _triple

from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.modules.diffusionmodules.model import get_timestep_embedding
from comfy.rmsnorm import RMSNorm


class Cache:
    """Memoise values by string key.

    ``cache(key, fn)`` returns the value stored under ``prefix + key``; on a
    miss it calls ``fn()``, stores the result and returns it. When ``disable``
    is true, ``fn()`` is evaluated on every call and nothing is stored.
    ``namespace(name)`` returns a view that shares the same dict but prefixes
    keys with ``"<prefix><name>."``.
    """

    def __init__(self, disable=False, prefix="", cache=None):
        self.cache = cache if cache is not None else {}
        self.disable = disable
        self.prefix = prefix

    def __call__(self, key: str, fn: Callable):
        if self.disable:
            return fn()

        key = self.prefix + key
        try:
            result = self.cache[key]
        except KeyError:
            result = fn()
            self.cache[key] = result
        return result

    def namespace(self, namespace: str):
        return Cache(
            disable=self.disable,
            prefix=self.prefix + namespace + ".",
            cache=self.cache,
        )

    def get(self, key: str):
        """Return the stored value for ``key``.

        Raises ``KeyError`` if nothing was stored, which is always the case
        for a disabled cache.
        """
        key = self.prefix + key
        return self.cache[key]


def repeat_concat(
    vid: torch.FloatTensor,  # (VL ... c)
    txt: torch.FloatTensor,  # (TL ... c)
    vid_len: torch.LongTensor,  # (n*b)
    txt_len: torch.LongTensor,  # (b)
    txt_repeat: List,  # (n)
) -> torch.FloatTensor:  # (L ... c)
    """Interleave video segments with repeated text segments along dim 0.

    Text segment ``i`` is repeated ``txt_repeat[i]`` times; the expanded text
    list is then zipped with the video segments, giving
    ``[vid_0, txt'_0, vid_1, txt'_1, ...]``.
    """
    vid = torch.split(vid, vid_len.tolist())
    txt = torch.split(txt, txt_len.tolist())
    txt = [[x] * n for x, n in zip(txt, txt_repeat)]
    txt = list(chain(*txt))
    return torch.cat(list(chain(*zip(vid, txt))))


def concat(
    vid: torch.FloatTensor,  # (VL ... c)
    txt: torch.FloatTensor,  # (TL ... c)
    vid_len: torch.LongTensor,  # (b)
    txt_len: torch.LongTensor,  # (b)
) -> torch.FloatTensor:  # (L ... c)
    """Interleave per-sample video and text segments: ``[vid_0, txt_0, vid_1, txt_1, ...]``."""
    vid = torch.split(vid, vid_len.tolist())
    txt = torch.split(txt, txt_len.tolist())
    return torch.cat(list(chain(*zip(vid, txt))))


def concat_idx(
    vid_len: torch.LongTensor,  # (b)
    txt_len: torch.LongTensor,  # (b)
) -> Tuple[
    Callable,
    Callable,
]:
    """Build index-based versions of :func:`concat` and its inverse.

    Returns ``(concat_fn, unconcat_fn)``:
    ``concat_fn(vid, txt)`` produces the interleaved order of :func:`concat`.
    ``unconcat_fn(hid)`` undoes it and returns ``(vid, txt)``.
    """
    device = vid_len.device
    vid_idx = torch.arange(vid_len.sum(), device=device)
    txt_idx = torch.arange(len(vid_idx), len(vid_idx) + txt_len.sum(), device=device)
    tgt_idx = concat(vid_idx, txt_idx, vid_len, txt_len)
    # argsort of a permutation is its inverse.
    src_idx = torch.argsort(tgt_idx)
    return (
        lambda vid, txt: torch.index_select(torch.cat([vid, txt]), 0, tgt_idx),
        lambda hid: torch.index_select(hid, 0, src_idx).split([len(vid_idx), len(txt_idx)]),
    )


def repeat_concat_idx(
    vid_len: torch.LongTensor,  # (n*b)
    txt_len: torch.LongTensor,  # (b)
    txt_repeat: torch.LongTensor,  # (n)
) -> Tuple[
    Callable,
    Callable,
]:
    """Build index-based versions of :func:`repeat_concat` and its inverse.

    The inverse, ``unconcat_coalesce(hid)``, splits the packed sequence back
    into ``(vid, txt)``. The repeated copies of each text token are averaged.
    """
    device = vid_len.device
    vid_idx = torch.arange(vid_len.sum(), device=device)
    txt_idx = torch.arange(len(vid_idx), len(vid_idx) + txt_len.sum(), device=device)
    txt_repeat_list = txt_repeat.tolist()
    tgt_idx = repeat_concat(vid_idx, txt_idx, vid_len, txt_len, txt_repeat)
    # tgt_idx contains each text index txt_repeat times. Sorting groups the
    # copies of one text token next to each other, so they can be reshaped
    # to (tokens, repeat) and averaged.
    src_idx = torch.argsort(tgt_idx)
    txt_idx_len = len(tgt_idx) - len(vid_idx)
    repeat_txt_len = (txt_len * txt_repeat).tolist()

    def unconcat_coalesce(hid):
        vid_out, txt_out = hid[src_idx].split([len(vid_idx), txt_idx_len])
        txt_out_coalesced = []
        for txt, repeat_time in zip(txt_out.split(repeat_txt_len), txt_repeat_list):
            txt = txt.reshape(-1, repeat_time, *txt.shape[1:]).mean(1)
            txt_out_coalesced.append(txt)
        return vid_out, torch.cat(txt_out_coalesced)

    return (
        lambda vid, txt: torch.cat([vid, txt])[tgt_idx],
        lambda hid: unconcat_coalesce(hid),
    )


@dataclass
class MMArg:
    """A pair of per-branch values (video, text) for :class:`MMModule` arguments."""

    vid: Any
    txt: Any


def safe_pad_operation(x, padding, mode='constant', value=0.0):
    """Pad ``x`` with ``F.pad``, retrying in float32 when a Half kernel is missing.

    Only the 'replicate', 'reflect' and 'circular' modes are retried. If
    ``F.pad`` raises a ``RuntimeError`` containing "not implemented for
    'Half'", the input is padded in float32 and cast back to its dtype. Any
    other error is re-raised.
    """
    # Modes that need the Half-precision workaround.
    problematic_modes = ['replicate', 'reflect', 'circular']

    if mode in problematic_modes:
        try:
            return F.pad(x, padding, mode=mode, value=value)
        except RuntimeError as e:
            if "not implemented for 'Half'" in str(e):
                original_dtype = x.dtype
                return F.pad(x.float(), padding, mode=mode, value=value).to(original_dtype)
            raise
    # 'constant' and other compatible modes need no workaround.
    return F.pad(x, padding, mode=mode, value=value)


def get_args(key: str, args: List[Any]) -> List[Any]:
    """Pick attribute ``key`` ("vid"/"txt") from every :class:`MMArg` in ``args``."""
    return [getattr(v, key) if isinstance(v, MMArg) else v for v in args]


def get_kwargs(key: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Pick attribute ``key`` ("vid"/"txt") from every :class:`MMArg` value in ``kwargs``."""
    return {k: getattr(v, key) if isinstance(v, MMArg) else v for k, v in kwargs.items()}


def make_720Pwindows(size, num_windows, shift=False):
    """Partition a ``(t, h, w)`` grid into windows; return a list of slice triples.

    The spatial window size is chosen by rescaling ``(h, w)`` to an area of
    45 * 80 and splitting that into ``num_windows[1:]`` windows. The temporal
    window is ``ceil(min(t, 30) / num_windows[0])``. With ``shift``, windows
    are offset by half a window along every axis whose window is smaller
    than the extent. Empty windows are skipped.
    """
    t, h, w = size
    resized_nt, resized_nh, resized_nw = num_windows
    scale = sqrt((45 * 80) / (h * w))
    resized_h, resized_w = round(h * scale), round(w * scale)
    wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw)
    wt = ceil(min(t, 30) / resized_nt)

    st, sh, sw = (
        0.5 * shift if wt < t else 0,
        0.5 * shift if wh < h else 0,
        0.5 * shift if ww < w else 0,
    )
    nt, nh, nw = ceil((t - st) / wt), ceil((h - sh) / wh), ceil((w - sw) / ww)
    if shift:
        # A shifted grid needs one more window to cover the leading half-window.
        nt += 1 if st > 0 else 0
        nh += 1 if sh > 0 else 0
        nw += 1 if sw > 0 else 0

    windows = []
    for iw in range(nw):
        w_start = max(int((iw - sw) * ww), 0)
        w_end = min(int((iw - sw + 1) * ww), w)
        if w_end <= w_start:
            continue
        for ih in range(nh):
            h_start = max(int((ih - sh) * wh), 0)
            h_end = min(int((ih - sh + 1) * wh), h)
            if h_end <= h_start:
                continue
            for it in range(nt):
                t_start = max(int((it - st) * wt), 0)
                t_end = min(int((it - st + 1) * wt), t)
                if t_end <= t_start:
                    continue
                windows.append((slice(t_start, t_end), slice(h_start, h_end), slice(w_start, w_end)))
    return windows


def get_window_op(window_method: Union[None, bool, str, Callable]) -> Callable:
    """Resolve ``window_method`` into a ``(size, num_windows) -> windows`` function.

    * callable: returned unchanged.
    * ``None`` / ``False``: unshifted :func:`make_720Pwindows`.
    * ``True``: shifted :func:`make_720Pwindows`.
    * ``str``: shifted iff the name contains ``"swin"``. NOTE(review): this
      assumes names like ``"720pwin_by_size_bysize"`` (plain) and
      ``"720pswin_by_size_bysize"`` (shifted); confirm against the model
      config.
    """
    if callable(window_method):
        return window_method
    if window_method is None or isinstance(window_method, bool):
        shift = bool(window_method)
    elif isinstance(window_method, str):
        shift = "swin" in window_method
    else:
        raise TypeError(f"Unsupported window_method: {window_method!r}")

    def window_op(size, num_windows):
        return make_720Pwindows(size, num_windows, shift=shift)

    return window_op


class RotaryEmbedding(nn.Module):
    """Rotary position embedding frequency table.

    ``forward(t)`` maps positions ``t`` to ``t * freqs``, with each
    frequency repeated twice along the last dimension. Results for
    integer-length, non-'pixel' sequences starting at offset 0 are cached
    in the ``cached_freqs`` buffer.
    """

    def __init__(
        self,
        dim,
        custom_freqs=None,
        freqs_for='lang',
        theta=10000,
        max_freq=10,
        num_freqs=1,
        learned_freq=False,
        use_xpos=False,
        xpos_scale_base=512,
        interpolate_factor=1.,
        theta_rescale_factor=1.,
        seq_before_head_dim=False,
        cache_if_possible=True,
        cache_max_seq_len=8192,
    ):
        super().__init__()

        theta *= theta_rescale_factor ** (dim / (dim - 2))

        self.freqs_for = freqs_for

        if exists(custom_freqs):
            freqs = custom_freqs
        elif freqs_for == 'lang':
            freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
        elif freqs_for == 'pixel':
            freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
        elif freqs_for == 'constant':
            freqs = torch.ones(num_freqs).float()
        else:
            raise ValueError(f"Unknown freqs_for: {freqs_for!r}")

        self.cache_if_possible = cache_if_possible
        self.cache_max_seq_len = cache_max_seq_len

        self.register_buffer('cached_freqs', torch.zeros(cache_max_seq_len, dim), persistent=False)
        self.cached_freqs_seq_len = 0

        self.freqs = nn.Parameter(freqs, requires_grad=learned_freq)

        self.learned_freq = learned_freq

        # Dummy buffer used only to report the module's device.
        self.register_buffer('dummy', torch.tensor(0), persistent=False)

        # Default sequence dimension.
        self.seq_before_head_dim = seq_before_head_dim
        self.default_seq_dim = -3 if seq_before_head_dim else -2

        # Interpolation factors.
        assert interpolate_factor >= 1.
        self.interpolate_factor = interpolate_factor

        # xpos
        self.use_xpos = use_xpos

        if not use_xpos:
            return

        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
        self.scale_base = xpos_scale_base

        self.register_buffer('scale', scale, persistent=False)
        self.register_buffer('cached_scales', torch.zeros(cache_max_seq_len, dim), persistent=False)
        self.cached_scales_seq_len = 0

        # Expose apply_rotary_emb as a static method.
        self.apply_rotary_emb = staticmethod(apply_rotary_emb)

    @property
    def device(self):
        return self.dummy.device

    def get_axial_freqs(
        self,
        *dims,
        offsets=None
    ):
        """Return the concatenated per-axis frequencies for a grid of shape ``dims``.

        The result has shape ``(*dims, len(dims) * per_axis_dim)``.
        """
        Colon = slice(None)
        all_freqs = []

        # Handle offsets.
        if exists(offsets):
            assert len(offsets) == len(dims)

        # Get frequencies for each axis.
        for ind, dim in enumerate(dims):

            offset = 0
            if exists(offsets):
                offset = offsets[ind]

            if self.freqs_for == 'pixel':
                pos = torch.linspace(-1, 1, steps=dim, device=self.device)
            else:
                pos = torch.arange(dim, device=self.device)

            pos = pos + offset

            freqs = self.forward(pos, seq_len=dim)

            all_axis = [None] * len(dims)
            all_axis[ind] = Colon

            new_axis_slice = (Ellipsis, *all_axis, Colon)
            all_freqs.append(freqs[new_axis_slice])

        # Concatenate all freqs.
        all_freqs = torch.broadcast_tensors(*all_freqs)
        return torch.cat(all_freqs, dim=-1)

    def forward(
        self,
        t,
        seq_len: int | None = None,
        offset=0
    ):
        should_cache = (
            self.cache_if_possible and
            not self.learned_freq and
            exists(seq_len) and
            self.freqs_for != 'pixel' and
            (offset + seq_len) <= self.cache_max_seq_len
        )

        if (
            should_cache and
            exists(self.cached_freqs) and
            (offset + seq_len) <= self.cached_freqs_seq_len
        ):
            return self.cached_freqs[offset:(offset + seq_len)].detach()

        freqs = self.freqs

        # Outer product of positions and frequencies. This is torch.einsum
        # (pattern first); einops.einsum expects the pattern last.
        freqs = torch.einsum('..., f -> ... f', t.type(freqs.dtype), freqs)
        # einops.repeat, not the module-level `repeat` helper of this file.
        freqs = einops.repeat(freqs, '... n -> ... (n r)', r=2)

        if should_cache and offset == 0:
            self.cached_freqs[:seq_len] = freqs.detach()
            self.cached_freqs_seq_len = seq_len

        return freqs


class RotaryEmbeddingBase(nn.Module):
    """Wrap a 'pixel' :class:`RotaryEmbedding` of size ``dim // rope_dim`` per axis.

    The rope's ``freqs`` parameter is replaced by a buffer of the same values.
    """

    def __init__(self, dim: int, rope_dim: int):
        super().__init__()
        self.rope = RotaryEmbedding(
            dim=dim // rope_dim,
            freqs_for="pixel",
            max_freq=256,
        )
        freqs = self.rope.freqs
        del self.rope.freqs
        self.rope.register_buffer("freqs", freqs.data)

    def get_axial_freqs(self, *dims):
        return self.rope.get_axial_freqs(*dims)


class RotaryEmbedding3d(RotaryEmbeddingBase):
    """3-axis (T, H, W) rotary embedding applied to ``(b, h, T*H*W, d)`` q/k."""

    def __init__(self, dim: int):
        super().__init__(dim, rope_dim=3)
        self.mm = False

    def forward(
        self,
        q: torch.FloatTensor,  # b h l d
        k: torch.FloatTensor,  # b h l d
        size: Tuple[int, int, int],
    ) -> Tuple[
        torch.FloatTensor,
        torch.FloatTensor,
    ]:
        T, H, W = size
        freqs = self.get_axial_freqs(T, H, W)
        q = rearrange(q, "b h (T H W) d -> b h T H W d", T=T, H=H, W=W)
        k = rearrange(k, "b h (T H W) d -> b h T H W d", T=T, H=H, W=W)
        q = apply_rotary_emb(freqs, q.float()).to(q.dtype)
        k = apply_rotary_emb(freqs, k.float()).to(k.dtype)
        q = rearrange(q, "b h T H W d -> b h (T H W) d")
        k = rearrange(k, "b h T H W d -> b h (T H W) d")
        return q, k


class MMRotaryEmbeddingBase(RotaryEmbeddingBase):
    """Like :class:`RotaryEmbeddingBase`, but uses a 'lang' rope (theta=10000); sets ``mm``."""

    def __init__(self, dim: int, rope_dim: int):
        super().__init__(dim, rope_dim)
        self.rope = RotaryEmbedding(
            dim=dim // rope_dim,
            freqs_for="lang",
            theta=10000,
        )
        freqs = self.rope.freqs
        del self.rope.freqs
        self.rope.register_buffer("freqs", freqs.data)
        self.mm = True


def slice_at_dim(t, dim_slice: slice, *, dim):
    """Return ``t`` indexed by ``dim_slice`` along dimension ``dim`` (negative allowed)."""
    dim += (t.ndim if dim < 0 else 0)
    colons = [slice(None)] * t.ndim
    colons[dim] = dim_slice
    return t[tuple(colons)]


# rotary embedding helper functions

def rotate_half(x):
    """Map each adjacent feature pair ``(x1, x2)`` to ``(-x2, x1)``."""
    x = rearrange(x, '... (d r) -> ... d r', r=2)
    x1, x2 = x.unbind(dim=-1)
    x = torch.stack((-x2, x1), dim=-1)
    return rearrange(x, '... d r -> ... (d r)')


def exists(val):
    return val is not None


def apply_rotary_emb(
    freqs,
    t,
    start_index=0,
    scale=1.,
    seq_dim=-2,
    freqs_seq_dim=None
):
    """Rotate features ``[start_index, start_index + freqs.shape[-1])`` of ``t``.

    If ``freqs`` is 2-D or ``t`` is 3-D and ``freqs_seq_dim`` is not given,
    ``freqs_seq_dim`` defaults to 0. When set, ``freqs`` is trimmed to its
    last ``t.shape[seq_dim]`` entries along that dim. The output keeps
    ``t``'s dtype.
    """
    dtype = t.dtype

    if not exists(freqs_seq_dim):
        if freqs.ndim == 2 or t.ndim == 3:
            freqs_seq_dim = 0

    if t.ndim == 3 or exists(freqs_seq_dim):
        seq_len = t.shape[seq_dim]
        freqs = slice_at_dim(freqs, slice(-seq_len, None), dim=freqs_seq_dim)

    rot_dim = freqs.shape[-1]
    end_index = start_index + rot_dim

    assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'

    t_left = t[..., :start_index]
    t_middle = t[..., start_index:end_index]
    t_right = t[..., end_index:]

    t_transformed = (t_middle * freqs.cos() * scale) + (rotate_half(t_middle) * freqs.sin() * scale)

    out = torch.cat((t_left, t_transformed, t_right), dim=-1)

    return out.type(dtype)


class NaMMRotaryEmbedding3d(MMRotaryEmbeddingBase):
    """Multimodal 3-axis rope for packed ``(L, h, d)`` video and text q/k.

    For sample ``i`` with video shape ``(f, h, w)`` and text length ``l``,
    video token ``(t, y, x)`` gets time position ``l + t`` and spatial
    positions ``y``, ``x``. Text token ``j`` gets position ``j`` on all three
    axes.
    """

    def __init__(self, dim: int):
        super().__init__(dim, rope_dim=3)

    def forward(
        self,
        vid_q: torch.FloatTensor,  # L h d
        vid_k: torch.FloatTensor,  # L h d
        vid_shape: torch.LongTensor,  # B 3
        txt_q: torch.FloatTensor,  # L h d
        txt_k: torch.FloatTensor,  # L h d
        txt_shape: torch.LongTensor,  # B 1
        cache: Cache,
    ) -> Tuple[
        torch.FloatTensor,
        torch.FloatTensor,
        torch.FloatTensor,
        torch.FloatTensor,
    ]:
        vid_freqs, txt_freqs = cache(
            "mmrope_freqs_3d",
            lambda: self.get_freqs(vid_shape, txt_shape),
        )

        def rotate(freqs, x):
            # Rotate in float32 with heads leading, then restore layout/dtype.
            x_hld = rearrange(x, "L h d -> h L d")
            x_hld = apply_rotary_emb(freqs, x_hld.float()).to(x.dtype)
            return rearrange(x_hld, "h L d -> L h d")

        vid_q = rotate(vid_freqs, vid_q)
        vid_k = rotate(vid_freqs, vid_k)
        txt_q = rotate(txt_freqs, txt_q)
        txt_k = rotate(txt_freqs, txt_k)
        return vid_q, vid_k, txt_q, txt_k

    def get_freqs(
        self,
        vid_shape: torch.LongTensor,
        txt_shape: torch.LongTensor,
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
    ]:
        """Return packed ``(sum f*h*w, 3d)`` video and ``(sum l, 3d)`` text freqs.

        All three axes share one 'lang' rope, so a single table of positions
        0..1023 serves every axis. Each sample's video freqs are built by
        broadcasting its (time, height, width) slices. The previous version
        materialised a full 1024x128x128 grid (several GB) just to slice it.
        """
        axis_freqs = self.get_axial_freqs(1024)  # (1024, d)
        d = axis_freqs.size(-1)

        vid_freq_list, txt_freq_list = [], []
        for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()):
            # Video frames follow the l text positions on the time axis.
            t_freq = axis_freqs[l : l + f]
            h_freq = axis_freqs[:h]
            w_freq = axis_freqs[:w]
            vid_freq = torch.cat(
                torch.broadcast_tensors(
                    t_freq[:, None, None, :],
                    h_freq[None, :, None, :],
                    w_freq[None, None, :, :],
                ),
                dim=-1,
            ).reshape(-1, 3 * d)
            # Text token j uses position j on every axis.
            txt_freq = axis_freqs[:l].repeat(1, 3)
            vid_freq_list.append(vid_freq)
            txt_freq_list.append(txt_freq)
        return torch.cat(vid_freq_list, dim=0), torch.cat(txt_freq_list, dim=0)


class MMModule(nn.Module):
    """Apply a module to a (video, text) pair.

    With ``shared_weights`` one module (``self.all``) serves both branches.
    Otherwise separate ``vid``/``txt`` modules are built from the per-branch
    values of any :class:`MMArg` arguments. With ``vid_only``, no text module
    is built and the text input is returned unchanged.
    """

    def __init__(
        self,
        module: Callable[..., nn.Module],
        *args,
        shared_weights: bool = False,
        vid_only: bool = False,
        **kwargs,
    ):
        super().__init__()
        self.shared_weights = shared_weights
        self.vid_only = vid_only
        if self.shared_weights:
            assert get_args("vid", args) == get_args("txt", args)
            assert get_kwargs("vid", kwargs) == get_kwargs("txt", kwargs)
            self.all = module(*get_args("vid", args), **get_kwargs("vid", kwargs))
        else:
            self.vid = module(*get_args("vid", args), **get_kwargs("vid", kwargs))
            self.txt = (
                module(*get_args("txt", args), **get_kwargs("txt", kwargs))
                if not vid_only
                else None
            )

    def forward(
        self,
        vid: torch.FloatTensor,
        txt: torch.FloatTensor,
        *args,
        **kwargs,
    ) -> Tuple[
        torch.FloatTensor,
        torch.FloatTensor,
    ]:
        vid_module = self.vid if not self.shared_weights else self.all
        vid = vid_module(vid, *get_args("vid", args), **get_kwargs("vid", kwargs))
        if not self.vid_only:
            txt_module = self.txt if not self.shared_weights else self.all
            txt = txt_module(txt, *get_args("txt", args), **get_kwargs("txt", kwargs))
        return vid, txt


def get_na_rope(rope_type: Optional[str], dim: int):
    """Return the rope module for ``rope_type``; ``None`` means no rope.

    Only ``"mmrope3d"`` is recognised; any other value also yields ``None``.
    """
    # 7b doesn't use rope
    if rope_type is None:
        return None
    if rope_type == "mmrope3d":
        return NaMMRotaryEmbedding3d(dim=dim)
    return None


class NaMMAttention(nn.Module):
    """Joint video/text self-attention over packed (native-resolution) sequences."""

    def __init__(
        self,
        vid_dim: int,
        txt_dim: int,
        heads: int,
        head_dim: int,
        qk_bias: bool,
        qk_norm,
        qk_norm_eps: float,
        rope_type: Optional[str],
        rope_dim: int,
        shared_weights: bool,
        **kwargs,
    ):
        super().__init__()
        dim = MMArg(vid_dim, txt_dim)
        inner_dim = heads * head_dim
        qkv_dim = inner_dim * 3
        self.head_dim = head_dim
        self.proj_qkv = MMModule(
            nn.Linear, dim, qkv_dim, bias=qk_bias, shared_weights=shared_weights
        )
        self.proj_out = MMModule(nn.Linear, inner_dim, dim, shared_weights=shared_weights)
        self.norm_q = MMModule(
            qk_norm,
            dim=head_dim,
            eps=qk_norm_eps,
            elementwise_affine=True,
            shared_weights=shared_weights,
        )
        self.norm_k = MMModule(
            qk_norm,
            dim=head_dim,
            eps=qk_norm_eps,
            elementwise_affine=True,
            shared_weights=shared_weights,
        )

        self.rope = get_na_rope(rope_type=rope_type, dim=rope_dim)

    def forward(
        self,
        vid: torch.FloatTensor,  # l c
        txt: torch.FloatTensor,  # l c
        vid_shape: torch.LongTensor,  # b 3
        txt_shape: torch.LongTensor,  # b 1
        cache: Cache,
    ) -> Tuple[
        torch.FloatTensor,
        torch.FloatTensor,
    ]:
        vid_qkv, txt_qkv = self.proj_qkv(vid, txt)
        vid_qkv = rearrange(vid_qkv, "l (o h d) -> l o h d", o=3, d=self.head_dim)
        txt_qkv = rearrange(txt_qkv, "l (o h d) -> l o h d", o=3, d=self.head_dim)

        vid_q, vid_k, vid_v = vid_qkv.unbind(1)
        txt_q, txt_k, txt_v = txt_qkv.unbind(1)

        vid_q, txt_q = self.norm_q(vid_q, txt_q)
        vid_k, txt_k = self.norm_k(vid_k, txt_k)

        if self.rope:
            if self.rope.mm:
                vid_q, vid_k, txt_q, txt_k = self.rope(
                    vid_q, vid_k, vid_shape, txt_q, txt_k, txt_shape, cache
                )
            else:
                vid_q, vid_k = self.rope(vid_q, vid_k, vid_shape, cache)

        vid_len = cache("vid_len", lambda: vid_shape.prod(-1))
        txt_len = cache("txt_len", lambda: txt_shape.prod(-1))

        # Pack as (b, L, h, d). The view assumes every sample in the batch
        # has the same video length and the same text length.
        b = len(vid_len)
        vq, vk, vv = [t.view(b, -1, *vid_q.shape[1:]) for t in (vid_q, vid_k, vid_v)]
        tq, tk, tv = [t.view(b, -1, *txt_q.shape[1:]) for t in (txt_q, txt_k, txt_v)]

        q = torch.cat([vq, tq], dim=1)
        k = torch.cat([vk, tk], dim=1)
        v = torch.cat([vv, tv], dim=1)

        _, unconcat = cache("mm_pnp", lambda: concat_idx(vid_len, txt_len))

        # NOTE: follows ComfyUI's optimized_attention(q, k, v, heads, ...)
        # convention: with skip_reshape / skip_output_reshape, inputs and
        # output are laid out (b, heads, L, d).
        heads = q.shape[2]
        attn = optimized_attention(
            q.transpose(1, 2),
            k.transpose(1, 2),
            v.transpose(1, 2),
            heads,
            skip_reshape=True,
            skip_output_reshape=True,
        )
        # Back to (b*L, h, d) in order [vid_0, txt_0, vid_1, txt_1, ...],
        # which is what `unconcat` expects.
        attn = attn.transpose(1, 2).flatten(0, 1)
        attn = rearrange(attn, "l h d -> l (h d)")
        vid_out, txt_out = unconcat(attn)

        vid_out, txt_out = self.proj_out(vid_out, txt_out)
        return vid_out, txt_out


def window(
    hid: torch.FloatTensor,  # (L c)
    hid_shape: torch.LongTensor,  # (b n)
    window_fn: Callable[[torch.Tensor], List[torch.Tensor]],
):
    """Split each sample with ``window_fn`` and re-pack all windows.

    Returns ``(packed windows, per-window shapes, windows per sample)``.
    """
    hid = unflatten(hid, hid_shape)
    hid = list(map(window_fn, hid))
    hid_windows = torch.tensor(list(map(len, hid)), device=hid_shape.device)
    hid, hid_shape = flatten(list(chain(*hid)))
    return hid, hid_shape, hid_windows


def window_idx(
    hid_shape: torch.LongTensor,  # (b n)
    window_fn: Callable[[torch.Tensor], List[torch.Tensor]],
):
    """Return ``(partition, reverse, window_shape, window_count)`` for windowing.

    ``partition``/``reverse`` are index-select closures mapping packed tokens
    to window order and back.
    """
    hid_idx = torch.arange(hid_shape.prod(-1).sum(), device=hid_shape.device).unsqueeze(-1)
    tgt_idx, tgt_shape, tgt_windows = window(hid_idx, hid_shape, window_fn)
    tgt_idx = tgt_idx.squeeze(-1)
    src_idx = torch.argsort(tgt_idx)
    return (
        lambda hid: torch.index_select(hid, 0, tgt_idx),
        lambda hid: torch.index_select(hid, 0, src_idx),
        tgt_shape,
        tgt_windows,
    )


def _varlen_attention(
    q: torch.Tensor,  # (L h d)
    k: torch.Tensor,  # (L h d)
    v: torch.Tensor,  # (L h d)
    seqlens: torch.LongTensor,  # (n), sums to L
) -> torch.Tensor:  # (L h d)
    """Run non-causal attention separately on each packed segment of length ``seqlens[i]``.

    If all segments have the same length they are batched into one call;
    otherwise attention runs segment by segment.
    """
    lengths = seqlens.tolist()
    if not lengths:
        return q
    heads = q.shape[1]

    def call(qb, kb, vb):
        # (n, l, h, d) -> (n, h, l, d) -> attention -> (n, l, h, d)
        out = optimized_attention(
            qb.transpose(1, 2),
            kb.transpose(1, 2),
            vb.transpose(1, 2),
            heads,
            skip_reshape=True,
            skip_output_reshape=True,
        )
        return out.transpose(1, 2)

    if all(n == lengths[0] for n in lengths):
        shape = (len(lengths), lengths[0], *q.shape[1:])
        out = call(q.reshape(shape), k.reshape(shape), v.reshape(shape))
        return out.reshape(-1, *q.shape[1:])

    outs = [
        call(qs.unsqueeze(0), ks.unsqueeze(0), vs.unsqueeze(0))[0]
        for qs, ks, vs in zip(q.split(lengths), k.split(lengths), v.split(lengths))
    ]
    return torch.cat(outs)


class NaSwinAttention(NaMMAttention):
    """Windowed variant of :class:`NaMMAttention`.

    Video tokens are split into windows and each window attends jointly with
    a copy of its sample's text. The text outputs of all windows are then
    averaged back into one.
    """

    def __init__(
        self,
        *args,
        window: Union[int, Tuple[int, int, int]],
        window_method,  # shifted or not (bool / None / str / callable, see get_window_op)
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.window = _triple(window)
        self.window_method = window_method
        assert all(map(lambda v: isinstance(v, int) and v >= 0, self.window))

        self.window_op = get_window_op(window_method)

    def forward(
        self,
        vid: torch.FloatTensor,  # l c
        txt: torch.FloatTensor,  # l c
        vid_shape: torch.LongTensor,  # b 3
        txt_shape: torch.LongTensor,  # b 1
        cache: Cache,
    ) -> Tuple[
        torch.FloatTensor,
        torch.FloatTensor,
    ]:
        vid_qkv, txt_qkv = self.proj_qkv(vid, txt)

        # Re-order the video tokens so that each window is contiguous.
        cache_win = cache.namespace(f"{self.window_method}_{self.window}_sd3")

        def make_window(x: torch.Tensor):
            t, h, w, _ = x.shape
            window_slices = self.window_op((t, h, w), self.window)
            return [x[st, sh, sw] for (st, sh, sw) in window_slices]

        window_partition, window_reverse, window_shape, window_count = cache_win(
            "win_transform",
            lambda: window_idx(vid_shape, make_window),
        )
        vid_qkv_win = window_partition(vid_qkv)

        vid_qkv_win = rearrange(vid_qkv_win, "l (o h d) -> l o h d", o=3, d=self.head_dim)
        txt_qkv = rearrange(txt_qkv, "l (o h d) -> l o h d", o=3, d=self.head_dim)

        vid_q, vid_k, vid_v = vid_qkv_win.unbind(1)
        txt_q, txt_k, txt_v = txt_qkv.unbind(1)

        vid_q, txt_q = self.norm_q(vid_q, txt_q)
        vid_k, txt_k = self.norm_k(vid_k, txt_k)

        txt_len = cache("txt_len", lambda: txt_shape.prod(-1))

        vid_len_win = cache_win("vid_len", lambda: window_shape.prod(-1))
        txt_len_win = cache_win("txt_len", lambda: txt_len.repeat_interleave(window_count))
        all_len_win = cache_win("all_len", lambda: vid_len_win + txt_len_win)
        concat_win, unconcat_win = cache_win(
            "mm_pnp", lambda: repeat_concat_idx(vid_len_win, txt_len, window_count)
        )

        # Window rope.
        if self.rope:
            if self.rope.mm:
                # Repeat the text q/k once per window so each window's copy
                # of the text gets its own rope positions.
                num_h = txt_q.shape[1]

                def repeat_txt(x):
                    x = rearrange(x, "l h d -> l (h d)")
                    x = unflatten(x, txt_shape)
                    x = list(chain(*[[seg] * n for seg, n in zip(x, window_count)]))
                    x, shape = flatten(x)
                    return rearrange(x, "l (h d) -> l h d", h=num_h), shape

                txt_q_repeat, txt_shape_repeat = repeat_txt(txt_q)
                txt_k_repeat, _ = repeat_txt(txt_k)

                vid_q, vid_k, txt_q, txt_k = self.rope(
                    vid_q, vid_k, window_shape, txt_q_repeat, txt_k_repeat, txt_shape_repeat, cache_win
                )
            else:
                vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win)

        # Each packed segment is one window's video tokens followed by its
        # text copy; segment lengths are all_len_win.
        out = _varlen_attention(
            concat_win(vid_q, txt_q),
            concat_win(vid_k, txt_k),
            concat_win(vid_v, txt_v),
            all_len_win,
        ).type_as(vid_q)

        # Text pooling: repeated text outputs are averaged in unconcat_win.
        vid_out, txt_out = unconcat_win(out)

        vid_out = rearrange(vid_out, "l h d -> l (h d)")
        txt_out = rearrange(txt_out, "l h d -> l (h d)")
        vid_out = window_reverse(vid_out)

        vid_out, txt_out = self.proj_out(vid_out, txt_out)

        return vid_out, txt_out


class MLP(nn.Module):
    """Two-layer feed-forward block with tanh-approximated GELU."""

    def __init__(
        self,
        dim: int,
        expand_ratio: int,
    ):
        super().__init__()
        self.proj_in = nn.Linear(dim, dim * expand_ratio)
        self.act = nn.GELU("tanh")
        self.proj_out = nn.Linear(dim * expand_ratio, dim)

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        x = self.proj_in(x)
        x = self.act(x)
        x = self.proj_out(x)
        return x


class SwiGLUMLP(nn.Module):
    """SwiGLU feed-forward block; hidden size is 2/3 * dim * expand_ratio rounded up to ``multiple_of``."""

    def __init__(
        self,
        dim: int,
        expand_ratio: int,
        multiple_of: int = 256,
    ):
        super().__init__()
        hidden_dim = int(2 * dim * expand_ratio / 3)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
        self.proj_in_gate = nn.Linear(dim, hidden_dim, bias=False)
        self.proj_out = nn.Linear(hidden_dim, dim, bias=False)
        self.proj_in = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        x = self.proj_out(F.silu(self.proj_in_gate(x)) * self.proj_in(x))
        return x


def get_mlp(mlp_type: Optional[str] = "normal"):
    """Return the MLP class for ``mlp_type`` ("normal" or "swiglu")."""
    # 3b and 7b use different mlp types
    if mlp_type == "normal":
        return MLP
    if mlp_type == "swiglu":
        return SwiGLUMLP
    raise ValueError(f"Unknown mlp_type: {mlp_type!r}")


class NaMMSRTransformerBlock(nn.Module):
    """Pre-norm block: windowed joint attention then MLP, each with ada modulation and a residual.

    On the last layer the MLP norm, the MLP and the ada modules are video-only.
    """

    def __init__(
        self,
        *,
        vid_dim: int,
        txt_dim: int,
        emb_dim: int,
        heads: int,
        head_dim: int,
        expand_ratio: int,
        norm,
        norm_eps: float,
        ada,
        qk_bias: bool,
        qk_norm,
        mlp_type: str,
        shared_weights: bool,
        rope_type: str,
        rope_dim: int,
        is_last_layer: bool,
        **kwargs,
    ):
        super().__init__()
        dim = MMArg(vid_dim, txt_dim)
        self.attn_norm = MMModule(norm, dim=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights,)

        self.attn = NaSwinAttention(
            vid_dim=vid_dim,
            txt_dim=txt_dim,
            heads=heads,
            head_dim=head_dim,
            qk_bias=qk_bias,
            qk_norm=qk_norm,
            qk_norm_eps=norm_eps,
            rope_type=rope_type,
            rope_dim=rope_dim,
            shared_weights=shared_weights,
            window=kwargs.pop("window", None),
            window_method=kwargs.pop("window_method", None),
        )

        self.mlp_norm = MMModule(norm, dim=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights, vid_only=is_last_layer)
        self.mlp = MMModule(
            get_mlp(mlp_type),
            dim=dim,
            expand_ratio=expand_ratio,
            shared_weights=shared_weights,
            vid_only=is_last_layer
        )
        self.ada = MMModule(ada, dim=dim, emb_dim=emb_dim, layers=["attn", "mlp"], shared_weights=shared_weights, vid_only=is_last_layer)
        self.is_last_layer = is_last_layer

    def forward(
        self,
        vid: torch.FloatTensor,  # l c
        txt: torch.FloatTensor,  # l c
        vid_shape: torch.LongTensor,  # b 3
        txt_shape: torch.LongTensor,  # b 1
        emb: torch.FloatTensor,
        cache: Cache,
    ) -> Tuple[
        torch.FloatTensor,
        torch.FloatTensor,
        torch.LongTensor,
        torch.LongTensor,
    ]:
        hid_len = MMArg(
            cache("vid_len", lambda: vid_shape.prod(-1)),
            cache("txt_len", lambda: txt_shape.prod(-1)),
        )
        ada_kwargs = {
            "emb": emb,
            "hid_len": hid_len,
            "cache": cache,
            "branch_tag": MMArg("vid", "txt"),
        }

        vid_attn, txt_attn = self.attn_norm(vid, txt)
        vid_attn, txt_attn = self.ada(vid_attn, txt_attn, layer="attn", mode="in", **ada_kwargs)
        vid_attn, txt_attn = self.attn(vid_attn, txt_attn, vid_shape, txt_shape, cache)
        vid_attn, txt_attn = self.ada(vid_attn, txt_attn, layer="attn", mode="out", **ada_kwargs)
        vid_attn, txt_attn = (vid_attn + vid), (txt_attn + txt)

        vid_mlp, txt_mlp = self.mlp_norm(vid_attn, txt_attn)
        vid_mlp, txt_mlp = self.ada(vid_mlp, txt_mlp, layer="mlp", mode="in", **ada_kwargs)
        vid_mlp, txt_mlp = self.mlp(vid_mlp, txt_mlp)
        vid_mlp, txt_mlp = self.ada(vid_mlp, txt_mlp, layer="mlp", mode="out", **ada_kwargs)
        vid_mlp, txt_mlp = (vid_mlp + vid_attn), (txt_mlp + txt_attn)

        return vid_mlp, txt_mlp, vid_shape, txt_shape


class PatchOut(nn.Module):
    """Project tokens to ``out_channels * t * h * w`` and un-patchify ``(b, T, H, W, c)`` to ``(b, c, T*t, H*h, W*w)``."""

    def __init__(
        self,
        out_channels: int,
        patch_size: Union[int, Tuple[int, int, int]],
        dim: int,
    ):
        super().__init__()
        t, h, w = _triple(patch_size)
        self.patch_size = t, h, w
        self.proj = nn.Linear(dim, out_channels * t * h * w)

    def forward(
        self,
        vid: torch.Tensor,
    ) -> torch.Tensor:
        t, h, w = self.patch_size
        vid = self.proj(vid)
        vid = rearrange(vid, "b T H W (t h w c) -> b c (T t) (H h) (W w)", t=t, h=h, w=w)
        if t > 1:
            # Drop the t - 1 leading frames that PatchIn prepended.
            vid = vid[:, :, (t - 1):]
        return vid


class NaPatchOut(PatchOut):
    """Packed-sequence :class:`PatchOut`.

    Reads ``patch.vid_shape_before_patchify`` from ``cache``; that value is
    stored by :class:`NaPatchIn`, so both must share an enabled cache.
    """

    def forward(
        self,
        vid: torch.FloatTensor,  # l c
        vid_shape: torch.LongTensor,
        cache: Cache = Cache(disable=True),  # for test
    ) -> Tuple[
        torch.FloatTensor,
        torch.LongTensor,
    ]:
        cache = cache.namespace("patch")
        vid_shape_before_patchify = cache.get("vid_shape_before_patchify")

        t, h, w = self.patch_size
        vid = self.proj(vid)

        if not (t == h == w == 1):
            vid = unflatten(vid, vid_shape)
            for i in range(len(vid)):
                vid[i] = rearrange(vid[i], "T H W (t h w c) -> (T t) (H h) (W w) c", t=t, h=h, w=w)
                if t > 1 and vid_shape_before_patchify[i, 0] % t != 0:
                    # Remove the frames NaPatchIn padded to reach a multiple of t.
                    vid[i] = vid[i][(t - vid_shape_before_patchify[i, 0] % t):]
            vid, vid_shape = flatten(vid)

        return vid, vid_shape


class PatchIn(nn.Module):
    """Patchify ``(b, c, T*t, H*h, W*w)`` into ``(b, T, H, W, dim)`` tokens.

    When ``t > 1`` the input must have ``1 (mod t)`` frames; the first frame
    is prepended ``t - 1`` times.
    """

    def __init__(
        self,
        in_channels: int,
        patch_size: Union[int, Tuple[int, int, int]],
        dim: int,
    ):
        super().__init__()
        t, h, w = _triple(patch_size)
        self.patch_size = t, h, w
        self.proj = nn.Linear(in_channels * t * h * w, dim)

    def forward(
        self,
        vid: torch.Tensor,
    ) -> torch.Tensor:
        t, h, w = self.patch_size
        if t > 1:
            assert vid.size(2) % t == 1
            vid = torch.cat([vid[:, :, :1]] * (t - 1) + [vid], dim=2)
        vid = rearrange(vid, "b c (T t) (H h) (W w) -> b T H W (t h w c)", t=t, h=h, w=w)
        vid = self.proj(vid)
        return vid


class NaPatchIn(PatchIn):
    """Packed-sequence :class:`PatchIn`; stores the pre-patchify shape in ``cache`` under ``patch.``."""

    def forward(
        self,
        vid: torch.Tensor,  # l c
        vid_shape: torch.LongTensor,
        cache: Cache = Cache(disable=True),  # for test
    ) -> Tuple[
        torch.Tensor,
        torch.LongTensor,
    ]:
        cache = cache.namespace("patch")
        vid_shape_before_patchify = cache("vid_shape_before_patchify", lambda: vid_shape)
        t, h, w = self.patch_size
        if not (t == h == w == 1):
            vid = unflatten(vid, vid_shape)
            for i in range(len(vid)):
                if t > 1 and vid_shape_before_patchify[i, 0] % t != 0:
                    # Pad by repeating the first frame up to a multiple of t.
                    vid[i] = torch.cat([vid[i][:1]] * (t - vid[i].size(0) % t) + [vid[i]], dim=0)
                vid[i] = rearrange(vid[i], "(T t) (H h) (W w) c -> T H W (t h w c)", t=t, h=h, w=w)
            vid, vid_shape = flatten(vid)

        vid = self.proj(vid)
        return vid, vid_shape


def expand_dims(x: torch.Tensor, dim: int, ndim: int):
    """Insert singleton axes at ``dim`` until ``x`` has ``ndim`` dimensions."""
    shape = x.shape
    shape = shape[:dim] + (1,) * (ndim - len(shape)) + shape[dim:]
    return x.reshape(shape)


class AdaSingle(nn.Module):
    """Adaptive modulation from a shared ``(b, 6 * dim)`` embedding plus learned per-layer offsets.

    ``mode="in"`` applies ``hid * scale + shift`` and ``mode="out"`` applies
    ``hid * gate``, both in place on ``hid``.
    """

    def __init__(
        self,
        dim: int,
        emb_dim: int,
        layers: List[str],
        modes: Tuple[str, ...] = ("in", "out"),
    ):
        assert emb_dim == 6 * dim, "AdaSingle requires emb_dim == 6 * dim"
        super().__init__()
        self.dim = dim
        self.emb_dim = emb_dim
        self.layers = layers
        for l in layers:
            if "in" in modes:
                self.register_parameter(f"{l}_shift", nn.Parameter(torch.randn(dim) / dim**0.5))
                self.register_parameter(
                    f"{l}_scale", nn.Parameter(torch.randn(dim) / dim**0.5 + 1)
                )
            if "out" in modes:
                self.register_parameter(f"{l}_gate", nn.Parameter(torch.randn(dim) / dim**0.5))

    def forward(
        self,
        hid: torch.FloatTensor,  # b ... c
        emb: torch.FloatTensor,  # b d
        layer: str,
        mode: str,
        cache: Cache = Cache(disable=True),
        branch_tag: str = "",
        hid_len: Optional[torch.LongTensor] = None,  # b
    ) -> torch.FloatTensor:
        idx = self.layers.index(layer)
        emb = rearrange(emb, "b (d l g) -> b d l g", l=len(self.layers), g=3)[..., idx, :]
        emb = expand_dims(emb, 1, hid.ndim + 1)

        shiftA, scaleA, gateA = emb.unbind(-1)
        shiftB, scaleB, gateB = (
            getattr(self, f"{layer}_shift", None),
            getattr(self, f"{layer}_scale", None),
            getattr(self, f"{layer}_gate", None),
        )

        if mode == "in":
            return hid.mul_(scaleA + scaleB).add_(shiftA + shiftB)
        if mode == "out":
            return hid.mul_(gateA + gateB)
        raise NotImplementedError


def emb_add(emb1: torch.Tensor, emb2: Optional[torch.Tensor]):
    """Return ``emb1 + emb2``, or ``emb1`` when ``emb2`` is None."""
    return emb1 if emb2 is None else emb1 + emb2


class TimeEmbedding(nn.Module):
    """Sinusoidal timestep embedding followed by a 3-layer SiLU MLP."""

    def __init__(
        self,
        sinusoidal_dim: int,
        hidden_dim: int,
        output_dim: int,
    ):
        super().__init__()
        self.sinusoidal_dim = sinusoidal_dim
        self.proj_in = nn.Linear(sinusoidal_dim, hidden_dim)
        self.proj_hid = nn.Linear(hidden_dim, hidden_dim)
        self.proj_out = nn.Linear(hidden_dim, output_dim)
        self.act = nn.SiLU()

    def forward(
        self,
        timestep: Union[int, float, torch.IntTensor, torch.FloatTensor],
        device: torch.device,
        dtype: torch.dtype,
    ) -> torch.FloatTensor:
        if not torch.is_tensor(timestep):
            timestep = torch.tensor([timestep], device=device, dtype=dtype)
        if timestep.ndim == 0:
            timestep = timestep[None]

        emb = get_timestep_embedding(
            timesteps=timestep,
            embedding_dim=self.sinusoidal_dim,
            flip_sin_to_cos=False,
            downscale_freq_shift=0,
        )
        emb = emb.to(dtype)
        emb = self.proj_in(emb)
        emb = self.act(emb)
        emb = self.proj_hid(emb)
        emb = self.act(emb)
        emb = self.proj_out(emb)
        return emb


def flatten(
    hid: List[torch.FloatTensor],  # List of (*** c)
) -> Tuple[
    torch.FloatTensor,  # (L c)
    torch.LongTensor,  # (b n)
]:
    """Pack a list of ``(*shape, c)`` tensors into ``(L, c)`` plus their stacked shapes."""
    assert len(hid) > 0
    shape = torch.stack([torch.tensor(x.shape[:-1], device=hid[0].device) for x in hid])
    hid = torch.cat([x.flatten(0, -2) for x in hid])
    return hid, shape


def unflatten(
    hid: torch.FloatTensor,  # (L c) or (L ... c)
    hid_shape: torch.LongTensor,  # (b n)
) -> List[torch.Tensor]:  # List of (*** c) or (*** ... c)
    """Inverse of :func:`flatten`: split ``hid`` by ``hid_shape`` and restore each shape."""
    hid_len = hid_shape.prod(-1)
    hid = hid.split(hid_len.tolist())
    hid = [x.unflatten(0, s.tolist()) for x, s in zip(hid, hid_shape)]
    return hid


def repeat(
    hid: torch.FloatTensor,  # (L c)
    hid_shape: torch.LongTensor,  # (b n)
    pattern: str,
    **kwargs: Dict[str, torch.LongTensor],  # (b)
) -> Tuple[
    torch.FloatTensor,
    torch.LongTensor,
]:
    """Apply ``einops.repeat(pattern)`` to each packed sample with per-sample kwargs."""
    hid = unflatten(hid, hid_shape)
    kwargs = [{k: v[i].item() for k, v in kwargs.items()} for i in range(len(hid))]
    return flatten([einops.repeat(h, pattern, **a) for h, a in zip(hid, kwargs)])


@dataclass
class NaDiTOutput:
    vid_sample: torch.Tensor


class NaDiT(nn.Module):
    """Native-resolution video DiT built from :class:`NaMMSRTransformerBlock` layers.

    NOTE(review): ``NaMMSRTransformerBlock`` requires ``rope_type``,
    ``rope_dim`` and ``shared_weights``, which this constructor does not
    derive. They must be supplied through ``**kwargs``. ``qk_rope``,
    ``shared_qkv`` and ``shared_mlp`` are forwarded but absorbed unused by
    the block's ``**kwargs``.
    """

    def __init__(
        self,
        norm_eps,
        qk_rope,
        num_layers,
        mlp_type,
        vid_in_channels=33,
        vid_out_channels=16,
        vid_dim=2560,
        txt_in_dim=5120,
        heads=20,
        head_dim=128,
        expand_ratio=4,
        qk_bias=False,
        patch_size=[1, 2, 2],
        shared_qkv: bool = False,
        shared_mlp: bool = False,
        window_method: Optional[Tuple[str]] = None,
        temporal_window_size: Optional[int] = None,
        temporal_shifted: bool = False,
        **kwargs,
    ):
        txt_dim = vid_dim
        emb_dim = vid_dim * 6
        block_type = ["mmdit_sr"] * num_layers
        window = num_layers * [(4, 3, 3)]
        ada = AdaSingle
        norm = RMSNorm
        qk_norm = RMSNorm
        super().__init__()
        self.vid_in = NaPatchIn(
            in_channels=vid_in_channels,
            patch_size=patch_size,
            dim=vid_dim,
        )
        self.txt_in = (
            nn.Linear(txt_in_dim, txt_dim)
            if txt_in_dim and txt_in_dim != txt_dim
            else nn.Identity()
        )
        self.emb_in = TimeEmbedding(
            sinusoidal_dim=256,
            hidden_dim=max(vid_dim, txt_dim),
            output_dim=emb_dim,
        )

        # Broadcast scalar per-layer settings to one value per layer.
        if window_method is None or isinstance(window_method, str):
            window_method = [window_method] * num_layers
        if temporal_window_size is None or isinstance(temporal_window_size, int):
            temporal_window_size = [temporal_window_size] * num_layers
        if temporal_shifted is None or isinstance(temporal_shifted, bool):
            temporal_shifted = [temporal_shifted] * num_layers

        blocks = []
        for i in range(num_layers):
            # Only the final block makes its MLP/ada branches video-only.
            # A caller-supplied `is_last_layer` in **kwargs still wins.
            block_kwargs = {"is_last_layer": i == num_layers - 1, **kwargs}
            blocks.append(
                NaMMSRTransformerBlock(
                    vid_dim=vid_dim,
                    txt_dim=txt_dim,
                    emb_dim=emb_dim,
                    heads=heads,
                    head_dim=head_dim,
                    expand_ratio=expand_ratio,
                    norm=norm,
                    norm_eps=norm_eps,
                    ada=ada,
                    qk_bias=qk_bias,
                    qk_rope=qk_rope,
                    qk_norm=qk_norm,
                    shared_qkv=shared_qkv,
                    shared_mlp=shared_mlp,
                    mlp_type=mlp_type,
                    window=window[i],
                    window_method=window_method[i],
                    temporal_window_size=temporal_window_size[i],
                    temporal_shifted=temporal_shifted[i],
                    **block_kwargs,
                )
            )
        self.blocks = nn.ModuleList(blocks)

        self.vid_out = NaPatchOut(
            out_channels=vid_out_channels,
            patch_size=patch_size,
            dim=vid_dim,
        )

        self.need_txt_repeat = block_type[0] in [
            "mmdit_stwin",
            "mmdit_stwin_spatial",
            "mmdit_stwin_3d_spatial",
        ]

    def set_gradient_checkpointing(self, enable: bool):
        self.gradient_checkpointing = enable

    def forward(
        self,
        vid: torch.FloatTensor,  # l c
        txt: torch.FloatTensor,  # l c
        vid_shape: torch.LongTensor,  # b 3
        txt_shape: torch.LongTensor,  # b 1
        timestep: Union[int, float, torch.IntTensor, torch.FloatTensor],  # b
        disable_cache: bool = True,  # for test
    ):
        cache = Cache(disable=disable_cache)
        # NaPatchOut reads the pre-patchify shape that NaPatchIn stores, so
        # this bookkeeping cache must stay enabled even when `disable_cache`
        # is set, and both must receive the same instance.
        patch_cache = Cache()

        # Text input.
        if txt_shape.size(-1) == 1 and self.need_txt_repeat:
            txt, txt_shape = repeat(txt, txt_shape, "l c -> t l c", t=vid_shape[:, 0])
        # slice vid after patching in when using sequence parallelism
        txt = self.txt_in(txt)

        # Video input.
        # Sequence parallel slicing is done inside patching class.
        vid, vid_shape = self.vid_in(vid, vid_shape, patch_cache)

        # Embedding input.
        emb = self.emb_in(timestep, device=vid.device, dtype=vid.dtype)

        # Body
        for block in self.blocks:
            vid, txt, vid_shape, txt_shape = block(
                vid=vid,
                txt=txt,
                vid_shape=vid_shape,
                txt_shape=txt_shape,
                emb=emb,
                cache=cache,
            )

        vid, vid_shape = self.vid_out(vid, vid_shape, patch_cache)
        return NaDiTOutput(vid_sample=vid)