implementation plan

commit 9dede56fac (parent: 3d21d2afb6)
@@ -553,6 +553,10 @@ class Decoder(nn.Module):
             t = sample.shape[2]
             output_buffer[:, :, output_offset[0]:output_offset[0] + t].copy_(sample)
             output_offset[0] += t
+            # If there isn't space in the output buffer, you need to stash the unconsumed
+            # frames in the new state as their own thing. forward_resume then just immediately
+            # copies those into the next slice of output. Do not clone. Just slice.
+            # The VRAM is not a big deal.
             return
 
         up_block = self.up_blocks[idx]
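The stash-on-overflow behavior those added comments describe can be made concrete. Below is a minimal sketch, assuming a free function `write_or_stash` and a plain `stash` list standing in for the "new state"; both are illustrative names, not from this commit:

```python
import torch

def write_or_stash(sample, output_buffer, output_offset, stash):
    # Copy as many frames of `sample` as fit into output_buffer along T (dim 2).
    t = sample.shape[2]
    space = output_buffer.shape[2] - output_offset[0]
    n = min(t, space)
    if n > 0:
        output_buffer[:, :, output_offset[0]:output_offset[0] + n].copy_(sample[:, :, :n])
        output_offset[0] += n
    if n < t:
        # Unconsumed frames: keep a slice, not a clone, as the plan says; the
        # retained VRAM is acceptable and forward_resume just copies this
        # straight into the next output slice.
        stash.append(sample[:, :, n:])

# Tiny check: a 7-frame sample against 5 frames of remaining buffer space.
buf, off, stash = torch.zeros(1, 4, 5, 8, 8), [0], []
write_or_stash(torch.randn(1, 4, 7, 8, 8), buf, off, stash)
assert off[0] == 5 and stash[0].shape[2] == 2
```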
@@ -575,11 +579,17 @@ class Decoder(nn.Module):
             # when we are not chunking, detach our x so the callee can free it as soon as they are done
             next_sample_ref = [sample]
             del sample
+            # Just let this run_up run unconditionally regardless; it's ok because either a
+            # lower-layer chunker or the output-frame stash will do the work anyway, so this is unchanged.
             self.run_up(idx + 1, next_sample_ref, ended, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
             return
         else:
             samples = torch.chunk(sample, chunks=num_chunks, dim=2)
 
+            # This loop might need to become a while.
+            # If the output buffer is exhausted (or None), it needs to stash whatever is left of
+            # the samples list to the new state.
+            # Exhaustion is detectable here with output_offset[0] vs the output_buffer shape in T.
             for chunk_idx, sample1 in enumerate(samples):
                 self.run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
 
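The for-to-while change sketched in those comments would look roughly like this; `run_chunks`, `run_up_fn` (standing in for the bound `self.run_up(idx + 1, ...)` call), and `stash` are assumed names:

```python
def run_chunks(samples, ended, run_up_fn, stash, output_buffer, output_offset):
    # Drain the chunk list until the output buffer is exhausted; whatever is
    # left over becomes new state for the next forward_resume call.
    samples = list(samples)
    i = 0
    while i < len(samples):
        # The exhaustion check the comments describe: write offset vs T extent.
        if output_buffer is None or output_offset[0] >= output_buffer.shape[2]:
            stash.append(samples[i:])
            return False          # not finished; remainder stashed as new state
        run_up_fn([samples[i]], ended and i == len(samples) - 1)
        i += 1
    return True                   # every chunk was consumed
```

Returning a flag rather than raising keeps the caller in charge of re-stashing, which matches the "restash incomplete lists" note in the forward_resume plan below.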
@@ -630,6 +640,8 @@ class Decoder(nn.Module):
         )
         timestep_shift_scale = ada_values.unbind(dim=1)
 
+        # The meaning of output_buffer == None changes.
+        # forward now needs to do this for the non-start/resume case.
         if output_buffer is None:
             output_buffer = torch.empty(
                 self.decode_output_shape(sample.shape),
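One way to read the changed `output_buffer == None` contract: the plain forward path allocates, while the start/resume path treats None as exhausted from the outset. A hedged sketch, where `ensure_output_buffer` and `resumable` are assumed names:

```python
import torch

def ensure_output_buffer(output_buffer, shape, device, dtype, resumable):
    # Caller-provided buffer: use it as-is.
    if output_buffer is not None:
        return output_buffer, False
    # Start/resume path: None now means "already exhausted", so everything
    # gets parked as resumable state instead of being decoded eagerly.
    if resumable:
        return None, True
    # Plain forward path: allocate the full output up front.
    return torch.empty(shape, device=device, dtype=dtype), False
```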
@@ -643,6 +655,24 @@ class Decoder(nn.Module):
 
         return output_buffer
 
+    def forward_start(self, *args, **kwargs):
+        # output_buffer == None implies initial exhaustion, so this should park it in a
+        # resumable state at the bottom of the run_up stack.
+        raise NotImplementedError("Decoder.forward_start is not implemented yet")
+
+    # completely new function (maybe)
+    def forward_resume(self, *args, **kwargs):
+        # your code here
+        # Inspect the new state. Pop any complete frames first.
+        # Then execute run_up on the highest-index frame. You will need to be the list iterator
+        # for chunked run_up stashes and potentially re-stash incomplete lists.
+        # Come down the indices, stashing highest to lowest, and just like run_up you need to
+        # return if output_buffer is exhausted, with the extra run_up calls potentially leaving
+        # behind new state for the next time this is called.
+        # If we have truly finished, do the same logic as the finally below to clean up.
+        # We will design the protocol for signalling the end to the caller later.
+        raise NotImplementedError("Decoder.forward_resume is not implemented yet")
+
     def forward(self, *args, **kwargs):
         try:
             return self.forward_orig(*args, **kwargs)
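To make the resume protocol above concrete, here is a sketch of the "pop any complete frames first" step. `ResumeState`, `ready_frames`, `pending_chunks`, and `drain_ready_frames` are illustrative names, under the assumption that completed frames are stashed separately from per-level chunk lists:

```python
import torch

class ResumeState:
    def __init__(self):
        self.ready_frames = []     # decoded tensors waiting for buffer space
        self.pending_chunks = {}   # level index -> leftover chunk lists

def drain_ready_frames(state, output_buffer, output_offset):
    # Pop complete frames first, copying slices into the buffer until it is
    # full; return False when exhausted so the caller keeps the rest as state.
    while state.ready_frames:
        sample = state.ready_frames[0]
        space = output_buffer.shape[2] - output_offset[0]
        if space == 0:
            return False
        t = min(sample.shape[2], space)
        output_buffer[:, :, output_offset[0]:output_offset[0] + t].copy_(sample[:, :, :t])
        output_offset[0] += t
        if t < sample.shape[2]:
            state.ready_frames[0] = sample[:, :, t:]   # slice, do not clone
            return False
        state.ready_frames.pop(0)
    return True    # drained; run_up on the highest stashed index can proceed
```

Only after this drain returns True would forward_resume walk the stashed indices from highest to lowest, per the plan; the end-of-stream signalling protocol is still to be designed.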