From fd5c0755af18530ff1225f5946c7a647b9694032 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com>
Date: Tue, 13 Jan 2026 00:28:59 +0200
Subject: [PATCH] Reduce LTX2 VRAM use by more efficient timestep embed handling (#11829)

---
 comfy/ldm/lightricks/av_model.py | 136 ++++++++++++++++++++++++-------
 1 file changed, 106 insertions(+), 30 deletions(-)

diff --git a/comfy/ldm/lightricks/av_model.py b/comfy/ldm/lightricks/av_model.py
index 759535501..c12ace241 100644
--- a/comfy/ldm/lightricks/av_model.py
+++ b/comfy/ldm/lightricks/av_model.py
@@ -11,6 +11,69 @@ from comfy.ldm.lightricks.model import (
 from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
 import comfy.ldm.common_dit
 
+class CompressedTimestep:
+    """Store video timestep embeddings in compressed form using per-frame indexing."""
+    __slots__ = ('data', 'batch_size', 'num_frames', 'patches_per_frame', 'feature_dim')
+
+    def __init__(self, tensor: torch.Tensor, patches_per_frame: int):
+        """
+        tensor: [batch_size, num_tokens, feature_dim] tensor where num_tokens = num_frames * patches_per_frame
+        patches_per_frame: Number of spatial patches per frame (height * width in latent space)
+        """
+        self.batch_size, num_tokens, self.feature_dim = tensor.shape
+
+        # Check if compression is valid (num_tokens must be divisible by patches_per_frame)
+        if num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame:
+            self.patches_per_frame = patches_per_frame
+            self.num_frames = num_tokens // patches_per_frame
+
+            # Reshape to [batch, frames, patches_per_frame, feature_dim] and store one value per frame
+            # All patches in a frame are identical, so we only keep the first one
+            reshaped = tensor.view(self.batch_size, self.num_frames, patches_per_frame, self.feature_dim)
+            self.data = reshaped[:, :, 0, :].contiguous()  # [batch, frames, feature_dim]
+        else:
+            # Not divisible or too small - store directly without compression
+            self.patches_per_frame = 1
+            self.num_frames = num_tokens
+            self.data = tensor
+
+    def expand(self):
+        """Expand back to original tensor."""
+        if self.patches_per_frame == 1:
+            return self.data
+
+        # [batch, frames, feature_dim] -> [batch, frames, patches_per_frame, feature_dim] -> [batch, tokens, feature_dim]
+        expanded = self.data.unsqueeze(2).expand(self.batch_size, self.num_frames, self.patches_per_frame, self.feature_dim)
+        return expanded.reshape(self.batch_size, -1, self.feature_dim)
+
+    def expand_for_computation(self, scale_shift_table: torch.Tensor, batch_size: int, indices: slice = slice(None, None)):
+        """Compute ada values on compressed per-frame data, then expand spatially."""
+        num_ada_params = scale_shift_table.shape[0]
+
+        # No compression - compute directly
+        if self.patches_per_frame == 1:
+            num_tokens = self.data.shape[1]
+            dim_per_param = self.feature_dim // num_ada_params
+            reshaped = self.data.reshape(batch_size, num_tokens, num_ada_params, dim_per_param)[:, :, indices, :]
+            table_values = scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(device=self.data.device, dtype=self.data.dtype)
+            ada_values = (table_values + reshaped).unbind(dim=2)
+            return ada_values
+
+        # Compressed: compute on per-frame data then expand spatially
+        # Reshape: [batch, frames, feature_dim] -> [batch, frames, num_ada_params, dim_per_param]
+        frame_reshaped = self.data.reshape(batch_size, self.num_frames, num_ada_params, -1)[:, :, indices, :]
+        table_values = scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(
+            device=self.data.device, dtype=self.data.dtype
+        )
+        frame_ada = (table_values + frame_reshaped).unbind(dim=2)
+
+        # Expand each ada parameter spatially: [batch, frames, dim] -> [batch, frames, patches, dim] -> [batch, tokens, dim]
+        return tuple(
+            frame_val.unsqueeze(2).expand(batch_size, self.num_frames, self.patches_per_frame, -1)
+            .reshape(batch_size, -1, frame_val.shape[-1])
+            for frame_val in frame_ada
+        )
+
 class BasicAVTransformerBlock(nn.Module):
     def __init__(
         self,
@@ -119,6 +182,9 @@ class BasicAVTransformerBlock(nn.Module):
     def get_ada_values(
         self, scale_shift_table: torch.Tensor, batch_size: int, timestep: torch.Tensor, indices: slice = slice(None, None)
     ):
+        if isinstance(timestep, CompressedTimestep):
+            return timestep.expand_for_computation(scale_shift_table, batch_size, indices)
+
         num_ada_params = scale_shift_table.shape[0]
 
         ada_values = (
@@ -146,10 +212,7 @@ class BasicAVTransformerBlock(nn.Module):
             gate_timestep,
         )
 
-        scale_shift_chunks = [t.squeeze(2) for t in scale_shift_ada_values]
-        gate_ada_values = [t.squeeze(2) for t in gate_ada_values]
-
-        return (*scale_shift_chunks, *gate_ada_values)
+        return (*scale_shift_ada_values, *gate_ada_values)
 
     def forward(
         self,
@@ -543,72 +606,80 @@ class LTXAVModel(LTXVModel):
         if grid_mask is not None:
             timestep = timestep[:, grid_mask]
 
-        timestep = timestep * self.timestep_scale_multiplier
+        timestep_scaled = timestep * self.timestep_scale_multiplier
+
         v_timestep, v_embedded_timestep = self.adaln_single(
-            timestep.flatten(),
+            timestep_scaled.flatten(),
             {"resolution": None, "aspect_ratio": None},
             batch_size=batch_size,
            hidden_dtype=hidden_dtype,
         )
-        # Second dimension is 1 or number of tokens (if timestep_per_token)
-        v_timestep = v_timestep.view(batch_size, -1, v_timestep.shape[-1])
-        v_embedded_timestep = v_embedded_timestep.view(
-            batch_size, -1, v_embedded_timestep.shape[-1]
-        )
+        # Calculate patches_per_frame from orig_shape: [batch, channels, frames, height, width]
+        # Video tokens are arranged as (frames * height * width), so patches_per_frame = height * width
+        orig_shape = kwargs.get("orig_shape")
+        v_patches_per_frame = None
+        if orig_shape is not None and len(orig_shape) == 5:
+            # orig_shape[3] = height, orig_shape[4] = width (in latent space)
+            v_patches_per_frame = orig_shape[3] * orig_shape[4]
+
+        # Reshape to [batch_size, num_tokens, dim] and compress for storage
+        v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame)
+        v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame)
 
         # Prepare audio timestep
         a_timestep = kwargs.get("a_timestep")
         if a_timestep is not None:
-            a_timestep = a_timestep * self.timestep_scale_multiplier
+            a_timestep_scaled = a_timestep * self.timestep_scale_multiplier
+            a_timestep_flat = a_timestep_scaled.flatten()
+            timestep_flat = timestep_scaled.flatten()
             av_ca_factor = self.av_ca_timestep_scale_multiplier / self.timestep_scale_multiplier
+            # Cross-attention timesteps - compress these too
             av_ca_audio_scale_shift_timestep, _ = self.av_ca_audio_scale_shift_adaln_single(
-                a_timestep.flatten(),
+                a_timestep_flat,
                 {"resolution": None, "aspect_ratio": None},
                 batch_size=batch_size,
                 hidden_dtype=hidden_dtype,
             )
 
             av_ca_video_scale_shift_timestep, _ = self.av_ca_video_scale_shift_adaln_single(
-                timestep.flatten(),
+                timestep_flat,
                 {"resolution": None, "aspect_ratio": None},
                 batch_size=batch_size,
                 hidden_dtype=hidden_dtype,
             )
 
             av_ca_a2v_gate_noise_timestep, _ = self.av_ca_a2v_gate_adaln_single(
-                timestep.flatten() * av_ca_factor,
+                timestep_flat * av_ca_factor,
                 {"resolution": None, "aspect_ratio": None},
                 batch_size=batch_size,
                 hidden_dtype=hidden_dtype,
            )
 
             av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single(
-                a_timestep.flatten() * av_ca_factor,
+                a_timestep_flat * av_ca_factor,
                 {"resolution": None, "aspect_ratio": None},
                 batch_size=batch_size,
                 hidden_dtype=hidden_dtype,
             )
 
+            # Compress cross-attention timesteps (only video side, audio is too small to benefit)
+            cross_av_timestep_ss = [
+                av_ca_audio_scale_shift_timestep.view(batch_size, -1, av_ca_audio_scale_shift_timestep.shape[-1]),
+                CompressedTimestep(av_ca_video_scale_shift_timestep.view(batch_size, -1, av_ca_video_scale_shift_timestep.shape[-1]), v_patches_per_frame),  # video - compressed
+                CompressedTimestep(av_ca_a2v_gate_noise_timestep.view(batch_size, -1, av_ca_a2v_gate_noise_timestep.shape[-1]), v_patches_per_frame),  # video - compressed
+                av_ca_v2a_gate_noise_timestep.view(batch_size, -1, av_ca_v2a_gate_noise_timestep.shape[-1]),
+            ]
+
             a_timestep, a_embedded_timestep = self.audio_adaln_single(
-                a_timestep.flatten(),
+                a_timestep_flat,
                 {"resolution": None, "aspect_ratio": None},
                 batch_size=batch_size,
                 hidden_dtype=hidden_dtype,
             )
+            # Audio timesteps
             a_timestep = a_timestep.view(batch_size, -1, a_timestep.shape[-1])
-            a_embedded_timestep = a_embedded_timestep.view(
-                batch_size, -1, a_embedded_timestep.shape[-1]
-            )
-            cross_av_timestep_ss = [
-                av_ca_audio_scale_shift_timestep,
-                av_ca_video_scale_shift_timestep,
-                av_ca_a2v_gate_noise_timestep,
-                av_ca_v2a_gate_noise_timestep,
-            ]
-            cross_av_timestep_ss = list(
-                [t.view(batch_size, -1, t.shape[-1]) for t in cross_av_timestep_ss]
-            )
+            a_embedded_timestep = a_embedded_timestep.view(batch_size, -1, a_embedded_timestep.shape[-1])
         else:
-            a_timestep = timestep
+            a_timestep = timestep_scaled
             a_embedded_timestep = kwargs.get("embedded_timestep")
             cross_av_timestep_ss = []
@@ -767,6 +838,11 @@ class LTXAVModel(LTXVModel):
         ax = x[1]
         v_embedded_timestep = embedded_timestep[0]
         a_embedded_timestep = embedded_timestep[1]
+
+        # Expand compressed video timestep if needed
+        if isinstance(v_embedded_timestep, CompressedTimestep):
+            v_embedded_timestep = v_embedded_timestep.expand()
+
         vx = super()._process_output(vx, v_embedded_timestep, keyframe_idxs, **kwargs)
 
         # Process audio output
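
Note: the core idea of the patch is that every spatial patch inside a video frame shares the same timestep embedding, so the [batch, frames * height * width, dim] embedding tensor only needs one row per frame and can be re-broadcast on demand. The sketch below illustrates that storage/expansion round trip in isolation. The class name PerFrameTimestepCache and the toy shapes are illustrative stand-ins, not the CompressedTimestep implementation from comfy/ldm/lightricks/av_model.py.

import torch


class PerFrameTimestepCache:
    """Minimal sketch: keep one timestep embedding per frame instead of one per token."""

    def __init__(self, per_token: torch.Tensor, patches_per_frame: int):
        # per_token: [batch, num_tokens, dim] with num_tokens = frames * patches_per_frame.
        batch, num_tokens, dim = per_token.shape
        if patches_per_frame and num_tokens % patches_per_frame == 0:
            frames = num_tokens // patches_per_frame
            # Every patch in a frame carries the same embedding, so keep only the first patch of each frame.
            self.data = per_token.reshape(batch, frames, patches_per_frame, dim)[:, :, 0, :].contiguous()
            self.patches_per_frame = patches_per_frame
        else:
            # Fallback: token count does not tile into whole frames, store uncompressed.
            self.data = per_token
            self.patches_per_frame = 1

    def expand(self) -> torch.Tensor:
        # Re-broadcast [batch, frames, dim] back to [batch, frames * patches_per_frame, dim].
        if self.patches_per_frame == 1:
            return self.data
        batch, frames, dim = self.data.shape
        return (
            self.data.unsqueeze(2)
            .expand(batch, frames, self.patches_per_frame, dim)
            .reshape(batch, frames * self.patches_per_frame, dim)
        )


if __name__ == "__main__":
    # Toy shapes: 2 latent frames of 16 spatial patches each, 64-dim embedding.
    batch, frames, patches, dim = 1, 2, 16, 64
    per_frame = torch.randn(batch, frames, 1, dim).expand(batch, frames, patches, dim)
    per_token = per_frame.reshape(batch, frames * patches, dim)

    cache = PerFrameTimestepCache(per_token, patches_per_frame=patches)
    assert torch.equal(cache.expand(), per_token)
    print("stored", tuple(cache.data.shape), "vs expanded", tuple(cache.expand().shape))

Stored this way, the per-token tensor shrinks by a factor of patches_per_frame (the height * width of the latent grid), which is where the VRAM saving comes from; the patch additionally computes the per-block scale/shift/gate values on the compressed per-frame tensor and only broadcasts the result back to token resolution.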