from abc import ABC, abstractmethod
from enum import Enum
import functools
import math
from typing import Dict, Optional, Tuple

from einops import rearrange
import numpy as np
import torch
from torch import nn

import comfy.patcher_extension
import comfy.ldm.modules.attention
import comfy.ldm.common_dit
from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords


def _log_base(x, base):
    return np.log(x) / np.log(base)


class LTXRopeType(str, Enum):
    INTERLEAVED = "interleaved"
    SPLIT = "split"

    KEY = "rope_type"

    @classmethod
    def from_dict(cls, kwargs, default=None):
        if default is None:
            default = cls.INTERLEAVED
        return cls(kwargs.get(cls.KEY, default))


class LTXFrequenciesPrecision(str, Enum):
    FLOAT32 = "float32"
    FLOAT64 = "float64"

    KEY = "frequencies_precision"

    @classmethod
    def from_dict(cls, kwargs, default=None):
        if default is None:
            default = cls.FLOAT32
        return cls(kwargs.get(cls.KEY, default))


def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
):
    """
    This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.

    Args:
        timesteps (torch.Tensor):
            a 1-D Tensor of N indices, one per batch element. These may be fractional.
        embedding_dim (int):
            the dimension of the output.
        flip_sin_to_cos (bool):
            Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False).
        downscale_freq_shift (float):
            Controls the delta between frequencies between dimensions.
        scale (float):
            Scaling factor applied to the embeddings.
        max_period (int):
            Controls the maximum frequency of the embeddings.

    Returns:
        torch.Tensor: an [N x dim] Tensor of positional embeddings.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
    exponent = exponent / (half_dim - downscale_freq_shift)

    emb = torch.exp(exponent)
    emb = timesteps[:, None].float() * emb[None, :]

    # scale embeddings
    emb = scale * emb

    # concat sine and cosine embeddings
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)

    # flip sine and cosine embeddings
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)

    # zero pad
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb


class TimestepEmbedding(nn.Module):
    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: int = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim=None,
        sample_proj_bias=True,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()

        self.linear_1 = operations.Linear(in_channels, time_embed_dim, sample_proj_bias, dtype=dtype, device=device)

        if cond_proj_dim is not None:
            self.cond_proj = operations.Linear(cond_proj_dim, in_channels, bias=False, dtype=dtype, device=device)
        else:
            self.cond_proj = None

        self.act = nn.SiLU()

        if out_dim is not None:
            time_embed_dim_out = out_dim
        else:
            time_embed_dim_out = time_embed_dim
        self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias, dtype=dtype, device=device)

        if post_act_fn is None:
            self.post_act = None
        # else:
        #     self.post_act = get_activation(post_act_fn)

    def forward(self, sample, condition=None):
        if condition is not None:
            sample = sample + self.cond_proj(condition)
        sample = self.linear_1(sample)

        if self.act is not None:
            sample = self.act(sample)
        sample = self.linear_2(sample)

        if self.post_act is not None:
            sample = self.post_act(sample)
        return sample


class Timesteps(nn.Module):
    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift
        self.scale = scale

    def forward(self, timesteps):
        t_emb = get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
            scale=self.scale,
        )
        return t_emb


class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
    """
    For PixArt-Alpha.

    Reference:
    https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
    """

    def __init__(
        self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False, dtype=None, device=None, operations=None
    ):
        super().__init__()

        self.outdim = size_emb_dim
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(
            in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, device=device, operations=operations
        )

    def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
        timesteps_proj = self.time_proj(timestep)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (N, D)
        return timesteps_emb


class AdaLayerNormSingle(nn.Module):
    r"""
    Norm layer adaptive layer norm single (adaLN-single).

    As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        use_additional_conditions (`bool`): To use additional conditions for normalization or not.
    """

    def __init__(
        self,
        embedding_dim: int,
        embedding_coefficient: int = 6,
        use_additional_conditions: bool = False,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()

        self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
            embedding_dim,
            size_emb_dim=embedding_dim // 3,
            use_additional_conditions=use_additional_conditions,
            dtype=dtype,
            device=device,
            operations=operations,
        )

        self.silu = nn.SiLU()
        self.linear = operations.Linear(embedding_dim, embedding_coefficient * embedding_dim, bias=True, dtype=dtype, device=device)

    def forward(
        self,
        timestep: torch.Tensor,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        batch_size: Optional[int] = None,
        hidden_dtype: Optional[torch.dtype] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # No modulation happening here.
        added_cond_kwargs = added_cond_kwargs or {"resolution": None, "aspect_ratio": None}
        embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
        return self.linear(self.silu(embedded_timestep)), embedded_timestep


class PixArtAlphaTextProjection(nn.Module):
    """
    Projects caption embeddings. Also handles dropout for classifier-free guidance.
    Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
    """

    def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", dtype=None, device=None, operations=None):
        super().__init__()
        if out_features is None:
            out_features = hidden_size
        self.linear_1 = operations.Linear(in_features=in_features, out_features=hidden_size, bias=True, dtype=dtype, device=device)
        if act_fn == "gelu_tanh":
            self.act_1 = nn.GELU(approximate="tanh")
        elif act_fn == "silu":
            self.act_1 = nn.SiLU()
        else:
            raise ValueError(f"Unknown activation function: {act_fn}")
        self.linear_2 = operations.Linear(in_features=hidden_size, out_features=out_features, bias=True, dtype=dtype, device=device)

    def forward(self, caption):
        hidden_states = self.linear_1(caption)
        hidden_states = self.act_1(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states


class GELU_approx(nn.Module):
    def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=None):
        super().__init__()
        self.proj = operations.Linear(dim_in, dim_out, dtype=dtype, device=device)

    def forward(self, x):
        return torch.nn.functional.gelu(self.proj(x), approximate="tanh")


class FeedForward(nn.Module):
    def __init__(self, dim, dim_out, mult=4, glu=False, dropout=0.0, dtype=None, device=None, operations=None):
        super().__init__()
        inner_dim = int(dim * mult)
        project_in = GELU_approx(dim, inner_dim, dtype=dtype, device=device, operations=operations)

        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            operations.Linear(inner_dim, dim_out, dtype=dtype, device=device),
        )

    def forward(self, x):
        return self.net(x)


def apply_rotary_emb(input_tensor, freqs_cis):
    cos_freqs, sin_freqs = freqs_cis[0], freqs_cis[1]
    split_pe = freqs_cis[2] if len(freqs_cis) > 2 else False
    return (
        apply_split_rotary_emb(input_tensor, cos_freqs, sin_freqs)
        if split_pe
        else apply_interleaved_rotary_emb(input_tensor, cos_freqs, sin_freqs)
    )


def apply_interleaved_rotary_emb(input_tensor, cos_freqs, sin_freqs):
    # TODO: remove duplicate funcs and pick the best/fastest one
    t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2)
    t1, t2 = t_dup.unbind(dim=-1)
    t_dup = torch.stack((-t2, t1), dim=-1)
    input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)")

    out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs

    return out


def apply_split_rotary_emb(input_tensor, cos, sin):
    needs_reshape = False
    if input_tensor.ndim != 4 and cos.ndim == 4:
        B, H, T, _ = cos.shape
        input_tensor = input_tensor.reshape(B, T, H, -1).swapaxes(1, 2)
        needs_reshape = True

    split_input = rearrange(input_tensor, "... (d r) -> ... d r", d=2)
    first_half_input = split_input[..., :1, :]
    second_half_input = split_input[..., 1:, :]

    output = split_input * cos.unsqueeze(-2)
    first_half_output = output[..., :1, :]
    second_half_output = output[..., 1:, :]

    first_half_output.addcmul_(-sin.unsqueeze(-2), second_half_input)
    second_half_output.addcmul_(sin.unsqueeze(-2), first_half_input)

    output = rearrange(output, "... d r -> ... (d r)")
    return output.swapaxes(1, 2).reshape(B, T, -1) if needs_reshape else output


class CrossAttention(nn.Module):
    def __init__(
        self,
        query_dim,
        context_dim=None,
        heads=8,
        dim_head=64,
        dropout=0.0,
        attn_precision=None,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = query_dim if context_dim is None else context_dim
        self.attn_precision = attn_precision
        self.heads = heads
        self.dim_head = dim_head

        self.q_norm = operations.RMSNorm(inner_dim, eps=1e-5, dtype=dtype, device=device)
        self.k_norm = operations.RMSNorm(inner_dim, eps=1e-5, dtype=dtype, device=device)

        self.to_q = operations.Linear(query_dim, inner_dim, bias=True, dtype=dtype, device=device)
        self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
        self.to_v = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)

        self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))

    def forward(self, x, context=None, mask=None, pe=None, k_pe=None, transformer_options={}):
        q = self.to_q(x)
        context = x if context is None else context
        k = self.to_k(context)
        v = self.to_v(context)

        q = self.q_norm(q)
        k = self.k_norm(k)

        if pe is not None:
            q = apply_rotary_emb(q, pe)
            k = apply_rotary_emb(k, pe if k_pe is None else k_pe)

        if mask is None:
            out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
        else:
            out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
        return self.to_out(out)


class BasicTransformerBlock(nn.Module):
    def __init__(self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None):
        super().__init__()

        self.attn_precision = attn_precision
        self.attn1 = CrossAttention(
            query_dim=dim,
            heads=n_heads,
            dim_head=d_head,
            context_dim=None,
            attn_precision=self.attn_precision,
            dtype=dtype,
            device=device,
            operations=operations,
        )
        self.ff = FeedForward(dim, dim_out=dim, glu=True, dtype=dtype, device=device, operations=operations)

        self.attn2 = CrossAttention(
            query_dim=dim,
            context_dim=context_dim,
            heads=n_heads,
            dim_head=d_head,
            attn_precision=self.attn_precision,
            dtype=dtype,
            device=device,
            operations=operations,
        )

        self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))

    def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
            self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype)
            + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)
        ).unbind(dim=2)

        attn1_input = comfy.ldm.common_dit.rms_norm(x)
        attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa)
        attn1_input = self.attn1(attn1_input, pe=pe, transformer_options=transformer_options)
        x.addcmul_(attn1_input, gate_msa)
        del attn1_input
        x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)

        y = comfy.ldm.common_dit.rms_norm(x)
        y = torch.addcmul(y, y, scale_mlp).add_(shift_mlp)
        x.addcmul_(self.ff(y), gate_mlp)
        return x


def get_fractional_positions(indices_grid, max_pos):
    n_pos_dims = indices_grid.shape[1]
    assert n_pos_dims == len(max_pos), f"Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})"
    fractional_positions = torch.stack(
        [indices_grid[:, i] / max_pos[i] for i in range(n_pos_dims)],
        axis=-1,
    )
    return fractional_positions


@functools.lru_cache(maxsize=5)
def generate_freq_grid_np(positional_embedding_theta, positional_embedding_max_pos_count, inner_dim, _=None):
    theta = positional_embedding_theta
    start = 1
    end = theta
    n_elem = 2 * positional_embedding_max_pos_count

    pow_indices = np.power(
        theta,
        np.linspace(
            _log_base(start, theta),
            _log_base(end, theta),
            inner_dim // n_elem,
            dtype=np.float64,
        ),
    )
    return torch.tensor(pow_indices * math.pi / 2, dtype=torch.float32)


def generate_freq_grid_pytorch(positional_embedding_theta, positional_embedding_max_pos_count, inner_dim, device):
    theta = positional_embedding_theta
    start = 1
    end = theta
    n_elem = 2 * positional_embedding_max_pos_count

    indices = theta ** (
        torch.linspace(
            math.log(start, theta),
            math.log(end, theta),
            inner_dim // n_elem,
            device=device,
            dtype=torch.float32,
        )
    )
    indices = indices.to(dtype=torch.float32)
    indices = indices * math.pi / 2
    return indices


def generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid):
    if use_middle_indices_grid:
        assert len(indices_grid.shape) == 4 and indices_grid.shape[-1] == 2
        indices_grid_start, indices_grid_end = indices_grid[..., 0], indices_grid[..., 1]
        indices_grid = (indices_grid_start + indices_grid_end) / 2.0
    elif len(indices_grid.shape) == 4:
        indices_grid = indices_grid[..., 0]

    # Get fractional positions and compute frequency indices
    fractional_positions = get_fractional_positions(indices_grid, max_pos)
    indices = indices.to(device=fractional_positions.device)

    freqs = (
        (indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
        .transpose(-1, -2)
        .flatten(2)
    )
    return freqs


def interleaved_freqs_cis(freqs, pad_size):
    cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
    sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
    if pad_size != 0:
        cos_padding = torch.ones_like(cos_freq[:, :, :pad_size])
        sin_padding = torch.zeros_like(cos_freq[:, :, :pad_size])
        cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
        sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
    return cos_freq, sin_freq


def split_freqs_cis(freqs, pad_size, num_attention_heads):
    cos_freq = freqs.cos()
    sin_freq = freqs.sin()
    if pad_size != 0:
        cos_padding = torch.ones_like(cos_freq[:, :, :pad_size])
        sin_padding = torch.zeros_like(sin_freq[:, :, :pad_size])
        cos_freq = torch.concatenate([cos_padding, cos_freq], axis=-1)
        sin_freq = torch.concatenate([sin_padding, sin_freq], axis=-1)

    # Reshape freqs to be compatible with multi-head attention
    B, T, half_HD = cos_freq.shape
    cos_freq = cos_freq.reshape(B, T, num_attention_heads, half_HD // num_attention_heads)
    sin_freq = sin_freq.reshape(B, T, num_attention_heads, half_HD // num_attention_heads)
    cos_freq = torch.swapaxes(cos_freq, 1, 2)  # (B, H, T, D//2)
    sin_freq = torch.swapaxes(sin_freq, 1, 2)  # (B, H, T, D//2)

    return cos_freq, sin_freq


class LTXBaseModel(torch.nn.Module, ABC):
    """
    Abstract base class for LTX models (Lightricks Transformer models).

    This class defines the common interface and shared functionality for all
    LTX models, including LTXV (video) and LTXAV (audio-video) variants.
    """

    def __init__(
        self,
        in_channels: int,
        cross_attention_dim: int,
        attention_head_dim: int,
        num_attention_heads: int,
        caption_channels: int,
        num_layers: int,
        positional_embedding_theta: float = 10000.0,
        positional_embedding_max_pos: list = [20, 2048, 2048],
        causal_temporal_positioning: bool = False,
        vae_scale_factors: tuple = (8, 32, 32),
        use_middle_indices_grid=False,
        timestep_scale_multiplier=1000.0,
        dtype=None,
        device=None,
        operations=None,
        **kwargs,
    ):
        super().__init__()
        self.generator = None
        self.vae_scale_factors = vae_scale_factors
        self.use_middle_indices_grid = use_middle_indices_grid
        self.dtype = dtype
        self.in_channels = in_channels
        self.cross_attention_dim = cross_attention_dim
        self.attention_head_dim = attention_head_dim
        self.num_attention_heads = num_attention_heads
        self.caption_channels = caption_channels
        self.num_layers = num_layers
        self.positional_embedding_theta = positional_embedding_theta
        self.positional_embedding_max_pos = positional_embedding_max_pos
        self.split_positional_embedding = LTXRopeType.from_dict(kwargs)
        self.freq_grid_generator = (
            generate_freq_grid_np
            if LTXFrequenciesPrecision.from_dict(kwargs) == LTXFrequenciesPrecision.FLOAT64
            else generate_freq_grid_pytorch
        )
        self.causal_temporal_positioning = causal_temporal_positioning
        self.operations = operations
        self.timestep_scale_multiplier = timestep_scale_multiplier

        # Common dimensions
        self.inner_dim = num_attention_heads * attention_head_dim
        self.out_channels = in_channels

        # Initialize common components
        self._init_common_components(device, dtype)

        # Initialize model-specific components
        self._init_model_components(device, dtype, **kwargs)

        # Initialize transformer blocks
        self._init_transformer_blocks(device, dtype, **kwargs)

        # Initialize output components
        self._init_output_components(device, dtype)

    def _init_common_components(self, device, dtype):
        """Initialize components common to all LTX models:
        - patchify_proj: Linear projection for patchifying input
        - adaln_single: AdaLN layer for timestep embedding
        - caption_projection: Linear projection for caption embedding
        """
        self.patchify_proj = self.operations.Linear(self.in_channels, self.inner_dim, bias=True, dtype=dtype, device=device)

        self.adaln_single = AdaLayerNormSingle(
            self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations
        )

        self.caption_projection = PixArtAlphaTextProjection(
            in_features=self.caption_channels,
            hidden_size=self.inner_dim,
            dtype=dtype,
            device=device,
            operations=self.operations,
        )

    @abstractmethod
    def _init_model_components(self, device, dtype, **kwargs):
        """Initialize model-specific components. Must be implemented by subclasses."""
        pass

    @abstractmethod
    def _init_transformer_blocks(self, device, dtype, **kwargs):
        """Initialize transformer blocks. Must be implemented by subclasses."""
        pass

    @abstractmethod
    def _init_output_components(self, device, dtype):
        """Initialize output components. Must be implemented by subclasses."""
        pass

    @abstractmethod
    def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs):
        """Process input data. Must be implemented by subclasses."""
        pass

    @abstractmethod
    def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, **kwargs):
        """Process transformer blocks. Must be implemented by subclasses."""
        pass

    @abstractmethod
    def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):
        """Process output data. Must be implemented by subclasses."""
        pass

    def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs):
        """Prepare timestep embeddings."""
        grid_mask = kwargs.get("grid_mask", None)
        if grid_mask is not None:
            timestep = timestep[:, grid_mask]
        timestep = timestep * self.timestep_scale_multiplier
        timestep, embedded_timestep = self.adaln_single(
            timestep.flatten(),
            {"resolution": None, "aspect_ratio": None},
            batch_size=batch_size,
            hidden_dtype=hidden_dtype,
        )
        # Second dimension is 1 or number of tokens (if timestep_per_token)
        timestep = timestep.view(batch_size, -1, timestep.shape[-1])
        embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1])
        return timestep, embedded_timestep

    def _prepare_context(self, context, batch_size, x, attention_mask=None):
        """Prepare context for transformer blocks."""
        if self.caption_projection is not None:
            context = self.caption_projection(context)
            context = context.view(batch_size, -1, x.shape[-1])
        return context, attention_mask

    def _precompute_freqs_cis(
        self,
        indices_grid,
        dim,
        out_dtype,
        theta=10000.0,
        max_pos=[20, 2048, 2048],
        use_middle_indices_grid=False,
        num_attention_heads=32,
    ):
        split_mode = self.split_positional_embedding == LTXRopeType.SPLIT

        indices = self.freq_grid_generator(theta, indices_grid.shape[1], dim, indices_grid.device)
        freqs = generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid)

        if split_mode:
            expected_freqs = dim // 2
            current_freqs = freqs.shape[-1]
            pad_size = expected_freqs - current_freqs
            cos_freq, sin_freq = split_freqs_cis(freqs, pad_size, num_attention_heads)
        else:
            # 2 for cos and sin, times 3 for (t, x, y) positions, or times 1 for temporal only
            n_elem = 2 * indices_grid.shape[1]
            cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem)

        return cos_freq.to(out_dtype), sin_freq.to(out_dtype), split_mode

    def _prepare_positional_embeddings(self, pixel_coords, frame_rate, x_dtype):
        """Prepare positional embeddings."""
        fractional_coords = pixel_coords.to(torch.float32)
        fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate)

        pe = self._precompute_freqs_cis(
            fractional_coords,
            dim=self.inner_dim,
            out_dtype=x_dtype,
            max_pos=self.positional_embedding_max_pos,
            use_middle_indices_grid=self.use_middle_indices_grid,
            num_attention_heads=self.num_attention_heads,
        )
        return pe

    def _prepare_attention_mask(self, attention_mask, x_dtype):
        """Prepare attention mask."""
        if attention_mask is not None and not torch.is_floating_point(attention_mask):
            attention_mask = (attention_mask - 1).to(x_dtype).reshape(
                (attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
            ) * torch.finfo(x_dtype).max
        return attention_mask

    def forward(
        self,
        x,
        timestep,
        context,
        attention_mask,
        frame_rate=25,
        transformer_options={},
        keyframe_idxs=None,
        denoise_mask=None,
        **kwargs
    ):
        """
        Forward pass for LTX models.

        Args:
            x: Input tensor
            timestep: Timestep tensor
            context: Context tensor (e.g., text embeddings)
            attention_mask: Attention mask tensor
            frame_rate: Frame rate for temporal processing
            transformer_options: Additional options for transformer blocks
            keyframe_idxs: Keyframe indices for temporal processing
            denoise_mask: Optional mask selecting which latent tokens are denoised
            **kwargs: Additional keyword arguments

        Returns:
            Processed output tensor
        """
        return comfy.patcher_extension.WrapperExecutor.new_class_executor(
            self._forward,
            self,
            comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options),
        ).execute(x, timestep, context, attention_mask, frame_rate, transformer_options, keyframe_idxs, denoise_mask=denoise_mask, **kwargs)

    def _forward(
        self,
        x,
        timestep,
        context,
        attention_mask,
        frame_rate=25,
        transformer_options={},
        keyframe_idxs=None,
        denoise_mask=None,
        **kwargs
    ):
        """
        Internal forward pass for LTX models.

        Args:
            x: Input tensor
            timestep: Timestep tensor
            context: Context tensor (e.g., text embeddings)
            attention_mask: Attention mask tensor
            frame_rate: Frame rate for temporal processing
            transformer_options: Additional options for transformer blocks
            keyframe_idxs: Keyframe indices for temporal processing
            denoise_mask: Optional mask selecting which latent tokens are denoised
            **kwargs: Additional keyword arguments

        Returns:
            Processed output tensor
        """
        if isinstance(x, list):
            input_dtype = x[0].dtype
            batch_size = x[0].shape[0]
        else:
            input_dtype = x.dtype
            batch_size = x.shape[0]

        # Process input
        merged_args = {**transformer_options, **kwargs}
        x, pixel_coords, additional_args = self._process_input(x, keyframe_idxs, denoise_mask, **merged_args)
        merged_args.update(additional_args)

        # Prepare timestep and context
        timestep, embedded_timestep = self._prepare_timestep(timestep, batch_size, input_dtype, **merged_args)
        context, attention_mask = self._prepare_context(context, batch_size, x, attention_mask)

        # Prepare attention mask and positional embeddings
        attention_mask = self._prepare_attention_mask(attention_mask, input_dtype)
        pe = self._prepare_positional_embeddings(pixel_coords, frame_rate, input_dtype)

        # Process transformer blocks
        x = self._process_transformer_blocks(
            x, context, attention_mask, timestep, pe, transformer_options=transformer_options, **merged_args
        )

        # Process output
        x = self._process_output(x, embedded_timestep, keyframe_idxs, **merged_args)

        return x


class LTXVModel(LTXBaseModel):
    """LTXV model for video generation."""

    def __init__(
        self,
        in_channels=128,
        cross_attention_dim=2048,
        attention_head_dim=64,
        num_attention_heads=32,
        caption_channels=4096,
        num_layers=28,
        positional_embedding_theta=10000.0,
        positional_embedding_max_pos=[20, 2048, 2048],
        causal_temporal_positioning=False,
        vae_scale_factors=(8, 32, 32),
        use_middle_indices_grid=False,
        timestep_scale_multiplier=1000.0,
        dtype=None,
        device=None,
        operations=None,
        **kwargs,
    ):
        super().__init__(
            in_channels=in_channels,
            cross_attention_dim=cross_attention_dim,
            attention_head_dim=attention_head_dim,
            num_attention_heads=num_attention_heads,
            caption_channels=caption_channels,
            num_layers=num_layers,
            positional_embedding_theta=positional_embedding_theta,
            positional_embedding_max_pos=positional_embedding_max_pos,
            causal_temporal_positioning=causal_temporal_positioning,
            vae_scale_factors=vae_scale_factors,
            use_middle_indices_grid=use_middle_indices_grid,
            timestep_scale_multiplier=timestep_scale_multiplier,
            dtype=dtype,
            device=device,
            operations=operations,
            **kwargs,
        )

    def _init_model_components(self, device, dtype, **kwargs):
        """Initialize LTXV-specific components."""
        # No additional components needed for LTXV beyond base class
        pass

    def _init_transformer_blocks(self, device, dtype, **kwargs):
        """Initialize transformer blocks for LTXV."""
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    self.inner_dim,
                    self.num_attention_heads,
                    self.attention_head_dim,
                    context_dim=self.cross_attention_dim,
                    dtype=dtype,
                    device=device,
                    operations=self.operations,
                )
                for _ in range(self.num_layers)
            ]
        )

    def _init_output_components(self, device, dtype):
        """Initialize output components for LTXV."""
        self.scale_shift_table = nn.Parameter(torch.empty(2, self.inner_dim, dtype=dtype, device=device))
        self.norm_out = self.operations.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.proj_out = self.operations.Linear(self.inner_dim, self.out_channels, dtype=dtype, device=device)
        self.patchifier = SymmetricPatchifier(1, start_end=True)

    def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs):
        """Process input for LTXV."""
        additional_args = {"orig_shape": list(x.shape)}
        x, latent_coords = self.patchifier.patchify(x)
        pixel_coords = latent_to_pixel_coords(
            latent_coords=latent_coords,
            scale_factors=self.vae_scale_factors,
            causal_fix=self.causal_temporal_positioning,
        )

        grid_mask = None
        if keyframe_idxs is not None:
            additional_args.update({"orig_patchified_shape": list(x.shape)})
            denoise_mask = self.patchifier.patchify(denoise_mask)[0]
            grid_mask = ~torch.any(denoise_mask < 0, dim=-1)[0]
            additional_args.update({"grid_mask": grid_mask})
            x = x[:, grid_mask, :]
            pixel_coords = pixel_coords[:, :, grid_mask, ...]
            kf_grid_mask = grid_mask[-keyframe_idxs.shape[2]:]
            keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :]
            pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs

        x = self.patchify_proj(x)
        return x, pixel_coords, additional_args

    def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs):
        """Process transformer blocks for LTXV."""
        patches_replace = transformer_options.get("patches_replace", {})
        blocks_replace = patches_replace.get("dit", {})

        for i, block in enumerate(self.transformer_blocks):
            if ("double_block", i) in blocks_replace:
                def block_wrap(args):
                    out = {}
                    out["img"] = block(
                        args["img"],
                        context=args["txt"],
                        attention_mask=args["attention_mask"],
                        timestep=args["vec"],
                        pe=args["pe"],
                        transformer_options=args["transformer_options"],
                    )
                    return out

                out = blocks_replace[("double_block", i)](
                    {
                        "img": x,
                        "txt": context,
                        "attention_mask": attention_mask,
                        "vec": timestep,
                        "pe": pe,
                        "transformer_options": transformer_options,
                    },
                    {"original_block": block_wrap},
                )
                x = out["img"]
            else:
                x = block(
                    x,
                    context=context,
                    attention_mask=attention_mask,
                    timestep=timestep,
                    pe=pe,
                    transformer_options=transformer_options,
                )
        return x

    def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):
        """Process output for LTXV."""
        # Apply scale-shift modulation
        scale_shift_values = (
            self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None]
        )
        shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]

        x = self.norm_out(x)
        x = x * (1 + scale) + shift
        x = self.proj_out(x)

        if keyframe_idxs is not None:
            grid_mask = kwargs["grid_mask"]
            orig_patchified_shape = kwargs["orig_patchified_shape"]
            full_x = torch.zeros(orig_patchified_shape, dtype=x.dtype, device=x.device)
            full_x[:, grid_mask, :] = x
            x = full_x

        # Unpatchify to restore original dimensions
        orig_shape = kwargs["orig_shape"]
        x = self.patchifier.unpatchify(
            latents=x,
            output_height=orig_shape[3],
            output_width=orig_shape[4],
            output_num_frames=orig_shape[2],
            out_channels=orig_shape[1] // math.prod(self.patchifier.patch_size),
        )
        return x
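

# Minimal usage sketch of the standalone helpers above (an illustrative addition, not
# part of the model's runtime path). It only assumes torch is available: it checks that
# get_timestep_embedding yields an [N, dim] tensor, and that the interleaved RoPE
# variant with zero rotation angle (cos=1, sin=0) is an identity on its input.
if __name__ == "__main__":
    t = torch.tensor([0.0, 250.0, 999.0])
    emb = get_timestep_embedding(t, embedding_dim=256)
    assert emb.shape == (3, 256)

    x = torch.randn(1, 8, 16, 64)   # (batch, heads, tokens, head_dim)
    cos = torch.ones(1, 1, 16, 64)  # zero angle -> cos=1
    sin = torch.zeros(1, 1, 16, 64)  # zero angle -> sin=0
    out = apply_interleaved_rotary_emb(x, cos, sin)
    assert torch.allclose(out, x)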