ComfyUI/comfy/ldm/lightricks/av_model.py
2026-01-05 01:58:59 -05:00

838 lines
31 KiB
Python

from typing import Tuple
import torch
import torch.nn as nn
from comfy.ldm.lightricks.model import (
CrossAttention,
FeedForward,
AdaLayerNormSingle,
PixArtAlphaTextProjection,
LTXVModel,
)
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
import comfy.ldm.common_dit
class BasicAVTransformerBlock(nn.Module):
def __init__(
self,
v_dim,
a_dim,
v_heads,
a_heads,
vd_head,
ad_head,
v_context_dim=None,
a_context_dim=None,
attn_precision=None,
dtype=None,
device=None,
operations=None,
):
super().__init__()
self.attn_precision = attn_precision
self.attn1 = CrossAttention(
query_dim=v_dim,
heads=v_heads,
dim_head=vd_head,
context_dim=None,
attn_precision=self.attn_precision,
dtype=dtype,
device=device,
operations=operations,
)
self.audio_attn1 = CrossAttention(
query_dim=a_dim,
heads=a_heads,
dim_head=ad_head,
context_dim=None,
attn_precision=self.attn_precision,
dtype=dtype,
device=device,
operations=operations,
)
self.attn2 = CrossAttention(
query_dim=v_dim,
context_dim=v_context_dim,
heads=v_heads,
dim_head=vd_head,
attn_precision=self.attn_precision,
dtype=dtype,
device=device,
operations=operations,
)
self.audio_attn2 = CrossAttention(
query_dim=a_dim,
context_dim=a_context_dim,
heads=a_heads,
dim_head=ad_head,
attn_precision=self.attn_precision,
dtype=dtype,
device=device,
operations=operations,
)
# Q: Video, K,V: Audio
self.audio_to_video_attn = CrossAttention(
query_dim=v_dim,
context_dim=a_dim,
heads=a_heads,
dim_head=ad_head,
attn_precision=self.attn_precision,
dtype=dtype,
device=device,
operations=operations,
)
# Q: Audio, K,V: Video
self.video_to_audio_attn = CrossAttention(
query_dim=a_dim,
context_dim=v_dim,
heads=a_heads,
dim_head=ad_head,
attn_precision=self.attn_precision,
dtype=dtype,
device=device,
operations=operations,
)
self.ff = FeedForward(
v_dim, dim_out=v_dim, glu=True, dtype=dtype, device=device, operations=operations
)
self.audio_ff = FeedForward(
a_dim, dim_out=a_dim, glu=True, dtype=dtype, device=device, operations=operations
)
self.scale_shift_table = nn.Parameter(torch.empty(6, v_dim, device=device, dtype=dtype))
self.audio_scale_shift_table = nn.Parameter(
torch.empty(6, a_dim, device=device, dtype=dtype)
)
self.scale_shift_table_a2v_ca_audio = nn.Parameter(
torch.empty(5, a_dim, device=device, dtype=dtype)
)
self.scale_shift_table_a2v_ca_video = nn.Parameter(
torch.empty(5, v_dim, device=device, dtype=dtype)
)
def get_ada_values(
self, scale_shift_table: torch.Tensor, batch_size: int, timestep: torch.Tensor, indices: slice = slice(None, None)
):
num_ada_params = scale_shift_table.shape[0]
ada_values = (
scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(device=timestep.device, dtype=timestep.dtype)
+ timestep.reshape(batch_size, timestep.shape[1], num_ada_params, -1)[:, :, indices, :]
).unbind(dim=2)
return ada_values
def get_av_ca_ada_values(
self,
scale_shift_table: torch.Tensor,
batch_size: int,
scale_shift_timestep: torch.Tensor,
gate_timestep: torch.Tensor,
num_scale_shift_values: int = 4,
):
scale_shift_ada_values = self.get_ada_values(
scale_shift_table[:num_scale_shift_values, :],
batch_size,
scale_shift_timestep,
)
gate_ada_values = self.get_ada_values(
scale_shift_table[num_scale_shift_values:, :],
batch_size,
gate_timestep,
)
scale_shift_chunks = [t.squeeze(2) for t in scale_shift_ada_values]
gate_ada_values = [t.squeeze(2) for t in gate_ada_values]
return (*scale_shift_chunks, *gate_ada_values)
def forward(
self,
x: Tuple[torch.Tensor, torch.Tensor],
v_context=None,
a_context=None,
attention_mask=None,
v_timestep=None,
a_timestep=None,
v_pe=None,
a_pe=None,
v_cross_pe=None,
a_cross_pe=None,
v_cross_scale_shift_timestep=None,
a_cross_scale_shift_timestep=None,
v_cross_gate_timestep=None,
a_cross_gate_timestep=None,
transformer_options=None,
) -> Tuple[torch.Tensor, torch.Tensor]:
run_vx = transformer_options.get("run_vx", True)
run_ax = transformer_options.get("run_ax", True)
vx, ax = x
run_ax = run_ax and ax.numel() > 0
run_a2v = run_vx and transformer_options.get("a2v_cross_attn", True) and ax.numel() > 0
run_v2a = run_ax and transformer_options.get("v2a_cross_attn", True)
if run_vx:
vshift_msa, vscale_msa, vgate_msa = (
self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 3))
)
norm_vx = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_msa) + vshift_msa
vx += self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options) * vgate_msa
vx += self.attn2(
comfy.ldm.common_dit.rms_norm(vx),
context=v_context,
mask=attention_mask,
transformer_options=transformer_options,
)
del vshift_msa, vscale_msa, vgate_msa
if run_ax:
ashift_msa, ascale_msa, agate_msa = (
self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(0, 3))
)
norm_ax = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_msa) + ashift_msa
ax += (
self.audio_attn1(norm_ax, pe=a_pe, transformer_options=transformer_options)
* agate_msa
)
ax += self.audio_attn2(
comfy.ldm.common_dit.rms_norm(ax),
context=a_context,
mask=attention_mask,
transformer_options=transformer_options,
)
del ashift_msa, ascale_msa, agate_msa
# Audio - Video cross attention.
if run_a2v or run_v2a:
# norm3
vx_norm3 = comfy.ldm.common_dit.rms_norm(vx)
ax_norm3 = comfy.ldm.common_dit.rms_norm(ax)
(
scale_ca_audio_hidden_states_a2v,
shift_ca_audio_hidden_states_a2v,
scale_ca_audio_hidden_states_v2a,
shift_ca_audio_hidden_states_v2a,
gate_out_v2a,
) = self.get_av_ca_ada_values(
self.scale_shift_table_a2v_ca_audio,
ax.shape[0],
a_cross_scale_shift_timestep,
a_cross_gate_timestep,
)
(
scale_ca_video_hidden_states_a2v,
shift_ca_video_hidden_states_a2v,
scale_ca_video_hidden_states_v2a,
shift_ca_video_hidden_states_v2a,
gate_out_a2v,
) = self.get_av_ca_ada_values(
self.scale_shift_table_a2v_ca_video,
vx.shape[0],
v_cross_scale_shift_timestep,
v_cross_gate_timestep,
)
if run_a2v:
vx_scaled = (
vx_norm3 * (1 + scale_ca_video_hidden_states_a2v)
+ shift_ca_video_hidden_states_a2v
)
ax_scaled = (
ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v)
+ shift_ca_audio_hidden_states_a2v
)
vx += (
self.audio_to_video_attn(
vx_scaled,
context=ax_scaled,
pe=v_cross_pe,
k_pe=a_cross_pe,
transformer_options=transformer_options,
)
* gate_out_a2v
)
del gate_out_a2v
del scale_ca_video_hidden_states_a2v,\
shift_ca_video_hidden_states_a2v,\
scale_ca_audio_hidden_states_a2v,\
shift_ca_audio_hidden_states_a2v,\
if run_v2a:
ax_scaled = (
ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a)
+ shift_ca_audio_hidden_states_v2a
)
vx_scaled = (
vx_norm3 * (1 + scale_ca_video_hidden_states_v2a)
+ shift_ca_video_hidden_states_v2a
)
ax += (
self.video_to_audio_attn(
ax_scaled,
context=vx_scaled,
pe=a_cross_pe,
k_pe=v_cross_pe,
transformer_options=transformer_options,
)
* gate_out_v2a
)
del gate_out_v2a
del scale_ca_video_hidden_states_v2a,\
shift_ca_video_hidden_states_v2a,\
scale_ca_audio_hidden_states_v2a,\
shift_ca_audio_hidden_states_v2a
if run_vx:
vshift_mlp, vscale_mlp, vgate_mlp = (
self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(3, None))
)
vx_scaled = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_mlp) + vshift_mlp
vx += self.ff(vx_scaled) * vgate_mlp
del vshift_mlp, vscale_mlp, vgate_mlp
if run_ax:
ashift_mlp, ascale_mlp, agate_mlp = (
self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(3, None))
)
ax_scaled = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_mlp) + ashift_mlp
ax += self.audio_ff(ax_scaled) * agate_mlp
del ashift_mlp, ascale_mlp, agate_mlp
return vx, ax
class LTXAVModel(LTXVModel):
"""LTXAV model for audio-video generation."""
def __init__(
self,
in_channels=128,
audio_in_channels=128,
cross_attention_dim=4096,
audio_cross_attention_dim=2048,
attention_head_dim=128,
audio_attention_head_dim=64,
num_attention_heads=32,
audio_num_attention_heads=32,
caption_channels=3840,
num_layers=48,
positional_embedding_theta=10000.0,
positional_embedding_max_pos=[20, 2048, 2048],
audio_positional_embedding_max_pos=[20],
causal_temporal_positioning=False,
vae_scale_factors=(8, 32, 32),
use_middle_indices_grid=False,
timestep_scale_multiplier=1000.0,
av_ca_timestep_scale_multiplier=1.0,
dtype=None,
device=None,
operations=None,
**kwargs,
):
# Store audio-specific parameters
self.audio_in_channels = audio_in_channels
self.audio_cross_attention_dim = audio_cross_attention_dim
self.audio_attention_head_dim = audio_attention_head_dim
self.audio_num_attention_heads = audio_num_attention_heads
self.audio_positional_embedding_max_pos = audio_positional_embedding_max_pos
# Calculate audio dimensions
self.audio_inner_dim = audio_num_attention_heads * audio_attention_head_dim
self.audio_out_channels = audio_in_channels
# Audio-specific constants
self.num_audio_channels = 8
self.audio_frequency_bins = 16
self.av_ca_timestep_scale_multiplier = av_ca_timestep_scale_multiplier
super().__init__(
in_channels=in_channels,
cross_attention_dim=cross_attention_dim,
attention_head_dim=attention_head_dim,
num_attention_heads=num_attention_heads,
caption_channels=caption_channels,
num_layers=num_layers,
positional_embedding_theta=positional_embedding_theta,
positional_embedding_max_pos=positional_embedding_max_pos,
causal_temporal_positioning=causal_temporal_positioning,
vae_scale_factors=vae_scale_factors,
use_middle_indices_grid=use_middle_indices_grid,
timestep_scale_multiplier=timestep_scale_multiplier,
dtype=dtype,
device=device,
operations=operations,
**kwargs,
)
def _init_model_components(self, device, dtype, **kwargs):
"""Initialize LTXAV-specific components."""
# Audio-specific projections
self.audio_patchify_proj = self.operations.Linear(
self.audio_in_channels, self.audio_inner_dim, bias=True, dtype=dtype, device=device
)
# Audio-specific AdaLN
self.audio_adaln_single = AdaLayerNormSingle(
self.audio_inner_dim,
use_additional_conditions=False,
dtype=dtype,
device=device,
operations=self.operations,
)
num_scale_shift_values = 4
self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle(
self.inner_dim,
use_additional_conditions=False,
embedding_coefficient=num_scale_shift_values,
dtype=dtype,
device=device,
operations=self.operations,
)
self.av_ca_a2v_gate_adaln_single = AdaLayerNormSingle(
self.inner_dim,
use_additional_conditions=False,
embedding_coefficient=1,
dtype=dtype,
device=device,
operations=self.operations,
)
self.av_ca_audio_scale_shift_adaln_single = AdaLayerNormSingle(
self.audio_inner_dim,
use_additional_conditions=False,
embedding_coefficient=num_scale_shift_values,
dtype=dtype,
device=device,
operations=self.operations,
)
self.av_ca_v2a_gate_adaln_single = AdaLayerNormSingle(
self.audio_inner_dim,
use_additional_conditions=False,
embedding_coefficient=1,
dtype=dtype,
device=device,
operations=self.operations,
)
# Audio caption projection
self.audio_caption_projection = PixArtAlphaTextProjection(
in_features=self.caption_channels,
hidden_size=self.audio_inner_dim,
dtype=dtype,
device=device,
operations=self.operations,
)
def _init_transformer_blocks(self, device, dtype, **kwargs):
"""Initialize transformer blocks for LTXAV."""
self.transformer_blocks = nn.ModuleList(
[
BasicAVTransformerBlock(
v_dim=self.inner_dim,
a_dim=self.audio_inner_dim,
v_heads=self.num_attention_heads,
a_heads=self.audio_num_attention_heads,
vd_head=self.attention_head_dim,
ad_head=self.audio_attention_head_dim,
v_context_dim=self.cross_attention_dim,
a_context_dim=self.audio_cross_attention_dim,
dtype=dtype,
device=device,
operations=self.operations,
)
for _ in range(self.num_layers)
]
)
def _init_output_components(self, device, dtype):
"""Initialize output components for LTXAV."""
# Video output components
super()._init_output_components(device, dtype)
# Audio output components
self.audio_scale_shift_table = nn.Parameter(
torch.empty(2, self.audio_inner_dim, dtype=dtype, device=device)
)
self.audio_norm_out = self.operations.LayerNorm(
self.audio_inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device
)
self.audio_proj_out = self.operations.Linear(
self.audio_inner_dim, self.audio_out_channels, dtype=dtype, device=device
)
self.a_patchifier = AudioPatchifier(1, start_end=True)
def separate_audio_and_video_latents(self, x, audio_length):
"""Separate audio and video latents from combined input."""
# vx = x[:, : self.in_channels]
# ax = x[:, self.in_channels :]
#
# ax = ax.reshape(ax.shape[0], -1)
# ax = ax[:, : audio_length * self.num_audio_channels * self.audio_frequency_bins]
#
# ax = ax.reshape(
# ax.shape[0], self.num_audio_channels, audio_length, self.audio_frequency_bins
# )
vx = x[0]
ax = x[1] if len(x) > 1 else torch.zeros(
(vx.shape[0], self.num_audio_channels, 0, self.audio_frequency_bins),
device=vx.device, dtype=vx.dtype
)
return vx, ax
def recombine_audio_and_video_latents(self, vx, ax, target_shape=None):
if ax.numel() == 0:
return vx
else:
return [vx, ax]
"""Recombine audio and video latents for output."""
# if ax.device != vx.device or ax.dtype != vx.dtype:
# logging.warning("Audio and video latents are on different devices or dtypes.")
# ax = ax.to(device=vx.device, dtype=vx.dtype)
# logging.warning(f"Audio audio latent moved to device: {ax.device}, dtype: {ax.dtype}")
#
# ax = ax.reshape(ax.shape[0], -1)
# # pad to f x h x w of the video latents
# divisor = vx.shape[-1] * vx.shape[-2] * vx.shape[-3]
# if target_shape is None:
# repetitions = math.ceil(ax.shape[-1] / divisor)
# else:
# repetitions = target_shape[1] - vx.shape[1]
# padded_len = repetitions * divisor
# ax = F.pad(ax, (0, padded_len - ax.shape[-1]))
# ax = ax.reshape(ax.shape[0], -1, vx.shape[-3], vx.shape[-2], vx.shape[-1])
# return torch.cat([vx, ax], dim=1)
def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs):
"""Process input for LTXAV - separate audio and video, then patchify."""
audio_length = kwargs.get("audio_length", 0)
# Separate audio and video latents
vx, ax = self.separate_audio_and_video_latents(x, audio_length)
[vx, v_pixel_coords, additional_args] = super()._process_input(
vx, keyframe_idxs, denoise_mask, **kwargs
)
ax, a_latent_coords = self.a_patchifier.patchify(ax)
ax = self.audio_patchify_proj(ax)
# additional_args.update({"av_orig_shape": list(x.shape)})
return [vx, ax], [v_pixel_coords, a_latent_coords], additional_args
def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs):
"""Prepare timestep embeddings."""
# TODO: some code reuse is needed here.
grid_mask = kwargs.get("grid_mask", None)
if grid_mask is not None:
timestep = timestep[:, grid_mask]
timestep = timestep * self.timestep_scale_multiplier
v_timestep, v_embedded_timestep = self.adaln_single(
timestep.flatten(),
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
# Second dimension is 1 or number of tokens (if timestep_per_token)
v_timestep = v_timestep.view(batch_size, -1, v_timestep.shape[-1])
v_embedded_timestep = v_embedded_timestep.view(
batch_size, -1, v_embedded_timestep.shape[-1]
)
# Prepare audio timestep
a_timestep = kwargs.get("a_timestep")
if a_timestep is not None:
a_timestep = a_timestep * self.timestep_scale_multiplier
av_ca_factor = self.av_ca_timestep_scale_multiplier / self.timestep_scale_multiplier
av_ca_audio_scale_shift_timestep, _ = self.av_ca_audio_scale_shift_adaln_single(
a_timestep.flatten(),
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
av_ca_video_scale_shift_timestep, _ = self.av_ca_video_scale_shift_adaln_single(
timestep.flatten(),
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
av_ca_a2v_gate_noise_timestep, _ = self.av_ca_a2v_gate_adaln_single(
timestep.flatten() * av_ca_factor,
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single(
a_timestep.flatten() * av_ca_factor,
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
a_timestep, a_embedded_timestep = self.audio_adaln_single(
a_timestep.flatten(),
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
a_timestep = a_timestep.view(batch_size, -1, a_timestep.shape[-1])
a_embedded_timestep = a_embedded_timestep.view(
batch_size, -1, a_embedded_timestep.shape[-1]
)
cross_av_timestep_ss = [
av_ca_audio_scale_shift_timestep,
av_ca_video_scale_shift_timestep,
av_ca_a2v_gate_noise_timestep,
av_ca_v2a_gate_noise_timestep,
]
cross_av_timestep_ss = list(
[t.view(batch_size, -1, t.shape[-1]) for t in cross_av_timestep_ss]
)
else:
a_timestep = timestep
a_embedded_timestep = kwargs.get("embedded_timestep")
cross_av_timestep_ss = []
return [v_timestep, a_timestep, cross_av_timestep_ss], [
v_embedded_timestep,
a_embedded_timestep,
]
def _prepare_context(self, context, batch_size, x, attention_mask=None):
vx = x[0]
ax = x[1]
v_context, a_context = torch.split(
context, int(context.shape[-1] / 2), len(context.shape) - 1
)
v_context, attention_mask = super()._prepare_context(
v_context, batch_size, vx, attention_mask
)
if self.audio_caption_projection is not None:
a_context = self.audio_caption_projection(a_context)
a_context = a_context.view(batch_size, -1, ax.shape[-1])
return [v_context, a_context], attention_mask
def _prepare_positional_embeddings(self, pixel_coords, frame_rate, x_dtype):
v_pixel_coords = pixel_coords[0]
v_pe = super()._prepare_positional_embeddings(v_pixel_coords, frame_rate, x_dtype)
a_latent_coords = pixel_coords[1]
a_pe = self._precompute_freqs_cis(
a_latent_coords,
dim=self.audio_inner_dim,
out_dtype=x_dtype,
max_pos=self.audio_positional_embedding_max_pos,
use_middle_indices_grid=self.use_middle_indices_grid,
num_attention_heads=self.audio_num_attention_heads,
)
# calculate positional embeddings for the middle of the token duration, to use in av cross attention layers.
max_pos = max(
self.positional_embedding_max_pos[0], self.audio_positional_embedding_max_pos[0]
)
v_pixel_coords = v_pixel_coords.to(torch.float32)
v_pixel_coords[:, 0] = v_pixel_coords[:, 0] * (1.0 / frame_rate)
av_cross_video_freq_cis = self._precompute_freqs_cis(
v_pixel_coords[:, 0:1, :],
dim=self.audio_cross_attention_dim,
out_dtype=x_dtype,
max_pos=[max_pos],
use_middle_indices_grid=True,
num_attention_heads=self.audio_num_attention_heads,
)
av_cross_audio_freq_cis = self._precompute_freqs_cis(
a_latent_coords[:, 0:1, :],
dim=self.audio_cross_attention_dim,
out_dtype=x_dtype,
max_pos=[max_pos],
use_middle_indices_grid=True,
num_attention_heads=self.audio_num_attention_heads,
)
return [(v_pe, av_cross_video_freq_cis), (a_pe, av_cross_audio_freq_cis)]
def _process_transformer_blocks(
self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs
):
vx = x[0]
ax = x[1]
v_context = context[0]
a_context = context[1]
v_timestep = timestep[0]
a_timestep = timestep[1]
v_pe, av_cross_video_freq_cis = pe[0]
a_pe, av_cross_audio_freq_cis = pe[1]
(
av_ca_audio_scale_shift_timestep,
av_ca_video_scale_shift_timestep,
av_ca_a2v_gate_noise_timestep,
av_ca_v2a_gate_noise_timestep,
) = timestep[2]
"""Process transformer blocks for LTXAV."""
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
# Process transformer blocks
for i, block in enumerate(self.transformer_blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(
args["img"],
v_context=args["v_context"],
a_context=args["a_context"],
attention_mask=args["attention_mask"],
v_timestep=args["v_timestep"],
a_timestep=args["a_timestep"],
v_pe=args["v_pe"],
a_pe=args["a_pe"],
v_cross_pe=args["v_cross_pe"],
a_cross_pe=args["a_cross_pe"],
v_cross_scale_shift_timestep=args["v_cross_scale_shift_timestep"],
a_cross_scale_shift_timestep=args["a_cross_scale_shift_timestep"],
v_cross_gate_timestep=args["v_cross_gate_timestep"],
a_cross_gate_timestep=args["a_cross_gate_timestep"],
transformer_options=args["transformer_options"],
)
return out
out = blocks_replace[("double_block", i)](
{
"img": (vx, ax),
"v_context": v_context,
"a_context": a_context,
"attention_mask": attention_mask,
"v_timestep": v_timestep,
"a_timestep": a_timestep,
"v_pe": v_pe,
"a_pe": a_pe,
"v_cross_pe": av_cross_video_freq_cis,
"a_cross_pe": av_cross_audio_freq_cis,
"v_cross_scale_shift_timestep": av_ca_video_scale_shift_timestep,
"a_cross_scale_shift_timestep": av_ca_audio_scale_shift_timestep,
"v_cross_gate_timestep": av_ca_a2v_gate_noise_timestep,
"a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep,
"transformer_options": transformer_options,
},
{"original_block": block_wrap},
)
vx, ax = out["img"]
else:
vx, ax = block(
(vx, ax),
v_context=v_context,
a_context=a_context,
attention_mask=attention_mask,
v_timestep=v_timestep,
a_timestep=a_timestep,
v_pe=v_pe,
a_pe=a_pe,
v_cross_pe=av_cross_video_freq_cis,
a_cross_pe=av_cross_audio_freq_cis,
v_cross_scale_shift_timestep=av_ca_video_scale_shift_timestep,
a_cross_scale_shift_timestep=av_ca_audio_scale_shift_timestep,
v_cross_gate_timestep=av_ca_a2v_gate_noise_timestep,
a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep,
transformer_options=transformer_options,
)
return [vx, ax]
def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):
vx = x[0]
ax = x[1]
v_embedded_timestep = embedded_timestep[0]
a_embedded_timestep = embedded_timestep[1]
vx = super()._process_output(vx, v_embedded_timestep, keyframe_idxs, **kwargs)
# Process audio output
a_scale_shift_values = (
self.audio_scale_shift_table[None, None].to(device=a_embedded_timestep.device, dtype=a_embedded_timestep.dtype)
+ a_embedded_timestep[:, :, None]
)
a_shift, a_scale = a_scale_shift_values[:, :, 0], a_scale_shift_values[:, :, 1]
ax = self.audio_norm_out(ax)
ax = ax * (1 + a_scale) + a_shift
ax = self.audio_proj_out(ax)
# Unpatchify audio
ax = self.a_patchifier.unpatchify(
ax, channels=self.num_audio_channels, freq=self.audio_frequency_bins
)
# Recombine audio and video
original_shape = kwargs.get("av_orig_shape")
return self.recombine_audio_and_video_latents(vx, ax, original_shape)
def forward(
self,
x,
timestep,
context,
attention_mask=None,
frame_rate=25,
transformer_options={},
keyframe_idxs=None,
**kwargs,
):
"""
Forward pass for LTXAV model.
Args:
x: Combined audio-video input tensor
timestep: Tuple of (video_timestep, audio_timestep) or single timestep
context: Context tensor (e.g., text embeddings)
attention_mask: Attention mask tensor
frame_rate: Frame rate for temporal processing
transformer_options: Additional options for transformer blocks
keyframe_idxs: Keyframe indices for temporal processing
**kwargs: Additional keyword arguments including audio_length
Returns:
Combined audio-video output tensor
"""
# Handle timestep format
if isinstance(timestep, (tuple, list)) and len(timestep) == 2:
v_timestep, a_timestep = timestep
kwargs["a_timestep"] = a_timestep
timestep = v_timestep
else:
kwargs["a_timestep"] = timestep
# Call parent forward method
return super().forward(
x,
timestep,
context,
attention_mask,
frame_rate,
transformer_options,
keyframe_idxs,
**kwargs,
)