# ComfyUI/comfy_extras/nodes/nodes_language.py
from __future__ import annotations
import copy
import inspect
import logging
import operator
import os.path
from functools import reduce
from typing import Any, Dict, Optional, List, Callable, Union
import torch
from transformers import AutoTokenizer, PreTrainedModel, LogitsProcessor, TextStreamer, \
PreTrainedTokenizerBase, LogitsProcessorList, PretrainedConfig, AutoProcessor, BatchFeature, ProcessorMixin, \
LlavaNextForConditionalGeneration, LlavaNextProcessor, AutoModel
from typing_extensions import TypedDict
from comfy.language.chat_templates import KNOWN_CHAT_TEMPLATES
from comfy.language.language_types import ProcessorResult
from comfy.language.transformers_model_management import TransformersManagedModel
from comfy.model_downloader import get_huggingface_repo_list, get_or_download_huggingface_repo
from comfy.model_management import get_torch_device_name, load_model_gpu, unet_dtype, unet_offload_device
from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult
from comfy.utils import comfy_tqdm, seed_for_block, comfy_progress, ProgressBar

_AUTO_CHAT_TEMPLATE = "default"

# add LLaVA support when the optional package is installed
try:
    from llava import model  # noqa: F401  (importing registers additional LLaVA model classes)

    logging.info("Additional LLaVA models are now supported")
except ImportError:
    logging.info("Install LLaVA with `pip install git+https://github.com/AppMana/appmana-comfyui-llava` for additional LLaVA support")

# aka kwargs type
_GENERATION_KWARGS_TYPE = Dict[str, Any]
_GENERATION_KWARGS_TYPE_NAME = "SAMPLER"
_TOKENS_TYPE = Union[ProcessorResult, BatchFeature]
TOKENS_TYPE_NAME = "TOKENS"


class _ProgressTextStreamer(TextStreamer):
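    """Streams decoded text to a callback rather than stdout, so per-token progress can be surfaced in the UI."""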
def __init__(self, on_finalized_text: Callable[[str, bool], None], tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
super().__init__(tokenizer, skip_prompt, **decode_kwargs)
self.on_finalized_text_handler = on_finalized_text
def on_finalized_text(self, text: str, stream_end: bool = False):
self.on_finalized_text_handler(text, stream_end)


class _ProgressLogitsProcessor(LogitsProcessor):
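    """Records the probability of the EOS token at each step, used as a rough estimate of generation progress."""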
def __init__(self, model: TransformersManagedModel):
self.eos_token_id = model.tokenizer.eos_token_id
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
probabilities = scores.softmax(dim=-1)
self.eos_probability = probabilities[:, self.eos_token_id].item()
return scores


# todo: for per-token progress, should this really look like {"ui": {"string": [value]}}?
class TransformerStreamedProgress(TypedDict):
next_token: str


class TransformerSamplerBase(CustomNode):
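    """Base class for sampler nodes; each subclass contributes keyword arguments that are forwarded to `generate`."""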
RETURN_TYPES = _GENERATION_KWARGS_TYPE_NAME,
RETURN_NAMES = "GENERATION ARGS",
FUNCTION = "execute"
CATEGORY = "language/samplers"
@property
def do_sample(self):
return True
def execute(self, **kwargs):
return {
"do_sample": self.do_sample,
**kwargs
},


class TransformerTopKSampler(TransformerSamplerBase):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"top_k": ("INT", {"default": 50, "min": 1})
}
}


class TransformerTopPSampler(TransformerSamplerBase):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"top_p": ("FLOAT", {"default": 0.9, "min": 0, "max": 1})
}
}


class TransformerTemperatureSampler(TransformerSamplerBase):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"temperature": ("FLOAT", {"default": 1.0, "min": 0})
}
}


class TransformerGreedySampler(TransformerSamplerBase):
@property
def do_sample(self):
return False
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
}
}


class TransformersGenerationConfig(CustomNode):
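    """Exposes the generation config bundled with a loaded model as generation kwargs."""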
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"model": ("MODEL",)
}
}
RETURN_TYPES = _GENERATION_KWARGS_TYPE_NAME,
RETURN_NAMES = "GENERATION ARGS",
FUNCTION = "execute"
CATEGORY = "language"
    def execute(self, model: TransformersManagedModel):
        if model.model.generation_config is not None:
            return model.model.generation_config,
        return {},


class TransformerContrastiveSearchSampler(TransformerTopKSampler):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
top_k = TransformerTopKSampler.INPUT_TYPES()
top_k["required"] |= {
"penalty_alpha": ("FLOAT", {"default": 0.6, "min": 0})
}
return top_k


class TransformerBeamSearchSampler(TransformerSamplerBase):
@property
def do_sample(self):
return False
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"num_beams": ("INT", {"default": 1, "min": 0}),
"early_stopping": ("BOOLEAN", {"default": True})
}
}


class TransformerMergeSamplers(CustomNode):
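    """Merges up to five sampler outputs into one generation-kwargs dict; `do_sample` is set if any input enables sampling."""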
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
range_ = {"value0": (_GENERATION_KWARGS_TYPE_NAME, {"forceInput": True})}
range_.update({f"value{i}": (_GENERATION_KWARGS_TYPE_NAME, {"forceInput": True}) for i in range(1, 5)})
return {
"required": range_
}
CATEGORY = "language"
RETURN_TYPES = _GENERATION_KWARGS_TYPE_NAME,
FUNCTION = "execute"
def execute(self, **kwargs):
do_sample = {
"do_sample": any(k == "do_sample" and v for value in kwargs.values() for k, v in value.items())
}
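        # dict union is applied left to right, so on key collisions the later sampler wins,
        # e.g. {"do_sample": True, "top_k": 50} merged with {"temperature": 0.7} gives
        # {"do_sample": True, "top_k": 50, "temperature": 0.7}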
return (reduce(operator.or_, list(kwargs.values()) + [do_sample], {}),)


class TransformersImageProcessorLoader(CustomNode):
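    """Loads an AutoProcessor from a Hugging Face repo and patches it onto an already-loaded model."""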
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"ckpt_name": (get_huggingface_repo_list(),),
"subfolder": ("STRING", {}),
"model": ("MODEL", {}),
"overwrite_tokenizer": ("BOOLEAN", {"default": False}),
}
}
CATEGORY = "language"
RETURN_TYPES = "MODEL",
FUNCTION = "execute"
    def execute(self, ckpt_name: str, subfolder: Optional[str] = None, model: Optional[TransformersManagedModel] = None, overwrite_tokenizer: bool = False):
        hub_kwargs = {}
        if subfolder is not None and subfolder != "":
            hub_kwargs["subfolder"] = subfolder
        ckpt_name = get_or_download_huggingface_repo(ckpt_name)
        processor = AutoProcessor.from_pretrained(
            ckpt_name,
            torch_dtype=unet_dtype(),
            device_map=get_torch_device_name(unet_offload_device()),
            low_cpu_mem_usage=True,
            trust_remote_code=True,
            **hub_kwargs
        )
        return model.patch_processor(processor, overwrite_tokenizer),


class TransformersLoader(CustomNode):
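    """Loads a transformers model with its tokenizer and, when available, its processor, wrapped in a TransformersManagedModel for ComfyUI device management."""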
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"ckpt_name": (get_huggingface_repo_list(),),
"subfolder": ("STRING", {})
},
}
CATEGORY = "language"
RETURN_TYPES = "MODEL",
FUNCTION = "execute"
def execute(self, ckpt_name: str, subfolder: Optional[str] = None, *args, **kwargs):
hub_kwargs = {}
if subfolder is not None and subfolder != "":
hub_kwargs["subfolder"] = subfolder
ckpt_name = get_or_download_huggingface_repo(ckpt_name)
with comfy_tqdm():
from_pretrained_kwargs = {
"pretrained_model_name_or_path": ckpt_name,
"torch_dtype": unet_dtype(),
"device_map": get_torch_device_name(unet_offload_device()),
"low_cpu_mem_usage": True,
"trust_remote_code": True,
**hub_kwargs
}
            try:
                model = AutoModel.from_pretrained(**from_pretrained_kwargs)
            except Exception:
                # not yet supported by AutoModel
                model = LlavaNextForConditionalGeneration.from_pretrained(**from_pretrained_kwargs)
            config_dict, _ = PretrainedConfig.get_config_dict(ckpt_name, trust_remote_code=True, **hub_kwargs)
            try:
                try:
                    processor = AutoProcessor.from_pretrained(**from_pretrained_kwargs)
                except Exception:
                    processor = LlavaNextProcessor.from_pretrained(**from_pretrained_kwargs)
            except Exception:
                # not every model ships a processor; the tokenizer alone suffices for text-only models
                processor = None
            # AutoProcessor may return a tokenizer rather than a processor; treat that as no processor
            if not isinstance(processor, ProcessorMixin):
                processor = None
tokenizer = getattr(processor, "tokenizer") if processor is not None and hasattr(processor, "tokenizer") else AutoTokenizer.from_pretrained(ckpt_name, **hub_kwargs)
model_managed = TransformersManagedModel(
repo_id=ckpt_name,
model=model,
tokenizer=tokenizer,
config_dict=config_dict,
processor=processor
)
return model_managed,


class OneShotInstructTokenize(CustomNode):
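    """Tokenizes a single prompt (optionally with images), using a named chat template or, by default, a known template matching the model's repo name."""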
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"model": ("MODEL",),
"prompt": ("STRING", {"default": "", "multiline": True}),
"chat_template": ([_AUTO_CHAT_TEMPLATE] + list(KNOWN_CHAT_TEMPLATES.keys()), {})
},
"optional": {
"images": ("IMAGE", {}),
}
}
CATEGORY = "language"
RETURN_TYPES = (TOKENS_TYPE_NAME,)
FUNCTION = "execute"
    def execute(self, model: TransformersManagedModel, prompt: str, images: Optional[List[torch.Tensor] | torch.Tensor] = None, chat_template: str = _AUTO_CHAT_TEMPLATE) -> ValidatedNodeResult:
        if chat_template == _AUTO_CHAT_TEMPLATE:
            # look up a known template by the exact repo name
            model_name = os.path.basename(model.repo_id)
            if model_name in KNOWN_CHAT_TEMPLATES:
                chat_template = KNOWN_CHAT_TEMPLATES[model_name]
            else:
                chat_template = None
        else:
            chat_template = KNOWN_CHAT_TEMPLATES[chat_template]
        return model.tokenize(prompt, images, chat_template),


class TransformersGenerate(CustomNode):
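    """Runs `generate` on tokenized inputs, streaming per-token progress to the UI, and returns the decoded completion."""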
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"model": ("MODEL",),
"tokens": (TOKENS_TYPE_NAME, {}),
"max_new_tokens": ("INT", {"default": 512, "min": 1}),
"repetition_penalty": ("FLOAT", {"default": 0.0, "min": 0}),
"seed": ("INT", {"default": 0, "min": 0, "max": 2 ** 32 - 1}),
"use_cache": ("BOOLEAN", {"default": True}),
},
"optional": {
"sampler": (_GENERATION_KWARGS_TYPE_NAME, {}),
}
}
CATEGORY = "language"
RETURN_TYPES = ("STRING",)
FUNCTION = "execute"
def execute(self,
model: Optional[TransformersManagedModel] = None,
tokens: _TOKENS_TYPE = None,
max_new_tokens: int = 512,
repetition_penalty: float = 0.0,
seed: int = 0,
sampler: Optional[_GENERATION_KWARGS_TYPE] = None,
*args,
**kwargs
):
tokens = copy.copy(tokens)
sampler = sampler or {}
generate_kwargs = copy.copy(sampler)
load_model_gpu(model)
transformers_model: PreTrainedModel = model.model
tokenizer: PreTrainedTokenizerBase | AutoTokenizer = model.tokenizer
        # remove inputs that neither generate nor prepare_inputs_for_generation accepts,
        # which maximizes compatibility with different models
        generate_signature = inspect.signature(transformers_model.generate).parameters
        prepare_signature = inspect.signature(transformers_model.prepare_inputs_for_generation).parameters
        to_delete = set(tokens.keys()) - generate_signature.keys() - prepare_signature.keys()
        gen_sig_keys = generate_signature.keys()
        # some models take input_ids, others inputs; keep only the spelling generate expects
        if "input_ids" in tokens and "inputs" in tokens:
            if "input_ids" in gen_sig_keys:
                to_delete.add("inputs")
            elif "inputs" in gen_sig_keys:
                to_delete.add("input_ids")
        for unused_kwarg in to_delete:
            tokens.pop(unused_kwarg)
            logging.info(f"{transformers_model.name_or_path}.generate does not accept {unused_kwarg}, removing")
# images should be moved to model
for key in ("images", "pixel_values"):
if key in tokens:
tokens[key] = tokens[key].to(device=model.current_device, dtype=model.model_dtype())
inputs = tokens
progress_logits_processor = _ProgressLogitsProcessor(model)
progress_bar: ProgressBar
with comfy_progress(total=max_new_tokens) as progress_bar:
# todo: deal with batches correctly, don't assume batch size 1
token_count = 0
# progress
def on_finalized_text(next_token: str, stop: bool):
nonlocal token_count
nonlocal progress_bar
# todo: this has to be more mathematically sensible
eos_token_probability = progress_logits_processor.eos_probability
token_count += 1
value = max(eos_token_probability * max_new_tokens, token_count)
preview = TransformerStreamedProgress(next_token=next_token)
progress_bar.update_absolute(value, total=max_new_tokens, preview_image_or_output=preview)
text_streamer = _ProgressTextStreamer(on_finalized_text, tokenizer, True)
with seed_for_block(seed):
output_ids = transformers_model.generate(
**inputs,
logits_processor=LogitsProcessorList([progress_logits_processor]),
streamer=text_streamer,
max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty if repetition_penalty != 0 else None,
**generate_kwargs
)
if transformers_model.config.is_encoder_decoder:
start_position = 1
else:
start_position = inputs["input_ids" if "input_ids" in inputs else "inputs"].shape[1]
output_ids = output_ids[:, start_position:]
        # todo: is this redundant, considering decoding already happened in on_finalized_text?
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        # the inputs are no longer needed; drop the reference so GPU-loaded tensors like images can be freed
        # (the previous per-value `del` loop only unbound the loop variable and freed nothing)
        del tokens
# todo: better support batches
return outputs[0],


class PreviewString(CustomNode):
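    """Displays a string in the UI."""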
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"value": ("STRING", {"forceInput": True}),
}
}
CATEGORY = "language"
FUNCTION = "execute"
RETURN_TYPES = ("STRING",)
OUTPUT_NODE = True
def execute(self, value: str):
return {"ui": {"string": [value]}}


NODE_CLASS_MAPPINGS = {}
for cls in (
TransformerTopKSampler,
TransformerTopPSampler,
TransformerTemperatureSampler,
TransformerGreedySampler,
TransformerContrastiveSearchSampler,
TransformerBeamSearchSampler,
TransformerMergeSamplers,
TransformersLoader,
TransformersImageProcessorLoader,
TransformersGenerate,
OneShotInstructTokenize,
PreviewString,
):
NODE_CLASS_MAPPINGS[cls.__name__] = cls
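
# A minimal sketch of how these nodes compose outside the graph UI (the repo id is
# illustrative only; any text-generation checkpoint from the loader's list would do):
#
#   model, = TransformersLoader().execute("microsoft/Phi-3-mini-4k-instruct")
#   tokens, = OneShotInstructTokenize().execute(model, "Write a haiku about GPUs.", chat_template=_AUTO_CHAT_TEMPLATE)
#   text, = TransformersGenerate().execute(model=model, tokens=tokens, max_new_tokens=64, seed=42)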