HunyuanImage2.1: Implement Hunyuan APG

KimbingNg 2025-09-15 22:27:43 +08:00
parent 4f1f26ac6c
commit 0836853fec


@@ -1,12 +1,14 @@
from numpy import arccos
import nodes
import node_helpers
import torch
import re
import comfy.model_management


class CLIPTextEncodeHunyuanDiT:
    @classmethod
    def INPUT_TYPES(s):
    def INPUT_TYPES(cls):
        return {"required": {
            "clip": ("CLIP", ),
            "bert": ("STRING", {"multiline": True, "dynamicPrompts": True}),
@@ -23,6 +25,220 @@ class CLIPTextEncodeHunyuanDiT:
        return (clip.encode_from_tokens_scheduled(tokens), )


class MomentumBuffer:
    def __init__(self, momentum: float):
        self.momentum = momentum
        self.running_average = 0

    def update(self, update_value: torch.Tensor):
        new_average = self.momentum * self.running_average
        self.running_average = update_value + new_average
def normalized_guidance_apg(
    pred_cond: torch.Tensor,
    pred_uncond: torch.Tensor,
    guidance_scale: float,
    momentum_buffer,
    eta: float = 1.0,
    norm_threshold: float = 0.0,
    use_original_formulation: bool = False,
):
    diff = pred_cond - pred_uncond
    dim = [-i for i in range(1, len(diff.shape))]

    if momentum_buffer is not None:
        # Accumulate the guidance direction across steps and use the running average.
        momentum_buffer.update(diff)
        diff = momentum_buffer.running_average

    if norm_threshold > 0:
        # Clamp the guidance direction so its L2 norm never exceeds norm_threshold.
        ones = torch.ones_like(diff)
        diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
        scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
        diff = diff * scale_factor

    # Project the guidance direction onto the conditional prediction: the component
    # parallel to pred_cond is scaled by eta, the orthogonal component is kept as-is.
    v0, v1 = diff.double(), pred_cond.double()
    v1 = torch.nn.functional.normalize(v1, dim=dim)
    v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
    v0_orthogonal = v0 - v0_parallel
    diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
    normalized_update = diff_orthogonal + eta * diff_parallel

    pred = pred_cond if use_original_formulation else pred_uncond
    pred = pred + guidance_scale * normalized_update

    return pred
class AdaptiveProjectedGuidance:
    def __init__(
        self,
        guidance_scale: float = 7.5,
        adaptive_projected_guidance_momentum=None,
        adaptive_projected_guidance_rescale: float = 15.0,
        # eta: float = 1.0,
        eta: float = 0.0,
        guidance_rescale: float = 0.0,
        use_original_formulation: bool = False,
        start: float = 0.0,
        stop: float = 1.0,
    ):
        super().__init__()

        self.guidance_scale = guidance_scale
        self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
        self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
        self.eta = eta
        self.guidance_rescale = guidance_rescale
        self.use_original_formulation = use_original_formulation
        self.momentum_buffer = None

    def __call__(self, pred_cond: torch.Tensor, pred_uncond=None, step=None) -> torch.Tensor:
        if step == 0 and self.adaptive_projected_guidance_momentum is not None:
            # Reset the momentum buffer at the start of each sampling run.
            self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)

        pred = normalized_guidance_apg(
            pred_cond,
            pred_uncond,
            self.guidance_scale,
            self.momentum_buffer,
            self.eta,
            self.adaptive_projected_guidance_rescale,
            self.use_original_formulation,
        )

        return pred
class HunyuanMixModeAPG:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL", ),
                "has_quoted_text": ("HAS_QUOTED_TEXT", ),

                "guidance_scale": ("FLOAT", {"default": 8.0, "min": 1.0, "max": 30.0, "step": 0.1}),

                "general_eta": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 10.0, "step": 0.01}),
                "general_norm_threshold": ("FLOAT", {"default": 10.0, "min": 0.0, "max": 50.0, "step": 0.1}),
                "general_momentum": ("FLOAT", {"default": -0.5, "min": -5.0, "max": 1.0, "step": 0.01}),
                "general_start_step": ("INT", {"default": 10, "min": -1, "max": 1000}),

                "ocr_eta": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 10.0, "step": 0.01}),
                "ocr_norm_threshold": ("FLOAT", {"default": 10.0, "min": 0.0, "max": 50.0, "step": 0.1}),
                "ocr_momentum": ("FLOAT", {"default": -0.5, "min": -5.0, "max": 1.0, "step": 0.01}),
                "ocr_start_step": ("INT", {"default": 75, "min": -1, "max": 1000}),
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "apply_mix_mode_apg"
    CATEGORY = "sampling/custom_sampling/hunyuan"
    @classmethod
    def IS_CHANGED(cls, *args, **kwargs):
        return True

    def apply_mix_mode_apg(self, model, has_quoted_text, guidance_scale, general_eta, general_norm_threshold, general_momentum, general_start_step,
                           ocr_eta, ocr_norm_threshold, ocr_momentum, ocr_start_step):

        general_apg = AdaptiveProjectedGuidance(
            guidance_scale=guidance_scale,
            eta=general_eta,
            adaptive_projected_guidance_rescale=general_norm_threshold,
            adaptive_projected_guidance_momentum=general_momentum
        )

        ocr_apg = AdaptiveProjectedGuidance(
            eta=ocr_eta,
            adaptive_projected_guidance_rescale=ocr_norm_threshold,
            adaptive_projected_guidance_momentum=ocr_momentum
        )
        current_step = {"step": 0}

        def cfg_function(args):
            cond = args["cond"]
            uncond = args["uncond"]
            cond_scale = args["cond_scale"]

            step = current_step["step"]
            current_step["step"] += 1

            if not has_quoted_text:
                # General prompts: switch from plain CFG to APG after general_start_step.
                if step > general_start_step:
                    modified_cond = general_apg(cond, uncond, step).to(torch.bfloat16)
                    return modified_cond
                else:
                    if cond_scale > 1:
                        _ = general_apg(cond, uncond, step)  # track momentum
                    return uncond + (cond - uncond) * cond_scale
            else:
                # Prompts with quoted text: switch to the OCR APG after ocr_start_step.
                if step > ocr_start_step:
                    modified_cond = ocr_apg(cond, uncond, step)
                    return modified_cond
                else:
                    if cond_scale > 1:
                        _ = ocr_apg(cond, uncond, step)  # track momentum
                    return uncond + (cond - uncond) * cond_scale

            return cond

        m = model.clone()
        m.set_model_sampler_cfg_function(cfg_function, disable_cfg1_optimization=True)

        return (m,)
class CLIPTextEncodeHunyuanDiTWithTextDetection:
    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {
            "clip": ("CLIP", ),
            "text": ("STRING", {"multiline": True, "dynamicPrompts": True}),
        }}

    RETURN_TYPES = ("CONDITIONING", "HAS_QUOTED_TEXT")
    RETURN_NAMES = ("conditioning", "has_quoted_text")
    FUNCTION = "encode"
    CATEGORY = "advanced/conditioning/hunyuan"
    def detect_quoted_text(self, text):
        """Detect quoted text in the prompt."""
        text_prompt_texts = []

        # Patterns to match different quote styles
        pattern_quote_double = r'\"(.*?)\"'
        pattern_quote_chinese_single = r'‘(.*?)’'
        pattern_quote_chinese_double = r'“(.*?)”'

        matches_quote_double = re.findall(pattern_quote_double, text)
        matches_quote_chinese_single = re.findall(pattern_quote_chinese_single, text)
        matches_quote_chinese_double = re.findall(pattern_quote_chinese_double, text)

        text_prompt_texts.extend(matches_quote_double)
        text_prompt_texts.extend(matches_quote_chinese_single)
        text_prompt_texts.extend(matches_quote_chinese_double)

        return len(text_prompt_texts) > 0
    def encode(self, clip, text):
        tokens = clip.tokenize(text)
        has_quoted_text = self.detect_quoted_text(text)

        conditioning = clip.encode_from_tokens_scheduled(tokens)

        c = []
        for t in conditioning:
            n = [t[0], t[1].copy()]
            n[1]['has_quoted_text'] = has_quoted_text
            c.append(n)

        return (c, has_quoted_text)
class EmptyHunyuanLatentVideo:
    @classmethod
    def INPUT_TYPES(s):
@@ -151,8 +367,16 @@ class HunyuanRefinerLatent:
        return (positive, negative, out_latent)


NODE_DISPLAY_NAME_MAPPINGS = {
    "HunyuanMixModeAPG": "Hunyuan Mix Mode APG",
    "HunyuanStepBasedAPG": "Hunyuan Step Based APG",
}

NODE_CLASS_MAPPINGS = {
    "HunyuanMixModeAPG": HunyuanMixModeAPG,
    "CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT,
    "CLIPTextEncodeHunyuanDiTWithTextDetection": CLIPTextEncodeHunyuanDiTWithTextDetection,
    "TextEncodeHunyuanVideo_ImageToVideo": TextEncodeHunyuanVideo_ImageToVideo,
    "EmptyHunyuanLatentVideo": EmptyHunyuanLatentVideo,
    "HunyuanImageToVideo": HunyuanImageToVideo,