diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index f975b5e11..894540879 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -606,6 +606,11 @@ class HunyuanImage21(LatentFormat): latent_rgb_factors_bias = [0.0007, -0.0256, -0.0206] +class HunyuanImage21Refiner(LatentFormat): + latent_channels = 64 + latent_dimensions = 3 + scale_factor = 1.03682 + class Hunyuan3Dv2(LatentFormat): latent_channels = 64 latent_dimensions = 1 diff --git a/comfy/ldm/hunyuan_video/model.py b/comfy/ldm/hunyuan_video/model.py index 7732182a4..ca86b8bb1 100644 --- a/comfy/ldm/hunyuan_video/model.py +++ b/comfy/ldm/hunyuan_video/model.py @@ -278,6 +278,7 @@ class HunyuanVideo(nn.Module): guidance: Tensor = None, guiding_frame_index=None, ref_latent=None, + disable_time_r=False, control=None, transformer_options={}, ) -> Tensor: @@ -288,7 +289,7 @@ class HunyuanVideo(nn.Module): img = self.img_in(img) vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype)) - if self.time_r_in is not None: + if (self.time_r_in is not None) and (not disable_time_r): w = torch.where(transformer_options['sigmas'][0] == transformer_options['sample_sigmas'])[0] # This most likely could be improved if len(w) > 0: timesteps_r = transformer_options['sample_sigmas'][w[0] + 1] @@ -428,14 +429,14 @@ class HunyuanVideo(nn.Module): img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) return repeat(img_ids, "h w c -> b (h w) c", b=bs) - def forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs): + def forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs): return comfy.patcher_extension.WrapperExecutor.new_class_executor( self._forward, self, comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) - ).execute(x, timestep, context, y, txt_byt5, guidance, attention_mask, guiding_frame_index, ref_latent, control, transformer_options, **kwargs) + ).execute(x, timestep, context, y, txt_byt5, guidance, attention_mask, guiding_frame_index, ref_latent, disable_time_r, control, transformer_options, **kwargs) - def _forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs): + def _forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs): bs = x.shape[0] if len(self.patch_size) == 3: img_ids = self.img_ids(x) @@ -443,5 +444,5 @@ class HunyuanVideo(nn.Module): else: img_ids = self.img_ids_2d(x) txt_ids = torch.zeros((bs, context.shape[1], 2), device=x.device, dtype=x.dtype) - out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options) + out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, guidance, guiding_frame_index, ref_latent, disable_time_r=disable_time_r, control=control, transformer_options=transformer_options) return out diff --git a/comfy/ldm/hunyuan_video/vae_refiner.py b/comfy/ldm/hunyuan_video/vae_refiner.py new file mode 100644 index 000000000..e3fff9bbe --- /dev/null +++ b/comfy/ldm/hunyuan_video/vae_refiner.py @@ -0,0 +1,268 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d +import comfy.ops +import comfy.ldm.models.autoencoder +ops = comfy.ops.disable_weight_init + +class RMS_norm(nn.Module): + def __init__(self, dim): + super().__init__() + shape = (dim, 1, 1, 1) + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.empty(shape)) + + def forward(self, x): + return F.normalize(x, dim=1) * self.scale * self.gamma + +class DnSmpl(nn.Module): + def __init__(self, ic, oc, tds=True): + super().__init__() + fct = 2 * 2 * 2 if tds else 1 * 2 * 2 + assert oc % fct == 0 + self.conv = VideoConv3d(ic, oc // fct, kernel_size=3) + + self.tds = tds + self.gs = fct * ic // oc + + def forward(self, x): + r1 = 2 if self.tds else 1 + h = self.conv(x) + + if self.tds: + hf = h[:, :, :1, :, :] + b, c, f, ht, wd = hf.shape + hf = hf.reshape(b, c, f, ht // 2, 2, wd // 2, 2) + hf = hf.permute(0, 4, 6, 1, 2, 3, 5) + hf = hf.reshape(b, 2 * 2 * c, f, ht // 2, wd // 2) + hf = torch.cat([hf, hf], dim=1) + + hn = h[:, :, 1:, :, :] + b, c, frms, ht, wd = hn.shape + nf = frms // r1 + hn = hn.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2) + hn = hn.permute(0, 3, 5, 7, 1, 2, 4, 6) + hn = hn.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2) + + h = torch.cat([hf, hn], dim=2) + + xf = x[:, :, :1, :, :] + b, ci, f, ht, wd = xf.shape + xf = xf.reshape(b, ci, f, ht // 2, 2, wd // 2, 2) + xf = xf.permute(0, 4, 6, 1, 2, 3, 5) + xf = xf.reshape(b, 2 * 2 * ci, f, ht // 2, wd // 2) + B, C, T, H, W = xf.shape + xf = xf.view(B, h.shape[1], self.gs // 2, T, H, W).mean(dim=2) + + xn = x[:, :, 1:, :, :] + b, ci, frms, ht, wd = xn.shape + nf = frms // r1 + xn = xn.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2) + xn = xn.permute(0, 3, 5, 7, 1, 2, 4, 6) + xn = xn.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2) + B, C, T, H, W = xn.shape + xn = xn.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2) + sc = torch.cat([xf, xn], dim=2) + else: + b, c, frms, ht, wd = h.shape + nf = frms // r1 + h = h.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2) + h = h.permute(0, 3, 5, 7, 1, 2, 4, 6) + h = h.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2) + + b, ci, frms, ht, wd = x.shape + nf = frms // r1 + sc = x.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2) + sc = sc.permute(0, 3, 5, 7, 1, 2, 4, 6) + sc = sc.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2) + B, C, T, H, W = sc.shape + sc = sc.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2) + + return h + sc + + +class UpSmpl(nn.Module): + def __init__(self, ic, oc, tus=True): + super().__init__() + fct = 2 * 2 * 2 if tus else 1 * 2 * 2 + self.conv = VideoConv3d(ic, oc * fct, kernel_size=3) + + self.tus = tus + self.rp = fct * oc // ic + + def forward(self, x): + r1 = 2 if self.tus else 1 + h = self.conv(x) + + if self.tus: + hf = h[:, :, :1, :, :] + b, c, f, ht, wd = hf.shape + nc = c // (2 * 2) + hf = hf.reshape(b, 2, 2, nc, f, ht, wd) + hf = hf.permute(0, 3, 4, 5, 1, 6, 2) + hf = hf.reshape(b, nc, f, ht * 2, wd * 2) + hf = hf[:, : hf.shape[1] // 2] + + hn = h[:, :, 1:, :, :] + b, c, frms, ht, wd = hn.shape + nc = c // (r1 * 2 * 2) + hn = hn.reshape(b, r1, 2, 2, nc, frms, ht, wd) + hn = hn.permute(0, 4, 5, 1, 6, 2, 7, 3) + hn = hn.reshape(b, nc, frms * r1, ht * 2, wd * 2) + + h = torch.cat([hf, hn], dim=2) + + xf = x[:, :, :1, :, :] + b, ci, f, ht, wd = xf.shape + xf = xf.repeat_interleave(repeats=self.rp // 2, dim=1) + b, c, f, ht, wd = xf.shape + nc = c // (2 * 2) + xf = xf.reshape(b, 2, 2, nc, f, ht, wd) + xf = xf.permute(0, 3, 4, 5, 1, 6, 2) + xf = xf.reshape(b, nc, f, ht * 2, wd * 2) + + xn = x[:, :, 1:, :, :] + xn = xn.repeat_interleave(repeats=self.rp, dim=1) + b, c, frms, ht, wd = xn.shape + nc = c // (r1 * 2 * 2) + xn = xn.reshape(b, r1, 2, 2, nc, frms, ht, wd) + xn = xn.permute(0, 4, 5, 1, 6, 2, 7, 3) + xn = xn.reshape(b, nc, frms * r1, ht * 2, wd * 2) + sc = torch.cat([xf, xn], dim=2) + else: + b, c, frms, ht, wd = h.shape + nc = c // (r1 * 2 * 2) + h = h.reshape(b, r1, 2, 2, nc, frms, ht, wd) + h = h.permute(0, 4, 5, 1, 6, 2, 7, 3) + h = h.reshape(b, nc, frms * r1, ht * 2, wd * 2) + + sc = x.repeat_interleave(repeats=self.rp, dim=1) + b, c, frms, ht, wd = sc.shape + nc = c // (r1 * 2 * 2) + sc = sc.reshape(b, r1, 2, 2, nc, frms, ht, wd) + sc = sc.permute(0, 4, 5, 1, 6, 2, 7, 3) + sc = sc.reshape(b, nc, frms * r1, ht * 2, wd * 2) + + return h + sc + +class Encoder(nn.Module): + def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks, + ffactor_spatial, ffactor_temporal, downsample_match_channel=True, **_): + super().__init__() + self.z_channels = z_channels + self.block_out_channels = block_out_channels + self.num_res_blocks = num_res_blocks + self.conv_in = VideoConv3d(in_channels, block_out_channels[0], 3, 1, 1) + + self.down = nn.ModuleList() + ch = block_out_channels[0] + depth = (ffactor_spatial >> 1).bit_length() + depth_temporal = ((ffactor_spatial // ffactor_temporal) >> 1).bit_length() + + for i, tgt in enumerate(block_out_channels): + stage = nn.Module() + stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt, + out_channels=tgt, + temb_channels=0, + conv_op=VideoConv3d, norm_op=RMS_norm) + for j in range(num_res_blocks)]) + ch = tgt + if i < depth: + nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and downsample_match_channel else ch + stage.downsample = DnSmpl(ch, nxt, tds=i >= depth_temporal) + ch = nxt + self.down.append(stage) + + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm) + self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=RMS_norm) + self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm) + + self.norm_out = RMS_norm(ch) + self.conv_out = VideoConv3d(ch, z_channels << 1, 3, 1, 1) + + self.regul = comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer() + + def forward(self, x): + x = x.unsqueeze(2) + x = self.conv_in(x) + + for stage in self.down: + for blk in stage.block: + x = blk(x) + if hasattr(stage, 'downsample'): + x = stage.downsample(x) + + x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x))) + + b, c, t, h, w = x.shape + grp = c // (self.z_channels << 1) + skip = x.view(b, c // grp, grp, t, h, w).mean(2) + + out = self.conv_out(F.silu(self.norm_out(x))) + skip + out = self.regul(out)[0] + + out = torch.cat((out[:, :, :1], out), dim=2) + out = out.permute(0, 2, 1, 3, 4) + b, f_times_2, c, h, w = out.shape + out = out.reshape(b, f_times_2 // 2, 2 * c, h, w) + out = out.permute(0, 2, 1, 3, 4).contiguous() + return out + +class Decoder(nn.Module): + def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks, + ffactor_spatial, ffactor_temporal, upsample_match_channel=True, **_): + super().__init__() + block_out_channels = block_out_channels[::-1] + self.z_channels = z_channels + self.block_out_channels = block_out_channels + self.num_res_blocks = num_res_blocks + + ch = block_out_channels[0] + self.conv_in = VideoConv3d(z_channels, ch, 3) + + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm) + self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=RMS_norm) + self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm) + + self.up = nn.ModuleList() + depth = (ffactor_spatial >> 1).bit_length() + depth_temporal = (ffactor_temporal >> 1).bit_length() + + for i, tgt in enumerate(block_out_channels): + stage = nn.Module() + stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt, + out_channels=tgt, + temb_channels=0, + conv_op=VideoConv3d, norm_op=RMS_norm) + for j in range(num_res_blocks + 1)]) + ch = tgt + if i < depth: + nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and upsample_match_channel else ch + stage.upsample = UpSmpl(ch, nxt, tus=i < depth_temporal) + ch = nxt + self.up.append(stage) + + self.norm_out = RMS_norm(ch) + self.conv_out = VideoConv3d(ch, out_channels, 3) + + def forward(self, z): + z = z.permute(0, 2, 1, 3, 4) + b, f, c, h, w = z.shape + z = z.reshape(b, f, 2, c // 2, h, w) + z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w) + z = z.permute(0, 2, 1, 3, 4) + z = z[:, :, 1:] + + x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1) + x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x))) + + for stage in self.up: + for blk in stage.block: + x = blk(x) + if hasattr(stage, 'upsample'): + x = stage.upsample(x) + + return self.conv_out(F.silu(self.norm_out(x))) diff --git a/comfy/ldm/models/autoencoder.py b/comfy/ldm/models/autoencoder.py index 13bd6e16b..611d36a1b 100644 --- a/comfy/ldm/models/autoencoder.py +++ b/comfy/ldm/models/autoencoder.py @@ -26,6 +26,12 @@ class DiagonalGaussianRegularizer(torch.nn.Module): z = posterior.mode() return z, None +class EmptyRegularizer(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: + return z, None class AbstractAutoencoder(torch.nn.Module): """ diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 8f598a848..4245eedca 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -145,7 +145,7 @@ class Downsample(nn.Module): class ResnetBlock(nn.Module): def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, - dropout=0.0, temb_channels=512, conv_op=ops.Conv2d): + dropout=0.0, temb_channels=512, conv_op=ops.Conv2d, norm_op=Normalize): super().__init__() self.in_channels = in_channels out_channels = in_channels if out_channels is None else out_channels @@ -153,7 +153,7 @@ class ResnetBlock(nn.Module): self.use_conv_shortcut = conv_shortcut self.swish = torch.nn.SiLU(inplace=True) - self.norm1 = Normalize(in_channels) + self.norm1 = norm_op(in_channels) self.conv1 = conv_op(in_channels, out_channels, kernel_size=3, @@ -162,7 +162,7 @@ class ResnetBlock(nn.Module): if temb_channels > 0: self.temb_proj = ops.Linear(temb_channels, out_channels) - self.norm2 = Normalize(out_channels) + self.norm2 = norm_op(out_channels) self.dropout = torch.nn.Dropout(dropout, inplace=True) self.conv2 = conv_op(out_channels, out_channels, @@ -305,11 +305,11 @@ def vae_attention(): return normal_attention class AttnBlock(nn.Module): - def __init__(self, in_channels, conv_op=ops.Conv2d): + def __init__(self, in_channels, conv_op=ops.Conv2d, norm_op=Normalize): super().__init__() self.in_channels = in_channels - self.norm = Normalize(in_channels) + self.norm = norm_op(in_channels) self.q = conv_op(in_channels, in_channels, kernel_size=1, diff --git a/comfy/model_base.py b/comfy/model_base.py index 993ff65e6..c69a9d1ad 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1432,3 +1432,23 @@ class HunyuanImage21(BaseModel): out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance])) return out + +class HunyuanImage21Refiner(HunyuanImage21): + def concat_cond(self, **kwargs): + noise = kwargs.get("noise", None) + image = kwargs.get("concat_latent_image", None) + device = kwargs["device"] + + if image is None: + shape_image = list(noise.shape) + image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device) + else: + image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") + image = self.process_latent_in(image) + image = utils.resize_to_batch_size(image, noise.shape[0]) + return image + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + out['disable_time_r'] = comfy.conds.CONDConstant(True) + return out diff --git a/comfy/sd.py b/comfy/sd.py index 9dd9a74d4..02ddc7239 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -285,6 +285,7 @@ class VAE: self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0) self.working_dtypes = [torch.bfloat16, torch.float32] self.disable_offload = False + self.not_video = False self.downscale_index_formula = None self.upscale_index_formula = None @@ -409,6 +410,20 @@ class VAE: self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32) self.downscale_index_formula = (8, 32, 32) self.working_dtypes = [torch.bfloat16, torch.float32] + elif "decoder.conv_in.conv.weight" in sd and sd['decoder.conv_in.conv.weight'].shape[1] == 32: + ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True} + self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1] + self.downscale_ratio = 16 + self.upscale_ratio = 16 + self.latent_dim = 3 + self.not_video = True + self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] + self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.EmptyRegularizer"}, + encoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Encoder", 'params': ddconfig}, + decoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Decoder", 'params': ddconfig}) + + self.memory_used_encode = lambda shape, dtype: (1400 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype) + self.memory_used_decode = lambda shape, dtype: (1400 * shape[-3] * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype) elif "decoder.conv_in.conv.weight" in sd: ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} ddconfig["conv3d"] = True @@ -669,7 +684,7 @@ class VAE: self.throw_exception_if_invalid() pixel_samples = self.vae_encode_crop_pixels(pixel_samples) pixel_samples = pixel_samples.movedim(-1, 1) - if self.latent_dim == 3 and pixel_samples.ndim < 5: + if not self.not_video and self.latent_dim == 3 and pixel_samples.ndim < 5: pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0) try: memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index aa953b462..ba1b8c313 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1321,6 +1321,23 @@ class HunyuanImage21(HunyuanVideo): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ACEStep, Omnigen2, QwenImage] +class HunyuanImage21Refiner(HunyuanVideo): + unet_config = { + "image_model": "hunyuan_video", + "patch_size": [1, 1, 1], + "vec_in_dim": None, + } + + sampling_settings = { + "shift": 1.0, + } + + latent_format = latent_formats.HunyuanImage21Refiner + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.HunyuanImage21Refiner(self, device=device) + return out + +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ACEStep, Omnigen2, QwenImage] models += [SVD_img2vid] diff --git a/comfy_extras/nodes_hunyuan.py b/comfy_extras/nodes_hunyuan.py index ce031ceb2..351a7e2cb 100644 --- a/comfy_extras/nodes_hunyuan.py +++ b/comfy_extras/nodes_hunyuan.py @@ -128,6 +128,28 @@ class EmptyHunyuanImageLatent: latent = torch.zeros([batch_size, 64, height // 32, width // 32], device=comfy.model_management.intermediate_device()) return ({"samples":latent}, ) +class HunyuanRefinerLatent: + @classmethod + def INPUT_TYPES(s): + return {"required": {"positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "latent": ("LATENT", ), + }} + + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + + FUNCTION = "execute" + + def execute(self, positive, negative, latent): + latent = latent["samples"] + + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": latent}) + negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": latent}) + out_latent = {} + out_latent["samples"] = torch.zeros([latent.shape[0], 32, latent.shape[-3], latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device()) + return (positive, negative, out_latent) + NODE_CLASS_MAPPINGS = { "CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT, @@ -135,4 +157,5 @@ NODE_CLASS_MAPPINGS = { "EmptyHunyuanLatentVideo": EmptyHunyuanLatentVideo, "HunyuanImageToVideo": HunyuanImageToVideo, "EmptyHunyuanImageLatent": EmptyHunyuanImageLatent, + "HunyuanRefinerLatent": HunyuanRefinerLatent, }