mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-11 23:00:51 +08:00)

Improve language and compositing nodes

This commit is contained in:
parent 7e1201e777
commit a4fb34a0b8
@@ -206,6 +206,15 @@ def _create_parser() -> EnhancedConfigArgParser:
         help="When running ComfyUI as a distributed worker, this specifies the kind of executor that should be used to run the actual ComfyUI workflow worker. A ThreadPoolExecutor is the default. A ProcessPoolExecutor results in better memory management, since the process will be closed and large, contiguous blocks of CUDA memory can be freed."
     )
 
+    parser.add_argument(
+        "--openai-api-key",
+        required=False,
+        type=str,
+        help="Configures the OpenAI API Key for the OpenAI nodes",
+        env_var="OPENAI_API_KEY",
+        default=None
+    )
+
     # now give plugins a chance to add configuration
     for entry_point in entry_points().select(group='comfyui.custom_config'):
         try:
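For illustration, a minimal sketch of how the new key is expected to reach the process, assuming EnhancedConfigArgParser resolves env_var the way configargparse-style parsers do (command-line value preferred, environment variable as fallback). The snippet uses plain argparse as a stand-in and an obviously fake key:

import argparse
import os

# Stand-in for the env_var-aware parser: prefer the CLI flag, fall back to the environment.
parser = argparse.ArgumentParser()
parser.add_argument("--openai-api-key", required=False, type=str, default=None)
args = parser.parse_args(["--openai-api-key", "sk-example"])

openai_api_key = args.openai_api_key or os.environ.get("OPENAI_API_KEY")
print(openai_api_key)  # "sk-example" here; falls back to the env var when the flag is omitted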
|||||||
@@ -111,6 +111,7 @@ class Configuration(dict):
        force_hf_local_dir_mode (bool): Download repos from huggingface.co to the models/huggingface directory with the "local_dir" argument instead of models/huggingface_cache with the "cache_dir" argument, recreating the traditional file structure.
        executor_factory (str): Either ThreadPoolExecutor or ProcessPoolExecutor, defaulting to ThreadPoolExecutor
        preview_size (int): Sets the maximum preview size for sampler nodes. Defaults to 512.
+       openai_api_key (str): Configures the OpenAI API Key for the OpenAI nodes
     """
 
     def __init__(self, **kwargs):
@@ -198,6 +199,7 @@ class Configuration(dict):
             self[key] = value
 
         self.executor_factory: str = "ThreadPoolExecutor"
+        self.openai_api_key: Optional[str] = None
 
     def __getattr__(self, item):
         if item not in self:
|||||||
@@ -2,6 +2,7 @@ from jaxtyping import Float
 from torch import Tensor
 
 ImageBatch = Float[Tensor, "batch height width channels"]
+MaskBatch = Float[Tensor, "batch height width"]
 RGBImageBatch = Float[Tensor, "batch height width 3"]
 RGBAImageBatch = Float[Tensor, "batch height width 4"]
 RGBImage = Float[Tensor, "height width 3"]
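These aliases are jaxtyping shape annotations; the new MaskBatch documents that masks are (batch, height, width) while image batches carry a trailing channel dimension. A small illustrative sketch (not part of the commit) of a helper annotated with the new alias:

import torch
from comfy.component_model.tensor_types import RGBImageBatch, MaskBatch

def apply_mask(image: RGBImageBatch, mask: MaskBatch) -> RGBImageBatch:
    # Broadcast the (batch, height, width) mask over the channel dimension.
    return image * (1 - mask).unsqueeze(-1)

image = torch.rand(1, 64, 64, 3)
mask = torch.zeros(1, 64, 64)
print(apply_mask(image, mask).shape)  # torch.Size([1, 64, 64, 3])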
|||||||
@@ -24,13 +24,14 @@ class ProcessPoolExecutor(ProcessPool, Executor):
                  args: list = (),
                  kwargs: dict = {},
                  timeout: float = None) -> ProcessFuture:
-        try:
-            args: ExecutePromptArgs
-            prompt, prompt_id, client_id, span_context, progress_handler, configuration = args
-        except ValueError:
-            pass
-        super().schedule(function, args, kwargs, timeout)
+        # todo: restart worker when there is insufficient VRAM or the workflows are sufficiently different
+        # try:
+        #     args: ExecutePromptArgs
+        #     prompt, prompt_id, client_id, span_context, progress_handler, configuration = args
+        # except ValueError:
+        #     pass
+        return super().schedule(function, args, kwargs, timeout)
 
     def submit(self, fn, /, *args, **kwargs) -> concurrent.futures.Future:
         return self.schedule(fn, args=list(args), kwargs=kwargs, timeout=None)
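The functional change in this hunk is that schedule now returns the ProcessFuture, so submit, the concurrent.futures-style entry point shown above, hands callers a real future instead of None. A stripped-down sketch of the same pattern against pebble directly (assuming pebble is installed; run_workflow is a made-up stand-in for the workflow worker):

from pebble import ProcessPool

def run_workflow(prompt_id: str) -> str:
    # Stand-in for the actual ComfyUI workflow worker function.
    return f"finished {prompt_id}"

if __name__ == "__main__":
    with ProcessPool(max_workers=1) as pool:
        # schedule(...) returns a ProcessFuture; without returning it (the old behavior),
        # submit() would hand back None and callers could not wait on the result.
        future = pool.schedule(run_workflow, args=["prompt-1"], timeout=None)
        print(future.result())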
|||||||
@@ -1,9 +1,17 @@
 from __future__ import annotations
 
+from typing import Union, Callable, List, Optional, Protocol, runtime_checkable
+
+import numpy as np
 import torch
-from transformers import BatchEncoding
+from PIL.Image import Image
+from transformers import BatchEncoding, BatchFeature, TensorType
+from transformers.tokenization_utils_base import TextInput, PreTokenizedInput, TruncationStrategy
+from transformers.utils import PaddingStrategy
 from typing_extensions import TypedDict, NotRequired
 
+from comfy.component_model.tensor_types import RGBImageBatch
+
 
 class ProcessorResult(TypedDict):
     """
||||||
@@ -18,7 +26,61 @@ class ProcessorResult(TypedDict):
 
     attention_mask: NotRequired[torch.Tensor]
     pixel_values: NotRequired[torch.Tensor]
-    images: NotRequired[torch.Tensor]
-    inputs: NotRequired[BatchEncoding]
+    images: NotRequired[RGBImageBatch]
+    inputs: NotRequired[BatchEncoding | list[str]]
     image_sizes: NotRequired[torch.Tensor]
 
 
+class GenerationKwargs(TypedDict):
+    top_k: NotRequired[int]
+    top_p: NotRequired[float]
+    temperature: NotRequired[float]
+    penalty_alpha: NotRequired[float]
+    num_beams: NotRequired[int]
+    early_stopping: NotRequired[bool]
+
+
+GENERATION_KWARGS_TYPE = GenerationKwargs
+GENERATION_KWARGS_TYPE_NAME = "SAMPLER"
+TOKENS_TYPE = Union[ProcessorResult, BatchFeature]
+TOKENS_TYPE_NAME = "TOKENS"
+
+
+class TransformerStreamedProgress(TypedDict):
+    next_token: str
+
+
+LLaVAProcessor = Callable[
+    [
+        Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],  # text parameter
+        Union[Image, np.ndarray, torch.Tensor, List[Image], List[np.ndarray], List[torch.Tensor]],  # images parameter
+        Union[bool, str, PaddingStrategy],  # padding parameter
+        Union[bool, str, TruncationStrategy],  # truncation parameter
+        Optional[int],  # max_length parameter
+        Optional[Union[str, TensorType]]  # return_tensors parameter
+    ],
+    BatchFeature
+]
+
+
+@runtime_checkable
+class LanguageModel(Protocol):
+    @staticmethod
+    def from_pretrained(ckpt_name: str, subfolder: Optional[str] = None) -> "LanguageModel":
+        ...
+
+    def generate(self, tokens: TOKENS_TYPE = None,
+                 max_new_tokens: int = 512,
+                 repetition_penalty: float = 0.0,
+                 seed: int = 0,
+                 sampler: Optional[GENERATION_KWARGS_TYPE] = None,
+                 *args,
+                 **kwargs) -> str:
+        ...
+
+    def tokenize(self, prompt: str, images: List[torch.Tensor] | torch.Tensor, chat_template: str | None = None) -> ProcessorResult:
+        ...
+
+    @property
+    def repo_id(self) -> str:
+        return ""
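Because LanguageModel is declared runtime_checkable, a class satisfies it purely structurally; isinstance only verifies that the methods exist. A minimal sketch of a conforming implementation (EchoModel is hypothetical and only meant to show the required surface):

from typing import Optional

class EchoModel:
    @staticmethod
    def from_pretrained(ckpt_name: str, subfolder: Optional[str] = None) -> "EchoModel":
        return EchoModel()

    def generate(self, tokens=None, max_new_tokens=512, repetition_penalty=0.0,
                 seed=0, sampler=None, *args, **kwargs) -> str:
        # Echo the prompt back; a real implementation would run a model here.
        return "".join(tokens.get("inputs", [])) if tokens else ""

    def tokenize(self, prompt, images, chat_template=None):
        return {"inputs": [prompt], "images": images}

    @property
    def repo_id(self) -> str:
        return "example/echo"

model = EchoModel.from_pretrained("example/echo")
# isinstance(model, LanguageModel) is True here, because only method presence is checked.
print(model.generate(model.tokenize("hello", [])))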
|||||||
@@ -1,38 +1,32 @@
 from __future__ import annotations
 
 import copy
+import inspect
 import logging
+import operator
 import pathlib
 import warnings
-from typing import Optional, Any, Callable, Union, List
+from functools import reduce
+from typing import Optional, Any, Callable, List
 
-import numpy as np
 import torch
-from PIL.Image import Image
 from transformers import PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, AutoProcessor, AutoTokenizer, \
-    TensorType, BatchFeature
-from transformers.tokenization_utils_base import PreTokenizedInput, TextInput, TruncationStrategy
-from transformers.utils import PaddingStrategy
+    BatchFeature, AutoModelForVision2Seq, AutoModelForSeq2SeqLM, AutoModelForCausalLM, AutoModel, \
+    PretrainedConfig, TextStreamer, LogitsProcessor
+from transformers.models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES, \
+    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
 
 from .chat_templates import KNOWN_CHAT_TEMPLATES
-from .language_types import ProcessorResult
-from ..model_management import unet_offload_device, get_torch_device
+from .language_types import ProcessorResult, TOKENS_TYPE, GENERATION_KWARGS_TYPE, TransformerStreamedProgress, \
+    LLaVAProcessor, LanguageModel
+from .. import model_management
+from ..model_downloader import get_or_download_huggingface_repo
+from ..model_management import unet_offload_device, get_torch_device, unet_dtype, load_models_gpu
 from ..model_management_types import ModelManageable
+from ..utils import comfy_tqdm, ProgressBar, comfy_progress, seed_for_block
 
-LLaVAProcessor = Callable[
-    [
-        Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],  # text parameter
-        Union[Image, np.ndarray, torch.Tensor, List[Image], List[np.ndarray], List[torch.Tensor]],  # images parameter
-        Union[bool, str, PaddingStrategy],  # padding parameter
-        Union[bool, str, TruncationStrategy],  # truncation parameter
-        Optional[int],  # max_length parameter
-        Optional[Union[str, TensorType]]  # return_tensors parameter
-    ],
-    BatchFeature
-]
-
 
-class TransformersManagedModel(ModelManageable):
+class TransformersManagedModel(ModelManageable, LanguageModel):
     def __init__(
             self,
             repo_id: str,
||||||
@@ -41,7 +35,7 @@ class TransformersManagedModel(ModelManageable):
             config_dict: Optional[dict] = None,
             processor: Optional[ProcessorMixin | AutoProcessor] = None
     ):
-        self.repo_id = repo_id
+        self._repo_id = repo_id
         self.model = model
         self._tokenizer = tokenizer
         self._processor = processor
||||||
@@ -54,6 +48,200 @@ class TransformersManagedModel(ModelManageable):
         if model.device != self.offload_device:
             model.to(device=self.offload_device)
 
+    @staticmethod
+    def from_pretrained(ckpt_name: str, subfolder: Optional[str] = None) -> "TransformersManagedModel":
+        hub_kwargs = {}
+        if subfolder is not None and subfolder != "":
+            hub_kwargs["subfolder"] = subfolder
+        repo_id = ckpt_name
+        ckpt_name = get_or_download_huggingface_repo(ckpt_name)
+        with comfy_tqdm():
+            from_pretrained_kwargs = {
+                "pretrained_model_name_or_path": ckpt_name,
+                "trust_remote_code": True,
+                **hub_kwargs
+            }
+
+            # compute bitsandbytes configuration
+            try:
+                import bitsandbytes
+            except ImportError:
+                pass
+
+            config_dict, _ = PretrainedConfig.get_config_dict(ckpt_name, **hub_kwargs)
+            model_type = config_dict["model_type"]
+            # language models prefer to use bfloat16 over float16
+            kwargs_to_try = ({"torch_dtype": unet_dtype(supported_dtypes=(torch.bfloat16, torch.float16, torch.float32)),
+                              "low_cpu_mem_usage": True,
+                              "device_map": str(unet_offload_device()), }, {})
+
+            # if we have flash-attn installed, try to use it
+            try:
+                import flash_attn
+                attn_override_kwargs = {
+                    "attn_implementation": "flash_attention_2",
+                    **kwargs_to_try[0]
+                }
+                kwargs_to_try = (attn_override_kwargs, *kwargs_to_try)
+                logging.debug(f"while loading model {ckpt_name}, flash_attn was installed, so the flash_attention_2 implementation will be tried")
+            except ImportError:
+                pass
+            for i, props in enumerate(kwargs_to_try):
+                try:
+                    if model_type in MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES:
+                        model = AutoModelForVision2Seq.from_pretrained(**from_pretrained_kwargs, **props)
+                    elif model_type in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
+                        model = AutoModelForSeq2SeqLM.from_pretrained(**from_pretrained_kwargs, **props)
+                    elif model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
+                        model = AutoModelForCausalLM.from_pretrained(**from_pretrained_kwargs, **props)
+                    else:
+                        model = AutoModel.from_pretrained(**from_pretrained_kwargs, **props)
+                    if model is not None:
+                        break
+                except Exception as exc_info:
+                    if i == len(kwargs_to_try) - 1:
+                        raise exc_info
+                    else:
+                        logging.warning(f"tried to import transformers model {ckpt_name} but got exception when trying additional import args {props}", exc_info=exc_info)
+                finally:
+                    torch.set_default_dtype(torch.float32)
+
+            for i, props in enumerate(kwargs_to_try):
+                try:
+                    try:
+                        processor = AutoProcessor.from_pretrained(**from_pretrained_kwargs, **props)
+                    except:
+                        processor = None
+                    if isinstance(processor, PreTrainedTokenizerBase):
+                        tokenizer = processor
+                        processor = None
+                    else:
+                        tokenizer = getattr(processor, "tokenizer") if processor is not None and hasattr(processor, "tokenizer") else AutoTokenizer.from_pretrained(ckpt_name, **hub_kwargs, **props)
+                    if tokenizer is not None or processor is not None:
+                        break
+                except Exception as exc_info:
+                    if i == len(kwargs_to_try) - 1:
+                        raise exc_info
+                finally:
+                    torch.set_default_dtype(torch.float32)
+
+        if model_management.xformers_enabled() and hasattr(model, "enable_xformers_memory_efficient_attention"):
+            model.enable_xformers_memory_efficient_attention()
+            logging.debug("enabled xformers memory efficient attention")
+
+        model_managed = TransformersManagedModel(
+            repo_id=repo_id,
+            model=model,
+            tokenizer=tokenizer,
+            config_dict=config_dict,
+            processor=processor
+        )
+
+        return model_managed
+
+    def generate(self, tokens: TOKENS_TYPE = None,
+                 max_new_tokens: int = 512,
+                 repetition_penalty: float = 0.0,
+                 seed: int = 0,
+                 sampler: Optional[GENERATION_KWARGS_TYPE] = None,
+                 *args,
+                 **kwargs) -> str:
+        tokens = copy.copy(tokens)
+        tokens_original = copy.copy(tokens)
+        sampler = sampler or {}
+        generate_kwargs = copy.copy(sampler)
+        load_models_gpu([self])
+        transformers_model: PreTrainedModel = self.model
+        tokenizer: PreTrainedTokenizerBase | AutoTokenizer = self.tokenizer
+        # remove unused inputs
+        # maximizes compatibility with different models
+        generate_signature = inspect.signature(transformers_model.generate).parameters
+        prepare_signature = inspect.signature(transformers_model.prepare_inputs_for_generation).parameters
+        to_delete = set(reduce(operator.sub, map(lambda x: x.keys(), [tokens, generate_signature, prepare_signature])))
+        gen_sig_keys = generate_signature.keys()
+        if "tgt_lang" in tokens:
+            to_delete.add("tgt_lang")
+            to_delete.add("src_lang")
+            to_delete.discard("input_ids")
+            if "forced_bos_token_id" in tokens:
+                to_delete.discard("forced_bos_token_id")
+            elif hasattr(tokenizer, "convert_tokens_to_ids"):
+                generate_kwargs["forced_bos_token_id"] = tokenizer.convert_tokens_to_ids(tokens["tgt_lang"])
+            else:
+                logging.warning(f"tokenizer {tokenizer} unexpected for translation task")
+        if "input_ids" in tokens and "inputs" in tokens:
+            if "input_ids" in gen_sig_keys:
+                to_delete.add("inputs")
+            elif "inputs" in gen_sig_keys:
+                to_delete.add("input_ids")
+        for unused_kwarg in to_delete:
+            tokens.pop(unused_kwarg)
+            logging.debug(f"{transformers_model.name_or_path}.generate does not accept {unused_kwarg}, removing")
+
+        # images should be moved to model
+        for key in ("images", "pixel_values"):
+            if key in tokens:
+                tokens[key] = tokens[key].to(device=self.current_device, dtype=self.model_dtype())
+
+        # sets up inputs
+        inputs = tokens
+
+        # used to determine if text streaming is supported
+        num_beams = generate_kwargs.get("num_beams", transformers_model.generation_config.num_beams)
+
+        progress_bar: ProgressBar
+        with comfy_progress(total=max_new_tokens) as progress_bar:
+            # todo: deal with batches correctly, don't assume batch size 1
+            token_count = 0
+
+            # progress
+            def on_finalized_text(next_token: str, stop: bool):
+                nonlocal token_count
+                nonlocal progress_bar
+
+                token_count += 1
+                preview = TransformerStreamedProgress(next_token=next_token)
+                progress_bar.update_absolute(token_count, total=max_new_tokens, preview_image_or_output=preview)
+
+            text_streamer = _ProgressTextStreamer(on_finalized_text, tokenizer, True)
+
+            with seed_for_block(seed):
+                if hasattr(inputs, "encodings") and inputs.encodings is not None and all(hasattr(encoding, "attention_mask") for encoding in inputs.encodings) and "attention_mask" in inputs:
+                    inputs.pop("attention_mask")
+                output_ids = transformers_model.generate(
+                    **inputs,
+                    streamer=text_streamer if num_beams <= 1 else None,
+                    max_new_tokens=max_new_tokens,
+                    repetition_penalty=repetition_penalty if repetition_penalty != 0 else None,
+                    **generate_kwargs
+                )
+
+        if not transformers_model.config.is_encoder_decoder:
+            start_position = inputs["input_ids" if "input_ids" in inputs else "inputs"].shape[1]
+            output_ids = output_ids[:, start_position:]
+
+        if hasattr(tokenizer, "src_lang") and "src_lang" in tokens_original:
+            prev_src_lang = tokenizer.src_lang
+            tokenizer.src_lang = tokens_original["src_lang"]
+        else:
+            prev_src_lang = None
+        # todo: is this redundant consider I'm decoding in the on_finalized_text block?
+        try:
+            outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
+        finally:
+            if prev_src_lang is not None:
+                tokenizer.src_lang = prev_src_lang
+        # gpu-loaded stuff like images can now be unloaded
+        if hasattr(tokens, "to"):
+            del tokens
+        else:
+            for to_delete in tokens.values():
+                del to_delete
+            del tokens
+
+        # todo: better support batches
+        return outputs[0]
+
     @property
     def tokenizer(self) -> PreTrainedTokenizerBase | AutoTokenizer:
         return self._tokenizer
||||||
@@ -178,9 +366,32 @@ class TransformersManagedModel(ModelManageable):
             **batch_feature
         }
 
+    @property
+    def repo_id(self) -> str:
+        return self._repo_id
+
     def __str__(self):
         if self.repo_id is not None:
             repo_id_as_path = pathlib.PurePath(self.repo_id)
             return f"<TransformersManagedModel for {'/'.join(repo_id_as_path.parts[-2:])} ({self.model.__class__.__name__})>"
         else:
             return f"<TransformersManagedModel for {self.model.__class__.__name__}>"
+
+
+class _ProgressTextStreamer(TextStreamer):
+    def __init__(self, on_finalized_text: Callable[[str, bool], None], tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
+        super().__init__(tokenizer, skip_prompt, **decode_kwargs)
+        self.on_finalized_text_handler = on_finalized_text
+
+    def on_finalized_text(self, text: str, stream_end: bool = False):
+        self.on_finalized_text_handler(text, stream_end)
+
+
+class _ProgressLogitsProcessor(LogitsProcessor):
+    def __init__(self, model: TransformersManagedModel):
+        self.eos_token_id = model.tokenizer.eos_token_id
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        probabilities = scores.softmax(dim=-1)
+        self.eos_probability = probabilities[:, self.eos_token_id].item()
+        return scores
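With the loader and generate now living on TransformersManagedModel, a node (or a plain script) can drive any supported checkpoint through the same two calls. A rough usage sketch, using one of the repos registered below in KNOWN_HUGGINGFACE_MODEL_REPOS; sampler keys follow the GenerationKwargs TypedDict, and this is illustrative rather than a verified test:

from comfy.language.transformers_model_management import TransformersManagedModel

model = TransformersManagedModel.from_pretrained("roborovski/superprompt-v1")
tokens = model.tokenize("Expand: a photo of a cat", images=[], chat_template=None)
text = model.generate(tokens, max_new_tokens=64, seed=0, sampler={"temperature": 0.7})
print(text)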
|||||||
@@ -385,6 +385,10 @@ KNOWN_CONTROLNETS: Final[KnownDownloadables] = KnownDownloadables([
     HuggingFile("InstantX/FLUX.1-dev-Controlnet-Canny", "diffusion_pytorch_model.safetensors", save_with_filename="instantx-flux.1-dev-controlnet-canny.safetensors"),
     HuggingFile("InstantX/FLUX.1-dev-Controlnet-Union", "diffusion_pytorch_model.safetensors", save_with_filename="instantx-flux.1-dev-controlnet-union.safetensors"),
     HuggingFile("Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro", "diffusion_pytorch_model.safetensors", save_with_filename="shakker-labs-flux.1-dev-controlnet-union-pro.safetensors"),
+    HuggingFile("TheMistoAI/MistoLine_Flux.dev", "mistoline_flux.dev_v1.safetensors"),
+    HuggingFile("XLabs-AI/flux-controlnet-collections", "flux-canny-controlnet-v3.safetensors"),
+    HuggingFile("XLabs-AI/flux-controlnet-collections", "flux-depth-controlnet-v3.safetensors"),
+    HuggingFile("XLabs-AI/flux-controlnet-collections", "flux-hed-controlnet-v3.safetensors"),
 ], folder_name="controlnet")
 
 KNOWN_DIFF_CONTROLNETS: Final[KnownDownloadables] = KnownDownloadables([
||||||
@@ -418,6 +422,7 @@ KNOWN_HUGGINGFACE_MODEL_REPOS: Final[Set[str]] = {
     'llava-hf/llava-v1.6-mistral-7b-hf',
     'facebook/nllb-200-distilled-1.3B',
     'THUDM/chatglm3-6b',
+    'roborovski/superprompt-v1',
 }
 
 KNOWN_UNET_MODELS: Final[KnownDownloadables] = KnownDownloadables([
|||||||
@@ -24,7 +24,7 @@ from .. import model_management
 from ..cli_args import args
 
 from ..cmd import folder_paths, latent_preview
-from ..component_model.tensor_types import RGBImage
+from ..component_model.tensor_types import RGBImage, RGBImageBatch, MaskBatch
 from ..execution_context import current_execution_context
 from ..images import open_image
 from ..ldm.flux.weight_dtypes import FLUX_WEIGHT_DTYPES
||||||
@@ -808,7 +808,7 @@ class ControlNetApply:
 
     CATEGORY = "conditioning/controlnet"
 
-    def apply_controlnet(self, conditioning, control_net, image, strength):
+    def apply_controlnet(self, conditioning, control_net, image: RGBImageBatch, strength):
         if strength == 0:
             return (conditioning, )
 
||||||
@@ -1573,7 +1573,7 @@ class LoadImage:
     RETURN_TYPES = ("IMAGE", "MASK")
     FUNCTION = "load_image"
 
-    def load_image(self, image: str):
+    def load_image(self, image: str) -> tuple[RGBImageBatch, MaskBatch]:
         image_path = folder_paths.get_annotated_filepath(image)
 
         img = node_helpers.pillow(Image.open, image_path)
||||||
@@ -1703,7 +1703,7 @@ class ImageScale:
 
     CATEGORY = "image/upscaling"
 
-    def upscale(self, image, upscale_method, width, height, crop):
+    def upscale(self, image: RGBImageBatch, upscale_method, width, height, crop) -> tuple[RGBImageBatch]:
         if width == 0 and height == 0:
             s = image
         else:
|||||||
@@ -5,7 +5,7 @@ import torch
 from skimage import exposure
 
 import comfy.utils
-from comfy.component_model.tensor_types import RGBImageBatch, ImageBatch
+from comfy.component_model.tensor_types import RGBImageBatch, ImageBatch, MaskBatch
 from comfy.nodes.package_typing import CustomNode
 
 
||||||
@@ -34,10 +34,7 @@ class PorterDuffMode(Enum):
     XOR = 17
 
 
-def porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_image: torch.Tensor, dst_alpha: torch.Tensor, mode: PorterDuffMode):
-    # convert mask to alpha
-    src_alpha = 1 - src_alpha
-    dst_alpha = 1 - dst_alpha
+def _porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_image: torch.Tensor, dst_alpha: torch.Tensor, mode: PorterDuffMode):
     # premultiply alpha
     src_image = src_image * src_alpha
     dst_image = dst_image * dst_alpha
||||||
@@ -109,24 +106,31 @@ def porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_
     return out_image, out_alpha
 
 
-class PorterDuffImageComposite:
+class PorterDuffImageCompositeV2:
     @classmethod
     def INPUT_TYPES(s):
         return {
             "required": {
                 "source": ("IMAGE",),
-                "source_alpha": ("MASK",),
                 "destination": ("IMAGE",),
-                "destination_alpha": ("MASK",),
                 "mode": ([mode.name for mode in PorterDuffMode], {"default": PorterDuffMode.DST.name}),
             },
+            "optional": {
+                "source_alpha": ("MASK",),
+                "destination_alpha": ("MASK",),
+            }
         }
 
     RETURN_TYPES = ("IMAGE", "MASK")
     FUNCTION = "composite"
     CATEGORY = "mask/compositing"
 
-    def composite(self, source: torch.Tensor, source_alpha: torch.Tensor, destination: torch.Tensor, destination_alpha: torch.Tensor, mode):
+    def composite(self, source: RGBImageBatch, destination: RGBImageBatch, mode, source_alpha: MaskBatch = None, destination_alpha: MaskBatch = None) -> tuple[RGBImageBatch, MaskBatch]:
+        if source_alpha is None:
+            source_alpha = torch.zeros(source.shape[:3])
+        if destination_alpha is None:
+            destination_alpha = torch.zeros(destination.shape[:3])
+
         batch_size = min(len(source), len(source_alpha), len(destination), len(destination_alpha))
         out_images = []
         out_alphas = []
||||||
@@ -153,7 +157,7 @@ class PorterDuffImageComposite:
             upscale_output = comfy.utils.common_upscale(upscale_input, dst_alpha.shape[1], dst_alpha.shape[0], upscale_method='bicubic', crop='center')
             src_alpha = upscale_output.permute(0, 2, 3, 1).squeeze(0)
 
-            out_image, out_alpha = porter_duff_composite(src_image, src_alpha, dst_image, dst_alpha, PorterDuffMode[mode])
+            out_image, out_alpha = _porter_duff_composite(src_image, src_alpha, dst_image, dst_alpha, PorterDuffMode[mode])
 
             out_images.append(out_image)
             out_alphas.append(out_alpha.squeeze(2))
||||||
@@ -162,6 +166,28 @@ class PorterDuffImageComposite:
         return result
 
 
+class PorterDuffImageCompositeV1(PorterDuffImageCompositeV2):
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "source": ("IMAGE",),
+                "source_alpha": ("MASK",),
+                "destination": ("IMAGE",),
+                "destination_alpha": ("MASK",),
+                "mode": ([mode.name for mode in PorterDuffMode], {"default": PorterDuffMode.DST.name}),
+            },
+        }
+
+    FUNCTION = "composite_v1"
+
+    def composite_v1(self, source: torch.Tensor, source_alpha: torch.Tensor, destination: torch.Tensor, destination_alpha: torch.Tensor, mode) -> tuple[RGBImageBatch, MaskBatch]:
+        # convert mask to alpha
+        source_alpha = 1 - source_alpha
+        destination_alpha = 1 - destination_alpha
+        return super().composite(source, destination, mode, source_alpha, destination_alpha)
+
+
 class SplitImageWithAlpha:
     @classmethod
     def INPUT_TYPES(s):
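The V1/V2 split keeps old workflows intact: V2 makes both masks optional and substitutes all-zero masks when they are omitted, while V1 preserves the legacy behavior of requiring masks and inverting them into alphas before delegating to V2. A rough sketch of calling both node classes directly, outside the node graph (shapes follow the ImageBatch/MaskBatch conventions above; the values are made up):

import torch

source = torch.rand(1, 64, 64, 3)       # RGBImageBatch: (batch, height, width, 3)
destination = torch.rand(1, 64, 64, 3)
mask = torch.zeros(1, 64, 64)           # MaskBatch: (batch, height, width)

# V2: alphas are optional and default to zero masks when omitted.
image_v2, alpha_v2 = PorterDuffImageCompositeV2().composite(source, destination, "DST")

# V1: legacy inputs; masks are inverted into alphas, then V2's composite runs.
image_v1, alpha_v1 = PorterDuffImageCompositeV1().composite_v1(source, mask, destination, mask, "DST")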
||||||
@@ -312,7 +338,8 @@ class Posterize(CustomNode):
 
 
 NODE_CLASS_MAPPINGS = {
-    "PorterDuffImageComposite": PorterDuffImageComposite,
+    "PorterDuffImageComposite": PorterDuffImageCompositeV1,
+    "PorterDuffImageCompositeV2": PorterDuffImageCompositeV2,
     "SplitImageWithAlpha": SplitImageWithAlpha,
     "JoinImageWithAlpha": JoinImageWithAlpha,
     "EnhanceContrast": EnhanceContrast,
||||||
@@ -321,7 +348,8 @@ NODE_CLASS_MAPPINGS = {
 }
 
 NODE_DISPLAY_NAME_MAPPINGS = {
-    "PorterDuffImageComposite": "Porter-Duff Image Composite",
+    "PorterDuffImageComposite": "Porter-Duff Image Composite (V1)",
+    "PorterDuffImageCompositeV2": "Image Composite",
     "SplitImageWithAlpha": "Split Image with Alpha",
     "JoinImageWithAlpha": "Join Image with Alpha",
 }
||||||
|
|||||||
@ -1,80 +1,32 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import copy
|
|
||||||
import inspect
|
|
||||||
import logging
|
|
||||||
import operator
|
import operator
|
||||||
import os.path
|
import os.path
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
from typing import Any, Dict, Optional, List, Callable, Union
|
from typing import Optional, List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from transformers import AutoTokenizer, PreTrainedModel, LogitsProcessor, TextStreamer, \
|
from transformers import AutoProcessor
|
||||||
PreTrainedTokenizerBase, PretrainedConfig, AutoProcessor, BatchFeature, AutoModel, AutoModelForCausalLM, \
|
|
||||||
AutoModelForSeq2SeqLM
|
|
||||||
from transformers.models.auto.modeling_auto import MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, \
|
|
||||||
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES, AutoModelForVision2Seq
|
|
||||||
from transformers.models.m2m_100.tokenization_m2m_100 import \
|
from transformers.models.m2m_100.tokenization_m2m_100 import \
|
||||||
FAIRSEQ_LANGUAGE_CODES as tokenization_m2m_100_FAIRSEQ_LANGUAGE_CODES
|
FAIRSEQ_LANGUAGE_CODES as tokenization_m2m_100_FAIRSEQ_LANGUAGE_CODES
|
||||||
from transformers.models.nllb.tokenization_nllb import \
|
from transformers.models.nllb.tokenization_nllb import \
|
||||||
FAIRSEQ_LANGUAGE_CODES as tokenization_nllb_FAIRSEQ_LANGUAGE_CODES
|
FAIRSEQ_LANGUAGE_CODES as tokenization_nllb_FAIRSEQ_LANGUAGE_CODES
|
||||||
from typing_extensions import TypedDict
|
|
||||||
|
|
||||||
from comfy import model_management
|
|
||||||
from comfy.cmd import folder_paths
|
from comfy.cmd import folder_paths
|
||||||
from comfy.component_model.folder_path_types import SaveImagePathResponse
|
from comfy.component_model.folder_path_types import SaveImagePathResponse
|
||||||
from comfy.language.chat_templates import KNOWN_CHAT_TEMPLATES
|
from comfy.language.chat_templates import KNOWN_CHAT_TEMPLATES
|
||||||
from comfy.language.language_types import ProcessorResult
|
from comfy.language.language_types import GENERATION_KWARGS_TYPE, GENERATION_KWARGS_TYPE_NAME, TOKENS_TYPE, \
|
||||||
|
TOKENS_TYPE_NAME, LanguageModel
|
||||||
from comfy.language.transformers_model_management import TransformersManagedModel
|
from comfy.language.transformers_model_management import TransformersManagedModel
|
||||||
from comfy.model_downloader import get_huggingface_repo_list, get_or_download_huggingface_repo
|
from comfy.model_downloader import get_huggingface_repo_list, get_or_download_huggingface_repo
|
||||||
from comfy.model_management import get_torch_device_name, unet_dtype, unet_offload_device, load_models_gpu
|
from comfy.model_management import get_torch_device_name, unet_dtype, unet_offload_device
|
||||||
from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult
|
from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult
|
||||||
from comfy.utils import comfy_tqdm, seed_for_block, comfy_progress, ProgressBar
|
|
||||||
|
|
||||||
_AUTO_CHAT_TEMPLATE = "default"
|
_AUTO_CHAT_TEMPLATE = "default"
|
||||||
|
|
||||||
# add llava support
|
|
||||||
try:
|
|
||||||
from llava import model as _llava_model_side_effects
|
|
||||||
|
|
||||||
logging.debug("Additional LLaVA models are now supported")
|
|
||||||
except ImportError as exc:
|
|
||||||
logging.debug(f"Install LLavA with `pip install git+https://github.com/AppMana/appmana-comfyui-llava` for additional LLaVA support")
|
|
||||||
|
|
||||||
# aka kwargs type
|
|
||||||
_GENERATION_KWARGS_TYPE = Dict[str, Any]
|
|
||||||
_GENERATION_KWARGS_TYPE_NAME = "SAMPLER"
|
|
||||||
|
|
||||||
_TOKENS_TYPE = Union[ProcessorResult, BatchFeature]
|
|
||||||
TOKENS_TYPE_NAME = "TOKENS"
|
|
||||||
|
|
||||||
|
|
||||||
class _ProgressTextStreamer(TextStreamer):
|
|
||||||
def __init__(self, on_finalized_text: Callable[[str, bool], None], tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
|
|
||||||
super().__init__(tokenizer, skip_prompt, **decode_kwargs)
|
|
||||||
self.on_finalized_text_handler = on_finalized_text
|
|
||||||
|
|
||||||
def on_finalized_text(self, text: str, stream_end: bool = False):
|
|
||||||
self.on_finalized_text_handler(text, stream_end)
|
|
||||||
|
|
||||||
|
|
||||||
class _ProgressLogitsProcessor(LogitsProcessor):
|
|
||||||
def __init__(self, model: TransformersManagedModel):
|
|
||||||
self.eos_token_id = model.tokenizer.eos_token_id
|
|
||||||
|
|
||||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
|
||||||
probabilities = scores.softmax(dim=-1)
|
|
||||||
self.eos_probability = probabilities[:, self.eos_token_id].item()
|
|
||||||
return scores
|
|
||||||
|
|
||||||
|
|
||||||
# todo: for per token progress, should this really look like {"ui": {"string": [value]}} ?
|
|
||||||
class TransformerStreamedProgress(TypedDict):
|
|
||||||
next_token: str
|
|
||||||
|
|
||||||
|
|
||||||
class TransformerSamplerBase(CustomNode):
|
class TransformerSamplerBase(CustomNode):
|
||||||
RETURN_TYPES = _GENERATION_KWARGS_TYPE_NAME,
|
RETURN_TYPES = GENERATION_KWARGS_TYPE_NAME,
|
||||||
RETURN_NAMES = "GENERATION ARGS",
|
RETURN_NAMES = "GENERATION ARGS",
|
||||||
FUNCTION = "execute"
|
FUNCTION = "execute"
|
||||||
CATEGORY = "language/samplers"
|
CATEGORY = "language/samplers"
|
||||||
@ -142,7 +94,7 @@ class TransformersGenerationConfig(CustomNode):
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
RETURN_TYPES = _GENERATION_KWARGS_TYPE_NAME,
|
RETURN_TYPES = GENERATION_KWARGS_TYPE_NAME,
|
||||||
RETURN_NAMES = "GENERATION ARGS",
|
RETURN_NAMES = "GENERATION ARGS",
|
||||||
FUNCTION = "execute"
|
FUNCTION = "execute"
|
||||||
CATEGORY = "language"
|
CATEGORY = "language"
|
||||||
@@ -182,15 +134,15 @@ class TransformerBeamSearchSampler(TransformerSamplerBase):
 class TransformerMergeSamplers(CustomNode):
     @classmethod
     def INPUT_TYPES(cls) -> InputTypes:
-        range_ = {"value0": (_GENERATION_KWARGS_TYPE_NAME, {"forceInput": True})}
-        range_.update({f"value{i}": (_GENERATION_KWARGS_TYPE_NAME, {"forceInput": True}) for i in range(1, 5)})
+        range_ = {"value0": (GENERATION_KWARGS_TYPE_NAME, {"forceInput": True})}
+        range_.update({f"value{i}": (GENERATION_KWARGS_TYPE_NAME, {"forceInput": True}) for i in range(1, 5)})
 
         return {
             "required": range_
         }
 
     CATEGORY = "language"
-    RETURN_TYPES = _GENERATION_KWARGS_TYPE_NAME,
+    RETURN_TYPES = GENERATION_KWARGS_TYPE_NAME,
     FUNCTION = "execute"
 
     def execute(self, **kwargs):
||||||
@ -238,98 +190,11 @@ class TransformersLoader(CustomNode):
|
|||||||
|
|
||||||
CATEGORY = "language"
|
CATEGORY = "language"
|
||||||
RETURN_TYPES = "MODEL",
|
RETURN_TYPES = "MODEL",
|
||||||
|
RETURN_NAMES = "language model",
|
||||||
FUNCTION = "execute"
|
FUNCTION = "execute"
|
||||||
|
|
||||||
def execute(self, ckpt_name: str, subfolder: Optional[str] = None, *args, **kwargs):
|
def execute(self, ckpt_name: str, subfolder: Optional[str] = None, *args, **kwargs) -> tuple[TransformersManagedModel]:
|
||||||
hub_kwargs = {}
|
return TransformersManagedModel.from_pretrained(ckpt_name, subfolder),
|
||||||
if subfolder is not None and subfolder != "":
|
|
||||||
hub_kwargs["subfolder"] = subfolder
|
|
||||||
|
|
||||||
ckpt_name = get_or_download_huggingface_repo(ckpt_name)
|
|
||||||
with comfy_tqdm():
|
|
||||||
from_pretrained_kwargs = {
|
|
||||||
"pretrained_model_name_or_path": ckpt_name,
|
|
||||||
"trust_remote_code": True,
|
|
||||||
**hub_kwargs
|
|
||||||
}
|
|
||||||
|
|
||||||
# if flash attention exists, use it
|
|
||||||
|
|
||||||
# compute bitsandbytes configuration
|
|
||||||
try:
|
|
||||||
import bitsandbytes
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
config_dict, _ = PretrainedConfig.get_config_dict(ckpt_name, **hub_kwargs)
|
|
||||||
model_type = config_dict["model_type"]
|
|
||||||
# language models prefer to use bfloat16 over float16
|
|
||||||
kwargs_to_try = ({"torch_dtype": unet_dtype(supported_dtypes=(torch.bfloat16, torch.float16, torch.float32)),
|
|
||||||
"low_cpu_mem_usage": True,
|
|
||||||
"device_map": str(unet_offload_device()), }, {})
|
|
||||||
|
|
||||||
# if we have flash-attn installed, try to use it
|
|
||||||
try:
|
|
||||||
import flash_attn
|
|
||||||
attn_override_kwargs = {
|
|
||||||
"attn_implementation": "flash_attention_2",
|
|
||||||
**kwargs_to_try[0]
|
|
||||||
}
|
|
||||||
kwargs_to_try = (attn_override_kwargs, *kwargs_to_try)
|
|
||||||
logging.debug(f"while loading model {ckpt_name}, flash_attn was installed, so the flash_attention_2 implementation will be tried")
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
for i, props in enumerate(kwargs_to_try):
|
|
||||||
try:
|
|
||||||
if model_type in MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES:
|
|
||||||
model = AutoModelForVision2Seq.from_pretrained(**from_pretrained_kwargs, **props)
|
|
||||||
elif model_type in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
|
|
||||||
model = AutoModelForSeq2SeqLM.from_pretrained(**from_pretrained_kwargs, **props)
|
|
||||||
elif model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
|
|
||||||
model = AutoModelForCausalLM.from_pretrained(**from_pretrained_kwargs, **props)
|
|
||||||
else:
|
|
||||||
model = AutoModel.from_pretrained(**from_pretrained_kwargs, **props)
|
|
||||||
if model is not None:
|
|
||||||
break
|
|
||||||
except Exception as exc_info:
|
|
||||||
if i == len(kwargs_to_try) - 1:
|
|
||||||
raise exc_info
|
|
||||||
else:
|
|
||||||
logging.warning(f"tried to import transformers model {ckpt_name} but got exception when trying additional import args {props}", exc_info=exc_info)
|
|
||||||
finally:
|
|
||||||
torch.set_default_dtype(torch.float32)
|
|
||||||
|
|
||||||
for i, props in enumerate(kwargs_to_try):
|
|
||||||
try:
|
|
||||||
try:
|
|
||||||
processor = AutoProcessor.from_pretrained(**from_pretrained_kwargs, **props)
|
|
||||||
except:
|
|
||||||
processor = None
|
|
||||||
if isinstance(processor, PreTrainedTokenizerBase):
|
|
||||||
tokenizer = processor
|
|
||||||
processor = None
|
|
||||||
else:
|
|
||||||
tokenizer = getattr(processor, "tokenizer") if processor is not None and hasattr(processor, "tokenizer") else AutoTokenizer.from_pretrained(ckpt_name, **hub_kwargs, **props)
|
|
||||||
if tokenizer is not None or processor is not None:
|
|
||||||
break
|
|
||||||
except Exception as exc_info:
|
|
||||||
if i == len(kwargs_to_try) - 1:
|
|
||||||
raise exc_info
|
|
||||||
finally:
|
|
||||||
torch.set_default_dtype(torch.float32)
|
|
||||||
|
|
||||||
if model_management.xformers_enabled() and hasattr(model, "enable_xformers_memory_efficient_attention"):
|
|
||||||
model.enable_xformers_memory_efficient_attention()
|
|
||||||
logging.debug("enabled xformers memory efficient attention")
|
|
||||||
|
|
||||||
model_managed = TransformersManagedModel(
|
|
||||||
repo_id=ckpt_name,
|
|
||||||
model=model,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
config_dict=config_dict,
|
|
||||||
processor=processor
|
|
||||||
)
|
|
||||||
return model_managed,
|
|
||||||
|
|
||||||
|
|
||||||
class TransformersTokenize(CustomNode):
|
class TransformersTokenize(CustomNode):
|
||||||
@@ -346,7 +211,7 @@ class TransformersTokenize(CustomNode):
     RETURN_TYPES = (TOKENS_TYPE_NAME,)
     FUNCTION = "execute"
 
-    def execute(self, model: TransformersManagedModel, prompt: str) -> ValidatedNodeResult:
+    def execute(self, model: LanguageModel, prompt: str) -> ValidatedNodeResult:
         return model.tokenize(prompt, [], None),
 
 
||||||
@@ -452,7 +317,7 @@ class OneShotInstructTokenize(CustomNode):
     RETURN_TYPES = (TOKENS_TYPE_NAME,)
     FUNCTION = "execute"
 
-    def execute(self, model: TransformersManagedModel, prompt: str, images: List[torch.Tensor] | torch.Tensor = None, chat_template: str = "__auto__") -> ValidatedNodeResult:
+    def execute(self, model: LanguageModel, prompt: str, images: List[torch.Tensor] | torch.Tensor = None, chat_template: str = "__auto__") -> ValidatedNodeResult:
         if chat_template == _AUTO_CHAT_TEMPLATE:
             # use an exact match
             model_name = os.path.basename(model.repo_id)
||||||
@@ -475,10 +340,9 @@ class TransformersGenerate(CustomNode):
                 "max_new_tokens": ("INT", {"default": 512, "min": 1}),
                 "repetition_penalty": ("FLOAT", {"default": 0.0, "min": 0}),
                 "seed": ("INT", {"default": 0, "min": 0, "max": 2 ** 32 - 1}),
-                "use_cache": ("BOOLEAN", {"default": True}),
             },
             "optional": {
-                "sampler": (_GENERATION_KWARGS_TYPE_NAME, {}),
+                "sampler": (GENERATION_KWARGS_TYPE_NAME, {}),
             }
         }
 
||||||
@ -487,110 +351,14 @@ class TransformersGenerate(CustomNode):
|
|||||||
FUNCTION = "execute"
|
FUNCTION = "execute"
|
||||||
|
|
||||||
def execute(self,
|
def execute(self,
|
||||||
model: Optional[TransformersManagedModel] = None,
|
model: Optional[LanguageModel] = None,
|
||||||
tokens: _TOKENS_TYPE = None,
|
tokens: TOKENS_TYPE = None,
|
||||||
max_new_tokens: int = 512,
|
max_new_tokens: int = 512,
|
||||||
repetition_penalty: float = 0.0,
|
repetition_penalty: float = 0.0,
|
||||||
seed: int = 0,
|
seed: int = 0,
|
||||||
sampler: Optional[_GENERATION_KWARGS_TYPE] = None,
|
sampler: Optional[GENERATION_KWARGS_TYPE] = None,
|
||||||
*args,
|
|
||||||
**kwargs
|
|
||||||
):
|
):
|
||||||
tokens = copy.copy(tokens)
|
return model.generate(tokens, max_new_tokens, repetition_penalty, seed, sampler),
|
||||||
tokens_original = copy.copy(tokens)
|
|
||||||
sampler = sampler or {}
|
|
||||||
generate_kwargs = copy.copy(sampler)
|
|
||||||
load_models_gpu([model])
|
|
||||||
transformers_model: PreTrainedModel = model.model
|
|
||||||
tokenizer: PreTrainedTokenizerBase | AutoTokenizer = model.tokenizer
|
|
||||||
# remove unused inputs
|
|
||||||
# maximizes compatibility with different models
|
|
||||||
generate_signature = inspect.signature(transformers_model.generate).parameters
|
|
||||||
prepare_signature = inspect.signature(transformers_model.prepare_inputs_for_generation).parameters
|
|
||||||
to_delete = set(reduce(operator.sub, map(lambda x: x.keys(), [tokens, generate_signature, prepare_signature])))
|
|
||||||
gen_sig_keys = generate_signature.keys()
|
|
||||||
if "tgt_lang" in tokens:
|
|
||||||
to_delete.add("tgt_lang")
|
|
||||||
to_delete.add("src_lang")
|
|
||||||
to_delete.discard("input_ids")
|
|
||||||
if "forced_bos_token_id" in tokens:
|
|
||||||
to_delete.discard("forced_bos_token_id")
|
|
||||||
elif hasattr(tokenizer, "convert_tokens_to_ids"):
|
|
||||||
generate_kwargs["forced_bos_token_id"] = tokenizer.convert_tokens_to_ids(tokens["tgt_lang"])
|
|
||||||
else:
|
|
||||||
logging.warning(f"tokenizer {tokenizer} unexpected for translation task")
|
|
||||||
if "input_ids" in tokens and "inputs" in tokens:
|
|
||||||
if "input_ids" in gen_sig_keys:
|
|
||||||
to_delete.add("inputs")
|
|
||||||
elif "inputs" in gen_sig_keys:
|
|
||||||
to_delete.add("input_ids")
|
|
||||||
for unused_kwarg in to_delete:
|
|
||||||
tokens.pop(unused_kwarg)
|
|
||||||
logging.debug(f"{transformers_model.name_or_path}.generate does not accept {unused_kwarg}, removing")
|
|
||||||
|
|
||||||
# images should be moved to model
|
|
||||||
for key in ("images", "pixel_values"):
|
|
||||||
if key in tokens:
|
|
||||||
tokens[key] = tokens[key].to(device=model.current_device, dtype=model.model_dtype())
|
|
||||||
|
|
||||||
# sets up inputs
|
|
||||||
inputs = tokens
|
|
||||||
|
|
||||||
# used to determine if text streaming is supported
|
|
||||||
num_beams = generate_kwargs.get("num_beams", transformers_model.generation_config.num_beams)
|
|
||||||
|
|
||||||
progress_bar: ProgressBar
|
|
||||||
with comfy_progress(total=max_new_tokens) as progress_bar:
|
|
||||||
# todo: deal with batches correctly, don't assume batch size 1
|
|
||||||
token_count = 0
|
|
||||||
|
|
||||||
# progress
|
|
||||||
def on_finalized_text(next_token: str, stop: bool):
|
|
||||||
nonlocal token_count
|
|
||||||
nonlocal progress_bar
|
|
||||||
|
|
||||||
token_count += 1
|
|
||||||
preview = TransformerStreamedProgress(next_token=next_token)
|
|
||||||
progress_bar.update_absolute(token_count, total=max_new_tokens, preview_image_or_output=preview)
|
|
||||||
|
|
||||||
text_streamer = _ProgressTextStreamer(on_finalized_text, tokenizer, True)
|
|
||||||
|
|
||||||
with seed_for_block(seed):
|
|
||||||
if hasattr(inputs, "encodings") and inputs.encodings is not None and all(hasattr(encoding, "attention_mask") for encoding in inputs.encodings) and "attention_mask" in inputs:
|
|
||||||
inputs.pop("attention_mask")
|
|
||||||
output_ids = transformers_model.generate(
|
|
||||||
**inputs,
|
|
||||||
streamer=text_streamer if num_beams <= 1 else None,
|
|
||||||
max_new_tokens=max_new_tokens,
|
|
||||||
repetition_penalty=repetition_penalty if repetition_penalty != 0 else None,
|
|
||||||
**generate_kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
if not transformers_model.config.is_encoder_decoder:
|
|
||||||
start_position = inputs["input_ids" if "input_ids" in inputs else "inputs"].shape[1]
|
|
||||||
output_ids = output_ids[:, start_position:]
|
|
||||||
|
|
||||||
if hasattr(tokenizer, "src_lang") and "src_lang" in tokens_original:
|
|
||||||
prev_src_lang = tokenizer.src_lang
|
|
||||||
tokenizer.src_lang = tokens_original["src_lang"]
|
|
||||||
else:
|
|
||||||
prev_src_lang = None
|
|
||||||
# todo: is this redundant consider I'm decoding in the on_finalized_text block?
|
|
||||||
try:
|
|
||||||
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
|
||||||
finally:
|
|
||||||
if prev_src_lang is not None:
|
|
||||||
tokenizer.src_lang = prev_src_lang
|
|
||||||
# gpu-loaded stuff like images can now be unloaded
|
|
||||||
if hasattr(tokens, "to"):
|
|
||||||
del tokens
|
|
||||||
else:
|
|
||||||
for to_delete in tokens.values():
|
|
||||||
del to_delete
|
|
||||||
del tokens
|
|
||||||
|
|
||||||
# todo: better support batches
|
|
||||||
return outputs[0],
|
|
||||||
|
|
||||||
|
|
||||||
class PreviewString(CustomNode):
|
class PreviewString(CustomNode):
|
||||||
|
|||||||
comfy_extras/nodes/nodes_openai.py (new file, 206 lines)
@ -0,0 +1,206 @@
|
|||||||
|
import base64
import io
import os
from io import BytesIO
from typing import Literal, Optional

import numpy as np
import requests
import torch
from PIL import Image
from openai import OpenAI
from openai.types.chat import ChatCompletionMessageParam

from comfy.cli_args import args
from comfy.component_model.tensor_types import RGBImageBatch
from comfy.language.language_types import LanguageModel, ProcessorResult, GENERATION_KWARGS_TYPE, TOKENS_TYPE, \
    TransformerStreamedProgress
from comfy.nodes.package_typing import CustomNode, InputTypes
from comfy.utils import comfy_progress, ProgressBar, seed_for_block


class _Client:
    _client: Optional[OpenAI] = None

    @staticmethod
    def instance() -> OpenAI:
        if _Client._client is None:
            open_ai_api_key = args.openai_api_key
            _Client._client = OpenAI(
                api_key=open_ai_api_key,
            )

        return _Client._client


def validate_has_key():
    open_api_key = os.environ.get("OPENAI_API_KEY", args.openai_api_key)
    if open_api_key is None or open_api_key == "":
        return "set OPENAI_API_KEY environment variable"
    return True


def image_to_base64(image: RGBImageBatch) -> str:
    pil_image = Image.fromarray(np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8))
    buffered = io.BytesIO()
    pil_image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


class OpenAILanguageModelWrapper(LanguageModel):
    def __init__(self, model: str):
        self.model = model
        self.client = _Client.instance()

    @staticmethod
    def from_pretrained(ckpt_name: str, subfolder: Optional[str] = None) -> "OpenAILanguageModelWrapper":
        return OpenAILanguageModelWrapper(ckpt_name)

    def generate(self, tokens: TOKENS_TYPE = None,
                 max_new_tokens: int = 512,
                 repetition_penalty: float = 0.0,
                 seed: int = 0,
                 sampler: Optional[GENERATION_KWARGS_TYPE] = None,
                 *args,
                 **kwargs) -> str:
        sampler = sampler or {}
        prompt = tokens.get("inputs", [])
        prompt = "".join(prompt)
        images = tokens.get("images", [])
        messages: list[ChatCompletionMessageParam] = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                ] + [
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{image_to_base64(image)}"
                        }
                    } for image in images
                ]
            }
        ]

        progress_bar: ProgressBar
        with comfy_progress(total=max_new_tokens) as progress_bar:
            token_count = 0
            full_response = ""

            def on_finalized_text(next_token: str, stop: bool):
                nonlocal token_count
                nonlocal progress_bar
                nonlocal full_response

                token_count += 1
                full_response += next_token
                preview = TransformerStreamedProgress(next_token=next_token)
                progress_bar.update_absolute(max_new_tokens if stop else token_count, total=max_new_tokens, preview_image_or_output=preview)

            with seed_for_block(seed):
                stream = self.client.chat.completions.create(
                    model=self.model,
                    messages=messages,
                    max_tokens=max_new_tokens,
                    temperature=sampler.get("temperature", 1.0),
                    top_p=sampler.get("top_p", 1.0),
                    # n=1,
                    # stop=None,
                    # presence_penalty=repetition_penalty,
                    seed=seed,
                    stream=True
                )

                for chunk in stream:
                    if chunk.choices[0].delta.content is not None:
                        on_finalized_text(chunk.choices[0].delta.content, False)

                on_finalized_text("", True)  # Signal the end of streaming

        return full_response

    def tokenize(self, prompt: str, images: RGBImageBatch, chat_template: str | None = None) -> ProcessorResult:
        # OpenAI API doesn't require explicit tokenization, so we'll just return the prompt and images as is
        return {
            "inputs": [prompt],
            "attention_mask": torch.ones(1, len(prompt)),  # Dummy attention mask
            "images": images
        }

    @property
    def repo_id(self) -> str:
        return f"openai/{self.model}"


class OpenAILanguageModelLoader(CustomNode):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "model": (["gpt-4o", "gpt-4o-mini", "gpt-4-turbo", "gpt-4", "gpt-3.5-turbo"], {"default": "gpt-3.5-turbo"})
            }
        }

    RETURN_TYPES = ("MODEL",)
    RETURN_NAMES = ("language model",)

    FUNCTION = "execute"
    CATEGORY = "openai"

    def execute(self, model: str) -> tuple[LanguageModel]:
        return OpenAILanguageModelWrapper(model),

    @classmethod
    def VALIDATE_INPUTS(cls):
        return validate_has_key()


class DallEGenerate(CustomNode):

    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {
            "model": (["dall-e-2", "dall-e-3"], {"default": "dall-e-3"}),
            "text": ("STRING", {"multiline": True}),
            "size": (["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"], {"default": "1024x1024"}),
            "quality": (["standard", "hd"], {"default": "standard"}),
        }}

    RETURN_TYPES = ("IMAGE", "STRING",)
    RETURN_NAMES = ("images", "revised prompt")
    FUNCTION = "generate"

    CATEGORY = "openai"

    def generate(self,
                 model: Literal["dall-e-2", "dall-e-3"],
                 text: str,
                 size: Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"],
                 quality: Literal["standard", "hd"]) -> tuple[RGBImageBatch, str]:
        response = _Client.instance().images.generate(
            model=model,
            prompt=text,
            size=size,
            quality=quality,
            n=1,
        )

        image_url = response.data[0].url
        image_response = requests.get(image_url)

        img = Image.open(BytesIO(image_response.content))

        image = np.array(img).astype(np.float32) / 255.0
        image = torch.from_numpy(image)[None,]
        return image, response.data[0].revised_prompt

    @classmethod
    def VALIDATE_INPUTS(cls):
        return validate_has_key()


NODE_CLASS_MAPPINGS = {
    "DallEGenerate": DallEGenerate,
    "OpenAILanguageModelLoader": OpenAILanguageModelLoader
}
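As a rough usage sketch, the new wrapper can also be exercised directly from Python once the OPENAI_API_KEY environment variable is configured. The classes and method signatures below come from nodes_openai.py above; the model choice, prompt text, and random image are purely illustrative, and generate() assumes it runs where comfy_progress is available (normally a ComfyUI worker):

import torch

from comfy_extras.nodes.nodes_openai import OpenAILanguageModelLoader, DallEGenerate

# Load a chat-model wrapper; "gpt-4o" is one of the options the loader node exposes.
model, = OpenAILanguageModelLoader().execute("gpt-4o")

# tokenize() only packages the prompt and images for the API call; nothing is tokenized locally.
tokens = model.tokenize("Describe this image:", torch.rand((1, 224, 224, 3)))

# Streams the chat completion and returns the concatenated text.
text = model.generate(tokens, max_new_tokens=128)

# DALL-E generation returns an IMAGE batch tensor plus the revised prompt string.
images, revised_prompt = DallEGenerate().generate("dall-e-3", text, "1024x1024", "standard")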
@@ -65,4 +65,6 @@ ml_dtypes
diffusers>=0.30.1
vtracer
skia-python
pebble>=5.0.7
openai
anthropic
14 tests/inference/test_language.py Normal file
@@ -0,0 +1,14 @@
import torch

from comfy_extras.nodes.nodes_language import TransformersLoader, OneShotInstructTokenize


def test_integration_transformers_loader_and_tokenize():
    loader = TransformersLoader()
    tokenize = OneShotInstructTokenize()

    model, = loader.execute("llava-hf/llava-v1.6-mistral-7b-hf", "")
    tokens, = tokenize.execute(model, "Describe this image:", torch.rand((1, 224, 224, 3)), "llava-v1.6-mistral-7b-hf", )

    assert isinstance(tokens, dict)
    assert "input_ids" in tokens or "inputs" in tokens
@@ -1,10 +1,17 @@
import io
import os
import tempfile
from unittest.mock import patch, Mock

import pytest
import torch
from PIL import Image

from comfy.language.language_types import LanguageModel, ProcessorResult
from comfy_extras.nodes.nodes_language import SaveString
from comfy_extras.nodes.nodes_language import TransformersLoader, OneShotInstructTokenize, TransformersGenerate, \
    PreviewString
from comfy_extras.nodes.nodes_openai import OpenAILanguageModelLoader, OpenAILanguageModelWrapper, DallEGenerate


@pytest.fixture
@@ -57,3 +64,151 @@ def test_save_string_default_extension(save_string_node, mock_get_save_path):
    assert os.path.exists(saved_file_path)
    with open(saved_file_path, "r") as f:
        assert f.read() == test_string


@pytest.fixture
def mock_openai_client():
    with patch('comfy_extras.nodes.nodes_openai._Client') as mock_client:
        instance = mock_client.instance.return_value
        instance.chat.completions.create = Mock()
        instance.images.generate = Mock()
        yield instance


def test_transformers_loader():
    loader = TransformersLoader()
    model, = loader.execute("microsoft/Phi-3-mini-4k-instruct", "")
    assert isinstance(model, LanguageModel)
    assert model.repo_id == "microsoft/Phi-3-mini-4k-instruct"


def test_one_shot_instruct_tokenize(mocker):
    tokenize = OneShotInstructTokenize()
    mock_model = mocker.Mock()
    mock_model.tokenize.return_value = {"input_ids": torch.tensor([[1, 2, 3]])}

    tokens, = tokenize.execute(mock_model, "What comes after apple?", [], "phi-3")
    mock_model.tokenize.assert_called_once_with("What comes after apple?", [], mocker.ANY)
    assert "input_ids" in tokens


def test_transformers_generate(mocker):
    generate = TransformersGenerate()
    mock_model = mocker.Mock()
    mock_model.generate.return_value = "The letter B comes after A in the alphabet."

    tokens: ProcessorResult = {"inputs": torch.tensor([[1, 2, 3]])}
    result, = generate.execute(mock_model, tokens, 512, 0, 42)
    mock_model.generate.assert_called_once()
    assert isinstance(result, str)
    assert "letter B" in result


def test_preview_string():
    preview = PreviewString()
    result = preview.execute("Test output")
    assert result == {"ui": {"string": ["Test output"]}}


def test_openai_language_model_loader():
    if not "OPENAI_API_KEY" in os.environ:
        pytest.skip("must set OPENAI_API_KEY")
    loader = OpenAILanguageModelLoader()
    model, = loader.execute("gpt-3.5-turbo")
    assert isinstance(model, OpenAILanguageModelWrapper)
    assert model.model == "gpt-3.5-turbo"


def test_openai_language_model_wrapper_generate(mock_openai_client):
    wrapper = OpenAILanguageModelWrapper("gpt-3.5-turbo")
    mock_stream = [
        Mock(choices=[Mock(delta=Mock(content="This "))]),
        Mock(choices=[Mock(delta=Mock(content="is "))]),
        Mock(choices=[Mock(delta=Mock(content="a "))]),
        Mock(choices=[Mock(delta=Mock(content="test "))]),
        Mock(choices=[Mock(delta=Mock(content="response."))]),
    ]

    mock_openai_client.chat.completions.create.return_value = mock_stream

    tokens = {"inputs": ["What is the capital of France?"]}
    result = wrapper.generate(tokens, max_new_tokens=50)

    mock_openai_client.chat.completions.create.assert_called_once_with(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": [{"type": "text", "text": "What is the capital of France?"}]}],
        max_tokens=50,
        temperature=1.0,
        top_p=1.0,
        seed=0,
        stream=True
    )
    assert result == "This is a test response."


def test_openai_language_model_wrapper_generate_with_image(mock_openai_client):
    wrapper = OpenAILanguageModelWrapper("gpt-4-vision-preview")
    mock_stream = [
        Mock(choices=[Mock(delta=Mock(content="This "))]),
        Mock(choices=[Mock(delta=Mock(content="image "))]),
        Mock(choices=[Mock(delta=Mock(content="shows "))]),
        Mock(choices=[Mock(delta=Mock(content="a "))]),
        Mock(choices=[Mock(delta=Mock(content="landscape."))]),
    ]
    mock_openai_client.chat.completions.create.return_value = mock_stream

    image_tensor = torch.rand((1, 224, 224, 3))
    tokens: ProcessorResult = {
        "inputs": ["Describe this image:"],
        "images": image_tensor
    }
    result = wrapper.generate(tokens, max_new_tokens=50)

    mock_openai_client.chat.completions.create.assert_called_once()
    assert result == "This image shows a landscape."


def test_dalle_generate(mock_openai_client):
    dalle = DallEGenerate()
    mock_openai_client.images.generate.return_value = Mock(
        data=[Mock(url="http://example.com/image.jpg", revised_prompt="A beautiful sunset")]
    )
    test_image = Image.new('RGB', (10, 10), color='red')
    img_byte_arr = io.BytesIO()
    test_image.save(img_byte_arr, format='PNG')
    img_byte_arr = img_byte_arr.getvalue()

    with patch('requests.get') as mock_get:
        mock_get.return_value = Mock(content=img_byte_arr)
        image, revised_prompt = dalle.generate("dall-e-3", "Create a sunset image", "1024x1024", "standard")

    assert isinstance(image, torch.Tensor)
    assert image.shape == (1, 10, 10, 3)
    assert torch.allclose(image, torch.tensor([1.0, 0, 0]).view(1, 1, 1, 3).expand(1, 10, 10, 3))
    assert revised_prompt == "A beautiful sunset"
    mock_openai_client.images.generate.assert_called_once_with(
        model="dall-e-3",
        prompt="Create a sunset image",
        size="1024x1024",
        quality="standard",
        n=1,
    )


def test_integration_openai_loader_and_wrapper(mock_openai_client):
    loader = OpenAILanguageModelLoader()
    model, = loader.execute("gpt-4")

    mock_stream = [
        Mock(choices=[Mock(delta=Mock(content="Paris "))]),
        Mock(choices=[Mock(delta=Mock(content="is "))]),
        Mock(choices=[Mock(delta=Mock(content="the "))]),
        Mock(choices=[Mock(delta=Mock(content="capital "))]),
        Mock(choices=[Mock(delta=Mock(content="of France."))]),
    ]
    mock_openai_client.chat.completions.create.return_value = mock_stream

    tokens = {"inputs": ["What is the capital of France?"]}
    result = model.generate(tokens, max_new_tokens=50)

    assert result == "Paris is the capital of France."
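Assuming pytest and pytest-mock are available alongside the updated requirements, the mocked OpenAI and DALL-E tests above can be selected by keyword from Python; the -k expression is only an illustration:

import pytest

# The mock_openai_client fixture patches _Client, so no network access is needed for
# the keyword-selected tests; test_openai_language_model_loader skips itself when
# OPENAI_API_KEY is unset.
raise SystemExit(pytest.main(["-q", "-k", "openai or dalle"]))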