mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-08 08:12:34 +08:00
wan: vae: encoder: Add feature cache layer that corks singles
If a downsample only gives you a single frame, save it to the feature cache and return nothing to the top level. This increases the efficiency of cacheability, but also prepares support for going two by two rather than four by four on the frames.
This commit is contained in:
parent
7a16e8aa4e
commit
2fac7a8726
@ -99,7 +99,7 @@ class Resample(nn.Module):
|
||||
else:
|
||||
self.resample = nn.Identity()
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
|
||||
b, c, t, h, w = x.size()
|
||||
if self.mode == 'upsample3d':
|
||||
if feat_cache is not None:
|
||||
@ -146,7 +146,6 @@ class Resample(nn.Module):
|
||||
idx = feat_idx[0]
|
||||
if feat_cache[idx] is None:
|
||||
feat_cache[idx] = x.clone()
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
|
||||
cache_x = x[:, :, -1:, :, :].clone()
|
||||
@ -157,7 +156,17 @@ class Resample(nn.Module):
|
||||
x = self.time_conv(
|
||||
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
|
||||
deferred_x = feat_cache[idx + 1]
|
||||
if deferred_x is not None:
|
||||
x = torch.cat([deferred_x, x], 2)
|
||||
feat_cache[idx + 1] = None
|
||||
|
||||
if x.shape[2] == 1 and not final:
|
||||
feat_cache[idx + 1] = x
|
||||
x = None
|
||||
|
||||
feat_idx[0] += 2
|
||||
return x
|
||||
|
||||
|
||||
@ -177,7 +186,7 @@ class ResidualBlock(nn.Module):
|
||||
self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
|
||||
if in_dim != out_dim else nn.Identity()
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
|
||||
old_x = x
|
||||
for layer in self.residual:
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
@ -213,7 +222,7 @@ class AttentionBlock(nn.Module):
|
||||
self.proj = ops.Conv2d(dim, dim, 1)
|
||||
self.optimized_attention = vae_attention()
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
|
||||
identity = x
|
||||
b, c, t, h, w = x.size()
|
||||
x = rearrange(x, 'b c t h w -> (b t) c h w')
|
||||
@ -283,7 +292,7 @@ class Encoder3d(nn.Module):
|
||||
RMS_norm(out_dim, images=False), nn.SiLU(),
|
||||
CausalConv3d(out_dim, z_dim, 3, padding=1))
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
@ -303,7 +312,9 @@ class Encoder3d(nn.Module):
|
||||
## downsamples
|
||||
for layer in self.downsamples:
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
x = layer(x, feat_cache, feat_idx, final=final)
|
||||
if x is None:
|
||||
return None
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
@ -441,10 +452,10 @@ class Decoder3d(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
def count_conv3d(model):
|
||||
def count_cache_layers(model):
|
||||
count = 0
|
||||
for m in model.modules():
|
||||
if isinstance(m, CausalConv3d):
|
||||
if isinstance(m, CausalConv3d) or (isinstance(m, Resample) and m.mode == 'downsample3d'):
|
||||
count += 1
|
||||
return count
|
||||
|
||||
@ -485,7 +496,7 @@ class WanVAE(nn.Module):
|
||||
iter_ = 1 + (t - 1) // 4
|
||||
feat_map = None
|
||||
if iter_ > 1:
|
||||
feat_map = [None] * count_conv3d(self.encoder)
|
||||
feat_map = [None] * count_cache_layers(self.encoder)
|
||||
## 对encode输入的x,按时间拆分为1、4、4、4....
|
||||
for i in range(iter_):
|
||||
conv_idx = [0]
|
||||
@ -498,8 +509,12 @@ class WanVAE(nn.Module):
|
||||
out_ = self.encoder(
|
||||
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
|
||||
feat_cache=feat_map,
|
||||
feat_idx=conv_idx)
|
||||
feat_idx=conv_idx,
|
||||
final=(i == (iter_ - 1)))
|
||||
if out_ is None:
|
||||
continue
|
||||
out = torch.cat([out, out_], 2)
|
||||
|
||||
mu, log_var = self.conv1(out).chunk(2, dim=1)
|
||||
return mu
|
||||
|
||||
@ -509,7 +524,7 @@ class WanVAE(nn.Module):
|
||||
iter_ = z.shape[2]
|
||||
feat_map = None
|
||||
if iter_ > 1:
|
||||
feat_map = [None] * count_conv3d(self.decoder)
|
||||
feat_map = [None] * count_cache_layers(self.decoder)
|
||||
x = self.conv2(z)
|
||||
for i in range(iter_):
|
||||
conv_idx = [0]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user