# ComfyUI/comfy/ldm/seedvr/model.py
from dataclasses import dataclass
from typing import Optional, Tuple, Union, List, Dict, Any, Callable
import einops
from einops import rearrange, einsum
from torch import nn
import torch.nn.functional as F
from math import ceil, sqrt, pi
import torch
from itertools import chain
from comfy.ldm.modules.diffusionmodules.model import get_timestep_embedding
from comfy.ldm.modules.attention import optimized_attention
from comfy.rmsnorm import RMSNorm
from torch.nn.modules.utils import _triple
class Cache:
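    """Keyed memoization helper: __call__ runs fn() once per key and caches the
    result; namespace() returns a view with a key prefix so nested callers can
    share one cache dict without collisions.
    """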
def __init__(self, disable=False, prefix="", cache=None):
self.cache = cache if cache is not None else {}
self.disable = disable
self.prefix = prefix
def __call__(self, key: str, fn: Callable):
if self.disable:
return fn()
key = self.prefix + key
try:
result = self.cache[key]
except KeyError:
result = fn()
self.cache[key] = result
return result
def namespace(self, namespace: str):
return Cache(
disable=self.disable,
prefix=self.prefix + namespace + ".",
cache=self.cache,
)
def get(self, key: str):
key = self.prefix + key
return self.cache[key]
def repeat_concat(
vid: torch.FloatTensor, # (VL ... c)
txt: torch.FloatTensor, # (TL ... c)
vid_len: torch.LongTensor, # (n*b)
txt_len: torch.LongTensor, # (b)
txt_repeat: List, # (n)
) -> torch.FloatTensor: # (L ... c)
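    # Split the packed vid/txt tokens back into per-sample chunks, repeat each
    # text chunk txt_repeat times, and interleave vid/txt chunks into one
    # packed sequence.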
vid = torch.split(vid, vid_len.tolist())
txt = torch.split(txt, txt_len.tolist())
txt = [[x] * n for x, n in zip(txt, txt_repeat)]
txt = list(chain(*txt))
return torch.cat(list(chain(*zip(vid, txt))))
def concat(
vid: torch.FloatTensor, # (VL ... c)
txt: torch.FloatTensor, # (TL ... c)
vid_len: torch.LongTensor, # (b)
txt_len: torch.LongTensor, # (b)
) -> torch.FloatTensor: # (L ... c)
vid = torch.split(vid, vid_len.tolist())
txt = torch.split(txt, txt_len.tolist())
return torch.cat(list(chain(*zip(vid, txt))))
def concat_idx(
vid_len: torch.LongTensor, # (b)
txt_len: torch.LongTensor, # (b)
) -> Tuple[
Callable,
Callable,
]:
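    # Build the index maps once: the first returned callable interleaves packed
    # vid/txt tokens per sample, the second restores the original (vid, txt)
    # split from the interleaved sequence.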
device = vid_len.device
vid_idx = torch.arange(vid_len.sum(), device=device)
txt_idx = torch.arange(len(vid_idx), len(vid_idx) + txt_len.sum(), device=device)
tgt_idx = concat(vid_idx, txt_idx, vid_len, txt_len)
src_idx = torch.argsort(tgt_idx)
return (
lambda vid, txt: torch.index_select(torch.cat([vid, txt]), 0, tgt_idx),
lambda all: torch.index_select(all, 0, src_idx).split([len(vid_idx), len(txt_idx)]),
)
def repeat_concat_idx(
vid_len: torch.LongTensor, # (n*b)
txt_len: torch.LongTensor, # (b)
txt_repeat: torch.LongTensor, # (n)
) -> Tuple[
Callable,
Callable,
]:
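    # Like concat_idx, but each sample's text tokens are repeated for every
    # window; the returned un-concat callable averages the repeated text
    # outputs back into a single chunk per sample.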
device = vid_len.device
vid_idx = torch.arange(vid_len.sum(), device=device)
txt_idx = torch.arange(len(vid_idx), len(vid_idx) + txt_len.sum(), device=device)
txt_repeat_list = txt_repeat.tolist()
tgt_idx = repeat_concat(vid_idx, txt_idx, vid_len, txt_len, txt_repeat)
src_idx = torch.argsort(tgt_idx)
txt_idx_len = len(tgt_idx) - len(vid_idx)
repeat_txt_len = (txt_len * txt_repeat).tolist()
def unconcat_coalesce(all):
vid_out, txt_out = all[src_idx].split([len(vid_idx), txt_idx_len])
txt_out_coalesced = []
for txt, repeat_time in zip(txt_out.split(repeat_txt_len), txt_repeat_list):
txt = txt.reshape(-1, repeat_time, *txt.shape[1:]).mean(1)
txt_out_coalesced.append(txt)
return vid_out, torch.cat(txt_out_coalesced)
return (
lambda vid, txt: torch.cat([vid, txt])[tgt_idx],
lambda all: unconcat_coalesce(all),
)
@dataclass
class MMArg:
vid: Any
txt: Any
def safe_pad_operation(x, padding, mode='constant', value=0.0):
"""Safe padding operation that handles Half precision only for problematic modes"""
    # Padding modes that require the Half-precision workaround
problematic_modes = ['replicate', 'reflect', 'circular']
if mode in problematic_modes:
try:
return F.pad(x, padding, mode=mode, value=value)
except RuntimeError as e:
if "not implemented for 'Half'" in str(e):
original_dtype = x.dtype
return F.pad(x.float(), padding, mode=mode, value=value).to(original_dtype)
else:
raise e
else:
        # 'constant' and other compatible modes need no workaround
return F.pad(x, padding, mode=mode, value=value)
def get_args(key: str, args: List[Any]) -> List[Any]:
return [getattr(v, key) if isinstance(v, MMArg) else v for v in args]
def get_kwargs(key: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
return {k: getattr(v, key) if isinstance(v, MMArg) else v for k, v in kwargs.items()}
def make_720Pwindows(size, num_windows, shift = False):
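    # Partition a (t, h, w) token grid into attention windows sized relative to
    # a 720p latent grid (45x80); returns a list of (t, h, w) slice triples.
    # With shift=True the windows are offset by half a window (Swin-style).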
t, h, w = size
resized_nt, resized_nh, resized_nw = num_windows
scale = sqrt((45 * 80) / (h * w))
resized_h, resized_w = round(h * scale), round(w * scale)
wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw)
wt = ceil(min(t, 30) / resized_nt)
st, sh, sw = (0.5 * shift if wt < t else 0,
0.5 * shift if wh < h else 0,
0.5 * shift if ww < w else 0)
nt, nh, nw = ceil((t - st) / wt), ceil((h - sh) / wh), ceil((w - sw) / ww)
if shift:
nt += 1 if st > 0 else 0
nh += 1 if sh > 0 else 0
nw += 1 if sw > 0 else 0
windows = []
for iw in range(nw):
w_start = max(int((iw - sw) * ww), 0)
w_end = min(int((iw - sw + 1) * ww), w)
if w_end <= w_start:
continue
for ih in range(nh):
h_start = max(int((ih - sh) * wh), 0)
h_end = min(int((ih - sh + 1) * wh), h)
if h_end <= h_start:
continue
for it in range(nt):
t_start = max(int((it - st) * wt), 0)
t_end = min(int((it - st + 1) * wt), t)
if t_end <= t_start:
continue
windows.append((slice(t_start, t_end),
slice(h_start, h_end),
slice(w_start, w_end)))
return windows
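# The rotary embedding classes below closely follow lucidrains' rotary-embedding-torch.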
class RotaryEmbedding(nn.Module):
def __init__(
self,
dim,
custom_freqs,
freqs_for = 'lang',
theta = 10000,
max_freq = 10,
num_freqs = 1,
learned_freq = False,
use_xpos = False,
xpos_scale_base = 512,
interpolate_factor = 1.,
theta_rescale_factor = 1.,
seq_before_head_dim = False,
cache_if_possible = True,
cache_max_seq_len = 8192
):
super().__init__()
theta *= theta_rescale_factor ** (dim / (dim - 2))
self.freqs_for = freqs_for
if exists(custom_freqs):
freqs = custom_freqs
elif freqs_for == 'lang':
freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
elif freqs_for == 'pixel':
freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
elif freqs_for == 'constant':
freqs = torch.ones(num_freqs).float()
self.cache_if_possible = cache_if_possible
self.cache_max_seq_len = cache_max_seq_len
self.register_buffer('cached_freqs', torch.zeros(cache_max_seq_len, dim), persistent = False)
self.cached_freqs_seq_len = 0
self.freqs = nn.Parameter(freqs, requires_grad = learned_freq)
self.learned_freq = learned_freq
# dummy for device
self.register_buffer('dummy', torch.tensor(0), persistent = False)
# default sequence dimension
self.seq_before_head_dim = seq_before_head_dim
self.default_seq_dim = -3 if seq_before_head_dim else -2
# interpolation factors
assert interpolate_factor >= 1.
self.interpolate_factor = interpolate_factor
# xpos
self.use_xpos = use_xpos
if not use_xpos:
return
scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
self.scale_base = xpos_scale_base
self.register_buffer('scale', scale, persistent = False)
self.register_buffer('cached_scales', torch.zeros(cache_max_seq_len, dim), persistent = False)
self.cached_scales_seq_len = 0
# add apply_rotary_emb as static method
self.apply_rotary_emb = staticmethod(apply_rotary_emb)
@property
def device(self):
return self.dummy.device
def get_axial_freqs(
self,
*dims,
offsets = None
):
Colon = slice(None)
all_freqs = []
# handle offset
if exists(offsets):
assert len(offsets) == len(dims)
# get frequencies for each axis
for ind, dim in enumerate(dims):
offset = 0
if exists(offsets):
offset = offsets[ind]
if self.freqs_for == 'pixel':
pos = torch.linspace(-1, 1, steps = dim, device = self.device)
else:
pos = torch.arange(dim, device = self.device)
pos = pos + offset
freqs = self.forward(pos, seq_len = dim)
all_axis = [None] * len(dims)
all_axis[ind] = Colon
new_axis_slice = (Ellipsis, *all_axis, Colon)
all_freqs.append(freqs[new_axis_slice])
# concat all freqs
all_freqs = torch.broadcast_tensors(*all_freqs)
return torch.cat(all_freqs, dim = -1)
def forward(
self,
t,
seq_len: int | None = None,
offset = 0
):
should_cache = (
self.cache_if_possible and
not self.learned_freq and
exists(seq_len) and
self.freqs_for != 'pixel' and
(offset + seq_len) <= self.cache_max_seq_len
)
if (
should_cache and \
exists(self.cached_freqs) and \
(offset + seq_len) <= self.cached_freqs_seq_len
):
return self.cached_freqs[offset:(offset + seq_len)].detach()
        freqs = self.freqs
        # outer product of positions and inverse frequencies, then duplicate each
        # frequency so it rotates a channel pair
        freqs = torch.einsum('..., f -> ... f', t.type(freqs.dtype), freqs)
        freqs = einops.repeat(freqs, '... n -> ... (n r)', r=2)
if should_cache and offset == 0:
self.cached_freqs[:seq_len] = freqs.detach()
self.cached_freqs_seq_len = seq_len
return freqs
class RotaryEmbeddingBase(nn.Module):
def __init__(self, dim: int, rope_dim: int):
super().__init__()
self.rope = RotaryEmbedding(
dim=dim // rope_dim,
freqs_for="pixel",
max_freq=256,
)
freqs = self.rope.freqs
del self.rope.freqs
self.rope.register_buffer("freqs", freqs.data)
def get_axial_freqs(self, *dims):
return self.rope.get_axial_freqs(*dims)
class RotaryEmbedding3d(RotaryEmbeddingBase):
def __init__(self, dim: int):
super().__init__(dim, rope_dim=3)
self.mm = False
def forward(
self,
q: torch.FloatTensor, # b h l d
k: torch.FloatTensor, # b h l d
size: Tuple[int, int, int],
) -> Tuple[
torch.FloatTensor,
torch.FloatTensor,
]:
T, H, W = size
freqs = self.get_axial_freqs(T, H, W)
q = rearrange(q, "b h (T H W) d -> b h T H W d", T=T, H=H, W=W)
k = rearrange(k, "b h (T H W) d -> b h T H W d", T=T, H=H, W=W)
q = apply_rotary_emb(freqs, q.float()).to(q.dtype)
k = apply_rotary_emb(freqs, k.float()).to(k.dtype)
q = rearrange(q, "b h T H W d -> b h (T H W) d")
k = rearrange(k, "b h T H W d -> b h (T H W) d")
return q, k
class MMRotaryEmbeddingBase(RotaryEmbeddingBase):
def __init__(self, dim: int, rope_dim: int):
super().__init__(dim, rope_dim)
self.rope = RotaryEmbedding(
dim=dim // rope_dim,
freqs_for="lang",
theta=10000,
)
freqs = self.rope.freqs
del self.rope.freqs
self.rope.register_buffer("freqs", freqs.data)
self.mm = True
def slice_at_dim(t, dim_slice: slice, *, dim):
dim += (t.ndim if dim < 0 else 0)
colons = [slice(None)] * t.ndim
colons[dim] = dim_slice
return t[tuple(colons)]
# rotary embedding helper functions
def rotate_half(x):
x = rearrange(x, '... (d r) -> ... d r', r = 2)
x1, x2 = x.unbind(dim = -1)
x = torch.stack((-x2, x1), dim = -1)
return rearrange(x, '... d r -> ... (d r)')
def exists(val):
return val is not None
def apply_rotary_emb(
freqs,
t,
start_index = 0,
scale = 1.,
seq_dim = -2,
freqs_seq_dim = None
):
dtype = t.dtype
if not exists(freqs_seq_dim):
if freqs.ndim == 2 or t.ndim == 3:
freqs_seq_dim = 0
if t.ndim == 3 or exists(freqs_seq_dim):
seq_len = t.shape[seq_dim]
freqs = slice_at_dim(freqs, slice(-seq_len, None), dim = freqs_seq_dim)
rot_dim = freqs.shape[-1]
end_index = start_index + rot_dim
assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
t_left = t[..., :start_index]
t_middle = t[..., start_index:end_index]
t_right = t[..., end_index:]
t_transformed = (t_middle * freqs.cos() * scale) + (rotate_half(t_middle) * freqs.sin() * scale)
out = torch.cat((t_left, t_transformed, t_right), dim=-1)
return out.type(dtype)
class NaMMRotaryEmbedding3d(MMRotaryEmbeddingBase):
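    """Multimodal RoPE for packed sequences: 3D axial rotary frequencies for
    video tokens and 1D frequencies (tiled across the three axes) for text
    tokens; video temporal positions are offset by the text length (get_freqs).
    """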
def __init__(self, dim: int):
super().__init__(dim, rope_dim=3)
def forward(
self,
vid_q: torch.FloatTensor, # L h d
vid_k: torch.FloatTensor, # L h d
vid_shape: torch.LongTensor, # B 3
txt_q: torch.FloatTensor, # L h d
txt_k: torch.FloatTensor, # L h d
txt_shape: torch.LongTensor, # B 1
cache: Cache,
) -> Tuple[
torch.FloatTensor,
torch.FloatTensor,
torch.FloatTensor,
torch.FloatTensor,
]:
vid_freqs, txt_freqs = cache(
"mmrope_freqs_3d",
lambda: self.get_freqs(vid_shape, txt_shape),
)
vid_q = rearrange(vid_q, "L h d -> h L d")
vid_k = rearrange(vid_k, "L h d -> h L d")
vid_q = apply_rotary_emb(vid_freqs, vid_q.float()).to(vid_q.dtype)
vid_k = apply_rotary_emb(vid_freqs, vid_k.float()).to(vid_k.dtype)
vid_q = rearrange(vid_q, "h L d -> L h d")
vid_k = rearrange(vid_k, "h L d -> L h d")
txt_q = rearrange(txt_q, "L h d -> h L d")
txt_k = rearrange(txt_k, "L h d -> h L d")
txt_q = apply_rotary_emb(txt_freqs, txt_q.float()).to(txt_q.dtype)
txt_k = apply_rotary_emb(txt_freqs, txt_k.float()).to(txt_k.dtype)
txt_q = rearrange(txt_q, "h L d -> L h d")
txt_k = rearrange(txt_k, "h L d -> L h d")
return vid_q, vid_k, txt_q, txt_k
def get_freqs(
self,
vid_shape: torch.LongTensor,
txt_shape: torch.LongTensor,
) -> Tuple[
torch.Tensor,
torch.Tensor,
]:
vid_freqs = self.get_axial_freqs(1024, 128, 128)
txt_freqs = self.get_axial_freqs(1024)
vid_freq_list, txt_freq_list = [], []
for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()):
vid_freq = vid_freqs[l : l + f, :h, :w].reshape(-1, vid_freqs.size(-1))
txt_freq = txt_freqs[:l].repeat(1, 3).reshape(-1, vid_freqs.size(-1))
vid_freq_list.append(vid_freq)
txt_freq_list.append(txt_freq)
return torch.cat(vid_freq_list, dim=0), torch.cat(txt_freq_list, dim=0)
class MMModule(nn.Module):
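    """Instantiates `module` separately for the video and text branches (or once,
    when shared_weights is set) and dispatches MMArg-typed args/kwargs to the
    matching branch in forward().
    """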
def __init__(
self,
module: Callable[..., nn.Module],
*args,
shared_weights: bool = False,
vid_only: bool = False,
**kwargs,
):
super().__init__()
self.shared_weights = shared_weights
self.vid_only = vid_only
if self.shared_weights:
assert get_args("vid", args) == get_args("txt", args)
assert get_kwargs("vid", kwargs) == get_kwargs("txt", kwargs)
self.all = module(*get_args("vid", args), **get_kwargs("vid", kwargs))
else:
self.vid = module(*get_args("vid", args), **get_kwargs("vid", kwargs))
self.txt = (
module(*get_args("txt", args), **get_kwargs("txt", kwargs))
if not vid_only
else None
)
def forward(
self,
vid: torch.FloatTensor,
txt: torch.FloatTensor,
*args,
**kwargs,
) -> Tuple[
torch.FloatTensor,
torch.FloatTensor,
]:
vid_module = self.vid if not self.shared_weights else self.all
vid = vid_module(vid, *get_args("vid", args), **get_kwargs("vid", kwargs))
if not self.vid_only:
txt_module = self.txt if not self.shared_weights else self.all
txt = txt_module(txt, *get_args("txt", args), **get_kwargs("txt", kwargs))
return vid, txt
def get_na_rope(rope_type: Optional[str], dim: int):
# 7b doesn't use rope
if rope_type is None:
return None
if rope_type == "mmrope3d":
return NaMMRotaryEmbedding3d(dim=dim)
class NaMMAttention(nn.Module):
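    """Joint video-text self-attention over packed variable-length sequences,
    with per-head-dim QK normalization and optional rotary embeddings.
    """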
def __init__(
self,
vid_dim: int,
txt_dim: int,
heads: int,
head_dim: int,
qk_bias: bool,
qk_norm,
qk_norm_eps: float,
rope_type: Optional[str],
rope_dim: int,
shared_weights: bool,
**kwargs,
):
super().__init__()
dim = MMArg(vid_dim, txt_dim)
inner_dim = heads * head_dim
qkv_dim = inner_dim * 3
self.head_dim = head_dim
self.proj_qkv = MMModule(
nn.Linear, dim, qkv_dim, bias=qk_bias, shared_weights=shared_weights
)
self.proj_out = MMModule(nn.Linear, inner_dim, dim, shared_weights=shared_weights)
self.norm_q = MMModule(
qk_norm,
dim=head_dim,
eps=qk_norm_eps,
elementwise_affine=True,
shared_weights=shared_weights,
)
self.norm_k = MMModule(
qk_norm,
dim=head_dim,
eps=qk_norm_eps,
elementwise_affine=True,
shared_weights=shared_weights,
)
self.rope = get_na_rope(rope_type=rope_type, dim=rope_dim)
def forward(
self,
vid: torch.FloatTensor, # l c
txt: torch.FloatTensor, # l c
vid_shape: torch.LongTensor, # b 3
txt_shape: torch.LongTensor, # b 1
cache: Cache,
) -> Tuple[
torch.FloatTensor,
torch.FloatTensor,
]:
vid_qkv, txt_qkv = self.proj_qkv(vid, txt)
vid_qkv = rearrange(vid_qkv, "l (o h d) -> l o h d", o=3, d=self.head_dim)
txt_qkv = rearrange(txt_qkv, "l (o h d) -> l o h d", o=3, d=self.head_dim)
vid_q, vid_k, vid_v = vid_qkv.unbind(1)
txt_q, txt_k, txt_v = txt_qkv.unbind(1)
vid_q, txt_q = self.norm_q(vid_q, txt_q)
vid_k, txt_k = self.norm_k(vid_k, txt_k)
if self.rope:
if self.rope.mm:
vid_q, vid_k, txt_q, txt_k = self.rope(
vid_q, vid_k, vid_shape, txt_q, txt_k, txt_shape, cache
)
else:
vid_q, vid_k = self.rope(vid_q, vid_k, vid_shape, cache)
vid_len = cache("vid_len", lambda: vid_shape.prod(-1))
txt_len = cache("txt_len", lambda: txt_shape.prod(-1))
all_len = cache("all_len", lambda: vid_len + txt_len)
b = len(vid_len)
vq, vk, vv = [t.view(b, -1, *vid_q.shape[1:]) for t in (vid_q, vid_k, vid_v)]
        tq, tk, tv = [t.view(b, -1, *txt_q.shape[1:]) for t in (txt_q, txt_k, txt_v)]
q = torch.cat([vq, tq], dim=1)
k = torch.cat([vk, tk], dim=1)
v = torch.cat([vv, tv], dim=1)
_, unconcat = cache("mm_pnp", lambda: concat_idx(vid_len, txt_len))
attn = optimized_attention(q, k, v, skip_reshape=True, skip_output_reshape=True)
attn = attn.flatten(0, 1) # to continue working with the rest of the code
attn = rearrange(attn, "l h d -> l (h d)")
vid_out, txt_out = unconcat(attn)
vid_out, txt_out = self.proj_out(vid_out, txt_out)
return vid_out, txt_out
def window(
hid: torch.FloatTensor, # (L c)
hid_shape: torch.LongTensor, # (b n)
window_fn: Callable[[torch.Tensor], List[torch.Tensor]],
):
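    # Split each sample's token grid into windows via window_fn and re-pack the
    # windows into one flat sequence; also returns the per-sample window counts.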
hid = unflatten(hid, hid_shape)
hid = list(map(window_fn, hid))
hid_windows = torch.tensor(list(map(len, hid)), device=hid_shape.device)
hid, hid_shape = flatten(list(chain(*hid)))
return hid, hid_shape, hid_windows
def window_idx(
hid_shape: torch.LongTensor, # (b n)
window_fn: Callable[[torch.Tensor], List[torch.Tensor]],
):
hid_idx = torch.arange(hid_shape.prod(-1).sum(), device=hid_shape.device).unsqueeze(-1)
tgt_idx, tgt_shape, tgt_windows = window(hid_idx, hid_shape, window_fn)
tgt_idx = tgt_idx.squeeze(-1)
src_idx = torch.argsort(tgt_idx)
return (
lambda hid: torch.index_select(hid, 0, tgt_idx),
lambda hid: torch.index_select(hid, 0, src_idx),
tgt_shape,
tgt_windows,
)
class NaSwinAttention(NaMMAttention):
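    """Windowed (Swin-style) multimodal attention: video tokens are partitioned
    into windows and each window attends jointly with the sample's text tokens;
    text outputs from all windows of a sample are averaged back together.
    """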
def __init__(
self,
*args,
window: Union[int, Tuple[int, int, int]],
        window_method: Callable,  # window partition function (plain or shifted)
**kwargs,
):
super().__init__(*args, **kwargs)
self.window = _triple(window)
self.window_method = window_method
assert all(map(lambda v: isinstance(v, int) and v >= 0, self.window))
self.window_op = window_method
def forward(
self,
vid: torch.FloatTensor, # l c
txt: torch.FloatTensor, # l c
vid_shape: torch.LongTensor, # b 3
txt_shape: torch.LongTensor, # b 1
cache: Cache,
) -> Tuple[
torch.FloatTensor,
torch.FloatTensor,
]:
vid_qkv, txt_qkv = self.proj_qkv(vid, txt)
# re-org the input seq for window attn
cache_win = cache.namespace(f"{self.window_method}_{self.window}_sd3")
def make_window(x: torch.Tensor):
t, h, w, _ = x.shape
window_slices = self.window_op((t, h, w), self.window)
return [x[st, sh, sw] for (st, sh, sw) in window_slices]
window_partition, window_reverse, window_shape, window_count = cache_win(
"win_transform",
lambda: window_idx(vid_shape, make_window),
)
vid_qkv_win = window_partition(vid_qkv)
vid_qkv_win = rearrange(vid_qkv_win, "l (o h d) -> l o h d", o=3, d=self.head_dim)
txt_qkv = rearrange(txt_qkv, "l (o h d) -> l o h d", o=3, d=self.head_dim)
vid_q, vid_k, vid_v = vid_qkv_win.unbind(1)
txt_q, txt_k, txt_v = txt_qkv.unbind(1)
vid_q, txt_q = self.norm_q(vid_q, txt_q)
vid_k, txt_k = self.norm_k(vid_k, txt_k)
txt_len = cache("txt_len", lambda: txt_shape.prod(-1))
vid_len_win = cache_win("vid_len", lambda: window_shape.prod(-1))
txt_len_win = cache_win("txt_len", lambda: txt_len.repeat_interleave(window_count))
all_len_win = cache_win("all_len", lambda: vid_len_win + txt_len_win)
concat_win, unconcat_win = cache_win(
"mm_pnp", lambda: repeat_concat_idx(vid_len_win, txt_len, window_count)
)
# window rope
if self.rope:
if self.rope.mm:
# repeat text q and k for window mmrope
_, num_h, _ = txt_q.shape
txt_q_repeat = rearrange(txt_q, "l h d -> l (h d)")
txt_q_repeat = unflatten(txt_q_repeat, txt_shape)
txt_q_repeat = [[x] * n for x, n in zip(txt_q_repeat, window_count)]
txt_q_repeat = list(chain(*txt_q_repeat))
txt_q_repeat, txt_shape_repeat = flatten(txt_q_repeat)
txt_q_repeat = rearrange(txt_q_repeat, "l (h d) -> l h d", h=num_h)
txt_k_repeat = rearrange(txt_k, "l h d -> l (h d)")
txt_k_repeat = unflatten(txt_k_repeat, txt_shape)
txt_k_repeat = [[x] * n for x, n in zip(txt_k_repeat, window_count)]
txt_k_repeat = list(chain(*txt_k_repeat))
txt_k_repeat, _ = flatten(txt_k_repeat)
txt_k_repeat = rearrange(txt_k_repeat, "l (h d) -> l h d", h=num_h)
vid_q, vid_k, txt_q, txt_k = self.rope(
vid_q, vid_k, window_shape, txt_q_repeat, txt_k_repeat, txt_shape_repeat, cache_win
)
else:
vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win)
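        # NOTE: assumes self.attn is a varlen attention callable with a
        # flash-attn style interface (packed qkv plus cu_seqlens / max_seqlen).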
out = self.attn(
q=concat_win(vid_q, txt_q).bfloat16(),
k=concat_win(vid_k, txt_k).bfloat16(),
v=concat_win(vid_v, txt_v).bfloat16(),
cu_seqlens_q=cache_win(
"vid_seqlens_q", lambda: safe_pad_operation(all_len_win.cumsum(0), (1, 0)).int()
),
cu_seqlens_k=cache_win(
"vid_seqlens_k", lambda: safe_pad_operation(all_len_win.cumsum(0), (1, 0)).int()
),
max_seqlen_q=cache_win("vid_max_seqlen_q", lambda: all_len_win.max().item()),
max_seqlen_k=cache_win("vid_max_seqlen_k", lambda: all_len_win.max().item()),
).type_as(vid_q)
# text pooling
vid_out, txt_out = unconcat_win(out)
vid_out = rearrange(vid_out, "l h d -> l (h d)")
txt_out = rearrange(txt_out, "l h d -> l (h d)")
vid_out = window_reverse(vid_out)
vid_out, txt_out = self.proj_out(vid_out, txt_out)
return vid_out, txt_out
class MLP(nn.Module):
def __init__(
self,
dim: int,
expand_ratio: int,
):
super().__init__()
self.proj_in = nn.Linear(dim, dim * expand_ratio)
self.act = nn.GELU("tanh")
self.proj_out = nn.Linear(dim * expand_ratio, dim)
def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
x = self.proj_in(x)
x = self.act(x)
x = self.proj_out(x)
return x
class SwiGLUMLP(nn.Module):
def __init__(
self,
dim: int,
expand_ratio: int,
multiple_of: int = 256,
):
super().__init__()
hidden_dim = int(2 * dim * expand_ratio / 3)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
self.proj_in_gate = nn.Linear(dim, hidden_dim, bias=False)
self.proj_out = nn.Linear(hidden_dim, dim, bias=False)
self.proj_in = nn.Linear(dim, hidden_dim, bias=False)
def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
x = self.proj_out(F.silu(self.proj_in_gate(x)) * self.proj_in(x))
return x
def get_mlp(mlp_type: Optional[str] = "normal"):
    # the 3B and 7B variants use different MLP types
if mlp_type == "normal":
return MLP
elif mlp_type == "swiglu":
return SwiGLUMLP
class NaMMSRTransformerBlock(nn.Module):
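    """MMDiT-style super-resolution block: AdaLN-modulated window attention and
    an AdaLN-modulated MLP, each with a residual connection; on the last layer
    the MLP and its modulation are applied to the video branch only.
    """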
def __init__(
self,
*,
vid_dim: int,
txt_dim: int,
emb_dim: int,
heads: int,
head_dim: int,
expand_ratio: int,
norm,
norm_eps: float,
ada,
qk_bias: bool,
qk_norm,
mlp_type: str,
shared_weights: bool,
rope_type: str,
rope_dim: int,
is_last_layer: bool,
**kwargs,
):
super().__init__()
dim = MMArg(vid_dim, txt_dim)
self.attn_norm = MMModule(norm, dim=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights,)
self.attn = NaSwinAttention(
vid_dim=vid_dim,
txt_dim=txt_dim,
heads=heads,
head_dim=head_dim,
qk_bias=qk_bias,
qk_norm=qk_norm,
qk_norm_eps=norm_eps,
rope_type=rope_type,
rope_dim=rope_dim,
shared_weights=shared_weights,
window=kwargs.pop("window", None),
window_method=kwargs.pop("window_method", None),
)
self.mlp_norm = MMModule(norm, dim=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights, vid_only=is_last_layer)
self.mlp = MMModule(
get_mlp(mlp_type),
dim=dim,
expand_ratio=expand_ratio,
shared_weights=shared_weights,
vid_only=is_last_layer
)
self.ada = MMModule(ada, dim=dim, emb_dim=emb_dim, layers=["attn", "mlp"], shared_weights=shared_weights, vid_only=is_last_layer)
self.is_last_layer = is_last_layer
def forward(
self,
vid: torch.FloatTensor, # l c
txt: torch.FloatTensor, # l c
vid_shape: torch.LongTensor, # b 3
txt_shape: torch.LongTensor, # b 1
emb: torch.FloatTensor,
cache: Cache,
) -> Tuple[
torch.FloatTensor,
torch.FloatTensor,
torch.LongTensor,
torch.LongTensor,
]:
hid_len = MMArg(
cache("vid_len", lambda: vid_shape.prod(-1)),
cache("txt_len", lambda: txt_shape.prod(-1)),
)
ada_kwargs = {
"emb": emb,
"hid_len": hid_len,
"cache": cache,
"branch_tag": MMArg("vid", "txt"),
}
vid_attn, txt_attn = self.attn_norm(vid, txt)
vid_attn, txt_attn = self.ada(vid_attn, txt_attn, layer="attn", mode="in", **ada_kwargs)
vid_attn, txt_attn = self.attn(vid_attn, txt_attn, vid_shape, txt_shape, cache)
vid_attn, txt_attn = self.ada(vid_attn, txt_attn, layer="attn", mode="out", **ada_kwargs)
vid_attn, txt_attn = (vid_attn + vid), (txt_attn + txt)
vid_mlp, txt_mlp = self.mlp_norm(vid_attn, txt_attn)
vid_mlp, txt_mlp = self.ada(vid_mlp, txt_mlp, layer="mlp", mode="in", **ada_kwargs)
vid_mlp, txt_mlp = self.mlp(vid_mlp, txt_mlp)
vid_mlp, txt_mlp = self.ada(vid_mlp, txt_mlp, layer="mlp", mode="out", **ada_kwargs)
vid_mlp, txt_mlp = (vid_mlp + vid_attn), (txt_mlp + txt_attn)
return vid_mlp, txt_mlp, vid_shape, txt_shape
class PatchOut(nn.Module):
def __init__(
self,
out_channels: int,
patch_size: Union[int, Tuple[int, int, int]],
dim: int,
):
super().__init__()
t, h, w = _triple(patch_size)
self.patch_size = t, h, w
self.proj = nn.Linear(dim, out_channels * t * h * w)
def forward(
self,
vid: torch.Tensor,
) -> torch.Tensor:
t, h, w = self.patch_size
vid = self.proj(vid)
vid = rearrange(vid, "b T H W (t h w c) -> b c (T t) (H h) (W w)", t=t, h=h, w=w)
if t > 1:
vid = vid[:, :, (t - 1) :]
return vid
class NaPatchOut(PatchOut):
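    """Un-patchify packed video tokens: project back to pixel-patch channels,
    fold the patch dims into (T, H, W), and drop the frames that NaPatchIn
    prepended when the temporal length was not divisible by the patch size.
    """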
def forward(
self,
vid: torch.FloatTensor, # l c
vid_shape: torch.LongTensor,
cache: Cache = Cache(disable=True), # for test
) -> Tuple[
torch.FloatTensor,
torch.LongTensor,
]:
cache = cache.namespace("patch")
vid_shape_before_patchify = cache.get("vid_shape_before_patchify")
t, h, w = self.patch_size
vid = self.proj(vid)
if not (t == h == w == 1):
vid = unflatten(vid, vid_shape)
for i in range(len(vid)):
vid[i] = rearrange(vid[i], "T H W (t h w c) -> (T t) (H h) (W w) c", t=t, h=h, w=w)
if t > 1 and vid_shape_before_patchify[i, 0] % t != 0:
vid[i] = vid[i][(t - vid_shape_before_patchify[i, 0] % t) :]
vid, vid_shape = flatten(vid)
return vid, vid_shape
class PatchIn(nn.Module):
def __init__(
self,
in_channels: int,
patch_size: Union[int, Tuple[int, int, int]],
dim: int,
):
super().__init__()
t, h, w = _triple(patch_size)
self.patch_size = t, h, w
self.proj = nn.Linear(in_channels * t * h * w, dim)
def forward(
self,
vid: torch.Tensor,
) -> torch.Tensor:
t, h, w = self.patch_size
if t > 1:
assert vid.size(2) % t == 1
vid = torch.cat([vid[:, :, :1]] * (t - 1) + [vid], dim=2)
vid = rearrange(vid, "b c (T t) (H h) (W w) -> b T H W (t h w c)", t=t, h=h, w=w)
vid = self.proj(vid)
return vid
class NaPatchIn(PatchIn):
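    """Patchify packed video tokens: when the temporal length is not divisible
    by the temporal patch size, the first frame is repeated to pad it, then each
    (t, h, w) patch is folded into the channel dim and projected to `dim`.
    """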
def forward(
self,
vid: torch.Tensor, # l c
vid_shape: torch.LongTensor,
cache: Cache = Cache(disable=True), # for test
) -> torch.Tensor:
cache = cache.namespace("patch")
vid_shape_before_patchify = cache("vid_shape_before_patchify", lambda: vid_shape)
t, h, w = self.patch_size
if not (t == h == w == 1):
vid = unflatten(vid, vid_shape)
for i in range(len(vid)):
if t > 1 and vid_shape_before_patchify[i, 0] % t != 0:
vid[i] = torch.cat([vid[i][:1]] * (t - vid[i].size(0) % t) + [vid[i]], dim=0)
vid[i] = rearrange(vid[i], "(T t) (H h) (W w) c -> T H W (t h w c)", t=t, h=h, w=w)
vid, vid_shape = flatten(vid)
vid = self.proj(vid)
return vid, vid_shape
def expand_dims(x: torch.Tensor, dim: int, ndim: int):
shape = x.shape
shape = shape[:dim] + (1,) * (ndim - len(shape)) + shape[dim:]
return x.reshape(shape)
class AdaSingle(nn.Module):
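    """AdaLN-single modulation: learned per-layer shift/scale/gate parameters are
    added to the shift/scale/gate chunks projected from the timestep embedding,
    then applied in place ("in" = scale-and-shift, "out" = gate).
    """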
def __init__(
self,
dim: int,
emb_dim: int,
layers: List[str],
modes: List[str] = ["in", "out"],
):
assert emb_dim == 6 * dim, "AdaSingle requires emb_dim == 6 * dim"
super().__init__()
self.dim = dim
self.emb_dim = emb_dim
self.layers = layers
for l in layers:
if "in" in modes:
self.register_parameter(f"{l}_shift", nn.Parameter(torch.randn(dim) / dim**0.5))
self.register_parameter(
f"{l}_scale", nn.Parameter(torch.randn(dim) / dim**0.5 + 1)
)
if "out" in modes:
self.register_parameter(f"{l}_gate", nn.Parameter(torch.randn(dim) / dim**0.5))
def forward(
self,
hid: torch.FloatTensor, # b ... c
emb: torch.FloatTensor, # b d
layer: str,
mode: str,
cache: Cache = Cache(disable=True),
branch_tag: str = "",
hid_len: Optional[torch.LongTensor] = None, # b
) -> torch.FloatTensor:
idx = self.layers.index(layer)
emb = rearrange(emb, "b (d l g) -> b d l g", l=len(self.layers), g=3)[..., idx, :]
emb = expand_dims(emb, 1, hid.ndim + 1)
shiftA, scaleA, gateA = emb.unbind(-1)
shiftB, scaleB, gateB = (
getattr(self, f"{layer}_shift", None),
getattr(self, f"{layer}_scale", None),
getattr(self, f"{layer}_gate", None),
)
if mode == "in":
return hid.mul_(scaleA + scaleB).add_(shiftA + shiftB)
if mode == "out":
return hid.mul_(gateA + gateB)
raise NotImplementedError
def emb_add(emb1: torch.Tensor, emb2: Optional[torch.Tensor]):
return emb1 if emb2 is None else emb1 + emb2
class TimeEmbedding(nn.Module):
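    """Sinusoidal timestep embedding followed by a small SiLU MLP that projects
    to output_dim.
    """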
def __init__(
self,
sinusoidal_dim: int,
hidden_dim: int,
output_dim: int,
):
super().__init__()
self.sinusoidal_dim = sinusoidal_dim
self.proj_in = nn.Linear(sinusoidal_dim, hidden_dim)
self.proj_hid = nn.Linear(hidden_dim, hidden_dim)
self.proj_out = nn.Linear(hidden_dim, output_dim)
self.act = nn.SiLU()
def forward(
self,
timestep: Union[int, float, torch.IntTensor, torch.FloatTensor],
device: torch.device,
dtype: torch.dtype,
) -> torch.FloatTensor:
if not torch.is_tensor(timestep):
timestep = torch.tensor([timestep], device=device, dtype=dtype)
if timestep.ndim == 0:
timestep = timestep[None]
emb = get_timestep_embedding(
timesteps=timestep,
embedding_dim=self.sinusoidal_dim,
flip_sin_to_cos=False,
downscale_freq_shift=0,
)
emb = emb.to(dtype)
emb = self.proj_in(emb)
emb = self.act(emb)
emb = self.proj_hid(emb)
emb = self.act(emb)
emb = self.proj_out(emb)
return emb
def flatten(
hid: List[torch.FloatTensor], # List of (*** c)
) -> Tuple[
torch.FloatTensor, # (L c)
torch.LongTensor, # (b n)
]:
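    # Pack a list of (..., c) tensors into a single (L, c) tensor plus a (b, n)
    # tensor recording each sample's original shape.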
assert len(hid) > 0
shape = torch.stack([torch.tensor(x.shape[:-1], device=hid[0].device) for x in hid])
hid = torch.cat([x.flatten(0, -2) for x in hid])
return hid, shape
def unflatten(
hid: torch.FloatTensor, # (L c) or (L ... c)
hid_shape: torch.LongTensor, # (b n)
) -> List[torch.Tensor]: # List of (*** c) or (*** ... c)
hid_len = hid_shape.prod(-1)
hid = hid.split(hid_len.tolist())
hid = [x.unflatten(0, s.tolist()) for x, s in zip(hid, hid_shape)]
return hid
def repeat(
hid: torch.FloatTensor, # (L c)
hid_shape: torch.LongTensor, # (b n)
pattern: str,
**kwargs: Dict[str, torch.LongTensor], # (b)
) -> Tuple[
torch.FloatTensor,
torch.LongTensor,
]:
hid = unflatten(hid, hid_shape)
kwargs = [{k: v[i].item() for k, v in kwargs.items()} for i in range(len(hid))]
return flatten([einops.repeat(h, pattern, **a) for h, a in zip(hid, kwargs)])
@dataclass
class NaDiTOutput:
vid_sample: torch.Tensor
class NaDiT(nn.Module):
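    """SeedVR diffusion transformer: processes packed variable-length video and
    text token sequences through a stack of multimodal window-attention blocks,
    conditioned on a timestep embedding.
    """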
def __init__(
self,
norm_eps,
qk_rope,
num_layers,
mlp_type,
        vid_in_channels=33,
        vid_out_channels=16,
        vid_dim=2560,
        txt_in_dim=5120,
        heads=20,
        head_dim=128,
        expand_ratio=4,
        qk_bias=False,
        patch_size=[1, 2, 2],
        shared_qkv: bool = False,
        shared_mlp: bool = False,
        window_method: Optional[Tuple[str]] = None,
        temporal_window_size: Optional[int] = None,
        temporal_shifted: bool = False,
**kwargs,
):
txt_dim = vid_dim
emb_dim = vid_dim * 6
block_type = ["mmdit_sr"] * num_layers
window = num_layers * [(4,3,3)]
ada = AdaSingle
norm = RMSNorm
qk_norm = RMSNorm
if isinstance(block_type, str):
block_type = [block_type] * num_layers
elif len(block_type) != num_layers:
            raise ValueError("The length of ``block_type`` must equal ``num_layers``.")
super().__init__()
self.vid_in = NaPatchIn(
in_channels=vid_in_channels,
patch_size=patch_size,
dim=vid_dim,
)
self.txt_in = (
nn.Linear(txt_in_dim, txt_dim)
if txt_in_dim and txt_in_dim != txt_dim
else nn.Identity()
)
self.emb_in = TimeEmbedding(
sinusoidal_dim=256,
hidden_dim=max(vid_dim, txt_dim),
output_dim=emb_dim,
)
if window is None or isinstance(window[0], int):
window = [window] * num_layers
if window_method is None or isinstance(window_method, str):
window_method = [window_method] * num_layers
if temporal_window_size is None or isinstance(temporal_window_size, int):
temporal_window_size = [temporal_window_size] * num_layers
if temporal_shifted is None or isinstance(temporal_shifted, bool):
temporal_shifted = [temporal_shifted] * num_layers
self.blocks = nn.ModuleList(
[
NaMMSRTransformerBlock(
vid_dim=vid_dim,
txt_dim=txt_dim,
emb_dim=emb_dim,
heads=heads,
head_dim=head_dim,
expand_ratio=expand_ratio,
norm=norm,
norm_eps=norm_eps,
ada=ada,
qk_bias=qk_bias,
qk_rope=qk_rope,
qk_norm=qk_norm,
shared_qkv=shared_qkv,
shared_mlp=shared_mlp,
mlp_type=mlp_type,
window=window[i],
window_method=window_method[i],
temporal_window_size=temporal_window_size[i],
temporal_shifted=temporal_shifted[i],
**kwargs,
)
for i in range(num_layers)
]
)
self.vid_out = NaPatchOut(
out_channels=vid_out_channels,
patch_size=patch_size,
dim=vid_dim,
)
self.need_txt_repeat = block_type[0] in [
"mmdit_stwin",
"mmdit_stwin_spatial",
"mmdit_stwin_3d_spatial",
]
def set_gradient_checkpointing(self, enable: bool):
self.gradient_checkpointing = enable
def forward(
self,
vid: torch.FloatTensor, # l c
txt: torch.FloatTensor, # l c
vid_shape: torch.LongTensor, # b 3
txt_shape: torch.LongTensor, # b 1
timestep: Union[int, float, torch.IntTensor, torch.FloatTensor], # b
disable_cache: bool = True, # for test
):
# Text input.
if txt_shape.size(-1) == 1 and self.need_txt_repeat:
txt, txt_shape = repeat(txt, txt_shape, "l c -> t l c", t=vid_shape[:, 0])
# slice vid after patching in when using sequence parallelism
txt = self.txt_in(txt)
# Video input.
# Sequence parallel slicing is done inside patching class.
vid, vid_shape = self.vid_in(vid, vid_shape)
# Embedding input.
emb = self.emb_in(timestep, device=vid.device, dtype=vid.dtype)
# Body
cache = Cache(disable=disable_cache)
for i, block in enumerate(self.blocks):
vid, txt, vid_shape, txt_shape = block(
vid=vid,
txt=txt,
vid_shape=vid_shape,
txt_shape=txt_shape,
emb=emb,
cache=cache,
)
vid, vid_shape = self.vid_out(vid, vid_shape, cache)
return NaDiTOutput(vid_sample=vid)