mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-24 21:30:15 +08:00
LTX2: Refactor forward function for better VRAM efficiency and fix spatial inpainting (#12046)
* Disable timestep embed compression when inpainting Spatial inpainting not compatible with the compression * Reduce crossattn peak VRAM * LTX2: Refactor forward function for better VRAM efficiency
This commit is contained in:
parent
79cdbc81cb
commit
55bd606e92
@ -18,12 +18,12 @@ class CompressedTimestep:
|
||||
def __init__(self, tensor: torch.Tensor, patches_per_frame: int):
|
||||
"""
|
||||
tensor: [batch_size, num_tokens, feature_dim] tensor where num_tokens = num_frames * patches_per_frame
|
||||
patches_per_frame: Number of spatial patches per frame (height * width in latent space)
|
||||
patches_per_frame: Number of spatial patches per frame (height * width in latent space), or None to disable compression
|
||||
"""
|
||||
self.batch_size, num_tokens, self.feature_dim = tensor.shape
|
||||
|
||||
# Check if compression is valid (num_tokens must be divisible by patches_per_frame)
|
||||
if num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame:
|
||||
if patches_per_frame is not None and num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame:
|
||||
self.patches_per_frame = patches_per_frame
|
||||
self.num_frames = num_tokens // patches_per_frame
|
||||
|
||||
@ -215,22 +215,9 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
return (*scale_shift_ada_values, *gate_ada_values)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: Tuple[torch.Tensor, torch.Tensor],
|
||||
v_context=None,
|
||||
a_context=None,
|
||||
attention_mask=None,
|
||||
v_timestep=None,
|
||||
a_timestep=None,
|
||||
v_pe=None,
|
||||
a_pe=None,
|
||||
v_cross_pe=None,
|
||||
a_cross_pe=None,
|
||||
v_cross_scale_shift_timestep=None,
|
||||
a_cross_scale_shift_timestep=None,
|
||||
v_cross_gate_timestep=None,
|
||||
a_cross_gate_timestep=None,
|
||||
transformer_options=None,
|
||||
self, x: Tuple[torch.Tensor, torch.Tensor], v_context=None, a_context=None, attention_mask=None, v_timestep=None, a_timestep=None,
|
||||
v_pe=None, a_pe=None, v_cross_pe=None, a_cross_pe=None, v_cross_scale_shift_timestep=None, a_cross_scale_shift_timestep=None,
|
||||
v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
run_vx = transformer_options.get("run_vx", True)
|
||||
run_ax = transformer_options.get("run_ax", True)
|
||||
@ -240,144 +227,102 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
run_a2v = run_vx and transformer_options.get("a2v_cross_attn", True) and ax.numel() > 0
|
||||
run_v2a = run_ax and transformer_options.get("v2a_cross_attn", True)
|
||||
|
||||
# video
|
||||
if run_vx:
|
||||
vshift_msa, vscale_msa, vgate_msa = (
|
||||
self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 3))
|
||||
)
|
||||
|
||||
# video self-attention
|
||||
vshift_msa, vscale_msa = (self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 2)))
|
||||
norm_vx = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_msa) + vshift_msa
|
||||
vx += self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options) * vgate_msa
|
||||
vx += self.attn2(
|
||||
comfy.ldm.common_dit.rms_norm(vx),
|
||||
context=v_context,
|
||||
mask=attention_mask,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
del vshift_msa, vscale_msa, vgate_msa
|
||||
del vshift_msa, vscale_msa
|
||||
attn1_out = self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options)
|
||||
del norm_vx
|
||||
# video cross-attention
|
||||
vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0]
|
||||
vx.addcmul_(attn1_out, vgate_msa)
|
||||
del vgate_msa, attn1_out
|
||||
vx.add_(self.attn2(comfy.ldm.common_dit.rms_norm(vx), context=v_context, mask=attention_mask, transformer_options=transformer_options))
|
||||
|
||||
# audio
|
||||
if run_ax:
|
||||
ashift_msa, ascale_msa, agate_msa = (
|
||||
self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(0, 3))
|
||||
)
|
||||
|
||||
# audio self-attention
|
||||
ashift_msa, ascale_msa = (self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(0, 2)))
|
||||
norm_ax = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_msa) + ashift_msa
|
||||
ax += (
|
||||
self.audio_attn1(norm_ax, pe=a_pe, transformer_options=transformer_options)
|
||||
* agate_msa
|
||||
)
|
||||
ax += self.audio_attn2(
|
||||
comfy.ldm.common_dit.rms_norm(ax),
|
||||
context=a_context,
|
||||
mask=attention_mask,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
del ashift_msa, ascale_msa
|
||||
attn1_out = self.audio_attn1(norm_ax, pe=a_pe, transformer_options=transformer_options)
|
||||
del norm_ax
|
||||
# audio cross-attention
|
||||
agate_msa = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(2, 3))[0]
|
||||
ax.addcmul_(attn1_out, agate_msa)
|
||||
del agate_msa, attn1_out
|
||||
ax.add_(self.audio_attn2(comfy.ldm.common_dit.rms_norm(ax), context=a_context, mask=attention_mask, transformer_options=transformer_options))
|
||||
|
||||
del ashift_msa, ascale_msa, agate_msa
|
||||
|
||||
# Audio - Video cross attention.
|
||||
# video - audio cross attention.
|
||||
if run_a2v or run_v2a:
|
||||
# norm3
|
||||
vx_norm3 = comfy.ldm.common_dit.rms_norm(vx)
|
||||
ax_norm3 = comfy.ldm.common_dit.rms_norm(ax)
|
||||
|
||||
(
|
||||
scale_ca_audio_hidden_states_a2v,
|
||||
shift_ca_audio_hidden_states_a2v,
|
||||
scale_ca_audio_hidden_states_v2a,
|
||||
shift_ca_audio_hidden_states_v2a,
|
||||
gate_out_v2a,
|
||||
) = self.get_av_ca_ada_values(
|
||||
self.scale_shift_table_a2v_ca_audio,
|
||||
ax.shape[0],
|
||||
a_cross_scale_shift_timestep,
|
||||
a_cross_gate_timestep,
|
||||
)
|
||||
|
||||
(
|
||||
scale_ca_video_hidden_states_a2v,
|
||||
shift_ca_video_hidden_states_a2v,
|
||||
scale_ca_video_hidden_states_v2a,
|
||||
shift_ca_video_hidden_states_v2a,
|
||||
gate_out_a2v,
|
||||
) = self.get_av_ca_ada_values(
|
||||
self.scale_shift_table_a2v_ca_video,
|
||||
vx.shape[0],
|
||||
v_cross_scale_shift_timestep,
|
||||
v_cross_gate_timestep,
|
||||
)
|
||||
|
||||
# audio to video cross attention
|
||||
if run_a2v:
|
||||
vx_scaled = (
|
||||
vx_norm3 * (1 + scale_ca_video_hidden_states_a2v)
|
||||
+ shift_ca_video_hidden_states_a2v
|
||||
)
|
||||
ax_scaled = (
|
||||
ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v)
|
||||
+ shift_ca_audio_hidden_states_a2v
|
||||
)
|
||||
vx += (
|
||||
self.audio_to_video_attn(
|
||||
vx_scaled,
|
||||
context=ax_scaled,
|
||||
pe=v_cross_pe,
|
||||
k_pe=a_cross_pe,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
* gate_out_a2v
|
||||
)
|
||||
scale_ca_audio_hidden_states_a2v, shift_ca_audio_hidden_states_a2v = self.get_ada_values(
|
||||
self.scale_shift_table_a2v_ca_audio[:4, :], ax.shape[0], a_cross_scale_shift_timestep)[:2]
|
||||
scale_ca_video_hidden_states_a2v_v, shift_ca_video_hidden_states_a2v_v = self.get_ada_values(
|
||||
self.scale_shift_table_a2v_ca_video[:4, :], vx.shape[0], v_cross_scale_shift_timestep)[:2]
|
||||
|
||||
del gate_out_a2v
|
||||
del scale_ca_video_hidden_states_a2v,\
|
||||
shift_ca_video_hidden_states_a2v,\
|
||||
scale_ca_audio_hidden_states_a2v,\
|
||||
shift_ca_audio_hidden_states_a2v,\
|
||||
vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_a2v_v) + shift_ca_video_hidden_states_a2v_v
|
||||
ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v) + shift_ca_audio_hidden_states_a2v
|
||||
del scale_ca_video_hidden_states_a2v_v, shift_ca_video_hidden_states_a2v_v, scale_ca_audio_hidden_states_a2v, shift_ca_audio_hidden_states_a2v
|
||||
|
||||
a2v_out = self.audio_to_video_attn(vx_scaled, context=ax_scaled, pe=v_cross_pe, k_pe=a_cross_pe, transformer_options=transformer_options)
|
||||
del vx_scaled, ax_scaled
|
||||
|
||||
gate_out_a2v = self.get_ada_values(self.scale_shift_table_a2v_ca_video[4:, :], vx.shape[0], v_cross_gate_timestep)[0]
|
||||
vx.addcmul_(a2v_out, gate_out_a2v)
|
||||
del gate_out_a2v, a2v_out
|
||||
|
||||
# video to audio cross attention
|
||||
if run_v2a:
|
||||
ax_scaled = (
|
||||
ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a)
|
||||
+ shift_ca_audio_hidden_states_v2a
|
||||
)
|
||||
vx_scaled = (
|
||||
vx_norm3 * (1 + scale_ca_video_hidden_states_v2a)
|
||||
+ shift_ca_video_hidden_states_v2a
|
||||
)
|
||||
ax += (
|
||||
self.video_to_audio_attn(
|
||||
ax_scaled,
|
||||
context=vx_scaled,
|
||||
pe=a_cross_pe,
|
||||
k_pe=v_cross_pe,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
* gate_out_v2a
|
||||
)
|
||||
scale_ca_audio_hidden_states_v2a, shift_ca_audio_hidden_states_v2a = self.get_ada_values(
|
||||
self.scale_shift_table_a2v_ca_audio[:4, :], ax.shape[0], a_cross_scale_shift_timestep)[2:4]
|
||||
scale_ca_video_hidden_states_v2a, shift_ca_video_hidden_states_v2a = self.get_ada_values(
|
||||
self.scale_shift_table_a2v_ca_video[:4, :], vx.shape[0], v_cross_scale_shift_timestep)[2:4]
|
||||
|
||||
del gate_out_v2a
|
||||
del scale_ca_video_hidden_states_v2a,\
|
||||
shift_ca_video_hidden_states_v2a,\
|
||||
scale_ca_audio_hidden_states_v2a,\
|
||||
shift_ca_audio_hidden_states_v2a
|
||||
ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a) + shift_ca_audio_hidden_states_v2a
|
||||
vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_v2a) + shift_ca_video_hidden_states_v2a
|
||||
del scale_ca_video_hidden_states_v2a, shift_ca_video_hidden_states_v2a, scale_ca_audio_hidden_states_v2a, shift_ca_audio_hidden_states_v2a
|
||||
|
||||
v2a_out = self.video_to_audio_attn(ax_scaled, context=vx_scaled, pe=a_cross_pe, k_pe=v_cross_pe, transformer_options=transformer_options)
|
||||
del ax_scaled, vx_scaled
|
||||
|
||||
gate_out_v2a = self.get_ada_values(self.scale_shift_table_a2v_ca_audio[4:, :], ax.shape[0], a_cross_gate_timestep)[0]
|
||||
ax.addcmul_(v2a_out, gate_out_v2a)
|
||||
del gate_out_v2a, v2a_out
|
||||
|
||||
del vx_norm3, ax_norm3
|
||||
|
||||
# video feedforward
|
||||
if run_vx:
|
||||
vshift_mlp, vscale_mlp, vgate_mlp = (
|
||||
self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(3, None))
|
||||
)
|
||||
|
||||
vshift_mlp, vscale_mlp = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(3, 5))
|
||||
vx_scaled = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_mlp) + vshift_mlp
|
||||
vx += self.ff(vx_scaled) * vgate_mlp
|
||||
del vshift_mlp, vscale_mlp, vgate_mlp
|
||||
del vshift_mlp, vscale_mlp
|
||||
|
||||
ff_out = self.ff(vx_scaled)
|
||||
del vx_scaled
|
||||
|
||||
vgate_mlp = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(5, 6))[0]
|
||||
vx.addcmul_(ff_out, vgate_mlp)
|
||||
del vgate_mlp, ff_out
|
||||
|
||||
# audio feedforward
|
||||
if run_ax:
|
||||
ashift_mlp, ascale_mlp, agate_mlp = (
|
||||
self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(3, None))
|
||||
)
|
||||
|
||||
ashift_mlp, ascale_mlp = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(3, 5))
|
||||
ax_scaled = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_mlp) + ashift_mlp
|
||||
ax += self.audio_ff(ax_scaled) * agate_mlp
|
||||
del ashift_mlp, ascale_mlp
|
||||
|
||||
del ashift_mlp, ascale_mlp, agate_mlp
|
||||
ff_out = self.audio_ff(ax_scaled)
|
||||
del ax_scaled
|
||||
|
||||
agate_mlp = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(5, 6))[0]
|
||||
ax.addcmul_(ff_out, agate_mlp)
|
||||
del agate_mlp, ff_out
|
||||
|
||||
return vx, ax
|
||||
|
||||
@ -589,9 +534,20 @@ class LTXAVModel(LTXVModel):
|
||||
audio_length = kwargs.get("audio_length", 0)
|
||||
# Separate audio and video latents
|
||||
vx, ax = self.separate_audio_and_video_latents(x, audio_length)
|
||||
|
||||
has_spatial_mask = False
|
||||
if denoise_mask is not None:
|
||||
# check if any frame has spatial variation (inpainting)
|
||||
for frame_idx in range(denoise_mask.shape[2]):
|
||||
frame_mask = denoise_mask[0, 0, frame_idx]
|
||||
if frame_mask.numel() > 0 and frame_mask.min() != frame_mask.max():
|
||||
has_spatial_mask = True
|
||||
break
|
||||
|
||||
[vx, v_pixel_coords, additional_args] = super()._process_input(
|
||||
vx, keyframe_idxs, denoise_mask, **kwargs
|
||||
)
|
||||
additional_args["has_spatial_mask"] = has_spatial_mask
|
||||
|
||||
ax, a_latent_coords = self.a_patchifier.patchify(ax)
|
||||
ax = self.audio_patchify_proj(ax)
|
||||
@ -618,8 +574,9 @@ class LTXAVModel(LTXVModel):
|
||||
# Calculate patches_per_frame from orig_shape: [batch, channels, frames, height, width]
|
||||
# Video tokens are arranged as (frames * height * width), so patches_per_frame = height * width
|
||||
orig_shape = kwargs.get("orig_shape")
|
||||
has_spatial_mask = kwargs.get("has_spatial_mask", None)
|
||||
v_patches_per_frame = None
|
||||
if orig_shape is not None and len(orig_shape) == 5:
|
||||
if not has_spatial_mask and orig_shape is not None and len(orig_shape) == 5:
|
||||
# orig_shape[3] = height, orig_shape[4] = width (in latent space)
|
||||
v_patches_per_frame = orig_shape[3] * orig_shape[4]
|
||||
|
||||
@ -662,10 +619,11 @@ class LTXAVModel(LTXVModel):
|
||||
)
|
||||
|
||||
# Compress cross-attention timesteps (only video side, audio is too small to benefit)
|
||||
# v_patches_per_frame is None for spatial masks, set for temporal masks or no mask
|
||||
cross_av_timestep_ss = [
|
||||
av_ca_audio_scale_shift_timestep.view(batch_size, -1, av_ca_audio_scale_shift_timestep.shape[-1]),
|
||||
CompressedTimestep(av_ca_video_scale_shift_timestep.view(batch_size, -1, av_ca_video_scale_shift_timestep.shape[-1]), v_patches_per_frame), # video - compressed
|
||||
CompressedTimestep(av_ca_a2v_gate_noise_timestep.view(batch_size, -1, av_ca_a2v_gate_noise_timestep.shape[-1]), v_patches_per_frame), # video - compressed
|
||||
CompressedTimestep(av_ca_video_scale_shift_timestep.view(batch_size, -1, av_ca_video_scale_shift_timestep.shape[-1]), v_patches_per_frame), # video - compressed if possible
|
||||
CompressedTimestep(av_ca_a2v_gate_noise_timestep.view(batch_size, -1, av_ca_a2v_gate_noise_timestep.shape[-1]), v_patches_per_frame), # video - compressed if possible
|
||||
av_ca_v2a_gate_noise_timestep.view(batch_size, -1, av_ca_v2a_gate_noise_timestep.shape[-1]),
|
||||
]
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user