diff --git a/comfy/ldm/wan/vae.py b/comfy/ldm/wan/vae.py index 71f73c64e..efbf7644a 100644 --- a/comfy/ldm/wan/vae.py +++ b/comfy/ldm/wan/vae.py @@ -98,8 +98,10 @@ class Resample(nn.Module): else: self.resample = nn.Identity() - - def forward(self, x, feat_cache=None, feat_idx=[0]): + # Fix: mutable default argument feat_idx=[0] would persist between calls + def forward(self, x, feat_cache=None, feat_idx=None): + if feat_idx is None: + feat_idx = [0] b, c, t, h, w = x.size() if self.mode == 'upsample3d': if feat_cache is not None: @@ -176,8 +178,10 @@ class ResidualBlock(nn.Module): CausalConv3d(out_dim, out_dim, 3, padding=1)) self.shortcut = CausalConv3d(in_dim, out_dim, 1) \ if in_dim != out_dim else nn.Identity() - - def forward(self, x, feat_cache=None, feat_idx=[0]): + # Fix: mutable default argument feat_idx=[0] would persist between calls + def forward(self, x, feat_cache=None, feat_idx=None): + if feat_idx is None: + feat_idx = [0] old_x = x for layer in self.residual: if isinstance(layer, CausalConv3d) and feat_cache is not None: @@ -282,8 +286,10 @@ class Encoder3d(nn.Module): self.head = nn.Sequential( RMS_norm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, z_dim, 3, padding=1)) - - def forward(self, x, feat_cache=None, feat_idx=[0]): + # Fix: mutable default argument feat_idx=[0] would persist between calls + def forward(self, x, feat_cache=None, feat_idx=None): + if feat_idx is None: + feat_idx = [0] if feat_cache is not None: idx = feat_idx[0] cache_x = x[:, :, -CACHE_T:, :, :].clone() @@ -388,8 +394,10 @@ class Decoder3d(nn.Module): self.head = nn.Sequential( RMS_norm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, output_channels, 3, padding=1)) - - def forward(self, x, feat_cache=None, feat_idx=[0]): + # Fix: mutable default argument feat_idx=[0] would persist between calls + def forward(self, x, feat_cache=None, feat_idx=None): + if feat_idx is None: + feat_idx = [0] ## conv1 if feat_cache is not None: idx = feat_idx[0] diff --git a/comfy/ldm/wan/vae2_2.py b/comfy/ldm/wan/vae2_2.py index 8e1593a54..fb27a68a2 100644 --- a/comfy/ldm/wan/vae2_2.py +++ b/comfy/ldm/wan/vae2_2.py @@ -54,7 +54,10 @@ class Resample(nn.Module): else: self.resample = nn.Identity() - def forward(self, x, feat_cache=None, feat_idx=[0]): + # Fix: mutable default argument feat_idx=[0] would persist between calls + def forward(self, x, feat_cache=None, feat_idx=None): + if feat_idx is None: + feat_idx = [0] b, c, t, h, w = x.size() if self.mode == "upsample3d": if feat_cache is not None: @@ -135,7 +138,10 @@ class ResidualBlock(nn.Module): CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()) - def forward(self, x, feat_cache=None, feat_idx=[0]): + # Fix: mutable default argument feat_idx=[0] would persist between calls + def forward(self, x, feat_cache=None, feat_idx=None): + if feat_idx is None: + feat_idx = [0] old_x = x for layer in self.residual: if isinstance(layer, CausalConv3d) and feat_cache is not None: @@ -326,7 +332,10 @@ class Down_ResidualBlock(nn.Module): self.downsamples = nn.Sequential(*downsamples) - def forward(self, x, feat_cache=None, feat_idx=[0]): + # Fix: mutable default argument feat_idx=[0] would persist between calls + def forward(self, x, feat_cache=None, feat_idx=None): + if feat_idx is None: + feat_idx = [0] x_copy = x for module in self.downsamples: x = module(x, feat_cache, feat_idx) @@ -367,8 +376,10 @@ class Up_ResidualBlock(nn.Module): upsamples.append(Resample(out_dim, mode=mode)) self.upsamples = nn.Sequential(*upsamples) - - def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False): + # Fix: mutable default argument feat_idx=[0] would persist between calls + def forward(self, x, feat_cache=None, feat_idx=None, first_chunk=False): + if feat_idx is None: + feat_idx = [0] x_main = x for module in self.upsamples: x_main = module(x_main, feat_cache, feat_idx) @@ -438,7 +449,10 @@ class Encoder3d(nn.Module): CausalConv3d(out_dim, z_dim, 3, padding=1), ) - def forward(self, x, feat_cache=None, feat_idx=[0]): + # Fix: mutable default argument feat_idx=[0] would persist between calls + def forward(self, x, feat_cache=None, feat_idx=None): + if feat_idx is None: + feat_idx = [0] if feat_cache is not None: idx = feat_idx[0] @@ -550,7 +564,10 @@ class Decoder3d(nn.Module): CausalConv3d(out_dim, 12, 3, padding=1), ) - def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False): + # Fix: mutable default argument feat_idx=[0] would persist between calls + def forward(self, x, feat_cache=None, feat_idx=None, first_chunk=False): + if feat_idx is None: + feat_idx = [0] if feat_cache is not None: idx = feat_idx[0] cache_x = x[:, :, -CACHE_T:, :, :].clone()