Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-02-10 21:42:37 +08:00)
vae_refiner: roll the convolution through temporal
Work in progress. Roll the convolution through time using 2-latent-frame chunks and a FIFO queue for the convolution seams.
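What this amounts to, as a standalone sketch rather than the committed code: a Conv3d that pads only spatially needs two frames of temporal left context per output frame, so the decode can stream 2-frame chunks as long as each chunk hands its last two context-extended input frames to the next. The rolled pass then matches a single full-sequence pass:

import torch

# Temporal kernel 3, no temporal padding: each output frame needs the two
# preceding input frames as left context.
conv = torch.nn.Conv3d(4, 4, kernel_size=3, padding=(0, 1, 1))
x = torch.randn(1, 4, 8, 16, 16)  # (batch, channels, frames, height, width)

# Full pass: replicate-pad two frames at the front, convolve once.
full = conv(torch.nn.functional.pad(x, (0, 0, 0, 0, 2, 0), mode='replicate'))

# Rolled pass: 2-frame chunks, each carrying its last two padded input frames.
carry, outs = None, []
for chunk in torch.split(x, 2, dim=2):
    if carry is None:  # first chunk: synthesize the left context by padding
        chunk = torch.nn.functional.pad(chunk, (0, 0, 0, 0, 2, 0), mode='replicate')
    else:              # later chunks: reuse the previous chunk's tail
        chunk = torch.cat([carry, chunk], dim=2)
    carry = chunk[:, :, -2:].clone()  # clone so the large chunk can be freed
    outs.append(conv(chunk))

assert torch.allclose(torch.cat(outs, dim=2), full, atol=1e-5)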
This commit is contained in:
parent 0aa6eb2edc
commit dc2e308422
@@ -6,6 +6,16 @@ import comfy.ops
 import comfy.ldm.models.autoencoder
 
 ops = comfy.ops.disable_weight_init
 
 
+class SpatialPadConv3d(nn.Module):
+    def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding_mode='replicate', padding=1, **kwargs):
+        super().__init__()
+        self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, padding=(0, padding, padding), dilation=dilation, **kwargs)
+
+    def forward(self, x):
+        return self.conv(x)
+
+
 class RMS_norm(nn.Module):
     def __init__(self, dim):
         super().__init__()
@@ -83,6 +93,25 @@ class DnSmpl(nn.Module):
         return h + sc
 
 
+def conv_carry(xl, op, conv_carry_in=None, conv_carry_out=None):
+    x = xl[0]
+    xl.clear()
+
+    if isinstance(op, SpatialPadConv3d):
+        if conv_carry_in is None:
+            x = torch.nn.functional.pad(x, (0, 0, 0, 0, 2, 0), mode='replicate')
+        else:
+            x = torch.cat([conv_carry_in.pop(0), x], dim=2)
+
+    out = op(x)
+
+    if conv_carry_out is not None:
+        to_push = x[:, :, -2:, :, :].clone()
+        conv_carry_out.append(to_push)
+
+    return out
+
+
 class UpSmpl(nn.Module):
     def __init__(self, ic, oc, tus=True, refiner_vae=True, op=VideoConv3d):
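A note on the FIFO discipline in conv_carry: each SpatialPadConv3d a chunk passes through appends its two-frame input tail to conv_carry_out, and the next chunk pops tails from the front of conv_carry_in, so the lists stay aligned only if every chunk visits the ops in the same order. Schematically (run_all_ops is a hypothetical stand-in for the decoder stages):

carry_in = None
for chunk in chunks:                                 # 2-latent-frame slices, oldest first
    carry_out = []                                   # filled front-to-back as ops are visited
    chunk = run_all_ops(chunk, carry_in, carry_out)  # hypothetical helper
    carry_in = carry_out                             # this chunk's tails seed the next chunk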
@@ -94,46 +123,53 @@ class UpSmpl(nn.Module):
         self.tus = tus
         self.rp = fct * oc // ic
 
-    def forward(self, x):
+    def forward(self, x, conv_carry_in=None, conv_carry_out=None):
         r1 = 2 if self.tus else 1
-        h = self.conv(x)
+        h = conv_carry([x], self.conv, conv_carry_in, conv_carry_out)
 
         if self.tus and self.refiner_vae:
-            hf = h[:, :, :1, :, :]
-            b, c, f, ht, wd = hf.shape
-            nc = c // (2 * 2)
-            hf = hf.reshape(b, 2, 2, nc, f, ht, wd)
-            hf = hf.permute(0, 3, 4, 5, 1, 6, 2)
-            hf = hf.reshape(b, nc, f, ht * 2, wd * 2)
-            hf = hf[:, : hf.shape[1] // 2]
+            if conv_carry_in is None:
+                hf = h[:, :, :1, :, :]
+                b, c, f, ht, wd = hf.shape
+                nc = c // (2 * 2)
+                hf = hf.reshape(b, 2, 2, nc, f, ht, wd)
+                hf = hf.permute(0, 3, 4, 5, 1, 6, 2)
+                hf = hf.reshape(b, nc, f, ht * 2, wd * 2)
+                hf = hf[:, : hf.shape[1] // 2]
 
-            hn = h[:, :, 1:, :, :]
-            b, c, frms, ht, wd = hn.shape
+                h = h[:, :, 1:, :, :]
+
+                xf = x[:, :, :1, :, :]
+                b, ci, f, ht, wd = xf.shape
+                xf = xf.repeat_interleave(repeats=self.rp // 2, dim=1)
+                b, c, f, ht, wd = xf.shape
+                nc = c // (2 * 2)
+                xf = xf.reshape(b, 2, 2, nc, f, ht, wd)
+                xf = xf.permute(0, 3, 4, 5, 1, 6, 2)
+                xf = xf.reshape(b, nc, f, ht * 2, wd * 2)
+
+                x = x[:, :, 1:, :, :]
+
+            b, c, frms, ht, wd = h.shape
             nc = c // (r1 * 2 * 2)
-            hn = hn.reshape(b, r1, 2, 2, nc, frms, ht, wd)
-            hn = hn.permute(0, 4, 5, 1, 6, 2, 7, 3)
-            hn = hn.reshape(b, nc, frms * r1, ht * 2, wd * 2)
+            h = h.reshape(b, r1, 2, 2, nc, frms, ht, wd)
+            h = h.permute(0, 4, 5, 1, 6, 2, 7, 3)
+            h = h.reshape(b, nc, frms * r1, ht * 2, wd * 2)
 
-            h = torch.cat([hf, hn], dim=2)
-
-            xf = x[:, :, :1, :, :]
-            b, ci, f, ht, wd = xf.shape
-            xf = xf.repeat_interleave(repeats=self.rp // 2, dim=1)
-            b, c, f, ht, wd = xf.shape
-            nc = c // (2 * 2)
-            xf = xf.reshape(b, 2, 2, nc, f, ht, wd)
-            xf = xf.permute(0, 3, 4, 5, 1, 6, 2)
-            xf = xf.reshape(b, nc, f, ht * 2, wd * 2)
-
-            xn = x[:, :, 1:, :, :]
-            xn = xn.repeat_interleave(repeats=self.rp, dim=1)
-            b, c, frms, ht, wd = xn.shape
+            x = x.repeat_interleave(repeats=self.rp, dim=1)
+            b, c, frms, ht, wd = x.shape
             nc = c // (r1 * 2 * 2)
-            xn = xn.reshape(b, r1, 2, 2, nc, frms, ht, wd)
-            xn = xn.permute(0, 4, 5, 1, 6, 2, 7, 3)
-            xn = xn.reshape(b, nc, frms * r1, ht * 2, wd * 2)
-            sc = torch.cat([xf, xn], dim=2)
+            x = x.reshape(b, r1, 2, 2, nc, frms, ht, wd)
+            x = x.permute(0, 4, 5, 1, 6, 2, 7, 3)
+            x = x.reshape(b, nc, frms * r1, ht * 2, wd * 2)
+
+            if conv_carry_in is None:
+                h = torch.cat([hf, h], dim=2)
+                sc = torch.cat([xf, x], dim=2)
+            else:
+                sc = x
         else:
+            #FIXME: make this work
             b, c, frms, ht, wd = h.shape
             nc = c // (r1 * 2 * 2)
             h = h.reshape(b, r1, 2, 2, nc, frms, ht, wd)
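The reshape/permute/reshape chains above are a combined temporal and spatial pixel shuffle: channel groups of size r1 * 2 * 2 fold into time (factor r1) and space (factor 2 x 2). A toy shape check of the rearrangement, with illustrative sizes rather than the committed ones:

import torch

b, nc, frms, ht, wd, r1 = 1, 3, 2, 4, 4, 2
h = torch.randn(b, r1 * 2 * 2 * nc, frms, ht, wd)

h = h.reshape(b, r1, 2, 2, nc, frms, ht, wd)
h = h.permute(0, 4, 5, 1, 6, 2, 7, 3)  # -> (b, nc, frms, r1, ht, 2, wd, 2)
h = h.reshape(b, nc, frms * r1, ht * 2, wd * 2)

assert h.shape == (1, 3, 4, 8, 8)  # time doubled by r1, height and width by 2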
@@ -229,6 +265,23 @@ class Encoder(nn.Module):
         return out
 
 
+class HunyuanRefinerResnetBlock(ResnetBlock):
+    def __init__(self, in_channels, out_channels, conv_op=SpatialPadConv3d, norm_op=RMS_norm):
+        super().__init__(in_channels=in_channels, out_channels=out_channels, temb_channels=0, conv_op=SpatialPadConv3d, norm_op=RMS_norm)
+
+    def forward(self, x, conv_carry_in=None, conv_carry_out=None):
+        h = x
+        h = [ self.swish(self.norm1(x)) ]
+        h = conv_carry(h, self.conv1, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
+
+        h = [ self.dropout(self.swish(self.norm2(h))) ]
+        h = conv_carry(h, self.conv2, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
+
+        if self.in_channels != self.out_channels:
+            x = self.nin_shortcut(x)
+
+        return x+h
+
+
 class Decoder(nn.Module):
     def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks,
                  ffactor_spatial, ffactor_temporal, upsample_match_channel=True, refiner_vae=True, **_):
@@ -240,7 +293,7 @@ class Decoder(nn.Module):
 
         self.refiner_vae = refiner_vae
         if self.refiner_vae:
-            conv_op = VideoConv3d
+            conv_op = SpatialPadConv3d
             norm_op = RMS_norm
         else:
             conv_op = ops.Conv3d
@@ -250,9 +303,9 @@ class Decoder(nn.Module):
         self.conv_in = conv_op(z_channels, ch, kernel_size=3, stride=1, padding=1)
 
         self.mid = nn.Module()
-        self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
+        self.mid.block_1 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
         self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
-        self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
+        self.mid.block_2 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
 
         self.up = nn.ModuleList()
         depth = (ffactor_spatial >> 1).bit_length()
@@ -260,9 +313,8 @@ class Decoder(nn.Module):
 
         for i, tgt in enumerate(block_out_channels):
             stage = nn.Module()
-            stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
+            stage.block = nn.ModuleList([HunyuanRefinerResnetBlock(in_channels=ch if j == 0 else tgt,
                                                      out_channels=tgt,
-                                                     temb_channels=0,
                                                      conv_op=conv_op, norm_op=norm_op)
                                          for j in range(num_res_blocks + 1)])
             ch = tgt
@@ -286,16 +338,29 @@ class Decoder(nn.Module):
         # z = z.permute(0, 2, 1, 3, 4)
         # z = z[:, :, 1:]
 
-        x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
+        x = conv_carry([z], self.conv_in) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
         x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
 
-        for stage in self.up:
-            for blk in stage.block:
-                x = blk(x)
-            if hasattr(stage, 'upsample'):
-                x = stage.upsample(x)
-
-        out = self.conv_out(F.silu(self.norm_out(x)))
+        conv_carry_in = None
+
+        x = torch.split(x, 2, dim=2)
+        out = []
+
+        for i, x1 in enumerate(x):
+            conv_carry_out = []
+            for stage in self.up:
+                for blk in stage.block:
+                    x1 = blk(x1, conv_carry_in, conv_carry_out)
+                if hasattr(stage, 'upsample'):
+                    x1 = stage.upsample(x1, conv_carry_in, conv_carry_out)
+
+            x1 = [ F.silu(self.norm_out(x1)) ]
+            x1 = conv_carry(x1, self.conv_out, conv_carry_in, conv_carry_out)
+            out.append(x1)
+            conv_carry_in = conv_carry_out
+        del x
+
+        out = torch.cat(out, dim=2)
 
         if not self.refiner_vae:
             if z.shape[-3] == 1:
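A design note on this loop: only one 2-latent-frame chunk flows through the up stages and conv_out at a time, so peak activation memory tracks the chunk size rather than the clip length, and the per-op tails are the only cross-chunk state. A minimal sketch, with toy modules rather than the committed decoder, showing that the per-op FIFO reproduces a full pass through a stack of temporal convs:

import torch

convs = [torch.nn.Conv3d(4, 4, 3, padding=(0, 1, 1)) for _ in range(2)]
pad = lambda t: torch.nn.functional.pad(t, (0, 0, 0, 0, 2, 0), mode='replicate')

def carry_op(t, op, cin, cout):
    # Mirror of conv_carry above: pad the first chunk, splice the carried
    # tail into later ones, and push this op's new tail for the next chunk.
    t = pad(t) if cin is None else torch.cat([cin.pop(0), t], dim=2)
    if cout is not None:
        cout.append(t[:, :, -2:].clone())
    return op(t)

z = torch.randn(1, 4, 8, 16, 16)
full = z
for c in convs:                      # reference: one pass over all frames
    full = c(pad(full))

cin, outs = None, []
for x1 in torch.split(z, 2, dim=2):  # rolled: 2-frame chunks with FIFO carries
    cout = []
    for c in convs:
        x1 = carry_op(x1, c, cin, cout)
    outs.append(x1)
    cin = cout

assert torch.allclose(torch.cat(outs, dim=2), full, atol=1e-5)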