from __future__ import annotations import contextlib import copy import inspect import logging import operator import pathlib import weakref from functools import reduce from typing import Optional, Any, Callable import torch import transformers from transformers import PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, AutoProcessor, AutoTokenizer, \ BatchFeature, AutoModelForSeq2SeqLM, AutoModelForCausalLM, AutoModel, \ PretrainedConfig, TextStreamer, LogitsProcessor from transformers.models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES, \ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, AutoModelForImageTextToText from .chat_templates import KNOWN_CHAT_TEMPLATES from .language_types import ProcessorResult, TOKENS_TYPE, GENERATION_KWARGS_TYPE, TransformerStreamedProgress, \ LanguageModel, LanguagePrompt from .. import model_management from ..cli_args import args from ..component_model.tensor_types import RGBImageBatch from ..model_downloader import get_or_download_huggingface_repo from ..model_management import unet_offload_device, get_torch_device, unet_dtype, load_models_gpu from ..model_management_types import ModelManageableStub from ..utils import comfy_tqdm, ProgressBar, comfy_progress, seed_for_block logger = logging.getLogger(__name__) # tweaks to support florence 2 _OVERRIDDEN_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = list(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.keys()) + ['florence2'] # should be added if the expectation is that this model emits special tokens _DO_NOT_SKIP_SPECIAL_TOKENS = {'florence2', 'paligemma'} class TransformersManagedModel(ModelManageableStub, LanguageModel): def __init__( self, repo_id: str, model: PreTrainedModel, tokenizer: Optional[PreTrainedTokenizerBase] = None, config_dict: Optional[dict] = None, processor: Optional[ProcessorMixin | AutoProcessor] = None ): self._repo_id = repo_id self._model = model self._tokenizer = tokenizer self._processor = processor self._object_patches: dict[str, Any] = {} self._parameter_count = sum(param.nelement() for param in self.model.state_dict().values()) self._size = sum(param.nelement() * param.element_size() for param in self.model.state_dict().values()) self.load_device = get_torch_device() self.offload_device = unet_offload_device() self._config_dict = config_dict self._on_set_processor(self._processor) self._model_type = "" self._original_transformers_managed_model: weakref.ReferenceType["TransformersManagedModel"] = weakref.ref(self) self.wrappers = {} self.callbacks = {} self._hook_mode = None self._model_options = {"transformer_options": {}} if model.device != self.offload_device: model.to(device=self.offload_device) @property def hook_mode(self): from ..hooks import EnumHookMode if self._hook_mode is None: self._hook_mode = EnumHookMode.MaxSpeed return self._hook_mode @hook_mode.setter def hook_mode(self, value): self._hook_mode = value def prepare_hook_patches_current_keyframe(self, t, hook_group, model_options): return def model_patches_models(self): return [] def restore_hook_patches(self): return def cleanup(self): pass def pre_run(self): pass def prepare_state(self, *args, **kwargs): pass def register_all_hook_patches(self, a, b, c, d): pass def get_nested_additional_models(self): return [] def apply_hooks(self, *args, **kwargs): return {} def add_wrapper(self, wrapper_type: str, wrapper: Callable): self.add_wrapper_with_key(wrapper_type, None, wrapper) def add_wrapper_with_key(self, wrapper_type: str, key: str, wrapper: Callable): w = self.wrappers.setdefault(wrapper_type, {}).setdefault(key, []) w.append(wrapper) def remove_wrappers_with_key(self, wrapper_type: str, key: str): w = self.wrappers.get(wrapper_type, {}) if key in w: w.pop(key) def get_wrappers_with_key(self, wrapper_type: str, key: str): w_list = [] w_list.extend(self.wrappers.get(wrapper_type, {}).get(key, [])) return w_list def get_all_wrappers(self, wrapper_type: str): w_list = [] for w in self.wrappers.get(wrapper_type, {}).values(): w_list.extend(w) return w_list @property def model_options(self): return self._model_options @model_options.setter def model_options(self, value): self._model_options = value @property def diffusion_model(self): return self.model @diffusion_model.setter def diffusion_model(self, value): self.add_object_patch("model", value) @staticmethod def from_pretrained(ckpt_name: str, subfolder: Optional[str] = None, config_dict: PretrainedConfig | dict | None = None, **kwargs) -> "TransformersManagedModel": hub_kwargs = {} if subfolder is not None and subfolder.strip() != "": hub_kwargs["subfolder"] = subfolder repo_id = ckpt_name with comfy_tqdm(): ckpt_name = get_or_download_huggingface_repo(repo_id) if config_dict is None: config_dict, _ = PretrainedConfig.get_config_dict(ckpt_name, **hub_kwargs) elif isinstance(config_dict, PretrainedConfig): config_dict: dict = config_dict.to_dict() else: config_dict = {} try: model_type = config_dict["model_type"] except KeyError: logger.debug(f"Configuration was missing for repo_id={repo_id}") model_type = "" from_pretrained_kwargs = { "pretrained_model_name_or_path": ckpt_name, **hub_kwargs, **kwargs, } # language models prefer to use bfloat16 over float16 default_kwargs = { "dtype": unet_dtype(supported_dtypes=(torch.bfloat16, torch.float16, torch.float32)), "low_cpu_mem_usage": True, "device_map": str(unet_offload_device()), # transformers usually has a better upstream implementation than whatever is put into the author's repos "trust_remote_code": False, } default_kwargs_trust_remote = { **default_kwargs, "trust_remote_code": True } kwargses_to_try = (default_kwargs, default_kwargs_trust_remote, {}) # if we have flash-attn installed, try to use it try: if model_management.flash_attn_enabled(): attn_override_kwargs = { "attn_implementation": "flash_attention_2", **kwargses_to_try[0] } kwargses_to_try = (attn_override_kwargs, *kwargses_to_try) logger.debug(f"while loading model {ckpt_name}, flash_attn was installed, so the flash_attention_2 implementation will be tried") except ImportError: pass for i, kwargs_to_try in enumerate(kwargses_to_try): try: if model_type in MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES: model = AutoModelForImageTextToText.from_pretrained(**from_pretrained_kwargs, **kwargs_to_try) elif model_type in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES: model = AutoModelForSeq2SeqLM.from_pretrained(**from_pretrained_kwargs, **kwargs_to_try) elif model_type in _OVERRIDDEN_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: model = AutoModelForCausalLM.from_pretrained(**from_pretrained_kwargs, **kwargs_to_try) else: model = AutoModel.from_pretrained(**from_pretrained_kwargs, **kwargs_to_try) if model is not None: break except Exception as exc_info: if i == len(kwargses_to_try) - 1: raise exc_info else: logger.warning(f"tried to import transformers model {ckpt_name} but got exception when trying additional import args {kwargs_to_try}", exc_info=exc_info) finally: torch.set_default_dtype(torch.float32) for i, kwargs_to_try in enumerate(kwargses_to_try): try: try: processor = AutoProcessor.from_pretrained(**from_pretrained_kwargs, **kwargs_to_try) except: processor = None if isinstance(processor, PreTrainedTokenizerBase): tokenizer = processor processor = None else: try: tokenizer = getattr(processor, "tokenizer") if processor is not None and hasattr(processor, "tokenizer") else AutoTokenizer.from_pretrained(ckpt_name, **hub_kwargs, **kwargs_to_try) except Exception: try: tokenizer = AutoTokenizer.from_pretrained(ckpt_name, use_fast=True, legacy=False, **hub_kwargs, **kwargs_to_try) except Exception: if repo_id != ckpt_name: tokenizer = AutoTokenizer.from_pretrained(repo_id, use_fast=True, legacy=False, **hub_kwargs, **kwargs_to_try) else: raise if tokenizer is not None or processor is not None: break except Exception as exc_info: if i == len(kwargses_to_try) - 1: raise exc_info finally: torch.set_default_dtype(torch.float32) if model_management.xformers_enabled() and hasattr(model, "enable_xformers_memory_efficient_attention"): model.enable_xformers_memory_efficient_attention() logger.debug("enabled xformers memory efficient attention") model_managed = TransformersManagedModel( repo_id=repo_id, model=model, tokenizer=tokenizer, config_dict=config_dict, processor=processor ) model_managed._model_type = model_type return model_managed def generate(self, tokens: TOKENS_TYPE = None, max_new_tokens: int = 512, seed: int = 0, sampler: Optional[GENERATION_KWARGS_TYPE] = None, *args, **kwargs) -> str: tokens = copy.copy(tokens) tokens_original = copy.copy(tokens) sampler = sampler or {} generate_kwargs = copy.copy(sampler) load_models_gpu([self]) transformers_model: PreTrainedModel = self.model tokenizer: PreTrainedTokenizerBase | AutoTokenizer = self.tokenizer # remove unused inputs # maximizes compatibility with different models generate_signature = inspect.signature(transformers_model.generate).parameters prepare_signature = inspect.signature(transformers_model.prepare_inputs_for_generation).parameters if hasattr(transformers_model, "forward"): forward_signature = inspect.signature(transformers_model.forward).parameters else: forward_signature = {} to_delete = set(reduce(operator.sub, map(lambda x: x.keys(), [tokens, generate_signature, prepare_signature, forward_signature]))) gen_sig_keys = generate_signature.keys() if "tgt_lang" in tokens: to_delete.add("tgt_lang") to_delete.add("src_lang") to_delete.discard("input_ids") if "forced_bos_token_id" in tokens: to_delete.discard("forced_bos_token_id") elif hasattr(tokenizer, "convert_tokens_to_ids"): generate_kwargs["forced_bos_token_id"] = tokenizer.convert_tokens_to_ids(tokens["tgt_lang"]) else: logger.warning(f"tokenizer {tokenizer} unexpected for translation task") if "input_ids" in tokens and "inputs" in tokens: if "input_ids" in gen_sig_keys: to_delete.add("inputs") elif "inputs" in gen_sig_keys: to_delete.add("input_ids") for unused_kwarg in to_delete: tokens.pop(unused_kwarg) logger.debug(f"{transformers_model.name_or_path}.generate does not accept {unused_kwarg}, removing") # images should be moved to model for key in ("images", "pixel_values"): if key in tokens: tokens[key] = tokens[key].to(device=self.current_device, dtype=self.model_dtype()) # sets up inputs inputs = tokens # used to determine if text streaming is supported num_beams = generate_kwargs.get("num_beams", transformers_model.generation_config.num_beams) progress_bar: ProgressBar with comfy_progress(total=max_new_tokens) as progress_bar: # todo: deal with batches correctly, don't assume batch size 1 token_count = 0 # progress def on_finalized_text(next_token: str, stop: bool): nonlocal token_count nonlocal progress_bar token_count += 1 preview = TransformerStreamedProgress(next_token=next_token) progress_bar.update_absolute(token_count, total=max_new_tokens, preview_image_or_output=preview) text_streamer = _ProgressTextStreamer(on_finalized_text, tokenizer, True) try: import triton # pylint: disable=import-error has_triton = True except (ImportError, ModuleNotFoundError): has_triton = False with seed_for_block(seed), torch.inference_mode(mode=True) if has_triton else contextlib.nullcontext(): if hasattr(inputs, "encodings") and inputs.encodings is not None and all(hasattr(encoding, "attention_mask") for encoding in inputs.encodings) and "attention_mask" in inputs: inputs.pop("attention_mask") from ..patcher_extension import WrapperExecutor, WrappersMP, get_all_wrappers def _generate(inputs, streamer, max_new_tokens, **generate_kwargs): return transformers_model.generate( **inputs, streamer=streamer, max_new_tokens=max_new_tokens, **generate_kwargs ) output_ids = WrapperExecutor.new_class_executor( _generate, self, get_all_wrappers(WrappersMP.APPLY_MODEL, self.model_options) ).execute(inputs, text_streamer if num_beams <= 1 else None, max_new_tokens, **generate_kwargs) if not transformers_model.config.is_encoder_decoder: start_position = inputs["input_ids" if "input_ids" in inputs else "inputs"].shape[1] output_ids = output_ids[:, start_position:] if hasattr(tokenizer, "src_lang") and "src_lang" in tokens_original: prev_src_lang = tokenizer.src_lang tokenizer.src_lang = tokens_original["src_lang"] else: prev_src_lang = None # todo: is this redundant consider I'm decoding in the on_finalized_text block? try: outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=self._model_type not in _DO_NOT_SKIP_SPECIAL_TOKENS, clean_up_tokenization_spaces=False) finally: if prev_src_lang is not None: tokenizer.src_lang = prev_src_lang # gpu-loaded stuff like images can now be unloaded if hasattr(tokens, "to"): del tokens else: for to_delete in tokens.values(): del to_delete del tokens # todo: better support batches return outputs[0] @property def tokenizer(self) -> PreTrainedTokenizerBase | AutoTokenizer: return self._tokenizer @property def processor(self) -> AutoProcessor | ProcessorMixin | None: return self._processor @property def config_dict(self) -> dict: """ The original configuration dictionary located in the Transformers model. Many models derive from base models and should inherit their settings like a chat template. This config_dict will have the base model's name in _name_or_path, enabling a lookup for the valid chat template when it is not specified by the derived model (it almost never is). :return: the dict value of the config.json in the HuggingFace model """ if self._config_dict is not None: return self._config_dict return self.model.config.to_dict() def lowvram_patch_counter(self): return 0 @property def current_device(self) -> torch.device: return self.model.device def is_clone(self, other: Any) -> bool: return hasattr(other, "model") and self.model is other.model def clone_has_same_weights(self, clone: Any) -> bool: if not isinstance(clone, TransformersManagedModel): return False clone: TransformersManagedModel if not self.is_clone(clone): return False try: return frozenset(self.model.active_adapters()) == frozenset(clone.model.active_adapters()) except ValueError as no_adapters: return True def model_size(self) -> int: return self._size def model_patches_to(self, arg: torch.device | torch.dtype): if getattr(self.model, "is_loaded_in_4bit", False) or getattr(self.model, "is_loaded_in_8bit", False): return if isinstance(arg, torch.device): self.model.to(device=arg) else: self.model.to(arg) def model_dtype(self) -> torch.dtype: return self.model.dtype def patch_model(self, device_to: torch.device | None = None, lowvram_model_memory: int = 0, load_weights: bool = True, force_patch_weights: bool = False) -> torch.nn.Module: return self.model.to(device=device_to) def unpatch_model(self, device_to: torch.device | None = None, unpatch_weights: Optional[bool] = False) -> torch.nn.Module: return self.model.to(device=device_to) def patch_processor(self, processor: Any, overwrite_tokenizer: bool = False) -> TransformersManagedModel: model = copy.copy(self) model._processor = processor if hasattr(processor, "tokenizer") and overwrite_tokenizer: model._tokenizer = processor.tokenizer self._on_set_processor(model._processor) return model def _on_set_processor(self, processor: Any): if processor is not None and hasattr(processor, "image_processor") and hasattr(processor.image_processor, "do_rescale"): processor.image_processor.do_rescale = False def tokenize(self, prompt: str | LanguagePrompt, images: RGBImageBatch | None, videos: list[torch.Tensor] | None = None, chat_template: str | None = None) -> ProcessorResult: tokenizer = self.processor if self.processor is not None else self.tokenizer assert tokenizer is not None assert hasattr(tokenizer, "decode") # try to retrieve a matching chat template chat_template = chat_template or tokenizer.chat_template if hasattr(tokenizer, "chat_template") else None if chat_template is None and self.config_dict is not None and "_name_or_path" in self.config_dict: candidate_chat_templates = [(name, template) for name, template in KNOWN_CHAT_TEMPLATES.items() if name in self.config_dict["_name_or_path"] or name in self.model.name_or_path] if len(candidate_chat_templates) > 0: filename, chat_template = candidate_chat_templates[0] logger.debug(f"Selected chat template filename={filename} for {self.model.name_or_path}") if isinstance(images, list): images = torch.stack(images, dim=0) if images is not None: image_sizes = [(image.shape[-2], image.shape[-3]) for image in images] else: image_sizes = [] # todo: what is the best choice for this? # probably select a size that related to the vision tower? images = torch.zeros((0, 0, 0, 3)) try: if hasattr(tokenizer, "apply_chat_template"): messages: LanguagePrompt if isinstance(prompt, list) and len(prompt) > 0 and isinstance(prompt[0], dict): messages = prompt elif images is not None and len(images) > 0 or videos is not None and len(videos) > 0: messages = [ {"role": "user", "content": [ { "type": "text", "text": prompt if isinstance(prompt, str) else "" } ] + [ {"type": "image"} for _ in range(len(images)) ] + [ {"type": "video"} for _ in range(len(videos)) ] } ] else: messages = [ {"role": "user", "content": prompt}, ] prompt = tokenizer.apply_chat_template(messages, chat_template=chat_template, add_generation_prompt=True, tokenize=False) except Exception as exc: logger.debug("Could not apply chat template", exc_info=exc) if isinstance(prompt, list): # Fallback: extract text from messages if chat template application failed or wasn't available extracted_text = [] for message in prompt: if isinstance(message, dict) and "content" in message: content = message["content"] if isinstance(content, str): extracted_text.append(content) elif isinstance(content, list): for item in content: if isinstance(item, dict) and item.get("type") == "text": extracted_text.append(item.get("text", "")) prompt = "\n".join(extracted_text) if self.processor is None and isinstance(prompt, str): batch_encoding = tokenizer(prompt, return_tensors="pt").to(device=self.load_device) return {**batch_encoding} else: if hasattr(self.processor, "to"): self.processor.to(device=self.load_device) # convert tuple to list from images.unbind() for paligemma workaround image_tensor_list = list(images.unbind()) if images is not None and len(images) > 0 else None # Convert videos to list of list of frames (uint8) if videos is not None and len(videos) > 0: new_videos = [] for v in videos: # Convert to uint8 0-255 if float if v.dtype == torch.float32 or v.dtype == torch.float16 or v.dtype == torch.bfloat16: v = (v * 255).to(torch.uint8) # Convert (T, H, W, C) tensor to list of (H, W, C) tensors if v.ndim == 4: new_videos.append(list(v)) else: new_videos.append([v]) # Fallback if not 4D videos = new_videos # Check if processor accepts 'videos' argument import inspect processor_params = inspect.signature(self.processor).parameters has_videos_arg = "videos" in processor_params kwargs = { "text": [prompt], "images": image_tensor_list, "return_tensors": "pt", "padding": True, } if videos is None or len(videos) == 0: pass elif has_videos_arg: kwargs["videos"] = videos if "input_data_format" in processor_params: kwargs["input_data_format"] = "channels_last" elif videos is not None and len(videos) > 0: if args.enable_video_to_image_fallback: # Fallback: flatten video frames into images if processor doesn't support 'videos' # videos is List[List[Frame]] where Frame is (H, W, C) flattened_frames = [] for video in videos: flattened_frames.extend(video) # Convert list of frames to list of tensors if needed, or just append to images list # images is currently a list of tensors if kwargs["images"] is None: kwargs["images"] = [] # Ensure frames are in the same format as images (tensors) # Frames in videos are already tensors (uint8) kwargs["images"].extend(flattened_frames) else: logger.warning(f"Model {self.model.name_or_path} does not support video inputs and video-to-image fallback is disabled. Use --enable-video-to-image-fallback to enable it.") try: batch_feature: BatchFeature = self.processor(**kwargs) except TypeError as exc_info: logger.warning(f"Exception while trying to run processor. Your transformers package is version {transformers.__version__} and may need to be updated") raise exc_info if hasattr(self.processor, "to"): self.processor.to(device=self.offload_device) assert "input_ids" in batch_feature try: batch_feature.to(device=self.load_device, dtype=self.model_dtype()) except TypeError: # works around Pixtral processor bug batch_feature.to(self.load_device) batch_feature.to(self.model_dtype()) # noinspection PyTypeChecker batch_feature_dict = { "inputs": batch_feature["input_ids"], **batch_feature } if "pixel_values" in batch_feature and "image_sizes" not in batch_feature_dict: batch_feature_dict["image_sizes"] = image_sizes if "pixel_values" in batch_feature and "images" not in batch_feature_dict: batch_feature_dict["images"] = batch_feature["pixel_values"] return batch_feature_dict @property def repo_id(self) -> str: return self._repo_id def __str__(self): if self.repo_id is not None: repo_id_as_path = pathlib.PurePath(self.repo_id) return f"" else: return f"" def clone(self) -> TransformersManagedModel: m = copy.copy(self) # deep copy a few objects m._object_patches = copy.copy(self._object_patches) return m def add_object_patch(self, name: str, obj: Any): # for the sake of compatibility, rewrite the name to the actual model field if name == "diffusion_model": name = "model" self._object_patches[name] = obj def get_model_object(self, name: str) -> torch.nn.Module: if name == "diffusion_model": name = "model" return super().get_model_object(name) @property def model(self) -> PreTrainedModel | torch.nn.Module: return self._object_patches.get("model", self._model) class _ProgressTextStreamer(TextStreamer): def __init__(self, on_finalized_text: Callable[[str, bool], None], tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs): super().__init__(tokenizer, skip_prompt, **decode_kwargs) self.on_finalized_text_handler = on_finalized_text def on_finalized_text(self, text: str, stream_end: bool = False): self.on_finalized_text_handler(text, stream_end) class _ProgressLogitsProcessor(LogitsProcessor): def __init__(self, model: TransformersManagedModel): self.eos_token_id = model.tokenizer.eos_token_id def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: probabilities = scores.softmax(dim=-1) self.eos_probability = probabilities[:, self.eos_token_id].item() return scores