vae_refiner: roll the convolution through temporal

Work in progress.

Roll the convolution through time using 2-latent-frame chunks and a
FIFO queue for the convolution seams.
This commit is contained in:
Rattus 2025-11-19 21:15:37 +10:00 committed by comfyanonymous
parent 0aa6eb2edc
commit dc2e308422

View File

@ -6,6 +6,16 @@ import comfy.ops
import comfy.ldm.models.autoencoder import comfy.ldm.models.autoencoder
ops = comfy.ops.disable_weight_init ops = comfy.ops.disable_weight_init
class SpatialPadConv3d(nn.Module):
def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding_mode='replicate', padding=1, **kwargs):
super().__init__()
self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, padding=(0, padding, padding), dilation=dilation, **kwargs)
def forward(self, x):
return self.conv(x)
class RMS_norm(nn.Module): class RMS_norm(nn.Module):
def __init__(self, dim): def __init__(self, dim):
super().__init__() super().__init__()
@ -83,6 +93,25 @@ class DnSmpl(nn.Module):
return h + sc return h + sc
def conv_carry(xl, op, conv_carry_in=None, conv_carry_out=None):
x = xl[0]
xl.clear()
if isinstance(op, SpatialPadConv3d):
if conv_carry_in is None:
x = torch.nn.functional.pad(x, (0, 0, 0, 0, 2, 0), mode = 'replicate')
else:
x = torch.cat([conv_carry_in.pop(0), x], dim=2)
out = op(x)
if conv_carry_out is not None:
to_push = x[:, :, -2:, :, :].clone()
conv_carry_out.append(to_push)
return out
class UpSmpl(nn.Module): class UpSmpl(nn.Module):
def __init__(self, ic, oc, tus=True, refiner_vae=True, op=VideoConv3d): def __init__(self, ic, oc, tus=True, refiner_vae=True, op=VideoConv3d):
@ -94,46 +123,53 @@ class UpSmpl(nn.Module):
self.tus = tus self.tus = tus
self.rp = fct * oc // ic self.rp = fct * oc // ic
def forward(self, x): def forward(self, x, conv_carry_in=None, conv_carry_out=None):
r1 = 2 if self.tus else 1 r1 = 2 if self.tus else 1
h = self.conv(x) h = conv_carry([x], self.conv, conv_carry_in, conv_carry_out)
if self.tus and self.refiner_vae: if self.tus and self.refiner_vae:
hf = h[:, :, :1, :, :] if conv_carry_in is None:
b, c, f, ht, wd = hf.shape hf = h[:, :, :1, :, :]
nc = c // (2 * 2) b, c, f, ht, wd = hf.shape
hf = hf.reshape(b, 2, 2, nc, f, ht, wd) nc = c // (2 * 2)
hf = hf.permute(0, 3, 4, 5, 1, 6, 2) hf = hf.reshape(b, 2, 2, nc, f, ht, wd)
hf = hf.reshape(b, nc, f, ht * 2, wd * 2) hf = hf.permute(0, 3, 4, 5, 1, 6, 2)
hf = hf[:, : hf.shape[1] // 2] hf = hf.reshape(b, nc, f, ht * 2, wd * 2)
hf = hf[:, : hf.shape[1] // 2]
hn = h[:, :, 1:, :, :] h = h[:, :, 1:, :, :]
b, c, frms, ht, wd = hn.shape
xf = x[:, :, :1, :, :]
b, ci, f, ht, wd = xf.shape
xf = xf.repeat_interleave(repeats=self.rp // 2, dim=1)
b, c, f, ht, wd = xf.shape
nc = c // (2 * 2)
xf = xf.reshape(b, 2, 2, nc, f, ht, wd)
xf = xf.permute(0, 3, 4, 5, 1, 6, 2)
xf = xf.reshape(b, nc, f, ht * 2, wd * 2)
x = x[:, :, 1:, :, :]
b, c, frms, ht, wd = h.shape
nc = c // (r1 * 2 * 2) nc = c // (r1 * 2 * 2)
hn = hn.reshape(b, r1, 2, 2, nc, frms, ht, wd) h = h.reshape(b, r1, 2, 2, nc, frms, ht, wd)
hn = hn.permute(0, 4, 5, 1, 6, 2, 7, 3) h = h.permute(0, 4, 5, 1, 6, 2, 7, 3)
hn = hn.reshape(b, nc, frms * r1, ht * 2, wd * 2) h = h.reshape(b, nc, frms * r1, ht * 2, wd * 2)
h = torch.cat([hf, hn], dim=2) x = x.repeat_interleave(repeats=self.rp, dim=1)
b, c, frms, ht, wd = x.shape
xf = x[:, :, :1, :, :]
b, ci, f, ht, wd = xf.shape
xf = xf.repeat_interleave(repeats=self.rp // 2, dim=1)
b, c, f, ht, wd = xf.shape
nc = c // (2 * 2)
xf = xf.reshape(b, 2, 2, nc, f, ht, wd)
xf = xf.permute(0, 3, 4, 5, 1, 6, 2)
xf = xf.reshape(b, nc, f, ht * 2, wd * 2)
xn = x[:, :, 1:, :, :]
xn = xn.repeat_interleave(repeats=self.rp, dim=1)
b, c, frms, ht, wd = xn.shape
nc = c // (r1 * 2 * 2) nc = c // (r1 * 2 * 2)
xn = xn.reshape(b, r1, 2, 2, nc, frms, ht, wd) x = x.reshape(b, r1, 2, 2, nc, frms, ht, wd)
xn = xn.permute(0, 4, 5, 1, 6, 2, 7, 3) x = x.permute(0, 4, 5, 1, 6, 2, 7, 3)
xn = xn.reshape(b, nc, frms * r1, ht * 2, wd * 2) x = x.reshape(b, nc, frms * r1, ht * 2, wd * 2)
sc = torch.cat([xf, xn], dim=2)
if conv_carry_in is None:
h = torch.cat([hf, h], dim=2)
sc = torch.cat([xf, x], dim=2)
else:
sc = x
else: else:
#FIXME: make this work
b, c, frms, ht, wd = h.shape b, c, frms, ht, wd = h.shape
nc = c // (r1 * 2 * 2) nc = c // (r1 * 2 * 2)
h = h.reshape(b, r1, 2, 2, nc, frms, ht, wd) h = h.reshape(b, r1, 2, 2, nc, frms, ht, wd)
@ -229,6 +265,23 @@ class Encoder(nn.Module):
return out return out
class HunyuanRefinerResnetBlock(ResnetBlock):
def __init__(self, in_channels, out_channels, conv_op=SpatialPadConv3d, norm_op=RMS_norm):
super().__init__(in_channels=in_channels, out_channels=out_channels, temb_channels=0, conv_op=SpatialPadConv3d, norm_op=RMS_norm)
def forward(self, x, conv_carry_in=None, conv_carry_out=None):
h = x
h = [ self.swish(self.norm1(x)) ]
h = conv_carry(h, self.conv1, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
h = [ self.dropout(self.swish(self.norm2(h))) ]
h = conv_carry(h, self.conv2, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
if self.in_channels != self.out_channels:
x = self.nin_shortcut(x)
return x+h
class Decoder(nn.Module): class Decoder(nn.Module):
def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks, def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks,
ffactor_spatial, ffactor_temporal, upsample_match_channel=True, refiner_vae=True, **_): ffactor_spatial, ffactor_temporal, upsample_match_channel=True, refiner_vae=True, **_):
@ -240,7 +293,7 @@ class Decoder(nn.Module):
self.refiner_vae = refiner_vae self.refiner_vae = refiner_vae
if self.refiner_vae: if self.refiner_vae:
conv_op = VideoConv3d conv_op = SpatialPadConv3d
norm_op = RMS_norm norm_op = RMS_norm
else: else:
conv_op = ops.Conv3d conv_op = ops.Conv3d
@ -250,9 +303,9 @@ class Decoder(nn.Module):
self.conv_in = conv_op(z_channels, ch, kernel_size=3, stride=1, padding=1) self.conv_in = conv_op(z_channels, ch, kernel_size=3, stride=1, padding=1)
self.mid = nn.Module() self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op) self.mid.block_1 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op) self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op) self.mid.block_2 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
self.up = nn.ModuleList() self.up = nn.ModuleList()
depth = (ffactor_spatial >> 1).bit_length() depth = (ffactor_spatial >> 1).bit_length()
@ -260,9 +313,8 @@ class Decoder(nn.Module):
for i, tgt in enumerate(block_out_channels): for i, tgt in enumerate(block_out_channels):
stage = nn.Module() stage = nn.Module()
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt, stage.block = nn.ModuleList([HunyuanRefinerResnetBlock(in_channels=ch if j == 0 else tgt,
out_channels=tgt, out_channels=tgt,
temb_channels=0,
conv_op=conv_op, norm_op=norm_op) conv_op=conv_op, norm_op=norm_op)
for j in range(num_res_blocks + 1)]) for j in range(num_res_blocks + 1)])
ch = tgt ch = tgt
@ -286,16 +338,29 @@ class Decoder(nn.Module):
# z = z.permute(0, 2, 1, 3, 4) # z = z.permute(0, 2, 1, 3, 4)
# z = z[:, :, 1:] # z = z[:, :, 1:]
x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1) x = conv_carry([z], self.conv_in) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x))) x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
for stage in self.up: conv_carry_in = None
for blk in stage.block:
x = blk(x)
if hasattr(stage, 'upsample'):
x = stage.upsample(x)
out = self.conv_out(F.silu(self.norm_out(x))) x = torch.split(x, 2, dim=2)
out = []
for i, x1 in enumerate(x):
conv_carry_out = []
for stage in self.up:
for blk in stage.block:
x1 = blk(x1, conv_carry_in, conv_carry_out)
if hasattr(stage, 'upsample'):
x1 = stage.upsample(x1, conv_carry_in, conv_carry_out)
x1 = [ F.silu(self.norm_out(x1)) ]
x1 = conv_carry(x1, self.conv_out, conv_carry_in, conv_carry_out)
out.append(x1)
conv_carry_in = conv_carry_out
del x
out = torch.cat(out, dim=2)
if not self.refiner_vae: if not self.refiner_vae:
if z.shape[-3] == 1: if z.shape[-3] == 1: