ComfyUI/comfy_extras/nodes/nodes_language.py
LLM support in ComfyUI

- Currently uses `transformers`
- Supports model management, correctly loading and unloading models based on
  what your machine can support
- Includes a Text Diffusers 2 workflow to demonstrate text rendering in SD1.5


from __future__ import annotations
from typing import Any, Dict
import torch
from fastchat.model import get_conversation_template
from transformers import AutoModelForCausalLM, AutoTokenizer
from comfy.language.language_types import ProcArgsRes
from comfy.language.transformers_model_management import TransformersManagedModel
from comfy.model_downloader import huggingface_repos
from comfy.model_management import get_torch_device_name, load_model_gpu, unet_dtype, unet_offload_device
from comfy.nodes.package_typing import CustomNode, InputTypes
from comfy.utils import comfy_tqdm, seed_for_block

_transformer_args_deterministic_decoding = {
    "max_length": ("INT", {"default": 4096, "min": 1}),
    "temperature": ("FLOAT", {"default": 0.7, "min": 0}),
    "repetition_penalty": ("FLOAT", {"default": 1.0, "min": 0}),
}


def proc_args(kwargs: Dict[str, Any]) -> ProcArgsRes:
    # Split the node's keyword arguments into the seed and the kwargs forwarded
    # to `generate`, keeping only the keys declared above.
    generate_kwargs = {k: v for k, v in kwargs.items() if k in _transformer_args_deterministic_decoding}
    # The seed is read from the raw kwargs; it is never part of the generate kwargs.
    seed = kwargs.get("seed", 0)
    return ProcArgsRes(seed, generate_kwargs)


class TransformersLoader(CustomNode):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "ckpt_name": (huggingface_repos(),)
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "execute"

    def execute(self, ckpt_name: str):
        with comfy_tqdm():
            model = AutoModelForCausalLM.from_pretrained(
                ckpt_name,
                torch_dtype=unet_dtype(),
                device_map=get_torch_device_name(unet_offload_device()),
                low_cpu_mem_usage=True,
                trust_remote_code=True,
            )
            tokenizer = AutoTokenizer.from_pretrained(ckpt_name)
        model_managed = TransformersManagedModel(ckpt_name, model, tokenizer)
        return model_managed,
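
# Usage sketch (an assumption, not part of the original node definitions): driven
# directly, outside of a graph, the loader returns a single TransformersManagedModel.
# The repo id below is a placeholder and would have to be one of the entries
# reported by huggingface_repos().
#
#   managed_model, = TransformersLoader().execute("some-org/some-model")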


class SimpleBatchDecode(CustomNode):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "model": ("MODEL",),
                "prompt": ("STRING", {"default": "", "multiline": True}),
                **_transformer_args_deterministic_decoding
            }
        }

    RETURN_TYPES = ("STRING",)
    FUNCTION = "execute"

    def execute(self, model: TransformersManagedModel, prompt: str, **kwargs):
        load_model_gpu(model)
        seed, generate_kwargs = proc_args(kwargs)
        tokenizer = model.tokenizer
        inputs = tokenizer(prompt, return_tensors="pt").to(model.current_device)
        with comfy_tqdm():
            with seed_for_block(seed):
                generate_ids = model.model.generate(inputs.input_ids, **generate_kwargs)
                outputs = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        return outputs,


class SimpleInstruct(CustomNode):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "model": ("MODEL",),
                "prompt": ("STRING", {"default": "", "multiline": True}),
                **_transformer_args_deterministic_decoding
            }
        }

    RETURN_TYPES = ("STRING",)
    FUNCTION = "execute"

    def execute(self, model: TransformersManagedModel, prompt: str, **kwargs):
        load_model_gpu(model)
        seed, generate_kwargs = proc_args(kwargs)
        conv = get_conversation_template(model.repo_id)
        conv.append_message(conv.roles[0], prompt)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()
        inputs = model.tokenizer([prompt], return_token_type_ids=False)
        inputs = {k: torch.tensor(v).to(model.current_device) for k, v in inputs.items()}
        with seed_for_block(seed):
            output_ids = model.model.generate(
                **inputs,
                do_sample=True,
                **generate_kwargs
            )
        if model.model.config.is_encoder_decoder:
            output_ids = output_ids[0]
        else:
            output_ids = output_ids[0][len(inputs["input_ids"][0]):]
        outputs = model.tokenizer.decode(
            output_ids, skip_special_tokens=True, spaces_between_special_tokens=False
        )
        return outputs,
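
# Illustrative note (an assumption, not in the original file): SimpleInstruct relies
# on fastchat to wrap the raw prompt in the model's chat markup before tokenization,
# roughly:
#
#   conv = get_conversation_template("some-org/some-chat-model")  # placeholder repo id
#   conv.append_message(conv.roles[0], "Write a haiku about GPUs.")
#   conv.append_message(conv.roles[1], None)
#   templated_prompt = conv.get_prompt()  # the string fed to the tokenizer above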


class PreviewString(CustomNode):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "value": ("STRING", {}),
            }
        }

    FUNCTION = "execute"
    RETURN_TYPES = ("STRING",)
    OUTPUT_NODE = True

    def execute(self, value: str):
        return {"ui": {"string": [value]}}


NODE_CLASS_MAPPINGS = {}
for cls in (
        TransformersLoader,
        SimpleBatchDecode,
        SimpleInstruct,
        PreviewString,
):
    NODE_CLASS_MAPPINGS[cls.__name__] = cls
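

if __name__ == "__main__":
    # Minimal smoke-test sketch (an assumption, not part of the original file):
    # load a causal LM and run one deterministic decode end to end. The repo id
    # is a placeholder; substitute any entry returned by huggingface_repos().
    managed_model, = TransformersLoader().execute("some-org/some-small-model")
    texts, = SimpleBatchDecode().execute(
        managed_model,
        prompt="Hello, world!",
        max_length=64,
        temperature=0.7,
        repetition_penalty=1.0,
    )
    # PreviewString only packages the value for the UI, so print it here as well.
    print(PreviewString().execute(texts[0])["ui"]["string"][0])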