ltx: vae: free non-chunking recursive intermediates

This commit is contained in:
Rattus 2026-03-17 13:39:13 +10:00
parent d4a3cfc649
commit 5dc1bbb9df

View File

@ -543,7 +543,9 @@ class Decoder(nn.Module):
output = [] output = []
max_chunk_size = get_max_chunk_size(sample.device) max_chunk_size = get_max_chunk_size(sample.device)
def run_up(idx, sample, ended): def run_up(idx, sample_ref, ended):
sample = sample_ref[0]
sample_ref[0] = None
if idx >= len(self.up_blocks): if idx >= len(self.up_blocks):
sample = self.conv_norm_out(sample) sample = self.conv_norm_out(sample)
if timestep_shift_scale is not None: if timestep_shift_scale is not None:
@ -572,12 +574,20 @@ class Decoder(nn.Module):
total_bytes = sample.numel() * sample.element_size() total_bytes = sample.numel() * sample.element_size()
num_chunks = (total_bytes + max_chunk_size - 1) // max_chunk_size num_chunks = (total_bytes + max_chunk_size - 1) // max_chunk_size
samples = torch.chunk(sample, chunks=num_chunks, dim=2)
for chunk_idx, sample1 in enumerate(samples): if num_chunks == 1:
run_up(idx + 1, sample1, ended and chunk_idx == len(samples) - 1) # when we are not chunking, drop our own reference to sample so the callee can free it as soon as it is done
next_sample_ref = [sample]
del sample
run_up(idx + 1, next_sample_ref, ended)
return
else:
samples = torch.chunk(sample, chunks=num_chunks, dim=2)
run_up(0, sample, True) for chunk_idx, sample1 in enumerate(samples):
run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1)
run_up(0, [sample], True)
sample = torch.cat(output, dim=2) sample = torch.cat(output, dim=2)
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)