mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-10 19:57:42 +08:00
HunyuanImage2.1: Implement Hunyuan APG
This commit is contained in:
parent
4f1f26ac6c
commit
0836853fec
@ -1,12 +1,14 @@
|
|||||||
|
from numpy import arccos
|
||||||
import nodes
|
import nodes
|
||||||
import node_helpers
|
import node_helpers
|
||||||
import torch
|
import torch
|
||||||
|
import re
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
|
|
||||||
|
|
||||||
class CLIPTextEncodeHunyuanDiT:
|
class CLIPTextEncodeHunyuanDiT:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(cls):
|
||||||
return {"required": {
|
return {"required": {
|
||||||
"clip": ("CLIP", ),
|
"clip": ("CLIP", ),
|
||||||
"bert": ("STRING", {"multiline": True, "dynamicPrompts": True}),
|
"bert": ("STRING", {"multiline": True, "dynamicPrompts": True}),
|
||||||
@ -23,6 +25,220 @@ class CLIPTextEncodeHunyuanDiT:
|
|||||||
|
|
||||||
return (clip.encode_from_tokens_scheduled(tokens), )
|
return (clip.encode_from_tokens_scheduled(tokens), )
|
||||||
|
|
||||||
|
class MomentumBuffer:
|
||||||
|
def __init__(self, momentum: float):
|
||||||
|
self.momentum = momentum
|
||||||
|
self.running_average = 0
|
||||||
|
|
||||||
|
def update(self, update_value: torch.Tensor):
|
||||||
|
new_average = self.momentum * self.running_average
|
||||||
|
self.running_average = update_value + new_average
|
||||||
|
|
||||||
|
def normalized_guidance_apg(
|
||||||
|
pred_cond: torch.Tensor,
|
||||||
|
pred_uncond: torch.Tensor,
|
||||||
|
guidance_scale: float,
|
||||||
|
momentum_buffer,
|
||||||
|
eta: float = 1.0,
|
||||||
|
norm_threshold: float = 0.0,
|
||||||
|
use_original_formulation: bool = False,
|
||||||
|
):
|
||||||
|
diff = pred_cond - pred_uncond
|
||||||
|
dim = [-i for i in range(1, len(diff.shape))]
|
||||||
|
|
||||||
|
if momentum_buffer is not None:
|
||||||
|
momentum_buffer.update(diff)
|
||||||
|
diff = momentum_buffer.running_average
|
||||||
|
|
||||||
|
if norm_threshold > 0:
|
||||||
|
ones = torch.ones_like(diff)
|
||||||
|
diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
|
||||||
|
scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
|
||||||
|
diff = diff * scale_factor
|
||||||
|
|
||||||
|
v0, v1 = diff.double(), pred_cond.double()
|
||||||
|
v1 = torch.nn.functional.normalize(v1, dim=dim)
|
||||||
|
v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
|
||||||
|
v0_orthogonal = v0 - v0_parallel
|
||||||
|
diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
|
||||||
|
|
||||||
|
normalized_update = diff_orthogonal + eta * diff_parallel
|
||||||
|
pred = pred_cond if use_original_formulation else pred_uncond
|
||||||
|
pred = pred + guidance_scale * normalized_update
|
||||||
|
|
||||||
|
return pred
|
||||||
|
|
||||||
|
class AdaptiveProjectedGuidance:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
guidance_scale: float = 7.5,
|
||||||
|
adaptive_projected_guidance_momentum=None,
|
||||||
|
adaptive_projected_guidance_rescale: float = 15.0,
|
||||||
|
# eta: float = 1.0,
|
||||||
|
eta: float = 0.0,
|
||||||
|
guidance_rescale: float = 0.0,
|
||||||
|
use_original_formulation: bool = False,
|
||||||
|
start: float = 0.0,
|
||||||
|
stop: float = 1.0,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.guidance_scale = guidance_scale
|
||||||
|
self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
|
||||||
|
self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
|
||||||
|
self.eta = eta
|
||||||
|
self.guidance_rescale = guidance_rescale
|
||||||
|
self.use_original_formulation = use_original_formulation
|
||||||
|
self.momentum_buffer = None
|
||||||
|
|
||||||
|
def __call__(self, pred_cond: torch.Tensor, pred_uncond=None, step=None) -> torch.Tensor:
|
||||||
|
|
||||||
|
if step == 0 and self.adaptive_projected_guidance_momentum is not None:
|
||||||
|
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
|
||||||
|
|
||||||
|
pred = normalized_guidance_apg(
|
||||||
|
pred_cond,
|
||||||
|
pred_uncond,
|
||||||
|
self.guidance_scale,
|
||||||
|
self.momentum_buffer,
|
||||||
|
self.eta,
|
||||||
|
self.adaptive_projected_guidance_rescale,
|
||||||
|
self.use_original_formulation,
|
||||||
|
)
|
||||||
|
|
||||||
|
return pred
|
||||||
|
|
||||||
|
class HunyuanMixModeAPG:
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"model": ("MODEL", ),
|
||||||
|
"has_quoted_text": ("HAS_QUOTED_TEXT", ),
|
||||||
|
|
||||||
|
"guidance_scale": ("FLOAT", {"default": 8.0, "min": 1.0, "max": 30.0, "step": 0.1}),
|
||||||
|
|
||||||
|
"general_eta": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 10.0, "step": 0.01}),
|
||||||
|
"general_norm_threshold": ("FLOAT", {"default": 10.0, "min": 0.0, "max": 50.0, "step": 0.1}),
|
||||||
|
"general_momentum": ("FLOAT", {"default": -0.5, "min": -5.0, "max": 1.0, "step": 0.01}),
|
||||||
|
"general_start_step": ("INT", {"default": 10, "min": -1, "max": 1000}),
|
||||||
|
|
||||||
|
"ocr_eta": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 10.0, "step": 0.01}),
|
||||||
|
"ocr_norm_threshold": ("FLOAT", {"default": 10.0, "min": 0.0, "max": 50.0, "step": 0.1}),
|
||||||
|
"ocr_momentum": ("FLOAT", {"default": -0.5, "min": -5.0, "max": 1.0, "step": 0.01}),
|
||||||
|
"ocr_start_step": ("INT", {"default": 75, "min": -1, "max": 1000}),
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("MODEL",)
|
||||||
|
FUNCTION = "apply_mix_mode_apg"
|
||||||
|
CATEGORY = "sampling/custom_sampling/hunyuan"
|
||||||
|
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def IS_CHANGED(cls, *args, **kwargs):
|
||||||
|
return True
|
||||||
|
|
||||||
|
def apply_mix_mode_apg(self, model, has_quoted_text, guidance_scale, general_eta, general_norm_threshold, general_momentum, general_start_step,
|
||||||
|
ocr_eta, ocr_norm_threshold, ocr_momentum, ocr_start_step):
|
||||||
|
|
||||||
|
general_apg = AdaptiveProjectedGuidance(
|
||||||
|
guidance_scale=guidance_scale,
|
||||||
|
eta=general_eta,
|
||||||
|
adaptive_projected_guidance_rescale=general_norm_threshold,
|
||||||
|
adaptive_projected_guidance_momentum=general_momentum
|
||||||
|
)
|
||||||
|
|
||||||
|
ocr_apg = AdaptiveProjectedGuidance(
|
||||||
|
eta=ocr_eta,
|
||||||
|
adaptive_projected_guidance_rescale=ocr_norm_threshold,
|
||||||
|
adaptive_projected_guidance_momentum=ocr_momentum
|
||||||
|
)
|
||||||
|
|
||||||
|
current_step = {"step": 0}
|
||||||
|
|
||||||
|
def cfg_function(args):
|
||||||
|
cond = args["cond"]
|
||||||
|
uncond = args["uncond"]
|
||||||
|
cond_scale = args["cond_scale"]
|
||||||
|
|
||||||
|
step = current_step["step"]
|
||||||
|
current_step["step"] += 1
|
||||||
|
|
||||||
|
if not has_quoted_text:
|
||||||
|
if step > general_start_step:
|
||||||
|
modified_cond = general_apg(cond, uncond, step).to(torch.bfloat16)
|
||||||
|
return modified_cond
|
||||||
|
else:
|
||||||
|
if cond_scale > 1:
|
||||||
|
_ = general_apg(cond, uncond, step) # track momentum
|
||||||
|
return uncond + (cond - uncond) * cond_scale
|
||||||
|
else:
|
||||||
|
if step > ocr_start_step:
|
||||||
|
modified_cond = ocr_apg(cond, uncond, step)
|
||||||
|
return modified_cond
|
||||||
|
else:
|
||||||
|
if cond_scale > 1:
|
||||||
|
_ = ocr_apg(cond, uncond, step)
|
||||||
|
return uncond + (cond - uncond) * cond_scale
|
||||||
|
|
||||||
|
return cond
|
||||||
|
|
||||||
|
|
||||||
|
m = model.clone()
|
||||||
|
m.set_model_sampler_cfg_function(cfg_function, disable_cfg1_optimization=True)
|
||||||
|
return (m,)
|
||||||
|
|
||||||
|
class CLIPTextEncodeHunyuanDiTWithTextDetection:
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {"required": {
|
||||||
|
"clip": ("CLIP", ),
|
||||||
|
"text": ("STRING", {"multiline": True, "dynamicPrompts": True}),
|
||||||
|
}}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("CONDITIONING", "HAS_QUOTED_TEXT")
|
||||||
|
RETURN_NAMES = ("conditioning", "has_quoted_text")
|
||||||
|
FUNCTION = "encode"
|
||||||
|
|
||||||
|
CATEGORY = "advanced/conditioning/hunyuan"
|
||||||
|
|
||||||
|
def detect_quoted_text(self, text):
|
||||||
|
"""Detect quoted text in the prompt"""
|
||||||
|
text_prompt_texts = []
|
||||||
|
|
||||||
|
# Patterns to match different quote styles
|
||||||
|
pattern_quote_double = r'\"(.*?)\"'
|
||||||
|
pattern_quote_chinese_single = r'‘(.*?)’'
|
||||||
|
pattern_quote_chinese_double = r'“(.*?)”'
|
||||||
|
|
||||||
|
matches_quote_double = re.findall(pattern_quote_double, text)
|
||||||
|
matches_quote_chinese_single = re.findall(pattern_quote_chinese_single, text)
|
||||||
|
matches_quote_chinese_double = re.findall(pattern_quote_chinese_double, text)
|
||||||
|
|
||||||
|
text_prompt_texts.extend(matches_quote_double)
|
||||||
|
text_prompt_texts.extend(matches_quote_chinese_single)
|
||||||
|
text_prompt_texts.extend(matches_quote_chinese_double)
|
||||||
|
|
||||||
|
return len(text_prompt_texts) > 0
|
||||||
|
|
||||||
|
def encode(self, clip, text):
|
||||||
|
tokens = clip.tokenize(text)
|
||||||
|
has_quoted_text = self.detect_quoted_text(text)
|
||||||
|
|
||||||
|
conditioning = clip.encode_from_tokens_scheduled(tokens)
|
||||||
|
|
||||||
|
c = []
|
||||||
|
for t in conditioning:
|
||||||
|
n = [t[0], t[1].copy()]
|
||||||
|
n[1]['has_quoted_text'] = has_quoted_text
|
||||||
|
c.append(n)
|
||||||
|
|
||||||
|
return (c, has_quoted_text)
|
||||||
|
|
||||||
class EmptyHunyuanLatentVideo:
|
class EmptyHunyuanLatentVideo:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
@ -151,8 +367,16 @@ class HunyuanRefinerLatent:
|
|||||||
return (positive, negative, out_latent)
|
return (positive, negative, out_latent)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
|
"HunyuanMixModeAPG": "Hunyuan Mix Mode APG",
|
||||||
|
"HunyuanStepBasedAPG": "Hunyuan Step Based APG",
|
||||||
|
}
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"HunyuanMixModeAPG": HunyuanMixModeAPG,
|
||||||
"CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT,
|
"CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT,
|
||||||
|
"CLIPTextEncodeHunyuanDiTWithTextDetection": CLIPTextEncodeHunyuanDiTWithTextDetection,
|
||||||
"TextEncodeHunyuanVideo_ImageToVideo": TextEncodeHunyuanVideo_ImageToVideo,
|
"TextEncodeHunyuanVideo_ImageToVideo": TextEncodeHunyuanVideo_ImageToVideo,
|
||||||
"EmptyHunyuanLatentVideo": EmptyHunyuanLatentVideo,
|
"EmptyHunyuanLatentVideo": EmptyHunyuanLatentVideo,
|
||||||
"HunyuanImageToVideo": HunyuanImageToVideo,
|
"HunyuanImageToVideo": HunyuanImageToVideo,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user