ltx: vae: move max_chunk_size to the RunUpState

This commit is contained in:
Rattus 2026-04-08 09:48:51 +10:00
parent 930df2d70b
commit ce054bbf2d

View File

@ -17,10 +17,11 @@ from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed
ops = comfy.ops.disable_weight_init ops = comfy.ops.disable_weight_init
class RunUpState: class RunUpState:
def __init__(self, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_frames=None): def __init__(self, timestep_shift_scale, scaled_timestep, checkpoint_fn, max_chunk_size, output_frames=None):
self.timestep_shift_scale = timestep_shift_scale self.timestep_shift_scale = timestep_shift_scale
self.scaled_timestep = scaled_timestep self.scaled_timestep = scaled_timestep
self.checkpoint_fn = checkpoint_fn self.checkpoint_fn = checkpoint_fn
self.max_chunk_size = max_chunk_size
self.output_frames = output_frames self.output_frames = output_frames
def in_meta_context(): def in_meta_context():
@ -544,7 +545,7 @@ class Decoder(nn.Module):
c, (ts, hs, ws), to = self._output_scale c, (ts, hs, ws), to = self._output_scale
return (input_shape[0], c, input_shape[2] * ts - to, input_shape[3] * hs, input_shape[4] * ws) return (input_shape[0], c, input_shape[2] * ts - to, input_shape[3] * hs, input_shape[4] * ws)
def run_up(self, idx, sample_ref, ended, run_up_state, output_buffer, output_offset, max_chunk_size): def run_up(self, idx, sample_ref, ended, run_up_state, output_buffer, output_offset):
sample = sample_ref[0] sample = sample_ref[0]
sample_ref[0] = None sample_ref[0] = None
if idx >= len(self.up_blocks): if idx >= len(self.up_blocks):
@ -583,7 +584,7 @@ class Decoder(nn.Module):
return return
total_bytes = sample.numel() * sample.element_size() total_bytes = sample.numel() * sample.element_size()
num_chunks = (total_bytes + max_chunk_size - 1) // max_chunk_size num_chunks = (total_bytes + run_up_state.max_chunk_size - 1) // run_up_state.max_chunk_size
if num_chunks == 1: if num_chunks == 1:
# when we are not chunking, detach our x so the callee can free it as soon as they are done # when we are not chunking, detach our x so the callee can free it as soon as they are done
@ -591,7 +592,7 @@ class Decoder(nn.Module):
del sample del sample
#Just let this run_up unconditionally regardless; it's ok because either a lower layer #Just let this run_up unconditionally regardless; it's ok because either a lower layer
#chunker or output frame stash will do the work anyway, so it's unchanged. #chunker or output frame stash will do the work anyway, so it's unchanged.
self.run_up(idx + 1, next_sample_ref, ended, run_up_state, output_buffer, output_offset, max_chunk_size) self.run_up(idx + 1, next_sample_ref, ended, run_up_state, output_buffer, output_offset)
return return
else: else:
samples = torch.chunk(sample, chunks=num_chunks, dim=2) samples = torch.chunk(sample, chunks=num_chunks, dim=2)
@ -601,7 +602,7 @@ class Decoder(nn.Module):
#list to new state. #list to new state.
#exhaustion is detectable here with output_offset[0] vs output_buffer shape in T. #exhaustion is detectable here with output_offset[0] vs output_buffer shape in T.
for chunk_idx, sample1 in enumerate(samples): for chunk_idx, sample1 in enumerate(samples):
self.run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1, run_up_state, output_buffer, output_offset, max_chunk_size) self.run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1, run_up_state, output_buffer, output_offset)
def forward_orig( def forward_orig(
self, self,
@ -652,15 +653,15 @@ class Decoder(nn.Module):
output_offset = [0] output_offset = [0]
max_chunk_size = get_max_chunk_size(sample.device)
run_up_state = RunUpState( run_up_state = RunUpState(
timestep_shift_scale=timestep_shift_scale, timestep_shift_scale=timestep_shift_scale,
scaled_timestep=scaled_timestep, scaled_timestep=scaled_timestep,
checkpoint_fn=checkpoint_fn, checkpoint_fn=checkpoint_fn,
max_chunk_size=get_max_chunk_size(sample.device),
) )
self.temporal_cache_state[threading.get_ident()] = run_up_state self.temporal_cache_state[threading.get_ident()] = run_up_state
self.run_up(0, [sample], True, run_up_state, output_buffer, output_offset, max_chunk_size) self.run_up(0, [sample], True, run_up_state, output_buffer, output_offset)
return output_buffer return output_buffer