From 7a88c578e0f46fdb72c974de64faeedcf84d03e6 Mon Sep 17 00:00:00 2001 From: Rattus Date: Fri, 20 Mar 2026 11:38:46 +1000 Subject: [PATCH] wan: vae: Don't recursion in local fns (move run_up) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Moved Decoder3d’s recursive run_up out of forward into a class method to avoid nested closure self-reference cycles. This avoids cyclic garbage that delays garbage of tensors which in turn delays VRAM release before tiled fallback. --- comfy/ldm/wan/vae.py | 74 +++++++++++++++++++++++--------------------- 1 file changed, 38 insertions(+), 36 deletions(-) diff --git a/comfy/ldm/wan/vae.py b/comfy/ldm/wan/vae.py index a96b83c6c..deeb8695b 100644 --- a/comfy/ldm/wan/vae.py +++ b/comfy/ldm/wan/vae.py @@ -360,6 +360,43 @@ class Decoder3d(nn.Module): RMS_norm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, output_channels, 3, padding=1)) + def run_up(self, layer_idx, x_ref, feat_cache, feat_idx, out_chunks): + x = x_ref[0] + x_ref[0] = None + if layer_idx >= len(self.upsamples): + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + cache_x = x[:, :, -CACHE_T:, :, :] + x = layer(x, feat_cache[feat_idx[0]]) + feat_cache[feat_idx[0]] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + out_chunks.append(x) + return + + layer = self.upsamples[layer_idx] + if isinstance(layer, Resample) and layer.mode == 'upsample3d' and x.shape[2] > 1: + for frame_idx in range(x.shape[2]): + self.run_up( + layer_idx, + [x[:, :, frame_idx:frame_idx + 1, :, :]], + feat_cache, + feat_idx.copy(), + out_chunks, + ) + del x + return + + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + next_x_ref = [x] + del x + self.run_up(layer_idx + 1, next_x_ref, feat_cache, feat_idx, out_chunks) + def forward(self, x, feat_cache=None, feat_idx=[0]): ## conv1 if feat_cache is not None: @@ -380,42 +417,7 @@ class Decoder3d(nn.Module): out_chunks = [] - def run_up(layer_idx, x_ref, feat_idx): - x = x_ref[0] - x_ref[0] = None - if layer_idx >= len(self.upsamples): - for layer in self.head: - if isinstance(layer, CausalConv3d) and feat_cache is not None: - cache_x = x[:, :, -CACHE_T:, :, :] - x = layer(x, feat_cache[feat_idx[0]]) - feat_cache[feat_idx[0]] = cache_x - feat_idx[0] += 1 - else: - x = layer(x) - out_chunks.append(x) - return - - layer = self.upsamples[layer_idx] - if isinstance(layer, Resample) and layer.mode == 'upsample3d' and x.shape[2] > 1: - for frame_idx in range(x.shape[2]): - run_up( - layer_idx, - [x[:, :, frame_idx:frame_idx + 1, :, :]], - feat_idx.copy(), - ) - del x - return - - if feat_cache is not None: - x = layer(x, feat_cache, feat_idx) - else: - x = layer(x) - - next_x_ref = [x] - del x - run_up(layer_idx + 1, next_x_ref, feat_idx) - - run_up(0, [x], feat_idx) + self.run_up(0, [x], feat_cache, feat_idx, out_chunks) return out_chunks