diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py
index 77e642a94..dfcf9ef27 100644
--- a/comfy/latent_formats.py
+++ b/comfy/latent_formats.py
@@ -611,6 +611,11 @@ class HunyuanImage21Refiner(LatentFormat):
     latent_dimensions = 3
     scale_factor = 1.03682
 
+class HunyuanVideo15(LatentFormat):
+    latent_channels = 32
+    latent_dimensions = 3
+    scale_factor = 1.03682
+
 class Hunyuan3Dv2(LatentFormat):
     latent_channels = 64
     latent_dimensions = 1
diff --git a/comfy/ldm/hunyuan_video/model.py b/comfy/ldm/hunyuan_video/model.py
index 5132e6c07..b6ec421fe 100644
--- a/comfy/ldm/hunyuan_video/model.py
+++ b/comfy/ldm/hunyuan_video/model.py
@@ -42,6 +42,8 @@ class HunyuanVideoParams:
     guidance_embed: bool
     byt5: bool
     meanflow: bool
+    use_cond_type_embedding: bool
+    vision_in_dim: int
 
 
 class SelfAttentionRef(nn.Module):
@@ -196,11 +198,16 @@ class HunyuanVideo(nn.Module):
     def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
         super().__init__()
         self.dtype = dtype
+        operation_settings = {"operations": operations, "device": device, "dtype": dtype}
+
         params = HunyuanVideoParams(**kwargs)
+        print("HunyuanVideo params:", params)
         self.params = params
         self.patch_size = params.patch_size
         self.in_channels = params.in_channels
         self.out_channels = params.out_channels
+        self.use_cond_type_embedding = params.use_cond_type_embedding
+        self.vision_in_dim = params.vision_in_dim
         if params.hidden_size % params.num_heads != 0:
             raise ValueError(
                 f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
             )
@@ -266,6 +273,18 @@ class HunyuanVideo(nn.Module):
         if final_layer:
             self.final_layer = LastLayer(self.hidden_size, self.patch_size[-1], self.out_channels, dtype=dtype, device=device, operations=operations)
 
+        # HunyuanVideo 1.5 specific modules
+        if self.vision_in_dim is not None:
+            from comfy.ldm.wan.model import MLPProj  # todo move
+            self.vision_in = MLPProj(in_dim=self.vision_in_dim, out_dim=self.hidden_size, operation_settings=operation_settings)
+        else:
+            self.vision_in = None
+        if self.use_cond_type_embedding:
+            # 0: text_encoder feature 1: byt5 feature 2: vision_encoder feature
+            self.cond_type_embedding = nn.Embedding(3, self.hidden_size)
+        else:
+            self.cond_type_embedding = None
+
     def forward_orig(
         self,
         img: Tensor,
@@ -337,6 +356,44 @@ class HunyuanVideo(nn.Module):
             txt = torch.cat((txt, txt_byt5), dim=1)
             txt_ids = torch.cat((txt_ids, txt_byt5_ids), dim=1)
 
+        # if self.cond_type_embedding is not None:
+        #     self.cond_type_embedding.to(txt.device)
+        #     cond_emb = self.cond_type_embedding(torch.zeros_like(txt[:, :, 0], device=txt.device, dtype=torch.long))
+        #     txt = txt + cond_emb.to(txt.dtype)
+
+        # if txt_byt5 is None:
+        #     txt_byt5 = torch.zeros((1, 1000, 1472), device=txt.device, dtype=txt.dtype)
+        # if self.byt5_in is not None and txt_byt5 is not None:
+        #     txt_byt5 = self.byt5_in(txt_byt5)
+        #     if self.cond_type_embedding is not None:
+        #         cond_emb = self.cond_type_embedding(torch.ones_like(txt_byt5[:, :, 0], device=txt_byt5.device, dtype=torch.long))
+        #         txt_byt5 = txt_byt5 + cond_emb.to(txt_byt5.dtype)
+        #     txt_byt5_ids = torch.zeros((txt_ids.shape[0], txt_byt5.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
+        #     #txt = torch.cat((txt, txt_byt5), dim=1)
+        #     #txt_ids = torch.cat((txt_ids, txt_byt5_ids), dim=1)
+        #     print("txt_byt5 shape:", txt_byt5.shape)
+        #     print("txt shape:", txt.shape)
+        #     txt = torch.cat((txt_byt5, txt), dim=1)
+        #     txt_ids = torch.cat((txt_byt5_ids, txt_ids), dim=1)
+
+        # vision_states = torch.zeros(img.shape[0], 729, self.vision_in_dim, device=img.device, dtype=img.dtype)
+        # if self.cond_type_embedding is not None:
+        #     extra_encoder_hidden_states = self.vision_in(vision_states)
+        #     extra_encoder_hidden_states = extra_encoder_hidden_states * 0.0 #t2v
+        #     cond_emb = self.cond_type_embedding(
+        #         2 * torch.ones_like(
+        #             extra_encoder_hidden_states[:, :, 0],
+        #             dtype=torch.long,
+        #             device=extra_encoder_hidden_states.device,
+        #         )
+        #     )
+        #     extra_encoder_hidden_states = extra_encoder_hidden_states + cond_emb
+        #     print("extra_encoder_hidden_states shape:", extra_encoder_hidden_states.shape)
+        #     txt = torch.cat((extra_encoder_hidden_states.to(txt.dtype), txt), dim=1)
+
+        #     extra_txt_ids = torch.zeros((txt_ids.shape[0], extra_encoder_hidden_states.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
+        #     txt_ids = torch.cat((extra_txt_ids, txt_ids), dim=1)
+
         ids = torch.cat((img_ids, txt_ids), dim=1)
         pe = self.pe_embedder(ids)
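Note: the commented-out block above appears to assemble the 1.5 conditioning sequence by adding the learned cond_type_embedding to each stream (0 = text encoder, 1 = byt5, 2 = vision encoder) and concatenating the vision and byt5 tokens in front of the plain text tokens. A minimal standalone sketch of that layout follows; the hidden size, sequence lengths and the tag() helper are illustrative assumptions, not part of this patch.

import torch
import torch.nn as nn

hidden_size = 3072  # illustrative value, not taken from this patch
cond_type_embedding = nn.Embedding(3, hidden_size)

def tag(tokens, cond_type):
    # tokens: [batch, seq, hidden]; add the per-stream condition-type embedding
    idx = torch.full(tokens.shape[:2], cond_type, dtype=torch.long, device=tokens.device)
    return tokens + cond_type_embedding(idx).to(tokens.dtype)

txt = tag(torch.randn(1, 256, hidden_size), 0)      # text encoder features
byt5 = tag(torch.randn(1, 1000, hidden_size), 1)    # byt5 features after byt5_in
vision = tag(torch.randn(1, 729, hidden_size), 2)   # vision features after vision_in (zeroed for t2v)
txt = torch.cat((vision, byt5, txt), dim=1)         # vision, then byt5, then text, as in the commented code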
diff --git a/comfy/ldm/hunyuan_video/vae_refiner.py b/comfy/ldm/hunyuan_video/vae_refiner.py
index c2a0b507d..aab56ca6c 100644
--- a/comfy/ldm/hunyuan_video/vae_refiner.py
+++ b/comfy/ldm/hunyuan_video/vae_refiner.py
@@ -220,11 +220,12 @@ class Encoder(nn.Module):
 
         if self.refiner_vae:
             out = self.regul(out)[0]
-            out = torch.cat((out[:, :, :1], out), dim=2)
-            out = out.permute(0, 2, 1, 3, 4)
-            b, f_times_2, c, h, w = out.shape
-            out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)
-            out = out.permute(0, 2, 1, 3, 4).contiguous()
+            # todo don't break this
+            # out = torch.cat((out[:, :, :1], out), dim=2)
+            # out = out.permute(0, 2, 1, 3, 4)
+            # b, f_times_2, c, h, w = out.shape
+            # out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)
+            # out = out.permute(0, 2, 1, 3, 4).contiguous()
 
         return out
 
@@ -275,13 +276,15 @@ class Decoder(nn.Module):
         self.conv_out = conv_op(ch, out_channels, 3, stride=1, padding=1)
 
     def forward(self, z):
-        if self.refiner_vae:
-            z = z.permute(0, 2, 1, 3, 4)
-            b, f, c, h, w = z.shape
-            z = z.reshape(b, f, 2, c // 2, h, w)
-            z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w)
-            z = z.permute(0, 2, 1, 3, 4)
-            z = z[:, :, 1:]
+
+        # todo don't break this
+        # if self.refiner_vae:
+        #     z = z.permute(0, 2, 1, 3, 4)
+        #     b, f, c, h, w = z.shape
+        #     z = z.reshape(b, f, 2, c // 2, h, w)
+        #     z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w)
+        #     z = z.permute(0, 2, 1, 3, 4)
+        #     z = z[:, :, 1:]
 
         x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
         x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 7c788d085..9549693b3 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -1536,3 +1536,36 @@ class HunyuanImage21Refiner(HunyuanImage21):
         out = super().extra_conds(**kwargs)
         out['disable_time_r'] = comfy.conds.CONDConstant(True)
         return out
+
+class HunyuanVideo15(HunyuanImage21):
+    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+        super().__init__(model_config, model_type, device=device)
+
+    def concat_cond(self, **kwargs):
+        noise = kwargs.get("noise", None)
+        extra_channels = self.diffusion_model.img_in.proj.weight.shape[1] - noise.shape[1] - 1  #noise 32 img cond 32 + mask 1
+        if extra_channels == 0:
+            return None
+
+        image = kwargs.get("concat_latent_image", None)
+        device = kwargs["device"]
+
+        if image is None:
+            shape_image = list(noise.shape)
+            shape_image[1] = extra_channels
+            image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device)
+        else:
+            latent_dim = self.latent_format.latent_channels
+            image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
+            for i in range(0, image.shape[1], latent_dim):
+                image[:, i: i + latent_dim] = self.process_latent_in(image[:, i: i + latent_dim])
+            image = utils.resize_to_batch_size(image, noise.shape[0])
+
+        mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
+        if mask is None:
+            mask = torch.zeros_like(noise)[:, :1]
+        else:
+            mask = torch.zeros_like(noise)[:, :1]
+            mask[:, :, 1:] = 1.0
+
+        return torch.cat((image, mask), dim=1)
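As a quick shape check of what the new concat_cond feeds into the 65-channel img_in projection (per the inline comment: 32 noise + 32 image-conditioning + 1 mask channels). The latent dimensions below are made up for illustration.

import torch

b, c, t, h, w = 1, 32, 9, 30, 40                 # example latent shape, illustrative only
noise = torch.zeros(b, c, t, h, w)
image_cond = torch.zeros_like(noise)             # zeros when no concat_latent_image is provided
mask = torch.zeros(b, 1, t, h, w)                # mask[:, :, 1:] is set to 1.0 in the image branch
concat = torch.cat((image_cond, mask), dim=1)    # what concat_cond returns
model_in = torch.cat((noise, concat), dim=1)
assert model_in.shape[1] == 65                   # matches "in_channels": 65 in supported_models.py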
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index 3142a7fc3..f3355da5a 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -186,6 +186,16 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
 
         guidance_keys = list(filter(lambda a: a.startswith("{}guidance_in.".format(key_prefix)), state_dict_keys))
         dit_config["guidance_embed"] = len(guidance_keys) > 0
+
+        # HunyuanVideo 1.5
+        if '{}cond_type_embedding.weight'.format(key_prefix) in state_dict_keys:
+            dit_config["use_cond_type_embedding"] = True
+        else:
+            dit_config["use_cond_type_embedding"] = False
+        if '{}vision_in.proj.0.weight'.format(key_prefix) in state_dict_keys:
+            dit_config["vision_in_dim"] = state_dict['{}vision_in.proj.0.weight'.format(key_prefix)].shape[0]
+        else:
+            dit_config["vision_in_dim"] = None
         return dit_config
 
     if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index 4064bdae1..d9807251b 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -1373,7 +1373,35 @@ class HunyuanImage21Refiner(HunyuanVideo):
     def get_model(self, state_dict, prefix="", device=None):
         out = model_base.HunyuanImage21Refiner(self, device=device)
         return out
+
+class HunyuanVideo15(HunyuanVideo):
+    unet_config = {
+        "image_model": "hunyuan_video",
+        "patch_size": [1, 1, 1],
+        "in_channels": 65,
+        "out_channels": 32,
+        "depth": 54,
+        "vision_in_dim": 1152,
+    }
 
-models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage]
+    sampling_settings = {
+        "shift": 7.0,
+    }
+    memory_usage_factor = 7.7
+    supported_inference_dtypes = [torch.bfloat16, torch.float32]
+
+    latent_format = latent_formats.HunyuanVideo15
+
+    def get_model(self, state_dict, prefix="", device=None):
+        print("HunyuanVideo15")
+        out = model_base.HunyuanVideo15(self, device=device)
+        return out
+
+    def clip_target(self, state_dict={}):
+        pref = self.text_encoder_key_prefix[0]
+        hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
+        return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect))
+
+models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage]
 
 models += [SVD_img2vid]
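The detection added in model_detection.py keys off two tensors in the checkpoint: the presence of cond_type_embedding.weight toggles use_cond_type_embedding, and the leading dimension of vision_in.proj.0.weight is read back as vision_in_dim. Roughly, with a toy state dict (key names follow the patch; tensor widths are assumed):

import torch

state_dict = {
    "cond_type_embedding.weight": torch.zeros(3, 3072),  # 3 condition types; width assumed
    "vision_in.proj.0.weight": torch.zeros(1152),         # leading dim read back as vision_in_dim
}
dit_config = {}
dit_config["use_cond_type_embedding"] = "cond_type_embedding.weight" in state_dict
w = state_dict.get("vision_in.proj.0.weight")
dit_config["vision_in_dim"] = w.shape[0] if w is not None else None
print(dit_config)  # {'use_cond_type_embedding': True, 'vision_in_dim': 1152}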
diff --git a/comfy_extras/nodes_hunyuan.py b/comfy_extras/nodes_hunyuan.py
index f7c34d059..6a8dfdb7c 100644
--- a/comfy_extras/nodes_hunyuan.py
+++ b/comfy_extras/nodes_hunyuan.py
@@ -57,6 +57,22 @@ class EmptyHunyuanLatentVideo(io.ComfyNode):
 
     generate = execute # TODO: remove
 
+class EmptyHunyuanVideo15Latent(EmptyHunyuanLatentVideo):
+    @classmethod
+    def define_schema(cls):
+        schema = super().define_schema()
+        schema.node_id = "EmptyHunyuanVideo15Latent"
+        return schema
+
+    @classmethod
+    def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput:
+        # 32 latent channels, using spatial scale factor of 16 instead of 8
+        latent = torch.zeros([batch_size, 32, ((length - 1) // 4) + 1, height // 16, width // 16], device=comfy.model_management.intermediate_device())
+        return io.NodeOutput({"samples": latent})
+
+    generate = execute # TODO: remove
+
+
 PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = (
     "<|start_header_id|>system<|end_header_id|>\n\n\nDescribe the video by detailing the following aspects according to the reference image: "
     "1. The main content and theme of the video."
@@ -210,6 +226,7 @@ class HunyuanExtension(ComfyExtension):
             CLIPTextEncodeHunyuanDiT,
             TextEncodeHunyuanVideo_ImageToVideo,
             EmptyHunyuanLatentVideo,
+            EmptyHunyuanVideo15Latent,
             HunyuanImageToVideo,
             EmptyHunyuanImageLatent,
             HunyuanRefinerLatent,
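For a sense of the latent geometry EmptyHunyuanVideo15Latent produces (16x spatial compression, 4x temporal compression on length - 1 plus the first frame, and 32 channels per latent_formats.HunyuanVideo15), here is a small helper whose name is illustrative and which mirrors the torch.zeros(...) call above; the example resolution is arbitrary.

def hunyuan_video_15_latent_shape(width, height, length, batch_size=1, channels=32):
    # mirrors the shape used in EmptyHunyuanVideo15Latent.execute
    return [batch_size, channels, ((length - 1) // 4) + 1, height // 16, width // 16]

print(hunyuan_video_15_latent_shape(1280, 720, 121))  # [1, 32, 31, 45, 80]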