diff --git a/comfy/ldm/wan/vae.py b/comfy/ldm/wan/vae.py index 71f73c64e..a96b83c6c 100644 --- a/comfy/ldm/wan/vae.py +++ b/comfy/ldm/wan/vae.py @@ -99,7 +99,7 @@ class Resample(nn.Module): else: self.resample = nn.Identity() - def forward(self, x, feat_cache=None, feat_idx=[0]): + def forward(self, x, feat_cache=None, feat_idx=[0], final=False): b, c, t, h, w = x.size() if self.mode == 'upsample3d': if feat_cache is not None: @@ -109,22 +109,7 @@ class Resample(nn.Module): feat_idx[0] += 1 else: - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[ - idx] is not None and feat_cache[idx] != 'Rep': - # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) - if cache_x.shape[2] < 2 and feat_cache[ - idx] is not None and feat_cache[idx] == 'Rep': - cache_x = torch.cat([ - torch.zeros_like(cache_x).to(cache_x.device), - cache_x - ], - dim=2) + cache_x = x[:, :, -CACHE_T:, :, :] if feat_cache[idx] == 'Rep': x = self.time_conv(x) else: @@ -145,19 +130,24 @@ class Resample(nn.Module): if feat_cache is not None: idx = feat_idx[0] if feat_cache[idx] is None: - feat_cache[idx] = x.clone() - feat_idx[0] += 1 + feat_cache[idx] = x else: - cache_x = x[:, :, -1:, :, :].clone() - # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep': - # # cache last frame of last two chunk - # cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) - + cache_x = x[:, :, -1:, :, :] x = self.time_conv( torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) feat_cache[idx] = cache_x - feat_idx[0] += 1 + + deferred_x = feat_cache[idx + 1] + if deferred_x is not None: + x = torch.cat([deferred_x, x], 2) + feat_cache[idx + 1] = None + + if x.shape[2] == 1 and not final: + feat_cache[idx + 1] = x + x = None + + feat_idx[0] += 2 return x @@ -177,19 +167,12 @@ class ResidualBlock(nn.Module): self.shortcut = CausalConv3d(in_dim, out_dim, 1) \ if in_dim != out_dim else nn.Identity() - def forward(self, x, feat_cache=None, feat_idx=[0]): + def forward(self, x, feat_cache=None, feat_idx=[0], final=False): old_x = x for layer in self.residual: if isinstance(layer, CausalConv3d) and feat_cache is not None: idx = feat_idx[0] - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) + cache_x = x[:, :, -CACHE_T:, :, :] x = layer(x, cache_list=feat_cache, cache_idx=idx) feat_cache[idx] = cache_x feat_idx[0] += 1 @@ -213,7 +196,7 @@ class AttentionBlock(nn.Module): self.proj = ops.Conv2d(dim, dim, 1) self.optimized_attention = vae_attention() - def forward(self, x): + def forward(self, x, feat_cache=None, feat_idx=[0], final=False): identity = x b, c, t, h, w = x.size() x = rearrange(x, 'b c t h w -> (b t) c h w') @@ -283,17 +266,10 @@ class Encoder3d(nn.Module): RMS_norm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, z_dim, 3, padding=1)) - def forward(self, x, feat_cache=None, feat_idx=[0]): + def forward(self, x, feat_cache=None, feat_idx=[0], final=False): if feat_cache is not None: idx = feat_idx[0] - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) + cache_x = x[:, :, -CACHE_T:, :, :] x = self.conv1(x, feat_cache[idx]) feat_cache[idx] = cache_x feat_idx[0] += 1 @@ -303,14 +279,16 @@ class Encoder3d(nn.Module): ## downsamples for layer in self.downsamples: if feat_cache is not None: - x = layer(x, feat_cache, feat_idx) + x = layer(x, feat_cache, feat_idx, final=final) + if x is None: + return None else: x = layer(x) ## middle for layer in self.middle: - if isinstance(layer, ResidualBlock) and feat_cache is not None: - x = layer(x, feat_cache, feat_idx) + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx, final=final) else: x = layer(x) @@ -318,14 +296,7 @@ class Encoder3d(nn.Module): for layer in self.head: if isinstance(layer, CausalConv3d) and feat_cache is not None: idx = feat_idx[0] - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) + cache_x = x[:, :, -CACHE_T:, :, :] x = layer(x, feat_cache[idx]) feat_cache[idx] = cache_x feat_idx[0] += 1 @@ -393,14 +364,7 @@ class Decoder3d(nn.Module): ## conv1 if feat_cache is not None: idx = feat_idx[0] - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) + cache_x = x[:, :, -CACHE_T:, :, :] x = self.conv1(x, feat_cache[idx]) feat_cache[idx] = cache_x feat_idx[0] += 1 @@ -409,42 +373,56 @@ class Decoder3d(nn.Module): ## middle for layer in self.middle: - if isinstance(layer, ResidualBlock) and feat_cache is not None: - x = layer(x, feat_cache, feat_idx) - else: - x = layer(x) - - ## upsamples - for layer in self.upsamples: if feat_cache is not None: x = layer(x, feat_cache, feat_idx) else: x = layer(x) - ## head - for layer in self.head: - if isinstance(layer, CausalConv3d) and feat_cache is not None: - idx = feat_idx[0] - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) - x = layer(x, feat_cache[idx]) - feat_cache[idx] = cache_x - feat_idx[0] += 1 + out_chunks = [] + + def run_up(layer_idx, x_ref, feat_idx): + x = x_ref[0] + x_ref[0] = None + if layer_idx >= len(self.upsamples): + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + cache_x = x[:, :, -CACHE_T:, :, :] + x = layer(x, feat_cache[feat_idx[0]]) + feat_cache[feat_idx[0]] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + out_chunks.append(x) + return + + layer = self.upsamples[layer_idx] + if isinstance(layer, Resample) and layer.mode == 'upsample3d' and x.shape[2] > 1: + for frame_idx in range(x.shape[2]): + run_up( + layer_idx, + [x[:, :, frame_idx:frame_idx + 1, :, :]], + feat_idx.copy(), + ) + del x + return + + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) else: x = layer(x) - return x + + next_x_ref = [x] + del x + run_up(layer_idx + 1, next_x_ref, feat_idx) + + run_up(0, [x], feat_idx) + return out_chunks -def count_conv3d(model): +def count_cache_layers(model): count = 0 for m in model.modules(): - if isinstance(m, CausalConv3d): + if isinstance(m, CausalConv3d) or (isinstance(m, Resample) and m.mode == 'downsample3d'): count += 1 return count @@ -482,11 +460,12 @@ class WanVAE(nn.Module): conv_idx = [0] ## cache t = x.shape[2] - iter_ = 1 + (t - 1) // 4 + t = 1 + ((t - 1) // 4) * 4 + iter_ = 1 + (t - 1) // 2 feat_map = None if iter_ > 1: - feat_map = [None] * count_conv3d(self.encoder) - ## 对encode输入的x,按时间拆分为1、4、4、4.... + feat_map = [None] * count_cache_layers(self.encoder) + ## 对encode输入的x,按时间拆分为1、2、2、2....(总帧数先按4N+1向下取整) for i in range(iter_): conv_idx = [0] if i == 0: @@ -496,20 +475,23 @@ class WanVAE(nn.Module): feat_idx=conv_idx) else: out_ = self.encoder( - x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :], + x[:, :, 1 + 2 * (i - 1):1 + 2 * i, :, :], feat_cache=feat_map, - feat_idx=conv_idx) + feat_idx=conv_idx, + final=(i == (iter_ - 1))) + if out_ is None: + continue out = torch.cat([out, out_], 2) + mu, log_var = self.conv1(out).chunk(2, dim=1) return mu def decode(self, z): - conv_idx = [0] # z: [b,c,t,h,w] - iter_ = z.shape[2] + iter_ = 1 + z.shape[2] // 2 feat_map = None if iter_ > 1: - feat_map = [None] * count_conv3d(self.decoder) + feat_map = [None] * count_cache_layers(self.decoder) x = self.conv2(z) for i in range(iter_): conv_idx = [0] @@ -520,8 +502,8 @@ class WanVAE(nn.Module): feat_idx=conv_idx) else: out_ = self.decoder( - x[:, :, i:i + 1, :, :], + x[:, :, 1 + 2 * (i - 1):1 + 2 * i, :, :], feat_cache=feat_map, feat_idx=conv_idx) - out = torch.cat([out, out_], 2) - return out + out += out_ + return torch.cat(out, 2)