Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-02-17 08:52:34 +08:00)
Merge 780b7d7d28 into cd66d72b46
Commit 8df9de3e38
comfy/sd.py

@@ -836,7 +836,7 @@ class CLIPType(Enum):
     OMNIGEN2 = 17
     QWEN_IMAGE = 18
     HUNYUAN_IMAGE = 19
+    HUNYUAN_IMAGE_REFINER = 20

 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
     clip_data = []
@@ -995,6 +995,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
         if clip_type == CLIPType.HUNYUAN_IMAGE:
             clip_target.clip = comfy.text_encoders.hunyuan_image.te(byt5=False, **llama_detect(clip_data))
             clip_target.tokenizer = comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer
+        elif clip_type == CLIPType.HUNYUAN_IMAGE_REFINER:
+            clip_target.clip = comfy.text_encoders.hunyuan_image.te(byt5=False, refiner=True, **llama_detect(clip_data))
+            clip_target.tokenizer = comfy.text_encoders.hunyuan_image.HunyuanImageRefinerTokenizer
         else:
             clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data))
             clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer
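The two comfy/sd.py hunks above add the enum member and the loader branch that selects the refiner text encoder. A minimal sketch of exercising that path directly from Python; the checkpoint filename below is a placeholder, and any state dict accepted by the qwen25_7b encoder would do:

import comfy.sd
from comfy.sd import CLIPType

# Hypothetical checkpoint path; substitute a real Qwen2.5-VL 7B text encoder file.
clip = comfy.sd.load_clip(
    ckpt_paths=["models/text_encoders/qwen_2.5_vl_7b.safetensors"],
    clip_type=CLIPType.HUNYUAN_IMAGE_REFINER,
)
tokens = clip.tokenize("a red sports car parked by the sea")
conditioning = clip.encode_from_tokens_scheduled(tokens)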
comfy/text_encoders/hunyuan_image.py

@@ -4,6 +4,8 @@ from .qwen_image import QwenImageTokenizer, QwenImageTEModel
 from transformers import ByT5Tokenizer
 import os
 import re
+import torch
+import numbers

 class ByT5SmallTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
@@ -38,6 +40,13 @@ class HunyuanImageTokenizer(QwenImageTokenizer):
         out['byt5'] = self.byt5.tokenize_with_weights(''.join(map(lambda a: 'Text "{}". '.format(a), text_prompt_texts)), return_word_ids, **kwargs)
         return out

+class HunyuanImageRefinerTokenizer(HunyuanImageTokenizer):
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
+        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
+        self.llama_template = "<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
+
+
+
 class Qwen25_7BVLIModel(sd1_clip.SDClipModel):
     def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}):
         llama_scaled_fp8 = model_options.get("qwen_scaled_fp8", None)
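HunyuanImageRefinerTokenizer only swaps in a different llama chat template; everything else is inherited from HunyuanImageTokenizer. A standalone illustration of what that template produces for a user prompt (the .format call normally happens inside the tokenizer machinery):

llama_template = (
    "<|start_header_id|>system<|end_header_id|>\n\n"
    "Describe the image by detailing the color, shape, size, texture, quantity, text, "
    "spatial relationships of the objects and background:<|eot_id|>\n"
    "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
)
print(llama_template.format("a watercolor painting of a fox"))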
@@ -53,9 +62,9 @@ class ByT5SmallModel(sd1_clip.SDClipModel):
         super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, model_options=model_options, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True)


-class HunyuanImageTEModel(QwenImageTEModel):
+class HunyuanImageTEModel(sd1_clip.SD1ClipModel):
     def __init__(self, byt5=True, device="cpu", dtype=None, model_options={}):
-        super(QwenImageTEModel, self).__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options)
+        super().__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options)

         if byt5:
             self.byt5_small = ByT5SmallModel(device=device, dtype=dtype, model_options=model_options)
@@ -63,11 +72,35 @@ class HunyuanImageTEModel(QwenImageTEModel):
             self.byt5_small = None

     def encode_token_weights(self, token_weight_pairs):
-        cond, p, extra = super().encode_token_weights(token_weight_pairs)
+        out, pooled, extra = super().encode_token_weights(token_weight_pairs)
+        tok_pairs = token_weight_pairs["qwen25_7b"][0]
+        count_im_start = 0
+        for i, v in enumerate(tok_pairs):
+            elem = v[0]
+            if not torch.is_tensor(elem):
+                if isinstance(elem, numbers.Integral):
+                    if elem == 151644 and count_im_start < 2:
+                        template_end = i
+                        count_im_start += 1
+
+        if out.shape[1] > (template_end + 3):
+            if tok_pairs[template_end + 1][0] == 872:
+                if tok_pairs[template_end + 2][0] == 198:
+                    template_end += 3
+
+        out = out[:, template_end:]
+
+        extra["attention_mask"] = extra["attention_mask"][:, template_end:]
+        if extra["attention_mask"].sum() == torch.numel(extra["attention_mask"]):
+            extra.pop("attention_mask")  # attention mask is useless if no masked elements
+        # noqa: W293
+
         if self.byt5_small is not None and "byt5" in token_weight_pairs:
-            out = self.byt5_small.encode_token_weights(token_weight_pairs["byt5"])
-            extra["conditioning_byt5small"] = out[0]
-        return cond, p, extra
+            byt5_out = self.byt5_small.encode_token_weights(token_weight_pairs["byt5"])
+            extra["conditioning_byt5small"] = byt5_out[0]
+        return out, pooled, extra

     def set_clip_options(self, options):
         super().set_clip_options(options)
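The loop above locates the end of the fixed chat template in the Qwen2.5-VL token stream and slices the returned embeddings (and attention mask) so only the user prompt onward conditions the model. A self-contained sketch of the same trimming logic; the ids come from the diff, while the comments interpreting them (151644 as <|im_start|>, 872 as "user", 198 as a newline) are assumptions:

import torch

def trim_template(out: torch.Tensor, token_ids: list) -> torch.Tensor:
    # Find the second <|im_start|> (assumed id 151644), i.e. the start of the user turn.
    template_end = 0
    count_im_start = 0
    for i, tok in enumerate(token_ids):
        if tok == 151644 and count_im_start < 2:
            template_end = i
            count_im_start += 1
    # Skip the "user" + newline tokens that follow it, if present (assumed ids 872, 198).
    if out.shape[1] > template_end + 3 and token_ids[template_end + 1] == 872 and token_ids[template_end + 2] == 198:
        template_end += 3
    return out[:, template_end:]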
@@ -84,9 +117,33 @@ class HunyuanImageTEModel(QwenImageTEModel):
             return self.byt5_small.load_sd(sd)
         else:
             return super().load_sd(sd)

+class HunyuanImageRefinerTEModel(sd1_clip.SD1ClipModel):
+    def __init__(self, device="cpu", dtype=None, model_options={}):
+        super().__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options)
+
+    def encode_token_weights(self, token_weight_pairs):
+        out, pooled, extra = super().encode_token_weights(token_weight_pairs)
+        tok_pairs = token_weight_pairs["qwen25_7b"][0]
+        for i, v in enumerate(tok_pairs):
+            elem = v[0]
+            if not torch.is_tensor(elem):
+                if isinstance(elem, numbers.Integral):
+                    if elem == 6171:
+                        template_end = i
+                        break
+
+        out = out[:, template_end-1:]
+
+        extra["attention_mask"] = extra["attention_mask"][:, template_end-1:]
+        if extra["attention_mask"].sum() == torch.numel(extra["attention_mask"]):
+            extra.pop("attention_mask")  # attention mask is useless if no masked elements
+
+        return out, pooled, extra
+
+
-def te(byt5=True, dtype_llama=None, llama_scaled_fp8=None):
-    class QwenImageTEModel_(HunyuanImageTEModel):
+def te(byt5=True, dtype_llama=None, llama_scaled_fp8=None, refiner=False):
+    class HunyuanImageTEModel_(HunyuanImageTEModel):
         def __init__(self, device="cpu", dtype=None, model_options={}):
             if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
                 model_options = model_options.copy()
@@ -94,4 +151,14 @@ def te(byt5=True, dtype_llama=None, llama_scaled_fp8=None):
             if dtype_llama is not None:
                 dtype = dtype_llama
             super().__init__(byt5=byt5, device=device, dtype=dtype, model_options=model_options)
-    return QwenImageTEModel_
+    class HunyuanImageTEModel_refiner(HunyuanImageRefinerTEModel):
+        def __init__(self, device="cpu", dtype=None, model_options={}):
+            if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
+                model_options = model_options.copy()
+                model_options["qwen_scaled_fp8"] = llama_scaled_fp8
+            if dtype_llama is not None:
+                dtype = dtype_llama
+            assert refiner, "refiner must be True"
+            assert not byt5, "byt5 must be False"
+            super().__init__(device=device, dtype=dtype, model_options=model_options)
+    return HunyuanImageTEModel_refiner if refiner else HunyuanImageTEModel_
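A brief sketch of how the te() factory above is meant to be called for each variant, mirroring the two call sites added in comfy/sd.py rather than introducing any new API:

import comfy.text_encoders.hunyuan_image as hunyuan_image

BaseTE = hunyuan_image.te(byt5=False)                   # returns the HunyuanImageTEModel_ class
RefinerTE = hunyuan_image.te(byt5=False, refiner=True)  # returns the HunyuanImageTEModel_refiner class
# Instantiating either class builds the underlying Qwen2.5-VL 7B text encoder,
# which in practice happens inside load_text_encoder_state_dicts().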
comfy_extras/nodes_hunyuan.py

@@ -1,12 +1,15 @@
+import math
 import nodes
 import node_helpers
 import torch
+import re
 import comfy.model_management
+import comfy.patcher_extension


 class CLIPTextEncodeHunyuanDiT:
     @classmethod
-    def INPUT_TYPES(s):
+    def INPUT_TYPES(cls):
         return {"required": {
             "clip": ("CLIP", ),
             "bert": ("STRING", {"multiline": True, "dynamicPrompts": True}),
@@ -23,6 +26,249 @@ class CLIPTextEncodeHunyuanDiT:
         return (clip.encode_from_tokens_scheduled(tokens), )

+
+class MomentumBuffer:
+    def __init__(self, momentum: float):
+        self.momentum = momentum
+        self.running_average = 0
+
+    def update(self, update_value: torch.Tensor):
+        new_average = self.momentum * self.running_average
+        self.running_average = update_value + new_average
+
+
+def normalized_guidance_apg(
+    pred_cond: torch.Tensor,
+    pred_uncond: torch.Tensor,
+    guidance_scale: float,
+    momentum_buffer,
+    eta: float = 1.0,
+    norm_threshold: float = 0.0,
+    use_original_formulation: bool = False,
+):
+    diff = pred_cond - pred_uncond
+    dim = [-i for i in range(1, len(diff.shape))]
+
+    if momentum_buffer is not None:
+        momentum_buffer.update(diff)
+        diff = momentum_buffer.running_average
+
+    if norm_threshold > 0:
+        ones = torch.ones_like(diff)
+        diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
+        scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
+        diff = diff * scale_factor
+
+    v0, v1 = diff.double(), pred_cond.double()
+    v1 = torch.nn.functional.normalize(v1, dim=dim)
+    v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
+    v0_orthogonal = v0 - v0_parallel
+    diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
+
+    normalized_update = diff_orthogonal + eta * diff_parallel
+    pred = pred_cond if use_original_formulation else pred_uncond
+    pred = pred + guidance_scale * normalized_update
+
+    return pred
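normalized_guidance_apg splits the cond/uncond difference into components parallel and orthogonal to the conditional prediction, rescales only the parallel part by eta, and optionally clamps the update norm. A toy call with random tensors, assuming MomentumBuffer and normalized_guidance_apg from the block above are in scope, just to show argument order and shapes:

import torch

cond = torch.randn(1, 4, 8, 8)
uncond = torch.randn(1, 4, 8, 8)
buf = MomentumBuffer(momentum=-0.5)

pred = normalized_guidance_apg(
    pred_cond=cond,
    pred_uncond=uncond,
    guidance_scale=10.0,
    momentum_buffer=buf,
    eta=0.0,
    norm_threshold=10.0,
)
print(pred.shape)  # torch.Size([1, 4, 8, 8])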
+
+class AdaptiveProjectedGuidance:
+    def __init__(
+        self,
+        guidance_scale: float = 7.5,
+        adaptive_projected_guidance_momentum=None,
+        adaptive_projected_guidance_rescale: float = 15.0,
+        # eta: float = 1.0,
+        eta: float = 0.0,
+        guidance_rescale: float = 0.0,
+        use_original_formulation: bool = False,
+        start: float = 0.0,
+        stop: float = 1.0,
+    ):
+        super().__init__()
+
+        self.guidance_scale = guidance_scale
+        self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
+        self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
+        self.eta = eta
+        self.guidance_rescale = guidance_rescale
+        self.use_original_formulation = use_original_formulation
+        self.momentum_buffer = None
+
+    def __call__(self, pred_cond: torch.Tensor, pred_uncond=None, is_first_step=False) -> torch.Tensor:
+        if is_first_step and self.adaptive_projected_guidance_momentum is not None:
+            self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
+
+        pred = normalized_guidance_apg(
+            pred_cond,
+            pred_uncond,
+            self.guidance_scale,
+            self.momentum_buffer,
+            self.eta,
+            self.adaptive_projected_guidance_rescale,
+            self.use_original_formulation,
+        )
+
+        return pred
+
+class HunyuanMixModeAPG:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "model": ("MODEL", ),
+                "has_quoted_text": ("BOOLEAN", ),
+
+                "guidance_scale": ("FLOAT", {"default": 10.0, "min": 1.0, "max": 30.0, "step": 0.1}),
+
+                "general_eta": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 10.0, "step": 0.01}),
+                "general_norm_threshold": ("FLOAT", {"default": 10.0, "min": 0.0, "max": 50.0, "step": 0.1}),
+                "general_momentum": ("FLOAT", {"default": -0.5, "min": -5.0, "max": 1.0, "step": 0.01}),
+                "general_start_percent": ("FLOAT", {"default": 0.10, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "The relative sampling step to begin use of general APG."}),
+
+                "ocr_eta": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 10.0, "step": 0.01}),
+                "ocr_norm_threshold": ("FLOAT", {"default": 10.0, "min": 0.0, "max": 50.0, "step": 0.1}),
+                "ocr_momentum": ("FLOAT", {"default": -0.5, "min": -5.0, "max": 1.0, "step": 0.01}),
+                "ocr_start_percent": ("FLOAT", {"default": 0.75, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "The relative sampling step to begin use of OCR APG."}),
+            }
+        }
+
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "apply_mix_mode_apg"
+    CATEGORY = "sampling/custom_sampling/hunyuan"
+
+    def apply_mix_mode_apg(self, model, has_quoted_text, guidance_scale, general_eta, general_norm_threshold, general_momentum, general_start_percent,
+                           ocr_eta, ocr_norm_threshold, ocr_momentum, ocr_start_percent):
+
+        general_apg = AdaptiveProjectedGuidance(
+            guidance_scale=guidance_scale,
+            eta=general_eta,
+            adaptive_projected_guidance_rescale=general_norm_threshold,
+            adaptive_projected_guidance_momentum=general_momentum
+        )
+
+        ocr_apg = AdaptiveProjectedGuidance(
+            eta=ocr_eta,
+            adaptive_projected_guidance_rescale=ocr_norm_threshold,
+            adaptive_projected_guidance_momentum=ocr_momentum
+        )
+
+        m = model.clone()
+
+        model_sampling = m.model.model_sampling
+        general_start_t = model_sampling.percent_to_sigma(general_start_percent)
+        ocr_start_t = model_sampling.percent_to_sigma(ocr_start_percent)
+
+        def cfg_function(args):
+            sigma = args["sigma"].to(torch.float32)
+            is_first_step = math.isclose(sigma.item(), args['model_options']['transformer_options']['sample_sigmas'][0].item())
+            cond = args["cond"]
+            uncond = args["uncond"]
+            cond_scale = args["cond_scale"]
+
+            sigma = sigma[:, None, None, None]
+
+            if not has_quoted_text:
+                if sigma[0] <= general_start_t:
+                    modified_cond = general_apg(cond / sigma, uncond / sigma, is_first_step=is_first_step)
+                    return modified_cond * sigma
+                else:
+                    if cond_scale > 1:
+                        _ = general_apg(cond / sigma, uncond / sigma, is_first_step=is_first_step)  # track momentum
+                        return uncond + (cond - uncond) * cond_scale
+            else:
+                if sigma[0] <= ocr_start_t:
+                    modified_cond = ocr_apg(cond / sigma, uncond / sigma, is_first_step=is_first_step)
+                    return modified_cond * sigma
+                else:
+                    if cond_scale > 1:
+                        _ = ocr_apg(cond / sigma, uncond / sigma, is_first_step=is_first_step)  # track momentum
+                        return uncond + (cond - uncond) * cond_scale
+
+            return cond
+
+        m.set_model_sampler_cfg_function(cfg_function, disable_cfg1_optimization=True)
+        return (m,)
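HunyuanMixModeAPG patches the model's CFG function so APG takes over once sigma falls below the configured start percent, with separate settings for general prompts and quoted-text (OCR) prompts. Outside the graph UI the node can be driven directly; `model` below stands for whatever a checkpoint loader produced and is not defined here:

apg = HunyuanMixModeAPG()
(patched_model,) = apg.apply_mix_mode_apg(
    model,                       # a ModelPatcher from a checkpoint/diffusion model loader
    has_quoted_text=False,
    guidance_scale=10.0,
    general_eta=0.0, general_norm_threshold=10.0,
    general_momentum=-0.5, general_start_percent=0.10,
    ocr_eta=0.0, ocr_norm_threshold=10.0,
    ocr_momentum=-0.5, ocr_start_percent=0.75,
)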
+
+class CLIPTextEncodeHunyuanDiTWithTextDetection:
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {"required": {
+            "clip": ("CLIP", ),
+            "text": ("STRING", {"multiline": True, "dynamicPrompts": True}),
+        }}
+
+    RETURN_TYPES = ("CONDITIONING", "BOOLEAN", "STRING")
+    RETURN_NAMES = ("conditioning", "has_quoted_text", "text")
+    FUNCTION = "encode"
+
+    CATEGORY = "advanced/conditioning/hunyuan"
+
+    def detect_quoted_text(self, text):
+        """Detect quoted text in the prompt"""
+        text_prompt_texts = []
+
+        # Patterns to match different quote styles
+        pattern_quote_double = r'\"(.*?)\"'
+        pattern_quote_chinese_single = r'‘(.*?)’'
+        pattern_quote_chinese_double = r'“(.*?)”'
+
+        matches_quote_double = re.findall(pattern_quote_double, text)
+        matches_quote_chinese_single = re.findall(pattern_quote_chinese_single, text)
+        matches_quote_chinese_double = re.findall(pattern_quote_chinese_double, text)
+
+        text_prompt_texts.extend(matches_quote_double)
+        text_prompt_texts.extend(matches_quote_chinese_single)
+        text_prompt_texts.extend(matches_quote_chinese_double)
+
+        return len(text_prompt_texts) > 0
+
+    def encode(self, clip, text):
+        tokens = clip.tokenize(text)
+        has_quoted_text = self.detect_quoted_text(text)
+
+        conditioning = clip.encode_from_tokens_scheduled(tokens)
+
+        c = []
+        for t in conditioning:
+            n = [t[0], t[1].copy()]
+            n[1]['has_quoted_text'] = has_quoted_text
+            c.append(n)
+
+        return (c, has_quoted_text, text)
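detect_quoted_text is what drives the has_quoted_text output consumed by the APG node above; it needs only the prompt string, so it is easy to sanity-check in isolation:

node = CLIPTextEncodeHunyuanDiTWithTextDetection()
print(node.detect_quoted_text('A neon sign that says "OPEN 24H"'))  # True
print(node.detect_quoted_text('A quiet forest at dawn'))            # False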
+
+class CLIPTextEncodeHunyuanImageRefiner:
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {"required": {
+            "clip": ("CLIP", ),
+            "text": ("STRING", ),
+        }}
+
+    RETURN_TYPES = ("CONDITIONING",)
+    RETURN_NAMES = ("conditioning",)
+    FUNCTION = "encode"
+
+    CATEGORY = "advanced/conditioning/hunyuan"
+
+    def encode(self, clip, text):
+        tokens = clip.tokenize(text)
+
+        conditioning = clip.encode_from_tokens_scheduled(tokens)
+
+        c = []
+        for t in conditioning:
+            n = [t[0], t[1].copy()]
+            c.append(n)
+
+        return (c, )
+
+
 class EmptyHunyuanLatentVideo:
     @classmethod
     def INPUT_TYPES(s):
@@ -151,8 +397,16 @@ class HunyuanRefinerLatent:
         return (positive, negative, out_latent)


+NODE_DISPLAY_NAME_MAPPINGS = {
+    "HunyuanMixModeAPG": "Hunyuan Mix Mode APG",
+}
+
 NODE_CLASS_MAPPINGS = {
+    "HunyuanMixModeAPG": HunyuanMixModeAPG,
     "CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT,
+    "CLIPTextEncodeHunyuanDiTWithTextDetection": CLIPTextEncodeHunyuanDiTWithTextDetection,
+    "CLIPTextEncodeHunyuanImageRefiner": CLIPTextEncodeHunyuanImageRefiner,
     "TextEncodeHunyuanVideo_ImageToVideo": TextEncodeHunyuanVideo_ImageToVideo,
     "EmptyHunyuanLatentVideo": EmptyHunyuanLatentVideo,
     "HunyuanImageToVideo": HunyuanImageToVideo,
nodes.py

@@ -929,7 +929,7 @@ class CLIPLoader:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
-                              "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image"], ),
+                              "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image","hunyuan_image_refiner"], ),
                              },
                "optional": {
                              "device": (["default", "cpu"], {"advanced": True}),
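The new dropdown string is expected to resolve to the enum member added in comfy/sd.py through the loader's existing name-based lookup; the getattr call below is a sketch of that convention, not a quote of CLIPLoader.load_clip:

import comfy.sd

clip_type = getattr(comfy.sd.CLIPType, "hunyuan_image_refiner".upper(),
                    comfy.sd.CLIPType.STABLE_DIFFUSION)
assert clip_type is comfy.sd.CLIPType.HUNYUAN_IMAGE_REFINER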