From 7b88e672efedf7b542bfe3f4c1d1cffdff77edb2 Mon Sep 17 00:00:00 2001 From: Rattus Date: Tue, 17 Mar 2026 14:25:15 +1000 Subject: [PATCH] ltx: vae: cleanup some intermediates The conv layer can be the VRAM peak and it does a torch.cat. So cleanup the pieces of the cat. Also clear our the cache ASAP as each layer detect its end as this VAE surges in VRAM at the end due to the ended padding increasing the size of the final frame convolutions off-the-books to the chunker. So if all the earlier layers free up their cache it can offset that surge. Its a fragmentation nightmare, and the chance of it having to recache the pyt allocator is very high, but you wont OOM. --- comfy/ldm/lightricks/vae/causal_conv3d.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/comfy/ldm/lightricks/vae/causal_conv3d.py b/comfy/ldm/lightricks/vae/causal_conv3d.py index b8341edbc..356394239 100644 --- a/comfy/ldm/lightricks/vae/causal_conv3d.py +++ b/comfy/ldm/lightricks/vae/causal_conv3d.py @@ -65,9 +65,13 @@ class CausalConv3d(nn.Module): self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False) x = torch.cat(pieces, dim=2) + del pieces + del cached if needs_caching: self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False) + elif is_end: + self.temporal_cache_state[tid] = (None, True) return self.conv(x) if x.shape[2] >= self.time_kernel_size else x[:, :, :0, :, :]