mirror of https://github.com/comfyanonymous/ComfyUI.git
ltx: vae: implement start and resume protocol
parent 06381d5d18
commit b23f1f456f
@@ -17,11 +17,13 @@ from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed
 ops = comfy.ops.disable_weight_init
 
 class RunUpState:
-    def __init__(self, timestep_shift_scale, scaled_timestep, checkpoint_fn, max_chunk_size, output_frames=None):
+    def __init__(self, timestep_shift_scale, scaled_timestep, checkpoint_fn, max_chunk_size, output_shape, output_dtype, output_frames=None):
         self.timestep_shift_scale = timestep_shift_scale
         self.scaled_timestep = scaled_timestep
         self.checkpoint_fn = checkpoint_fn
         self.max_chunk_size = max_chunk_size
+        self.output_shape = output_shape
+        self.output_dtype = output_dtype
         self.output_frames = output_frames
         self.pending_samples = []
 
@@ -614,6 +616,7 @@ class Decoder(nn.Module):
     ) -> torch.FloatTensor:
         r"""The forward method of the `Decoder` class."""
         batch_size = sample.shape[0]
+        output_shape = self.decode_output_shape(sample.shape)
 
         mark_conv3d_ended(self.conv_in)
         sample = self.conv_in(sample, causal=self.causal)
@@ -660,6 +663,8 @@ class Decoder(nn.Module):
                 scaled_timestep=scaled_timestep,
                 checkpoint_fn=checkpoint_fn,
                 max_chunk_size=get_max_chunk_size(sample.device),
+                output_shape=output_shape,
+                output_dtype=sample.dtype,
             )
             self.temporal_cache_state[threading.get_ident()] = run_up_state
 
@@ -667,23 +672,57 @@ class Decoder(nn.Module):
 
         return output_buffer
 
-    def forward_start(self, *args, **kwargs):
-        #output_buffer == None implies initial exhaustion, so this should park it in a
-        #resumable state on the bottom of the run_up stack.
-        raise NotImplementedError("Decoder.forward_start is not implemented yet")
-
-    #completely new function (maybe)
-    def forward_resume(self, *args, **kwargs):
-        #your code here
-        #inspect the new state. Pop any complete frames first.
-        #Then execute run_up on the highest index frame. You will need to be the list iterator
-        #for chunked run_up stashes and potentially restash incompleted lists.
-        #come down the indicies as stash highest to lowest and just like run_up you need to
-        #return if output_buffer is exhausted with the extra run_up calls potentially leaving
-        #behind new state for the next time this is called.
-        #if we have truly finished, do the same logic as the finally below to clean up.
-        #we will design the protocol for signalling end to the caller later.
-        raise NotImplementedError("Decoder.forward_resume is not implemented yet")
+    def forward_start(
+        self,
+        sample: torch.FloatTensor,
+        timestep: Optional[torch.Tensor] = None,
+    ):
+        try:
+            return self.forward_orig(sample, timestep=timestep, output_buffer=None)
+        except Exception:
+            clear_temporal_cache_state(self)
+            raise
+
+    def forward_resume(self, output_t: int):
+        tid = threading.get_ident()
+        run_up_state = self.temporal_cache_state.get(tid, None)
+        if run_up_state is None:
+            return None
+
+        output_shape = list(run_up_state.output_shape)
+        output_shape[2] = output_t
+        output_buffer = torch.empty(
+            output_shape,
+            dtype=run_up_state.output_dtype, device=comfy.model_management.intermediate_device(),
+        )
+        output_offset = [0]
+
+        try:
+            if run_up_state.output_frames is not None:
+                output_slice = output_buffer[:, :, :run_up_state.output_frames.shape[2]]
+                t = output_slice.shape[2]
+                output_slice.copy_(run_up_state.output_frames[:, :, :t])
+                output_offset[0] += t
+                run_up_state.output_frames = None if t == run_up_state.output_frames.shape[2] else run_up_state.output_frames[:, :, t:]
+
+            pending_samples = run_up_state.pending_samples
+            run_up_state.pending_samples = []
+            while len(pending_samples):
+                idx, samples, ended = pending_samples.pop(0)
+                while len(samples):
+                    if output_offset[0] == output_buffer.shape[2]:
+                        pending_samples = [(idx, samples, ended)] + pending_samples
+                        run_up_state.pending_samples.extend(pending_samples)
+                        return output_buffer
+                    sample1 = samples.pop(0)
+                    self.run_up(idx, [sample1], ended and len(samples) == 0, run_up_state, output_buffer, output_offset)
+
+            if run_up_state.output_frames is None and not run_up_state.pending_samples:
+                clear_temporal_cache_state(self)
+            return output_buffer[:, :, :output_offset[0]]
+        except Exception:
+            clear_temporal_cache_state(self)
+            raise
 
     def forward(
         self,
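Taken together, the two new methods form the start and resume protocol named in the commit message: forward_start runs the decode with output_buffer=None so that decoded frames are parked in the thread-local RunUpState, and forward_resume(output_t) then drains up to output_t frames per call until the state is exhausted and cleared. A minimal caller-side sketch (not part of this commit; decoder, latent, and chunk_t are illustrative names, and the return-value handling is inferred from the code above):

    import torch

    chunk_t = 8                                   # frames we are willing to materialize per resume call
    decoder.forward_start(latent, timestep=None)  # decode runs, output frames park in the RunUpState

    chunks = []
    while True:
        frames = decoder.forward_resume(chunk_t)
        if frames is None or frames.shape[2] == 0:
            break                                 # state already cleared, nothing left to drain
        chunks.append(frames)
        if frames.shape[2] < chunk_t:
            break                                 # short chunk: the run-up state was exhausted
    video = torch.cat(chunks, dim=2)              # dim 2 is the temporal axis of the output buffer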
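The trickiest piece of forward_resume is the pending_samples bookkeeping: each stash entry is an (idx, samples, ended) tuple, and when the per-call buffer fills mid-entry, the unfinished entry plus everything queued behind it is pushed back onto the state for the next call. A toy model of that restash behaviour (illustration only, not code from the commit; it pretends each sample yields exactly one output frame):

    def drain(pending, capacity):
        # Mimics the two nested loops in forward_resume: consume samples until
        # `capacity` frames are produced or the stash runs dry, restashing
        # whatever is left unfinished.
        produced = 0
        leftover = []
        while pending:
            idx, samples, ended = pending.pop(0)
            while samples:
                if produced == capacity:
                    leftover = [(idx, samples, ended)] + pending
                    return produced, leftover
                samples.pop(0)
                produced += 1  # stand-in for run_up() writing one sample's frames
        return produced, leftover

    stash = [(0, ["s0", "s1"], False), (1, ["s2"], True)]
    print(drain(stash, 2))  # -> (2, [(1, ['s2'], True)]): 's2' waits for the next resume call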