mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-16 14:47:33 +08:00
Fixes mutable default arguments in the wan vae.
In Python, mutable default arguments are evaluated **once at function definition time** and shared across all subsequent calls. This is a well-known Python pitfall:
```python
# BAD: this list is shared across ALL calls to forward()
def forward(self, x, feat_cache=None, feat_idx=[0]):
    feat_idx[0] += 1  # modifies the shared default list!
```
In `comfy/ldm/wan/vae.py` and `comfy/ldm/wan/vae2_2.py`, the `forward` methods of `Resample`, `ResidualBlock`, `Down_ResidualBlock`, `Up_ResidualBlock`, `Encoder3d` and `Decoder3d` all use `feat_idx=[0]` as a default argument. Since `feat_idx[0]` is incremented inside these methods, the default value accumulates between inference runs. On the second run, `feat_idx[0]` no longer starts at `0` but at whatever value it reached at the end of the first run, causing incorrect cache indexing throughout the entire encoder and decoder.
**Fix:**
```python
# GOOD: a new list is created for every call that doesn't pass feat_idx
def forward(self, x, feat_cache=None, feat_idx=None):
    # Fix: mutable default argument feat_idx=[0] would persist between calls
    if feat_idx is None:
        feat_idx = [0]
```
**Observed impact:** On AMD/ROCm hardware this bug caused 4-5x slower inference on all runs after the first with WAN VAE. After this fix, only Run 2 remains slightly slower (due to a separate MIOpen kernel cache issue), while Run 3 and beyond are now as fast as Run 1. The bug likely affects all hardware to some degree as incorrect cache indexing causes unnecessary recomputation.
Related issues exist in the ROCm tracker and in the ComfyUI tracker:
https://github.com/ROCm/ROCm/issues/6008
https://github.com/Comfy-Org/ComfyUI/issues/12672#issuecomment-4059981039
This commit is contained in:
parent
16cd8d8a8f
commit
385221234c
@ -98,8 +98,10 @@ class Resample(nn.Module):
|
||||
|
||||
else:
|
||||
self.resample = nn.Identity()
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
# Fix: mutable default argument feat_idx=[0] would persist between calls
|
||||
def forward(self, x, feat_cache=None, feat_idx=None):
|
||||
if feat_idx is None:
|
||||
feat_idx = [0]
|
||||
b, c, t, h, w = x.size()
|
||||
if self.mode == 'upsample3d':
|
||||
if feat_cache is not None:
|
||||
@ -176,8 +178,10 @@ class ResidualBlock(nn.Module):
|
||||
CausalConv3d(out_dim, out_dim, 3, padding=1))
|
||||
self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
|
||||
if in_dim != out_dim else nn.Identity()
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
# Fix: mutable default argument feat_idx=[0] would persist between calls
|
||||
def forward(self, x, feat_cache=None, feat_idx=None):
|
||||
if feat_idx is None:
|
||||
feat_idx = [0]
|
||||
old_x = x
|
||||
for layer in self.residual:
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
@ -282,8 +286,10 @@ class Encoder3d(nn.Module):
|
||||
self.head = nn.Sequential(
|
||||
RMS_norm(out_dim, images=False), nn.SiLU(),
|
||||
CausalConv3d(out_dim, z_dim, 3, padding=1))
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
# Fix: mutable default argument feat_idx=[0] would persist between calls
|
||||
def forward(self, x, feat_cache=None, feat_idx=None):
|
||||
if feat_idx is None:
|
||||
feat_idx = [0]
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
@ -388,8 +394,10 @@ class Decoder3d(nn.Module):
|
||||
self.head = nn.Sequential(
|
||||
RMS_norm(out_dim, images=False), nn.SiLU(),
|
||||
CausalConv3d(out_dim, output_channels, 3, padding=1))
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
# Fix: mutable default argument feat_idx=[0] would persist between calls
|
||||
def forward(self, x, feat_cache=None, feat_idx=None):
|
||||
if feat_idx is None:
|
||||
feat_idx = [0]
|
||||
## conv1
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
|
||||
@ -54,7 +54,10 @@ class Resample(nn.Module):
|
||||
else:
|
||||
self.resample = nn.Identity()
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
# Fix: mutable default argument feat_idx=[0] would persist between calls
|
||||
def forward(self, x, feat_cache=None, feat_idx=None):
|
||||
if feat_idx is None:
|
||||
feat_idx = [0]
|
||||
b, c, t, h, w = x.size()
|
||||
if self.mode == "upsample3d":
|
||||
if feat_cache is not None:
|
||||
@ -135,7 +138,10 @@ class ResidualBlock(nn.Module):
|
||||
CausalConv3d(in_dim, out_dim, 1)
|
||||
if in_dim != out_dim else nn.Identity())
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
# Fix: mutable default argument feat_idx=[0] would persist between calls
|
||||
def forward(self, x, feat_cache=None, feat_idx=None):
|
||||
if feat_idx is None:
|
||||
feat_idx = [0]
|
||||
old_x = x
|
||||
for layer in self.residual:
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
@ -326,7 +332,10 @@ class Down_ResidualBlock(nn.Module):
|
||||
|
||||
self.downsamples = nn.Sequential(*downsamples)
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
# Fix: mutable default argument feat_idx=[0] would persist between calls
|
||||
def forward(self, x, feat_cache=None, feat_idx=None):
|
||||
if feat_idx is None:
|
||||
feat_idx = [0]
|
||||
x_copy = x
|
||||
for module in self.downsamples:
|
||||
x = module(x, feat_cache, feat_idx)
|
||||
@ -367,8 +376,10 @@ class Up_ResidualBlock(nn.Module):
|
||||
upsamples.append(Resample(out_dim, mode=mode))
|
||||
|
||||
self.upsamples = nn.Sequential(*upsamples)
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
|
||||
# Fix: mutable default argument feat_idx=[0] would persist between calls
|
||||
def forward(self, x, feat_cache=None, feat_idx=None, first_chunk=False):
|
||||
if feat_idx is None:
|
||||
feat_idx = [0]
|
||||
x_main = x
|
||||
for module in self.upsamples:
|
||||
x_main = module(x_main, feat_cache, feat_idx)
|
||||
@ -438,7 +449,10 @@ class Encoder3d(nn.Module):
|
||||
CausalConv3d(out_dim, z_dim, 3, padding=1),
|
||||
)
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
# Fix: mutable default argument feat_idx=[0] would persist between calls
|
||||
def forward(self, x, feat_cache=None, feat_idx=None):
|
||||
if feat_idx is None:
|
||||
feat_idx = [0]
|
||||
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
@ -550,7 +564,10 @@ class Decoder3d(nn.Module):
|
||||
CausalConv3d(out_dim, 12, 3, padding=1),
|
||||
)
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
|
||||
# Fix: mutable default argument feat_idx=[0] would persist between calls
|
||||
def forward(self, x, feat_cache=None, feat_idx=None, first_chunk=False):
|
||||
if feat_idx is None:
|
||||
feat_idx = [0]
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user