mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-10 19:57:42 +08:00
HunyuanImage2.1: Fix refiner template
This commit is contained in:
parent
0836853fec
commit
192b74ccc1
@ -836,7 +836,7 @@ class CLIPType(Enum):
|
|||||||
OMNIGEN2 = 17
|
OMNIGEN2 = 17
|
||||||
QWEN_IMAGE = 18
|
QWEN_IMAGE = 18
|
||||||
HUNYUAN_IMAGE = 19
|
HUNYUAN_IMAGE = 19
|
||||||
|
HUNYUAN_IMAGE_REFINER = 20
|
||||||
|
|
||||||
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
||||||
clip_data = []
|
clip_data = []
|
||||||
@ -995,6 +995,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
|||||||
if clip_type == CLIPType.HUNYUAN_IMAGE:
|
if clip_type == CLIPType.HUNYUAN_IMAGE:
|
||||||
clip_target.clip = comfy.text_encoders.hunyuan_image.te(byt5=False, **llama_detect(clip_data))
|
clip_target.clip = comfy.text_encoders.hunyuan_image.te(byt5=False, **llama_detect(clip_data))
|
||||||
clip_target.tokenizer = comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer
|
clip_target.tokenizer = comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer
|
||||||
|
elif clip_type == CLIPType.HUNYUAN_IMAGE_REFINER:
|
||||||
|
clip_target.clip = comfy.text_encoders.hunyuan_image.te(byt5=False, refiner=True, **llama_detect(clip_data))
|
||||||
|
clip_target.tokenizer = comfy.text_encoders.hunyuan_image.HunyuanImageRefinerTokenizer
|
||||||
else:
|
else:
|
||||||
clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data))
|
clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data))
|
||||||
clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer
|
clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer
|
||||||
|
|||||||
@ -4,6 +4,8 @@ from .qwen_image import QwenImageTokenizer, QwenImageTEModel
|
|||||||
from transformers import ByT5Tokenizer
|
from transformers import ByT5Tokenizer
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
import torch
|
||||||
|
import numbers
|
||||||
|
|
||||||
class ByT5SmallTokenizer(sd1_clip.SDTokenizer):
|
class ByT5SmallTokenizer(sd1_clip.SDTokenizer):
|
||||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
@ -38,6 +40,13 @@ class HunyuanImageTokenizer(QwenImageTokenizer):
|
|||||||
out['byt5'] = self.byt5.tokenize_with_weights(''.join(map(lambda a: 'Text "{}". '.format(a), text_prompt_texts)), return_word_ids, **kwargs)
|
out['byt5'] = self.byt5.tokenize_with_weights(''.join(map(lambda a: 'Text "{}". '.format(a), text_prompt_texts)), return_word_ids, **kwargs)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
class HunyuanImageRefinerTokenizer(HunyuanImageTokenizer):
    """Tokenizer variant that installs the refiner's prompt template.

    Construction is delegated unchanged to ``HunyuanImageTokenizer``; the only
    difference is the ``llama_template`` attribute assigned afterwards.
    """

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)

        # Adjacent string literals are joined at compile time, so the two
        # pieces below add up to one uninterrupted template string.
        # NOTE(review): "{}" is presumably filled with the user prompt by the
        # parent tokenizer -- confirm against QwenImageTokenizer.
        system_turn = (
            "<|start_header_id|>system<|end_header_id|>\n\n"
            "Describe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:"
            "<|eot_id|>\n"
        )
        user_turn = (
            "<|start_header_id|>user<|end_header_id|>\n\n"
            "{}<|eot_id|>"
        )
        self.llama_template = system_turn + user_turn
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Qwen25_7BVLIModel(sd1_clip.SDClipModel):
|
class Qwen25_7BVLIModel(sd1_clip.SDClipModel):
|
||||||
def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}):
|
def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}):
|
||||||
llama_scaled_fp8 = model_options.get("qwen_scaled_fp8", None)
|
llama_scaled_fp8 = model_options.get("qwen_scaled_fp8", None)
|
||||||
@ -53,9 +62,9 @@ class ByT5SmallModel(sd1_clip.SDClipModel):
|
|||||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, model_options=model_options, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True)
|
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, model_options=model_options, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True)
|
||||||
|
|
||||||
|
|
||||||
class HunyuanImageTEModel(QwenImageTEModel):
|
class HunyuanImageTEModel(sd1_clip.SD1ClipModel):
|
||||||
def __init__(self, byt5=True, device="cpu", dtype=None, model_options={}):
|
def __init__(self, byt5=True, device="cpu", dtype=None, model_options={}):
|
||||||
super(QwenImageTEModel, self).__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options)
|
super().__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options)
|
||||||
|
|
||||||
if byt5:
|
if byt5:
|
||||||
self.byt5_small = ByT5SmallModel(device=device, dtype=dtype, model_options=model_options)
|
self.byt5_small = ByT5SmallModel(device=device, dtype=dtype, model_options=model_options)
|
||||||
@ -63,11 +72,35 @@ class HunyuanImageTEModel(QwenImageTEModel):
|
|||||||
self.byt5_small = None
|
self.byt5_small = None
|
||||||
|
|
||||||
def encode_token_weights(self, token_weight_pairs):
    """Encode the prompt and strip the leading template positions.

    The parent encoder produces ``(out, pooled, extra)``. The "qwen25_7b"
    token list is scanned for token id 151644; the index of its second
    occurrence (or its only one) marks where the template ends. If the two
    tokens after that marker are 872 and 198, they are skipped as well.
    Everything before the resulting index is cut from ``out`` and from
    ``extra["attention_mask"]``.

    When a ByT5 model is attached and "byt5" tokens are supplied, their
    encoding is stored in ``extra["conditioning_byt5small"]``.

    Args:
        token_weight_pairs: mapping with a "qwen25_7b" entry and optionally
            a "byt5" entry, each holding token/weight pairs.

    Returns:
        Tuple ``(out, pooled, extra)`` with the template prefix removed.
    """
    out, pooled, extra = super().encode_token_weights(token_weight_pairs)
    tok_pairs = token_weight_pairs["qwen25_7b"][0]

    # Start from "trim nothing": previously template_end was only assigned
    # inside the loop, so a prompt without the marker token raised
    # UnboundLocalError further down.
    template_end = 0
    count_im_start = 0
    for i, v in enumerate(tok_pairs):
        elem = v[0]
        # Entries may be tensors rather than plain token ids; only integer
        # ids can match the marker.
        if torch.is_tensor(elem) or not isinstance(elem, numbers.Integral):
            continue
        # NOTE(review): 151644 is presumably the <|im_start|> id, judging by
        # the counter's name -- confirm against the tokenizer vocabulary.
        if elem == 151644 and count_im_start < 2:
            template_end = i
            count_im_start += 1

    # Only look past the marker when one was actually found; otherwise the
    # fallback index 0 could trim three real prompt tokens by coincidence.
    if count_im_start > 0 and out.shape[1] > (template_end + 3):
        if tok_pairs[template_end + 1][0] == 872 and tok_pairs[template_end + 2][0] == 198:
            template_end += 3

    out = out[:, template_end:]

    extra["attention_mask"] = extra["attention_mask"][:, template_end:]
    if extra["attention_mask"].sum() == torch.numel(extra["attention_mask"]):
        extra.pop("attention_mask")  # attention mask is useless if no masked elements

    if self.byt5_small is not None and "byt5" in token_weight_pairs:
        byt5_out = self.byt5_small.encode_token_weights(token_weight_pairs["byt5"])
        extra["conditioning_byt5small"] = byt5_out[0]
    return out, pooled, extra
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def set_clip_options(self, options):
|
def set_clip_options(self, options):
|
||||||
super().set_clip_options(options)
|
super().set_clip_options(options)
|
||||||
@ -84,9 +117,33 @@ class HunyuanImageTEModel(QwenImageTEModel):
|
|||||||
return self.byt5_small.load_sd(sd)
|
return self.byt5_small.load_sd(sd)
|
||||||
else:
|
else:
|
||||||
return super().load_sd(sd)
|
return super().load_sd(sd)
|
||||||
|
class HunyuanImageRefinerTEModel(sd1_clip.SD1ClipModel):
    """Text-encoder wrapper for the refiner that trims the prompt template.

    Wraps a single ``Qwen25_7BVLIModel`` registered under the name
    "qwen25_7b" and post-processes its output so that positions belonging
    to the template prefix are removed.
    """

    def __init__(self, device="cpu", dtype=None, model_options={}):
        super().__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options)

    def encode_token_weights(self, token_weight_pairs):
        """Encode the prompt and drop everything before the marker token.

        The first occurrence of token id 6171 in the "qwen25_7b" token list
        is located; the output and attention mask are kept from one position
        before that marker onwards.

        Args:
            token_weight_pairs: mapping with a "qwen25_7b" entry holding
                token/weight pairs.

        Returns:
            Tuple ``(out, pooled, extra)`` with the template prefix removed.
            ``extra["attention_mask"]`` is dropped when nothing is masked.
        """
        out, pooled, extra = super().encode_token_weights(token_weight_pairs)
        tok_pairs = token_weight_pairs["qwen25_7b"][0]

        # Default to 0 so a prompt without the marker keeps the full output
        # instead of raising UnboundLocalError on an unassigned template_end.
        template_end = 0
        for i, v in enumerate(tok_pairs):
            elem = v[0]
            # Entries may be tensors rather than plain ids; skip those.
            if not torch.is_tensor(elem) and isinstance(elem, numbers.Integral) and elem == 6171:
                template_end = i
                break

        # Clamp at 0: a marker at index 0 (or no marker) would otherwise give
        # a start of -1, which slices away everything but the last position.
        start = max(template_end - 1, 0)
        out = out[:, start:]

        extra["attention_mask"] = extra["attention_mask"][:, start:]
        if extra["attention_mask"].sum() == torch.numel(extra["attention_mask"]):
            extra.pop("attention_mask")  # attention mask is useless if no masked elements

        return out, pooled, extra
|
||||||
|
|
||||||
|
|
||||||
|
def te(byt5=True, dtype_llama=None, llama_scaled_fp8=None, refiner=False):
|
||||||
|
class HunyuanImageTEModel_(HunyuanImageTEModel):
|
||||||
|
|
||||||
def te(byt5=True, dtype_llama=None, llama_scaled_fp8=None):
|
|
||||||
class QwenImageTEModel_(HunyuanImageTEModel):
|
|
||||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||||
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
|
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
|
||||||
model_options = model_options.copy()
|
model_options = model_options.copy()
|
||||||
@ -94,4 +151,14 @@ def te(byt5=True, dtype_llama=None, llama_scaled_fp8=None):
|
|||||||
if dtype_llama is not None:
|
if dtype_llama is not None:
|
||||||
dtype = dtype_llama
|
dtype = dtype_llama
|
||||||
super().__init__(byt5=byt5, device=device, dtype=dtype, model_options=model_options)
|
super().__init__(byt5=byt5, device=device, dtype=dtype, model_options=model_options)
|
||||||
return QwenImageTEModel_
|
class HunyuanImageTEModel_refiner(HunyuanImageRefinerTEModel):
|
||||||
|
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||||
|
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
|
||||||
|
model_options = model_options.copy()
|
||||||
|
model_options["qwen_scaled_fp8"] = llama_scaled_fp8
|
||||||
|
if dtype_llama is not None:
|
||||||
|
dtype = dtype_llama
|
||||||
|
assert refiner, "refiner must be True"
|
||||||
|
assert not byt5, "byt5 must be False"
|
||||||
|
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||||
|
return HunyuanImageTEModel_refiner if refiner else HunyuanImageTEModel_
|
||||||
|
|||||||
@ -199,9 +199,8 @@ class CLIPTextEncodeHunyuanDiTWithTextDetection:
|
|||||||
"clip": ("CLIP", ),
|
"clip": ("CLIP", ),
|
||||||
"text": ("STRING", {"multiline": True, "dynamicPrompts": True}),
|
"text": ("STRING", {"multiline": True, "dynamicPrompts": True}),
|
||||||
}}
|
}}
|
||||||
|
RETURN_TYPES = ("CONDITIONING", "HAS_QUOTED_TEXT", "STRING")
|
||||||
RETURN_TYPES = ("CONDITIONING", "HAS_QUOTED_TEXT")
|
RETURN_NAMES = ("conditioning", "has_quoted_text", "text")
|
||||||
RETURN_NAMES = ("conditioning", "has_quoted_text")
|
|
||||||
FUNCTION = "encode"
|
FUNCTION = "encode"
|
||||||
|
|
||||||
CATEGORY = "advanced/conditioning/hunyuan"
|
CATEGORY = "advanced/conditioning/hunyuan"
|
||||||
@ -237,7 +236,35 @@ class CLIPTextEncodeHunyuanDiTWithTextDetection:
|
|||||||
n[1]['has_quoted_text'] = has_quoted_text
|
n[1]['has_quoted_text'] = has_quoted_text
|
||||||
c.append(n)
|
c.append(n)
|
||||||
|
|
||||||
return (c, has_quoted_text)
|
return (c, has_quoted_text, text)
|
||||||
|
|
||||||
|
class CLIPTextEncodeHunyuanImageRefiner:
    """Node that turns a text prompt into conditioning via the given CLIP object."""

    RETURN_TYPES = ("CONDITIONING",)
    RETURN_NAMES = ("conditioning",)
    FUNCTION = "encode"

    CATEGORY = "advanced/conditioning/hunyuan"

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the two required inputs: a CLIP object and a string."""
        required = {
            "clip": ("CLIP", ),
            "text": ("STRING", ),
        }
        return {"required": required}

    def encode(self, clip, text):
        """Tokenize ``text``, encode it, and return the conditioning list.

        Each conditioning entry is rebuilt as a fresh two-element list whose
        second element is a shallow copy of the original options mapping, so
        the caller's result does not alias what ``clip`` returned.

        Returns:
            A one-element tuple containing the list of ``[cond, options]``
            entries.
        """
        scheduled = clip.encode_from_tokens_scheduled(clip.tokenize(text))
        copied = [[entry[0], entry[1].copy()] for entry in scheduled]
        return (copied, )
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class EmptyHunyuanLatentVideo:
|
class EmptyHunyuanLatentVideo:
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -370,13 +397,13 @@ class HunyuanRefinerLatent:
|
|||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
"HunyuanMixModeAPG": "Hunyuan Mix Mode APG",
|
"HunyuanMixModeAPG": "Hunyuan Mix Mode APG",
|
||||||
"HunyuanStepBasedAPG": "Hunyuan Step Based APG",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"HunyuanMixModeAPG": HunyuanMixModeAPG,
|
"HunyuanMixModeAPG": HunyuanMixModeAPG,
|
||||||
"CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT,
|
"CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT,
|
||||||
"CLIPTextEncodeHunyuanDiTWithTextDetection": CLIPTextEncodeHunyuanDiTWithTextDetection,
|
"CLIPTextEncodeHunyuanDiTWithTextDetection": CLIPTextEncodeHunyuanDiTWithTextDetection,
|
||||||
|
"CLIPTextEncodeHunyuanImageRefiner": CLIPTextEncodeHunyuanImageRefiner,
|
||||||
"TextEncodeHunyuanVideo_ImageToVideo": TextEncodeHunyuanVideo_ImageToVideo,
|
"TextEncodeHunyuanVideo_ImageToVideo": TextEncodeHunyuanVideo_ImageToVideo,
|
||||||
"EmptyHunyuanLatentVideo": EmptyHunyuanLatentVideo,
|
"EmptyHunyuanLatentVideo": EmptyHunyuanLatentVideo,
|
||||||
"HunyuanImageToVideo": HunyuanImageToVideo,
|
"HunyuanImageToVideo": HunyuanImageToVideo,
|
||||||
|
|||||||
2
nodes.py
2
nodes.py
@ -929,7 +929,7 @@ class CLIPLoader:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
|
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
|
||||||
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image"], ),
|
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image","hunyuan_image_refiner"], ),
|
||||||
},
|
},
|
||||||
"optional": {
|
"optional": {
|
||||||
"device": (["default", "cpu"], {"advanced": True}),
|
"device": (["default", "cpu"], {"advanced": True}),
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user