ComfyUI/comfy/language/transformers_model_management.py
doctorpangloss 8741cb3ce8 LLM support in ComfyUI
- Currently uses `transformers`
 - Supports model management and correctly loading and unloading models
   based on what your machine can support
 - Includes a Text Diffusers 2 workflow to demonstrate text rendering in
   SD1.5
2024-05-14 17:30:23 -07:00

71 lines
2.6 KiB
Python

from __future__ import annotations
import warnings
from typing import Optional, Any
import torch
from transformers import PreTrainedModel, PreTrainedTokenizerBase
from ..model_management import unet_offload_device, get_torch_device
from ..model_management_types import ModelManageable
class TransformersManagedModel(ModelManageable):
    """Wrap a Hugging Face ``transformers`` model as a ``ModelManageable``.

    The wrapper keeps the model (and, optionally, its tokenizer), records the
    model's parameter count and byte size once at construction, and moves the
    model between ``load_device`` and ``offload_device`` when asked to patch
    or unpatch it.
    """

    # Set in __init__ from get_torch_device().
    load_device: torch.device
    # Set in __init__ from unet_offload_device().
    offload_device: torch.device
    model: PreTrainedModel

    def __init__(self, repo_id: str, model: PreTrainedModel, tokenizer: Optional[PreTrainedTokenizerBase] = None):
        """Store the model and move it to the offload device.

        Args:
            repo_id: Identifier stored on the instance as ``repo_id``.
            model: The model to manage. It is moved to the offload device if
                it is not already there.
            tokenizer: Optional tokenizer stored alongside the model.
        """
        self.repo_id = repo_id
        self.model = model
        self.tokenizer = tokenizer

        # One pass over the state dict yields both the element count and the
        # byte size, instead of materialising the state dict twice.
        parameter_count = 0
        size_in_bytes = 0
        for tensor in model.state_dict().values():
            element_count = tensor.nelement()
            parameter_count += element_count
            size_in_bytes += element_count * tensor.element_size()
        self._parameter_count = parameter_count
        self._size = size_in_bytes

        self.load_device = get_torch_device()
        self.offload_device = unet_offload_device()
        if model.device != self.offload_device:
            model.to(device=self.offload_device)

    @property
    def current_device(self) -> torch.device:
        """Return the device reported by the wrapped model."""
        return self.model.device

    def is_clone(self, other: Any) -> bool:
        """Return True when ``other`` has a ``model`` attribute that is this very model object."""
        return hasattr(other, "model") and self.model is other.model

    def clone_has_same_weights(self, clone: Any) -> bool:
        """Return True when ``clone`` wraps the same model with the same active adapters.

        Anything that is not a ``TransformersManagedModel``, or that wraps a
        different model object, is reported as not having the same weights.
        """
        if not isinstance(clone, TransformersManagedModel):
            return False
        if not self.is_clone(clone):
            return False

        try:
            return frozenset(self.model.active_adapters()) == frozenset(clone.model.active_adapters())
        except ValueError:
            # transformers raises ValueError from active_adapters() when no
            # adapter has been loaded (or PEFT is unavailable). Both wrappers
            # hold the same model object at this point, so with no adapters in
            # play the weights are necessarily identical.
            return True

    def model_size(self) -> int:
        """Return the model size in bytes as measured at construction."""
        return self._size

    def model_patches_to(self, arg: torch.device | torch.dtype):
        """Move the model to a device, or convert it to a dtype.

        Args:
            arg: A ``torch.device`` is passed as ``device=``; any other value
                is passed positionally to ``model.to``.
        """
        if isinstance(arg, torch.device):
            self.model.to(device=arg)
        else:
            self.model.to(arg)

    def model_dtype(self) -> torch.dtype:
        """Return the dtype reported by the wrapped model."""
        return self.model.dtype

    def patch_model_lowvram(self, device_to: torch.device, lowvram_model_memory: int) -> torch.nn.Module:
        """Warn that adapters are unsupported, then move the model to ``device_to``.

        ``lowvram_model_memory`` is accepted for interface compatibility and
        is not used.
        """
        warnings.warn("Transformers models do not currently support adapters like LoRAs")
        return self.model.to(device=device_to)

    def patch_model(self, device_to: torch.device, patch_weights: bool) -> torch.nn.Module:
        """Warn that adapters are unsupported, then move the model to ``device_to``.

        ``patch_weights`` is accepted for interface compatibility and is not
        used.
        """
        warnings.warn("Transformers models do not currently support adapters like LoRAs")
        return self.model.to(device=device_to)

    def unpatch_model(self, offload_device: torch.device, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
        """Warn that adapters are unsupported, then move the model to ``offload_device``.

        ``unpatch_weights`` is accepted for interface compatibility and is not
        used.
        """
        warnings.warn("Transformers models do not currently support adapters like LoRAs")
        return self.model.to(device=offload_device)