from __future__ import annotations

import copy
import inspect
import logging
import operator
import os.path
from functools import reduce
from typing import Any, Dict, Optional, List, Callable, Union

import torch
from transformers import AutoTokenizer, PreTrainedModel, LogitsProcessor, TextStreamer, \
    PreTrainedTokenizerBase, PretrainedConfig, AutoProcessor, BatchFeature, AutoModel, AutoModelForCausalLM, \
    AutoModelForSeq2SeqLM
from transformers.models.auto.modeling_auto import MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, \
    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES, AutoModelForVision2Seq
from transformers.models.m2m_100.tokenization_m2m_100 import \
    FAIRSEQ_LANGUAGE_CODES as tokenization_m2m_100_FAIRSEQ_LANGUAGE_CODES
from transformers.models.nllb.tokenization_nllb import \
    FAIRSEQ_LANGUAGE_CODES as tokenization_nllb_FAIRSEQ_LANGUAGE_CODES
from typing_extensions import TypedDict

from comfy import model_management
from comfy.language.chat_templates import KNOWN_CHAT_TEMPLATES
from comfy.language.language_types import ProcessorResult
from comfy.language.transformers_model_management import TransformersManagedModel
from comfy.model_downloader import get_huggingface_repo_list, get_or_download_huggingface_repo
from comfy.model_management import get_torch_device_name, unet_dtype, unet_offload_device, load_models_gpu
from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult
from comfy.utils import comfy_tqdm, seed_for_block, comfy_progress, ProgressBar
_AUTO_CHAT_TEMPLATE = "default"

# add LLaVA support
try:
    from llava import model as _llava_model_side_effects

    logging.debug("Additional LLaVA models are now supported")
except ImportError:
    logging.debug("Install LLaVA with `pip install git+https://github.com/AppMana/appmana-comfyui-llava` for additional LLaVA support")

# aka kwargs type
_GENERATION_KWARGS_TYPE = Dict[str, Any]
_GENERATION_KWARGS_TYPE_NAME = "SAMPLER"

_TOKENS_TYPE = Union[ProcessorResult, BatchFeature]
TOKENS_TYPE_NAME = "TOKENS"

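# A TOKENS value is a mapping produced by a tokenizer or processor. As a rough
# sketch (exact keys vary by model), it typically carries tensors such as
# {"input_ids": LongTensor[batch, seq], "attention_mask": LongTensor[batch, seq]},
# plus processor-specific extras like pixel values for vision-language models.
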
class _ProgressTextStreamer(TextStreamer):
    def __init__(self, on_finalized_text: Callable[[str, bool], None], tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
        super().__init__(tokenizer, skip_prompt, **decode_kwargs)
        self.on_finalized_text_handler = on_finalized_text

    def on_finalized_text(self, text: str, stream_end: bool = False):
        self.on_finalized_text_handler(text, stream_end)

class _ProgressLogitsProcessor(LogitsProcessor):
    def __init__(self, model: TransformersManagedModel):
        self.eos_token_id = model.tokenizer.eos_token_id
        # initialized here so the attribute exists before the first __call__
        self.eos_probability = 0.0

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        probabilities = scores.softmax(dim=-1)
        self.eos_probability = probabilities[:, self.eos_token_id].item()
        return scores

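# Note: _ProgressLogitsProcessor is not wired into generate() below; it is a
# hook that records the probability of the EOS token at each step, which a
# caller could use to estimate how close generation is to finishing.
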
# todo: for per token progress, should this really look like {"ui": {"string": [value]}} ?
class TransformerStreamedProgress(TypedDict):
    next_token: str

class TransformerSamplerBase(CustomNode):
    RETURN_TYPES = _GENERATION_KWARGS_TYPE_NAME,
    RETURN_NAMES = "GENERATION ARGS",
    FUNCTION = "execute"
    CATEGORY = "language/samplers"

    @property
    def do_sample(self):
        return True

    def execute(self, **kwargs):
        return {
            "do_sample": self.do_sample,
            **kwargs
        },

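# For illustration: TransformerTopKSampler.execute(top_k=50) returns
# ({"do_sample": True, "top_k": 50},) and TransformerGreedySampler.execute()
# returns ({"do_sample": False},). These dicts are merged downstream and passed
# to PreTrainedModel.generate as keyword arguments.
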
class TransformerTopKSampler(TransformerSamplerBase):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "top_k": ("INT", {"default": 50, "min": 1})
            }
        }

class TransformerTopPSampler(TransformerSamplerBase):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "top_p": ("FLOAT", {"default": 0.9, "min": 0, "max": 1})
            }
        }

class TransformerTemperatureSampler(TransformerSamplerBase):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "temperature": ("FLOAT", {"default": 1.0, "min": 0})
            }
        }

class TransformerGreedySampler(TransformerSamplerBase):
    @property
    def do_sample(self):
        return False

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
            }
        }

class TransformersGenerationConfig(CustomNode):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "model": ("MODEL", {})
            }
        }

    RETURN_TYPES = _GENERATION_KWARGS_TYPE_NAME,
    RETURN_NAMES = "GENERATION ARGS",
    FUNCTION = "execute"
    CATEGORY = "language"

    def execute(self, model: TransformersManagedModel):
        # node outputs are tuples, so wrap the result in one
        if model.model.generation_config is not None:
            return model.model.generation_config,

        return {},

class TransformerContrastiveSearchSampler(TransformerTopKSampler):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        top_k = TransformerTopKSampler.INPUT_TYPES()
        top_k["required"] |= {
            "penalty_alpha": ("FLOAT", {"default": 0.6, "min": 0})
        }
        return top_k

class TransformerBeamSearchSampler(TransformerSamplerBase):
    @property
    def do_sample(self):
        return False

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                # beam search requires at least one beam
                "num_beams": ("INT", {"default": 1, "min": 1}),
                "early_stopping": ("BOOLEAN", {"default": True})
            }
        }

class TransformerMergeSamplers(CustomNode):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        range_ = {"value0": (_GENERATION_KWARGS_TYPE_NAME, {"forceInput": True})}
        range_.update({f"value{i}": (_GENERATION_KWARGS_TYPE_NAME, {"forceInput": True}) for i in range(1, 5)})

        return {
            "required": range_
        }

    CATEGORY = "language"
    RETURN_TYPES = _GENERATION_KWARGS_TYPE_NAME,
    FUNCTION = "execute"

    def execute(self, **kwargs):
        do_sample = {
            "do_sample": any(k == "do_sample" and v for value in kwargs.values() for k, v in value.items())
        }

        return (reduce(operator.or_, list(kwargs.values()) + [do_sample], {}),)

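# For illustration: merging {"do_sample": True, "top_k": 50} with
# {"do_sample": True, "temperature": 0.7} yields
# {"do_sample": True, "top_k": 50, "temperature": 0.7}. Later inputs win on
# key collisions, and do_sample is True if any merged sampler sampled.
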
class TransformersImageProcessorLoader(CustomNode):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "ckpt_name": (get_huggingface_repo_list(),),
                "subfolder": ("STRING", {}),
                "model": ("MODEL", {}),
                "overwrite_tokenizer": ("BOOLEAN", {"default": False}),
            }
        }

    CATEGORY = "language"
    RETURN_TYPES = "MODEL",
    FUNCTION = "execute"

    def execute(self, ckpt_name: str, subfolder: Optional[str] = None, model: TransformersManagedModel = None, overwrite_tokenizer: bool = False):
        hub_kwargs = {}
        if subfolder is not None and subfolder != "":
            hub_kwargs["subfolder"] = subfolder
        ckpt_name = get_or_download_huggingface_repo(ckpt_name)
        processor = AutoProcessor.from_pretrained(ckpt_name, torch_dtype=unet_dtype(), device_map=get_torch_device_name(unet_offload_device()), low_cpu_mem_usage=True, trust_remote_code=True, **hub_kwargs)
        return model.patch_processor(processor, overwrite_tokenizer),

class TransformersLoader(CustomNode):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "ckpt_name": (get_huggingface_repo_list(),),
                "subfolder": ("STRING", {})
            },
        }

    CATEGORY = "language"
    RETURN_TYPES = "MODEL",
    FUNCTION = "execute"

    def execute(self, ckpt_name: str, subfolder: Optional[str] = None, *args, **kwargs):
        hub_kwargs = {}
        if subfolder is not None and subfolder != "":
            hub_kwargs["subfolder"] = subfolder

        ckpt_name = get_or_download_huggingface_repo(ckpt_name)
        with comfy_tqdm():
            from_pretrained_kwargs = {
                "pretrained_model_name_or_path": ckpt_name,
                "trust_remote_code": True,
                **hub_kwargs
            }

            # check whether bitsandbytes is available; a quantization
            # configuration is not yet computed from it
            try:
                import bitsandbytes
            except ImportError:
                pass

            config_dict, _ = PretrainedConfig.get_config_dict(ckpt_name, **hub_kwargs)
            model_type = config_dict["model_type"]
            # language models prefer to use bfloat16 over float16
            kwargs_to_try = ({"torch_dtype": unet_dtype(supported_dtypes=(torch.bfloat16, torch.float16, torch.float32)),
                              "low_cpu_mem_usage": True,
                              "device_map": str(unet_offload_device()), }, {})

            # if we have flash-attn installed, try to use it
            try:
                import flash_attn
                attn_override_kwargs = {
                    "attn_implementation": "flash_attention_2",
                    **kwargs_to_try[0]
                }
                kwargs_to_try = (attn_override_kwargs, *kwargs_to_try)
                logging.debug(f"while loading model {ckpt_name}, flash_attn was installed, so the flash_attention_2 implementation will be tried")
            except ImportError:
                pass
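            # Load attempts fall back through kwargs_to_try: first (optionally)
            # flash_attention_2, then the dtype/device_map preset, then no
            # extra kwargs at all. The last failure is re-raised.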
            for i, props in enumerate(kwargs_to_try):
                try:
                    if model_type in MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES:
                        model = AutoModelForVision2Seq.from_pretrained(**from_pretrained_kwargs, **props)
                    elif model_type in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
                        model = AutoModelForSeq2SeqLM.from_pretrained(**from_pretrained_kwargs, **props)
                    elif model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
                        model = AutoModelForCausalLM.from_pretrained(**from_pretrained_kwargs, **props)
                    else:
                        model = AutoModel.from_pretrained(**from_pretrained_kwargs, **props)
                    if model is not None:
                        break
                except Exception as exc_info:
                    if i == len(kwargs_to_try) - 1:
                        raise exc_info
                    else:
                        logging.warning(f"tried to import transformers model {ckpt_name} but got exception when trying additional import args {props}", exc_info=exc_info)

            for i, props in enumerate(kwargs_to_try):
                try:
                    try:
                        processor = AutoProcessor.from_pretrained(**from_pretrained_kwargs, **props)
                    except Exception:
                        processor = None
                    if isinstance(processor, PreTrainedTokenizerBase):
                        tokenizer = processor
                        processor = None
                    else:
                        tokenizer = getattr(processor, "tokenizer") if processor is not None and hasattr(processor, "tokenizer") else AutoTokenizer.from_pretrained(ckpt_name, **hub_kwargs, **props)
                    if tokenizer is not None or processor is not None:
                        break
                except Exception as exc_info:
                    if i == len(kwargs_to_try) - 1:
                        raise exc_info

        if model_management.xformers_enabled() and hasattr(model, "enable_xformers_memory_efficient_attention"):
            model.enable_xformers_memory_efficient_attention()
            logging.debug("enabled xformers memory efficient attention")

        model_managed = TransformersManagedModel(
            repo_id=ckpt_name,
            model=model,
            tokenizer=tokenizer,
            config_dict=config_dict,
            processor=processor
        )
        return model_managed,

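# Typical wiring, as a sketch: TransformersLoader loads a MODEL;
# TransformersTokenize or OneShotInstructTokenize turns a prompt into TOKENS;
# TransformersGenerate decodes TOKENS into a STRING, optionally steered by a
# SAMPLER from the sampler nodes above; PreviewString displays the result.
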
class TransformersTokenize(CustomNode):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "model": ("MODEL",),
                "prompt": ("STRING", {"default": "", "multiline": True}),
            },
        }

    CATEGORY = "language"
    RETURN_TYPES = (TOKENS_TYPE_NAME,)
    FUNCTION = "execute"

    def execute(self, model: TransformersManagedModel, prompt: str) -> ValidatedNodeResult:
        return model.tokenize(prompt, [], None),

class TransformersM2M100LanguageCodes(CustomNode):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "lang_id": (tokenization_m2m_100_FAIRSEQ_LANGUAGE_CODES["m2m100"], {"default": "en"}),
            },
        }

    CATEGORY = "language"
    RETURN_TYPES = ("STRING",)
    FUNCTION = "execute"

    def execute(self, lang_id: str) -> ValidatedNodeResult:
        return lang_id,

class TransformersFlores200LanguageCodes(CustomNode):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "lang_id": (tokenization_nllb_FAIRSEQ_LANGUAGE_CODES, {"default": "eng_Latn"}),
            },
        }

    CATEGORY = "language"
    RETURN_TYPES = ("STRING",)
    FUNCTION = "execute"

    def execute(self, lang_id: str) -> ValidatedNodeResult:
        return lang_id,

class TransformersTranslationTokenize(CustomNode):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "model": ("MODEL",),
                "prompt": ("STRING", {"default": "", "multiline": True}),
                "src_lang": ("STRING", {}),
                "tgt_lang": ("STRING", {}),
            },
        }

    CATEGORY = "language"
    RETURN_TYPES = (TOKENS_TYPE_NAME,)
    FUNCTION = "execute"

    def execute(self, model: TransformersManagedModel, prompt: str, src_lang: str, tgt_lang: str) -> ValidatedNodeResult:
        tokenizer = model.tokenizer

        if hasattr(tokenizer, "src_lang"):
            prev_src_lang = tokenizer.src_lang
        else:
            prev_src_lang = None
        if hasattr(tokenizer, "tgt_lang"):
            prev_tgt_lang = tokenizer.tgt_lang
        else:
            prev_tgt_lang = None

        try:
            if hasattr(tokenizer, "_build_translation_inputs"):
                encoded = tokenizer._build_translation_inputs(
                    prompt, return_tensors="pt", src_lang=src_lang, tgt_lang=tgt_lang
                )
            else:
                tokenizer.src_lang = src_lang
                tokenizer.tgt_lang = tgt_lang

                encoded = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
            encoded["input_ids"] = encoded["input_ids"].to(device=model.load_device)
            encoded["attention_mask"] = encoded["attention_mask"].to(device=model.load_device)
            encoded["src_lang"] = src_lang
            encoded["tgt_lang"] = tgt_lang
            return encoded,
        finally:
            if prev_src_lang is not None:
                tokenizer.src_lang = prev_src_lang
            if prev_tgt_lang is not None:
                tokenizer.tgt_lang = prev_tgt_lang

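# For illustration with an NLLB checkpoint: src_lang="eng_Latn" and
# tgt_lang="fra_Latn" (FLORES-200 codes, see TransformersFlores200LanguageCodes)
# tokenize an English prompt for translation into French; M2M100 checkpoints
# use codes like "en" instead.
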
class OneShotInstructTokenize(CustomNode):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "model": ("MODEL",),
                "prompt": ("STRING", {"default": "", "multiline": True}),
                "chat_template": ([_AUTO_CHAT_TEMPLATE] + list(KNOWN_CHAT_TEMPLATES.keys()), {})
            },
            "optional": {
                "images": ("IMAGE", {}),
            }
        }

    CATEGORY = "language"
    RETURN_TYPES = (TOKENS_TYPE_NAME,)
    FUNCTION = "execute"

    def execute(self, model: TransformersManagedModel, prompt: str, images: List[torch.Tensor] | torch.Tensor = None, chat_template: str = _AUTO_CHAT_TEMPLATE) -> ValidatedNodeResult:
        if chat_template == _AUTO_CHAT_TEMPLATE:
            # look up a known template by the model's exact name
            model_name = os.path.basename(model.repo_id)
            if model_name in KNOWN_CHAT_TEMPLATES:
                chat_template = KNOWN_CHAT_TEMPLATES[model_name]
            else:
                chat_template = None
        else:
            chat_template = KNOWN_CHAT_TEMPLATES[chat_template]
        return model.tokenize(prompt, images, chat_template),

class TransformersGenerate(CustomNode):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "model": ("MODEL",),
                "tokens": (TOKENS_TYPE_NAME, {}),
                "max_new_tokens": ("INT", {"default": 512, "min": 1}),
                "repetition_penalty": ("FLOAT", {"default": 0.0, "min": 0}),
                "seed": ("INT", {"default": 0, "min": 0, "max": 2 ** 32 - 1}),
                "use_cache": ("BOOLEAN", {"default": True}),
            },
            "optional": {
                "sampler": (_GENERATION_KWARGS_TYPE_NAME, {}),
            }
        }

    CATEGORY = "language"
    RETURN_TYPES = ("STRING",)
    FUNCTION = "execute"

    def execute(self,
                model: Optional[TransformersManagedModel] = None,
                tokens: _TOKENS_TYPE = None,
                max_new_tokens: int = 512,
                repetition_penalty: float = 0.0,
                seed: int = 0,
                sampler: Optional[_GENERATION_KWARGS_TYPE] = None,
                *args,
                **kwargs
                ):
        tokens = copy.copy(tokens)
        tokens_original = copy.copy(tokens)
        sampler = sampler or {}
        generate_kwargs = copy.copy(sampler)
        load_models_gpu([model])
        transformers_model: PreTrainedModel = model.model
        tokenizer: PreTrainedTokenizerBase | AutoTokenizer = model.tokenizer
        # remove unused inputs to maximize compatibility with different models
        generate_signature = inspect.signature(transformers_model.generate).parameters
        prepare_signature = inspect.signature(transformers_model.prepare_inputs_for_generation).parameters
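        # Set difference: keys present in tokens that neither generate() nor
        # prepare_inputs_for_generation() accepts are slated for removal below.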
        to_delete = set(reduce(operator.sub, map(lambda x: x.keys(), [tokens, generate_signature, prepare_signature])))
        gen_sig_keys = generate_signature.keys()
        if "tgt_lang" in tokens:
            to_delete.add("tgt_lang")
            to_delete.add("src_lang")
            to_delete.discard("input_ids")
            if "forced_bos_token_id" in tokens:
                to_delete.discard("forced_bos_token_id")
            elif hasattr(tokenizer, "convert_tokens_to_ids"):
                generate_kwargs["forced_bos_token_id"] = tokenizer.convert_tokens_to_ids(tokens["tgt_lang"])
            else:
                logging.warning(f"tokenizer {tokenizer} unexpected for translation task")
        if "input_ids" in tokens and "inputs" in tokens:
            if "input_ids" in gen_sig_keys:
                to_delete.add("inputs")
            elif "inputs" in gen_sig_keys:
                to_delete.add("input_ids")
        for unused_kwarg in to_delete:
            tokens.pop(unused_kwarg)
            logging.debug(f"{transformers_model.name_or_path}.generate does not accept {unused_kwarg}, removing")

        # images should be moved to the model's device and dtype
        for key in ("images", "pixel_values"):
            if key in tokens:
                tokens[key] = tokens[key].to(device=model.current_device, dtype=model.model_dtype())

        # set up inputs
        inputs = tokens

        # used to determine if text streaming is supported
        num_beams = generate_kwargs.get("num_beams", transformers_model.generation_config.num_beams)

        progress_bar: ProgressBar
        with comfy_progress(total=max_new_tokens) as progress_bar:
            # todo: deal with batches correctly, don't assume batch size 1
            token_count = 0

            # progress
            def on_finalized_text(next_token: str, stop: bool):
                nonlocal token_count
                nonlocal progress_bar

                token_count += 1
                preview = TransformerStreamedProgress(next_token=next_token)
                progress_bar.update_absolute(token_count, total=max_new_tokens, preview_image_or_output=preview)

            text_streamer = _ProgressTextStreamer(on_finalized_text, tokenizer, True)

            with seed_for_block(seed):
                if hasattr(inputs, "encodings") and inputs.encodings is not None and all(hasattr(encoding, "attention_mask") for encoding in inputs.encodings) and "attention_mask" in inputs:
                    inputs.pop("attention_mask")
                output_ids = transformers_model.generate(
                    **inputs,
                    streamer=text_streamer if num_beams <= 1 else None,
                    max_new_tokens=max_new_tokens,
                    repetition_penalty=repetition_penalty if repetition_penalty != 0 else None,
                    **generate_kwargs
                )

                if not transformers_model.config.is_encoder_decoder:
                    # decoder-only models echo the prompt; strip it from the output
                    start_position = inputs["input_ids" if "input_ids" in inputs else "inputs"].shape[1]
                    output_ids = output_ids[:, start_position:]

        if hasattr(tokenizer, "src_lang") and "src_lang" in tokens_original:
            prev_src_lang = tokenizer.src_lang
            tokenizer.src_lang = tokens_original["src_lang"]
        else:
            prev_src_lang = None
        # todo: is this redundant, considering decoding already occurs in the on_finalized_text block?
        try:
            outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        finally:
            if prev_src_lang is not None:
                tokenizer.src_lang = prev_src_lang
        # gpu-loaded values like images can now be unloaded
        if hasattr(tokens, "to"):
            del tokens
        else:
            for to_delete in tokens.values():
                del to_delete
            del tokens

        # todo: better support batches
        return outputs[0],

class PreviewString(CustomNode):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "value": ("STRING", {"forceInput": True}),
            }
        }

    CATEGORY = "language"
    FUNCTION = "execute"
    RETURN_TYPES = ("STRING",)
    OUTPUT_NODE = True

    def execute(self, value: str):
        return {"ui": {"string": [value]}}

NODE_CLASS_MAPPINGS = {}
for cls in (
        TransformerTopKSampler,
        TransformerTopPSampler,
        TransformerTemperatureSampler,
        TransformerGreedySampler,
        TransformerContrastiveSearchSampler,
        TransformerBeamSearchSampler,
        TransformerMergeSamplers,
        TransformersLoader,
        TransformersImageProcessorLoader,
        TransformersGenerate,
        OneShotInstructTokenize,
        TransformersM2M100LanguageCodes,
        TransformersTokenize,
        TransformersFlores200LanguageCodes,
        TransformersTranslationTokenize,
        PreviewString,
):
    NODE_CLASS_MAPPINGS[cls.__name__] = cls