Additional chat templates to ease the use of many models.

doctorpangloss 2024-06-06 20:51:05 -07:00
parent ebf2ef27c7
commit 6575409461
4 changed files with 128 additions and 53 deletions

View File

@@ -4,14 +4,14 @@ import warnings
from typing import Optional, Any
import torch
from transformers import PreTrainedModel, PreTrainedTokenizerBase
from transformers import PreTrainedModel, PreTrainedTokenizerBase, PretrainedConfig
from ..model_management import unet_offload_device, get_torch_device
from ..model_management_types import ModelManageable
class TransformersManagedModel(ModelManageable):
def __init__(self, repo_id: str, model: PreTrainedModel, tokenizer: Optional[PreTrainedTokenizerBase] = None):
def __init__(self, repo_id: str, model: PreTrainedModel, tokenizer: Optional[PreTrainedTokenizerBase] = None, config_dict: Optional[dict] = None):
self.repo_id = repo_id
self.model = model
self.tokenizer = tokenizer
@@ -19,10 +19,25 @@ class TransformersManagedModel(ModelManageable):
self._size = sum(param.nelement() * param.element_size() for param in self.model.state_dict().values())
self.load_device = get_torch_device()
self.offload_device = unet_offload_device()
self._config_dict = config_dict
if model.device != self.offload_device:
model.to(device=self.offload_device)
@property
def config_dict(self) -> dict:
"""
The original configuration dictionary of the Transformers model, as read from its config.json.
Many models derive from base models and should inherit settings such as the chat template. This
config_dict carries the base model's name in _name_or_path, enabling a lookup of a valid
chat template when the derived model does not specify one (it almost never does).
:return: the dict value of the config.json in the HuggingFace model
"""
if self._config_dict is not None:
return self._config_dict
return self.model.config.to_dict()
@property
def lowvram_patch_counter(self):
return 0

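The config_dict property added above is what makes base-model lookups work: a fine-tune's config.json usually carries the base checkpoint's name in _name_or_path. Here is a minimal standalone sketch of that lookup; the registry contents and name fragments below are placeholder assumptions, not part of this commit:

from typing import Optional

# Hypothetical registry keyed by base-model name fragments; the real templates
# come from the huggingface_extra_chat_templates package added to requirements below.
KNOWN_CHAT_TEMPLATES = {
    "mistral": "{% for message in messages %}...{% endfor %}",  # placeholder Jinja
    "llama-2": "{% for message in messages %}...{% endfor %}",  # placeholder Jinja
}

def lookup_chat_template(config_dict: dict) -> Optional[str]:
    # Fine-tunes usually inherit _name_or_path from the base checkpoint,
    # e.g. "mistralai/Mistral-7B-Instruct-v0.2" for a Mistral derivative.
    name = (config_dict.get("_name_or_path") or "").lower()
    for fragment, template in KNOWN_CHAT_TEMPLATES.items():
        if fragment in name:
            return template
    return None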
View File

@@ -1,23 +1,41 @@
from __future__ import annotations
import copy
import logging
import operator
from functools import reduce
from importlib.resources import files
from importlib.resources.abc import Traversable
from pathlib import Path
from typing import Any, Dict, Optional, List, Callable, TypedDict
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, LogitsProcessor, TextStreamer, \
PreTrainedTokenizerBase, LogitsProcessorList
PreTrainedTokenizerBase, LogitsProcessorList, PretrainedConfig
from comfy.language.transformers_model_management import TransformersManagedModel
from comfy.model_downloader import huggingface_repos
from comfy.model_management import get_torch_device_name, load_model_gpu, unet_dtype, unet_offload_device
from comfy.nodes.package_typing import CustomNode, InputTypes
from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult
from comfy.utils import comfy_tqdm, seed_for_block, comfy_progress, ProgressBar
# aka kwargs type
_GENERATION_KWARGS_TYPE = Dict[str, Any]
_GENERATION_KWARGS_TYPE_NAME = "GENERATE_KWARGS"
_GENERATION_KWARGS_TYPE_NAME = "SAMPLER"
_TOKENS_TYPE = torch.Tensor
TOKENS_TYPE_NAME = "TOKENS"
KNOWN_CHAT_TEMPLATES = {}
def _update_known_chat_templates():
try:
_chat_templates: Traversable = files("huggingface_extra_chat_templates") / "chat_templates"
# strip four-space indentation and newlines so each Jinja template collapses to one line
_extra_jinja_templates = {Path(traversable.name).stem: traversable.read_text().replace('    ', '').replace('\n', '') for traversable in _chat_templates.iterdir() if traversable.is_file()}
KNOWN_CHAT_TEMPLATES.update(_extra_jinja_templates)
except ImportError as exc:
logging.warning("Could not load extra chat templates, some text models will fail", exc_info=exc)
class _ProgressTextStreamer(TextStreamer):
@@ -193,20 +211,59 @@ class TransformersLoader(CustomNode):
with comfy_tqdm():
model = AutoModelForCausalLM.from_pretrained(ckpt_name, torch_dtype=unet_dtype(), device_map=get_torch_device_name(unet_offload_device()), low_cpu_mem_usage=True, trust_remote_code=True, **hub_kwargs)
tokenizer = AutoTokenizer.from_pretrained(ckpt_name)
model_managed = TransformersManagedModel(ckpt_name, model, tokenizer)
config_dict, _ = PretrainedConfig.get_config_dict(ckpt_name, trust_remote_code=True, **hub_kwargs)
model_managed = TransformersManagedModel(ckpt_name, model, tokenizer, config_dict)
return model_managed,
class TransformerGenerate(CustomNode):
class OneShotInstructTokenize(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"model": ("MODEL",),
"prompt": ("STRING", {"default": "", "multiline": True}),
}
}
CATEGORY = "language"
RETURN_TYPES = (TOKENS_TYPE_NAME,)
FUNCTION = "execute"
def execute(self, model: TransformersManagedModel, prompt: str) -> ValidatedNodeResult:
tokenizer: PreTrainedTokenizerBase | AutoTokenizer = model.tokenizer
assert tokenizer is not None
assert hasattr(tokenizer, "decode")
# try to retrieve a matching chat template
chat_template = tokenizer.chat_template if hasattr(tokenizer, "chat_template") else None
if chat_template is None:
candidate_chat_templates = [(name, template) for name, template in KNOWN_CHAT_TEMPLATES.items() if name in model.config_dict["_name_or_path"] or name in model.model.name_or_path]
if len(candidate_chat_templates) > 0:
filename, chat_template = candidate_chat_templates[0]
logging.debug(f"Selected chat template filename={filename} for {model.model.name_or_path}")
try:
# todo: this should come from node inputs
prompt = tokenizer.apply_chat_template([
{"role": "user", "content": prompt},
], chat_template=chat_template, add_generation_prompt=True, tokenize=False)
except Exception as exc:
logging.error("Could not apply chat template", exc_info=exc)
return tokenizer(prompt, return_tensors="pt"),
class TransformersGenerate(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"model": ("MODEL",),
"tokens": (TOKENS_TYPE_NAME, {}),
"max_new_tokens": ("INT", {"default": 512, "min": 1}),
"repetition_penalty": ("FLOAT", {"default": 0.0, "min": 0}),
"seed": ("INT", {"default": 0}),
"seed": ("INT", {"default": 0, "min": 0, "max": 2 ** 32 - 1}),
"use_cache": ("BOOLEAN", {"default": True}),
},
"optional": {
"images": ("IMAGE", {}),
@@ -220,32 +277,27 @@ class TransformerGenerate(CustomNode):
def execute(self,
model: Optional[TransformersManagedModel] = None,
prompt: str = "",
tokens: _TOKENS_TYPE = None,
max_new_tokens: int = 512,
repetition_penalty: float = 0.0,
seed: int = 0,
images: Optional[List[torch.Tensor]] = None,
images: Optional[List[torch.Tensor] | torch.Tensor] = None,
sampler: Optional[_GENERATION_KWARGS_TYPE] = None,
*args,
**kwargs
):
sampler = sampler or {}
generate_kwargs = copy.copy(sampler)
# gracefully support LlaVA and others
if images is not None and not isinstance(images, torch.Tensor):
images = torch.stack(images, dim=0)
if images is not None:
generate_kwargs["images"] = images
# assuming it's of the form (batch, features..., height, width)
generate_kwargs["images_sizes"] = [(images.shape[-2], images.shape[-1]) for _ in range(images.shape[0])]
load_model_gpu(model)
if sampler is None:
sampler = {}
tokenizer: PreTrainedTokenizerBase | AutoTokenizer = model.tokenizer
assert tokenizer is not None
assert hasattr(tokenizer, "decode")
try:
# todo: this should come from node inputs
prompt = tokenizer.apply_chat_template([
{"role": "user", "content": prompt},
], add_generation_prompt=True, tokenize=False)
except Exception as exc:
logging.error("Could not apply chat template", exc_info=exc)
inputs = tokenizer(prompt, return_tensors="pt").to(model.current_device)
inputs = tokens.to(model.current_device)
transformers_model: PreTrainedModel = model.model
progress_logits_processor = _ProgressLogitsProcessor(model)
progress_bar: ProgressBar
@@ -264,7 +316,6 @@ class TransformerGenerate(CustomNode):
value = max(eos_token_probability * max_new_tokens, token_count)
preview = TransformerStreamedProgress(next_token=next_token)
progress_bar.update_absolute(value, total=max_new_tokens, preview_image_or_output=preview)
pass
text_streamer = _ProgressTextStreamer(on_finalized_text, tokenizer, True)
@@ -276,7 +327,7 @@ class TransformerGenerate(CustomNode):
streamer=text_streamer,
max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty if repetition_penalty != 0 else None,
**sampler
**generate_kwargs
)
if transformers_model.config.is_encoder_decoder:
@@ -318,7 +369,10 @@ for cls in (
TransformerBeamSearchSampler,
TransformerMergeSamplers,
TransformersLoader,
TransformerGenerate,
TransformersGenerate,
OneShotInstructTokenize,
PreviewString,
):
NODE_CLASS_MAPPINGS[cls.__name__] = cls
_update_known_chat_templates()

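The split into OneShotInstructTokenize and TransformersGenerate mirrors the plain transformers flow: resolve a chat template, render the user message, tokenize, then generate from tokens. The same fallback logic as a standalone sketch; the model name and the fallback template here are assumptions for illustration:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
# Stand-in for a template pulled from KNOWN_CHAT_TEMPLATES when the
# tokenizer does not ship one of its own.
fallback_template = "{% for message in messages %}{{ message['content'] }}{% endfor %}"

chat_template = getattr(tokenizer, "chat_template", None) or fallback_template
prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Write a haiku about diffusion models."}],
    chat_template=chat_template,
    add_generation_prompt=True,
    tokenize=False,
)
tokens = tokenizer(prompt, return_tensors="pt")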
View File

@@ -23,8 +23,10 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
from __future__ import annotations
import string
from typing import Optional
from typing import Optional, List
from comfy.language.transformers_model_management import TransformersManagedModel
from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult
@@ -32,7 +34,7 @@ from comfy.sd import CLIP
from comfy.sd1_clip import SDTokenizer
class TextDiffuserTokens(CustomNode):
class TextDiffuserAddTokens(CustomNode):
ALPHABET = string.digits + string.ascii_lowercase + string.ascii_uppercase + string.punctuation + ' ' # len(alphabet) = 95
TOKENS = []
@@ -49,17 +51,17 @@ class TextDiffuserTokens(CustomNode):
def execute(self, clip: CLIP):
clip = clip.clone()
if len(TextDiffuserTokens.TOKENS) == 0:
if len(TextDiffuserAddTokens.TOKENS) == 0:
for i in range(520):
TextDiffuserTokens.TOKENS.append(f'l{i}</w>')
TextDiffuserTokens.TOKENS.append(f't{i}</w>')
TextDiffuserTokens.TOKENS.append(f'r{i}</w>')
TextDiffuserTokens.TOKENS.append(f'b{i}</w>')
for c in TextDiffuserTokens.ALPHABET:
TextDiffuserTokens.TOKENS.append(f'[{c}]</w>')
TextDiffuserAddTokens.TOKENS.append(f'l{i}</w>')
TextDiffuserAddTokens.TOKENS.append(f't{i}</w>')
TextDiffuserAddTokens.TOKENS.append(f'r{i}</w>')
TextDiffuserAddTokens.TOKENS.append(f'b{i}</w>')
for c in TextDiffuserAddTokens.ALPHABET:
TextDiffuserAddTokens.TOKENS.append(f'[{c}]</w>')
tokenizer: SDTokenizer = clip.tokenizer.sd_tokenizer
existing_vocab = frozenset(tokenizer.tokenizer.get_vocab().keys())
tokens = [t for t in TextDiffuserTokens.TOKENS if t not in existing_vocab]
tokens = [t for t in TextDiffuserAddTokens.TOKENS if t not in existing_vocab]
if len(tokens) != 0:
tokenizer.add_tokens(tokens)
@@ -67,15 +69,15 @@ class TextDiffuserTokens(CustomNode):
return clip,
class TextDiffuserPrepare(CustomNode):
class TextDiffuserPrepareInstructPrompt(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"prompt": ("STRING", {"default": "", "multiline": True}),
"text": ("STRING", {"default": "", "multiline": True}),
},
"optional": {
"text": ("STRING", {"default": "", "multiline": True})
"text_to_render": ("STRING", {"default": "", "multiline": True})
}
}
@@ -83,27 +85,27 @@ class TextDiffuserPrepare(CustomNode):
RETURN_TYPES = "STRING",
RETURN_NAMES = "INSTRUCT STRING",
def execute(self, prompt: str, text: Optional[str] = None, *args, **kwargs) -> ValidatedNodeResult:
keywords = text.split("\n")
def execute(self, text: str, text_to_render: Optional[str] = None, *args, **kwargs) -> ValidatedNodeResult:
keywords = text_to_render.split("\n") if text_to_render else []
if len(keywords) > 0:
# TextDiffuser does indeed format keywords as
# ['some', 'word']
message = f'Given a prompt that will be used to generate an image, plan the layout of visual text for the image. The size of the image is 128x128. Therefore, all properties of the positions should not exceed 128, including the coordinates of top, left, right, and bottom. In addition, we also provide all keywords at random order for reference. You dont need to specify the details of font styles. At each line, the format should be keyword left, top, right, bottom. So let us begin. Prompt: {prompt}. Keywords: {str(keywords)}'
message = f'Given a prompt that will be used to generate an image, plan the layout of visual text for the image. The size of the image is 128x128. Therefore, all properties of the positions should not exceed 128, including the coordinates of top, left, right, and bottom. In addition, we also provide all keywords at random order for reference. You dont need to specify the details of font styles. At each line, the format should be keyword left, top, right, bottom. So let us begin. Prompt: {text}. Keywords: {str(keywords)}'
else:
message = f'Given a prompt that will be used to generate an image, plan the layout of visual text for the image. The size of the image is 128x128. Therefore, all properties of the positions should not exceed 128, including the coordinates of top, left, right, and bottom. All keywords are included in the caption. You dont need to specify the details of font styles. At each line, the format should be keyword left, top, right, bottom. So let us begin. Prompt: {prompt}'
message = f'Given a prompt that will be used to generate an image, plan the layout of visual text for the image. The size of the image is 128x128. Therefore, all properties of the positions should not exceed 128, including the coordinates of top, left, right, and bottom. All keywords are included in the caption. You dont need to specify the details of font styles. At each line, the format should be keyword left, top, right, bottom. So let us begin. Prompt: {text}'
return message,
class TextDiffuserDecodeLayout(CustomNode):
class TextDiffuserDecodeLayoutString2ClipString(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"layout_model": ("MODEL", {}),
"clip": ("CLIP", {}),
"prompt": ("STRING", {}),
"instruct_response": ("STRING", {})
"prompt": ("STRING", {"forceInput": True}),
"instruct_response": ("STRING", {"forceInput": True})
}
}
@@ -111,7 +113,10 @@ class TextDiffuserDecodeLayout(CustomNode):
RETURN_TYPES = "STRING",
RETURN_NAMES = "CLIP STRING",
def execute(self, layout_model: TransformersManagedModel, clip: CLIP, prompt: str, instruct_response: str, *args, **kwargs) -> ValidatedNodeResult:
def execute(self, layout_model: TransformersManagedModel, clip: CLIP, prompt: str, instruct_response: str | List[str], *args, **kwargs) -> ValidatedNodeResult:
# todo: better support for batching
if isinstance(instruct_response, list):
instruct_response = instruct_response[0]
current_ocr = instruct_response.split('\n')
words = [clip.tokenizer.sd_tokenizer.tokenizer.eos_token, clip.tokenizer.sd_tokenizer.tokenizer.bos_token]
for ocr in current_ocr:
@@ -136,8 +141,8 @@ class TextDiffuserDecodeLayout(CustomNode):
NODE_CLASS_MAPPINGS = {}
for cls in (
TextDiffuserDecodeLayout,
TextDiffuserPrepare,
TextDiffuserTokens,
TextDiffuserDecodeLayoutString2ClipString,
TextDiffuserPrepareInstructPrompt,
TextDiffuserAddTokens,
):
NODE_CLASS_MAPPINGS[cls.__name__] = cls

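For orientation, the tokens registered by TextDiffuserAddTokens (the l{i}/t{i}/r{i}/b{i} coordinate tokens plus the per-character [c] tokens) are the target vocabulary for the `keyword left, top, right, bottom` lines the layout model emits. A rough illustrative parser for one such line, assuming well-formed output; the node's real decoding also handles BOS/EOS bookkeeping:

def layout_line_to_clip_tokens(line: str) -> list:
    # Expected shape: "<keyword> <left>, <top>, <right>, <bottom>", with all
    # coordinates inside the 128x128 grid the instruct prompt asks for.
    word, coords = line.split(" ", 1)
    left, top, right, bottom = (int(v) for v in coords.replace(",", " ").split())
    tokens = [f"[{c}]</w>" for c in word]  # per-character tokens
    tokens += [f"l{left}</w>", f"t{top}</w>", f"r{right}</w>", f"b{bottom}</w>"]
    return tokens

# e.g. layout_line_to_clip_tokens("hello 10, 20, 90, 40")
# -> ['[h]</w>', '[e]</w>', '[l]</w>', '[l]</w>', '[o]</w>',
#     'l10</w>', 't20</w>', 'r90</w>', 'b40</w>']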
View File

@@ -49,6 +49,7 @@ opentelemetry-util-http
opentelemetry-instrumentation-aio-pika
opentelemetry-instrumentation-requests
opentelemetry-semantic-conventions
huggingface_extra_chat_templates @ git+https://github.com/AppMana/appmana-comfyui-chat-templates.git
wrapt>=1.16.0
certifi
spandrel
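
The new huggingface_extra_chat_templates requirement ships the templates as package data, which is why _update_known_chat_templates reads them through importlib.resources rather than from the filesystem. A quick post-install smoke test, assuming the chat_templates/ package directory used above:

from importlib.resources import files

templates_dir = files("huggingface_extra_chat_templates") / "chat_templates"
for entry in templates_dir.iterdir():
    if entry.is_file():
        # The template name (minus extension) is the key matched
        # against _name_or_path in the loader nodes.
        print(entry.name, len(entry.read_text()), "chars")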