From 6571c912a70a6e9233283467f85b7ee10b3d7b59 Mon Sep 17 00:00:00 2001 From: Rattus Date: Sun, 30 Nov 2025 08:56:07 +1000 Subject: [PATCH] model: Add temporal roll to main VAE decoder If there are no attention layers, its a standard resnet and VideoConv3d is asked for, substitute in the temporal rolloing VAE algorithm. This reduces VAE usage by the temporal dimension (can be huge VRAM savings). --- comfy/ldm/modules/diffusionmodules/model.py | 84 +++++++++++++-------- 1 file changed, 52 insertions(+), 32 deletions(-) diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 47713d0d8..098c7f247 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -126,29 +126,24 @@ class Upsample(nn.Module): stride=1, padding=1) - def forward(self, x): + def forward(self, x, conv_carry_in=None, conv_carry_out=None): scale_factor = self.scale_factor if isinstance(scale_factor, (int, float)): scale_factor = (scale_factor,) * (x.ndim - 2) if x.ndim == 5 and scale_factor[0] > 1.0: - t = x.shape[2] - if t > 1: - a, b = x.split((1, t - 1), dim=2) - del x - b = interpolate_up(b, scale_factor) - else: - a = x - - a = interpolate_up(a.squeeze(2), scale_factor=scale_factor[1:]).unsqueeze(2) - if t > 1: - x = torch.cat((a, b), dim=2) - else: - x = a + results = [] + if conv_carry_in is None: + first = x[:, :, :1, :, :] + results.append(interpolate_up(first.squeeze(2), scale_factor=scale_factor[1:]).unsqueeze(2)) + x = x[:, :, 1:, :, :] + if x.shape[2] > 0: + results.append(interpolate_up(x, scale_factor)) + x = torch_cat_if_needed(results, dim=2) else: x = interpolate_up(x, scale_factor) if self.with_conv: - x = self.conv(x) + x = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out) return x @@ -664,10 +659,17 @@ class Decoder(nn.Module): self.resolution = resolution self.in_channels = in_channels self.tanh_out = tanh_out + self.carried = False if conv3d: - conv_op = VideoConv3d - conv_out_op = VideoConv3d + if not attn_resolutions and resnet_op == ResnetBlock: + conv_op = CarriedConv3d + conv_out_op = CarriedConv3d + self.carried = True + else: + conv_op = VideoConv3d + conv_out_op = VideoConv3d + mid_attn_conv_op = ops.Conv3d else: conv_op = ops.Conv2d @@ -742,25 +744,43 @@ class Decoder(nn.Module): temb = None # z to block_in - h = self.conv_in(z) + h = conv_carry_causal_3d([z], self.conv_in) # middle h = self.mid.block_1(h, temb, **kwargs) h = self.mid.attn_1(h, **kwargs) h = self.mid.block_2(h, temb, **kwargs) - # upsampling - for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks+1): - h = self.up[i_level].block[i_block](h, temb, **kwargs) - if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h, **kwargs) - if i_level != 0: - h = self.up[i_level].upsample(h) + if self.carried: + h = torch.split(h, 2, dim=2) + else: + h = [ h ] + out = [] - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h, **kwargs) - if self.tanh_out: - h = torch.tanh(h) - return h + conv_carry_in = None + + # upsampling + for i, h1 in enumerate(h): + conv_carry_out = [] + if i == len(h) - 1: + conv_carry_out = None + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h1 = self.up[i_level].block[i_block](h1, temb, conv_carry_in, conv_carry_out, **kwargs) + if len(self.up[i_level].attn) > 0: + assert i == 0 #carried should not happen if attn exists + h1 = self.up[i_level].attn[i_block](h1, **kwargs) + if i_level != 0: + h1 = self.up[i_level].upsample(h1, conv_carry_in, conv_carry_out) + + h1 = self.norm_out(h1) + h1 = [ nonlinearity(h1) ] + h1 = conv_carry_causal_3d(h1, self.conv_out, conv_carry_in, conv_carry_out) + if self.tanh_out: + h1 = torch.tanh(h1) + out.append(h1) + conv_carry_in = conv_carry_out + + out = torch_cat_if_needed(out, dim=2) + + return out