ltx: vae: implement start and resume protocol

Rattus 2026-04-08 16:14:52 +10:00
parent 06381d5d18
commit b23f1f456f


@@ -17,11 +17,13 @@ from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed
 ops = comfy.ops.disable_weight_init
 
 class RunUpState:
-    def __init__(self, timestep_shift_scale, scaled_timestep, checkpoint_fn, max_chunk_size, output_frames=None):
+    def __init__(self, timestep_shift_scale, scaled_timestep, checkpoint_fn, max_chunk_size, output_shape, output_dtype, output_frames=None):
         self.timestep_shift_scale = timestep_shift_scale
         self.scaled_timestep = scaled_timestep
         self.checkpoint_fn = checkpoint_fn
         self.max_chunk_size = max_chunk_size
+        self.output_shape = output_shape
+        self.output_dtype = output_dtype
         self.output_frames = output_frames
         self.pending_samples = []
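
The two fields added here let a later resume call size a fresh output buffer without re-deriving anything from the input latent: `forward_resume` (in the last hunk below) clones the stashed shape, overrides its temporal dimension, and allocates with the stashed dtype. A minimal sketch of that use, assuming the (B, C, T, H, W) layout implied by the code indexing dim 2:

    # Sketch only: rebuild a caller-sized buffer from the stashed shape/dtype.
    # output_t is the number of frames the caller asked for on this call.
    shape = list(run_up_state.output_shape)
    shape[2] = output_t  # temporal dimension only; batch/channel/spatial stay fixed
    buffer = torch.empty(shape, dtype=run_up_state.output_dtype)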
@@ -614,6 +616,7 @@ class Decoder(nn.Module):
     ) -> torch.FloatTensor:
         r"""The forward method of the `Decoder` class."""
         batch_size = sample.shape[0]
+        output_shape = self.decode_output_shape(sample.shape)
         mark_conv3d_ended(self.conv_in)
         sample = self.conv_in(sample, causal=self.causal)
@@ -660,6 +663,8 @@
             scaled_timestep=scaled_timestep,
             checkpoint_fn=checkpoint_fn,
             max_chunk_size=get_max_chunk_size(sample.device),
+            output_shape=output_shape,
+            output_dtype=sample.dtype,
         )
         self.temporal_cache_state[threading.get_ident()] = run_up_state
@@ -667,23 +672,57 @@
         return output_buffer
 
-    def forward_start(self, *args, **kwargs):
-        #output_buffer == None implies initial exhaustion, so this should park it in a
-        #resumable state on the bottom of the run_up stack.
-        raise NotImplementedError("Decoder.forward_start is not implemented yet")
-
-    #completely new function (maybe)
-    def forward_resume(self, *args, **kwargs):
-        #your code here
-        #inspect the new state. Pop any complete frames first.
-        #Then execute run_up on the highest index frame. You will need to be the list iterator
-        #for chunked run_up stashes and potentially restash incompleted lists.
-        #come down the indicies as stash highest to lowest and just like run_up you need to
-        #return if output_buffer is exhausted with the extra run_up calls potentially leaving
-        #behind new state for the next time this is called.
-        #if we have truly finished, do the same logic as the finally below to clean up.
-        #we will design the protocol for signalling end to the caller later.
-        raise NotImplementedError("Decoder.forward_resume is not implemented yet")
+    def forward_start(
+        self,
+        sample: torch.FloatTensor,
+        timestep: Optional[torch.Tensor] = None,
+    ):
+        try:
+            return self.forward_orig(sample, timestep=timestep, output_buffer=None)
+        except Exception:
+            clear_temporal_cache_state(self)
+            raise
+
+    def forward_resume(self, output_t: int):
+        tid = threading.get_ident()
+        run_up_state = self.temporal_cache_state.get(tid, None)
+        if run_up_state is None:
+            return None
+
+        output_shape = list(run_up_state.output_shape)
+        output_shape[2] = output_t
+        output_buffer = torch.empty(
+            output_shape,
+            dtype=run_up_state.output_dtype, device=comfy.model_management.intermediate_device(),
+        )
+        output_offset = [0]
+        try:
+            if run_up_state.output_frames is not None:
+                output_slice = output_buffer[:, :, :run_up_state.output_frames.shape[2]]
+                t = output_slice.shape[2]
+                output_slice.copy_(run_up_state.output_frames[:, :, :t])
+                output_offset[0] += t
+                run_up_state.output_frames = None if t == run_up_state.output_frames.shape[2] else run_up_state.output_frames[:, :, t:]
+
+            pending_samples = run_up_state.pending_samples
+            run_up_state.pending_samples = []
+            while len(pending_samples):
+                idx, samples, ended = pending_samples.pop(0)
+                while len(samples):
+                    if output_offset[0] == output_buffer.shape[2]:
+                        pending_samples = [(idx, samples, ended)] + pending_samples
+                        run_up_state.pending_samples.extend(pending_samples)
+                        return output_buffer
+                    sample1 = samples.pop(0)
+                    self.run_up(idx, [sample1], ended and len(samples) == 0, run_up_state, output_buffer, output_offset)
+
+            if run_up_state.output_frames is None and not run_up_state.pending_samples:
+                clear_temporal_cache_state(self)
+            return output_buffer[:, :, :output_offset[0]]
+        except Exception:
+            clear_temporal_cache_state(self)
+            raise
 
     def forward(
         self,
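
Taken together, the start/resume protocol this commit lands is two-phase: call `forward_start` once, which runs the decode with `output_buffer=None` so every decoded-but-unemitted frame is parked in the per-thread `RunUpState`; then call `forward_resume(output_t)` repeatedly to drain up to `output_t` frames per call until the state is cleared. A caller-side sketch of that loop, where `decode_in_chunks`, `decoder`, `latent`, and `chunk_t` are hypothetical names and the end-of-stream signals (a short chunk, then `None`) are inferred from the code above:

    import torch

    def decode_in_chunks(decoder, latent, timestep=None, chunk_t=8):
        # Hypothetical driver for the start/resume protocol above.
        chunks = []
        out = decoder.forward_start(latent, timestep=timestep)  # parks state
        if out is not None:
            chunks.append(out)
        while True:
            out = decoder.forward_resume(chunk_t)  # drain up to chunk_t frames
            if out is None:                        # no state left: already cleared
                break
            chunks.append(out)
            if out.shape[2] < chunk_t:             # short chunk: fully drained
                break
        return torch.cat(chunks, dim=2)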