From 930df2d70bcbf554e4e04f0fe652731d11d0e1eb Mon Sep 17 00:00:00 2001 From: Rattus Date: Wed, 8 Apr 2026 09:36:57 +1000 Subject: [PATCH] ltx: vae: save left over tail frames to state If it doesnt fit stash it. --- .../lightricks/vae/causal_video_autoencoder.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py index ab1990898..209f35758 100644 --- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py +++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py @@ -17,10 +17,11 @@ from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed ops = comfy.ops.disable_weight_init class RunUpState: - def __init__(self, timestep_shift_scale, scaled_timestep, checkpoint_fn): + def __init__(self, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_frames=None): self.timestep_shift_scale = timestep_shift_scale self.scaled_timestep = scaled_timestep self.checkpoint_fn = checkpoint_fn + self.output_frames = output_frames def in_meta_context(): return torch.device("meta") == torch.empty(0).device @@ -557,13 +558,15 @@ class Decoder(nn.Module): sample = self.conv_out(sample, causal=self.causal) if sample is not None and sample.shape[2] > 0: sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) - t = sample.shape[2] - output_buffer[:, :, output_offset[0]:output_offset[0] + t].copy_(sample) + if output_buffer is None: + run_up_state.output_frames = sample + return + output_slice = output_buffer[:, :, output_offset[0]:output_offset[0] + sample.shape[2]] + t = output_slice.shape[2] + output_slice.copy_(sample[:, :, :t]) output_offset[0] += t - #if there isnt space in the output buffer, you need to stash unconsumed - #frames in the new state as its own thing. Forward resume then just immedatiately - #copies those into the next slice of output. Do not clone. Just slice. - #The VRAM is not a big deal. + if t < sample.shape[2]: + run_up_state.output_frames = sample[:, :, t:] return up_block = self.up_blocks[idx]