mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-20 00:24:59 +08:00
* wan: vae: encoder: Add feature cache layer that corks singles If a downsample only gives you a single frame, save it to the feature cache and return nothing to the top level. This increases the efficiency of cacheability, but also prepares support for going two by two rather than four by four on the frames. * wan: remove all concatenation with the feature cache The loopers are now responsible for ensuring that non-final frames are processed at least two-by-two, eliminating the need for this cat case. * wan: vae: recurse and chunk for 2+2 frames on decode Avoid having to clone off slices of 4 frame chunks and reduce the size of the big 6 frame convolutions down to 4. Save the VRAMs. * wan: encode frames 2x2. Reduce VRAM usage greatly by encoding frames 2 at a time rather than 4. * wan: vae: remove cloning The loopers now control the chunking such that there is never more than 2 frames, so just cache these slices directly and avoid the clone allocations completely. * wan: vae: free consumer caller tensors on recursion * wan: vae: restyle a little to match LTX
510 lines
17 KiB
Python
510 lines
17 KiB
Python
# original version: https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/vae.py
|
||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||
|
||
import torch
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
from einops import rearrange
|
||
from comfy.ldm.modules.diffusionmodules.model import vae_attention, torch_cat_if_needed
|
||
|
||
import comfy.ops
|
||
# Layer namespace used for the Conv2d/Conv3d layers in this file.
# NOTE(review): presumably these variants skip default weight init - confirm in comfy.ops.
ops = comfy.ops.disable_weight_init
|
||
|
||
# Number of trailing frames (along dim 2) that the blocks in this file slice
# off their input and store in the feature cache for the next chunk.
CACHE_T = 2
|
||
|
||
|
||
class CausalConv3d(ops.Conv3d):
    """
    Causal 3d convolution.

    The temporal padding requested at construction time is removed from the
    underlying convolution and applied by ``forward`` in front of the input
    only, using cached frames from an earlier chunk where available and
    zeros otherwise.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        requested = self.padding
        # Twice the requested temporal padding is applied up front by
        # forward(); the convolution itself only pads height and width.
        self._padding = 2 * requested[0]
        self.padding = (0, requested[1], requested[2])

    def forward(self, x, cache_x=None, cache_list=None, cache_idx=None):
        """
        Convolve ``x``, prepending cached frames and/or zeros along dim 2.

        Args:
            x: input tensor; dim 2 is the one being padded.
            cache_x: optional frames to prepend in place of zero padding.
            cache_list: optional list holding cached frames. When given,
                ``cache_list[cache_idx]`` is used as ``cache_x`` and the
                slot is reset to ``None``.
            cache_idx: index into ``cache_list``.
        """
        if cache_list is not None:
            # Take the cached frames out of the list so the list no longer
            # references them.
            cache_x = cache_list[cache_idx]
            cache_list[cache_idx] = None

        lone_uncached_frame = cache_x is None and x.shape[2] == 1
        if lone_uncached_frame:
            # Fast path - per the original note, the op pads for us by
            # truncating the weight and saves math on a pile of zeros.
            return super().forward(x, autopad="causal_zero")

        if self._padding > 0:
            zeros_needed = self._padding
            if cache_x is not None:
                cache_x = cache_x.to(x.device)
                # Cached frames count towards the padding budget.
                zeros_needed = max(0, zeros_needed - cache_x.shape[2])
            zeros_shape = list(x.shape)
            zeros_shape[2] = zeros_needed
            zeros = torch.zeros(zeros_shape, device=x.device, dtype=x.dtype)
            # NOTE(review): torch_cat_if_needed is presumably tolerant of a
            # None cache_x / zero-length padding - confirm in its definition.
            x = torch_cat_if_needed([zeros, cache_x, x], dim=2)
        del cache_x

        return super().forward(x)
|
||
|
||
|
||
class RMS_norm(nn.Module):
    """
    Normalise along the channel dim, then rescale by ``sqrt(dim)`` and a
    learned per-channel gain (plus an optional learned bias).
    """

    def __init__(self, dim, channel_first=True, images=True, bias=False):
        super().__init__()
        if channel_first:
            # Trailing singleton dims let the parameters broadcast over
            # (h, w) when images=True and over (t, h, w) otherwise.
            trailing = (1, 1) if images else (1, 1, 1)
            shape = (dim, *trailing)
        else:
            shape = (dim,)

        self.channel_first = channel_first
        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.ones(shape))
        self.bias = nn.Parameter(torch.zeros(shape)) if bias else None

    def forward(self, x):
        """Return ``normalize(x) * scale * gamma (+ bias)``."""
        channel_dim = 1 if self.channel_first else -1
        offset = self.bias.to(x) if self.bias is not None else 0
        return F.normalize(x, dim=channel_dim) * self.scale * self.gamma.to(x) + offset
|
||
|
||
|
||
class Resample(nn.Module):
    """
    Spatial (2d) or spatio-temporal (3d) up/down-sampling block.

    ``mode`` selects which layers are built:

    * ``'upsample2d'`` / ``'upsample3d'``: nearest-exact 2x upsample over
      (h, w) followed by a 3x3 conv that halves the channels. The 3d variant
      adds a causal temporal conv producing ``2 * dim`` channels, which
      ``forward`` interleaves into twice as many frames.
    * ``'downsample2d'`` / ``'downsample3d'``: zero-pad one pixel on the
      right/bottom, then a stride-2 3x3 conv. The 3d variant adds a stride-2
      causal temporal conv.
    * ``'none'``: identity.
    """

    def __init__(self, dim, mode):
        assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
                        'downsample3d')
        super().__init__()
        self.dim = dim
        self.mode = mode

        # layers
        if mode == 'upsample2d':
            self.resample = nn.Sequential(
                nn.Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
                ops.Conv2d(dim, dim // 2, 3, padding=1))
        elif mode == 'upsample3d':
            self.resample = nn.Sequential(
                nn.Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
                ops.Conv2d(dim, dim // 2, 3, padding=1))
            self.time_conv = CausalConv3d(
                dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))

        elif mode == 'downsample2d':
            self.resample = nn.Sequential(
                nn.ZeroPad2d((0, 1, 0, 1)),
                ops.Conv2d(dim, dim, 3, stride=(2, 2)))
        elif mode == 'downsample3d':
            self.resample = nn.Sequential(
                nn.ZeroPad2d((0, 1, 0, 1)),
                ops.Conv2d(dim, dim, 3, stride=(2, 2)))
            self.time_conv = CausalConv3d(
                dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))

        else:
            self.resample = nn.Identity()

    def forward(self, x, feat_cache=None, feat_idx=None, final=False):
        """
        Resample ``x`` (a 5-dim ``b c t h w`` tensor).

        Args:
            x: input tensor.
            feat_cache: optional list of per-layer cache slots. Without it
                only the spatial resampling runs. An ``upsample3d`` layer
                uses one slot, a ``downsample3d`` layer uses two.
            feat_idx: one-element list holding the next free slot index;
                advanced in place. Defaults to a fresh ``[0]`` per call.
            final: for ``downsample3d`` with a cache, set on the last chunk
                so that a single remaining frame is returned rather than
                held back.

        Returns:
            The resampled tensor, or ``None`` when a ``downsample3d`` layer
            produced a single non-final frame and parked it in the cache.
        """
        # A list default would be shared between calls and is mutated below,
        # so build a fresh counter when the caller did not supply one.
        if feat_idx is None:
            feat_idx = [0]

        b, c, t, h, w = x.size()
        if self.mode == 'upsample3d':
            if feat_cache is not None:
                idx = feat_idx[0]
                if feat_cache[idx] is None:
                    # First chunk: no temporal upsampling, only mark the
                    # slot so the next chunk knows there is no history.
                    feat_cache[idx] = 'Rep'
                    feat_idx[0] += 1
                else:
                    cache_x = x[:, :, -CACHE_T:, :, :]
                    # The slot holds either the 'Rep' marker or a tensor;
                    # check the type first so a tensor is never compared
                    # against a string.
                    if isinstance(feat_cache[idx], str) and feat_cache[idx] == 'Rep':
                        x = self.time_conv(x)
                    else:
                        x = self.time_conv(x, feat_cache[idx])
                    feat_cache[idx] = cache_x
                    feat_idx[0] += 1

                    # time_conv doubled the channels; split them into two
                    # groups and interleave the groups along time.
                    x = x.reshape(b, 2, c, t, h, w)
                    x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
                                    3)
                    x = x.reshape(b, c, t * 2, h, w)
        t = x.shape[2]
        # Fold time into the batch so the 2d resampling runs per frame.
        x = rearrange(x, 'b c t h w -> (b t) c h w')
        x = self.resample(x)
        x = rearrange(x, '(b t) c h w -> b c t h w', t=t)

        if self.mode == 'downsample3d':
            if feat_cache is not None:
                idx = feat_idx[0]
                if feat_cache[idx] is None:
                    # First chunk: passed through without temporal
                    # downsampling; remembered as history for the next one.
                    feat_cache[idx] = x
                else:
                    cache_x = x[:, :, -1:, :, :]
                    x = self.time_conv(
                        torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
                    feat_cache[idx] = cache_x

                    # Prepend a frame that was held back on an earlier call.
                    deferred_x = feat_cache[idx + 1]
                    if deferred_x is not None:
                        x = torch.cat([deferred_x, x], 2)
                        feat_cache[idx + 1] = None

                    # Hold back a lone frame so downstream layers see
                    # frames in pairs, unless this is the last chunk.
                    if x.shape[2] == 1 and not final:
                        feat_cache[idx + 1] = x
                        x = None

                feat_idx[0] += 2
        return x
|
||
|
||
|
||
class ResidualBlock(nn.Module):
    """
    Two (norm, SiLU, causal 3x3x3 conv) stages with a skip connection.

    The skip path is a 1x1x1 causal conv when ``in_dim != out_dim`` and the
    identity otherwise.
    """

    def __init__(self, in_dim, out_dim, dropout=0.0):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        # layers
        self.residual = nn.Sequential(
            RMS_norm(in_dim, images=False), nn.SiLU(),
            CausalConv3d(in_dim, out_dim, 3, padding=1),
            RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
            CausalConv3d(out_dim, out_dim, 3, padding=1))
        self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
            if in_dim != out_dim else nn.Identity()

    def forward(self, x, feat_cache=None, feat_idx=None, final=False):
        """
        Apply the residual stack and add the shortcut of the input.

        Args:
            x: input tensor; dim 2 is sliced for caching.
            feat_cache: optional list of cache slots. Each causal conv in
                the residual stack consumes the frames in its slot and
                leaves the last ``CACHE_T`` frames of its own input there.
            feat_idx: one-element list holding the next slot index; advanced
                by one per causal conv. Defaults to a fresh ``[0]`` per call.
            final: accepted so all blocks share a call signature; unused.
        """
        # A list default would be shared between calls and is mutated below,
        # so build a fresh counter when the caller did not supply one.
        if feat_idx is None:
            feat_idx = [0]

        old_x = x
        for layer in self.residual:
            if isinstance(layer, CausalConv3d) and feat_cache is not None:
                idx = feat_idx[0]
                # Slice the tail of the conv input before x is rebound.
                cache_x = x[:, :, -CACHE_T:, :, :]
                x = layer(x, cache_list=feat_cache, cache_idx=idx)
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)
        return x + self.shortcut(old_x)
|
||
|
||
|
||
class AttentionBlock(nn.Module):
    """
    Causal self-attention with a single head.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

        # layers
        self.norm = RMS_norm(dim)
        self.to_qkv = ops.Conv2d(dim, dim * 3, 1)
        self.proj = ops.Conv2d(dim, dim, 1)
        self.optimized_attention = vae_attention()

    def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
        """
        Attend within each frame of ``x`` (``b c t h w``) and add the input.

        ``feat_cache``, ``feat_idx`` and ``final`` are accepted so all blocks
        share a call signature; none of them is read or modified here.
        """
        skip = x
        _b, _c, t, _h, _w = x.size()

        # Fold time into the batch: the 2d layers below see one frame each.
        x = rearrange(x, 'b c t h w -> (b t) c h w')
        x = self.norm(x)

        # One 1x1 conv yields query, key and value stacked along channels.
        qkv = self.to_qkv(x)
        q, k, v = qkv.chunk(3, dim=1)
        del qkv
        x = self.optimized_attention(q, k, v)

        # Project, restore the time axis and add the residual.
        x = self.proj(x)
        x = rearrange(x, '(b t) c h w-> b c t h w', t=t)
        return x + skip
|
||
|
||
|
||
class Encoder3d(nn.Module):
    """
    Encoder: an input causal conv, a stack of residual (+ optional
    attention) blocks with a Resample between stages, a middle
    residual/attention/residual group and a norm-SiLU-conv head.
    """

    def __init__(self,
                 dim=128,
                 z_dim=4,
                 input_channels=3,
                 dim_mult=[1, 2, 4, 4],
                 num_res_blocks=2,
                 attn_scales=[],
                 temperal_downsample=[True, True, False],
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_downsample = temperal_downsample

        # dimensions
        dims = [dim * u for u in [1] + dim_mult]
        scale = 1.0

        # init block
        self.conv1 = CausalConv3d(input_channels, dims[0], 3, padding=1)

        # downsample blocks
        downsamples = []
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            # residual (+attention) blocks
            for _ in range(num_res_blocks):
                downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
                if scale in attn_scales:
                    downsamples.append(AttentionBlock(out_dim))
                in_dim = out_dim

            # downsample block (every stage except the last)
            if i != len(dim_mult) - 1:
                mode = 'downsample3d' if temperal_downsample[
                    i] else 'downsample2d'
                downsamples.append(Resample(out_dim, mode=mode))
                scale /= 2.0
        self.downsamples = nn.Sequential(*downsamples)

        # middle blocks
        self.middle = nn.Sequential(
            ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim),
            ResidualBlock(out_dim, out_dim, dropout))

        # output blocks
        self.head = nn.Sequential(
            RMS_norm(out_dim, images=False), nn.SiLU(),
            CausalConv3d(out_dim, z_dim, 3, padding=1))

    def forward(self, x, feat_cache=None, feat_idx=None, final=False):
        """
        Encode one chunk of frames.

        Args:
            x: input tensor; dim 2 is sliced for caching.
            feat_cache: optional list of cache slots shared by every caching
                layer; ``None`` disables caching.
            feat_idx: one-element list holding the next slot index; advanced
                in place as layers are visited. Defaults to a fresh ``[0]``
                per call.
            final: forwarded to the downsample and middle layers to mark the
                last chunk.

        Returns:
            The encoded tensor, or ``None`` if a downsample layer held its
            output back (it returned ``None``), in which case the remaining
            layers are skipped for this call.
        """
        # A list default would be shared between calls and is mutated below,
        # so build a fresh counter when the caller did not supply one.
        if feat_idx is None:
            feat_idx = [0]

        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :]
            x = self.conv1(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv1(x)

        ## downsamples
        for layer in self.downsamples:
            if feat_cache is not None:
                x = layer(x, feat_cache, feat_idx, final=final)
                if x is None:
                    return None
            else:
                x = layer(x)

        ## middle
        for layer in self.middle:
            if feat_cache is not None:
                x = layer(x, feat_cache, feat_idx, final=final)
            else:
                x = layer(x)

        ## head
        for layer in self.head:
            if isinstance(layer, CausalConv3d) and feat_cache is not None:
                idx = feat_idx[0]
                cache_x = x[:, :, -CACHE_T:, :, :]
                x = layer(x, feat_cache[idx])
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)
        return x
|
||
|
||
|
||
class Decoder3d(nn.Module):
    """
    Decoder: an input causal conv, a middle residual/attention/residual
    group, a stack of residual (+ optional attention) blocks with a Resample
    between stages, and a norm-SiLU-conv head.
    """

    def __init__(self,
                 dim=128,
                 z_dim=4,
                 output_channels=3,
                 dim_mult=[1, 2, 4, 4],
                 num_res_blocks=2,
                 attn_scales=[],
                 temperal_upsample=[False, True, True],
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_upsample = temperal_upsample

        # dimensions
        dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
        scale = 1.0 / 2**(len(dim_mult) - 2)

        # init block
        self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)

        # middle blocks
        self.middle = nn.Sequential(
            ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
            ResidualBlock(dims[0], dims[0], dropout))

        # upsample blocks
        upsamples = []
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            # residual (+attention) blocks
            if i in (1, 2, 3):
                # These stages follow an upsampling Resample, whose conv
                # outputs half the channels it receives.
                in_dim = in_dim // 2
            for _ in range(num_res_blocks + 1):
                upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
                if scale in attn_scales:
                    upsamples.append(AttentionBlock(out_dim))
                in_dim = out_dim

            # upsample block (every stage except the last)
            if i != len(dim_mult) - 1:
                mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
                upsamples.append(Resample(out_dim, mode=mode))
                scale *= 2.0
        self.upsamples = nn.Sequential(*upsamples)

        # output blocks
        self.head = nn.Sequential(
            RMS_norm(out_dim, images=False), nn.SiLU(),
            CausalConv3d(out_dim, output_channels, 3, padding=1))

    def forward(self, x, feat_cache=None, feat_idx=None):
        """
        Decode one chunk of latent frames.

        Args:
            x: input tensor; dim 2 is sliced for caching and for splitting
                into single frames.
            feat_cache: optional list of cache slots shared by every caching
                layer; ``None`` disables caching.
            feat_idx: one-element list holding the next slot index; advanced
                in place as layers are visited. Defaults to a fresh ``[0]``
                per call.

        Returns:
            A list of output tensors in frame order. Whenever an
            ``upsample3d`` layer would receive more than one frame, the
            frames are sent through the rest of the network one at a time,
            and each pass appends its own entry.
        """
        # A list default would be shared between calls and is mutated below,
        # so build a fresh counter when the caller did not supply one.
        if feat_idx is None:
            feat_idx = [0]

        ## conv1
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :]
            x = self.conv1(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv1(x)

        ## middle
        for layer in self.middle:
            if feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        out_chunks = []

        def run_up(layer_idx, x_ref, feat_idx):
            # Take the tensor out of the one-element holder so the caller's
            # holder no longer keeps it alive while deeper layers run.
            x = x_ref[0]
            x_ref[0] = None
            if layer_idx >= len(self.upsamples):
                # Past the last upsample layer: run the head and emit.
                for layer in self.head:
                    if isinstance(layer, CausalConv3d) and feat_cache is not None:
                        cache_x = x[:, :, -CACHE_T:, :, :]
                        x = layer(x, feat_cache[feat_idx[0]])
                        feat_cache[feat_idx[0]] = cache_x
                        feat_idx[0] += 1
                    else:
                        x = layer(x)
                out_chunks.append(x)
                return

            layer = self.upsamples[layer_idx]
            if isinstance(layer, Resample) and layer.mode == 'upsample3d' and x.shape[2] > 1:
                # Feed this layer one frame at a time. Each frame gets its
                # own copy of the counter so it starts from the same slot.
                for frame_idx in range(x.shape[2]):
                    run_up(
                        layer_idx,
                        [x[:, :, frame_idx:frame_idx + 1, :, :]],
                        feat_idx.copy(),
                    )
                del x
                return

            if feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

            next_x_ref = [x]
            del x
            run_up(layer_idx + 1, next_x_ref, feat_idx)

        # Hand over the only reference so the middle-block output can be
        # released as soon as the first upsample layer has consumed it.
        first_x_ref = [x]
        del x
        run_up(0, first_x_ref, feat_idx)
        return out_chunks
|
||
|
||
|
||
def count_cache_layers(model):
    """
    Count the feature-cache slots ``model`` needs.

    Every ``CausalConv3d`` submodule counts once, and every ``Resample`` in
    ``'downsample3d'`` mode counts once in addition to its own temporal conv.
    """
    def takes_slot(module):
        if isinstance(module, CausalConv3d):
            return True
        return isinstance(module, Resample) and module.mode == 'downsample3d'

    return sum(1 for module in model.modules() if takes_slot(module))
|
||
|
||
|
||
class WanVAE(nn.Module):
    """
    Video VAE built from an Encoder3d / Decoder3d pair plus 1x1x1 causal
    convs on the latent side. Both directions process the clip in short
    temporal chunks and carry per-layer feature caches between chunks.
    """

    def __init__(self,
                 dim=128,
                 z_dim=4,
                 dim_mult=[1, 2, 4, 4],
                 num_res_blocks=2,
                 attn_scales=[],
                 temperal_downsample=[True, True, False],
                 image_channels=3,
                 conv_out_channels=3,
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_downsample = temperal_downsample
        # The decoder undoes the encoder's stages in reverse order.
        self.temperal_upsample = temperal_downsample[::-1]

        # modules
        latent_stats_dim = z_dim * 2
        self.encoder = Encoder3d(dim, latent_stats_dim, image_channels, dim_mult, num_res_blocks,
                                 attn_scales, self.temperal_downsample, dropout)
        self.conv1 = CausalConv3d(latent_stats_dim, latent_stats_dim, 1)
        self.conv2 = CausalConv3d(z_dim, z_dim, 1)
        self.decoder = Decoder3d(dim, z_dim, conv_out_channels, dim_mult, num_res_blocks,
                                 attn_scales, self.temperal_upsample, dropout)

    def encode(self, x):
        """
        Encode ``x`` and return the first half of the channels produced by
        ``conv1`` (``mu``); the second half (``log_var``) is discarded.
        """
        ## cache
        usable_frames = 1 + ((x.shape[2] - 1) // 4) * 4
        num_chunks = 1 + (usable_frames - 1) // 2
        feat_map = [None] * count_cache_layers(self.encoder) if num_chunks > 1 else None

        ## Split the encoder input x along time into chunks of 1, 2, 2, 2, ...
        ## frames (the frame count is first rounded down to the form 4N+1).
        for chunk in range(num_chunks):
            conv_idx = [0]
            if chunk == 0:
                out = self.encoder(
                    x[:, :, :1, :, :],
                    feat_cache=feat_map,
                    feat_idx=conv_idx)
                continue

            start = 1 + 2 * (chunk - 1)
            out_ = self.encoder(
                x[:, :, start:start + 2, :, :],
                feat_cache=feat_map,
                feat_idx=conv_idx,
                final=(chunk == (num_chunks - 1)))
            # The encoder returns None when it held this chunk's output back.
            if out_ is not None:
                out = torch.cat([out, out_], 2)

        mu, _log_var = self.conv1(out).chunk(2, dim=1)
        return mu

    def decode(self, z):
        """
        Decode latents ``z`` and return all decoded pieces concatenated
        along dim 2.
        """
        # z: [b,c,t,h,w]
        num_chunks = 1 + z.shape[2] // 2
        feat_map = [None] * count_cache_layers(self.decoder) if num_chunks > 1 else None

        x = self.conv2(z)
        out = []
        for chunk in range(num_chunks):
            conv_idx = [0]
            # Chunk 0 is frame 0 alone; chunk i > 0 covers frames
            # [2i - 1, 2i + 1), i.e. up to two frames.
            start = max(0, 2 * chunk - 1)
            stop = 2 * chunk + 1
            out.extend(self.decoder(
                x[:, :, start:stop, :, :],
                feat_cache=feat_map,
                feat_idx=conv_idx))
        return torch.cat(out, 2)
|