from typing import Tuple

import torch
import torch.nn as nn

from comfy.ldm.lightricks.model import (
    CrossAttention,
    FeedForward,
    AdaLayerNormSingle,
    PixArtAlphaTextProjection,
    LTXVModel,
)
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
import comfy.ldm.common_dit


class BasicAVTransformerBlock(nn.Module):
    def __init__(
        self,
        v_dim,
        a_dim,
        v_heads,
        a_heads,
        vd_head,
        ad_head,
        v_context_dim=None,
        a_context_dim=None,
        attn_precision=None,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        self.attn_precision = attn_precision

        self.attn1 = CrossAttention(
            query_dim=v_dim,
            heads=v_heads,
            dim_head=vd_head,
            context_dim=None,
            attn_precision=self.attn_precision,
            dtype=dtype,
            device=device,
            operations=operations,
        )
        self.audio_attn1 = CrossAttention(
            query_dim=a_dim,
            heads=a_heads,
            dim_head=ad_head,
            context_dim=None,
            attn_precision=self.attn_precision,
            dtype=dtype,
            device=device,
            operations=operations,
        )
        self.attn2 = CrossAttention(
            query_dim=v_dim,
            context_dim=v_context_dim,
            heads=v_heads,
            dim_head=vd_head,
            attn_precision=self.attn_precision,
            dtype=dtype,
            device=device,
            operations=operations,
        )
        self.audio_attn2 = CrossAttention(
            query_dim=a_dim,
            context_dim=a_context_dim,
            heads=a_heads,
            dim_head=ad_head,
            attn_precision=self.attn_precision,
            dtype=dtype,
            device=device,
            operations=operations,
        )
        # Q: Video, K,V: Audio
        self.audio_to_video_attn = CrossAttention(
            query_dim=v_dim,
            context_dim=a_dim,
            heads=a_heads,
            dim_head=ad_head,
            attn_precision=self.attn_precision,
            dtype=dtype,
            device=device,
            operations=operations,
        )
        # Q: Audio, K,V: Video
        self.video_to_audio_attn = CrossAttention(
            query_dim=a_dim,
            context_dim=v_dim,
            heads=a_heads,
            dim_head=ad_head,
            attn_precision=self.attn_precision,
            dtype=dtype,
            device=device,
            operations=operations,
        )

        self.ff = FeedForward(
            v_dim, dim_out=v_dim, glu=True, dtype=dtype, device=device, operations=operations
        )
        self.audio_ff = FeedForward(
            a_dim, dim_out=a_dim, glu=True, dtype=dtype, device=device, operations=operations
        )

        self.scale_shift_table = nn.Parameter(torch.empty(6, v_dim, device=device, dtype=dtype))
        self.audio_scale_shift_table = nn.Parameter(
            torch.empty(6, a_dim, device=device, dtype=dtype)
        )
        self.scale_shift_table_a2v_ca_audio = nn.Parameter(
            torch.empty(5, a_dim, device=device, dtype=dtype)
        )
        self.scale_shift_table_a2v_ca_video = nn.Parameter(
            torch.empty(5, v_dim, device=device, dtype=dtype)
        )

    def get_ada_values(
        self,
        scale_shift_table: torch.Tensor,
        batch_size: int,
        timestep: torch.Tensor,
        indices: slice = slice(None, None),
    ):
        num_ada_params = scale_shift_table.shape[0]
        ada_values = (
            scale_shift_table[indices]
            .unsqueeze(0)
            .unsqueeze(0)
            .to(device=timestep.device, dtype=timestep.dtype)
            + timestep.reshape(batch_size, timestep.shape[1], num_ada_params, -1)[:, :, indices, :]
        ).unbind(dim=2)
        return ada_values

    def get_av_ca_ada_values(
        self,
        scale_shift_table: torch.Tensor,
        batch_size: int,
        scale_shift_timestep: torch.Tensor,
        gate_timestep: torch.Tensor,
        num_scale_shift_values: int = 4,
    ):
        scale_shift_ada_values = self.get_ada_values(
            scale_shift_table[:num_scale_shift_values, :],
            batch_size,
            scale_shift_timestep,
        )
        gate_ada_values = self.get_ada_values(
            scale_shift_table[num_scale_shift_values:, :],
            batch_size,
            gate_timestep,
        )
        scale_shift_chunks = [t.squeeze(2) for t in scale_shift_ada_values]
        gate_ada_values = [t.squeeze(2) for t in gate_ada_values]
        return (*scale_shift_chunks, *gate_ada_values)

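    # A note on the forward pass below: each modality goes through AdaLN-modulated
    # self-attention (attn1 / audio_attn1), text cross-attention (attn2 / audio_attn2),
    # bidirectional audio<->video cross-attention, and a gated feed-forward block. The
    # scale/shift/gate values come from get_ada_values / get_av_ca_ada_values above.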
    def forward(
        self,
        x: Tuple[torch.Tensor, torch.Tensor],
        v_context=None,
        a_context=None,
        attention_mask=None,
        v_timestep=None,
        a_timestep=None,
        v_pe=None,
        a_pe=None,
        v_cross_pe=None,
        a_cross_pe=None,
        v_cross_scale_shift_timestep=None,
        a_cross_scale_shift_timestep=None,
        v_cross_gate_timestep=None,
        a_cross_gate_timestep=None,
        transformer_options=None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        run_vx = transformer_options.get("run_vx", True)
        run_ax = transformer_options.get("run_ax", True)

        vx, ax = x
        run_ax = run_ax and ax.numel() > 0
        run_a2v = run_vx and transformer_options.get("a2v_cross_attn", True) and ax.numel() > 0
        run_v2a = run_ax and transformer_options.get("v2a_cross_attn", True)

        if run_vx:
            vshift_msa, vscale_msa, vgate_msa = self.get_ada_values(
                self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 3)
            )
            norm_vx = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_msa) + vshift_msa
            vx += self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options) * vgate_msa
            vx += self.attn2(
                comfy.ldm.common_dit.rms_norm(vx),
                context=v_context,
                mask=attention_mask,
                transformer_options=transformer_options,
            )
            del vshift_msa, vscale_msa, vgate_msa

        if run_ax:
            ashift_msa, ascale_msa, agate_msa = self.get_ada_values(
                self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(0, 3)
            )
            norm_ax = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_msa) + ashift_msa
            ax += (
                self.audio_attn1(norm_ax, pe=a_pe, transformer_options=transformer_options)
                * agate_msa
            )
            ax += self.audio_attn2(
                comfy.ldm.common_dit.rms_norm(ax),
                context=a_context,
                mask=attention_mask,
                transformer_options=transformer_options,
            )
            del ashift_msa, ascale_msa, agate_msa

        # Audio - Video cross attention.
        if run_a2v or run_v2a:
            # norm3
            vx_norm3 = comfy.ldm.common_dit.rms_norm(vx)
            ax_norm3 = comfy.ldm.common_dit.rms_norm(ax)

            (
                scale_ca_audio_hidden_states_a2v,
                shift_ca_audio_hidden_states_a2v,
                scale_ca_audio_hidden_states_v2a,
                shift_ca_audio_hidden_states_v2a,
                gate_out_v2a,
            ) = self.get_av_ca_ada_values(
                self.scale_shift_table_a2v_ca_audio,
                ax.shape[0],
                a_cross_scale_shift_timestep,
                a_cross_gate_timestep,
            )
            (
                scale_ca_video_hidden_states_a2v,
                shift_ca_video_hidden_states_a2v,
                scale_ca_video_hidden_states_v2a,
                shift_ca_video_hidden_states_v2a,
                gate_out_a2v,
            ) = self.get_av_ca_ada_values(
                self.scale_shift_table_a2v_ca_video,
                vx.shape[0],
                v_cross_scale_shift_timestep,
                v_cross_gate_timestep,
            )

            if run_a2v:
                vx_scaled = (
                    vx_norm3 * (1 + scale_ca_video_hidden_states_a2v)
                    + shift_ca_video_hidden_states_a2v
                )
                ax_scaled = (
                    ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v)
                    + shift_ca_audio_hidden_states_a2v
                )
                vx += (
                    self.audio_to_video_attn(
                        vx_scaled,
                        context=ax_scaled,
                        pe=v_cross_pe,
                        k_pe=a_cross_pe,
                        transformer_options=transformer_options,
                    )
                    * gate_out_a2v
                )
                del gate_out_a2v
                del scale_ca_video_hidden_states_a2v, \
                    shift_ca_video_hidden_states_a2v, \
                    scale_ca_audio_hidden_states_a2v, \
                    shift_ca_audio_hidden_states_a2v

            if run_v2a:
                ax_scaled = (
                    ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a)
                    + shift_ca_audio_hidden_states_v2a
                )
                vx_scaled = (
                    vx_norm3 * (1 + scale_ca_video_hidden_states_v2a)
                    + shift_ca_video_hidden_states_v2a
                )
                ax += (
                    self.video_to_audio_attn(
                        ax_scaled,
                        context=vx_scaled,
                        pe=a_cross_pe,
                        k_pe=v_cross_pe,
                        transformer_options=transformer_options,
                    )
                    * gate_out_v2a
                )
                del gate_out_v2a
                del scale_ca_video_hidden_states_v2a, \
                    shift_ca_video_hidden_states_v2a, \
                    scale_ca_audio_hidden_states_v2a, \
                    shift_ca_audio_hidden_states_v2a

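        # Gated feed-forward blocks: rows 3..5 of the 6-row scale_shift_table supply the
        # MLP shift/scale/gate, mirroring rows 0..2 used for the self-attention above.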
        if run_vx:
            vshift_mlp, vscale_mlp, vgate_mlp = self.get_ada_values(
                self.scale_shift_table, vx.shape[0], v_timestep, slice(3, None)
            )
            vx_scaled = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_mlp) + vshift_mlp
            vx += self.ff(vx_scaled) * vgate_mlp
            del vshift_mlp, vscale_mlp, vgate_mlp

        if run_ax:
            ashift_mlp, ascale_mlp, agate_mlp = self.get_ada_values(
                self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(3, None)
            )
            ax_scaled = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_mlp) + ashift_mlp
            ax += self.audio_ff(ax_scaled) * agate_mlp
            del ashift_mlp, ascale_mlp, agate_mlp

        return vx, ax


class LTXAVModel(LTXVModel):
    """LTXAV model for audio-video generation."""

    def __init__(
        self,
        in_channels=128,
        audio_in_channels=128,
        cross_attention_dim=4096,
        audio_cross_attention_dim=2048,
        attention_head_dim=128,
        audio_attention_head_dim=64,
        num_attention_heads=32,
        audio_num_attention_heads=32,
        caption_channels=3840,
        num_layers=48,
        positional_embedding_theta=10000.0,
        positional_embedding_max_pos=[20, 2048, 2048],
        audio_positional_embedding_max_pos=[20],
        causal_temporal_positioning=False,
        vae_scale_factors=(8, 32, 32),
        use_middle_indices_grid=False,
        timestep_scale_multiplier=1000.0,
        av_ca_timestep_scale_multiplier=1.0,
        dtype=None,
        device=None,
        operations=None,
        **kwargs,
    ):
        # Store audio-specific parameters
        self.audio_in_channels = audio_in_channels
        self.audio_cross_attention_dim = audio_cross_attention_dim
        self.audio_attention_head_dim = audio_attention_head_dim
        self.audio_num_attention_heads = audio_num_attention_heads
        self.audio_positional_embedding_max_pos = audio_positional_embedding_max_pos

        # Calculate audio dimensions
        self.audio_inner_dim = audio_num_attention_heads * audio_attention_head_dim
        self.audio_out_channels = audio_in_channels

        # Audio-specific constants
        self.num_audio_channels = 8
        self.audio_frequency_bins = 16

        self.av_ca_timestep_scale_multiplier = av_ca_timestep_scale_multiplier

        super().__init__(
            in_channels=in_channels,
            cross_attention_dim=cross_attention_dim,
            attention_head_dim=attention_head_dim,
            num_attention_heads=num_attention_heads,
            caption_channels=caption_channels,
            num_layers=num_layers,
            positional_embedding_theta=positional_embedding_theta,
            positional_embedding_max_pos=positional_embedding_max_pos,
            causal_temporal_positioning=causal_temporal_positioning,
            vae_scale_factors=vae_scale_factors,
            use_middle_indices_grid=use_middle_indices_grid,
            timestep_scale_multiplier=timestep_scale_multiplier,
            dtype=dtype,
            device=device,
            operations=operations,
            **kwargs,
        )

    def _init_model_components(self, device, dtype, **kwargs):
        """Initialize LTXAV-specific components."""
        # Audio-specific projections
        self.audio_patchify_proj = self.operations.Linear(
            self.audio_in_channels, self.audio_inner_dim, bias=True, dtype=dtype, device=device
        )

        # Audio-specific AdaLN
        self.audio_adaln_single = AdaLayerNormSingle(
            self.audio_inner_dim,
            use_additional_conditions=False,
            dtype=dtype,
            device=device,
            operations=self.operations,
        )

        num_scale_shift_values = 4
        self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle(
            self.inner_dim,
            use_additional_conditions=False,
            embedding_coefficient=num_scale_shift_values,
            dtype=dtype,
            device=device,
            operations=self.operations,
        )
        self.av_ca_a2v_gate_adaln_single = AdaLayerNormSingle(
            self.inner_dim,
            use_additional_conditions=False,
            embedding_coefficient=1,
            dtype=dtype,
            device=device,
            operations=self.operations,
        )
        self.av_ca_audio_scale_shift_adaln_single = AdaLayerNormSingle(
            self.audio_inner_dim,
            use_additional_conditions=False,
            embedding_coefficient=num_scale_shift_values,
            dtype=dtype,
            device=device,
            operations=self.operations,
        )
        self.av_ca_v2a_gate_adaln_single = AdaLayerNormSingle(
            self.audio_inner_dim,
            use_additional_conditions=False,
            embedding_coefficient=1,
            dtype=dtype,
            device=device,
            operations=self.operations,
        )

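        # The four AdaLN modules above emit the per-token modulation for the audio<->video
        # cross-attention: embedding_coefficient=4 yields the scale/shift pairs and
        # embedding_coefficient=1 yields the gates consumed in
        # BasicAVTransformerBlock.get_av_ca_ada_values.
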
        # Audio caption projection
        self.audio_caption_projection = PixArtAlphaTextProjection(
            in_features=self.caption_channels,
            hidden_size=self.audio_inner_dim,
            dtype=dtype,
            device=device,
            operations=self.operations,
        )

    def _init_transformer_blocks(self, device, dtype, **kwargs):
        """Initialize transformer blocks for LTXAV."""
        self.transformer_blocks = nn.ModuleList(
            [
                BasicAVTransformerBlock(
                    v_dim=self.inner_dim,
                    a_dim=self.audio_inner_dim,
                    v_heads=self.num_attention_heads,
                    a_heads=self.audio_num_attention_heads,
                    vd_head=self.attention_head_dim,
                    ad_head=self.audio_attention_head_dim,
                    v_context_dim=self.cross_attention_dim,
                    a_context_dim=self.audio_cross_attention_dim,
                    dtype=dtype,
                    device=device,
                    operations=self.operations,
                )
                for _ in range(self.num_layers)
            ]
        )

    def _init_output_components(self, device, dtype):
        """Initialize output components for LTXAV."""
        # Video output components
        super()._init_output_components(device, dtype)

        # Audio output components
        self.audio_scale_shift_table = nn.Parameter(
            torch.empty(2, self.audio_inner_dim, dtype=dtype, device=device)
        )
        self.audio_norm_out = self.operations.LayerNorm(
            self.audio_inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device
        )
        self.audio_proj_out = self.operations.Linear(
            self.audio_inner_dim, self.audio_out_channels, dtype=dtype, device=device
        )

        self.a_patchifier = AudioPatchifier(1, start_end=True)

    def separate_audio_and_video_latents(self, x, audio_length):
        """Separate audio and video latents from combined input."""
        # vx = x[:, : self.in_channels]
        # ax = x[:, self.in_channels :]
        #
        # ax = ax.reshape(ax.shape[0], -1)
        # ax = ax[:, : audio_length * self.num_audio_channels * self.audio_frequency_bins]
        #
        # ax = ax.reshape(
        #     ax.shape[0], self.num_audio_channels, audio_length, self.audio_frequency_bins
        # )
        vx = x[0]
        ax = (
            x[1]
            if len(x) > 1
            else torch.zeros(
                (vx.shape[0], self.num_audio_channels, 0, self.audio_frequency_bins),
                device=vx.device,
                dtype=vx.dtype,
            )
        )
        return vx, ax

    def recombine_audio_and_video_latents(self, vx, ax, target_shape=None):
        """Recombine audio and video latents for output."""
        if ax.numel() == 0:
            return vx
        else:
            return [vx, ax]
        # if ax.device != vx.device or ax.dtype != vx.dtype:
        #     logging.warning("Audio and video latents are on different devices or dtypes.")
        #     ax = ax.to(device=vx.device, dtype=vx.dtype)
        #     logging.warning(f"Audio latent moved to device: {ax.device}, dtype: {ax.dtype}")
        #
        # ax = ax.reshape(ax.shape[0], -1)
        # # pad to f x h x w of the video latents
        # divisor = vx.shape[-1] * vx.shape[-2] * vx.shape[-3]
        # if target_shape is None:
        #     repetitions = math.ceil(ax.shape[-1] / divisor)
        # else:
        #     repetitions = target_shape[1] - vx.shape[1]
        # padded_len = repetitions * divisor
        # ax = F.pad(ax, (0, padded_len - ax.shape[-1]))
        # ax = ax.reshape(ax.shape[0], -1, vx.shape[-3], vx.shape[-2], vx.shape[-1])
        # return torch.cat([vx, ax], dim=1)

    def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs):
        """Process input for LTXAV - separate audio and video, then patchify."""
        audio_length = kwargs.get("audio_length", 0)

        # Separate audio and video latents
        vx, ax = self.separate_audio_and_video_latents(x, audio_length)

        [vx, v_pixel_coords, additional_args] = super()._process_input(
            vx, keyframe_idxs, denoise_mask, **kwargs
        )

        ax, a_latent_coords = self.a_patchifier.patchify(ax)
        ax = self.audio_patchify_proj(ax)

        # additional_args.update({"av_orig_shape": list(x.shape)})
        return [vx, ax], [v_pixel_coords, a_latent_coords], additional_args

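    # _prepare_timestep below builds separate AdaLN embeddings for the video stream, the
    # audio stream, and the four audio<->video cross-attention modulation signals
    # (scale/shift per modality plus the a2v and v2a gates).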
    def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs):
        """Prepare timestep embeddings."""
        # TODO: some code reuse is needed here.
        grid_mask = kwargs.get("grid_mask", None)
        if grid_mask is not None:
            timestep = timestep[:, grid_mask]

        timestep = timestep * self.timestep_scale_multiplier
        v_timestep, v_embedded_timestep = self.adaln_single(
            timestep.flatten(),
            {"resolution": None, "aspect_ratio": None},
            batch_size=batch_size,
            hidden_dtype=hidden_dtype,
        )
        # Second dimension is 1 or number of tokens (if timestep_per_token)
        v_timestep = v_timestep.view(batch_size, -1, v_timestep.shape[-1])
        v_embedded_timestep = v_embedded_timestep.view(
            batch_size, -1, v_embedded_timestep.shape[-1]
        )

        # Prepare audio timestep
        a_timestep = kwargs.get("a_timestep")
        if a_timestep is not None:
            a_timestep = a_timestep * self.timestep_scale_multiplier
            av_ca_factor = self.av_ca_timestep_scale_multiplier / self.timestep_scale_multiplier

            av_ca_audio_scale_shift_timestep, _ = self.av_ca_audio_scale_shift_adaln_single(
                a_timestep.flatten(),
                {"resolution": None, "aspect_ratio": None},
                batch_size=batch_size,
                hidden_dtype=hidden_dtype,
            )
            av_ca_video_scale_shift_timestep, _ = self.av_ca_video_scale_shift_adaln_single(
                timestep.flatten(),
                {"resolution": None, "aspect_ratio": None},
                batch_size=batch_size,
                hidden_dtype=hidden_dtype,
            )
            av_ca_a2v_gate_noise_timestep, _ = self.av_ca_a2v_gate_adaln_single(
                timestep.flatten() * av_ca_factor,
                {"resolution": None, "aspect_ratio": None},
                batch_size=batch_size,
                hidden_dtype=hidden_dtype,
            )
            av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single(
                a_timestep.flatten() * av_ca_factor,
                {"resolution": None, "aspect_ratio": None},
                batch_size=batch_size,
                hidden_dtype=hidden_dtype,
            )

            a_timestep, a_embedded_timestep = self.audio_adaln_single(
                a_timestep.flatten(),
                {"resolution": None, "aspect_ratio": None},
                batch_size=batch_size,
                hidden_dtype=hidden_dtype,
            )
            a_timestep = a_timestep.view(batch_size, -1, a_timestep.shape[-1])
            a_embedded_timestep = a_embedded_timestep.view(
                batch_size, -1, a_embedded_timestep.shape[-1]
            )

            cross_av_timestep_ss = [
                av_ca_audio_scale_shift_timestep,
                av_ca_video_scale_shift_timestep,
                av_ca_a2v_gate_noise_timestep,
                av_ca_v2a_gate_noise_timestep,
            ]
            cross_av_timestep_ss = [
                t.view(batch_size, -1, t.shape[-1]) for t in cross_av_timestep_ss
            ]
        else:
            a_timestep = timestep
            a_embedded_timestep = kwargs.get("embedded_timestep")
            cross_av_timestep_ss = []

        return [v_timestep, a_timestep, cross_av_timestep_ss], [
            v_embedded_timestep,
            a_embedded_timestep,
        ]

    def _prepare_context(self, context, batch_size, x, attention_mask=None):
        vx = x[0]
        ax = x[1]

        v_context, a_context = torch.split(
            context, int(context.shape[-1] / 2), len(context.shape) - 1
        )
        v_context, attention_mask = super()._prepare_context(
            v_context, batch_size, vx, attention_mask
        )

        if self.audio_caption_projection is not None:
            a_context = self.audio_caption_projection(a_context)
            a_context = a_context.view(batch_size, -1, ax.shape[-1])

        return [v_context, a_context], attention_mask

    def _prepare_positional_embeddings(self, pixel_coords, frame_rate, x_dtype):
        v_pixel_coords = pixel_coords[0]
        v_pe = super()._prepare_positional_embeddings(v_pixel_coords, frame_rate, x_dtype)

        a_latent_coords = pixel_coords[1]
        a_pe = self._precompute_freqs_cis(
            a_latent_coords,
            dim=self.audio_inner_dim,
            out_dtype=x_dtype,
            max_pos=self.audio_positional_embedding_max_pos,
            use_middle_indices_grid=self.use_middle_indices_grid,
            num_attention_heads=self.audio_num_attention_heads,
        )

        # Calculate positional embeddings for the middle of the token duration, to use in
        # the AV cross-attention layers.
        max_pos = max(
            self.positional_embedding_max_pos[0], self.audio_positional_embedding_max_pos[0]
        )
        v_pixel_coords = v_pixel_coords.to(torch.float32)
        v_pixel_coords[:, 0] = v_pixel_coords[:, 0] * (1.0 / frame_rate)
        av_cross_video_freq_cis = self._precompute_freqs_cis(
            v_pixel_coords[:, 0:1, :],
            dim=self.audio_cross_attention_dim,
            out_dtype=x_dtype,
            max_pos=[max_pos],
            use_middle_indices_grid=True,
            num_attention_heads=self.audio_num_attention_heads,
        )
        av_cross_audio_freq_cis = self._precompute_freqs_cis(
            a_latent_coords[:, 0:1, :],
            dim=self.audio_cross_attention_dim,
            out_dtype=x_dtype,
            max_pos=[max_pos],
            use_middle_indices_grid=True,
            num_attention_heads=self.audio_num_attention_heads,
        )

        return [(v_pe, av_cross_video_freq_cis), (a_pe, av_cross_audio_freq_cis)]

    def _process_transformer_blocks(
        self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs
    ):
        """Process transformer blocks for LTXAV."""
        vx = x[0]
        ax = x[1]
        v_context = context[0]
        a_context = context[1]
        v_timestep = timestep[0]
        a_timestep = timestep[1]
        v_pe, av_cross_video_freq_cis = pe[0]
        a_pe, av_cross_audio_freq_cis = pe[1]
        (
            av_ca_audio_scale_shift_timestep,
            av_ca_video_scale_shift_timestep,
            av_ca_a2v_gate_noise_timestep,
            av_ca_v2a_gate_noise_timestep,
        ) = timestep[2]

        patches_replace = transformer_options.get("patches_replace", {})
        blocks_replace = patches_replace.get("dit", {})

        # Process transformer blocks
        for i, block in enumerate(self.transformer_blocks):
            if ("double_block", i) in blocks_replace:
                def block_wrap(args):
                    out = {}
                    out["img"] = block(
                        args["img"],
                        v_context=args["v_context"],
                        a_context=args["a_context"],
                        attention_mask=args["attention_mask"],
                        v_timestep=args["v_timestep"],
                        a_timestep=args["a_timestep"],
                        v_pe=args["v_pe"],
                        a_pe=args["a_pe"],
                        v_cross_pe=args["v_cross_pe"],
                        a_cross_pe=args["a_cross_pe"],
                        v_cross_scale_shift_timestep=args["v_cross_scale_shift_timestep"],
                        a_cross_scale_shift_timestep=args["a_cross_scale_shift_timestep"],
                        v_cross_gate_timestep=args["v_cross_gate_timestep"],
                        a_cross_gate_timestep=args["a_cross_gate_timestep"],
                        transformer_options=args["transformer_options"],
                    )
                    return out

                out = blocks_replace[("double_block", i)](
                    {
                        "img": (vx, ax),
                        "v_context": v_context,
                        "a_context": a_context,
                        "attention_mask": attention_mask,
                        "v_timestep": v_timestep,
                        "a_timestep": a_timestep,
                        "v_pe": v_pe,
                        "a_pe": a_pe,
                        "v_cross_pe": av_cross_video_freq_cis,
                        "a_cross_pe": av_cross_audio_freq_cis,
                        "v_cross_scale_shift_timestep": av_ca_video_scale_shift_timestep,
                        "a_cross_scale_shift_timestep": av_ca_audio_scale_shift_timestep,
                        "v_cross_gate_timestep": av_ca_a2v_gate_noise_timestep,
                        "a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep,
                        "transformer_options": transformer_options,
                    },
                    {"original_block": block_wrap},
                )
                vx, ax = out["img"]
            else:
                vx, ax = block(
                    (vx, ax),
                    v_context=v_context,
                    a_context=a_context,
                    attention_mask=attention_mask,
                    v_timestep=v_timestep,
                    a_timestep=a_timestep,
                    v_pe=v_pe,
                    a_pe=a_pe,
                    v_cross_pe=av_cross_video_freq_cis,
                    a_cross_pe=av_cross_audio_freq_cis,
                    v_cross_scale_shift_timestep=av_ca_video_scale_shift_timestep,
                    a_cross_scale_shift_timestep=av_ca_audio_scale_shift_timestep,
                    v_cross_gate_timestep=av_ca_a2v_gate_noise_timestep,
                    a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep,
                    transformer_options=transformer_options,
                )

        return [vx, ax]

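    # _process_output below applies the final audio scale/shift and LayerNorm, projects back
    # to audio_out_channels, unpatchifies the audio tokens, and recombines audio and video.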
    def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):
        vx = x[0]
        ax = x[1]
        v_embedded_timestep = embedded_timestep[0]
        a_embedded_timestep = embedded_timestep[1]

        vx = super()._process_output(vx, v_embedded_timestep, keyframe_idxs, **kwargs)

        # Process audio output
        a_scale_shift_values = (
            self.audio_scale_shift_table[None, None].to(
                device=a_embedded_timestep.device, dtype=a_embedded_timestep.dtype
            )
            + a_embedded_timestep[:, :, None]
        )
        a_shift, a_scale = a_scale_shift_values[:, :, 0], a_scale_shift_values[:, :, 1]
        ax = self.audio_norm_out(ax)
        ax = ax * (1 + a_scale) + a_shift
        ax = self.audio_proj_out(ax)

        # Unpatchify audio
        ax = self.a_patchifier.unpatchify(
            ax, channels=self.num_audio_channels, freq=self.audio_frequency_bins
        )

        # Recombine audio and video
        original_shape = kwargs.get("av_orig_shape")
        return self.recombine_audio_and_video_latents(vx, ax, original_shape)

    def forward(
        self,
        x,
        timestep,
        context,
        attention_mask=None,
        frame_rate=25,
        transformer_options={},
        keyframe_idxs=None,
        **kwargs,
    ):
        """
        Forward pass for the LTXAV model.

        Args:
            x: Audio-video input latents, given as [video_latent, audio_latent]
                (the audio entry may be omitted)
            timestep: Tuple of (video_timestep, audio_timestep) or a single shared timestep
            context: Context tensor (e.g., text embeddings)
            attention_mask: Attention mask tensor
            frame_rate: Frame rate for temporal processing
            transformer_options: Additional options for transformer blocks
            keyframe_idxs: Keyframe indices for temporal processing
            **kwargs: Additional keyword arguments, including audio_length

        Returns:
            [video_output, audio_output], or the video output alone when no audio
            latent is present.
        """
        # Handle timestep format
        if isinstance(timestep, (tuple, list)) and len(timestep) == 2:
            v_timestep, a_timestep = timestep
            kwargs["a_timestep"] = a_timestep
            timestep = v_timestep
        else:
            kwargs["a_timestep"] = timestep

        # Call parent forward method
        return super().forward(
            x,
            timestep,
            context,
            attention_mask,
            frame_rate,
            transformer_options,
            keyframe_idxs,
            **kwargs,
        )
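

# Minimal usage sketch (illustrative only; the dtype, `operations` namespace, and latent
# shapes below are assumptions, not values taken from this file):
#
#   model = LTXAVModel(dtype=torch.bfloat16, device="cuda", operations=ops)
#   out = model(
#       [video_latent, audio_latent],   # video/audio latents; the audio entry may be empty
#       (v_timestep, a_timestep),       # per-modality timesteps, or a single shared tensor
#       context,                        # text embeddings, split in half for the two streams
#       frame_rate=25,
#       transformer_options={},
#   )
#   # `out` is [video_out, audio_out], or just video_out when the audio latent is empty.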