mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-08 16:22:38 +08:00
ltx: vae: free non-chunking recursive intermediates
This commit is contained in:
parent
d4a3cfc649
commit
5dc1bbb9df
@ -543,7 +543,9 @@ class Decoder(nn.Module):
|
|||||||
output = []
|
output = []
|
||||||
max_chunk_size = get_max_chunk_size(sample.device)
|
max_chunk_size = get_max_chunk_size(sample.device)
|
||||||
|
|
||||||
def run_up(idx, sample, ended):
|
def run_up(idx, sample_ref, ended):
|
||||||
|
sample = sample_ref[0]
|
||||||
|
sample_ref[0] = None
|
||||||
if idx >= len(self.up_blocks):
|
if idx >= len(self.up_blocks):
|
||||||
sample = self.conv_norm_out(sample)
|
sample = self.conv_norm_out(sample)
|
||||||
if timestep_shift_scale is not None:
|
if timestep_shift_scale is not None:
|
||||||
@ -572,12 +574,20 @@ class Decoder(nn.Module):
|
|||||||
|
|
||||||
total_bytes = sample.numel() * sample.element_size()
|
total_bytes = sample.numel() * sample.element_size()
|
||||||
num_chunks = (total_bytes + max_chunk_size - 1) // max_chunk_size
|
num_chunks = (total_bytes + max_chunk_size - 1) // max_chunk_size
|
||||||
samples = torch.chunk(sample, chunks=num_chunks, dim=2)
|
|
||||||
|
|
||||||
for chunk_idx, sample1 in enumerate(samples):
|
if num_chunks == 1:
|
||||||
run_up(idx + 1, sample1, ended and chunk_idx == len(samples) - 1)
|
# when we are not chunking, detach our x so the callee can free it as soon as they are done
|
||||||
|
next_sample_ref = [sample]
|
||||||
|
del sample
|
||||||
|
run_up(idx + 1, next_sample_ref, ended)
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
samples = torch.chunk(sample, chunks=num_chunks, dim=2)
|
||||||
|
|
||||||
run_up(0, sample, True)
|
for chunk_idx, sample1 in enumerate(samples):
|
||||||
|
run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1)
|
||||||
|
|
||||||
|
run_up(0, [sample], True)
|
||||||
sample = torch.cat(output, dim=2)
|
sample = torch.cat(output, dim=2)
|
||||||
|
|
||||||
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
|
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user