diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py
index dfcf9ef27..2d0884898 100644
--- a/comfy/latent_formats.py
+++ b/comfy/latent_formats.py
@@ -612,6 +612,43 @@ class HunyuanImage21Refiner(LatentFormat):
     scale_factor = 1.03682
 
 class HunyuanVideo15(LatentFormat):
+    latent_rgb_factors = [ #placeholder values todo: replace with proper ones
+        [-0.0154, -0.0397, -0.0521],
+        [ 0.0005, 0.0093, 0.0006],
+        [-0.0805, -0.0773, -0.0586],
+        [-0.0494, -0.0487, -0.0498],
+        [-0.0212, -0.0076, -0.0261],
+        [-0.0179, -0.0417, -0.0505],
+        [ 0.0158, 0.0310, 0.0239],
+        [ 0.0409, 0.0516, 0.0201],
+        [ 0.0350, 0.0553, 0.0036],
+        [-0.0447, -0.0327, -0.0479],
+        [-0.0038, -0.0221, -0.0365],
+        [-0.0423, -0.0718, -0.0654],
+        [ 0.0039, 0.0368, 0.0104],
+        [ 0.0655, 0.0217, 0.0122],
+        [ 0.0490, 0.1638, 0.2053],
+        [ 0.0932, 0.0829, 0.0650],
+        [-0.0186, -0.0209, -0.0135],
+        [-0.0080, -0.0076, -0.0148],
+        [-0.0284, -0.0201, 0.0011],
+        [-0.0642, -0.0294, -0.0777],
+        [-0.0035, 0.0076, -0.0140],
+        [ 0.0519, 0.0731, 0.0887],
+        [-0.0102, 0.0095, 0.0704],
+        [ 0.0068, 0.0218, -0.0023],
+        [-0.0726, -0.0486, -0.0519],
+        [ 0.0260, 0.0295, 0.0263],
+        [ 0.0250, 0.0333, 0.0341],
+        [ 0.0168, -0.0120, -0.0174],
+        [ 0.0226, 0.1037, 0.0114],
+        [ 0.2577, 0.1906, 0.1604],
+        [-0.0646, -0.0137, -0.0018],
+        [-0.0646, -0.0137, -0.0018]
+    ]
+
+    latent_rgb_factors_bias = [0.0007, -0.0256, -0.0206]
+
     latent_channels = 32
     latent_dimensions = 3
     scale_factor = 1.03682
diff --git a/comfy/ldm/hunyuan_video/model.py b/comfy/ldm/hunyuan_video/model.py
index b6ec421fe..0fdec8659 100644
--- a/comfy/ldm/hunyuan_video/model.py
+++ b/comfy/ldm/hunyuan_video/model.py
@@ -6,7 +6,6 @@ import comfy.ldm.flux.layers
 import comfy.ldm.modules.diffusionmodules.mmdit
 from comfy.ldm.modules.attention import optimized_attention
 
-
 from dataclasses import dataclass
 from einops import repeat
 
@@ -295,6 +294,7 @@ class HunyuanVideo(nn.Module):
         timesteps: Tensor,
         y: Tensor = None,
         txt_byt5=None,
+        vision_states=None, #todo hunyuan video 1.5 vision encoder states input
         guidance: Tensor = None,
         guiding_frame_index=None,
         ref_latent=None,
@@ -352,47 +352,24 @@ class HunyuanVideo(nn.Module):
 
         if self.byt5_in is not None and txt_byt5 is not None:
             txt_byt5 = self.byt5_in(txt_byt5)
+            if self.cond_type_embedding is not None:
+                cond_emb = self.cond_type_embedding(torch.ones_like(txt_byt5[:, :, 0], device=txt_byt5.device, dtype=torch.long))
+                txt_byt5 = txt_byt5 + cond_emb.to(txt_byt5.dtype)
+                txt = torch.cat((txt_byt5, txt), dim=1) # byt5 first for HunyuanVideo1.5
+            else:
+                txt = torch.cat((txt, txt_byt5), dim=1)
             txt_byt5_ids = torch.zeros((txt_ids.shape[0], txt_byt5.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
-            txt = torch.cat((txt, txt_byt5), dim=1)
             txt_ids = torch.cat((txt_ids, txt_byt5_ids), dim=1)
 
-        # if self.cond_type_embedding is not None:
-        #     self.cond_type_embedding.to(txt.device)
-        #     cond_emb = self.cond_type_embedding(torch.zeros_like(txt[:, :, 0], device=txt.device, dtype=torch.long))
-        #     txt = txt + cond_emb.to(txt.dtype)
-
-        # if txt_byt5 is None:
-        #     txt_byt5 = torch.zeros((1, 1000, 1472), device=txt.device, dtype=txt.dtype)
-        # if self.byt5_in is not None and txt_byt5 is not None:
-        #     txt_byt5 = self.byt5_in(txt_byt5)
-        #     if self.cond_type_embedding is not None:
-        #         cond_emb = self.cond_type_embedding(torch.ones_like(txt_byt5[:, :, 0], device=txt_byt5.device, dtype=torch.long))
-        #         txt_byt5 = txt_byt5 + cond_emb.to(txt_byt5.dtype)
-        #     txt_byt5_ids = torch.zeros((txt_ids.shape[0], txt_byt5.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
-        #     #txt = torch.cat((txt, txt_byt5), dim=1)
-        #     #txt_ids = torch.cat((txt_ids, txt_byt5_ids), dim=1)
-        #     print("txt_byt5 shape:", txt_byt5.shape)
-        #     print("txt shape:", txt.shape)
-        #     txt = torch.cat((txt_byt5, txt), dim=1)
-        #     txt_ids = torch.cat((txt_byt5_ids, txt_ids), dim=1)
-
-        # vision_states = torch.zeros(img.shape[0], 729, self.vision_in_dim, device=img.device, dtype=img.dtype)
-        # if self.cond_type_embedding is not None:
-        #     extra_encoder_hidden_states = self.vision_in(vision_states)
-        #     extra_encoder_hidden_states = extra_encoder_hidden_states * 0.0 #t2v
-        #     cond_emb = self.cond_type_embedding(
-        #         2 * torch.ones_like(
-        #             extra_encoder_hidden_states[:, :, 0],
-        #             dtype=torch.long,
-        #             device=extra_encoder_hidden_states.device,
-        #         )
-        #     )
-        #     extra_encoder_hidden_states = extra_encoder_hidden_states + cond_emb
-        #     print("extra_encoder_hidden_states shape:", extra_encoder_hidden_states.shape)
-        #     txt = torch.cat((extra_encoder_hidden_states.to(txt.dtype), txt), dim=1)
-
-        #     extra_txt_ids = torch.zeros((txt_ids.shape[0], extra_encoder_hidden_states.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
-        #     txt_ids = torch.cat((extra_txt_ids, txt_ids), dim=1)
+        #todo vision_in
+        if self.cond_type_embedding is not None and vision_states is not None:
+            txt_vision_states = self.vision_in(vision_states)
+            cond_emb = self.cond_type_embedding(2 * torch.ones_like(txt_vision_states[:, :, 0], dtype=torch.long, device=txt_vision_states.device))
+            txt_vision_states = txt_vision_states + cond_emb
+            #print("txt_vision_states shape:", txt_vision_states.shape)
+            txt = torch.cat((txt_vision_states.to(txt.dtype), txt), dim=1)
+            extra_txt_ids = torch.zeros((txt_ids.shape[0], txt_vision_states.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
+            txt_ids = torch.cat((txt_ids, extra_txt_ids), dim=1)
 
         ids = torch.cat((img_ids, txt_ids), dim=1)
         pe = self.pe_embedder(ids)
@@ -487,14 +464,14 @@ class HunyuanVideo(nn.Module):
             img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
         return repeat(img_ids, "h w c -> b (h w) c", b=bs)
 
-    def forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
+    def forward(self, x, timestep, context, y=None, txt_byt5=None, vision_states=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
         return comfy.patcher_extension.WrapperExecutor.new_class_executor(
             self._forward,
             self,
             comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
-        ).execute(x, timestep, context, y, txt_byt5, guidance, attention_mask, guiding_frame_index, ref_latent, disable_time_r, control, transformer_options, **kwargs)
+        ).execute(x, timestep, context, y, txt_byt5, vision_states, guidance, attention_mask, guiding_frame_index, ref_latent, disable_time_r, control, transformer_options, **kwargs)
 
-    def _forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
+    def _forward(self, x, timestep, context, y=None, txt_byt5=None, vision_states=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
         bs = x.shape[0]
         if len(self.patch_size) == 3:
             img_ids = self.img_ids(x)
@@ -502,5 +479,5 @@ class HunyuanVideo(nn.Module):
         else:
             img_ids = self.img_ids_2d(x)
             txt_ids = torch.zeros((bs, context.shape[1], 2), device=x.device, dtype=x.dtype)
-        out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, guidance, guiding_frame_index, ref_latent, disable_time_r=disable_time_r, control=control, transformer_options=transformer_options)
+        out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, vision_states, guidance, guiding_frame_index, ref_latent, disable_time_r=disable_time_r, control=control, transformer_options=transformer_options)
         return out
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index d9807251b..215ebc047 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -1373,7 +1373,7 @@ class HunyuanImage21Refiner(HunyuanVideo):
     def get_model(self, state_dict, prefix="", device=None):
         out = model_base.HunyuanImage21Refiner(self, device=device)
         return out
-    
+
 class HunyuanVideo15(HunyuanVideo):
     unet_config = {
         "image_model": "hunyuan_video",
@@ -1387,7 +1387,7 @@
     sampling_settings = {
         "shift": 7.0,
     }
-    memory_usage_factor = 7.7
+    memory_usage_factor = 4.0 #TODO
 
     supported_inference_dtypes = [torch.bfloat16, torch.float32]
     latent_format = latent_formats.HunyuanVideo15
@@ -1396,7 +1396,7 @@ class HunyuanVideo15(HunyuanVideo):
         print("HunyuanVideo15")
         out = model_base.HunyuanVideo15(self, device=device)
         return out
-    
+
     def clip_target(self, state_dict={}):
         pref = self.text_encoder_key_prefix[0]
         hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
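
For context (not part of the patch): the latent_rgb_factors / latent_rgb_factors_bias pair added to HunyuanVideo15 follows the same convention as the other LatentFormat classes, a per-channel linear projection used to build cheap RGB previews from the 32-channel latent. A minimal sketch of that projection, with an illustrative helper name and shapes rather than ComfyUI's actual preview code:

import torch

def latent_to_rgb_preview(latent, factors, bias):
    # latent:  [B, C, ...spatial] tensor, e.g. [B, 32, T, H, W] for a video latent
    # factors: list of C rows of [r, g, b], i.e. a [C, 3] matrix
    # bias:    [r, g, b] offset added after the projection
    factors = torch.tensor(factors, dtype=latent.dtype, device=latent.device)
    bias = torch.tensor(bias, dtype=latent.dtype, device=latent.device)
    rgb = torch.movedim(latent, 1, -1) @ factors + bias  # channels last, contract C against [C, 3]
    return ((rgb + 1.0) / 2.0).clamp(0.0, 1.0)           # rough squash to [0, 1]; real previewers may scale differently

preview = latent_to_rgb_preview(torch.randn(1, 32, 4, 30, 52),
                                [[-0.0154, -0.0397, -0.0521]] * 32,  # stand-in rows; use the real 32-row table
                                [0.0007, -0.0256, -0.0206])
print(preview.shape)  # torch.Size([1, 4, 30, 52, 3])

The model.py change builds one text stream from up to three token sources, each offset by a learned condition-type embedding before concatenation (index 1 tags the byt5 tokens, index 2 the vision-encoder tokens, and the resulting order is vision, byt5, text). A minimal sketch of that tagging pattern, also outside the patch and with made-up sizes; the real cond_type_embedding, byt5_in and vision_in modules are constructed elsewhere in HunyuanVideo:

import torch
import torch.nn as nn

hidden = 3072                                  # assumed inner dimension
cond_type_embedding = nn.Embedding(3, hidden)  # the patch uses indices 1 (byt5) and 2 (vision)

def tag(tokens, cond_index):
    # Add the same condition-type vector to every token in the stream.
    idx = torch.full(tokens.shape[:2], cond_index, dtype=torch.long, device=tokens.device)
    return tokens + cond_type_embedding(idx).to(tokens.dtype)

txt = torch.randn(1, 256, hidden)                 # projected text-encoder tokens
txt_byt5 = tag(torch.randn(1, 64, hidden), 1)     # byt5 glyph tokens
txt_vision = tag(torch.randn(1, 729, hidden), 2)  # vision-encoder tokens

txt = torch.cat((txt_byt5, txt), dim=1)     # byt5 first, as in the new forward_orig
txt = torch.cat((txt_vision, txt), dim=1)   # vision states prepended when provided
print(txt.shape)  # torch.Size([1, 1049, 3072])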