From 035414ede49c1b043ea6de054ca512bcbf0f6b35 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Tue, 17 Mar 2026 14:34:39 -0700 Subject: [PATCH] Reduce WAN VAE VRAM, Save use cases for OOM/Tiler (#13014) * wan: vae: encoder: Add feature cache layer that corks singles If a downsample only gives you a single frame, save it to the feature cache and return nothing to the top level. This increases the efficiency of cacheability, but also prepares support for going two by two rather than four by four on the frames. * wan: remove all concatentation with the feature cache The loopers are now responsible for ensuring that non-final frames are processes at least two-by-two, elimiating the need for this cat case. * wan: vae: recurse and chunk for 2+2 frames on decode Avoid having to clone off slices of 4 frame chunks and reduce the size of the big 6 frame convolutions down to 4. Save the VRAMs. * wan: encode frames 2x2. Reduce VRAM usage greatly by encoding frames 2 at a time rather than 4. * wan: vae: remove cloning The loopers now control the chunking such there is noever more than 2 frames, so just cache these slices directly and avoid the clone allocations completely. * wan: vae: free consumer caller tensors on recursion * wan: vae: restyle a little to match LTX --- comfy/ldm/wan/vae.py | 180 +++++++++++++++++++------------------------ 1 file changed, 81 insertions(+), 99 deletions(-) diff --git a/comfy/ldm/wan/vae.py b/comfy/ldm/wan/vae.py index 71f73c64e..a96b83c6c 100644 --- a/comfy/ldm/wan/vae.py +++ b/comfy/ldm/wan/vae.py @@ -99,7 +99,7 @@ class Resample(nn.Module): else: self.resample = nn.Identity() - def forward(self, x, feat_cache=None, feat_idx=[0]): + def forward(self, x, feat_cache=None, feat_idx=[0], final=False): b, c, t, h, w = x.size() if self.mode == 'upsample3d': if feat_cache is not None: @@ -109,22 +109,7 @@ class Resample(nn.Module): feat_idx[0] += 1 else: - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[ - idx] is not None and feat_cache[idx] != 'Rep': - # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) - if cache_x.shape[2] < 2 and feat_cache[ - idx] is not None and feat_cache[idx] == 'Rep': - cache_x = torch.cat([ - torch.zeros_like(cache_x).to(cache_x.device), - cache_x - ], - dim=2) + cache_x = x[:, :, -CACHE_T:, :, :] if feat_cache[idx] == 'Rep': x = self.time_conv(x) else: @@ -145,19 +130,24 @@ class Resample(nn.Module): if feat_cache is not None: idx = feat_idx[0] if feat_cache[idx] is None: - feat_cache[idx] = x.clone() - feat_idx[0] += 1 + feat_cache[idx] = x else: - cache_x = x[:, :, -1:, :, :].clone() - # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep': - # # cache last frame of last two chunk - # cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) - + cache_x = x[:, :, -1:, :, :] x = self.time_conv( torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) feat_cache[idx] = cache_x - feat_idx[0] += 1 + + deferred_x = feat_cache[idx + 1] + if deferred_x is not None: + x = torch.cat([deferred_x, x], 2) + feat_cache[idx + 1] = None + + if x.shape[2] == 1 and not final: + feat_cache[idx + 1] = x + x = None + + feat_idx[0] += 2 return x @@ -177,19 +167,12 @@ class ResidualBlock(nn.Module): self.shortcut = CausalConv3d(in_dim, out_dim, 1) \ if in_dim != out_dim else nn.Identity() - def forward(self, x, feat_cache=None, feat_idx=[0]): + def forward(self, x, feat_cache=None, feat_idx=[0], final=False): old_x = x for layer in self.residual: if isinstance(layer, CausalConv3d) and feat_cache is not None: idx = feat_idx[0] - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) + cache_x = x[:, :, -CACHE_T:, :, :] x = layer(x, cache_list=feat_cache, cache_idx=idx) feat_cache[idx] = cache_x feat_idx[0] += 1 @@ -213,7 +196,7 @@ class AttentionBlock(nn.Module): self.proj = ops.Conv2d(dim, dim, 1) self.optimized_attention = vae_attention() - def forward(self, x): + def forward(self, x, feat_cache=None, feat_idx=[0], final=False): identity = x b, c, t, h, w = x.size() x = rearrange(x, 'b c t h w -> (b t) c h w') @@ -283,17 +266,10 @@ class Encoder3d(nn.Module): RMS_norm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, z_dim, 3, padding=1)) - def forward(self, x, feat_cache=None, feat_idx=[0]): + def forward(self, x, feat_cache=None, feat_idx=[0], final=False): if feat_cache is not None: idx = feat_idx[0] - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) + cache_x = x[:, :, -CACHE_T:, :, :] x = self.conv1(x, feat_cache[idx]) feat_cache[idx] = cache_x feat_idx[0] += 1 @@ -303,14 +279,16 @@ class Encoder3d(nn.Module): ## downsamples for layer in self.downsamples: if feat_cache is not None: - x = layer(x, feat_cache, feat_idx) + x = layer(x, feat_cache, feat_idx, final=final) + if x is None: + return None else: x = layer(x) ## middle for layer in self.middle: - if isinstance(layer, ResidualBlock) and feat_cache is not None: - x = layer(x, feat_cache, feat_idx) + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx, final=final) else: x = layer(x) @@ -318,14 +296,7 @@ class Encoder3d(nn.Module): for layer in self.head: if isinstance(layer, CausalConv3d) and feat_cache is not None: idx = feat_idx[0] - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) + cache_x = x[:, :, -CACHE_T:, :, :] x = layer(x, feat_cache[idx]) feat_cache[idx] = cache_x feat_idx[0] += 1 @@ -393,14 +364,7 @@ class Decoder3d(nn.Module): ## conv1 if feat_cache is not None: idx = feat_idx[0] - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) + cache_x = x[:, :, -CACHE_T:, :, :] x = self.conv1(x, feat_cache[idx]) feat_cache[idx] = cache_x feat_idx[0] += 1 @@ -409,42 +373,56 @@ class Decoder3d(nn.Module): ## middle for layer in self.middle: - if isinstance(layer, ResidualBlock) and feat_cache is not None: - x = layer(x, feat_cache, feat_idx) - else: - x = layer(x) - - ## upsamples - for layer in self.upsamples: if feat_cache is not None: x = layer(x, feat_cache, feat_idx) else: x = layer(x) - ## head - for layer in self.head: - if isinstance(layer, CausalConv3d) and feat_cache is not None: - idx = feat_idx[0] - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) - x = layer(x, feat_cache[idx]) - feat_cache[idx] = cache_x - feat_idx[0] += 1 + out_chunks = [] + + def run_up(layer_idx, x_ref, feat_idx): + x = x_ref[0] + x_ref[0] = None + if layer_idx >= len(self.upsamples): + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + cache_x = x[:, :, -CACHE_T:, :, :] + x = layer(x, feat_cache[feat_idx[0]]) + feat_cache[feat_idx[0]] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + out_chunks.append(x) + return + + layer = self.upsamples[layer_idx] + if isinstance(layer, Resample) and layer.mode == 'upsample3d' and x.shape[2] > 1: + for frame_idx in range(x.shape[2]): + run_up( + layer_idx, + [x[:, :, frame_idx:frame_idx + 1, :, :]], + feat_idx.copy(), + ) + del x + return + + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) else: x = layer(x) - return x + + next_x_ref = [x] + del x + run_up(layer_idx + 1, next_x_ref, feat_idx) + + run_up(0, [x], feat_idx) + return out_chunks -def count_conv3d(model): +def count_cache_layers(model): count = 0 for m in model.modules(): - if isinstance(m, CausalConv3d): + if isinstance(m, CausalConv3d) or (isinstance(m, Resample) and m.mode == 'downsample3d'): count += 1 return count @@ -482,11 +460,12 @@ class WanVAE(nn.Module): conv_idx = [0] ## cache t = x.shape[2] - iter_ = 1 + (t - 1) // 4 + t = 1 + ((t - 1) // 4) * 4 + iter_ = 1 + (t - 1) // 2 feat_map = None if iter_ > 1: - feat_map = [None] * count_conv3d(self.encoder) - ## 对encode输入的x,按时间拆分为1、4、4、4.... + feat_map = [None] * count_cache_layers(self.encoder) + ## 对encode输入的x,按时间拆分为1、2、2、2....(总帧数先按4N+1向下取整) for i in range(iter_): conv_idx = [0] if i == 0: @@ -496,20 +475,23 @@ class WanVAE(nn.Module): feat_idx=conv_idx) else: out_ = self.encoder( - x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :], + x[:, :, 1 + 2 * (i - 1):1 + 2 * i, :, :], feat_cache=feat_map, - feat_idx=conv_idx) + feat_idx=conv_idx, + final=(i == (iter_ - 1))) + if out_ is None: + continue out = torch.cat([out, out_], 2) + mu, log_var = self.conv1(out).chunk(2, dim=1) return mu def decode(self, z): - conv_idx = [0] # z: [b,c,t,h,w] - iter_ = z.shape[2] + iter_ = 1 + z.shape[2] // 2 feat_map = None if iter_ > 1: - feat_map = [None] * count_conv3d(self.decoder) + feat_map = [None] * count_cache_layers(self.decoder) x = self.conv2(z) for i in range(iter_): conv_idx = [0] @@ -520,8 +502,8 @@ class WanVAE(nn.Module): feat_idx=conv_idx) else: out_ = self.decoder( - x[:, :, i:i + 1, :, :], + x[:, :, 1 + 2 * (i - 1):1 + 2 * i, :, :], feat_cache=feat_map, feat_idx=conv_idx) - out = torch.cat([out, out_], 2) - return out + out += out_ + return torch.cat(out, 2)