From 4ba0b9125d8da56b006475d9c39d296acbd733ec Mon Sep 17 00:00:00 2001
From: Rattus
Date: Wed, 8 Apr 2026 10:09:43 +1000
Subject: [PATCH] ltx: vae: save un-actionable chunks to RunUpState.

---
 .../ldm/lightricks/vae/causal_video_autoencoder.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
index 0b2d0edcf..55be72e6f 100644
--- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
+++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
@@ -23,6 +23,7 @@ class RunUpState:
         self.checkpoint_fn = checkpoint_fn
         self.max_chunk_size = max_chunk_size
         self.output_frames = output_frames
+        self.pending_samples = []
 
 def in_meta_context():
     return torch.device("meta") == torch.empty(0).device
@@ -595,14 +596,13 @@ class Decoder(nn.Module):
                 self.run_up(idx + 1, next_sample_ref, ended, run_up_state, output_buffer, output_offset)
                 return
         else:
-            samples = torch.chunk(sample, chunks=num_chunks, dim=2)
+            samples = list(torch.chunk(sample, chunks=num_chunks, dim=2))
 
-            #This loop might need to become a while.
-            #If the output buffer is exhausted (or none), it need to stash whatever is left of the samples
-            #list to new state.
-            #exhaustion is detectable here with output_offset[0] vs output_buffer shape in T.
-            for chunk_idx, sample1 in enumerate(samples):
-                self.run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1, run_up_state, output_buffer, output_offset)
+            while len(samples):
+                if output_buffer is None or output_offset[0] == output_buffer.shape[2]:
+                    run_up_state.pending_samples.append((idx + 1, samples, ended))
+                    return
+                self.run_up(idx + 1, [samples.pop(0)], ended and len(samples) == 0, run_up_state, output_buffer, output_offset)
 
     def forward_orig(
         self,
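
Note (editor's sketch, not part of the patch): this diff only stashes un-actionable
chunks; nothing here drains pending_samples yet. A minimal sketch of the consuming
side, assuming a later call arrives with a fresh output_buffer and a reset
output_offset; resume_pending is a hypothetical helper name, and the
(next_idx, samples, ended) tuple layout matches what run_up appends above:

    # Hypothetical Decoder method (assumption, not in this patch): replays chunks
    # stashed in run_up_state.pending_samples into a newly provided output_buffer.
    def resume_pending(self, run_up_state, output_buffer, output_offset):
        while run_up_state.pending_samples:
            next_idx, samples, ended = run_up_state.pending_samples.pop(0)
            while samples:
                # Buffer full again: re-stash the remainder and wait for the next call.
                if output_offset[0] == output_buffer.shape[2]:
                    run_up_state.pending_samples.insert(0, (next_idx, samples, ended))
                    return
                self.run_up(next_idx, [samples.pop(0)], ended and len(samples) == 0,
                            run_up_state, output_buffer, output_offset)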