Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-11 06:40:48 +08:00)

Multi-modal LLM support and ongoing improvements to language features.

This commit is contained in: parent 6575409461, commit 7f300bcb7a
@@ -0,0 +1,3 @@
+from .chat_templates import _update_known_chat_templates
+
+_update_known_chat_templates()
comfy/language/chat_templates.py (new file, 18 lines)

@@ -0,0 +1,18 @@
+from __future__ import annotations
+
+import logging
+from importlib.abc import Traversable
+from importlib.resources import files
+from pathlib import Path
+
+KNOWN_CHAT_TEMPLATES = {}
+
+
+def _update_known_chat_templates():
+    try:
+        _chat_templates: Traversable
+        with files("huggingface_extra_chat_templates") / "chat_templates" as _chat_templates:
+            _extra_jinja_templates = {Path(traversable.name).stem: traversable.read_text().replace(' ', '').replace('\n', '') for traversable in _chat_templates.iterdir() if traversable.is_file()}
+            KNOWN_CHAT_TEMPLATES.update(_extra_jinja_templates)
+    except ImportError as exc:
+        logging.warning("Could not load extra chat templates, some text models will fail", exc_info=exc)
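A minimal usage sketch (not part of the diff), assuming the optional huggingface_extra_chat_templates package is installed; without it the loader only logs a warning and leaves the registry empty. The "mistral-instruct" name below is purely illustrative.

from comfy.language.chat_templates import KNOWN_CHAT_TEMPLATES, _update_known_chat_templates

_update_known_chat_templates()

# Keys are template file stems; e.g. a hypothetical mistral-instruct.jinja
# would be registered under "mistral-instruct".
for name, template in sorted(KNOWN_CHAT_TEMPLATES.items()):
    print(name, len(template), "characters")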
comfy/language/language_types.py (new file, 24 lines)

@@ -0,0 +1,24 @@
+from __future__ import annotations
+
+import torch
+from transformers import BatchEncoding
+from typing_extensions import TypedDict, NotRequired
+
+
+class ProcessorResult(TypedDict):
+    """
+    Attributes:
+        attention_mask: attention mask
+        pixel_values: post image-processed values
+
+        images: used for LLaVA compatibility and points to pixel_values
+        inputs: used for LLaVA compatibility and points to input_ids
+        image_sizes: used for LLaVA compatibility, stores the (width, height) tuples of the original input images
+    """
+
+    attention_mask: NotRequired[torch.Tensor]
+    pixel_values: NotRequired[torch.Tensor]
+
+    images: NotRequired[torch.Tensor]
+    inputs: BatchEncoding
+    image_sizes: NotRequired[torch.Tensor]
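A quick sketch (not part of the diff) of how a text-only ProcessorResult might be assembled; it assumes an ordinary Hugging Face tokenizer, and "gpt2" is only an illustrative checkpoint. Only the "inputs" key is required; the LLaVA-compatibility aliases matter only for multi-modal models.

from transformers import AutoTokenizer

from comfy.language.language_types import ProcessorResult

# Assumption: "gpt2" stands in for any tokenizer repo.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
batch_encoding = tokenizer("Hello, world", return_tensors="pt")
result: ProcessorResult = {
    "inputs": batch_encoding,
    "attention_mask": batch_encoding["attention_mask"],
}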
comfy/language/transformers_model_management.py

@@ -1,28 +1,66 @@
 from __future__ import annotations
 
+import copy
+import logging
 import warnings
-from typing import Optional, Any
+from typing import Optional, Any, Callable, Union, List
 
+import numpy as np
 import torch
-from transformers import PreTrainedModel, PreTrainedTokenizerBase, PretrainedConfig
+from PIL.Image import Image
+from transformers import PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, AutoProcessor, AutoTokenizer, \
+    TensorType, BatchFeature
+from transformers.tokenization_utils_base import PreTokenizedInput, TextInput, TruncationStrategy
+from transformers.utils import PaddingStrategy
 
+from .chat_templates import KNOWN_CHAT_TEMPLATES
+from .language_types import ProcessorResult
 from ..model_management import unet_offload_device, get_torch_device
 from ..model_management_types import ModelManageable
 
+LLaVAProcessor = Callable[
+    [
+        Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],  # text parameter
+        Union[Image, np.ndarray, torch.Tensor, List[Image], List[np.ndarray], List[torch.Tensor]],  # images parameter
+        Union[bool, str, PaddingStrategy],  # padding parameter
+        Union[bool, str, TruncationStrategy],  # truncation parameter
+        Optional[int],  # max_length parameter
+        Optional[Union[str, TensorType]]  # return_tensors parameter
+    ],
+    BatchFeature
+]
+
 
 class TransformersManagedModel(ModelManageable):
-    def __init__(self, repo_id: str, model: PreTrainedModel, tokenizer: Optional[PreTrainedTokenizerBase] = None, config_dict: Optional[dict] = None):
+    def __init__(
+            self,
+            repo_id: str,
+            model: PreTrainedModel,
+            tokenizer: Optional[PreTrainedTokenizerBase] = None,
+            config_dict: Optional[dict] = None,
+            processor: Optional[ProcessorMixin | AutoProcessor] = None
+    ):
         self.repo_id = repo_id
         self.model = model
-        self.tokenizer = tokenizer
+        self._tokenizer = tokenizer
+        self._processor = processor
         self._parameter_count = sum(param.nelement() for param in self.model.state_dict().values())
         self._size = sum(param.nelement() * param.element_size() for param in self.model.state_dict().values())
         self.load_device = get_torch_device()
         self.offload_device = unet_offload_device()
         self._config_dict = config_dict
+        self._on_set_processor(self._processor)
         if model.device != self.offload_device:
             model.to(device=self.offload_device)
 
+    @property
+    def tokenizer(self) -> PreTrainedTokenizerBase | AutoTokenizer:
+        return self._tokenizer
+
+    @property
+    def processor(self) -> AutoProcessor | ProcessorMixin | LLaVAProcessor | None:
+        return self._processor
+
     @property
     def config_dict(self) -> dict:
         """

@@ -67,7 +105,10 @@ class TransformersManagedModel(ModelManageable):
         if not self.is_clone(clone):
             return False
 
-        return frozenset(self.model.active_adapters()) == frozenset(clone.model.active_adapters())
+        try:
+            return frozenset(self.model.active_adapters()) == frozenset(clone.model.active_adapters())
+        except ValueError as no_adapters:
+            return True
 
     def model_size(self) -> int:
         return self._size

@@ -92,3 +133,57 @@
     def unpatch_model(self, offload_device: torch.device, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
         warnings.warn("Transformers models do not currently support adapters like LoRAs")
         return self.model.to(device=offload_device)
+
+    def patch_processor(self, processor: Any, overwrite_tokenizer: bool = False) -> TransformersManagedModel:
+        model = copy.copy(self)
+        model._processor = processor
+        if hasattr(processor, "tokenizer") and overwrite_tokenizer:
+            model._tokenizer = processor.tokenizer
+        self._on_set_processor(model._processor)
+        return model
+
+    def _on_set_processor(self, processor: Any):
+        if processor is not None and hasattr(processor, "image_processor") and hasattr(processor.image_processor, "do_rescale"):
+            processor.image_processor.do_rescale = False
+
+    def tokenize(self, prompt: str, images: List[torch.Tensor] | torch.Tensor, chat_template: str) -> ProcessorResult:
+        tokenizer = self.tokenizer
+        assert tokenizer is not None
+        assert hasattr(tokenizer, "decode")
+
+        # try to retrieve a matching chat template
+        chat_template = chat_template or (tokenizer.chat_template if hasattr(tokenizer, "chat_template") else None)
+        if chat_template is None:
+            candidate_chat_templates = [(name, template) for name, template in KNOWN_CHAT_TEMPLATES.items() if name in self.config_dict["_name_or_path"] or name in self.model.name_or_path]
+            if len(candidate_chat_templates) > 0:
+                filename, chat_template = candidate_chat_templates[0]
+                logging.debug(f"Selected chat template filename={filename} for {self.model.name_or_path}")
+        try:
+            # todo: this should come from node inputs
+            prompt = tokenizer.apply_chat_template([
+                {"role": "user", "content": prompt},
+            ], chat_template=chat_template, add_generation_prompt=True, tokenize=False)
+        except Exception as exc:
+            logging.error("Could not apply chat template", exc_info=exc)
+
+        if self.processor is None:
+            batch_encoding = tokenizer(prompt, return_tensors="pt").to(device=self.load_device)
+            return {**batch_encoding}
+        else:
+            assert images is not None and len(images) > 0, "When using a multi-modal model, pass at least one, possibly empty, image"
+            if hasattr(self.processor, "to"):
+                self.processor.to(device=self.load_device)
+
+            assert "<image>" in prompt, "You must specify a <image> token inside the prompt for it to be substituted correctly by a HuggingFace processor"
+            batch_feature: BatchFeature = self.processor([prompt], images=images, padding=True, return_tensors="pt")
+            if hasattr(self.processor, "to"):
+                self.processor.to(device=self.offload_device)
+            assert "input_ids" in batch_feature
+            batch_feature.to(device=self.load_device, dtype=self.model_dtype())
+            # noinspection PyTypeChecker
+            return {
+                "image_sizes": [(image.shape[-1], image.shape[-2]) for image in images],
+                "images": batch_feature["pixel_values"],
+                "inputs": batch_feature["input_ids"],
+                **batch_feature
+            }
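A sketch (not part of the diff) of driving the new multi-modal tokenize path directly. It assumes `managed` is a TransformersManagedModel that was constructed with a processor; the tensor shape and prompt are illustrative, and the <image> token is asserted whenever a processor is attached.

import torch

# Assumption: `managed` is a TransformersManagedModel built with a processor.
image = torch.zeros((3, 336, 336))  # placeholder image tensor, shape is illustrative
tokens = managed.tokenize(
    prompt="Describe this picture: <image>",  # <image> is required when a processor is set
    images=[image],
    chat_template=None,  # falls back to the tokenizer's template or a KNOWN_CHAT_TEMPLATES match
)
# tokens now carries input_ids plus the LLaVA-compatible images/inputs/image_sizes aliases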
@@ -303,6 +303,7 @@ KNOWN_HUGGINGFACE_MODEL_REPOS = {
     "JingyeChen22/textdiffuser2_layout_planner",
     'JingyeChen22/textdiffuser2-full-ft',
     "microsoft/Phi-3-mini-4k-instruct",
+    "llava-hf/llava-v1.6-mistral-7b-hf"
 }
 
 KNOWN_UNET_MODELS: List[Union[CivitFile | HuggingFile]] = [
@@ -1,42 +1,44 @@
 from __future__ import annotations
 
 import copy
+import inspect
 import logging
 import operator
+import os.path
 from functools import reduce
-from importlib.resources import files
-from importlib.resources.abc import Traversable
-from pathlib import Path
-from typing import Any, Dict, Optional, List, Callable, TypedDict
+from typing import Any, Dict, Optional, List, Callable, Union
 
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, LogitsProcessor, TextStreamer, \
-    PreTrainedTokenizerBase, LogitsProcessorList, PretrainedConfig
+    PreTrainedTokenizerBase, LogitsProcessorList, PretrainedConfig, AutoProcessor, BatchFeature, ProcessorMixin, \
+    LlavaNextForConditionalGeneration, LlavaNextProcessor
+from typing_extensions import TypedDict
 
+from comfy.language.chat_templates import KNOWN_CHAT_TEMPLATES
+from comfy.language.language_types import ProcessorResult
 from comfy.language.transformers_model_management import TransformersManagedModel
 from comfy.model_downloader import huggingface_repos
 from comfy.model_management import get_torch_device_name, load_model_gpu, unet_dtype, unet_offload_device
 from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult
 from comfy.utils import comfy_tqdm, seed_for_block, comfy_progress, ProgressBar
 
+_AUTO_CHAT_TEMPLATE = "default"
+
+# add llava support
+try:
+    from llava import model
+
+    logging.info("Additional LLaVA models are now supported")
+except ImportError as exc:
+    logging.info("Install LLaVA with `pip install git+https://github.com/AppMana/appmana-comfyui-llava` for additional LLaVA support")
+
 # aka kwargs type
 _GENERATION_KWARGS_TYPE = Dict[str, Any]
 _GENERATION_KWARGS_TYPE_NAME = "SAMPLER"
-_TOKENS_TYPE = torch.Tensor
+_TOKENS_TYPE = Union[ProcessorResult, BatchFeature]
 TOKENS_TYPE_NAME = "TOKENS"
 
-KNOWN_CHAT_TEMPLATES = {}
-
-
-def _update_known_chat_templates():
-    try:
-        _chat_templates: Traversable
-        with files("huggingface_extra_chat_templates") / "chat_templates" as _chat_templates:
-            _extra_jinja_templates = {Path(traversable.name).stem: traversable.read_text().replace(' ', '').replace('\n', '') for traversable in _chat_templates.iterdir() if traversable.is_file()}
-            KNOWN_CHAT_TEMPLATES.update(_extra_jinja_templates)
-    except ImportError as exc:
-        logging.warning("Could not load extra chat templates, some text models will fail", exc_info=exc)
-
 
 class _ProgressTextStreamer(TextStreamer):
     def __init__(self, on_finalized_text: Callable[[str, bool], None], tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
@@ -190,13 +192,15 @@ class TransformerMergeSamplers(CustomNode):
         return (reduce(operator.or_, list(kwargs.values()) + [do_sample], {}),)
 
 
-class TransformersLoader(CustomNode):
+class TransformersImageProcessorLoader(CustomNode):
     @classmethod
     def INPUT_TYPES(cls) -> InputTypes:
         return {
             "required": {
                 "ckpt_name": (huggingface_repos(),),
-                "subfolder": ("STRING", {})
+                "subfolder": ("STRING", {}),
+                "model": ("MODEL", {}),
+                "overwrite_tokenizer": ("BOOLEAN", {"default": False}),
             }
         }
 
@@ -204,15 +208,66 @@
     RETURN_TYPES = "MODEL",
     FUNCTION = "execute"
 
-    def execute(self, ckpt_name: str, subfolder: Optional[str] = None):
+    def execute(self, ckpt_name: str, subfolder: Optional[str] = None, model: TransformersManagedModel = None, overwrite_tokenizer: bool = False):
+        hub_kwargs = {}
+        if subfolder is not None and subfolder != "":
+            hub_kwargs["subfolder"] = subfolder
+        processor = AutoProcessor.from_pretrained(ckpt_name, torch_dtype=unet_dtype(), device_map=get_torch_device_name(unet_offload_device()), low_cpu_mem_usage=True, trust_remote_code=True, **hub_kwargs)
+        return model.patch_processor(processor, overwrite_tokenizer),
+
+
+class TransformersLoader(CustomNode):
+    @classmethod
+    def INPUT_TYPES(cls) -> InputTypes:
+        return {
+            "required": {
+                "ckpt_name": (huggingface_repos(),),
+                "subfolder": ("STRING", {})
+            },
+        }
+
+    CATEGORY = "language"
+    RETURN_TYPES = "MODEL",
+    FUNCTION = "execute"
+
+    def execute(self, ckpt_name: str, subfolder: Optional[str] = None, *args, **kwargs):
         hub_kwargs = {}
         if subfolder is not None and subfolder != "":
             hub_kwargs["subfolder"] = subfolder
         with comfy_tqdm():
-            model = AutoModelForCausalLM.from_pretrained(ckpt_name, torch_dtype=unet_dtype(), device_map=get_torch_device_name(unet_offload_device()), low_cpu_mem_usage=True, trust_remote_code=True, **hub_kwargs)
-            tokenizer = AutoTokenizer.from_pretrained(ckpt_name)
+            from_pretrained_kwargs = {
+                "pretrained_model_name_or_path": ckpt_name,
+                "torch_dtype": unet_dtype(),
+                "device_map": get_torch_device_name(unet_offload_device()),
+                "low_cpu_mem_usage": True,
+                "trust_remote_code": True,
+                **hub_kwargs
+            }
+
+            try:
+                model = AutoModelForCausalLM.from_pretrained(**from_pretrained_kwargs)
+            except Exception:
+                model = LlavaNextForConditionalGeneration.from_pretrained(**from_pretrained_kwargs)
+
             config_dict, _ = PretrainedConfig.get_config_dict(ckpt_name, trust_remote_code=True, **hub_kwargs)
-        model_managed = TransformersManagedModel(ckpt_name, model, tokenizer, config_dict)
+            try:
+                try:
+                    processor = AutoProcessor.from_pretrained(**from_pretrained_kwargs)
+                except Exception:
+                    processor = LlavaNextProcessor.from_pretrained(**from_pretrained_kwargs)
+            except Exception:
+                processor = None
+            if not isinstance(processor, ProcessorMixin):
+                processor = None
+            tokenizer = getattr(processor, "tokenizer") if processor is not None and hasattr(processor, "tokenizer") else AutoTokenizer.from_pretrained(ckpt_name, **hub_kwargs)
+
+        model_managed = TransformersManagedModel(
+            repo_id=ckpt_name,
+            model=model,
+            tokenizer=tokenizer,
+            config_dict=config_dict,
+            processor=processor
+        )
         return model_managed,
 
 
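A sketch (not part of the diff) of what the reworked loader amounts to when called outside the graph, using the execute signature defined above; the repo id is one of those registered in KNOWN_HUGGINGFACE_MODEL_REPOS in this commit.

loader = TransformersLoader()
(managed_model,) = loader.execute("llava-hf/llava-v1.6-mistral-7b-hf")
# For a LLaVA-Next checkpoint the fallback branches load LlavaNextForConditionalGeneration
# and LlavaNextProcessor; the processor's tokenizer is reused when one is found.
print(type(managed_model.model), type(managed_model.processor), type(managed_model.tokenizer))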
@@ -223,6 +278,10 @@ class OneShotInstructTokenize(CustomNode):
             "required": {
                 "model": ("MODEL",),
                 "prompt": ("STRING", {"default": "", "multiline": True}),
+                "chat_template": ([_AUTO_CHAT_TEMPLATE] + list(KNOWN_CHAT_TEMPLATES.keys()), {})
+            },
+            "optional": {
+                "images": ("IMAGE", {}),
             }
         }
 
@@ -230,27 +289,17 @@
     RETURN_TYPES = (TOKENS_TYPE_NAME,)
     FUNCTION = "execute"
 
-    def execute(self, model: TransformersManagedModel, prompt: str) -> ValidatedNodeResult:
-        tokenizer: PreTrainedTokenizerBase | AutoTokenizer = model.tokenizer
-        assert tokenizer is not None
-        assert hasattr(tokenizer, "decode")
-
-        # try to retrieve a matching chat template
-        chat_template = tokenizer.chat_template if hasattr(tokenizer, "chat_template") else None
-        if chat_template is None:
-            candidate_chat_templates = [(name, template) for name, template in KNOWN_CHAT_TEMPLATES.items() if name in model.config_dict["_name_or_path"] or name in model.model.name_or_path]
-            if len(candidate_chat_templates) > 0:
-                filename, chat_template = candidate_chat_templates[0]
-                logging.debug(f"Selected chat template filename={filename} for {model.model.name_or_path}")
-        try:
-            # todo: this should come from node inputs
-            prompt = tokenizer.apply_chat_template([
-                {"role": "user", "content": prompt},
-            ], chat_template=chat_template, add_generation_prompt=True, tokenize=False)
-        except Exception as exc:
-            logging.error("Could not apply chat template", exc_info=exc)
-
-        return tokenizer(prompt, return_tensors="pt"),
+    def execute(self, model: TransformersManagedModel, prompt: str, images: List[torch.Tensor] | torch.Tensor = None, chat_template: str = _AUTO_CHAT_TEMPLATE) -> ValidatedNodeResult:
+        if chat_template == _AUTO_CHAT_TEMPLATE:
+            # use an exact match
+            model_name = os.path.basename(model.repo_id)
+            if model_name in KNOWN_CHAT_TEMPLATES:
+                chat_template = KNOWN_CHAT_TEMPLATES[model_name]
+            else:
+                chat_template = None
+        else:
+            chat_template = KNOWN_CHAT_TEMPLATES[chat_template]
+        return model.tokenize(prompt, images, chat_template),
 
 
 class TransformersGenerate(CustomNode):
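A sketch (not part of the diff) of tokenizing a one-shot multi-modal prompt with the node above. It assumes `managed_model` and `image_batch` already exist, e.g. from the loader sketch earlier and a ComfyUI IMAGE input; the prompt text is illustrative.

tokenize_node = OneShotInstructTokenize()
(tokens,) = tokenize_node.execute(
    model=managed_model,
    prompt="<image>\nWhat is shown in this picture?",
    images=image_batch,  # required when the model carries a processor
    chat_template=_AUTO_CHAT_TEMPLATE,  # "default": resolve a template from the repo basename
)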
@@ -266,7 +315,6 @@ class TransformersGenerate(CustomNode):
                 "use_cache": ("BOOLEAN", {"default": True}),
             },
             "optional": {
-                "images": ("IMAGE", {}),
                 "sampler": (_GENERATION_KWARGS_TYPE_NAME, {}),
             }
         }
@@ -281,24 +329,36 @@
                 max_new_tokens: int = 512,
                 repetition_penalty: float = 0.0,
                 seed: int = 0,
-                images: Optional[List[torch.Tensor] | torch.Tensor] = None,
                 sampler: Optional[_GENERATION_KWARGS_TYPE] = None,
                 *args,
                 **kwargs
                 ):
+        tokens = copy.copy(tokens)
         sampler = sampler or {}
         generate_kwargs = copy.copy(sampler)
-        # gracefully support LlaVA and others
-        if images is not None and not isinstance(images, torch.Tensor):
-            images = torch.stack(images, dim=0)
-        if images is not None:
-            generate_kwargs["images"] = images
-            # assuming it's of the form (batch, features..., height, width)
-            generate_kwargs["images_sizes"] = [(images.shape[-2], images.shape[-1]) for _ in range(images.shape[0])]
         load_model_gpu(model)
-        tokenizer: PreTrainedTokenizerBase | AutoTokenizer = model.tokenizer
-        inputs = tokens.to(model.current_device)
         transformers_model: PreTrainedModel = model.model
+        tokenizer: PreTrainedTokenizerBase | AutoTokenizer = model.tokenizer
+        # remove unused inputs
+        # maximizes compatibility with different models
+        generate_signature = inspect.signature(transformers_model.generate).parameters
+        prepare_signature = inspect.signature(transformers_model.prepare_inputs_for_generation).parameters
+        to_delete = set(reduce(operator.sub, map(lambda x: x.keys(), [tokens, generate_signature, prepare_signature])))
+        gen_sig_keys = generate_signature.keys()
+        if "input_ids" in tokens and "inputs" in tokens:
+            if "input_ids" in gen_sig_keys:
+                to_delete.add("inputs")
+            elif "inputs" in gen_sig_keys:
+                to_delete.add("input_ids")
+        for unused_kwarg in to_delete:
+            tokens.pop(unused_kwarg)
+            logging.info(f"{transformers_model.name_or_path}.generate does not accept {unused_kwarg}, removing")
+
+        # images should be moved to model
+        for key in ("images", "pixel_values"):
+            if key in tokens:
+                tokens[key] = tokens[key].to(device=model.current_device, dtype=model.model_dtype())
+        inputs = tokens
         progress_logits_processor = _ProgressLogitsProcessor(model)
         progress_bar: ProgressBar
         with comfy_progress(total=max_new_tokens) as progress_bar:
@@ -320,9 +380,8 @@
             text_streamer = _ProgressTextStreamer(on_finalized_text, tokenizer, True)
 
             with seed_for_block(seed):
-                # load the model as close to the actual generation as possible
                 output_ids = transformers_model.generate(
-                    inputs.input_ids,
+                    **inputs,
                     logits_processor=LogitsProcessorList([progress_logits_processor]),
                     streamer=text_streamer,
                     max_new_tokens=max_new_tokens,
@@ -333,12 +392,22 @@
         if transformers_model.config.is_encoder_decoder:
             start_position = 1
         else:
-            start_position = inputs.input_ids.shape[1]
+            start_position = inputs["input_ids" if "input_ids" in inputs else "inputs"].shape[1]
         output_ids = output_ids[:, start_position:]
 
         # todo: is this redundant considering I'm decoding in the on_finalized_text block?
         outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
-        return outputs,
+        # gpu-loaded stuff like images can now be unloaded
+        if hasattr(tokens, "to"):
+            del tokens
+        else:
+            for to_delete in tokens.values():
+                del to_delete
+            del tokens
+
+        # todo: better support batches
+        return outputs[0],
 
 
 class PreviewString(CustomNode):
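An end-to-end sketch (not part of the diff) of the text-only path through these nodes. It assumes the generate node's leading parameters are named model and tokens, which precede the hunk shown above and are not visible here; the repo id and prompt are illustrative.

(managed,) = TransformersLoader().execute("microsoft/Phi-3-mini-4k-instruct")
(tokens,) = OneShotInstructTokenize().execute(model=managed, prompt="Write a haiku about GPUs.", chat_template=_AUTO_CHAT_TEMPLATE)
# Assumption: the generate node accepts model= and tokens= keyword arguments.
(text,) = TransformersGenerate().execute(model=managed, tokens=tokens, max_new_tokens=64, seed=42)
print(text)  # the node now returns a single decoded string rather than a list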
@@ -369,10 +438,9 @@ for cls in (
         TransformerBeamSearchSampler,
         TransformerMergeSamplers,
         TransformersLoader,
+        TransformersImageProcessorLoader,
         TransformersGenerate,
         OneShotInstructTokenize,
         PreviewString,
 ):
     NODE_CLASS_MAPPINGS[cls.__name__] = cls
-
-_update_known_chat_templates()
@@ -284,10 +284,11 @@ class DevNullUris(CustomNode):
 class StringJoin(CustomNode):
     @classmethod
     def INPUT_TYPES(cls) -> InputTypes:
-        required = {f"value{i}": ("STRING", {"default": "", "multiline": True, "forceInput": True}) for i in range(5)}
-        required["separator"] = ("STRING", {"default": "_"})
+        optional = {f"value{i}": ("STRING", {"default": "", "multiline": True, "forceInput": True}) for i in range(5)}
+        optional["separator"] = ("STRING", {"default": "_"})
         return {
-            "required": required
+            "required": {},
+            "optional": optional
         }
 
     RETURN_TYPES = ("STRING",)
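A small check (not part of the diff) of the new socket layout, assuming the class is importable: after this change every input, including the separator, is optional, so unconnected values simply keep their defaults.

spec = StringJoin.INPUT_TYPES()
assert spec["required"] == {}
assert set(spec["optional"]) == {"value0", "value1", "value2", "value3", "value4", "separator"}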