mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-10 13:32:36 +08:00
init
This commit is contained in:
parent
10e90a5757
commit
cadd00226b
@ -611,6 +611,11 @@ class HunyuanImage21Refiner(LatentFormat):
|
||||
latent_dimensions = 3
|
||||
scale_factor = 1.03682
|
||||
|
||||
class HunyuanVideo15(LatentFormat):
|
||||
latent_channels = 32
|
||||
latent_dimensions = 3
|
||||
scale_factor = 1.03682
|
||||
|
||||
class Hunyuan3Dv2(LatentFormat):
|
||||
latent_channels = 64
|
||||
latent_dimensions = 1
|
||||
|
||||
@ -42,6 +42,8 @@ class HunyuanVideoParams:
|
||||
guidance_embed: bool
|
||||
byt5: bool
|
||||
meanflow: bool
|
||||
use_cond_type_embedding: bool
|
||||
vision_in_dim: int
|
||||
|
||||
|
||||
class SelfAttentionRef(nn.Module):
|
||||
@ -196,11 +198,16 @@ class HunyuanVideo(nn.Module):
|
||||
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
|
||||
|
||||
params = HunyuanVideoParams(**kwargs)
|
||||
print("HunyuanVideo params:", params)
|
||||
self.params = params
|
||||
self.patch_size = params.patch_size
|
||||
self.in_channels = params.in_channels
|
||||
self.out_channels = params.out_channels
|
||||
self.use_cond_type_embedding = params.use_cond_type_embedding
|
||||
self.vision_in_dim = params.vision_in_dim
|
||||
if params.hidden_size % params.num_heads != 0:
|
||||
raise ValueError(
|
||||
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
|
||||
@ -266,6 +273,18 @@ class HunyuanVideo(nn.Module):
|
||||
if final_layer:
|
||||
self.final_layer = LastLayer(self.hidden_size, self.patch_size[-1], self.out_channels, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
# HunyuanVideo 1.5 specific modules
|
||||
if self.vision_in_dim is not None:
|
||||
from comfy.ldm.wan.model import MLPProj # todo move
|
||||
self.vision_in = MLPProj(in_dim=self.vision_in_dim, out_dim=self.hidden_size, operation_settings=operation_settings)
|
||||
else:
|
||||
self.vision_in = None
|
||||
if self.use_cond_type_embedding:
|
||||
# 0: text_encoder feature 1: byt5 feature 2: vision_encoder feature
|
||||
self.cond_type_embedding = nn.Embedding(3, self.hidden_size)
|
||||
else:
|
||||
self.cond_type_embedding = None
|
||||
|
||||
def forward_orig(
|
||||
self,
|
||||
img: Tensor,
|
||||
@ -337,6 +356,44 @@ class HunyuanVideo(nn.Module):
|
||||
txt = torch.cat((txt, txt_byt5), dim=1)
|
||||
txt_ids = torch.cat((txt_ids, txt_byt5_ids), dim=1)
|
||||
|
||||
# if self.cond_type_embedding is not None:
|
||||
# self.cond_type_embedding.to(txt.device)
|
||||
# cond_emb = self.cond_type_embedding(torch.zeros_like(txt[:, :, 0], device=txt.device, dtype=torch.long))
|
||||
# txt = txt + cond_emb.to(txt.dtype)
|
||||
|
||||
# if txt_byt5 is None:
|
||||
# txt_byt5 = torch.zeros((1, 1000, 1472), device=txt.device, dtype=txt.dtype)
|
||||
# if self.byt5_in is not None and txt_byt5 is not None:
|
||||
# txt_byt5 = self.byt5_in(txt_byt5)
|
||||
# if self.cond_type_embedding is not None:
|
||||
# cond_emb = self.cond_type_embedding(torch.ones_like(txt_byt5[:, :, 0], device=txt_byt5.device, dtype=torch.long))
|
||||
# txt_byt5 = txt_byt5 + cond_emb.to(txt_byt5.dtype)
|
||||
# txt_byt5_ids = torch.zeros((txt_ids.shape[0], txt_byt5.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
|
||||
# #txt = torch.cat((txt, txt_byt5), dim=1)
|
||||
# #txt_ids = torch.cat((txt_ids, txt_byt5_ids), dim=1)
|
||||
# print("txt_byt5 shape:", txt_byt5.shape)
|
||||
# print("txt shape:", txt.shape)
|
||||
# txt = torch.cat((txt_byt5, txt), dim=1)
|
||||
# txt_ids = torch.cat((txt_byt5_ids, txt_ids), dim=1)
|
||||
|
||||
# vision_states = torch.zeros(img.shape[0], 729, self.vision_in_dim, device=img.device, dtype=img.dtype)
|
||||
# if self.cond_type_embedding is not None:
|
||||
# extra_encoder_hidden_states = self.vision_in(vision_states)
|
||||
# extra_encoder_hidden_states = extra_encoder_hidden_states * 0.0 #t2v
|
||||
# cond_emb = self.cond_type_embedding(
|
||||
# 2 * torch.ones_like(
|
||||
# extra_encoder_hidden_states[:, :, 0],
|
||||
# dtype=torch.long,
|
||||
# device=extra_encoder_hidden_states.device,
|
||||
# )
|
||||
# )
|
||||
# extra_encoder_hidden_states = extra_encoder_hidden_states + cond_emb
|
||||
# print("extra_encoder_hidden_states shape:", extra_encoder_hidden_states.shape)
|
||||
# txt = torch.cat((extra_encoder_hidden_states.to(txt.dtype), txt), dim=1)
|
||||
|
||||
# extra_txt_ids = torch.zeros((txt_ids.shape[0], extra_encoder_hidden_states.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
|
||||
# txt_ids = torch.cat((extra_txt_ids, txt_ids), dim=1)
|
||||
|
||||
ids = torch.cat((img_ids, txt_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
|
||||
|
||||
@ -220,11 +220,12 @@ class Encoder(nn.Module):
|
||||
if self.refiner_vae:
|
||||
out = self.regul(out)[0]
|
||||
|
||||
out = torch.cat((out[:, :, :1], out), dim=2)
|
||||
out = out.permute(0, 2, 1, 3, 4)
|
||||
b, f_times_2, c, h, w = out.shape
|
||||
out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)
|
||||
out = out.permute(0, 2, 1, 3, 4).contiguous()
|
||||
# todo don't break this
|
||||
# out = torch.cat((out[:, :, :1], out), dim=2)
|
||||
# out = out.permute(0, 2, 1, 3, 4)
|
||||
# b, f_times_2, c, h, w = out.shape
|
||||
# out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)
|
||||
# out = out.permute(0, 2, 1, 3, 4).contiguous()
|
||||
|
||||
return out
|
||||
|
||||
@ -275,13 +276,15 @@ class Decoder(nn.Module):
|
||||
self.conv_out = conv_op(ch, out_channels, 3, stride=1, padding=1)
|
||||
|
||||
def forward(self, z):
|
||||
if self.refiner_vae:
|
||||
z = z.permute(0, 2, 1, 3, 4)
|
||||
b, f, c, h, w = z.shape
|
||||
z = z.reshape(b, f, 2, c // 2, h, w)
|
||||
z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w)
|
||||
z = z.permute(0, 2, 1, 3, 4)
|
||||
z = z[:, :, 1:]
|
||||
|
||||
# todo don't break this
|
||||
# if self.refiner_vae:
|
||||
# z = z.permute(0, 2, 1, 3, 4)
|
||||
# b, f, c, h, w = z.shape
|
||||
# z = z.reshape(b, f, 2, c // 2, h, w)
|
||||
# z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w)
|
||||
# z = z.permute(0, 2, 1, 3, 4)
|
||||
# z = z[:, :, 1:]
|
||||
|
||||
x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
|
||||
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
|
||||
|
||||
@ -1536,3 +1536,36 @@ class HunyuanImage21Refiner(HunyuanImage21):
|
||||
out = super().extra_conds(**kwargs)
|
||||
out['disable_time_r'] = comfy.conds.CONDConstant(True)
|
||||
return out
|
||||
|
||||
class HunyuanVideo15(HunyuanImage21):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device)
|
||||
|
||||
def concat_cond(self, **kwargs):
|
||||
noise = kwargs.get("noise", None)
|
||||
extra_channels = self.diffusion_model.img_in.proj.weight.shape[1] - noise.shape[1] - 1 #noise 32 img cond 32 + mask 1
|
||||
if extra_channels == 0:
|
||||
return None
|
||||
|
||||
image = kwargs.get("concat_latent_image", None)
|
||||
device = kwargs["device"]
|
||||
|
||||
if image is None:
|
||||
shape_image = list(noise.shape)
|
||||
shape_image[1] = extra_channels
|
||||
image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device)
|
||||
else:
|
||||
latent_dim = self.latent_format.latent_channels
|
||||
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||
for i in range(0, image.shape[1], latent_dim):
|
||||
image[:, i: i + latent_dim] = self.process_latent_in(image[:, i: i + latent_dim])
|
||||
image = utils.resize_to_batch_size(image, noise.shape[0])
|
||||
|
||||
mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
|
||||
if mask is None:
|
||||
mask = torch.zeros_like(noise)[:, :1]
|
||||
else:
|
||||
mask = torch.zeros_like(noise)[:, :1]
|
||||
mask[:, :, 1:] = 1.0
|
||||
|
||||
return torch.cat((image, mask), dim=1)
|
||||
|
||||
@ -186,6 +186,14 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
|
||||
guidance_keys = list(filter(lambda a: a.startswith("{}guidance_in.".format(key_prefix)), state_dict_keys))
|
||||
dit_config["guidance_embed"] = len(guidance_keys) > 0
|
||||
|
||||
# HunyuanVideo 1.5
|
||||
if '{}cond_type_embedding.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["use_cond_type_embedding"] = True
|
||||
if '{}vision_in.proj.0.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["vision_in_dim"] = state_dict['{}vision_in.proj.0.weight'.format(key_prefix)].shape[0]
|
||||
else:
|
||||
dit_config["vision_in_dim"] = None
|
||||
return dit_config
|
||||
|
||||
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
|
||||
|
||||
@ -1373,7 +1373,35 @@ class HunyuanImage21Refiner(HunyuanVideo):
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.HunyuanImage21Refiner(self, device=device)
|
||||
return out
|
||||
|
||||
class HunyuanVideo15(HunyuanVideo):
|
||||
unet_config = {
|
||||
"image_model": "hunyuan_video",
|
||||
"patch_size": [1, 1, 1],
|
||||
"in_channels": 65,
|
||||
"out_channels": 32,
|
||||
"depth": 54,
|
||||
"vision_in_dim": 1152,
|
||||
}
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage]
|
||||
sampling_settings = {
|
||||
"shift": 7.0,
|
||||
}
|
||||
memory_usage_factor = 7.7
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
latent_format = latent_formats.HunyuanVideo15
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
print("HunyuanVideo15")
|
||||
out = model_base.HunyuanVideo15(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect))
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage]
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
||||
@ -57,6 +57,22 @@ class EmptyHunyuanLatentVideo(io.ComfyNode):
|
||||
generate = execute # TODO: remove
|
||||
|
||||
|
||||
class EmptyHunyuanVideo15Latent(EmptyHunyuanLatentVideo):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
schema = super().define_schema()
|
||||
schema.node_id = "EmptyHunyuanVideo15Latent"
|
||||
return schema
|
||||
|
||||
@classmethod
|
||||
def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput:
|
||||
# Using scale factor of 16 instead of 8
|
||||
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 16, width // 16], device=comfy.model_management.intermediate_device())
|
||||
return io.NodeOutput({"samples": latent})
|
||||
|
||||
generate = execute # TODO: remove
|
||||
|
||||
|
||||
PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = (
|
||||
"<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the video by detailing the following aspects according to the reference image: "
|
||||
"1. The main content and theme of the video."
|
||||
@ -210,6 +226,7 @@ class HunyuanExtension(ComfyExtension):
|
||||
CLIPTextEncodeHunyuanDiT,
|
||||
TextEncodeHunyuanVideo_ImageToVideo,
|
||||
EmptyHunyuanLatentVideo,
|
||||
EmptyHunyuanVideo15Latent,
|
||||
HunyuanImageToVideo,
|
||||
EmptyHunyuanImageLatent,
|
||||
HunyuanRefinerLatent,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user