Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-02-16 00:12:33 +08:00)
Reduce LTX2 VRAM use by more efficient timestep embed handling (#11829)
commit fd5c0755af
parent c881a1d689
@@ -11,6 +11,69 @@ from comfy.ldm.lightricks.model import (
 from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
 import comfy.ldm.common_dit
 
+
+class CompressedTimestep:
+    """Store video timestep embeddings in compressed form using per-frame indexing."""
+    __slots__ = ('data', 'batch_size', 'num_frames', 'patches_per_frame', 'feature_dim')
+
+    def __init__(self, tensor: torch.Tensor, patches_per_frame: int):
+        """
+        tensor: [batch_size, num_tokens, feature_dim] tensor where num_tokens = num_frames * patches_per_frame
+        patches_per_frame: Number of spatial patches per frame (height * width in latent space)
+        """
+        self.batch_size, num_tokens, self.feature_dim = tensor.shape
+
+        # Check if compression is valid (num_tokens must be divisible by patches_per_frame;
+        # a falsy value such as None falls back to uncompressed storage)
+        if patches_per_frame and num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame:
+            self.patches_per_frame = patches_per_frame
+            self.num_frames = num_tokens // patches_per_frame
+
+            # Reshape to [batch, frames, patches_per_frame, feature_dim] and store one value per frame
+            # All patches in a frame are identical, so we only keep the first one
+            reshaped = tensor.view(self.batch_size, self.num_frames, patches_per_frame, self.feature_dim)
+            self.data = reshaped[:, :, 0, :].contiguous()  # [batch, frames, feature_dim]
+        else:
+            # Not divisible or too small - store directly without compression
+            self.patches_per_frame = 1
+            self.num_frames = num_tokens
+            self.data = tensor
+
+    def expand(self):
+        """Expand back to original tensor."""
+        if self.patches_per_frame == 1:
+            return self.data
+
+        # [batch, frames, feature_dim] -> [batch, frames, patches_per_frame, feature_dim] -> [batch, tokens, feature_dim]
+        expanded = self.data.unsqueeze(2).expand(self.batch_size, self.num_frames, self.patches_per_frame, self.feature_dim)
+        return expanded.reshape(self.batch_size, -1, self.feature_dim)
+
+    def expand_for_computation(self, scale_shift_table: torch.Tensor, batch_size: int, indices: slice = slice(None, None)):
+        """Compute ada values on compressed per-frame data, then expand spatially."""
+        num_ada_params = scale_shift_table.shape[0]
+
+        # No compression - compute directly
+        if self.patches_per_frame == 1:
+            num_tokens = self.data.shape[1]
+            dim_per_param = self.feature_dim // num_ada_params
+            reshaped = self.data.reshape(batch_size, num_tokens, num_ada_params, dim_per_param)[:, :, indices, :]
+            table_values = scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(device=self.data.device, dtype=self.data.dtype)
+            ada_values = (table_values + reshaped).unbind(dim=2)
+            return ada_values
+
+        # Compressed: compute on per-frame data then expand spatially
+        # Reshape: [batch, frames, feature_dim] -> [batch, frames, num_ada_params, dim_per_param]
+        frame_reshaped = self.data.reshape(batch_size, self.num_frames, num_ada_params, -1)[:, :, indices, :]
+        table_values = scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(
+            device=self.data.device, dtype=self.data.dtype
+        )
+        frame_ada = (table_values + frame_reshaped).unbind(dim=2)
+
+        # Expand each ada parameter spatially: [batch, frames, dim] -> [batch, frames, patches, dim] -> [batch, tokens, dim]
+        return tuple(
+            frame_val.unsqueeze(2).expand(batch_size, self.num_frames, self.patches_per_frame, -1)
+            .reshape(batch_size, -1, frame_val.shape[-1])
+            for frame_val in frame_ada
+        )
+
+
 class BasicAVTransformerBlock(nn.Module):
     def __init__(
         self,
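The class above is the core of the change: LTX2's per-token timestep embeddings repeat the same row for every spatial patch within a frame, so storing one row per frame is lossless. A minimal self-contained sketch of that round trip (shapes are illustrative assumptions, not taken from the commit):

```python
import torch

# Illustrative shapes (assumptions): 8 latent frames, a 32x48 latent grid,
# embedding dim 4096 - i.e. 1536 spatial patches per frame.
B, F, HW, D = 2, 8, 32 * 48, 4096

# Per-token embedding the model previously kept around: within a frame,
# every spatial patch carries an identical row.
per_frame = torch.randn(B, F, D)
full = per_frame.unsqueeze(2).expand(B, F, HW, D).reshape(B, F * HW, D)

# "Compression" as in CompressedTimestep.__init__: keep one row per frame.
compressed = full.view(B, F, HW, D)[:, :, 0, :].contiguous()  # [B, F, D]

# Expansion as in CompressedTimestep.expand(): reproduces the tensor exactly.
expanded = compressed.unsqueeze(2).expand(B, F, HW, D).reshape(B, F * HW, D)
assert torch.equal(full, expanded)

print(f"stored elements shrink by {full.numel() // compressed.numel()}x")  # 1536x
```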
@@ -119,6 +182,9 @@ class BasicAVTransformerBlock(nn.Module):
     def get_ada_values(
         self, scale_shift_table: torch.Tensor, batch_size: int, timestep: torch.Tensor, indices: slice = slice(None, None)
     ):
+        if isinstance(timestep, CompressedTimestep):
+            return timestep.expand_for_computation(scale_shift_table, batch_size, indices)
+
         num_ada_params = scale_shift_table.shape[0]
 
         ada_values = (
@@ -146,10 +212,7 @@ class BasicAVTransformerBlock(nn.Module):
             gate_timestep,
         )
 
-        scale_shift_chunks = [t.squeeze(2) for t in scale_shift_ada_values]
-        gate_ada_values = [t.squeeze(2) for t in gate_ada_values]
-
-        return (*scale_shift_chunks, *gate_ada_values)
+        return (*scale_shift_ada_values, *gate_ada_values)
 
     def forward(
         self,
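These two hunks route compressed timesteps through the adaLN path: `get_ada_values` now lets `CompressedTimestep.expand_for_computation` add the `scale_shift_table` on the small per-frame tensor and broadcasts only the results to per-token shape. A hedged sketch of that ordering (names and shapes are assumptions for illustration):

```python
import torch

# P ada parameters (e.g. shift/scale/gate entries) of width d each.
B, F, HW, P, d = 2, 8, 32 * 48, 6, 512

frame_embed = torch.randn(B, F, P * d)  # compressed per-frame timestep embedding
table = torch.randn(P, d)               # learned scale_shift_table

# Add the table on per-frame data: [B, F, P, d] instead of [B, F*HW, P, d],
# so the addition touches HW times fewer elements.
frame_ada = (table.unsqueeze(0).unsqueeze(0) + frame_embed.view(B, F, P, d)).unbind(dim=2)

# Expand each of the P results spatially only at the end.
ada_values = tuple(
    v.unsqueeze(2).expand(B, F, HW, d).reshape(B, F * HW, d) for v in frame_ada
)
print(len(ada_values), ada_values[0].shape)  # 6 torch.Size([2, 12288, 512])
```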
@@ -543,72 +606,80 @@ class LTXAVModel(LTXVModel):
         if grid_mask is not None:
             timestep = timestep[:, grid_mask]
 
-        timestep = timestep * self.timestep_scale_multiplier
+        timestep_scaled = timestep * self.timestep_scale_multiplier
 
         v_timestep, v_embedded_timestep = self.adaln_single(
-            timestep.flatten(),
+            timestep_scaled.flatten(),
            {"resolution": None, "aspect_ratio": None},
             batch_size=batch_size,
             hidden_dtype=hidden_dtype,
         )
 
-        # Second dimension is 1 or number of tokens (if timestep_per_token)
-        v_timestep = v_timestep.view(batch_size, -1, v_timestep.shape[-1])
-        v_embedded_timestep = v_embedded_timestep.view(
-            batch_size, -1, v_embedded_timestep.shape[-1]
-        )
+        # Calculate patches_per_frame from orig_shape: [batch, channels, frames, height, width]
+        # Video tokens are arranged as (frames * height * width), so patches_per_frame = height * width
+        orig_shape = kwargs.get("orig_shape")
+        v_patches_per_frame = None
+        if orig_shape is not None and len(orig_shape) == 5:
+            # orig_shape[3] = height, orig_shape[4] = width (in latent space)
+            v_patches_per_frame = orig_shape[3] * orig_shape[4]
+
+        # Reshape to [batch_size, num_tokens, dim] and compress for storage
+        v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame)
+        v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame)
 
         # Prepare audio timestep
         a_timestep = kwargs.get("a_timestep")
         if a_timestep is not None:
-            a_timestep = a_timestep * self.timestep_scale_multiplier
+            a_timestep_scaled = a_timestep * self.timestep_scale_multiplier
+            a_timestep_flat = a_timestep_scaled.flatten()
+            timestep_flat = timestep_scaled.flatten()
             av_ca_factor = self.av_ca_timestep_scale_multiplier / self.timestep_scale_multiplier
 
+            # Cross-attention timesteps - compress these too
             av_ca_audio_scale_shift_timestep, _ = self.av_ca_audio_scale_shift_adaln_single(
-                a_timestep.flatten(),
+                a_timestep_flat,
                 {"resolution": None, "aspect_ratio": None},
                 batch_size=batch_size,
                 hidden_dtype=hidden_dtype,
             )
             av_ca_video_scale_shift_timestep, _ = self.av_ca_video_scale_shift_adaln_single(
-                timestep.flatten(),
+                timestep_flat,
                 {"resolution": None, "aspect_ratio": None},
                 batch_size=batch_size,
                 hidden_dtype=hidden_dtype,
             )
             av_ca_a2v_gate_noise_timestep, _ = self.av_ca_a2v_gate_adaln_single(
-                timestep.flatten() * av_ca_factor,
+                timestep_flat * av_ca_factor,
                 {"resolution": None, "aspect_ratio": None},
                 batch_size=batch_size,
                 hidden_dtype=hidden_dtype,
             )
             av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single(
-                a_timestep.flatten() * av_ca_factor,
+                a_timestep_flat * av_ca_factor,
                 {"resolution": None, "aspect_ratio": None},
                 batch_size=batch_size,
                 hidden_dtype=hidden_dtype,
             )
 
+            # Compress cross-attention timesteps (only video side, audio is too small to benefit)
+            cross_av_timestep_ss = [
+                av_ca_audio_scale_shift_timestep.view(batch_size, -1, av_ca_audio_scale_shift_timestep.shape[-1]),
+                CompressedTimestep(av_ca_video_scale_shift_timestep.view(batch_size, -1, av_ca_video_scale_shift_timestep.shape[-1]), v_patches_per_frame),  # video - compressed
+                CompressedTimestep(av_ca_a2v_gate_noise_timestep.view(batch_size, -1, av_ca_a2v_gate_noise_timestep.shape[-1]), v_patches_per_frame),  # video - compressed
+                av_ca_v2a_gate_noise_timestep.view(batch_size, -1, av_ca_v2a_gate_noise_timestep.shape[-1]),
+            ]
+
             a_timestep, a_embedded_timestep = self.audio_adaln_single(
-                a_timestep.flatten(),
+                a_timestep_flat,
                 {"resolution": None, "aspect_ratio": None},
                 batch_size=batch_size,
                 hidden_dtype=hidden_dtype,
             )
+            # Audio timesteps
             a_timestep = a_timestep.view(batch_size, -1, a_timestep.shape[-1])
-            a_embedded_timestep = a_embedded_timestep.view(
-                batch_size, -1, a_embedded_timestep.shape[-1]
-            )
-            cross_av_timestep_ss = [
-                av_ca_audio_scale_shift_timestep,
-                av_ca_video_scale_shift_timestep,
-                av_ca_a2v_gate_noise_timestep,
-                av_ca_v2a_gate_noise_timestep,
-            ]
-            cross_av_timestep_ss = list(
-                [t.view(batch_size, -1, t.shape[-1]) for t in cross_av_timestep_ss]
-            )
+            a_embedded_timestep = a_embedded_timestep.view(batch_size, -1, a_embedded_timestep.shape[-1])
         else:
-            a_timestep = timestep
+            a_timestep = timestep_scaled
             a_embedded_timestep = kwargs.get("embedded_timestep")
             cross_av_timestep_ss = []
 
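The `patches_per_frame = height * width` derivation above relies on video tokens being laid out frame-major, so each consecutive run of `height * width` tokens shares one timestep row. A small sketch of that layout (the patchify order shown is an assumption for illustration, not lifted from the patchifier):

```python
import torch

B, C, F, H, W = 1, 16, 4, 3, 2  # tiny illustrative latent shape

# Give every latent element its frame index as a value.
latent = torch.arange(F).view(1, 1, F, 1, 1).expand(B, C, F, H, W).float()

# Frame-major token order: [B, C, F, H, W] -> [B, F*H*W, C].
tokens = latent.permute(0, 2, 3, 4, 1).reshape(B, F * H * W, C)

# Each run of H*W tokens is constant -> one stored value per frame suffices.
print(tokens[0, :, 0].view(F, H * W))
# tensor([[0., 0., 0., 0., 0., 0.],
#         [1., 1., 1., 1., 1., 1.],
#         [2., 2., 2., 2., 2., 2.],
#         [3., 3., 3., 3., 3., 3.]])
```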
@@ -767,6 +838,11 @@ class LTXAVModel(LTXVModel):
         ax = x[1]
         v_embedded_timestep = embedded_timestep[0]
         a_embedded_timestep = embedded_timestep[1]
+
+        # Expand compressed video timestep if needed
+        if isinstance(v_embedded_timestep, CompressedTimestep):
+            v_embedded_timestep = v_embedded_timestep.expand()
+
         vx = super()._process_output(vx, v_embedded_timestep, keyframe_idxs, **kwargs)
 
         # Process audio output
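For a sense of scale, a back-of-envelope estimate of the saving per stored video timestep tensor (all numbers are illustrative assumptions, not measurements from the commit):

```python
# Assumed shapes: 31 latent frames on a 64x96 latent grid, embed dim 4096, fp16.
frames, h, w, dim, bytes_per = 31, 64, 96, 4096, 2

tokens = frames * h * w                            # 190_464 video tokens
full_mib = tokens * dim * bytes_per / 2**20        # ~1488 MiB per full tensor
compressed_mib = frames * dim * bytes_per / 2**20  # ~0.24 MiB compressed

print(f"{full_mib:.0f} MiB -> {compressed_mib:.2f} MiB per tensor")
```

Since the forward pass holds several such tensors at once (`v_timestep`, `v_embedded_timestep`, and the two compressed cross-attention entries), the saving multiplies accordingly.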