From 657540946151f6b55c990754a9884d28dd5b98c6 Mon Sep 17 00:00:00 2001
From: doctorpangloss <@hiddenswitch.com>
Date: Thu, 6 Jun 2024 20:51:05 -0700
Subject: [PATCH] Additional chat templates to ease the use of many models.

---
 .../language/transformers_model_management.py |  21 +++-
 comfy_extras/nodes/nodes_language.py          | 106 +++++++++++++-----
 comfy_extras/nodes/nodes_textdiffusers.py     |  53 +++++----
 requirements.txt                              |   1 +
 4 files changed, 128 insertions(+), 53 deletions(-)

diff --git a/comfy/language/transformers_model_management.py b/comfy/language/transformers_model_management.py
index 55b47d81b..e318c729d 100644
--- a/comfy/language/transformers_model_management.py
+++ b/comfy/language/transformers_model_management.py
@@ -4,14 +4,14 @@ import warnings
 from typing import Optional, Any
 
 import torch
-from transformers import PreTrainedModel, PreTrainedTokenizerBase
+from transformers import PreTrainedModel, PreTrainedTokenizerBase, PretrainedConfig
 
 from ..model_management import unet_offload_device, get_torch_device
 from ..model_management_types import ModelManageable
 
 
 class TransformersManagedModel(ModelManageable):
-    def __init__(self, repo_id: str, model: PreTrainedModel, tokenizer: Optional[PreTrainedTokenizerBase] = None):
+    def __init__(self, repo_id: str, model: PreTrainedModel, tokenizer: Optional[PreTrainedTokenizerBase] = None, config_dict: Optional[dict] = None):
         self.repo_id = repo_id
         self.model = model
         self.tokenizer = tokenizer
@@ -19,10 +19,25 @@ class TransformersManagedModel(ModelManageable):
         self._size = sum(param.nelement() * param.element_size() for param in self.model.state_dict().values())
         self.load_device = get_torch_device()
         self.offload_device = unet_offload_device()
-
+        self._config_dict = config_dict
         if model.device != self.offload_device:
             model.to(device=self.offload_device)
 
+    @property
+    def config_dict(self) -> dict:
+        """
+        The original configuration dictionary of the Transformers model.
+
+        Many models derive from base models and should inherit their settings, such as the chat template.
+        This config_dict will contain the base model's name in _name_or_path, enabling a lookup of the
+        correct chat template when the derived model does not specify one (it almost never does).
+
+        :return: the dict value of the config.json in the HuggingFace model
+        """
+        if self._config_dict is not None:
+            return self._config_dict
+
+        return self.model.config.to_dict()
+
     @property
     def lowvram_patch_counter(self):
         return 0
diff --git a/comfy_extras/nodes/nodes_language.py b/comfy_extras/nodes/nodes_language.py
index 1e33f7839..c2095d41d 100644
--- a/comfy_extras/nodes/nodes_language.py
+++ b/comfy_extras/nodes/nodes_language.py
@@ -1,23 +1,41 @@
 from __future__ import annotations
 
+import copy
 import logging
 import operator
 from functools import reduce
+from importlib.resources import files
+from importlib.resources.abc import Traversable
+from pathlib import Path
 from typing import Any, Dict, Optional, List, Callable, TypedDict
 
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, LogitsProcessor, TextStreamer, \
-    PreTrainedTokenizerBase, LogitsProcessorList
+    PreTrainedTokenizerBase, LogitsProcessorList, PretrainedConfig
 
 from comfy.language.transformers_model_management import TransformersManagedModel
 from comfy.model_downloader import huggingface_repos
 from comfy.model_management import get_torch_device_name, load_model_gpu, unet_dtype, unet_offload_device
-from comfy.nodes.package_typing import CustomNode, InputTypes
+from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult
 from comfy.utils import comfy_tqdm, seed_for_block, comfy_progress, ProgressBar
 
 # aka kwargs type
 _GENERATION_KWARGS_TYPE = Dict[str, Any]
-_GENERATION_KWARGS_TYPE_NAME = "GENERATE_KWARGS"
+_GENERATION_KWARGS_TYPE_NAME = "SAMPLER"
+_TOKENS_TYPE = torch.Tensor
+TOKENS_TYPE_NAME = "TOKENS"
+
+KNOWN_CHAT_TEMPLATES = {}
+
+
+def _update_known_chat_templates():
+    try:
+        # files() returns a Traversable, not a context manager, so bind it directly
+        _chat_templates: Traversable = files("huggingface_extra_chat_templates") / "chat_templates"
+        # strip the indentation and newlines the template files carry for readability
+        _extra_jinja_templates = {Path(traversable.name).stem: traversable.read_text().replace('    ', '').replace('\n', '') for traversable in _chat_templates.iterdir() if traversable.is_file()}
+        KNOWN_CHAT_TEMPLATES.update(_extra_jinja_templates)
+    except ImportError as exc:
+        logging.warning("Could not load extra chat templates; some text models will fail", exc_info=exc)
 
 
 class _ProgressTextStreamer(TextStreamer):
@@ -193,20 +211,59 @@ class TransformersLoader(CustomNode):
         with comfy_tqdm():
             model = AutoModelForCausalLM.from_pretrained(ckpt_name, torch_dtype=unet_dtype(), device_map=get_torch_device_name(unet_offload_device()), low_cpu_mem_usage=True, trust_remote_code=True, **hub_kwargs)
         tokenizer = AutoTokenizer.from_pretrained(ckpt_name)
-        model_managed = TransformersManagedModel(ckpt_name, model, tokenizer)
+        config_dict, _ = PretrainedConfig.get_config_dict(ckpt_name, trust_remote_code=True, **hub_kwargs)
+        model_managed = TransformersManagedModel(ckpt_name, model, tokenizer, config_dict)
         return model_managed,
 
 
-class TransformerGenerate(CustomNode):
+class OneShotInstructTokenize(CustomNode):
     @classmethod
     def INPUT_TYPES(cls) -> InputTypes:
         return {
             "required": {
                 "model": ("MODEL",),
                 "prompt": ("STRING", {"default": "", "multiline": True}),
+            }
+        }
+
+    CATEGORY = "language"
+    RETURN_TYPES = (TOKENS_TYPE_NAME,)
+    FUNCTION = "execute"
+
+    def execute(self, model: TransformersManagedModel, prompt: str) -> ValidatedNodeResult:
+        tokenizer: PreTrainedTokenizerBase | AutoTokenizer = model.tokenizer
+        assert tokenizer is not None
+        assert hasattr(tokenizer, "decode")
+
+        # try to retrieve a matching chat template
+        chat_template = tokenizer.chat_template if hasattr(tokenizer, "chat_template") else None
+        if chat_template is None:
+            candidate_chat_templates = [(name, template) for name, template in KNOWN_CHAT_TEMPLATES.items() if name in model.config_dict["_name_or_path"] or name in model.model.name_or_path]
+            if len(candidate_chat_templates) > 0:
+                filename, chat_template = candidate_chat_templates[0]
+                logging.debug(f"Selected chat template filename={filename} for {model.model.name_or_path}")
+        try:
+            # todo: this should come from node inputs
+            prompt = tokenizer.apply_chat_template([
+                {"role": "user", "content": prompt},
+            ], chat_template=chat_template, add_generation_prompt=True, tokenize=False)
+        except Exception as exc:
+            logging.error("Could not apply chat template", exc_info=exc)
+
+        return tokenizer(prompt, return_tensors="pt"),
+
+
+class TransformersGenerate(CustomNode):
+    @classmethod
+    def INPUT_TYPES(cls) -> InputTypes:
+        return {
+            "required": {
+                "model": ("MODEL",),
+                "tokens": (TOKENS_TYPE_NAME, {}),
                 "max_new_tokens": ("INT", {"default": 512, "min": 1}),
                 "repetition_penalty": ("FLOAT", {"default": 0.0, "min": 0}),
-                "seed": ("INT", {"default": 0}),
+                "seed": ("INT", {"default": 0, "min": 0, "max": 2 ** 32 - 1}),
+                "use_cache": ("BOOLEAN", {"default": True}),
             },
             "optional": {
                 "images": ("IMAGE", {}),
@@ -220,32 +277,27 @@ class TransformerGenerate(CustomNode):
 
     def execute(self,
                 model: Optional[TransformersManagedModel] = None,
-                prompt: str = "",
+                tokens: _TOKENS_TYPE = None,
                 max_new_tokens: int = 512,
                 repetition_penalty: float = 0.0,
                 seed: int = 0,
-                images: Optional[List[torch.Tensor]] = None,
+                images: Optional[List[torch.Tensor] | torch.Tensor] = None,
                 sampler: Optional[_GENERATION_KWARGS_TYPE] = None,
                 *args,
                 **kwargs
                 ):
+        sampler = sampler or {}
+        generate_kwargs = copy.copy(sampler)
+        # gracefully support LLaVA and others
+        if images is not None and not isinstance(images, torch.Tensor):
+            images = torch.stack(images, dim=0)
+        if images is not None:
+            generate_kwargs["images"] = images
+            # assuming it's of the form (batch, features..., height, width)
+            generate_kwargs["images_sizes"] = [(images.shape[-2], images.shape[-1]) for _ in range(images.shape[0])]
         load_model_gpu(model)
-
-        if sampler is None:
-            sampler = {}
         tokenizer: PreTrainedTokenizerBase | AutoTokenizer = model.tokenizer
         assert tokenizer is not None
         assert hasattr(tokenizer, "decode")
-
-        try:
-            # todo: this should come from node inputs
-            prompt = tokenizer.apply_chat_template([
-                {"role": "user", "content": prompt},
-            ], add_generation_prompt=True, tokenize=False)
-        except Exception as exc:
-            logging.error("Could not apply chat template", exc_info=exc)
-        inputs = tokenizer(prompt, return_tensors="pt").to(model.current_device)
+        inputs = tokens.to(model.current_device)
         transformers_model: PreTrainedModel = model.model
         progress_logits_processor = _ProgressLogitsProcessor(model)
         progress_bar: ProgressBar
@@ -264,7 +316,6 @@ class TransformerGenerate(CustomNode):
                 value = max(eos_token_probability * max_new_tokens, token_count)
                 preview = TransformerStreamedProgress(next_token=next_token)
                 progress_bar.update_absolute(value, total=max_new_tokens, preview_image_or_output=preview)
-            pass
 
         text_streamer = _ProgressTextStreamer(on_finalized_text, tokenizer, True)
 
@@ -276,7 +327,7 @@ class TransformerGenerate(CustomNode):
             streamer=text_streamer,
             max_new_tokens=max_new_tokens,
             repetition_penalty=repetition_penalty if repetition_penalty != 0 else None,
-            **sampler
+            **generate_kwargs
         )
 
         if transformers_model.config.is_encoder_decoder:
@@ -318,7 +369,10 @@ for cls in (
     TransformerBeamSearchSampler,
     TransformerMergeSamplers,
     TransformersLoader,
-    TransformerGenerate,
+    TransformersGenerate,
+    OneShotInstructTokenize,
     PreviewString,
 ):
     NODE_CLASS_MAPPINGS[cls.__name__] = cls
+
+_update_known_chat_templates()
diff --git a/comfy_extras/nodes/nodes_textdiffusers.py b/comfy_extras/nodes/nodes_textdiffusers.py
index bd253f9c5..15af8a005 100644
--- a/comfy_extras/nodes/nodes_textdiffusers.py
+++ b/comfy_extras/nodes/nodes_textdiffusers.py
@@ -23,8 +23,10 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 SOFTWARE.
 """
+from __future__ import annotations
+
 import string
-from typing import Optional
+from typing import Optional, List
 
 from comfy.language.transformers_model_management import TransformersManagedModel
 from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult
@@ -32,7 +34,7 @@ from comfy.sd import CLIP
 from comfy.sd1_clip import SDTokenizer
 
 
-class TextDiffuserTokens(CustomNode):
+class TextDiffuserAddTokens(CustomNode):
     ALPHABET = string.digits + string.ascii_lowercase + string.ascii_uppercase + string.punctuation + ' '  # len(alphabet) = 95
     TOKENS = []
 
@@ -49,17 +51,17 @@ class TextDiffuserTokens(CustomNode):
 
     def execute(self, clip: CLIP):
         clip = clip.clone()
-        if len(TextDiffuserTokens.TOKENS) == 0:
+        if len(TextDiffuserAddTokens.TOKENS) == 0:
             for i in range(520):
-                TextDiffuserTokens.TOKENS.append(f'l{i}')
-                TextDiffuserTokens.TOKENS.append(f't{i}')
-                TextDiffuserTokens.TOKENS.append(f'r{i}')
-                TextDiffuserTokens.TOKENS.append(f'b{i}')
-            for c in TextDiffuserTokens.ALPHABET:
-                TextDiffuserTokens.TOKENS.append(f'[{c}]')
+                TextDiffuserAddTokens.TOKENS.append(f'l{i}')
+                TextDiffuserAddTokens.TOKENS.append(f't{i}')
+                TextDiffuserAddTokens.TOKENS.append(f'r{i}')
+                TextDiffuserAddTokens.TOKENS.append(f'b{i}')
+            for c in TextDiffuserAddTokens.ALPHABET:
+                TextDiffuserAddTokens.TOKENS.append(f'[{c}]')
         tokenizer: SDTokenizer = clip.tokenizer.sd_tokenizer
         existing_vocab = frozenset(tokenizer.tokenizer.get_vocab().keys())
-        tokens = [t for t in TextDiffuserTokens.TOKENS if t not in existing_vocab]
+        tokens = [t for t in TextDiffuserAddTokens.TOKENS if t not in existing_vocab]
         if len(tokens) != 0:
             tokenizer.add_tokens(tokens)
 
@@ -67,15 +69,15 @@ class TextDiffuserTokens(CustomNode):
         return clip,
 
 
-class TextDiffuserPrepare(CustomNode):
+class TextDiffuserPrepareInstructPrompt(CustomNode):
     @classmethod
     def INPUT_TYPES(cls) -> InputTypes:
         return {
             "required": {
-                "prompt": ("STRING", {"default": "", "multiline": True}),
+                "text": ("STRING", {"default": "", "multiline": True}),
             },
             "optional": {
-                "text": ("STRING", {"default": "", "multiline": True})
+                "text_to_render": ("STRING", {"default": "", "multiline": True})
             }
         }
 
@@ -83,27 +85,27 @@ class TextDiffuserPrepare(CustomNode):
     RETURN_TYPES = "STRING",
     RETURN_NAMES = "INSTRUCT STRING",
 
-    def execute(self, prompt: str, text: Optional[str] = None, *args, **kwargs) -> ValidatedNodeResult:
-        keywords = text.split("\n")
+    def execute(self, text: str, text_to_render: Optional[str] = None, *args, **kwargs) -> ValidatedNodeResult:
+        # guard against the optional input being omitted
+        keywords = text_to_render.split("\n") if text_to_render else []
         if len(keywords) > 0:
             # text diffusers does indeed format keywords as
             # ['some', 'word']
-            message = f'Given a prompt that will be used to generate an image, plan the layout of visual text for the image. The size of the image is 128x128. Therefore, all properties of the positions should not exceed 128, including the coordinates of top, left, right, and bottom. In addition, we also provide all keywords at random order for reference. You dont need to specify the details of font styles. At each line, the format should be keyword left, top, right, bottom. So let us begin. Prompt: {prompt}. Keywords: {str(keywords)}'
+            message = f'Given a prompt that will be used to generate an image, plan the layout of visual text for the image. The size of the image is 128x128. Therefore, all properties of the positions should not exceed 128, including the coordinates of top, left, right, and bottom. In addition, we also provide all keywords at random order for reference. You dont need to specify the details of font styles. At each line, the format should be keyword left, top, right, bottom. So let us begin. Prompt: {text}. Keywords: {str(keywords)}'
         else:
-            message = f'Given a prompt that will be used to generate an image, plan the layout of visual text for the image. The size of the image is 128x128. Therefore, all properties of the positions should not exceed 128, including the coordinates of top, left, right, and bottom. All keywords are included in the caption. You dont need to specify the details of font styles. At each line, the format should be keyword left, top, right, bottom. So let us begin. Prompt: {prompt}'
+            message = f'Given a prompt that will be used to generate an image, plan the layout of visual text for the image. The size of the image is 128x128. Therefore, all properties of the positions should not exceed 128, including the coordinates of top, left, right, and bottom. All keywords are included in the caption. You dont need to specify the details of font styles. At each line, the format should be keyword left, top, right, bottom. So let us begin. Prompt: {text}'
 
         return message,
 
 
-class TextDiffuserDecodeLayout(CustomNode):
+class TextDiffuserDecodeLayoutString2ClipString(CustomNode):
     @classmethod
     def INPUT_TYPES(cls) -> InputTypes:
         return {
             "required": {
                 "layout_model": ("MODEL", {}),
                 "clip": ("CLIP", {}),
-                "prompt": ("STRING", {}),
-                "instruct_response": ("STRING", {})
+                "prompt": ("STRING", {"forceInput": True}),
+                "instruct_response": ("STRING", {"forceInput": True})
             }
         }
 
@@ -111,7 +113,10 @@ class TextDiffuserDecodeLayout(CustomNode):
     RETURN_TYPES = "STRING",
     RETURN_NAMES = "CLIP STRING",
 
-    def execute(self, layout_model: TransformersManagedModel, clip: CLIP, prompt: str, instruct_response: str, *args, **kwargs) -> ValidatedNodeResult:
+    def execute(self, layout_model: TransformersManagedModel, clip: CLIP, prompt: str, instruct_response: str | List[str], *args, **kwargs) -> ValidatedNodeResult:
+        # todo: better support for batching
+        if isinstance(instruct_response, list):
+            instruct_response = instruct_response[0]
         current_ocr = instruct_response.split('\n')
         words = [clip.tokenizer.sd_tokenizer.tokenizer.eos_token, clip.tokenizer.sd_tokenizer.tokenizer.bos_token]
         for ocr in current_ocr:
@@ -136,8 +141,8 @@ class TextDiffuserDecodeLayout(CustomNode):
 NODE_CLASS_MAPPINGS = {}
 
 for cls in (
-    TextDiffuserDecodeLayout,
-    TextDiffuserPrepare,
-    TextDiffuserTokens,
+    TextDiffuserDecodeLayoutString2ClipString,
+    TextDiffuserPrepareInstructPrompt,
+    TextDiffuserAddTokens,
 ):
     NODE_CLASS_MAPPINGS[cls.__name__] = cls
diff --git a/requirements.txt b/requirements.txt
index 31c2ccb90..c4dcf9ce1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -49,6 +49,7 @@ opentelemetry-util-http
 opentelemetry-instrumentation-aio-pika
 opentelemetry-instrumentation-requests
 opentelemetry-semantic-conventions
+huggingface_extra_chat_templates @ git+https://github.com/AppMana/appmana-comfyui-chat-templates.git
 wrapt>=1.16.0
 certifi
 spandrel
\ No newline at end of file
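
---

Reviewer note, not part of the patch: the sketch below walks through the chat template fallback that OneShotInstructTokenize implements, using only the public transformers API. The repo id and the ChatML-style template are illustrative assumptions for the sketch, not values shipped by this change; at runtime KNOWN_CHAT_TEMPLATES is populated from the huggingface_extra_chat_templates package instead.

    from transformers import AutoTokenizer

    # hand-written ChatML stand-in for the packaged templates (assumption)
    known_chat_templates = {
        "TinyLlama": (
            "{% for message in messages %}"
            "{{ '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>\n' }}"
            "{% endfor %}"
            "{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
        ),
    }

    tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")  # assumed model
    chat_template = getattr(tokenizer, "chat_template", None)
    if chat_template is None:
        # same substring match the node performs against config.json's _name_or_path,
        # which lets fine-tunes inherit their base model's template
        candidates = [t for name, t in known_chat_templates.items() if name in tokenizer.name_or_path]
        chat_template = candidates[0] if candidates else None

    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": "Describe a sunset."}],
        chat_template=chat_template,
        add_generation_prompt=True,
        tokenize=False,
    )
    tokens = tokenizer(prompt, return_tensors="pt")  # the node's TOKENS output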
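Continuing the sketch above, this is roughly what TransformersGenerate then does with the TOKENS value: move the BatchEncoding onto the model's device and unpack it into generate(). The model id is again an assumption; the real node additionally wires in streaming, progress reporting, and the sampler kwargs.

    import torch
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # assumed model
        torch_dtype=torch.float16,
        device_map="auto",
    )

    # BatchEncoding.to() moves input_ids and attention_mask together, which is
    # what `inputs = tokens.to(model.current_device)` relies on in the node
    tokens = tokens.to(model.device)
    output_ids = model.generate(**tokens, max_new_tokens=32, use_cache=True)  # mirrors the new use_cache toggle
    new_tokens = output_ids[0][tokens["input_ids"].shape[-1]:]
    print(tokenizer.decode(new_tokens, skip_special_tokens=True))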