mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-08 16:22:38 +08:00
wan: vae: remove cloning
The loopers now control the chunking such there is noever more than 2 frames, so just cache these slices directly and avoid the clone allocations completely.
This commit is contained in:
parent
6c4843567d
commit
45995ed76a
@ -109,7 +109,7 @@ class Resample(nn.Module):
|
|||||||
feat_idx[0] += 1
|
feat_idx[0] += 1
|
||||||
else:
|
else:
|
||||||
|
|
||||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
cache_x = x[:, :, -CACHE_T:, :, :]
|
||||||
if feat_cache[idx] == 'Rep':
|
if feat_cache[idx] == 'Rep':
|
||||||
x = self.time_conv(x)
|
x = self.time_conv(x)
|
||||||
else:
|
else:
|
||||||
@ -130,10 +130,10 @@ class Resample(nn.Module):
|
|||||||
if feat_cache is not None:
|
if feat_cache is not None:
|
||||||
idx = feat_idx[0]
|
idx = feat_idx[0]
|
||||||
if feat_cache[idx] is None:
|
if feat_cache[idx] is None:
|
||||||
feat_cache[idx] = x.clone()
|
feat_cache[idx] = x
|
||||||
else:
|
else:
|
||||||
|
|
||||||
cache_x = x[:, :, -1:, :, :].clone()
|
cache_x = x[:, :, -1:, :, :]
|
||||||
x = self.time_conv(
|
x = self.time_conv(
|
||||||
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
|
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
|
||||||
feat_cache[idx] = cache_x
|
feat_cache[idx] = cache_x
|
||||||
@ -172,7 +172,7 @@ class ResidualBlock(nn.Module):
|
|||||||
for layer in self.residual:
|
for layer in self.residual:
|
||||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||||
idx = feat_idx[0]
|
idx = feat_idx[0]
|
||||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
cache_x = x[:, :, -CACHE_T:, :, :]
|
||||||
x = layer(x, cache_list=feat_cache, cache_idx=idx)
|
x = layer(x, cache_list=feat_cache, cache_idx=idx)
|
||||||
feat_cache[idx] = cache_x
|
feat_cache[idx] = cache_x
|
||||||
feat_idx[0] += 1
|
feat_idx[0] += 1
|
||||||
@ -269,7 +269,7 @@ class Encoder3d(nn.Module):
|
|||||||
def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
|
def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
|
||||||
if feat_cache is not None:
|
if feat_cache is not None:
|
||||||
idx = feat_idx[0]
|
idx = feat_idx[0]
|
||||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
cache_x = x[:, :, -CACHE_T:, :, :]
|
||||||
x = self.conv1(x, feat_cache[idx])
|
x = self.conv1(x, feat_cache[idx])
|
||||||
feat_cache[idx] = cache_x
|
feat_cache[idx] = cache_x
|
||||||
feat_idx[0] += 1
|
feat_idx[0] += 1
|
||||||
@ -296,7 +296,7 @@ class Encoder3d(nn.Module):
|
|||||||
for layer in self.head:
|
for layer in self.head:
|
||||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||||
idx = feat_idx[0]
|
idx = feat_idx[0]
|
||||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
cache_x = x[:, :, -CACHE_T:, :, :]
|
||||||
x = layer(x, feat_cache[idx])
|
x = layer(x, feat_cache[idx])
|
||||||
feat_cache[idx] = cache_x
|
feat_cache[idx] = cache_x
|
||||||
feat_idx[0] += 1
|
feat_idx[0] += 1
|
||||||
@ -364,7 +364,7 @@ class Decoder3d(nn.Module):
|
|||||||
## conv1
|
## conv1
|
||||||
if feat_cache is not None:
|
if feat_cache is not None:
|
||||||
idx = feat_idx[0]
|
idx = feat_idx[0]
|
||||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
cache_x = x[:, :, -CACHE_T:, :, :]
|
||||||
x = self.conv1(x, feat_cache[idx])
|
x = self.conv1(x, feat_cache[idx])
|
||||||
feat_cache[idx] = cache_x
|
feat_cache[idx] = cache_x
|
||||||
feat_idx[0] += 1
|
feat_idx[0] += 1
|
||||||
@ -383,7 +383,7 @@ class Decoder3d(nn.Module):
|
|||||||
def run_head(x, feat_idx):
|
def run_head(x, feat_idx):
|
||||||
for layer in self.head:
|
for layer in self.head:
|
||||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
cache_x = x[:, :, -CACHE_T:, :, :]
|
||||||
x = layer(x, feat_cache[feat_idx[0]])
|
x = layer(x, feat_cache[feat_idx[0]])
|
||||||
feat_cache[feat_idx[0]] = cache_x
|
feat_cache[feat_idx[0]] = cache_x
|
||||||
feat_idx[0] += 1
|
feat_idx[0] += 1
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user