mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 23:00:51 +08:00
- Experimental support for sage attention on Linux
- Diffusers loader now supports model indices
- Transformers model management now aligns with updates to ComfyUI
- Flux layers correctly use unbind
- Add float8 support for model loading in more places
- Experimental quantization approaches from Quanto and torchao
- Model upscaling interacts with memory management better

This update also disables ROCm testing because it isn't reliable enough on consumer hardware. ROCm is not really supported by the 7600.
422 lines
19 KiB
Python
from __future__ import annotations

import copy
import inspect
import logging
import operator
import pathlib
import warnings
from functools import reduce
from typing import Optional, Any, Callable

import torch
from transformers import PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, AutoProcessor, AutoTokenizer, \
    BatchFeature, AutoModelForVision2Seq, AutoModelForSeq2SeqLM, AutoModelForCausalLM, AutoModel, \
    PretrainedConfig, TextStreamer, LogitsProcessor
from transformers.models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES, \
    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES

from .chat_templates import KNOWN_CHAT_TEMPLATES
from .language_types import ProcessorResult, TOKENS_TYPE, GENERATION_KWARGS_TYPE, TransformerStreamedProgress, \
    LLaVAProcessor, LanguageModel, LanguagePrompt
from .. import model_management
from ..component_model.tensor_types import RGBImageBatch
from ..model_downloader import get_or_download_huggingface_repo
from ..model_management import unet_offload_device, get_torch_device, unet_dtype, load_models_gpu
from ..model_management_types import ModelManageable
from ..utils import comfy_tqdm, ProgressBar, comfy_progress, seed_for_block, tensor2pil


class TransformersManagedModel(ModelManageable, LanguageModel):
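    """
    Wraps a Hugging Face ``PreTrainedModel`` together with its tokenizer and optional processor so that
    ComfyUI's model management can load, offload and patch it like any other model, while exposing the
    ``LanguageModel`` interface (``tokenize`` and ``generate``) to language nodes.
    """
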
    def __init__(
            self,
            repo_id: str,
            model: PreTrainedModel,
            tokenizer: Optional[PreTrainedTokenizerBase] = None,
            config_dict: Optional[dict] = None,
            processor: Optional[ProcessorMixin | AutoProcessor] = None
    ):
        self._repo_id = repo_id
        self.model = model
        self._tokenizer = tokenizer
        self._processor = processor
        self._parameter_count = sum(param.nelement() for param in self.model.state_dict().values())
        self._size = sum(param.nelement() * param.element_size() for param in self.model.state_dict().values())
        self.load_device = get_torch_device()
        self.offload_device = unet_offload_device()
        self._config_dict = config_dict
        self._on_set_processor(self._processor)
        if model.device != self.offload_device:
            model.to(device=self.offload_device)

    @staticmethod
    def from_pretrained(ckpt_name: str, subfolder: Optional[str] = None) -> "TransformersManagedModel":
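        """
        Loads a repository as a managed transformers model.

        The repository is downloaded or resolved locally, its config.json is inspected to pick the
        matching ``Auto*`` class, and loading is attempted first with ComfyUI's preferred dtype, device
        map and attention implementation, falling back to the library defaults if that fails. A matching
        processor and/or tokenizer is loaded the same way.
        """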
        hub_kwargs = {}
        if subfolder is not None and subfolder != "":
            hub_kwargs["subfolder"] = subfolder
        repo_id = ckpt_name
        with comfy_tqdm():
            ckpt_name = get_or_download_huggingface_repo(ckpt_name)

        from_pretrained_kwargs = {
            "pretrained_model_name_or_path": ckpt_name,
            "trust_remote_code": True,
            **hub_kwargs
        }

        # check whether bitsandbytes is available; no quantization configuration is computed here
        try:
            import bitsandbytes
        except ImportError:
            pass

        config_dict, _ = PretrainedConfig.get_config_dict(ckpt_name, **hub_kwargs)
        model_type = config_dict["model_type"]
        # language models prefer to use bfloat16 over float16
        kwargs_to_try = ({"torch_dtype": unet_dtype(supported_dtypes=(torch.bfloat16, torch.float16, torch.float32)),
                          "low_cpu_mem_usage": True,
                          "device_map": str(unet_offload_device()), }, {})

        # if we have flash-attn installed, try to use it
        try:
            if model_management.flash_attn_enabled():
                attn_override_kwargs = {
                    "attn_implementation": "flash_attention_2",
                    **kwargs_to_try[0]
                }
                kwargs_to_try = (attn_override_kwargs, *kwargs_to_try)
                logging.debug(f"while loading model {ckpt_name}, flash_attn was installed, so the flash_attention_2 implementation will be tried")
        except ImportError:
            pass
        for i, props in enumerate(kwargs_to_try):
            try:
                if model_type in MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES:
                    model = AutoModelForVision2Seq.from_pretrained(**from_pretrained_kwargs, **props)
                elif model_type in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
                    model = AutoModelForSeq2SeqLM.from_pretrained(**from_pretrained_kwargs, **props)
                elif model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
                    model = AutoModelForCausalLM.from_pretrained(**from_pretrained_kwargs, **props)
                else:
                    model = AutoModel.from_pretrained(**from_pretrained_kwargs, **props)
                if model is not None:
                    break
            except Exception as exc_info:
                if i == len(kwargs_to_try) - 1:
                    raise exc_info
                else:
                    logging.warning(f"tried to load transformers model {ckpt_name} but got an exception with the additional load kwargs {props}", exc_info=exc_info)
            finally:
                torch.set_default_dtype(torch.float32)

        for i, props in enumerate(kwargs_to_try):
            try:
                try:
                    processor = AutoProcessor.from_pretrained(**from_pretrained_kwargs, **props)
                except Exception:
                    processor = None
                if isinstance(processor, PreTrainedTokenizerBase):
                    tokenizer = processor
                    processor = None
                else:
                    tokenizer = getattr(processor, "tokenizer") if processor is not None and hasattr(processor, "tokenizer") else AutoTokenizer.from_pretrained(ckpt_name, **hub_kwargs, **props)
                if tokenizer is not None or processor is not None:
                    break
            except Exception as exc_info:
                if i == len(kwargs_to_try) - 1:
                    raise exc_info
            finally:
                torch.set_default_dtype(torch.float32)

        if model_management.xformers_enabled() and hasattr(model, "enable_xformers_memory_efficient_attention"):
            model.enable_xformers_memory_efficient_attention()
            logging.debug("enabled xformers memory efficient attention")

        model_managed = TransformersManagedModel(
            repo_id=repo_id,
            model=model,
            tokenizer=tokenizer,
            config_dict=config_dict,
            processor=processor
        )

        return model_managed

    def generate(self, tokens: TOKENS_TYPE = None,
                 max_new_tokens: int = 512,
                 repetition_penalty: float = 0.0,
                 seed: int = 0,
                 sampler: Optional[GENERATION_KWARGS_TYPE] = None,
                 *args,
                 **kwargs) -> str:
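        """
        Runs the wrapped model's ``generate`` with ComfyUI progress reporting.

        The token dict produced by :meth:`tokenize` is pruned down to the keyword arguments that the
        model's ``generate`` and ``prepare_inputs_for_generation`` actually accept, image tensors are
        moved onto the load device, and tokens are streamed through a ``TextStreamer`` so the UI can
        display incremental progress. Returns the decoded text of the first sequence in the batch.
        """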
        tokens = copy.copy(tokens)
        tokens_original = copy.copy(tokens)
        sampler = sampler or {}
        generate_kwargs = copy.copy(sampler)
        load_models_gpu([self])
        transformers_model: PreTrainedModel = self.model
        tokenizer: PreTrainedTokenizerBase | AutoTokenizer = self.tokenizer
        # remove inputs that generate and prepare_inputs_for_generation do not accept;
        # this maximizes compatibility with different models
        generate_signature = inspect.signature(transformers_model.generate).parameters
        prepare_signature = inspect.signature(transformers_model.prepare_inputs_for_generation).parameters
        to_delete = set(reduce(operator.sub, map(lambda x: x.keys(), [tokens, generate_signature, prepare_signature])))
        gen_sig_keys = generate_signature.keys()
        if "tgt_lang" in tokens:
            to_delete.add("tgt_lang")
            to_delete.add("src_lang")
            to_delete.discard("input_ids")
            if "forced_bos_token_id" in tokens:
                to_delete.discard("forced_bos_token_id")
            elif hasattr(tokenizer, "convert_tokens_to_ids"):
                generate_kwargs["forced_bos_token_id"] = tokenizer.convert_tokens_to_ids(tokens["tgt_lang"])
            else:
                logging.warning(f"tokenizer {tokenizer} unexpected for translation task")
        if "input_ids" in tokens and "inputs" in tokens:
            if "input_ids" in gen_sig_keys:
                to_delete.add("inputs")
            elif "inputs" in gen_sig_keys:
                to_delete.add("input_ids")
        for unused_kwarg in to_delete:
            tokens.pop(unused_kwarg)
            logging.debug(f"{transformers_model.name_or_path}.generate does not accept {unused_kwarg}, removing")

        # image tensors should be moved to the model's device and dtype
        for key in ("images", "pixel_values"):
            if key in tokens:
                tokens[key] = tokens[key].to(device=self.current_device, dtype=self.model_dtype())

        # sets up inputs
        inputs = tokens

        # used to determine if text streaming is supported
        num_beams = generate_kwargs.get("num_beams", transformers_model.generation_config.num_beams)

        progress_bar: ProgressBar
        with comfy_progress(total=max_new_tokens) as progress_bar:
            # todo: deal with batches correctly, don't assume batch size 1
            token_count = 0

            # progress callback invoked by the streamer for each finalized chunk of text
            def on_finalized_text(next_token: str, stop: bool):
                nonlocal token_count
                nonlocal progress_bar

                token_count += 1
                preview = TransformerStreamedProgress(next_token=next_token)
                progress_bar.update_absolute(token_count, total=max_new_tokens, preview_image_or_output=preview)

            text_streamer = _ProgressTextStreamer(on_finalized_text, tokenizer, True)

            with seed_for_block(seed):
                if hasattr(inputs, "encodings") and inputs.encodings is not None and all(hasattr(encoding, "attention_mask") for encoding in inputs.encodings) and "attention_mask" in inputs:
                    inputs.pop("attention_mask")
                output_ids = transformers_model.generate(
                    **inputs,
                    streamer=text_streamer if num_beams <= 1 else None,
                    max_new_tokens=max_new_tokens,
                    repetition_penalty=repetition_penalty if repetition_penalty != 0 else None,
                    **generate_kwargs
                )

        if not transformers_model.config.is_encoder_decoder:
            # decoder-only models echo the prompt, so strip the prompt tokens from the output
            start_position = inputs["input_ids" if "input_ids" in inputs else "inputs"].shape[1]
            output_ids = output_ids[:, start_position:]

        if hasattr(tokenizer, "src_lang") and "src_lang" in tokens_original:
            prev_src_lang = tokenizer.src_lang
            tokenizer.src_lang = tokens_original["src_lang"]
        else:
            prev_src_lang = None
        # todo: is this redundant considering the streamer already decodes in on_finalized_text?
        try:
            outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        finally:
            if prev_src_lang is not None:
                tokenizer.src_lang = prev_src_lang
        # gpu-loaded tensors like images can now be unloaded
        if hasattr(tokens, "to"):
            del tokens
        else:
            for to_delete in tokens.values():
                del to_delete
            del tokens

        # todo: better support batches
        return outputs[0]

    @property
    def tokenizer(self) -> PreTrainedTokenizerBase | AutoTokenizer:
        return self._tokenizer

    @property
    def processor(self) -> AutoProcessor | ProcessorMixin | LLaVAProcessor | None:
        return self._processor

    @property
    def config_dict(self) -> dict:
        """
        The original configuration dictionary located in the Transformers model.

        Many models derive from base models and should inherit their settings like a chat template. This
        config_dict will have the base model's name in _name_or_path, enabling a lookup for the valid
        chat template when it is not specified by the derived model (it almost never is).
        :return: the dict value of the config.json in the HuggingFace model
        """
        if self._config_dict is not None:
            return self._config_dict

        return self.model.config.to_dict()

    def lowvram_patch_counter(self):
        return 0

    @property
    def current_device(self) -> torch.device:
        return self.model.device

    def is_clone(self, other: Any) -> bool:
        return hasattr(other, "model") and self.model is other.model

    def clone_has_same_weights(self, clone: Any) -> bool:
        if not isinstance(clone, TransformersManagedModel):
            return False

        clone: TransformersManagedModel

        if not self.is_clone(clone):
            return False

        try:
            return frozenset(self.model.active_adapters()) == frozenset(clone.model.active_adapters())
        except ValueError as no_adapters:
            return True

    def model_size(self) -> int:
        return self._size

    def model_patches_to(self, arg: torch.device | torch.dtype):
        if isinstance(arg, torch.device):
            self.model.to(device=arg)
        else:
            self.model.to(arg)

    def model_dtype(self) -> torch.dtype:
        return self.model.dtype

    def patch_model(self, device_to: torch.device | None = None, lowvram_model_memory: int = 0, load_weights: bool = True, force_patch_weights: bool = False) -> torch.nn.Module:
        return self.model.to(device=device_to)

    def unpatch_model(self, device_to: torch.device | None = None, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
        return self.model.to(device=device_to)

    def patch_processor(self, processor: Any, overwrite_tokenizer: bool = False) -> TransformersManagedModel:
        model = copy.copy(self)
        model._processor = processor
        if hasattr(processor, "tokenizer") and overwrite_tokenizer:
            model._tokenizer = processor.tokenizer
        self._on_set_processor(model._processor)
        return model

    def _on_set_processor(self, processor: Any):
        if processor is not None and hasattr(processor, "image_processor") and hasattr(processor.image_processor, "do_rescale"):
            processor.image_processor.do_rescale = False

    def tokenize(self, prompt: str | LanguagePrompt, images: RGBImageBatch | None, chat_template: str | None = None) -> ProcessorResult:
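        """
        Converts a prompt (and optional images) into the keyword arguments expected by :meth:`generate`.

        When the tokenizer or a known base model provides a chat template, the prompt is wrapped into a
        user message and the template is applied. If a processor is configured, it encodes both text and
        images; otherwise the tokenizer alone encodes the prompt.
        """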
        tokenizer = self.tokenizer
        assert tokenizer is not None
        assert hasattr(tokenizer, "decode")

        # try to retrieve a matching chat template
        chat_template = chat_template or (tokenizer.chat_template if hasattr(tokenizer, "chat_template") else None)
        if chat_template is None and self.config_dict is not None and "_name_or_path" in self.config_dict:
            candidate_chat_templates = [(name, template) for name, template in KNOWN_CHAT_TEMPLATES.items() if name in self.config_dict["_name_or_path"] or name in self.model.name_or_path]
            if len(candidate_chat_templates) > 0:
                filename, chat_template = candidate_chat_templates[0]
                logging.debug(f"Selected chat template filename={filename} for {self.model.name_or_path}")
        if isinstance(images, list):
            images = torch.stack(images, dim=0)
        if images is not None:
            # record the (width, height) of each [H, W, C] image tensor
            image_sizes = [(image.shape[-2], image.shape[-3]) for image in images]
        else:
            image_sizes = []
            images = []

        try:
            if hasattr(tokenizer, "apply_chat_template"):
                messages: LanguagePrompt
                if isinstance(prompt, list) and len(prompt) > 0 and isinstance(prompt[0], dict):
                    messages = prompt
                elif chat_template is not None and "content[" in chat_template:
                    messages = [
                        {
                            "role": "user",
                            "content": [
                                {"type": "text", "text": prompt},
                            ] + [
                                {"type": "image"} for _ in range(len(images))
                            ],
                        }
                    ]
                else:
                    messages = [
                        {"role": "user", "content": prompt},
                    ]
                prompt = tokenizer.apply_chat_template(messages, chat_template=chat_template, add_generation_prompt=True, tokenize=False)
        except Exception as exc:
            logging.debug("Could not apply chat template", exc_info=exc)

        if self.processor is None and isinstance(prompt, str):
            batch_encoding = tokenizer(prompt, return_tensors="pt").to(device=self.load_device)
            return {**batch_encoding}
        else:
            if hasattr(self.processor, "to"):
                self.processor.to(device=self.load_device)

            batch_feature: BatchFeature = self.processor(text=[prompt], images=images.unbind(), return_tensors="pt", padding=True)
            if hasattr(self.processor, "to"):
                self.processor.to(device=self.offload_device)
            assert "input_ids" in batch_feature
            batch_feature.to(device=self.load_device, dtype=self.model_dtype())
            # noinspection PyTypeChecker
            return {
                "image_sizes": image_sizes,
                "images": batch_feature["pixel_values"],
                "inputs": batch_feature["input_ids"],
                **batch_feature
            }

    @property
    def repo_id(self) -> str:
        return self._repo_id

    def __str__(self):
        if self.repo_id is not None:
            repo_id_as_path = pathlib.PurePath(self.repo_id)
            return f"<TransformersManagedModel for {'/'.join(repo_id_as_path.parts[-2:])} ({self.model.__class__.__name__})>"
        else:
            return f"<TransformersManagedModel for {self.model.__class__.__name__}>"


class _ProgressTextStreamer(TextStreamer):
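    """
    A ``TextStreamer`` that forwards each finalized chunk of decoded text to a callback instead of
    printing it, letting callers report token-by-token progress.
    """
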
    def __init__(self, on_finalized_text: Callable[[str, bool], None], tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
        super().__init__(tokenizer, skip_prompt, **decode_kwargs)
        self.on_finalized_text_handler = on_finalized_text

    def on_finalized_text(self, text: str, stream_end: bool = False):
        self.on_finalized_text_handler(text, stream_end)


class _ProgressLogitsProcessor(LogitsProcessor):
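    """
    Records the probability the model assigns to the EOS token at each step, which can serve as a rough
    estimate of how close generation is to finishing. The scores themselves are returned unmodified.
    """
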
    def __init__(self, model: TransformersManagedModel):
        self.eos_token_id = model.tokenizer.eos_token_id
        # updated on every __call__ with the latest EOS probability
        self.eos_probability = 0.0

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        probabilities = scores.softmax(dim=-1)
        self.eos_probability = probabilities[:, self.eos_token_id].item()
        return scores
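
# A minimal usage sketch (not executed on import): it assumes the surrounding ComfyUI model management
# is initialized and that the repo id resolves to a small causal language model. The repo id below is
# only an illustrative placeholder, not a recommendation.
#
#     model = TransformersManagedModel.from_pretrained("some-org/some-small-llm")
#     tokens = model.tokenize("Write a haiku about GPUs.", images=None)
#     text = model.generate(tokens=tokens, max_new_tokens=64, seed=42)
#     print(text)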