ComfyUI/comfy_extras/nodes/nodes_language.py

628 lines
23 KiB
Python

from __future__ import annotations
import copy
import inspect
import logging
import operator
import os.path
from functools import reduce
from typing import Any, Dict, Optional, List, Callable, Union
import torch
from transformers import AutoTokenizer, PreTrainedModel, LogitsProcessor, TextStreamer, \
PreTrainedTokenizerBase, PretrainedConfig, AutoProcessor, BatchFeature, AutoModel, AutoModelForCausalLM, \
AutoModelForSeq2SeqLM
from transformers.models.auto.modeling_auto import MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, \
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES, AutoModelForVision2Seq
from transformers.models.m2m_100.tokenization_m2m_100 import \
FAIRSEQ_LANGUAGE_CODES as tokenization_m2m_100_FAIRSEQ_LANGUAGE_CODES
from transformers.models.nllb.tokenization_nllb import \
FAIRSEQ_LANGUAGE_CODES as tokenization_nllb_FAIRSEQ_LANGUAGE_CODES
from typing_extensions import TypedDict
from comfy import model_management
from comfy.language.chat_templates import KNOWN_CHAT_TEMPLATES
from comfy.language.language_types import ProcessorResult
from comfy.language.transformers_model_management import TransformersManagedModel
from comfy.model_downloader import get_huggingface_repo_list, get_or_download_huggingface_repo
from comfy.model_management import get_torch_device_name, unet_dtype, unet_offload_device, load_models_gpu
from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult
from comfy.utils import comfy_tqdm, seed_for_block, comfy_progress, ProgressBar
# Chat template choice meaning "look the template up by the model's repository name".
_AUTO_CHAT_TEMPLATE = "default"

# add llava support
try:
    # presumably imported only for its side effects (per the alias name); the alias is not used in this file
    from llava import model as _llava_model_side_effects
except ImportError:
    logging.debug("Install LLavA with `pip install git+https://github.com/AppMana/appmana-comfyui-llava` for additional LLaVA support")
else:
    logging.debug("Additional LLaVA models are now supported")

# aka kwargs type
_GENERATION_KWARGS_TYPE = Dict[str, Any]
# node socket type name used for generation kwargs
_GENERATION_KWARGS_TYPE_NAME = "SAMPLER"
_TOKENS_TYPE = Union[ProcessorResult, BatchFeature]
# node socket type name used for tokenized inputs
TOKENS_TYPE_NAME = "TOKENS"
class _ProgressTextStreamer(TextStreamer):
    """Text streamer that forwards every finalized piece of text to a callback.

    :param on_finalized_text: callable invoked as ``callback(text, stream_end)``
    :param tokenizer: tokenizer handed to the base streamer
    :param skip_prompt: forwarded to the base streamer
    :param decode_kwargs: forwarded to the base streamer
    """

    def __init__(self, on_finalized_text: Callable[[str, bool], None], tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
        super().__init__(tokenizer, skip_prompt=skip_prompt, **decode_kwargs)
        self.on_finalized_text_handler = on_finalized_text

    def on_finalized_text(self, text: str, stream_end: bool = False):
        """Relay the finalized text and the end-of-stream flag to the stored callback."""
        callback = self.on_finalized_text_handler
        callback(text, stream_end)
class _ProgressLogitsProcessor(LogitsProcessor):
    """Logits processor that records the end-of-sequence probability at each step.

    The scores are returned untouched; the only effect is updating
    ``eos_probability`` with the softmax probability of the tokenizer's EOS token.
    """

    def __init__(self, model: TransformersManagedModel):
        self.eos_token_id = model.tokenizer.eos_token_id
        # initialised here so the attribute exists before the first generation step
        self.eos_probability: float = 0.0

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Record the EOS probability for this step and pass ``scores`` through.

        :param input_ids: token ids generated so far (unused)
        :param scores: next-token scores, indexed as ``[batch, vocabulary]``
        :return: ``scores`` unchanged
        """
        # indexing with a ``None`` id would insert an axis instead of selecting a column
        if self.eos_token_id is not None:
            probabilities = scores.softmax(dim=-1)
            # reduce over the batch so ``.item()`` is valid for more than one row;
            # with a single row this is the same value the plain ``.item()`` gave
            self.eos_probability = probabilities[:, self.eos_token_id].max().item()
        return scores
# todo: for per token progress, should this really look like {"ui": {"string": [value]}} ?
class TransformerStreamedProgress(TypedDict):
    """Preview payload attached to progress updates while text is being generated."""
    # the piece of text most recently finalized by the text streamer
    next_token: str
class TransformerSamplerBase(CustomNode):
    """Shared base for sampler nodes.

    Whatever inputs a subclass declares are emitted unchanged as a dictionary of
    generation arguments, preceded by a ``do_sample`` entry.
    """
    RETURN_TYPES = (_GENERATION_KWARGS_TYPE_NAME,)
    RETURN_NAMES = ("GENERATION ARGS",)
    FUNCTION = "execute"
    CATEGORY = "language/samplers"

    @property
    def do_sample(self):
        """Whether this sampler asks for sampling; subclasses may override."""
        return True

    def execute(self, **kwargs):
        """Return a 1-tuple holding ``do_sample`` merged with the node inputs."""
        # ``do_sample`` goes in first so an input of the same name overrides it
        generation_args = {"do_sample": self.do_sample}
        generation_args.update(kwargs)
        return (generation_args,)
class TransformerTopKSampler(TransformerSamplerBase):
    """Sampler node that contributes the ``top_k`` generation argument."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        """Declare a single required integer input, ``top_k``."""
        top_k_spec = ("INT", {"default": 50, "min": 1})
        return {"required": {"top_k": top_k_spec}}
class TransformerTopPSampler(TransformerSamplerBase):
    """Sampler node that contributes the ``top_p`` generation argument."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        """Declare a single required float input, ``top_p``, bounded to [0, 1]."""
        top_p_spec = ("FLOAT", {"default": 0.9, "min": 0, "max": 1})
        return {"required": {"top_p": top_p_spec}}
class TransformerTemperatureSampler(TransformerSamplerBase):
    """Sampler node that contributes the ``temperature`` generation argument."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        """Declare a single required float input, ``temperature``, with minimum 0."""
        temperature_spec = ("FLOAT", {"default": 1.0, "min": 0})
        return {"required": {"temperature": temperature_spec}}
class TransformerGreedySampler(TransformerSamplerBase):
    """Sampler node with no inputs that emits ``do_sample=False``."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        """Declare an empty set of required inputs."""
        return {"required": {}}

    @property
    def do_sample(self):
        """Sampling is switched off for this node."""
        return False
class TransformersGenerationConfig(CustomNode):
    """Node that outputs the generation configuration attached to a loaded model."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        """Declare the single required ``model`` input."""
        return {
            "required": {
                "model": ("MODEL", {})
            }
        }

    RETURN_TYPES = _GENERATION_KWARGS_TYPE_NAME,
    RETURN_NAMES = "GENERATION ARGS",
    FUNCTION = "execute"
    CATEGORY = "language"

    def execute(self, model: TransformersManagedModel):
        """Return the model's generation config, or an empty dict when it has none.

        The value is wrapped in a 1-tuple to line up with the single entry in
        ``RETURN_TYPES``, as every other node in this file does.

        :param model: managed model whose ``model.generation_config`` is read
        :return: 1-tuple with the generation config or ``{}``
        """
        generation_config = model.model.generation_config
        if generation_config is not None:
            # NOTE(review): the config object is passed through as-is, while the other
            # nodes in this file treat this socket's values as plain dicts — confirm
            # whether it should be converted before it is wired into them.
            return generation_config,
        return {},
class TransformerContrastiveSearchSampler(TransformerTopKSampler):
    """Sampler node that adds ``penalty_alpha`` to the top-k sampler's inputs."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        """Return the top-k input declaration extended with ``penalty_alpha``."""
        input_types = TransformerTopKSampler.INPUT_TYPES()
        # extend the freshly built declaration in place
        input_types["required"].update({
            "penalty_alpha": ("FLOAT", {"default": 0.6, "min": 0})
        })
        return input_types
class TransformerBeamSearchSampler(TransformerSamplerBase):
    """Sampler node contributing ``num_beams`` and ``early_stopping`` with ``do_sample=False``."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        """Declare the required ``num_beams`` and ``early_stopping`` inputs."""
        required = {}
        required["num_beams"] = ("INT", {"default": 1, "min": 0})
        required["early_stopping"] = ("BOOLEAN", {"default": True})
        return {"required": required}

    @property
    def do_sample(self):
        """Sampling is switched off for this node."""
        return False
class TransformerMergeSamplers(CustomNode):
    """Node that merges up to five generation-argument dictionaries into one."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        """Declare five required inputs, ``value0`` through ``value4``."""
        required = {
            f"value{index}": (_GENERATION_KWARGS_TYPE_NAME, {"forceInput": True})
            for index in range(5)
        }
        return {"required": required}

    CATEGORY = "language"
    RETURN_TYPES = (_GENERATION_KWARGS_TYPE_NAME,)
    FUNCTION = "execute"

    def execute(self, **kwargs):
        """Merge the inputs left to right; later inputs win on duplicate keys.

        ``do_sample`` is applied last and is true when any input carries a truthy
        ``do_sample`` entry.

        :return: 1-tuple holding the merged dictionary
        """
        sampling_requested = any(
            flag
            for sampler_args in kwargs.values()
            for name, flag in sampler_args.items()
            if name == "do_sample"
        )
        merged = {}
        for sampler_args in [*kwargs.values(), {"do_sample": sampling_requested}]:
            merged = merged | sampler_args
        return (merged,)
class TransformersImageProcessorLoader(CustomNode):
    """Node that loads a processor from a repository and patches it into a model."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        """Declare the repository, subfolder, model and overwrite-tokenizer inputs."""
        required = {
            "ckpt_name": (get_huggingface_repo_list(),),
            "subfolder": ("STRING", {}),
            "model": ("MODEL", {}),
            "overwrite_tokenizer": ("BOOLEAN", {"default": False}),
        }
        return {"required": required}

    CATEGORY = "language"
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "execute"

    def execute(self, ckpt_name: str, subfolder: Optional[str] = None, model: TransformersManagedModel = None, overwrite_tokenizer: bool = False):
        """Load the processor for ``ckpt_name`` and return the model patched with it.

        :param ckpt_name: repository to resolve via ``get_or_download_huggingface_repo``
        :param subfolder: optional subfolder; ``None`` and ``""`` are both ignored
        :param model: managed model whose ``patch_processor`` is called
        :param overwrite_tokenizer: forwarded to ``patch_processor``
        :return: 1-tuple with the result of ``patch_processor``
        """
        subfolder_missing = subfolder is None or subfolder == ""
        hub_kwargs = {} if subfolder_missing else {"subfolder": subfolder}
        repo_path = get_or_download_huggingface_repo(ckpt_name)
        dtype = unet_dtype()
        device_name = get_torch_device_name(unet_offload_device())
        processor = AutoProcessor.from_pretrained(
            repo_path,
            torch_dtype=dtype,
            device_map=device_name,
            low_cpu_mem_usage=True,
            trust_remote_code=True,
            **hub_kwargs
        )
        patched_model = model.patch_processor(processor, overwrite_tokenizer)
        return (patched_model,)
class TransformersLoader(CustomNode):
    """Node that loads a transformers model together with its tokenizer and/or processor."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        """Declare the repository and subfolder inputs."""
        return {
            "required": {
                "ckpt_name": (get_huggingface_repo_list(),),
                "subfolder": ("STRING", {})
            },
        }

    CATEGORY = "language"
    RETURN_TYPES = "MODEL",
    FUNCTION = "execute"

    def execute(self, ckpt_name: str, subfolder: Optional[str] = None, *args, **kwargs):
        """Load the model, then its processor/tokenizer, and wrap them in a managed model.

        Loading is attempted with several sets of extra ``from_pretrained`` keyword
        arguments in turn: flash attention (only when ``flash_attn`` imports), then
        preferred dtype / low-memory / offload-device settings, then no extras.

        :param ckpt_name: repository to resolve via ``get_or_download_huggingface_repo``
        :param subfolder: optional subfolder; ``None`` and ``""`` are both ignored
        :return: 1-tuple with the ``TransformersManagedModel``
        :raises Exception: whatever the last attempt raised when every attempt at
            loading the model, or the tokenizer, fails
        """
        hub_kwargs = {}
        if subfolder is not None and subfolder != "":
            hub_kwargs["subfolder"] = subfolder
        ckpt_name = get_or_download_huggingface_repo(ckpt_name)
        with comfy_tqdm():
            from_pretrained_kwargs = {
                "pretrained_model_name_or_path": ckpt_name,
                "trust_remote_code": True,
                **hub_kwargs
            }
            # if flash attention exists, use it
            # compute bitsandbytes configuration
            try:
                import bitsandbytes
            except ImportError:
                pass
            config_dict, _ = PretrainedConfig.get_config_dict(ckpt_name, **hub_kwargs)
            model_type = config_dict["model_type"]
            # language models prefer to use bfloat16 over float16
            kwargs_to_try = ({"torch_dtype": unet_dtype(supported_dtypes=(torch.bfloat16, torch.float16, torch.float32)),
                              "low_cpu_mem_usage": True,
                              "device_map": str(unet_offload_device()), }, {})
            # if we have flash-attn installed, try to use it
            try:
                import flash_attn
                attn_override_kwargs = {
                    "attn_implementation": "flash_attention_2",
                    **kwargs_to_try[0]
                }
                kwargs_to_try = (attn_override_kwargs, *kwargs_to_try)
                logging.debug(f"while loading model {ckpt_name}, flash_attn was installed, so the flash_attention_2 implementation will be tried")
            except ImportError:
                pass

            last_attempt = len(kwargs_to_try) - 1

            # pick the auto class from the model type and fall through the kwargs sets
            for i, props in enumerate(kwargs_to_try):
                try:
                    if model_type in MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES:
                        model = AutoModelForVision2Seq.from_pretrained(**from_pretrained_kwargs, **props)
                    elif model_type in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
                        model = AutoModelForSeq2SeqLM.from_pretrained(**from_pretrained_kwargs, **props)
                    elif model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
                        model = AutoModelForCausalLM.from_pretrained(**from_pretrained_kwargs, **props)
                    else:
                        model = AutoModel.from_pretrained(**from_pretrained_kwargs, **props)
                    if model is not None:
                        break
                except Exception as exc_info:
                    if i == last_attempt:
                        # nothing left to try: surface the failure with its original traceback
                        raise
                    logging.warning(f"tried to import transformers model {ckpt_name} but got exception when trying additional import args {props}", exc_info=exc_info)

            for i, props in enumerate(kwargs_to_try):
                try:
                    try:
                        processor = AutoProcessor.from_pretrained(**from_pretrained_kwargs, **props)
                    except Exception as processor_exc:
                        # best effort: a missing processor is tolerated, but interrupts
                        # (KeyboardInterrupt / SystemExit) are no longer swallowed here
                        logging.debug(f"could not load a processor for {ckpt_name} with import args {props}", exc_info=processor_exc)
                        processor = None
                    if isinstance(processor, PreTrainedTokenizerBase):
                        # the "processor" turned out to be a plain tokenizer
                        tokenizer = processor
                        processor = None
                    elif processor is not None and hasattr(processor, "tokenizer"):
                        tokenizer = processor.tokenizer
                    else:
                        tokenizer = AutoTokenizer.from_pretrained(ckpt_name, **hub_kwargs, **props)
                    if tokenizer is not None or processor is not None:
                        break
                except Exception as exc_info:
                    if i == last_attempt:
                        raise
                    # mirrors the model loop above instead of failing silently
                    logging.warning(f"tried to load tokenizer for {ckpt_name} but got exception when trying additional import args {props}", exc_info=exc_info)

        if model_management.xformers_enabled() and hasattr(model, "enable_xformers_memory_efficient_attention"):
            model.enable_xformers_memory_efficient_attention()
            logging.debug("enabled xformers memory efficient attention")
        model_managed = TransformersManagedModel(
            repo_id=ckpt_name,
            model=model,
            tokenizer=tokenizer,
            config_dict=config_dict,
            processor=processor
        )
        return model_managed,
class TransformersTokenize(CustomNode):
    """Node that tokenizes a plain prompt with the model's own ``tokenize`` method."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        """Declare the required ``model`` and ``prompt`` inputs."""
        required = {
            "model": ("MODEL",),
            "prompt": ("STRING", {"default": "", "multiline": True}),
        }
        return {"required": required}

    CATEGORY = "language"
    RETURN_TYPES = (TOKENS_TYPE_NAME,)
    FUNCTION = "execute"

    def execute(self, model: TransformersManagedModel, prompt: str) -> ValidatedNodeResult:
        """Tokenize ``prompt`` with an empty image list and no chat template."""
        tokens = model.tokenize(prompt, [], None)
        return (tokens,)
class TransformersM2M100LanguageCodes(CustomNode):
    """Node offering the M2M100 language codes as a choice and emitting the pick."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        """Declare ``lang_id`` as a choice among the ``m2m100`` language codes."""
        language_choices = tokenization_m2m_100_FAIRSEQ_LANGUAGE_CODES["m2m100"]
        return {
            "required": {
                "lang_id": (language_choices, {"default": "en"}),
            },
        }

    CATEGORY = "language"
    RETURN_TYPES = ("STRING",)
    FUNCTION = "execute"

    def execute(self, lang_id: str) -> ValidatedNodeResult:
        """Return the selected language code unchanged, as a 1-tuple."""
        return (lang_id,)
class TransformersFlores200LanguageCodes(CustomNode):
    """Node offering the NLLB tokenizer's language codes as a choice and emitting the pick."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        """Declare ``lang_id`` as a choice among the NLLB language codes."""
        language_choices = tokenization_nllb_FAIRSEQ_LANGUAGE_CODES
        return {
            "required": {
                "lang_id": (language_choices, {"default": "eng_Latn"}),
            },
        }

    CATEGORY = "language"
    RETURN_TYPES = ("STRING",)
    FUNCTION = "execute"

    def execute(self, lang_id: str) -> ValidatedNodeResult:
        """Return the selected language code unchanged, as a 1-tuple."""
        return (lang_id,)
class TransformersTranslationTokenize(CustomNode):
    """Node that tokenizes a prompt for translation between two languages."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        """Declare the required model, prompt, source and target language inputs."""
        required = {
            "model": ("MODEL",),
            "prompt": ("STRING", {"default": "", "multiline": True}),
            "src_lang": ("STRING", {}),
            "tgt_lang": ("STRING", {}),
        }
        return {"required": required}

    CATEGORY = "language"
    RETURN_TYPES = (TOKENS_TYPE_NAME,)
    FUNCTION = "execute"

    def execute(self, model: TransformersManagedModel, prompt: str, src_lang: str, tgt_lang: str) -> ValidatedNodeResult:
        """Encode ``prompt`` and tag the encoding with both language codes.

        The tokenizer's previous ``src_lang`` / ``tgt_lang`` are put back afterwards
        when they existed and were not ``None``.

        :return: 1-tuple with the encoding; ``input_ids`` and ``attention_mask`` are
            moved to ``model.load_device`` and ``src_lang`` / ``tgt_lang`` are added
        """
        tokenizer = model.tokenizer
        # remember the current language settings so they can be restored
        prev_src_lang = tokenizer.src_lang if hasattr(tokenizer, "src_lang") else None
        prev_tgt_lang = tokenizer.tgt_lang if hasattr(tokenizer, "tgt_lang") else None
        try:
            if not hasattr(tokenizer, "_build_translation_inputs"):
                # no dedicated helper: set the languages on the tokenizer and call it directly
                tokenizer.src_lang = src_lang
                tokenizer.tgt_lang = tgt_lang
                encoded = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
            else:
                encoded = tokenizer._build_translation_inputs(
                    prompt, return_tensors="pt", src_lang=src_lang, tgt_lang=tgt_lang
                )
            for tensor_key in ("input_ids", "attention_mask"):
                encoded[tensor_key] = encoded[tensor_key].to(device=model.load_device)
            encoded["src_lang"] = src_lang
            encoded["tgt_lang"] = tgt_lang
            return (encoded,)
        finally:
            if prev_src_lang is not None:
                tokenizer.src_lang = prev_src_lang
            if prev_tgt_lang is not None:
                tokenizer.tgt_lang = prev_tgt_lang
class OneShotInstructTokenize(CustomNode):
    """Node that tokenizes a single instruction prompt, optionally with images."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        """Declare model, prompt and chat-template inputs plus optional images."""
        return {
            "required": {
                "model": ("MODEL",),
                "prompt": ("STRING", {"default": "", "multiline": True}),
                "chat_template": ([_AUTO_CHAT_TEMPLATE] + list(KNOWN_CHAT_TEMPLATES.keys()), {})
            },
            "optional": {
                "images": ("IMAGE", {}),
            }
        }

    CATEGORY = "language"
    RETURN_TYPES = (TOKENS_TYPE_NAME,)
    FUNCTION = "execute"

    def execute(self, model: TransformersManagedModel, prompt: str, images: List[torch.Tensor] | torch.Tensor = None, chat_template: str = "__auto__") -> ValidatedNodeResult:
        """Resolve the chat template and tokenize ``prompt`` with it.

        :param model: managed model providing ``repo_id`` and ``tokenize``
        :param prompt: text to tokenize
        :param images: optional images forwarded to ``model.tokenize``
        :param chat_template: a key of ``KNOWN_CHAT_TEMPLATES``, or the automatic
            choice, which looks the template up by the base name of ``model.repo_id``
            and uses no template when there is no exact match
        :return: 1-tuple with the result of ``model.tokenize``
        :raises KeyError: if ``chat_template`` is neither automatic nor a known key
        """
        # The signature default ("__auto__") differs from the value offered in
        # INPUT_TYPES (_AUTO_CHAT_TEMPLATE); accept both so relying on the default
        # selects automatic lookup instead of being used as a template key.
        if chat_template in (_AUTO_CHAT_TEMPLATE, "__auto__"):
            # use an exact match
            model_name = os.path.basename(model.repo_id)
            chat_template = KNOWN_CHAT_TEMPLATES.get(model_name)
        else:
            chat_template = KNOWN_CHAT_TEMPLATES[chat_template]
        return model.tokenize(prompt, images, chat_template),
class TransformersGenerate(CustomNode):
    """Node that runs text generation on tokenized inputs and returns the decoded text."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        """Declare the generation inputs and the optional sampler arguments."""
        return {
            "required": {
                "model": ("MODEL",),
                "tokens": (TOKENS_TYPE_NAME, {}),
                "max_new_tokens": ("INT", {"default": 512, "min": 1}),
                "repetition_penalty": ("FLOAT", {"default": 0.0, "min": 0}),
                "seed": ("INT", {"default": 0, "min": 0, "max": 2 ** 32 - 1}),
                "use_cache": ("BOOLEAN", {"default": True}),
            },
            "optional": {
                "sampler": (_GENERATION_KWARGS_TYPE_NAME, {}),
            }
        }

    CATEGORY = "language"
    RETURN_TYPES = ("STRING",)
    FUNCTION = "execute"

    def execute(self,
                model: Optional[TransformersManagedModel] = None,
                tokens: _TOKENS_TYPE = None,
                max_new_tokens: int = 512,
                repetition_penalty: float = 0.0,
                seed: int = 0,
                sampler: Optional[_GENERATION_KWARGS_TYPE] = None,
                *args,
                **kwargs
                ):
        """Generate text from ``tokens`` and return the first decoded sequence.

        :param model: managed model; loaded onto the GPU before generating
        :param tokens: tokenized inputs; a shallow copy is modified, not the argument
        :param max_new_tokens: forwarded to ``generate`` and used as the progress total
        :param repetition_penalty: forwarded to ``generate``; ``0`` is sent as ``None``
        :param seed: seed applied around the ``generate`` call
        :param sampler: extra keyword arguments for ``generate``
        :param kwargs: remaining node inputs; ``use_cache`` is forwarded to
            ``generate`` unless ``sampler`` already sets it
        :return: 1-tuple with the first decoded output string
        """
        tokens = copy.copy(tokens)
        tokens_original = copy.copy(tokens)
        sampler = sampler or {}
        generate_kwargs = copy.copy(sampler)
        # ``use_cache`` is declared in INPUT_TYPES but has no named parameter, so it
        # arrives through **kwargs; forward it rather than discarding it
        if "use_cache" in kwargs:
            generate_kwargs.setdefault("use_cache", kwargs["use_cache"])
        load_models_gpu([model])
        transformers_model: PreTrainedModel = model.model
        tokenizer: PreTrainedTokenizerBase | AutoTokenizer = model.tokenizer
        # remove unused inputs
        # maximizes compatibility with different models
        generate_signature = inspect.signature(transformers_model.generate).parameters
        prepare_signature = inspect.signature(transformers_model.prepare_inputs_for_generation).parameters
        # token keys that neither ``generate`` nor ``prepare_inputs_for_generation`` names
        to_delete = set(tokens.keys() - generate_signature.keys() - prepare_signature.keys())
        gen_sig_keys = generate_signature.keys()
        if "tgt_lang" in tokens:
            # translation inputs: the language tags are metadata, not model inputs
            to_delete.add("tgt_lang")
            to_delete.add("src_lang")
            to_delete.discard("input_ids")
            if "forced_bos_token_id" in tokens:
                to_delete.discard("forced_bos_token_id")
            elif hasattr(tokenizer, "convert_tokens_to_ids"):
                generate_kwargs["forced_bos_token_id"] = tokenizer.convert_tokens_to_ids(tokens["tgt_lang"])
            else:
                logging.warning(f"tokenizer {tokenizer} unexpected for translation task")
        if "input_ids" in tokens and "inputs" in tokens:
            # keep only the spelling that ``generate`` itself names
            if "input_ids" in gen_sig_keys:
                to_delete.add("inputs")
            elif "inputs" in gen_sig_keys:
                to_delete.add("input_ids")
        for unused_kwarg in to_delete:
            # ``src_lang`` is queued for removal even when it was never present
            tokens.pop(unused_kwarg, None)
            logging.debug(f"{transformers_model.name_or_path}.generate does not accept {unused_kwarg}, removing")
        # images should be moved to model
        for key in ("images", "pixel_values"):
            if key in tokens:
                tokens[key] = tokens[key].to(device=model.current_device, dtype=model.model_dtype())
        # sets up inputs
        inputs = tokens
        # used to determine if text streaming is supported
        num_beams = generate_kwargs.get("num_beams", transformers_model.generation_config.num_beams)
        progress_bar: ProgressBar
        with comfy_progress(total=max_new_tokens) as progress_bar:
            # todo: deal with batches correctly, don't assume batch size 1
            token_count = 0

            # progress
            def on_finalized_text(next_token: str, stop: bool):
                nonlocal token_count
                token_count += 1
                preview = TransformerStreamedProgress(next_token=next_token)
                progress_bar.update_absolute(token_count, total=max_new_tokens, preview_image_or_output=preview)

            text_streamer = _ProgressTextStreamer(on_finalized_text, tokenizer, True)
            with seed_for_block(seed):
                if hasattr(inputs, "encodings") and inputs.encodings is not None and all(hasattr(encoding, "attention_mask") for encoding in inputs.encodings) and "attention_mask" in inputs:
                    inputs.pop("attention_mask")
                output_ids = transformers_model.generate(
                    **inputs,
                    # the streamer is only attached when a single beam is used
                    streamer=text_streamer if num_beams <= 1 else None,
                    max_new_tokens=max_new_tokens,
                    repetition_penalty=repetition_penalty if repetition_penalty != 0 else None,
                    **generate_kwargs
                )
                if not transformers_model.config.is_encoder_decoder:
                    # drop the first ``start_position`` output positions, i.e. the input length
                    start_position = inputs["input_ids" if "input_ids" in inputs else "inputs"].shape[1]
                    output_ids = output_ids[:, start_position:]
        if hasattr(tokenizer, "src_lang") and "src_lang" in tokens_original:
            prev_src_lang = tokenizer.src_lang
            tokenizer.src_lang = tokens_original["src_lang"]
        else:
            prev_src_lang = None
        # todo: is this redundant consider I'm decoding in the on_finalized_text block?
        try:
            outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        finally:
            if prev_src_lang is not None:
                tokenizer.src_lang = prev_src_lang
        # The previous explicit ``del`` of ``tokens`` and of a loop variable over its
        # values only unbound local names (``inputs`` still referenced the same
        # object), so it released nothing before the function returned.
        # todo: better support batches
        return outputs[0],
class PreviewString(CustomNode):
    """Output node that publishes a string to the UI."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        """Declare the single required, connection-only ``value`` input."""
        value_spec = ("STRING", {"forceInput": True})
        return {"required": {"value": value_spec}}

    CATEGORY = "language"
    FUNCTION = "execute"
    RETURN_TYPES = ("STRING",)
    OUTPUT_NODE = True

    def execute(self, value: str):
        """Return a UI payload carrying ``value`` as a one-element list under ``string``."""
        ui_payload = {"string": [value]}
        return {"ui": ui_payload}
# Registry of the node classes defined in this module, keyed by class name.
NODE_CLASS_MAPPINGS = {}
# NOTE(review): TransformersGenerationConfig is defined above but is not listed here — confirm whether that is intentional.
for cls in (
        TransformerTopKSampler,
        TransformerTopPSampler,
        TransformerTemperatureSampler,
        TransformerGreedySampler,
        TransformerContrastiveSearchSampler,
        TransformerBeamSearchSampler,
        TransformerMergeSamplers,
        TransformersLoader,
        TransformersImageProcessorLoader,
        TransformersGenerate,
        OneShotInstructTokenize,
        TransformersM2M100LanguageCodes,
        TransformersTokenize,
        TransformersFlores200LanguageCodes,
        TransformersTranslationTokenize,
        PreviewString,
):
    # each class is registered under its own ``__name__``
    NODE_CLASS_MAPPINGS[cls.__name__] = cls