From 8658c570f972a37f98ec9bd9c29668f5adc6294c Mon Sep 17 00:00:00 2001 From: Rattus Date: Tue, 17 Mar 2026 13:26:34 +1000 Subject: [PATCH] wan: vae: free consumer caller tensors on recursion --- comfy/ldm/wan/vae.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/comfy/ldm/wan/vae.py b/comfy/ldm/wan/vae.py index d224336a3..c8346992f 100644 --- a/comfy/ldm/wan/vae.py +++ b/comfy/ldm/wan/vae.py @@ -391,7 +391,9 @@ class Decoder3d(nn.Module): x = layer(x) out_chunks.append(x) - def run_upsamples(layer_idx, x, feat_idx): + def run_upsamples(layer_idx, x_ref, feat_idx): + x = x_ref[0] + x_ref[0] = None if layer_idx >= len(self.upsamples): run_head(x, feat_idx) return @@ -401,9 +403,10 @@ class Decoder3d(nn.Module): for frame_idx in range(x.shape[2]): run_upsamples( layer_idx, - x[:, :, frame_idx:frame_idx + 1, :, :], + [x[:, :, frame_idx:frame_idx + 1, :, :]], feat_idx.copy(), ) + del x return if feat_cache is not None: @@ -411,9 +414,11 @@ class Decoder3d(nn.Module): else: x = layer(x) - run_upsamples(layer_idx + 1, x, feat_idx) + next_x_ref = [x] + del x + run_upsamples(layer_idx + 1, next_x_ref, feat_idx) - run_upsamples(0, x, feat_idx) + run_upsamples(0, [x], feat_idx) return out_chunks