from typing import Tuple

import torch
import torch.nn as nn

from comfy.ldm.lightricks.model import (
    CrossAttention,
    FeedForward,
    AdaLayerNormSingle,
    PixArtAlphaTextProjection,
    LTXVModel,
)
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
import comfy.ldm.common_dit


class CompressedTimestep:
    """Store video timestep embeddings in compressed form using per-frame indexing."""

    __slots__ = ('data', 'batch_size', 'num_frames', 'patches_per_frame', 'feature_dim')

    def __init__(self, tensor: torch.Tensor, patches_per_frame: int):
        """
        tensor: [batch_size, num_tokens, feature_dim] tensor where num_tokens = num_frames * patches_per_frame
        patches_per_frame: Number of spatial patches per frame (height * width in latent space)
        """
        self.batch_size, num_tokens, self.feature_dim = tensor.shape

        # Check if compression is valid (num_tokens must be divisible by patches_per_frame).
        # Note: patches_per_frame may be None when the original latent shape is unknown (see _prepare_timestep).
        if patches_per_frame is not None and num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame:
            self.patches_per_frame = patches_per_frame
            self.num_frames = num_tokens // patches_per_frame

            # Reshape to [batch, frames, patches_per_frame, feature_dim] and store one value per frame
            # All patches in a frame are identical, so we only keep the first one
            reshaped = tensor.view(self.batch_size, self.num_frames, patches_per_frame, self.feature_dim)
            self.data = reshaped[:, :, 0, :].contiguous()  # [batch, frames, feature_dim]
        else:
            # Not divisible, unknown or too small - store directly without compression
            self.patches_per_frame = 1
            self.num_frames = num_tokens
            self.data = tensor

    def expand(self):
        """Expand back to original tensor."""
        if self.patches_per_frame == 1:
            return self.data

        # [batch, frames, feature_dim] -> [batch, frames, patches_per_frame, feature_dim] -> [batch, tokens, feature_dim]
        expanded = self.data.unsqueeze(2).expand(self.batch_size, self.num_frames, self.patches_per_frame, self.feature_dim)
        return expanded.reshape(self.batch_size, -1, self.feature_dim)

    def expand_for_computation(self, scale_shift_table: torch.Tensor, batch_size: int, indices: slice = slice(None, None)):
        """Compute ada values on compressed per-frame data, then expand spatially."""
        num_ada_params = scale_shift_table.shape[0]

        # No compression - compute directly
        if self.patches_per_frame == 1:
            num_tokens = self.data.shape[1]
            dim_per_param = self.feature_dim // num_ada_params
            reshaped = self.data.reshape(batch_size, num_tokens, num_ada_params, dim_per_param)[:, :, indices, :]
            table_values = scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(device=self.data.device, dtype=self.data.dtype)
            ada_values = (table_values + reshaped).unbind(dim=2)
            return ada_values

        # Compressed: compute on per-frame data then expand spatially
        # Reshape: [batch, frames, feature_dim] -> [batch, frames, num_ada_params, dim_per_param]
        frame_reshaped = self.data.reshape(batch_size, self.num_frames, num_ada_params, -1)[:, :, indices, :]
        table_values = scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(
            device=self.data.device, dtype=self.data.dtype
        )
        frame_ada = (table_values + frame_reshaped).unbind(dim=2)

        # Expand each ada parameter spatially: [batch, frames, dim] -> [batch, frames, patches, dim] -> [batch, tokens, dim]
        return tuple(
            frame_val.unsqueeze(2).expand(batch_size, self.num_frames, self.patches_per_frame, -1)
            .reshape(batch_size, -1, frame_val.shape[-1])
            for frame_val in frame_ada
        )
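
# Illustrative round-trip for CompressedTimestep (a sketch; the sizes below are picked
# for the example and are not values taken from the model pipeline). The compression
# relies on every patch within a frame sharing the same timestep embedding, which is how
# the per-token embeddings are built in LTXAVModel._prepare_timestep below:
#
#     per_frame = torch.randn(1, 8, 1, 128)                          # [batch, frames, 1, dim]
#     tokens = per_frame.expand(1, 8, 256, 128).reshape(1, -1, 128)  # 8 frames * 256 patches
#     ct = CompressedTimestep(tokens, patches_per_frame=256)
#     assert ct.data.shape == (1, 8, 128)                            # one row kept per frame
#     assert torch.equal(ct.expand(), tokens)                        # lossless for this input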


class BasicAVTransformerBlock(nn.Module):
    def __init__(
        self,
        v_dim,
        a_dim,
        v_heads,
        a_heads,
        vd_head,
        ad_head,
        v_context_dim=None,
        a_context_dim=None,
        attn_precision=None,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()

        self.attn_precision = attn_precision

        self.attn1 = CrossAttention(
            query_dim=v_dim,
            heads=v_heads,
            dim_head=vd_head,
            context_dim=None,
            attn_precision=self.attn_precision,
            dtype=dtype,
            device=device,
            operations=operations,
        )
        self.audio_attn1 = CrossAttention(
            query_dim=a_dim,
            heads=a_heads,
            dim_head=ad_head,
            context_dim=None,
            attn_precision=self.attn_precision,
            dtype=dtype,
            device=device,
            operations=operations,
        )

        self.attn2 = CrossAttention(
            query_dim=v_dim,
            context_dim=v_context_dim,
            heads=v_heads,
            dim_head=vd_head,
            attn_precision=self.attn_precision,
            dtype=dtype,
            device=device,
            operations=operations,
        )
        self.audio_attn2 = CrossAttention(
            query_dim=a_dim,
            context_dim=a_context_dim,
            heads=a_heads,
            dim_head=ad_head,
            attn_precision=self.attn_precision,
            dtype=dtype,
            device=device,
            operations=operations,
        )

        # Q: Video, K,V: Audio
        self.audio_to_video_attn = CrossAttention(
            query_dim=v_dim,
            context_dim=a_dim,
            heads=a_heads,
            dim_head=ad_head,
            attn_precision=self.attn_precision,
            dtype=dtype,
            device=device,
            operations=operations,
        )

        # Q: Audio, K,V: Video
        self.video_to_audio_attn = CrossAttention(
            query_dim=a_dim,
            context_dim=v_dim,
            heads=a_heads,
            dim_head=ad_head,
            attn_precision=self.attn_precision,
            dtype=dtype,
            device=device,
            operations=operations,
        )

        self.ff = FeedForward(
            v_dim, dim_out=v_dim, glu=True, dtype=dtype, device=device, operations=operations
        )
        self.audio_ff = FeedForward(
            a_dim, dim_out=a_dim, glu=True, dtype=dtype, device=device, operations=operations
        )

        self.scale_shift_table = nn.Parameter(torch.empty(6, v_dim, device=device, dtype=dtype))
        self.audio_scale_shift_table = nn.Parameter(
            torch.empty(6, a_dim, device=device, dtype=dtype)
        )

        self.scale_shift_table_a2v_ca_audio = nn.Parameter(
            torch.empty(5, a_dim, device=device, dtype=dtype)
        )
        self.scale_shift_table_a2v_ca_video = nn.Parameter(
            torch.empty(5, v_dim, device=device, dtype=dtype)
        )

    def get_ada_values(
        self, scale_shift_table: torch.Tensor, batch_size: int, timestep: torch.Tensor, indices: slice = slice(None, None)
    ):
        if isinstance(timestep, CompressedTimestep):
            return timestep.expand_for_computation(scale_shift_table, batch_size, indices)

        num_ada_params = scale_shift_table.shape[0]

        ada_values = (
            scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(device=timestep.device, dtype=timestep.dtype)
            + timestep.reshape(batch_size, timestep.shape[1], num_ada_params, -1)[:, :, indices, :]
        ).unbind(dim=2)
        return ada_values

    def get_av_ca_ada_values(
        self,
        scale_shift_table: torch.Tensor,
        batch_size: int,
        scale_shift_timestep: torch.Tensor,
        gate_timestep: torch.Tensor,
        num_scale_shift_values: int = 4,
    ):
        scale_shift_ada_values = self.get_ada_values(
            scale_shift_table[:num_scale_shift_values, :],
            batch_size,
            scale_shift_timestep,
        )
        gate_ada_values = self.get_ada_values(
            scale_shift_table[num_scale_shift_values:, :],
            batch_size,
            gate_timestep,
        )

        return (*scale_shift_ada_values, *gate_ada_values)
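
    # Row layout of the 5-row scale_shift_table_a2v_ca_* parameters, inferred from how the
    # returned values are unpacked in forward() below: rows 0-3 provide the scale/shift pairs
    # for the a2v and v2a cross-attention inputs (driven by the scale/shift timestep), and
    # row 4 provides the corresponding gate (driven by the gate timestep).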

    def forward(
        self,
        x: Tuple[torch.Tensor, torch.Tensor],
        v_context=None,
        a_context=None,
        attention_mask=None,
        v_timestep=None,
        a_timestep=None,
        v_pe=None,
        a_pe=None,
        v_cross_pe=None,
        a_cross_pe=None,
        v_cross_scale_shift_timestep=None,
        a_cross_scale_shift_timestep=None,
        v_cross_gate_timestep=None,
        a_cross_gate_timestep=None,
        transformer_options=None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        run_vx = transformer_options.get("run_vx", True)
        run_ax = transformer_options.get("run_ax", True)

        vx, ax = x
        run_ax = run_ax and ax.numel() > 0
        run_a2v = run_vx and transformer_options.get("a2v_cross_attn", True) and ax.numel() > 0
        run_v2a = run_ax and transformer_options.get("v2a_cross_attn", True)

        if run_vx:
            vshift_msa, vscale_msa, vgate_msa = (
                self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 3))
            )

            norm_vx = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_msa) + vshift_msa
            vx += self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options) * vgate_msa
            vx += self.attn2(
                comfy.ldm.common_dit.rms_norm(vx),
                context=v_context,
                mask=attention_mask,
                transformer_options=transformer_options,
            )

            del vshift_msa, vscale_msa, vgate_msa

        if run_ax:
            ashift_msa, ascale_msa, agate_msa = (
                self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(0, 3))
            )

            norm_ax = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_msa) + ashift_msa
            ax += (
                self.audio_attn1(norm_ax, pe=a_pe, transformer_options=transformer_options)
                * agate_msa
            )
            ax += self.audio_attn2(
                comfy.ldm.common_dit.rms_norm(ax),
                context=a_context,
                mask=attention_mask,
                transformer_options=transformer_options,
            )

            del ashift_msa, ascale_msa, agate_msa

        # Audio - Video cross attention.
        if run_a2v or run_v2a:
            # norm3
            vx_norm3 = comfy.ldm.common_dit.rms_norm(vx)
            ax_norm3 = comfy.ldm.common_dit.rms_norm(ax)

            (
                scale_ca_audio_hidden_states_a2v,
                shift_ca_audio_hidden_states_a2v,
                scale_ca_audio_hidden_states_v2a,
                shift_ca_audio_hidden_states_v2a,
                gate_out_v2a,
            ) = self.get_av_ca_ada_values(
                self.scale_shift_table_a2v_ca_audio,
                ax.shape[0],
                a_cross_scale_shift_timestep,
                a_cross_gate_timestep,
            )

            (
                scale_ca_video_hidden_states_a2v,
                shift_ca_video_hidden_states_a2v,
                scale_ca_video_hidden_states_v2a,
                shift_ca_video_hidden_states_v2a,
                gate_out_a2v,
            ) = self.get_av_ca_ada_values(
                self.scale_shift_table_a2v_ca_video,
                vx.shape[0],
                v_cross_scale_shift_timestep,
                v_cross_gate_timestep,
            )

            if run_a2v:
                vx_scaled = (
                    vx_norm3 * (1 + scale_ca_video_hidden_states_a2v)
                    + shift_ca_video_hidden_states_a2v
                )
                ax_scaled = (
                    ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v)
                    + shift_ca_audio_hidden_states_a2v
                )
                vx += (
                    self.audio_to_video_attn(
                        vx_scaled,
                        context=ax_scaled,
                        pe=v_cross_pe,
                        k_pe=a_cross_pe,
                        transformer_options=transformer_options,
                    )
                    * gate_out_a2v
                )

                del gate_out_a2v
                del scale_ca_video_hidden_states_a2v, \
                    shift_ca_video_hidden_states_a2v, \
                    scale_ca_audio_hidden_states_a2v, \
                    shift_ca_audio_hidden_states_a2v

            if run_v2a:
                ax_scaled = (
                    ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a)
                    + shift_ca_audio_hidden_states_v2a
                )
                vx_scaled = (
                    vx_norm3 * (1 + scale_ca_video_hidden_states_v2a)
                    + shift_ca_video_hidden_states_v2a
                )
                ax += (
                    self.video_to_audio_attn(
                        ax_scaled,
                        context=vx_scaled,
                        pe=a_cross_pe,
                        k_pe=v_cross_pe,
                        transformer_options=transformer_options,
                    )
                    * gate_out_v2a
                )

                del gate_out_v2a
                del scale_ca_video_hidden_states_v2a, \
                    shift_ca_video_hidden_states_v2a, \
                    scale_ca_audio_hidden_states_v2a, \
                    shift_ca_audio_hidden_states_v2a

        if run_vx:
            vshift_mlp, vscale_mlp, vgate_mlp = (
                self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(3, None))
            )

            vx_scaled = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_mlp) + vshift_mlp
            vx += self.ff(vx_scaled) * vgate_mlp
            del vshift_mlp, vscale_mlp, vgate_mlp

        if run_ax:
            ashift_mlp, ascale_mlp, agate_mlp = (
                self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(3, None))
            )

            ax_scaled = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_mlp) + ashift_mlp
            ax += self.audio_ff(ax_scaled) * agate_mlp

            del ashift_mlp, ascale_mlp, agate_mlp

        return vx, ax


class LTXAVModel(LTXVModel):
    """LTXAV model for audio-video generation."""

    def __init__(
        self,
        in_channels=128,
        audio_in_channels=128,
        cross_attention_dim=4096,
        audio_cross_attention_dim=2048,
        attention_head_dim=128,
        audio_attention_head_dim=64,
        num_attention_heads=32,
        audio_num_attention_heads=32,
        caption_channels=3840,
        num_layers=48,
        positional_embedding_theta=10000.0,
        positional_embedding_max_pos=[20, 2048, 2048],
        audio_positional_embedding_max_pos=[20],
        causal_temporal_positioning=False,
        vae_scale_factors=(8, 32, 32),
        use_middle_indices_grid=False,
        timestep_scale_multiplier=1000.0,
        av_ca_timestep_scale_multiplier=1.0,
        dtype=None,
        device=None,
        operations=None,
        **kwargs,
    ):
        # Store audio-specific parameters
        self.audio_in_channels = audio_in_channels
        self.audio_cross_attention_dim = audio_cross_attention_dim
        self.audio_attention_head_dim = audio_attention_head_dim
        self.audio_num_attention_heads = audio_num_attention_heads
        self.audio_positional_embedding_max_pos = audio_positional_embedding_max_pos

        # Calculate audio dimensions
        self.audio_inner_dim = audio_num_attention_heads * audio_attention_head_dim
        self.audio_out_channels = audio_in_channels

        # Audio-specific constants
        self.num_audio_channels = 8
        self.audio_frequency_bins = 16

        self.av_ca_timestep_scale_multiplier = av_ca_timestep_scale_multiplier

        super().__init__(
            in_channels=in_channels,
            cross_attention_dim=cross_attention_dim,
            attention_head_dim=attention_head_dim,
            num_attention_heads=num_attention_heads,
            caption_channels=caption_channels,
            num_layers=num_layers,
            positional_embedding_theta=positional_embedding_theta,
            positional_embedding_max_pos=positional_embedding_max_pos,
            causal_temporal_positioning=causal_temporal_positioning,
            vae_scale_factors=vae_scale_factors,
            use_middle_indices_grid=use_middle_indices_grid,
            timestep_scale_multiplier=timestep_scale_multiplier,
            dtype=dtype,
            device=device,
            operations=operations,
            **kwargs,
        )

    def _init_model_components(self, device, dtype, **kwargs):
        """Initialize LTXAV-specific components."""
        # Audio-specific projections
        self.audio_patchify_proj = self.operations.Linear(
            self.audio_in_channels, self.audio_inner_dim, bias=True, dtype=dtype, device=device
        )

        # Audio-specific AdaLN
        self.audio_adaln_single = AdaLayerNormSingle(
            self.audio_inner_dim,
            use_additional_conditions=False,
            dtype=dtype,
            device=device,
            operations=self.operations,
        )

        num_scale_shift_values = 4
        self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle(
            self.inner_dim,
            use_additional_conditions=False,
            embedding_coefficient=num_scale_shift_values,
            dtype=dtype,
            device=device,
            operations=self.operations,
        )
        self.av_ca_a2v_gate_adaln_single = AdaLayerNormSingle(
            self.inner_dim,
            use_additional_conditions=False,
            embedding_coefficient=1,
            dtype=dtype,
            device=device,
            operations=self.operations,
        )
        self.av_ca_audio_scale_shift_adaln_single = AdaLayerNormSingle(
            self.audio_inner_dim,
            use_additional_conditions=False,
            embedding_coefficient=num_scale_shift_values,
            dtype=dtype,
            device=device,
            operations=self.operations,
        )
        self.av_ca_v2a_gate_adaln_single = AdaLayerNormSingle(
            self.audio_inner_dim,
            use_additional_conditions=False,
            embedding_coefficient=1,
            dtype=dtype,
            device=device,
            operations=self.operations,
        )

        # Audio caption projection
        self.audio_caption_projection = PixArtAlphaTextProjection(
            in_features=self.caption_channels,
            hidden_size=self.audio_inner_dim,
            dtype=dtype,
            device=device,
            operations=self.operations,
        )

    def _init_transformer_blocks(self, device, dtype, **kwargs):
        """Initialize transformer blocks for LTXAV."""
        self.transformer_blocks = nn.ModuleList(
            [
                BasicAVTransformerBlock(
                    v_dim=self.inner_dim,
                    a_dim=self.audio_inner_dim,
                    v_heads=self.num_attention_heads,
                    a_heads=self.audio_num_attention_heads,
                    vd_head=self.attention_head_dim,
                    ad_head=self.audio_attention_head_dim,
                    v_context_dim=self.cross_attention_dim,
                    a_context_dim=self.audio_cross_attention_dim,
                    dtype=dtype,
                    device=device,
                    operations=self.operations,
                )
                for _ in range(self.num_layers)
            ]
        )

    def _init_output_components(self, device, dtype):
        """Initialize output components for LTXAV."""
        # Video output components
        super()._init_output_components(device, dtype)
        # Audio output components
        self.audio_scale_shift_table = nn.Parameter(
            torch.empty(2, self.audio_inner_dim, dtype=dtype, device=device)
        )
        self.audio_norm_out = self.operations.LayerNorm(
            self.audio_inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device
        )
        self.audio_proj_out = self.operations.Linear(
            self.audio_inner_dim, self.audio_out_channels, dtype=dtype, device=device
        )
        self.a_patchifier = AudioPatchifier(1, start_end=True)

    def separate_audio_and_video_latents(self, x, audio_length):
        """Separate audio and video latents from combined input."""
        # vx = x[:, : self.in_channels]
        # ax = x[:, self.in_channels :]
        #
        # ax = ax.reshape(ax.shape[0], -1)
        # ax = ax[:, : audio_length * self.num_audio_channels * self.audio_frequency_bins]
        #
        # ax = ax.reshape(
        #     ax.shape[0], self.num_audio_channels, audio_length, self.audio_frequency_bins
        # )

        vx = x[0]
        ax = x[1] if len(x) > 1 else torch.zeros(
            (vx.shape[0], self.num_audio_channels, 0, self.audio_frequency_bins),
            device=vx.device, dtype=vx.dtype
        )
        return vx, ax

    def recombine_audio_and_video_latents(self, vx, ax, target_shape=None):
        """Recombine audio and video latents for output."""
        if ax.numel() == 0:
            return vx
        else:
            return [vx, ax]
        # if ax.device != vx.device or ax.dtype != vx.dtype:
        #     logging.warning("Audio and video latents are on different devices or dtypes.")
        #     ax = ax.to(device=vx.device, dtype=vx.dtype)
        #     logging.warning(f"Audio latent moved to device: {ax.device}, dtype: {ax.dtype}")
        #
        # ax = ax.reshape(ax.shape[0], -1)
        # # pad to f x h x w of the video latents
        # divisor = vx.shape[-1] * vx.shape[-2] * vx.shape[-3]
        # if target_shape is None:
        #     repetitions = math.ceil(ax.shape[-1] / divisor)
        # else:
        #     repetitions = target_shape[1] - vx.shape[1]
        # padded_len = repetitions * divisor
        # ax = F.pad(ax, (0, padded_len - ax.shape[-1]))
        # ax = ax.reshape(ax.shape[0], -1, vx.shape[-3], vx.shape[-2], vx.shape[-1])
        # return torch.cat([vx, ax], dim=1)

    def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs):
        """Process input for LTXAV - separate audio and video, then patchify."""
        audio_length = kwargs.get("audio_length", 0)
        # Separate audio and video latents
        vx, ax = self.separate_audio_and_video_latents(x, audio_length)
        [vx, v_pixel_coords, additional_args] = super()._process_input(
            vx, keyframe_idxs, denoise_mask, **kwargs
        )

        ax, a_latent_coords = self.a_patchifier.patchify(ax)
        ax = self.audio_patchify_proj(ax)

        # additional_args.update({"av_orig_shape": list(x.shape)})
        return [vx, ax], [v_pixel_coords, a_latent_coords], additional_args

    def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs):
        """Prepare timestep embeddings."""
        # TODO: some code reuse is needed here.
        grid_mask = kwargs.get("grid_mask", None)
        if grid_mask is not None:
            timestep = timestep[:, grid_mask]

        timestep_scaled = timestep * self.timestep_scale_multiplier

        v_timestep, v_embedded_timestep = self.adaln_single(
            timestep_scaled.flatten(),
            {"resolution": None, "aspect_ratio": None},
            batch_size=batch_size,
            hidden_dtype=hidden_dtype,
        )

        # Calculate patches_per_frame from orig_shape: [batch, channels, frames, height, width]
        # Video tokens are arranged as (frames * height * width), so patches_per_frame = height * width
        orig_shape = kwargs.get("orig_shape")
        v_patches_per_frame = None
        if orig_shape is not None and len(orig_shape) == 5:
            # orig_shape[3] = height, orig_shape[4] = width (in latent space)
            v_patches_per_frame = orig_shape[3] * orig_shape[4]

        # Reshape to [batch_size, num_tokens, dim] and compress for storage
        v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame)
        v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame)
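        # e.g. (sizes assumed purely for illustration): a video latent of shape
        # [1, 128, 8, 16, 16] gives v_patches_per_frame = 16 * 16 = 256 and
        # 8 * 256 = 2048 tokens, so the compressed representation stores 8 per-frame
        # rows instead of 2048 per-token rows.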

        # Prepare audio timestep
        a_timestep = kwargs.get("a_timestep")
        if a_timestep is not None:
            a_timestep_scaled = a_timestep * self.timestep_scale_multiplier
            a_timestep_flat = a_timestep_scaled.flatten()
            timestep_flat = timestep_scaled.flatten()
            av_ca_factor = self.av_ca_timestep_scale_multiplier / self.timestep_scale_multiplier

            # Cross-attention timesteps - compress these too
            av_ca_audio_scale_shift_timestep, _ = self.av_ca_audio_scale_shift_adaln_single(
                a_timestep_flat,
                {"resolution": None, "aspect_ratio": None},
                batch_size=batch_size,
                hidden_dtype=hidden_dtype,
            )
            av_ca_video_scale_shift_timestep, _ = self.av_ca_video_scale_shift_adaln_single(
                timestep_flat,
                {"resolution": None, "aspect_ratio": None},
                batch_size=batch_size,
                hidden_dtype=hidden_dtype,
            )
            av_ca_a2v_gate_noise_timestep, _ = self.av_ca_a2v_gate_adaln_single(
                timestep_flat * av_ca_factor,
                {"resolution": None, "aspect_ratio": None},
                batch_size=batch_size,
                hidden_dtype=hidden_dtype,
            )
            av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single(
                a_timestep_flat * av_ca_factor,
                {"resolution": None, "aspect_ratio": None},
                batch_size=batch_size,
                hidden_dtype=hidden_dtype,
            )

            # Compress cross-attention timesteps (only video side, audio is too small to benefit)
            cross_av_timestep_ss = [
                av_ca_audio_scale_shift_timestep.view(batch_size, -1, av_ca_audio_scale_shift_timestep.shape[-1]),
                CompressedTimestep(av_ca_video_scale_shift_timestep.view(batch_size, -1, av_ca_video_scale_shift_timestep.shape[-1]), v_patches_per_frame),  # video - compressed
                CompressedTimestep(av_ca_a2v_gate_noise_timestep.view(batch_size, -1, av_ca_a2v_gate_noise_timestep.shape[-1]), v_patches_per_frame),  # video - compressed
                av_ca_v2a_gate_noise_timestep.view(batch_size, -1, av_ca_v2a_gate_noise_timestep.shape[-1]),
            ]

            a_timestep, a_embedded_timestep = self.audio_adaln_single(
                a_timestep_flat,
                {"resolution": None, "aspect_ratio": None},
                batch_size=batch_size,
                hidden_dtype=hidden_dtype,
            )
            # Audio timesteps
            a_timestep = a_timestep.view(batch_size, -1, a_timestep.shape[-1])
            a_embedded_timestep = a_embedded_timestep.view(batch_size, -1, a_embedded_timestep.shape[-1])
        else:
            a_timestep = timestep_scaled
            a_embedded_timestep = kwargs.get("embedded_timestep")
            cross_av_timestep_ss = []

        return [v_timestep, a_timestep, cross_av_timestep_ss], [
            v_embedded_timestep,
            a_embedded_timestep,
        ]

    def _prepare_context(self, context, batch_size, x, attention_mask=None):
        vx = x[0]
        ax = x[1]
        v_context, a_context = torch.split(
            context, int(context.shape[-1] / 2), len(context.shape) - 1
        )

        v_context, attention_mask = super()._prepare_context(
            v_context, batch_size, vx, attention_mask
        )
        if self.audio_caption_projection is not None:
            a_context = self.audio_caption_projection(a_context)
            a_context = a_context.view(batch_size, -1, ax.shape[-1])

        return [v_context, a_context], attention_mask

    def _prepare_positional_embeddings(self, pixel_coords, frame_rate, x_dtype):
        v_pixel_coords = pixel_coords[0]
        v_pe = super()._prepare_positional_embeddings(v_pixel_coords, frame_rate, x_dtype)

        a_latent_coords = pixel_coords[1]
        a_pe = self._precompute_freqs_cis(
            a_latent_coords,
            dim=self.audio_inner_dim,
            out_dtype=x_dtype,
            max_pos=self.audio_positional_embedding_max_pos,
            use_middle_indices_grid=self.use_middle_indices_grid,
            num_attention_heads=self.audio_num_attention_heads,
        )

        # calculate positional embeddings for the middle of the token duration, to use in av cross attention layers.
        max_pos = max(
            self.positional_embedding_max_pos[0], self.audio_positional_embedding_max_pos[0]
        )
        v_pixel_coords = v_pixel_coords.to(torch.float32)
        v_pixel_coords[:, 0] = v_pixel_coords[:, 0] * (1.0 / frame_rate)
        av_cross_video_freq_cis = self._precompute_freqs_cis(
            v_pixel_coords[:, 0:1, :],
            dim=self.audio_cross_attention_dim,
            out_dtype=x_dtype,
            max_pos=[max_pos],
            use_middle_indices_grid=True,
            num_attention_heads=self.audio_num_attention_heads,
        )
        av_cross_audio_freq_cis = self._precompute_freqs_cis(
            a_latent_coords[:, 0:1, :],
            dim=self.audio_cross_attention_dim,
            out_dtype=x_dtype,
            max_pos=[max_pos],
            use_middle_indices_grid=True,
            num_attention_heads=self.audio_num_attention_heads,
        )

        return [(v_pe, av_cross_video_freq_cis), (a_pe, av_cross_audio_freq_cis)]

    def _process_transformer_blocks(
        self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs
    ):
        """Process transformer blocks for LTXAV."""
        vx = x[0]
        ax = x[1]
        v_context = context[0]
        a_context = context[1]
        v_timestep = timestep[0]
        a_timestep = timestep[1]
        v_pe, av_cross_video_freq_cis = pe[0]
        a_pe, av_cross_audio_freq_cis = pe[1]

        (
            av_ca_audio_scale_shift_timestep,
            av_ca_video_scale_shift_timestep,
            av_ca_a2v_gate_noise_timestep,
            av_ca_v2a_gate_noise_timestep,
        ) = timestep[2]

        patches_replace = transformer_options.get("patches_replace", {})
        blocks_replace = patches_replace.get("dit", {})

        # Process transformer blocks
        for i, block in enumerate(self.transformer_blocks):
            if ("double_block", i) in blocks_replace:

                def block_wrap(args):
                    out = {}
                    out["img"] = block(
                        args["img"],
                        v_context=args["v_context"],
                        a_context=args["a_context"],
                        attention_mask=args["attention_mask"],
                        v_timestep=args["v_timestep"],
                        a_timestep=args["a_timestep"],
                        v_pe=args["v_pe"],
                        a_pe=args["a_pe"],
                        v_cross_pe=args["v_cross_pe"],
                        a_cross_pe=args["a_cross_pe"],
                        v_cross_scale_shift_timestep=args["v_cross_scale_shift_timestep"],
                        a_cross_scale_shift_timestep=args["a_cross_scale_shift_timestep"],
                        v_cross_gate_timestep=args["v_cross_gate_timestep"],
                        a_cross_gate_timestep=args["a_cross_gate_timestep"],
                        transformer_options=args["transformer_options"],
                    )
                    return out

                out = blocks_replace[("double_block", i)](
                    {
                        "img": (vx, ax),
                        "v_context": v_context,
                        "a_context": a_context,
                        "attention_mask": attention_mask,
                        "v_timestep": v_timestep,
                        "a_timestep": a_timestep,
                        "v_pe": v_pe,
                        "a_pe": a_pe,
                        "v_cross_pe": av_cross_video_freq_cis,
                        "a_cross_pe": av_cross_audio_freq_cis,
                        "v_cross_scale_shift_timestep": av_ca_video_scale_shift_timestep,
                        "a_cross_scale_shift_timestep": av_ca_audio_scale_shift_timestep,
                        "v_cross_gate_timestep": av_ca_a2v_gate_noise_timestep,
                        "a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep,
                        "transformer_options": transformer_options,
                    },
                    {"original_block": block_wrap},
                )
                vx, ax = out["img"]
            else:
                vx, ax = block(
                    (vx, ax),
                    v_context=v_context,
                    a_context=a_context,
                    attention_mask=attention_mask,
                    v_timestep=v_timestep,
                    a_timestep=a_timestep,
                    v_pe=v_pe,
                    a_pe=a_pe,
                    v_cross_pe=av_cross_video_freq_cis,
                    a_cross_pe=av_cross_audio_freq_cis,
                    v_cross_scale_shift_timestep=av_ca_video_scale_shift_timestep,
                    a_cross_scale_shift_timestep=av_ca_audio_scale_shift_timestep,
                    v_cross_gate_timestep=av_ca_a2v_gate_noise_timestep,
                    a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep,
                    transformer_options=transformer_options,
                )

        return [vx, ax]

    def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):
        vx = x[0]
        ax = x[1]
        v_embedded_timestep = embedded_timestep[0]
        a_embedded_timestep = embedded_timestep[1]

        # Expand compressed video timestep if needed
        if isinstance(v_embedded_timestep, CompressedTimestep):
            v_embedded_timestep = v_embedded_timestep.expand()

        vx = super()._process_output(vx, v_embedded_timestep, keyframe_idxs, **kwargs)

        # Process audio output
        a_scale_shift_values = (
            self.audio_scale_shift_table[None, None].to(device=a_embedded_timestep.device, dtype=a_embedded_timestep.dtype)
            + a_embedded_timestep[:, :, None]
        )
        a_shift, a_scale = a_scale_shift_values[:, :, 0], a_scale_shift_values[:, :, 1]

        ax = self.audio_norm_out(ax)
        ax = ax * (1 + a_scale) + a_shift
        ax = self.audio_proj_out(ax)

        # Unpatchify audio
        ax = self.a_patchifier.unpatchify(
            ax, channels=self.num_audio_channels, freq=self.audio_frequency_bins
        )

        # Recombine audio and video
        original_shape = kwargs.get("av_orig_shape")
        return self.recombine_audio_and_video_latents(vx, ax, original_shape)

    def forward(
        self,
        x,
        timestep,
        context,
        attention_mask=None,
        frame_rate=25,
        transformer_options={},
        keyframe_idxs=None,
        **kwargs,
    ):
        """
        Forward pass for LTXAV model.

        Args:
            x: Input latents as a list: [video_latent] or [video_latent, audio_latent]
            timestep: Tuple of (video_timestep, audio_timestep) or single timestep
            context: Context tensor (e.g., text embeddings)
            attention_mask: Attention mask tensor
            frame_rate: Frame rate for temporal processing
            transformer_options: Additional options for transformer blocks
            keyframe_idxs: Keyframe indices for temporal processing
            **kwargs: Additional keyword arguments including audio_length

        Returns:
            Video output latent, or [video, audio] output latents when audio is present
        """
        # Handle timestep format
        if isinstance(timestep, (tuple, list)) and len(timestep) == 2:
            v_timestep, a_timestep = timestep
            kwargs["a_timestep"] = a_timestep
            timestep = v_timestep
        else:
            kwargs["a_timestep"] = timestep

        # Call parent forward method
        return super().forward(
            x,
            timestep,
            context,
            attention_mask,
            frame_rate,
            transformer_options,
            keyframe_idxs,
            **kwargs,
        )
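
# Minimal call sketch (shapes and arguments are assumptions for illustration, not values
# taken from ComfyUI's model loading code): the model consumes a list of latents plus a
# per-stream timestep pair, and returns [video, audio] when an audio latent is present.
#
#     model = LTXAVModel(num_layers=2, dtype=torch.float32, device="cpu", operations=ops)
#     out = model(
#         [video_latent, audio_latent],      # video: [B, C, F, H, W], audio: [B, 8, T, 16]
#         (video_timestep, audio_timestep),  # per-stream timesteps
#         context,                           # text embeddings; video/audio halves split on the last dim
#         frame_rate=25,
#         transformer_options={},
#     )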