ltx: vae: implement chunked encoder + CPU IO chunking (Big VRAM reductions) (#13062)

* ltx: vae: add cache state to downsample block

* ltx: vae: Add time stride awareness to causal_conv_3d

* ltx: vae: Automate truncation for encoder

Other VAEs just truncate without error. Do the same.

* sd/ltx: Make chunked_io a flag in its own right

This is becoming bi-directional, so make it a purpose-named flag in its own right.

* ltx: vae: implement chunked encoder + CPU IO chunking

People are doing things with big frame counts in LTX, including V2V
flows. Implement the time-chunked encoder to keep VRAM down, paired with
the converse of the new CPU pre-allocation technique: chunks are brought
over from the CPU just in time (see the sketch after this list).

* ltx: vae-encode: round chunk sizes more strictly

Only powers of 2 and multiples of 8 are valid due to cache slicing.
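
In outline, the encoder-side technique looks like this; a minimal sketch, assuming a caller-supplied encode_fn and a pre-computed chunk size (the names and numbers are illustrative, not this PR's code):

import torch

def round_chunk_t(chunk_t: int) -> int:
    # Valid chunk sizes are 2, 4, or a multiple of 8, so that chunk
    # boundaries stay aligned with the temporal cache slicing.
    if chunk_t < 4:
        return 2
    if chunk_t < 8:
        return 4
    return (chunk_t // 8) * 8

def chunked_encode(encode_fn, video_cpu: torch.Tensor, device, chunk_t: int):
    # video_cpu is [B, C, T, H, W] and stays on the CPU; only one chunk
    # is resident on the GPU at a time (brought over just in time).
    chunks = [video_cpu[:, :, :1]]  # the first frame is its own chunk (causal conv)
    if video_cpu.shape[2] > 1:
        chunks += list(torch.split(video_cpu[:, :, 1:], round_chunk_t(chunk_t), dim=2))
    outputs = [encode_fn(c.to(device)) for c in chunks]
    return torch.cat([o for o in outputs if o is not None], dim=2)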
rattus committed 2026-03-19 09:58:47 -07:00
parent f6b869d7d3
commit fabed694a2
3 changed files with 92 additions and 26 deletions


@@ -23,6 +23,11 @@ class CausalConv3d(nn.Module):
         self.in_channels = in_channels
         self.out_channels = out_channels
+        if isinstance(stride, int):
+            self.time_stride = stride
+        else:
+            self.time_stride = stride[0]
+
         kernel_size = (kernel_size, kernel_size, kernel_size)
         self.time_kernel_size = kernel_size[0]
@@ -58,18 +63,23 @@ class CausalConv3d(nn.Module):
         pieces = [ cached, x ]
         if is_end and not causal:
             pieces.append(x[:, :, -1:, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1)))
+        input_length = sum([piece.shape[2] for piece in pieces])
+        cache_length = (self.time_kernel_size - self.time_stride) + ((input_length - self.time_kernel_size) % self.time_stride)
         needs_caching = not is_end
-        if needs_caching and x.shape[2] >= self.time_kernel_size - 1:
+        if needs_caching and cache_length == 0:
+            self.temporal_cache_state[tid] = (x[:, :, :0, :, :], False)
             needs_caching = False
-            self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
+        if needs_caching and x.shape[2] >= cache_length:
+            needs_caching = False
+            self.temporal_cache_state[tid] = (x[:, :, -cache_length:, :, :], False)
         x = torch.cat(pieces, dim=2)
         del pieces
         del cached
         if needs_caching:
-            self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
+            self.temporal_cache_state[tid] = (x[:, :, -cache_length:, :, :], False)
         elif is_end:
             self.temporal_cache_state[tid] = (None, True)
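
For intuition, the new cache_length collapses to the old time_kernel_size - 1 when the stride is 1, and shrinks when a strided conv consumes frames cleanly; a quick check with kernel size 3 (the input lengths here are arbitrary examples):

# K = time_kernel_size, S = time_stride, L = total input length
cache_len = lambda K, S, L: (K - S) + ((L - K) % S)

assert cache_len(3, 1, 9) == 2   # stride 1: keep K - 1 frames, as before
assert cache_len(3, 2, 9) == 1   # stride 2, in phase: one frame carries over
assert cache_len(3, 2, 10) == 2  # stride 2, off phase: two frames carry over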


@@ -233,10 +233,7 @@ class Encoder(nn.Module):
         self.gradient_checkpointing = False

-    def forward_orig(self, sample: torch.FloatTensor) -> torch.FloatTensor:
-        r"""The forward method of the `Encoder` class."""
-        sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
+    def _forward_chunk(self, sample: torch.FloatTensor) -> Optional[torch.FloatTensor]:
         sample = self.conv_in(sample)

         checkpoint_fn = (
@@ -247,10 +244,14 @@ class Encoder(nn.Module):
         for down_block in self.down_blocks:
             sample = checkpoint_fn(down_block)(sample)

+        if sample is None or sample.shape[2] == 0:
+            return None
+
         sample = self.conv_norm_out(sample)
         sample = self.conv_act(sample)
         sample = self.conv_out(sample)

+        if sample is None or sample.shape[2] == 0:
+            return None
+
         if self.latent_log_var == "uniform":
             last_channel = sample[:, -1:, ...]
@@ -282,9 +283,35 @@ class Encoder(nn.Module):
         return sample

+    def forward_orig(self, sample: torch.FloatTensor, device=None) -> torch.FloatTensor:
+        r"""The forward method of the `Encoder` class."""
+        max_chunk_size = get_max_chunk_size(sample.device if device is None else device) * 2  # encoder is more memory-efficient than decoder
+        frame_size = sample[:, :, :1, :, :].numel() * sample.element_size()
+        frame_size = int(frame_size * (self.conv_in.out_channels / self.conv_in.in_channels))
+        outputs = []
+        samples = [sample[:, :, :1, :, :]]
+        if sample.shape[2] > 1:
+            chunk_t = max(2, max_chunk_size // frame_size)
+            if chunk_t < 4:
+                chunk_t = 2
+            elif chunk_t < 8:
+                chunk_t = 4
+            else:
+                chunk_t = (chunk_t // 8) * 8
+            samples += list(torch.split(sample[:, :, 1:, :, :], chunk_t, dim=2))
+        for chunk_idx, chunk in enumerate(samples):
+            if chunk_idx == len(samples) - 1:
+                mark_conv3d_ended(self)
+            chunk = patchify(chunk, patch_size_hw=self.patch_size, patch_size_t=1).to(device=device)
+            output = self._forward_chunk(chunk)
+            if output is not None:
+                outputs.append(output)
+        return torch_cat_if_needed(outputs, dim=2)
+
     def forward(self, *args, **kwargs):
-        #No encoder support so just flag the end so it doesnt use the cache.
-        mark_conv3d_ended(self)
         try:
             return self.forward_orig(*args, **kwargs)
         finally:
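
Worked numbers for the sizing above, assuming get_max_chunk_size returns 1 GiB (doubled to 2 GiB by the encoder bonus) and a 768x512 fp16 input with a 3 -> 128 channel conv_in; all figures are illustrative:

frame_size = 3 * 768 * 512 * 2                    # one RGB fp16 frame ~= 2.25 MiB
frame_size = int(frame_size * (128 / 3))          # scaled by conv_in channel expansion
max_chunk_size = 2 * 1024**3                      # assumed budget after the * 2

chunk_t = max(2, max_chunk_size // frame_size)    # 21 frames fit
chunk_t = (chunk_t // 8) * 8                      # rounded down for cache alignment
print(chunk_t)                                    # -> 16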
@@ -737,12 +764,25 @@ class SpaceToDepthDownsample(nn.Module):
             causal=True,
             spatial_padding_mode=spatial_padding_mode,
         )
+        self.temporal_cache_state = {}

     def forward(self, x, causal: bool = True):
-        if self.stride[0] == 2:
+        tid = threading.get_ident()
+        cached, pad_first, cached_x, cached_input = self.temporal_cache_state.get(tid, (None, True, None, None))
+        if cached_input is not None:
+            x = torch_cat_if_needed([cached_input, x], dim=2)
+            cached_input = None
+        if self.stride[0] == 2 and pad_first:
             x = torch.cat(
                 [x[:, :, :1, :, :], x], dim=2
             )  # duplicate first frames for padding
+            pad_first = False
+
+        if x.shape[2] < self.stride[0]:
+            cached_input = x
+            self.temporal_cache_state[tid] = (cached, pad_first, cached_x, cached_input)
+            return None

         # skip connection
         x_in = rearrange(
# skip connection
x_in = rearrange(
@@ -757,15 +797,26 @@ class SpaceToDepthDownsample(nn.Module):
         # conv
         x = self.conv(x, causal=causal)

-        x = rearrange(
-            x,
-            "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
-            p1=self.stride[0],
-            p2=self.stride[1],
-            p3=self.stride[2],
-        )
+        if self.stride[0] == 2 and x.shape[2] == 1:
+            if cached_x is not None:
+                x = torch_cat_if_needed([cached_x, x], dim=2)
+                cached_x = None
+            else:
+                cached_x = x
+                x = None

-        x = x + x_in
+        if x is not None:
+            x = rearrange(
+                x,
+                "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
+                p1=self.stride[0],
+                p2=self.stride[1],
+                p3=self.stride[2],
+            )
+        cached = add_exchange_cache(x, cached, x_in, dim=2)
+        self.temporal_cache_state[tid] = (cached, pad_first, cached_x, cached_input)
         return x
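
The downsample block now follows the same per-thread cache convention as CausalConv3d: state is keyed by threading.get_ident(), and the block returns None until enough frames have arrived. A stripped-down sketch of the pattern (not the real class):

import threading
import torch

class ChunkedBlockSketch:
    def __init__(self, time_stride: int = 2):
        self.time_stride = time_stride
        self.temporal_cache_state = {}  # one entry per calling thread

    def forward(self, x: torch.Tensor):
        tid = threading.get_ident()
        cached = self.temporal_cache_state.pop(tid, None)
        if cached is not None:
            x = torch.cat([cached, x], dim=2)
        if x.shape[2] < self.time_stride:
            # Not enough frames for one strided step: stash and wait.
            self.temporal_cache_state[tid] = x
            return None
        return x  # the real block downsamples here and caches any remainder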
@@ -1098,6 +1149,8 @@ class processor(nn.Module):
         return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)

 class VideoVAE(nn.Module):
+    comfy_has_chunked_io = True
+
     def __init__(self, version=0, config=None):
         super().__init__()
@@ -1240,11 +1293,9 @@ class VideoVAE(nn.Module):
         }
         return config

-    def encode(self, x):
-        frames_count = x.shape[2]
-        if ((frames_count - 1) % 8) != 0:
-            raise ValueError("Invalid number of frames: Encode input must have 1 + 8 * x frames (e.g., 1, 9, 17, ...). Please check your input.")
-
-        means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
+    def encode(self, x, device=None):
+        x = x[:, :, :max(1, 1 + ((x.shape[2] - 1) // 8) * 8), :, :]
+        means, logvar = torch.chunk(self.encoder(x, device=device), 2, dim=1)
         return self.per_channel_statistics.normalize(means)

     def decode_output_shape(self, input_shape):
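
The truncation keeps the largest valid prefix of 1 + 8k frames instead of raising, in line with how other VAEs behave; for example:

# frames kept = max(1, 1 + ((T - 1) // 8) * 8)
for T in (1, 9, 17, 20, 100):
    print(T, "->", max(1, 1 + ((T - 1) // 8) * 8))
# 1 -> 1, 9 -> 9, 17 -> 17, 20 -> 17, 100 -> 97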


@@ -953,7 +953,7 @@ class VAE:
             # Pre-allocate output for VAEs that support direct buffer writes
             preallocated = False
-            if hasattr(self.first_stage_model, 'decode_output_shape'):
+            if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
                 pixel_samples = torch.empty(self.first_stage_model.decode_output_shape(samples_in.shape), device=self.output_device, dtype=self.vae_output_dtype())
                 preallocated = True
@@ -1038,8 +1038,13 @@ class VAE:
             batch_number = max(1, batch_number)
             samples = None
             for x in range(0, pixel_samples.shape[0], batch_number):
-                pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype).to(self.device)
-                out = self.first_stage_model.encode(pixels_in).to(self.output_device).to(dtype=self.vae_output_dtype())
+                pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype)
+                if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
+                    out = self.first_stage_model.encode(pixels_in, device=self.device)
+                else:
+                    pixels_in = pixels_in.to(self.device)
+                    out = self.first_stage_model.encode(pixels_in)
+                out = out.to(self.output_device).to(dtype=self.vae_output_dtype())
                 if samples is None:
                     samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
                 samples[x:x + batch_number] = out
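
Taken together, a model opts into chunked IO by setting the flag and accepting a device argument; a sketch of the contract this diff implies (method bodies elided):

import torch.nn as nn

class ChunkedIOVAE(nn.Module):
    # Tells comfy's VAE wrapper to hand over CPU-resident pixels plus the
    # compute device, instead of moving the whole batch up front.
    comfy_has_chunked_io = True

    def encode(self, x, device=None):
        # x may still be on the CPU; move it (or chunks of it) as needed.
        ...

    def decode_output_shape(self, input_shape):
        # Lets the wrapper pre-allocate the decode output buffer.
        ...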