from typing import Tuple

import torch
import torch.nn as nn

from comfy.ldm.lightricks.model import (
    CrossAttention,
    FeedForward,
    AdaLayerNormSingle,
    PixArtAlphaTextProjection,
    LTXVModel,
)
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
import comfy.ldm.common_dit


class BasicAVTransformerBlock(nn.Module):
    def __init__(
        self,
        v_dim,
        a_dim,
        v_heads,
        a_heads,
        vd_head,
        ad_head,
        v_context_dim=None,
        a_context_dim=None,
        attn_precision=None,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()

        self.attn_precision = attn_precision

        self.attn1 = CrossAttention(
            query_dim=v_dim,
            heads=v_heads,
            dim_head=vd_head,
            context_dim=None,
            attn_precision=self.attn_precision,
            dtype=dtype,
            device=device,
            operations=operations,
        )
        self.audio_attn1 = CrossAttention(
            query_dim=a_dim,
            heads=a_heads,
            dim_head=ad_head,
            context_dim=None,
            attn_precision=self.attn_precision,
            dtype=dtype,
            device=device,
            operations=operations,
        )

        self.attn2 = CrossAttention(
            query_dim=v_dim,
            context_dim=v_context_dim,
            heads=v_heads,
            dim_head=vd_head,
            attn_precision=self.attn_precision,
            dtype=dtype,
            device=device,
            operations=operations,
        )
        self.audio_attn2 = CrossAttention(
            query_dim=a_dim,
            context_dim=a_context_dim,
            heads=a_heads,
            dim_head=ad_head,
            attn_precision=self.attn_precision,
            dtype=dtype,
            device=device,
            operations=operations,
        )

        # Q: Video, K,V: Audio
        self.audio_to_video_attn = CrossAttention(
            query_dim=v_dim,
            context_dim=a_dim,
            heads=a_heads,
            dim_head=ad_head,
            attn_precision=self.attn_precision,
            dtype=dtype,
            device=device,
            operations=operations,
        )

        # Q: Audio, K,V: Video
        self.video_to_audio_attn = CrossAttention(
            query_dim=a_dim,
            context_dim=v_dim,
            heads=a_heads,
            dim_head=ad_head,
            attn_precision=self.attn_precision,
            dtype=dtype,
            device=device,
            operations=operations,
        )

        self.ff = FeedForward(
            v_dim, dim_out=v_dim, glu=True, dtype=dtype, device=device, operations=operations
        )
        self.audio_ff = FeedForward(
            a_dim, dim_out=a_dim, glu=True, dtype=dtype, device=device, operations=operations
        )

        self.scale_shift_table = nn.Parameter(torch.empty(6, v_dim, device=device, dtype=dtype))
        self.audio_scale_shift_table = nn.Parameter(
            torch.empty(6, a_dim, device=device, dtype=dtype)
        )

        self.scale_shift_table_a2v_ca_audio = nn.Parameter(
            torch.empty(5, a_dim, device=device, dtype=dtype)
        )
        self.scale_shift_table_a2v_ca_video = nn.Parameter(
            torch.empty(5, v_dim, device=device, dtype=dtype)
        )
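
    # AdaLN-single style modulation: each selected row of the learned scale/shift table acts
    # as a bias on the matching chunk of the (already projected) timestep embedding. The
    # embedding is reshaped to (batch, tokens_or_1, num_ada_params, dim) and unbound along the
    # parameter axis into separate shift / scale / gate tensors.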
    def get_ada_values(
        self, scale_shift_table: torch.Tensor, batch_size: int, timestep: torch.Tensor, indices: slice = slice(None, None)
    ):
        num_ada_params = scale_shift_table.shape[0]

        ada_values = (
            scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(device=timestep.device, dtype=timestep.dtype)
            + timestep.reshape(batch_size, timestep.shape[1], num_ada_params, -1)[:, :, indices, :]
        ).unbind(dim=2)
        return ada_values
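
    # The 5-row AV cross-attention tables are split into num_scale_shift_values (4) scale/shift
    # rows driven by the scale/shift timestep embedding, plus one gate row driven by a separately
    # scaled gate timestep embedding (see _prepare_timestep in LTXAVModel below).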
    def get_av_ca_ada_values(
        self,
        scale_shift_table: torch.Tensor,
        batch_size: int,
        scale_shift_timestep: torch.Tensor,
        gate_timestep: torch.Tensor,
        num_scale_shift_values: int = 4,
    ):
        scale_shift_ada_values = self.get_ada_values(
            scale_shift_table[:num_scale_shift_values, :],
            batch_size,
            scale_shift_timestep,
        )
        gate_ada_values = self.get_ada_values(
            scale_shift_table[num_scale_shift_values:, :],
            batch_size,
            gate_timestep,
        )

        scale_shift_chunks = [t.squeeze(2) for t in scale_shift_ada_values]
        gate_ada_values = [t.squeeze(2) for t in gate_ada_values]

        return (*scale_shift_chunks, *gate_ada_values)
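
    # Block dataflow: (1) per-modality self-attention and text cross-attention,
    # (2) optional audio<->video cross-attention in both directions, (3) per-modality
    # feed-forward. The run_vx / run_ax / a2v_cross_attn / v2a_cross_attn entries of
    # transformer_options can skip individual stages.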
    def forward(
        self,
        x: Tuple[torch.Tensor, torch.Tensor],
        v_context=None,
        a_context=None,
        attention_mask=None,
        v_timestep=None,
        a_timestep=None,
        v_pe=None,
        a_pe=None,
        v_cross_pe=None,
        a_cross_pe=None,
        v_cross_scale_shift_timestep=None,
        a_cross_scale_shift_timestep=None,
        v_cross_gate_timestep=None,
        a_cross_gate_timestep=None,
        transformer_options=None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        run_vx = transformer_options.get("run_vx", True)
        run_ax = transformer_options.get("run_ax", True)

        vx, ax = x
        run_ax = run_ax and ax.numel() > 0
        run_a2v = run_vx and transformer_options.get("a2v_cross_attn", True) and ax.numel() > 0
        run_v2a = run_ax and transformer_options.get("v2a_cross_attn", True)

        if run_vx:
            vshift_msa, vscale_msa, vgate_msa = (
                self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 3))
            )

            norm_vx = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_msa) + vshift_msa
            vx += self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options) * vgate_msa
            vx += self.attn2(
                comfy.ldm.common_dit.rms_norm(vx),
                context=v_context,
                mask=attention_mask,
                transformer_options=transformer_options,
            )

            del vshift_msa, vscale_msa, vgate_msa

        if run_ax:
            ashift_msa, ascale_msa, agate_msa = (
                self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(0, 3))
            )

            norm_ax = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_msa) + ashift_msa
            ax += (
                self.audio_attn1(norm_ax, pe=a_pe, transformer_options=transformer_options)
                * agate_msa
            )
            ax += self.audio_attn2(
                comfy.ldm.common_dit.rms_norm(ax),
                context=a_context,
                mask=attention_mask,
                transformer_options=transformer_options,
            )

            del ashift_msa, ascale_msa, agate_msa

        # Audio - Video cross attention.
        if run_a2v or run_v2a:
            # norm3
            vx_norm3 = comfy.ldm.common_dit.rms_norm(vx)
            ax_norm3 = comfy.ldm.common_dit.rms_norm(ax)

            (
                scale_ca_audio_hidden_states_a2v,
                shift_ca_audio_hidden_states_a2v,
                scale_ca_audio_hidden_states_v2a,
                shift_ca_audio_hidden_states_v2a,
                gate_out_v2a,
            ) = self.get_av_ca_ada_values(
                self.scale_shift_table_a2v_ca_audio,
                ax.shape[0],
                a_cross_scale_shift_timestep,
                a_cross_gate_timestep,
            )

            (
                scale_ca_video_hidden_states_a2v,
                shift_ca_video_hidden_states_a2v,
                scale_ca_video_hidden_states_v2a,
                shift_ca_video_hidden_states_v2a,
                gate_out_a2v,
            ) = self.get_av_ca_ada_values(
                self.scale_shift_table_a2v_ca_video,
                vx.shape[0],
                v_cross_scale_shift_timestep,
                v_cross_gate_timestep,
            )

            if run_a2v:
                vx_scaled = (
                    vx_norm3 * (1 + scale_ca_video_hidden_states_a2v)
                    + shift_ca_video_hidden_states_a2v
                )
                ax_scaled = (
                    ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v)
                    + shift_ca_audio_hidden_states_a2v
                )
                vx += (
                    self.audio_to_video_attn(
                        vx_scaled,
                        context=ax_scaled,
                        pe=v_cross_pe,
                        k_pe=a_cross_pe,
                        transformer_options=transformer_options,
                    )
                    * gate_out_a2v
                )

                del gate_out_a2v
                del scale_ca_video_hidden_states_a2v,\
                    shift_ca_video_hidden_states_a2v,\
                    scale_ca_audio_hidden_states_a2v,\
                    shift_ca_audio_hidden_states_a2v

            if run_v2a:
                ax_scaled = (
                    ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a)
                    + shift_ca_audio_hidden_states_v2a
                )
                vx_scaled = (
                    vx_norm3 * (1 + scale_ca_video_hidden_states_v2a)
                    + shift_ca_video_hidden_states_v2a
                )
                ax += (
                    self.video_to_audio_attn(
                        ax_scaled,
                        context=vx_scaled,
                        pe=a_cross_pe,
                        k_pe=v_cross_pe,
                        transformer_options=transformer_options,
                    )
                    * gate_out_v2a
                )

                del gate_out_v2a
                del scale_ca_video_hidden_states_v2a,\
                    shift_ca_video_hidden_states_v2a,\
                    scale_ca_audio_hidden_states_v2a,\
                    shift_ca_audio_hidden_states_v2a

        if run_vx:
            vshift_mlp, vscale_mlp, vgate_mlp = (
                self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(3, None))
            )

            vx_scaled = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_mlp) + vshift_mlp
            vx += self.ff(vx_scaled) * vgate_mlp
            del vshift_mlp, vscale_mlp, vgate_mlp

        if run_ax:
            ashift_mlp, ascale_mlp, agate_mlp = (
                self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(3, None))
            )

            ax_scaled = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_mlp) + ashift_mlp
            ax += self.audio_ff(ax_scaled) * agate_mlp

            del ashift_mlp, ascale_mlp, agate_mlp

        return vx, ax
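

# Illustrative only (not part of the original module): with the default LTXAVModel
# configuration, each block would be constructed roughly as
#     BasicAVTransformerBlock(
#         v_dim=32 * 128, a_dim=32 * 64, v_heads=32, a_heads=32, vd_head=128, ad_head=64,
#         v_context_dim=4096, a_context_dim=2048, dtype=..., device=..., operations=...,
#     )
# where v_dim / a_dim are the video / audio inner dimensions (heads * head_dim per modality).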
class LTXAVModel(LTXVModel):
    """LTXAV model for audio-video generation."""

    def __init__(
        self,
        in_channels=128,
        audio_in_channels=128,
        cross_attention_dim=4096,
        audio_cross_attention_dim=2048,
        attention_head_dim=128,
        audio_attention_head_dim=64,
        num_attention_heads=32,
        audio_num_attention_heads=32,
        caption_channels=3840,
        num_layers=48,
        positional_embedding_theta=10000.0,
        positional_embedding_max_pos=[20, 2048, 2048],
        audio_positional_embedding_max_pos=[20],
        causal_temporal_positioning=False,
        vae_scale_factors=(8, 32, 32),
        use_middle_indices_grid=False,
        timestep_scale_multiplier=1000.0,
        av_ca_timestep_scale_multiplier=1.0,
        dtype=None,
        device=None,
        operations=None,
        **kwargs,
    ):
        # Store audio-specific parameters
        self.audio_in_channels = audio_in_channels
        self.audio_cross_attention_dim = audio_cross_attention_dim
        self.audio_attention_head_dim = audio_attention_head_dim
        self.audio_num_attention_heads = audio_num_attention_heads
        self.audio_positional_embedding_max_pos = audio_positional_embedding_max_pos

        # Calculate audio dimensions
        self.audio_inner_dim = audio_num_attention_heads * audio_attention_head_dim
        self.audio_out_channels = audio_in_channels

        # Audio-specific constants
        self.num_audio_channels = 8
        self.audio_frequency_bins = 16

        self.av_ca_timestep_scale_multiplier = av_ca_timestep_scale_multiplier

        super().__init__(
            in_channels=in_channels,
            cross_attention_dim=cross_attention_dim,
            attention_head_dim=attention_head_dim,
            num_attention_heads=num_attention_heads,
            caption_channels=caption_channels,
            num_layers=num_layers,
            positional_embedding_theta=positional_embedding_theta,
            positional_embedding_max_pos=positional_embedding_max_pos,
            causal_temporal_positioning=causal_temporal_positioning,
            vae_scale_factors=vae_scale_factors,
            use_middle_indices_grid=use_middle_indices_grid,
            timestep_scale_multiplier=timestep_scale_multiplier,
            dtype=dtype,
            device=device,
            operations=operations,
            **kwargs,
        )

    def _init_model_components(self, device, dtype, **kwargs):
        """Initialize LTXAV-specific components."""
        # Audio-specific projections
        self.audio_patchify_proj = self.operations.Linear(
            self.audio_in_channels, self.audio_inner_dim, bias=True, dtype=dtype, device=device
        )

        # Audio-specific AdaLN
        self.audio_adaln_single = AdaLayerNormSingle(
            self.audio_inner_dim,
            use_additional_conditions=False,
            dtype=dtype,
            device=device,
            operations=self.operations,
        )

        num_scale_shift_values = 4
        self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle(
            self.inner_dim,
            use_additional_conditions=False,
            embedding_coefficient=num_scale_shift_values,
            dtype=dtype,
            device=device,
            operations=self.operations,
        )
        self.av_ca_a2v_gate_adaln_single = AdaLayerNormSingle(
            self.inner_dim,
            use_additional_conditions=False,
            embedding_coefficient=1,
            dtype=dtype,
            device=device,
            operations=self.operations,
        )
        self.av_ca_audio_scale_shift_adaln_single = AdaLayerNormSingle(
            self.audio_inner_dim,
            use_additional_conditions=False,
            embedding_coefficient=num_scale_shift_values,
            dtype=dtype,
            device=device,
            operations=self.operations,
        )
        self.av_ca_v2a_gate_adaln_single = AdaLayerNormSingle(
            self.audio_inner_dim,
            use_additional_conditions=False,
            embedding_coefficient=1,
            dtype=dtype,
            device=device,
            operations=self.operations,
        )

        # Audio caption projection
        self.audio_caption_projection = PixArtAlphaTextProjection(
            in_features=self.caption_channels,
            hidden_size=self.audio_inner_dim,
            dtype=dtype,
            device=device,
            operations=self.operations,
        )

    def _init_transformer_blocks(self, device, dtype, **kwargs):
        """Initialize transformer blocks for LTXAV."""
        self.transformer_blocks = nn.ModuleList(
            [
                BasicAVTransformerBlock(
                    v_dim=self.inner_dim,
                    a_dim=self.audio_inner_dim,
                    v_heads=self.num_attention_heads,
                    a_heads=self.audio_num_attention_heads,
                    vd_head=self.attention_head_dim,
                    ad_head=self.audio_attention_head_dim,
                    v_context_dim=self.cross_attention_dim,
                    a_context_dim=self.audio_cross_attention_dim,
                    dtype=dtype,
                    device=device,
                    operations=self.operations,
                )
                for _ in range(self.num_layers)
            ]
        )

    def _init_output_components(self, device, dtype):
        """Initialize output components for LTXAV."""
        # Video output components
        super()._init_output_components(device, dtype)
        # Audio output components
        self.audio_scale_shift_table = nn.Parameter(
            torch.empty(2, self.audio_inner_dim, dtype=dtype, device=device)
        )
        self.audio_norm_out = self.operations.LayerNorm(
            self.audio_inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device
        )
        self.audio_proj_out = self.operations.Linear(
            self.audio_inner_dim, self.audio_out_channels, dtype=dtype, device=device
        )
        self.a_patchifier = AudioPatchifier(1, start_end=True)

    def separate_audio_and_video_latents(self, x, audio_length):
        """Separate audio and video latents from combined input."""
        # vx = x[:, : self.in_channels]
        # ax = x[:, self.in_channels :]
        #
        # ax = ax.reshape(ax.shape[0], -1)
        # ax = ax[:, : audio_length * self.num_audio_channels * self.audio_frequency_bins]
        #
        # ax = ax.reshape(
        #     ax.shape[0], self.num_audio_channels, audio_length, self.audio_frequency_bins
        # )

        vx = x[0]
        ax = x[1] if len(x) > 1 else torch.zeros(
            (vx.shape[0], self.num_audio_channels, 0, self.audio_frequency_bins),
            device=vx.device, dtype=vx.dtype
        )
        return vx, ax

    def recombine_audio_and_video_latents(self, vx, ax, target_shape=None):
        """Recombine audio and video latents for output."""
        if ax.numel() == 0:
            return vx
        else:
            return [vx, ax]
        # if ax.device != vx.device or ax.dtype != vx.dtype:
        #     logging.warning("Audio and video latents are on different devices or dtypes.")
        #     ax = ax.to(device=vx.device, dtype=vx.dtype)
        #     logging.warning(f"Audio latent moved to device: {ax.device}, dtype: {ax.dtype}")
        #
        # ax = ax.reshape(ax.shape[0], -1)
        # # pad to f x h x w of the video latents
        # divisor = vx.shape[-1] * vx.shape[-2] * vx.shape[-3]
        # if target_shape is None:
        #     repetitions = math.ceil(ax.shape[-1] / divisor)
        # else:
        #     repetitions = target_shape[1] - vx.shape[1]
        # padded_len = repetitions * divisor
        # ax = F.pad(ax, (0, padded_len - ax.shape[-1]))
        # ax = ax.reshape(ax.shape[0], -1, vx.shape[-3], vx.shape[-2], vx.shape[-1])
        # return torch.cat([vx, ax], dim=1)
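
    # Input handling note: x is expected as a list of [video_latents, audio_latents]; the video
    # part goes through the parent LTXVModel input processing while the audio part is patchified
    # separately and projected to the audio inner dimension.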
    def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs):
        """Process input for LTXAV - separate audio and video, then patchify."""
        audio_length = kwargs.get("audio_length", 0)
        # Separate audio and video latents
        vx, ax = self.separate_audio_and_video_latents(x, audio_length)
        [vx, v_pixel_coords, additional_args] = super()._process_input(
            vx, keyframe_idxs, denoise_mask, **kwargs
        )

        ax, a_latent_coords = self.a_patchifier.patchify(ax)
        ax = self.audio_patchify_proj(ax)

        # additional_args.update({"av_orig_shape": list(x.shape)})
        return [vx, ax], [v_pixel_coords, a_latent_coords], additional_args
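
    # Timestep handling note: besides the per-modality video / audio AdaLN timestep embeddings,
    # four extra embeddings condition the audio<->video cross-attention layers (scale/shift for
    # each modality, plus a2v and v2a gates effectively scaled by av_ca_timestep_scale_multiplier).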
    def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs):
        """Prepare timestep embeddings."""
        # TODO: some code reuse is needed here.
        grid_mask = kwargs.get("grid_mask", None)
        if grid_mask is not None:
            timestep = timestep[:, grid_mask]

        timestep = timestep * self.timestep_scale_multiplier
        v_timestep, v_embedded_timestep = self.adaln_single(
            timestep.flatten(),
            {"resolution": None, "aspect_ratio": None},
            batch_size=batch_size,
            hidden_dtype=hidden_dtype,
        )

        # Second dimension is 1 or number of tokens (if timestep_per_token)
        v_timestep = v_timestep.view(batch_size, -1, v_timestep.shape[-1])
        v_embedded_timestep = v_embedded_timestep.view(
            batch_size, -1, v_embedded_timestep.shape[-1]
        )

        # Prepare audio timestep
        a_timestep = kwargs.get("a_timestep")
        if a_timestep is not None:
            a_timestep = a_timestep * self.timestep_scale_multiplier
            av_ca_factor = self.av_ca_timestep_scale_multiplier / self.timestep_scale_multiplier

            av_ca_audio_scale_shift_timestep, _ = self.av_ca_audio_scale_shift_adaln_single(
                a_timestep.flatten(),
                {"resolution": None, "aspect_ratio": None},
                batch_size=batch_size,
                hidden_dtype=hidden_dtype,
            )
            av_ca_video_scale_shift_timestep, _ = self.av_ca_video_scale_shift_adaln_single(
                timestep.flatten(),
                {"resolution": None, "aspect_ratio": None},
                batch_size=batch_size,
                hidden_dtype=hidden_dtype,
            )
            av_ca_a2v_gate_noise_timestep, _ = self.av_ca_a2v_gate_adaln_single(
                timestep.flatten() * av_ca_factor,
                {"resolution": None, "aspect_ratio": None},
                batch_size=batch_size,
                hidden_dtype=hidden_dtype,
            )
            av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single(
                a_timestep.flatten() * av_ca_factor,
                {"resolution": None, "aspect_ratio": None},
                batch_size=batch_size,
                hidden_dtype=hidden_dtype,
            )

            a_timestep, a_embedded_timestep = self.audio_adaln_single(
                a_timestep.flatten(),
                {"resolution": None, "aspect_ratio": None},
                batch_size=batch_size,
                hidden_dtype=hidden_dtype,
            )
            a_timestep = a_timestep.view(batch_size, -1, a_timestep.shape[-1])
            a_embedded_timestep = a_embedded_timestep.view(
                batch_size, -1, a_embedded_timestep.shape[-1]
            )
            cross_av_timestep_ss = [
                av_ca_audio_scale_shift_timestep,
                av_ca_video_scale_shift_timestep,
                av_ca_a2v_gate_noise_timestep,
                av_ca_v2a_gate_noise_timestep,
            ]
            cross_av_timestep_ss = [
                t.view(batch_size, -1, t.shape[-1]) for t in cross_av_timestep_ss
            ]
        else:
            a_timestep = timestep
            a_embedded_timestep = kwargs.get("embedded_timestep")
            cross_av_timestep_ss = []

        return [v_timestep, a_timestep, cross_av_timestep_ss], [
            v_embedded_timestep,
            a_embedded_timestep,
        ]
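
    # Context handling note: the caption embedding is split in half along its last dimension;
    # the first half conditions the video stream and the second half is projected (via
    # audio_caption_projection) to condition the audio stream.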
    def _prepare_context(self, context, batch_size, x, attention_mask=None):
        vx = x[0]
        ax = x[1]
        v_context, a_context = torch.split(
            context, int(context.shape[-1] / 2), len(context.shape) - 1
        )

        v_context, attention_mask = super()._prepare_context(
            v_context, batch_size, vx, attention_mask
        )
        if self.audio_caption_projection is not None:
            a_context = self.audio_caption_projection(a_context)
            a_context = a_context.view(batch_size, -1, ax.shape[-1])

        return [v_context, a_context], attention_mask

    def _prepare_positional_embeddings(self, pixel_coords, frame_rate, x_dtype):
        v_pixel_coords = pixel_coords[0]
        v_pe = super()._prepare_positional_embeddings(v_pixel_coords, frame_rate, x_dtype)

        a_latent_coords = pixel_coords[1]
        a_pe = self._precompute_freqs_cis(
            a_latent_coords,
            dim=self.audio_inner_dim,
            out_dtype=x_dtype,
            max_pos=self.audio_positional_embedding_max_pos,
            use_middle_indices_grid=self.use_middle_indices_grid,
            num_attention_heads=self.audio_num_attention_heads,
        )

        # Calculate positional embeddings for the middle of the token duration, to use in the
        # AV cross-attention layers.
        max_pos = max(
            self.positional_embedding_max_pos[0], self.audio_positional_embedding_max_pos[0]
        )
        v_pixel_coords = v_pixel_coords.to(torch.float32)
        v_pixel_coords[:, 0] = v_pixel_coords[:, 0] * (1.0 / frame_rate)
        av_cross_video_freq_cis = self._precompute_freqs_cis(
            v_pixel_coords[:, 0:1, :],
            dim=self.audio_cross_attention_dim,
            out_dtype=x_dtype,
            max_pos=[max_pos],
            use_middle_indices_grid=True,
            num_attention_heads=self.audio_num_attention_heads,
        )
        av_cross_audio_freq_cis = self._precompute_freqs_cis(
            a_latent_coords[:, 0:1, :],
            dim=self.audio_cross_attention_dim,
            out_dtype=x_dtype,
            max_pos=[max_pos],
            use_middle_indices_grid=True,
            num_attention_heads=self.audio_num_attention_heads,
        )

        return [(v_pe, av_cross_video_freq_cis), (a_pe, av_cross_audio_freq_cis)]

    def _process_transformer_blocks(
        self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs
    ):
        """Process transformer blocks for LTXAV."""
        vx = x[0]
        ax = x[1]
        v_context = context[0]
        a_context = context[1]
        v_timestep = timestep[0]
        a_timestep = timestep[1]
        v_pe, av_cross_video_freq_cis = pe[0]
        a_pe, av_cross_audio_freq_cis = pe[1]

        (
            av_ca_audio_scale_shift_timestep,
            av_ca_video_scale_shift_timestep,
            av_ca_a2v_gate_noise_timestep,
            av_ca_v2a_gate_noise_timestep,
        ) = timestep[2]

        patches_replace = transformer_options.get("patches_replace", {})
        blocks_replace = patches_replace.get("dit", {})

        # Process transformer blocks
        for i, block in enumerate(self.transformer_blocks):
            if ("double_block", i) in blocks_replace:

                def block_wrap(args):
                    out = {}
                    out["img"] = block(
                        args["img"],
                        v_context=args["v_context"],
                        a_context=args["a_context"],
                        attention_mask=args["attention_mask"],
                        v_timestep=args["v_timestep"],
                        a_timestep=args["a_timestep"],
                        v_pe=args["v_pe"],
                        a_pe=args["a_pe"],
                        v_cross_pe=args["v_cross_pe"],
                        a_cross_pe=args["a_cross_pe"],
                        v_cross_scale_shift_timestep=args["v_cross_scale_shift_timestep"],
                        a_cross_scale_shift_timestep=args["a_cross_scale_shift_timestep"],
                        v_cross_gate_timestep=args["v_cross_gate_timestep"],
                        a_cross_gate_timestep=args["a_cross_gate_timestep"],
                        transformer_options=args["transformer_options"],
                    )
                    return out

                out = blocks_replace[("double_block", i)](
                    {
                        "img": (vx, ax),
                        "v_context": v_context,
                        "a_context": a_context,
                        "attention_mask": attention_mask,
                        "v_timestep": v_timestep,
                        "a_timestep": a_timestep,
                        "v_pe": v_pe,
                        "a_pe": a_pe,
                        "v_cross_pe": av_cross_video_freq_cis,
                        "a_cross_pe": av_cross_audio_freq_cis,
                        "v_cross_scale_shift_timestep": av_ca_video_scale_shift_timestep,
                        "a_cross_scale_shift_timestep": av_ca_audio_scale_shift_timestep,
                        "v_cross_gate_timestep": av_ca_a2v_gate_noise_timestep,
                        "a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep,
                        "transformer_options": transformer_options,
                    },
                    {"original_block": block_wrap},
                )
                vx, ax = out["img"]
            else:
                vx, ax = block(
                    (vx, ax),
                    v_context=v_context,
                    a_context=a_context,
                    attention_mask=attention_mask,
                    v_timestep=v_timestep,
                    a_timestep=a_timestep,
                    v_pe=v_pe,
                    a_pe=a_pe,
                    v_cross_pe=av_cross_video_freq_cis,
                    a_cross_pe=av_cross_audio_freq_cis,
                    v_cross_scale_shift_timestep=av_ca_video_scale_shift_timestep,
                    a_cross_scale_shift_timestep=av_ca_audio_scale_shift_timestep,
                    v_cross_gate_timestep=av_ca_a2v_gate_noise_timestep,
                    a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep,
                    transformer_options=transformer_options,
                )

        return [vx, ax]

    def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):
        vx = x[0]
        ax = x[1]
        v_embedded_timestep = embedded_timestep[0]
        a_embedded_timestep = embedded_timestep[1]
        vx = super()._process_output(vx, v_embedded_timestep, keyframe_idxs, **kwargs)

        # Process audio output
        a_scale_shift_values = (
            self.audio_scale_shift_table[None, None].to(device=a_embedded_timestep.device, dtype=a_embedded_timestep.dtype)
            + a_embedded_timestep[:, :, None]
        )
        a_shift, a_scale = a_scale_shift_values[:, :, 0], a_scale_shift_values[:, :, 1]

        ax = self.audio_norm_out(ax)
        ax = ax * (1 + a_scale) + a_shift
        ax = self.audio_proj_out(ax)

        # Unpatchify audio
        ax = self.a_patchifier.unpatchify(
            ax, channels=self.num_audio_channels, freq=self.audio_frequency_bins
        )

        # Recombine audio and video
        original_shape = kwargs.get("av_orig_shape")
        return self.recombine_audio_and_video_latents(vx, ax, original_shape)

    def forward(
        self,
        x,
        timestep,
        context,
        attention_mask=None,
        frame_rate=25,
        transformer_options={},
        keyframe_idxs=None,
        **kwargs,
    ):
        """
        Forward pass for LTXAV model.

        Args:
            x: Audio-video input latents, given as [video_latents, audio_latents]
            timestep: Tuple of (video_timestep, audio_timestep) or single timestep
            context: Context tensor (e.g., text embeddings)
            attention_mask: Attention mask tensor
            frame_rate: Frame rate for temporal processing
            transformer_options: Additional options for transformer blocks
            keyframe_idxs: Keyframe indices for temporal processing
            **kwargs: Additional keyword arguments including audio_length

        Returns:
            Video output latents, or a [video, audio] list when audio latents are present
        """
        # Handle timestep format
        if isinstance(timestep, (tuple, list)) and len(timestep) == 2:
            v_timestep, a_timestep = timestep
            kwargs["a_timestep"] = a_timestep
            timestep = v_timestep
        else:
            kwargs["a_timestep"] = timestep

        # Call parent forward method
        return super().forward(
            x,
            timestep,
            context,
            attention_mask,
            frame_rate,
            transformer_options,
            keyframe_idxs,
            **kwargs,
        )
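

# Illustrative only (not part of the original module): a rough sketch of how the model is
# typically driven, assuming video latents of shape (B, 128, F, H, W) packed together with
# audio latents of shape (B, 8, T, 16) and per-modality timesteps:
#
#     model = LTXAVModel(dtype=..., device=..., operations=...)
#     out = model(
#         [video_latents, audio_latents],
#         timestep=(v_timestep, a_timestep),
#         context=text_embeddings,
#         frame_rate=25,
#         transformer_options={},
#     )
#
# When the audio latent is empty, only the video latent is returned; otherwise the output is
# a [video, audio] list (see recombine_audio_and_video_latents).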