implementation plan

Rattus 2026-04-08 08:45:18 +10:00
parent 3d21d2afb6
commit 9dede56fac


@@ -553,6 +553,10 @@ class Decoder(nn.Module):
t = sample.shape[2]
output_buffer[:, :, output_offset[0]:output_offset[0] + t].copy_(sample)
output_offset[0] += t
# If there isn't space in the output buffer, you need to stash the unconsumed
# frames in the new state as their own entry. forward_resume then just immediately
# copies those into the next slice of output. Do not clone, just slice;
# the VRAM is not a big deal. (Sketch after this hunk.)
return
up_block = self.up_blocks[idx]
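A minimal sketch of the frame stash described in the hunk above, assuming a hypothetical _emit helper and a self._pending_frames attribute (neither exists in the file; the names are illustrative only):

def _emit(self, sample, output_buffer, output_offset):
    t = sample.shape[2]
    space = output_buffer.shape[2] - output_offset[0]
    n = min(t, space)
    if n > 0:
        output_buffer[:, :, output_offset[0]:output_offset[0] + n].copy_(sample[:, :, :n])
        output_offset[0] += n
    if n < t:
        # No clone: stash a view (slice) of the unconsumed frames for forward_resume.
        self._pending_frames = sample[:, :, n:]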
@@ -575,11 +579,17 @@ class Decoder(nn.Module):
# when we are not chunking, detach our x so the callee can free it as soon as they are done
next_sample_ref = [sample]
del sample
# Just let this run_up run unconditionally regardless: it's ok, because either a
# lower-layer chunker or the output frame stash will do the work anyway, so this stays unchanged.
self.run_up(idx + 1, next_sample_ref, ended, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
return
else:
samples = torch.chunk(sample, chunks=num_chunks, dim=2)
# This loop might need to become a while (see the sketch after this hunk).
# If the output buffer is exhausted (or None), it needs to stash whatever is left of
# the samples list into the new state.
# Exhaustion is detectable here by comparing output_offset[0] against the output_buffer shape in T.
for chunk_idx, sample1 in enumerate(samples):
self.run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
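A sketch of that for-to-while conversion, under the assumption of a hypothetical self._pending_chunks stash; exhaustion is tested exactly as the comment says, output_offset[0] against the output_buffer shape in T:

samples = list(torch.chunk(sample, chunks=num_chunks, dim=2))
while samples:
    if output_buffer is None or output_offset[0] >= output_buffer.shape[2]:
        # Output exhausted: restash the leftover chunks for forward_resume.
        self._pending_chunks = (idx, samples)
        return
    sample1 = samples.pop(0)
    self.run_up(idx + 1, [sample1], ended and not samples, timestep_shift_scale,
                scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)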
@@ -630,6 +640,8 @@ class Decoder(nn.Module):
)
timestep_shift_scale = ada_values.unbind(dim=1)
# The meaning of output_buffer == None changes:
# forward now needs to do this allocation itself for the non start/resume case.
if output_buffer is None:
output_buffer = torch.empty(
self.decode_output_shape(sample.shape),
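The changed meaning could be captured in a small predicate, sketched here with a hypothetical name, where None no longer means "allocate for me" but "no space at all":

def _output_exhausted(self, output_buffer, output_offset):
    # Under the new protocol, a missing buffer counts as exhausted.
    return output_buffer is None or output_offset[0] >= output_buffer.shape[2]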
@@ -643,6 +655,24 @@ class Decoder(nn.Module):
return output_buffer
def forward_start(self, *args, **kwargs):
# output_buffer == None implies initial exhaustion, so this should park everything
# in a resumable state at the bottom of the run_up stack.
raise NotImplementedError("Decoder.forward_start is not implemented yet")
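If the comment above holds, forward_start might be as thin as this sketch; forward_orig taking output_buffer as a keyword, and both _pending_* attributes, are assumptions rather than the final design:

def forward_start(self, *args, **kwargs):
    # Reset the hypothetical resume state, then run with no output buffer at all.
    self._pending_frames = None
    self._pending_chunks = None
    # output_buffer=None now means "initially exhausted", so run_up parks its work
    # as resumable state at the bottom of the stack instead of writing output.
    return self.forward_orig(*args, output_buffer=None, **kwargs)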
# Completely new function (maybe).
def forward_resume(self, *args, **kwargs):
# Your code here.
# Inspect the new state. Pop any complete frames first.
# Then execute run_up on the highest-index frame. You will need to be the list iterator
# for chunked run_up stashes, and potentially restash incomplete lists.
# Come down the stashed indices highest to lowest; just like run_up, you need to
# return if output_buffer is exhausted, with the extra run_up calls potentially leaving
# behind new state for the next time this is called.
# If we have truly finished, do the same cleanup logic as the finally below.
# We will design the protocol for signalling the end to the caller later.
raise NotImplementedError("Decoder.forward_resume is not implemented yet")
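A compact sketch of that resume walk, again with hypothetical state names and with the end-of-stream signal deliberately left open:

def forward_resume(self, output_buffer, output_offset, *args, **kwargs):
    # Pop any complete stashed frames first; slice, don't clone.
    if self._pending_frames is not None:
        frames, self._pending_frames = self._pending_frames, None
        n = min(frames.shape[2], output_buffer.shape[2] - output_offset[0])
        output_buffer[:, :, output_offset[0]:output_offset[0] + n].copy_(frames[:, :, :n])
        output_offset[0] += n
        if n < frames.shape[2]:
            self._pending_frames = frames[:, :, n:]
            return output_buffer  # exhausted again before reaching the run_up stashes
    # Walk the stashed run_up lists highest index first, restashing on exhaustion.
    while self._pending_chunks is not None:
        idx, samples = self._pending_chunks
        self._pending_chunks = None
        ...  # re-enter run_up as the list iterator; return early if output is exhausted
    return output_buffer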
def forward(self, *args, **kwargs):
try:
return self.forward_orig(*args, **kwargs)