wan: vae: free consumer caller tensors on recursion

This commit is contained in:
Rattus 2026-03-17 13:26:34 +10:00
parent 45995ed76a
commit 8658c570f9

View File

@ -391,7 +391,9 @@ class Decoder3d(nn.Module):
x = layer(x)
out_chunks.append(x)
def run_upsamples(layer_idx, x, feat_idx):
def run_upsamples(layer_idx, x_ref, feat_idx):
x = x_ref[0]
x_ref[0] = None
if layer_idx >= len(self.upsamples):
run_head(x, feat_idx)
return
@ -401,9 +403,10 @@ class Decoder3d(nn.Module):
for frame_idx in range(x.shape[2]):
run_upsamples(
layer_idx,
x[:, :, frame_idx:frame_idx + 1, :, :],
[x[:, :, frame_idx:frame_idx + 1, :, :]],
feat_idx.copy(),
)
del x
return
if feat_cache is not None:
@ -411,9 +414,11 @@ class Decoder3d(nn.Module):
else:
x = layer(x)
run_upsamples(layer_idx + 1, x, feat_idx)
next_x_ref = [x]
del x
run_upsamples(layer_idx + 1, next_x_ref, feat_idx)
run_upsamples(0, x, feat_idx)
run_upsamples(0, [x], feat_idx)
return out_chunks