From ab9e006873adea59131dac58ee0aef3ba53a5a03 Mon Sep 17 00:00:00 2001
From: Rattus
Date: Wed, 8 Apr 2026 09:08:28 +1000
Subject: [PATCH] ltx: vae: Move run_up() constants into a state object.

Consolidate the arguments that stay constant across the run_up()
recursion (timestep_shift_scale, scaled_timestep, checkpoint_fn) into a
single RunUpState object; it will grow more fields in later commits.
Store it on the Decoder module itself, keyed by thread id, so it can be
reused.
---
 .../vae/causal_video_autoencoder.py           | 31 +++++++++++++------
 1 file changed, 22 insertions(+), 9 deletions(-)

diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
index e137ae28d..ab1990898 100644
--- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
+++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
@@ -16,6 +16,12 @@ from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed
 
 ops = comfy.ops.disable_weight_init
 
+class RunUpState:
+    def __init__(self, timestep_shift_scale, scaled_timestep, checkpoint_fn):
+        self.timestep_shift_scale = timestep_shift_scale
+        self.scaled_timestep = scaled_timestep
+        self.checkpoint_fn = checkpoint_fn
+
 def in_meta_context():
     return torch.device("meta") == torch.empty(0).device
 
@@ -530,19 +536,20 @@ class Decoder(nn.Module):
             ).unsqueeze(1).expand(2, output_channel), persistent=False,
         )
+        self.temporal_cache_state = {}
 
     def decode_output_shape(self, input_shape):
         c, (ts, hs, ws), to = self._output_scale
         return (input_shape[0], c, input_shape[2] * ts - to, input_shape[3] * hs, input_shape[4] * ws)
 
-    def run_up(self, idx, sample_ref, ended, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size):
+    def run_up(self, idx, sample_ref, ended, run_up_state, output_buffer, output_offset, max_chunk_size):
         sample = sample_ref[0]
         sample_ref[0] = None
 
         if idx >= len(self.up_blocks):
             sample = self.conv_norm_out(sample)
-            if timestep_shift_scale is not None:
-                shift, scale = timestep_shift_scale
+            if run_up_state.timestep_shift_scale is not None:
+                shift, scale = run_up_state.timestep_shift_scale
                 sample = sample * (1 + scale) + shift
             sample = self.conv_act(sample)
             if ended:
@@ -563,11 +570,11 @@
             if ended:
                 mark_conv3d_ended(up_block)
             if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
-                sample = checkpoint_fn(up_block)(
-                    sample, causal=self.causal, timestep=scaled_timestep
+                sample = run_up_state.checkpoint_fn(up_block)(
+                    sample, causal=self.causal, timestep=run_up_state.scaled_timestep
                 )
             else:
-                sample = checkpoint_fn(up_block)(sample, causal=self.causal)
+                sample = run_up_state.checkpoint_fn(up_block)(sample, causal=self.causal)
 
             if sample is None or sample.shape[2] == 0:
                 return
@@ -581,7 +588,7 @@
             del sample
             #Just let this run_up unconditionally regardless of, its ok because either a lower layer
             #chunker or output frame stash will do the work anyway. so unchanged.
-            self.run_up(idx + 1, next_sample_ref, ended, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
+            self.run_up(idx + 1, next_sample_ref, ended, run_up_state, output_buffer, output_offset, max_chunk_size)
             return
         else:
             samples = torch.chunk(sample, chunks=num_chunks, dim=2)
@@ -591,7 +598,7 @@
             #list to new state.
             #exhaustion is detectable here with output_offset[0] vs output_buffer shape in T.
             for chunk_idx, sample1 in enumerate(samples):
-                self.run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
+                self.run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1, run_up_state, output_buffer, output_offset, max_chunk_size)
 
     def forward_orig(
         self,
@@ -643,8 +650,14 @@
         output_offset = [0]
         max_chunk_size = get_max_chunk_size(sample.device)
 
-        self.run_up(0, [sample], True, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
+        run_up_state = RunUpState(
+            timestep_shift_scale=timestep_shift_scale,
+            scaled_timestep=scaled_timestep,
+            checkpoint_fn=checkpoint_fn,
+        )
+        self.temporal_cache_state[threading.get_ident()] = run_up_state
+        self.run_up(0, [sample], True, run_up_state, output_buffer, output_offset, max_chunk_size)
 
         return output_buffer
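
Note: the consolidated-state pattern is easiest to see in isolation. Below
is a minimal, self-contained sketch of it. RunUpState and the thread-keyed
temporal_cache_state mirror this patch; ToyDecoder, its forward()/run_up()
bodies, and the numeric values are illustrative stand-ins, not ComfyUI's
actual Decoder API.

    import threading

    class RunUpState:
        # Bundles the values that stay constant across one decode's
        # run_up() recursion, instead of threading them through every
        # call site as three loose parameters.
        def __init__(self, timestep_shift_scale, scaled_timestep, checkpoint_fn):
            self.timestep_shift_scale = timestep_shift_scale
            self.scaled_timestep = scaled_timestep
            self.checkpoint_fn = checkpoint_fn

    class ToyDecoder:
        def __init__(self):
            # One state object per decoding thread, keyed by thread id,
            # so concurrent decodes do not clobber each other's state.
            self.temporal_cache_state = {}

        def forward(self, sample):
            state = RunUpState(
                timestep_shift_scale=(0.5, 2.0),    # hypothetical (shift, scale)
                scaled_timestep=None,
                checkpoint_fn=lambda block: block,  # identity stand-in for checkpointing
            )
            self.temporal_cache_state[threading.get_ident()] = state
            return self.run_up(sample, state)

        def run_up(self, sample, state):
            # New fields can be added to RunUpState later without
            # touching any of the recursive call sites.
            if state.timestep_shift_scale is not None:
                shift, scale = state.timestep_shift_scale
                sample = sample * (1 + scale) + shift
            return state.checkpoint_fn(lambda x: x)(sample)

    print(ToyDecoder().forward(1.0))  # 1.0 * (1 + 2.0) + 0.5 -> 3.5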