import torch
import torch.nn as nn
import torch.nn.functional as F

from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, CarriedConv3d, Normalize, conv_carry_causal_3d, torch_cat_if_needed
import comfy.ops
import comfy.ldm.models.autoencoder
import comfy.model_management

ops = comfy.ops.disable_weight_init


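# RMS normalization over the channel dimension (dim=1) with a learned per-channel
# gain; the (dim, 1, 1, 1) gamma broadcasts over the T, H, W axes of the
# (B, C, T, H, W) activations used in this module.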
class RMS_norm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        shape = (dim, 1, 1, 1)
        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.empty(shape))

    def forward(self, x):
        return F.normalize(x, dim=1) * self.scale * comfy.model_management.cast_to(self.gamma, dtype=x.dtype, device=x.device)


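# Downsampling block: a 3x3x3 conv (applied through conv_carry_causal_3d) followed by a
# space(-time)-to-depth rearrangement, plus a parameter-free shortcut that applies the
# same rearrangement to the block input and averages channel groups down to the output width.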
class DnSmpl(nn.Module):
    def __init__(self, ic, oc, tds, refiner_vae, op):
        super().__init__()
        fct = 2 * 2 * 2 if tds else 1 * 2 * 2  # downsample factor: 2x in T, H, W when tds, else H, W only
        assert oc % fct == 0
        self.conv = op(ic, oc // fct, kernel_size=3, stride=1, padding=1)
        self.refiner_vae = refiner_vae

        self.tds = tds
        self.gs = fct * ic // oc  # group size for averaging the rearranged input down to oc channels

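    # In refiner mode, when no carry comes in from a previous chunk, the first frame is
    # handled separately: it is downsampled spatially only and concatenated back in front
    # of the temporally downsampled remainder.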
    def forward(self, x, conv_carry_in=None, conv_carry_out=None):
        r1 = 2 if self.tds else 1
        h = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)

        if self.tds and self.refiner_vae and conv_carry_in is None:
            hf = h[:, :, :1, :, :]
            b, c, f, ht, wd = hf.shape
            hf = hf.reshape(b, c, f, ht // 2, 2, wd // 2, 2)
            hf = hf.permute(0, 4, 6, 1, 2, 3, 5)
            hf = hf.reshape(b, 2 * 2 * c, f, ht // 2, wd // 2)
            hf = torch.cat([hf, hf], dim=1)

            h = h[:, :, 1:, :, :]

            xf = x[:, :, :1, :, :]
            b, ci, f, ht, wd = xf.shape
            xf = xf.reshape(b, ci, f, ht // 2, 2, wd // 2, 2)
            xf = xf.permute(0, 4, 6, 1, 2, 3, 5)
            xf = xf.reshape(b, 2 * 2 * ci, f, ht // 2, wd // 2)
            B, C, T, H, W = xf.shape
            xf = xf.view(B, hf.shape[1], self.gs // 2, T, H, W).mean(dim=2)

            x = x[:, :, 1:, :, :]

            if h.shape[2] == 0:
                return hf + xf

        b, c, frms, ht, wd = h.shape
        nf = frms // r1
        h = h.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
        h = h.permute(0, 3, 5, 7, 1, 2, 4, 6)
        h = h.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2)

        b, ci, frms, ht, wd = x.shape
        nf = frms // r1
        x = x.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2)
        x = x.permute(0, 3, 5, 7, 1, 2, 4, 6)
        x = x.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2)
        B, C, T, H, W = x.shape
        x = x.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2)

        if self.tds and self.refiner_vae and conv_carry_in is None:
            h = torch.cat([hf, h], dim=2)
            x = torch.cat([xf, x], dim=2)

        return h + x


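# Upsampling block: a 3x3x3 conv producing fct * oc channels followed by a
# depth-to-space(-time) rearrangement, plus a parameter-free shortcut that
# repeats the input channels and rearranges them the same way.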
class UpSmpl(nn.Module):
    def __init__(self, ic, oc, tus, refiner_vae, op):
        super().__init__()
        fct = 2 * 2 * 2 if tus else 1 * 2 * 2  # upsample factor: 2x in T, H, W when tus, else H, W only
        self.conv = op(ic, oc * fct, kernel_size=3, stride=1, padding=1)
        self.refiner_vae = refiner_vae

        self.tus = tus
        self.rp = fct * oc // ic  # channel repeat factor for the parameter-free shortcut

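    # As in DnSmpl.forward, in refiner mode with no incoming carry the first frame is
    # upsampled spatially only and concatenated back in front of the temporally upsampled rest.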
    def forward(self, x, conv_carry_in=None, conv_carry_out=None):
        r1 = 2 if self.tus else 1
        h = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)

        if self.tus and self.refiner_vae and conv_carry_in is None:
            hf = h[:, :, :1, :, :]
            b, c, f, ht, wd = hf.shape
            nc = c // (2 * 2)
            hf = hf.reshape(b, 2, 2, nc, f, ht, wd)
            hf = hf.permute(0, 3, 4, 5, 1, 6, 2)
            hf = hf.reshape(b, nc, f, ht * 2, wd * 2)
            hf = hf[:, : hf.shape[1] // 2]

            h = h[:, :, 1:, :, :]

            xf = x[:, :, :1, :, :]
            b, ci, f, ht, wd = xf.shape
            xf = xf.repeat_interleave(repeats=self.rp // 2, dim=1)
            b, c, f, ht, wd = xf.shape
            nc = c // (2 * 2)
            xf = xf.reshape(b, 2, 2, nc, f, ht, wd)
            xf = xf.permute(0, 3, 4, 5, 1, 6, 2)
            xf = xf.reshape(b, nc, f, ht * 2, wd * 2)

            x = x[:, :, 1:, :, :]

        b, c, frms, ht, wd = h.shape
        nc = c // (r1 * 2 * 2)
        h = h.reshape(b, r1, 2, 2, nc, frms, ht, wd)
        h = h.permute(0, 4, 5, 1, 6, 2, 7, 3)
        h = h.reshape(b, nc, frms * r1, ht * 2, wd * 2)

        x = x.repeat_interleave(repeats=self.rp, dim=1)
        b, c, frms, ht, wd = x.shape
        nc = c // (r1 * 2 * 2)
        x = x.reshape(b, r1, 2, 2, nc, frms, ht, wd)
        x = x.permute(0, 4, 5, 1, 6, 2, 7, 3)
        x = x.reshape(b, nc, frms * r1, ht * 2, wd * 2)

        if self.tus and self.refiner_vae and conv_carry_in is None:
            h = torch.cat([hf, h], dim=2)
            x = torch.cat([xf, x], dim=2)

        return h + x


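# Encoder: conv_in -> per-resolution ResnetBlock stages with DnSmpl downsampling ->
# mid (resnet / attention / resnet) -> norm + conv_out producing 2 * z_channels
# (mean and logvar for the diagonal Gaussian), with a channel-group-mean skip
# from the mid output added to the conv_out result.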
class Encoder(nn.Module):
    def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
                 ffactor_spatial, ffactor_temporal, downsample_match_channel=True, refiner_vae=True, **_):
        super().__init__()
        self.z_channels = z_channels
        self.block_out_channels = block_out_channels
        self.num_res_blocks = num_res_blocks
        self.ffactor_temporal = ffactor_temporal

        self.refiner_vae = refiner_vae
        if self.refiner_vae:
            conv_op = CarriedConv3d
            norm_op = RMS_norm
        else:
            conv_op = ops.Conv3d
            norm_op = Normalize

        self.conv_in = conv_op(in_channels, block_out_channels[0], 3, 1, 1)

        self.down = nn.ModuleList()
        ch = block_out_channels[0]
        depth = (ffactor_spatial >> 1).bit_length()  # number of stages with a spatial downsample
        depth_temporal = ((ffactor_spatial // self.ffactor_temporal) >> 1).bit_length()  # stages at or past this index also downsample time

        for i, tgt in enumerate(block_out_channels):
            stage = nn.Module()
            stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
                                                     out_channels=tgt,
                                                     temb_channels=0,
                                                     conv_op=conv_op, norm_op=norm_op)
                                         for j in range(num_res_blocks)])
            ch = tgt
            if i < depth:
                nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and downsample_match_channel else ch
                stage.downsample = DnSmpl(ch, nxt, tds=i >= depth_temporal, refiner_vae=self.refiner_vae, op=conv_op)
                ch = nxt
            self.down.append(stage)

        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
        self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
        self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)

        self.norm_out = norm_op(ch)
        self.conv_out = conv_op(ch, z_channels << 1, 3, 1, 1)

        self.regul = comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer()

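    # In refiner mode the input is split along time into a single leading frame plus
    # chunks of ffactor_temporal * 2 frames; chunks are encoded sequentially with the
    # causal conv state (carry) passed from one chunk to the next, then concatenated
    # and run through the mid/attention blocks as a whole.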
    def forward(self, x):
        if not self.refiner_vae and x.shape[2] == 1:
            x = x.expand(-1, -1, self.ffactor_temporal, -1, -1)

        if self.refiner_vae:
            xl = [x[:, :, :1, :, :]]
            if x.shape[2] > self.ffactor_temporal:
                xl += torch.split(x[:, :, 1: 1 + ((x.shape[2] - 1) // self.ffactor_temporal) * self.ffactor_temporal, :, :], self.ffactor_temporal * 2, dim=2)
            x = xl
        else:
            x = [x]
        out = []

        conv_carry_in = None

        for i, x1 in enumerate(x):
            conv_carry_out = []
            if i == len(x) - 1:
                conv_carry_out = None

            x1 = conv_carry_causal_3d([x1], self.conv_in, conv_carry_in, conv_carry_out)

            for stage in self.down:
                for blk in stage.block:
                    x1 = blk(x1, None, conv_carry_in, conv_carry_out)
                if hasattr(stage, 'downsample'):
                    x1 = stage.downsample(x1, conv_carry_in, conv_carry_out)

            out.append(x1)
            conv_carry_in = conv_carry_out

        out = torch_cat_if_needed(out, dim=2)

        x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(out)))
        del out

        # channel-group-mean skip that reduces the mid output to 2 * z_channels
        b, c, t, h, w = x.shape
        grp = c // (self.z_channels << 1)
        skip = x.view(b, c // grp, grp, t, h, w).mean(2)

        out = conv_carry_causal_3d([F.silu(self.norm_out(x))], self.conv_out) + skip

        if self.refiner_vae:
            out = self.regul(out)[0]

        return out


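# Decoder: conv_in (plus a channel-repeated copy of z as a skip) -> mid blocks ->
# per-resolution ResnetBlock stages with UpSmpl upsampling -> norm + conv_out.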
class Decoder(nn.Module):
    def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks,
                 ffactor_spatial, ffactor_temporal, upsample_match_channel=True, refiner_vae=True, **_):
        super().__init__()
        block_out_channels = block_out_channels[::-1]
        self.z_channels = z_channels
        self.block_out_channels = block_out_channels
        self.num_res_blocks = num_res_blocks

        self.refiner_vae = refiner_vae
        if self.refiner_vae:
            conv_op = CarriedConv3d
            norm_op = RMS_norm
        else:
            conv_op = ops.Conv3d
            norm_op = Normalize

        ch = block_out_channels[0]
        self.conv_in = conv_op(z_channels, ch, kernel_size=3, stride=1, padding=1)

        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
        self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
        self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)

        self.up = nn.ModuleList()
        depth = (ffactor_spatial >> 1).bit_length()  # number of stages with a spatial upsample
        depth_temporal = (ffactor_temporal >> 1).bit_length()  # the first depth_temporal stages also upsample time

        for i, tgt in enumerate(block_out_channels):
            stage = nn.Module()
            stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
                                                     out_channels=tgt,
                                                     temb_channels=0,
                                                     conv_op=conv_op, norm_op=norm_op)
                                         for j in range(num_res_blocks + 1)])
            ch = tgt
            if i < depth:
                nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and upsample_match_channel else ch
                stage.upsample = UpSmpl(ch, nxt, tus=i < depth_temporal, refiner_vae=self.refiner_vae, op=conv_op)
                ch = nxt
            self.up.append(stage)

        self.norm_out = norm_op(ch)
        self.conv_out = conv_op(ch, out_channels, 3, stride=1, padding=1)

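    # In refiner mode the latent is decoded in chunks of 2 frames along time, with the
    # causal conv carry passed between chunks; the chunk outputs are concatenated at the end.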
    def forward(self, z):
        # conv_in plus a parameter-free skip: z's channels repeated up to the first block width
        x = conv_carry_causal_3d([z], self.conv_in) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
        x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))

        if self.refiner_vae:
            x = torch.split(x, 2, dim=2)
        else:
            x = [x]
        out = []

        conv_carry_in = None

        for i, x1 in enumerate(x):
            conv_carry_out = []
            if i == len(x) - 1:
                conv_carry_out = None
            for stage in self.up:
                for blk in stage.block:
                    x1 = blk(x1, None, conv_carry_in, conv_carry_out)
                if hasattr(stage, 'upsample'):
                    x1 = stage.upsample(x1, conv_carry_in, conv_carry_out)

            x1 = conv_carry_causal_3d([F.silu(self.norm_out(x1))], self.conv_out, conv_carry_in, conv_carry_out)
            out.append(x1)
            conv_carry_in = conv_carry_out
        del x

        out = torch_cat_if_needed(out, dim=2)

        if not self.refiner_vae and z.shape[-3] == 1:
            out = out[:, :, -1:]

        return out