This commit is contained in:
kijai 2025-11-13 23:09:05 +02:00 committed by comfyanonymous
parent 10e90a5757
commit cadd00226b
7 changed files with 164 additions and 13 deletions

View File

@ -611,6 +611,11 @@ class HunyuanImage21Refiner(LatentFormat):
latent_dimensions = 3
scale_factor = 1.03682
class HunyuanVideo15(LatentFormat):
latent_channels = 32
latent_dimensions = 3
scale_factor = 1.03682
class Hunyuan3Dv2(LatentFormat):
latent_channels = 64
latent_dimensions = 1

View File

@ -42,6 +42,8 @@ class HunyuanVideoParams:
guidance_embed: bool
byt5: bool
meanflow: bool
use_cond_type_embedding: bool
vision_in_dim: int
class SelfAttentionRef(nn.Module):
@ -196,11 +198,16 @@ class HunyuanVideo(nn.Module):
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
super().__init__()
self.dtype = dtype
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
params = HunyuanVideoParams(**kwargs)
print("HunyuanVideo params:", params)
self.params = params
self.patch_size = params.patch_size
self.in_channels = params.in_channels
self.out_channels = params.out_channels
self.use_cond_type_embedding = params.use_cond_type_embedding
self.vision_in_dim = params.vision_in_dim
if params.hidden_size % params.num_heads != 0:
raise ValueError(
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
@ -266,6 +273,18 @@ class HunyuanVideo(nn.Module):
if final_layer:
self.final_layer = LastLayer(self.hidden_size, self.patch_size[-1], self.out_channels, dtype=dtype, device=device, operations=operations)
# HunyuanVideo 1.5 specific modules
if self.vision_in_dim is not None:
from comfy.ldm.wan.model import MLPProj # todo move
self.vision_in = MLPProj(in_dim=self.vision_in_dim, out_dim=self.hidden_size, operation_settings=operation_settings)
else:
self.vision_in = None
if self.use_cond_type_embedding:
# 0: text_encoder feature 1: byt5 feature 2: vision_encoder feature
self.cond_type_embedding = nn.Embedding(3, self.hidden_size)
else:
self.cond_type_embedding = None
def forward_orig(
self,
img: Tensor,
@ -337,6 +356,44 @@ class HunyuanVideo(nn.Module):
txt = torch.cat((txt, txt_byt5), dim=1)
txt_ids = torch.cat((txt_ids, txt_byt5_ids), dim=1)
# if self.cond_type_embedding is not None:
# self.cond_type_embedding.to(txt.device)
# cond_emb = self.cond_type_embedding(torch.zeros_like(txt[:, :, 0], device=txt.device, dtype=torch.long))
# txt = txt + cond_emb.to(txt.dtype)
# if txt_byt5 is None:
# txt_byt5 = torch.zeros((1, 1000, 1472), device=txt.device, dtype=txt.dtype)
# if self.byt5_in is not None and txt_byt5 is not None:
# txt_byt5 = self.byt5_in(txt_byt5)
# if self.cond_type_embedding is not None:
# cond_emb = self.cond_type_embedding(torch.ones_like(txt_byt5[:, :, 0], device=txt_byt5.device, dtype=torch.long))
# txt_byt5 = txt_byt5 + cond_emb.to(txt_byt5.dtype)
# txt_byt5_ids = torch.zeros((txt_ids.shape[0], txt_byt5.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
# #txt = torch.cat((txt, txt_byt5), dim=1)
# #txt_ids = torch.cat((txt_ids, txt_byt5_ids), dim=1)
# print("txt_byt5 shape:", txt_byt5.shape)
# print("txt shape:", txt.shape)
# txt = torch.cat((txt_byt5, txt), dim=1)
# txt_ids = torch.cat((txt_byt5_ids, txt_ids), dim=1)
# vision_states = torch.zeros(img.shape[0], 729, self.vision_in_dim, device=img.device, dtype=img.dtype)
# if self.cond_type_embedding is not None:
# extra_encoder_hidden_states = self.vision_in(vision_states)
# extra_encoder_hidden_states = extra_encoder_hidden_states * 0.0 #t2v
# cond_emb = self.cond_type_embedding(
# 2 * torch.ones_like(
# extra_encoder_hidden_states[:, :, 0],
# dtype=torch.long,
# device=extra_encoder_hidden_states.device,
# )
# )
# extra_encoder_hidden_states = extra_encoder_hidden_states + cond_emb
# print("extra_encoder_hidden_states shape:", extra_encoder_hidden_states.shape)
# txt = torch.cat((extra_encoder_hidden_states.to(txt.dtype), txt), dim=1)
# extra_txt_ids = torch.zeros((txt_ids.shape[0], extra_encoder_hidden_states.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
# txt_ids = torch.cat((extra_txt_ids, txt_ids), dim=1)
ids = torch.cat((img_ids, txt_ids), dim=1)
pe = self.pe_embedder(ids)

View File

@ -220,11 +220,12 @@ class Encoder(nn.Module):
if self.refiner_vae:
out = self.regul(out)[0]
out = torch.cat((out[:, :, :1], out), dim=2)
out = out.permute(0, 2, 1, 3, 4)
b, f_times_2, c, h, w = out.shape
out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)
out = out.permute(0, 2, 1, 3, 4).contiguous()
# todo don't break this
# out = torch.cat((out[:, :, :1], out), dim=2)
# out = out.permute(0, 2, 1, 3, 4)
# b, f_times_2, c, h, w = out.shape
# out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)
# out = out.permute(0, 2, 1, 3, 4).contiguous()
return out
@ -275,13 +276,15 @@ class Decoder(nn.Module):
self.conv_out = conv_op(ch, out_channels, 3, stride=1, padding=1)
def forward(self, z):
if self.refiner_vae:
z = z.permute(0, 2, 1, 3, 4)
b, f, c, h, w = z.shape
z = z.reshape(b, f, 2, c // 2, h, w)
z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w)
z = z.permute(0, 2, 1, 3, 4)
z = z[:, :, 1:]
# todo don't break this
# if self.refiner_vae:
# z = z.permute(0, 2, 1, 3, 4)
# b, f, c, h, w = z.shape
# z = z.reshape(b, f, 2, c // 2, h, w)
# z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w)
# z = z.permute(0, 2, 1, 3, 4)
# z = z[:, :, 1:]
x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))

View File

@ -1536,3 +1536,36 @@ class HunyuanImage21Refiner(HunyuanImage21):
out = super().extra_conds(**kwargs)
out['disable_time_r'] = comfy.conds.CONDConstant(True)
return out
class HunyuanVideo15(HunyuanImage21):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device)
def concat_cond(self, **kwargs):
noise = kwargs.get("noise", None)
extra_channels = self.diffusion_model.img_in.proj.weight.shape[1] - noise.shape[1] - 1 #noise 32 img cond 32 + mask 1
if extra_channels == 0:
return None
image = kwargs.get("concat_latent_image", None)
device = kwargs["device"]
if image is None:
shape_image = list(noise.shape)
shape_image[1] = extra_channels
image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device)
else:
latent_dim = self.latent_format.latent_channels
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
for i in range(0, image.shape[1], latent_dim):
image[:, i: i + latent_dim] = self.process_latent_in(image[:, i: i + latent_dim])
image = utils.resize_to_batch_size(image, noise.shape[0])
mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
if mask is None:
mask = torch.zeros_like(noise)[:, :1]
else:
mask = torch.zeros_like(noise)[:, :1]
mask[:, :, 1:] = 1.0
return torch.cat((image, mask), dim=1)

View File

@ -186,6 +186,14 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
guidance_keys = list(filter(lambda a: a.startswith("{}guidance_in.".format(key_prefix)), state_dict_keys))
dit_config["guidance_embed"] = len(guidance_keys) > 0
# HunyuanVideo 1.5
if '{}cond_type_embedding.weight'.format(key_prefix) in state_dict_keys:
dit_config["use_cond_type_embedding"] = True
if '{}vision_in.proj.0.weight'.format(key_prefix) in state_dict_keys:
dit_config["vision_in_dim"] = state_dict['{}vision_in.proj.0.weight'.format(key_prefix)].shape[0]
else:
dit_config["vision_in_dim"] = None
return dit_config
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight)

View File

@ -1373,7 +1373,35 @@ class HunyuanImage21Refiner(HunyuanVideo):
def get_model(self, state_dict, prefix="", device=None):
out = model_base.HunyuanImage21Refiner(self, device=device)
return out
class HunyuanVideo15(HunyuanVideo):
unet_config = {
"image_model": "hunyuan_video",
"patch_size": [1, 1, 1],
"in_channels": 65,
"out_channels": 32,
"depth": 54,
"vision_in_dim": 1152,
}
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage]
sampling_settings = {
"shift": 7.0,
}
memory_usage_factor = 7.7
supported_inference_dtypes = [torch.bfloat16, torch.float32]
latent_format = latent_formats.HunyuanVideo15
def get_model(self, state_dict, prefix="", device=None):
print("HunyuanVideo15")
out = model_base.HunyuanVideo15(self, device=device)
return out
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage]
models += [SVD_img2vid]

View File

@ -57,6 +57,22 @@ class EmptyHunyuanLatentVideo(io.ComfyNode):
generate = execute # TODO: remove
class EmptyHunyuanVideo15Latent(EmptyHunyuanLatentVideo):
@classmethod
def define_schema(cls):
schema = super().define_schema()
schema.node_id = "EmptyHunyuanVideo15Latent"
return schema
@classmethod
def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput:
# Using scale factor of 16 instead of 8
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 16, width // 16], device=comfy.model_management.intermediate_device())
return io.NodeOutput({"samples": latent})
generate = execute # TODO: remove
PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = (
"<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the video by detailing the following aspects according to the reference image: "
"1. The main content and theme of the video."
@ -210,6 +226,7 @@ class HunyuanExtension(ComfyExtension):
CLIPTextEncodeHunyuanDiT,
TextEncodeHunyuanVideo_ImageToVideo,
EmptyHunyuanLatentVideo,
EmptyHunyuanVideo15Latent,
HunyuanImageToVideo,
EmptyHunyuanImageLatent,
HunyuanRefinerLatent,