Reduce LTX2 VAE VRAM consumption (#12028)

* causal_video_ae: Remove attention ResNet

This attention_head_dim argument does not exist on this constructor so
this is dead code. Remove as generic attention mid VAE conflicts with
temporal roll.

* ltx-vae: consoldate causal/non-causal code paths

* ltx-vae: add cache rolling adder

* ltx-vae: use cached adder for resnet

* ltx-vae: Implement rolling VAE

Implement a temporal rolling VAE for the LTX2 VAE.

Usually when doing temporal rolling VAEs you can just chunk on time relying
on causality and cache behind you as you go. The LTX VAE is however
non-causal.

So go whole hog and implement per layer run ahead and backpressure between
the decoder layers using recursive state beween the layers.

Operations are ammended with temporal_cache_state{} which they can use to
hold any state then need for partial execution. Convolutions cache their
inputs behind the up to N-1 frames, and skip connections need to cache the
mismatch between convolution input and output that happens due to missing
future (non-causal) input.

Each call to run_up() processes a layer accross a range on input that
may or may not be complete. It goes depth first to process as much as
possible to try and digest frames to the final output ASAP. If layers run
out of input due to convolution losses, they simply return without action
effectively applying back-pressure to the earlier layers. As the earlier
layers do more work and caller deeper, the partial states are reconciled
and output continues to digest depth first as much as possible.

Chunking is done using a size quota rather than a fixed frame length and
any layer can initiate chunking, and multiple layers can chunk at different
granulatiries. This remove the old limitation of always having to process
1 latent frame to entirety and having to hold 8 full decoded frames as
the VRAM peak.
This commit is contained in:
rattus 2026-01-22 13:54:18 -08:00 committed by GitHub
parent 8490eedadf
commit 0fd1b78736
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 160 additions and 64 deletions

View File

@ -1,11 +1,11 @@
from typing import Tuple, Union from typing import Tuple, Union
import threading
import torch import torch
import torch.nn as nn import torch.nn as nn
import comfy.ops import comfy.ops
ops = comfy.ops.disable_weight_init ops = comfy.ops.disable_weight_init
class CausalConv3d(nn.Module): class CausalConv3d(nn.Module):
def __init__( def __init__(
self, self,
@ -42,23 +42,34 @@ class CausalConv3d(nn.Module):
padding_mode=spatial_padding_mode, padding_mode=spatial_padding_mode,
groups=groups, groups=groups,
) )
self.temporal_cache_state={}
def forward(self, x, causal: bool = True): def forward(self, x, causal: bool = True):
if causal: tid = threading.get_ident()
first_frame_pad = x[:, :, :1, :, :].repeat(
(1, 1, self.time_kernel_size - 1, 1, 1) cached, is_end = self.temporal_cache_state.get(tid, (None, False))
) if cached is None:
x = torch.concatenate((first_frame_pad, x), dim=2) padding_length = self.time_kernel_size - 1
else: if not causal:
first_frame_pad = x[:, :, :1, :, :].repeat( padding_length = padding_length // 2
(1, 1, (self.time_kernel_size - 1) // 2, 1, 1) if x.shape[2] == 0:
) return x
last_frame_pad = x[:, :, -1:, :, :].repeat( cached = x[:, :, :1, :, :].repeat((1, 1, padding_length, 1, 1))
(1, 1, (self.time_kernel_size - 1) // 2, 1, 1) pieces = [ cached, x ]
) if is_end and not causal:
x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2) pieces.append(x[:, :, -1:, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1)))
x = self.conv(x)
return x needs_caching = not is_end
if needs_caching and x.shape[2] >= self.time_kernel_size - 1:
needs_caching = False
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
x = torch.cat(pieces, dim=2)
if needs_caching:
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
return self.conv(x) if x.shape[2] >= self.time_kernel_size else x[:, :, :0, :, :]
@property @property
def weight(self): def weight(self):

View File

@ -1,4 +1,5 @@
from __future__ import annotations from __future__ import annotations
import threading
import torch import torch
from torch import nn from torch import nn
from functools import partial from functools import partial
@ -6,12 +7,35 @@ import math
from einops import rearrange from einops import rearrange
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
from .conv_nd_factory import make_conv_nd, make_linear_nd from .conv_nd_factory import make_conv_nd, make_linear_nd
from .causal_conv3d import CausalConv3d
from .pixel_norm import PixelNorm from .pixel_norm import PixelNorm
from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings
import comfy.ops import comfy.ops
from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed
ops = comfy.ops.disable_weight_init ops = comfy.ops.disable_weight_init
def mark_conv3d_ended(module):
tid = threading.get_ident()
for _, m in module.named_modules():
if isinstance(m, CausalConv3d):
current = m.temporal_cache_state.get(tid, (None, False))
m.temporal_cache_state[tid] = (current[0], True)
def split2(tensor, split_point, dim=2):
return torch.split(tensor, [split_point, tensor.shape[dim] - split_point], dim=dim)
def add_exchange_cache(dest, cache_in, new_input, dim=2):
if dest is not None:
if cache_in is not None:
cache_to_dest = min(dest.shape[dim], cache_in.shape[dim])
lead_in_dest, dest = split2(dest, cache_to_dest, dim=dim)
lead_in_source, cache_in = split2(cache_in, cache_to_dest, dim=dim)
lead_in_dest.add_(lead_in_source)
body, new_input = split2(new_input, dest.shape[dim], dim)
dest.add_(body)
return torch_cat_if_needed([cache_in, new_input], dim=dim)
class Encoder(nn.Module): class Encoder(nn.Module):
r""" r"""
The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation. The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.
@ -205,7 +229,7 @@ class Encoder(nn.Module):
self.gradient_checkpointing = False self.gradient_checkpointing = False
def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor: def forward_orig(self, sample: torch.FloatTensor) -> torch.FloatTensor:
r"""The forward method of the `Encoder` class.""" r"""The forward method of the `Encoder` class."""
sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
@ -254,6 +278,22 @@ class Encoder(nn.Module):
return sample return sample
def forward(self, *args, **kwargs):
#No encoder support so just flag the end so it doesnt use the cache.
mark_conv3d_ended(self)
try:
return self.forward_orig(*args, **kwargs)
finally:
tid = threading.get_ident()
for _, module in self.named_modules():
# ComfyUI doesn't thread this kind of stuff today, but just in case
# we key on the thread to make it thread safe.
tid = threading.get_ident()
if hasattr(module, "temporal_cache_state"):
module.temporal_cache_state.pop(tid, None)
MAX_CHUNK_SIZE=(128 * 1024 ** 2)
class Decoder(nn.Module): class Decoder(nn.Module):
r""" r"""
@ -341,18 +381,6 @@ class Decoder(nn.Module):
timestep_conditioning=timestep_conditioning, timestep_conditioning=timestep_conditioning,
spatial_padding_mode=spatial_padding_mode, spatial_padding_mode=spatial_padding_mode,
) )
elif block_name == "attn_res_x":
block = UNetMidBlock3D(
dims=dims,
in_channels=input_channel,
num_layers=block_params["num_layers"],
resnet_groups=norm_num_groups,
norm_layer=norm_layer,
inject_noise=block_params.get("inject_noise", False),
timestep_conditioning=timestep_conditioning,
attention_head_dim=block_params["attention_head_dim"],
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "res_x_y": elif block_name == "res_x_y":
output_channel = output_channel // block_params.get("multiplier", 2) output_channel = output_channel // block_params.get("multiplier", 2)
block = ResnetBlock3D( block = ResnetBlock3D(
@ -428,8 +456,9 @@ class Decoder(nn.Module):
) )
self.last_scale_shift_table = nn.Parameter(torch.empty(2, output_channel)) self.last_scale_shift_table = nn.Parameter(torch.empty(2, output_channel))
# def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor: # def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
def forward( def forward_orig(
self, self,
sample: torch.FloatTensor, sample: torch.FloatTensor,
timestep: Optional[torch.Tensor] = None, timestep: Optional[torch.Tensor] = None,
@ -437,6 +466,7 @@ class Decoder(nn.Module):
r"""The forward method of the `Decoder` class.""" r"""The forward method of the `Decoder` class."""
batch_size = sample.shape[0] batch_size = sample.shape[0]
mark_conv3d_ended(self.conv_in)
sample = self.conv_in(sample, causal=self.causal) sample = self.conv_in(sample, causal=self.causal)
checkpoint_fn = ( checkpoint_fn = (
@ -445,24 +475,12 @@ class Decoder(nn.Module):
else lambda x: x else lambda x: x
) )
scaled_timestep = None timestep_shift_scale = None
if self.timestep_conditioning: if self.timestep_conditioning:
assert ( assert (
timestep is not None timestep is not None
), "should pass timestep with timestep_conditioning=True" ), "should pass timestep with timestep_conditioning=True"
scaled_timestep = timestep * self.timestep_scale_multiplier.to(dtype=sample.dtype, device=sample.device) scaled_timestep = timestep * self.timestep_scale_multiplier.to(dtype=sample.dtype, device=sample.device)
for up_block in self.up_blocks:
if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
sample = checkpoint_fn(up_block)(
sample, causal=self.causal, timestep=scaled_timestep
)
else:
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
sample = self.conv_norm_out(sample)
if self.timestep_conditioning:
embedded_timestep = self.last_time_embedder( embedded_timestep = self.last_time_embedder(
timestep=scaled_timestep.flatten(), timestep=scaled_timestep.flatten(),
resolution=None, resolution=None,
@ -483,16 +501,62 @@ class Decoder(nn.Module):
embedded_timestep.shape[-2], embedded_timestep.shape[-2],
embedded_timestep.shape[-1], embedded_timestep.shape[-1],
) )
shift, scale = ada_values.unbind(dim=1) timestep_shift_scale = ada_values.unbind(dim=1)
sample = sample * (1 + scale) + shift
sample = self.conv_act(sample) output = []
sample = self.conv_out(sample, causal=self.causal)
def run_up(idx, sample, ended):
if idx >= len(self.up_blocks):
sample = self.conv_norm_out(sample)
if timestep_shift_scale is not None:
shift, scale = timestep_shift_scale
sample = sample * (1 + scale) + shift
sample = self.conv_act(sample)
if ended:
mark_conv3d_ended(self.conv_out)
sample = self.conv_out(sample, causal=self.causal)
if sample is not None and sample.shape[2] > 0:
output.append(sample)
return
up_block = self.up_blocks[idx]
if (ended):
mark_conv3d_ended(up_block)
if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
sample = checkpoint_fn(up_block)(
sample, causal=self.causal, timestep=scaled_timestep
)
else:
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
if sample is None or sample.shape[2] == 0:
return
total_bytes = sample.numel() * sample.element_size()
num_chunks = (total_bytes + MAX_CHUNK_SIZE - 1) // MAX_CHUNK_SIZE
samples = torch.chunk(sample, chunks=num_chunks, dim=2)
for chunk_idx, sample1 in enumerate(samples):
run_up(idx + 1, sample1, ended and chunk_idx == len(samples) - 1)
run_up(0, sample, True)
sample = torch.cat(output, dim=2)
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
return sample return sample
def forward(self, *args, **kwargs):
try:
return self.forward_orig(*args, **kwargs)
finally:
for _, module in self.named_modules():
#ComfyUI doesn't thread this kind of stuff today, but just incase
#we key on the thread to make it thread safe.
tid = threading.get_ident()
if hasattr(module, "temporal_cache_state"):
module.temporal_cache_state.pop(tid, None)
class UNetMidBlock3D(nn.Module): class UNetMidBlock3D(nn.Module):
""" """
@ -663,8 +727,22 @@ class DepthToSpaceUpsample(nn.Module):
) )
self.residual = residual self.residual = residual
self.out_channels_reduction_factor = out_channels_reduction_factor self.out_channels_reduction_factor = out_channels_reduction_factor
self.temporal_cache_state = {}
def forward(self, x, causal: bool = True, timestep: Optional[torch.Tensor] = None): def forward(self, x, causal: bool = True, timestep: Optional[torch.Tensor] = None):
tid = threading.get_ident()
cached, drop_first_conv, drop_first_res = self.temporal_cache_state.get(tid, (None, True, True))
y = self.conv(x, causal=causal)
y = rearrange(
y,
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
p1=self.stride[0],
p2=self.stride[1],
p3=self.stride[2],
)
if self.stride[0] == 2 and y.shape[2] > 0 and drop_first_conv:
y = y[:, :, 1:, :, :]
drop_first_conv = False
if self.residual: if self.residual:
# Reshape and duplicate the input to match the output shape # Reshape and duplicate the input to match the output shape
x_in = rearrange( x_in = rearrange(
@ -676,21 +754,20 @@ class DepthToSpaceUpsample(nn.Module):
) )
num_repeat = math.prod(self.stride) // self.out_channels_reduction_factor num_repeat = math.prod(self.stride) // self.out_channels_reduction_factor
x_in = x_in.repeat(1, num_repeat, 1, 1, 1) x_in = x_in.repeat(1, num_repeat, 1, 1, 1)
if self.stride[0] == 2: if self.stride[0] == 2 and x_in.shape[2] > 0 and drop_first_res:
x_in = x_in[:, :, 1:, :, :] x_in = x_in[:, :, 1:, :, :]
x = self.conv(x, causal=causal) drop_first_res = False
x = rearrange(
x, if y.shape[2] == 0:
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)", y = None
p1=self.stride[0],
p2=self.stride[1], cached = add_exchange_cache(y, cached, x_in, dim=2)
p3=self.stride[2], self.temporal_cache_state[tid] = (cached, drop_first_conv, drop_first_res)
)
if self.stride[0] == 2: else:
x = x[:, :, 1:, :, :] self.temporal_cache_state[tid] = (None, drop_first_conv, False)
if self.residual:
x = x + x_in return y
return x
class LayerNorm(nn.Module): class LayerNorm(nn.Module):
def __init__(self, dim, eps, elementwise_affine=True) -> None: def __init__(self, dim, eps, elementwise_affine=True) -> None:
@ -807,6 +884,8 @@ class ResnetBlock3D(nn.Module):
torch.randn(4, in_channels) / in_channels**0.5 torch.randn(4, in_channels) / in_channels**0.5
) )
self.temporal_cache_state={}
def _feed_spatial_noise( def _feed_spatial_noise(
self, hidden_states: torch.FloatTensor, per_channel_scale: torch.FloatTensor self, hidden_states: torch.FloatTensor, per_channel_scale: torch.FloatTensor
) -> torch.FloatTensor: ) -> torch.FloatTensor:
@ -880,9 +959,12 @@ class ResnetBlock3D(nn.Module):
input_tensor = self.conv_shortcut(input_tensor) input_tensor = self.conv_shortcut(input_tensor)
output_tensor = input_tensor + hidden_states tid = threading.get_ident()
cached = self.temporal_cache_state.get(tid, None)
cached = add_exchange_cache(hidden_states, cached, input_tensor, dim=2)
self.temporal_cache_state[tid] = cached
return output_tensor return hidden_states
def patchify(x, patch_size_hw, patch_size_t=1): def patchify(x, patch_size_hw, patch_size_t=1):

View File

@ -14,10 +14,13 @@ if model_management.xformers_enabled_vae():
import xformers.ops import xformers.ops
def torch_cat_if_needed(xl, dim): def torch_cat_if_needed(xl, dim):
xl = [x for x in xl if x is not None and x.shape[dim] > 0]
if len(xl) > 1: if len(xl) > 1:
return torch.cat(xl, dim) return torch.cat(xl, dim)
else: elif len(xl) == 1:
return xl[0] return xl[0]
else:
return None
def get_timestep_embedding(timesteps, embedding_dim): def get_timestep_embedding(timesteps, embedding_dim):
""" """