diff --git a/comfy/ldm/wan/vae.py b/comfy/ldm/wan/vae.py index 9a3db212f..65296afa8 100644 --- a/comfy/ldm/wan/vae.py +++ b/comfy/ldm/wan/vae.py @@ -378,24 +378,43 @@ class Decoder3d(nn.Module): else: x = layer(x) - ## upsamples - for layer in self.upsamples: + out_chunks = [] + + def run_head(x, feat_idx): + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + cache_x = x[:, :, -CACHE_T:, :, :].clone() + x = layer(x, feat_cache[feat_idx[0]]) + feat_cache[feat_idx[0]] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + out_chunks.append(x) + + def run_upsamples(layer_idx, x, feat_idx): + if layer_idx >= len(self.upsamples): + run_head(x, feat_idx) + return + + layer = self.upsamples[layer_idx] + if isinstance(layer, Resample) and layer.mode == 'upsample3d' and x.shape[2] > 1: + for frame_idx in range(x.shape[2]): + run_upsamples( + layer_idx, + x[:, :, frame_idx:frame_idx + 1, :, :], + feat_idx.copy(), + ) + return + if feat_cache is not None: x = layer(x, feat_cache, feat_idx) else: x = layer(x) - ## head - for layer in self.head: - if isinstance(layer, CausalConv3d) and feat_cache is not None: - idx = feat_idx[0] - cache_x = x[:, :, -CACHE_T:, :, :].clone() - x = layer(x, feat_cache[idx]) - feat_cache[idx] = cache_x - feat_idx[0] += 1 - else: - x = layer(x) - return x + run_upsamples(layer_idx + 1, x, feat_idx) + + run_upsamples(0, x, feat_idx) + return out_chunks def count_cache_layers(model): @@ -483,5 +502,5 @@ class WanVAE(nn.Module): x[:, :, 1 + 2 * (i - 1):1 + 2 * i, :, :], feat_cache=feat_map, feat_idx=conv_idx) - out = torch.cat([out, out_], 2) - return out + out += out_ + return torch.cat(out, 2)