From ce054bbf2d30fa7f462b8904c39e3389fbaeee32 Mon Sep 17 00:00:00 2001
From: Rattus
Date: Wed, 8 Apr 2026 09:48:51 +1000
Subject: [PATCH] ltx: vae: move max_chunk_size to the RunUpState

---
 .../lightricks/vae/causal_video_autoencoder.py | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
index 209f35758..0b2d0edcf 100644
--- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
+++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
@@ -17,10 +17,11 @@ from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed
 ops = comfy.ops.disable_weight_init
 
 class RunUpState:
-    def __init__(self, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_frames=None):
+    def __init__(self, timestep_shift_scale, scaled_timestep, checkpoint_fn, max_chunk_size, output_frames=None):
         self.timestep_shift_scale = timestep_shift_scale
         self.scaled_timestep = scaled_timestep
         self.checkpoint_fn = checkpoint_fn
+        self.max_chunk_size = max_chunk_size
         self.output_frames = output_frames
 
 def in_meta_context():
@@ -544,7 +545,7 @@ class Decoder(nn.Module):
         c, (ts, hs, ws), to = self._output_scale
         return (input_shape[0], c, input_shape[2] * ts - to, input_shape[3] * hs, input_shape[4] * ws)
 
-    def run_up(self, idx, sample_ref, ended, run_up_state, output_buffer, output_offset, max_chunk_size):
+    def run_up(self, idx, sample_ref, ended, run_up_state, output_buffer, output_offset):
         sample = sample_ref[0]
         sample_ref[0] = None
         if idx >= len(self.up_blocks):
@@ -583,7 +584,7 @@ class Decoder(nn.Module):
             return
 
         total_bytes = sample.numel() * sample.element_size()
-        num_chunks = (total_bytes + max_chunk_size - 1) // max_chunk_size
+        num_chunks = (total_bytes + run_up_state.max_chunk_size - 1) // run_up_state.max_chunk_size
 
         if num_chunks == 1:
             # when we are not chunking, detach our x so the callee can free it as soon as they are done
@@ -591,7 +592,7 @@ class Decoder(nn.Module):
             del sample
             #Just let this run_up unconditionally regardless of, its ok because either a lower layer
             #chunker or output frame stash will do the work anyway. so unchanged.
-            self.run_up(idx + 1, next_sample_ref, ended, run_up_state, output_buffer, output_offset, max_chunk_size)
+            self.run_up(idx + 1, next_sample_ref, ended, run_up_state, output_buffer, output_offset)
             return
         else:
             samples = torch.chunk(sample, chunks=num_chunks, dim=2)
@@ -601,7 +602,7 @@ class Decoder(nn.Module):
             #list to new state.
             #exhaustion is detectable here with output_offset[0] vs output_buffer shape in T.
             for chunk_idx, sample1 in enumerate(samples):
-                self.run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1, run_up_state, output_buffer, output_offset, max_chunk_size)
+                self.run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1, run_up_state, output_buffer, output_offset)
 
     def forward_orig(
         self,
@@ -652,15 +653,15 @@ class Decoder(nn.Module):
 
         output_offset = [0]
 
-        max_chunk_size = get_max_chunk_size(sample.device)
         run_up_state = RunUpState(
             timestep_shift_scale=timestep_shift_scale,
             scaled_timestep=scaled_timestep,
             checkpoint_fn=checkpoint_fn,
+            max_chunk_size=get_max_chunk_size(sample.device),
         )
         self.temporal_cache_state[threading.get_ident()] = run_up_state
 
-        self.run_up(0, [sample], True, run_up_state, output_buffer, output_offset, max_chunk_size)
+        self.run_up(0, [sample], True, run_up_state, output_buffer, output_offset)
 
         return output_buffer
 
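
A minimal standalone sketch of the pattern this commit applies, for illustration only: the per-call max_chunk_size argument is folded into the shared RunUpState, so the recursive run_up calls no longer thread it through every level. The class and attribute names mirror the diff; the simplified function body and the 512 MiB chunk size are assumptions made for the example (the patch itself derives the value from get_max_chunk_size(sample.device)).

    class RunUpState:
        def __init__(self, timestep_shift_scale, scaled_timestep, checkpoint_fn,
                     max_chunk_size, output_frames=None):
            self.timestep_shift_scale = timestep_shift_scale
            self.scaled_timestep = scaled_timestep
            self.checkpoint_fn = checkpoint_fn
            self.max_chunk_size = max_chunk_size  # chunk sizing now lives on the shared state
            self.output_frames = output_frames

    def run_up(idx, sample_bytes, run_up_state):
        # Chunk count comes from the state object instead of an extra positional argument.
        return (sample_bytes + run_up_state.max_chunk_size - 1) // run_up_state.max_chunk_size

    state = RunUpState(
        timestep_shift_scale=0.0,
        scaled_timestep=None,
        checkpoint_fn=None,
        max_chunk_size=512 * 1024 * 1024,  # assumed value for the sketch
    )
    print(run_up(0, 3 * 1024 ** 3, state))  # a 3 GiB sample splits into 6 chunks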