mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-24 21:30:15 +08:00
Some checks are pending
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
* ops: introduce autopad for conv3d. This works around PyTorch's inability to causally pad as part of the kernel and avoids massive weight duplication for padding. * wan-vae: rework causal padding. This currently uses F.pad, which takes a full deep copy and is liable to be the VRAM peak. Instead, kick spatial padding back to the op and consolidate the temporal padding with the cat for the cache. * wan-vae: implement zero pad fast path. The WAN VAE is also used by QWEN, where it is applied to single images. These convolutions are, however, zero-padded 3D convolutions, which means the VAE is effectively just 2D, using only the last element of the conv weight in the temporal dimension. Fast-path this to avoid adding zeros that then just evaporate in the convolution math but still cost computation.
524 lines
18 KiB
Python
524 lines
18 KiB
Python
# original version: https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/vae.py
|
||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||
|
||
import torch
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
from einops import rearrange
|
||
from comfy.ldm.modules.diffusionmodules.model import vae_attention, torch_cat_if_needed
|
||
|
||
import comfy.ops
|
||
# Number of trailing frames each causal conv keeps as history between chunks
# (see the `x[:, :, -CACHE_T:]` slices below).
CACHE_T = 2

# Layer factory used throughout this file (the name suggests weight
# initialization is disabled -- confirm against comfy.ops).
ops = comfy.ops.disable_weight_init
|
||
|
||
|
||
class CausalConv3d(ops.Conv3d):
    """
    Causal 3D convolution.

    Temporal padding is applied only on the past side (prepended along dim 2),
    so an output frame never depends on later input frames. The constructor's
    temporal ``padding`` is reinterpreted: ``2 * padding[0]`` frames of history
    are supplied in ``forward`` (from a streaming cache and/or zeros), while the
    spatial padding is left to the underlying op.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Total causal (front) padding in frames; the op gets no temporal padding.
        self._padding = 2 * self.padding[0]
        self.padding = (0, self.padding[1], self.padding[2])

    def forward(self, x, cache_x=None, cache_list=None, cache_idx=None):
        """Apply the causal convolution to ``x`` (dim 2 is time).

        Args:
            x: input chunk.
            cache_x: optional trailing frames of the previous chunk, prepended
                along dim 2 in place of (some of) the zero padding.
            cache_list: alternative way to pass the cache. ``cache_list[cache_idx]``
                is read and then set to ``None`` so the caller's list no longer
                keeps it alive while the padded input is built.
            cache_idx: index into ``cache_list``.
        """
        if cache_list is not None:
            cache_x = cache_list[cache_idx]
            cache_list[cache_idx] = None

        # Fast path: the op pads for us by truncating the weight, and saves math
        # on a pile of zeros. This matches the explicit zero-padding path below
        # only when the causal padding equals the temporal receptive field minus
        # one; for any other configuration fall through to the explicit path so
        # the output length stays correct.
        causal_extent = self.dilation[0] * (self.kernel_size[0] - 1)
        if cache_x is None and x.shape[2] == 1 and self._padding == causal_extent:
            return super().forward(x, autopad="causal_zero")

        if self._padding > 0:
            padding_needed = self._padding
            if cache_x is not None:
                cache_x = cache_x.to(x.device)
                # Cached history frames replace the same number of zero frames.
                padding_needed = max(0, padding_needed - cache_x.shape[2])
            padding_shape = list(x.shape)
            padding_shape[2] = padding_needed
            padding = torch.zeros(padding_shape, device=x.device, dtype=x.dtype)
            # One concatenation for zeros + history + input (no separate F.pad copy).
            x = torch_cat_if_needed([padding, cache_x, x], dim=2)
            del cache_x

        return super().forward(x)
|
||
|
||
|
||
class RMS_norm(nn.Module):
    """L2-normalize along the channel axis, then rescale by sqrt(dim) * gamma (+ bias).

    Args:
        dim: number of channels.
        channel_first: if True, normalize over dim 1 and shape the parameters
            ``(dim, 1, 1)`` / ``(dim, 1, 1, 1)`` so they broadcast over the
            trailing axes; otherwise normalize over the last axis with
            parameters of shape ``(dim,)``.
        images: with ``channel_first``, selects 2 trailing singleton axes
            (True) or 3 (False).
        bias: add a learnable bias initialized to zeros.
    """

    def __init__(self, dim, channel_first=True, images=True, bias=False):
        super().__init__()
        if channel_first:
            trailing = (1, 1) if images else (1, 1, 1)
            param_shape = (dim,) + trailing
        else:
            param_shape = (dim,)

        self.channel_first = channel_first
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(param_shape))
        if bias:
            self.bias = nn.Parameter(torch.zeros(param_shape))
        else:
            self.bias = None

    def forward(self, x):
        norm_dim = 1 if self.channel_first else -1
        # Parameters are cast to x's dtype/device on the fly.
        out = F.normalize(x, dim=norm_dim) * self.scale * self.gamma.to(x)
        if self.bias is not None:
            out = out + self.bias.to(x)
        return out
|
||
|
||
|
||
class Resample(nn.Module):
    """2x spatial resampling with optional causal 2x temporal resampling.

    ``mode`` is one of:
      - 'none': identity.
      - 'upsample2d': nearest-exact 2x spatial upsample, then a 3x3 conv to
        ``dim // 2`` channels.
      - 'upsample3d': as 'upsample2d', preceded (only when a feature cache is
        given) by a causal temporal conv to ``2 * dim`` channels whose two
        halves are interleaved into twice as many frames.
      - 'downsample2d': zero-pad right/bottom by one, then a stride-2 3x3 conv.
      - 'downsample3d': as 'downsample2d', followed (only when a feature cache
        is given) by a stride-2 temporal conv over the previous chunk's last
        frame plus the current chunk.
    """

    def __init__(self, dim, mode):
        assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
                        'downsample3d')
        super().__init__()
        self.dim = dim
        self.mode = mode

        # layers
        if mode == 'upsample2d':
            self.resample = nn.Sequential(
                nn.Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
                ops.Conv2d(dim, dim // 2, 3, padding=1))
        elif mode == 'upsample3d':
            self.resample = nn.Sequential(
                nn.Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
                ops.Conv2d(dim, dim // 2, 3, padding=1))
            self.time_conv = CausalConv3d(
                dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
        elif mode == 'downsample2d':
            self.resample = nn.Sequential(
                nn.ZeroPad2d((0, 1, 0, 1)),
                ops.Conv2d(dim, dim, 3, stride=(2, 2)))
        elif mode == 'downsample3d':
            self.resample = nn.Sequential(
                nn.ZeroPad2d((0, 1, 0, 1)),
                ops.Conv2d(dim, dim, 3, stride=(2, 2)))
            self.time_conv = CausalConv3d(
                dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
        else:
            self.resample = nn.Identity()

    def forward(self, x, feat_cache=None, feat_idx=None):
        """Resample a 5D chunk ``x`` of shape (b, c, t, h, w).

        Args:
            x: input chunk.
            feat_cache: optional list of per-conv history slots shared across
                chunks; temporal resampling only happens when it is given.
            feat_idx: one-element list holding the next cache slot index,
                advanced in place. Defaults to a fresh ``[0]`` per call.
        """
        if feat_idx is None:
            feat_idx = [0]
        b, c, t, h, w = x.size()

        if self.mode == 'upsample3d' and feat_cache is not None:
            idx = feat_idx[0]
            prev = feat_cache[idx]
            if prev is None:
                # First chunk: no temporal upsampling; mark the slot so the next
                # chunk runs the temporal conv with zero history.
                feat_cache[idx] = 'Rep'
            else:
                is_rep = isinstance(prev, str) and prev == 'Rep'
                cache_x = x[:, :, -CACHE_T:, :, :].clone()
                if cache_x.shape[2] < 2:
                    # Single-frame chunk: extend the saved history to two frames
                    # using zeros (after the first chunk) or the previous last frame.
                    if is_rep:
                        history = torch.zeros_like(cache_x)
                    else:
                        history = prev[:, :, -1:, :, :].to(cache_x.device)
                    cache_x = torch.cat([history, cache_x], dim=2)
                del prev
                if is_rep:
                    x = self.time_conv(x)
                else:
                    # Hand the history over through the list so the conv can
                    # release it once it has been concatenated.
                    x = self.time_conv(x, cache_list=feat_cache, cache_idx=idx)
                feat_cache[idx] = cache_x

                # (b, 2c, t, h, w) -> (b, c, 2t, h, w), interleaving the halves in time.
                x = x.reshape(b, 2, c, t, h, w)
                x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
                x = x.reshape(b, c, t * 2, h, w)
            feat_idx[0] += 1

        # Spatial resampling is applied frame by frame.
        t = x.shape[2]
        x = rearrange(x, 'b c t h w -> (b t) c h w')
        x = self.resample(x)
        x = rearrange(x, '(b t) c h w -> b c t h w', t=t)

        if self.mode == 'downsample3d' and feat_cache is not None:
            idx = feat_idx[0]
            prev = feat_cache[idx]
            if prev is None:
                # First chunk: no temporal downsampling. Only the last frame is
                # ever read back, so that is all we keep.
                feat_cache[idx] = x[:, :, -1:, :, :].clone()
            else:
                cache_x = x[:, :, -1:, :, :].clone()
                x = self.time_conv(torch.cat([prev[:, :, -1:, :, :], x], 2))
                feat_cache[idx] = cache_x
            del prev
            feat_idx[0] += 1
        return x
|
||
|
||
|
||
class ResidualBlock(nn.Module):
    """Residual block of two causal 3x3x3 convolutions.

    ``residual`` is RMS_norm -> SiLU -> CausalConv3d -> RMS_norm -> SiLU ->
    Dropout -> CausalConv3d. The shortcut is a 1x1x1 CausalConv3d when
    ``in_dim != out_dim`` and the identity otherwise.
    """

    def __init__(self, in_dim, out_dim, dropout=0.0):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        # layers
        self.residual = nn.Sequential(
            RMS_norm(in_dim, images=False), nn.SiLU(),
            CausalConv3d(in_dim, out_dim, 3, padding=1),
            RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
            CausalConv3d(out_dim, out_dim, 3, padding=1))
        self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
            if in_dim != out_dim else nn.Identity()

    def forward(self, x, feat_cache=None, feat_idx=None):
        """Apply the block to chunk ``x``.

        Args:
            x: 5D input chunk (dim 2 is time).
            feat_cache: optional list of per-conv history slots shared across
                chunks; each causal conv in ``residual`` uses one slot.
            feat_idx: one-element list holding the next slot index, advanced in
                place. Defaults to a fresh ``[0]`` per call.
        """
        if feat_idx is None:
            feat_idx = [0]
        old_x = x
        for layer in self.residual:
            if isinstance(layer, CausalConv3d) and feat_cache is not None:
                idx = feat_idx[0]
                # History for the next chunk: this layer's last CACHE_T input frames.
                cache_x = x[:, :, -CACHE_T:, :, :].clone()
                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    # Single-frame chunk: prepend the previous chunk's last frame.
                    cache_x = torch.cat([
                        feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                            cache_x.device), cache_x
                    ], dim=2)
                # The conv reads and clears feat_cache[idx] itself.
                x = layer(x, cache_list=feat_cache, cache_idx=idx)
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)
        return x + self.shortcut(old_x)
|
||
|
||
|
||
class AttentionBlock(nn.Module):
    """
    Per-frame spatial self-attention with a residual connection.

    Time is folded into the batch, so attention never mixes frames (which keeps
    it causal). The original docstring describes it as single-head; the kernel
    itself comes from ``vae_attention()``.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

        # layers
        self.norm = RMS_norm(dim)
        self.to_qkv = ops.Conv2d(dim, dim * 3, 1)
        self.proj = ops.Conv2d(dim, dim, 1)
        self.optimized_attention = vae_attention()

    def forward(self, x):
        _, _, num_frames, _, _ = x.size()
        residual = x

        # (b, c, t, h, w) -> (b*t, c, h, w): each frame attends only to itself.
        frames = self.norm(rearrange(x, 'b c t h w -> (b t) c h w'))

        # 1x1 conv produces q, k and v stacked along channels.
        query, key, value = self.to_qkv(frames).chunk(3, dim=1)
        attended = self.proj(self.optimized_attention(query, key, value))

        return rearrange(attended, '(b t) c h w-> b c t h w', t=num_frames) + residual
|
||
|
||
|
||
class Encoder3d(nn.Module):
    """Causal 3D encoder.

    Pipeline: conv1 -> residual(+attention) stages, each but the last followed
    by a Resample ('downsample3d' where ``temperal_downsample[i]`` is true,
    otherwise 'downsample2d') -> middle (residual, attention, residual) ->
    head (RMS_norm, SiLU, CausalConv3d to ``z_dim`` channels).

    When ``forward`` receives ``feat_cache``, every causal conv reads and
    refreshes its own history slot, so a video can be encoded chunk by chunk
    along time.
    """

    def __init__(self,
                 dim=128,
                 z_dim=4,
                 input_channels=3,
                 dim_mult=[1, 2, 4, 4],
                 num_res_blocks=2,
                 attn_scales=[],
                 temperal_downsample=[True, True, False],
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_downsample = temperal_downsample

        # dimensions
        dims = [dim * u for u in [1] + dim_mult]
        scale = 1.0

        # init block
        self.conv1 = CausalConv3d(input_channels, dims[0], 3, padding=1)

        # downsample blocks
        downsamples = []
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            # residual (+attention) blocks
            for _ in range(num_res_blocks):
                downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
                if scale in attn_scales:
                    downsamples.append(AttentionBlock(out_dim))
                in_dim = out_dim

            # downsample block
            if i != len(dim_mult) - 1:
                mode = 'downsample3d' if temperal_downsample[i] else 'downsample2d'
                downsamples.append(Resample(out_dim, mode=mode))
                scale /= 2.0
        self.downsamples = nn.Sequential(*downsamples)

        # middle blocks
        self.middle = nn.Sequential(
            ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim),
            ResidualBlock(out_dim, out_dim, dropout))

        # output blocks
        self.head = nn.Sequential(
            RMS_norm(out_dim, images=False), nn.SiLU(),
            CausalConv3d(out_dim, z_dim, 3, padding=1))

    @staticmethod
    def _cached_causal_conv(conv, x, feat_cache, feat_idx):
        """Run causal ``conv`` on chunk ``x`` using slot ``feat_idx[0]``, then store this chunk's tail there."""
        idx = feat_idx[0]
        cache_x = x[:, :, -CACHE_T:, :, :].clone()
        if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
            # Single-frame chunk: prepend the previous chunk's last frame so the
            # stored history still holds two frames.
            cache_x = torch.cat(
                [feat_cache[idx][:, :, -1:, :, :].to(cache_x.device), cache_x], dim=2)
        # Passing the list lets the conv clear the slot itself and drop the old
        # history as soon as it has been concatenated.
        x = conv(x, cache_list=feat_cache, cache_idx=idx)
        feat_cache[idx] = cache_x
        feat_idx[0] += 1
        return x

    def forward(self, x, feat_cache=None, feat_idx=None):
        """Encode chunk ``x``.

        Args:
            x: 5D input chunk (dim 2 is time).
            feat_cache: optional list of per-conv history slots shared across chunks.
            feat_idx: one-element list holding the next slot index, advanced in
                place. Defaults to a fresh ``[0]`` per call.
        """
        if feat_idx is None:
            feat_idx = [0]
        caching = feat_cache is not None

        ## conv1
        if caching:
            x = self._cached_causal_conv(self.conv1, x, feat_cache, feat_idx)
        else:
            x = self.conv1(x)

        ## downsamples
        for layer in self.downsamples:
            # AttentionBlock.forward takes no cache arguments.
            if caching and not isinstance(layer, AttentionBlock):
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## middle
        for layer in self.middle:
            if isinstance(layer, ResidualBlock) and caching:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## head
        for layer in self.head:
            if isinstance(layer, CausalConv3d) and caching:
                x = self._cached_causal_conv(layer, x, feat_cache, feat_idx)
            else:
                x = layer(x)
        return x
|
||
|
||
|
||
class Decoder3d(nn.Module):
    """Causal 3D decoder.

    Pipeline: conv1 -> middle (residual, attention, residual) ->
    residual(+attention) stages, each but the last followed by a Resample
    ('upsample3d' where ``temperal_upsample[i]`` is true, otherwise
    'upsample2d') -> head (RMS_norm, SiLU, CausalConv3d to ``output_channels``).

    When ``forward`` receives ``feat_cache``, every causal conv reads and
    refreshes its own history slot, so latents can be decoded chunk by chunk
    along time.
    """

    def __init__(self,
                 dim=128,
                 z_dim=4,
                 output_channels=3,
                 dim_mult=[1, 2, 4, 4],
                 num_res_blocks=2,
                 attn_scales=[],
                 temperal_upsample=[False, True, True],
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_upsample = temperal_upsample

        # dimensions
        dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
        scale = 1.0 / 2**(len(dim_mult) - 2)

        # init block
        self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)

        # middle blocks
        self.middle = nn.Sequential(
            ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
            ResidualBlock(dims[0], dims[0], dropout))

        # upsample blocks
        upsamples = []
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            # Every stage after the first follows an upsampling Resample, whose
            # 2D conv halves the channel count (dim -> dim // 2).
            if i > 0:
                in_dim = in_dim // 2
            # residual (+attention) blocks
            for _ in range(num_res_blocks + 1):
                upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
                if scale in attn_scales:
                    upsamples.append(AttentionBlock(out_dim))
                in_dim = out_dim

            # upsample block
            if i != len(dim_mult) - 1:
                mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
                upsamples.append(Resample(out_dim, mode=mode))
                scale *= 2.0
        self.upsamples = nn.Sequential(*upsamples)

        # output blocks
        self.head = nn.Sequential(
            RMS_norm(out_dim, images=False), nn.SiLU(),
            CausalConv3d(out_dim, output_channels, 3, padding=1))

    @staticmethod
    def _cached_causal_conv(conv, x, feat_cache, feat_idx):
        """Run causal ``conv`` on chunk ``x`` using slot ``feat_idx[0]``, then store this chunk's tail there."""
        idx = feat_idx[0]
        cache_x = x[:, :, -CACHE_T:, :, :].clone()
        if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
            # Single-frame chunk: prepend the previous chunk's last frame so the
            # stored history still holds two frames.
            cache_x = torch.cat(
                [feat_cache[idx][:, :, -1:, :, :].to(cache_x.device), cache_x], dim=2)
        # Passing the list lets the conv clear the slot itself and drop the old
        # history as soon as it has been concatenated.
        x = conv(x, cache_list=feat_cache, cache_idx=idx)
        feat_cache[idx] = cache_x
        feat_idx[0] += 1
        return x

    def forward(self, x, feat_cache=None, feat_idx=None):
        """Decode chunk ``x``.

        Args:
            x: 5D latent chunk (dim 2 is time).
            feat_cache: optional list of per-conv history slots shared across chunks.
            feat_idx: one-element list holding the next slot index, advanced in
                place. Defaults to a fresh ``[0]`` per call.
        """
        if feat_idx is None:
            feat_idx = [0]
        caching = feat_cache is not None

        ## conv1
        if caching:
            x = self._cached_causal_conv(self.conv1, x, feat_cache, feat_idx)
        else:
            x = self.conv1(x)

        ## middle
        for layer in self.middle:
            if isinstance(layer, ResidualBlock) and caching:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## upsamples
        for layer in self.upsamples:
            # AttentionBlock.forward takes no cache arguments.
            if caching and not isinstance(layer, AttentionBlock):
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## head
        for layer in self.head:
            if isinstance(layer, CausalConv3d) and caching:
                x = self._cached_causal_conv(layer, x, feat_cache, feat_idx)
            else:
                x = layer(x)
        return x
|
||
|
||
|
||
def count_conv3d(model):
    """Return how many CausalConv3d modules ``model.modules()`` yields (``model`` included)."""
    return sum(isinstance(module, CausalConv3d) for module in model.modules())
|
||
|
||
|
||
class WanVAE(nn.Module):
    """Wan causal video VAE: Encoder3d + 1x1x1 convs + Decoder3d, run chunk by chunk in time."""

    def __init__(self,
                 dim=128,
                 z_dim=4,
                 dim_mult=[1, 2, 4, 4],
                 num_res_blocks=2,
                 attn_scales=[],
                 temperal_downsample=[True, True, False],
                 image_channels=3,
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_downsample = temperal_downsample
        self.temperal_upsample = temperal_downsample[::-1]

        # modules
        self.encoder = Encoder3d(dim, z_dim * 2, image_channels, dim_mult, num_res_blocks,
                                 attn_scales, self.temperal_downsample, dropout)
        self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
        self.conv2 = CausalConv3d(z_dim, z_dim, 1)
        self.decoder = Decoder3d(dim, z_dim, image_channels, dim_mult, num_res_blocks,
                                 attn_scales, self.temperal_upsample, dropout)

    def encode(self, x):
        """Encode video ``x`` (dim 2 is time) and return the latent mean.

        Time is split into chunks of 1, 4, 4, 4, ... frames. Each chunk goes
        through the encoder with a shared per-conv cache so the convolutions
        stay causal across chunk boundaries. The second half of ``conv1``'s
        channels (the log-variance) is discarded.

        Raises:
            ValueError: if ``x`` has no frames.
        """
        t = x.shape[2]
        if t < 1:
            raise ValueError("WanVAE.encode: input has no frames along dim 2")
        # One history slot per causal conv of the module that consumes the cache.
        feat_map = [None] * count_conv3d(self.encoder)
        iter_ = 1 + (t - 1) // 4

        chunks = []
        for i in range(iter_):
            if i == 0:
                frames = x[:, :, :1, :, :]
            else:
                frames = x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :]
            # A fresh slot counter per chunk; the cache itself persists.
            chunks.append(self.encoder(frames, feat_cache=feat_map, feat_idx=[0]))

        # Concatenate once rather than re-copying the growing output every chunk.
        out = chunks[0] if len(chunks) == 1 else torch.cat(chunks, 2)
        del chunks  # release the per-chunk tensors before running conv1
        mu, _ = self.conv1(out).chunk(2, dim=1)
        return mu

    def decode(self, z):
        """Decode latents ``z`` (dim 2 is time) one latent frame at a time.

        A shared per-conv cache carries history between frames so the decoder
        stays causal.

        Raises:
            ValueError: if ``z`` has no frames.
        """
        num_frames = z.shape[2]
        if num_frames < 1:
            raise ValueError("WanVAE.decode: latent has no frames along dim 2")
        feat_map = [None] * count_conv3d(self.decoder)
        x = self.conv2(z)

        chunks = [
            self.decoder(x[:, :, i:i + 1, :, :], feat_cache=feat_map, feat_idx=[0])
            for i in range(num_frames)
        ]
        # Concatenate once rather than re-copying the growing output every frame.
        return chunks[0] if len(chunks) == 1 else torch.cat(chunks, 2)
|