ltx: vae: save left over tail frames to state

If it doesn't fit, stash it.
Rattus 2026-04-08 09:36:57 +10:00
parent ab9e006873
commit 930df2d70b


@@ -17,10 +17,11 @@ from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed
 ops = comfy.ops.disable_weight_init
 class RunUpState:
-    def __init__(self, timestep_shift_scale, scaled_timestep, checkpoint_fn):
+    def __init__(self, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_frames=None):
         self.timestep_shift_scale = timestep_shift_scale
         self.scaled_timestep = scaled_timestep
         self.checkpoint_fn = checkpoint_fn
+        self.output_frames = output_frames
 def in_meta_context():
     return torch.device("meta") == torch.empty(0).device
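
The new output_frames slot defaults to None and only holds a tensor once the decoder has produced frames it could not yet write out. A minimal sketch of what that looks like from the caller's side (the argument values are placeholders; only the parameter names come from the hunk above):

    state = RunUpState(timestep_shift_scale=None, scaled_timestep=None, checkpoint_fn=None)
    assert state.output_frames is None  # nothing stashed until the decoder runs out of buffer space
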
@@ -557,13 +558,15 @@ class Decoder(nn.Module):
         sample = self.conv_out(sample, causal=self.causal)
         if sample is not None and sample.shape[2] > 0:
             sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
-            t = sample.shape[2]
-            output_buffer[:, :, output_offset[0]:output_offset[0] + t].copy_(sample)
+            if output_buffer is None:
+                run_up_state.output_frames = sample
+                return
+            output_slice = output_buffer[:, :, output_offset[0]:output_offset[0] + sample.shape[2]]
+            t = output_slice.shape[2]
+            output_slice.copy_(sample[:, :, :t])
             output_offset[0] += t
+            # If there isn't space in the output buffer, stash the unconsumed
+            # frames in the new state. A forward resume then immediately copies
+            # them into the next slice of output. Do not clone, just slice;
+            # the extra VRAM is not a big deal.
+            if t < sample.shape[2]:
+                run_up_state.output_frames = sample[:, :, t:]
+                return
         up_block = self.up_blocks[idx]
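
The diff only shows the producer side; how the stashed frames get consumed again is implied by the comment. A hedged sketch of that resume step, assuming the same buffer and offset conventions as above (the helper name drain_stashed_frames is hypothetical and not part of this commit):

    def drain_stashed_frames(run_up_state, output_buffer, output_offset):
        # Copy any previously stashed tail frames into the next slice of the
        # output buffer, then re-stash whatever still does not fit.
        stashed = run_up_state.output_frames
        if stashed is None:
            return
        space = output_buffer.shape[2] - output_offset[0]
        t = min(stashed.shape[2], space)
        output_buffer[:, :, output_offset[0]:output_offset[0] + t].copy_(stashed[:, :, :t])
        output_offset[0] += t
        run_up_state.output_frames = stashed[:, :, t:] if t < stashed.shape[2] else None
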