mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-03-10 19:57:42 +08:00)
HunyuanImage2.1: Implement Hunyuan APG
This commit is contained in:
parent 4f1f26ac6c
commit 0836853fec
@@ -1,12 +1,14 @@
+from numpy import arccos
 import nodes
 import node_helpers
 import torch
+import re
 import comfy.model_management


 class CLIPTextEncodeHunyuanDiT:
     @classmethod
-    def INPUT_TYPES(s):
+    def INPUT_TYPES(cls):
         return {"required": {
             "clip": ("CLIP", ),
             "bert": ("STRING", {"multiline": True, "dynamicPrompts": True}),
@@ -23,6 +25,220 @@ class CLIPTextEncodeHunyuanDiT:

         return (clip.encode_from_tokens_scheduled(tokens), )

+
+class MomentumBuffer:
+    def __init__(self, momentum: float):
+        self.momentum = momentum
+        self.running_average = 0
+
+    def update(self, update_value: torch.Tensor):
+        new_average = self.momentum * self.running_average
+        self.running_average = update_value + new_average
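+
+
+# Adaptive Projected Guidance: project the (momentum-smoothed) guidance delta
+# onto directions parallel and orthogonal to the conditional prediction, then
+# down-weight the parallel component via eta to curb oversaturation.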
+def normalized_guidance_apg(
+    pred_cond: torch.Tensor,
+    pred_uncond: torch.Tensor,
+    guidance_scale: float,
+    momentum_buffer,
+    eta: float = 1.0,
+    norm_threshold: float = 0.0,
+    use_original_formulation: bool = False,
+):
+    diff = pred_cond - pred_uncond
+    dim = [-i for i in range(1, len(diff.shape))]
+
+    if momentum_buffer is not None:
+        momentum_buffer.update(diff)
+        diff = momentum_buffer.running_average
+
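+    # Optional rescale: clamp the L2 norm of the delta to at most norm_threshold.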
+    if norm_threshold > 0:
+        ones = torch.ones_like(diff)
+        diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
+        scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
+        diff = diff * scale_factor
+
+    v0, v1 = diff.double(), pred_cond.double()
+    v1 = torch.nn.functional.normalize(v1, dim=dim)
+    v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
+    v0_orthogonal = v0 - v0_parallel
+    diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
+
+    normalized_update = diff_orthogonal + eta * diff_parallel
+    pred = pred_cond if use_original_formulation else pred_uncond
+    pred = pred + guidance_scale * normalized_update
+
+    return pred
+
+
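+# Stateful wrapper holding the APG parameters and the per-run momentum buffer.
+# Note: adaptive_projected_guidance_rescale is passed through as norm_threshold.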
+class AdaptiveProjectedGuidance:
+    def __init__(
+        self,
+        guidance_scale: float = 7.5,
+        adaptive_projected_guidance_momentum=None,
+        adaptive_projected_guidance_rescale: float = 15.0,
+        # eta: float = 1.0,
+        eta: float = 0.0,
+        guidance_rescale: float = 0.0,
+        use_original_formulation: bool = False,
+        start: float = 0.0,
+        stop: float = 1.0,
+    ):
+        super().__init__()
+
+        self.guidance_scale = guidance_scale
+        self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
+        self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
+        self.eta = eta
+        self.guidance_rescale = guidance_rescale
+        self.use_original_formulation = use_original_formulation
+        self.momentum_buffer = None
+
+    def __call__(self, pred_cond: torch.Tensor, pred_uncond=None, step=None) -> torch.Tensor:
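+        # Reset the momentum buffer at step 0 so state never leaks between runs.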
+        if step == 0 and self.adaptive_projected_guidance_momentum is not None:
+            self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
+
+        pred = normalized_guidance_apg(
+            pred_cond,
+            pred_uncond,
+            self.guidance_scale,
+            self.momentum_buffer,
+            self.eta,
+            self.adaptive_projected_guidance_rescale,
+            self.use_original_formulation,
+        )
+
+        return pred
+
+
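+# ComfyUI node: applies one of two APG configurations per step, the "general"
+# one for ordinary prompts and the "ocr" one when the prompt carries quoted
+# text to be rendered in the image.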
+class HunyuanMixModeAPG:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "model": ("MODEL", ),
+                "has_quoted_text": ("HAS_QUOTED_TEXT", ),
+
+                "guidance_scale": ("FLOAT", {"default": 8.0, "min": 1.0, "max": 30.0, "step": 0.1}),
+
+                "general_eta": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 10.0, "step": 0.01}),
+                "general_norm_threshold": ("FLOAT", {"default": 10.0, "min": 0.0, "max": 50.0, "step": 0.1}),
+                "general_momentum": ("FLOAT", {"default": -0.5, "min": -5.0, "max": 1.0, "step": 0.01}),
+                "general_start_step": ("INT", {"default": 10, "min": -1, "max": 1000}),
+
+                "ocr_eta": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 10.0, "step": 0.01}),
+                "ocr_norm_threshold": ("FLOAT", {"default": 10.0, "min": 0.0, "max": 50.0, "step": 0.1}),
+                "ocr_momentum": ("FLOAT", {"default": -0.5, "min": -5.0, "max": 1.0, "step": 0.01}),
+                "ocr_start_step": ("INT", {"default": 75, "min": -1, "max": 1000}),
+            }
+        }
+
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "apply_mix_mode_apg"
+    CATEGORY = "sampling/custom_sampling/hunyuan"
+
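+    # Returning True marks the node as always changed, forcing re-execution on
+    # every run (the APG objects hold mutable momentum state).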
+    @classmethod
+    def IS_CHANGED(cls, *args, **kwargs):
+        return True
+
+    def apply_mix_mode_apg(self, model, has_quoted_text, guidance_scale, general_eta, general_norm_threshold, general_momentum, general_start_step,
+                           ocr_eta, ocr_norm_threshold, ocr_momentum, ocr_start_step):
+
+        general_apg = AdaptiveProjectedGuidance(
+            guidance_scale=guidance_scale,
+            eta=general_eta,
+            adaptive_projected_guidance_rescale=general_norm_threshold,
+            adaptive_projected_guidance_momentum=general_momentum
+        )
+
+        ocr_apg = AdaptiveProjectedGuidance(
+            eta=ocr_eta,
+            adaptive_projected_guidance_rescale=ocr_norm_threshold,
+            adaptive_projected_guidance_momentum=ocr_momentum
+        )
+
+        current_step = {"step": 0}
+
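+        # Custom CFG combine. The step counter lives in a dict so this closure
+        # can mutate it across sampler calls.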
+        def cfg_function(args):
+            cond = args["cond"]
+            uncond = args["uncond"]
+            cond_scale = args["cond_scale"]
+
+            step = current_step["step"]
+            current_step["step"] += 1
+
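+            # Route by prompt type; before the configured start step, fall back
+            # to vanilla CFG while still updating the APG momentum buffer.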
+            if not has_quoted_text:
+                if step > general_start_step:
+                    modified_cond = general_apg(cond, uncond, step).to(torch.bfloat16)
+                    return modified_cond
+                else:
+                    if cond_scale > 1:
+                        _ = general_apg(cond, uncond, step)  # track momentum
+                    return uncond + (cond - uncond) * cond_scale
+            else:
+                if step > ocr_start_step:
+                    modified_cond = ocr_apg(cond, uncond, step)
+                    return modified_cond
+                else:
+                    if cond_scale > 1:
+                        _ = ocr_apg(cond, uncond, step)
+                    return uncond + (cond - uncond) * cond_scale
+
+            return cond
+
+        m = model.clone()
+        m.set_model_sampler_cfg_function(cfg_function, disable_cfg1_optimization=True)
+        return (m,)
+
+
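+# Text encoder variant that also reports whether the prompt contains quoted
+# text, which HunyuanMixModeAPG uses to choose the OCR configuration.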
+class CLIPTextEncodeHunyuanDiTWithTextDetection:
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {"required": {
+            "clip": ("CLIP", ),
+            "text": ("STRING", {"multiline": True, "dynamicPrompts": True}),
+        }}
+
+    RETURN_TYPES = ("CONDITIONING", "HAS_QUOTED_TEXT")
+    RETURN_NAMES = ("conditioning", "has_quoted_text")
+    FUNCTION = "encode"
+
+    CATEGORY = "advanced/conditioning/hunyuan"
+
+    def detect_quoted_text(self, text):
+        """Detect quoted text in the prompt"""
+        text_prompt_texts = []
+
+        # Patterns to match different quote styles
+        pattern_quote_double = r'\"(.*?)\"'
+        pattern_quote_chinese_single = r'‘(.*?)’'
+        pattern_quote_chinese_double = r'“(.*?)”'
+
+        matches_quote_double = re.findall(pattern_quote_double, text)
+        matches_quote_chinese_single = re.findall(pattern_quote_chinese_single, text)
+        matches_quote_chinese_double = re.findall(pattern_quote_chinese_double, text)
+
+        text_prompt_texts.extend(matches_quote_double)
+        text_prompt_texts.extend(matches_quote_chinese_single)
+        text_prompt_texts.extend(matches_quote_chinese_double)
+
+        return len(text_prompt_texts) > 0
+
+    def encode(self, clip, text):
+        tokens = clip.tokenize(text)
+        has_quoted_text = self.detect_quoted_text(text)
+
+        conditioning = clip.encode_from_tokens_scheduled(tokens)
+
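+        # Copy each conditioning entry and attach the quote flag to its options dict.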
+        c = []
+        for t in conditioning:
+            n = [t[0], t[1].copy()]
+            n[1]['has_quoted_text'] = has_quoted_text
+            c.append(n)
+
+        return (c, has_quoted_text)
+
+
 class EmptyHunyuanLatentVideo:
     @classmethod
     def INPUT_TYPES(s):
@@ -151,8 +367,16 @@ class HunyuanRefinerLatent:
         return (positive, negative, out_latent)


+NODE_DISPLAY_NAME_MAPPINGS = {
+    "HunyuanMixModeAPG": "Hunyuan Mix Mode APG",
+    "HunyuanStepBasedAPG": "Hunyuan Step Based APG",
+}
+
 NODE_CLASS_MAPPINGS = {
+    "HunyuanMixModeAPG": HunyuanMixModeAPG,
     "CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT,
+    "CLIPTextEncodeHunyuanDiTWithTextDetection": CLIPTextEncodeHunyuanDiTWithTextDetection,
     "TextEncodeHunyuanVideo_ImageToVideo": TextEncodeHunyuanVideo_ImageToVideo,
     "EmptyHunyuanLatentVideo": EmptyHunyuanLatentVideo,
     "HunyuanImageToVideo": HunyuanImageToVideo,