from __future__ import annotations

import copy
import logging
import pathlib
import warnings
from typing import Optional, Any, Callable, Union, List

import numpy as np
import torch
from PIL.Image import Image
from transformers import PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, AutoProcessor, AutoTokenizer, \
    TensorType, BatchFeature
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput, TruncationStrategy
from transformers.utils import PaddingStrategy

from .chat_templates import KNOWN_CHAT_TEMPLATES
from .language_types import ProcessorResult
from ..model_management import unet_offload_device, get_torch_device
from ..model_management_types import ModelManageable

LLaVAProcessor = Callable[
    [
        Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],  # text parameter
        Union[Image, np.ndarray, torch.Tensor, List[Image], List[np.ndarray], List[torch.Tensor]],  # images parameter
        Union[bool, str, PaddingStrategy],  # padding parameter
        Union[bool, str, TruncationStrategy],  # truncation parameter
        Optional[int],  # max_length parameter
        Optional[Union[str, TensorType]]  # return_tensors parameter
    ],
    BatchFeature
]
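
# Illustrative sketch of how a LLaVAProcessor-compatible callable is invoked; the variable
# names below are placeholders, and real HuggingFace processors are typically called with
# keyword arguments as in tokenize() further down:
#
#     features: BatchFeature = processor(
#         ["<image>\nDescribe the picture."],  # text
#         images=image_batch,                  # images
#         padding=True,                        # padding
#         truncation=False,                    # truncation
#         max_length=None,                     # max_length
#         return_tensors="pt",                 # return_tensors
#     )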


class TransformersManagedModel(ModelManageable):
    def __init__(
            self,
            repo_id: str,
            model: PreTrainedModel,
            tokenizer: Optional[PreTrainedTokenizerBase] = None,
            config_dict: Optional[dict] = None,
            processor: Optional[ProcessorMixin | AutoProcessor] = None
    ):
        self.repo_id = repo_id
        self.model = model
        self._tokenizer = tokenizer
        self._processor = processor
        self._parameter_count = sum(param.nelement() for param in self.model.state_dict().values())
        self._size = sum(param.nelement() * param.element_size() for param in self.model.state_dict().values())
        self.load_device = get_torch_device()
        self.offload_device = unet_offload_device()
        self._config_dict = config_dict
        self._on_set_processor(self._processor)
        if model.device != self.offload_device:
            model.to(device=self.offload_device)

    @property
    def tokenizer(self) -> PreTrainedTokenizerBase | AutoTokenizer:
        return self._tokenizer

    @property
    def processor(self) -> AutoProcessor | ProcessorMixin | LLaVAProcessor | None:
        return self._processor

    @property
    def config_dict(self) -> dict:
        """
        The original configuration dictionary of the Transformers model.

        Many models derive from base models and should inherit settings such as the chat template. This
        config_dict carries the base model's name in _name_or_path, enabling a lookup of the correct
        chat template when it is not specified by the derived model (it almost never is).

        :return: the dict value of the config.json in the HuggingFace model
        """
        if self._config_dict is not None:
            return self._config_dict

        return self.model.config.to_dict()

    def lowvram_patch_counter(self):
        return 0

    @property
    def current_device(self) -> torch.device:
        return self.model.device

    def is_clone(self, other: Any) -> bool:
        return hasattr(other, "model") and self.model is other.model
def clone_has_same_weights(self, clone: Any) -> bool:
|
|
if not isinstance(clone, TransformersManagedModel):
|
|
return False
|
|
|
|
clone: TransformersManagedModel
|
|
|
|
if not self.is_clone(clone):
|
|
return False
|
|
|
|
try:
|
|
return frozenset(self.model.active_adapters()) == frozenset(clone.model.active_adapters())
|
|
except ValueError as no_adapters:
|
|
return True
|
|
|
|

    def model_size(self) -> int:
        return self._size

    def model_patches_to(self, arg: torch.device | torch.dtype):
        if isinstance(arg, torch.device):
            self.model.to(device=arg)
        else:
            self.model.to(arg)

    def model_dtype(self) -> torch.dtype:
        return self.model.dtype

    def patch_model_lowvram(self, device_to: torch.device, lowvram_model_memory: int, force_patch_weights=False) -> torch.nn.Module:
        warnings.warn("Transformers models do not currently support adapters like LoRAs")
        return self.model.to(device=device_to)

    def patch_model(self, device_to: torch.device | None = None, patch_weights: bool = True) -> torch.nn.Module:
        return self.model.to(device=device_to)

    def unpatch_model(self, offload_device: torch.device | None = None, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
        return self.model.to(device=offload_device)

    def patch_processor(self, processor: Any, overwrite_tokenizer: bool = False) -> TransformersManagedModel:
        # return a shallow copy that shares the underlying weights but uses the new processor
        model = copy.copy(self)
        model._processor = processor
        if hasattr(processor, "tokenizer") and overwrite_tokenizer:
            model._tokenizer = processor.tokenizer
        self._on_set_processor(model._processor)
        return model

    def _on_set_processor(self, processor: Any):
        if processor is not None and hasattr(processor, "image_processor") and hasattr(processor.image_processor, "do_rescale"):
            # images are passed as float tensors in [0, 1], so the processor must not rescale them again
            processor.image_processor.do_rescale = False

    def tokenize(self, prompt: str, images: List[torch.Tensor] | torch.Tensor, chat_template: str | None = None) -> ProcessorResult:
        tokenizer = self.tokenizer
        assert tokenizer is not None
        assert hasattr(tokenizer, "decode")

        # try to retrieve a matching chat template
        chat_template = chat_template or (tokenizer.chat_template if hasattr(tokenizer, "chat_template") else None)
        if chat_template is None and self.config_dict is not None and "_name_or_path" in self.config_dict:
            candidate_chat_templates = [(name, template) for name, template in KNOWN_CHAT_TEMPLATES.items() if name in self.config_dict["_name_or_path"] or name in self.model.name_or_path]
            if len(candidate_chat_templates) > 0:
                filename, chat_template = candidate_chat_templates[0]
                logging.debug(f"Selected chat template filename={filename} for {self.model.name_or_path}")
        try:
            if hasattr(tokenizer, "apply_chat_template"):
                # todo: this should come from node inputs
                prompt = tokenizer.apply_chat_template([
                    {"role": "user", "content": prompt},
                ], chat_template=chat_template, add_generation_prompt=True, tokenize=False)
        except Exception as exc:
            logging.debug("Could not apply chat template", exc_info=exc)

        if self.processor is None:
            batch_encoding = tokenizer(prompt, return_tensors="pt").to(device=self.load_device)
            return {**batch_encoding}
        else:
            assert images is not None and len(images) > 0, "When using a multi-modal model, pass at least one, possibly empty, image"
            if hasattr(self.processor, "to"):
                self.processor.to(device=self.load_device)

            assert "<image>" in prompt.lower(), "You must specify an <image> token inside the prompt for it to be substituted correctly by a HuggingFace processor"
            batch_feature: BatchFeature = self.processor([prompt], images=images.unbind(), padding=True, return_tensors="pt")
            if hasattr(self.processor, "to"):
                self.processor.to(device=self.offload_device)
            assert "input_ids" in batch_feature
            batch_feature.to(device=self.load_device, dtype=self.model_dtype())
            # noinspection PyTypeChecker
            return {
                "image_sizes": [(image.shape[-1], image.shape[-2]) for image in images],
                "images": batch_feature["pixel_values"],
                "inputs": batch_feature["input_ids"],
                **batch_feature
            }
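
    # Illustrative note for tokenize() above (placeholder names): a multi-modal call requires an
    # "<image>" token in the prompt and a batched image tensor, e.g.
    #
    #     result = managed_model.tokenize("<image>\nWhat is in this picture?", images=image_batch)
    #     result["inputs"]  # token ids produced by the processor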

    def __str__(self):
        if self.repo_id is not None:
            repo_id_as_path = pathlib.PurePath(self.repo_id)
            return f"<TransformersManagedModel for {'/'.join(repo_id_as_path.parts[-2:])} ({self.model.__class__.__name__})>"
        else:
            return f"<TransformersManagedModel for {self.model.__class__.__name__}>"
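

# Minimal usage sketch for a text-only model. The repository id and the surrounding steps are
# illustrative assumptions; the language nodes in this package normally construct and schedule
# these objects through the model management machinery:
#
#     from transformers import AutoModelForCausalLM, AutoTokenizer
#
#     repo_id = "some-org/some-chat-llm"  # placeholder repository id
#     managed = TransformersManagedModel(
#         repo_id=repo_id,
#         model=AutoModelForCausalLM.from_pretrained(repo_id),
#         tokenizer=AutoTokenizer.from_pretrained(repo_id),
#     )
#     tokens = managed.tokenize("Describe a sunset over the ocean.", images=[])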