diff --git a/comfy/__init__.py b/comfy/__init__.py index c3aa137ea..db42feda8 100644 --- a/comfy/__init__.py +++ b/comfy/__init__.py @@ -1 +1 @@ -__version__ = "0.3.22" +__version__ = "0.3.23" diff --git a/comfy/clip_model.py b/comfy/clip_model.py index 58cf8f338..074bacca7 100644 --- a/comfy/clip_model.py +++ b/comfy/clip_model.py @@ -98,8 +98,12 @@ class CLIPTextModel_(torch.nn.Module): self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device) - def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32): - x = self.embeddings(input_tokens, dtype=dtype) + def forward(self, input_tokens=None, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32): + if embeds is not None: + x = embeds + ops.cast_to(self.embeddings.position_embedding.weight, dtype=dtype, device=embeds.device) + else: + x = self.embeddings(input_tokens, dtype=dtype) + mask = None if attention_mask is not None: mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]) @@ -117,7 +121,10 @@ class CLIPTextModel_(torch.nn.Module): if i is not None and final_layer_norm_intermediate: i = self.final_layer_norm(i) - pooled_output = x[torch.arange(x.shape[0], device=x.device), (torch.round(input_tokens).to(dtype=torch.int, device=x.device) == self.eos_token_id).int().argmax(dim=-1),] + if num_tokens is not None: + pooled_output = x[list(range(x.shape[0])), list(map(lambda a: a - 1, num_tokens))] + else: + pooled_output = x[torch.arange(x.shape[0], device=x.device), (torch.round(input_tokens).to(dtype=torch.int, device=x.device) == self.eos_token_id).int().argmax(dim=-1),] return x, i, pooled_output class CLIPTextModel(torch.nn.Module): @@ -205,6 +212,15 @@ class CLIPVision(torch.nn.Module): pooled_output = self.post_layernorm(x[:, 0, :]) return x, i, pooled_output +class LlavaProjector(torch.nn.Module): + def __init__(self, in_dim, out_dim, dtype, device, operations): + super().__init__() + self.linear_1 = operations.Linear(in_dim, out_dim, bias=True, device=device, dtype=dtype) + self.linear_2 = operations.Linear(out_dim, out_dim, bias=True, device=device, dtype=dtype) + + def forward(self, x): + return self.linear_2(torch.nn.functional.gelu(self.linear_1(x[:, 1:]))) + class CLIPVisionModelProjection(torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() @@ -214,7 +230,16 @@ class CLIPVisionModelProjection(torch.nn.Module): else: self.visual_projection = lambda a: a + if "llava3" == config_dict.get("projector_type", None): + self.multi_modal_projector = LlavaProjector(config_dict["hidden_size"], 4096, dtype, device, operations) + else: + self.multi_modal_projector = None + def forward(self, *args, **kwargs): x = self.vision_model(*args, **kwargs) out = self.visual_projection(x[2]) - return (x[0], x[1], out) + projected = None + if self.multi_modal_projector is not None: + projected = self.multi_modal_projector(x[1]) + + return (x[0], x[1], out, projected) diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index 070421be8..1c9289524 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -78,6 +78,7 @@ class ClipVisionModel(): outputs["last_hidden_state"] = out[0].to(model_management.intermediate_device()) outputs["image_embeds"] = out[2].to(model_management.intermediate_device()) outputs["penultimate_hidden_states"] = out[1].to(model_management.intermediate_device()) + outputs["mm_projected"] = out[3] return outputs @@ -119,7 +120,10 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False): if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152: json_config = files.get_path_as_dict(None, "clip_vision_siglip_384.json") elif sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577: - json_config = files.get_path_as_dict(None, "clip_vision_config_vitl_336.json") + if "multi_modal_projector.linear_1.bias" in sd: + json_config = files.get_path_as_dict(None, "clip_vision_config_vitl_336_llava.json") + else: + json_config = files.get_path_as_dict(None, "clip_vision_config_vitl_336.json") else: json_config = files.get_path_as_dict(None, "clip_vision_config_vitl.json") else: diff --git a/comfy/clip_vision_config_vitl_336_llava.json b/comfy/clip_vision_config_vitl_336_llava.json new file mode 100644 index 000000000..f23a50d8b --- /dev/null +++ b/comfy/clip_vision_config_vitl_336_llava.json @@ -0,0 +1,19 @@ +{ + "attention_dropout": 0.0, + "dropout": 0.0, + "hidden_act": "quick_gelu", + "hidden_size": 1024, + "image_size": 336, + "initializer_factor": 1.0, + "initializer_range": 0.02, + "intermediate_size": 4096, + "layer_norm_eps": 1e-5, + "model_type": "clip_vision_model", + "num_attention_heads": 16, + "num_channels": 3, + "num_hidden_layers": 24, + "patch_size": 14, + "projection_dim": 768, + "projector_type": "llava3", + "torch_dtype": "float32" +} diff --git a/comfy/model_base.py b/comfy/model_base.py index 0a628f5b4..5cbf324c0 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -936,6 +936,12 @@ class HunyuanVideo(BaseModel): return out +class HunyuanVideoI2V(HunyuanVideo): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device) + self.concat_keys = ("concat_image", "mask_inverted") + + class HunyuanVideoSkyreelsI2V(HunyuanVideo): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 9450a3710..c4e7902e5 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -175,71 +175,92 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): self.layer_idx = self.options_default[1] self.return_projected_pooled = self.options_default[2] - def set_up_textual_embeddings(self, tokens, current_embeds): - out_tokens = [] - next_new_token = token_dict_size = current_embeds.weight.shape[0] - embedding_weights = [] + def process_tokens(self, tokens, device): + end_token = self.special_tokens.get("end", None) + if end_token is None: + cmp_token = self.special_tokens.get("pad", -1) + else: + cmp_token = end_token + + embeds_out = [] + attention_masks = [] + num_tokens = [] for x in tokens: + attention_mask = [] tokens_temp = [] + other_embeds = [] + eos = False + index = 0 for y in x: if isinstance(y, numbers.Integral): - tokens_temp += [int(y)] - else: - if y.shape[0] == current_embeds.weight.shape[1]: - embedding_weights += [y] - tokens_temp += [next_new_token] - next_new_token += 1 + if eos: + attention_mask.append(0) else: - logging.warning("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored {} != {}".format(y.shape[0], current_embeds.weight.shape[1])) - while len(tokens_temp) < len(x): - tokens_temp += [self.special_tokens["pad"]] - out_tokens += [tokens_temp] + attention_mask.append(1) + token = int(y) + tokens_temp += [token] + if not eos and token == cmp_token: + if end_token is None: + attention_mask[-1] = 0 + eos = True + else: + other_embeds.append((index, y)) + index += 1 - n = token_dict_size - if len(embedding_weights) > 0: - new_embedding = self.operations.Embedding(next_new_token + 1, current_embeds.weight.shape[1], device=current_embeds.weight.device, dtype=current_embeds.weight.dtype) - new_embedding.weight[:token_dict_size] = current_embeds.weight - for x in embedding_weights: - new_embedding.weight[n] = x - n += 1 - self.transformer.set_input_embeddings(new_embedding) + tokens_embed = torch.tensor([tokens_temp], device=device, dtype=torch.long) + tokens_embed = self.transformer.get_input_embeddings()(tokens_embed, out_dtype=torch.float32) + index = 0 + pad_extra = 0 + for o in other_embeds: + emb = o[1] + if torch.is_tensor(emb): + emb = {"type": "embedding", "data": emb} - processed_tokens = [] - for x in out_tokens: - processed_tokens += [list(map(lambda a: n if a == -1 else a, x))] # The EOS token should always be the largest one + emb_type = emb.get("type", None) + if emb_type == "embedding": + emb = emb.get("data", None) + else: + if hasattr(self.transformer, "preprocess_embed"): + emb = self.transformer.preprocess_embed(emb, device=device) + else: + emb = None - return processed_tokens + if emb is None: + index += -1 + continue + + ind = index + o[0] + emb = emb.view(1, -1, emb.shape[-1]).to(device=device, dtype=torch.float32) + emb_shape = emb.shape[1] + if emb.shape[-1] == tokens_embed.shape[-1]: + tokens_embed = torch.cat([tokens_embed[:, :ind], emb, tokens_embed[:, ind:]], dim=1) + attention_mask = attention_mask[:ind] + [1] * emb_shape + attention_mask[ind:] + index += emb_shape - 1 + else: + index += -1 + pad_extra += emb_shape + logging.warning("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored {} != {}".format(emb.shape[-1], tokens_embed.shape[-1])) + + if pad_extra > 0: + padd_embed = self.transformer.get_input_embeddings()(torch.tensor([[self.special_tokens["pad"]] * pad_extra], device=device, dtype=torch.long), out_dtype=torch.float32) + tokens_embed = torch.cat([tokens_embed, padd_embed], dim=1) + + embeds_out.append(tokens_embed) + attention_masks.append(attention_mask) + num_tokens.append(sum(attention_mask)) + + return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens def forward(self, tokens): - backup_embeds = self.transformer.get_input_embeddings() - device = backup_embeds.weight.device - tokens = self.set_up_textual_embeddings(tokens, backup_embeds) - tokens = torch.tensor(tokens, dtype=torch.long).to(device) - - attention_mask = None - if self.enable_attention_masks or self.zero_out_masked or self.return_attention_masks: - attention_mask = torch.zeros_like(tokens) - end_token = self.special_tokens.get("end", None) - if end_token is None: - cmp_token = self.special_tokens.get("pad", -1) - else: - cmp_token = end_token - - for x in range(attention_mask.shape[0]): - for y in range(attention_mask.shape[1]): - attention_mask[x, y] = 1 - if tokens[x, y] == cmp_token: - if end_token is None: - attention_mask[x, y] = 0 - break + device = self.transformer.get_input_embeddings().weight.device + embeds, attention_mask, num_tokens = self.process_tokens(tokens, device) attention_mask_model = None if self.enable_attention_masks: attention_mask_model = attention_mask - outputs = self.transformer(tokens, attention_mask_model, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32) - self.transformer.set_input_embeddings(backup_embeds) + outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32) if self.layer == "last": z = outputs[0].float() diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 2c984857e..f19223e1e 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -887,6 +887,16 @@ class HunyuanVideo(supported_models_base.BASE): return supported_models_base.ClipTarget(hunyuan_video.HunyuanVideoTokenizer, hunyuan_video.hunyuan_video_clip(**hunyuan_detect)) +class HunyuanVideoI2V(HunyuanVideo): + unet_config = { + "image_model": "hunyuan_video", + "in_channels": 33, + } + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.HunyuanVideoI2V(self, device=device) + return out + class HunyuanVideoSkyreelsI2V(HunyuanVideo): unet_config = { "image_model": "hunyuan_video", @@ -1022,6 +1032,6 @@ class WAN21_I2V(WAN21_T2V): return out -models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V] +models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V] models += [SVD_img2vid] diff --git a/comfy/text_encoders/bert.py b/comfy/text_encoders/bert.py index c4c12b071..2113bbbe0 100644 --- a/comfy/text_encoders/bert.py +++ b/comfy/text_encoders/bert.py @@ -95,8 +95,11 @@ class BertEmbeddings(torch.nn.Module): self.LayerNorm = operations.LayerNorm(embed_dim, eps=layer_norm_eps, dtype=dtype, device=device) - def forward(self, input_tokens, token_type_ids=None, dtype=None): - x = self.word_embeddings(input_tokens, out_dtype=dtype) + def forward(self, input_tokens, embeds=None, token_type_ids=None, dtype=None): + if embeds is not None: + x = embeds + else: + x = self.word_embeddings(input_tokens, out_dtype=dtype) x += ops.cast_to_input(self.position_embeddings.weight[:x.shape[1]], x) if token_type_ids is not None: x += self.token_type_embeddings(token_type_ids, out_dtype=x.dtype) @@ -115,8 +118,8 @@ class BertModel_(torch.nn.Module): self.embeddings = BertEmbeddings(config_dict["vocab_size"], config_dict["max_position_embeddings"], config_dict["type_vocab_size"], config_dict["pad_token_id"], embed_dim, layer_norm_eps, dtype, device, operations) self.encoder = BertEncoder(config_dict["num_hidden_layers"], embed_dim, config_dict["intermediate_size"], config_dict["num_attention_heads"], layer_norm_eps, dtype, device, operations) - def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None): - x = self.embeddings(input_tokens, dtype=dtype) + def forward(self, input_tokens, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None): + x = self.embeddings(input_tokens, embeds=embeds, dtype=dtype) mask = None if attention_mask is not None: mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]) diff --git a/comfy/text_encoders/hunyuan_video.py b/comfy/text_encoders/hunyuan_video.py index 675e9bcbc..112e349ae 100644 --- a/comfy/text_encoders/hunyuan_video.py +++ b/comfy/text_encoders/hunyuan_video.py @@ -1,4 +1,5 @@ import torch +import numbers from transformers import LlamaTokenizerFast from .llama import Llama2 @@ -25,7 +26,7 @@ class LLAMA3Tokenizer(sd1_clip.SDTokenizer): if tokenizer_data is None: tokenizer_data = {} tokenizer_path = files.get_package_as_path("comfy.text_encoders.llama_tokenizer") - super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='llama', tokenizer_class=LlamaTokenizerFast, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, pad_token=128258, end_token=128009, min_length=min_length) + super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='llama', tokenizer_class=LlamaTokenizerFast, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, pad_token=128258, min_length=min_length) class LLAMAModel(sd1_clip.SDClipModel): @@ -46,18 +47,26 @@ class HunyuanVideoTokenizer: tokenizer_data = {} clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer) self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory) - self.llama_template = """<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: 1. The main content and theme of the video.2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.4. background environment, light, style and atmosphere.5. camera angles, movements, and transitions used in the video:<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n""" # 95 tokens + self.llama_template = """<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: 1. The main content and theme of the video.2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.4. background environment, light, style and atmosphere.5. camera angles, movements, and transitions used in the video:<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>""" # 95 tokens self.llama = LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=1) - def tokenize_with_weights(self, text: str, return_word_ids=False, llama_template=None, **kwargs): + def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, image_embeds=None, **kwargs): out = {} out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids) if llama_template is None: - llama_text = "{}{}".format(self.llama_template, text) + llama_text = self.llama_template.format(text) else: - llama_text = "{}{}".format(llama_template, text) - out["llama"] = self.llama.tokenize_with_weights(llama_text, return_word_ids) + llama_text = llama_template.format(text) + llama_text_tokens = self.llama.tokenize_with_weights(llama_text, return_word_ids) + embed_count = 0 + for r in llama_text_tokens: + for i in range(len(r)): + if r[i][0] == 128257: + if image_embeds is not None and embed_count < image_embeds.shape[0]: + r[i] = ({"type": "embedding", "data": image_embeds[embed_count], "original_type": "image"},) + r[i][1:] + embed_count += 1 + out["llama"] = llama_text_tokens return out def untokenize(self, token_weight_pair): @@ -93,20 +102,45 @@ class HunyuanVideoClipModel(torch.nn.Module): llama_out, llama_pooled, llama_extra_out = self.llama.encode_token_weights(token_weight_pairs_llama) template_end = 0 - for i, v in enumerate(token_weight_pairs_llama[0]): - if v[0] == 128007: # <|end_header_id|> - template_end = i + image_start = None + image_end = None + extra_sizes = 0 + user_end = 9999999999999 + + tok_pairs = token_weight_pairs_llama[0] + for i, v in enumerate(tok_pairs): + elem = v[0] + if not torch.is_tensor(elem): + if isinstance(elem, numbers.Integral): + if elem == 128006: + if tok_pairs[i + 1][0] == 882: + if tok_pairs[i + 2][0] == 128007: + template_end = i + 2 + user_end = -1 + if elem == 128009 and user_end == -1: + user_end = i + 1 + else: + if elem.get("original_type") == "image": + elem_size = elem.get("data").shape[0] + if image_start is None: + image_start = i + extra_sizes + image_end = i + elem_size + extra_sizes + extra_sizes += elem_size - 1 if llama_out.shape[1] > (template_end + 2): - if token_weight_pairs_llama[0][template_end + 1][0] == 271: + if tok_pairs[template_end + 1][0] == 271: template_end += 2 - llama_out = llama_out[:, template_end:] - llama_extra_out["attention_mask"] = llama_extra_out["attention_mask"][:, template_end:] + llama_output = llama_out[:, template_end + extra_sizes:user_end + extra_sizes] + llama_extra_out["attention_mask"] = llama_extra_out["attention_mask"][:, template_end + extra_sizes:user_end + extra_sizes] if llama_extra_out["attention_mask"].sum() == torch.numel(llama_extra_out["attention_mask"]): llama_extra_out.pop("attention_mask") # attention mask is useless if no masked elements + if image_start is not None: + image_output = llama_out[:, image_start: image_end] + llama_output = torch.cat([image_output[:, ::2], llama_output], dim=1) + l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l) - return llama_out, l_pooled, llama_extra_out + return llama_output, l_pooled, llama_extra_out def load_sd(self, sd): if "text_model.encoder.layers.1.mlp.fc1.weight" in sd: diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index 49ace1ee4..c887bf9a4 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -245,8 +245,11 @@ class Llama2_(nn.Module): self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype) # self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype) - def forward(self, x, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None): - x = self.embed_tokens(x, out_dtype=dtype) + def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None): + if embeds is not None: + x = embeds + else: + x = self.embed_tokens(x, out_dtype=dtype) if self.normalize_in: x *= self.config.hidden_size ** 0.5 diff --git a/comfy/text_encoders/t5.py b/comfy/text_encoders/t5.py index dd5da6ebc..f87a68db3 100644 --- a/comfy/text_encoders/t5.py +++ b/comfy/text_encoders/t5.py @@ -251,8 +251,11 @@ class T5(torch.nn.Module): def set_input_embeddings(self, embeddings): self.shared = embeddings - def forward(self, input_ids, *args, **kwargs): - x = self.shared(input_ids, out_dtype=kwargs.get("dtype", torch.float32)) + def forward(self, input_ids, attention_mask, embeds=None, num_tokens=None, **kwargs): + if input_ids is None: + x = embeds + else: + x = self.shared(input_ids, out_dtype=kwargs.get("dtype", torch.float32)) if self.dtype not in [torch.float32, torch.float16, torch.bfloat16]: x = torch.nan_to_num(x) # Fix for fp8 T5 base - return self.encoder(x, *args, **kwargs) + return self.encoder(x, attention_mask=attention_mask, **kwargs) diff --git a/comfy_extras/nodes/nodes_hunyuan.py b/comfy_extras/nodes/nodes_hunyuan.py index 5d286b290..493ac9f08 100644 --- a/comfy_extras/nodes/nodes_hunyuan.py +++ b/comfy_extras/nodes/nodes_hunyuan.py @@ -2,7 +2,8 @@ import torch import comfy.model_management from comfy.nodes.common import MAX_RESOLUTION - +from comfy.nodes import base_nodes as nodes +from comfy import node_helpers class CLIPTextEncodeHunyuanDiT: @classmethod @@ -39,7 +40,73 @@ class EmptyHunyuanLatentVideo: latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) return ({"samples":latent}, ) +PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = ( + "<|start_header_id|>system<|end_header_id|>\n\n\nDescribe the video by detailing the following aspects according to the reference image: " + "1. The main content and theme of the video." + "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects." + "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects." + "4. background environment, light, style and atmosphere." + "5. camera angles, movements, and transitions used in the video:<|eot_id|>\n\n" + "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" + "<|start_header_id|>assistant<|end_header_id|>\n\n" +) + +class TextEncodeHunyuanVideo_ImageToVideo: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "clip": ("CLIP", ), + "clip_vision_output": ("CLIP_VISION_OUTPUT", ), + "prompt": ("STRING", {"multiline": True, "dynamicPrompts": True}), + }} + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "encode" + + CATEGORY = "advanced/conditioning" + + def encode(self, clip, clip_vision_output, prompt): + tokens = clip.tokenize(prompt, llama_template=PROMPT_TEMPLATE_ENCODE_VIDEO_I2V, image_embeds=clip_vision_output.mm_projected) + return (clip.encode_from_tokens_scheduled(tokens), ) + + +class HunyuanImageToVideo: + @classmethod + def INPUT_TYPES(s): + return {"required": {"positive": ("CONDITIONING", ), + "vae": ("VAE", ), + "width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "length": ("INT", {"default": 53, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), + }, + "optional": {"start_image": ("IMAGE", ), + }} + + RETURN_TYPES = ("CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "latent") + FUNCTION = "encode" + + CATEGORY = "conditioning/video_models" + + def encode(self, positive, vae, width, height, length, batch_size, start_image=None): + latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + if start_image is not None: + start_image = comfy.utils.common_upscale(start_image[:length, :, :, :3].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + + concat_latent_image = vae.encode(start_image) + mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) + mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0 + + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask}) + + out_latent = {} + out_latent["samples"] = latent + return (positive, out_latent) + + NODE_CLASS_MAPPINGS = { "CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT, + "TextEncodeHunyuanVideo_ImageToVideo": TextEncodeHunyuanVideo_ImageToVideo, "EmptyHunyuanLatentVideo": EmptyHunyuanLatentVideo, + "HunyuanImageToVideo": HunyuanImageToVideo, } diff --git a/setup.py b/setup.py index 28d8bfce7..db52385f5 100644 --- a/setup.py +++ b/setup.py @@ -23,7 +23,7 @@ package_name = "comfyui" """ The current version. """ -version = "0.3.22" +version = "0.3.23" """ The package index to the torch built with AMD ROCm.