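"""Custom nodes for tokenizing prompts and generating text with Hugging Face
``transformers`` language models inside ComfyUI."""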
from __future__ import annotations

import operator
import os.path
import re
from abc import ABC, abstractmethod
from functools import reduce
from typing import Optional, List

import torch
from transformers import AutoProcessor
from transformers.models.m2m_100.tokenization_m2m_100 import \
    FAIRSEQ_LANGUAGE_CODES as tokenization_m2m_100_FAIRSEQ_LANGUAGE_CODES
from transformers.models.nllb.tokenization_nllb import \
    FAIRSEQ_LANGUAGE_CODES as tokenization_nllb_FAIRSEQ_LANGUAGE_CODES

from comfy.cmd import folder_paths
from comfy.component_model.folder_path_types import SaveImagePathTuple
from comfy.language.chat_templates import KNOWN_CHAT_TEMPLATES
from comfy.language.language_types import GENERATION_KWARGS_TYPE, GENERATION_KWARGS_TYPE_NAME, TOKENS_TYPE, \
    TOKENS_TYPE_NAME, LanguageModel, LanguagePrompt
from comfy.language.transformers_model_management import TransformersManagedModel
from comfy.model_downloader import get_huggingface_repo_list, get_or_download_huggingface_repo
from comfy.model_management import get_torch_device_name, unet_dtype, unet_offload_device
from comfy.node_helpers import export_custom_nodes, export_package_as_web_directory
from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult, Seed

_AUTO_CHAT_TEMPLATE = "default"


class TransformerSamplerBase(CustomNode, ABC):
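    """Base class for sampler nodes that emit a dict of generation kwargs for a language model."""
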
    RETURN_TYPES = GENERATION_KWARGS_TYPE_NAME,
    RETURN_NAMES = "GENERATION ARGS",
    FUNCTION = "execute"
    CATEGORY = "language/samplers"

    @classmethod
    @abstractmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return ...

    @property
    def do_sample(self):
        return True

    def execute(self, **kwargs):
        return {
            "do_sample": self.do_sample,
            **kwargs
        },


class TransformerTopKSampler(TransformerSamplerBase):
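    """Restricts sampling to the ``top_k`` most likely next tokens."""
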
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "top_k": ("INT", {"default": 50, "min": 1})
            }
        }


class TransformerTopPSampler(TransformerSamplerBase):
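    """Nucleus sampling: samples from the smallest set of tokens whose cumulative probability exceeds ``top_p``."""
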
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "top_p": ("FLOAT", {"default": 0.9, "min": 0, "max": 1})
            }
        }


class TransformerTemperatureSampler(TransformerSamplerBase):
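    """Scales the next-token logits by ``1 / temperature`` before sampling."""
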
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "temperature": ("FLOAT", {"default": 1.0, "min": 0, "step": 0.001})
            }
        }


class TransformerGreedySampler(TransformerSamplerBase):
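    """Disables sampling, always picking the most likely next token."""
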
    @property
    def do_sample(self):
        return False

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {}
        }


class TransformersGenerationConfig(CustomNode):
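    """Returns the generation config bundled with a loaded model, if it has one."""
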
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "model": ("MODEL", {})
            }
        }

    RETURN_TYPES = GENERATION_KWARGS_TYPE_NAME,
    RETURN_NAMES = "GENERATION ARGS",
    FUNCTION = "execute"
    CATEGORY = "language"

    def execute(self, model: TransformersManagedModel):
        # node outputs are tuples, so the single value is wrapped before returning
        if model.model.generation_config is not None:
            return model.model.generation_config,

        return {},


class TransformerContrastiveSearchSampler(TransformerTopKSampler):
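    """Contrastive search: top-k candidates re-ranked with a degeneration penalty (``penalty_alpha``)."""
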
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        top_k = TransformerTopKSampler.INPUT_TYPES()
        top_k["required"] |= {
            "penalty_alpha": ("FLOAT", {"default": 0.6, "min": 0})
        }
        return top_k


class TransformerBeamSearchSampler(TransformerSamplerBase):
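    """Beam search decoding over ``num_beams`` beams."""
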
    @property
    def do_sample(self):
        return False

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "num_beams": ("INT", {"default": 1, "min": 0}),
                "early_stopping": ("BOOLEAN", {"default": True})
            }
        }


class TransformerMergeSamplers(CustomNode):
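    """Merges up to five generation-kwargs dicts into a single dict."""
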
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        range_ = {"value0": (GENERATION_KWARGS_TYPE_NAME, {"forceInput": True})}
        range_.update({f"value{i}": (GENERATION_KWARGS_TYPE_NAME, {"forceInput": True}) for i in range(1, 5)})

        return {
            "required": range_
        }

    CATEGORY = "language"
    RETURN_TYPES = GENERATION_KWARGS_TYPE_NAME,
    FUNCTION = "execute"

    def execute(self, **kwargs):
        # sampling stays enabled if any of the merged configs requested it
        do_sample = {
            "do_sample": any(k == "do_sample" and v for value in kwargs.values() for k, v in value.items())
        }

        # later inputs win key conflicts; the combined do_sample is applied last
        return (reduce(operator.or_, list(kwargs.values()) + [do_sample], {}),)


class TransformersImageProcessorLoader(CustomNode):
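    """Loads an ``AutoProcessor`` from a Hugging Face repo and attaches it to an existing model."""
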
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "ckpt_name": (get_huggingface_repo_list(),),
                "subfolder": ("STRING", {}),
                "model": ("MODEL", {}),
                "overwrite_tokenizer": ("BOOLEAN", {"default": False}),
            }
        }

    CATEGORY = "language"
    RETURN_TYPES = "MODEL",
    FUNCTION = "execute"

    def execute(self, ckpt_name: str, subfolder: Optional[str] = None, model: TransformersManagedModel = None, overwrite_tokenizer: bool = False):
        hub_kwargs = {}
        if subfolder is not None and subfolder != "":
            hub_kwargs["subfolder"] = subfolder
        ckpt_name = get_or_download_huggingface_repo(ckpt_name)
        processor = AutoProcessor.from_pretrained(
            ckpt_name,
            torch_dtype=unet_dtype(),
            device_map=get_torch_device_name(unet_offload_device()),
            low_cpu_mem_usage=True,
            trust_remote_code=True,
            **hub_kwargs
        )
        return model.patch_processor(processor, overwrite_tokenizer),


class TransformersLoader(CustomNode):
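    """Loads a transformers model chosen from the list of known Hugging Face repos."""
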
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "ckpt_name": (get_huggingface_repo_list(),),
            },
            "optional": {
                "subfolder": ("STRING", {}),
            }
        }

    CATEGORY = "language"
    RETURN_TYPES = "MODEL",
    RETURN_NAMES = "language model",
    FUNCTION = "execute"

    def execute(self, ckpt_name: str, subfolder: Optional[str] = None, *args, **kwargs) -> tuple[TransformersManagedModel]:
        return TransformersManagedModel.from_pretrained(ckpt_name, subfolder),


class TransformersLoader1(TransformersLoader):
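    """Variant of ``TransformersLoader`` that accepts a free-form repo name instead of a dropdown choice."""
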
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "ckpt_name": ("STRING", {}),
            },
            "optional": {
                "subfolder": ("STRING", {}),
            }
        }


class TransformersTokenize(CustomNode):
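    """Tokenizes a plain text prompt with the model's tokenizer."""
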
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "model": ("MODEL",),
                "prompt": ("STRING", {"default": "", "multiline": True}),
            },
        }

    CATEGORY = "language"
    RETURN_TYPES = (TOKENS_TYPE_NAME,)
    FUNCTION = "execute"

    def execute(self, model: LanguageModel, prompt: str) -> ValidatedNodeResult:
        return model.tokenize(prompt, [], None),


class TransformersM2M100LanguageCodes(CustomNode):
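    """Selects one of the M2M100 FAIRSEQ language codes as a string."""
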
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "lang_id": (tokenization_m2m_100_FAIRSEQ_LANGUAGE_CODES["m2m100"], {"default": "en"}),
            },
        }

    CATEGORY = "language"
    RETURN_TYPES = ("STRING",)
    FUNCTION = "execute"

    def execute(self, lang_id: str) -> ValidatedNodeResult:
        return lang_id,


class TransformersFlores200LanguageCodes(CustomNode):
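    """Selects one of the NLLB FLORES-200 language codes as a string."""
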
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "lang_id": (tokenization_nllb_FAIRSEQ_LANGUAGE_CODES, {"default": "eng_Latn"}),
            },
        }

    CATEGORY = "language"
    RETURN_TYPES = ("STRING",)
    FUNCTION = "execute"

    def execute(self, lang_id: str) -> ValidatedNodeResult:
        return lang_id,


class TransformersTranslationTokenize(CustomNode):
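    """Tokenizes a prompt for translation models, applying source and target language codes."""
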
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "model": ("MODEL",),
                "prompt": ("STRING", {"default": "", "multiline": True}),
                "src_lang": ("STRING", {}),
                "tgt_lang": ("STRING", {}),
            },
        }

    CATEGORY = "language"
    RETURN_TYPES = (TOKENS_TYPE_NAME,)
    FUNCTION = "execute"

    def execute(self, model: TransformersManagedModel, prompt: str, src_lang: str, tgt_lang: str) -> ValidatedNodeResult:
        tokenizer = model.tokenizer

        # remember the tokenizer's language state so it can be restored afterwards
        if hasattr(tokenizer, "src_lang"):
            prev_src_lang = tokenizer.src_lang
        else:
            prev_src_lang = None
        if hasattr(tokenizer, "tgt_lang"):
            prev_tgt_lang = tokenizer.tgt_lang
        else:
            prev_tgt_lang = None

        try:
            if hasattr(tokenizer, "_build_translation_inputs"):
                encoded = tokenizer._build_translation_inputs(
                    prompt, return_tensors="pt", src_lang=src_lang, tgt_lang=tgt_lang
                )
            else:
                tokenizer.src_lang = src_lang
                tokenizer.tgt_lang = tgt_lang

                encoded = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
            encoded["input_ids"] = encoded["input_ids"].to(device=model.load_device)
            encoded["attention_mask"] = encoded["attention_mask"].to(device=model.load_device)
            encoded["src_lang"] = src_lang
            encoded["tgt_lang"] = tgt_lang
            return encoded,
        finally:
            if prev_src_lang is not None:
                tokenizer.src_lang = prev_src_lang
            if prev_tgt_lang is not None:
                tokenizer.tgt_lang = prev_tgt_lang


class OneShotInstructTokenize(CustomNode):
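    """Builds a one-shot chat message, optionally with images, videos and a system prompt,
    and tokenizes it with the selected chat template."""
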
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "model": ("MODEL",),
                "prompt": ("STRING", {"default": "", "multiline": True}),
                "chat_template": ([_AUTO_CHAT_TEMPLATE] + list(KNOWN_CHAT_TEMPLATES.keys()), {})
            },
            "optional": {
                "images": ("IMAGE", {}),
                "videos": ("VIDEO", {}),
                "system_prompt": ("STRING", {"multiline": True, "default": ""})
            }
        }

    CATEGORY = "language"
    RETURN_TYPES = (TOKENS_TYPE_NAME,)
    FUNCTION = "execute"

    def execute(self, model: LanguageModel, prompt: str, images: List[torch.Tensor] | torch.Tensor = None, videos: list | object = None, chat_template: Optional[str] = _AUTO_CHAT_TEMPLATE, system_prompt: str = "") -> ValidatedNodeResult:
        if chat_template == _AUTO_CHAT_TEMPLATE:
            model_name = os.path.basename(model.repo_id)
            if model_name in KNOWN_CHAT_TEMPLATES:
                chat_template = KNOWN_CHAT_TEMPLATES[model_name]
            else:
                chat_template = None
        elif chat_template is not None:
            chat_template = KNOWN_CHAT_TEMPLATES[chat_template]

        video_tensors = []
        if videos is not None:
            if not isinstance(videos, list):
                videos_list = [videos]
            else:
                videos_list = videos

            for vid in videos_list:
                if hasattr(vid, "get_components"):
                    components = vid.get_components()
                    video_tensors.append(components.images)
                elif isinstance(vid, torch.Tensor):
                    video_tensors.append(vid)

        messages: LanguagePrompt | str

        has_images = images is not None and len(images) > 0
        has_videos = len(video_tensors) > 0

        if system_prompt != "" or has_images or has_videos:
            user_content = [{"type": "text", "text": prompt}]
            if has_images:
                user_content += [{"type": "image"} for _ in range(len(images))]

            if has_videos:
                user_content += [{"type": "video"} for _ in range(len(video_tensors))]

            messages = [
                {"role": "user", "content": user_content}
            ]

            if system_prompt.strip() != "":
                messages.insert(0, {"role": "system", "content": system_prompt})
        else:
            messages = prompt

        return model.tokenize(messages, images, video_tensors, chat_template),


class TransformersGenerate(CustomNode):
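    """Runs the language model over tokenized input and returns the generated string."""
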
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "model": ("MODEL",),
                "tokens": (TOKENS_TYPE_NAME, {}),
                "max_new_tokens": ("INT", {"default": 512, "min": 1, "max": 0xffffffff}),
                "seed": Seed,
            },
            "optional": {
                "sampler": (GENERATION_KWARGS_TYPE_NAME, {}),
            }
        }

    CATEGORY = "language"
    RETURN_TYPES = ("STRING",)
    FUNCTION = "execute"

    def execute(self,
                model: Optional[LanguageModel] = None,
                tokens: TOKENS_TYPE = None,
                max_new_tokens: int = 512,
                seed: int = 0,
                sampler: Optional[GENERATION_KWARGS_TYPE] = None,
                ):
        return model.generate(tokens, max_new_tokens, seed, sampler),


class PreviewString(CustomNode):
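    """Displays a string in the UI."""
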
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "value": ("STRING", {"forceInput": True}),
            }
        }

    CATEGORY = "strings"
    FUNCTION = "execute"
    RETURN_TYPES = ("STRING",)
    OUTPUT_NODE = True

    def execute(self, value: str):
        return {"ui": {"string": [value]}}


class SaveString(CustomNode):
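    """Writes one or more strings to files in the output directory."""
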
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "value": ("STRING", {"forceInput": True}),
                "filename_prefix": ("STRING", {"default": "ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."})
            },
            "optional": {
                "extension": ("STRING", {"default": ".txt"})
            }
        }

    CATEGORY = "strings"
    FUNCTION = "execute"
    OUTPUT_NODE = True
    RETURN_TYPES = ()

    def get_save_path(self, filename_prefix) -> SaveImagePathTuple:
        return folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory(), 0, 0)

    def execute(self, value: str | list[str] = "", filename_prefix: str = "ComfyUI", extension: str = ".txt"):
        full_output_folder, filename, counter, subfolder, filename_prefix = self.get_save_path(filename_prefix)
        if isinstance(value, str):
            value = [value]

        for i, value_i in enumerate(value):
            # roughly matches the behavior of save image, but does not support batch numbers
            with open(os.path.join(full_output_folder, f"{filename}_{counter:05d}{extension}" if len(value) == 1 else f"{filename}_{counter:05d}_{i:02d}{extension}"), "wt+") as f:
                f.write(value_i)
        return {"ui": {"string": value}}


class OmitThink(CustomNode):
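    """Strips ``<think>...</think>`` reasoning blocks from generated text."""
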
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "value": ("STRING", {"forceInput": True}),
            },
        }

    CATEGORY = "strings"
    FUNCTION = "execute"
    OUTPUT_NODE = True
    RETURN_TYPES = ("STRING",)

    def execute(self, value: str | list[str] = "") -> tuple[list[str]]:
        # matched spans are removed along with the tags themselves
        pattern_explicit = r"<think>.*?</think>"
        pattern_missing_start = r"^.*?</think>"

        if isinstance(value, str):
            values = [value]
        else:
            values = value

        result = []
        for value in values:
            if "<think>" in value:
                cleaned_text = re.sub(pattern_explicit, "", value, flags=re.DOTALL)
            elif "</think>" in value:
                # the opening tag was cut off, so drop everything up to the closing tag
                cleaned_text = re.sub(pattern_missing_start, "", value, flags=re.DOTALL)
            else:
                cleaned_text = value
            result.append(cleaned_text.strip())

        return result,


export_custom_nodes()
export_package_as_web_directory("comfy_extras.language_web")