From 95ca2e56c82c1c714dba685bd81ebf3f7baf8efa Mon Sep 17 00:00:00 2001
From: rattus128 <46076784+rattus128@users.noreply.github.com>
Date: Tue, 14 Oct 2025 05:23:11 +1000
Subject: [PATCH] WAN2.2: Fix cache VRAM leak on error (#10308)

Same change pattern as 7e8dd275c243ad460ed5015d2e13611d81d2a569 applied
to WAN2.2.

If this suffers an exception (such as a VRAM OOM), it will leave the
encode() and decode() methods, which skips the cleanup of the WAN
feature cache. The comfy node cache then ultimately keeps a reference
to this object, which in turn references large tensors from the failed
execution.

The feature cache is currently set up as a class variable on the
encoder/decoder; however, the encode and decode functions always clear
it on both entry and exit of normal execution. It is likely the design
intent is for this to be usable as a streaming encoder where the input
comes in batches, but the functions as they are today don't support
that.

So simplify by bringing the cache back to a local variable, so that if
it does VRAM OOM, the cache itself is properly garbage collected when
the encode()/decode() functions disappear from the stack.
---
 comfy/ldm/wan/vae2_2.py | 37 ++++++++++++++-----------------------
 1 file changed, 14 insertions(+), 23 deletions(-)

diff --git a/comfy/ldm/wan/vae2_2.py b/comfy/ldm/wan/vae2_2.py
index 1f6d584a2..8e1593a54 100644
--- a/comfy/ldm/wan/vae2_2.py
+++ b/comfy/ldm/wan/vae2_2.py
@@ -657,51 +657,51 @@ class WanVAE(nn.Module):
         )
 
     def encode(self, x):
-        self.clear_cache()
+        conv_idx = [0]
+        feat_map = [None] * count_conv3d(self.encoder)
         x = patchify(x, patch_size=2)
         t = x.shape[2]
         iter_ = 1 + (t - 1) // 4
         for i in range(iter_):
-            self._enc_conv_idx = [0]
+            conv_idx = [0]
             if i == 0:
                 out = self.encoder(
                     x[:, :, :1, :, :],
-                    feat_cache=self._enc_feat_map,
-                    feat_idx=self._enc_conv_idx,
+                    feat_cache=feat_map,
+                    feat_idx=conv_idx,
                 )
             else:
                 out_ = self.encoder(
                     x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
-                    feat_cache=self._enc_feat_map,
-                    feat_idx=self._enc_conv_idx,
+                    feat_cache=feat_map,
+                    feat_idx=conv_idx,
                 )
                 out = torch.cat([out, out_], 2)
         mu, log_var = self.conv1(out).chunk(2, dim=1)
-        self.clear_cache()
         return mu
 
     def decode(self, z):
-        self.clear_cache()
+        conv_idx = [0]
+        feat_map = [None] * count_conv3d(self.decoder)
         iter_ = z.shape[2]
         x = self.conv2(z)
         for i in range(iter_):
-            self._conv_idx = [0]
+            conv_idx = [0]
             if i == 0:
                 out = self.decoder(
                     x[:, :, i:i + 1, :, :],
-                    feat_cache=self._feat_map,
-                    feat_idx=self._conv_idx,
+                    feat_cache=feat_map,
+                    feat_idx=conv_idx,
                     first_chunk=True,
                 )
             else:
                 out_ = self.decoder(
                     x[:, :, i:i + 1, :, :],
-                    feat_cache=self._feat_map,
-                    feat_idx=self._conv_idx,
+                    feat_cache=feat_map,
+                    feat_idx=conv_idx,
                 )
                 out = torch.cat([out, out_], 2)
         out = unpatchify(out, patch_size=2)
-        self.clear_cache()
         return out
 
     def reparameterize(self, mu, log_var):
@@ -715,12 +715,3 @@ class WanVAE(nn.Module):
             return mu
         std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
         return mu + std * torch.randn_like(std)
-
-    def clear_cache(self):
-        self._conv_num = count_conv3d(self.decoder)
-        self._conv_idx = [0]
-        self._feat_map = [None] * self._conv_num
-        # cache encode
-        self._enc_conv_num = count_conv3d(self.encoder)
-        self._enc_conv_idx = [0]
-        self._enc_feat_map = [None] * self._enc_conv_num
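
As context for the change above, here is a minimal, self-contained sketch of
the leak pattern the patch removes. It is not ComfyUI code; the ToyCodec
class and the tensor sizes are hypothetical stand-ins for the WAN2.2
encoder/decoder and its feature cache.

# Minimal sketch of the leak pattern fixed above -- not ComfyUI code.
# "ToyCodec" and the tensor sizes below are hypothetical stand-ins.
import torch


class ToyCodec:
    def __init__(self):
        # Long-lived attribute, analogous to the old _feat_map cache.
        self._feat_map = None

    def decode_leaky(self, z):
        # Cache stored on self: if the work below raises (e.g. a CUDA OOM),
        # the attribute keeps referencing the large tensors, and anything
        # that keeps this object alive (such as a node cache) pins them too.
        self._feat_map = [torch.zeros(1024, 1024) for _ in range(8)]
        raise RuntimeError("simulated OOM")  # the trailing cleanup never runs

    def decode_fixed(self, z):
        # Cache held in a local: when the exception unwinds this frame, the
        # list and its tensors become unreachable and are garbage collected.
        feat_map = [torch.zeros(1024, 1024) for _ in range(8)]
        raise RuntimeError("simulated OOM")


codec = ToyCodec()
try:
    codec.decode_leaky(None)
except RuntimeError:
    pass
# True: the failed call left ~32 MB of tensors referenced from the object.
print(codec._feat_map is not None)

The patch follows the second shape: the cache lives only in the
encode()/decode() stack frames, so a failed execution cannot keep it alive.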