From 7f300bcb7a8c73e70213dada80a6a6e7079313d5 Mon Sep 17 00:00:00 2001
From: doctorpangloss <@hiddenswitch.com>
Date: Fri, 7 Jun 2024 16:23:10 -0700
Subject: [PATCH] Multi-modal LLM support and ongoing improvements to language features.

---
 comfy/language/__init__.py                    |   3 +
 comfy/language/chat_templates.py              |  18 ++
 comfy/language/language_types.py              |  24 +++
 .../language/transformers_model_management.py | 105 +++++++++-
 comfy/model_downloader.py                     |   1 +
 comfy_extras/nodes/nodes_language.py          | 192 ++++++++++++------
 comfy_extras/nodes/nodes_open_api.py          |   7 +-
 7 files changed, 280 insertions(+), 70 deletions(-)
 create mode 100644 comfy/language/chat_templates.py
 create mode 100644 comfy/language/language_types.py

diff --git a/comfy/language/__init__.py b/comfy/language/__init__.py
index e69de29bb..cb63896d4 100644
--- a/comfy/language/__init__.py
+++ b/comfy/language/__init__.py
@@ -0,0 +1,3 @@
+from .chat_templates import _update_known_chat_templates
+
+_update_known_chat_templates()
diff --git a/comfy/language/chat_templates.py b/comfy/language/chat_templates.py
new file mode 100644
index 000000000..64b9b2de1
--- /dev/null
+++ b/comfy/language/chat_templates.py
@@ -0,0 +1,18 @@
+from __future__ import annotations
+
+import logging
+from importlib.abc import Traversable
+from importlib.resources import files
+from pathlib import Path
+
+KNOWN_CHAT_TEMPLATES = {}
+
+
+def _update_known_chat_templates():
+    try:
+        _chat_templates: Traversable
+        with files("huggingface_extra_chat_templates") / "chat_templates" as _chat_templates:
+            _extra_jinja_templates = {Path(traversable.name).stem: traversable.read_text().replace('    ', '').replace('\n', '') for traversable in _chat_templates.iterdir() if traversable.is_file()}
+            KNOWN_CHAT_TEMPLATES.update(_extra_jinja_templates)
+    except ImportError as exc:
+        logging.warning("Could not load extra chat templates, some text models will fail", exc_info=exc)
diff --git a/comfy/language/language_types.py b/comfy/language/language_types.py
new file mode 100644
index 000000000..a6d26f64f
--- /dev/null
+++ b/comfy/language/language_types.py
@@ -0,0 +1,24 @@
+from __future__ import annotations
+
+import torch
+from transformers import BatchEncoding
+from typing_extensions import TypedDict, NotRequired
+
+
+class ProcessorResult(TypedDict):
+    """
+    Attributes:
+        attention_mask: attention mask
+        pixel_values: post image-processed values
+
+        images: used for LLaVA compatibility and points to pixel_values
+        inputs: used for LLaVA compatibility and points to input_ids
+        image_sizes: used for LLaVA compatibility, stores the (width, height) tuples of the original input images
+    """
+
+    attention_mask: NotRequired[torch.Tensor]
+    pixel_values: NotRequired[torch.Tensor]
+
+    images: NotRequired[torch.Tensor]
+    inputs: BatchEncoding
+    image_sizes: NotRequired[torch.Tensor]
diff --git a/comfy/language/transformers_model_management.py b/comfy/language/transformers_model_management.py
index e318c729d..cf6f51d8d 100644
--- a/comfy/language/transformers_model_management.py
+++ b/comfy/language/transformers_model_management.py
@@ -1,28 +1,66 @@
 from __future__ import annotations
 
+import copy
+import logging
 import warnings
-from typing import Optional, Any
+from typing import Optional, Any, Callable, Union, List
 
+import numpy as np
 import torch
-from transformers import PreTrainedModel, PreTrainedTokenizerBase, PretrainedConfig
+from PIL.Image import Image
+from transformers import PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, AutoProcessor, AutoTokenizer, \
+    TensorType, BatchFeature
+from transformers.tokenization_utils_base import PreTokenizedInput, TextInput, TruncationStrategy
+from transformers.utils import PaddingStrategy
 
+from .chat_templates import KNOWN_CHAT_TEMPLATES
+from .language_types import ProcessorResult
 from ..model_management import unet_offload_device, get_torch_device
 from ..model_management_types import ModelManageable
 
+LLaVAProcessor = Callable[
+    [
+        Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],  # text parameter
+        Union[Image, np.ndarray, torch.Tensor, List[Image], List[np.ndarray], List[torch.Tensor]],  # images parameter
+        Union[bool, str, PaddingStrategy],  # padding parameter
+        Union[bool, str, TruncationStrategy],  # truncation parameter
+        Optional[int],  # max_length parameter
+        Optional[Union[str, TensorType]]  # return_tensors parameter
+    ],
+    BatchFeature
+]
+
 
 class TransformersManagedModel(ModelManageable):
-    def __init__(self, repo_id: str, model: PreTrainedModel, tokenizer: Optional[PreTrainedTokenizerBase] = None, config_dict: Optional[dict] = None):
+    def __init__(
+            self,
+            repo_id: str,
+            model: PreTrainedModel,
+            tokenizer: Optional[PreTrainedTokenizerBase] = None,
+            config_dict: Optional[dict] = None,
+            processor: Optional[ProcessorMixin | AutoProcessor] = None
+    ):
         self.repo_id = repo_id
         self.model = model
-        self.tokenizer = tokenizer
+        self._tokenizer = tokenizer
+        self._processor = processor
         self._parameter_count = sum(param.nelement() for param in self.model.state_dict().values())
         self._size = sum(param.nelement() * param.element_size() for param in self.model.state_dict().values())
         self.load_device = get_torch_device()
         self.offload_device = unet_offload_device()
         self._config_dict = config_dict
+        self._on_set_processor(self._processor)
         if model.device != self.offload_device:
             model.to(device=self.offload_device)
 
+    @property
+    def tokenizer(self) -> PreTrainedTokenizerBase | AutoTokenizer:
+        return self._tokenizer
+
+    @property
+    def processor(self) -> AutoProcessor | ProcessorMixin | LLaVAProcessor | None:
+        return self._processor
+
     @property
     def config_dict(self) -> dict:
         """
@@ -67,7 +105,10 @@ class TransformersManagedModel(ModelManageable):
         if not self.is_clone(clone):
             return False
 
-        return frozenset(self.model.active_adapters()) == frozenset(clone.model.active_adapters())
+        try:
+            return frozenset(self.model.active_adapters()) == frozenset(clone.model.active_adapters())
+        except ValueError as no_adapters:
+            return True
 
     def model_size(self) -> int:
         return self._size
@@ -92,3 +133,57 @@ class TransformersManagedModel(ModelManageable):
     def unpatch_model(self, offload_device: torch.device, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
         warnings.warn("Transformers models do not currently support adapters like LoRAs")
         return self.model.to(device=offload_device)
+
+    def patch_processor(self, processor: Any, overwrite_tokenizer: bool = False) -> TransformersManagedModel:
+        model = copy.copy(self)
+        model._processor = processor
+        if hasattr(processor, "tokenizer") and overwrite_tokenizer:
+            model._tokenizer = processor.tokenizer
+        self._on_set_processor(model._processor)
+        return model
+
+    def _on_set_processor(self, processor: Any):
+        if processor is not None and hasattr(processor, "image_processor") and hasattr(processor.image_processor, "do_rescale"):
+            processor.image_processor.do_rescale = False
+
+    def tokenize(self, prompt: str, images: List[torch.Tensor] | torch.Tensor, chat_template: str) -> ProcessorResult:
+        tokenizer = self.tokenizer
+        assert tokenizer is not None
+        assert hasattr(tokenizer, "decode")
+
+        # try to retrieve a matching chat template
+        chat_template = chat_template or tokenizer.chat_template if hasattr(tokenizer, "chat_template") else None
+        if chat_template is None:
+            candidate_chat_templates = [(name, template) for name, template in KNOWN_CHAT_TEMPLATES.items() if name in self.config_dict["_name_or_path"] or name in self.model.name_or_path]
+            if len(candidate_chat_templates) > 0:
+                filename, chat_template = candidate_chat_templates[0]
+                logging.debug(f"Selected chat template filename={filename} for {self.model.name_or_path}")
+        try:
+            # todo: this should come from node inputs
+            prompt = tokenizer.apply_chat_template([
+                {"role": "user", "content": prompt},
+            ], chat_template=chat_template, add_generation_prompt=True, tokenize=False)
+        except Exception as exc:
+            logging.error("Could not apply chat template", exc_info=exc)
+
+        if self.processor is None:
+            batch_encoding = tokenizer(prompt, return_tensors="pt").to(device=self.load_device)
+            return {**batch_encoding}
+        else:
+            assert images is not None and len(images) > 0, "When using a multi-modal model, pass at least one, possibly empty, image"
+            if hasattr(self.processor, "to"):
+                self.processor.to(device=self.load_device)
+
+            assert "<image>" in prompt, "You must specify a <image> token inside the prompt for it to be substituted correctly by a HuggingFace processor"
+            batch_feature: BatchFeature = self.processor([prompt], images=images, padding=True, return_tensors="pt")
+            if hasattr(self.processor, "to"):
+                self.processor.to(device=self.offload_device)
+            assert "input_ids" in batch_feature
+            batch_feature.to(device=self.load_device, dtype=self.model_dtype())
+            # noinspection PyTypeChecker
+            return {
+                "image_sizes": [(image.shape[-1], image.shape[-2]) for image in images],
+                "images": batch_feature["pixel_values"],
+                "inputs": batch_feature["input_ids"],
+                **batch_feature
+            }
diff --git a/comfy/model_downloader.py b/comfy/model_downloader.py
index d6af7b7a7..f983a94b3 100644
--- a/comfy/model_downloader.py
+++ b/comfy/model_downloader.py
@@ -303,6 +303,7 @@ KNOWN_HUGGINGFACE_MODEL_REPOS = {
     "JingyeChen22/textdiffuser2_layout_planner",
     'JingyeChen22/textdiffuser2-full-ft',
     "microsoft/Phi-3-mini-4k-instruct",
+    "llava-hf/llava-v1.6-mistral-7b-hf"
 }
 
 KNOWN_UNET_MODELS: List[Union[CivitFile | HuggingFile]] = [
diff --git a/comfy_extras/nodes/nodes_language.py b/comfy_extras/nodes/nodes_language.py
index c2095d41d..d097a18e7 100644
--- a/comfy_extras/nodes/nodes_language.py
+++ b/comfy_extras/nodes/nodes_language.py
@@ -1,42 +1,44 @@
 from __future__ import annotations
 
 import copy
+import inspect
 import logging
 import operator
+import os.path
 from functools import reduce
-from importlib.resources import files
-from importlib.resources.abc import Traversable
-from pathlib import Path
-from typing import Any, Dict, Optional, List, Callable, TypedDict
+from typing import Any, Dict, Optional, List, Callable, Union
 
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, LogitsProcessor, TextStreamer, \
-    PreTrainedTokenizerBase, LogitsProcessorList, PretrainedConfig
+    PreTrainedTokenizerBase, LogitsProcessorList, PretrainedConfig, AutoProcessor, BatchFeature, ProcessorMixin, \
+    LlavaNextForConditionalGeneration, LlavaNextProcessor
+from typing_extensions import TypedDict
 
+from comfy.language.chat_templates import KNOWN_CHAT_TEMPLATES
+from comfy.language.language_types import ProcessorResult
 from comfy.language.transformers_model_management import TransformersManagedModel
 from comfy.model_downloader import huggingface_repos
 from comfy.model_management import get_torch_device_name, load_model_gpu, unet_dtype, unet_offload_device
 from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult
 from comfy.utils import comfy_tqdm, seed_for_block, comfy_progress, ProgressBar
 
+_AUTO_CHAT_TEMPLATE = "default"
+
+# add llava support
+try:
+    from llava import model
+
+    logging.info("Additional LLaVA models are now supported")
+except ImportError as exc:
+    logging.info("Install LLaVA with `pip install git+https://github.com/AppMana/appmana-comfyui-llava` for additional LLaVA support")
+
 # aka kwargs type
 _GENERATION_KWARGS_TYPE = Dict[str, Any]
 _GENERATION_KWARGS_TYPE_NAME = "SAMPLER"
-_TOKENS_TYPE = torch.Tensor
+
+_TOKENS_TYPE = Union[ProcessorResult, BatchFeature]
 TOKENS_TYPE_NAME = "TOKENS"
 
-KNOWN_CHAT_TEMPLATES = {}
-
-
-def _update_known_chat_templates():
-    try:
-        _chat_templates: Traversable
-        with files("huggingface_extra_chat_templates") / "chat_templates" as _chat_templates:
-            _extra_jinja_templates = {Path(traversable.name).stem: traversable.read_text().replace('    ', '').replace('\n', '') for traversable in _chat_templates.iterdir() if traversable.is_file()}
-            KNOWN_CHAT_TEMPLATES.update(_extra_jinja_templates)
-    except ImportError as exc:
-        logging.warning("Could not load extra chat templates, some text models will fail", exc_info=exc)
-
 
 class _ProgressTextStreamer(TextStreamer):
     def __init__(self, on_finalized_text: Callable[[str, bool], None], tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
@@ -190,13 +192,15 @@ class TransformerMergeSamplers(CustomNode):
         return (reduce(operator.or_, list(kwargs.values()) + [do_sample], {}),)
 
 
-class TransformersLoader(CustomNode):
+class TransformersImageProcessorLoader(CustomNode):
     @classmethod
     def INPUT_TYPES(cls) -> InputTypes:
         return {
             "required": {
                 "ckpt_name": (huggingface_repos(),),
-                "subfolder": ("STRING", {})
+                "subfolder": ("STRING", {}),
+                "model": ("MODEL", {}),
+                "overwrite_tokenizer": ("BOOLEAN", {"default": False}),
             }
         }
 
@@ -204,15 +208,66 @@ class TransformersLoader(CustomNode):
     RETURN_TYPES = "MODEL",
     FUNCTION = "execute"
 
-    def execute(self, ckpt_name: str, subfolder: Optional[str] = None):
+    def execute(self, ckpt_name: str, subfolder: Optional[str] = None, model: TransformersManagedModel = None, overwrite_tokenizer: bool = False):
+        hub_kwargs = {}
+        if subfolder is not None and subfolder != "":
+            hub_kwargs["subfolder"] = subfolder
+        processor = AutoProcessor.from_pretrained(ckpt_name, torch_dtype=unet_dtype(), device_map=get_torch_device_name(unet_offload_device()), low_cpu_mem_usage=True, trust_remote_code=True, **hub_kwargs)
+        return model.patch_processor(processor, overwrite_tokenizer),
+
+
+class TransformersLoader(CustomNode):
+    @classmethod
+    def INPUT_TYPES(cls) -> InputTypes:
+        return {
+            "required": {
+                "ckpt_name": (huggingface_repos(),),
+                "subfolder": ("STRING", {})
+            },
+        }
+
+    CATEGORY = "language"
+    RETURN_TYPES = "MODEL",
+    FUNCTION = "execute"
+
+    def execute(self, ckpt_name: str, subfolder: Optional[str] = None, *args, **kwargs):
         hub_kwargs = {}
         if subfolder is not None and subfolder != "":
             hub_kwargs["subfolder"] = subfolder
         with comfy_tqdm():
-            model = AutoModelForCausalLM.from_pretrained(ckpt_name, torch_dtype=unet_dtype(), device_map=get_torch_device_name(unet_offload_device()), low_cpu_mem_usage=True, trust_remote_code=True, **hub_kwargs)
-            tokenizer = AutoTokenizer.from_pretrained(ckpt_name)
+            from_pretrained_kwargs = {
+                "pretrained_model_name_or_path": ckpt_name,
+                "torch_dtype": unet_dtype(),
+                "device_map": get_torch_device_name(unet_offload_device()),
+                "low_cpu_mem_usage": True,
+                "trust_remote_code": True,
+                **hub_kwargs
+            }
+
+            try:
+                model = AutoModelForCausalLM.from_pretrained(**from_pretrained_kwargs)
+            except Exception:
+                model = LlavaNextForConditionalGeneration.from_pretrained(**from_pretrained_kwargs)
+
             config_dict, _ = PretrainedConfig.get_config_dict(ckpt_name, trust_remote_code=True, **hub_kwargs)
-            model_managed = TransformersManagedModel(ckpt_name, model, tokenizer, config_dict)
+            try:
+                try:
+                    processor = AutoProcessor.from_pretrained(**from_pretrained_kwargs)
+                except Exception:
+                    processor = LlavaNextProcessor.from_pretrained(**from_pretrained_kwargs)
+            except Exception:
+                processor = None
+            if not isinstance(processor, ProcessorMixin):
+                processor = None
+            tokenizer = getattr(processor, "tokenizer") if processor is not None and hasattr(processor, "tokenizer") else AutoTokenizer.from_pretrained(ckpt_name, **hub_kwargs)
+
+            model_managed = TransformersManagedModel(
+                repo_id=ckpt_name,
+                model=model,
+                tokenizer=tokenizer,
+                config_dict=config_dict,
+                processor=processor
+            )
         return model_managed,
 
 
@@ -223,6 +278,10 @@ class OneShotInstructTokenize(CustomNode):
             "required": {
                 "model": ("MODEL",),
                 "prompt": ("STRING", {"default": "", "multiline": True}),
+                "chat_template": ([_AUTO_CHAT_TEMPLATE] + list(KNOWN_CHAT_TEMPLATES.keys()), {})
+            },
+            "optional": {
+                "images": ("IMAGE", {}),
             }
         }
 
@@ -230,27 +289,17 @@ class OneShotInstructTokenize(CustomNode):
     RETURN_TYPES = (TOKENS_TYPE_NAME,)
     FUNCTION = "execute"
 
-    def execute(self, model: TransformersManagedModel, prompt: str) -> ValidatedNodeResult:
-        tokenizer: PreTrainedTokenizerBase | AutoTokenizer = model.tokenizer
-        assert tokenizer is not None
-        assert hasattr(tokenizer, "decode")
-
-        # try to retrieve a matching chat template
-        chat_template = tokenizer.chat_template if hasattr(tokenizer, "chat_template") else None
-        if chat_template is None:
-            candidate_chat_templates = [(name, template) for name, template in KNOWN_CHAT_TEMPLATES.items() if name in model.config_dict["_name_or_path"] or name in model.model.name_or_path]
-            if len(candidate_chat_templates) > 0:
-                filename, chat_template = candidate_chat_templates[0]
-                logging.debug(f"Selected chat template filename={filename} for {model.model.name_or_path}")
-        try:
-            # todo: this should come from node inputs
-            prompt = tokenizer.apply_chat_template([
-                {"role": "user", "content": prompt},
-            ], chat_template=chat_template, add_generation_prompt=True, tokenize=False)
-        except Exception as exc:
-            logging.error("Could not apply chat template", exc_info=exc)
-
-        return tokenizer(prompt, return_tensors="pt"),
+    def execute(self, model: TransformersManagedModel, prompt: str, images: List[torch.Tensor] | torch.Tensor = None, chat_template: str = _AUTO_CHAT_TEMPLATE) -> ValidatedNodeResult:
+        if chat_template == _AUTO_CHAT_TEMPLATE:
+            # use an exact match
+            model_name = os.path.basename(model.repo_id)
+            if model_name in KNOWN_CHAT_TEMPLATES:
+                chat_template = KNOWN_CHAT_TEMPLATES[model_name]
+            else:
+                chat_template = None
+        else:
+            chat_template = KNOWN_CHAT_TEMPLATES[chat_template]
+        return model.tokenize(prompt, images, chat_template),
 
 
 class TransformersGenerate(CustomNode):
@@ -266,7 +315,6 @@ class TransformersGenerate(CustomNode):
                 "use_cache": ("BOOLEAN", {"default": True}),
             },
             "optional": {
-                "images": ("IMAGE", {}),
                 "sampler": (_GENERATION_KWARGS_TYPE_NAME, {}),
             }
         }
@@ -281,24 +329,36 @@ class TransformersGenerate(CustomNode):
             max_new_tokens: int = 512,
             repetition_penalty: float = 0.0,
             seed: int = 0,
-            images: Optional[List[torch.Tensor] | torch.Tensor] = None,
             sampler: Optional[_GENERATION_KWARGS_TYPE] = None,
             *args,
             **kwargs
     ):
+        tokens = copy.copy(tokens)
         sampler = sampler or {}
         generate_kwargs = copy.copy(sampler)
-        # gracefully support LlaVA and others
-        if images is not None and not isinstance(images, torch.Tensor):
-            images = torch.stack(images, dim=0)
-        if images is not None:
-            generate_kwargs["images"] = images
-            # assuming it's of the form (batch, features..., height, width)
-            generate_kwargs["images_sizes"] = [(images.shape[-2], images.shape[-1]) for _ in range(images.shape[0])]
 
         load_model_gpu(model)
-        tokenizer: PreTrainedTokenizerBase | AutoTokenizer = model.tokenizer
-        inputs = tokens.to(model.current_device)
         transformers_model: PreTrainedModel = model.model
+        tokenizer: PreTrainedTokenizerBase | AutoTokenizer = model.tokenizer
+        # remove unused inputs
+        # maximizes compatibility with different models
+        generate_signature = inspect.signature(transformers_model.generate).parameters
+        prepare_signature = inspect.signature(transformers_model.prepare_inputs_for_generation).parameters
+        to_delete = set(reduce(operator.sub, map(lambda x: x.keys(), [tokens, generate_signature, prepare_signature])))
+        gen_sig_keys = generate_signature.keys()
+        if "input_ids" in tokens and "inputs" in tokens:
+            if "input_ids" in gen_sig_keys:
+                to_delete.add("inputs")
+            elif "inputs" in gen_sig_keys:
+                to_delete.add("input_ids")
+        for unused_kwarg in to_delete:
+            tokens.pop(unused_kwarg)
+            logging.info(f"{transformers_model.name_or_path}.generate does not accept {unused_kwarg}, removing")
+
+        # images should be moved to model
+        for key in ("images", "pixel_values"):
+            if key in tokens:
+                tokens[key] = tokens[key].to(device=model.current_device, dtype=model.model_dtype())
+        inputs = tokens
         progress_logits_processor = _ProgressLogitsProcessor(model)
         progress_bar: ProgressBar
         with comfy_progress(total=max_new_tokens) as progress_bar:
@@ -320,9 +380,8 @@ class TransformersGenerate(CustomNode):
             text_streamer = _ProgressTextStreamer(on_finalized_text, tokenizer, True)
 
             with seed_for_block(seed):
-                # load the model as close to the actual generation as possible
                 output_ids = transformers_model.generate(
-                    inputs.input_ids,
+                    **inputs,
                     logits_processor=LogitsProcessorList([progress_logits_processor]),
                     streamer=text_streamer,
                     max_new_tokens=max_new_tokens,
@@ -333,12 +392,22 @@ class TransformersGenerate(CustomNode):
             if transformers_model.config.is_encoder_decoder:
                 start_position = 1
             else:
-                start_position = inputs.input_ids.shape[1]
+                start_position = inputs["input_ids" if "input_ids" in inputs else "inputs"].shape[1]
             output_ids = output_ids[:, start_position:]
 
         # todo: is this redundant consider I'm decoding in the on_finalized_text block?
         outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
-        return outputs,
+
+        # gpu-loaded stuff like images can now be unloaded
+        if hasattr(tokens, "to"):
+            del tokens
+        else:
+            for to_delete in tokens.values():
+                del to_delete
+            del tokens
+
+        # todo: better support batches
+        return outputs[0],
 
 
 class PreviewString(CustomNode):
@@ -369,10 +438,9 @@ for cls in (
     TransformerBeamSearchSampler,
     TransformerMergeSamplers,
    TransformersLoader,
+    TransformersImageProcessorLoader,
     TransformersGenerate,
     OneShotInstructTokenize,
     PreviewString,
 ):
     NODE_CLASS_MAPPINGS[cls.__name__] = cls
-
-_update_known_chat_templates()
diff --git a/comfy_extras/nodes/nodes_open_api.py b/comfy_extras/nodes/nodes_open_api.py
index b71792a72..714240efe 100644
--- a/comfy_extras/nodes/nodes_open_api.py
+++ b/comfy_extras/nodes/nodes_open_api.py
@@ -284,10 +284,11 @@ class DevNullUris(CustomNode):
 class StringJoin(CustomNode):
     @classmethod
     def INPUT_TYPES(cls) -> InputTypes:
-        required = {f"value{i}": ("STRING", {"default": "", "multiline": True, "forceInput": True}) for i in range(5)}
-        required["separator"] = ("STRING", {"default": "_"})
+        optional = {f"value{i}": ("STRING", {"default": "", "multiline": True, "forceInput": True}) for i in range(5)}
+        optional["separator"] = ("STRING", {"default": "_"})
         return {
-            "required": required
+            "required": {},
+            "optional": optional
         }
 
     RETURN_TYPES = ("STRING",)
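
The node classes introduced by this patch compose into a simple pipeline: TransformersLoader produces a managed model (with a processor when the checkpoint is multi-modal), OneShotInstructTokenize turns a prompt and optional images into TOKENS, and TransformersGenerate decodes the result. The following is a minimal sketch of that flow when the nodes are called directly from Python rather than through a workflow graph; the keyword names passed to TransformersGenerate (tokens, max_new_tokens, seed) and the blank test image are illustrative assumptions, not taken verbatim from the diff.

# Minimal usage sketch, assuming a working ComfyUI environment where
# comfy_extras.nodes.nodes_language imports cleanly. The graph executor
# normally wires these calls; here they are invoked by hand.
import torch

from comfy_extras.nodes.nodes_language import (
    OneShotInstructTokenize,
    TransformersGenerate,
    TransformersLoader,
)

# Load a multi-modal checkpoint; the loader falls back to the LLaVA-Next
# classes when AutoModelForCausalLM cannot build the model.
model, = TransformersLoader().execute("llava-hf/llava-v1.6-mistral-7b-hf", subfolder="")

# ComfyUI-style image batch (B, H, W, C) in [0, 1]; a blank frame stands in for real input.
images = torch.zeros((1, 336, 336, 3))

# The prompt must contain the <image> placeholder so the processor can substitute pixel values.
tokens, = OneShotInstructTokenize().execute(
    model,
    "<image>\nDescribe this image.",
    images=images,
    chat_template="default",
)

# Generate text; keys in `tokens` that generate() does not accept are stripped first.
text, = TransformersGenerate().execute(model=model, tokens=tokens, max_new_tokens=64, seed=0)
print(text)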