kijai 2025-11-16 00:44:51 +02:00 committed by comfyanonymous
parent cadd00226b
commit 4f242de56f
3 changed files with 60 additions and 46 deletions

View File

@@ -612,6 +612,43 @@ class HunyuanImage21Refiner(LatentFormat):
     scale_factor = 1.03682
 
 class HunyuanVideo15(LatentFormat):
+    latent_rgb_factors = [ #placeholder values todo: replace with proper ones
+        [-0.0154, -0.0397, -0.0521],
+        [ 0.0005, 0.0093, 0.0006],
+        [-0.0805, -0.0773, -0.0586],
+        [-0.0494, -0.0487, -0.0498],
+        [-0.0212, -0.0076, -0.0261],
+        [-0.0179, -0.0417, -0.0505],
+        [ 0.0158, 0.0310, 0.0239],
+        [ 0.0409, 0.0516, 0.0201],
+        [ 0.0350, 0.0553, 0.0036],
+        [-0.0447, -0.0327, -0.0479],
+        [-0.0038, -0.0221, -0.0365],
+        [-0.0423, -0.0718, -0.0654],
+        [ 0.0039, 0.0368, 0.0104],
+        [ 0.0655, 0.0217, 0.0122],
+        [ 0.0490, 0.1638, 0.2053],
+        [ 0.0932, 0.0829, 0.0650],
+        [-0.0186, -0.0209, -0.0135],
+        [-0.0080, -0.0076, -0.0148],
+        [-0.0284, -0.0201, 0.0011],
+        [-0.0642, -0.0294, -0.0777],
+        [-0.0035, 0.0076, -0.0140],
+        [ 0.0519, 0.0731, 0.0887],
+        [-0.0102, 0.0095, 0.0704],
+        [ 0.0068, 0.0218, -0.0023],
+        [-0.0726, -0.0486, -0.0519],
+        [ 0.0260, 0.0295, 0.0263],
+        [ 0.0250, 0.0333, 0.0341],
+        [ 0.0168, -0.0120, -0.0174],
+        [ 0.0226, 0.1037, 0.0114],
+        [ 0.2577, 0.1906, 0.1604],
+        [-0.0646, -0.0137, -0.0018],
+        [-0.0646, -0.0137, -0.0018]
+    ]
+    latent_rgb_factors_bias = [0.0007, -0.0256, -0.0206]
     latent_channels = 32
     latent_dimensions = 3
     scale_factor = 1.03682
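Note (not part of the diff): a 32 x 3 latent_rgb_factors table plus a 3-value bias like the one added above is typically consumed by latent-preview code as a linear projection from latent channels to RGB. A minimal, self-contained sketch of that projection, with the function name and tensor shapes assumed for illustration:

import torch

def preview_latent_as_rgb(latent, rgb_factors, bias):
    # latent: [B, C, T, H, W] with C = 32 here; rgb_factors: C x 3 nested list; bias: 3 floats
    factors = torch.tensor(rgb_factors, dtype=latent.dtype, device=latent.device)  # [C, 3]
    b = torch.tensor(bias, dtype=latent.dtype, device=latent.device)               # [3]
    # project the channel axis to RGB, then add the per-channel bias
    rgb = torch.einsum("bcthw,cr->btrhw", latent, factors) + b.view(1, 1, 3, 1, 1)
    return rgb.clamp(-1.0, 1.0)  # rough preview range; real scaling may differ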

View File

@@ -6,7 +6,6 @@ import comfy.ldm.flux.layers
 import comfy.ldm.modules.diffusionmodules.mmdit
 from comfy.ldm.modules.attention import optimized_attention
 from dataclasses import dataclass
 from einops import repeat
@@ -295,6 +294,7 @@ class HunyuanVideo(nn.Module):
         timesteps: Tensor,
         y: Tensor = None,
         txt_byt5=None,
+        vision_states=None, #todo hunyuan video 1.5 vision encoder states input
         guidance: Tensor = None,
         guiding_frame_index=None,
         ref_latent=None,
@@ -352,47 +352,24 @@ class HunyuanVideo(nn.Module):
         if self.byt5_in is not None and txt_byt5 is not None:
             txt_byt5 = self.byt5_in(txt_byt5)
+            if self.cond_type_embedding is not None:
+                cond_emb = self.cond_type_embedding(torch.ones_like(txt_byt5[:, :, 0], device=txt_byt5.device, dtype=torch.long))
+                txt_byt5 = txt_byt5 + cond_emb.to(txt_byt5.dtype)
+                txt = torch.cat((txt_byt5, txt), dim=1) # byt5 first for HunyuanVideo1.5
+            else:
+                txt = torch.cat((txt, txt_byt5), dim=1)
             txt_byt5_ids = torch.zeros((txt_ids.shape[0], txt_byt5.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
-            txt = torch.cat((txt, txt_byt5), dim=1)
             txt_ids = torch.cat((txt_ids, txt_byt5_ids), dim=1)
 
-        # if self.cond_type_embedding is not None: #todo vision_in
-        #     self.cond_type_embedding.to(txt.device)
-        #     cond_emb = self.cond_type_embedding(torch.zeros_like(txt[:, :, 0], device=txt.device, dtype=torch.long))
-        #     txt = txt + cond_emb.to(txt.dtype)
-
-        # if txt_byt5 is None:
-        #     txt_byt5 = torch.zeros((1, 1000, 1472), device=txt.device, dtype=txt.dtype)
-        # if self.byt5_in is not None and txt_byt5 is not None:
-        #     txt_byt5 = self.byt5_in(txt_byt5)
-        #     if self.cond_type_embedding is not None:
-        #         cond_emb = self.cond_type_embedding(torch.ones_like(txt_byt5[:, :, 0], device=txt_byt5.device, dtype=torch.long))
-        #         txt_byt5 = txt_byt5 + cond_emb.to(txt_byt5.dtype)
-        #     txt_byt5_ids = torch.zeros((txt_ids.shape[0], txt_byt5.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
-        #     #txt = torch.cat((txt, txt_byt5), dim=1)
-        #     #txt_ids = torch.cat((txt_ids, txt_byt5_ids), dim=1)
-        #     print("txt_byt5 shape:", txt_byt5.shape)
-        #     print("txt shape:", txt.shape)
-        #     txt = torch.cat((txt_byt5, txt), dim=1)
-        #     txt_ids = torch.cat((txt_byt5_ids, txt_ids), dim=1)
-
-        # vision_states = torch.zeros(img.shape[0], 729, self.vision_in_dim, device=img.device, dtype=img.dtype)
-        # if self.cond_type_embedding is not None:
-        #     extra_encoder_hidden_states = self.vision_in(vision_states)
-        #     extra_encoder_hidden_states = extra_encoder_hidden_states * 0.0 #t2v
-        #     cond_emb = self.cond_type_embedding(
-        #         2 * torch.ones_like(
-        #             extra_encoder_hidden_states[:, :, 0],
-        #             dtype=torch.long,
-        #             device=extra_encoder_hidden_states.device,
-        #         )
-        #     )
-        #     extra_encoder_hidden_states = extra_encoder_hidden_states + cond_emb
-        #     print("extra_encoder_hidden_states shape:", extra_encoder_hidden_states.shape)
-        #     txt = torch.cat((extra_encoder_hidden_states.to(txt.dtype), txt), dim=1)
-        #     extra_txt_ids = torch.zeros((txt_ids.shape[0], extra_encoder_hidden_states.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
-        #     txt_ids = torch.cat((extra_txt_ids, txt_ids), dim=1)
+        if self.cond_type_embedding is not None and vision_states is not None:
+            txt_vision_states = self.vision_in(vision_states)
+            cond_emb = self.cond_type_embedding(2 * torch.ones_like(txt_vision_states[:, :, 0], dtype=torch.long, device=txt_vision_states.device))
+            txt_vision_states = txt_vision_states + cond_emb
+            #print("txt_vision_states shape:", txt_vision_states.shape)
+            txt = torch.cat((txt_vision_states.to(txt.dtype), txt), dim=1)
+            extra_txt_ids = torch.zeros((txt_ids.shape[0], txt_vision_states.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
+            txt_ids = torch.cat((txt_ids, extra_txt_ids), dim=1)
 
         ids = torch.cat((img_ids, txt_ids), dim=1)
         pe = self.pe_embedder(ids)
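Note (not part of the diff): the new path above tags byt5 glyph tokens with condition-type embedding id 1 and, for HunyuanVideo 1.5, concatenates them in front of the text tokens; vision encoder states are projected through vision_in, tagged with id 2, prepended as well, and matching all-zero rotary ids are appended to txt_ids. A standalone sketch of that tagging pattern, with hidden sizes and module names assumed for illustration only:

import torch
import torch.nn as nn

class VisionTokenMixer(nn.Module):
    # Simplified stand-in, not the model's real modules or dimensions.
    def __init__(self, hidden_size=3072, vision_dim=1152, num_cond_types=3):
        super().__init__()
        self.vision_in = nn.Linear(vision_dim, hidden_size)
        self.cond_type_embedding = nn.Embedding(num_cond_types, hidden_size)

    def forward(self, txt, txt_ids, vision_states):
        # project vision tokens into the transformer hidden size
        vis = self.vision_in(vision_states)
        # tag every vision token with condition type 2 (the diff uses type 1 for byt5 tokens)
        vis = vis + self.cond_type_embedding(2 * torch.ones_like(vis[:, :, 0], dtype=torch.long))
        # prepend vision tokens to the text stream and give them all-zero rotary ids
        txt = torch.cat((vis.to(txt.dtype), txt), dim=1)
        zero_ids = torch.zeros((txt_ids.shape[0], vis.shape[1], txt_ids.shape[-1]),
                               device=txt_ids.device, dtype=txt_ids.dtype)
        return txt, torch.cat((txt_ids, zero_ids), dim=1)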
@@ -487,14 +464,14 @@ class HunyuanVideo(nn.Module):
         img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
         return repeat(img_ids, "h w c -> b (h w) c", b=bs)
 
-    def forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
+    def forward(self, x, timestep, context, y=None, txt_byt5=None, vision_states=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
         return comfy.patcher_extension.WrapperExecutor.new_class_executor(
             self._forward,
             self,
             comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
-        ).execute(x, timestep, context, y, txt_byt5, guidance, attention_mask, guiding_frame_index, ref_latent, disable_time_r, control, transformer_options, **kwargs)
+        ).execute(x, timestep, context, y, txt_byt5, vision_states, guidance, attention_mask, guiding_frame_index, ref_latent, disable_time_r, control, transformer_options, **kwargs)
 
-    def _forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
+    def _forward(self, x, timestep, context, y=None, txt_byt5=None, vision_states=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
         bs = x.shape[0]
         if len(self.patch_size) == 3:
             img_ids = self.img_ids(x)
@@ -502,5 +479,5 @@ class HunyuanVideo(nn.Module):
         else:
             img_ids = self.img_ids_2d(x)
         txt_ids = torch.zeros((bs, context.shape[1], 2), device=x.device, dtype=x.dtype)
-        out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, guidance, guiding_frame_index, ref_latent, disable_time_r=disable_time_r, control=control, transformer_options=transformer_options)
+        out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, vision_states, guidance, guiding_frame_index, ref_latent, disable_time_r=disable_time_r, control=control, transformer_options=transformer_options)
         return out
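Note (not part of the diff): WrapperExecutor forwards the .execute() arguments positionally into _forward, so adding vision_states in the middle of the signature requires updating the forward signature, the .execute() call, and _forward together, exactly as the hunks above do. A tiny illustration of why the slots must stay in sync, using a simplified stand-in rather than ComfyUI's actual executor:

class MiniExecutor:
    # Simplified stand-in: forwards positional args verbatim to the wrapped function.
    def __init__(self, fn):
        self.fn = fn

    def execute(self, *args, **kwargs):
        return self.fn(*args, **kwargs)

def _forward(x, timestep, context, y=None, txt_byt5=None, vision_states=None, guidance=None):
    return vision_states  # report which slot received the vision tokens

# vision_states must occupy the same positional slot in execute() as in _forward:
print(MiniExecutor(_forward).execute("x", 0, "ctx", None, None, "vision_tokens"))  # vision_tokens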

View File

@@ -1373,7 +1373,7 @@ class HunyuanImage21Refiner(HunyuanVideo):
     def get_model(self, state_dict, prefix="", device=None):
         out = model_base.HunyuanImage21Refiner(self, device=device)
         return out
 
 class HunyuanVideo15(HunyuanVideo):
     unet_config = {
         "image_model": "hunyuan_video",
@@ -1387,7 +1387,7 @@ class HunyuanVideo15(HunyuanVideo):
     sampling_settings = {
         "shift": 7.0,
     }
-    memory_usage_factor = 7.7
+    memory_usage_factor = 4.0 #TODO
 
     supported_inference_dtypes = [torch.bfloat16, torch.float32]
     latent_format = latent_formats.HunyuanVideo15
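Note (not part of the diff): the "shift": 7.0 entry is the flow-matching timestep shift commonly used by flow models, which remaps a sigma s as shift * s / (1 + (shift - 1) * s), biasing sampling steps toward the high-noise end. A quick standalone sketch of that remapping (assumed helper, not ComfyUI's ModelSampling code):

def shift_sigma(sigma: float, shift: float = 7.0) -> float:
    # larger shift spends more of the schedule at high noise
    return shift * sigma / (1.0 + (shift - 1.0) * sigma)

print(shift_sigma(0.5))  # 0.875 with the default shift of 7.0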
@@ -1396,7 +1396,7 @@ class HunyuanVideo15(HunyuanVideo):
         print("HunyuanVideo15")
         out = model_base.HunyuanVideo15(self, device=device)
         return out
 
     def clip_target(self, state_dict={}):
         pref = self.text_encoder_key_prefix[0]
         hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))