wan: vae: remove cloning

The loopers now control the chunking such that there are never more than 2
frames, so just cache these slices directly and avoid the clone
allocations completely.
This commit is contained in:
Rattus 2026-03-16 22:11:07 +10:00
parent 6c4843567d
commit 45995ed76a

View File

@ -109,7 +109,7 @@ class Resample(nn.Module):
feat_idx[0] += 1 feat_idx[0] += 1
else: else:
cache_x = x[:, :, -CACHE_T:, :, :].clone() cache_x = x[:, :, -CACHE_T:, :, :]
if feat_cache[idx] == 'Rep': if feat_cache[idx] == 'Rep':
x = self.time_conv(x) x = self.time_conv(x)
else: else:
@ -130,10 +130,10 @@ class Resample(nn.Module):
if feat_cache is not None: if feat_cache is not None:
idx = feat_idx[0] idx = feat_idx[0]
if feat_cache[idx] is None: if feat_cache[idx] is None:
feat_cache[idx] = x.clone() feat_cache[idx] = x
else: else:
cache_x = x[:, :, -1:, :, :].clone() cache_x = x[:, :, -1:, :, :]
x = self.time_conv( x = self.time_conv(
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
feat_cache[idx] = cache_x feat_cache[idx] = cache_x
@ -172,7 +172,7 @@ class ResidualBlock(nn.Module):
for layer in self.residual: for layer in self.residual:
if isinstance(layer, CausalConv3d) and feat_cache is not None: if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0] idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone() cache_x = x[:, :, -CACHE_T:, :, :]
x = layer(x, cache_list=feat_cache, cache_idx=idx) x = layer(x, cache_list=feat_cache, cache_idx=idx)
feat_cache[idx] = cache_x feat_cache[idx] = cache_x
feat_idx[0] += 1 feat_idx[0] += 1
@ -269,7 +269,7 @@ class Encoder3d(nn.Module):
def forward(self, x, feat_cache=None, feat_idx=[0], final=False): def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
if feat_cache is not None: if feat_cache is not None:
idx = feat_idx[0] idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone() cache_x = x[:, :, -CACHE_T:, :, :]
x = self.conv1(x, feat_cache[idx]) x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x feat_cache[idx] = cache_x
feat_idx[0] += 1 feat_idx[0] += 1
@ -296,7 +296,7 @@ class Encoder3d(nn.Module):
for layer in self.head: for layer in self.head:
if isinstance(layer, CausalConv3d) and feat_cache is not None: if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0] idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone() cache_x = x[:, :, -CACHE_T:, :, :]
x = layer(x, feat_cache[idx]) x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x feat_cache[idx] = cache_x
feat_idx[0] += 1 feat_idx[0] += 1
@ -364,7 +364,7 @@ class Decoder3d(nn.Module):
## conv1 ## conv1
if feat_cache is not None: if feat_cache is not None:
idx = feat_idx[0] idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone() cache_x = x[:, :, -CACHE_T:, :, :]
x = self.conv1(x, feat_cache[idx]) x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x feat_cache[idx] = cache_x
feat_idx[0] += 1 feat_idx[0] += 1
@ -383,7 +383,7 @@ class Decoder3d(nn.Module):
def run_head(x, feat_idx): def run_head(x, feat_idx):
for layer in self.head: for layer in self.head:
if isinstance(layer, CausalConv3d) and feat_cache is not None: if isinstance(layer, CausalConv3d) and feat_cache is not None:
cache_x = x[:, :, -CACHE_T:, :, :].clone() cache_x = x[:, :, -CACHE_T:, :, :]
x = layer(x, feat_cache[feat_idx[0]]) x = layer(x, feat_cache[feat_idx[0]])
feat_cache[feat_idx[0]] = cache_x feat_cache[feat_idx[0]] = cache_x
feat_idx[0] += 1 feat_idx[0] += 1