mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-27 19:02:31 +08:00
implementation plan
This commit is contained in:
parent
3d21d2afb6
commit
9dede56fac
@ -553,6 +553,10 @@ class Decoder(nn.Module):
|
|||||||
t = sample.shape[2]
|
t = sample.shape[2]
|
||||||
output_buffer[:, :, output_offset[0]:output_offset[0] + t].copy_(sample)
|
output_buffer[:, :, output_offset[0]:output_offset[0] + t].copy_(sample)
|
||||||
output_offset[0] += t
|
output_offset[0] += t
|
||||||
|
#if there isnt space in the output buffer, you need to stash unconsumed
|
||||||
|
#frames in the new state as its own thing. Forward resume then just immedatiately
|
||||||
|
#copies those into the next slice of output. Do not clone. Just slice.
|
||||||
|
#The VRAM is not a big deal.
|
||||||
return
|
return
|
||||||
|
|
||||||
up_block = self.up_blocks[idx]
|
up_block = self.up_blocks[idx]
|
||||||
@ -575,11 +579,17 @@ class Decoder(nn.Module):
|
|||||||
# when we are not chunking, detach our x so the callee can free it as soon as they are done
|
# when we are not chunking, detach our x so the callee can free it as soon as they are done
|
||||||
next_sample_ref = [sample]
|
next_sample_ref = [sample]
|
||||||
del sample
|
del sample
|
||||||
|
#Just let this run_up unconditionally regardless of, its ok because either a lower layer
|
||||||
|
#chunker or output frame stash will do the work anyway. so unchanged.
|
||||||
self.run_up(idx + 1, next_sample_ref, ended, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
|
self.run_up(idx + 1, next_sample_ref, ended, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
samples = torch.chunk(sample, chunks=num_chunks, dim=2)
|
samples = torch.chunk(sample, chunks=num_chunks, dim=2)
|
||||||
|
|
||||||
|
#This loop might need to become a while.
|
||||||
|
#If the output buffer is exhausted (or none), it need to stash whatever is left of the samples
|
||||||
|
#list to new state.
|
||||||
|
#exhaustion is detectable here with output_offset[0] vs output_buffer shape in T.
|
||||||
for chunk_idx, sample1 in enumerate(samples):
|
for chunk_idx, sample1 in enumerate(samples):
|
||||||
self.run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
|
self.run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
|
||||||
|
|
||||||
@ -630,6 +640,8 @@ class Decoder(nn.Module):
|
|||||||
)
|
)
|
||||||
timestep_shift_scale = ada_values.unbind(dim=1)
|
timestep_shift_scale = ada_values.unbind(dim=1)
|
||||||
|
|
||||||
|
#The meaning of output_buffer == None changes.
|
||||||
|
#forward now needs to do this for the non start-resume case.
|
||||||
if output_buffer is None:
|
if output_buffer is None:
|
||||||
output_buffer = torch.empty(
|
output_buffer = torch.empty(
|
||||||
self.decode_output_shape(sample.shape),
|
self.decode_output_shape(sample.shape),
|
||||||
@ -643,6 +655,24 @@ class Decoder(nn.Module):
|
|||||||
|
|
||||||
return output_buffer
|
return output_buffer
|
||||||
|
|
||||||
|
def forward_start(self, *args, **kwargs):
|
||||||
|
#output_buffer == None implies initial exhaustion, so this should park it in a
|
||||||
|
#resumable state on the bottom of the run_up stack.
|
||||||
|
raise NotImplementedError("Decoder.forward_start is not implemented yet")
|
||||||
|
|
||||||
|
#completely new function (maybe)
|
||||||
|
def forward_resume(self, *args, **kwargs):
|
||||||
|
#your code here
|
||||||
|
#inspect the new state. Pop any complete frames first.
|
||||||
|
#Then execute run_up on the highest index frame. You will need to be the list iterator
|
||||||
|
#for chunked run_up stashes and potentially restash incompleted lists.
|
||||||
|
#come down the indicies as stash highest to lowest and just like run_up you need to
|
||||||
|
#return if output_buffer is exhausted with the extra run_up calls potentially leaving
|
||||||
|
#behind new state for the next time this is called.
|
||||||
|
#if we have truly finished, do the same logic as the finally below to clean up.
|
||||||
|
#we will design the protocol for signalling end to the caller later.
|
||||||
|
raise NotImplementedError("Decoder.forward_resume is not implemented yet")
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
try:
|
try:
|
||||||
return self.forward_orig(*args, **kwargs)
|
return self.forward_orig(*args, **kwargs)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user