Commit Graph

5 Commits

Author SHA1 Message Date
rattus
1a157e1f97
Reduce LTX VAE VRAM usage and save use cases from OOMs/Tiler (#13013)
* ltx: vae: scale the chunk size with the users VRAM

Scale this linearly down for users with low VRAM.

* ltx: vae: free non-chunking recursive intermediates

* ltx: vae: cleanup some intermediates

The conv layer can be the VRAM peak and it does a torch.cat. So cleanup
the pieces of the cat. Also clear our the cache ASAP as each layer detect
its end as this VAE surges in VRAM at the end due to the ended padding
increasing the size of the final frame convolutions off-the-books to
the chunker. So if all the earlier layers free up their cache it can
offset that surge.

Its a fragmentation nightmare, and the chance of it having to recache the
pyt allocator is very high, but you wont OOM.
2026-03-17 17:32:43 -04:00
rattus
0fd1b78736
Reduce LTX2 VAE VRAM consumption (#12028)
* causal_video_ae: Remove attention ResNet

This attention_head_dim argument does not exist on this constructor so
this is dead code. Remove as generic attention mid VAE conflicts with
temporal roll.

* ltx-vae: consoldate causal/non-causal code paths

* ltx-vae: add cache rolling adder

* ltx-vae: use cached adder for resnet

* ltx-vae: Implement rolling VAE

Implement a temporal rolling VAE for the LTX2 VAE.

Usually when doing temporal rolling VAEs you can just chunk on time relying
on causality and cache behind you as you go. The LTX VAE is however
non-causal.

So go whole hog and implement per layer run ahead and backpressure between
the decoder layers using recursive state beween the layers.

Operations are ammended with temporal_cache_state{} which they can use to
hold any state then need for partial execution. Convolutions cache their
inputs behind the up to N-1 frames, and skip connections need to cache the
mismatch between convolution input and output that happens due to missing
future (non-causal) input.

Each call to run_up() processes a layer accross a range on input that
may or may not be complete. It goes depth first to process as much as
possible to try and digest frames to the final output ASAP. If layers run
out of input due to convolution losses, they simply return without action
effectively applying back-pressure to the earlier layers. As the earlier
layers do more work and caller deeper, the partial states are reconciled
and output continues to digest depth first as much as possible.

Chunking is done using a size quota rather than a fixed frame length and
any layer can initiate chunking, and multiple layers can chunk at different
granulatiries. This remove the old limitation of always having to process
1 latent frame to entirety and having to hold 8 full decoded frames as
the VRAM peak.
2026-01-22 16:54:18 -05:00
comfyanonymous
93fedd92fe Support LTXV 0.9.5.
Credits: Lightricks team.
2025-03-05 00:13:49 -05:00
comfyanonymous
e5c3f4b87f LTXV lowvram fixes. 2024-11-22 17:17:11 -05:00
comfyanonymous
5e16f1d24b Support Lightricks LTX-Video model. 2024-11-22 08:46:39 -05:00