# ComfyUI/comfy_extras/nodes_hunyuan.py
import nodes
import node_helpers
import torch
import re
import comfy.model_management
import comfy.utils


class CLIPTextEncodeHunyuanDiT:
    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {
            "clip": ("CLIP", ),
            "bert": ("STRING", {"multiline": True, "dynamicPrompts": True}),
            "mt5xl": ("STRING", {"multiline": True, "dynamicPrompts": True}),
        }}

    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "encode"
    CATEGORY = "advanced/conditioning"

    def encode(self, clip, bert, mt5xl):
        tokens = clip.tokenize(bert)
        tokens["mt5xl"] = clip.tokenize(mt5xl)["mt5xl"]
        return (clip.encode_from_tokens_scheduled(tokens), )


class MomentumBuffer:
    """Tracks a momentum-weighted running average of guidance deltas for APG."""

    def __init__(self, momentum: float):
        self.momentum = momentum
        self.running_average = 0

    def update(self, update_value: torch.Tensor):
        new_average = self.momentum * self.running_average
        self.running_average = update_value + new_average
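

def _momentum_buffer_sketch():
    # A minimal sketch, not part of the node API: with the negative momentum
    # used by the APG defaults below, the buffer leans toward the most recent
    # update rather than accumulating a long history.
    buf = MomentumBuffer(momentum=-0.5)
    buf.update(torch.ones(2, 2))  # running_average is now all 1.0
    buf.update(torch.ones(2, 2))  # 1.0 + (-0.5) * 1.0 == 0.5 everywhere
    return buf.running_average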


def normalized_guidance_apg(
    pred_cond: torch.Tensor,
    pred_uncond: torch.Tensor,
    guidance_scale: float,
    momentum_buffer,
    eta: float = 1.0,
    norm_threshold: float = 0.0,
    use_original_formulation: bool = False,
):
    diff = pred_cond - pred_uncond
    dim = [-i for i in range(1, len(diff.shape))]  # all dims except batch

    if momentum_buffer is not None:
        momentum_buffer.update(diff)
        diff = momentum_buffer.running_average

    # Rescale the guidance delta so its per-sample L2 norm never exceeds
    # norm_threshold.
    if norm_threshold > 0:
        ones = torch.ones_like(diff)
        diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
        scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
        diff = diff * scale_factor

    # Decompose the delta into components parallel and orthogonal to the
    # conditional prediction; eta weights how much of the parallel part
    # (the CFG-like push) is kept.
    v0, v1 = diff.double(), pred_cond.double()
    v1 = torch.nn.functional.normalize(v1, dim=dim)
    v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
    v0_orthogonal = v0 - v0_parallel
    diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
    normalized_update = diff_orthogonal + eta * diff_parallel
    pred = pred_cond if use_original_formulation else pred_uncond
    pred = pred + guidance_scale * normalized_update
    return pred


class AdaptiveProjectedGuidance:
    def __init__(
        self,
        guidance_scale: float = 7.5,
        adaptive_projected_guidance_momentum=None,
        adaptive_projected_guidance_rescale: float = 15.0,
        # eta: float = 1.0,
        eta: float = 0.0,
        guidance_rescale: float = 0.0,
        use_original_formulation: bool = False,
        start: float = 0.0,
        stop: float = 1.0,
    ):
        super().__init__()
        self.guidance_scale = guidance_scale
        self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
        self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
        self.eta = eta
        self.guidance_rescale = guidance_rescale
        self.use_original_formulation = use_original_formulation
        self.momentum_buffer = None

    def __call__(self, pred_cond: torch.Tensor, pred_uncond=None, step=None) -> torch.Tensor:
        if step == 0 and self.adaptive_projected_guidance_momentum is not None:
            self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
        pred = normalized_guidance_apg(
            pred_cond,
            pred_uncond,
            self.guidance_scale,
            self.momentum_buffer,
            self.eta,
            self.adaptive_projected_guidance_rescale,
            self.use_original_formulation,
        )
        return pred
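

def _apg_usage_sketch():
    # A minimal usage sketch with hypothetical shapes and values, outside
    # ComfyUI's sampling loop: the guider is called once per sampling step
    # with the conditional and unconditional model predictions, and the
    # momentum buffer is (re)created when step == 0.
    apg = AdaptiveProjectedGuidance(guidance_scale=7.5,
                                    adaptive_projected_guidance_momentum=-0.5)
    pred_cond = torch.randn(1, 16, 8, 8)
    pred_uncond = torch.randn(1, 16, 8, 8)
    for step in range(4):
        pred = apg(pred_cond, pred_uncond, step=step)
    return pred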


class HunyuanMixModeAPG:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL", ),
                "has_quoted_text": ("HAS_QUOTED_TEXT", ),
                "guidance_scale": ("FLOAT", {"default": 8.0, "min": 1.0, "max": 30.0, "step": 0.1}),
                "general_eta": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 10.0, "step": 0.01}),
                "general_norm_threshold": ("FLOAT", {"default": 10.0, "min": 0.0, "max": 50.0, "step": 0.1}),
                "general_momentum": ("FLOAT", {"default": -0.5, "min": -5.0, "max": 1.0, "step": 0.01}),
                "general_start_step": ("INT", {"default": 10, "min": -1, "max": 1000}),
                "ocr_eta": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 10.0, "step": 0.01}),
                "ocr_norm_threshold": ("FLOAT", {"default": 10.0, "min": 0.0, "max": 50.0, "step": 0.1}),
                "ocr_momentum": ("FLOAT", {"default": -0.5, "min": -5.0, "max": 1.0, "step": 0.01}),
                "ocr_start_step": ("INT", {"default": 75, "min": -1, "max": 1000}),
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "apply_mix_mode_apg"
    CATEGORY = "sampling/custom_sampling/hunyuan"

    @classmethod
    def IS_CHANGED(cls, *args, **kwargs):
        return True

    def apply_mix_mode_apg(self, model, has_quoted_text, guidance_scale, general_eta, general_norm_threshold, general_momentum, general_start_step,
                           ocr_eta, ocr_norm_threshold, ocr_momentum, ocr_start_step):

        general_apg = AdaptiveProjectedGuidance(
            guidance_scale=guidance_scale,
            eta=general_eta,
            adaptive_projected_guidance_rescale=general_norm_threshold,
            adaptive_projected_guidance_momentum=general_momentum
        )

        ocr_apg = AdaptiveProjectedGuidance(
            eta=ocr_eta,
            adaptive_projected_guidance_rescale=ocr_norm_threshold,
            adaptive_projected_guidance_momentum=ocr_momentum
        )

        current_step = {"step": 0}

        def cfg_function(args):
            cond = args["cond"]
            uncond = args["uncond"]
            cond_scale = args["cond_scale"]

            step = current_step["step"]
            current_step["step"] += 1

            if not has_quoted_text:
                if step > general_start_step:
                    modified_cond = general_apg(cond, uncond, step).to(torch.bfloat16)
                    return modified_cond
                else:
                    if cond_scale > 1:
                        _ = general_apg(cond, uncond, step)  # track momentum
                        return uncond + (cond - uncond) * cond_scale
            else:
                if step > ocr_start_step:
                    modified_cond = ocr_apg(cond, uncond, step)
                    return modified_cond
                else:
                    if cond_scale > 1:
                        _ = ocr_apg(cond, uncond, step)
                        return uncond + (cond - uncond) * cond_scale

            return cond

        m = model.clone()
        m.set_model_sampler_cfg_function(cfg_function, disable_cfg1_optimization=True)
        return (m,)


class CLIPTextEncodeHunyuanDiTWithTextDetection:
    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {
            "clip": ("CLIP", ),
            "text": ("STRING", {"multiline": True, "dynamicPrompts": True}),
        }}

    RETURN_TYPES = ("CONDITIONING", "HAS_QUOTED_TEXT")
    RETURN_NAMES = ("conditioning", "has_quoted_text")
    FUNCTION = "encode"
    CATEGORY = "advanced/conditioning/hunyuan"

    def detect_quoted_text(self, text):
        """Detect quoted text in the prompt"""
        text_prompt_texts = []

        # Patterns to match different quote styles
        pattern_quote_double = r'\"(.*?)\"'
        pattern_quote_chinese_single = r'‘(.*?)’'
        pattern_quote_chinese_double = r'“(.*?)”'

        matches_quote_double = re.findall(pattern_quote_double, text)
        matches_quote_chinese_single = re.findall(pattern_quote_chinese_single, text)
        matches_quote_chinese_double = re.findall(pattern_quote_chinese_double, text)

        text_prompt_texts.extend(matches_quote_double)
        text_prompt_texts.extend(matches_quote_chinese_single)
        text_prompt_texts.extend(matches_quote_chinese_double)

        return len(text_prompt_texts) > 0

    def encode(self, clip, text):
        tokens = clip.tokenize(text)
        has_quoted_text = self.detect_quoted_text(text)

        conditioning = clip.encode_from_tokens_scheduled(tokens)

        c = []
        for t in conditioning:
            n = [t[0], t[1].copy()]
            n[1]['has_quoted_text'] = has_quoted_text
            c.append(n)
        return (c, has_quoted_text)
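

def _quoted_text_detection_sketch():
    # A small illustrative check, not a registered node: prompts containing
    # quoted strings are the ones HunyuanMixModeAPG routes to the
    # OCR-oriented APG branch.
    node = CLIPTextEncodeHunyuanDiTWithTextDetection()
    assert node.detect_quoted_text('a neon sign that says "OPEN"') is True
    assert node.detect_quoted_text('a quiet mountain landscape') is False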


class EmptyHunyuanLatentVideo:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
                             "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
                             "length": ("INT", {"default": 25, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "generate"
    CATEGORY = "latent/video"

    def generate(self, width, height, length, batch_size=1):
        latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
        return ({"samples": latent}, )


PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = (
    "<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the video by detailing the following aspects according to the reference image: "
    "1. The main content and theme of the video."
    "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
    "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
    "4. background environment, light, style and atmosphere."
    "5. camera angles, movements, and transitions used in the video:<|eot_id|>\n\n"
    "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
    "<|start_header_id|>assistant<|end_header_id|>\n\n"
)


class TextEncodeHunyuanVideo_ImageToVideo:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "clip": ("CLIP", ),
            "clip_vision_output": ("CLIP_VISION_OUTPUT", ),
            "prompt": ("STRING", {"multiline": True, "dynamicPrompts": True}),
            "image_interleave": ("INT", {"default": 2, "min": 1, "max": 512, "tooltip": "Balance between image and text influence: higher values give the text prompt more influence relative to the image."}),
        }}

    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "encode"
    CATEGORY = "advanced/conditioning"

    def encode(self, clip, clip_vision_output, prompt, image_interleave):
        tokens = clip.tokenize(prompt, llama_template=PROMPT_TEMPLATE_ENCODE_VIDEO_I2V, image_embeds=clip_vision_output.mm_projected, image_interleave=image_interleave)
        return (clip.encode_from_tokens_scheduled(tokens), )


class HunyuanImageToVideo:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"positive": ("CONDITIONING", ),
                             "vae": ("VAE", ),
                             "width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
                             "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
                             "length": ("INT", {"default": 53, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
                             "guidance_type": (["v1 (concat)", "v2 (replace)", "custom"], )
                             },
                "optional": {"start_image": ("IMAGE", ),
                             }}

    RETURN_TYPES = ("CONDITIONING", "LATENT")
    RETURN_NAMES = ("positive", "latent")
    FUNCTION = "encode"
    CATEGORY = "conditioning/video_models"

    def encode(self, positive, vae, width, height, length, batch_size, guidance_type, start_image=None):
        latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
        out_latent = {}

        if start_image is not None:
            start_image = comfy.utils.common_upscale(start_image[:length, :, :, :3].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)

            concat_latent_image = vae.encode(start_image)
            mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
            mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0

            if guidance_type == "v1 (concat)":
                cond = {"concat_latent_image": concat_latent_image, "concat_mask": mask}
            elif guidance_type == "v2 (replace)":
                cond = {'guiding_frame_index': 0}
                latent[:, :, :concat_latent_image.shape[2]] = concat_latent_image
                out_latent["noise_mask"] = mask
            elif guidance_type == "custom":
                cond = {"ref_latent": concat_latent_image}

            positive = node_helpers.conditioning_set_values(positive, cond)

        out_latent["samples"] = latent
        return (positive, out_latent)


class EmptyHunyuanImageLatent:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"width": ("INT", {"default": 2048, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
                             "height": ("INT", {"default": 2048, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "generate"
    CATEGORY = "latent"

    def generate(self, width, height, batch_size=1):
        latent = torch.zeros([batch_size, 64, height // 32, width // 32], device=comfy.model_management.intermediate_device())
        return ({"samples": latent}, )


class HunyuanRefinerLatent:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"positive": ("CONDITIONING", ),
                             "negative": ("CONDITIONING", ),
                             "latent": ("LATENT", ),
                             "noise_augmentation": ("FLOAT", {"default": 0.10, "min": 0.0, "max": 1.0, "step": 0.01}),
                             }}

    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
    RETURN_NAMES = ("positive", "negative", "latent")
    FUNCTION = "execute"

    def execute(self, positive, negative, latent, noise_augmentation):
        latent = latent["samples"]
        positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": latent, "noise_augmentation": noise_augmentation})
        negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": latent, "noise_augmentation": noise_augmentation})
        out_latent = {}
        out_latent["samples"] = torch.zeros([latent.shape[0], 32, latent.shape[-3], latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device())
        return (positive, negative, out_latent)


NODE_DISPLAY_NAME_MAPPINGS = {
    "HunyuanMixModeAPG": "Hunyuan Mix Mode APG",
}

NODE_CLASS_MAPPINGS = {
    "HunyuanMixModeAPG": HunyuanMixModeAPG,
    "CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT,
    "CLIPTextEncodeHunyuanDiTWithTextDetection": CLIPTextEncodeHunyuanDiTWithTextDetection,
    "TextEncodeHunyuanVideo_ImageToVideo": TextEncodeHunyuanVideo_ImageToVideo,
    "EmptyHunyuanLatentVideo": EmptyHunyuanLatentVideo,
    "HunyuanImageToVideo": HunyuanImageToVideo,
    "EmptyHunyuanImageLatent": EmptyHunyuanImageLatent,
    "HunyuanRefinerLatent": HunyuanRefinerLatent,
}