diff --git a/comfy/ldm/wan/vae.py b/comfy/ldm/wan/vae.py index c8346992f..a96b83c6c 100644 --- a/comfy/ldm/wan/vae.py +++ b/comfy/ldm/wan/vae.py @@ -380,28 +380,25 @@ class Decoder3d(nn.Module): out_chunks = [] - def run_head(x, feat_idx): - for layer in self.head: - if isinstance(layer, CausalConv3d) and feat_cache is not None: - cache_x = x[:, :, -CACHE_T:, :, :] - x = layer(x, feat_cache[feat_idx[0]]) - feat_cache[feat_idx[0]] = cache_x - feat_idx[0] += 1 - else: - x = layer(x) - out_chunks.append(x) - - def run_upsamples(layer_idx, x_ref, feat_idx): + def run_up(layer_idx, x_ref, feat_idx): x = x_ref[0] x_ref[0] = None if layer_idx >= len(self.upsamples): - run_head(x, feat_idx) + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + cache_x = x[:, :, -CACHE_T:, :, :] + x = layer(x, feat_cache[feat_idx[0]]) + feat_cache[feat_idx[0]] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + out_chunks.append(x) return layer = self.upsamples[layer_idx] if isinstance(layer, Resample) and layer.mode == 'upsample3d' and x.shape[2] > 1: for frame_idx in range(x.shape[2]): - run_upsamples( + run_up( layer_idx, [x[:, :, frame_idx:frame_idx + 1, :, :]], feat_idx.copy(), @@ -416,9 +413,9 @@ class Decoder3d(nn.Module): next_x_ref = [x] del x - run_upsamples(layer_idx + 1, next_x_ref, feat_idx) + run_up(layer_idx + 1, next_x_ref, feat_idx) - run_upsamples(0, [x], feat_idx) + run_up(0, [x], feat_idx) return out_chunks