vae_refiner: roll the convolution through the temporal axis, part II

Roll the convolution through time using 2-latent-frame chunks and a
FIFO queue for the convolution seams.
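
How the seam carry works, in miniature: the sketch below is modeled on the conv_carry_causal_3d helper added in this commit, but with simplified names and toy shapes, so treat it as an illustration rather than the actual ComfyUI API. Streaming a clip through an unpadded causal conv in chunks reproduces the monolithic result because each chunk hands its last two frames to the next one.

import torch
import torch.nn as nn
import torch.nn.functional as F

class NoPadConv3d(nn.Module):
    # Unpadded 3x3x3 convolution; all padding is applied by the caller.
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Conv3d(in_ch, out_ch, 3)

    def forward(self, x):
        return self.conv(x)

def conv_carry_causal_3d(x, op, carry_in=None, carry_out=None):
    # Remember the last 2 frames of this chunk; the next chunk prepends them
    # so the convolution sees across the seam instead of re-padding there.
    if carry_out is not None:
        carry_out.append(x[:, :, -2:].clone())
    if carry_in is None:
        # First chunk: replicate 2 frames at the front (causal in time) plus
        # 1 pixel of spatial padding for the 3x3x3 kernel.
        x = F.pad(x, (1, 1, 1, 1, 2, 0), mode='replicate')
    else:
        prev = carry_in.pop(0)  # FIFO: oldest carry first
        x = torch.cat([prev, x], dim=2)
        x = F.pad(x, (1, 1, 1, 1, 2 - prev.shape[2], 0), mode='replicate')
    return op(x)

# Streaming two 2-frame chunks matches convolving the whole clip at once.
conv = NoPadConv3d(4, 4)
clip = torch.randn(1, 4, 4, 16, 16)  # (B, C, T, H, W)
full = conv_carry_causal_3d(clip, conv)

fifo = []
a = conv_carry_causal_3d(clip[:, :, :2], conv, carry_out=fifo)
b = conv_carry_causal_3d(clip[:, :, 2:], conv, carry_in=fifo)
print(torch.allclose(full, torch.cat([a, b], dim=2), atol=1e-6))  # True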

Added support for the encoder, lowered the chunk size to 1 latent frame to save
more VRAM, and made it work for Hunyuan Image 3.0 (which shares this code).
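
The encoder and decoder loops in the diff then just thread one such FIFO from chunk to chunk through every carried convolution in the stack: each conv pushes its seam frames in forward order, and the next chunk pops them back in the same order. A rough driver, reusing NoPadConv3d and conv_carry_causal_3d from the sketch above (the two-conv stack is only a stand-in for the real down/up stages):

def run_streamed(chunks, stack):
    # One FIFO per chunk serves the whole stack because push and pop order match.
    out, carry_in = [], None
    for i, x in enumerate(chunks):
        carry_out = [] if i < len(chunks) - 1 else None  # last chunk: nothing to save
        for op in stack:
            x = conv_carry_causal_3d(x, op, carry_in, carry_out)
        out.append(x)
        carry_in = carry_out  # hand the FIFO to the next chunk
    return torch.cat(out, dim=2)

stack = [NoPadConv3d(4, 4), NoPadConv3d(4, 4)]
clip = torch.randn(1, 4, 6, 16, 16)
streamed = run_streamed(list(torch.split(clip, 2, dim=2)), stack)  # 2-latent-frame chunks

whole = clip
for op in stack:  # same stack, no chunking
    whole = conv_carry_causal_3d(whole, op)
print(torch.allclose(streamed, whole, atol=1e-6))  # True

Only the chunk in flight plus the small per-conv carries need to be resident at once, which is where the VRAM saving comes from.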

Fixed names, cleaned up code.
Rattus 2025-11-19 21:15:37 +10:00 committed by comfyanonymous
parent d8858cb58b
commit 023036ef9d


@@ -6,9 +6,8 @@ import comfy.ops
 import comfy.ldm.models.autoencoder

 ops = comfy.ops.disable_weight_init

-class SpatialPadConv3d(nn.Module):
-    def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding_mode='replicate', padding=1, **kwargs):
+class NoPadConv3d(nn.Module):
+    def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, **kwargs):
         super().__init__()
         self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, dilation=dilation, **kwargs)
@@ -16,6 +15,28 @@ class SpatialPadConv3d(nn.Module):
         return self.conv(x)

+def conv_carry_causal_3d(xl, op, conv_carry_in=None, conv_carry_out=None):
+    x = xl[0]
+    xl.clear()
+
+    if conv_carry_out is not None:
+        to_push = x[:, :, -2:, :, :].clone()
+        conv_carry_out.append(to_push)
+
+    if isinstance(op, NoPadConv3d):
+        if conv_carry_in is None:
+            x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2, 0), mode = 'replicate')
+        else:
+            carry_len = conv_carry_in[0].shape[2]
+            x = torch.cat([conv_carry_in.pop(0), x], dim=2)
+            x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2 - carry_len, 0), mode = 'replicate')
+
+    out = op(x)
+    return out
+
 class RMS_norm(nn.Module):
     def __init__(self, dim):
         super().__init__()
@@ -37,11 +58,12 @@ class DnSmpl(nn.Module):
         self.tds = tds
         self.gs = fct * ic // oc

-    def forward(self, x):
+    def forward(self, x, conv_carry_in=None, conv_carry_out=None):
         r1 = 2 if self.tds else 1
-        h = self.conv(x)
+        h = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)

-        if self.tds and self.refiner_vae:
+        if self.tds and self.refiner_vae and conv_carry_in is None:
             hf = h[:, :, :1, :, :]
             b, c, f, ht, wd = hf.shape
             hf = hf.reshape(b, c, f, ht // 2, 2, wd // 2, 2)
@@ -49,14 +71,7 @@ class DnSmpl(nn.Module):
             hf = hf.reshape(b, 2 * 2 * c, f, ht // 2, wd // 2)
             hf = torch.cat([hf, hf], dim=1)
-
-            hn = h[:, :, 1:, :, :]
-            b, c, frms, ht, wd = hn.shape
-            nf = frms // r1
-            hn = hn.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
-            hn = hn.permute(0, 3, 5, 7, 1, 2, 4, 6)
-            hn = hn.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2)
-            h = torch.cat([hf, hn], dim=2)
+            h = h[:, :, 1:, :, :]

             xf = x[:, :, :1, :, :]
             b, ci, f, ht, wd = xf.shape
@@ -64,20 +79,14 @@ class DnSmpl(nn.Module):
             xf = xf.permute(0, 4, 6, 1, 2, 3, 5)
             xf = xf.reshape(b, 2 * 2 * ci, f, ht // 2, wd // 2)
             B, C, T, H, W = xf.shape
-            xf = xf.view(B, h.shape[1], self.gs // 2, T, H, W).mean(dim=2)
+            xf = xf.view(B, hf.shape[1], self.gs // 2, T, H, W).mean(dim=2)
+            x = x[:, :, 1:, :, :]
+
+            if h.shape[2] == 0:
+                return hf + xf
-            xn = x[:, :, 1:, :, :]
-            b, ci, frms, ht, wd = xn.shape
-            nf = frms // r1
-            xn = xn.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2)
-            xn = xn.permute(0, 3, 5, 7, 1, 2, 4, 6)
-            xn = xn.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2)
-            B, C, T, H, W = xn.shape
-            xn = xn.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2)
-            sc = torch.cat([xf, xn], dim=2)
-        else:
-            b, c, frms, ht, wd = h.shape
-            nf = frms // r1
-            h = h.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
-            h = h.permute(0, 3, 5, 7, 1, 2, 4, 6)
+        b, c, frms, ht, wd = h.shape
+        nf = frms // r1
+        h = h.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
+        h = h.permute(0, 3, 5, 7, 1, 2, 4, 6)
@@ -85,33 +94,17 @@ class DnSmpl(nn.Module):
-            b, ci, frms, ht, wd = x.shape
-            nf = frms // r1
-            sc = x.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2)
-            sc = sc.permute(0, 3, 5, 7, 1, 2, 4, 6)
-            sc = sc.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2)
-            B, C, T, H, W = sc.shape
-            sc = sc.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2)
-
-        return h + sc
-
-def conv_carry(xl, op, conv_carry_in=None, conv_carry_out=None):
-    x = xl[0]
-    xl.clear()
-
-    if conv_carry_out is not None:
-        to_push = x[:, :, -2:, :, :].clone()
-        conv_carry_out.append(to_push)
-
-    if isinstance(op, SpatialPadConv3d):
-        if conv_carry_in is None:
-            x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2, 0), mode = 'replicate')
-        else:
-            x = torch.cat([conv_carry_in.pop(0), x], dim=2)
-            x = torch.nn.functional.pad(x, (1, 1, 1, 1, 0, 0), mode = 'replicate')
-
-    out = op(x)
-    return out
+        b, ci, frms, ht, wd = x.shape
+        nf = frms // r1
+        x = x.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2)
+        x = x.permute(0, 3, 5, 7, 1, 2, 4, 6)
+        x = x.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2)
+        B, C, T, H, W = x.shape
+        x = x.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2)
+
+        if self.tds and self.refiner_vae and conv_carry_in is None:
+            h = torch.cat([hf, h], dim=2)
+            x = torch.cat([xf, x], dim=2)
+
+        return h + x

 class UpSmpl(nn.Module):
@@ -126,10 +119,9 @@ class UpSmpl(nn.Module):
     def forward(self, x, conv_carry_in=None, conv_carry_out=None):
         r1 = 2 if self.tus else 1
-        h = conv_carry([x], self.conv, conv_carry_in, conv_carry_out)
+        h = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)

-        if self.tus and self.refiner_vae:
-            if conv_carry_in is None:
+        if self.tus and self.refiner_vae and conv_carry_in is None:
             hf = h[:, :, :1, :, :]
             b, c, f, ht, wd = hf.shape
             nc = c // (2 * 2)
@@ -164,27 +156,28 @@ class UpSmpl(nn.Module):
         x = x.permute(0, 4, 5, 1, 6, 2, 7, 3)
         x = x.reshape(b, nc, frms * r1, ht * 2, wd * 2)

-            if conv_carry_in is None:
-                h = torch.cat([hf, h], dim=2)
-                sc = torch.cat([xf, x], dim=2)
-            else:
-                sc = x
-        else:
-            #FIXME: make this work
-            b, c, frms, ht, wd = h.shape
-            nc = c // (r1 * 2 * 2)
-            h = h.reshape(b, r1, 2, 2, nc, frms, ht, wd)
-            h = h.permute(0, 4, 5, 1, 6, 2, 7, 3)
-            h = h.reshape(b, nc, frms * r1, ht * 2, wd * 2)
-
-            sc = x.repeat_interleave(repeats=self.rp, dim=1)
-            b, c, frms, ht, wd = sc.shape
-            nc = c // (r1 * 2 * 2)
-            sc = sc.reshape(b, r1, 2, 2, nc, frms, ht, wd)
-            sc = sc.permute(0, 4, 5, 1, 6, 2, 7, 3)
-            sc = sc.reshape(b, nc, frms * r1, ht * 2, wd * 2)
-
-        return h + sc
+        if self.tus and self.refiner_vae and conv_carry_in is None:
+            h = torch.cat([hf, h], dim=2)
+            x = torch.cat([xf, x], dim=2)
+
+        return h + x
+
+class HunyuanRefinerResnetBlock(ResnetBlock):
+    def __init__(self, in_channels, out_channels, conv_op=NoPadConv3d, norm_op=RMS_norm):
+        super().__init__(in_channels=in_channels, out_channels=out_channels, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
+
+    def forward(self, x, conv_carry_in=None, conv_carry_out=None):
+        h = x
+        h = [ self.swish(self.norm1(x)) ]
+        h = conv_carry_causal_3d(h, self.conv1, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
+        h = [ self.dropout(self.swish(self.norm2(h))) ]
+        h = conv_carry_causal_3d(h, self.conv2, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
+
+        if self.in_channels != self.out_channels:
+            x = self.nin_shortcut(x)
+
+        return x+h

 class Encoder(nn.Module):
     def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
@@ -197,7 +190,7 @@ class Encoder(nn.Module):
         self.refiner_vae = refiner_vae

         if self.refiner_vae:
-            conv_op = VideoConv3d
+            conv_op = NoPadConv3d
             norm_op = RMS_norm
         else:
             conv_op = ops.Conv3d
@@ -212,9 +205,8 @@ class Encoder(nn.Module):
         for i, tgt in enumerate(block_out_channels):
             stage = nn.Module()
-            stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
+            stage.block = nn.ModuleList([HunyuanRefinerResnetBlock(in_channels=ch if j == 0 else tgt,
                                                      out_channels=tgt,
-                                                     temb_channels=0,
                                                      conv_op=conv_op, norm_op=norm_op)
                                          for j in range(num_res_blocks)])
             ch = tgt
@@ -225,9 +217,9 @@ class Encoder(nn.Module):
             self.down.append(stage)

         self.mid = nn.Module()
-        self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
+        self.mid.block_1 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
         self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
-        self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
+        self.mid.block_2 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)

         self.norm_out = norm_op(ch)
         self.conv_out = conv_op(ch, z_channels << 1, 3, 1, 1)
@@ -238,21 +230,46 @@ class Encoder(nn.Module):
         if not self.refiner_vae and x.shape[2] == 1:
             x = x.expand(-1, -1, self.ffactor_temporal, -1, -1)

-        x = self.conv_in(x)
-
-        for stage in self.down:
-            for blk in stage.block:
-                x = blk(x)
-            if hasattr(stage, 'downsample'):
-                x = stage.downsample(x)
-
-        x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
+        if self.refiner_vae:
+            xl = [x[:, :, :1, :, :]]
+            if x.shape[2] > 1:
+                xl += torch.split(x[:, :, 1:, :, :], self.ffactor_temporal, dim=2)
+            x = xl
+        else:
+            x = [x]
+
+        out = []
+        conv_carry_in = None
+        for i, x1 in enumerate(x):
+            conv_carry_out = []
+            if i == len(x) - 1:
+                conv_carry_out = None
+
+            x1 = [ x1 ]
+            x1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out)
+
+            for stage in self.down:
+                for blk in stage.block:
+                    x1 = blk(x1, conv_carry_in, conv_carry_out)
+                if hasattr(stage, 'downsample'):
+                    x1 = stage.downsample(x1, conv_carry_in, conv_carry_out)
+
+            out.append(x1)
+            conv_carry_in = conv_carry_out
+
+        if len(out) > 1:
+            out = torch.cat(out, dim=2)
+        else:
+            out = out[0]
+
+        x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(out)))
+        del out

         b, c, t, h, w = x.shape
         grp = c // (self.z_channels << 1)
         skip = x.view(b, c // grp, grp, t, h, w).mean(2)
-        out = self.conv_out(F.silu(self.norm_out(x))) + skip
+        out = conv_carry_causal_3d([F.silu(self.norm_out(x))], self.conv_out) + skip

         if self.refiner_vae:
             out = self.regul(out)[0]
@@ -266,23 +283,6 @@ class Encoder(nn.Module):
         return out

-class HunyuanRefinerResnetBlock(ResnetBlock):
-    def __init__(self, in_channels, out_channels, conv_op=SpatialPadConv3d, norm_op=RMS_norm):
-        super().__init__(in_channels=in_channels, out_channels=out_channels, temb_channels=0, conv_op=SpatialPadConv3d, norm_op=RMS_norm)
-
-    def forward(self, x, conv_carry_in=None, conv_carry_out=None):
-        h = x
-        h = [ self.swish(self.norm1(x)) ]
-        h = conv_carry(h, self.conv1, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
-        h = [ self.dropout(self.swish(self.norm2(h))) ]
-        h = conv_carry(h, self.conv2, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
-
-        if self.in_channels != self.out_channels:
-            x = self.nin_shortcut(x)
-
-        return x+h

 class Decoder(nn.Module):
     def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks,
                  ffactor_spatial, ffactor_temporal, upsample_match_channel=True, refiner_vae=True, **_):
@@ -294,7 +294,7 @@ class Decoder(nn.Module):
         self.refiner_vae = refiner_vae

         if self.refiner_vae:
-            conv_op = SpatialPadConv3d
+            conv_op = NoPadConv3d
             norm_op = RMS_norm
         else:
             conv_op = ops.Conv3d
@@ -339,16 +339,21 @@ class Decoder(nn.Module):
         # z = z.permute(0, 2, 1, 3, 4)
         # z = z[:, :, 1:]

-        x = conv_carry([z], self.conv_in) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
+        x = conv_carry_causal_3d([z], self.conv_in) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
         x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))

-        conv_carry_in = None
-        x = torch.split(x, 2, dim=2)
-        out = []
+        if self.refiner_vae:
+            x = torch.split(x, 1, dim=2)
+        else:
+            x = [ x ]
+
+        out = []
+        conv_carry_in = None
         for i, x1 in enumerate(x):
             conv_carry_out = []
+            if i == len(x) - 1:
+                conv_carry_out = None
+
             for stage in self.up:
                 for blk in stage.block:
                     x1 = blk(x1, conv_carry_in, conv_carry_out)
@@ -356,15 +361,19 @@ class Decoder(nn.Module):
                     x1 = stage.upsample(x1, conv_carry_in, conv_carry_out)

             x1 = [ F.silu(self.norm_out(x1)) ]
-            x1 = conv_carry(x1, self.conv_out, conv_carry_in, conv_carry_out)
+            x1 = conv_carry_causal_3d(x1, self.conv_out, conv_carry_in, conv_carry_out)
             out.append(x1)
             conv_carry_in = conv_carry_out

         del x
-        out = torch.cat(out, dim=2)
+
+        if len(out) > 1:
+            out = torch.cat(out, dim=2)
+        else:
+            out = out[0]

         if not self.refiner_vae:
             if z.shape[-3] == 1:
                 out = out[:, :, -1:]

         return out