mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-27 10:52:31 +08:00
ltx: vae: implement start and resume protocol
This commit is contained in:
parent
06381d5d18
commit
b23f1f456f
@ -17,11 +17,13 @@ from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
class RunUpState:
    """Per-thread state parked between a decode start and its resume calls.

    Captures everything the chunked VAE run-up needs to pick up where it
    left off: the timestep conditioning, the checkpointing callable, the
    chunking limit, and the shape/dtype of the buffer to decode into.
    """

    def __init__(self, timestep_shift_scale, scaled_timestep, checkpoint_fn, max_chunk_size, output_shape, output_dtype, output_frames=None):
        # Timestep conditioning captured when the decode started.
        self.timestep_shift_scale = timestep_shift_scale
        self.scaled_timestep = scaled_timestep
        # Callable used to (optionally) checkpoint run-up computation.
        self.checkpoint_fn = checkpoint_fn
        # Upper bound on how many samples a single run-up step may process.
        self.max_chunk_size = max_chunk_size
        # Layout of the full decode output; resume allocates buffers from it.
        self.output_shape = output_shape
        self.output_dtype = output_dtype
        # Frames already decoded but not yet handed to the caller (or None).
        self.output_frames = output_frames
        # (idx, samples, ended) tuples still awaiting run-up; fresh per instance.
        self.pending_samples = []
|
||||
|
||||
@ -614,6 +616,7 @@ class Decoder(nn.Module):
|
||||
) -> torch.FloatTensor:
|
||||
r"""The forward method of the `Decoder` class."""
|
||||
batch_size = sample.shape[0]
|
||||
output_shape = self.decode_output_shape(sample.shape)
|
||||
|
||||
mark_conv3d_ended(self.conv_in)
|
||||
sample = self.conv_in(sample, causal=self.causal)
|
||||
@ -660,6 +663,8 @@ class Decoder(nn.Module):
|
||||
scaled_timestep=scaled_timestep,
|
||||
checkpoint_fn=checkpoint_fn,
|
||||
max_chunk_size=get_max_chunk_size(sample.device),
|
||||
output_shape=output_shape,
|
||||
output_dtype=sample.dtype,
|
||||
)
|
||||
self.temporal_cache_state[threading.get_ident()] = run_up_state
|
||||
|
||||
@ -667,23 +672,57 @@ class Decoder(nn.Module):
|
||||
|
||||
return output_buffer
|
||||
|
||||
def forward_start(self, *args, **kwargs):
|
||||
#output_buffer == None implies initial exhaustion, so this should park it in a
|
||||
#resumable state on the bottom of the run_up stack.
|
||||
raise NotImplementedError("Decoder.forward_start is not implemented yet")
|
||||
def forward_start(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Optional[torch.Tensor] = None,
|
||||
):
|
||||
try:
|
||||
return self.forward_orig(sample, timestep=timestep, output_buffer=None)
|
||||
except Exception:
|
||||
clear_temporal_cache_state(self)
|
||||
raise
|
||||
|
||||
    def forward_resume(self, output_t: int):
        """Resume a parked decode run-up, emitting up to ``output_t`` frames.

        Returns ``None`` when no state was parked for the calling thread;
        otherwise returns a tensor shaped like the full decode output but
        truncated to at most ``output_t`` frames along dim 2.
        """
        tid = threading.get_ident()
        run_up_state = self.temporal_cache_state.get(tid, None)
        if run_up_state is None:
            # Nothing parked by an earlier start/resume for this thread.
            return None

        # Allocate the caller-visible buffer: same layout as the full decode
        # output, truncated to output_t frames on the time axis (dim 2).
        output_shape = list(run_up_state.output_shape)
        output_shape[2] = output_t
        output_buffer = torch.empty(
            output_shape,
            dtype=run_up_state.output_dtype, device=comfy.model_management.intermediate_device(),
        )
        # Single-element list so run_up can advance the write cursor in place.
        output_offset = [0]

        try:
            # First drain frames already decoded on a previous call.
            if run_up_state.output_frames is not None:
                # Slicing clamps, so t == min(output_t, stashed frame count).
                output_slice = output_buffer[:, :, :run_up_state.output_frames.shape[2]]
                t = output_slice.shape[2]
                output_slice.copy_(run_up_state.output_frames[:, :, :t])
                output_offset[0] += t
                # Keep any leftover frames parked if they did not all fit.
                run_up_state.output_frames = None if t == run_up_state.output_frames.shape[2] else run_up_state.output_frames[:, :, t:]

            # Take ownership of the pending chunk list; run_up may stash new
            # pending work back into run_up_state.pending_samples as we go.
            pending_samples = run_up_state.pending_samples
            run_up_state.pending_samples = []
            while len(pending_samples):
                idx, samples, ended = pending_samples.pop(0)
                while len(samples):
                    if output_offset[0] == output_buffer.shape[2]:
                        # Buffer full: re-park the unprocessed remainder
                        # (current chunk first) for the next resume call.
                        pending_samples = [(idx, samples, ended)] + pending_samples
                        run_up_state.pending_samples.extend(pending_samples)
                        return output_buffer
                    sample1 = samples.pop(0)
                    # Propagate `ended` only on the last chunk of this index.
                    self.run_up(idx, [sample1], ended and len(samples) == 0, run_up_state, output_buffer, output_offset)

            if run_up_state.output_frames is None and not run_up_state.pending_samples:
                # Truly finished: drop the parked state and return only the
                # frames actually written.
                clear_temporal_cache_state(self)
                return output_buffer[:, :, :output_offset[0]]
            # NOTE(review): when run_up leaves new pending state but the
            # buffer is not full, control falls through and implicitly
            # returns None — confirm this matches the intended end-signalling
            # protocol for callers.
        except Exception:
            # On any failure, discard the parked state so a later call does
            # not resume from a corrupted run-up.
            clear_temporal_cache_state(self)
            raise
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user