mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-27 19:02:31 +08:00
ltx: vae: save un-actionable chunks to RunUpState.
This commit is contained in:
parent
ce054bbf2d
commit
4ba0b9125d
@ -23,6 +23,7 @@ class RunUpState:
|
|||||||
self.checkpoint_fn = checkpoint_fn
|
self.checkpoint_fn = checkpoint_fn
|
||||||
self.max_chunk_size = max_chunk_size
|
self.max_chunk_size = max_chunk_size
|
||||||
self.output_frames = output_frames
|
self.output_frames = output_frames
|
||||||
|
self.pending_samples = []
|
||||||
|
|
||||||
def in_meta_context():
|
def in_meta_context():
|
||||||
return torch.device("meta") == torch.empty(0).device
|
return torch.device("meta") == torch.empty(0).device
|
||||||
@ -595,14 +596,13 @@ class Decoder(nn.Module):
|
|||||||
self.run_up(idx + 1, next_sample_ref, ended, run_up_state, output_buffer, output_offset)
|
self.run_up(idx + 1, next_sample_ref, ended, run_up_state, output_buffer, output_offset)
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
samples = torch.chunk(sample, chunks=num_chunks, dim=2)
|
samples = list(torch.chunk(sample, chunks=num_chunks, dim=2))
|
||||||
|
|
||||||
#This loop might need to become a while.
|
while len(samples):
|
||||||
#If the output buffer is exhausted (or none), it need to stash whatever is left of the samples
|
if output_buffer is None or output_offset[0] == output_buffer.shape[2]:
|
||||||
#list to new state.
|
run_up_state.pending_samples.append((idx + 1, samples, ended))
|
||||||
#exhaustion is detectable here with output_offset[0] vs output_buffer shape in T.
|
return
|
||||||
for chunk_idx, sample1 in enumerate(samples):
|
self.run_up(idx + 1, [samples.pop(0)], ended and len(samples) == 1, run_up_state, output_buffer, output_offset)
|
||||||
self.run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1, run_up_state, output_buffer, output_offset)
|
|
||||||
|
|
||||||
def forward_orig(
|
def forward_orig(
|
||||||
self,
|
self,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user