diff --git a/comfy/ldm/wan/vae.py b/comfy/ldm/wan/vae.py index a4a567130..d224336a3 100644 --- a/comfy/ldm/wan/vae.py +++ b/comfy/ldm/wan/vae.py @@ -109,7 +109,7 @@ class Resample(nn.Module): feat_idx[0] += 1 else: - cache_x = x[:, :, -CACHE_T:, :, :].clone() + cache_x = x[:, :, -CACHE_T:, :, :] if feat_cache[idx] == 'Rep': x = self.time_conv(x) else: @@ -130,10 +130,10 @@ class Resample(nn.Module): if feat_cache is not None: idx = feat_idx[0] if feat_cache[idx] is None: - feat_cache[idx] = x.clone() + feat_cache[idx] = x else: - cache_x = x[:, :, -1:, :, :].clone() + cache_x = x[:, :, -1:, :, :] x = self.time_conv( torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) feat_cache[idx] = cache_x @@ -172,7 +172,7 @@ class ResidualBlock(nn.Module): for layer in self.residual: if isinstance(layer, CausalConv3d) and feat_cache is not None: idx = feat_idx[0] - cache_x = x[:, :, -CACHE_T:, :, :].clone() + cache_x = x[:, :, -CACHE_T:, :, :] x = layer(x, cache_list=feat_cache, cache_idx=idx) feat_cache[idx] = cache_x feat_idx[0] += 1 @@ -269,7 +269,7 @@ class Encoder3d(nn.Module): def forward(self, x, feat_cache=None, feat_idx=[0], final=False): if feat_cache is not None: idx = feat_idx[0] - cache_x = x[:, :, -CACHE_T:, :, :].clone() + cache_x = x[:, :, -CACHE_T:, :, :] x = self.conv1(x, feat_cache[idx]) feat_cache[idx] = cache_x feat_idx[0] += 1 @@ -296,7 +296,7 @@ class Encoder3d(nn.Module): for layer in self.head: if isinstance(layer, CausalConv3d) and feat_cache is not None: idx = feat_idx[0] - cache_x = x[:, :, -CACHE_T:, :, :].clone() + cache_x = x[:, :, -CACHE_T:, :, :] x = layer(x, feat_cache[idx]) feat_cache[idx] = cache_x feat_idx[0] += 1 @@ -364,7 +364,7 @@ class Decoder3d(nn.Module): ## conv1 if feat_cache is not None: idx = feat_idx[0] - cache_x = x[:, :, -CACHE_T:, :, :].clone() + cache_x = x[:, :, -CACHE_T:, :, :] x = self.conv1(x, feat_cache[idx]) feat_cache[idx] = cache_x feat_idx[0] += 1 @@ -383,7 +383,7 @@ class Decoder3d(nn.Module): def run_head(x, feat_idx): for layer in self.head: if isinstance(layer, CausalConv3d) and feat_cache is not None: - cache_x = x[:, :, -CACHE_T:, :, :].clone() + cache_x = x[:, :, -CACHE_T:, :, :] x = layer(x, feat_cache[feat_idx[0]]) feat_cache[feat_idx[0]] = cache_x feat_idx[0] += 1