mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-24 21:30:15 +08:00
LTX2: Refactor forward function for better VRAM efficiency and fix spatial inpainting (#12046)
* Disable timestep embed compression when inpainting Spatial inpainting not compatible with the compression * Reduce crossattn peak VRAM * LTX2: Refactor forward function for better VRAM efficiency
This commit is contained in:
parent
79cdbc81cb
commit
55bd606e92
@ -18,12 +18,12 @@ class CompressedTimestep:
|
|||||||
def __init__(self, tensor: torch.Tensor, patches_per_frame: int):
|
def __init__(self, tensor: torch.Tensor, patches_per_frame: int):
|
||||||
"""
|
"""
|
||||||
tensor: [batch_size, num_tokens, feature_dim] tensor where num_tokens = num_frames * patches_per_frame
|
tensor: [batch_size, num_tokens, feature_dim] tensor where num_tokens = num_frames * patches_per_frame
|
||||||
patches_per_frame: Number of spatial patches per frame (height * width in latent space)
|
patches_per_frame: Number of spatial patches per frame (height * width in latent space), or None to disable compression
|
||||||
"""
|
"""
|
||||||
self.batch_size, num_tokens, self.feature_dim = tensor.shape
|
self.batch_size, num_tokens, self.feature_dim = tensor.shape
|
||||||
|
|
||||||
# Check if compression is valid (num_tokens must be divisible by patches_per_frame)
|
# Check if compression is valid (num_tokens must be divisible by patches_per_frame)
|
||||||
if num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame:
|
if patches_per_frame is not None and num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame:
|
||||||
self.patches_per_frame = patches_per_frame
|
self.patches_per_frame = patches_per_frame
|
||||||
self.num_frames = num_tokens // patches_per_frame
|
self.num_frames = num_tokens // patches_per_frame
|
||||||
|
|
||||||
@ -215,22 +215,9 @@ class BasicAVTransformerBlock(nn.Module):
|
|||||||
return (*scale_shift_ada_values, *gate_ada_values)
|
return (*scale_shift_ada_values, *gate_ada_values)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self, x: Tuple[torch.Tensor, torch.Tensor], v_context=None, a_context=None, attention_mask=None, v_timestep=None, a_timestep=None,
|
||||||
x: Tuple[torch.Tensor, torch.Tensor],
|
v_pe=None, a_pe=None, v_cross_pe=None, a_cross_pe=None, v_cross_scale_shift_timestep=None, a_cross_scale_shift_timestep=None,
|
||||||
v_context=None,
|
v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None,
|
||||||
a_context=None,
|
|
||||||
attention_mask=None,
|
|
||||||
v_timestep=None,
|
|
||||||
a_timestep=None,
|
|
||||||
v_pe=None,
|
|
||||||
a_pe=None,
|
|
||||||
v_cross_pe=None,
|
|
||||||
a_cross_pe=None,
|
|
||||||
v_cross_scale_shift_timestep=None,
|
|
||||||
a_cross_scale_shift_timestep=None,
|
|
||||||
v_cross_gate_timestep=None,
|
|
||||||
a_cross_gate_timestep=None,
|
|
||||||
transformer_options=None,
|
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
run_vx = transformer_options.get("run_vx", True)
|
run_vx = transformer_options.get("run_vx", True)
|
||||||
run_ax = transformer_options.get("run_ax", True)
|
run_ax = transformer_options.get("run_ax", True)
|
||||||
@ -240,144 +227,102 @@ class BasicAVTransformerBlock(nn.Module):
|
|||||||
run_a2v = run_vx and transformer_options.get("a2v_cross_attn", True) and ax.numel() > 0
|
run_a2v = run_vx and transformer_options.get("a2v_cross_attn", True) and ax.numel() > 0
|
||||||
run_v2a = run_ax and transformer_options.get("v2a_cross_attn", True)
|
run_v2a = run_ax and transformer_options.get("v2a_cross_attn", True)
|
||||||
|
|
||||||
|
# video
|
||||||
if run_vx:
|
if run_vx:
|
||||||
vshift_msa, vscale_msa, vgate_msa = (
|
# video self-attention
|
||||||
self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 3))
|
vshift_msa, vscale_msa = (self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 2)))
|
||||||
)
|
|
||||||
|
|
||||||
norm_vx = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_msa) + vshift_msa
|
norm_vx = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_msa) + vshift_msa
|
||||||
vx += self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options) * vgate_msa
|
del vshift_msa, vscale_msa
|
||||||
vx += self.attn2(
|
attn1_out = self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options)
|
||||||
comfy.ldm.common_dit.rms_norm(vx),
|
del norm_vx
|
||||||
context=v_context,
|
# video cross-attention
|
||||||
mask=attention_mask,
|
vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0]
|
||||||
transformer_options=transformer_options,
|
vx.addcmul_(attn1_out, vgate_msa)
|
||||||
)
|
del vgate_msa, attn1_out
|
||||||
|
vx.add_(self.attn2(comfy.ldm.common_dit.rms_norm(vx), context=v_context, mask=attention_mask, transformer_options=transformer_options))
|
||||||
del vshift_msa, vscale_msa, vgate_msa
|
|
||||||
|
|
||||||
|
# audio
|
||||||
if run_ax:
|
if run_ax:
|
||||||
ashift_msa, ascale_msa, agate_msa = (
|
# audio self-attention
|
||||||
self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(0, 3))
|
ashift_msa, ascale_msa = (self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(0, 2)))
|
||||||
)
|
|
||||||
|
|
||||||
norm_ax = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_msa) + ashift_msa
|
norm_ax = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_msa) + ashift_msa
|
||||||
ax += (
|
del ashift_msa, ascale_msa
|
||||||
self.audio_attn1(norm_ax, pe=a_pe, transformer_options=transformer_options)
|
attn1_out = self.audio_attn1(norm_ax, pe=a_pe, transformer_options=transformer_options)
|
||||||
* agate_msa
|
del norm_ax
|
||||||
)
|
# audio cross-attention
|
||||||
ax += self.audio_attn2(
|
agate_msa = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(2, 3))[0]
|
||||||
comfy.ldm.common_dit.rms_norm(ax),
|
ax.addcmul_(attn1_out, agate_msa)
|
||||||
context=a_context,
|
del agate_msa, attn1_out
|
||||||
mask=attention_mask,
|
ax.add_(self.audio_attn2(comfy.ldm.common_dit.rms_norm(ax), context=a_context, mask=attention_mask, transformer_options=transformer_options))
|
||||||
transformer_options=transformer_options,
|
|
||||||
)
|
|
||||||
|
|
||||||
del ashift_msa, ascale_msa, agate_msa
|
# video - audio cross attention.
|
||||||
|
|
||||||
# Audio - Video cross attention.
|
|
||||||
if run_a2v or run_v2a:
|
if run_a2v or run_v2a:
|
||||||
# norm3
|
|
||||||
vx_norm3 = comfy.ldm.common_dit.rms_norm(vx)
|
vx_norm3 = comfy.ldm.common_dit.rms_norm(vx)
|
||||||
ax_norm3 = comfy.ldm.common_dit.rms_norm(ax)
|
ax_norm3 = comfy.ldm.common_dit.rms_norm(ax)
|
||||||
|
|
||||||
(
|
# audio to video cross attention
|
||||||
scale_ca_audio_hidden_states_a2v,
|
|
||||||
shift_ca_audio_hidden_states_a2v,
|
|
||||||
scale_ca_audio_hidden_states_v2a,
|
|
||||||
shift_ca_audio_hidden_states_v2a,
|
|
||||||
gate_out_v2a,
|
|
||||||
) = self.get_av_ca_ada_values(
|
|
||||||
self.scale_shift_table_a2v_ca_audio,
|
|
||||||
ax.shape[0],
|
|
||||||
a_cross_scale_shift_timestep,
|
|
||||||
a_cross_gate_timestep,
|
|
||||||
)
|
|
||||||
|
|
||||||
(
|
|
||||||
scale_ca_video_hidden_states_a2v,
|
|
||||||
shift_ca_video_hidden_states_a2v,
|
|
||||||
scale_ca_video_hidden_states_v2a,
|
|
||||||
shift_ca_video_hidden_states_v2a,
|
|
||||||
gate_out_a2v,
|
|
||||||
) = self.get_av_ca_ada_values(
|
|
||||||
self.scale_shift_table_a2v_ca_video,
|
|
||||||
vx.shape[0],
|
|
||||||
v_cross_scale_shift_timestep,
|
|
||||||
v_cross_gate_timestep,
|
|
||||||
)
|
|
||||||
|
|
||||||
if run_a2v:
|
if run_a2v:
|
||||||
vx_scaled = (
|
scale_ca_audio_hidden_states_a2v, shift_ca_audio_hidden_states_a2v = self.get_ada_values(
|
||||||
vx_norm3 * (1 + scale_ca_video_hidden_states_a2v)
|
self.scale_shift_table_a2v_ca_audio[:4, :], ax.shape[0], a_cross_scale_shift_timestep)[:2]
|
||||||
+ shift_ca_video_hidden_states_a2v
|
scale_ca_video_hidden_states_a2v_v, shift_ca_video_hidden_states_a2v_v = self.get_ada_values(
|
||||||
)
|
self.scale_shift_table_a2v_ca_video[:4, :], vx.shape[0], v_cross_scale_shift_timestep)[:2]
|
||||||
ax_scaled = (
|
|
||||||
ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v)
|
|
||||||
+ shift_ca_audio_hidden_states_a2v
|
|
||||||
)
|
|
||||||
vx += (
|
|
||||||
self.audio_to_video_attn(
|
|
||||||
vx_scaled,
|
|
||||||
context=ax_scaled,
|
|
||||||
pe=v_cross_pe,
|
|
||||||
k_pe=a_cross_pe,
|
|
||||||
transformer_options=transformer_options,
|
|
||||||
)
|
|
||||||
* gate_out_a2v
|
|
||||||
)
|
|
||||||
|
|
||||||
del gate_out_a2v
|
vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_a2v_v) + shift_ca_video_hidden_states_a2v_v
|
||||||
del scale_ca_video_hidden_states_a2v,\
|
ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v) + shift_ca_audio_hidden_states_a2v
|
||||||
shift_ca_video_hidden_states_a2v,\
|
del scale_ca_video_hidden_states_a2v_v, shift_ca_video_hidden_states_a2v_v, scale_ca_audio_hidden_states_a2v, shift_ca_audio_hidden_states_a2v
|
||||||
scale_ca_audio_hidden_states_a2v,\
|
|
||||||
shift_ca_audio_hidden_states_a2v,\
|
|
||||||
|
|
||||||
|
a2v_out = self.audio_to_video_attn(vx_scaled, context=ax_scaled, pe=v_cross_pe, k_pe=a_cross_pe, transformer_options=transformer_options)
|
||||||
|
del vx_scaled, ax_scaled
|
||||||
|
|
||||||
|
gate_out_a2v = self.get_ada_values(self.scale_shift_table_a2v_ca_video[4:, :], vx.shape[0], v_cross_gate_timestep)[0]
|
||||||
|
vx.addcmul_(a2v_out, gate_out_a2v)
|
||||||
|
del gate_out_a2v, a2v_out
|
||||||
|
|
||||||
|
# video to audio cross attention
|
||||||
if run_v2a:
|
if run_v2a:
|
||||||
ax_scaled = (
|
scale_ca_audio_hidden_states_v2a, shift_ca_audio_hidden_states_v2a = self.get_ada_values(
|
||||||
ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a)
|
self.scale_shift_table_a2v_ca_audio[:4, :], ax.shape[0], a_cross_scale_shift_timestep)[2:4]
|
||||||
+ shift_ca_audio_hidden_states_v2a
|
scale_ca_video_hidden_states_v2a, shift_ca_video_hidden_states_v2a = self.get_ada_values(
|
||||||
)
|
self.scale_shift_table_a2v_ca_video[:4, :], vx.shape[0], v_cross_scale_shift_timestep)[2:4]
|
||||||
vx_scaled = (
|
|
||||||
vx_norm3 * (1 + scale_ca_video_hidden_states_v2a)
|
|
||||||
+ shift_ca_video_hidden_states_v2a
|
|
||||||
)
|
|
||||||
ax += (
|
|
||||||
self.video_to_audio_attn(
|
|
||||||
ax_scaled,
|
|
||||||
context=vx_scaled,
|
|
||||||
pe=a_cross_pe,
|
|
||||||
k_pe=v_cross_pe,
|
|
||||||
transformer_options=transformer_options,
|
|
||||||
)
|
|
||||||
* gate_out_v2a
|
|
||||||
)
|
|
||||||
|
|
||||||
del gate_out_v2a
|
ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a) + shift_ca_audio_hidden_states_v2a
|
||||||
del scale_ca_video_hidden_states_v2a,\
|
vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_v2a) + shift_ca_video_hidden_states_v2a
|
||||||
shift_ca_video_hidden_states_v2a,\
|
del scale_ca_video_hidden_states_v2a, shift_ca_video_hidden_states_v2a, scale_ca_audio_hidden_states_v2a, shift_ca_audio_hidden_states_v2a
|
||||||
scale_ca_audio_hidden_states_v2a,\
|
|
||||||
shift_ca_audio_hidden_states_v2a
|
|
||||||
|
|
||||||
|
v2a_out = self.video_to_audio_attn(ax_scaled, context=vx_scaled, pe=a_cross_pe, k_pe=v_cross_pe, transformer_options=transformer_options)
|
||||||
|
del ax_scaled, vx_scaled
|
||||||
|
|
||||||
|
gate_out_v2a = self.get_ada_values(self.scale_shift_table_a2v_ca_audio[4:, :], ax.shape[0], a_cross_gate_timestep)[0]
|
||||||
|
ax.addcmul_(v2a_out, gate_out_v2a)
|
||||||
|
del gate_out_v2a, v2a_out
|
||||||
|
|
||||||
|
del vx_norm3, ax_norm3
|
||||||
|
|
||||||
|
# video feedforward
|
||||||
if run_vx:
|
if run_vx:
|
||||||
vshift_mlp, vscale_mlp, vgate_mlp = (
|
vshift_mlp, vscale_mlp = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(3, 5))
|
||||||
self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(3, None))
|
|
||||||
)
|
|
||||||
|
|
||||||
vx_scaled = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_mlp) + vshift_mlp
|
vx_scaled = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_mlp) + vshift_mlp
|
||||||
vx += self.ff(vx_scaled) * vgate_mlp
|
del vshift_mlp, vscale_mlp
|
||||||
del vshift_mlp, vscale_mlp, vgate_mlp
|
|
||||||
|
|
||||||
|
ff_out = self.ff(vx_scaled)
|
||||||
|
del vx_scaled
|
||||||
|
|
||||||
|
vgate_mlp = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(5, 6))[0]
|
||||||
|
vx.addcmul_(ff_out, vgate_mlp)
|
||||||
|
del vgate_mlp, ff_out
|
||||||
|
|
||||||
|
# audio feedforward
|
||||||
if run_ax:
|
if run_ax:
|
||||||
ashift_mlp, ascale_mlp, agate_mlp = (
|
ashift_mlp, ascale_mlp = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(3, 5))
|
||||||
self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(3, None))
|
|
||||||
)
|
|
||||||
|
|
||||||
ax_scaled = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_mlp) + ashift_mlp
|
ax_scaled = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_mlp) + ashift_mlp
|
||||||
ax += self.audio_ff(ax_scaled) * agate_mlp
|
del ashift_mlp, ascale_mlp
|
||||||
|
|
||||||
del ashift_mlp, ascale_mlp, agate_mlp
|
ff_out = self.audio_ff(ax_scaled)
|
||||||
|
del ax_scaled
|
||||||
|
|
||||||
|
agate_mlp = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(5, 6))[0]
|
||||||
|
ax.addcmul_(ff_out, agate_mlp)
|
||||||
|
del agate_mlp, ff_out
|
||||||
|
|
||||||
return vx, ax
|
return vx, ax
|
||||||
|
|
||||||
@ -589,9 +534,20 @@ class LTXAVModel(LTXVModel):
|
|||||||
audio_length = kwargs.get("audio_length", 0)
|
audio_length = kwargs.get("audio_length", 0)
|
||||||
# Separate audio and video latents
|
# Separate audio and video latents
|
||||||
vx, ax = self.separate_audio_and_video_latents(x, audio_length)
|
vx, ax = self.separate_audio_and_video_latents(x, audio_length)
|
||||||
|
|
||||||
|
has_spatial_mask = False
|
||||||
|
if denoise_mask is not None:
|
||||||
|
# check if any frame has spatial variation (inpainting)
|
||||||
|
for frame_idx in range(denoise_mask.shape[2]):
|
||||||
|
frame_mask = denoise_mask[0, 0, frame_idx]
|
||||||
|
if frame_mask.numel() > 0 and frame_mask.min() != frame_mask.max():
|
||||||
|
has_spatial_mask = True
|
||||||
|
break
|
||||||
|
|
||||||
[vx, v_pixel_coords, additional_args] = super()._process_input(
|
[vx, v_pixel_coords, additional_args] = super()._process_input(
|
||||||
vx, keyframe_idxs, denoise_mask, **kwargs
|
vx, keyframe_idxs, denoise_mask, **kwargs
|
||||||
)
|
)
|
||||||
|
additional_args["has_spatial_mask"] = has_spatial_mask
|
||||||
|
|
||||||
ax, a_latent_coords = self.a_patchifier.patchify(ax)
|
ax, a_latent_coords = self.a_patchifier.patchify(ax)
|
||||||
ax = self.audio_patchify_proj(ax)
|
ax = self.audio_patchify_proj(ax)
|
||||||
@ -618,8 +574,9 @@ class LTXAVModel(LTXVModel):
|
|||||||
# Calculate patches_per_frame from orig_shape: [batch, channels, frames, height, width]
|
# Calculate patches_per_frame from orig_shape: [batch, channels, frames, height, width]
|
||||||
# Video tokens are arranged as (frames * height * width), so patches_per_frame = height * width
|
# Video tokens are arranged as (frames * height * width), so patches_per_frame = height * width
|
||||||
orig_shape = kwargs.get("orig_shape")
|
orig_shape = kwargs.get("orig_shape")
|
||||||
|
has_spatial_mask = kwargs.get("has_spatial_mask", None)
|
||||||
v_patches_per_frame = None
|
v_patches_per_frame = None
|
||||||
if orig_shape is not None and len(orig_shape) == 5:
|
if not has_spatial_mask and orig_shape is not None and len(orig_shape) == 5:
|
||||||
# orig_shape[3] = height, orig_shape[4] = width (in latent space)
|
# orig_shape[3] = height, orig_shape[4] = width (in latent space)
|
||||||
v_patches_per_frame = orig_shape[3] * orig_shape[4]
|
v_patches_per_frame = orig_shape[3] * orig_shape[4]
|
||||||
|
|
||||||
@ -662,10 +619,11 @@ class LTXAVModel(LTXVModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Compress cross-attention timesteps (only video side, audio is too small to benefit)
|
# Compress cross-attention timesteps (only video side, audio is too small to benefit)
|
||||||
|
# v_patches_per_frame is None for spatial masks, set for temporal masks or no mask
|
||||||
cross_av_timestep_ss = [
|
cross_av_timestep_ss = [
|
||||||
av_ca_audio_scale_shift_timestep.view(batch_size, -1, av_ca_audio_scale_shift_timestep.shape[-1]),
|
av_ca_audio_scale_shift_timestep.view(batch_size, -1, av_ca_audio_scale_shift_timestep.shape[-1]),
|
||||||
CompressedTimestep(av_ca_video_scale_shift_timestep.view(batch_size, -1, av_ca_video_scale_shift_timestep.shape[-1]), v_patches_per_frame), # video - compressed
|
CompressedTimestep(av_ca_video_scale_shift_timestep.view(batch_size, -1, av_ca_video_scale_shift_timestep.shape[-1]), v_patches_per_frame), # video - compressed if possible
|
||||||
CompressedTimestep(av_ca_a2v_gate_noise_timestep.view(batch_size, -1, av_ca_a2v_gate_noise_timestep.shape[-1]), v_patches_per_frame), # video - compressed
|
CompressedTimestep(av_ca_a2v_gate_noise_timestep.view(batch_size, -1, av_ca_a2v_gate_noise_timestep.shape[-1]), v_patches_per_frame), # video - compressed if possible
|
||||||
av_ca_v2a_gate_noise_timestep.view(batch_size, -1, av_ca_v2a_gate_noise_timestep.shape[-1]),
|
av_ca_v2a_gate_noise_timestep.view(batch_size, -1, av_ca_v2a_gate_noise_timestep.shape[-1]),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user