wan: remove all concatenation with the feature cache

The loopers are now responsible for ensuring that non-final frames are
processed at least two-by-two, eliminating the need for these cat cases.
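
Every removed branch guarded the same corner: a chunk arriving at a
caching layer with a single frame, leaving the cached slice one frame
short of CACHE_T and forcing a torch.cat with the previous chunk's last
frame (or with zeros in the 'Rep' case). A minimal sketch of the
invariant the loopers now uphold, assuming CACHE_T == 2 as the removed
`cache_x.shape[2] < 2` guards imply:

import torch

CACHE_T = 2  # assumed; implied by the removed `cache_x.shape[2] < 2` guards

# Every non-final chunk now carries at least two frames, so the cached
# slice is always full width and the removed padding paths can never
# trigger.
chunk = torch.randn(1, 16, 2, 32, 32)  # [b, c, t >= 2, h, w]
cache_x = chunk[:, :, -CACHE_T:, :, :].clone()
assert cache_x.shape[2] == CACHE_T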
Author: Rattus
Date:   2026-03-16 21:08:41 +10:00
Parent: 2fac7a8726
Commit: d8fa68084f


@@ -110,21 +110,6 @@ class Resample(nn.Module):
                 else:
                     cache_x = x[:, :, -CACHE_T:, :, :].clone()
-                    if cache_x.shape[2] < 2 and feat_cache[
-                            idx] is not None and feat_cache[idx] != 'Rep':
-                        # cache last frame of last two chunk
-                        cache_x = torch.cat([
-                            feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
-                                cache_x.device), cache_x
-                        ],
-                                            dim=2)
-                    if cache_x.shape[2] < 2 and feat_cache[
-                            idx] is not None and feat_cache[idx] == 'Rep':
-                        cache_x = torch.cat([
-                            torch.zeros_like(cache_x).to(cache_x.device),
-                            cache_x
-                        ],
-                                            dim=2)
                     if feat_cache[idx] == 'Rep':
                         x = self.time_conv(x)
                     else:
@@ -149,10 +134,6 @@ class Resample(nn.Module):
                 else:
                     cache_x = x[:, :, -1:, :, :].clone()
-                    # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
-                    #     # cache last frame of last two chunk
-                    #     cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
                     x = self.time_conv(
                         torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
                     feat_cache[idx] = cache_x
@@ -192,13 +173,6 @@ class ResidualBlock(nn.Module):
             if isinstance(layer, CausalConv3d) and feat_cache is not None:
                 idx = feat_idx[0]
                 cache_x = x[:, :, -CACHE_T:, :, :].clone()
-                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
-                    # cache last frame of last two chunk
-                    cache_x = torch.cat([
-                        feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
-                            cache_x.device), cache_x
-                    ],
-                                        dim=2)
                 x = layer(x, cache_list=feat_cache, cache_idx=idx)
                 feat_cache[idx] = cache_x
                 feat_idx[0] += 1
@@ -296,13 +270,6 @@ class Encoder3d(nn.Module):
         if feat_cache is not None:
             idx = feat_idx[0]
             cache_x = x[:, :, -CACHE_T:, :, :].clone()
-            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
-                # cache last frame of last two chunk
-                cache_x = torch.cat([
-                    feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
-                        cache_x.device), cache_x
-                ],
-                                    dim=2)
             x = self.conv1(x, feat_cache[idx])
             feat_cache[idx] = cache_x
             feat_idx[0] += 1
@@ -320,8 +287,8 @@ class Encoder3d(nn.Module):
         ## middle
         for layer in self.middle:
-            if isinstance(layer, ResidualBlock) and feat_cache is not None:
-                x = layer(x, feat_cache, feat_idx)
+            if feat_cache is not None:
+                x = layer(x, feat_cache, feat_idx, final=final)
             else:
                 x = layer(x)
@@ -330,13 +297,6 @@ class Encoder3d(nn.Module):
             if isinstance(layer, CausalConv3d) and feat_cache is not None:
                 idx = feat_idx[0]
                 cache_x = x[:, :, -CACHE_T:, :, :].clone()
-                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
-                    # cache last frame of last two chunk
-                    cache_x = torch.cat([
-                        feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
-                            cache_x.device), cache_x
-                    ],
-                                        dim=2)
                 x = layer(x, feat_cache[idx])
                 feat_cache[idx] = cache_x
                 feat_idx[0] += 1
@@ -405,13 +365,6 @@ class Decoder3d(nn.Module):
         if feat_cache is not None:
             idx = feat_idx[0]
             cache_x = x[:, :, -CACHE_T:, :, :].clone()
-            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
-                # cache last frame of last two chunk
-                cache_x = torch.cat([
-                    feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
-                        cache_x.device), cache_x
-                ],
-                                    dim=2)
             x = self.conv1(x, feat_cache[idx])
             feat_cache[idx] = cache_x
             feat_idx[0] += 1
@@ -420,7 +373,7 @@ class Decoder3d(nn.Module):
         ## middle
         for layer in self.middle:
-            if isinstance(layer, ResidualBlock) and feat_cache is not None:
+            if feat_cache is not None:
                 x = layer(x, feat_cache, feat_idx)
             else:
                 x = layer(x)
@@ -437,13 +390,6 @@ class Decoder3d(nn.Module):
             if isinstance(layer, CausalConv3d) and feat_cache is not None:
                 idx = feat_idx[0]
                 cache_x = x[:, :, -CACHE_T:, :, :].clone()
-                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
-                    # cache last frame of last two chunk
-                    cache_x = torch.cat([
-                        feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
-                            cache_x.device), cache_x
-                    ],
-                                        dim=2)
                 x = layer(x, feat_cache[idx])
                 feat_cache[idx] = cache_x
                 feat_idx[0] += 1
@@ -519,9 +465,8 @@ class WanVAE(nn.Module):
         return mu

     def decode(self, z):
-        conv_idx = [0]
         # z: [b,c,t,h,w]
-        iter_ = z.shape[2]
+        iter_ = 1 + z.shape[2] // 2
         feat_map = None
         if iter_ > 1:
             feat_map = [None] * count_cache_layers(self.decoder)
@@ -535,7 +480,7 @@ class WanVAE(nn.Module):
                     feat_idx=conv_idx)
             else:
                 out_ = self.decoder(
-                    x[:, :, i:i + 1, :, :],
+                    x[:, :, 1 + 2 * (i - 1):1 + 2 * i, :, :],
                     feat_cache=feat_map,
                     feat_idx=conv_idx)
                 out = torch.cat([out, out_], 2)
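
The reworked loop is where the two-by-two guarantee comes from:
iteration 0 decodes the seed latent frame alone, and every later
iteration decodes a pair, so only the final chunk can ever fall short.
A hypothetical tracer of the slice arithmetic (decode_chunks is
illustrative, not part of this commit):

# Chunk 0 is the lone seed frame; chunk i > 0 covers latent frames
# 1 + 2*(i - 1) .. 1 + 2*i, clamped at the sequence end.
def decode_chunks(num_frames: int) -> list[tuple[int, int]]:
    iter_ = 1 + num_frames // 2
    chunks = [(0, 1)]
    for i in range(1, iter_):
        start = 1 + 2 * (i - 1)
        stop = min(1 + 2 * i, num_frames)
        chunks.append((start, stop))
    return chunks

print(decode_chunks(9))  # [(0, 1), (1, 3), (3, 5), (5, 7), (7, 9)]
print(decode_chunks(8))  # ends with the single-frame final chunk (7, 8)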