In Python, mutable default arguments are evaluated **once at function definition time** and shared across all subsequent calls. This is a well-known Python pitfall:
```python
# BAD: this list is shared across ALL calls to forward()
def forward(self, x, feat_cache=None, feat_idx=[0]):
    feat_idx[0] += 1  # modifies the shared default list!
```
In `comfy/ldm/wan/vae.py` and `comfy/ldm/wan/vae2_2.py`, the `forward` methods of `Resample`, `ResidualBlock`, `Down_ResidualBlock`, `Up_ResidualBlock`, `Encoder3d` and `Decoder3d` all use `feat_idx=[0]` as a default argument. Since `feat_idx[0]` is incremented inside these methods, the default value accumulates between inference runs. On the second run, `feat_idx[0]` no longer starts at `0` but at whatever value it reached at the end of the first run, causing incorrect cache indexing throughout the entire encoder and decoder.
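The accumulation described above can be reproduced in isolation. This is a minimal sketch, not the actual WAN VAE code: `Block` is a hypothetical stand-in class, and only the `forward` signature is copied from the affected methods.

```python
class Block:
    """Hypothetical stand-in for a WAN VAE module; not the real class."""

    def forward(self, x, feat_cache=None, feat_idx=[0]):  # shared default list
        start = feat_idx[0]  # index this call would use into the cache
        feat_idx[0] += 1     # mutates the one list stored on the function
        return start


block = Block()
run1 = block.forward("frame")  # first "inference run" starts at index 0
run2 = block.forward("frame")  # second run starts where the first left off

# The leaked state lives on the function object itself.
leaked_default = Block.forward.__defaults__[1]
print(run1, run2, leaked_default)
```

Inspecting `__defaults__` shows where the counter persists: it is attached to the function, so it outlives every individual call and every module instance.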
**Fix:**
```python
# GOOD: a new list is created for every call that doesn't pass feat_idx
def forward(self, x, feat_cache=None, feat_idx=None):
    # Fix: mutable default argument feat_idx=[0] would persist between calls
    if feat_idx is None:
        feat_idx = [0]
```
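The `None` sentinel still lets a caller thread one counter through nested `forward` calls within a single run, while each new top-level call starts from `0` again. A minimal sketch of that behavior, assuming hypothetical `Parent` and `Child` classes that only mimic the calling pattern:

```python
class Child:
    """Hypothetical leaf module that consumes one cache slot per call."""

    def forward(self, x, feat_cache=None, feat_idx=None):
        if feat_idx is None:
            feat_idx = [0]
        idx = feat_idx[0]
        feat_idx[0] += 1  # advances the caller's counter, not a default
        return idx


class Parent:
    """Hypothetical container that passes one counter to all children."""

    def __init__(self):
        self.blocks = [Child(), Child(), Child()]

    def forward(self, x, feat_cache=None, feat_idx=None):
        if feat_idx is None:
            feat_idx = [0]  # fresh counter for every top-level call
        return [b.forward(x, feat_cache, feat_idx) for b in self.blocks]


model = Parent()
first_run = model.forward("frame")
second_run = model.forward("frame")
print(first_run, second_run)
```

Both runs produce the same slot indices, which is the property the cache indexing relies on.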
**Observed impact:** On AMD/ROCm hardware this bug caused 4-5x slower inference on all runs after the first with WAN VAE. After this fix, only Run 2 remains slightly slower (due to a separate MIOpen kernel cache issue), while Run 3 and beyond are now as fast as Run 1. The bug likely affects all hardware to some degree as incorrect cache indexing causes unnecessary recomputation.
Related issues exist in the ROCm tracker and in the ComfyUI tracker:
https://github.com/ROCm/ROCm/issues/6008
https://github.com/Comfy-Org/ComfyUI/issues/12672#issuecomment-4059981039
The code throughout is None-safe and simply skips the feature cache
saving step when the cache is None. Set it to None for single-frame use
so Qwen doesn't burn VRAM on the unused cache.
* ops: introduce autopad for conv3d
This works around PyTorch's missing ability to apply causal padding as
part of the kernel and avoids massive weight duplication for padding.
* wan-vae: rework causal padding
This currently uses F.pad which takes a full deep copy and is liable to
be the VRAM peak. Instead, kick spatial padding back to the op and
consolidate the temporal padding with the cat for the cache.
* wan-vae: implement zero pad fast path
The WAN VAE is also used by Qwen, where it runs on single images. These
convolutions are, however, zero-padded 3D convolutions, which means the
VAE is effectively just 2D, using only the last element of the conv
weight in the temporal dimension. Fast-path this to avoid adding zeros
that then just evaporate in the convolution math but cost computation.
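The identity behind the fast path can be checked with a scalar toy model. This is only a sketch under simplifying assumptions: each frame is reduced to a single number and the temporal kernel to a list of weights, whereas the real code works on 3D convolution weights; the function names are hypothetical.

```python
def causal_conv_time(frames, weights):
    """Causal temporal convolution with len(weights) - 1 zero frames on the left."""
    k = len(weights)
    padded = [0.0] * (k - 1) + list(frames)
    return [
        sum(weights[j] * padded[t + j] for j in range(k))
        for t in range(len(frames))
    ]


def single_frame_fast_path(frame, weights):
    """Only the last temporal weight ever meets real data for one frame."""
    return weights[-1] * frame


weights = [0.5, -1.0, 2.0]
frame = 3.0
full = causal_conv_time([frame], weights)
fast = single_frame_fast_path(frame, weights)
print(full, fast)
```

With a single input frame, every weight except the last is multiplied by padding zeros, so the full and fast results agree.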
If this suffers an exception (such as a VRAM OOM), it will leave the
encode() and decode() methods, which skips the cleanup of the WAN
feature cache. The comfy node cache then ultimately keeps a reference
to this object, which in turn references large tensors from the failed
execution.
The feature cache is currently set up as a class variable on the
encoder/decoder; however, the encode and decode functions always clear
it on both entry and exit of normal execution.
It's likely the design intent is for this to be usable as a streaming
encoder where the input comes in batches; however, the functions as
they are today don't support that.
So simplify by bringing the cache back to a local variable, so that if
it does VRAM OOM the cache itself is properly garbage collected when
the encode()/decode() functions disappear from the stack.
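The lifetime difference can be sketched with weak references. All names here (`FakeTensor`, `AttributeCacheVAE`, `LocalCacheVAE`, `tensor_survives_oom`) are hypothetical stand-ins, the OOM is simulated with `MemoryError`, and the cache is modelled as an attribute on the object rather than the real encoder/decoder state.

```python
import gc
import weakref


class FakeTensor:
    """Hypothetical stand-in for a large GPU tensor."""


class AttributeCacheVAE:
    """Keeps the cache on the object; cleanup only runs on success."""

    def encode(self, refs, oom):
        self._feat_map = [FakeTensor()]
        refs.append(weakref.ref(self._feat_map[0]))
        if oom:
            raise MemoryError("simulated VRAM OOM")
        self._feat_map = None  # skipped when the line above raises


class LocalCacheVAE:
    """Keeps the cache in a local; it dies with the stack frame."""

    def encode(self, refs, oom):
        feat_map = [FakeTensor()]
        refs.append(weakref.ref(feat_map[0]))
        if oom:
            raise MemoryError("simulated VRAM OOM")


def tensor_survives_oom(vae):
    """Return True if the cached tensor is still alive after a failed encode."""
    refs = []
    try:
        vae.encode(refs, oom=True)
    except MemoryError:
        pass
    gc.collect()
    return refs[0]() is not None


leaky = tensor_survives_oom(AttributeCacheVAE())
local = tensor_survives_oom(LocalCacheVAE())
print(leaky, local)
```

As long as anything holds the VAE object, the attribute-based cache keeps its tensor alive after the failure, while the local cache is released once the exception has been handled.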