Multi-modal LLM support and ongoing improvements to language features.

This commit is contained in:
doctorpangloss 2024-06-07 16:23:10 -07:00
parent 6575409461
commit 7f300bcb7a
7 changed files with 280 additions and 70 deletions

View File

@@ -0,0 +1,3 @@
from .chat_templates import _update_known_chat_templates
_update_known_chat_templates()

View File

@@ -0,0 +1,18 @@
from __future__ import annotations
import logging
from importlib.abc import Traversable
from importlib.resources import files
from pathlib import Path
KNOWN_CHAT_TEMPLATES = {}
def _update_known_chat_templates():
try:
_chat_templates: Traversable
with files("huggingface_extra_chat_templates") / "chat_templates" as _chat_templates:
_extra_jinja_templates = {Path(traversable.name).stem: traversable.read_text().replace('    ', '').replace('\n', '') for traversable in _chat_templates.iterdir() if traversable.is_file()}
KNOWN_CHAT_TEMPLATES.update(_extra_jinja_templates)
except ImportError as exc:
logging.warning("Could not load extra chat templates, some text models will fail", exc_info=exc)

View File

@@ -0,0 +1,24 @@
from __future__ import annotations
import torch
from transformers import BatchEncoding
from typing_extensions import TypedDict, NotRequired
class ProcessorResult(TypedDict):
"""
Attributes:
attention_mask: attention mask returned by the tokenizer or processor
pixel_values: post image-processed values
images: used for LLaVA compatibility and points to pixel_values
inputs: used for LLaVA compatibility and points to input_ids
image_sizes: used for LLaVA compatibility, stores the (width, height) tuples of the original input images
"""
attention_mask: NotRequired[torch.Tensor]
pixel_values: NotRequired[torch.Tensor]
images: NotRequired[torch.Tensor]
inputs: BatchEncoding
image_sizes: NotRequired[torch.Tensor]
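A minimal sketch of how a ProcessorResult might be assembled from a processor's BatchEncoding, mirroring the aliasing described in the docstring; the helper name _as_processor_result is a hypothetical illustration, not an API introduced by this commit.

def _as_processor_result(encoding: BatchEncoding) -> ProcessorResult:
    # mirror input_ids (and pixel_values, when present) under the LLaVA-compatible aliases
    result: ProcessorResult = {"inputs": encoding["input_ids"], **encoding}
    if "pixel_values" in encoding:
        result["images"] = encoding["pixel_values"]
    return result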

View File

@@ -1,28 +1,66 @@
from __future__ import annotations
import copy
import logging
import warnings
from typing import Optional, Any
from typing import Optional, Any, Callable, Union, List
import numpy as np
import torch
from transformers import PreTrainedModel, PreTrainedTokenizerBase, PretrainedConfig
from PIL.Image import Image
from transformers import PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, AutoProcessor, AutoTokenizer, \
TensorType, BatchFeature
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput, TruncationStrategy
from transformers.utils import PaddingStrategy
from .chat_templates import KNOWN_CHAT_TEMPLATES
from .language_types import ProcessorResult
from ..model_management import unet_offload_device, get_torch_device
from ..model_management_types import ModelManageable
LLaVAProcessor = Callable[
[
Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]], # text parameter
Union[Image, np.ndarray, torch.Tensor, List[Image], List[np.ndarray], List[torch.Tensor]], # images parameter
Union[bool, str, PaddingStrategy], # padding parameter
Union[bool, str, TruncationStrategy], # truncation parameter
Optional[int], # max_length parameter
Optional[Union[str, TensorType]] # return_tensors parameter
],
BatchFeature
]
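# A hedged sketch (not part of the commit) of a call matching the LLaVAProcessor alias above,
# assuming a transformers LLaVA-style processor such as LlavaNextProcessor; exact keyword names
# vary across transformers versions and are an assumption here:
#
#   processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
#   batch: BatchFeature = processor(
#       "USER: <image>\nWhat is shown here? ASSISTANT:",
#       images=pil_image,      # PIL.Image, ndarray or tensor, per the alias
#       padding=True,
#       return_tensors="pt",
#   )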
class TransformersManagedModel(ModelManageable):
def __init__(self, repo_id: str, model: PreTrainedModel, tokenizer: Optional[PreTrainedTokenizerBase] = None, config_dict: Optional[dict] = None):
def __init__(
self,
repo_id: str,
model: PreTrainedModel,
tokenizer: Optional[PreTrainedTokenizerBase] = None,
config_dict: Optional[dict] = None,
processor: Optional[ProcessorMixin | AutoProcessor] = None
):
self.repo_id = repo_id
self.model = model
self.tokenizer = tokenizer
self._tokenizer = tokenizer
self._processor = processor
self._parameter_count = sum(param.nelement() for param in self.model.state_dict().values())
self._size = sum(param.nelement() * param.element_size() for param in self.model.state_dict().values())
self.load_device = get_torch_device()
self.offload_device = unet_offload_device()
self._config_dict = config_dict
self._on_set_processor(self._processor)
if model.device != self.offload_device:
model.to(device=self.offload_device)
@property
def tokenizer(self) -> PreTrainedTokenizerBase | AutoTokenizer:
return self._tokenizer
@property
def processor(self) -> AutoProcessor | ProcessorMixin | LLaVAProcessor | None:
return self._processor
@property
def config_dict(self) -> dict:
"""
@@ -67,7 +105,10 @@ class TransformersManagedModel(ModelManageable):
if not self.is_clone(clone):
return False
return frozenset(self.model.active_adapters()) == frozenset(clone.model.active_adapters())
try:
return frozenset(self.model.active_adapters()) == frozenset(clone.model.active_adapters())
except ValueError as no_adapters:
return True
def model_size(self) -> int:
return self._size
@@ -92,3 +133,57 @@ class TransformersManagedModel(ModelManageable):
def unpatch_model(self, offload_device: torch.device, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
warnings.warn("Transformers models do not currently support adapters like LoRAs")
return self.model.to(device=offload_device)
def patch_processor(self, processor: Any, overwrite_tokenizer: bool = False) -> TransformersManagedModel:
model = copy.copy(self)
model._processor = processor
if hasattr(processor, "tokenizer") and overwrite_tokenizer:
model._tokenizer = processor.tokenizer
self._on_set_processor(model._processor)
return model
def _on_set_processor(self, processor: Any):
if processor is not None and hasattr(processor, "image_processor") and hasattr(processor.image_processor, "do_rescale"):
processor.image_processor.do_rescale = False
def tokenize(self, prompt: str, images: Optional[List[torch.Tensor] | torch.Tensor], chat_template: Optional[str]) -> ProcessorResult:
tokenizer = self.tokenizer
assert tokenizer is not None
assert hasattr(tokenizer, "decode")
# try to retrieve a matching chat template
chat_template = chat_template or (tokenizer.chat_template if hasattr(tokenizer, "chat_template") else None)
if chat_template is None:
candidate_chat_templates = [(name, template) for name, template in KNOWN_CHAT_TEMPLATES.items() if name in self.config_dict["_name_or_path"] or name in self.model.name_or_path]
if len(candidate_chat_templates) > 0:
filename, chat_template = candidate_chat_templates[0]
logging.debug(f"Selected chat template filename={filename} for {self.model.name_or_path}")
try:
# todo: this should come from node inputs
prompt = tokenizer.apply_chat_template([
{"role": "user", "content": prompt},
], chat_template=chat_template, add_generation_prompt=True, tokenize=False)
except Exception as exc:
logging.error("Could not apply chat template", exc_info=exc)
if self.processor is None:
batch_encoding = tokenizer(prompt, return_tensors="pt").to(device=self.load_device)
return {**batch_encoding}
else:
assert images is not None and len(images) > 0, "When using a multi-modal model, pass at least one, possibly empty, image"
if hasattr(self.processor, "to"):
self.processor.to(device=self.load_device)
assert "<image>" in prompt, "You must specify a &lt;image&gt; token inside the prompt for it to be substituted correctly by a HuggingFace processor"
batch_feature: BatchFeature = self.processor([prompt], images=images, padding=True, return_tensors="pt")
if hasattr(self.processor, "to"):
self.processor.to(device=self.offload_device)
assert "input_ids" in batch_feature
batch_feature.to(device=self.load_device, dtype=self.model_dtype())
# noinspection PyTypeChecker
return {
"image_sizes": [(images.shape[-1], image.shape[-2]) for image in images],
"images": batch_feature["pixel_values"],
"inputs": batch_feature["input_ids"],
**batch_feature
}
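A rough usage sketch for the tokenize path above (not part of the commit), assuming `managed` is a TransformersManagedModel with a processor attached; the tensor shape, prompt, and returned keys are illustrative assumptions.

image = torch.zeros((1, 3, 336, 336))  # assumed (batch, channels, height, width) layout
tokens = managed.tokenize("Describe the <image> briefly.", images=image, chat_template=None)
assert "input_ids" in tokens and "images" in tokens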

View File

@@ -303,6 +303,7 @@ KNOWN_HUGGINGFACE_MODEL_REPOS = {
"JingyeChen22/textdiffuser2_layout_planner",
'JingyeChen22/textdiffuser2-full-ft',
"microsoft/Phi-3-mini-4k-instruct",
"llava-hf/llava-v1.6-mistral-7b-hf"
}
KNOWN_UNET_MODELS: List[Union[CivitFile | HuggingFile]] = [

View File

@@ -1,42 +1,44 @@
from __future__ import annotations
import copy
import inspect
import logging
import operator
import os.path
from functools import reduce
from importlib.resources import files
from importlib.resources.abc import Traversable
from pathlib import Path
from typing import Any, Dict, Optional, List, Callable, TypedDict
from typing import Any, Dict, Optional, List, Callable, Union
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, LogitsProcessor, TextStreamer, \
PreTrainedTokenizerBase, LogitsProcessorList, PretrainedConfig
PreTrainedTokenizerBase, LogitsProcessorList, PretrainedConfig, AutoProcessor, BatchFeature, ProcessorMixin, \
LlavaNextForConditionalGeneration, LlavaNextProcessor
from typing_extensions import TypedDict
from comfy.language.chat_templates import KNOWN_CHAT_TEMPLATES
from comfy.language.language_types import ProcessorResult
from comfy.language.transformers_model_management import TransformersManagedModel
from comfy.model_downloader import huggingface_repos
from comfy.model_management import get_torch_device_name, load_model_gpu, unet_dtype, unet_offload_device
from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult
from comfy.utils import comfy_tqdm, seed_for_block, comfy_progress, ProgressBar
_AUTO_CHAT_TEMPLATE = "default"
# add llava support
try:
from llava import model
logging.info("Additional LLaVA models are now supported")
except ImportError as exc:
logging.info(f"Install LLavA with `pip install git+https://github.com/AppMana/appmana-comfyui-llava` for additional LLaVA support")
# aka kwargs type
_GENERATION_KWARGS_TYPE = Dict[str, Any]
_GENERATION_KWARGS_TYPE_NAME = "SAMPLER"
_TOKENS_TYPE = torch.Tensor
_TOKENS_TYPE = Union[ProcessorResult, BatchFeature]
TOKENS_TYPE_NAME = "TOKENS"
KNOWN_CHAT_TEMPLATES = {}
def _update_known_chat_templates():
try:
_chat_templates: Traversable
with files("huggingface_extra_chat_templates") / "chat_templates" as _chat_templates:
_extra_jinja_templates = {Path(traversable.name).stem: traversable.read_text().replace('    ', '').replace('\n', '') for traversable in _chat_templates.iterdir() if traversable.is_file()}
KNOWN_CHAT_TEMPLATES.update(_extra_jinja_templates)
except ImportError as exc:
logging.warning("Could not load extra chat templates, some text models will fail", exc_info=exc)
class _ProgressTextStreamer(TextStreamer):
def __init__(self, on_finalized_text: Callable[[str, bool], None], tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
@@ -190,13 +192,15 @@ class TransformerMergeSamplers(CustomNode):
return (reduce(operator.or_, list(kwargs.values()) + [do_sample], {}),)
class TransformersLoader(CustomNode):
class TransformersImageProcessorLoader(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"ckpt_name": (huggingface_repos(),),
"subfolder": ("STRING", {})
"subfolder": ("STRING", {}),
"model": ("MODEL", {}),
"overwrite_tokenizer": ("BOOLEAN", {"default": False}),
}
}
@@ -204,15 +208,66 @@ class TransformersLoader(CustomNode):
RETURN_TYPES = "MODEL",
FUNCTION = "execute"
def execute(self, ckpt_name: str, subfolder: Optional[str] = None):
def execute(self, ckpt_name: str, subfolder: Optional[str] = None, model: TransformersManagedModel = None, overwrite_tokenizer: bool = False):
hub_kwargs = {}
if subfolder is not None and subfolder != "":
hub_kwargs["subfolder"] = subfolder
processor = AutoProcessor.from_pretrained(ckpt_name, torch_dtype=unet_dtype(), device_map=get_torch_device_name(unet_offload_device()), low_cpu_mem_usage=True, trust_remote_code=True, **hub_kwargs)
return model.patch_processor(processor, overwrite_tokenizer),
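# Hedged usage sketch for TransformersImageProcessorLoader (not part of the commit): attach a
# separately downloaded processor to an already-loaded MODEL. The repo id matches the entry
# added to KNOWN_HUGGINGFACE_MODEL_REPOS in this commit; the wiring itself is illustrative.
#
#   (model,) = TransformersLoader().execute("llava-hf/llava-v1.6-mistral-7b-hf")
#   (patched,) = TransformersImageProcessorLoader().execute(
#       "llava-hf/llava-v1.6-mistral-7b-hf",
#       subfolder="",
#       model=model,
#       overwrite_tokenizer=True,
#   )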
class TransformersLoader(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"ckpt_name": (huggingface_repos(),),
"subfolder": ("STRING", {})
},
}
CATEGORY = "language"
RETURN_TYPES = "MODEL",
FUNCTION = "execute"
def execute(self, ckpt_name: str, subfolder: Optional[str] = None, *args, **kwargs):
hub_kwargs = {}
if subfolder is not None and subfolder != "":
hub_kwargs["subfolder"] = subfolder
with comfy_tqdm():
model = AutoModelForCausalLM.from_pretrained(ckpt_name, torch_dtype=unet_dtype(), device_map=get_torch_device_name(unet_offload_device()), low_cpu_mem_usage=True, trust_remote_code=True, **hub_kwargs)
tokenizer = AutoTokenizer.from_pretrained(ckpt_name)
from_pretrained_kwargs = {
"pretrained_model_name_or_path": ckpt_name,
"torch_dtype": unet_dtype(),
"device_map": get_torch_device_name(unet_offload_device()),
"low_cpu_mem_usage": True,
"trust_remote_code": True,
**hub_kwargs
}
try:
model = AutoModelForCausalLM.from_pretrained(**from_pretrained_kwargs)
except Exception:
# not a plain causal LM checkpoint; fall back to the LLaVA-Next multi-modal class
model = LlavaNextForConditionalGeneration.from_pretrained(**from_pretrained_kwargs)
config_dict, _ = PretrainedConfig.get_config_dict(ckpt_name, trust_remote_code=True, **hub_kwargs)
model_managed = TransformersManagedModel(ckpt_name, model, tokenizer, config_dict)
try:
try:
processor = AutoProcessor.from_pretrained(**from_pretrained_kwargs)
except Exception:
processor = LlavaNextProcessor.from_pretrained(**from_pretrained_kwargs)
except Exception:
processor = None
if not isinstance(processor, ProcessorMixin):
processor = None
tokenizer = getattr(processor, "tokenizer") if processor is not None and hasattr(processor, "tokenizer") else AutoTokenizer.from_pretrained(ckpt_name, **hub_kwargs)
model_managed = TransformersManagedModel(
repo_id=ckpt_name,
model=model,
tokenizer=tokenizer,
config_dict=config_dict,
processor=processor
)
return model_managed,
@@ -223,6 +278,10 @@ class OneShotInstructTokenize(CustomNode):
"required": {
"model": ("MODEL",),
"prompt": ("STRING", {"default": "", "multiline": True}),
"chat_template": ([_AUTO_CHAT_TEMPLATE] + list(KNOWN_CHAT_TEMPLATES.keys()), {})
},
"optional": {
"images": ("IMAGE", {}),
}
}
@@ -230,27 +289,17 @@
RETURN_TYPES = (TOKENS_TYPE_NAME,)
FUNCTION = "execute"
def execute(self, model: TransformersManagedModel, prompt: str) -> ValidatedNodeResult:
tokenizer: PreTrainedTokenizerBase | AutoTokenizer = model.tokenizer
assert tokenizer is not None
assert hasattr(tokenizer, "decode")
# try to retrieve a matching chat template
chat_template = tokenizer.chat_template if hasattr(tokenizer, "chat_template") else None
if chat_template is None:
candidate_chat_templates = [(name, template) for name, template in KNOWN_CHAT_TEMPLATES.items() if name in model.config_dict["_name_or_path"] or name in model.model.name_or_path]
if len(candidate_chat_templates) > 0:
filename, chat_template = candidate_chat_templates[0]
logging.debug(f"Selected chat template filename={filename} for {model.model.name_or_path}")
try:
# todo: this should come from node inputs
prompt = tokenizer.apply_chat_template([
{"role": "user", "content": prompt},
], chat_template=chat_template, add_generation_prompt=True, tokenize=False)
except Exception as exc:
logging.error("Could not apply chat template", exc_info=exc)
return tokenizer(prompt, return_tensors="pt"),
def execute(self, model: TransformersManagedModel, prompt: str, images: Optional[List[torch.Tensor] | torch.Tensor] = None, chat_template: str = _AUTO_CHAT_TEMPLATE) -> ValidatedNodeResult:
if chat_template == _AUTO_CHAT_TEMPLATE:
# use an exact match
model_name = os.path.basename(model.repo_id)
if model_name in KNOWN_CHAT_TEMPLATES:
chat_template = KNOWN_CHAT_TEMPLATES[model_name]
else:
chat_template = None
else:
chat_template = KNOWN_CHAT_TEMPLATES[chat_template]
return model.tokenize(prompt, images, chat_template),
class TransformersGenerate(CustomNode):
@@ -266,7 +315,6 @@ class TransformersGenerate(CustomNode):
"use_cache": ("BOOLEAN", {"default": True}),
},
"optional": {
"images": ("IMAGE", {}),
"sampler": (_GENERATION_KWARGS_TYPE_NAME, {}),
}
}
@@ -281,24 +329,36 @@
max_new_tokens: int = 512,
repetition_penalty: float = 0.0,
seed: int = 0,
images: Optional[List[torch.Tensor] | torch.Tensor] = None,
sampler: Optional[_GENERATION_KWARGS_TYPE] = None,
*args,
**kwargs
):
tokens = copy.copy(tokens)
sampler = sampler or {}
generate_kwargs = copy.copy(sampler)
# gracefully support LLaVA and others
if images is not None and not isinstance(images, torch.Tensor):
images = torch.stack(images, dim=0)
if images is not None:
generate_kwargs["images"] = images
# assuming it's of the form (batch, features..., height, width)
generate_kwargs["images_sizes"] = [(images.shape[-2], images.shape[-1]) for _ in range(images.shape[0])]
load_model_gpu(model)
tokenizer: PreTrainedTokenizerBase | AutoTokenizer = model.tokenizer
inputs = tokens.to(model.current_device)
transformers_model: PreTrainedModel = model.model
tokenizer: PreTrainedTokenizerBase | AutoTokenizer = model.tokenizer
# remove unused inputs
# maximizes compatibility with different models
generate_signature = inspect.signature(transformers_model.generate).parameters
prepare_signature = inspect.signature(transformers_model.prepare_inputs_for_generation).parameters
to_delete = set(reduce(operator.sub, map(lambda x: x.keys(), [tokens, generate_signature, prepare_signature])))
gen_sig_keys = generate_signature.keys()
if "input_ids" in tokens and "inputs" in tokens:
if "input_ids" in gen_sig_keys:
to_delete.add("inputs")
elif "inputs" in gen_sig_keys:
to_delete.add("input_ids")
for unused_kwarg in to_delete:
tokens.pop(unused_kwarg)
logging.info(f"{transformers_model.name_or_path}.generate does not accept {unused_kwarg}, removing")
# image tensors should be moved to the model's device and dtype
for key in ("images", "pixel_values"):
if key in tokens:
tokens[key] = tokens[key].to(device=model.current_device, dtype=model.model_dtype())
inputs = tokens
progress_logits_processor = _ProgressLogitsProcessor(model)
progress_bar: ProgressBar
with comfy_progress(total=max_new_tokens) as progress_bar:
@@ -320,9 +380,8 @@ class TransformersGenerate(CustomNode):
text_streamer = _ProgressTextStreamer(on_finalized_text, tokenizer, True)
with seed_for_block(seed):
# load the model as close to the actual generation as possible
output_ids = transformers_model.generate(
inputs.input_ids,
**inputs,
logits_processor=LogitsProcessorList([progress_logits_processor]),
streamer=text_streamer,
max_new_tokens=max_new_tokens,
@@ -333,12 +392,22 @@
if transformers_model.config.is_encoder_decoder:
start_position = 1
else:
start_position = inputs.input_ids.shape[1]
start_position = inputs["input_ids" if "input_ids" in inputs else "inputs"].shape[1]
output_ids = output_ids[:, start_position:]
# todo: is this redundant considering I'm decoding in the on_finalized_text block?
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
return outputs,
# gpu-loaded stuff like images can now be unloaded
if hasattr(tokens, "to"):
del tokens
else:
for to_delete in tokens.values():
del to_delete
del tokens
# todo: better support batches
return outputs[0],
class PreviewString(CustomNode):
@@ -369,10 +438,9 @@ for cls in (
TransformerBeamSearchSampler,
TransformerMergeSamplers,
TransformersLoader,
TransformersImageProcessorLoader,
TransformersGenerate,
OneShotInstructTokenize,
PreviewString,
):
NODE_CLASS_MAPPINGS[cls.__name__] = cls
_update_known_chat_templates()
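To illustrate how the registered nodes compose for a multi-modal prompt, a rough end-to-end sketch (not part of the commit): the repo id is the one added to the known repos above, while the image shape, keyword names, and generation settings are assumptions rather than a prescribed workflow.

(model,) = TransformersLoader().execute("llava-hf/llava-v1.6-mistral-7b-hf")
image = torch.zeros((1, 3, 336, 336))  # assumed (batch, channels, height, width)
(tokens,) = OneShotInstructTokenize().execute(model, "What is in the <image>?", images=image)
(text,) = TransformersGenerate().execute(model=model, tokens=tokens, max_new_tokens=64, seed=0)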

View File

@@ -284,10 +284,11 @@ class DevNullUris(CustomNode):
class StringJoin(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
required = {f"value{i}": ("STRING", {"default": "", "multiline": True, "forceInput": True}) for i in range(5)}
required["separator"] = ("STRING", {"default": "_"})
optional = {f"value{i}": ("STRING", {"default": "", "multiline": True, "forceInput": True}) for i in range(5)}
optional["separator"] = ("STRING", {"default": "_"})
return {
"required": required
"required": {},
"optional": optional
}
RETURN_TYPES = ("STRING",)