ComfyUI/comfy/language/transformers_model_management.py
doctorpangloss bbe2ed330c Memory management and compilation improvements
- Experimental support for sage attention on Linux
- Diffusers loader now supports model indices
- Transformers model management now aligns with updates to ComfyUI
- Flux layers correctly use unbind
- Add float8 support for model loading in more places
- Experimental quantization approaches from Quanto and torchao
- Model upscaling interacts with memory management better

This update also disables ROCm testing because it isn't reliable enough
on consumer hardware. ROCm is not really supported by the 7600.
2024-10-09 09:13:47 -07:00


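"""Manages Hugging Face Transformers models inside ComfyUI's model management system.

This wrapper loads a Transformers checkpoint (causal LM, seq2seq, or vision-to-seq),
keeps it on the offload device until it is needed, and exposes a tokenize/generate
interface compatible with ComfyUI's LanguageModel type.
"""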
from __future__ import annotations

import copy
import inspect
import logging
import operator
import pathlib
import warnings
from functools import reduce
from typing import Optional, Any, Callable

import torch
from transformers import PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, AutoProcessor, AutoTokenizer, \
    BatchFeature, AutoModelForVision2Seq, AutoModelForSeq2SeqLM, AutoModelForCausalLM, AutoModel, \
    PretrainedConfig, TextStreamer, LogitsProcessor
from transformers.models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES, \
    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES

from .chat_templates import KNOWN_CHAT_TEMPLATES
from .language_types import ProcessorResult, TOKENS_TYPE, GENERATION_KWARGS_TYPE, TransformerStreamedProgress, \
    LLaVAProcessor, LanguageModel, LanguagePrompt
from .. import model_management
from ..component_model.tensor_types import RGBImageBatch
from ..model_downloader import get_or_download_huggingface_repo
from ..model_management import unet_offload_device, get_torch_device, unet_dtype, load_models_gpu
from ..model_management_types import ModelManageable
from ..utils import comfy_tqdm, ProgressBar, comfy_progress, seed_for_block, tensor2pil


class TransformersManagedModel(ModelManageable, LanguageModel):
    def __init__(
            self,
            repo_id: str,
            model: PreTrainedModel,
            tokenizer: Optional[PreTrainedTokenizerBase] = None,
            config_dict: Optional[dict] = None,
            processor: Optional[ProcessorMixin | AutoProcessor] = None
    ):
        self._repo_id = repo_id
        self.model = model
        self._tokenizer = tokenizer
        self._processor = processor
        self._parameter_count = sum(param.nelement() for param in self.model.state_dict().values())
        self._size = sum(param.nelement() * param.element_size() for param in self.model.state_dict().values())
        self.load_device = get_torch_device()
        self.offload_device = unet_offload_device()
        self._config_dict = config_dict
        self._on_set_processor(self._processor)

        if model.device != self.offload_device:
            model.to(device=self.offload_device)

    @staticmethod
    def from_pretrained(ckpt_name: str, subfolder: Optional[str] = None) -> "TransformersManagedModel":
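        """Loads a Transformers checkpoint, plus its tokenizer and optional processor, as a managed model.

        ``ckpt_name`` may be a local path or a Hugging Face repo id; the repo is downloaded if needed.
        Loading kwargs are tried from most to least optimized (flash attention, bf16/low-mem, defaults).
        """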
        hub_kwargs = {}
        if subfolder is not None and subfolder != "":
            hub_kwargs["subfolder"] = subfolder
        repo_id = ckpt_name
        with comfy_tqdm():
            ckpt_name = get_or_download_huggingface_repo(ckpt_name)
        from_pretrained_kwargs = {
            "pretrained_model_name_or_path": ckpt_name,
            "trust_remote_code": True,
            **hub_kwargs
        }

        # compute bitsandbytes configuration
        try:
            import bitsandbytes
        except ImportError:
            pass

        config_dict, _ = PretrainedConfig.get_config_dict(ckpt_name, **hub_kwargs)
        model_type = config_dict["model_type"]

        # language models prefer to use bfloat16 over float16
        kwargs_to_try = ({"torch_dtype": unet_dtype(supported_dtypes=(torch.bfloat16, torch.float16, torch.float32)),
                          "low_cpu_mem_usage": True,
                          "device_map": str(unet_offload_device()), }, {})

        # if we have flash-attn installed, try to use it
        try:
            if model_management.flash_attn_enabled():
                attn_override_kwargs = {
                    "attn_implementation": "flash_attention_2",
                    **kwargs_to_try[0]
                }
                kwargs_to_try = (attn_override_kwargs, *kwargs_to_try)
                logging.debug(f"while loading model {ckpt_name}, flash_attn was installed, so the flash_attention_2 implementation will be tried")
        except ImportError:
            pass

        for i, props in enumerate(kwargs_to_try):
            try:
                if model_type in MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES:
                    model = AutoModelForVision2Seq.from_pretrained(**from_pretrained_kwargs, **props)
                elif model_type in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
                    model = AutoModelForSeq2SeqLM.from_pretrained(**from_pretrained_kwargs, **props)
                elif model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
                    model = AutoModelForCausalLM.from_pretrained(**from_pretrained_kwargs, **props)
                else:
                    model = AutoModel.from_pretrained(**from_pretrained_kwargs, **props)
                if model is not None:
                    break
            except Exception as exc_info:
                if i == len(kwargs_to_try) - 1:
                    raise exc_info
                else:
                    logging.warning(f"tried to import transformers model {ckpt_name} but got exception when trying additional import args {props}", exc_info=exc_info)
            finally:
                torch.set_default_dtype(torch.float32)

        for i, props in enumerate(kwargs_to_try):
            try:
                try:
                    processor = AutoProcessor.from_pretrained(**from_pretrained_kwargs, **props)
                except Exception:
                    processor = None
                if isinstance(processor, PreTrainedTokenizerBase):
                    tokenizer = processor
                    processor = None
                else:
                    tokenizer = getattr(processor, "tokenizer") if processor is not None and hasattr(processor, "tokenizer") else AutoTokenizer.from_pretrained(ckpt_name, **hub_kwargs, **props)
                if tokenizer is not None or processor is not None:
                    break
            except Exception as exc_info:
                if i == len(kwargs_to_try) - 1:
                    raise exc_info
            finally:
                torch.set_default_dtype(torch.float32)

        if model_management.xformers_enabled() and hasattr(model, "enable_xformers_memory_efficient_attention"):
            model.enable_xformers_memory_efficient_attention()
            logging.debug("enabled xformers memory efficient attention")

        model_managed = TransformersManagedModel(
            repo_id=repo_id,
            model=model,
            tokenizer=tokenizer,
            config_dict=config_dict,
            processor=processor
        )
        return model_managed

    def generate(self, tokens: TOKENS_TYPE = None,
                 max_new_tokens: int = 512,
                 repetition_penalty: float = 0.0,
                 seed: int = 0,
                 sampler: Optional[GENERATION_KWARGS_TYPE] = None,
                 *args,
                 **kwargs) -> str:
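        """Runs ``model.generate`` on already-tokenized inputs and decodes the result to a string.

        Keys in ``tokens`` that the underlying model's ``generate``/``prepare_inputs_for_generation``
        do not accept are dropped for compatibility, and per-token progress is streamed to the
        ComfyUI progress bar while sampling.
        """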
        tokens = copy.copy(tokens)
        tokens_original = copy.copy(tokens)
        sampler = sampler or {}
        generate_kwargs = copy.copy(sampler)
        load_models_gpu([self])
        transformers_model: PreTrainedModel = self.model
        tokenizer: PreTrainedTokenizerBase | AutoTokenizer = self.tokenizer

        # remove unused inputs to maximize compatibility with different models:
        # anything in tokens that neither generate() nor prepare_inputs_for_generation() accepts is dropped
        generate_signature = inspect.signature(transformers_model.generate).parameters
        prepare_signature = inspect.signature(transformers_model.prepare_inputs_for_generation).parameters
        to_delete = set(reduce(operator.sub, map(lambda x: x.keys(), [tokens, generate_signature, prepare_signature])))
        gen_sig_keys = generate_signature.keys()
        if "tgt_lang" in tokens:
            to_delete.add("tgt_lang")
            to_delete.add("src_lang")
            to_delete.discard("input_ids")
            if "forced_bos_token_id" in tokens:
                to_delete.discard("forced_bos_token_id")
            elif hasattr(tokenizer, "convert_tokens_to_ids"):
                generate_kwargs["forced_bos_token_id"] = tokenizer.convert_tokens_to_ids(tokens["tgt_lang"])
            else:
                logging.warning(f"tokenizer {tokenizer} unexpected for translation task")
        if "input_ids" in tokens and "inputs" in tokens:
            if "input_ids" in gen_sig_keys:
                to_delete.add("inputs")
            elif "inputs" in gen_sig_keys:
                to_delete.add("input_ids")
        for unused_kwarg in to_delete:
            tokens.pop(unused_kwarg)
            logging.debug(f"{transformers_model.name_or_path}.generate does not accept {unused_kwarg}, removing")

        # images should be moved to model
        for key in ("images", "pixel_values"):
            if key in tokens:
                tokens[key] = tokens[key].to(device=self.current_device, dtype=self.model_dtype())

        # sets up inputs
        inputs = tokens

        # used to determine if text streaming is supported
        num_beams = generate_kwargs.get("num_beams", transformers_model.generation_config.num_beams)

        progress_bar: ProgressBar
        with comfy_progress(total=max_new_tokens) as progress_bar:
            # todo: deal with batches correctly, don't assume batch size 1
            token_count = 0

            # progress
            def on_finalized_text(next_token: str, stop: bool):
                nonlocal token_count
                nonlocal progress_bar

                token_count += 1
                preview = TransformerStreamedProgress(next_token=next_token)
                progress_bar.update_absolute(token_count, total=max_new_tokens, preview_image_or_output=preview)

            text_streamer = _ProgressTextStreamer(on_finalized_text, tokenizer, True)

            with seed_for_block(seed):
                if hasattr(inputs, "encodings") and inputs.encodings is not None and all(hasattr(encoding, "attention_mask") for encoding in inputs.encodings) and "attention_mask" in inputs:
                    inputs.pop("attention_mask")
                output_ids = transformers_model.generate(
                    **inputs,
                    streamer=text_streamer if num_beams <= 1 else None,
                    max_new_tokens=max_new_tokens,
                    repetition_penalty=repetition_penalty if repetition_penalty != 0 else None,
                    **generate_kwargs
                )
                if not transformers_model.config.is_encoder_decoder:
                    start_position = inputs["input_ids" if "input_ids" in inputs else "inputs"].shape[1]
                    output_ids = output_ids[:, start_position:]

        if hasattr(tokenizer, "src_lang") and "src_lang" in tokens_original:
            prev_src_lang = tokenizer.src_lang
            tokenizer.src_lang = tokens_original["src_lang"]
        else:
            prev_src_lang = None
        # todo: is this redundant considering I'm decoding in the on_finalized_text block?
        try:
            outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        finally:
            if prev_src_lang is not None:
                tokenizer.src_lang = prev_src_lang

        # gpu-loaded stuff like images can now be unloaded
        if hasattr(tokens, "to"):
            del tokens
        else:
            for to_delete in tokens.values():
                del to_delete
            del tokens

        # todo: better support batches
        return outputs[0]

    @property
    def tokenizer(self) -> PreTrainedTokenizerBase | AutoTokenizer:
        return self._tokenizer

    @property
    def processor(self) -> AutoProcessor | ProcessorMixin | LLaVAProcessor | None:
        return self._processor

    @property
    def config_dict(self) -> dict:
        """
        The original configuration dictionary located in the Transformers model.

        Many models derive from base models and should inherit their settings like a chat template. This
        config_dict will have the base model's name in _name_or_path, enabling a lookup for the valid
        chat template when it is not specified by the derived model (it almost never is).
        :return: the dict value of the config.json in the HuggingFace model
        """
        if self._config_dict is not None:
            return self._config_dict
        return self.model.config.to_dict()
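
    # For example (hypothetical names): if this model's config.json carries
    # "_name_or_path": "meta-llama/Llama-2-7b-chat-hf" and KNOWN_CHAT_TEMPLATES has an
    # entry whose key occurs in that path, tokenize() below falls back to that template
    # when neither the caller nor the tokenizer supplies one.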

    def lowvram_patch_counter(self):
        return 0

    @property
    def current_device(self) -> torch.device:
        return self.model.device

    def is_clone(self, other: Any) -> bool:
        return hasattr(other, "model") and self.model is other.model

    def clone_has_same_weights(self, clone: Any) -> bool:
        if not isinstance(clone, TransformersManagedModel):
            return False
        clone: TransformersManagedModel
        if not self.is_clone(clone):
            return False
        try:
            return frozenset(self.model.active_adapters()) == frozenset(clone.model.active_adapters())
        except ValueError as no_adapters:
            return True

    def model_size(self) -> int:
        return self._size

    def model_patches_to(self, arg: torch.device | torch.dtype):
        if isinstance(arg, torch.device):
            self.model.to(device=arg)
        else:
            self.model.to(arg)

    def model_dtype(self) -> torch.dtype:
        return self.model.dtype

    def patch_model(self, device_to: torch.device | None = None, lowvram_model_memory: int = 0, load_weights: bool = True, force_patch_weights: bool = False) -> torch.nn.Module:
        return self.model.to(device=device_to)

    def unpatch_model(self, device_to: torch.device | None = None, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
        return self.model.to(device=device_to)

    def patch_processor(self, processor: Any, overwrite_tokenizer: bool = False) -> TransformersManagedModel:
        model = copy.copy(self)
        model._processor = processor
        if hasattr(processor, "tokenizer") and overwrite_tokenizer:
            model._tokenizer = processor.tokenizer
        self._on_set_processor(model._processor)
        return model

    def _on_set_processor(self, processor: Any):
        if processor is not None and hasattr(processor, "image_processor") and hasattr(processor.image_processor, "do_rescale"):
            processor.image_processor.do_rescale = False

    def tokenize(self, prompt: str | LanguagePrompt, images: RGBImageBatch | None, chat_template: str | None = None) -> ProcessorResult:
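        """Converts a prompt (and optional image batch) into the token/feature dict that generate() expects.

        A chat template is applied when one is supplied, found on the tokenizer, or matched from
        KNOWN_CHAT_TEMPLATES via the base model's name; otherwise the raw prompt is tokenized.
        """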
        tokenizer = self.tokenizer
        assert tokenizer is not None
        assert hasattr(tokenizer, "decode")

        # try to retrieve a matching chat template
        chat_template = chat_template or tokenizer.chat_template if hasattr(tokenizer, "chat_template") else None
        if chat_template is None and self.config_dict is not None and "_name_or_path" in self.config_dict:
            candidate_chat_templates = [(name, template) for name, template in KNOWN_CHAT_TEMPLATES.items() if name in self.config_dict["_name_or_path"] or name in self.model.name_or_path]
            if len(candidate_chat_templates) > 0:
                filename, chat_template = candidate_chat_templates[0]
                logging.debug(f"Selected chat template filename={filename} for {self.model.name_or_path}")

        if isinstance(images, list):
            images = torch.stack(images, dim=0)
        if images is not None:
            # PIL it for the sake of simplicity
            image_sizes = [(image.shape[-2], image.shape[-3]) for image in images]
        else:
            image_sizes = []
            images = []

        try:
            if hasattr(tokenizer, "apply_chat_template"):
                messages: LanguagePrompt
                if isinstance(prompt, list) and len(prompt) > 0 and isinstance(prompt[0], dict):
                    messages = prompt
                elif "content[" in chat_template:
                    messages = [
                        {"role": "user",
                         "content": [
                             {
                                 "type": "text",
                                 "text": prompt
                             }
                         ] + [
                             {"type": "image"} for _ in range(len(images))
                         ]
                         }
                    ]
                else:
                    messages = [
                        {"role": "user", "content": prompt},
                    ]
                prompt = tokenizer.apply_chat_template(messages, chat_template=chat_template, add_generation_prompt=True, tokenize=False)
        except Exception as exc:
            logging.debug("Could not apply chat template", exc_info=exc)

        if self.processor is None and isinstance(prompt, str):
            batch_encoding = tokenizer(prompt, return_tensors="pt").to(device=self.load_device)
            return {**batch_encoding}
        else:
            if hasattr(self.processor, "to"):
                self.processor.to(device=self.load_device)
            batch_feature: BatchFeature = self.processor(text=[prompt], images=images.unbind(), return_tensors="pt", padding=True)
            if hasattr(self.processor, "to"):
                self.processor.to(device=self.offload_device)
            assert "input_ids" in batch_feature
            batch_feature.to(device=self.load_device, dtype=self.model_dtype())
            # noinspection PyTypeChecker
            return {
                "image_sizes": image_sizes,
                "images": batch_feature["pixel_values"],
                "inputs": batch_feature["input_ids"],
                **batch_feature
            }

    @property
    def repo_id(self) -> str:
        return self._repo_id

    def __str__(self):
        if self.repo_id is not None:
            repo_id_as_path = pathlib.PurePath(self.repo_id)
            return f"<TransformersManagedModel for {'/'.join(repo_id_as_path.parts[-2:])} ({self.model.__class__.__name__})>"
        else:
            return f"<TransformersManagedModel for {self.model.__class__.__name__}>"


class _ProgressTextStreamer(TextStreamer):
    def __init__(self, on_finalized_text: Callable[[str, bool], None], tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
        super().__init__(tokenizer, skip_prompt, **decode_kwargs)
        self.on_finalized_text_handler = on_finalized_text

    def on_finalized_text(self, text: str, stream_end: bool = False):
        self.on_finalized_text_handler(text, stream_end)


class _ProgressLogitsProcessor(LogitsProcessor):
    def __init__(self, model: TransformersManagedModel):
        self.eos_token_id = model.tokenizer.eos_token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        probabilities = scores.softmax(dim=-1)
        self.eos_probability = probabilities[:, self.eos_token_id].item()
        return scores
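
# A minimal usage sketch (assumptions: a ComfyUI runtime with its model paths configured, and a
# repo id that actually resolves; "microsoft/Phi-3-mini-4k-instruct" here is only illustrative):
#
#     model = TransformersManagedModel.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
#     tokens = model.tokenize("Write a haiku about VRAM.", images=None)
#     text = model.generate(tokens, max_new_tokens=64, seed=42)
#
# from_pretrained() instantiates the checkpoint on the offload device, tokenize() applies a chat
# template when it can find one, and generate() loads the model onto the load device via
# load_models_gpu() before sampling.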