ComfyUI/comfy/ldm/cogvideo/model.py
2026-04-10 20:28:05 +02:00

562 lines
26 KiB
Python

# CogVideoX 3D Transformer - ported to ComfyUI native ops
# Architecture reference: diffusers CogVideoXTransformer3DModel
# Style reference: comfy/ldm/wan/model.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.attention import optimized_attention
import comfy.patcher_extension
import comfy.ldm.common_dit
def _get_1d_rotary_pos_embed(dim, pos, theta=10000.0):
"""Returns (cos, sin) each with shape [seq_len, dim].
Frequencies are computed at dim//2 resolution then repeat_interleaved
to full dim, matching CogVideoX's interleaved (real, imag) pair format.
"""
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim))
angles = torch.outer(pos.float(), freqs.float())
cos = angles.cos().repeat_interleave(2, dim=-1).float()
sin = angles.sin().repeat_interleave(2, dim=-1).float()
return (cos, sin)
def apply_rotary_emb(x, freqs_cos_sin):
    """Apply CogVideoX rotary embedding to a query or key tensor.

    Args:
        x: tensor ``[..., seq_len, head_dim]``. The usual call site passes
            ``[B, heads, seq_len, head_dim]``, but any leading dims broadcast.
        freqs_cos_sin: ``(cos, sin)``, each ``[seq_len, head_dim]``. The tables
            are in interleaved form, so every value is repeated for its (real, imag)
            channel pair.

    Returns:
        Tensor with the shape and dtype of ``x``. The rotation itself is computed
        in float32.

    The last dim of ``x`` is viewed as ``(head_dim // 2, 2)`` pairs ``(a, b)``,
    which are rotated to ``a*cos - b*sin`` and ``b*cos + a*sin`` (the interleaved
    convention used by diffusers CogVideoX/Flux).
    """
    cos, sin = freqs_cos_sin
    # [seq_len, head_dim] tables broadcast against the trailing two dims of x.
    cos = cos.to(x.device)
    sin = sin.to(x.device)
    x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
    # (a, b) -> (-b, a), flattened back into the interleaved layout.
    x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(-2)
    return (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
def get_timestep_embedding(timesteps, dim, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1, max_period=10000):
    """Sinusoidal embedding of a 1-D tensor of timesteps.

    Args:
        timesteps: 1-D tensor with one (possibly fractional) value per batch entry.
        dim: output embedding width.
        flip_sin_to_cos: if True the output is ``[cos, sin]``; otherwise ``[sin, cos]``.
        downscale_freq_shift: subtracted from ``dim // 2`` in the frequency exponent's
            denominator. Uses the same convention as the ``freq_shift`` argument of
            diffusers ``Timesteps``.
        scale: multiplier applied to every phase.
        max_period: sets the lowest frequency, ``1 / max_period``.

    Returns:
        float32 tensor ``[len(timesteps), dim]``. For odd ``dim`` the last column is zero.
    """
    half = dim // 2
    exponent = -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device)
    # Previously the shift argument was ignored; with the default shift of 0 this is
    # bit-identical to dividing by `half`.
    freqs = torch.exp(exponent / (half - downscale_freq_shift))
    args = timesteps[:, None].float() * freqs[None] * scale
    embedding = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
    if flip_sin_to_cos:
        embedding = torch.cat([embedding[:, half:], embedding[:, :half]], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding
def get_3d_sincos_pos_embed(embed_dim, spatial_size, temporal_size, spatial_interpolation_scale=1.0, temporal_interpolation_scale=1.0, device=None):
    """Fixed 3-D sin-cos positional table for a ``T x H x W`` patch grid.

    Args:
        embed_dim: total embedding width. The result is exactly this wide when
            ``embed_dim`` is a multiple of 8.
        spatial_size: ``(width, height)`` of the patch grid, or one int for a square grid.
        temporal_size: number of temporal positions.
        spatial_interpolation_scale: spatial coordinates are divided by this value.
        temporal_interpolation_scale: temporal coordinates are divided by this value.
        device: device on which the coordinate grids are built.

    Returns:
        Tensor ``[T, H, W, C]``. The first ``embed_dim // 4`` channels encode the
        temporal index and the remaining channels encode the spatial position.
    """
    if isinstance(spatial_size, int):
        spatial_size = (spatial_size, spatial_size)
    width, height = spatial_size
    grid_w = torch.arange(width, dtype=torch.float32, device=device) / spatial_interpolation_scale
    grid_h = torch.arange(height, dtype=torch.float32, device=device) / spatial_interpolation_scale
    grid_t = torch.arange(temporal_size, dtype=torch.float32, device=device) / temporal_interpolation_scale
    grid_t, grid_h, grid_w = torch.meshgrid(grid_t, grid_h, grid_w, indexing="ij")
    # Channel split of the reference CogVideoX implementation (diffusers / SAT):
    # 1/4 of the channels are temporal, 3/4 are spatial.
    embed_dim_temporal = embed_dim // 4
    embed_dim_spatial = 3 * embed_dim // 4
    pos_embed_spatial = _get_2d_sincos_pos_embed(embed_dim_spatial, grid_h, grid_w, device=device)
    pos_embed_temporal = _get_1d_sincos_pos_embed(embed_dim_temporal, grid_t[:, 0, 0], device=device)
    _, h, w = grid_t.shape
    pos_embed_temporal = pos_embed_temporal[:, None, None, :].expand(-1, h, w, -1)
    return torch.cat([pos_embed_temporal, pos_embed_spatial], dim=-1)
def _get_2d_sincos_pos_embed(embed_dim, grid_h, grid_w, device=None):
    """Spatial sin-cos embedding for ``[T, H, W]`` coordinate grids.

    Each axis gets ``embed_dim // 2`` channels. The width-coordinate embedding comes
    first and the height-coordinate embedding second. This reproduces the reference
    implementation, which builds its grid with ``meshgrid(grid_w, grid_h,
    indexing="xy")`` and therefore feeds ``grid[0]`` (the width coordinates) into
    the first half.

    Returns:
        Tensor ``[T, H, W, 2 * (embed_dim // 2)]``.
    """
    t, h, w = grid_h.shape
    half_dim = embed_dim // 2
    pos_w = _get_1d_sincos_pos_embed(half_dim, grid_w.reshape(-1), device=device).reshape(t, h, w, half_dim)
    pos_h = _get_1d_sincos_pos_embed(half_dim, grid_h.reshape(-1), device=device).reshape(t, h, w, half_dim)
    return torch.cat([pos_w, pos_h], dim=-1)
def _get_1d_sincos_pos_embed(embed_dim, pos, device=None):
half = embed_dim // 2
freqs = torch.exp(-math.log(10000.0) * torch.arange(start=0, end=half, dtype=torch.float32, device=device) / half)
args = pos.float().reshape(-1)[:, None] * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if embed_dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
class CogVideoXPatchEmbed(nn.Module):
    """Project text tokens and patchify latent frames into one joint token sequence.

    ``forward(text_embeds, image_embeds)`` returns ``cat([text_tokens, image_tokens],
    dim=1)``. ``image_embeds`` is ``[B, T, C, H, W]``. When positional embeddings are
    enabled, a positional table is added to the sequence.
    """

    def __init__(self, patch_size=2, patch_size_t=None, in_channels=16, dim=1920,
                 text_dim=4096, bias=True, sample_width=90, sample_height=60,
                 sample_frames=49, temporal_compression_ratio=4,
                 max_text_seq_length=226, spatial_interpolation_scale=1.875,
                 temporal_interpolation_scale=1.0, use_positional_embeddings=True,
                 use_learned_positional_embeddings=True,
                 device=None, dtype=None, operations=None):
        super().__init__()
        self.patch_size = patch_size
        self.patch_size_t = patch_size_t
        self.dim = dim
        self.sample_height = sample_height
        self.sample_width = sample_width
        self.sample_frames = sample_frames
        self.temporal_compression_ratio = temporal_compression_ratio
        self.max_text_seq_length = max_text_seq_length
        self.spatial_interpolation_scale = spatial_interpolation_scale
        self.temporal_interpolation_scale = temporal_interpolation_scale
        self.use_positional_embeddings = use_positional_embeddings
        self.use_learned_positional_embeddings = use_learned_positional_embeddings
        if patch_size_t is None:
            # Per-frame 2-D patches.
            self.proj = operations.Conv2d(in_channels, dim, kernel_size=patch_size, stride=patch_size, bias=bias, device=device, dtype=dtype)
        else:
            # Spatio-temporal patches are flattened and projected with a Linear layer.
            self.proj = operations.Linear(in_channels * patch_size * patch_size * patch_size_t, dim, device=device, dtype=dtype)
        self.text_proj = operations.Linear(text_dim, dim, device=device, dtype=dtype)
        if use_positional_embeddings or use_learned_positional_embeddings:
            # Only learned tables are part of the state dict.
            persistent = use_learned_positional_embeddings
            pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
            self.register_buffer("pos_embedding", pos_embedding, persistent=persistent)

    def _get_positional_embeddings(self, sample_height, sample_width, sample_frames, device=None):
        """Joint ``[1, max_text_seq_length + num_patches, dim]`` table for the sample size.

        The text rows are zero. The image rows hold the 3-D sin-cos embedding.
        """
        post_patch_height = sample_height // self.patch_size
        post_patch_width = sample_width // self.patch_size
        post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
        if self.patch_size_t is not None:
            post_time_compression_frames = post_time_compression_frames // self.patch_size_t
        num_patches = post_patch_height * post_patch_width * post_time_compression_frames
        pos_embedding = get_3d_sincos_pos_embed(
            self.dim,
            (post_patch_width, post_patch_height),
            post_time_compression_frames,
            self.spatial_interpolation_scale,
            self.temporal_interpolation_scale,
            device=device,
        )
        pos_embedding = pos_embedding.reshape(-1, self.dim)
        joint_pos_embedding = pos_embedding.new_zeros(
            1, self.max_text_seq_length + num_patches, self.dim, requires_grad=False
        )
        joint_pos_embedding.data[:, self.max_text_seq_length:].copy_(pos_embedding)
        return joint_pos_embedding

    def _get_joint_pos_embedding(self, text_seq_length, num_image_patches, height, width, device, dtype):
        """Build the ``[1, text_seq_length + num_image_patches, dim]`` table added to the tokens.

        With learned positional embeddings, the stored ``pos_embedding`` buffer is used
        when the latent grid matches the one it was built for. Its first
        ``min(text_seq_length, max_text_seq_length)`` text rows go to the text tokens
        and its image rows go to the image tokens. In every other case the sin-cos
        table is recomputed for the current grid and the text rows stay zero.
        """
        joint_pos = torch.zeros(1, text_seq_length + num_image_patches, self.dim, device=device, dtype=dtype)
        if self.use_learned_positional_embeddings:
            learned_image_len = self.pos_embedding.shape[1] - self.max_text_seq_length
            if (height == self.sample_height and width == self.sample_width
                    and learned_image_len == num_image_patches):
                learned = self.pos_embedding.to(device=device, dtype=dtype)
                n_text = min(text_seq_length, self.max_text_seq_length)
                joint_pos[:, :n_text] = learned[:, :n_text]
                joint_pos[:, text_seq_length:] = learned[:, self.max_text_seq_length:]
                return joint_pos
        post_h = height // self.patch_size
        post_w = width // self.patch_size
        pos_embedding = get_3d_sincos_pos_embed(
            self.dim,
            (post_w, post_h),
            num_image_patches // (post_h * post_w),
            self.spatial_interpolation_scale,
            self.temporal_interpolation_scale,
            device=device,
        ).reshape(-1, self.dim)
        joint_pos[:, text_seq_length:] = pos_embedding.to(dtype=dtype)
        return joint_pos

    def forward(self, text_embeds, image_embeds):
        """Return ``cat([text_proj(text_embeds), patchified image tokens], dim=1)``.

        Positional embeddings are added when enabled. ``image_embeds`` is
        ``[B, T, C, H, W]``.
        """
        text_embeds = self.text_proj(text_embeds)
        batch_size, num_frames, channels, height, width = image_embeds.shape
        if self.patch_size_t is None:
            image_embeds = image_embeds.reshape(-1, channels, height, width)
            image_embeds = self.proj(image_embeds)
            image_embeds = image_embeds.view(batch_size, num_frames, *image_embeds.shape[1:])
            # [B, T, dim, h, w] -> [B, T, h*w, dim] -> [B, T*h*w, dim]
            image_embeds = image_embeds.flatten(3).transpose(2, 3)
            image_embeds = image_embeds.flatten(1, 2)
        else:
            p = self.patch_size
            p_t = self.patch_size_t
            image_embeds = image_embeds.permute(0, 1, 3, 4, 2)
            image_embeds = image_embeds.reshape(
                batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels
            )
            image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)
            image_embeds = self.proj(image_embeds)
        embeds = torch.cat([text_embeds, image_embeds], dim=1).contiguous()
        if self.use_positional_embeddings or self.use_learned_positional_embeddings:
            embeds = embeds + self._get_joint_pos_embedding(
                text_embeds.shape[1], image_embeds.shape[1], height, width,
                device=embeds.device, dtype=embeds.dtype,
            )
        return embeds
class CogVideoXLayerNormZero(nn.Module):
    """Shared LayerNorm with separate timestep modulation for video and text tokens.

    ``temb`` is mapped through SiLU + Linear to six ``dim``-wide vectors, in order
    (shift, scale, gate) for the video stream and (shift, scale, gate) for the text
    stream. Both streams use the same ``norm`` and are modulated as
    ``norm(x) * (1 + scale) + shift``. The two gates are returned with a
    broadcastable sequence axis.
    """

    def __init__(self, time_dim, dim, elementwise_affine=True, eps=1e-5, bias=True,
                 device=None, dtype=None, operations=None):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = operations.Linear(time_dim, 6 * dim, bias=bias, device=device, dtype=dtype)
        self.norm = operations.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype)

    def forward(self, hidden_states, encoder_hidden_states, temb):
        modulation = self.linear(self.silu(temb)).chunk(6, dim=1)
        # Insert a length-1 sequence axis so each vector broadcasts over tokens.
        shift, scale, gate, enc_shift, enc_scale, enc_gate = (m[:, None, :] for m in modulation)
        video = self.norm(hidden_states) * (1 + scale) + shift
        text = self.norm(encoder_hidden_states) * (1 + enc_scale) + enc_shift
        return video, text, gate, enc_gate
class CogVideoXAdaLayerNorm(nn.Module):
    """Adaptive LayerNorm: ``norm(x) * (1 + scale) + shift``.

    ``(shift, scale)`` are predicted from ``temb`` by SiLU + Linear, with shift
    first, and are broadcast over the sequence axis of ``x``.
    """

    def __init__(self, time_dim, dim, elementwise_affine=True, eps=1e-5,
                 device=None, dtype=None, operations=None):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = operations.Linear(time_dim, 2 * dim, device=device, dtype=dtype)
        self.norm = operations.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype)

    def forward(self, x, temb):
        modulation = self.linear(self.silu(temb))
        shift, scale = (part[:, None, :] for part in modulation.chunk(2, dim=1))
        return self.norm(x) * (1 + scale) + shift
class CogVideoXBlock(nn.Module):
    """One CogVideoX transformer layer over a joint text + video token sequence.

    The text tokens (``encoder_hidden_states``) are prepended to the video tokens
    (``hidden_states``). Both go through one self-attention and one GELU(tanh)
    feed-forward. Each sub-layer is pre-normalised and gated per stream by
    ``CogVideoXLayerNormZero``. Rotary embeddings, when given, are applied to the
    video part of the queries and keys only.
    """

    def __init__(self, dim, num_heads, head_dim, time_dim,
                 eps=1e-5, ff_inner_dim=None, ff_bias=True,
                 device=None, dtype=None, operations=None):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = head_dim
        factory = dict(device=device, dtype=dtype)

        self.norm1 = CogVideoXLayerNormZero(time_dim, dim, eps=eps, operations=operations, **factory)
        # Joint self-attention projections; q/k are LayerNorm-ed per head.
        self.q = operations.Linear(dim, dim, bias=True, **factory)
        self.k = operations.Linear(dim, dim, bias=True, **factory)
        self.v = operations.Linear(dim, dim, bias=True, **factory)
        self.norm_q = operations.LayerNorm(head_dim, eps=1e-6, elementwise_affine=True, **factory)
        self.norm_k = operations.LayerNorm(head_dim, eps=1e-6, elementwise_affine=True, **factory)
        self.attn_out = operations.Linear(dim, dim, bias=True, **factory)

        self.norm2 = CogVideoXLayerNormZero(time_dim, dim, eps=eps, operations=operations, **factory)
        hidden_width = ff_inner_dim if ff_inner_dim else 4 * dim
        self.ff_proj = operations.Linear(dim, hidden_width, bias=ff_bias, **factory)
        self.ff_out = operations.Linear(hidden_width, dim, bias=ff_bias, **factory)

    def forward(self, hidden_states, encoder_hidden_states, temb, image_rotary_emb=None, transformer_options={}):
        n_text = encoder_hidden_states.size(1)
        heads, hd = self.num_heads, self.head_dim

        # --- attention ---
        video_in, text_in, gate_video, gate_text = self.norm1(hidden_states, encoder_hidden_states, temb)
        joint = torch.cat([text_in, video_in], dim=1)
        bsz, seq_len, _ = joint.shape

        query = self.norm_q(self.q(joint).view(bsz, seq_len, heads, hd))
        key = self.norm_k(self.k(joint).view(bsz, seq_len, heads, hd))
        value = self.v(joint)

        if image_rotary_emb is not None:
            def rope_video_part(tokens):
                # Rotary helper expects [B, heads, seq, head_dim]; text tokens are left as-is.
                video = apply_rotary_emb(tokens[:, n_text:].transpose(1, 2), image_rotary_emb)
                return torch.cat([tokens[:, :n_text], video.transpose(1, 2)], dim=1)

            query = rope_video_part(query)
            key = rope_video_part(key)

        attended = optimized_attention(
            query.reshape(bsz, seq_len, heads * hd),
            key.reshape(bsz, seq_len, heads * hd),
            value,
            heads=heads,
            transformer_options=transformer_options,
        )
        attn_text, attn_video = self.attn_out(attended).split([n_text, seq_len - n_text], dim=1)
        hidden_states = hidden_states + gate_video * attn_video
        encoder_hidden_states = encoder_hidden_states + gate_text * attn_text

        # --- feed-forward on the re-joined sequence ---
        video_in, text_in, gate_video, gate_text = self.norm2(hidden_states, encoder_hidden_states, temb)
        ff = self.ff_out(F.gelu(self.ff_proj(torch.cat([text_in, video_in], dim=1)), approximate="tanh"))
        ff_text, ff_video = ff[:, :n_text], ff[:, n_text:]
        return (
            hidden_states + gate_video * ff_video,
            encoder_hidden_states + gate_text * ff_text,
        )
class CogVideoXTransformer3DModel(nn.Module):
    """CogVideoX 3D diffusion transformer (1.0 and 1.5 layouts).

    ``forward`` takes latents ``x`` in ComfyUI's ``[B, C, T, H, W]`` layout, a
    ``timestep`` tensor, text ``context`` and an optional ``ofs`` value. It returns
    the prediction in the same layout, cropped back to the unpadded ``T, H, W``.
    ``patch_size_t is None`` selects the 1.0 layout: per-frame 2-D patches with
    "linspace" rotary positions. Otherwise patches also span ``patch_size_t`` latent
    frames and rotary positions use "slice" mode (1.5).
    """

    def __init__(self,
                 num_attention_heads=30,
                 attention_head_dim=64,
                 in_channels=16,
                 out_channels=16,
                 flip_sin_to_cos=True,
                 freq_shift=0,
                 time_embed_dim=512,
                 ofs_embed_dim=None,
                 text_embed_dim=4096,
                 num_layers=30,
                 dropout=0.0,
                 attention_bias=True,
                 sample_width=90,
                 sample_height=60,
                 sample_frames=49,
                 patch_size=2,
                 patch_size_t=None,
                 temporal_compression_ratio=4,
                 max_text_seq_length=226,
                 spatial_interpolation_scale=1.875,
                 temporal_interpolation_scale=1.0,
                 use_rotary_positional_embeddings=False,
                 use_learned_positional_embeddings=False,
                 patch_bias=True,
                 image_model=None,
                 device=None,
                 dtype=None,
                 operations=None,
                 ):
        super().__init__()
        self.dtype = dtype
        dim = num_attention_heads * attention_head_dim
        self.dim = dim
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.patch_size = patch_size
        self.patch_size_t = patch_size_t
        self.max_text_seq_length = max_text_seq_length
        self.use_rotary_positional_embeddings = use_rotary_positional_embeddings
        # 1. Patch embedding (kept in float32)
        self.patch_embed = CogVideoXPatchEmbed(
            patch_size=patch_size,
            patch_size_t=patch_size_t,
            in_channels=in_channels,
            dim=dim,
            text_dim=text_embed_dim,
            bias=patch_bias,
            sample_width=sample_width,
            sample_height=sample_height,
            sample_frames=sample_frames,
            temporal_compression_ratio=temporal_compression_ratio,
            max_text_seq_length=max_text_seq_length,
            spatial_interpolation_scale=spatial_interpolation_scale,
            temporal_interpolation_scale=temporal_interpolation_scale,
            use_positional_embeddings=not use_rotary_positional_embeddings,
            use_learned_positional_embeddings=use_learned_positional_embeddings,
            device=device, dtype=torch.float32, operations=operations,
        )
        # 2. Time embedding
        self.time_proj_dim = dim
        self.time_proj_flip = flip_sin_to_cos
        self.time_proj_shift = freq_shift
        self.time_embedding_linear_1 = operations.Linear(dim, time_embed_dim, device=device, dtype=dtype)
        self.time_embedding_act = nn.SiLU()
        self.time_embedding_linear_2 = operations.Linear(time_embed_dim, time_embed_dim, device=device, dtype=dtype)
        # Optional OFS embedding (CogVideoX 1.5 I2V)
        self.ofs_proj_dim = ofs_embed_dim
        if ofs_embed_dim:
            self.ofs_embedding_linear_1 = operations.Linear(ofs_embed_dim, ofs_embed_dim, device=device, dtype=dtype)
            self.ofs_embedding_act = nn.SiLU()
            self.ofs_embedding_linear_2 = operations.Linear(ofs_embed_dim, ofs_embed_dim, device=device, dtype=dtype)
        else:
            self.ofs_embedding_linear_1 = None
        # 3. Transformer blocks
        self.blocks = nn.ModuleList([
            CogVideoXBlock(
                dim=dim,
                num_heads=num_attention_heads,
                head_dim=attention_head_dim,
                time_dim=time_embed_dim,
                eps=1e-5,
                device=device, dtype=dtype, operations=operations,
            )
            for _ in range(num_layers)
        ])
        self.norm_final = operations.LayerNorm(dim, eps=1e-5, elementwise_affine=True, device=device, dtype=dtype)
        # 4. Output
        self.norm_out = CogVideoXAdaLayerNorm(
            time_dim=time_embed_dim, dim=dim, eps=1e-5,
            device=device, dtype=dtype, operations=operations,
        )
        if patch_size_t is None:
            output_dim = patch_size * patch_size * out_channels
        else:
            output_dim = patch_size * patch_size * patch_size_t * out_channels
        self.proj_out = operations.Linear(dim, output_dim, device=device, dtype=dtype)
        self.spatial_interpolation_scale = spatial_interpolation_scale
        self.temporal_interpolation_scale = temporal_interpolation_scale
        self.temporal_compression_ratio = temporal_compression_ratio

    def forward(self, x, timestep, context, ofs=None, transformer_options={}, **kwargs):
        return comfy.patcher_extension.WrapperExecutor.new_class_executor(
            self._forward,
            self,
            comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
        ).execute(x, timestep, context, ofs, transformer_options, **kwargs)

    def _forward(self, x, timestep, context, ofs=None, transformer_options={}, **kwargs):
        # ComfyUI passes [B, C, T, H, W]; keep the unpadded sizes for the final crop.
        _, _, t, h, w = x.shape
        p = self.patch_size
        p_t = self.patch_size_t
        # Pad to patch size (temporal + spatial), same pattern as WAN.
        x = comfy.ldm.common_dit.pad_to_patch_size(x, (p_t if p_t is not None else 1, p, p))
        # CogVideoX expects [B, T, C, H, W]
        x = x.permute(0, 2, 1, 3, 4)
        batch_size, num_frames, _, height, width = x.shape

        # Time (and optional OFS) embedding
        t_emb = get_timestep_embedding(timestep, self.time_proj_dim, self.time_proj_flip, self.time_proj_shift)
        t_emb = t_emb.to(dtype=x.dtype)
        emb = self.time_embedding_linear_2(self.time_embedding_act(self.time_embedding_linear_1(t_emb)))
        if self.ofs_embedding_linear_1 is not None and ofs is not None:
            ofs_emb = get_timestep_embedding(ofs, self.ofs_proj_dim, self.time_proj_flip, self.time_proj_shift)
            ofs_emb = ofs_emb.to(dtype=x.dtype)
            ofs_emb = self.ofs_embedding_linear_2(self.ofs_embedding_act(self.ofs_embedding_linear_1(ofs_emb)))
            emb = emb + ofs_emb

        # Patch embedding: output is [text tokens | video tokens]
        hidden_states = self.patch_embed(context, x)
        text_seq_length = context.shape[1]
        encoder_hidden_states = hidden_states[:, :text_seq_length]
        hidden_states = hidden_states[:, text_seq_length:]

        image_rotary_emb = None
        if self.use_rotary_positional_embeddings:
            post_time = num_frames if p_t is None else num_frames // p_t
            image_rotary_emb = self._get_rotary_emb(height // p, width // p, post_time, device=x.device)

        for block in self.blocks:
            hidden_states, encoder_hidden_states = block(
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                temb=emb,
                image_rotary_emb=image_rotary_emb,
                transformer_options=transformer_options,
            )
        hidden_states = self.norm_final(hidden_states)
        hidden_states = self.norm_out(hidden_states, temb=emb)
        hidden_states = self.proj_out(hidden_states)

        # Unpatchify
        if p_t is None:
            output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
            output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
        else:
            output = hidden_states.reshape(
                batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
            )
            output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
        # Back to ComfyUI format [B, C, T, H, W] and crop padding
        return output.permute(0, 2, 1, 3, 4)[:, :, :t, :h, :w]

    @staticmethod
    def _get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
        """Region a grid of size ``src = (h, w)`` occupies inside a base grid.

        The grid is resized, keeping its aspect ratio, to fit a
        ``tgt_height x tgt_width`` base grid and is centred in it. Returns
        ``((top, left), (bottom, right))`` in base-grid units. Python's ``round`` is
        used, as in the reference implementation.
        """
        h, w = src
        if h / w > tgt_height / tgt_width:
            resize_height = tgt_height
            resize_width = int(round(tgt_height / h * w))
        else:
            resize_width = tgt_width
            resize_height = int(round(tgt_width / w * h))
        crop_top = int(round((tgt_height - resize_height) / 2.0))
        crop_left = int(round((tgt_width - resize_width) / 2.0))
        return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)

    def _get_rotary_emb(self, h, w, t, device):
        """Compute CogVideoX 3D rotary tables ``(cos, sin)``, each ``[t*h*w, head_dim]``.

        The head dim is split into t / h / w parts of ``d//4``, ``d//8*3`` and ``d//8*3``.
        - 1.5 (``patch_size_t`` set), "slice" mode: positions are integer indices.
        - 1.0, "linspace" mode: the ``h x w`` grid is mapped onto the crop region it
          occupies inside the base grid (``sample_height/width // patch_size``). At
          the base resolution this gives integer positions. ``spatial_interpolation_scale``
          is not applied here; it belongs to the sin-cos path only.
        """
        d = self.attention_head_dim
        dim_t = d // 4
        dim_h = d // 8 * 3
        dim_w = d // 8 * 3
        base_h = self.patch_embed.sample_height // self.patch_size
        base_w = self.patch_embed.sample_width // self.patch_size
        if self.patch_size_t is not None:
            # Computed at max(base, actual) size and sliced to the actual size below.
            grid_h = torch.arange(max(base_h, h), device=device, dtype=torch.float32)
            grid_w = torch.arange(max(base_w, w), device=device, dtype=torch.float32)
        else:
            (top, left), (bottom, right) = self._get_resize_crop_region_for_grid((h, w), base_w, base_h)
            grid_h = torch.linspace(top, bottom * (h - 1) / h, h, device=device, dtype=torch.float32)
            grid_w = torch.linspace(left, right * (w - 1) / w, w, device=device, dtype=torch.float32)
        grid_t = torch.arange(t, device=device, dtype=torch.float32)

        t_cos, t_sin = _get_1d_rotary_pos_embed(dim_t, grid_t)
        h_cos, h_sin = _get_1d_rotary_pos_embed(dim_h, grid_h)
        w_cos, w_sin = _get_1d_rotary_pos_embed(dim_w, grid_w)
        h_cos, h_sin = h_cos[:h], h_sin[:h]
        w_cos, w_sin = w_cos[:w], w_sin[:w]

        def combine(freq_t, freq_h, freq_w):
            # Broadcast each axis over the [t, h, w] grid and flatten in (t, h, w) order.
            freq_t = freq_t[:, None, None, :].expand(-1, h, w, -1)
            freq_h = freq_h[None, :, None, :].expand(t, -1, w, -1)
            freq_w = freq_w[None, None, :, :].expand(t, h, -1, -1)
            return torch.cat([freq_t, freq_h, freq_w], dim=-1).reshape(t * h * w, -1)

        return (combine(t_cos, h_cos, w_cos), combine(t_sin, h_sin, w_sin))