From a4fb34a0b80a8fd26c21b457642018967461d6c8 Mon Sep 17 00:00:00 2001 From: doctorpangloss <@hiddenswitch.com> Date: Thu, 5 Sep 2024 21:56:04 -0700 Subject: [PATCH] Improve language and compositing nodes --- comfy/cli_args.py | 9 + comfy/cli_args_types.py | 2 + comfy/component_model/tensor_types.py | 1 + comfy/distributed/process_pool_executor.py | 15 +- comfy/language/language_types.py | 70 ++++- .../language/transformers_model_management.py | 255 ++++++++++++++-- comfy/model_downloader.py | 5 + comfy/nodes/base_nodes.py | 8 +- comfy_extras/nodes/nodes_compositing.py | 52 +++- comfy_extras/nodes/nodes_language.py | 272 ++---------------- comfy_extras/nodes/nodes_openai.py | 206 +++++++++++++ requirements.txt | 4 +- tests/inference/test_language.py | 14 + tests/unit/test_language_nodes.py | 157 +++++++++- 14 files changed, 767 insertions(+), 303 deletions(-) create mode 100644 comfy_extras/nodes/nodes_openai.py create mode 100644 tests/inference/test_language.py diff --git a/comfy/cli_args.py b/comfy/cli_args.py index a5f0cae45..b61f5ebe4 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -206,6 +206,15 @@ def _create_parser() -> EnhancedConfigArgParser: help="When running ComfyUI as a distributed worker, this specifies the kind of executor that should be used to run the actual ComfyUI workflow worker. A ThreadPoolExecutor is the default. A ProcessPoolExecutor results in better memory management, since the process will be closed and large, contiguous blocks of CUDA memory can be freed." ) + parser.add_argument( + "--openai-api-key", + required=False, + type=str, + help="Configures the OpenAI API Key for the OpenAI nodes", + env_var="OPENAI_API_KEY", + default=None + ) + # now give plugins a chance to add configuration for entry_point in entry_points().select(group='comfyui.custom_config'): try: diff --git a/comfy/cli_args_types.py b/comfy/cli_args_types.py index 82ce9a212..ec11c02de 100644 --- a/comfy/cli_args_types.py +++ b/comfy/cli_args_types.py @@ -111,6 +111,7 @@ class Configuration(dict): force_hf_local_dir_mode (bool): Download repos from huggingface.co to the models/huggingface directory with the "local_dir" argument instead of models/huggingface_cache with the "cache_dir" argument, recreating the traditional file structure. executor_factory (str): Either ThreadPoolExecutor or ProcessPoolExecutor, defaulting to ThreadPoolExecutor preview_size (int): Sets the maximum preview size for sampler nodes. Defaults to 512. 
+ openai_api_key (str): Configures the OpenAI API Key for the OpenAI nodes """ def __init__(self, **kwargs): @@ -198,6 +199,7 @@ class Configuration(dict): self[key] = value self.executor_factory: str = "ThreadPoolExecutor" + self.openai_api_key: Optional[str] = None def __getattr__(self, item): if item not in self: diff --git a/comfy/component_model/tensor_types.py b/comfy/component_model/tensor_types.py index def0e21f1..ed88a6750 100644 --- a/comfy/component_model/tensor_types.py +++ b/comfy/component_model/tensor_types.py @@ -2,6 +2,7 @@ from jaxtyping import Float from torch import Tensor ImageBatch = Float[Tensor, "batch height width channels"] +MaskBatch = Float[Tensor, "batch height width"] RGBImageBatch = Float[Tensor, "batch height width 3"] RGBAImageBatch = Float[Tensor, "batch height width 4"] RGBImage = Float[Tensor, "height width 3"] diff --git a/comfy/distributed/process_pool_executor.py b/comfy/distributed/process_pool_executor.py index 8f0e954e8..9c1b7776a 100644 --- a/comfy/distributed/process_pool_executor.py +++ b/comfy/distributed/process_pool_executor.py @@ -24,13 +24,14 @@ class ProcessPoolExecutor(ProcessPool, Executor): args: list = (), kwargs: dict = {}, timeout: float = None) -> ProcessFuture: - try: - args: ExecutePromptArgs - prompt, prompt_id, client_id, span_context, progress_handler, configuration = args - - except ValueError: - pass - super().schedule(function, args, kwargs, timeout) + # todo: restart worker when there is insufficient VRAM or the workflows are sufficiently different + # try: + # args: ExecutePromptArgs + # prompt, prompt_id, client_id, span_context, progress_handler, configuration = args + # + # except ValueError: + # pass + return super().schedule(function, args, kwargs, timeout) def submit(self, fn, /, *args, **kwargs) -> concurrent.futures.Future: return self.schedule(fn, args=list(args), kwargs=kwargs, timeout=None) diff --git a/comfy/language/language_types.py b/comfy/language/language_types.py index c77dd6f1d..448a67d1c 100644 --- a/comfy/language/language_types.py +++ b/comfy/language/language_types.py @@ -1,9 +1,17 @@ from __future__ import annotations +from typing import Union, Callable, List, Optional, Protocol, runtime_checkable + +import numpy as np import torch -from transformers import BatchEncoding +from PIL.Image import Image +from transformers import BatchEncoding, BatchFeature, TensorType +from transformers.tokenization_utils_base import TextInput, PreTokenizedInput, TruncationStrategy +from transformers.utils import PaddingStrategy from typing_extensions import TypedDict, NotRequired +from comfy.component_model.tensor_types import RGBImageBatch + class ProcessorResult(TypedDict): """ @@ -18,7 +26,61 @@ class ProcessorResult(TypedDict): attention_mask: NotRequired[torch.Tensor] pixel_values: NotRequired[torch.Tensor] - - images: NotRequired[torch.Tensor] - inputs: NotRequired[BatchEncoding] + images: NotRequired[RGBImageBatch] + inputs: NotRequired[BatchEncoding | list[str]] image_sizes: NotRequired[torch.Tensor] + + +class GenerationKwargs(TypedDict): + top_k: NotRequired[int] + top_p: NotRequired[float] + temperature: NotRequired[float] + penalty_alpha: NotRequired[float] + num_beams: NotRequired[int] + early_stopping: NotRequired[bool] + + +GENERATION_KWARGS_TYPE = GenerationKwargs +GENERATION_KWARGS_TYPE_NAME = "SAMPLER" +TOKENS_TYPE = Union[ProcessorResult, BatchFeature] +TOKENS_TYPE_NAME = "TOKENS" + + +class TransformerStreamedProgress(TypedDict): + next_token: str + + +LLaVAProcessor = Callable[ + [ + 
Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]], # text parameter + Union[Image, np.ndarray, torch.Tensor, List[Image], List[np.ndarray], List[torch.Tensor]], # images parameter + Union[bool, str, PaddingStrategy], # padding parameter + Union[bool, str, TruncationStrategy], # truncation parameter + Optional[int], # max_length parameter + Optional[Union[str, TensorType]] # return_tensors parameter + ], + BatchFeature +] + + +@runtime_checkable +class LanguageModel(Protocol): + @staticmethod + def from_pretrained(ckpt_name: str, subfolder: Optional[str] = None) -> "LanguageModel": + ... + + def generate(self, tokens: TOKENS_TYPE = None, + max_new_tokens: int = 512, + repetition_penalty: float = 0.0, + seed: int = 0, + sampler: Optional[GENERATION_KWARGS_TYPE] = None, + *args, + **kwargs) -> str: + ... + + def tokenize(self, prompt: str, images: List[torch.Tensor] | torch.Tensor, chat_template: str | None = None) -> ProcessorResult: + ... + + @property + def repo_id(self) -> str: + return "" diff --git a/comfy/language/transformers_model_management.py b/comfy/language/transformers_model_management.py index 2b408a395..9ee4a53c5 100644 --- a/comfy/language/transformers_model_management.py +++ b/comfy/language/transformers_model_management.py @@ -1,38 +1,32 @@ from __future__ import annotations import copy +import inspect import logging +import operator import pathlib import warnings -from typing import Optional, Any, Callable, Union, List +from functools import reduce +from typing import Optional, Any, Callable, List -import numpy as np import torch -from PIL.Image import Image from transformers import PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, AutoProcessor, AutoTokenizer, \ - TensorType, BatchFeature -from transformers.tokenization_utils_base import PreTokenizedInput, TextInput, TruncationStrategy -from transformers.utils import PaddingStrategy + BatchFeature, AutoModelForVision2Seq, AutoModelForSeq2SeqLM, AutoModelForCausalLM, AutoModel, \ + PretrainedConfig, TextStreamer, LogitsProcessor +from transformers.models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES, \ + MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES from .chat_templates import KNOWN_CHAT_TEMPLATES -from .language_types import ProcessorResult -from ..model_management import unet_offload_device, get_torch_device +from .language_types import ProcessorResult, TOKENS_TYPE, GENERATION_KWARGS_TYPE, TransformerStreamedProgress, \ + LLaVAProcessor, LanguageModel +from .. 
import model_management +from ..model_downloader import get_or_download_huggingface_repo +from ..model_management import unet_offload_device, get_torch_device, unet_dtype, load_models_gpu from ..model_management_types import ModelManageable - -LLaVAProcessor = Callable[ - [ - Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]], # text parameter - Union[Image, np.ndarray, torch.Tensor, List[Image], List[np.ndarray], List[torch.Tensor]], # images parameter - Union[bool, str, PaddingStrategy], # padding parameter - Union[bool, str, TruncationStrategy], # truncation parameter - Optional[int], # max_length parameter - Optional[Union[str, TensorType]] # return_tensors parameter - ], - BatchFeature -] +from ..utils import comfy_tqdm, ProgressBar, comfy_progress, seed_for_block -class TransformersManagedModel(ModelManageable): +class TransformersManagedModel(ModelManageable, LanguageModel): def __init__( self, repo_id: str, @@ -41,7 +35,7 @@ class TransformersManagedModel(ModelManageable): config_dict: Optional[dict] = None, processor: Optional[ProcessorMixin | AutoProcessor] = None ): - self.repo_id = repo_id + self._repo_id = repo_id self.model = model self._tokenizer = tokenizer self._processor = processor @@ -54,6 +48,200 @@ class TransformersManagedModel(ModelManageable): if model.device != self.offload_device: model.to(device=self.offload_device) + @staticmethod + def from_pretrained(ckpt_name: str, subfolder: Optional[str] = None) -> "TransformersManagedModel": + hub_kwargs = {} + if subfolder is not None and subfolder != "": + hub_kwargs["subfolder"] = subfolder + repo_id = ckpt_name + ckpt_name = get_or_download_huggingface_repo(ckpt_name) + with comfy_tqdm(): + from_pretrained_kwargs = { + "pretrained_model_name_or_path": ckpt_name, + "trust_remote_code": True, + **hub_kwargs + } + + # compute bitsandbytes configuration + try: + import bitsandbytes + except ImportError: + pass + + config_dict, _ = PretrainedConfig.get_config_dict(ckpt_name, **hub_kwargs) + model_type = config_dict["model_type"] + # language models prefer to use bfloat16 over float16 + kwargs_to_try = ({"torch_dtype": unet_dtype(supported_dtypes=(torch.bfloat16, torch.float16, torch.float32)), + "low_cpu_mem_usage": True, + "device_map": str(unet_offload_device()), }, {}) + + # if we have flash-attn installed, try to use it + try: + import flash_attn + attn_override_kwargs = { + "attn_implementation": "flash_attention_2", + **kwargs_to_try[0] + } + kwargs_to_try = (attn_override_kwargs, *kwargs_to_try) + logging.debug(f"while loading model {ckpt_name}, flash_attn was installed, so the flash_attention_2 implementation will be tried") + except ImportError: + pass + for i, props in enumerate(kwargs_to_try): + try: + if model_type in MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES: + model = AutoModelForVision2Seq.from_pretrained(**from_pretrained_kwargs, **props) + elif model_type in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES: + model = AutoModelForSeq2SeqLM.from_pretrained(**from_pretrained_kwargs, **props) + elif model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: + model = AutoModelForCausalLM.from_pretrained(**from_pretrained_kwargs, **props) + else: + model = AutoModel.from_pretrained(**from_pretrained_kwargs, **props) + if model is not None: + break + except Exception as exc_info: + if i == len(kwargs_to_try) - 1: + raise exc_info + else: + logging.warning(f"tried to import transformers model {ckpt_name} but got exception when trying additional import args {props}", exc_info=exc_info) + finally: + 
torch.set_default_dtype(torch.float32) + + for i, props in enumerate(kwargs_to_try): + try: + try: + processor = AutoProcessor.from_pretrained(**from_pretrained_kwargs, **props) + except: + processor = None + if isinstance(processor, PreTrainedTokenizerBase): + tokenizer = processor + processor = None + else: + tokenizer = getattr(processor, "tokenizer") if processor is not None and hasattr(processor, "tokenizer") else AutoTokenizer.from_pretrained(ckpt_name, **hub_kwargs, **props) + if tokenizer is not None or processor is not None: + break + except Exception as exc_info: + if i == len(kwargs_to_try) - 1: + raise exc_info + finally: + torch.set_default_dtype(torch.float32) + + if model_management.xformers_enabled() and hasattr(model, "enable_xformers_memory_efficient_attention"): + model.enable_xformers_memory_efficient_attention() + logging.debug("enabled xformers memory efficient attention") + + model_managed = TransformersManagedModel( + repo_id=repo_id, + model=model, + tokenizer=tokenizer, + config_dict=config_dict, + processor=processor + ) + + return model_managed + + def generate(self, tokens: TOKENS_TYPE = None, + max_new_tokens: int = 512, + repetition_penalty: float = 0.0, + seed: int = 0, + sampler: Optional[GENERATION_KWARGS_TYPE] = None, + *args, + **kwargs) -> str: + tokens = copy.copy(tokens) + tokens_original = copy.copy(tokens) + sampler = sampler or {} + generate_kwargs = copy.copy(sampler) + load_models_gpu([self]) + transformers_model: PreTrainedModel = self.model + tokenizer: PreTrainedTokenizerBase | AutoTokenizer = self.tokenizer + # remove unused inputs + # maximizes compatibility with different models + generate_signature = inspect.signature(transformers_model.generate).parameters + prepare_signature = inspect.signature(transformers_model.prepare_inputs_for_generation).parameters + to_delete = set(reduce(operator.sub, map(lambda x: x.keys(), [tokens, generate_signature, prepare_signature]))) + gen_sig_keys = generate_signature.keys() + if "tgt_lang" in tokens: + to_delete.add("tgt_lang") + to_delete.add("src_lang") + to_delete.discard("input_ids") + if "forced_bos_token_id" in tokens: + to_delete.discard("forced_bos_token_id") + elif hasattr(tokenizer, "convert_tokens_to_ids"): + generate_kwargs["forced_bos_token_id"] = tokenizer.convert_tokens_to_ids(tokens["tgt_lang"]) + else: + logging.warning(f"tokenizer {tokenizer} unexpected for translation task") + if "input_ids" in tokens and "inputs" in tokens: + if "input_ids" in gen_sig_keys: + to_delete.add("inputs") + elif "inputs" in gen_sig_keys: + to_delete.add("input_ids") + for unused_kwarg in to_delete: + tokens.pop(unused_kwarg) + logging.debug(f"{transformers_model.name_or_path}.generate does not accept {unused_kwarg}, removing") + + # images should be moved to model + for key in ("images", "pixel_values"): + if key in tokens: + tokens[key] = tokens[key].to(device=self.current_device, dtype=self.model_dtype()) + + # sets up inputs + inputs = tokens + + # used to determine if text streaming is supported + num_beams = generate_kwargs.get("num_beams", transformers_model.generation_config.num_beams) + + progress_bar: ProgressBar + with comfy_progress(total=max_new_tokens) as progress_bar: + # todo: deal with batches correctly, don't assume batch size 1 + token_count = 0 + + # progress + def on_finalized_text(next_token: str, stop: bool): + nonlocal token_count + nonlocal progress_bar + + token_count += 1 + preview = TransformerStreamedProgress(next_token=next_token) + progress_bar.update_absolute(token_count, 
total=max_new_tokens, preview_image_or_output=preview) + + text_streamer = _ProgressTextStreamer(on_finalized_text, tokenizer, True) + + with seed_for_block(seed): + if hasattr(inputs, "encodings") and inputs.encodings is not None and all(hasattr(encoding, "attention_mask") for encoding in inputs.encodings) and "attention_mask" in inputs: + inputs.pop("attention_mask") + output_ids = transformers_model.generate( + **inputs, + streamer=text_streamer if num_beams <= 1 else None, + max_new_tokens=max_new_tokens, + repetition_penalty=repetition_penalty if repetition_penalty != 0 else None, + **generate_kwargs + ) + + if not transformers_model.config.is_encoder_decoder: + start_position = inputs["input_ids" if "input_ids" in inputs else "inputs"].shape[1] + output_ids = output_ids[:, start_position:] + + if hasattr(tokenizer, "src_lang") and "src_lang" in tokens_original: + prev_src_lang = tokenizer.src_lang + tokenizer.src_lang = tokens_original["src_lang"] + else: + prev_src_lang = None + # todo: is this redundant considering I'm decoding in the on_finalized_text block? + try: + outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) + finally: + if prev_src_lang is not None: + tokenizer.src_lang = prev_src_lang + # gpu-loaded stuff like images can now be unloaded + if hasattr(tokens, "to"): + del tokens + else: + for to_delete in tokens.values(): + del to_delete + del tokens + + # todo: better support batches + return outputs[0] + @property def tokenizer(self) -> PreTrainedTokenizerBase | AutoTokenizer: return self._tokenizer @@ -178,9 +366,32 @@ class TransformersManagedModel(ModelManageable): **batch_feature } + @property + def repo_id(self) -> str: + return self._repo_id + def __str__(self): if self.repo_id is not None: repo_id_as_path = pathlib.PurePath(self.repo_id) return f"" else: return f"" + + +class _ProgressTextStreamer(TextStreamer): + def __init__(self, on_finalized_text: Callable[[str, bool], None], tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs): + super().__init__(tokenizer, skip_prompt, **decode_kwargs) + self.on_finalized_text_handler = on_finalized_text + + def on_finalized_text(self, text: str, stream_end: bool = False): + self.on_finalized_text_handler(text, stream_end) + + +class _ProgressLogitsProcessor(LogitsProcessor): + def __init__(self, model: TransformersManagedModel): + self.eos_token_id = model.tokenizer.eos_token_id + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + probabilities = scores.softmax(dim=-1) + self.eos_probability = probabilities[:, self.eos_token_id].item() + return scores diff --git a/comfy/model_downloader.py b/comfy/model_downloader.py index 5cef8f07f..5f78da20e 100644 --- a/comfy/model_downloader.py +++ b/comfy/model_downloader.py @@ -385,6 +385,10 @@ KNOWN_CONTROLNETS: Final[KnownDownloadables] = KnownDownloadables([ HuggingFile("InstantX/FLUX.1-dev-Controlnet-Canny", "diffusion_pytorch_model.safetensors", save_with_filename="instantx-flux.1-dev-controlnet-canny.safetensors"), HuggingFile("InstantX/FLUX.1-dev-Controlnet-Union", "diffusion_pytorch_model.safetensors", save_with_filename="instantx-flux.1-dev-controlnet-union.safetensors"), HuggingFile("Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro", "diffusion_pytorch_model.safetensors", save_with_filename="shakker-labs-flux.1-dev-controlnet-union-pro.safetensors"), + HuggingFile("TheMistoAI/MistoLine_Flux.dev", "mistoline_flux.dev_v1.safetensors"), + 
HuggingFile("XLabs-AI/flux-controlnet-collections", "flux-canny-controlnet-v3.safetensors"), + HuggingFile("XLabs-AI/flux-controlnet-collections", "flux-depth-controlnet-v3.safetensors"), + HuggingFile("XLabs-AI/flux-controlnet-collections", "flux-hed-controlnet-v3.safetensors"), ], folder_name="controlnet") KNOWN_DIFF_CONTROLNETS: Final[KnownDownloadables] = KnownDownloadables([ @@ -418,6 +422,7 @@ KNOWN_HUGGINGFACE_MODEL_REPOS: Final[Set[str]] = { 'llava-hf/llava-v1.6-mistral-7b-hf', 'facebook/nllb-200-distilled-1.3B', 'THUDM/chatglm3-6b', + 'roborovski/superprompt-v1', } KNOWN_UNET_MODELS: Final[KnownDownloadables] = KnownDownloadables([ diff --git a/comfy/nodes/base_nodes.py b/comfy/nodes/base_nodes.py index b4368dd97..1813dab35 100644 --- a/comfy/nodes/base_nodes.py +++ b/comfy/nodes/base_nodes.py @@ -24,7 +24,7 @@ from .. import model_management from ..cli_args import args from ..cmd import folder_paths, latent_preview -from ..component_model.tensor_types import RGBImage +from ..component_model.tensor_types import RGBImage, RGBImageBatch, MaskBatch from ..execution_context import current_execution_context from ..images import open_image from ..ldm.flux.weight_dtypes import FLUX_WEIGHT_DTYPES @@ -808,7 +808,7 @@ class ControlNetApply: CATEGORY = "conditioning/controlnet" - def apply_controlnet(self, conditioning, control_net, image, strength): + def apply_controlnet(self, conditioning, control_net, image: RGBImageBatch, strength): if strength == 0: return (conditioning, ) @@ -1573,7 +1573,7 @@ class LoadImage: RETURN_TYPES = ("IMAGE", "MASK") FUNCTION = "load_image" - def load_image(self, image: str): + def load_image(self, image: str) -> tuple[RGBImageBatch, MaskBatch]: image_path = folder_paths.get_annotated_filepath(image) img = node_helpers.pillow(Image.open, image_path) @@ -1703,7 +1703,7 @@ class ImageScale: CATEGORY = "image/upscaling" - def upscale(self, image, upscale_method, width, height, crop): + def upscale(self, image: RGBImageBatch, upscale_method, width, height, crop) -> tuple[RGBImageBatch]: if width == 0 and height == 0: s = image else: diff --git a/comfy_extras/nodes/nodes_compositing.py b/comfy_extras/nodes/nodes_compositing.py index 445c3221c..2bfc3fa86 100644 --- a/comfy_extras/nodes/nodes_compositing.py +++ b/comfy_extras/nodes/nodes_compositing.py @@ -5,7 +5,7 @@ import torch from skimage import exposure import comfy.utils -from comfy.component_model.tensor_types import RGBImageBatch, ImageBatch +from comfy.component_model.tensor_types import RGBImageBatch, ImageBatch, MaskBatch from comfy.nodes.package_typing import CustomNode @@ -34,10 +34,7 @@ class PorterDuffMode(Enum): XOR = 17 -def porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_image: torch.Tensor, dst_alpha: torch.Tensor, mode: PorterDuffMode): - # convert mask to alpha - src_alpha = 1 - src_alpha - dst_alpha = 1 - dst_alpha +def _porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_image: torch.Tensor, dst_alpha: torch.Tensor, mode: PorterDuffMode): # premultiply alpha src_image = src_image * src_alpha dst_image = dst_image * dst_alpha @@ -109,24 +106,31 @@ def porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_ return out_image, out_alpha -class PorterDuffImageComposite: +class PorterDuffImageCompositeV2: @classmethod def INPUT_TYPES(s): return { "required": { "source": ("IMAGE",), - "source_alpha": ("MASK",), "destination": ("IMAGE",), - "destination_alpha": ("MASK",), "mode": ([mode.name for mode in PorterDuffMode], 
{"default": PorterDuffMode.DST.name}), }, + "optional": { + "source_alpha": ("MASK",), + "destination_alpha": ("MASK",), + } } RETURN_TYPES = ("IMAGE", "MASK") FUNCTION = "composite" CATEGORY = "mask/compositing" - def composite(self, source: torch.Tensor, source_alpha: torch.Tensor, destination: torch.Tensor, destination_alpha: torch.Tensor, mode): + def composite(self, source: RGBImageBatch, destination: RGBImageBatch, mode, source_alpha: MaskBatch = None, destination_alpha: MaskBatch = None) -> tuple[RGBImageBatch, MaskBatch]: + if source_alpha is None: + source_alpha = torch.zeros(source.shape[:3]) + if destination_alpha is None: + destination_alpha = torch.zeros(destination.shape[:3]) + batch_size = min(len(source), len(source_alpha), len(destination), len(destination_alpha)) out_images = [] out_alphas = [] @@ -153,7 +157,7 @@ class PorterDuffImageComposite: upscale_output = comfy.utils.common_upscale(upscale_input, dst_alpha.shape[1], dst_alpha.shape[0], upscale_method='bicubic', crop='center') src_alpha = upscale_output.permute(0, 2, 3, 1).squeeze(0) - out_image, out_alpha = porter_duff_composite(src_image, src_alpha, dst_image, dst_alpha, PorterDuffMode[mode]) + out_image, out_alpha = _porter_duff_composite(src_image, src_alpha, dst_image, dst_alpha, PorterDuffMode[mode]) out_images.append(out_image) out_alphas.append(out_alpha.squeeze(2)) @@ -162,6 +166,28 @@ class PorterDuffImageComposite: return result +class PorterDuffImageCompositeV1(PorterDuffImageCompositeV2): + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "source": ("IMAGE",), + "source_alpha": ("MASK",), + "destination": ("IMAGE",), + "destination_alpha": ("MASK",), + "mode": ([mode.name for mode in PorterDuffMode], {"default": PorterDuffMode.DST.name}), + }, + } + + FUNCTION = "composite_v1" + + def composite_v1(self, source: torch.Tensor, source_alpha: torch.Tensor, destination: torch.Tensor, destination_alpha: torch.Tensor, mode) -> tuple[RGBImageBatch, MaskBatch]: + # convert mask to alpha + source_alpha = 1 - source_alpha + destination_alpha = 1 - destination_alpha + return super().composite(source, destination, mode, source_alpha, destination_alpha) + + class SplitImageWithAlpha: @classmethod def INPUT_TYPES(s): @@ -312,7 +338,8 @@ class Posterize(CustomNode): NODE_CLASS_MAPPINGS = { - "PorterDuffImageComposite": PorterDuffImageComposite, + "PorterDuffImageComposite": PorterDuffImageCompositeV1, + "PorterDuffImageCompositeV2": PorterDuffImageCompositeV2, "SplitImageWithAlpha": SplitImageWithAlpha, "JoinImageWithAlpha": JoinImageWithAlpha, "EnhanceContrast": EnhanceContrast, @@ -321,7 +348,8 @@ NODE_CLASS_MAPPINGS = { } NODE_DISPLAY_NAME_MAPPINGS = { - "PorterDuffImageComposite": "Porter-Duff Image Composite", + "PorterDuffImageComposite": "Porter-Duff Image Composite (V1)", + "PorterDuffImageCompositeV2": "Image Composite", "SplitImageWithAlpha": "Split Image with Alpha", "JoinImageWithAlpha": "Join Image with Alpha", } diff --git a/comfy_extras/nodes/nodes_language.py b/comfy_extras/nodes/nodes_language.py index da8d2065f..12457256a 100644 --- a/comfy_extras/nodes/nodes_language.py +++ b/comfy_extras/nodes/nodes_language.py @@ -1,80 +1,32 @@ from __future__ import annotations -import copy -import inspect -import logging import operator import os.path from functools import reduce -from typing import Any, Dict, Optional, List, Callable, Union +from typing import Optional, List import torch -from transformers import AutoTokenizer, PreTrainedModel, LogitsProcessor, TextStreamer, \ - 
PreTrainedTokenizerBase, PretrainedConfig, AutoProcessor, BatchFeature, AutoModel, AutoModelForCausalLM, \ - AutoModelForSeq2SeqLM -from transformers.models.auto.modeling_auto import MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, \ - MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES, AutoModelForVision2Seq +from transformers import AutoProcessor from transformers.models.m2m_100.tokenization_m2m_100 import \ FAIRSEQ_LANGUAGE_CODES as tokenization_m2m_100_FAIRSEQ_LANGUAGE_CODES from transformers.models.nllb.tokenization_nllb import \ FAIRSEQ_LANGUAGE_CODES as tokenization_nllb_FAIRSEQ_LANGUAGE_CODES -from typing_extensions import TypedDict -from comfy import model_management from comfy.cmd import folder_paths from comfy.component_model.folder_path_types import SaveImagePathResponse from comfy.language.chat_templates import KNOWN_CHAT_TEMPLATES -from comfy.language.language_types import ProcessorResult +from comfy.language.language_types import GENERATION_KWARGS_TYPE, GENERATION_KWARGS_TYPE_NAME, TOKENS_TYPE, \ + TOKENS_TYPE_NAME, LanguageModel from comfy.language.transformers_model_management import TransformersManagedModel from comfy.model_downloader import get_huggingface_repo_list, get_or_download_huggingface_repo -from comfy.model_management import get_torch_device_name, unet_dtype, unet_offload_device, load_models_gpu +from comfy.model_management import get_torch_device_name, unet_dtype, unet_offload_device from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult -from comfy.utils import comfy_tqdm, seed_for_block, comfy_progress, ProgressBar _AUTO_CHAT_TEMPLATE = "default" -# add llava support -try: - from llava import model as _llava_model_side_effects - - logging.debug("Additional LLaVA models are now supported") -except ImportError as exc: - logging.debug(f"Install LLavA with `pip install git+https://github.com/AppMana/appmana-comfyui-llava` for additional LLaVA support") - -# aka kwargs type -_GENERATION_KWARGS_TYPE = Dict[str, Any] -_GENERATION_KWARGS_TYPE_NAME = "SAMPLER" - -_TOKENS_TYPE = Union[ProcessorResult, BatchFeature] -TOKENS_TYPE_NAME = "TOKENS" - - -class _ProgressTextStreamer(TextStreamer): - def __init__(self, on_finalized_text: Callable[[str, bool], None], tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs): - super().__init__(tokenizer, skip_prompt, **decode_kwargs) - self.on_finalized_text_handler = on_finalized_text - - def on_finalized_text(self, text: str, stream_end: bool = False): - self.on_finalized_text_handler(text, stream_end) - - -class _ProgressLogitsProcessor(LogitsProcessor): - def __init__(self, model: TransformersManagedModel): - self.eos_token_id = model.tokenizer.eos_token_id - - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: - probabilities = scores.softmax(dim=-1) - self.eos_probability = probabilities[:, self.eos_token_id].item() - return scores - - -# todo: for per token progress, should this really look like {"ui": {"string": [value]}} ? 
-class TransformerStreamedProgress(TypedDict): - next_token: str - class TransformerSamplerBase(CustomNode): - RETURN_TYPES = _GENERATION_KWARGS_TYPE_NAME, + RETURN_TYPES = GENERATION_KWARGS_TYPE_NAME, RETURN_NAMES = "GENERATION ARGS", FUNCTION = "execute" CATEGORY = "language/samplers" @@ -142,7 +94,7 @@ class TransformersGenerationConfig(CustomNode): } } - RETURN_TYPES = _GENERATION_KWARGS_TYPE_NAME, + RETURN_TYPES = GENERATION_KWARGS_TYPE_NAME, RETURN_NAMES = "GENERATION ARGS", FUNCTION = "execute" CATEGORY = "language" @@ -182,15 +134,15 @@ class TransformerBeamSearchSampler(TransformerSamplerBase): class TransformerMergeSamplers(CustomNode): @classmethod def INPUT_TYPES(cls) -> InputTypes: - range_ = {"value0": (_GENERATION_KWARGS_TYPE_NAME, {"forceInput": True})} - range_.update({f"value{i}": (_GENERATION_KWARGS_TYPE_NAME, {"forceInput": True}) for i in range(1, 5)}) + range_ = {"value0": (GENERATION_KWARGS_TYPE_NAME, {"forceInput": True})} + range_.update({f"value{i}": (GENERATION_KWARGS_TYPE_NAME, {"forceInput": True}) for i in range(1, 5)}) return { "required": range_ } CATEGORY = "language" - RETURN_TYPES = _GENERATION_KWARGS_TYPE_NAME, + RETURN_TYPES = GENERATION_KWARGS_TYPE_NAME, FUNCTION = "execute" def execute(self, **kwargs): @@ -238,98 +190,11 @@ class TransformersLoader(CustomNode): CATEGORY = "language" RETURN_TYPES = "MODEL", + RETURN_NAMES = "language model", FUNCTION = "execute" - def execute(self, ckpt_name: str, subfolder: Optional[str] = None, *args, **kwargs): - hub_kwargs = {} - if subfolder is not None and subfolder != "": - hub_kwargs["subfolder"] = subfolder - - ckpt_name = get_or_download_huggingface_repo(ckpt_name) - with comfy_tqdm(): - from_pretrained_kwargs = { - "pretrained_model_name_or_path": ckpt_name, - "trust_remote_code": True, - **hub_kwargs - } - - # if flash attention exists, use it - - # compute bitsandbytes configuration - try: - import bitsandbytes - except ImportError: - pass - - config_dict, _ = PretrainedConfig.get_config_dict(ckpt_name, **hub_kwargs) - model_type = config_dict["model_type"] - # language models prefer to use bfloat16 over float16 - kwargs_to_try = ({"torch_dtype": unet_dtype(supported_dtypes=(torch.bfloat16, torch.float16, torch.float32)), - "low_cpu_mem_usage": True, - "device_map": str(unet_offload_device()), }, {}) - - # if we have flash-attn installed, try to use it - try: - import flash_attn - attn_override_kwargs = { - "attn_implementation": "flash_attention_2", - **kwargs_to_try[0] - } - kwargs_to_try = (attn_override_kwargs, *kwargs_to_try) - logging.debug(f"while loading model {ckpt_name}, flash_attn was installed, so the flash_attention_2 implementation will be tried") - except ImportError: - pass - for i, props in enumerate(kwargs_to_try): - try: - if model_type in MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES: - model = AutoModelForVision2Seq.from_pretrained(**from_pretrained_kwargs, **props) - elif model_type in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES: - model = AutoModelForSeq2SeqLM.from_pretrained(**from_pretrained_kwargs, **props) - elif model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: - model = AutoModelForCausalLM.from_pretrained(**from_pretrained_kwargs, **props) - else: - model = AutoModel.from_pretrained(**from_pretrained_kwargs, **props) - if model is not None: - break - except Exception as exc_info: - if i == len(kwargs_to_try) - 1: - raise exc_info - else: - logging.warning(f"tried to import transformers model {ckpt_name} but got exception when trying additional import args {props}", exc_info=exc_info) - 
finally: - torch.set_default_dtype(torch.float32) - - for i, props in enumerate(kwargs_to_try): - try: - try: - processor = AutoProcessor.from_pretrained(**from_pretrained_kwargs, **props) - except: - processor = None - if isinstance(processor, PreTrainedTokenizerBase): - tokenizer = processor - processor = None - else: - tokenizer = getattr(processor, "tokenizer") if processor is not None and hasattr(processor, "tokenizer") else AutoTokenizer.from_pretrained(ckpt_name, **hub_kwargs, **props) - if tokenizer is not None or processor is not None: - break - except Exception as exc_info: - if i == len(kwargs_to_try) - 1: - raise exc_info - finally: - torch.set_default_dtype(torch.float32) - - if model_management.xformers_enabled() and hasattr(model, "enable_xformers_memory_efficient_attention"): - model.enable_xformers_memory_efficient_attention() - logging.debug("enabled xformers memory efficient attention") - - model_managed = TransformersManagedModel( - repo_id=ckpt_name, - model=model, - tokenizer=tokenizer, - config_dict=config_dict, - processor=processor - ) - return model_managed, + def execute(self, ckpt_name: str, subfolder: Optional[str] = None, *args, **kwargs) -> tuple[TransformersManagedModel]: + return TransformersManagedModel.from_pretrained(ckpt_name, subfolder), class TransformersTokenize(CustomNode): @@ -346,7 +211,7 @@ class TransformersTokenize(CustomNode): RETURN_TYPES = (TOKENS_TYPE_NAME,) FUNCTION = "execute" - def execute(self, model: TransformersManagedModel, prompt: str) -> ValidatedNodeResult: + def execute(self, model: LanguageModel, prompt: str) -> ValidatedNodeResult: return model.tokenize(prompt, [], None), @@ -452,7 +317,7 @@ class OneShotInstructTokenize(CustomNode): RETURN_TYPES = (TOKENS_TYPE_NAME,) FUNCTION = "execute" - def execute(self, model: TransformersManagedModel, prompt: str, images: List[torch.Tensor] | torch.Tensor = None, chat_template: str = "__auto__") -> ValidatedNodeResult: + def execute(self, model: LanguageModel, prompt: str, images: List[torch.Tensor] | torch.Tensor = None, chat_template: str = "__auto__") -> ValidatedNodeResult: if chat_template == _AUTO_CHAT_TEMPLATE: # use an exact match model_name = os.path.basename(model.repo_id) @@ -475,10 +340,9 @@ class TransformersGenerate(CustomNode): "max_new_tokens": ("INT", {"default": 512, "min": 1}), "repetition_penalty": ("FLOAT", {"default": 0.0, "min": 0}), "seed": ("INT", {"default": 0, "min": 0, "max": 2 ** 32 - 1}), - "use_cache": ("BOOLEAN", {"default": True}), }, "optional": { - "sampler": (_GENERATION_KWARGS_TYPE_NAME, {}), + "sampler": (GENERATION_KWARGS_TYPE_NAME, {}), } } @@ -487,110 +351,14 @@ class TransformersGenerate(CustomNode): FUNCTION = "execute" def execute(self, - model: Optional[TransformersManagedModel] = None, - tokens: _TOKENS_TYPE = None, + model: Optional[LanguageModel] = None, + tokens: TOKENS_TYPE = None, max_new_tokens: int = 512, repetition_penalty: float = 0.0, seed: int = 0, - sampler: Optional[_GENERATION_KWARGS_TYPE] = None, - *args, - **kwargs + sampler: Optional[GENERATION_KWARGS_TYPE] = None, ): - tokens = copy.copy(tokens) - tokens_original = copy.copy(tokens) - sampler = sampler or {} - generate_kwargs = copy.copy(sampler) - load_models_gpu([model]) - transformers_model: PreTrainedModel = model.model - tokenizer: PreTrainedTokenizerBase | AutoTokenizer = model.tokenizer - # remove unused inputs - # maximizes compatibility with different models - generate_signature = inspect.signature(transformers_model.generate).parameters - prepare_signature = 
inspect.signature(transformers_model.prepare_inputs_for_generation).parameters - to_delete = set(reduce(operator.sub, map(lambda x: x.keys(), [tokens, generate_signature, prepare_signature]))) - gen_sig_keys = generate_signature.keys() - if "tgt_lang" in tokens: - to_delete.add("tgt_lang") - to_delete.add("src_lang") - to_delete.discard("input_ids") - if "forced_bos_token_id" in tokens: - to_delete.discard("forced_bos_token_id") - elif hasattr(tokenizer, "convert_tokens_to_ids"): - generate_kwargs["forced_bos_token_id"] = tokenizer.convert_tokens_to_ids(tokens["tgt_lang"]) - else: - logging.warning(f"tokenizer {tokenizer} unexpected for translation task") - if "input_ids" in tokens and "inputs" in tokens: - if "input_ids" in gen_sig_keys: - to_delete.add("inputs") - elif "inputs" in gen_sig_keys: - to_delete.add("input_ids") - for unused_kwarg in to_delete: - tokens.pop(unused_kwarg) - logging.debug(f"{transformers_model.name_or_path}.generate does not accept {unused_kwarg}, removing") - - # images should be moved to model - for key in ("images", "pixel_values"): - if key in tokens: - tokens[key] = tokens[key].to(device=model.current_device, dtype=model.model_dtype()) - - # sets up inputs - inputs = tokens - - # used to determine if text streaming is supported - num_beams = generate_kwargs.get("num_beams", transformers_model.generation_config.num_beams) - - progress_bar: ProgressBar - with comfy_progress(total=max_new_tokens) as progress_bar: - # todo: deal with batches correctly, don't assume batch size 1 - token_count = 0 - - # progress - def on_finalized_text(next_token: str, stop: bool): - nonlocal token_count - nonlocal progress_bar - - token_count += 1 - preview = TransformerStreamedProgress(next_token=next_token) - progress_bar.update_absolute(token_count, total=max_new_tokens, preview_image_or_output=preview) - - text_streamer = _ProgressTextStreamer(on_finalized_text, tokenizer, True) - - with seed_for_block(seed): - if hasattr(inputs, "encodings") and inputs.encodings is not None and all(hasattr(encoding, "attention_mask") for encoding in inputs.encodings) and "attention_mask" in inputs: - inputs.pop("attention_mask") - output_ids = transformers_model.generate( - **inputs, - streamer=text_streamer if num_beams <= 1 else None, - max_new_tokens=max_new_tokens, - repetition_penalty=repetition_penalty if repetition_penalty != 0 else None, - **generate_kwargs - ) - - if not transformers_model.config.is_encoder_decoder: - start_position = inputs["input_ids" if "input_ids" in inputs else "inputs"].shape[1] - output_ids = output_ids[:, start_position:] - - if hasattr(tokenizer, "src_lang") and "src_lang" in tokens_original: - prev_src_lang = tokenizer.src_lang - tokenizer.src_lang = tokens_original["src_lang"] - else: - prev_src_lang = None - # todo: is this redundant consider I'm decoding in the on_finalized_text block? 
- try: - outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) - finally: - if prev_src_lang is not None: - tokenizer.src_lang = prev_src_lang - # gpu-loaded stuff like images can now be unloaded - if hasattr(tokens, "to"): - del tokens - else: - for to_delete in tokens.values(): - del to_delete - del tokens - - # todo: better support batches - return outputs[0], + return model.generate(tokens, max_new_tokens, repetition_penalty, seed, sampler), class PreviewString(CustomNode): diff --git a/comfy_extras/nodes/nodes_openai.py b/comfy_extras/nodes/nodes_openai.py new file mode 100644 index 000000000..325808286 --- /dev/null +++ b/comfy_extras/nodes/nodes_openai.py @@ -0,0 +1,206 @@ +import base64 +import io +import os +from io import BytesIO +from typing import Literal, Optional + +import numpy as np +import requests +import torch +from PIL import Image +from openai import OpenAI +from openai.types.chat import ChatCompletionMessageParam + +from comfy.cli_args import args +from comfy.component_model.tensor_types import RGBImageBatch +from comfy.language.language_types import LanguageModel, ProcessorResult, GENERATION_KWARGS_TYPE, TOKENS_TYPE, \ + TransformerStreamedProgress +from comfy.nodes.package_typing import CustomNode, InputTypes +from comfy.utils import comfy_progress, ProgressBar, seed_for_block + + +class _Client: + _client: Optional[OpenAI] = None + + @staticmethod + def instance() -> OpenAI: + if _Client._client is None: + open_ai_api_key = args.openai_api_key + _Client._client = OpenAI( + api_key=open_ai_api_key, + ) + + return _Client._client + + +def validate_has_key(): + open_api_key = os.environ.get("OPENAI_API_KEY", args.openai_api_key) + if open_api_key is None or open_api_key == "": + return "set OPENAI_API_KEY environment variable" + return True + + +def image_to_base64(image: RGBImageBatch) -> str: + pil_image = Image.fromarray(np.clip(255. 
* image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8)) + buffered = io.BytesIO() + pil_image.save(buffered, format="JPEG") + return base64.b64encode(buffered.getvalue()).decode("utf-8") + + +class OpenAILanguageModelWrapper(LanguageModel): + def __init__(self, model: str): + self.model = model + self.client = _Client.instance() + + @staticmethod + def from_pretrained(ckpt_name: str, subfolder: Optional[str] = None) -> "OpenAILanguageModelWrapper": + return OpenAILanguageModelWrapper(ckpt_name) + + def generate(self, tokens: TOKENS_TYPE = None, + max_new_tokens: int = 512, + repetition_penalty: float = 0.0, + seed: int = 0, + sampler: Optional[GENERATION_KWARGS_TYPE] = None, + *args, + **kwargs) -> str: + sampler = sampler or {} + prompt = tokens.get("inputs", []) + prompt = "".join(prompt) + images = tokens.get("images", []) + messages: list[ChatCompletionMessageParam] = [ + { + "role": "user", + "content": [ + {"type": "text", "text": prompt}, + ] + [ + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{image_to_base64(image)}" + } + } for image in images + ] + } + ] + + progress_bar: ProgressBar + with comfy_progress(total=max_new_tokens) as progress_bar: + token_count = 0 + full_response = "" + + def on_finalized_text(next_token: str, stop: bool): + nonlocal token_count + nonlocal progress_bar + nonlocal full_response + + token_count += 1 + full_response += next_token + preview = TransformerStreamedProgress(next_token=next_token) + progress_bar.update_absolute(max_new_tokens if stop else token_count, total=max_new_tokens, preview_image_or_output=preview) + + with seed_for_block(seed): + stream = self.client.chat.completions.create( + model=self.model, + messages=messages, + max_tokens=max_new_tokens, + temperature=sampler.get("temperature", 1.0), + top_p=sampler.get("top_p", 1.0), + # n=1, + # stop=None, + # presence_penalty=repetition_penalty, + seed=seed, + stream=True + ) + + for chunk in stream: + if chunk.choices[0].delta.content is not None: + on_finalized_text(chunk.choices[0].delta.content, False) + + on_finalized_text("", True) # Signal the end of streaming + + return full_response + + def tokenize(self, prompt: str, images: RGBImageBatch, chat_template: str | None = None) -> ProcessorResult: + # OpenAI API doesn't require explicit tokenization, so we'll just return the prompt and images as is + return { + "inputs": [prompt], + "attention_mask": torch.ones(1, len(prompt)), # Dummy attention mask + "images": images + } + + @property + def repo_id(self) -> str: + return f"openai/{self.model}" + + +class OpenAILanguageModelLoader(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "model": (["gpt-4o", "gpt-4o-mini", "gpt-4-turbo", "gpt-4", "gpt-3.5-turbo"], {"default": "gpt-3.5-turbo"}) + } + } + + RETURN_TYPES = ("MODEL",) + RETURN_NAMES = ("language model",) + + FUNCTION = "execute" + CATEGORY = "openai" + + def execute(self, model: str) -> tuple[LanguageModel]: + return OpenAILanguageModelWrapper(model), + + @classmethod + def VALIDATE_INPUTS(cls): + return validate_has_key() + + +class DallEGenerate(CustomNode): + + @classmethod + def INPUT_TYPES(cls): + return {"required": { + "model": (["dall-e-2", "dall-e-3"], {"default": "dall-e-3"}), + "text": ("STRING", {"multiline": True}), + "size": (["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"], {"default": "1024x1024"}), + "quality": (["standard", "hd"], {"default": "standard"}), + }} + + RETURN_TYPES = ("IMAGE", "STRING",) + RETURN_NAMES = 
("images", "revised prompt") + FUNCTION = "generate" + + CATEGORY = "openai" + + def generate(self, + model: Literal["dall-e-2", "dall-e-3"], + text: str, + size: Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"], + quality: Literal["standard", "hd"]) -> tuple[RGBImageBatch, str]: + response = _Client.instance().images.generate( + model=model, + prompt=text, + size=size, + quality=quality, + n=1, + ) + + image_url = response.data[0].url + image_response = requests.get(image_url) + + img = Image.open(BytesIO(image_response.content)) + + image = np.array(img).astype(np.float32) / 255.0 + image = torch.from_numpy(image)[None,] + return image, response.data[0].revised_prompt + + @classmethod + def VALIDATE_INPUTS(cls): + return validate_has_key() + + +NODE_CLASS_MAPPINGS = { + "DallEGenerate": DallEGenerate, + "OpenAILanguageModelLoader": OpenAILanguageModelLoader +} diff --git a/requirements.txt b/requirements.txt index 5c0eb7f7e..d31dee04f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -65,4 +65,6 @@ ml_dtypes diffusers>=0.30.1 vtracer skia-python -pebble>=5.0.7 \ No newline at end of file +pebble>=5.0.7 +openai +anthropic \ No newline at end of file diff --git a/tests/inference/test_language.py b/tests/inference/test_language.py new file mode 100644 index 000000000..dc125c03b --- /dev/null +++ b/tests/inference/test_language.py @@ -0,0 +1,14 @@ +import torch + +from comfy_extras.nodes.nodes_language import TransformersLoader, OneShotInstructTokenize + + +def test_integration_transformers_loader_and_tokenize(): + loader = TransformersLoader() + tokenize = OneShotInstructTokenize() + + model, = loader.execute("llava-hf/llava-v1.6-mistral-7b-hf", "") + tokens, = tokenize.execute(model, "Describe this image:", torch.rand((1, 224, 224, 3)), "llava-v1.6-mistral-7b-hf", ) + + assert isinstance(tokens, dict) + assert "input_ids" in tokens or "inputs" in tokens diff --git a/tests/unit/test_language_nodes.py b/tests/unit/test_language_nodes.py index ddb421b80..340ecb43e 100644 --- a/tests/unit/test_language_nodes.py +++ b/tests/unit/test_language_nodes.py @@ -1,10 +1,17 @@ +import io import os import tempfile -from unittest.mock import patch +from unittest.mock import patch, Mock import pytest +import torch +from PIL import Image +from comfy.language.language_types import LanguageModel, ProcessorResult from comfy_extras.nodes.nodes_language import SaveString +from comfy_extras.nodes.nodes_language import TransformersLoader, OneShotInstructTokenize, TransformersGenerate, \ + PreviewString +from comfy_extras.nodes.nodes_openai import OpenAILanguageModelLoader, OpenAILanguageModelWrapper, DallEGenerate @pytest.fixture @@ -57,3 +64,151 @@ def test_save_string_default_extension(save_string_node, mock_get_save_path): assert os.path.exists(saved_file_path) with open(saved_file_path, "r") as f: assert f.read() == test_string + + +@pytest.fixture +def mock_openai_client(): + with patch('comfy_extras.nodes.nodes_openai._Client') as mock_client: + instance = mock_client.instance.return_value + instance.chat.completions.create = Mock() + instance.images.generate = Mock() + yield instance + + +def test_transformers_loader(): + loader = TransformersLoader() + model, = loader.execute("microsoft/Phi-3-mini-4k-instruct", "") + assert isinstance(model, LanguageModel) + assert model.repo_id == "microsoft/Phi-3-mini-4k-instruct" + + +def test_one_shot_instruct_tokenize(mocker): + tokenize = OneShotInstructTokenize() + mock_model = mocker.Mock() + mock_model.tokenize.return_value = 
{"input_ids": torch.tensor([[1, 2, 3]])} + + tokens, = tokenize.execute(mock_model, "What comes after apple?", [], "phi-3") + mock_model.tokenize.assert_called_once_with("What comes after apple?", [], mocker.ANY) + assert "input_ids" in tokens + + +def test_transformers_generate(mocker): + generate = TransformersGenerate() + mock_model = mocker.Mock() + mock_model.generate.return_value = "The letter B comes after A in the alphabet." + + tokens: ProcessorResult = {"inputs": torch.tensor([[1, 2, 3]])} + result, = generate.execute(mock_model, tokens, 512, 0, 42) + mock_model.generate.assert_called_once() + assert isinstance(result, str) + assert "letter B" in result + + +def test_preview_string(): + preview = PreviewString() + result = preview.execute("Test output") + assert result == {"ui": {"string": ["Test output"]}} + + +def test_openai_language_model_loader(): + if not "OPENAI_API_KEY" in os.environ: + pytest.skip("must set OPENAI_API_KEY") + loader = OpenAILanguageModelLoader() + model, = loader.execute("gpt-3.5-turbo") + assert isinstance(model, OpenAILanguageModelWrapper) + assert model.model == "gpt-3.5-turbo" + + +def test_openai_language_model_wrapper_generate(mock_openai_client): + wrapper = OpenAILanguageModelWrapper("gpt-3.5-turbo") + mock_stream = [ + Mock(choices=[Mock(delta=Mock(content="This "))]), + Mock(choices=[Mock(delta=Mock(content="is "))]), + Mock(choices=[Mock(delta=Mock(content="a "))]), + Mock(choices=[Mock(delta=Mock(content="test "))]), + Mock(choices=[Mock(delta=Mock(content="response."))]), + ] + + mock_openai_client.chat.completions.create.return_value = mock_stream + + tokens = {"inputs": ["What is the capital of France?"]} + result = wrapper.generate(tokens, max_new_tokens=50) + + mock_openai_client.chat.completions.create.assert_called_once_with( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": [{"type": "text", "text": "What is the capital of France?"}]}], + max_tokens=50, + temperature=1.0, + top_p=1.0, + seed=0, + stream=True + ) + assert result == "This is a test response." + + +def test_openai_language_model_wrapper_generate_with_image(mock_openai_client): + wrapper = OpenAILanguageModelWrapper("gpt-4-vision-preview") + mock_stream = [ + Mock(choices=[Mock(delta=Mock(content="This "))]), + Mock(choices=[Mock(delta=Mock(content="image "))]), + Mock(choices=[Mock(delta=Mock(content="shows "))]), + Mock(choices=[Mock(delta=Mock(content="a "))]), + Mock(choices=[Mock(delta=Mock(content="landscape."))]), + ] + mock_openai_client.chat.completions.create.return_value = mock_stream + + image_tensor = torch.rand((1, 224, 224, 3)) + tokens: ProcessorResult = { + "inputs": ["Describe this image:"], + "images": image_tensor + } + result = wrapper.generate(tokens, max_new_tokens=50) + + mock_openai_client.chat.completions.create.assert_called_once() + assert result == "This image shows a landscape." 
+ + +def test_dalle_generate(mock_openai_client): + dalle = DallEGenerate() + mock_openai_client.images.generate.return_value = Mock( + data=[Mock(url="http://example.com/image.jpg", revised_prompt="A beautiful sunset")] + ) + test_image = Image.new('RGB', (10, 10), color='red') + img_byte_arr = io.BytesIO() + test_image.save(img_byte_arr, format='PNG') + img_byte_arr = img_byte_arr.getvalue() + + with patch('requests.get') as mock_get: + mock_get.return_value = Mock(content=img_byte_arr) + image, revised_prompt = dalle.generate("dall-e-3", "Create a sunset image", "1024x1024", "standard") + + assert isinstance(image, torch.Tensor) + assert image.shape == (1, 10, 10, 3) + assert torch.allclose(image, torch.tensor([1.0, 0, 0]).view(1, 1, 1, 3).expand(1, 10, 10, 3)) + assert revised_prompt == "A beautiful sunset" + mock_openai_client.images.generate.assert_called_once_with( + model="dall-e-3", + prompt="Create a sunset image", + size="1024x1024", + quality="standard", + n=1, + ) + + +def test_integration_openai_loader_and_wrapper(mock_openai_client): + loader = OpenAILanguageModelLoader() + model, = loader.execute("gpt-4") + + mock_stream = [ + Mock(choices=[Mock(delta=Mock(content="Paris "))]), + Mock(choices=[Mock(delta=Mock(content="is "))]), + Mock(choices=[Mock(delta=Mock(content="the "))]), + Mock(choices=[Mock(delta=Mock(content="capital "))]), + Mock(choices=[Mock(delta=Mock(content="of France."))]), + ] + mock_openai_client.chat.completions.create.return_value = mock_stream + + tokens = {"inputs": ["What is the capital of France?"]} + result = model.generate(tokens, max_new_tokens=50) + + assert result == "Paris is the capital of France."