From d8fa68084f5d0f97ae85aa98cfe885ddebcf95c1 Mon Sep 17 00:00:00 2001 From: Rattus Date: Mon, 16 Mar 2026 21:08:41 +1000 Subject: [PATCH] wan: remove all concatentation with the feature cache The loopers are now responsible for ensuring that non-final frames are processes at least two-by-two, elimiating the need for this cat case. --- comfy/ldm/wan/vae.py | 65 ++++---------------------------------------- 1 file changed, 5 insertions(+), 60 deletions(-) diff --git a/comfy/ldm/wan/vae.py b/comfy/ldm/wan/vae.py index 5b62cf123..9a3db212f 100644 --- a/comfy/ldm/wan/vae.py +++ b/comfy/ldm/wan/vae.py @@ -110,21 +110,6 @@ class Resample(nn.Module): else: cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[ - idx] is not None and feat_cache[idx] != 'Rep': - # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) - if cache_x.shape[2] < 2 and feat_cache[ - idx] is not None and feat_cache[idx] == 'Rep': - cache_x = torch.cat([ - torch.zeros_like(cache_x).to(cache_x.device), - cache_x - ], - dim=2) if feat_cache[idx] == 'Rep': x = self.time_conv(x) else: @@ -149,10 +134,6 @@ class Resample(nn.Module): else: cache_x = x[:, :, -1:, :, :].clone() - # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep': - # # cache last frame of last two chunk - # cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) - x = self.time_conv( torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) feat_cache[idx] = cache_x @@ -192,13 +173,6 @@ class ResidualBlock(nn.Module): if isinstance(layer, CausalConv3d) and feat_cache is not None: idx = feat_idx[0] cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) x = layer(x, cache_list=feat_cache, cache_idx=idx) feat_cache[idx] = cache_x feat_idx[0] += 1 @@ -296,13 +270,6 @@ class Encoder3d(nn.Module): if feat_cache is not None: idx = feat_idx[0] cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) x = self.conv1(x, feat_cache[idx]) feat_cache[idx] = cache_x feat_idx[0] += 1 @@ -320,8 +287,8 @@ class Encoder3d(nn.Module): ## middle for layer in self.middle: - if isinstance(layer, ResidualBlock) and feat_cache is not None: - x = layer(x, feat_cache, feat_idx) + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx, final=final) else: x = layer(x) @@ -330,13 +297,6 @@ class Encoder3d(nn.Module): if isinstance(layer, CausalConv3d) and feat_cache is not None: idx = feat_idx[0] cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) x = layer(x, feat_cache[idx]) feat_cache[idx] = cache_x feat_idx[0] += 1 @@ -405,13 +365,6 @@ class Decoder3d(nn.Module): if feat_cache is not None: idx = feat_idx[0] cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) x = self.conv1(x, feat_cache[idx]) feat_cache[idx] = cache_x feat_idx[0] += 1 @@ -420,7 +373,7 @@ class Decoder3d(nn.Module): ## middle for layer in self.middle: - if isinstance(layer, ResidualBlock) and feat_cache is not None: + if feat_cache is not None: x = layer(x, feat_cache, feat_idx) else: x = layer(x) @@ -437,13 +390,6 @@ class Decoder3d(nn.Module): if isinstance(layer, CausalConv3d) and feat_cache is not None: idx = feat_idx[0] cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) x = layer(x, feat_cache[idx]) feat_cache[idx] = cache_x feat_idx[0] += 1 @@ -519,9 +465,8 @@ class WanVAE(nn.Module): return mu def decode(self, z): - conv_idx = [0] # z: [b,c,t,h,w] - iter_ = z.shape[2] + iter_ = 1 + z.shape[2] // 2 feat_map = None if iter_ > 1: feat_map = [None] * count_cache_layers(self.decoder) @@ -535,7 +480,7 @@ class WanVAE(nn.Module): feat_idx=conv_idx) else: out_ = self.decoder( - x[:, :, i:i + 1, :, :], + x[:, :, 1 + 2 * (i - 1):1 + 2 * i, :, :], feat_cache=feat_map, feat_idx=conv_idx) out = torch.cat([out, out_], 2)