From 119fc04459049452ba5c485e2b8b634852c03867 Mon Sep 17 00:00:00 2001 From: Rattus Date: Sat, 29 Nov 2025 09:21:54 +1000 Subject: [PATCH] move refiner VAE temporal roller to core Move the carrying conv op to the common VAE code and give it a better name. Roll the carry implementation logic for Resnet into the base class and scrap the Hunyuan specific subclass. --- comfy/ldm/hunyuan_video/vae_refiner.py | 94 +++++---------------- comfy/ldm/modules/diffusionmodules/model.py | 49 +++++++++-- 2 files changed, 65 insertions(+), 78 deletions(-) diff --git a/comfy/ldm/hunyuan_video/vae_refiner.py b/comfy/ldm/hunyuan_video/vae_refiner.py index 9f750dcc4..ddf77cd0e 100644 --- a/comfy/ldm/hunyuan_video/vae_refiner.py +++ b/comfy/ldm/hunyuan_video/vae_refiner.py @@ -1,42 +1,12 @@ import torch import torch.nn as nn import torch.nn.functional as F -from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d, Normalize +from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, CarriedConv3d, Normalize, conv_carry_causal_3d, torch_cat_if_needed import comfy.ops import comfy.ldm.models.autoencoder import comfy.model_management ops = comfy.ops.disable_weight_init -class NoPadConv3d(nn.Module): - def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, **kwargs): - super().__init__() - self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, dilation=dilation, **kwargs) - - def forward(self, x): - return self.conv(x) - - -def conv_carry_causal_3d(xl, op, conv_carry_in=None, conv_carry_out=None): - - x = xl[0] - xl.clear() - - if conv_carry_out is not None: - to_push = x[:, :, -2:, :, :].clone() - conv_carry_out.append(to_push) - - if isinstance(op, NoPadConv3d): - if conv_carry_in is None: - x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2, 0), mode = 'replicate') - else: - carry_len = conv_carry_in[0].shape[2] - x = torch.cat([conv_carry_in.pop(0), x], dim=2) - x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2 - carry_len, 0), mode = 'replicate') - - out = op(x) - - return out - class RMS_norm(nn.Module): def __init__(self, dim): @@ -49,7 +19,7 @@ class RMS_norm(nn.Module): return F.normalize(x, dim=1) * self.scale * comfy.model_management.cast_to(self.gamma, dtype=x.dtype, device=x.device) class DnSmpl(nn.Module): - def __init__(self, ic, oc, tds=True, refiner_vae=True, op=VideoConv3d): + def __init__(self, ic, oc, tds, refiner_vae, op): super().__init__() fct = 2 * 2 * 2 if tds else 1 * 2 * 2 assert oc % fct == 0 @@ -109,7 +79,7 @@ class DnSmpl(nn.Module): class UpSmpl(nn.Module): - def __init__(self, ic, oc, tus=True, refiner_vae=True, op=VideoConv3d): + def __init__(self, ic, oc, tus, refiner_vae, op): super().__init__() fct = 2 * 2 * 2 if tus else 1 * 2 * 2 self.conv = op(ic, oc * fct, kernel_size=3, stride=1, padding=1) @@ -163,23 +133,6 @@ class UpSmpl(nn.Module): return h + x -class HunyuanRefinerResnetBlock(ResnetBlock): - def __init__(self, in_channels, out_channels, conv_op=NoPadConv3d, norm_op=RMS_norm): - super().__init__(in_channels=in_channels, out_channels=out_channels, temb_channels=0, conv_op=conv_op, norm_op=norm_op) - - def forward(self, x, conv_carry_in=None, conv_carry_out=None): - h = x - h = [ self.swish(self.norm1(x)) ] - h = conv_carry_causal_3d(h, self.conv1, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out) - - h = [ self.dropout(self.swish(self.norm2(h))) ] - h = conv_carry_causal_3d(h, self.conv2, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out) - - if self.in_channels != self.out_channels: - x = self.nin_shortcut(x) - - return x+h - class Encoder(nn.Module): def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks, ffactor_spatial, ffactor_temporal, downsample_match_channel=True, refiner_vae=True, **_): @@ -191,7 +144,7 @@ class Encoder(nn.Module): self.refiner_vae = refiner_vae if self.refiner_vae: - conv_op = NoPadConv3d + conv_op = CarriedConv3d norm_op = RMS_norm else: conv_op = ops.Conv3d @@ -206,9 +159,10 @@ class Encoder(nn.Module): for i, tgt in enumerate(block_out_channels): stage = nn.Module() - stage.block = nn.ModuleList([HunyuanRefinerResnetBlock(in_channels=ch if j == 0 else tgt, - out_channels=tgt, - conv_op=conv_op, norm_op=norm_op) + stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt, + out_channels=tgt, + temb_channels=0, + conv_op=conv_op, norm_op=norm_op) for j in range(num_res_blocks)]) ch = tgt if i < depth: @@ -218,9 +172,9 @@ class Encoder(nn.Module): self.down.append(stage) self.mid = nn.Module() - self.mid.block_1 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) + self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op) - self.mid.block_2 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) + self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) self.norm_out = norm_op(ch) self.conv_out = conv_op(ch, z_channels << 1, 3, 1, 1) @@ -246,22 +200,20 @@ class Encoder(nn.Module): conv_carry_out = [] if i == len(x) - 1: conv_carry_out = None + x1 = [ x1 ] x1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out) for stage in self.down: for blk in stage.block: - x1 = blk(x1, conv_carry_in, conv_carry_out) + x1 = blk(x1, None, conv_carry_in, conv_carry_out) if hasattr(stage, 'downsample'): x1 = stage.downsample(x1, conv_carry_in, conv_carry_out) out.append(x1) conv_carry_in = conv_carry_out - if len(out) > 1: - out = torch.cat(out, dim=2) - else: - out = out[0] + out = torch_cat_if_needed(out, dim=2) x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(out))) del out @@ -288,7 +240,7 @@ class Decoder(nn.Module): self.refiner_vae = refiner_vae if self.refiner_vae: - conv_op = NoPadConv3d + conv_op = CarriedConv3d norm_op = RMS_norm else: conv_op = ops.Conv3d @@ -298,9 +250,9 @@ class Decoder(nn.Module): self.conv_in = conv_op(z_channels, ch, kernel_size=3, stride=1, padding=1) self.mid = nn.Module() - self.mid.block_1 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) + self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op) - self.mid.block_2 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) + self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) self.up = nn.ModuleList() depth = (ffactor_spatial >> 1).bit_length() @@ -308,9 +260,10 @@ class Decoder(nn.Module): for i, tgt in enumerate(block_out_channels): stage = nn.Module() - stage.block = nn.ModuleList([HunyuanRefinerResnetBlock(in_channels=ch if j == 0 else tgt, - out_channels=tgt, - conv_op=conv_op, norm_op=norm_op) + stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt, + out_channels=tgt, + temb_channels=0, + conv_op=conv_op, norm_op=norm_op) for j in range(num_res_blocks + 1)]) ch = tgt if i < depth: @@ -340,7 +293,7 @@ class Decoder(nn.Module): conv_carry_out = None for stage in self.up: for blk in stage.block: - x1 = blk(x1, conv_carry_in, conv_carry_out) + x1 = blk(x1, None, conv_carry_in, conv_carry_out) if hasattr(stage, 'upsample'): x1 = stage.upsample(x1, conv_carry_in, conv_carry_out) @@ -350,10 +303,7 @@ class Decoder(nn.Module): conv_carry_in = conv_carry_out del x - if len(out) > 1: - out = torch.cat(out, dim=2) - else: - out = out[0] + out = torch_cat_if_needed(out, dim=2) if not self.refiner_vae: if z.shape[-3] == 1: diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 47bddf9b2..47713d0d8 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -13,6 +13,12 @@ if model_management.xformers_enabled_vae(): import xformers import xformers.ops +def torch_cat_if_needed(xl, dim): + if len(xl) > 1: + return torch.cat(xl, dim) + else: + return xl[0] + def get_timestep_embedding(timesteps, embedding_dim): """ This matches the implementation in Denoising Diffusion Probabilistic Models: @@ -43,6 +49,37 @@ def Normalize(in_channels, num_groups=32): return ops.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) +class CarriedConv3d(nn.Module): + def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, **kwargs): + super().__init__() + self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, dilation=dilation, **kwargs) + + def forward(self, x): + return self.conv(x) + + +def conv_carry_causal_3d(xl, op, conv_carry_in=None, conv_carry_out=None): + + x = xl[0] + xl.clear() + + if isinstance(op, CarriedConv3d): + if conv_carry_in is None: + x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2, 0), mode = 'replicate') + else: + carry_len = conv_carry_in[0].shape[2] + x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2 - carry_len, 0), mode = 'replicate') + x = torch.cat([conv_carry_in.pop(0), x], dim=2) + + if conv_carry_out is not None: + to_push = x[:, :, -2:, :, :].clone() + conv_carry_out.append(to_push) + + out = op(x) + + return out + + class VideoConv3d(nn.Module): def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding_mode='replicate', padding=1, **kwargs): super().__init__() @@ -183,23 +220,23 @@ class ResnetBlock(nn.Module): stride=1, padding=0) - def forward(self, x, temb=None): + def forward(self, x, temb=None, conv_carry_in=None, conv_carry_out=None): h = x h = self.norm1(h) - h = self.swish(h) - h = self.conv1(h) + h = [ self.swish(h) ] + h = conv_carry_causal_3d(h, self.conv1, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out) if temb is not None: h = h + self.temb_proj(self.swish(temb))[:,:,None,None] h = self.norm2(h) h = self.swish(h) - h = self.dropout(h) - h = self.conv2(h) + h = [ self.dropout(h) ] + h = conv_carry_causal_3d(h, self.conv2, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out) if self.in_channels != self.out_channels: if self.use_conv_shortcut: - x = self.conv_shortcut(x) + x = conv_carry_causal_3d([x], self.conv_shortcut, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out) else: x = self.nin_shortcut(x)