Support the new hunyuan vae. (#10150)

parent e4f99b479a
commit a6f83a4a1a
comfy/ldm/hunyuan_video/vae_refiner.py

@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d
+from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d, Normalize
 import comfy.ops
 import comfy.ldm.models.autoencoder
 ops = comfy.ops.disable_weight_init
@@ -17,11 +17,12 @@ class RMS_norm(nn.Module):
         return F.normalize(x, dim=1) * self.scale * self.gamma

 class DnSmpl(nn.Module):
-    def __init__(self, ic, oc, tds=True):
+    def __init__(self, ic, oc, tds=True, refiner_vae=True, op=VideoConv3d):
         super().__init__()
         fct = 2 * 2 * 2 if tds else 1 * 2 * 2
         assert oc % fct == 0
-        self.conv = VideoConv3d(ic, oc // fct, kernel_size=3)
+        self.conv = op(ic, oc // fct, kernel_size=3, stride=1, padding=1)
+        self.refiner_vae = refiner_vae

         self.tds = tds
         self.gs = fct * ic // oc
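
Note (not part of the diff): DnSmpl now receives its convolution as the `op` argument, so one block definition serves both VAE variants (causal VideoConv3d for the refiner, plain ops.Conv3d otherwise). A minimal sketch of the injection pattern, with illustrative names and sizes that are not taken from this file:

    import torch
    import torch.nn as nn

    class Block(nn.Module):
        # Stand-in for DnSmpl's conv construction: whatever callable is handed
        # in as `op` decides which Conv3d flavor the block is built from.
        def __init__(self, ic, oc, op=nn.Conv3d):
            super().__init__()
            self.conv = op(ic, oc, kernel_size=3, stride=1, padding=1)

        def forward(self, x):
            return self.conv(x)

    blk = Block(8, 16)  # nn.Conv3d stands in for VideoConv3d / ops.Conv3d here
    print(blk(torch.randn(1, 8, 4, 16, 16)).shape)  # torch.Size([1, 16, 4, 16, 16])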
@@ -30,7 +31,7 @@ class DnSmpl(nn.Module):
         r1 = 2 if self.tds else 1
         h = self.conv(x)

-        if self.tds:
+        if self.tds and self.refiner_vae:
             hf = h[:, :, :1, :, :]
             b, c, f, ht, wd = hf.shape
             hf = hf.reshape(b, c, f, ht // 2, 2, wd // 2, 2)
@@ -66,6 +67,7 @@ class DnSmpl(nn.Module):
             sc = torch.cat([xf, xn], dim=2)
         else:
             b, c, frms, ht, wd = h.shape
+
             nf = frms // r1
             h = h.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
             h = h.permute(0, 3, 5, 7, 1, 2, 4, 6)
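
Note (not part of the diff): the reshape/permute pair in this hunk is a 3D space(-time)-to-channel shuffle. The conv first shrinks channels by the factor `fct`, then each 2x2x2 (or 1x2x2) neighborhood is folded into channels, restoring `oc`. A runnable shape check with hypothetical sizes:

    import torch

    b, c, t, ht, wd = 1, 64, 4, 16, 16   # hypothetical conv output, c = oc // fct
    r1 = 2                               # temporal factor when tds is on
    h = torch.randn(b, c, t, ht, wd)
    h = h.reshape(b, c, t // r1, r1, ht // 2, 2, wd // 2, 2)
    h = h.permute(0, 3, 5, 7, 1, 2, 4, 6)   # move the 2x2x2 factors ahead of channels
    h = h.reshape(b, r1 * 2 * 2 * c, t // r1, ht // 2, wd // 2)
    print(h.shape)  # torch.Size([1, 512, 2, 8, 8])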
@@ -83,10 +85,11 @@ class DnSmpl(nn.Module):


 class UpSmpl(nn.Module):
-    def __init__(self, ic, oc, tus=True):
+    def __init__(self, ic, oc, tus=True, refiner_vae=True, op=VideoConv3d):
         super().__init__()
         fct = 2 * 2 * 2 if tus else 1 * 2 * 2
-        self.conv = VideoConv3d(ic, oc * fct, kernel_size=3)
+        self.conv = op(ic, oc * fct, kernel_size=3, stride=1, padding=1)
+        self.refiner_vae = refiner_vae

         self.tus = tus
         self.rp = fct * oc // ic
@@ -95,7 +98,7 @@ class UpSmpl(nn.Module):
         r1 = 2 if self.tus else 1
         h = self.conv(x)

-        if self.tus:
+        if self.tus and self.refiner_vae:
             hf = h[:, :, :1, :, :]
             b, c, f, ht, wd = hf.shape
             nc = c // (2 * 2)
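
Note (not part of the diff): UpSmpl runs the inverse direction; the conv expands channels to `oc * fct`, and a shuffle unfolds each channel group back into space(-time). The shuffle body itself is elided from this hunk, so the following is a sketch of the inverse operation under hypothetical sizes, mirroring the downsample check above:

    import torch

    b, oc, r1 = 1, 64, 2
    t, ht, wd = 2, 8, 8
    h = torch.randn(b, oc * r1 * 2 * 2, t, ht, wd)   # conv output, fct = 8
    h = h.reshape(b, r1, 2, 2, oc, t, ht, wd)
    h = h.permute(0, 4, 5, 1, 6, 2, 7, 3)            # interleave factors into T, H, W
    h = h.reshape(b, oc, t * r1, ht * 2, wd * 2)
    print(h.shape)  # torch.Size([1, 64, 4, 16, 16])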
@@ -148,43 +151,56 @@ class UpSmpl(nn.Module):

 class Encoder(nn.Module):
     def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
-                 ffactor_spatial, ffactor_temporal, downsample_match_channel=True, **_):
+                 ffactor_spatial, ffactor_temporal, downsample_match_channel=True, refiner_vae=True, **_):
         super().__init__()
         self.z_channels = z_channels
         self.block_out_channels = block_out_channels
         self.num_res_blocks = num_res_blocks
-        self.conv_in = VideoConv3d(in_channels, block_out_channels[0], 3, 1, 1)
+        self.ffactor_temporal = ffactor_temporal
+
+        self.refiner_vae = refiner_vae
+        if self.refiner_vae:
+            conv_op = VideoConv3d
+            norm_op = RMS_norm
+        else:
+            conv_op = ops.Conv3d
+            norm_op = Normalize
+
+        self.conv_in = conv_op(in_channels, block_out_channels[0], 3, 1, 1)

         self.down = nn.ModuleList()
         ch = block_out_channels[0]
         depth = (ffactor_spatial >> 1).bit_length()
-        depth_temporal = ((ffactor_spatial // ffactor_temporal) >> 1).bit_length()
+        depth_temporal = ((ffactor_spatial // self.ffactor_temporal) >> 1).bit_length()

         for i, tgt in enumerate(block_out_channels):
             stage = nn.Module()
             stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
                                                      out_channels=tgt,
                                                      temb_channels=0,
-                                                     conv_op=VideoConv3d, norm_op=RMS_norm)
+                                                     conv_op=conv_op, norm_op=norm_op)
                                          for j in range(num_res_blocks)])
             ch = tgt
             if i < depth:
                 nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and downsample_match_channel else ch
-                stage.downsample = DnSmpl(ch, nxt, tds=i >= depth_temporal)
+                stage.downsample = DnSmpl(ch, nxt, tds=i >= depth_temporal, refiner_vae=self.refiner_vae, op=conv_op)
                 ch = nxt
             self.down.append(stage)

         self.mid = nn.Module()
-        self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
-        self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=RMS_norm)
-        self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
+        self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
+        self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
+        self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)

-        self.norm_out = RMS_norm(ch)
-        self.conv_out = VideoConv3d(ch, z_channels << 1, 3, 1, 1)
+        self.norm_out = norm_op(ch)
+        self.conv_out = conv_op(ch, z_channels << 1, 3, 1, 1)

         self.regul = comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer()

     def forward(self, x):
+        if not self.refiner_vae and x.shape[2] == 1:
+            x = x.expand(-1, -1, self.ffactor_temporal, -1, -1)
+
         x = self.conv_in(x)

         for stage in self.down:
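
Note (not part of the diff): the new `refiner_vae=False` path turns this video architecture into an image VAE. A single input frame is broadcast to `ffactor_temporal` identical frames so the temporal downsampling stages still divide evenly and produce one latent frame. With `ffactor_temporal = 4`, the value the sd.py config below passes:

    import torch

    ffactor_temporal = 4
    x = torch.randn(1, 3, 1, 64, 64)                    # (B, C, T=1, H, W)
    if x.shape[2] == 1:
        x = x.expand(-1, -1, ffactor_temporal, -1, -1)  # a broadcast view, no copy
    print(x.shape)  # torch.Size([1, 3, 4, 64, 64])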
@@ -200,31 +216,42 @@ class Encoder(nn.Module):
             skip = x.view(b, c // grp, grp, t, h, w).mean(2)

         out = self.conv_out(F.silu(self.norm_out(x))) + skip
-        out = self.regul(out)[0]

-        out = torch.cat((out[:, :, :1], out), dim=2)
-        out = out.permute(0, 2, 1, 3, 4)
-        b, f_times_2, c, h, w = out.shape
-        out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)
-        out = out.permute(0, 2, 1, 3, 4).contiguous()
+        if self.refiner_vae:
+            out = self.regul(out)[0]
+
+            out = torch.cat((out[:, :, :1], out), dim=2)
+            out = out.permute(0, 2, 1, 3, 4)
+            b, f_times_2, c, h, w = out.shape
+            out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)
+            out = out.permute(0, 2, 1, 3, 4).contiguous()

         return out

 class Decoder(nn.Module):
     def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks,
-                 ffactor_spatial, ffactor_temporal, upsample_match_channel=True, **_):
+                 ffactor_spatial, ffactor_temporal, upsample_match_channel=True, refiner_vae=True, **_):
         super().__init__()
         block_out_channels = block_out_channels[::-1]
         self.z_channels = z_channels
         self.block_out_channels = block_out_channels
         self.num_res_blocks = num_res_blocks
+
+        self.refiner_vae = refiner_vae
+        if self.refiner_vae:
+            conv_op = VideoConv3d
+            norm_op = RMS_norm
+        else:
+            conv_op = ops.Conv3d
+            norm_op = Normalize
+
         ch = block_out_channels[0]
-        self.conv_in = VideoConv3d(z_channels, ch, 3)
+        self.conv_in = conv_op(z_channels, ch, kernel_size=3, stride=1, padding=1)

         self.mid = nn.Module()
-        self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
-        self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=RMS_norm)
-        self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
+        self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
+        self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
+        self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)

         self.up = nn.ModuleList()
         depth = (ffactor_spatial >> 1).bit_length()
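
Note (not part of the diff): the packing that this hunk keeps on the refiner path duplicates the first latent frame to make the frame count even, then folds frame pairs into channels, halving T and doubling C. A shape walk-through with hypothetical sizes:

    import torch

    out = torch.randn(1, 32, 3, 8, 8)             # (B, C, T, H, W) after the regularizer
    out = torch.cat((out[:, :, :1], out), dim=2)  # T: 3 -> 4 (now even)
    out = out.permute(0, 2, 1, 3, 4)              # (B, T, C, H, W)
    b, f_times_2, c, h, w = out.shape
    out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)   # fold frame pairs into C
    out = out.permute(0, 2, 1, 3, 4).contiguous()       # back to (B, 2C, T/2, H, W)
    print(out.shape)  # torch.Size([1, 64, 2, 8, 8])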
@@ -235,25 +262,26 @@ class Decoder(nn.Module):
             stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
                                                      out_channels=tgt,
                                                      temb_channels=0,
-                                                     conv_op=VideoConv3d, norm_op=RMS_norm)
+                                                     conv_op=conv_op, norm_op=norm_op)
                                          for j in range(num_res_blocks + 1)])
             ch = tgt
             if i < depth:
                 nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and upsample_match_channel else ch
-                stage.upsample = UpSmpl(ch, nxt, tus=i < depth_temporal)
+                stage.upsample = UpSmpl(ch, nxt, tus=i < depth_temporal, refiner_vae=self.refiner_vae, op=conv_op)
                 ch = nxt
             self.up.append(stage)

-        self.norm_out = RMS_norm(ch)
-        self.conv_out = VideoConv3d(ch, out_channels, 3)
+        self.norm_out = norm_op(ch)
+        self.conv_out = conv_op(ch, out_channels, 3, stride=1, padding=1)

     def forward(self, z):
-        z = z.permute(0, 2, 1, 3, 4)
-        b, f, c, h, w = z.shape
-        z = z.reshape(b, f, 2, c // 2, h, w)
-        z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w)
-        z = z.permute(0, 2, 1, 3, 4)
-        z = z[:, :, 1:]
+        if self.refiner_vae:
+            z = z.permute(0, 2, 1, 3, 4)
+            b, f, c, h, w = z.shape
+            z = z.reshape(b, f, 2, c // 2, h, w)
+            z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w)
+            z = z.permute(0, 2, 1, 3, 4)
+            z = z[:, :, 1:]

         x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
         x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
@@ -264,4 +292,10 @@ class Decoder(nn.Module):
             if hasattr(stage, 'upsample'):
                 x = stage.upsample(x)

-        return self.conv_out(F.silu(self.norm_out(x)))
+        out = self.conv_out(F.silu(self.norm_out(x)))
+
+        if not self.refiner_vae:
+            if z.shape[-3] == 1:
+                out = out[:, :, -1:]
+
+        return out
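
Note (not part of the diff): on the non-refiner path the decoder mirrors the encoder's single-frame trick. A one-frame latent still decodes through the temporal upsampling stages (plausibly to `ffactor_temporal` frames), and only the last frame is kept. Hypothetical shapes:

    import torch

    z = torch.randn(1, 32, 1, 8, 8)        # single-frame latent
    out = torch.randn(1, 3, 4, 128, 128)   # decoded clip before trimming
    if z.shape[-3] == 1:
        out = out[:, :, -1:]
    print(out.shape)  # torch.Size([1, 3, 1, 128, 128])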
comfy/sd.py
@@ -332,35 +332,51 @@ class VAE:
             self.first_stage_model = StageC_coder()
             self.downscale_ratio = 32
             self.latent_channels = 16
-        elif "decoder.conv_in.weight" in sd and sd['decoder.conv_in.weight'].shape[1] == 64:
-            ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
-            self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
-            self.downscale_ratio = 32
-            self.upscale_ratio = 32
-            self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
-            self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
-                                                        encoder_config={'target': "comfy.ldm.hunyuan_video.vae.Encoder", 'params': ddconfig},
-                                                        decoder_config={'target': "comfy.ldm.hunyuan_video.vae.Decoder", 'params': ddconfig})
-
-            self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
-            self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype)
-
         elif "decoder.conv_in.weight" in sd:
-            #default SD1.x/SD2.x VAE parameters
-            ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
-
-            if 'encoder.down.2.downsample.conv.weight' not in sd and 'decoder.up.3.upsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE
-                ddconfig['ch_mult'] = [1, 2, 4]
-                self.downscale_ratio = 4
-                self.upscale_ratio = 4
-
-            self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
-            if 'post_quant_conv.weight' in sd:
-                self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
-            else:
-                self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
-                                                            encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig},
-                                                            decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig})
+            if sd['decoder.conv_in.weight'].shape[1] == 64:
+                ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
+                self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
+                self.downscale_ratio = 32
+                self.upscale_ratio = 32
+                self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+                self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
+                                                            encoder_config={'target': "comfy.ldm.hunyuan_video.vae.Encoder", 'params': ddconfig},
+                                                            decoder_config={'target': "comfy.ldm.hunyuan_video.vae.Decoder", 'params': ddconfig})
+
+                self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
+                self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype)
+            elif sd['decoder.conv_in.weight'].shape[1] == 32:
+                ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True, "refiner_vae": False}
+                self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
+                self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+                self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
+                self.upscale_index_formula = (4, 16, 16)
+                self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16)
+                self.downscale_index_formula = (4, 16, 16)
+                self.latent_dim = 3
+                self.not_video = True
+                self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
+                                                            encoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Encoder", 'params': ddconfig},
+                                                            decoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Decoder", 'params': ddconfig})
+
+                self.memory_used_encode = lambda shape, dtype: (2800 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)
+                self.memory_used_decode = lambda shape, dtype: (2800 * shape[-3] * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
+            else:
+                #default SD1.x/SD2.x VAE parameters
+                ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
+
+                if 'encoder.down.2.downsample.conv.weight' not in sd and 'decoder.up.3.upsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE
+                    ddconfig['ch_mult'] = [1, 2, 4]
+                    self.downscale_ratio = 4
+                    self.upscale_ratio = 4
+
+                self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
+                if 'post_quant_conv.weight' in sd:
+                    self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
+                else:
+                    self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
+                                                                encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig},
+                                                                decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig})
         elif "decoder.layers.1.layers.0.beta" in sd:
             self.first_stage_model = AudioOobleckVAE()
             self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype)
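
Note (not part of the diff): the `(lambda, 16, 16)` tuples registered for the 32-channel case map sizes per axis: the lambda handles the temporal axis and the 16s the two spatial axes. The formulas below are copied from the diff; the sample frame counts are arbitrary:

    import math

    upscale   = lambda a: max(0, a * 4 - 3)                # latent frames -> pixel frames
    downscale = lambda a: max(0, math.floor((a + 3) / 4))  # pixel frames -> latent frames

    for frames in (1, 5, 9):
        lat = downscale(frames)
        print(frames, "->", lat, "->", upscale(lat))       # 1->1->1, 5->2->5, 9->3->9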