From b23f1f456fbaa2ccddb3812b3d7a59f36ae54d7a Mon Sep 17 00:00:00 2001
From: Rattus
Date: Wed, 8 Apr 2026 16:14:52 +1000
Subject: [PATCH] ltx: vae: implement start and resume protocol

---
 .../vae/causal_video_autoencoder.py           | 83 +++++++++++++++----
 1 file changed, 66 insertions(+), 17 deletions(-)

diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
index 13c1c6a1b..c0211addd 100644
--- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
+++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
@@ -17,11 +17,13 @@ from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed
 ops = comfy.ops.disable_weight_init
 
 class RunUpState:
-    def __init__(self, timestep_shift_scale, scaled_timestep, checkpoint_fn, max_chunk_size, output_frames=None):
+    def __init__(self, timestep_shift_scale, scaled_timestep, checkpoint_fn, max_chunk_size, output_shape, output_dtype, output_frames=None):
         self.timestep_shift_scale = timestep_shift_scale
         self.scaled_timestep = scaled_timestep
         self.checkpoint_fn = checkpoint_fn
         self.max_chunk_size = max_chunk_size
+        self.output_shape = output_shape
+        self.output_dtype = output_dtype
         self.output_frames = output_frames
         self.pending_samples = []
 
@@ -614,6 +616,7 @@ class Decoder(nn.Module):
     ) -> torch.FloatTensor:
         r"""The forward method of the `Decoder` class."""
         batch_size = sample.shape[0]
+        output_shape = self.decode_output_shape(sample.shape)
 
         mark_conv3d_ended(self.conv_in)
         sample = self.conv_in(sample, causal=self.causal)
@@ -660,6 +663,8 @@ class Decoder(nn.Module):
                 scaled_timestep=scaled_timestep,
                 checkpoint_fn=checkpoint_fn,
                 max_chunk_size=get_max_chunk_size(sample.device),
+                output_shape=output_shape,
+                output_dtype=sample.dtype,
             )
             self.temporal_cache_state[threading.get_ident()] = run_up_state
 
@@ -667,23 +672,67 @@ class Decoder(nn.Module):
 
         return output_buffer
 
-    def forward_start(self, *args, **kwargs):
-        #output_buffer == None implies initial exhaustion, so this should park it in a
-        #resumable state on the bottom of the run_up stack.
-        raise NotImplementedError("Decoder.forward_start is not implemented yet")
+    def forward_start(
+        self,
+        sample: torch.FloatTensor,
+        timestep: Optional[torch.Tensor] = None,
+    ):
+        # Run the decode with output_buffer=None so that forward_orig parks
+        # what it produces in this thread's RunUpState for later resume calls.
+        try:
+            return self.forward_orig(sample, timestep=timestep, output_buffer=None)
+        except Exception:
+            clear_temporal_cache_state(self)
+            raise
 
-    #completely new function (maybe)
-    def forward_resume(self, *args, **kwargs):
-        #your code here
-        #inspect the new state. Pop any complete frames first.
-        #Then execute run_up on the highest index frame. You will need to be the list iterator
-        #for chunked run_up stashes and potentially restash incompleted lists.
-        #come down the indicies as stash highest to lowest and just like run_up you need to
-        #return if output_buffer is exhausted with the extra run_up calls potentially leaving
-        #behind new state for the next time this is called.
-        #if we have truly finished, do the same logic as the finally below to clean up.
-        #we will design the protocol for signalling end to the caller later.
-        raise NotImplementedError("Decoder.forward_resume is not implemented yet")
+    def forward_resume(self, output_t: int):
+        # Drain up to output_t frames from the RunUpState parked by
+        # forward_start. Returns None if this thread has nothing to resume;
+        # a buffer shorter than output_t means the parked state was fully
+        # drained and has been cleared.
+        tid = threading.get_ident()
+        run_up_state = self.temporal_cache_state.get(tid, None)
+        if run_up_state is None:
+            return None
+
+        output_shape = list(run_up_state.output_shape)
+        output_shape[2] = output_t
+        output_buffer = torch.empty(
+            output_shape,
+            dtype=run_up_state.output_dtype,
+            device=comfy.model_management.intermediate_device(),
+        )
+        output_offset = [0]
+
+        try:
+            # First flush any frames already decoded on an earlier pass.
+            if run_up_state.output_frames is not None:
+                output_slice = output_buffer[:, :, :run_up_state.output_frames.shape[2]]
+                t = output_slice.shape[2]
+                output_slice.copy_(run_up_state.output_frames[:, :, :t])
+                output_offset[0] += t
+                run_up_state.output_frames = None if t == run_up_state.output_frames.shape[2] else run_up_state.output_frames[:, :, t:]
+
+            # Replay stashed samples one at a time so the remainder can be
+            # restashed if the output buffer fills mid-drain.
+            pending_samples = run_up_state.pending_samples
+            run_up_state.pending_samples = []
+            while len(pending_samples):
+                idx, samples, ended = pending_samples.pop(0)
+                while len(samples):
+                    if output_offset[0] == output_buffer.shape[2]:
+                        pending_samples = [(idx, samples, ended)] + pending_samples
+                        run_up_state.pending_samples.extend(pending_samples)
+                        return output_buffer
+                    sample1 = samples.pop(0)
+                    self.run_up(idx, [sample1], ended and len(samples) == 0, run_up_state, output_buffer, output_offset)
+
+            if run_up_state.output_frames is None and not run_up_state.pending_samples:
+                clear_temporal_cache_state(self)
+            return output_buffer[:, :, :output_offset[0]]
+        except Exception:
+            clear_temporal_cache_state(self)
+            raise
 
     def forward(
         self,
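
A minimal sketch of how a caller might drive the start/resume protocol
above. The driver is hypothetical: decode_in_chunks, decoder, latents and
chunk_t are illustrative names only, and the termination convention is
inferred from the patch (forward_resume returns None when the calling
thread has no parked RunUpState, and a buffer shorter than the requested
output_t once that state is fully drained). Per the removed TODO, the
end-of-stream signalling may still change.

    import torch

    def decode_in_chunks(decoder, latents, chunk_t=8):
        # decoder: a Decoder instance; latents: a [B, C, T, H, W] latent
        # tensor. Both are assumptions for illustration.
        frames = []

        # forward_start runs the decode with output_buffer=None, parking
        # anything it cannot emit in this thread's RunUpState.
        first = decoder.forward_start(latents)
        if first is not None:
            frames.append(first)

        while True:
            # Each resume call drains at most chunk_t parked frames.
            chunk = decoder.forward_resume(chunk_t)
            if chunk is None:
                break  # no RunUpState for this thread: nothing to resume
            frames.append(chunk)
            if chunk.shape[2] < chunk_t:
                break  # short buffer: parked state drained and cleared

        return torch.cat(frames, dim=2) if frames else None

If a drain happens to finish with exactly chunk_t frames, forward_resume
clears the per-thread state and the next call returns None, so the loop
above still terminates.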