mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 14:20:49 +08:00
Merge branch 'master' of github.com:comfyanonymous/ComfyUI
This commit is contained in:
commit
0f85e7d2b0
@ -1 +1 @@
|
||||
__version__ = "0.3.22"
|
||||
__version__ = "0.3.23"
|
||||
|
||||
@ -98,8 +98,12 @@ class CLIPTextModel_(torch.nn.Module):
|
||||
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
|
||||
self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32):
|
||||
x = self.embeddings(input_tokens, dtype=dtype)
|
||||
def forward(self, input_tokens=None, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32):
|
||||
if embeds is not None:
|
||||
x = embeds + ops.cast_to(self.embeddings.position_embedding.weight, dtype=dtype, device=embeds.device)
|
||||
else:
|
||||
x = self.embeddings(input_tokens, dtype=dtype)
|
||||
|
||||
mask = None
|
||||
if attention_mask is not None:
|
||||
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
|
||||
@ -117,7 +121,10 @@ class CLIPTextModel_(torch.nn.Module):
|
||||
if i is not None and final_layer_norm_intermediate:
|
||||
i = self.final_layer_norm(i)
|
||||
|
||||
pooled_output = x[torch.arange(x.shape[0], device=x.device), (torch.round(input_tokens).to(dtype=torch.int, device=x.device) == self.eos_token_id).int().argmax(dim=-1),]
|
||||
if num_tokens is not None:
|
||||
pooled_output = x[list(range(x.shape[0])), list(map(lambda a: a - 1, num_tokens))]
|
||||
else:
|
||||
pooled_output = x[torch.arange(x.shape[0], device=x.device), (torch.round(input_tokens).to(dtype=torch.int, device=x.device) == self.eos_token_id).int().argmax(dim=-1),]
|
||||
return x, i, pooled_output
|
||||
|
||||
class CLIPTextModel(torch.nn.Module):
|
||||
@ -205,6 +212,15 @@ class CLIPVision(torch.nn.Module):
|
||||
pooled_output = self.post_layernorm(x[:, 0, :])
|
||||
return x, i, pooled_output
|
||||
|
||||
class LlavaProjector(torch.nn.Module):
|
||||
def __init__(self, in_dim, out_dim, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.linear_1 = operations.Linear(in_dim, out_dim, bias=True, device=device, dtype=dtype)
|
||||
self.linear_2 = operations.Linear(out_dim, out_dim, bias=True, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear_2(torch.nn.functional.gelu(self.linear_1(x[:, 1:])))
|
||||
|
||||
class CLIPVisionModelProjection(torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
@ -214,7 +230,16 @@ class CLIPVisionModelProjection(torch.nn.Module):
|
||||
else:
|
||||
self.visual_projection = lambda a: a
|
||||
|
||||
if "llava3" == config_dict.get("projector_type", None):
|
||||
self.multi_modal_projector = LlavaProjector(config_dict["hidden_size"], 4096, dtype, device, operations)
|
||||
else:
|
||||
self.multi_modal_projector = None
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
x = self.vision_model(*args, **kwargs)
|
||||
out = self.visual_projection(x[2])
|
||||
return (x[0], x[1], out)
|
||||
projected = None
|
||||
if self.multi_modal_projector is not None:
|
||||
projected = self.multi_modal_projector(x[1])
|
||||
|
||||
return (x[0], x[1], out, projected)
|
||||
|
||||
@ -78,6 +78,7 @@ class ClipVisionModel():
|
||||
outputs["last_hidden_state"] = out[0].to(model_management.intermediate_device())
|
||||
outputs["image_embeds"] = out[2].to(model_management.intermediate_device())
|
||||
outputs["penultimate_hidden_states"] = out[1].to(model_management.intermediate_device())
|
||||
outputs["mm_projected"] = out[3]
|
||||
return outputs
|
||||
|
||||
|
||||
@ -119,7 +120,10 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
|
||||
if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152:
|
||||
json_config = files.get_path_as_dict(None, "clip_vision_siglip_384.json")
|
||||
elif sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577:
|
||||
json_config = files.get_path_as_dict(None, "clip_vision_config_vitl_336.json")
|
||||
if "multi_modal_projector.linear_1.bias" in sd:
|
||||
json_config = files.get_path_as_dict(None, "clip_vision_config_vitl_336_llava.json")
|
||||
else:
|
||||
json_config = files.get_path_as_dict(None, "clip_vision_config_vitl_336.json")
|
||||
else:
|
||||
json_config = files.get_path_as_dict(None, "clip_vision_config_vitl.json")
|
||||
else:
|
||||
|
||||
19
comfy/clip_vision_config_vitl_336_llava.json
Normal file
19
comfy/clip_vision_config_vitl_336_llava.json
Normal file
@ -0,0 +1,19 @@
|
||||
{
|
||||
"attention_dropout": 0.0,
|
||||
"dropout": 0.0,
|
||||
"hidden_act": "quick_gelu",
|
||||
"hidden_size": 1024,
|
||||
"image_size": 336,
|
||||
"initializer_factor": 1.0,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 4096,
|
||||
"layer_norm_eps": 1e-5,
|
||||
"model_type": "clip_vision_model",
|
||||
"num_attention_heads": 16,
|
||||
"num_channels": 3,
|
||||
"num_hidden_layers": 24,
|
||||
"patch_size": 14,
|
||||
"projection_dim": 768,
|
||||
"projector_type": "llava3",
|
||||
"torch_dtype": "float32"
|
||||
}
|
||||
@ -936,6 +936,12 @@ class HunyuanVideo(BaseModel):
|
||||
return out
|
||||
|
||||
|
||||
class HunyuanVideoI2V(HunyuanVideo):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device)
|
||||
self.concat_keys = ("concat_image", "mask_inverted")
|
||||
|
||||
|
||||
class HunyuanVideoSkyreelsI2V(HunyuanVideo):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device)
|
||||
|
||||
@ -175,71 +175,92 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
self.layer_idx = self.options_default[1]
|
||||
self.return_projected_pooled = self.options_default[2]
|
||||
|
||||
def set_up_textual_embeddings(self, tokens, current_embeds):
|
||||
out_tokens = []
|
||||
next_new_token = token_dict_size = current_embeds.weight.shape[0]
|
||||
embedding_weights = []
|
||||
def process_tokens(self, tokens, device):
|
||||
end_token = self.special_tokens.get("end", None)
|
||||
if end_token is None:
|
||||
cmp_token = self.special_tokens.get("pad", -1)
|
||||
else:
|
||||
cmp_token = end_token
|
||||
|
||||
embeds_out = []
|
||||
attention_masks = []
|
||||
num_tokens = []
|
||||
|
||||
for x in tokens:
|
||||
attention_mask = []
|
||||
tokens_temp = []
|
||||
other_embeds = []
|
||||
eos = False
|
||||
index = 0
|
||||
for y in x:
|
||||
if isinstance(y, numbers.Integral):
|
||||
tokens_temp += [int(y)]
|
||||
else:
|
||||
if y.shape[0] == current_embeds.weight.shape[1]:
|
||||
embedding_weights += [y]
|
||||
tokens_temp += [next_new_token]
|
||||
next_new_token += 1
|
||||
if eos:
|
||||
attention_mask.append(0)
|
||||
else:
|
||||
logging.warning("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored {} != {}".format(y.shape[0], current_embeds.weight.shape[1]))
|
||||
while len(tokens_temp) < len(x):
|
||||
tokens_temp += [self.special_tokens["pad"]]
|
||||
out_tokens += [tokens_temp]
|
||||
attention_mask.append(1)
|
||||
token = int(y)
|
||||
tokens_temp += [token]
|
||||
if not eos and token == cmp_token:
|
||||
if end_token is None:
|
||||
attention_mask[-1] = 0
|
||||
eos = True
|
||||
else:
|
||||
other_embeds.append((index, y))
|
||||
index += 1
|
||||
|
||||
n = token_dict_size
|
||||
if len(embedding_weights) > 0:
|
||||
new_embedding = self.operations.Embedding(next_new_token + 1, current_embeds.weight.shape[1], device=current_embeds.weight.device, dtype=current_embeds.weight.dtype)
|
||||
new_embedding.weight[:token_dict_size] = current_embeds.weight
|
||||
for x in embedding_weights:
|
||||
new_embedding.weight[n] = x
|
||||
n += 1
|
||||
self.transformer.set_input_embeddings(new_embedding)
|
||||
tokens_embed = torch.tensor([tokens_temp], device=device, dtype=torch.long)
|
||||
tokens_embed = self.transformer.get_input_embeddings()(tokens_embed, out_dtype=torch.float32)
|
||||
index = 0
|
||||
pad_extra = 0
|
||||
for o in other_embeds:
|
||||
emb = o[1]
|
||||
if torch.is_tensor(emb):
|
||||
emb = {"type": "embedding", "data": emb}
|
||||
|
||||
processed_tokens = []
|
||||
for x in out_tokens:
|
||||
processed_tokens += [list(map(lambda a: n if a == -1 else a, x))] # The EOS token should always be the largest one
|
||||
emb_type = emb.get("type", None)
|
||||
if emb_type == "embedding":
|
||||
emb = emb.get("data", None)
|
||||
else:
|
||||
if hasattr(self.transformer, "preprocess_embed"):
|
||||
emb = self.transformer.preprocess_embed(emb, device=device)
|
||||
else:
|
||||
emb = None
|
||||
|
||||
return processed_tokens
|
||||
if emb is None:
|
||||
index += -1
|
||||
continue
|
||||
|
||||
ind = index + o[0]
|
||||
emb = emb.view(1, -1, emb.shape[-1]).to(device=device, dtype=torch.float32)
|
||||
emb_shape = emb.shape[1]
|
||||
if emb.shape[-1] == tokens_embed.shape[-1]:
|
||||
tokens_embed = torch.cat([tokens_embed[:, :ind], emb, tokens_embed[:, ind:]], dim=1)
|
||||
attention_mask = attention_mask[:ind] + [1] * emb_shape + attention_mask[ind:]
|
||||
index += emb_shape - 1
|
||||
else:
|
||||
index += -1
|
||||
pad_extra += emb_shape
|
||||
logging.warning("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored {} != {}".format(emb.shape[-1], tokens_embed.shape[-1]))
|
||||
|
||||
if pad_extra > 0:
|
||||
padd_embed = self.transformer.get_input_embeddings()(torch.tensor([[self.special_tokens["pad"]] * pad_extra], device=device, dtype=torch.long), out_dtype=torch.float32)
|
||||
tokens_embed = torch.cat([tokens_embed, padd_embed], dim=1)
|
||||
|
||||
embeds_out.append(tokens_embed)
|
||||
attention_masks.append(attention_mask)
|
||||
num_tokens.append(sum(attention_mask))
|
||||
|
||||
return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens
|
||||
|
||||
def forward(self, tokens):
|
||||
backup_embeds = self.transformer.get_input_embeddings()
|
||||
device = backup_embeds.weight.device
|
||||
tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
|
||||
tokens = torch.tensor(tokens, dtype=torch.long).to(device)
|
||||
|
||||
attention_mask = None
|
||||
if self.enable_attention_masks or self.zero_out_masked or self.return_attention_masks:
|
||||
attention_mask = torch.zeros_like(tokens)
|
||||
end_token = self.special_tokens.get("end", None)
|
||||
if end_token is None:
|
||||
cmp_token = self.special_tokens.get("pad", -1)
|
||||
else:
|
||||
cmp_token = end_token
|
||||
|
||||
for x in range(attention_mask.shape[0]):
|
||||
for y in range(attention_mask.shape[1]):
|
||||
attention_mask[x, y] = 1
|
||||
if tokens[x, y] == cmp_token:
|
||||
if end_token is None:
|
||||
attention_mask[x, y] = 0
|
||||
break
|
||||
device = self.transformer.get_input_embeddings().weight.device
|
||||
embeds, attention_mask, num_tokens = self.process_tokens(tokens, device)
|
||||
|
||||
attention_mask_model = None
|
||||
if self.enable_attention_masks:
|
||||
attention_mask_model = attention_mask
|
||||
|
||||
outputs = self.transformer(tokens, attention_mask_model, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32)
|
||||
self.transformer.set_input_embeddings(backup_embeds)
|
||||
outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32)
|
||||
|
||||
if self.layer == "last":
|
||||
z = outputs[0].float()
|
||||
|
||||
@ -887,6 +887,16 @@ class HunyuanVideo(supported_models_base.BASE):
|
||||
return supported_models_base.ClipTarget(hunyuan_video.HunyuanVideoTokenizer, hunyuan_video.hunyuan_video_clip(**hunyuan_detect))
|
||||
|
||||
|
||||
class HunyuanVideoI2V(HunyuanVideo):
|
||||
unet_config = {
|
||||
"image_model": "hunyuan_video",
|
||||
"in_channels": 33,
|
||||
}
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.HunyuanVideoI2V(self, device=device)
|
||||
return out
|
||||
|
||||
class HunyuanVideoSkyreelsI2V(HunyuanVideo):
|
||||
unet_config = {
|
||||
"image_model": "hunyuan_video",
|
||||
@ -1022,6 +1032,6 @@ class WAN21_I2V(WAN21_T2V):
|
||||
return out
|
||||
|
||||
|
||||
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V]
|
||||
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V]
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
||||
@ -95,8 +95,11 @@ class BertEmbeddings(torch.nn.Module):
|
||||
|
||||
self.LayerNorm = operations.LayerNorm(embed_dim, eps=layer_norm_eps, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, input_tokens, token_type_ids=None, dtype=None):
|
||||
x = self.word_embeddings(input_tokens, out_dtype=dtype)
|
||||
def forward(self, input_tokens, embeds=None, token_type_ids=None, dtype=None):
|
||||
if embeds is not None:
|
||||
x = embeds
|
||||
else:
|
||||
x = self.word_embeddings(input_tokens, out_dtype=dtype)
|
||||
x += ops.cast_to_input(self.position_embeddings.weight[:x.shape[1]], x)
|
||||
if token_type_ids is not None:
|
||||
x += self.token_type_embeddings(token_type_ids, out_dtype=x.dtype)
|
||||
@ -115,8 +118,8 @@ class BertModel_(torch.nn.Module):
|
||||
self.embeddings = BertEmbeddings(config_dict["vocab_size"], config_dict["max_position_embeddings"], config_dict["type_vocab_size"], config_dict["pad_token_id"], embed_dim, layer_norm_eps, dtype, device, operations)
|
||||
self.encoder = BertEncoder(config_dict["num_hidden_layers"], embed_dim, config_dict["intermediate_size"], config_dict["num_attention_heads"], layer_norm_eps, dtype, device, operations)
|
||||
|
||||
def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
|
||||
x = self.embeddings(input_tokens, dtype=dtype)
|
||||
def forward(self, input_tokens, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
|
||||
x = self.embeddings(input_tokens, embeds=embeds, dtype=dtype)
|
||||
mask = None
|
||||
if attention_mask is not None:
|
||||
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import torch
|
||||
import numbers
|
||||
from transformers import LlamaTokenizerFast
|
||||
|
||||
from .llama import Llama2
|
||||
@ -25,7 +26,7 @@ class LLAMA3Tokenizer(sd1_clip.SDTokenizer):
|
||||
if tokenizer_data is None:
|
||||
tokenizer_data = {}
|
||||
tokenizer_path = files.get_package_as_path("comfy.text_encoders.llama_tokenizer")
|
||||
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='llama', tokenizer_class=LlamaTokenizerFast, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, pad_token=128258, end_token=128009, min_length=min_length)
|
||||
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='llama', tokenizer_class=LlamaTokenizerFast, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, pad_token=128258, min_length=min_length)
|
||||
|
||||
|
||||
class LLAMAModel(sd1_clip.SDClipModel):
|
||||
@ -46,18 +47,26 @@ class HunyuanVideoTokenizer:
|
||||
tokenizer_data = {}
|
||||
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
|
||||
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
|
||||
self.llama_template = """<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: 1. The main content and theme of the video.2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.4. background environment, light, style and atmosphere.5. camera angles, movements, and transitions used in the video:<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n""" # 95 tokens
|
||||
self.llama_template = """<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: 1. The main content and theme of the video.2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.4. background environment, light, style and atmosphere.5. camera angles, movements, and transitions used in the video:<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>""" # 95 tokens
|
||||
self.llama = LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=1)
|
||||
|
||||
def tokenize_with_weights(self, text: str, return_word_ids=False, llama_template=None, **kwargs):
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, image_embeds=None, **kwargs):
|
||||
out = {}
|
||||
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
|
||||
|
||||
if llama_template is None:
|
||||
llama_text = "{}{}".format(self.llama_template, text)
|
||||
llama_text = self.llama_template.format(text)
|
||||
else:
|
||||
llama_text = "{}{}".format(llama_template, text)
|
||||
out["llama"] = self.llama.tokenize_with_weights(llama_text, return_word_ids)
|
||||
llama_text = llama_template.format(text)
|
||||
llama_text_tokens = self.llama.tokenize_with_weights(llama_text, return_word_ids)
|
||||
embed_count = 0
|
||||
for r in llama_text_tokens:
|
||||
for i in range(len(r)):
|
||||
if r[i][0] == 128257:
|
||||
if image_embeds is not None and embed_count < image_embeds.shape[0]:
|
||||
r[i] = ({"type": "embedding", "data": image_embeds[embed_count], "original_type": "image"},) + r[i][1:]
|
||||
embed_count += 1
|
||||
out["llama"] = llama_text_tokens
|
||||
return out
|
||||
|
||||
def untokenize(self, token_weight_pair):
|
||||
@ -93,20 +102,45 @@ class HunyuanVideoClipModel(torch.nn.Module):
|
||||
llama_out, llama_pooled, llama_extra_out = self.llama.encode_token_weights(token_weight_pairs_llama)
|
||||
|
||||
template_end = 0
|
||||
for i, v in enumerate(token_weight_pairs_llama[0]):
|
||||
if v[0] == 128007: # <|end_header_id|>
|
||||
template_end = i
|
||||
image_start = None
|
||||
image_end = None
|
||||
extra_sizes = 0
|
||||
user_end = 9999999999999
|
||||
|
||||
tok_pairs = token_weight_pairs_llama[0]
|
||||
for i, v in enumerate(tok_pairs):
|
||||
elem = v[0]
|
||||
if not torch.is_tensor(elem):
|
||||
if isinstance(elem, numbers.Integral):
|
||||
if elem == 128006:
|
||||
if tok_pairs[i + 1][0] == 882:
|
||||
if tok_pairs[i + 2][0] == 128007:
|
||||
template_end = i + 2
|
||||
user_end = -1
|
||||
if elem == 128009 and user_end == -1:
|
||||
user_end = i + 1
|
||||
else:
|
||||
if elem.get("original_type") == "image":
|
||||
elem_size = elem.get("data").shape[0]
|
||||
if image_start is None:
|
||||
image_start = i + extra_sizes
|
||||
image_end = i + elem_size + extra_sizes
|
||||
extra_sizes += elem_size - 1
|
||||
|
||||
if llama_out.shape[1] > (template_end + 2):
|
||||
if token_weight_pairs_llama[0][template_end + 1][0] == 271:
|
||||
if tok_pairs[template_end + 1][0] == 271:
|
||||
template_end += 2
|
||||
llama_out = llama_out[:, template_end:]
|
||||
llama_extra_out["attention_mask"] = llama_extra_out["attention_mask"][:, template_end:]
|
||||
llama_output = llama_out[:, template_end + extra_sizes:user_end + extra_sizes]
|
||||
llama_extra_out["attention_mask"] = llama_extra_out["attention_mask"][:, template_end + extra_sizes:user_end + extra_sizes]
|
||||
if llama_extra_out["attention_mask"].sum() == torch.numel(llama_extra_out["attention_mask"]):
|
||||
llama_extra_out.pop("attention_mask") # attention mask is useless if no masked elements
|
||||
|
||||
if image_start is not None:
|
||||
image_output = llama_out[:, image_start: image_end]
|
||||
llama_output = torch.cat([image_output[:, ::2], llama_output], dim=1)
|
||||
|
||||
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
|
||||
return llama_out, l_pooled, llama_extra_out
|
||||
return llama_output, l_pooled, llama_extra_out
|
||||
|
||||
def load_sd(self, sd):
|
||||
if "text_model.encoder.layers.1.mlp.fc1.weight" in sd:
|
||||
|
||||
@ -245,8 +245,11 @@ class Llama2_(nn.Module):
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
|
||||
# self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
|
||||
x = self.embed_tokens(x, out_dtype=dtype)
|
||||
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
|
||||
if embeds is not None:
|
||||
x = embeds
|
||||
else:
|
||||
x = self.embed_tokens(x, out_dtype=dtype)
|
||||
|
||||
if self.normalize_in:
|
||||
x *= self.config.hidden_size ** 0.5
|
||||
|
||||
@ -251,8 +251,11 @@ class T5(torch.nn.Module):
|
||||
def set_input_embeddings(self, embeddings):
|
||||
self.shared = embeddings
|
||||
|
||||
def forward(self, input_ids, *args, **kwargs):
|
||||
x = self.shared(input_ids, out_dtype=kwargs.get("dtype", torch.float32))
|
||||
def forward(self, input_ids, attention_mask, embeds=None, num_tokens=None, **kwargs):
|
||||
if input_ids is None:
|
||||
x = embeds
|
||||
else:
|
||||
x = self.shared(input_ids, out_dtype=kwargs.get("dtype", torch.float32))
|
||||
if self.dtype not in [torch.float32, torch.float16, torch.bfloat16]:
|
||||
x = torch.nan_to_num(x) # Fix for fp8 T5 base
|
||||
return self.encoder(x, *args, **kwargs)
|
||||
return self.encoder(x, attention_mask=attention_mask, **kwargs)
|
||||
|
||||
@ -2,7 +2,8 @@ import torch
|
||||
|
||||
import comfy.model_management
|
||||
from comfy.nodes.common import MAX_RESOLUTION
|
||||
|
||||
from comfy.nodes import base_nodes as nodes
|
||||
from comfy import node_helpers
|
||||
|
||||
class CLIPTextEncodeHunyuanDiT:
|
||||
@classmethod
|
||||
@ -39,7 +40,73 @@ class EmptyHunyuanLatentVideo:
|
||||
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
return ({"samples":latent}, )
|
||||
|
||||
PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = (
|
||||
"<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the video by detailing the following aspects according to the reference image: "
|
||||
"1. The main content and theme of the video."
|
||||
"2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
|
||||
"3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
|
||||
"4. background environment, light, style and atmosphere."
|
||||
"5. camera angles, movements, and transitions used in the video:<|eot_id|>\n\n"
|
||||
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
|
||||
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
|
||||
class TextEncodeHunyuanVideo_ImageToVideo:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"clip": ("CLIP", ),
|
||||
"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
|
||||
"prompt": ("STRING", {"multiline": True, "dynamicPrompts": True}),
|
||||
}}
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "advanced/conditioning"
|
||||
|
||||
def encode(self, clip, clip_vision_output, prompt):
|
||||
tokens = clip.tokenize(prompt, llama_template=PROMPT_TEMPLATE_ENCODE_VIDEO_I2V, image_embeds=clip_vision_output.mm_projected)
|
||||
return (clip.encode_from_tokens_scheduled(tokens), )
|
||||
|
||||
|
||||
class HunyuanImageToVideo:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"positive": ("CONDITIONING", ),
|
||||
"vae": ("VAE", ),
|
||||
"width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
||||
"length": ("INT", {"default": 53, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||
},
|
||||
"optional": {"start_image": ("IMAGE", ),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING", "LATENT")
|
||||
RETURN_NAMES = ("positive", "latent")
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "conditioning/video_models"
|
||||
|
||||
def encode(self, positive, vae, width, height, length, batch_size, start_image=None):
|
||||
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
if start_image is not None:
|
||||
start_image = comfy.utils.common_upscale(start_image[:length, :, :, :3].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
|
||||
concat_latent_image = vae.encode(start_image)
|
||||
mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
|
||||
mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0
|
||||
|
||||
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
|
||||
|
||||
out_latent = {}
|
||||
out_latent["samples"] = latent
|
||||
return (positive, out_latent)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT,
|
||||
"TextEncodeHunyuanVideo_ImageToVideo": TextEncodeHunyuanVideo_ImageToVideo,
|
||||
"EmptyHunyuanLatentVideo": EmptyHunyuanLatentVideo,
|
||||
"HunyuanImageToVideo": HunyuanImageToVideo,
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user