ComfyUI/comfy/ldm/seedvr/model.py
2026-07-02 22:59:38 -04:00

1359 lines
48 KiB
Python

from dataclasses import dataclass
from typing import Optional, Tuple, Union, List, Dict, Any, Callable
import torch.nn.functional as F
from math import ceil, pi
import torch
from itertools import accumulate, chain
from comfy.ldm.modules.diffusionmodules.model import get_timestep_embedding
from comfy.ldm.seedvr.attention import optimized_var_attention
from torch.nn.modules.utils import _triple
from torch import nn
import math
from comfy.ldm.flux.math import apply_rope1
from comfy.ldm.seedvr.constants import (
BYTEDANCE_720P_REF_AREA,
BYTEDANCE_MAX_TEMPORAL_WINDOW,
BYTEDANCE_ROPE_MAX_FREQ,
BYTEDANCE_SINUSOIDAL_DIM,
ROPE_THETA,
SEEDVR2_7B_MLP_CHUNK,
SEEDVR2_7B_VID_DIM,
SEEDVR2_LATENT_CHANNELS,
SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS,
)
import comfy.model_management
class Cache:
    """Memoize computed values under string keys.

    A cache can be disabled (every lookup recomputes and nothing is stored)
    and can hand out namespaced views that share the same backing dict but
    prepend a prefix to every key.
    """

    def __init__(self, disable=False, prefix="", cache=None):
        # Use the supplied dict when given so that namespaced views share storage.
        self.cache = {} if cache is None else cache
        self.disable = disable
        self.prefix = prefix

    def __call__(self, key: str, fn: Callable):
        """Return the value stored for ``key``; call ``fn`` to fill it on a miss."""
        if self.disable:
            return fn()
        full_key = self.prefix + key
        if full_key in self.cache:
            return self.cache[full_key]
        value = fn()
        self.cache[full_key] = value
        return value

    def namespace(self, namespace: str):
        """Return a view on the same storage whose keys get ``namespace + "."`` prepended."""
        scoped_prefix = self.prefix + namespace + "."
        return Cache(disable=self.disable, prefix=scoped_prefix, cache=self.cache)
def repeat_concat(
    vid: torch.FloatTensor,  # (VL ... c)
    txt: torch.FloatTensor,  # (TL ... c)
    vid_len: torch.LongTensor,  # (n*b)
    txt_len: torch.LongTensor,  # (b)
    txt_repeat: List,  # (n)
) -> torch.FloatTensor:  # (L ... c)
    """Interleave video segments with repeated text segments along dim 0.

    ``vid`` and ``txt`` are split into segments of lengths ``vid_len`` and
    ``txt_len``. Each text segment is repeated ``txt_repeat[i]`` times, and
    the result alternates one video segment, one text segment, and so on.
    """
    vid_segments = torch.split(vid, vid_len.tolist())
    txt_segments = torch.split(txt, txt_len.tolist())

    repeated_txt = []
    for segment, times in zip(txt_segments, txt_repeat):
        repeated_txt.extend([segment] * times)

    interleaved = []
    for vid_segment, txt_segment in zip(vid_segments, repeated_txt):
        interleaved.append(vid_segment)
        interleaved.append(txt_segment)
    return torch.cat(interleaved)
def repeat_concat_idx(
    vid_len: torch.LongTensor,  # (n*b)
    txt_len: torch.LongTensor,  # (b)
    txt_repeat: torch.LongTensor,  # (n)
) -> Tuple[
    Callable,
    Callable,
]:
    """Precompute index maps for interleaving video windows with repeated text.

    Returns two callables:

    * ``concat(vid, txt)`` gathers ``torch.cat([vid, txt])`` into the
      interleaved layout produced by ``repeat_concat``.
    * ``unconcat(all)`` undoes that layout and returns ``(vid, txt)``, where
      the repeated copies of each text token are averaged back into one.
    """
    device = vid_len.device
    vid_idx = torch.arange(vid_len.sum(), device=device)
    num_vid = len(vid_idx)
    txt_idx = torch.arange(num_vid, num_vid + txt_len.sum(), device=device)
    txt_repeat_list = txt_repeat.tolist()

    # Running repeat_concat on plain indices yields the gather map directly.
    tgt_idx = repeat_concat(vid_idx, txt_idx, vid_len, txt_len, txt_repeat_list)
    src_idx = torch.argsort(tgt_idx)
    num_repeated_txt = len(tgt_idx) - num_vid
    repeat_txt_len = (txt_len * txt_repeat).tolist()

    def concat(vid, txt):
        return torch.cat([vid, txt])[tgt_idx]

    def unconcat_coalesce(all_tokens):
        vid_out, txt_out = all_tokens[src_idx].split([num_vid, num_repeated_txt])
        # After sorting, the copies of each text token are adjacent; average them.
        coalesced = [
            segment.reshape(-1, times, *segment.shape[1:]).mean(1)
            for segment, times in zip(txt_out.split(repeat_txt_len), txt_repeat_list)
        ]
        return vid_out, torch.cat(coalesced)

    return concat, unconcat_coalesce
def cumulative_lengths(lengths):
    """Return the running totals of ``lengths`` with a leading 0."""
    boundaries = [0]
    boundaries.extend(accumulate(lengths))
    return boundaries
@dataclass
class MMArg:
    """Holds one value for the video branch and one for the text branch.

    ``get_args`` / ``get_kwargs`` replace an ``MMArg`` with the attribute
    matching the branch they are resolving for.
    """

    vid: Any  # value selected when resolving with key "vid"
    txt: Any  # value selected when resolving with key "txt"
def get_args(key: str, args: List[Any]) -> List[Any]:
    """Resolve positional arguments for one branch.

    Every ``MMArg`` in ``args`` is replaced by its ``key`` attribute; all
    other values are passed through unchanged.
    """
    resolved = []
    for value in args:
        if isinstance(value, MMArg):
            value = getattr(value, key)
        resolved.append(value)
    return resolved
def get_kwargs(key: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Resolve keyword arguments for one branch.

    Every ``MMArg`` value in ``kwargs`` is replaced by its ``key`` attribute;
    all other values are passed through unchanged.
    """
    resolved = {}
    for name, value in kwargs.items():
        if isinstance(value, MMArg):
            value = getattr(value, key)
        resolved[name] = value
    return resolved
def get_window_op(name: str):
    """Return the window-partition function registered under ``name``.

    Raises:
        ValueError: if ``name`` is not one of the known windowing methods.
    """
    window_op = None
    if name == "720pwin_by_size_bysize":
        window_op = make_720Pwindows_bysize
    elif name == "720pswin_by_size_bysize":
        window_op = make_shifted_720Pwindows_bysize
    if window_op is None:
        raise ValueError(f"Unknown windowing method: {name}")
    return window_op
def make_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tuple[int, int, int]):
    """Split a ``(t, h, w)`` volume into a regular grid of window slices.

    The spatial window size is derived from the frame rescaled so that its
    area equals ``BYTEDANCE_720P_REF_AREA``; the temporal window size is
    derived from ``min(t, BYTEDANCE_MAX_TEMPORAL_WINDOW)``. The resulting
    sizes are then applied to the actual ``t``, ``h`` and ``w``.

    Returns:
        A list of ``(t_slice, h_slice, w_slice)`` tuples. Width is the
        outermost loop and time the innermost; empty windows are dropped.
    """
    t, h, w = size
    resized_nt, resized_nh, resized_nw = num_windows

    scale = math.sqrt(BYTEDANCE_720P_REF_AREA / (h * w))
    resized_h = round(h * scale)
    resized_w = round(w * scale)
    wh = ceil(resized_h / resized_nh)
    ww = ceil(resized_w / resized_nw)
    wt = ceil(min(t, BYTEDANCE_MAX_TEMPORAL_WINDOW) / resized_nt)
    nt, nh, nw = ceil(t / wt), ceil(h / wh), ceil(w / ww)

    def axis_slices(count, step, limit):
        # Non-empty [start, stop) spans of width `step`, clipped to `limit`.
        spans = []
        for i in range(count):
            start = i * step
            stop = min((i + 1) * step, limit)
            if stop > start:
                spans.append(slice(start, stop))
        return spans

    t_slices = axis_slices(nt, wt, t)
    h_slices = axis_slices(nh, wh, h)
    w_slices = axis_slices(nw, ww, w)
    return [
        (t_slice, h_slice, w_slice)
        for w_slice in w_slices
        for h_slice in h_slices
        for t_slice in t_slices
    ]
def make_shifted_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tuple[int, int, int]):
    """Split a ``(t, h, w)`` volume into windows shifted by half a window.

    Window sizes are computed exactly as in ``make_720Pwindows_bysize``. On
    every axis where the window is smaller than the axis, the grid is moved
    back by half a window and one extra window is added; an axis that fits
    in a single window is left unshifted.

    Returns:
        A list of ``(t_slice, h_slice, w_slice)`` tuples. Width is the
        outermost loop and time the innermost; empty windows are dropped.
    """
    t, h, w = size
    resized_nt, resized_nh, resized_nw = num_windows

    scale = math.sqrt(BYTEDANCE_720P_REF_AREA / (h * w))
    resized_h = round(h * scale)
    resized_w = round(w * scale)
    wh = ceil(resized_h / resized_nh)
    ww = ceil(resized_w / resized_nw)
    wt = ceil(min(t, BYTEDANCE_MAX_TEMPORAL_WINDOW) / resized_nt)

    # Half-window shift, only on axes that need more than one window.
    st = 0.5 if wt < t else 0
    sh = 0.5 if wh < h else 0
    sw = 0.5 if ww < w else 0

    nt, nh, nw = ceil((t - st) / wt), ceil((h - sh) / wh), ceil((w - sw) / ww)
    nt = nt + 1 if st > 0 else 1
    nh = nh + 1 if sh > 0 else 1
    nw = nw + 1 if sw > 0 else 1

    def axis_slices(count, shift, step, limit):
        # Non-empty shifted spans, clipped to [0, limit].
        spans = []
        for i in range(count):
            start = max(int((i - shift) * step), 0)
            stop = min(int((i - shift + 1) * step), limit)
            if stop > start:
                spans.append(slice(start, stop))
        return spans

    t_slices = axis_slices(nt, st, wt, t)
    h_slices = axis_slices(nh, sh, wh, h)
    w_slices = axis_slices(nw, sw, ww, w)
    return [
        (t_slice, h_slice, w_slice)
        for w_slice in w_slices
        for h_slice in h_slices
        for t_slice in t_slices
    ]
class RotaryEmbedding(nn.Module):
    """Produces rotary-embedding angles for 1-D positions or axial grids.

    ``freqs_for='lang'`` uses ``1 / theta ** (2i / dim)`` frequencies;
    ``freqs_for='pixel'`` uses ``dim // 2`` frequencies evenly spaced from
    ``1`` to ``max_freq / 2`` and multiplied by pi.
    """

    def __init__(
        self,
        dim,
        freqs_for = 'lang',
        theta = 10000,
        max_freq = 10,
    ):
        super().__init__()
        self.freqs_for = freqs_for
        if freqs_for == 'lang':
            exponents = torch.arange(0, dim, 2)[:(dim // 2)].float() / dim
            freqs = 1. / (theta ** exponents)
        elif freqs_for == 'pixel':
            freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
        else:
            raise ValueError(f"Unknown rotary frequency type: {freqs_for}")
        self.register_buffer("freqs", freqs)

    @property
    def device(self):
        """Device on which the frequency buffer currently lives."""
        return self.freqs.device

    def get_axial_freqs(
        self,
        *dims,
        offsets = None
    ):
        """Build angles for an N-D grid and concatenate them on the last dim.

        One set of angles is computed per axis in ``dims`` and broadcast over
        the other axes. In 'pixel' mode positions are ``linspace(-1, 1)``,
        otherwise ``arange``; ``offsets`` (one per axis) are added to them.

        Raises:
            ValueError: if ``offsets`` is given with a length other than ``len(dims)``.
        """
        has_offsets = offsets is not None
        if has_offsets and len(offsets) != len(dims):
            raise ValueError(f"SeedVR2 rotary offsets length must match dims length, got {len(offsets)} and {len(dims)}.")

        num_axes = len(dims)
        per_axis = []
        for axis, size in enumerate(dims):
            shift = offsets[axis] if has_offsets else 0
            if self.freqs_for == 'pixel':
                pos = torch.linspace(-1, 1, steps = size, device = self.device)
            else:
                pos = torch.arange(size, device = self.device)
            pos = pos + shift
            angles = self.forward(pos)
            # Keep this axis and the feature dim; add singleton dims for the other axes.
            index = [None] * num_axes
            index[axis] = slice(None)
            per_axis.append(angles[(Ellipsis, *index, slice(None))])
        return torch.cat(torch.broadcast_tensors(*per_axis), dim = -1)

    def forward(
        self,
        t,
    ):
        """Return ``t[..., None] * freqs`` with every angle duplicated pairwise."""
        base = self.freqs
        angles = torch.einsum('..., f -> ... f', t.type(base.dtype), base)
        # (..., f) -> (..., 2f) as [a0, a0, a1, a1, ...]
        return angles.unsqueeze(-1).expand(*angles.shape, 2).flatten(-2)
class RotaryEmbeddingBase(nn.Module):
    """Wraps a 'pixel' ``RotaryEmbedding`` sized ``dim // rope_dim`` per axis."""

    def __init__(self, dim: int, rope_dim: int):
        super().__init__()
        per_axis_dim = dim // rope_dim
        self.rope = RotaryEmbedding(
            dim=per_axis_dim,
            freqs_for="pixel",
            max_freq=BYTEDANCE_ROPE_MAX_FREQ,
        )

    def get_axial_freqs(self, *dims):
        """Delegate to the wrapped rope's ``get_axial_freqs`` (no offsets)."""
        return self.rope.get_axial_freqs(*dims)
class RotaryEmbedding3d(RotaryEmbeddingBase):
    """Three-axis rotary embedding; ``mm`` is False for this variant."""

    def __init__(self, dim: int):
        super().__init__(dim, rope_dim=3)
        # Flag read by the attention layer to choose the video-only rope call.
        self.mm = False
class NaRotaryEmbedding3d(RotaryEmbedding3d):
    """Applies 3-D rotary embedding to flattened q/k of several video windows."""

    def forward(
        self,
        q: torch.FloatTensor,
        k: torch.FloatTensor,
        shape: torch.LongTensor,
        cache: Cache,
    ) -> Tuple[
        torch.FloatTensor,
        torch.FloatTensor,
    ]:
        """Rotate ``q`` and ``k`` with angles built from ``shape``.

        The angles are computed once per cache under ``"rope_freqs_3d"``. The
        rotation is done in float32 and cast back to each input's dtype.
        """
        freqs = cache("rope_freqs_3d", lambda: self.get_freqs(shape))
        freqs = freqs.to(device=q.device)

        def rotate(x):
            # Move the sequence axis to dim -2 for the rotation, then restore it.
            swapped = x.transpose(0, 1)
            rotated = _apply_seedvr2_rotary_emb(freqs, swapped.float()).to(swapped.dtype)
            return rotated.transpose(0, 1)

        q = rotate(q)
        k = rotate(k)
        return q, k

    @torch._dynamo.disable
    def get_freqs(
        self,
        shape: torch.LongTensor,
    ) -> torch.Tensor:
        """Concatenate flattened axial angles for every ``(f, h, w)`` row of ``shape``."""
        # Primary provenance: ByteDance-Seed/SeedVR models/dit/rope.py builds
        # 7B pixel RoPE with the interleaved-angle convention, not Comfy's
        # Flux freqs_cis matrix.
        plain_rope = RotaryEmbedding(
            dim=self.rope.freqs.numel() * 2,
            freqs_for="pixel",
            max_freq=BYTEDANCE_ROPE_MAX_FREQ,
        ).to(self.rope.device)

        per_window = []
        for f, h, w in shape.tolist():
            grid = plain_rope.get_axial_freqs(f, h, w)
            per_window.append(grid.view(-1, grid.size(-1)))
        return torch.cat(per_window, dim=0)
class MMRotaryEmbeddingBase(RotaryEmbeddingBase):
    """Multimodal rope base: swaps in 'lang' frequencies and sets ``mm`` to True."""

    def __init__(self, dim: int, rope_dim: int):
        super().__init__(dim, rope_dim)
        # Replace the parent's 'pixel' rope with a theta-based one.
        lang_rope = RotaryEmbedding(
            dim=dim // rope_dim,
            freqs_for="lang",
            theta=ROPE_THETA,
        )
        self.rope = lang_rope
        self.mm = True
def slice_at_dim(t, dim_slice: slice, *, dim):
    """Apply ``dim_slice`` along dimension ``dim`` of ``t`` (negative ``dim`` allowed)."""
    if dim < 0:
        dim = dim + t.ndim
    index = [slice(None) for _ in range(t.ndim)]
    index[dim] = dim_slice
    return t[tuple(index)]
def rotate_half(x):
    """Map each adjacent pair ``(a, b)`` on the last dim to ``(-b, a)``."""
    pairs = x.reshape(*x.shape[:-1], x.shape[-1] // 2, 2)
    first, second = pairs[..., 0], pairs[..., 1]
    return torch.stack((-second, first), dim = -1).flatten(-2)
def exists(val):
    """Return True for any value other than ``None`` (falsy values included)."""
    return val is not None
def _apply_seedvr2_rotary_emb(
    freqs: torch.Tensor,
    t: torch.Tensor,
    start_index: int = 0,
    scale: float = 1.0,
    seq_dim: int = -2,
    freqs_seq_dim: int | None = None,
) -> torch.Tensor:
    """Rotate a slice of ``t``'s last dim using interleaved-pair angles ``freqs``.

    Features ``[start_index, start_index + freqs.shape[-1])`` are rotated;
    features outside that range are passed through unchanged. When ``freqs``
    is 2-D or ``t`` is 3-D, ``freqs`` is first cropped to its last
    ``t.shape[seq_dim]`` rows. The result has ``t``'s dtype.
    """
    out_dtype = t.dtype

    if freqs_seq_dim is None and (freqs.ndim == 2 or t.ndim == 3):
        freqs_seq_dim = 0
    if t.ndim == 3 or freqs_seq_dim is not None:
        seq_len = t.shape[seq_dim]
        freqs = slice_at_dim(freqs, slice(-seq_len, None), dim=freqs_seq_dim)

    rot_feats = freqs.shape[-1]
    end_index = start_index + rot_feats

    untouched_head = t[..., :start_index]
    to_rotate = t[..., start_index:end_index]
    untouched_tail = t[..., end_index:]

    freqs = freqs.to(device=to_rotate.device, dtype=to_rotate.dtype)
    cos_part = freqs.cos() * scale
    sin_part = freqs.sin() * scale
    rotated = (to_rotate * cos_part) + (rotate_half(to_rotate) * sin_part)

    return torch.cat((untouched_head, rotated, untouched_tail), dim=-1).to(out_dtype)
def _to_flux_freqs_cis(freqs_interleaved: torch.Tensor) -> torch.Tensor:
    """Convert pairwise-duplicated angles into 2x2 rotation matrices.

    Every second entry of the last dim is taken as an angle ``a`` and turned
    into ``[[cos a, -sin a], [sin a, cos a]]``; the output has shape
    ``(..., n_angles, 2, 2)`` in float32.
    """
    angles = freqs_interleaved[..., ::2].float()
    cos, sin = torch.cos(angles), torch.sin(angles)
    entries = torch.stack([cos, -sin, sin, cos], dim=-1)
    return entries.reshape(*angles.shape, 2, 2)
def _apply_rope1_partial(t: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """Apply ``apply_rope1`` chunk by chunk along the sequence axis (dim -2).

    Only the first ``2 * freqs_cis.shape[-3]`` features are rotated. The input
    is cloned when it requires grad or when training; otherwise ``t`` itself
    is overwritten and returned.
    """
    must_copy = t.requires_grad or comfy.model_management.in_training
    out = t.clone() if must_copy else t

    rot_d = 2 * freqs_cis.shape[-3]
    rotate_all_features = rot_d == out.shape[-1]
    seq_len = out.shape[-2]
    step = SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS

    for start in range(0, seq_len, step):
        end = min(start + step, seq_len)
        freqs_chunk = freqs_cis[start:end]
        if rotate_all_features:
            out[..., start:end, :] = apply_rope1(out[..., start:end, :], freqs_chunk).to(out.dtype)
        else:
            out[..., start:end, :rot_d] = apply_rope1(out[..., start:end, :rot_d], freqs_chunk).to(out.dtype)
    return out
class NaMMRotaryEmbedding3d(MMRotaryEmbeddingBase):
    """Multimodal 3-D rope: rotates video and text q/k with shared 'lang' angles."""

    def __init__(self, dim: int):
        super().__init__(dim, rope_dim=3)

    def forward(
        self,
        vid_q: torch.FloatTensor,  # L h d
        vid_k: torch.FloatTensor,  # L h d
        vid_shape: torch.LongTensor,  # B 3
        txt_q: torch.FloatTensor,  # L h d
        txt_k: torch.FloatTensor,  # L h d
        txt_shape: torch.LongTensor,  # B 1
        cache: Cache,
    ) -> Tuple[
        torch.FloatTensor,
        torch.FloatTensor,
        torch.FloatTensor,
        torch.FloatTensor,
    ]:
        """Rotate video and text q/k; angles are cached under ``"mmrope_freqs_3d"``."""
        vid_freqs, txt_freqs = cache(
            "mmrope_freqs_3d",
            lambda: self.get_freqs(vid_shape, txt_shape),
        )
        target_device = vid_q.device
        if vid_freqs.device != target_device:
            vid_freqs = vid_freqs.to(target_device)
        if txt_freqs.device != target_device:
            txt_freqs = txt_freqs.to(target_device)

        def rotate(x, freqs):
            # Put the sequence axis at dim -2 for the rotation, then restore (L h d).
            return _apply_rope1_partial(x.transpose(0, 1), freqs).transpose(0, 1)

        vid_q = rotate(vid_q, vid_freqs)
        vid_k = rotate(vid_k, vid_freqs)
        txt_q = rotate(txt_q, txt_freqs)
        txt_k = rotate(txt_k, txt_freqs)
        return vid_q, vid_k, txt_q, txt_k

    @torch._dynamo.disable  # Disable compilation: .tolist() is data-dependent and causes graph breaks
    def get_freqs(
        self,
        vid_shape: torch.LongTensor,
        txt_shape: torch.LongTensor,
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
    ]:
        """Build rotation matrices for every video and text token.

        For each sample the video temporal positions start at ``l`` (the
        sample's text length), so video frames occupy ``[l, l + f)`` while
        text tokens occupy ``[0, l)`` on all three axes.
        """
        samples = list(zip(vid_shape.tolist(), txt_shape[:, 0].tolist()))

        max_temporal = max_height = max_width = max_txt_len = 0
        for (f, h, w), l in samples:
            max_temporal = max(max_temporal, l + f)
            max_height = max(max_height, h)
            max_width = max(max_width, w)
            max_txt_len = max(max_txt_len, l)

        autocast_device = "cuda" if torch.cuda.is_available() else "cpu"
        with torch.amp.autocast(autocast_device, enabled=False):
            vid_freqs = self.get_axial_freqs(
                max_temporal + 16,
                max_height + 4,
                max_width + 4,
            ).float()
            txt_freqs = self.get_axial_freqs(max_txt_len + 16)

        feat = vid_freqs.size(-1)
        vid_freq_list, txt_freq_list = [], []
        for (f, h, w), l in samples:
            vid_freq_list.append(vid_freqs[l : l + f, :h, :w].reshape(-1, feat))
            # Text reuses the same 1-D angles on all three axes.
            txt_freq_list.append(txt_freqs[:l].repeat(1, 3).reshape(-1, feat))

        vid_freqs_interleaved = torch.cat(vid_freq_list, dim=0)
        txt_freqs_interleaved = torch.cat(txt_freq_list, dim=0)
        return _to_flux_freqs_cis(vid_freqs_interleaved), _to_flux_freqs_cis(txt_freqs_interleaved)
class MMModule(nn.Module):
    """Runs one module per modality (video / text), optionally sharing weights.

    Constructor and forward arguments may be ``MMArg`` instances, in which
    case each branch receives its own attribute. With ``shared_weights`` a
    single module ``self.all`` serves both branches; otherwise ``self.vid``
    and ``self.txt`` are separate, and ``self.txt`` is ``None`` when
    ``vid_only`` is set.
    """

    def __init__(
        self,
        module: Callable[..., nn.Module],
        *args,
        shared_weights: bool = False,
        vid_only: bool = False,
        **kwargs,
    ):
        super().__init__()
        self.shared_weights = shared_weights
        self.vid_only = vid_only

        if not self.shared_weights:
            self.vid = module(*get_args("vid", args), **get_kwargs("vid", kwargs))
            if vid_only:
                self.txt = None
            else:
                self.txt = module(*get_args("txt", args), **get_kwargs("txt", kwargs))
            return

        # A shared module can only be built when both branches agree on every argument.
        if get_args("vid", args) != get_args("txt", args):
            raise ValueError("SeedVR2 shared MMModule requires matching vid/txt args.")
        if get_kwargs("vid", kwargs) != get_kwargs("txt", kwargs):
            raise ValueError("SeedVR2 shared MMModule requires matching vid/txt kwargs.")
        self.all = module(*get_args("vid", args), **get_kwargs("vid", kwargs))

    def forward(
        self,
        vid: torch.FloatTensor,
        txt: torch.FloatTensor,
        *args,
        **kwargs,
    ) -> Tuple[
        torch.FloatTensor,
        torch.FloatTensor,
    ]:
        """Apply the video module, then the text module unless ``vid_only`` is set.

        ``txt`` is moved to the device and dtype of the video output before
        it is processed; with ``vid_only`` it is returned untouched.
        """
        shared = self.shared_weights

        vid_module = self.all if shared else self.vid
        vid = vid_module(vid, *get_args("vid", args), **get_kwargs("vid", kwargs))
        if self.vid_only:
            return vid, txt

        txt_module = self.all if shared else self.txt
        txt = txt.to(device=vid.device, dtype=vid.dtype)
        txt = txt_module(txt, *get_args("txt", args), **get_kwargs("txt", kwargs))
        return vid, txt
def get_na_rope(rope_type: Optional[str], dim: int):
    """Build the rope module for ``rope_type``; ``None`` means no rope.

    Raises:
        ValueError: if ``rope_type`` is neither ``None``, "rope3d" nor "mmrope3d".
    """
    if rope_type is None:
        return None
    if rope_type == "rope3d":
        rope_cls = NaRotaryEmbedding3d
    elif rope_type == "mmrope3d":
        rope_cls = NaMMRotaryEmbedding3d
    else:
        raise ValueError(f"Unknown SeedVR2 rope type: {rope_type}")
    return rope_cls(dim=dim)
class NaMMAttention(nn.Module):
    """Holds the projections, q/k norms and rope shared by the multimodal attention layers.

    This class only builds submodules; subclasses provide ``forward``.
    """

    def __init__(
        self,
        vid_dim: int,
        txt_dim: int,
        heads: int,
        head_dim: int,
        qk_bias: bool,
        qk_norm,
        qk_norm_eps: float,
        rope_type: Optional[str],
        rope_dim: int,
        shared_weights: bool,
        device, dtype, operations,
    ):
        super().__init__()
        dim = MMArg(vid_dim, txt_dim)
        inner_dim = heads * head_dim
        qkv_dim = inner_dim * 3
        self.heads = heads
        self.head_dim = head_dim

        # Fused q/k/v projection and output projection, one per modality.
        self.proj_qkv = MMModule(
            operations.Linear,
            dim,
            qkv_dim,
            bias=qk_bias,
            shared_weights=shared_weights,
            device=device,
            dtype=dtype,
        )
        self.proj_out = MMModule(
            operations.Linear,
            inner_dim,
            dim,
            shared_weights=shared_weights,
            device=device,
            dtype=dtype,
        )

        # q and k are normalised per head with identical settings.
        norm_kwargs = {
            "normalized_shape": head_dim,
            "eps": qk_norm_eps,
            "elementwise_affine": True,
            "shared_weights": shared_weights,
            "device": device,
            "dtype": dtype,
        }
        self.norm_q = MMModule(qk_norm, **norm_kwargs)
        self.norm_k = MMModule(qk_norm, **norm_kwargs)

        self.rope = get_na_rope(rope_type=rope_type, dim=rope_dim)
def window(
    hid: torch.FloatTensor,  # (L c)
    hid_shape: torch.LongTensor,  # (b n)
    window_fn: Callable[[torch.Tensor], List[torch.Tensor]],
):
    """Partition every sample of a flattened batch into windows.

    Returns:
        ``(hid, hid_shape, hid_windows, hid_len_list, hid_windows_list)``:
        the re-flattened windows, their shapes, the number of windows per
        sample (as a tensor and as a list) and the token count of each window.
    """
    windows_per_sample = [window_fn(sample) for sample in unflatten(hid, hid_shape)]
    hid_windows_list = [len(sample_windows) for sample_windows in windows_per_sample]
    hid_windows = torch.as_tensor(hid_windows_list, device=hid_shape.device)

    all_windows = list(chain.from_iterable(windows_per_sample))
    hid_len_list = [math.prod(win.shape[:-1]) for win in all_windows]
    hid, hid_shape = flatten(all_windows)
    return hid, hid_shape, hid_windows, hid_len_list, hid_windows_list
def window_idx(
    hid_shape: torch.LongTensor,  # (b n)
    window_fn: Callable[[torch.Tensor], List[torch.Tensor]],
):
    """Precompute gather indices for window partitioning and its inverse.

    The windowing is run once on a tensor of token indices, so the returned
    callables only need an ``index_select`` along dim 0.

    Returns:
        ``(partition, reverse, tgt_shape, tgt_windows, tgt_len_list,
        tgt_windows_list)`` where the last four come from ``window``.
    """
    num_tokens = hid_shape.prod(-1).sum()
    hid_idx = torch.arange(num_tokens, device=hid_shape.device).unsqueeze(-1)
    tgt_idx, tgt_shape, tgt_windows, tgt_len_list, tgt_windows_list = window(hid_idx, hid_shape, window_fn)
    tgt_idx = tgt_idx.squeeze(-1)
    src_idx = torch.argsort(tgt_idx)

    def partition(hid):
        return torch.index_select(hid, 0, tgt_idx)

    def reverse(hid):
        return torch.index_select(hid, 0, src_idx)

    return (
        partition,
        reverse,
        tgt_shape,
        tgt_windows,
        tgt_len_list,
        tgt_windows_list,
    )
class NaSwinAttention(NaMMAttention):
    """Windowed multimodal attention: each video window attends with its sample's text.

    Video tokens are partitioned into 3-D windows by ``window_method``; the
    text tokens of a sample are attached to every window of that sample, and
    variable-length attention is run over the ``[window, text]`` sequences.
    """

    def __init__(
        self,
        *args,
        window: Union[int, Tuple[int, int, int]],
        window_method: str,
        version: bool = False,
        **kwargs,
    ):
        """
        Args:
            window: number of windows per ``(t, h, w)`` axis, or one int for all three.
            window_method: name understood by ``get_window_op``.
            version: when True, the rope is always called with video q/k only.

        Raises:
            ValueError: if ``window`` holds anything but positive integers, or
                ``window_method`` is unknown.
        """
        super().__init__(*args, **kwargs)
        self.version_7b = version
        self.window = _triple(window)
        self.window_method = window_method
        # The window ops divide by every entry, so zero would only fail later
        # with a ZeroDivisionError inside forward; reject it up front.
        if not all(isinstance(v, int) and v > 0 for v in self.window):
            raise ValueError(f"SeedVR2 window must contain positive integers, got {self.window}.")
        self.window_op = get_window_op(window_method)

    def _repeat_txt_per_window(self, txt, txt_shape, window_count_list):
        """Repeat each sample's text ``(l, h, d)`` once per window of that sample.

        Returns the repeated tensor, reshaped back to ``(L, h, d)``, and the
        shape tensor of the repeated segments.
        """
        _, num_h, _ = txt.shape
        per_sample = unflatten(txt.flatten(1, 2), txt_shape)
        repeated = [[x] * n for x, n in zip(per_sample, window_count_list)]
        repeated = list(chain(*repeated))
        flat, shape_repeat = flatten(repeated)
        return flat.reshape(flat.shape[0], num_h, self.head_dim), shape_repeat

    def forward(
        self,
        vid: torch.FloatTensor,  # l c
        txt: torch.FloatTensor,  # l c
        vid_shape: torch.LongTensor,  # b 3
        txt_shape: torch.LongTensor,  # b 1
        cache: Cache,
    ) -> Tuple[
        torch.FloatTensor,
        torch.FloatTensor,
    ]:
        """Run windowed attention and return the projected ``(vid, txt)`` outputs."""
        vid_qkv, txt_qkv = self.proj_qkv(vid, txt)

        # Everything that depends on the window layout is cached per layout.
        cache_win = cache.namespace(f"{self.window_method}_{self.window}_sd3")

        def make_window(x: torch.Tensor):
            t, h, w, _ = x.shape
            window_slices = self.window_op((t, h, w), self.window)
            return [x[st, sh, sw] for (st, sh, sw) in window_slices]

        window_partition, window_reverse, window_shape, window_count, vid_len_win_list, window_count_list = cache_win(
            "win_transform",
            lambda: window_idx(vid_shape, make_window),
        )

        # Reorder video tokens window by window, then split q/k/v per head.
        vid_qkv_win = window_partition(vid_qkv)
        vid_qkv_win = vid_qkv_win.reshape(vid_qkv_win.shape[0], 3, self.heads, self.head_dim)
        txt_qkv = txt_qkv.reshape(txt_qkv.shape[0], 3, self.heads, self.head_dim)
        vid_q, vid_k, vid_v = vid_qkv_win.unbind(1)
        txt_q, txt_k, txt_v = txt_qkv.unbind(1)

        vid_q, txt_q = self.norm_q(vid_q, txt_q)
        vid_k, txt_k = self.norm_k(vid_k, txt_k)

        txt_len = cache("txt_len", lambda: txt_shape.prod(-1))
        vid_len_win = cache_win("vid_len", lambda: window_shape.prod(-1))
        txt_len = txt_len.to(window_count.device)

        if self.rope:
            if self.version_7b:
                vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win)
            elif self.rope.mm:
                # The multimodal rope needs one text copy per window.
                txt_q_repeat, txt_shape_repeat = self._repeat_txt_per_window(txt_q, txt_shape, window_count_list)
                txt_k_repeat, _ = self._repeat_txt_per_window(txt_k, txt_shape, window_count_list)
                vid_q, vid_k, txt_q, txt_k = self.rope(
                    vid_q, vid_k, window_shape, txt_q_repeat, txt_k_repeat, txt_shape_repeat, cache_win
                )
            else:
                vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win)

        # Not read again in this method; kept so the cache entry is still populated.
        txt_len_win = cache_win("txt_len", lambda: txt_len.repeat_interleave(window_count))
        txt_len_win_list = cache_win(
            "txt_len_list",
            lambda: [txt_len for txt_len, window_count in zip(txt_len.tolist(), window_count_list) for _ in range(window_count)],
        )
        # Length of every attention sequence: one window plus its sample's text.
        all_len_win = cache_win("all_len", lambda: [vid_len + txt_len for vid_len, txt_len in zip(vid_len_win_list, txt_len_win_list)])
        concat_win, unconcat_win = cache_win(
            "mm_pnp", lambda: repeat_concat_idx(vid_len_win, txt_len, window_count)
        )

        out = optimized_var_attention(
            q=concat_win(vid_q, txt_q),
            k=concat_win(vid_k, txt_k),
            v=concat_win(vid_v, txt_v),
            heads=self.heads, skip_reshape=True, skip_output_reshape=True,
            cu_seqlens_q=cache_win("vid_seqlens_q", lambda: cumulative_lengths(all_len_win)),
            cu_seqlens_k=cache_win("vid_seqlens_k", lambda: cumulative_lengths(all_len_win)),
        )

        # Split back into modalities, merge heads, and undo the window ordering.
        vid_out, txt_out = unconcat_win(out)
        vid_out = vid_out.flatten(1, 2)
        txt_out = txt_out.flatten(1, 2)
        vid_out = window_reverse(vid_out)
        vid_out, txt_out = self.proj_out(vid_out, txt_out)
        return vid_out, txt_out
class MLP(nn.Module):
    """Two-layer feed-forward block: Linear -> tanh-approximated GELU -> Linear."""

    def __init__(
        self,
        dim: int,
        expand_ratio: int,
        device, dtype, operations
    ):
        super().__init__()
        hidden_dim = dim * expand_ratio
        self.proj_in = operations.Linear(dim, hidden_dim, device=device, dtype=dtype)
        self.act = nn.GELU("tanh")
        self.proj_out = operations.Linear(hidden_dim, dim, device=device, dtype=dtype)

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        """Expand, activate, and project ``x`` back to ``dim`` features."""
        hidden = self.act(self.proj_in(x))
        return self.proj_out(hidden)
class SwiGLUMLP(nn.Module):
    """Bias-free gated feed-forward block: ``proj_out(silu(gate(x)) * proj_in(x))``.

    The hidden width is ``int(2 * dim * expand_ratio / 3)`` rounded up to the
    next multiple of ``multiple_of``.
    """

    def __init__(
        self,
        dim: int,
        expand_ratio: int,
        multiple_of: int = 256,
        device=None, dtype=None, operations=None
    ):
        super().__init__()
        raw_hidden = int(2 * dim * expand_ratio / 3)
        num_blocks = (raw_hidden + multiple_of - 1) // multiple_of
        hidden_dim = multiple_of * num_blocks

        linear_kwargs = {"bias": False, "device": device, "dtype": dtype}
        self.proj_in_gate = operations.Linear(dim, hidden_dim, **linear_kwargs)
        self.proj_out = operations.Linear(hidden_dim, dim, **linear_kwargs)
        self.proj_in = operations.Linear(dim, hidden_dim, **linear_kwargs)

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        """Gate the up-projection with SiLU and project back to ``dim`` features."""
        gate = F.silu(self.proj_in_gate(x))
        return self.proj_out(gate * self.proj_in(x))
def get_mlp(mlp_type: Optional[str] = "normal"):
    """Return the MLP class for ``mlp_type`` ("normal" or "swiglu").

    Raises:
        ValueError: for any other value, including ``None``.
    """
    if mlp_type == "normal":
        mlp_cls = MLP
    elif mlp_type == "swiglu":
        mlp_cls = SwiGLUMLP
    else:
        raise ValueError(f"Unknown SeedVR2 MLP type: {mlp_type}")
    return mlp_cls
class NaMMSRTransformerBlock(nn.Module):
    """One multimodal transformer block: windowed attention followed by an MLP.

    Both sub-layers are wrapped as ``norm -> ada("in") -> layer -> ada("out")``
    and added to their input as a residual. Video and text are processed by
    ``MMModule`` pairs; when ``is_last_layer`` is set, the MLP norm, the MLP
    and the ada modulation are built video-only.
    """

    def __init__(
        self,
        *,
        vid_dim: int,
        txt_dim: int,
        emb_dim: int,
        heads: int,
        head_dim: int,
        expand_ratio: int,
        norm,
        norm_eps: float,
        ada,
        qk_bias: bool,
        qk_norm,
        mlp_type: str,
        shared_weights: bool,
        rope_type: str,
        rope_dim: int,
        is_last_layer: bool,
        window: Union[int, Tuple[int, int, int]],
        window_method: str,
        version: bool,
        device, dtype, operations,
    ):
        super().__init__()
        # Per-modality feature width; MMModule picks the right one for each branch.
        dim = MMArg(vid_dim, txt_dim)
        self.attn_norm = MMModule(norm, normalized_shape=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights, device=device, dtype=dtype)
        # Note: the q/k norm reuses norm_eps as its epsilon.
        self.attn = NaSwinAttention(
            vid_dim=vid_dim,
            txt_dim=txt_dim,
            heads=heads,
            head_dim=head_dim,
            qk_bias=qk_bias,
            qk_norm=qk_norm,
            qk_norm_eps=norm_eps,
            rope_type=rope_type,
            rope_dim=rope_dim,
            shared_weights=shared_weights,
            window=window,
            window_method=window_method,
            version=version,
            device=device, dtype=dtype, operations=operations
        )
        self.mlp_norm = MMModule(norm, normalized_shape=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights, vid_only=is_last_layer, device=device, dtype=dtype)
        self.mlp = MMModule(
            get_mlp(mlp_type),
            dim=dim,
            expand_ratio=expand_ratio,
            shared_weights=shared_weights,
            vid_only=is_last_layer,
            device=device, dtype=dtype, operations=operations
        )
        # One ada module serves both sub-layers; the layer is selected per call.
        self.ada = MMModule(ada, dim=dim, emb_dim=emb_dim, layers=["attn", "mlp"], shared_weights=shared_weights, vid_only=is_last_layer, device=device, dtype=dtype)
        self.is_last_layer = is_last_layer
        # When truthy, forward uses the chunked MLP path (_seedvr2_7b_mlp).
        self.version = version

    def _seedvr2_7b_mlp(
        self,
        vid: torch.FloatTensor,
        txt: torch.FloatTensor,
    ) -> Tuple[
        torch.FloatTensor,
        torch.FloatTensor,
    ]:
        """Apply the MLP to video in chunks of ``SEEDVR2_7B_MLP_CHUNK`` rows.

        Mirrors ``MMModule.forward`` for ``self.mlp`` but splits the video
        tokens along dim 0. Text is processed in one piece, and skipped when
        the MLP is video-only.
        """
        vid_module = self.mlp.vid if not self.mlp.shared_weights else self.mlp.all
        if comfy.model_management.in_training or vid.requires_grad:
            # Training / grad path: build the result with cat rather than in-place writes.
            vid = torch.cat([vid_module(chunk) for chunk in vid.split(SEEDVR2_7B_MLP_CHUNK, dim=0)], dim=0)
        else:
            # Inference path: write each chunk into one preallocated output buffer.
            vid_out = None
            offset = 0
            for chunk in vid.split(SEEDVR2_7B_MLP_CHUNK, dim=0):
                chunk_out = vid_module(chunk)
                if vid_out is None:
                    # Allocated lazily so dtype/device/trailing shape follow the first chunk's output.
                    vid_out = chunk_out.new_empty((vid.shape[0], *chunk_out.shape[1:]))
                vid_out[offset:offset + chunk_out.shape[0]] = chunk_out
                offset += chunk_out.shape[0]
            vid = vid_out
        if not self.mlp.vid_only:
            txt_module = self.mlp.txt if not self.mlp.shared_weights else self.mlp.all
            txt = txt.to(device=vid.device, dtype=vid.dtype)
            txt = txt_module(txt)
        return vid, txt

    def forward(
        self,
        vid: torch.FloatTensor,  # l c
        txt: torch.FloatTensor,  # l c
        vid_shape: torch.LongTensor,  # b 3
        txt_shape: torch.LongTensor,  # b 1
        emb: torch.FloatTensor,
        cache: Cache,
    ) -> Tuple[
        torch.FloatTensor,
        torch.FloatTensor,
        torch.LongTensor,
        torch.LongTensor,
    ]:
        """Run the attention and MLP sub-layers with residual connections.

        Returns:
            ``(vid, txt, vid_shape, txt_shape)``; the shapes are passed
            through unchanged.
        """
        # Per-sample token counts, cached under "vid_len" / "txt_len".
        hid_len = MMArg(
            cache("vid_len", lambda: vid_shape.prod(-1)),
            cache("txt_len", lambda: txt_shape.prod(-1)),
        )
        # Arguments shared by all four ada calls; MMArg values resolve per branch.
        ada_kwargs = {
            "emb": emb,
            "hid_len": hid_len,
            "cache": cache,
            "branch_tag": MMArg("vid", "txt"),
        }
        # Attention sub-layer: norm -> ada in -> attention -> ada out -> residual.
        vid_attn, txt_attn = self.attn_norm(vid, txt)
        vid_attn, txt_attn = self.ada(vid_attn, txt_attn, layer="attn", mode="in", **ada_kwargs)
        vid_attn, txt_attn = self.attn(vid_attn, txt_attn, vid_shape, txt_shape, cache)
        vid_attn, txt_attn = self.ada(vid_attn, txt_attn, layer="attn", mode="out", **ada_kwargs)
        vid_attn, txt_attn = (vid_attn + vid), (txt_attn + txt)
        # MLP sub-layer: norm -> ada in -> MLP -> ada out -> residual.
        vid_mlp, txt_mlp = self.mlp_norm(vid_attn, txt_attn)
        vid_mlp, txt_mlp = self.ada(vid_mlp, txt_mlp, layer="mlp", mode="in", **ada_kwargs)
        if self.version:
            vid_mlp, txt_mlp = self._seedvr2_7b_mlp(vid_mlp, txt_mlp)
        else:
            vid_mlp, txt_mlp = self.mlp(vid_mlp, txt_mlp)
        vid_mlp, txt_mlp = self.ada(vid_mlp, txt_mlp, layer="mlp", mode="out", **ada_kwargs)
        vid_mlp, txt_mlp = (vid_mlp + vid_attn), (txt_mlp + txt_attn)
        return vid_mlp, txt_mlp, vid_shape, txt_shape
class PatchOut(nn.Module):
    """Projects tokens to pixels and folds ``(t, h, w)`` patches back into a video."""

    def __init__(
        self,
        out_channels: int,
        patch_size: Union[int, Tuple[int, int, int]],
        dim: int,
        device, dtype, operations
    ):
        super().__init__()
        pt, ph, pw = _triple(patch_size)
        self.patch_size = pt, ph, pw
        self.proj = operations.Linear(dim, out_channels * pt * ph * pw, device=device, dtype=dtype)

    def forward(
        self,
        vid: torch.Tensor,
    ) -> torch.Tensor:
        """Map ``(b, T, H, W, dim)`` tokens to a ``(b, c, T*t, H*h, W*w)`` video.

        When the temporal patch size is above 1, the first ``t - 1`` output
        frames are dropped.
        """
        pt, ph, pw = self.patch_size
        vid = self.proj(vid)
        b, T, H, W, channels = vid.shape
        c = channels // (pt * ph * pw)
        # b T H W (t h w c) -> b c (T t) (H h) (W w)
        patches = vid.view(b, T, H, W, pt, ph, pw, c)
        patches = patches.permute(0, 7, 1, 4, 2, 5, 3, 6)
        vid = patches.reshape(b, c, T * pt, H * ph, W * pw)
        if pt > 1:
            vid = vid[:, :, (pt - 1) :]
        return vid
class NaPatchOut(PatchOut):
    """``PatchOut`` for a flattened batch of variable-sized videos."""

    def forward(
        self,
        vid: torch.FloatTensor,  # l c
        vid_shape: torch.LongTensor,
        cache: Optional[Cache] = None,
        vid_shape_before_patchify = None
    ) -> Tuple[
        torch.FloatTensor,
        torch.LongTensor,
    ]:
        """Project tokens and unfold every sample's patches.

        Args:
            vid: flattened tokens of all samples.
            vid_shape: per-sample ``(T, H, W)`` token grid shape.
            cache: accepted for interface compatibility; not used here.
            vid_shape_before_patchify: per-sample shape before patchification.
                Required when the temporal patch size is above 1; it decides
                how many leading frames to drop from each sample.

        Returns:
            ``(vid, vid_shape)`` re-flattened after unfolding. With a patch
            size of ``(1, 1, 1)`` the projected tokens and the input shape are
            returned directly.

        Raises:
            ValueError: if the temporal patch size is above 1 and
                ``vid_shape_before_patchify`` is missing.
        """
        t, h, w = self.patch_size
        vid = self.proj(vid)
        if t == h == w == 1:
            return vid, vid_shape

        if t > 1 and vid_shape_before_patchify is None:
            # Previously this failed with "'NoneType' object is not subscriptable".
            raise ValueError(
                "SeedVR2 NaPatchOut requires vid_shape_before_patchify when the temporal patch size is > 1."
            )

        samples = unflatten(vid, vid_shape)
        for i, sample in enumerate(samples):
            T, H, W, channels = sample.shape
            c = channels // (t * h * w)
            # T H W (t h w c) -> (T t) (H h) (W w) c
            sample = sample.view(T, H, W, t, h, w, c).permute(0, 3, 1, 4, 2, 5, 6).reshape(T * t, H * h, W * w, c)
            if t > 1 and vid_shape_before_patchify[i, 0] % t != 0:
                # Drop the leading frames that do not correspond to original frames.
                sample = sample[(t - vid_shape_before_patchify[i, 0] % t) :]
            samples[i] = sample
        vid, vid_shape = flatten(samples)
        return vid, vid_shape
class PatchIn(nn.Module):
    """Cuts a video into ``(t, h, w)`` patches and projects each patch to ``dim``."""

    def __init__(
        self,
        in_channels: int,
        patch_size: Union[int, Tuple[int, int, int]],
        dim: int,
        device, dtype, operations
    ):
        super().__init__()
        pt, ph, pw = _triple(patch_size)
        self.patch_size = pt, ph, pw
        self.proj = operations.Linear(in_channels * pt * ph * pw, dim, device=device, dtype=dtype)

    def forward(
        self,
        vid: torch.Tensor,
    ) -> torch.Tensor:
        """Map a ``(b, c, T, H, W)`` video to ``(b, T/t, H/h, W/w, dim)`` tokens.

        With a temporal patch size above 1 the first frame is repeated
        ``t - 1`` times in front of the video.

        Raises:
            ValueError: if ``t > 1`` and ``T % t != 1``.
        """
        pt, ph, pw = self.patch_size
        if pt > 1:
            num_frames = vid.size(2)
            if num_frames % pt != 1:
                raise ValueError(
                    f"SeedVR2 patch input temporal size must satisfy T % {pt} == 1, got {vid.size(2)}."
                )
            first_frame = vid[:, :, :1]
            vid = torch.cat([first_frame] * (pt - 1) + [vid], dim=2)
        b, c, Tt, Hh, Ww = vid.shape
        nT, nH, nW = Tt // pt, Hh // ph, Ww // pw
        # b c (T t) (H h) (W w) -> b T H W (t h w c)
        patches = vid.view(b, c, nT, pt, nH, ph, nW, pw)
        patches = patches.permute(0, 2, 4, 6, 3, 5, 7, 1)
        tokens = patches.reshape(b, nT, nH, nW, pt * ph * pw * c)
        return self.proj(tokens)
class NaPatchIn(PatchIn):
    """``PatchIn`` for a flattened batch of variable-sized videos."""

    def forward(
        self,
        vid: torch.Tensor,  # l c
        vid_shape: torch.LongTensor,
        cache: Optional[Cache] = None,
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        """Patchify every sample and project the patches.

        Args:
            vid: flattened per-pixel features of all samples.
            vid_shape: per-sample ``(T, H, W)`` shape.
            cache: optional cache; the incoming ``vid_shape`` is stored in its
                "patch" namespace under ``"vid_shape_before_patchify"``.

        Returns:
            ``(tokens, vid_shape)``, where ``vid_shape`` is the per-sample
            patch-grid shape (unchanged for a patch size of ``(1, 1, 1)``).
        """
        if cache is None:
            cache = Cache(disable=True)
        cache = cache.namespace("patch")
        vid_shape_before_patchify = cache("vid_shape_before_patchify", lambda: vid_shape)

        t, h, w = self.patch_size
        if not (t == h == w == 1):
            samples = unflatten(vid, vid_shape)
            for i, sample in enumerate(samples):
                if t > 1 and vid_shape_before_patchify[i, 0] % t != 0:
                    # Repeat the first frame in front so the frame count divides by t.
                    sample = torch.cat([sample[:1]] * (t - sample.size(0) % t) + [sample], dim=0)
                Tt, Hh, Ww, c = sample.shape
                # (T t) (H h) (W w) c -> T H W (t h w c)
                samples[i] = sample.view(Tt // t, t, Hh // h, h, Ww // w, w, c).permute(0, 2, 4, 1, 3, 5, 6).reshape(Tt // t, Hh // h, Ww // w, t * h * w * c)
            vid, vid_shape = flatten(samples)
        vid = self.proj(vid)
        return vid, vid_shape
def expand_dims(x: torch.Tensor, dim: int, ndim: int):
    """Insert singleton dims at position ``dim`` until ``x`` has ``ndim`` dims."""
    current = x.shape
    padding = (1,) * (ndim - len(current))
    return x.reshape(current[:dim] + padding + current[dim:])
class AdaSingle(nn.Module):
    """Adaptive modulation: embedding-derived shift/scale/gate plus learned offsets.

    For every name in ``layers`` the module owns ``{layer}_shift`` and
    ``{layer}_scale`` (when "in" is in ``modes``) and ``{layer}_gate`` (when
    "out" is in ``modes``). Parameters are created uninitialised.
    """

    def __init__(
        self,
        dim: int,
        emb_dim: int,
        layers: List[str],
        modes: Tuple[str, ...] = ("in", "out"),
        device = None, dtype = None,
    ):
        if emb_dim != 6 * dim:
            raise ValueError(f"SeedVR2 AdaSingle requires emb_dim == 6 * dim, got emb_dim={emb_dim}, dim={dim}.")
        super().__init__()
        self.dim = dim
        self.emb_dim = emb_dim
        self.layers = layers

        suffixes = []
        if "in" in modes:
            suffixes += ["shift", "scale"]
        if "out" in modes:
            suffixes.append("gate")
        for layer_name in layers:
            for suffix in suffixes:
                param = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
                self.register_parameter(f"{layer_name}_{suffix}", param)

    def forward(
        self,
        hid: torch.FloatTensor,  # b ... c
        emb: torch.FloatTensor,  # b d
        layer: str,
        mode: str,
        cache: Optional[Cache] = None,
        branch_tag: str = "",
        hid_len: Optional[torch.LongTensor] = None,  # b
    ) -> torch.FloatTensor:
        """Modulate ``hid`` in place and return it.

        ``mode="in"`` applies ``hid * (scale) + (shift)``; ``mode="out"``
        applies ``hid * (gate)``. Each factor is the embedding part plus the
        learned parameter (the gate parameter is optional). When ``hid_len``
        is given, the per-sample embedding is repeated per token and cached
        under ``emb_repeat_{layer index}_{branch_tag}``.

        Raises:
            ValueError: if ``mode`` is neither "in" nor "out".
        """
        if cache is None:
            cache = Cache(disable=True)

        idx = self.layers.index(layer)
        # (b, d) -> (b, c, n_layers, 3), then pick this layer's (shift, scale, gate).
        layer_emb = emb.reshape(emb.shape[0], -1, len(self.layers), 3)[:, :, idx, :]
        layer_emb = expand_dims(layer_emb, 1, hid.ndim + 1)
        if hid_len is not None:
            per_sample_emb = layer_emb
            layer_emb = cache(
                f"emb_repeat_{idx}_{branch_tag}",
                lambda: torch.repeat_interleave(per_sample_emb, hid_len, dim=0),
            )

        shiftA, scaleA, gateA = layer_emb.unbind(-1)
        shiftB = getattr(self, f"{layer}_shift", None)
        scaleB = getattr(self, f"{layer}_scale", None)
        gateB = getattr(self, f"{layer}_gate", None)

        if mode == "in":
            return hid.mul_(scaleA + scaleB).add_(shiftA + shiftB)
        if mode == "out":
            gate = gateA if gateB is None else gateA + gateB
            return hid.mul_(gate)
        raise ValueError(f"Unknown AdaSingle mode: {mode}")
class TimeEmbedding(nn.Module):
    """Sinusoidal timestep embedding followed by a three-layer SiLU MLP."""

    def __init__(
        self,
        sinusoidal_dim: int,
        hidden_dim: int,
        output_dim: int,
        device, dtype, operations
    ):
        super().__init__()
        self.sinusoidal_dim = sinusoidal_dim
        linear_kwargs = {"device": device, "dtype": dtype}
        self.proj_in = operations.Linear(sinusoidal_dim, hidden_dim, **linear_kwargs)
        self.proj_hid = operations.Linear(hidden_dim, hidden_dim, **linear_kwargs)
        self.proj_out = operations.Linear(hidden_dim, output_dim, **linear_kwargs)
        self.act = nn.SiLU()

    def forward(
        self,
        timestep: Union[int, float, torch.IntTensor, torch.FloatTensor],
        device: torch.device,
        dtype: torch.dtype,
    ) -> torch.FloatTensor:
        """Embed ``timestep`` (scalar or tensor) and return the projected embedding.

        A Python number is wrapped into a one-element tensor on ``device``; a
        0-d tensor gets a leading dimension. The sinusoidal embedding is cast
        to ``dtype`` before the MLP.
        """
        if not torch.is_tensor(timestep):
            timestep = torch.tensor([timestep], device=device, dtype=dtype)
        if timestep.ndim == 0:
            timestep = timestep[None]

        emb = get_timestep_embedding(
            timesteps=timestep,
            embedding_dim=self.sinusoidal_dim,
            flip_sin_to_cos=False,
            downscale_freq_shift=0,
        ).to(dtype)

        for stage in (self.proj_in, self.act, self.proj_hid, self.act, self.proj_out):
            emb = stage(emb)
        return emb
def flatten(
    hid: List[torch.FloatTensor],  # List of (*** c)
) -> Tuple[
    torch.FloatTensor,  # (L c)
    torch.LongTensor,  # (b n)
]:
    """Concatenate tensors into one ``(L, c)`` tensor and record their leading shapes.

    Raises:
        ValueError: if ``hid`` is empty.
    """
    if len(hid) == 0:
        raise ValueError("SeedVR2 flatten requires at least one tensor.")
    leading_shapes = [x.shape[:-1] for x in hid]
    shape = torch.as_tensor(leading_shapes, device=hid[0].device)
    # Collapse every leading dim of each tensor into a single token axis.
    tokens = [x.flatten(0, -2) for x in hid]
    return torch.cat(tokens), shape
def unflatten(
    hid: torch.FloatTensor,  # (L c) or (L ... c)
    hid_shape: torch.LongTensor,  # (b n)
) -> List[torch.Tensor]:  # List of (*** c) or (*** ... c)
    """Split a flattened batch along dim 0 and restore each sample's leading shape."""
    lengths = hid_shape.prod(-1).tolist()
    segments = hid.split(lengths)
    return [
        segment.unflatten(0, dims.tolist())
        for segment, dims in zip(segments, hid_shape)
    ]
class NaDiT(nn.Module):
def __init__(
    self,
    norm_eps,
    num_layers,
    mlp_type,
    vid_in_channels=33,
    vid_out_channels=SEEDVR2_LATENT_CHANNELS,
    vid_dim=2560,
    txt_in_dim=5120,
    heads=20,
    head_dim=128,
    mm_layers=10,
    expand_ratio=4,
    qk_bias=False,
    patch_size=(1, 2, 2),
    rope_dim=128,
    rope_type="mmrope3d",
    vid_out_norm: Optional[str] = None,
    image_model=None,
    device=None,
    dtype=None,
    operations=None,
):
    """Build the SeedVR2 NaDiT backbone.

    Args:
        norm_eps: Epsilon forwarded to the transformer blocks and to the
            optional output RMSNorm.
        num_layers: Number of transformer blocks.
        mlp_type: Forwarded unchanged to every block.
        vid_in_channels: Channel count consumed by the input patch embed.
        vid_out_channels: Channel count produced by the output patch layer.
        vid_dim: Video hidden width. The text width and the embedding width
            (``vid_dim * 6``) are derived from it, and a value equal to
            ``SEEDVR2_7B_VID_DIM`` selects the 7B variant.
        txt_in_dim: Incoming text width; a linear projection is created only
            when it is truthy and differs from the derived text width.
        heads, head_dim, expand_ratio, qk_bias: Forwarded to every block.
        mm_layers: Either an int (blocks with index below it get
            ``shared_weights=False``) or a per-block sequence of flags.
        patch_size: Patch size for the input and output patch layers.
        rope_dim: Rotary dimension; ``None`` falls back to ``head_dim // 2``.
        rope_type: Rotary variant; overridden with ``"rope3d"`` for 7B.
        vid_out_norm: When not ``None``, an output RMSNorm and an output
            ``AdaSingle`` modulation are created.
        image_model: Must be ``None`` or ``"seedvr2"``.
        device, dtype, operations: Factory arguments and the layer namespace
            used to build submodules.

    Raises:
        ValueError: If ``image_model`` names a different model.
    """
    if image_model not in (None, "seedvr2"):
        raise ValueError(f"SeedVR2 NaDiT expected image_model='seedvr2', got {image_model!r}.")
    # Initialise nn.Module before touching any attribute on self.
    super().__init__()

    self._7b_version = vid_dim == SEEDVR2_7B_VID_DIM
    if self._7b_version:
        rope_type = "rope3d"
    self.dtype = dtype
    factory_kwargs = {"device": device, "dtype": dtype}

    # Alternate the two window methods block by block. Indexing by parity
    # yields exactly num_layers entries, including for odd layer counts.
    window_methods = ("720pwin_by_size_bysize", "720pswin_by_size_bysize")
    window_method = [window_methods[i % 2] for i in range(num_layers)]
    window = [(4, 3, 3)] * num_layers

    txt_dim = vid_dim
    emb_dim = vid_dim * 6
    ada = AdaSingle
    norm = operations.RMSNorm
    qk_norm = operations.RMSNorm
    rope_dim = rope_dim if rope_dim is not None else head_dim // 2

    # Placeholder buffers with fixed shapes.
    # NOTE(review): presumably filled from the checkpoint; the 5120 width
    # matches the default txt_in_dim - confirm before tying the two.
    self.register_buffer("positive_conditioning", torch.empty((58, 5120), device=device, dtype=dtype))
    self.register_buffer("negative_conditioning", torch.empty((64, 5120), device=device, dtype=dtype))

    self.vid_in = NaPatchIn(
        in_channels=vid_in_channels,
        patch_size=patch_size,
        dim=vid_dim,
        device=device, dtype=dtype, operations=operations
    )
    self.txt_in = (
        operations.Linear(txt_in_dim, txt_dim, **factory_kwargs)
        if txt_in_dim and txt_in_dim != txt_dim
        else nn.Identity()
    )
    self.emb_in = TimeEmbedding(
        sinusoidal_dim=BYTEDANCE_SINUSOIDAL_DIM,
        hidden_dim=max(vid_dim, txt_dim),
        output_dim=emb_dim,
        device=device, dtype=dtype, operations=operations
    )

    self.blocks = nn.ModuleList(
        [
            NaMMSRTransformerBlock(
                vid_dim=vid_dim,
                txt_dim=txt_dim,
                emb_dim=emb_dim,
                heads=heads,
                head_dim=head_dim,
                expand_ratio=expand_ratio,
                norm=norm,
                norm_eps=norm_eps,
                ada=ada,
                qk_bias=qk_bias,
                qk_norm=qk_norm,
                mlp_type=mlp_type,
                rope_dim=rope_dim,
                window=window[i],
                window_method=window_method[i],
                version=self._7b_version,
                is_last_layer=(i == num_layers - 1) and not self._7b_version,
                rope_type=rope_type,
                shared_weights=not (
                    (i < mm_layers) if isinstance(mm_layers, int) else mm_layers[i]
                ),
                operations=operations,
                **factory_kwargs
            )
            for i in range(num_layers)
        ]
    )

    self.vid_out = NaPatchOut(
        out_channels=vid_out_channels,
        patch_size=patch_size,
        dim=vid_dim,
        device=device, dtype=dtype, operations=operations
    )

    # The output norm and its modulation only exist when requested.
    self.vid_out_norm = None
    if vid_out_norm is not None:
        self.vid_out_norm = operations.RMSNorm(
            normalized_shape=vid_dim,
            eps=norm_eps,
            elementwise_affine=True,
            device=device, dtype=dtype
        )
        self.vid_out_ada = ada(
            dim=vid_dim,
            emb_dim=emb_dim,
            layers=["out"],
            modes=["in"],
            device=device, dtype=dtype
        )
def _resolve_text_conditioning(self, context, cond_or_uncond=None):
    """Turn the text conditioning batch into the packed form ``flatten`` returns.

    * A missing or empty ``context`` falls back to the
      ``positive_conditioning`` buffer.
    * When ``cond_or_uncond`` marks a single branch, the batch entries are
      packed in their given order.
    * Otherwise the batch is split into a first and a second half, and the
      second half is packed ahead of the first.

    Raises:
        ValueError: If a two-branch batch has an odd number of entries.
    """
    if context is None or context.numel() == 0:
        return flatten([self.positive_conditioning])

    # unbind(0) also covers a batch of one: it yields the same single
    # squeezed entry, so no separate case is needed.
    if NaDiT._seedvr2_is_single_conditioning_branch(cond_or_uncond):
        return flatten(context.unbind(0))

    if context.shape[0] % 2 != 0:
        raise ValueError(f"SeedVR2 expected an even text-conditioning batch, got shape {tuple(context.shape)}")

    neg_cond, pos_cond = context.chunk(2, dim=0)
    reordered = (*pos_cond.unbind(0), *neg_cond.unbind(0))
    return flatten(reordered)
@staticmethod
def _seedvr2_is_single_conditioning_branch(cond_or_uncond):
if cond_or_uncond is None or len(cond_or_uncond) == 0:
return False
first = cond_or_uncond[0]
return all(entry == first for entry in cond_or_uncond)
@staticmethod
def _check_seedvr2_video_latent(x, channels, name):
if x.ndim != 5:
raise ValueError(f"SeedVR2 expected {name} to be 5-D native latent, got shape {tuple(x.shape)}.")
if x.shape[1] != channels:
raise ValueError(f"SeedVR2 expected {name} channels to be {channels}, got shape {tuple(x.shape)}.")
return x
def _swap_pos_neg_halves(self, out, cond_or_uncond=None):
if NaDiT._seedvr2_is_single_conditioning_branch(cond_or_uncond):
return out
pos, neg = out.chunk(2, dim=0)
return torch.cat([neg, pos], dim=0)
def forward(
    self,
    x,
    timestep,
    context,  # l c
    disable_cache: bool = False,
    **kwargs
):
    """Run the transformer over a latent batch.

    Args:
        x: 5-D latent with ``SEEDVR2_LATENT_CHANNELS`` channels on axis 1.
        timestep: Passed to the timestep embedding.
        context: Text conditioning, resolved by
            ``_resolve_text_conditioning``.
        disable_cache: Disables the per-call ``Cache`` when True.
        **kwargs: Reads ``condition`` (required; 5-D, one channel more than
            ``x``, otherwise the same shape) and ``transformer_options``
            (``patches_replace["dit"]`` block overrides and
            ``cond_or_uncond``).

    Returns:
        The stacked per-sample outputs with channels moved back to axis 1,
        passed through ``_swap_pos_neg_halves``.

    Raises:
        ValueError: If ``condition`` is missing or the latent/conditioning
            shapes are invalid or inconsistent.
    """
    transformer_options = kwargs.get("transformer_options", {})
    blocks_replace = transformer_options.get("patches_replace", {}).get("dit", {})

    conditions = kwargs.get("condition")
    if conditions is None:
        raise ValueError("SeedVR2 requires conditioning latents from the SeedVR2Conditioning node.")

    x = self._check_seedvr2_video_latent(x, SEEDVR2_LATENT_CHANNELS, "latent")
    conditions = self._check_seedvr2_video_latent(conditions, SEEDVR2_LATENT_CHANNELS + 1, "conditioning")

    batch, _, frames, height, width = x.shape
    same_batch = conditions.shape[0] == batch
    same_grid = conditions.shape[2:] == (frames, height, width)
    if not (same_batch and same_grid):
        raise ValueError(
            f"SeedVR2 conditioning shape must match latent batch/temporal/spatial dimensions; got latent {tuple(x.shape)} and conditioning {tuple(conditions.shape)}."
        )

    # Channels-last layout so flatten() can pack everything but the channels.
    x = x.movedim(1, -1)
    conditions = conditions.movedim(1, -1)

    cache = Cache(disable=disable_cache)
    txt, txt_shape = self._resolve_text_conditioning(context, transformer_options.get("cond_or_uncond"))

    vid, vid_shape = flatten(x)
    cond_latent, _ = flatten(conditions)
    # Latent and conditioning channels are joined per token.
    vid = torch.cat([vid, cond_latent], dim=-1)
    txt = self.txt_in(txt)

    vid_shape_before_patchify = vid_shape
    vid, vid_shape = self.vid_in(vid, vid_shape, cache=cache)
    emb = self.emb_in(timestep, device=vid.device, dtype=vid.dtype)

    for index, block in enumerate(self.blocks):
        state = {
            "vid": vid,
            "txt": txt,
            "vid_shape": vid_shape,
            "txt_shape": txt_shape,
            "emb": emb,
            "cache": cache,
        }

        if ("block", index) not in blocks_replace:
            vid, txt, vid_shape, txt_shape = block(**state)
            continue

        # A replacement patch receives the inputs plus a callable that runs
        # the original block on whatever arguments the patch hands back.
        def block_wrap(args):
            new_vid, new_txt, new_vid_shape, new_txt_shape = block(
                vid=args["vid"],
                txt=args["txt"],
                vid_shape=args["vid_shape"],
                txt_shape=args["txt_shape"],
                emb=args["emb"],
                cache=args["cache"],
            )
            return {
                "vid": new_vid,
                "txt": new_txt,
                "vid_shape": new_vid_shape,
                "txt_shape": new_txt_shape,
            }

        patched = blocks_replace[("block", index)](state, {"original_block": block_wrap})
        vid = patched["vid"]
        txt = patched["txt"]
        vid_shape = patched["vid_shape"]
        txt_shape = patched["txt_shape"]

    if self.vid_out_norm:
        vid = self.vid_out_norm(vid)
        vid = self.vid_out_ada(
            vid,
            emb=emb,
            layer="out",
            mode="in",
            hid_len=cache("vid_len", lambda: vid_shape.prod(-1)),
            cache=cache,
            branch_tag="vid",
        )

    vid, vid_shape = self.vid_out(vid, vid_shape, cache, vid_shape_before_patchify=vid_shape_before_patchify)
    samples = unflatten(vid, vid_shape)
    out = torch.stack(samples).movedim(-1, 1)
    return self._swap_pos_neg_halves(out, transformer_options.get("cond_or_uncond"))