ComfyUI/comfy/language/fastchat_adapters.py
doctorpangloss 8741cb3ce8 LLM support in ComfyUI
- Currently uses `transformers`
- Supports model management and correctly loading and unloading models based on what your machine can support
- Includes a Text Diffusers 2 workflow to demonstrate text rendering in SD1.5
2024-05-14 17:30:23 -07:00


from __future__ import annotations
from typing import Optional
from fastchat.conversation import Conversation, get_conv_template
from fastchat.model.model_adapter import BaseModelAdapter
from transformers import AutoModelForCausalLM, AutoTokenizer


class Phi3Adapter(BaseModelAdapter):
    """The model adapter for Microsoft/Phi-3-mini-128k-instruct"""

    def match(self, model_path: str):
        return "phi-3-mini-128k-instruct" in model_path.lower()

    def load_model(self, model_path: str, from_pretrained_kwargs: dict):
        self.model = model = AutoModelForCausalLM.from_pretrained(
            model_path,
            low_cpu_mem_usage=True,
            trust_remote_code=True,
            **from_pretrained_kwargs,
        )
        self.tokenizer = tokenizer = AutoTokenizer.from_pretrained(model_path)
        return model, tokenizer

    def generate_prompt(self, instruction: str, input: Optional[str] = None) -> str:
        if input:
            prompt = f"<|user|>\n{instruction}\n{input}<|end|>\n<|assistant|>"
        else:
            prompt = f"<|user|>\n{instruction}<|end|>\n<|assistant|>"
        return prompt

    def generate_response(self, messages, max_new_tokens=500, temperature=0.0, do_sample=False):
        # Assemble the prompt back to front: start with the latest user turn,
        # then prepend the earlier user/assistant turns in conversation order.
        prompt = self.generate_prompt(messages[-1]["content"])
        for i in range(len(messages) - 2, -1, -1):
            if messages[i]["role"] == "user":
                prompt = self.generate_prompt(messages[i]["content"]) + prompt
            elif messages[i]["role"] == "assistant":
                prompt = messages[i]["content"] + prompt
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.model.device)
        generation_kwargs = {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "do_sample": do_sample,
            "pad_token_id": self.tokenizer.eos_token_id,
        }
        output_ids = self.model.generate(
            input_ids,
            **generation_kwargs,
        )
        # `generate` returns the prompt tokens followed by the completion, so
        # decode only the newly generated tokens. (Decoding with
        # skip_special_tokens=True strips the <|user|>/<|end|>/<|assistant|>
        # markers, so string-replacing the raw prompt in the decoded text
        # would not reliably remove it.)
        output = self.tokenizer.decode(
            output_ids[0][input_ids.shape[1]:], skip_special_tokens=True
        )
        return output.strip()

    def get_default_conv_template(self, model_path: str) -> Conversation:
        return get_conv_template("phi-3-mini")
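
How the adapter actually gets picked up is not shown in this file. Below is a minimal sketch of how it might be registered with FastChat and exercised directly; the checkpoint path, `from_pretrained` kwargs, and `messages` are illustrative, and the `comfy.language.fastchat_adapters` import assumes the repository root is on `sys.path`. `register_model_adapter` is FastChat's standard hook for adding adapters to its loader.

```python
from fastchat.model.model_adapter import register_model_adapter

from comfy.language.fastchat_adapters import Phi3Adapter

# Let FastChat's loader route matching checkpoints to this adapter
# (match() keys on "phi-3-mini-128k-instruct" in the model path).
register_model_adapter(Phi3Adapter)

# The adapter can also be driven directly, outside FastChat's serving stack.
adapter = Phi3Adapter()
model, tokenizer = adapter.load_model(
    "microsoft/Phi-3-mini-128k-instruct",  # illustrative checkpoint path
    {"device_map": "auto"},                # forwarded to from_pretrained; needs accelerate
)

messages = [
    {"role": "user", "content": "Name a prime number."},
    {"role": "assistant", "content": "2"},
    {"role": "user", "content": "And the next one?"},
]
# generate_response() assembles the Phi-3 chat markup back to front, e.g.:
# <|user|>\nName a prime number.<|end|>\n<|assistant|>2<|user|>\nAnd the next one?<|end|>\n<|assistant|>
print(adapter.generate_response(messages, max_new_tokens=32))
```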