Improve language and compositing nodes

doctorpangloss 2024-09-05 21:56:04 -07:00
parent 7e1201e777
commit a4fb34a0b8
14 changed files with 767 additions and 303 deletions

View File

@ -206,6 +206,15 @@ def _create_parser() -> EnhancedConfigArgParser:
help="When running ComfyUI as a distributed worker, this specifies the kind of executor that should be used to run the actual ComfyUI workflow worker. A ThreadPoolExecutor is the default. A ProcessPoolExecutor results in better memory management, since the process will be closed and large, contiguous blocks of CUDA memory can be freed."
)
parser.add_argument(
"--openai-api-key",
required=False,
type=str,
help="Configures the OpenAI API Key for the OpenAI nodes",
env_var="OPENAI_API_KEY",
default=None
)
# now give plugins a chance to add configuration
for entry_point in entry_points().select(group='comfyui.custom_config'):
try:

View File

@ -111,6 +111,7 @@ class Configuration(dict):
force_hf_local_dir_mode (bool): Download repos from huggingface.co to the models/huggingface directory with the "local_dir" argument instead of models/huggingface_cache with the "cache_dir" argument, recreating the traditional file structure.
executor_factory (str): Either ThreadPoolExecutor or ProcessPoolExecutor, defaulting to ThreadPoolExecutor
preview_size (int): Sets the maximum preview size for sampler nodes. Defaults to 512.
openai_api_key (str): Configures the OpenAI API Key for the OpenAI nodes
"""
def __init__(self, **kwargs):
@ -198,6 +199,7 @@ class Configuration(dict):
self[key] = value
self.executor_factory: str = "ThreadPoolExecutor"
self.openai_api_key: Optional[str] = None
def __getattr__(self, item):
if item not in self:
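
A hedged sketch of how the new openai_api_key setting is expected to reach the OpenAI nodes at runtime; the launch mechanism is illustrative, and the nodes added later in this commit read the value back through comfy.cli_args.args.

# Illustrative only: the key can come from the --openai-api-key flag or the
# OPENAI_API_KEY environment variable; both populate Configuration.openai_api_key.
from comfy.cli_args import args

api_key = args.openai_api_key  # None when neither the flag nor the env var is set
if not api_key:
    raise RuntimeError("set OPENAI_API_KEY or pass --openai-api-key")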

View File

@ -2,6 +2,7 @@ from jaxtyping import Float
from torch import Tensor
ImageBatch = Float[Tensor, "batch height width channels"]
MaskBatch = Float[Tensor, "batch height width"]
RGBImageBatch = Float[Tensor, "batch height width 3"]
RGBAImageBatch = Float[Tensor, "batch height width 4"]
RGBImage = Float[Tensor, "height width 3"]
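
The jaxtyping aliases are annotations only, but they document the expected tensor layouts. A small, hypothetical helper (not part of this commit) showing how the batch aliases read in a signature:

# Hypothetical helper for illustration; shapes follow the aliases above.
import torch
from comfy.component_model.tensor_types import MaskBatch, RGBAImageBatch, RGBImageBatch

def join_alpha(image: RGBImageBatch, mask: MaskBatch) -> RGBAImageBatch:
    # image: (batch, height, width, 3), mask: (batch, height, width)
    alpha = (1.0 - mask).unsqueeze(-1)        # MASK convention: 0 is opaque
    return torch.cat((image, alpha), dim=-1)  # (batch, height, width, 4)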

View File

@ -24,13 +24,14 @@ class ProcessPoolExecutor(ProcessPool, Executor):
args: list = (),
kwargs: dict = {},
timeout: float = None) -> ProcessFuture:
try:
args: ExecutePromptArgs
prompt, prompt_id, client_id, span_context, progress_handler, configuration = args
except ValueError:
pass
super().schedule(function, args, kwargs, timeout)
# todo: restart worker when there is insufficient VRAM or the workflows are sufficiently different
# try:
# args: ExecutePromptArgs
# prompt, prompt_id, client_id, span_context, progress_handler, configuration = args
#
# except ValueError:
# pass
return super().schedule(function, args, kwargs, timeout)
def submit(self, fn, /, *args, **kwargs) -> concurrent.futures.Future:
return self.schedule(fn, args=list(args), kwargs=kwargs, timeout=None)
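
The schedule override now returns the pebble ProcessFuture it creates, so submit yields a usable future. A hedged usage sketch; the import path is assumed, since the file name is not shown in this hunk:

# Sketch only: the module path below is an assumption, not taken from the diff.
from comfy.component_model.executors import ProcessPoolExecutor

executor = ProcessPoolExecutor()
future = executor.submit(pow, 2, 10)   # runs in a worker process
print(future.result())                 # 1024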

View File

@ -1,9 +1,17 @@
from __future__ import annotations
from typing import Union, Callable, List, Optional, Protocol, runtime_checkable
import numpy as np
import torch
from transformers import BatchEncoding
from PIL.Image import Image
from transformers import BatchEncoding, BatchFeature, TensorType
from transformers.tokenization_utils_base import TextInput, PreTokenizedInput, TruncationStrategy
from transformers.utils import PaddingStrategy
from typing_extensions import TypedDict, NotRequired
from comfy.component_model.tensor_types import RGBImageBatch
class ProcessorResult(TypedDict):
"""
@ -18,7 +26,61 @@ class ProcessorResult(TypedDict):
attention_mask: NotRequired[torch.Tensor]
pixel_values: NotRequired[torch.Tensor]
images: NotRequired[torch.Tensor]
inputs: NotRequired[BatchEncoding]
images: NotRequired[RGBImageBatch]
inputs: NotRequired[BatchEncoding | list[str]]
image_sizes: NotRequired[torch.Tensor]
class GenerationKwargs(TypedDict):
top_k: NotRequired[int]
top_p: NotRequired[float]
temperature: NotRequired[float]
penalty_alpha: NotRequired[float]
num_beams: NotRequired[int]
early_stopping: NotRequired[bool]
GENERATION_KWARGS_TYPE = GenerationKwargs
GENERATION_KWARGS_TYPE_NAME = "SAMPLER"
TOKENS_TYPE = Union[ProcessorResult, BatchFeature]
TOKENS_TYPE_NAME = "TOKENS"
class TransformerStreamedProgress(TypedDict):
next_token: str
LLaVAProcessor = Callable[
[
Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]], # text parameter
Union[Image, np.ndarray, torch.Tensor, List[Image], List[np.ndarray], List[torch.Tensor]], # images parameter
Union[bool, str, PaddingStrategy], # padding parameter
Union[bool, str, TruncationStrategy], # truncation parameter
Optional[int], # max_length parameter
Optional[Union[str, TensorType]] # return_tensors parameter
],
BatchFeature
]
@runtime_checkable
class LanguageModel(Protocol):
@staticmethod
def from_pretrained(ckpt_name: str, subfolder: Optional[str] = None) -> "LanguageModel":
...
def generate(self, tokens: TOKENS_TYPE = None,
max_new_tokens: int = 512,
repetition_penalty: float = 0.0,
seed: int = 0,
sampler: Optional[GENERATION_KWARGS_TYPE] = None,
*args,
**kwargs) -> str:
...
def tokenize(self, prompt: str, images: List[torch.Tensor] | torch.Tensor, chat_template: str | None = None) -> ProcessorResult:
...
@property
def repo_id(self) -> str:
return ""

View File

@ -1,38 +1,32 @@
from __future__ import annotations
import copy
import inspect
import logging
import operator
import pathlib
import warnings
from typing import Optional, Any, Callable, Union, List
from functools import reduce
from typing import Optional, Any, Callable, List
import numpy as np
import torch
from PIL.Image import Image
from transformers import PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, AutoProcessor, AutoTokenizer, \
TensorType, BatchFeature
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput, TruncationStrategy
from transformers.utils import PaddingStrategy
BatchFeature, AutoModelForVision2Seq, AutoModelForSeq2SeqLM, AutoModelForCausalLM, AutoModel, \
PretrainedConfig, TextStreamer, LogitsProcessor
from transformers.models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES, \
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from .chat_templates import KNOWN_CHAT_TEMPLATES
from .language_types import ProcessorResult
from ..model_management import unet_offload_device, get_torch_device
from .language_types import ProcessorResult, TOKENS_TYPE, GENERATION_KWARGS_TYPE, TransformerStreamedProgress, \
LLaVAProcessor, LanguageModel
from .. import model_management
from ..model_downloader import get_or_download_huggingface_repo
from ..model_management import unet_offload_device, get_torch_device, unet_dtype, load_models_gpu
from ..model_management_types import ModelManageable
LLaVAProcessor = Callable[
[
Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]], # text parameter
Union[Image, np.ndarray, torch.Tensor, List[Image], List[np.ndarray], List[torch.Tensor]], # images parameter
Union[bool, str, PaddingStrategy], # padding parameter
Union[bool, str, TruncationStrategy], # truncation parameter
Optional[int], # max_length parameter
Optional[Union[str, TensorType]] # return_tensors parameter
],
BatchFeature
]
from ..utils import comfy_tqdm, ProgressBar, comfy_progress, seed_for_block
class TransformersManagedModel(ModelManageable):
class TransformersManagedModel(ModelManageable, LanguageModel):
def __init__(
self,
repo_id: str,
@ -41,7 +35,7 @@ class TransformersManagedModel(ModelManageable):
config_dict: Optional[dict] = None,
processor: Optional[ProcessorMixin | AutoProcessor] = None
):
self.repo_id = repo_id
self._repo_id = repo_id
self.model = model
self._tokenizer = tokenizer
self._processor = processor
@ -54,6 +48,200 @@ class TransformersManagedModel(ModelManageable):
if model.device != self.offload_device:
model.to(device=self.offload_device)
@staticmethod
def from_pretrained(ckpt_name: str, subfolder: Optional[str] = None) -> "TransformersManagedModel":
hub_kwargs = {}
if subfolder is not None and subfolder != "":
hub_kwargs["subfolder"] = subfolder
repo_id = ckpt_name
ckpt_name = get_or_download_huggingface_repo(ckpt_name)
with comfy_tqdm():
from_pretrained_kwargs = {
"pretrained_model_name_or_path": ckpt_name,
"trust_remote_code": True,
**hub_kwargs
}
# compute bitsandbytes configuration
try:
import bitsandbytes
except ImportError:
pass
config_dict, _ = PretrainedConfig.get_config_dict(ckpt_name, **hub_kwargs)
model_type = config_dict["model_type"]
# language models prefer to use bfloat16 over float16
kwargs_to_try = ({"torch_dtype": unet_dtype(supported_dtypes=(torch.bfloat16, torch.float16, torch.float32)),
"low_cpu_mem_usage": True,
"device_map": str(unet_offload_device()), }, {})
# if we have flash-attn installed, try to use it
try:
import flash_attn
attn_override_kwargs = {
"attn_implementation": "flash_attention_2",
**kwargs_to_try[0]
}
kwargs_to_try = (attn_override_kwargs, *kwargs_to_try)
logging.debug(f"while loading model {ckpt_name}, flash_attn was installed, so the flash_attention_2 implementation will be tried")
except ImportError:
pass
for i, props in enumerate(kwargs_to_try):
try:
if model_type in MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES:
model = AutoModelForVision2Seq.from_pretrained(**from_pretrained_kwargs, **props)
elif model_type in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
model = AutoModelForSeq2SeqLM.from_pretrained(**from_pretrained_kwargs, **props)
elif model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
model = AutoModelForCausalLM.from_pretrained(**from_pretrained_kwargs, **props)
else:
model = AutoModel.from_pretrained(**from_pretrained_kwargs, **props)
if model is not None:
break
except Exception as exc_info:
if i == len(kwargs_to_try) - 1:
raise exc_info
else:
logging.warning(f"tried to import transformers model {ckpt_name} but got exception when trying additional import args {props}", exc_info=exc_info)
finally:
torch.set_default_dtype(torch.float32)
for i, props in enumerate(kwargs_to_try):
try:
try:
processor = AutoProcessor.from_pretrained(**from_pretrained_kwargs, **props)
except:
processor = None
if isinstance(processor, PreTrainedTokenizerBase):
tokenizer = processor
processor = None
else:
tokenizer = getattr(processor, "tokenizer") if processor is not None and hasattr(processor, "tokenizer") else AutoTokenizer.from_pretrained(ckpt_name, **hub_kwargs, **props)
if tokenizer is not None or processor is not None:
break
except Exception as exc_info:
if i == len(kwargs_to_try) - 1:
raise exc_info
finally:
torch.set_default_dtype(torch.float32)
if model_management.xformers_enabled() and hasattr(model, "enable_xformers_memory_efficient_attention"):
model.enable_xformers_memory_efficient_attention()
logging.debug("enabled xformers memory efficient attention")
model_managed = TransformersManagedModel(
repo_id=repo_id,
model=model,
tokenizer=tokenizer,
config_dict=config_dict,
processor=processor
)
return model_managed
def generate(self, tokens: TOKENS_TYPE = None,
max_new_tokens: int = 512,
repetition_penalty: float = 0.0,
seed: int = 0,
sampler: Optional[GENERATION_KWARGS_TYPE] = None,
*args,
**kwargs) -> str:
tokens = copy.copy(tokens)
tokens_original = copy.copy(tokens)
sampler = sampler or {}
generate_kwargs = copy.copy(sampler)
load_models_gpu([self])
transformers_model: PreTrainedModel = self.model
tokenizer: PreTrainedTokenizerBase | AutoTokenizer = self.tokenizer
# remove unused inputs
# maximizes compatibility with different models
generate_signature = inspect.signature(transformers_model.generate).parameters
prepare_signature = inspect.signature(transformers_model.prepare_inputs_for_generation).parameters
to_delete = set(reduce(operator.sub, map(lambda x: x.keys(), [tokens, generate_signature, prepare_signature])))
gen_sig_keys = generate_signature.keys()
if "tgt_lang" in tokens:
to_delete.add("tgt_lang")
to_delete.add("src_lang")
to_delete.discard("input_ids")
if "forced_bos_token_id" in tokens:
to_delete.discard("forced_bos_token_id")
elif hasattr(tokenizer, "convert_tokens_to_ids"):
generate_kwargs["forced_bos_token_id"] = tokenizer.convert_tokens_to_ids(tokens["tgt_lang"])
else:
logging.warning(f"tokenizer {tokenizer} unexpected for translation task")
if "input_ids" in tokens and "inputs" in tokens:
if "input_ids" in gen_sig_keys:
to_delete.add("inputs")
elif "inputs" in gen_sig_keys:
to_delete.add("input_ids")
for unused_kwarg in to_delete:
tokens.pop(unused_kwarg)
logging.debug(f"{transformers_model.name_or_path}.generate does not accept {unused_kwarg}, removing")
# images should be moved to model
for key in ("images", "pixel_values"):
if key in tokens:
tokens[key] = tokens[key].to(device=self.current_device, dtype=self.model_dtype())
# sets up inputs
inputs = tokens
# used to determine if text streaming is supported
num_beams = generate_kwargs.get("num_beams", transformers_model.generation_config.num_beams)
progress_bar: ProgressBar
with comfy_progress(total=max_new_tokens) as progress_bar:
# todo: deal with batches correctly, don't assume batch size 1
token_count = 0
# progress
def on_finalized_text(next_token: str, stop: bool):
nonlocal token_count
nonlocal progress_bar
token_count += 1
preview = TransformerStreamedProgress(next_token=next_token)
progress_bar.update_absolute(token_count, total=max_new_tokens, preview_image_or_output=preview)
text_streamer = _ProgressTextStreamer(on_finalized_text, tokenizer, True)
with seed_for_block(seed):
if hasattr(inputs, "encodings") and inputs.encodings is not None and all(hasattr(encoding, "attention_mask") for encoding in inputs.encodings) and "attention_mask" in inputs:
inputs.pop("attention_mask")
output_ids = transformers_model.generate(
**inputs,
streamer=text_streamer if num_beams <= 1 else None,
max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty if repetition_penalty != 0 else None,
**generate_kwargs
)
if not transformers_model.config.is_encoder_decoder:
start_position = inputs["input_ids" if "input_ids" in inputs else "inputs"].shape[1]
output_ids = output_ids[:, start_position:]
if hasattr(tokenizer, "src_lang") and "src_lang" in tokens_original:
prev_src_lang = tokenizer.src_lang
tokenizer.src_lang = tokens_original["src_lang"]
else:
prev_src_lang = None
# todo: is this redundant considering I'm decoding in the on_finalized_text block?
try:
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
finally:
if prev_src_lang is not None:
tokenizer.src_lang = prev_src_lang
# gpu-loaded stuff like images can now be unloaded
if hasattr(tokens, "to"):
del tokens
else:
for to_delete in tokens.values():
del to_delete
del tokens
# todo: better support batches
return outputs[0]
@property
def tokenizer(self) -> PreTrainedTokenizerBase | AutoTokenizer:
return self._tokenizer
@ -178,9 +366,32 @@ class TransformersManagedModel(ModelManageable):
**batch_feature
}
@property
def repo_id(self) -> str:
return self._repo_id
def __str__(self):
if self.repo_id is not None:
repo_id_as_path = pathlib.PurePath(self.repo_id)
return f"<TransformersManagedModel for {'/'.join(repo_id_as_path.parts[-2:])} ({self.model.__class__.__name__})>"
else:
return f"<TransformersManagedModel for {self.model.__class__.__name__}>"
class _ProgressTextStreamer(TextStreamer):
def __init__(self, on_finalized_text: Callable[[str, bool], None], tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
super().__init__(tokenizer, skip_prompt, **decode_kwargs)
self.on_finalized_text_handler = on_finalized_text
def on_finalized_text(self, text: str, stream_end: bool = False):
self.on_finalized_text_handler(text, stream_end)
class _ProgressLogitsProcessor(LogitsProcessor):
def __init__(self, model: TransformersManagedModel):
self.eos_token_id = model.tokenizer.eos_token_id
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
probabilities = scores.softmax(dim=-1)
self.eos_probability = probabilities[:, self.eos_token_id].item()
return scores

View File

@ -385,6 +385,10 @@ KNOWN_CONTROLNETS: Final[KnownDownloadables] = KnownDownloadables([
HuggingFile("InstantX/FLUX.1-dev-Controlnet-Canny", "diffusion_pytorch_model.safetensors", save_with_filename="instantx-flux.1-dev-controlnet-canny.safetensors"),
HuggingFile("InstantX/FLUX.1-dev-Controlnet-Union", "diffusion_pytorch_model.safetensors", save_with_filename="instantx-flux.1-dev-controlnet-union.safetensors"),
HuggingFile("Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro", "diffusion_pytorch_model.safetensors", save_with_filename="shakker-labs-flux.1-dev-controlnet-union-pro.safetensors"),
HuggingFile("TheMistoAI/MistoLine_Flux.dev", "mistoline_flux.dev_v1.safetensors"),
HuggingFile("XLabs-AI/flux-controlnet-collections", "flux-canny-controlnet-v3.safetensors"),
HuggingFile("XLabs-AI/flux-controlnet-collections", "flux-depth-controlnet-v3.safetensors"),
HuggingFile("XLabs-AI/flux-controlnet-collections", "flux-hed-controlnet-v3.safetensors"),
], folder_name="controlnet")
KNOWN_DIFF_CONTROLNETS: Final[KnownDownloadables] = KnownDownloadables([
@ -418,6 +422,7 @@ KNOWN_HUGGINGFACE_MODEL_REPOS: Final[Set[str]] = {
'llava-hf/llava-v1.6-mistral-7b-hf',
'facebook/nllb-200-distilled-1.3B',
'THUDM/chatglm3-6b',
'roborovski/superprompt-v1',
}
KNOWN_UNET_MODELS: Final[KnownDownloadables] = KnownDownloadables([

View File

@ -24,7 +24,7 @@ from .. import model_management
from ..cli_args import args
from ..cmd import folder_paths, latent_preview
from ..component_model.tensor_types import RGBImage
from ..component_model.tensor_types import RGBImage, RGBImageBatch, MaskBatch
from ..execution_context import current_execution_context
from ..images import open_image
from ..ldm.flux.weight_dtypes import FLUX_WEIGHT_DTYPES
@ -808,7 +808,7 @@ class ControlNetApply:
CATEGORY = "conditioning/controlnet"
def apply_controlnet(self, conditioning, control_net, image, strength):
def apply_controlnet(self, conditioning, control_net, image: RGBImageBatch, strength):
if strength == 0:
return (conditioning, )
@ -1573,7 +1573,7 @@ class LoadImage:
RETURN_TYPES = ("IMAGE", "MASK")
FUNCTION = "load_image"
def load_image(self, image: str):
def load_image(self, image: str) -> tuple[RGBImageBatch, MaskBatch]:
image_path = folder_paths.get_annotated_filepath(image)
img = node_helpers.pillow(Image.open, image_path)
@ -1703,7 +1703,7 @@ class ImageScale:
CATEGORY = "image/upscaling"
def upscale(self, image, upscale_method, width, height, crop):
def upscale(self, image: RGBImageBatch, upscale_method, width, height, crop) -> tuple[RGBImageBatch]:
if width == 0 and height == 0:
s = image
else:

View File

@ -5,7 +5,7 @@ import torch
from skimage import exposure
import comfy.utils
from comfy.component_model.tensor_types import RGBImageBatch, ImageBatch
from comfy.component_model.tensor_types import RGBImageBatch, ImageBatch, MaskBatch
from comfy.nodes.package_typing import CustomNode
@ -34,10 +34,7 @@ class PorterDuffMode(Enum):
XOR = 17
def porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_image: torch.Tensor, dst_alpha: torch.Tensor, mode: PorterDuffMode):
# convert mask to alpha
src_alpha = 1 - src_alpha
dst_alpha = 1 - dst_alpha
def _porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_image: torch.Tensor, dst_alpha: torch.Tensor, mode: PorterDuffMode):
# premultiply alpha
src_image = src_image * src_alpha
dst_image = dst_image * dst_alpha
@ -109,24 +106,31 @@ def porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_
return out_image, out_alpha
class PorterDuffImageComposite:
class PorterDuffImageCompositeV2:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"source": ("IMAGE",),
"source_alpha": ("MASK",),
"destination": ("IMAGE",),
"destination_alpha": ("MASK",),
"mode": ([mode.name for mode in PorterDuffMode], {"default": PorterDuffMode.DST.name}),
},
"optional": {
"source_alpha": ("MASK",),
"destination_alpha": ("MASK",),
}
}
RETURN_TYPES = ("IMAGE", "MASK")
FUNCTION = "composite"
CATEGORY = "mask/compositing"
def composite(self, source: torch.Tensor, source_alpha: torch.Tensor, destination: torch.Tensor, destination_alpha: torch.Tensor, mode):
def composite(self, source: RGBImageBatch, destination: RGBImageBatch, mode, source_alpha: MaskBatch = None, destination_alpha: MaskBatch = None) -> tuple[RGBImageBatch, MaskBatch]:
if source_alpha is None:
source_alpha = torch.zeros(source.shape[:3])
if destination_alpha is None:
destination_alpha = torch.zeros(destination.shape[:3])
batch_size = min(len(source), len(source_alpha), len(destination), len(destination_alpha))
out_images = []
out_alphas = []
@ -153,7 +157,7 @@ class PorterDuffImageComposite:
upscale_output = comfy.utils.common_upscale(upscale_input, dst_alpha.shape[1], dst_alpha.shape[0], upscale_method='bicubic', crop='center')
src_alpha = upscale_output.permute(0, 2, 3, 1).squeeze(0)
out_image, out_alpha = porter_duff_composite(src_image, src_alpha, dst_image, dst_alpha, PorterDuffMode[mode])
out_image, out_alpha = _porter_duff_composite(src_image, src_alpha, dst_image, dst_alpha, PorterDuffMode[mode])
out_images.append(out_image)
out_alphas.append(out_alpha.squeeze(2))
@ -162,6 +166,28 @@ class PorterDuffImageComposite:
return result
class PorterDuffImageCompositeV1(PorterDuffImageCompositeV2):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"source": ("IMAGE",),
"source_alpha": ("MASK",),
"destination": ("IMAGE",),
"destination_alpha": ("MASK",),
"mode": ([mode.name for mode in PorterDuffMode], {"default": PorterDuffMode.DST.name}),
},
}
FUNCTION = "composite_v1"
def composite_v1(self, source: torch.Tensor, source_alpha: torch.Tensor, destination: torch.Tensor, destination_alpha: torch.Tensor, mode) -> tuple[RGBImageBatch, MaskBatch]:
# convert mask to alpha
source_alpha = 1 - source_alpha
destination_alpha = 1 - destination_alpha
return super().composite(source, destination, mode, source_alpha, destination_alpha)
class SplitImageWithAlpha:
@classmethod
def INPUT_TYPES(s):
@ -312,7 +338,8 @@ class Posterize(CustomNode):
NODE_CLASS_MAPPINGS = {
"PorterDuffImageComposite": PorterDuffImageComposite,
"PorterDuffImageComposite": PorterDuffImageCompositeV1,
"PorterDuffImageCompositeV2": PorterDuffImageCompositeV2,
"SplitImageWithAlpha": SplitImageWithAlpha,
"JoinImageWithAlpha": JoinImageWithAlpha,
"EnhanceContrast": EnhanceContrast,
@ -321,7 +348,8 @@ NODE_CLASS_MAPPINGS = {
}
NODE_DISPLAY_NAME_MAPPINGS = {
"PorterDuffImageComposite": "Porter-Duff Image Composite",
"PorterDuffImageComposite": "Porter-Duff Image Composite (V1)",
"PorterDuffImageCompositeV2": "Image Composite",
"SplitImageWithAlpha": "Split Image with Alpha",
"JoinImageWithAlpha": "Join Image with Alpha",
}
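
The registration keeps the original node id pointing at the legacy behavior, while the V2 node takes its masks as optional inputs and no longer inverts them. A hedged sketch of the difference; the module path is assumed and the tensors are illustrative.

# Sketch only: module path assumed; values are random placeholders.
import torch
from comfy_extras.nodes.nodes_compositing import PorterDuffImageCompositeV1, PorterDuffImageCompositeV2

src = torch.rand(1, 64, 64, 3)
dst = torch.rand(1, 64, 64, 3)
mask = torch.zeros(1, 64, 64)

# V1 keeps the historical behavior: masks are required and inverted into alpha (0 is opaque).
image_v1, alpha_v1 = PorterDuffImageCompositeV1().composite_v1(src, mask, dst, mask, "DST")
# V2 treats the optional inputs as alpha directly and defaults missing ones to zeros.
image_v2, alpha_v2 = PorterDuffImageCompositeV2().composite(src, dst, "DST")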

View File

@ -1,80 +1,32 @@
from __future__ import annotations
import copy
import inspect
import logging
import operator
import os.path
from functools import reduce
from typing import Any, Dict, Optional, List, Callable, Union
from typing import Optional, List
import torch
from transformers import AutoTokenizer, PreTrainedModel, LogitsProcessor, TextStreamer, \
PreTrainedTokenizerBase, PretrainedConfig, AutoProcessor, BatchFeature, AutoModel, AutoModelForCausalLM, \
AutoModelForSeq2SeqLM
from transformers.models.auto.modeling_auto import MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, \
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES, AutoModelForVision2Seq
from transformers import AutoProcessor
from transformers.models.m2m_100.tokenization_m2m_100 import \
FAIRSEQ_LANGUAGE_CODES as tokenization_m2m_100_FAIRSEQ_LANGUAGE_CODES
from transformers.models.nllb.tokenization_nllb import \
FAIRSEQ_LANGUAGE_CODES as tokenization_nllb_FAIRSEQ_LANGUAGE_CODES
from typing_extensions import TypedDict
from comfy import model_management
from comfy.cmd import folder_paths
from comfy.component_model.folder_path_types import SaveImagePathResponse
from comfy.language.chat_templates import KNOWN_CHAT_TEMPLATES
from comfy.language.language_types import ProcessorResult
from comfy.language.language_types import GENERATION_KWARGS_TYPE, GENERATION_KWARGS_TYPE_NAME, TOKENS_TYPE, \
TOKENS_TYPE_NAME, LanguageModel
from comfy.language.transformers_model_management import TransformersManagedModel
from comfy.model_downloader import get_huggingface_repo_list, get_or_download_huggingface_repo
from comfy.model_management import get_torch_device_name, unet_dtype, unet_offload_device, load_models_gpu
from comfy.model_management import get_torch_device_name, unet_dtype, unet_offload_device
from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult
from comfy.utils import comfy_tqdm, seed_for_block, comfy_progress, ProgressBar
_AUTO_CHAT_TEMPLATE = "default"
# add llava support
try:
from llava import model as _llava_model_side_effects
logging.debug("Additional LLaVA models are now supported")
except ImportError as exc:
logging.debug(f"Install LLaVA with `pip install git+https://github.com/AppMana/appmana-comfyui-llava` for additional LLaVA support")
# aka kwargs type
_GENERATION_KWARGS_TYPE = Dict[str, Any]
_GENERATION_KWARGS_TYPE_NAME = "SAMPLER"
_TOKENS_TYPE = Union[ProcessorResult, BatchFeature]
TOKENS_TYPE_NAME = "TOKENS"
class _ProgressTextStreamer(TextStreamer):
def __init__(self, on_finalized_text: Callable[[str, bool], None], tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
super().__init__(tokenizer, skip_prompt, **decode_kwargs)
self.on_finalized_text_handler = on_finalized_text
def on_finalized_text(self, text: str, stream_end: bool = False):
self.on_finalized_text_handler(text, stream_end)
class _ProgressLogitsProcessor(LogitsProcessor):
def __init__(self, model: TransformersManagedModel):
self.eos_token_id = model.tokenizer.eos_token_id
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
probabilities = scores.softmax(dim=-1)
self.eos_probability = probabilities[:, self.eos_token_id].item()
return scores
# todo: for per token progress, should this really look like {"ui": {"string": [value]}} ?
class TransformerStreamedProgress(TypedDict):
next_token: str
class TransformerSamplerBase(CustomNode):
RETURN_TYPES = _GENERATION_KWARGS_TYPE_NAME,
RETURN_TYPES = GENERATION_KWARGS_TYPE_NAME,
RETURN_NAMES = "GENERATION ARGS",
FUNCTION = "execute"
CATEGORY = "language/samplers"
@ -142,7 +94,7 @@ class TransformersGenerationConfig(CustomNode):
}
}
RETURN_TYPES = _GENERATION_KWARGS_TYPE_NAME,
RETURN_TYPES = GENERATION_KWARGS_TYPE_NAME,
RETURN_NAMES = "GENERATION ARGS",
FUNCTION = "execute"
CATEGORY = "language"
@ -182,15 +134,15 @@ class TransformerBeamSearchSampler(TransformerSamplerBase):
class TransformerMergeSamplers(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
range_ = {"value0": (_GENERATION_KWARGS_TYPE_NAME, {"forceInput": True})}
range_.update({f"value{i}": (_GENERATION_KWARGS_TYPE_NAME, {"forceInput": True}) for i in range(1, 5)})
range_ = {"value0": (GENERATION_KWARGS_TYPE_NAME, {"forceInput": True})}
range_.update({f"value{i}": (GENERATION_KWARGS_TYPE_NAME, {"forceInput": True}) for i in range(1, 5)})
return {
"required": range_
}
CATEGORY = "language"
RETURN_TYPES = _GENERATION_KWARGS_TYPE_NAME,
RETURN_TYPES = GENERATION_KWARGS_TYPE_NAME,
FUNCTION = "execute"
def execute(self, **kwargs):
@ -238,98 +190,11 @@ class TransformersLoader(CustomNode):
CATEGORY = "language"
RETURN_TYPES = "MODEL",
RETURN_NAMES = "language model",
FUNCTION = "execute"
def execute(self, ckpt_name: str, subfolder: Optional[str] = None, *args, **kwargs):
hub_kwargs = {}
if subfolder is not None and subfolder != "":
hub_kwargs["subfolder"] = subfolder
ckpt_name = get_or_download_huggingface_repo(ckpt_name)
with comfy_tqdm():
from_pretrained_kwargs = {
"pretrained_model_name_or_path": ckpt_name,
"trust_remote_code": True,
**hub_kwargs
}
# if flash attention exists, use it
# compute bitsandbytes configuration
try:
import bitsandbytes
except ImportError:
pass
config_dict, _ = PretrainedConfig.get_config_dict(ckpt_name, **hub_kwargs)
model_type = config_dict["model_type"]
# language models prefer to use bfloat16 over float16
kwargs_to_try = ({"torch_dtype": unet_dtype(supported_dtypes=(torch.bfloat16, torch.float16, torch.float32)),
"low_cpu_mem_usage": True,
"device_map": str(unet_offload_device()), }, {})
# if we have flash-attn installed, try to use it
try:
import flash_attn
attn_override_kwargs = {
"attn_implementation": "flash_attention_2",
**kwargs_to_try[0]
}
kwargs_to_try = (attn_override_kwargs, *kwargs_to_try)
logging.debug(f"while loading model {ckpt_name}, flash_attn was installed, so the flash_attention_2 implementation will be tried")
except ImportError:
pass
for i, props in enumerate(kwargs_to_try):
try:
if model_type in MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES:
model = AutoModelForVision2Seq.from_pretrained(**from_pretrained_kwargs, **props)
elif model_type in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
model = AutoModelForSeq2SeqLM.from_pretrained(**from_pretrained_kwargs, **props)
elif model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
model = AutoModelForCausalLM.from_pretrained(**from_pretrained_kwargs, **props)
else:
model = AutoModel.from_pretrained(**from_pretrained_kwargs, **props)
if model is not None:
break
except Exception as exc_info:
if i == len(kwargs_to_try) - 1:
raise exc_info
else:
logging.warning(f"tried to import transformers model {ckpt_name} but got exception when trying additional import args {props}", exc_info=exc_info)
finally:
torch.set_default_dtype(torch.float32)
for i, props in enumerate(kwargs_to_try):
try:
try:
processor = AutoProcessor.from_pretrained(**from_pretrained_kwargs, **props)
except:
processor = None
if isinstance(processor, PreTrainedTokenizerBase):
tokenizer = processor
processor = None
else:
tokenizer = getattr(processor, "tokenizer") if processor is not None and hasattr(processor, "tokenizer") else AutoTokenizer.from_pretrained(ckpt_name, **hub_kwargs, **props)
if tokenizer is not None or processor is not None:
break
except Exception as exc_info:
if i == len(kwargs_to_try) - 1:
raise exc_info
finally:
torch.set_default_dtype(torch.float32)
if model_management.xformers_enabled() and hasattr(model, "enable_xformers_memory_efficient_attention"):
model.enable_xformers_memory_efficient_attention()
logging.debug("enabled xformers memory efficient attention")
model_managed = TransformersManagedModel(
repo_id=ckpt_name,
model=model,
tokenizer=tokenizer,
config_dict=config_dict,
processor=processor
)
return model_managed,
def execute(self, ckpt_name: str, subfolder: Optional[str] = None, *args, **kwargs) -> tuple[TransformersManagedModel]:
return TransformersManagedModel.from_pretrained(ckpt_name, subfolder),
class TransformersTokenize(CustomNode):
@ -346,7 +211,7 @@ class TransformersTokenize(CustomNode):
RETURN_TYPES = (TOKENS_TYPE_NAME,)
FUNCTION = "execute"
def execute(self, model: TransformersManagedModel, prompt: str) -> ValidatedNodeResult:
def execute(self, model: LanguageModel, prompt: str) -> ValidatedNodeResult:
return model.tokenize(prompt, [], None),
@ -452,7 +317,7 @@ class OneShotInstructTokenize(CustomNode):
RETURN_TYPES = (TOKENS_TYPE_NAME,)
FUNCTION = "execute"
def execute(self, model: TransformersManagedModel, prompt: str, images: List[torch.Tensor] | torch.Tensor = None, chat_template: str = "__auto__") -> ValidatedNodeResult:
def execute(self, model: LanguageModel, prompt: str, images: List[torch.Tensor] | torch.Tensor = None, chat_template: str = "__auto__") -> ValidatedNodeResult:
if chat_template == _AUTO_CHAT_TEMPLATE:
# use an exact match
model_name = os.path.basename(model.repo_id)
@ -475,10 +340,9 @@ class TransformersGenerate(CustomNode):
"max_new_tokens": ("INT", {"default": 512, "min": 1}),
"repetition_penalty": ("FLOAT", {"default": 0.0, "min": 0}),
"seed": ("INT", {"default": 0, "min": 0, "max": 2 ** 32 - 1}),
"use_cache": ("BOOLEAN", {"default": True}),
},
"optional": {
"sampler": (_GENERATION_KWARGS_TYPE_NAME, {}),
"sampler": (GENERATION_KWARGS_TYPE_NAME, {}),
}
}
@ -487,110 +351,14 @@ class TransformersGenerate(CustomNode):
FUNCTION = "execute"
def execute(self,
model: Optional[TransformersManagedModel] = None,
tokens: _TOKENS_TYPE = None,
model: Optional[LanguageModel] = None,
tokens: TOKENS_TYPE = None,
max_new_tokens: int = 512,
repetition_penalty: float = 0.0,
seed: int = 0,
sampler: Optional[_GENERATION_KWARGS_TYPE] = None,
*args,
**kwargs
sampler: Optional[GENERATION_KWARGS_TYPE] = None,
):
tokens = copy.copy(tokens)
tokens_original = copy.copy(tokens)
sampler = sampler or {}
generate_kwargs = copy.copy(sampler)
load_models_gpu([model])
transformers_model: PreTrainedModel = model.model
tokenizer: PreTrainedTokenizerBase | AutoTokenizer = model.tokenizer
# remove unused inputs
# maximizes compatibility with different models
generate_signature = inspect.signature(transformers_model.generate).parameters
prepare_signature = inspect.signature(transformers_model.prepare_inputs_for_generation).parameters
to_delete = set(reduce(operator.sub, map(lambda x: x.keys(), [tokens, generate_signature, prepare_signature])))
gen_sig_keys = generate_signature.keys()
if "tgt_lang" in tokens:
to_delete.add("tgt_lang")
to_delete.add("src_lang")
to_delete.discard("input_ids")
if "forced_bos_token_id" in tokens:
to_delete.discard("forced_bos_token_id")
elif hasattr(tokenizer, "convert_tokens_to_ids"):
generate_kwargs["forced_bos_token_id"] = tokenizer.convert_tokens_to_ids(tokens["tgt_lang"])
else:
logging.warning(f"tokenizer {tokenizer} unexpected for translation task")
if "input_ids" in tokens and "inputs" in tokens:
if "input_ids" in gen_sig_keys:
to_delete.add("inputs")
elif "inputs" in gen_sig_keys:
to_delete.add("input_ids")
for unused_kwarg in to_delete:
tokens.pop(unused_kwarg)
logging.debug(f"{transformers_model.name_or_path}.generate does not accept {unused_kwarg}, removing")
# images should be moved to model
for key in ("images", "pixel_values"):
if key in tokens:
tokens[key] = tokens[key].to(device=model.current_device, dtype=model.model_dtype())
# sets up inputs
inputs = tokens
# used to determine if text streaming is supported
num_beams = generate_kwargs.get("num_beams", transformers_model.generation_config.num_beams)
progress_bar: ProgressBar
with comfy_progress(total=max_new_tokens) as progress_bar:
# todo: deal with batches correctly, don't assume batch size 1
token_count = 0
# progress
def on_finalized_text(next_token: str, stop: bool):
nonlocal token_count
nonlocal progress_bar
token_count += 1
preview = TransformerStreamedProgress(next_token=next_token)
progress_bar.update_absolute(token_count, total=max_new_tokens, preview_image_or_output=preview)
text_streamer = _ProgressTextStreamer(on_finalized_text, tokenizer, True)
with seed_for_block(seed):
if hasattr(inputs, "encodings") and inputs.encodings is not None and all(hasattr(encoding, "attention_mask") for encoding in inputs.encodings) and "attention_mask" in inputs:
inputs.pop("attention_mask")
output_ids = transformers_model.generate(
**inputs,
streamer=text_streamer if num_beams <= 1 else None,
max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty if repetition_penalty != 0 else None,
**generate_kwargs
)
if not transformers_model.config.is_encoder_decoder:
start_position = inputs["input_ids" if "input_ids" in inputs else "inputs"].shape[1]
output_ids = output_ids[:, start_position:]
if hasattr(tokenizer, "src_lang") and "src_lang" in tokens_original:
prev_src_lang = tokenizer.src_lang
tokenizer.src_lang = tokens_original["src_lang"]
else:
prev_src_lang = None
# todo: is this redundant considering I'm decoding in the on_finalized_text block?
try:
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
finally:
if prev_src_lang is not None:
tokenizer.src_lang = prev_src_lang
# gpu-loaded stuff like images can now be unloaded
if hasattr(tokens, "to"):
del tokens
else:
for to_delete in tokens.values():
del to_delete
del tokens
# todo: better support batches
return outputs[0],
return model.generate(tokens, max_new_tokens, repetition_penalty, seed, sampler),
class PreviewString(CustomNode):
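
With the loader reduced to a thin wrapper over TransformersManagedModel.from_pretrained and generation delegated to the model itself, the node-level flow becomes a short pipeline. A hedged sketch mirroring the new integration test; the checkpoint, prompt, and chat template are illustrative.

# Sketch only: node flow after the refactor.
from comfy_extras.nodes.nodes_language import (
    OneShotInstructTokenize, PreviewString, TransformersGenerate, TransformersLoader,
)

model, = TransformersLoader().execute("microsoft/Phi-3-mini-4k-instruct", "")
tokens, = OneShotInstructTokenize().execute(model, "What comes after apple?", [], "phi-3")
text, = TransformersGenerate().execute(model=model, tokens=tokens, max_new_tokens=64, seed=42)
print(PreviewString().execute(text))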

View File

@ -0,0 +1,206 @@
import base64
import io
import os
from io import BytesIO
from typing import Literal, Optional
import numpy as np
import requests
import torch
from PIL import Image
from openai import OpenAI
from openai.types.chat import ChatCompletionMessageParam
from comfy.cli_args import args
from comfy.component_model.tensor_types import RGBImageBatch
from comfy.language.language_types import LanguageModel, ProcessorResult, GENERATION_KWARGS_TYPE, TOKENS_TYPE, \
TransformerStreamedProgress
from comfy.nodes.package_typing import CustomNode, InputTypes
from comfy.utils import comfy_progress, ProgressBar, seed_for_block
class _Client:
_client: Optional[OpenAI] = None
@staticmethod
def instance() -> OpenAI:
if _Client._client is None:
open_ai_api_key = args.openai_api_key
_Client._client = OpenAI(
api_key=open_ai_api_key,
)
return _Client._client
def validate_has_key():
open_api_key = os.environ.get("OPENAI_API_KEY", args.openai_api_key)
if open_api_key is None or open_api_key == "":
return "set OPENAI_API_KEY environment variable"
return True
def image_to_base64(image: RGBImageBatch) -> str:
pil_image = Image.fromarray(np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8))
buffered = io.BytesIO()
pil_image.save(buffered, format="JPEG")
return base64.b64encode(buffered.getvalue()).decode("utf-8")
class OpenAILanguageModelWrapper(LanguageModel):
def __init__(self, model: str):
self.model = model
self.client = _Client.instance()
@staticmethod
def from_pretrained(ckpt_name: str, subfolder: Optional[str] = None) -> "OpenAILanguageModelWrapper":
return OpenAILanguageModelWrapper(ckpt_name)
def generate(self, tokens: TOKENS_TYPE = None,
max_new_tokens: int = 512,
repetition_penalty: float = 0.0,
seed: int = 0,
sampler: Optional[GENERATION_KWARGS_TYPE] = None,
*args,
**kwargs) -> str:
sampler = sampler or {}
prompt = tokens.get("inputs", [])
prompt = "".join(prompt)
images = tokens.get("images", [])
messages: list[ChatCompletionMessageParam] = [
{
"role": "user",
"content": [
{"type": "text", "text": prompt},
] + [
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{image_to_base64(image)}"
}
} for image in images
]
}
]
progress_bar: ProgressBar
with comfy_progress(total=max_new_tokens) as progress_bar:
token_count = 0
full_response = ""
def on_finalized_text(next_token: str, stop: bool):
nonlocal token_count
nonlocal progress_bar
nonlocal full_response
token_count += 1
full_response += next_token
preview = TransformerStreamedProgress(next_token=next_token)
progress_bar.update_absolute(max_new_tokens if stop else token_count, total=max_new_tokens, preview_image_or_output=preview)
with seed_for_block(seed):
stream = self.client.chat.completions.create(
model=self.model,
messages=messages,
max_tokens=max_new_tokens,
temperature=sampler.get("temperature", 1.0),
top_p=sampler.get("top_p", 1.0),
# n=1,
# stop=None,
# presence_penalty=repetition_penalty,
seed=seed,
stream=True
)
for chunk in stream:
if chunk.choices[0].delta.content is not None:
on_finalized_text(chunk.choices[0].delta.content, False)
on_finalized_text("", True) # Signal the end of streaming
return full_response
def tokenize(self, prompt: str, images: RGBImageBatch, chat_template: str | None = None) -> ProcessorResult:
# OpenAI API doesn't require explicit tokenization, so we'll just return the prompt and images as is
return {
"inputs": [prompt],
"attention_mask": torch.ones(1, len(prompt)), # Dummy attention mask
"images": images
}
@property
def repo_id(self) -> str:
return f"openai/{self.model}"
class OpenAILanguageModelLoader(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"model": (["gpt-4o", "gpt-4o-mini", "gpt-4-turbo", "gpt-4", "gpt-3.5-turbo"], {"default": "gpt-3.5-turbo"})
}
}
RETURN_TYPES = ("MODEL",)
RETURN_NAMES = ("language model",)
FUNCTION = "execute"
CATEGORY = "openai"
def execute(self, model: str) -> tuple[LanguageModel]:
return OpenAILanguageModelWrapper(model),
@classmethod
def VALIDATE_INPUTS(cls):
return validate_has_key()
class DallEGenerate(CustomNode):
@classmethod
def INPUT_TYPES(cls):
return {"required": {
"model": (["dall-e-2", "dall-e-3"], {"default": "dall-e-3"}),
"text": ("STRING", {"multiline": True}),
"size": (["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"], {"default": "1024x1024"}),
"quality": (["standard", "hd"], {"default": "standard"}),
}}
RETURN_TYPES = ("IMAGE", "STRING",)
RETURN_NAMES = ("images", "revised prompt")
FUNCTION = "generate"
CATEGORY = "openai"
def generate(self,
model: Literal["dall-e-2", "dall-e-3"],
text: str,
size: Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"],
quality: Literal["standard", "hd"]) -> tuple[RGBImageBatch, str]:
response = _Client.instance().images.generate(
model=model,
prompt=text,
size=size,
quality=quality,
n=1,
)
image_url = response.data[0].url
image_response = requests.get(image_url)
img = Image.open(BytesIO(image_response.content))
image = np.array(img).astype(np.float32) / 255.0
image = torch.from_numpy(image)[None,]
return image, response.data[0].revised_prompt
@classmethod
def VALIDATE_INPUTS(cls):
return validate_has_key()
NODE_CLASS_MAPPINGS = {
"DallEGenerate": DallEGenerate,
"OpenAILanguageModelLoader": OpenAILanguageModelLoader
}
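
Because the wrapper implements the same LanguageModel protocol as the transformers-backed models, it plugs into the existing language nodes unchanged. A hedged usage sketch; the model names and prompts are illustrative, and a valid OPENAI_API_KEY must be configured.

# Sketch only: requires a valid OPENAI_API_KEY; responses will vary.
from comfy_extras.nodes.nodes_language import TransformersGenerate, TransformersTokenize
from comfy_extras.nodes.nodes_openai import DallEGenerate, OpenAILanguageModelLoader

model, = OpenAILanguageModelLoader().execute("gpt-4o-mini")
tokens, = TransformersTokenize().execute(model, "Describe Porter-Duff compositing in one sentence.")
text, = TransformersGenerate().execute(model=model, tokens=tokens, max_new_tokens=64)

images, revised_prompt = DallEGenerate().generate(
    model="dall-e-3",
    text="A watercolor lighthouse at dusk",
    size="1024x1024",
    quality="standard",
)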

View File

@ -65,4 +65,6 @@ ml_dtypes
diffusers>=0.30.1
vtracer
skia-python
pebble>=5.0.7
pebble>=5.0.7
openai
anthropic

View File

@ -0,0 +1,14 @@
import torch
from comfy_extras.nodes.nodes_language import TransformersLoader, OneShotInstructTokenize
def test_integration_transformers_loader_and_tokenize():
loader = TransformersLoader()
tokenize = OneShotInstructTokenize()
model, = loader.execute("llava-hf/llava-v1.6-mistral-7b-hf", "")
tokens, = tokenize.execute(model, "Describe this image:", torch.rand((1, 224, 224, 3)), "llava-v1.6-mistral-7b-hf", )
assert isinstance(tokens, dict)
assert "input_ids" in tokens or "inputs" in tokens

View File

@ -1,10 +1,17 @@
import io
import os
import tempfile
from unittest.mock import patch
from unittest.mock import patch, Mock
import pytest
import torch
from PIL import Image
from comfy.language.language_types import LanguageModel, ProcessorResult
from comfy_extras.nodes.nodes_language import SaveString
from comfy_extras.nodes.nodes_language import TransformersLoader, OneShotInstructTokenize, TransformersGenerate, \
PreviewString
from comfy_extras.nodes.nodes_openai import OpenAILanguageModelLoader, OpenAILanguageModelWrapper, DallEGenerate
@pytest.fixture
@ -57,3 +64,151 @@ def test_save_string_default_extension(save_string_node, mock_get_save_path):
assert os.path.exists(saved_file_path)
with open(saved_file_path, "r") as f:
assert f.read() == test_string
@pytest.fixture
def mock_openai_client():
with patch('comfy_extras.nodes.nodes_openai._Client') as mock_client:
instance = mock_client.instance.return_value
instance.chat.completions.create = Mock()
instance.images.generate = Mock()
yield instance
def test_transformers_loader():
loader = TransformersLoader()
model, = loader.execute("microsoft/Phi-3-mini-4k-instruct", "")
assert isinstance(model, LanguageModel)
assert model.repo_id == "microsoft/Phi-3-mini-4k-instruct"
def test_one_shot_instruct_tokenize(mocker):
tokenize = OneShotInstructTokenize()
mock_model = mocker.Mock()
mock_model.tokenize.return_value = {"input_ids": torch.tensor([[1, 2, 3]])}
tokens, = tokenize.execute(mock_model, "What comes after apple?", [], "phi-3")
mock_model.tokenize.assert_called_once_with("What comes after apple?", [], mocker.ANY)
assert "input_ids" in tokens
def test_transformers_generate(mocker):
generate = TransformersGenerate()
mock_model = mocker.Mock()
mock_model.generate.return_value = "The letter B comes after A in the alphabet."
tokens: ProcessorResult = {"inputs": torch.tensor([[1, 2, 3]])}
result, = generate.execute(mock_model, tokens, 512, 0, 42)
mock_model.generate.assert_called_once()
assert isinstance(result, str)
assert "letter B" in result
def test_preview_string():
preview = PreviewString()
result = preview.execute("Test output")
assert result == {"ui": {"string": ["Test output"]}}
def test_openai_language_model_loader():
if not "OPENAI_API_KEY" in os.environ:
pytest.skip("must set OPENAI_API_KEY")
loader = OpenAILanguageModelLoader()
model, = loader.execute("gpt-3.5-turbo")
assert isinstance(model, OpenAILanguageModelWrapper)
assert model.model == "gpt-3.5-turbo"
def test_openai_language_model_wrapper_generate(mock_openai_client):
wrapper = OpenAILanguageModelWrapper("gpt-3.5-turbo")
mock_stream = [
Mock(choices=[Mock(delta=Mock(content="This "))]),
Mock(choices=[Mock(delta=Mock(content="is "))]),
Mock(choices=[Mock(delta=Mock(content="a "))]),
Mock(choices=[Mock(delta=Mock(content="test "))]),
Mock(choices=[Mock(delta=Mock(content="response."))]),
]
mock_openai_client.chat.completions.create.return_value = mock_stream
tokens = {"inputs": ["What is the capital of France?"]}
result = wrapper.generate(tokens, max_new_tokens=50)
mock_openai_client.chat.completions.create.assert_called_once_with(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": [{"type": "text", "text": "What is the capital of France?"}]}],
max_tokens=50,
temperature=1.0,
top_p=1.0,
seed=0,
stream=True
)
assert result == "This is a test response."
def test_openai_language_model_wrapper_generate_with_image(mock_openai_client):
wrapper = OpenAILanguageModelWrapper("gpt-4-vision-preview")
mock_stream = [
Mock(choices=[Mock(delta=Mock(content="This "))]),
Mock(choices=[Mock(delta=Mock(content="image "))]),
Mock(choices=[Mock(delta=Mock(content="shows "))]),
Mock(choices=[Mock(delta=Mock(content="a "))]),
Mock(choices=[Mock(delta=Mock(content="landscape."))]),
]
mock_openai_client.chat.completions.create.return_value = mock_stream
image_tensor = torch.rand((1, 224, 224, 3))
tokens: ProcessorResult = {
"inputs": ["Describe this image:"],
"images": image_tensor
}
result = wrapper.generate(tokens, max_new_tokens=50)
mock_openai_client.chat.completions.create.assert_called_once()
assert result == "This image shows a landscape."
def test_dalle_generate(mock_openai_client):
dalle = DallEGenerate()
mock_openai_client.images.generate.return_value = Mock(
data=[Mock(url="http://example.com/image.jpg", revised_prompt="A beautiful sunset")]
)
test_image = Image.new('RGB', (10, 10), color='red')
img_byte_arr = io.BytesIO()
test_image.save(img_byte_arr, format='PNG')
img_byte_arr = img_byte_arr.getvalue()
with patch('requests.get') as mock_get:
mock_get.return_value = Mock(content=img_byte_arr)
image, revised_prompt = dalle.generate("dall-e-3", "Create a sunset image", "1024x1024", "standard")
assert isinstance(image, torch.Tensor)
assert image.shape == (1, 10, 10, 3)
assert torch.allclose(image, torch.tensor([1.0, 0, 0]).view(1, 1, 1, 3).expand(1, 10, 10, 3))
assert revised_prompt == "A beautiful sunset"
mock_openai_client.images.generate.assert_called_once_with(
model="dall-e-3",
prompt="Create a sunset image",
size="1024x1024",
quality="standard",
n=1,
)
def test_integration_openai_loader_and_wrapper(mock_openai_client):
loader = OpenAILanguageModelLoader()
model, = loader.execute("gpt-4")
mock_stream = [
Mock(choices=[Mock(delta=Mock(content="Paris "))]),
Mock(choices=[Mock(delta=Mock(content="is "))]),
Mock(choices=[Mock(delta=Mock(content="the "))]),
Mock(choices=[Mock(delta=Mock(content="capital "))]),
Mock(choices=[Mock(delta=Mock(content="of France."))]),
]
mock_openai_client.chat.completions.create.return_value = mock_stream
tokens = {"inputs": ["What is the capital of France?"]}
result = model.generate(tokens, max_new_tokens=50)
assert result == "Paris is the capital of France."