In Python, mutable default arguments are evaluated **once at function definition time** and shared across all subsequent calls. This is a well-known Python pitfall:
```python
# BAD: this list is shared across ALL calls to forward()
def forward(self, x, feat_cache=None, feat_idx=[0]):
    feat_idx[0] += 1  # modifies the shared default list!
```
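A minimal standalone sketch (the `step` function here is illustrative, not from the ComfyUI code) makes the sharing visible:

```python
def step(feat_idx=[0]):
    feat_idx[0] += 1
    return feat_idx[0]

print(step())             # 1
print(step())             # 2 -- state carried over from the first call
print(step.__defaults__)  # ([2],) -- the shared default lives on the function object
```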
In `comfy/ldm/wan/vae.py` and `comfy/ldm/wan/vae2_2.py`, the `forward` methods of `Resample`, `ResidualBlock`, `Down_ResidualBlock`, `Up_ResidualBlock`, `Encoder3d`, and `Decoder3d` all use `feat_idx=[0]` as a default argument. Because `feat_idx[0]` is incremented inside these methods, the shared default list accumulates state between inference runs: on the second run, `feat_idx[0]` no longer starts at `0` but at whatever value it reached at the end of the first run, causing incorrect cache indexing throughout the entire encoder and decoder.
**Fix:**
```python
# GOOD: a new list is created for every call that doesn't pass feat_idx
def forward(self, x, feat_cache=None, feat_idx=None):
    # Fix: mutable default argument feat_idx=[0] would persist between calls
    if feat_idx is None:
        feat_idx = [0]
```
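As a quick sanity check, the same standalone sketch with the `None` sentinel starts from a fresh counter on every call, while callers that pass their own list (as `encode()` and `decode()` do via `conv_idx = [0]`) still share one counter within a single run:

```python
def forward_fixed(feat_idx=None):
    # Fix: allocate a fresh list when the caller doesn't supply one
    if feat_idx is None:
        feat_idx = [0]
    feat_idx[0] += 1
    return feat_idx[0]

print(forward_fixed())  # 1
print(forward_fixed())  # 1 -- no state leaks between calls

# An explicitly passed list is still shared across sub-calls within one run:
shared = [0]
forward_fixed(shared)
forward_fixed(shared)
print(shared)           # [2]
```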
**Observed impact:** On AMD/ROCm hardware, this bug caused 4-5x slower WAN VAE inference on every run after the first. After this fix, only Run 2 remains slightly slower (due to a separate MIOpen kernel-cache issue), while Run 3 and beyond are as fast as Run 1. The bug likely affects all hardware to some degree, since incorrect cache indexing causes unnecessary recomputation.
Related issues exist in the ROCm tracker and in the ComfyUI tracker:
https://github.com/ROCm/ROCm/issues/6008
https://github.com/Comfy-Org/ComfyUI/issues/12672#issuecomment-4059981039
# original version: https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/vae.py
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from comfy.ldm.modules.diffusionmodules.model import vae_attention, torch_cat_if_needed

import comfy.ops
ops = comfy.ops.disable_weight_init

CACHE_T = 2


class CausalConv3d(ops.Conv3d):
    """
    Causal 3d convolution.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._padding = 2 * self.padding[0]
        self.padding = (0, self.padding[1], self.padding[2])

    def forward(self, x, cache_x=None, cache_list=None, cache_idx=None):
        if cache_list is not None:
            cache_x = cache_list[cache_idx]
            cache_list[cache_idx] = None

        if cache_x is None and x.shape[2] == 1:
            # Fast path - the op will pad for us by truncating the weight
            # and save math on a pile of zeros.
            return super().forward(x, autopad="causal_zero")

        if self._padding > 0:
            padding_needed = self._padding
            if cache_x is not None:
                cache_x = cache_x.to(x.device)
                padding_needed = max(0, padding_needed - cache_x.shape[2])
            padding_shape = list(x.shape)
            padding_shape[2] = padding_needed
            padding = torch.zeros(padding_shape, device=x.device, dtype=x.dtype)
            x = torch_cat_if_needed([padding, cache_x, x], dim=2)
            del cache_x

        return super().forward(x)


class RMS_norm(nn.Module):

    def __init__(self, dim, channel_first=True, images=True, bias=False):
        super().__init__()
        broadcastable_dims = (1, 1, 1) if not images else (1, 1)
        shape = (dim, *broadcastable_dims) if channel_first else (dim,)

        self.channel_first = channel_first
        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.ones(shape))
        self.bias = nn.Parameter(torch.zeros(shape)) if bias else None

    def forward(self, x):
        return F.normalize(
            x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma.to(x) + (self.bias.to(x) if self.bias is not None else 0)


class Resample(nn.Module):

    def __init__(self, dim, mode):
        assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
                        'downsample3d')
        super().__init__()
        self.dim = dim
        self.mode = mode

        # layers
        if mode == 'upsample2d':
            self.resample = nn.Sequential(
                nn.Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
                ops.Conv2d(dim, dim // 2, 3, padding=1))
        elif mode == 'upsample3d':
            self.resample = nn.Sequential(
                nn.Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
                ops.Conv2d(dim, dim // 2, 3, padding=1))
            self.time_conv = CausalConv3d(
                dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))

        elif mode == 'downsample2d':
            self.resample = nn.Sequential(
                nn.ZeroPad2d((0, 1, 0, 1)),
                ops.Conv2d(dim, dim, 3, stride=(2, 2)))
        elif mode == 'downsample3d':
            self.resample = nn.Sequential(
                nn.ZeroPad2d((0, 1, 0, 1)),
                ops.Conv2d(dim, dim, 3, stride=(2, 2)))
            self.time_conv = CausalConv3d(
                dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))

        else:
            self.resample = nn.Identity()

    # Fix: mutable default argument feat_idx=[0] would persist between calls
    def forward(self, x, feat_cache=None, feat_idx=None):
        if feat_idx is None:
            feat_idx = [0]
        b, c, t, h, w = x.size()
        if self.mode == 'upsample3d':
            if feat_cache is not None:
                idx = feat_idx[0]
                if feat_cache[idx] is None:
                    feat_cache[idx] = 'Rep'
                    feat_idx[0] += 1
                else:
                    cache_x = x[:, :, -CACHE_T:, :, :].clone()
                    if cache_x.shape[2] < 2 and feat_cache[
                            idx] is not None and feat_cache[idx] != 'Rep':
                        # cache last frame of last two chunks
                        cache_x = torch.cat([
                            feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                                cache_x.device), cache_x
                        ],
                                            dim=2)
                    if cache_x.shape[2] < 2 and feat_cache[
                            idx] is not None and feat_cache[idx] == 'Rep':
                        cache_x = torch.cat([
                            torch.zeros_like(cache_x).to(cache_x.device),
                            cache_x
                        ],
                                            dim=2)
                    if feat_cache[idx] == 'Rep':
                        x = self.time_conv(x)
                    else:
                        x = self.time_conv(x, feat_cache[idx])
                    feat_cache[idx] = cache_x
                    feat_idx[0] += 1

                    x = x.reshape(b, 2, c, t, h, w)
                    x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
                                    3)
                    x = x.reshape(b, c, t * 2, h, w)
        t = x.shape[2]
        x = rearrange(x, 'b c t h w -> (b t) c h w')
        x = self.resample(x)
        x = rearrange(x, '(b t) c h w -> b c t h w', t=t)

        if self.mode == 'downsample3d':
            if feat_cache is not None:
                idx = feat_idx[0]
                if feat_cache[idx] is None:
                    feat_cache[idx] = x.clone()
                    feat_idx[0] += 1
                else:
                    cache_x = x[:, :, -1:, :, :].clone()
                    # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
                    #     # cache last frame of last two chunks
                    #     cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)

                    x = self.time_conv(
                        torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
                    feat_cache[idx] = cache_x
                    feat_idx[0] += 1
        return x


class ResidualBlock(nn.Module):

    def __init__(self, in_dim, out_dim, dropout=0.0):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        # layers
        self.residual = nn.Sequential(
            RMS_norm(in_dim, images=False), nn.SiLU(),
            CausalConv3d(in_dim, out_dim, 3, padding=1),
            RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
            CausalConv3d(out_dim, out_dim, 3, padding=1))
        self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
            if in_dim != out_dim else nn.Identity()

    # Fix: mutable default argument feat_idx=[0] would persist between calls
    def forward(self, x, feat_cache=None, feat_idx=None):
        if feat_idx is None:
            feat_idx = [0]
        old_x = x
        for layer in self.residual:
            if isinstance(layer, CausalConv3d) and feat_cache is not None:
                idx = feat_idx[0]
                cache_x = x[:, :, -CACHE_T:, :, :].clone()
                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    # cache last frame of last two chunks
                    cache_x = torch.cat([
                        feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                            cache_x.device), cache_x
                    ],
                                        dim=2)
                x = layer(x, cache_list=feat_cache, cache_idx=idx)
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)
        return x + self.shortcut(old_x)


class AttentionBlock(nn.Module):
    """
    Causal self-attention with a single head.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

        # layers
        self.norm = RMS_norm(dim)
        self.to_qkv = ops.Conv2d(dim, dim * 3, 1)
        self.proj = ops.Conv2d(dim, dim, 1)
        self.optimized_attention = vae_attention()

    def forward(self, x):
        identity = x
        b, c, t, h, w = x.size()
        x = rearrange(x, 'b c t h w -> (b t) c h w')
        x = self.norm(x)

        # compute query, key, value
        q, k, v = self.to_qkv(x).chunk(3, dim=1)
        x = self.optimized_attention(q, k, v)

        # output
        x = self.proj(x)
        x = rearrange(x, '(b t) c h w-> b c t h w', t=t)
        return x + identity


class Encoder3d(nn.Module):

    def __init__(self,
                 dim=128,
                 z_dim=4,
                 input_channels=3,
                 dim_mult=[1, 2, 4, 4],
                 num_res_blocks=2,
                 attn_scales=[],
                 temperal_downsample=[True, True, False],
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_downsample = temperal_downsample

        # dimensions
        dims = [dim * u for u in [1] + dim_mult]
        scale = 1.0

        # init block
        self.conv1 = CausalConv3d(input_channels, dims[0], 3, padding=1)

        # downsample blocks
        downsamples = []
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            # residual (+attention) blocks
            for _ in range(num_res_blocks):
                downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
                if scale in attn_scales:
                    downsamples.append(AttentionBlock(out_dim))
                in_dim = out_dim

            # downsample block
            if i != len(dim_mult) - 1:
                mode = 'downsample3d' if temperal_downsample[
                    i] else 'downsample2d'
                downsamples.append(Resample(out_dim, mode=mode))
                scale /= 2.0
        self.downsamples = nn.Sequential(*downsamples)

        # middle blocks
        self.middle = nn.Sequential(
            ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim),
            ResidualBlock(out_dim, out_dim, dropout))

        # output blocks
        self.head = nn.Sequential(
            RMS_norm(out_dim, images=False), nn.SiLU(),
            CausalConv3d(out_dim, z_dim, 3, padding=1))

    # Fix: mutable default argument feat_idx=[0] would persist between calls
    def forward(self, x, feat_cache=None, feat_idx=None):
        if feat_idx is None:
            feat_idx = [0]
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # cache last frame of last two chunks
                cache_x = torch.cat([
                    feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                        cache_x.device), cache_x
                ],
                                    dim=2)
            x = self.conv1(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv1(x)

        ## downsamples
        for layer in self.downsamples:
            if feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## middle
        for layer in self.middle:
            if isinstance(layer, ResidualBlock) and feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## head
        for layer in self.head:
            if isinstance(layer, CausalConv3d) and feat_cache is not None:
                idx = feat_idx[0]
                cache_x = x[:, :, -CACHE_T:, :, :].clone()
                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    # cache last frame of last two chunks
                    cache_x = torch.cat([
                        feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                            cache_x.device), cache_x
                    ],
                                        dim=2)
                x = layer(x, feat_cache[idx])
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)
        return x


class Decoder3d(nn.Module):

    def __init__(self,
                 dim=128,
                 z_dim=4,
                 output_channels=3,
                 dim_mult=[1, 2, 4, 4],
                 num_res_blocks=2,
                 attn_scales=[],
                 temperal_upsample=[False, True, True],
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_upsample = temperal_upsample

        # dimensions
        dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
        scale = 1.0 / 2**(len(dim_mult) - 2)

        # init block
        self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)

        # middle blocks
        self.middle = nn.Sequential(
            ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
            ResidualBlock(dims[0], dims[0], dropout))

        # upsample blocks
        upsamples = []
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            # residual (+attention) blocks
            if i == 1 or i == 2 or i == 3:
                in_dim = in_dim // 2
            for _ in range(num_res_blocks + 1):
                upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
                if scale in attn_scales:
                    upsamples.append(AttentionBlock(out_dim))
                in_dim = out_dim

            # upsample block
            if i != len(dim_mult) - 1:
                mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
                upsamples.append(Resample(out_dim, mode=mode))
                scale *= 2.0
        self.upsamples = nn.Sequential(*upsamples)

        # output blocks
        self.head = nn.Sequential(
            RMS_norm(out_dim, images=False), nn.SiLU(),
            CausalConv3d(out_dim, output_channels, 3, padding=1))

    # Fix: mutable default argument feat_idx=[0] would persist between calls
    def forward(self, x, feat_cache=None, feat_idx=None):
        if feat_idx is None:
            feat_idx = [0]
        ## conv1
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # cache last frame of last two chunks
                cache_x = torch.cat([
                    feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                        cache_x.device), cache_x
                ],
                                    dim=2)
            x = self.conv1(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv1(x)

        ## middle
        for layer in self.middle:
            if isinstance(layer, ResidualBlock) and feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## upsamples
        for layer in self.upsamples:
            if feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## head
        for layer in self.head:
            if isinstance(layer, CausalConv3d) and feat_cache is not None:
                idx = feat_idx[0]
                cache_x = x[:, :, -CACHE_T:, :, :].clone()
                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    # cache last frame of last two chunks
                    cache_x = torch.cat([
                        feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                            cache_x.device), cache_x
                    ],
                                        dim=2)
                x = layer(x, feat_cache[idx])
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)
        return x


def count_conv3d(model):
    count = 0
    for m in model.modules():
        if isinstance(m, CausalConv3d):
            count += 1
    return count


class WanVAE(nn.Module):

    def __init__(self,
                 dim=128,
                 z_dim=4,
                 dim_mult=[1, 2, 4, 4],
                 num_res_blocks=2,
                 attn_scales=[],
                 temperal_downsample=[True, True, False],
                 image_channels=3,
                 conv_out_channels=3,
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_downsample = temperal_downsample
        self.temperal_upsample = temperal_downsample[::-1]

        # modules
        self.encoder = Encoder3d(dim, z_dim * 2, image_channels, dim_mult, num_res_blocks,
                                 attn_scales, self.temperal_downsample, dropout)
        self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
        self.conv2 = CausalConv3d(z_dim, z_dim, 1)
        self.decoder = Decoder3d(dim, z_dim, conv_out_channels, dim_mult, num_res_blocks,
                                 attn_scales, self.temperal_upsample, dropout)

    def encode(self, x):
        conv_idx = [0]
        ## cache
        t = x.shape[2]
        iter_ = 1 + (t - 1) // 4
        feat_map = None
        if iter_ > 1:
            feat_map = [None] * count_conv3d(self.encoder)
        ## split the encoder input x along the time axis into chunks of 1, 4, 4, 4, ...
        for i in range(iter_):
            conv_idx = [0]
            if i == 0:
                out = self.encoder(
                    x[:, :, :1, :, :],
                    feat_cache=feat_map,
                    feat_idx=conv_idx)
            else:
                out_ = self.encoder(
                    x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
                    feat_cache=feat_map,
                    feat_idx=conv_idx)
                out = torch.cat([out, out_], 2)
        mu, log_var = self.conv1(out).chunk(2, dim=1)
        return mu

    def decode(self, z):
        conv_idx = [0]
        # z: [b,c,t,h,w]
        iter_ = z.shape[2]
        feat_map = None
        if iter_ > 1:
            feat_map = [None] * count_conv3d(self.decoder)
        x = self.conv2(z)
        for i in range(iter_):
            conv_idx = [0]
            if i == 0:
                out = self.decoder(
                    x[:, :, i:i + 1, :, :],
                    feat_cache=feat_map,
                    feat_idx=conv_idx)
            else:
                out_ = self.decoder(
                    x[:, :, i:i + 1, :, :],
                    feat_cache=feat_map,
                    feat_idx=conv_idx)
                out = torch.cat([out, out_], 2)
        return out