diff --git a/comfy/ldm/hunyuan_video/vae_refiner.py b/comfy/ldm/hunyuan_video/vae_refiner.py index c6f742710..c2a0b507d 100644 --- a/comfy/ldm/hunyuan_video/vae_refiner.py +++ b/comfy/ldm/hunyuan_video/vae_refiner.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d +from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d, Normalize import comfy.ops import comfy.ldm.models.autoencoder ops = comfy.ops.disable_weight_init @@ -17,11 +17,12 @@ class RMS_norm(nn.Module): return F.normalize(x, dim=1) * self.scale * self.gamma class DnSmpl(nn.Module): - def __init__(self, ic, oc, tds=True): + def __init__(self, ic, oc, tds=True, refiner_vae=True, op=VideoConv3d): super().__init__() fct = 2 * 2 * 2 if tds else 1 * 2 * 2 assert oc % fct == 0 - self.conv = VideoConv3d(ic, oc // fct, kernel_size=3) + self.conv = op(ic, oc // fct, kernel_size=3, stride=1, padding=1) + self.refiner_vae = refiner_vae self.tds = tds self.gs = fct * ic // oc @@ -30,7 +31,7 @@ class DnSmpl(nn.Module): r1 = 2 if self.tds else 1 h = self.conv(x) - if self.tds: + if self.tds and self.refiner_vae: hf = h[:, :, :1, :, :] b, c, f, ht, wd = hf.shape hf = hf.reshape(b, c, f, ht // 2, 2, wd // 2, 2) @@ -66,6 +67,7 @@ class DnSmpl(nn.Module): sc = torch.cat([xf, xn], dim=2) else: b, c, frms, ht, wd = h.shape + nf = frms // r1 h = h.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2) h = h.permute(0, 3, 5, 7, 1, 2, 4, 6) @@ -83,10 +85,11 @@ class DnSmpl(nn.Module): class UpSmpl(nn.Module): - def __init__(self, ic, oc, tus=True): + def __init__(self, ic, oc, tus=True, refiner_vae=True, op=VideoConv3d): super().__init__() fct = 2 * 2 * 2 if tus else 1 * 2 * 2 - self.conv = VideoConv3d(ic, oc * fct, kernel_size=3) + self.conv = op(ic, oc * fct, kernel_size=3, stride=1, padding=1) + self.refiner_vae = refiner_vae self.tus = tus self.rp = fct * oc // ic @@ -95,7 +98,7 @@ class UpSmpl(nn.Module): r1 = 2 if self.tus else 1 h = self.conv(x) - if self.tus: + if self.tus and self.refiner_vae: hf = h[:, :, :1, :, :] b, c, f, ht, wd = hf.shape nc = c // (2 * 2) @@ -148,43 +151,56 @@ class UpSmpl(nn.Module): class Encoder(nn.Module): def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks, - ffactor_spatial, ffactor_temporal, downsample_match_channel=True, **_): + ffactor_spatial, ffactor_temporal, downsample_match_channel=True, refiner_vae=True, **_): super().__init__() self.z_channels = z_channels self.block_out_channels = block_out_channels self.num_res_blocks = num_res_blocks - self.conv_in = VideoConv3d(in_channels, block_out_channels[0], 3, 1, 1) + self.ffactor_temporal = ffactor_temporal + + self.refiner_vae = refiner_vae + if self.refiner_vae: + conv_op = VideoConv3d + norm_op = RMS_norm + else: + conv_op = ops.Conv3d + norm_op = Normalize + + self.conv_in = conv_op(in_channels, block_out_channels[0], 3, 1, 1) self.down = nn.ModuleList() ch = block_out_channels[0] depth = (ffactor_spatial >> 1).bit_length() - depth_temporal = ((ffactor_spatial // ffactor_temporal) >> 1).bit_length() + depth_temporal = ((ffactor_spatial // self.ffactor_temporal) >> 1).bit_length() for i, tgt in enumerate(block_out_channels): stage = nn.Module() stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt, out_channels=tgt, temb_channels=0, - conv_op=VideoConv3d, norm_op=RMS_norm) + conv_op=conv_op, norm_op=norm_op) for j in range(num_res_blocks)]) ch = tgt if i < depth: nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and downsample_match_channel else ch - stage.downsample = DnSmpl(ch, nxt, tds=i >= depth_temporal) + stage.downsample = DnSmpl(ch, nxt, tds=i >= depth_temporal, refiner_vae=self.refiner_vae, op=conv_op) ch = nxt self.down.append(stage) self.mid = nn.Module() - self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm) - self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=RMS_norm) - self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm) + self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op) + self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op) + self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op) - self.norm_out = RMS_norm(ch) - self.conv_out = VideoConv3d(ch, z_channels << 1, 3, 1, 1) + self.norm_out = norm_op(ch) + self.conv_out = conv_op(ch, z_channels << 1, 3, 1, 1) self.regul = comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer() def forward(self, x): + if not self.refiner_vae and x.shape[2] == 1: + x = x.expand(-1, -1, self.ffactor_temporal, -1, -1) + x = self.conv_in(x) for stage in self.down: @@ -200,31 +216,42 @@ class Encoder(nn.Module): skip = x.view(b, c // grp, grp, t, h, w).mean(2) out = self.conv_out(F.silu(self.norm_out(x))) + skip - out = self.regul(out)[0] - out = torch.cat((out[:, :, :1], out), dim=2) - out = out.permute(0, 2, 1, 3, 4) - b, f_times_2, c, h, w = out.shape - out = out.reshape(b, f_times_2 // 2, 2 * c, h, w) - out = out.permute(0, 2, 1, 3, 4).contiguous() + if self.refiner_vae: + out = self.regul(out)[0] + + out = torch.cat((out[:, :, :1], out), dim=2) + out = out.permute(0, 2, 1, 3, 4) + b, f_times_2, c, h, w = out.shape + out = out.reshape(b, f_times_2 // 2, 2 * c, h, w) + out = out.permute(0, 2, 1, 3, 4).contiguous() + return out class Decoder(nn.Module): def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks, - ffactor_spatial, ffactor_temporal, upsample_match_channel=True, **_): + ffactor_spatial, ffactor_temporal, upsample_match_channel=True, refiner_vae=True, **_): super().__init__() block_out_channels = block_out_channels[::-1] self.z_channels = z_channels self.block_out_channels = block_out_channels self.num_res_blocks = num_res_blocks + self.refiner_vae = refiner_vae + if self.refiner_vae: + conv_op = VideoConv3d + norm_op = RMS_norm + else: + conv_op = ops.Conv3d + norm_op = Normalize + ch = block_out_channels[0] - self.conv_in = VideoConv3d(z_channels, ch, 3) + self.conv_in = conv_op(z_channels, ch, kernel_size=3, stride=1, padding=1) self.mid = nn.Module() - self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm) - self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=RMS_norm) - self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm) + self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op) + self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op) + self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op) self.up = nn.ModuleList() depth = (ffactor_spatial >> 1).bit_length() @@ -235,25 +262,26 @@ class Decoder(nn.Module): stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt, out_channels=tgt, temb_channels=0, - conv_op=VideoConv3d, norm_op=RMS_norm) + conv_op=conv_op, norm_op=norm_op) for j in range(num_res_blocks + 1)]) ch = tgt if i < depth: nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and upsample_match_channel else ch - stage.upsample = UpSmpl(ch, nxt, tus=i < depth_temporal) + stage.upsample = UpSmpl(ch, nxt, tus=i < depth_temporal, refiner_vae=self.refiner_vae, op=conv_op) ch = nxt self.up.append(stage) - self.norm_out = RMS_norm(ch) - self.conv_out = VideoConv3d(ch, out_channels, 3) + self.norm_out = norm_op(ch) + self.conv_out = conv_op(ch, out_channels, 3, stride=1, padding=1) def forward(self, z): - z = z.permute(0, 2, 1, 3, 4) - b, f, c, h, w = z.shape - z = z.reshape(b, f, 2, c // 2, h, w) - z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w) - z = z.permute(0, 2, 1, 3, 4) - z = z[:, :, 1:] + if self.refiner_vae: + z = z.permute(0, 2, 1, 3, 4) + b, f, c, h, w = z.shape + z = z.reshape(b, f, 2, c // 2, h, w) + z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w) + z = z.permute(0, 2, 1, 3, 4) + z = z[:, :, 1:] x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1) x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x))) @@ -264,4 +292,10 @@ class Decoder(nn.Module): if hasattr(stage, 'upsample'): x = stage.upsample(x) - return self.conv_out(F.silu(self.norm_out(x))) + out = self.conv_out(F.silu(self.norm_out(x))) + + if not self.refiner_vae: + if z.shape[-3] == 1: + out = out[:, :, -1:] + + return out diff --git a/comfy/sd.py b/comfy/sd.py index 2df340739..873ad20f2 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -332,35 +332,51 @@ class VAE: self.first_stage_model = StageC_coder() self.downscale_ratio = 32 self.latent_channels = 16 - elif "decoder.conv_in.weight" in sd and sd['decoder.conv_in.weight'].shape[1] == 64: - ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True} - self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1] - self.downscale_ratio = 32 - self.upscale_ratio = 32 - self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] - self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"}, - encoder_config={'target': "comfy.ldm.hunyuan_video.vae.Encoder", 'params': ddconfig}, - decoder_config={'target': "comfy.ldm.hunyuan_video.vae.Decoder", 'params': ddconfig}) - - self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype) - self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype) - elif "decoder.conv_in.weight" in sd: - #default SD1.x/SD2.x VAE parameters - ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} - - if 'encoder.down.2.downsample.conv.weight' not in sd and 'decoder.up.3.upsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE - ddconfig['ch_mult'] = [1, 2, 4] - self.downscale_ratio = 4 - self.upscale_ratio = 4 - - self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1] - if 'post_quant_conv.weight' in sd: - self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1]) - else: + if sd['decoder.conv_in.weight'].shape[1] == 64: + ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True} + self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1] + self.downscale_ratio = 32 + self.upscale_ratio = 32 + self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"}, - encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig}, - decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig}) + encoder_config={'target': "comfy.ldm.hunyuan_video.vae.Encoder", 'params': ddconfig}, + decoder_config={'target': "comfy.ldm.hunyuan_video.vae.Decoder", 'params': ddconfig}) + + self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype) + self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype) + elif sd['decoder.conv_in.weight'].shape[1] == 32: + ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True, "refiner_vae": False} + self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1] + self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] + self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16) + self.upscale_index_formula = (4, 16, 16) + self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16) + self.downscale_index_formula = (4, 16, 16) + self.latent_dim = 3 + self.not_video = True + self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"}, + encoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Encoder", 'params': ddconfig}, + decoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Decoder", 'params': ddconfig}) + + self.memory_used_encode = lambda shape, dtype: (2800 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype) + self.memory_used_decode = lambda shape, dtype: (2800 * shape[-3] * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype) + else: + #default SD1.x/SD2.x VAE parameters + ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} + + if 'encoder.down.2.downsample.conv.weight' not in sd and 'decoder.up.3.upsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE + ddconfig['ch_mult'] = [1, 2, 4] + self.downscale_ratio = 4 + self.upscale_ratio = 4 + + self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1] + if 'post_quant_conv.weight' in sd: + self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1]) + else: + self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"}, + encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig}, + decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig}) elif "decoder.layers.1.layers.0.beta" in sd: self.first_stage_model = AudioOobleckVAE() self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype)