From 9dede56facdbdf5698d0e81c8024a00b01070103 Mon Sep 17 00:00:00 2001
From: Rattus
Date: Wed, 8 Apr 2026 08:45:18 +1000
Subject: [PATCH] implementation plan

---
 .../vae/causal_video_autoencoder.py | 30 +++++++++++++++++++
 1 file changed, 30 insertions(+)

diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
index 998122c85..539c0487c 100644
--- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
+++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
@@ -553,6 +553,10 @@ class Decoder(nn.Module):
             t = sample.shape[2]
             output_buffer[:, :, output_offset[0]:output_offset[0] + t].copy_(sample)
             output_offset[0] += t
+            #if there isn't space in the output buffer, stash the unconsumed frames in the
+            #new state as their own entry. forward_resume then just immediately copies them
+            #into the next slice of output. Do not clone, just slice; the VRAM held by the
+            #reference is not a big deal.
             return
 
         up_block = self.up_blocks[idx]
@@ -575,11 +579,17 @@ class Decoder(nn.Module):
             # when we are not chunking, detach our x so the callee can free it as soon as they are done
             next_sample_ref = [sample]
             del sample
+            #Let this run_up proceed unconditionally, even when the output buffer is exhausted:
+            #either a lower layer's chunker or the output frame stash will do the work anyway. So unchanged.
             self.run_up(idx + 1, next_sample_ref, ended, timestep_shift_scale, scaled_timestep,
                         checkpoint_fn, output_buffer, output_offset, max_chunk_size)
             return
         else:
             samples = torch.chunk(sample, chunks=num_chunks, dim=2)
+            #This loop might need to become a while loop.
+            #If the output buffer is exhausted (or None), it needs to stash whatever is left
+            #of the samples list into the new state.
+            #Exhaustion is detectable here by comparing output_offset[0] against the output_buffer shape in T.
             for chunk_idx, sample1 in enumerate(samples):
                 self.run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1,
                             timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
@@ -630,6 +640,8 @@ class Decoder(nn.Module):
         )
         timestep_shift_scale = ada_values.unbind(dim=1)
 
+        #The meaning of output_buffer == None changes: it now also marks a resumable decode.
+        #forward now needs to do this allocation itself for the non start/resume case.
         if output_buffer is None:
             output_buffer = torch.empty(
                 self.decode_output_shape(sample.shape),
@@ -643,6 +655,24 @@ class Decoder(nn.Module):
 
         return output_buffer
 
+    def forward_start(self, *args, **kwargs):
+        #output_buffer == None implies the buffer is exhausted from the start, so this
+        #should park the sample in a resumable state at the bottom of the run_up stack.
+        raise NotImplementedError("Decoder.forward_start is not implemented yet")
+
+    #completely new function (maybe)
+    def forward_resume(self, *args, **kwargs):
+        #Inspect the saved state. Pop any complete frames first; they copy straight
+        #into the fresh output buffer.
+        #Then execute run_up on the highest-index stash. This function has to act as the
+        #list iterator for chunked run_up stashes and potentially restash incomplete lists.
+        #Work down the indices, stashing highest to lowest, and just like run_up,
+        #return if output_buffer is exhausted, with the extra run_up calls potentially
+        #leaving behind new state for the next time this is called.
+        #If we have truly finished, do the same logic as the finally block below to clean up.
+        #We will design the protocol for signalling the end to the caller later.
+        raise NotImplementedError("Decoder.forward_resume is not implemented yet")
+
     def forward(self, *args, **kwargs):
         try:
             return self.forward_orig(*args, **kwargs)
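
Plan notes (outside the diff): a runnable toy of the frame-stash protocol the base-case
comments describe. The class name FrameStash, the pending list, and the emit/drain split
are assumptions for illustration, not the final ComfyUI state layout.

    import torch

    class FrameStash:
        """Toy model of the stash/resume protocol. Frames flow along dim 2 (T)."""

        def __init__(self):
            self.pending = []  # unconsumed frame slices, oldest first

        def drain(self, output_buffer, output_offset):
            # Copy stashed frames into output_buffer until it fills or we run dry.
            while self.pending:
                space = output_buffer.shape[2] - output_offset[0]
                if space == 0:
                    return False  # exhausted again; state survives for the next call
                frames = self.pending[0]
                n = min(frames.shape[2], space)
                output_buffer[:, :, output_offset[0]:output_offset[0] + n].copy_(frames[:, :, :n])
                output_offset[0] += n
                if n < frames.shape[2]:
                    # Slice, do not clone: the reference pins some VRAM, which is fine.
                    self.pending[0] = frames[:, :, n:]
                else:
                    self.pending.pop(0)
            return True  # everything consumed

        def emit(self, sample, output_buffer, output_offset):
            # The run_up base case: try to output, stash whatever does not fit.
            self.pending.append(sample)
            return self.drain(output_buffer, output_offset)

    # Two 4-frame buffers draining a 6-frame sample:
    stash = FrameStash()
    sample = torch.arange(6.0).view(1, 1, 6, 1, 1)
    buf1, off1 = torch.empty(1, 1, 4, 1, 1), [0]
    stash.emit(sample, buf1, off1)   # buf1 gets frames 0..3, frames 4..5 stashed
    buf2, off2 = torch.empty(1, 1, 4, 1, 1), [0]
    stash.drain(buf2, off2)          # buf2 gets frames 4..5, returns True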
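
The for-to-while change in the chunked branch could look roughly like this. run_up is
passed in as a callable and resume_stashes is an assumed name for the per-level state,
so the control flow can be exercised standalone; none of these names are from the file.

    import torch

    def chunked_dispatch(sample, num_chunks, ended, output_buffer, output_offset,
                         resume_stashes, idx, run_up):
        # Sketch of run_up's chunked branch as a while loop that can stop early.
        samples = list(torch.chunk(sample, chunks=num_chunks, dim=2))
        while samples:
            # Exhaustion check: output_offset[0] vs the buffer's T extent.
            if output_offset[0] >= output_buffer.shape[2]:
                resume_stashes[idx] = samples  # leftover chunks for forward_resume
                return
            sample1 = samples.pop(0)
            run_up(idx + 1, [sample1], ended and not samples)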
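
On the caller side, the intended shape is roughly the loop below, assuming (for
illustration only) that forward_start/forward_resume eventually return a (chunk, done)
pair; the plan explicitly defers the end-signalling protocol, so the signature and the
max_out_frames parameter are placeholders, not an API.

    def decode_in_chunks(decoder, latent, chunk_frames):
        # Hypothetical driver loop over a resumable decode.
        chunk, done = decoder.forward_start(latent, max_out_frames=chunk_frames)
        yield chunk
        while not done:
            chunk, done = decoder.forward_resume(max_out_frames=chunk_frames)
            yield chunk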