LTX2: Refactor forward function for better VRAM efficiency and fix spatial inpainting (#12046)

* Disable timestep embed compression when inpainting

Spatial inpainting is not compatible with the compression, since a spatial denoise mask gives tokens within the same frame different timestep values, so they cannot be collapsed to one embedding per frame.

* Reduce crossattn peak VRAM

* LTX2: Refactor forward function for better VRAM efficiency
Jukka Seppänen 2026-01-23 22:26:38 +02:00 committed by GitHub
parent 79cdbc81cb
commit 55bd606e92


@@ -18,12 +18,12 @@ class CompressedTimestep:
     def __init__(self, tensor: torch.Tensor, patches_per_frame: int):
         """
         tensor: [batch_size, num_tokens, feature_dim] tensor where num_tokens = num_frames * patches_per_frame
-        patches_per_frame: Number of spatial patches per frame (height * width in latent space)
+        patches_per_frame: Number of spatial patches per frame (height * width in latent space), or None to disable compression
         """
         self.batch_size, num_tokens, self.feature_dim = tensor.shape
         # Check if compression is valid (num_tokens must be divisible by patches_per_frame)
-        if num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame:
+        if patches_per_frame is not None and num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame:
             self.patches_per_frame = patches_per_frame
             self.num_frames = num_tokens // patches_per_frame
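For context on what the guard above protects: when no spatial mask is present, every spatial patch in a frame shares the same timestep embedding, so a [batch, num_frames * patches_per_frame, dim] tensor only needs one row per frame. A minimal sketch of that idea, with illustrative helper names that are not part of the actual CompressedTimestep API:

    import torch

    def compress_per_frame(t, patches_per_frame):
        # t: [batch, num_tokens, dim] where all patches of a frame carry identical values.
        b, n, d = t.shape
        if patches_per_frame is None or n % patches_per_frame != 0 or n < patches_per_frame:
            return t  # same fallback as the guard above: keep the full tensor
        return t.view(b, n // patches_per_frame, patches_per_frame, d)[:, :, 0, :]

    def expand_per_frame(c, patches_per_frame):
        # Broadcast the per-frame rows back to one row per token when needed.
        b, f, d = c.shape
        return c.unsqueeze(2).expand(b, f, patches_per_frame, d).reshape(b, f * patches_per_frame, d)

With patches_per_frame=None (the new spatial-inpainting path) the sketch, like the real guard, simply keeps the uncompressed tensor.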
@@ -215,22 +215,9 @@ class BasicAVTransformerBlock(nn.Module):
         return (*scale_shift_ada_values, *gate_ada_values)
 
     def forward(
-        self,
-        x: Tuple[torch.Tensor, torch.Tensor],
-        v_context=None,
-        a_context=None,
-        attention_mask=None,
-        v_timestep=None,
-        a_timestep=None,
-        v_pe=None,
-        a_pe=None,
-        v_cross_pe=None,
-        a_cross_pe=None,
-        v_cross_scale_shift_timestep=None,
-        a_cross_scale_shift_timestep=None,
-        v_cross_gate_timestep=None,
-        a_cross_gate_timestep=None,
-        transformer_options=None,
+        self, x: Tuple[torch.Tensor, torch.Tensor], v_context=None, a_context=None, attention_mask=None, v_timestep=None, a_timestep=None,
+        v_pe=None, a_pe=None, v_cross_pe=None, a_cross_pe=None, v_cross_scale_shift_timestep=None, a_cross_scale_shift_timestep=None,
+        v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         run_vx = transformer_options.get("run_vx", True)
         run_ax = transformer_options.get("run_ax", True)
@@ -240,144 +227,102 @@ class BasicAVTransformerBlock(nn.Module):
         run_a2v = run_vx and transformer_options.get("a2v_cross_attn", True) and ax.numel() > 0
         run_v2a = run_ax and transformer_options.get("v2a_cross_attn", True)
 
-        # video
         if run_vx:
-            vshift_msa, vscale_msa, vgate_msa = (
-                self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 3))
-            )
+            # video self-attention
+            vshift_msa, vscale_msa = (self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 2)))
             norm_vx = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_msa) + vshift_msa
-            vx += self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options) * vgate_msa
-            vx += self.attn2(
-                comfy.ldm.common_dit.rms_norm(vx),
-                context=v_context,
-                mask=attention_mask,
-                transformer_options=transformer_options,
-            )
-            del vshift_msa, vscale_msa, vgate_msa
+            del vshift_msa, vscale_msa
+            attn1_out = self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options)
+            del norm_vx
+            # video cross-attention
+            vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0]
+            vx.addcmul_(attn1_out, vgate_msa)
+            del vgate_msa, attn1_out
+            vx.add_(self.attn2(comfy.ldm.common_dit.rms_norm(vx), context=v_context, mask=attention_mask, transformer_options=transformer_options))
 
-        # audio
         if run_ax:
-            ashift_msa, ascale_msa, agate_msa = (
-                self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(0, 3))
-            )
+            # audio self-attention
+            ashift_msa, ascale_msa = (self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(0, 2)))
             norm_ax = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_msa) + ashift_msa
-            ax += (
-                self.audio_attn1(norm_ax, pe=a_pe, transformer_options=transformer_options)
-                * agate_msa
-            )
-            ax += self.audio_attn2(
-                comfy.ldm.common_dit.rms_norm(ax),
-                context=a_context,
-                mask=attention_mask,
-                transformer_options=transformer_options,
-            )
-            del ashift_msa, ascale_msa, agate_msa
+            del ashift_msa, ascale_msa
+            attn1_out = self.audio_attn1(norm_ax, pe=a_pe, transformer_options=transformer_options)
+            del norm_ax
+            # audio cross-attention
+            agate_msa = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(2, 3))[0]
+            ax.addcmul_(attn1_out, agate_msa)
+            del agate_msa, attn1_out
+            ax.add_(self.audio_attn2(comfy.ldm.common_dit.rms_norm(ax), context=a_context, mask=attention_mask, transformer_options=transformer_options))
 
-        # Audio - Video cross attention.
+        # video - audio cross attention.
        if run_a2v or run_v2a:
-            # norm3
             vx_norm3 = comfy.ldm.common_dit.rms_norm(vx)
             ax_norm3 = comfy.ldm.common_dit.rms_norm(ax)
-            (
-                scale_ca_audio_hidden_states_a2v,
-                shift_ca_audio_hidden_states_a2v,
-                scale_ca_audio_hidden_states_v2a,
-                shift_ca_audio_hidden_states_v2a,
-                gate_out_v2a,
-            ) = self.get_av_ca_ada_values(
-                self.scale_shift_table_a2v_ca_audio,
-                ax.shape[0],
-                a_cross_scale_shift_timestep,
-                a_cross_gate_timestep,
-            )
-            (
-                scale_ca_video_hidden_states_a2v,
-                shift_ca_video_hidden_states_a2v,
-                scale_ca_video_hidden_states_v2a,
-                shift_ca_video_hidden_states_v2a,
-                gate_out_a2v,
-            ) = self.get_av_ca_ada_values(
-                self.scale_shift_table_a2v_ca_video,
-                vx.shape[0],
-                v_cross_scale_shift_timestep,
-                v_cross_gate_timestep,
-            )
 
+            # audio to video cross attention
             if run_a2v:
-                vx_scaled = (
-                    vx_norm3 * (1 + scale_ca_video_hidden_states_a2v)
-                    + shift_ca_video_hidden_states_a2v
-                )
-                ax_scaled = (
-                    ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v)
-                    + shift_ca_audio_hidden_states_a2v
-                )
-                vx += (
-                    self.audio_to_video_attn(
-                        vx_scaled,
-                        context=ax_scaled,
-                        pe=v_cross_pe,
-                        k_pe=a_cross_pe,
-                        transformer_options=transformer_options,
-                    )
-                    * gate_out_a2v
-                )
-                del gate_out_a2v
-                del scale_ca_video_hidden_states_a2v,\
-                    shift_ca_video_hidden_states_a2v,\
-                    scale_ca_audio_hidden_states_a2v,\
-                    shift_ca_audio_hidden_states_a2v
+                scale_ca_audio_hidden_states_a2v, shift_ca_audio_hidden_states_a2v = self.get_ada_values(
+                    self.scale_shift_table_a2v_ca_audio[:4, :], ax.shape[0], a_cross_scale_shift_timestep)[:2]
+                scale_ca_video_hidden_states_a2v_v, shift_ca_video_hidden_states_a2v_v = self.get_ada_values(
+                    self.scale_shift_table_a2v_ca_video[:4, :], vx.shape[0], v_cross_scale_shift_timestep)[:2]
+                vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_a2v_v) + shift_ca_video_hidden_states_a2v_v
+                ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v) + shift_ca_audio_hidden_states_a2v
+                del scale_ca_video_hidden_states_a2v_v, shift_ca_video_hidden_states_a2v_v, scale_ca_audio_hidden_states_a2v, shift_ca_audio_hidden_states_a2v
+                a2v_out = self.audio_to_video_attn(vx_scaled, context=ax_scaled, pe=v_cross_pe, k_pe=a_cross_pe, transformer_options=transformer_options)
+                del vx_scaled, ax_scaled
+                gate_out_a2v = self.get_ada_values(self.scale_shift_table_a2v_ca_video[4:, :], vx.shape[0], v_cross_gate_timestep)[0]
+                vx.addcmul_(a2v_out, gate_out_a2v)
+                del gate_out_a2v, a2v_out
 
+            # video to audio cross attention
             if run_v2a:
-                ax_scaled = (
-                    ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a)
-                    + shift_ca_audio_hidden_states_v2a
-                )
-                vx_scaled = (
-                    vx_norm3 * (1 + scale_ca_video_hidden_states_v2a)
-                    + shift_ca_video_hidden_states_v2a
-                )
-                ax += (
-                    self.video_to_audio_attn(
-                        ax_scaled,
-                        context=vx_scaled,
-                        pe=a_cross_pe,
-                        k_pe=v_cross_pe,
-                        transformer_options=transformer_options,
-                    )
-                    * gate_out_v2a
-                )
-                del gate_out_v2a
-                del scale_ca_video_hidden_states_v2a,\
-                    shift_ca_video_hidden_states_v2a,\
-                    scale_ca_audio_hidden_states_v2a,\
-                    shift_ca_audio_hidden_states_v2a
+                scale_ca_audio_hidden_states_v2a, shift_ca_audio_hidden_states_v2a = self.get_ada_values(
+                    self.scale_shift_table_a2v_ca_audio[:4, :], ax.shape[0], a_cross_scale_shift_timestep)[2:4]
+                scale_ca_video_hidden_states_v2a, shift_ca_video_hidden_states_v2a = self.get_ada_values(
+                    self.scale_shift_table_a2v_ca_video[:4, :], vx.shape[0], v_cross_scale_shift_timestep)[2:4]
+                ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a) + shift_ca_audio_hidden_states_v2a
+                vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_v2a) + shift_ca_video_hidden_states_v2a
+                del scale_ca_video_hidden_states_v2a, shift_ca_video_hidden_states_v2a, scale_ca_audio_hidden_states_v2a, shift_ca_audio_hidden_states_v2a
+                v2a_out = self.video_to_audio_attn(ax_scaled, context=vx_scaled, pe=a_cross_pe, k_pe=v_cross_pe, transformer_options=transformer_options)
+                del ax_scaled, vx_scaled
+                gate_out_v2a = self.get_ada_values(self.scale_shift_table_a2v_ca_audio[4:, :], ax.shape[0], a_cross_gate_timestep)[0]
+                ax.addcmul_(v2a_out, gate_out_v2a)
+                del gate_out_v2a, v2a_out
+            del vx_norm3, ax_norm3
 
+        # video feedforward
         if run_vx:
-            vshift_mlp, vscale_mlp, vgate_mlp = (
-                self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(3, None))
-            )
+            vshift_mlp, vscale_mlp = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(3, 5))
             vx_scaled = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_mlp) + vshift_mlp
-            vx += self.ff(vx_scaled) * vgate_mlp
-            del vshift_mlp, vscale_mlp, vgate_mlp
+            del vshift_mlp, vscale_mlp
+            ff_out = self.ff(vx_scaled)
+            del vx_scaled
+            vgate_mlp = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(5, 6))[0]
+            vx.addcmul_(ff_out, vgate_mlp)
+            del vgate_mlp, ff_out
 
+        # audio feedforward
         if run_ax:
-            ashift_mlp, ascale_mlp, agate_mlp = (
-                self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(3, None))
-            )
+            ashift_mlp, ascale_mlp = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(3, 5))
             ax_scaled = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_mlp) + ashift_mlp
-            ax += self.audio_ff(ax_scaled) * agate_mlp
-            del ashift_mlp, ascale_mlp, agate_mlp
+            del ashift_mlp, ascale_mlp
+            ff_out = self.audio_ff(ax_scaled)
+            del ax_scaled
+            agate_mlp = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(5, 6))[0]
+            ax.addcmul_(ff_out, agate_mlp)
+            del agate_mlp, ff_out
 
         return vx, ax
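The peak-VRAM reduction in this block comes from two patterns visible in the new code: each gate is fetched from the ada table only after the attention or feed-forward output exists, so the scale/shift tensors can be freed before the heavy op runs, and the residual update uses in-place addcmul_/add_ instead of x += out * gate, which would allocate a full-size temporary for the product. A rough illustration with made-up shapes:

    import torch

    x = torch.randn(1, 8192, 2048)   # hidden states (illustrative shape)
    attn_out = torch.randn_like(x)   # attention output
    gate = torch.randn(1, 1, 2048)   # gate, broadcast over tokens

    # Old pattern: materializes a temporary the size of attn_out for (attn_out * gate).
    # x += attn_out * gate

    # New pattern: fused in-place multiply-add, no extra full-size temporary.
    x.addcmul_(attn_out, gate)
    del attn_out, gate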
@@ -589,9 +534,20 @@ class LTXAVModel(LTXVModel):
         audio_length = kwargs.get("audio_length", 0)
 
         # Separate audio and video latents
         vx, ax = self.separate_audio_and_video_latents(x, audio_length)
 
+        has_spatial_mask = False
+        if denoise_mask is not None:
+            # check if any frame has spatial variation (inpainting)
+            for frame_idx in range(denoise_mask.shape[2]):
+                frame_mask = denoise_mask[0, 0, frame_idx]
+                if frame_mask.numel() > 0 and frame_mask.min() != frame_mask.max():
+                    has_spatial_mask = True
+                    break
+
         [vx, v_pixel_coords, additional_args] = super()._process_input(
             vx, keyframe_idxs, denoise_mask, **kwargs
         )
+        additional_args["has_spatial_mask"] = has_spatial_mask
 
         ax, a_latent_coords = self.a_patchifier.patchify(ax)
         ax = self.audio_patchify_proj(ax)
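The new has_spatial_mask check only has to tell per-frame-uniform masks (temporal masking, still compatible with the timestep compression) apart from masks that vary inside a frame (spatial inpainting). A self-contained version of the same test, assuming a denoise mask laid out as [batch, channels, frames, height, width]:

    import torch

    def mask_is_spatial(denoise_mask):
        # A frame whose mask values are not all equal varies spatially -> inpainting.
        for frame_idx in range(denoise_mask.shape[2]):
            frame_mask = denoise_mask[0, 0, frame_idx]
            if frame_mask.numel() > 0 and frame_mask.min() != frame_mask.max():
                return True
        return False

    temporal = torch.linspace(0, 1, 4).view(1, 1, 4, 1, 1).expand(1, 1, 4, 8, 8)
    spatial = temporal.clone()
    spatial[0, 0, 2, :4, :4] = 0.0   # zero out a region of one frame only
    print(mask_is_spatial(temporal), mask_is_spatial(spatial))  # False True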
@@ -618,8 +574,9 @@ class LTXAVModel(LTXVModel):
         # Calculate patches_per_frame from orig_shape: [batch, channels, frames, height, width]
         # Video tokens are arranged as (frames * height * width), so patches_per_frame = height * width
         orig_shape = kwargs.get("orig_shape")
+        has_spatial_mask = kwargs.get("has_spatial_mask", None)
         v_patches_per_frame = None
-        if orig_shape is not None and len(orig_shape) == 5:
+        if not has_spatial_mask and orig_shape is not None and len(orig_shape) == 5:
             # orig_shape[3] = height, orig_shape[4] = width (in latent space)
             v_patches_per_frame = orig_shape[3] * orig_shape[4]
@@ -662,10 +619,11 @@ class LTXAVModel(LTXVModel):
         )
 
         # Compress cross-attention timesteps (only video side, audio is too small to benefit)
+        # v_patches_per_frame is None for spatial masks, set for temporal masks or no mask
         cross_av_timestep_ss = [
             av_ca_audio_scale_shift_timestep.view(batch_size, -1, av_ca_audio_scale_shift_timestep.shape[-1]),
-            CompressedTimestep(av_ca_video_scale_shift_timestep.view(batch_size, -1, av_ca_video_scale_shift_timestep.shape[-1]), v_patches_per_frame),  # video - compressed
-            CompressedTimestep(av_ca_a2v_gate_noise_timestep.view(batch_size, -1, av_ca_a2v_gate_noise_timestep.shape[-1]), v_patches_per_frame),  # video - compressed
+            CompressedTimestep(av_ca_video_scale_shift_timestep.view(batch_size, -1, av_ca_video_scale_shift_timestep.shape[-1]), v_patches_per_frame),  # video - compressed if possible
+            CompressedTimestep(av_ca_a2v_gate_noise_timestep.view(batch_size, -1, av_ca_a2v_gate_noise_timestep.shape[-1]), v_patches_per_frame),  # video - compressed if possible
             av_ca_v2a_gate_noise_timestep.view(batch_size, -1, av_ca_v2a_gate_noise_timestep.shape[-1]),
         ]
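Tying the pieces together: with a temporal mask or no mask, v_patches_per_frame is height * width and the two video entries above are stored once per frame; with a spatial inpainting mask it stays None and the new guard in CompressedTimestep keeps the full per-token tensor. A hypothetical usage, shapes chosen only for illustration and assuming CompressedTimestep from this module is in scope:

    import torch

    batch_size, frames, height, width, dim = 1, 6, 32, 32, 128
    ts = torch.randn(batch_size, frames * height * width, dim)

    compressed = CompressedTimestep(ts, height * width)  # temporal mask / no mask: per-frame storage
    fallback = CompressedTimestep(ts, None)              # spatial inpainting: uncompressed fallback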