From dc2e308422ddbf2e6b23c1facae8e249d194bb0f Mon Sep 17 00:00:00 2001 From: Rattus Date: Wed, 19 Nov 2025 21:15:37 +1000 Subject: [PATCH] vae_refiner: roll the convolution through temporal Work in progress. Roll the convolution through time using 2-latent-frame chunks and a FIFO queue for the convolution seams. --- comfy/ldm/hunyuan_video/vae_refiner.py | 153 ++++++++++++++++++------- 1 file changed, 109 insertions(+), 44 deletions(-) diff --git a/comfy/ldm/hunyuan_video/vae_refiner.py b/comfy/ldm/hunyuan_video/vae_refiner.py index aab56ca6c..e656e5996 100644 --- a/comfy/ldm/hunyuan_video/vae_refiner.py +++ b/comfy/ldm/hunyuan_video/vae_refiner.py @@ -6,6 +6,16 @@ import comfy.ops import comfy.ldm.models.autoencoder ops = comfy.ops.disable_weight_init + +class SpatialPadConv3d(nn.Module): + def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding_mode='replicate', padding=1, **kwargs): + super().__init__() + self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, padding=(0, padding, padding), dilation=dilation, **kwargs) + + def forward(self, x): + return self.conv(x) + + class RMS_norm(nn.Module): def __init__(self, dim): super().__init__() @@ -83,6 +93,25 @@ class DnSmpl(nn.Module): return h + sc +def conv_carry(xl, op, conv_carry_in=None, conv_carry_out=None): + + x = xl[0] + xl.clear() + + if isinstance(op, SpatialPadConv3d): + if conv_carry_in is None: + x = torch.nn.functional.pad(x, (0, 0, 0, 0, 2, 0), mode = 'replicate') + else: + x = torch.cat([conv_carry_in.pop(0), x], dim=2) + + out = op(x) + + if conv_carry_out is not None: + to_push = x[:, :, -2:, :, :].clone() + conv_carry_out.append(to_push) + + return out + class UpSmpl(nn.Module): def __init__(self, ic, oc, tus=True, refiner_vae=True, op=VideoConv3d): @@ -94,46 +123,53 @@ class UpSmpl(nn.Module): self.tus = tus self.rp = fct * oc // ic - def forward(self, x): + def forward(self, x, conv_carry_in=None, conv_carry_out=None): r1 = 2 if self.tus else 1 - h = self.conv(x) + h = conv_carry([x], self.conv, conv_carry_in, conv_carry_out) if self.tus and self.refiner_vae: - hf = h[:, :, :1, :, :] - b, c, f, ht, wd = hf.shape - nc = c // (2 * 2) - hf = hf.reshape(b, 2, 2, nc, f, ht, wd) - hf = hf.permute(0, 3, 4, 5, 1, 6, 2) - hf = hf.reshape(b, nc, f, ht * 2, wd * 2) - hf = hf[:, : hf.shape[1] // 2] + if conv_carry_in is None: + hf = h[:, :, :1, :, :] + b, c, f, ht, wd = hf.shape + nc = c // (2 * 2) + hf = hf.reshape(b, 2, 2, nc, f, ht, wd) + hf = hf.permute(0, 3, 4, 5, 1, 6, 2) + hf = hf.reshape(b, nc, f, ht * 2, wd * 2) + hf = hf[:, : hf.shape[1] // 2] - hn = h[:, :, 1:, :, :] - b, c, frms, ht, wd = hn.shape + h = h[:, :, 1:, :, :] + + xf = x[:, :, :1, :, :] + b, ci, f, ht, wd = xf.shape + xf = xf.repeat_interleave(repeats=self.rp // 2, dim=1) + b, c, f, ht, wd = xf.shape + nc = c // (2 * 2) + xf = xf.reshape(b, 2, 2, nc, f, ht, wd) + xf = xf.permute(0, 3, 4, 5, 1, 6, 2) + xf = xf.reshape(b, nc, f, ht * 2, wd * 2) + + x = x[:, :, 1:, :, :] + + b, c, frms, ht, wd = h.shape nc = c // (r1 * 2 * 2) - hn = hn.reshape(b, r1, 2, 2, nc, frms, ht, wd) - hn = hn.permute(0, 4, 5, 1, 6, 2, 7, 3) - hn = hn.reshape(b, nc, frms * r1, ht * 2, wd * 2) + h = h.reshape(b, r1, 2, 2, nc, frms, ht, wd) + h = h.permute(0, 4, 5, 1, 6, 2, 7, 3) + h = h.reshape(b, nc, frms * r1, ht * 2, wd * 2) - h = torch.cat([hf, hn], dim=2) - - xf = x[:, :, :1, :, :] - b, ci, f, ht, wd = xf.shape - xf = xf.repeat_interleave(repeats=self.rp // 2, dim=1) - b, c, f, ht, wd = xf.shape - nc = c // (2 * 2) - xf = xf.reshape(b, 2, 2, nc, f, ht, wd) - xf = xf.permute(0, 3, 4, 5, 1, 6, 2) - xf = xf.reshape(b, nc, f, ht * 2, wd * 2) - - xn = x[:, :, 1:, :, :] - xn = xn.repeat_interleave(repeats=self.rp, dim=1) - b, c, frms, ht, wd = xn.shape + x = x.repeat_interleave(repeats=self.rp, dim=1) + b, c, frms, ht, wd = x.shape nc = c // (r1 * 2 * 2) - xn = xn.reshape(b, r1, 2, 2, nc, frms, ht, wd) - xn = xn.permute(0, 4, 5, 1, 6, 2, 7, 3) - xn = xn.reshape(b, nc, frms * r1, ht * 2, wd * 2) - sc = torch.cat([xf, xn], dim=2) + x = x.reshape(b, r1, 2, 2, nc, frms, ht, wd) + x = x.permute(0, 4, 5, 1, 6, 2, 7, 3) + x = x.reshape(b, nc, frms * r1, ht * 2, wd * 2) + + if conv_carry_in is None: + h = torch.cat([hf, h], dim=2) + sc = torch.cat([xf, x], dim=2) + else: + sc = x else: + #FIXME: make this work b, c, frms, ht, wd = h.shape nc = c // (r1 * 2 * 2) h = h.reshape(b, r1, 2, 2, nc, frms, ht, wd) @@ -229,6 +265,23 @@ class Encoder(nn.Module): return out +class HunyuanRefinerResnetBlock(ResnetBlock): + def __init__(self, in_channels, out_channels, conv_op=SpatialPadConv3d, norm_op=RMS_norm): + super().__init__(in_channels=in_channels, out_channels=out_channels, temb_channels=0, conv_op=SpatialPadConv3d, norm_op=RMS_norm) + + def forward(self, x, conv_carry_in=None, conv_carry_out=None): + h = x + h = [ self.swish(self.norm1(x)) ] + h = conv_carry(h, self.conv1, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out) + + h = [ self.dropout(self.swish(self.norm2(h))) ] + h = conv_carry(h, self.conv2, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out) + + if self.in_channels != self.out_channels: + x = self.nin_shortcut(x) + + return x+h + class Decoder(nn.Module): def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks, ffactor_spatial, ffactor_temporal, upsample_match_channel=True, refiner_vae=True, **_): @@ -240,7 +293,7 @@ class Decoder(nn.Module): self.refiner_vae = refiner_vae if self.refiner_vae: - conv_op = VideoConv3d + conv_op = SpatialPadConv3d norm_op = RMS_norm else: conv_op = ops.Conv3d @@ -250,9 +303,9 @@ class Decoder(nn.Module): self.conv_in = conv_op(z_channels, ch, kernel_size=3, stride=1, padding=1) self.mid = nn.Module() - self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op) + self.mid.block_1 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op) - self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op) + self.mid.block_2 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) self.up = nn.ModuleList() depth = (ffactor_spatial >> 1).bit_length() @@ -260,9 +313,8 @@ class Decoder(nn.Module): for i, tgt in enumerate(block_out_channels): stage = nn.Module() - stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt, + stage.block = nn.ModuleList([HunyuanRefinerResnetBlock(in_channels=ch if j == 0 else tgt, out_channels=tgt, - temb_channels=0, conv_op=conv_op, norm_op=norm_op) for j in range(num_res_blocks + 1)]) ch = tgt @@ -286,16 +338,29 @@ class Decoder(nn.Module): # z = z.permute(0, 2, 1, 3, 4) # z = z[:, :, 1:] - x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1) + x = conv_carry([z], self.conv_in) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1) x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x))) - for stage in self.up: - for blk in stage.block: - x = blk(x) - if hasattr(stage, 'upsample'): - x = stage.upsample(x) + conv_carry_in = None - out = self.conv_out(F.silu(self.norm_out(x))) + x = torch.split(x, 2, dim=2) + out = [] + + for i, x1 in enumerate(x): + conv_carry_out = [] + for stage in self.up: + for blk in stage.block: + x1 = blk(x1, conv_carry_in, conv_carry_out) + if hasattr(stage, 'upsample'): + x1 = stage.upsample(x1, conv_carry_in, conv_carry_out) + + x1 = [ F.silu(self.norm_out(x1)) ] + x1 = conv_carry(x1, self.conv_out, conv_carry_in, conv_carry_out) + out.append(x1) + conv_carry_in = conv_carry_out + del x + + out = torch.cat(out, dim=2) if not self.refiner_vae: if z.shape[-3] == 1: