mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2026-01-11 06:40:48 +08:00

Improve language and compositing nodes

This commit is contained in:
parent 7e1201e777
commit a4fb34a0b8
@ -206,6 +206,15 @@ def _create_parser() -> EnhancedConfigArgParser:
|
||||
help="When running ComfyUI as a distributed worker, this specifies the kind of executor that should be used to run the actual ComfyUI workflow worker. A ThreadPoolExecutor is the default. A ProcessPoolExecutor results in better memory management, since the process will be closed and large, contiguous blocks of CUDA memory can be freed."
|
||||
)
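A minimal sketch of how the executor choice described in the help text above could be applied, using standard-library executors purely for illustration (the repository's own pebble-backed ProcessPoolExecutor appears later in this diff; the function name is hypothetical):

import concurrent.futures

def make_executor(executor_factory: str) -> concurrent.futures.Executor:
    # "ProcessPoolExecutor" runs each workflow in a child process, so large,
    # contiguous blocks of CUDA memory are released when that process exits;
    # "ThreadPoolExecutor" (the default) keeps everything in-process.
    if executor_factory == "ProcessPoolExecutor":
        return concurrent.futures.ProcessPoolExecutor(max_workers=1)
    return concurrent.futures.ThreadPoolExecutor(max_workers=1)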
|
||||
|
||||
parser.add_argument(
|
||||
"--openai-api-key",
|
||||
required=False,
|
||||
type=str,
|
||||
help="Configures the OpenAI API Key for the OpenAI nodes",
|
||||
env_var="OPENAI_API_KEY",
|
||||
default=None
|
||||
)
|
||||
|
||||
# now give plugins a chance to add configuration
|
||||
for entry_point in entry_points().select(group='comfyui.custom_config'):
|
||||
try:
|
||||
|
||||
@ -111,6 +111,7 @@ class Configuration(dict):
|
||||
force_hf_local_dir_mode (bool): Download repos from huggingface.co to the models/huggingface directory with the "local_dir" argument instead of models/huggingface_cache with the "cache_dir" argument, recreating the traditional file structure.
|
||||
executor_factory (str): Either ThreadPoolExecutor or ProcessPoolExecutor, defaulting to ThreadPoolExecutor
|
||||
preview_size (int): Sets the maximum preview size for sampler nodes. Defaults to 512.
|
||||
openai_api_key (str): Configures the OpenAI API Key for the OpenAI nodes
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
@ -198,6 +199,7 @@ class Configuration(dict):
|
||||
self[key] = value
|
||||
|
||||
self.executor_factory: str = "ThreadPoolExecutor"
|
||||
self.openai_api_key: Optional[str] = None
|
||||
|
||||
def __getattr__(self, item):
|
||||
if item not in self:
|
||||
|
||||
@ -2,6 +2,7 @@ from jaxtyping import Float
|
||||
from torch import Tensor
|
||||
|
||||
ImageBatch = Float[Tensor, "batch height width channels"]
|
||||
MaskBatch = Float[Tensor, "batch height width"]
|
||||
RGBImageBatch = Float[Tensor, "batch height width 3"]
|
||||
RGBAImageBatch = Float[Tensor, "batch height width 4"]
|
||||
RGBImage = Float[Tensor, "height width 3"]
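A minimal sketch of how these jaxtyping aliases annotate tensor shapes in signatures; to_mask is a hypothetical helper, not part of this diff:

def to_mask(images: RGBImageBatch) -> MaskBatch:
    # [batch, height, width, 3] float tensor in, [batch, height, width] float tensor out
    return images.mean(dim=-1)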
|
||||
|
||||
@ -24,13 +24,14 @@ class ProcessPoolExecutor(ProcessPool, Executor):
|
||||
args: list = (),
|
||||
kwargs: dict = {},
|
||||
timeout: float = None) -> ProcessFuture:
|
||||
try:
|
||||
args: ExecutePromptArgs
|
||||
prompt, prompt_id, client_id, span_context, progress_handler, configuration = args
|
||||
|
||||
except ValueError:
|
||||
pass
|
||||
super().schedule(function, args, kwargs, timeout)
|
||||
# todo: restart worker when there is insufficient VRAM or the workflows are sufficiently different
|
||||
# try:
|
||||
# args: ExecutePromptArgs
|
||||
# prompt, prompt_id, client_id, span_context, progress_handler, configuration = args
|
||||
#
|
||||
# except ValueError:
|
||||
# pass
|
||||
return super().schedule(function, args, kwargs, timeout)
|
||||
|
||||
def submit(self, fn, /, *args, **kwargs) -> concurrent.futures.Future:
|
||||
return self.schedule(fn, args=list(args), kwargs=kwargs, timeout=None)
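A minimal usage sketch, assuming the constructor is inherited unchanged from pebble's ProcessPool:

pool = ProcessPoolExecutor(max_workers=1)
future = pool.submit(len, [1, 2, 3])  # submit() adapts the concurrent.futures call style to pebble's schedule()
assert future.result() == 3
pool.stop()
pool.join()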
|
||||
|
||||
@ -1,9 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Union, Callable, List, Optional, Protocol, runtime_checkable
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import BatchEncoding
|
||||
from PIL.Image import Image
|
||||
from transformers import BatchEncoding, BatchFeature, TensorType
|
||||
from transformers.tokenization_utils_base import TextInput, PreTokenizedInput, TruncationStrategy
|
||||
from transformers.utils import PaddingStrategy
|
||||
from typing_extensions import TypedDict, NotRequired
|
||||
|
||||
from comfy.component_model.tensor_types import RGBImageBatch
|
||||
|
||||
|
||||
class ProcessorResult(TypedDict):
|
||||
"""
|
||||
@ -18,7 +26,61 @@ class ProcessorResult(TypedDict):
|
||||
|
||||
attention_mask: NotRequired[torch.Tensor]
|
||||
pixel_values: NotRequired[torch.Tensor]
|
||||
|
||||
images: NotRequired[torch.Tensor]
|
||||
inputs: NotRequired[BatchEncoding]
|
||||
images: NotRequired[RGBImageBatch]
|
||||
inputs: NotRequired[BatchEncoding | list[str]]
|
||||
image_sizes: NotRequired[torch.Tensor]
|
||||
|
||||
|
||||
class GenerationKwargs(TypedDict):
|
||||
top_k: NotRequired[int]
|
||||
top_p: NotRequired[float]
|
||||
temperature: NotRequired[float]
|
||||
penalty_alpha: NotRequired[float]
|
||||
num_beams: NotRequired[int]
|
||||
early_stopping: NotRequired[bool]
|
||||
|
||||
|
||||
GENERATION_KWARGS_TYPE = GenerationKwargs
|
||||
GENERATION_KWARGS_TYPE_NAME = "SAMPLER"
|
||||
TOKENS_TYPE = Union[ProcessorResult, BatchFeature]
|
||||
TOKENS_TYPE_NAME = "TOKENS"
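A short sketch of a sampler dictionary matching the GenerationKwargs TypedDict above; the values are illustrative:

sampler: GenerationKwargs = {"temperature": 0.7, "top_p": 0.9, "top_k": 50, "num_beams": 1}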
|
||||
|
||||
|
||||
class TransformerStreamedProgress(TypedDict):
|
||||
next_token: str
|
||||
|
||||
|
||||
LLaVAProcessor = Callable[
|
||||
[
|
||||
Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]], # text parameter
|
||||
Union[Image, np.ndarray, torch.Tensor, List[Image], List[np.ndarray], List[torch.Tensor]], # images parameter
|
||||
Union[bool, str, PaddingStrategy], # padding parameter
|
||||
Union[bool, str, TruncationStrategy], # truncation parameter
|
||||
Optional[int], # max_length parameter
|
||||
Optional[Union[str, TensorType]] # return_tensors parameter
|
||||
],
|
||||
BatchFeature
|
||||
]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class LanguageModel(Protocol):
|
||||
@staticmethod
|
||||
def from_pretrained(ckpt_name: str, subfolder: Optional[str] = None) -> "LanguageModel":
|
||||
...
|
||||
|
||||
def generate(self, tokens: TOKENS_TYPE = None,
|
||||
max_new_tokens: int = 512,
|
||||
repetition_penalty: float = 0.0,
|
||||
seed: int = 0,
|
||||
sampler: Optional[GENERATION_KWARGS_TYPE] = None,
|
||||
*args,
|
||||
**kwargs) -> str:
|
||||
...
|
||||
|
||||
def tokenize(self, prompt: str, images: List[torch.Tensor] | torch.Tensor, chat_template: str | None = None) -> ProcessorResult:
|
||||
...
|
||||
|
||||
@property
|
||||
def repo_id(self) -> str:
|
||||
return ""
|
||||
|
||||
@ -1,38 +1,32 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import inspect
|
||||
import logging
|
||||
import operator
|
||||
import pathlib
|
||||
import warnings
|
||||
from typing import Optional, Any, Callable, Union, List
|
||||
from functools import reduce
|
||||
from typing import Optional, Any, Callable, List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL.Image import Image
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, AutoProcessor, AutoTokenizer, \
|
||||
TensorType, BatchFeature
|
||||
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput, TruncationStrategy
|
||||
from transformers.utils import PaddingStrategy
|
||||
BatchFeature, AutoModelForVision2Seq, AutoModelForSeq2SeqLM, AutoModelForCausalLM, AutoModel, \
|
||||
PretrainedConfig, TextStreamer, LogitsProcessor
|
||||
from transformers.models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES, \
|
||||
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
||||
|
||||
from .chat_templates import KNOWN_CHAT_TEMPLATES
|
||||
from .language_types import ProcessorResult
|
||||
from ..model_management import unet_offload_device, get_torch_device
|
||||
from .language_types import ProcessorResult, TOKENS_TYPE, GENERATION_KWARGS_TYPE, TransformerStreamedProgress, \
|
||||
LLaVAProcessor, LanguageModel
|
||||
from .. import model_management
|
||||
from ..model_downloader import get_or_download_huggingface_repo
|
||||
from ..model_management import unet_offload_device, get_torch_device, unet_dtype, load_models_gpu
|
||||
from ..model_management_types import ModelManageable
|
||||
|
||||
LLaVAProcessor = Callable[
|
||||
[
|
||||
Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]], # text parameter
|
||||
Union[Image, np.ndarray, torch.Tensor, List[Image], List[np.ndarray], List[torch.Tensor]], # images parameter
|
||||
Union[bool, str, PaddingStrategy], # padding parameter
|
||||
Union[bool, str, TruncationStrategy], # truncation parameter
|
||||
Optional[int], # max_length parameter
|
||||
Optional[Union[str, TensorType]] # return_tensors parameter
|
||||
],
|
||||
BatchFeature
|
||||
]
|
||||
from ..utils import comfy_tqdm, ProgressBar, comfy_progress, seed_for_block
|
||||
|
||||
|
||||
class TransformersManagedModel(ModelManageable):
|
||||
class TransformersManagedModel(ModelManageable, LanguageModel):
|
||||
def __init__(
|
||||
self,
|
||||
repo_id: str,
|
||||
@ -41,7 +35,7 @@ class TransformersManagedModel(ModelManageable):
|
||||
config_dict: Optional[dict] = None,
|
||||
processor: Optional[ProcessorMixin | AutoProcessor] = None
|
||||
):
|
||||
self.repo_id = repo_id
|
||||
self._repo_id = repo_id
|
||||
self.model = model
|
||||
self._tokenizer = tokenizer
|
||||
self._processor = processor
|
||||
@ -54,6 +48,200 @@ class TransformersManagedModel(ModelManageable):
|
||||
if model.device != self.offload_device:
|
||||
model.to(device=self.offload_device)
|
||||
|
||||
@staticmethod
|
||||
def from_pretrained(ckpt_name: str, subfolder: Optional[str] = None) -> "TransformersManagedModel":
|
||||
hub_kwargs = {}
|
||||
if subfolder is not None and subfolder != "":
|
||||
hub_kwargs["subfolder"] = subfolder
|
||||
repo_id = ckpt_name
|
||||
ckpt_name = get_or_download_huggingface_repo(ckpt_name)
|
||||
with comfy_tqdm():
|
||||
from_pretrained_kwargs = {
|
||||
"pretrained_model_name_or_path": ckpt_name,
|
||||
"trust_remote_code": True,
|
||||
**hub_kwargs
|
||||
}
|
||||
|
||||
# compute bitsandbytes configuration
|
||||
try:
|
||||
import bitsandbytes
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
config_dict, _ = PretrainedConfig.get_config_dict(ckpt_name, **hub_kwargs)
|
||||
model_type = config_dict["model_type"]
|
||||
# language models prefer to use bfloat16 over float16
|
||||
kwargs_to_try = ({"torch_dtype": unet_dtype(supported_dtypes=(torch.bfloat16, torch.float16, torch.float32)),
|
||||
"low_cpu_mem_usage": True,
|
||||
"device_map": str(unet_offload_device()), }, {})
|
||||
|
||||
# if we have flash-attn installed, try to use it
|
||||
try:
|
||||
import flash_attn
|
||||
attn_override_kwargs = {
|
||||
"attn_implementation": "flash_attention_2",
|
||||
**kwargs_to_try[0]
|
||||
}
|
||||
kwargs_to_try = (attn_override_kwargs, *kwargs_to_try)
|
||||
logging.debug(f"while loading model {ckpt_name}, flash_attn was installed, so the flash_attention_2 implementation will be tried")
|
||||
except ImportError:
|
||||
pass
|
||||
for i, props in enumerate(kwargs_to_try):
|
||||
try:
|
||||
if model_type in MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES:
|
||||
model = AutoModelForVision2Seq.from_pretrained(**from_pretrained_kwargs, **props)
|
||||
elif model_type in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained(**from_pretrained_kwargs, **props)
|
||||
elif model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
|
||||
model = AutoModelForCausalLM.from_pretrained(**from_pretrained_kwargs, **props)
|
||||
else:
|
||||
model = AutoModel.from_pretrained(**from_pretrained_kwargs, **props)
|
||||
if model is not None:
|
||||
break
|
||||
except Exception as exc_info:
|
||||
if i == len(kwargs_to_try) - 1:
|
||||
raise exc_info
|
||||
else:
|
||||
logging.warning(f"tried to import transformers model {ckpt_name} but got exception when trying additional import args {props}", exc_info=exc_info)
|
||||
finally:
|
||||
torch.set_default_dtype(torch.float32)
|
||||
|
||||
for i, props in enumerate(kwargs_to_try):
|
||||
try:
|
||||
try:
|
||||
processor = AutoProcessor.from_pretrained(**from_pretrained_kwargs, **props)
|
||||
except:
|
||||
processor = None
|
||||
if isinstance(processor, PreTrainedTokenizerBase):
|
||||
tokenizer = processor
|
||||
processor = None
|
||||
else:
|
||||
tokenizer = getattr(processor, "tokenizer") if processor is not None and hasattr(processor, "tokenizer") else AutoTokenizer.from_pretrained(ckpt_name, **hub_kwargs, **props)
|
||||
if tokenizer is not None or processor is not None:
|
||||
break
|
||||
except Exception as exc_info:
|
||||
if i == len(kwargs_to_try) - 1:
|
||||
raise exc_info
|
||||
finally:
|
||||
torch.set_default_dtype(torch.float32)
|
||||
|
||||
if model_management.xformers_enabled() and hasattr(model, "enable_xformers_memory_efficient_attention"):
|
||||
model.enable_xformers_memory_efficient_attention()
|
||||
logging.debug("enabled xformers memory efficient attention")
|
||||
|
||||
model_managed = TransformersManagedModel(
|
||||
repo_id=repo_id,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
config_dict=config_dict,
|
||||
processor=processor
|
||||
)
|
||||
|
||||
return model_managed
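A minimal usage sketch of the loader path above, which picks AutoModelForVision2Seq, AutoModelForSeq2SeqLM, AutoModelForCausalLM or AutoModel based on the config's model_type; the repo id mirrors the one exercised in the tests added later in this diff:

model = TransformersManagedModel.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
assert model.repo_id == "microsoft/Phi-3-mini-4k-instruct"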
|
||||
|
||||
def generate(self, tokens: TOKENS_TYPE = None,
|
||||
max_new_tokens: int = 512,
|
||||
repetition_penalty: float = 0.0,
|
||||
seed: int = 0,
|
||||
sampler: Optional[GENERATION_KWARGS_TYPE] = None,
|
||||
*args,
|
||||
**kwargs) -> str:
|
||||
tokens = copy.copy(tokens)
|
||||
tokens_original = copy.copy(tokens)
|
||||
sampler = sampler or {}
|
||||
generate_kwargs = copy.copy(sampler)
|
||||
load_models_gpu([self])
|
||||
transformers_model: PreTrainedModel = self.model
|
||||
tokenizer: PreTrainedTokenizerBase | AutoTokenizer = self.tokenizer
|
||||
# remove unused inputs
|
||||
# maximizes compatibility with different models
|
||||
generate_signature = inspect.signature(transformers_model.generate).parameters
|
||||
prepare_signature = inspect.signature(transformers_model.prepare_inputs_for_generation).parameters
|
||||
to_delete = set(reduce(operator.sub, map(lambda x: x.keys(), [tokens, generate_signature, prepare_signature])))
|
||||
gen_sig_keys = generate_signature.keys()
|
||||
if "tgt_lang" in tokens:
|
||||
to_delete.add("tgt_lang")
|
||||
to_delete.add("src_lang")
|
||||
to_delete.discard("input_ids")
|
||||
if "forced_bos_token_id" in tokens:
|
||||
to_delete.discard("forced_bos_token_id")
|
||||
elif hasattr(tokenizer, "convert_tokens_to_ids"):
|
||||
generate_kwargs["forced_bos_token_id"] = tokenizer.convert_tokens_to_ids(tokens["tgt_lang"])
|
||||
else:
|
||||
logging.warning(f"tokenizer {tokenizer} unexpected for translation task")
|
||||
if "input_ids" in tokens and "inputs" in tokens:
|
||||
if "input_ids" in gen_sig_keys:
|
||||
to_delete.add("inputs")
|
||||
elif "inputs" in gen_sig_keys:
|
||||
to_delete.add("input_ids")
|
||||
for unused_kwarg in to_delete:
|
||||
tokens.pop(unused_kwarg)
|
||||
logging.debug(f"{transformers_model.name_or_path}.generate does not accept {unused_kwarg}, removing")
|
||||
|
||||
# image tensors should be moved to the model's device and dtype
|
||||
for key in ("images", "pixel_values"):
|
||||
if key in tokens:
|
||||
tokens[key] = tokens[key].to(device=self.current_device, dtype=self.model_dtype())
|
||||
|
||||
# sets up inputs
|
||||
inputs = tokens
|
||||
|
||||
# used to determine if text streaming is supported
|
||||
num_beams = generate_kwargs.get("num_beams", transformers_model.generation_config.num_beams)
|
||||
|
||||
progress_bar: ProgressBar
|
||||
with comfy_progress(total=max_new_tokens) as progress_bar:
|
||||
# todo: deal with batches correctly, don't assume batch size 1
|
||||
token_count = 0
|
||||
|
||||
# progress
|
||||
def on_finalized_text(next_token: str, stop: bool):
|
||||
nonlocal token_count
|
||||
nonlocal progress_bar
|
||||
|
||||
token_count += 1
|
||||
preview = TransformerStreamedProgress(next_token=next_token)
|
||||
progress_bar.update_absolute(token_count, total=max_new_tokens, preview_image_or_output=preview)
|
||||
|
||||
text_streamer = _ProgressTextStreamer(on_finalized_text, tokenizer, True)
|
||||
|
||||
with seed_for_block(seed):
|
||||
if hasattr(inputs, "encodings") and inputs.encodings is not None and all(hasattr(encoding, "attention_mask") for encoding in inputs.encodings) and "attention_mask" in inputs:
|
||||
inputs.pop("attention_mask")
|
||||
output_ids = transformers_model.generate(
|
||||
**inputs,
|
||||
streamer=text_streamer if num_beams <= 1 else None,
|
||||
max_new_tokens=max_new_tokens,
|
||||
repetition_penalty=repetition_penalty if repetition_penalty != 0 else None,
|
||||
**generate_kwargs
|
||||
)
|
||||
|
||||
if not transformers_model.config.is_encoder_decoder:
|
||||
start_position = inputs["input_ids" if "input_ids" in inputs else "inputs"].shape[1]
|
||||
output_ids = output_ids[:, start_position:]
|
||||
|
||||
if hasattr(tokenizer, "src_lang") and "src_lang" in tokens_original:
|
||||
prev_src_lang = tokenizer.src_lang
|
||||
tokenizer.src_lang = tokens_original["src_lang"]
|
||||
else:
|
||||
prev_src_lang = None
|
||||
# todo: is this redundant considering I'm decoding in the on_finalized_text callback?
|
||||
try:
|
||||
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||
finally:
|
||||
if prev_src_lang is not None:
|
||||
tokenizer.src_lang = prev_src_lang
|
||||
# gpu-loaded stuff like images can now be unloaded
|
||||
if hasattr(tokens, "to"):
|
||||
del tokens
|
||||
else:
|
||||
for to_delete in tokens.values():
|
||||
del to_delete
|
||||
del tokens
|
||||
|
||||
# todo: better support batches
|
||||
return outputs[0]
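A short end-to-end sketch against the generate method above, assuming the tokenize method required by the LanguageModel protocol (its body lies outside this excerpt); prompt and sampler values are illustrative:

tokens = model.tokenize("Finish this sentence: the quick brown fox", [], None)
text = model.generate(tokens, max_new_tokens=64, repetition_penalty=0.0, seed=42,
                      sampler={"temperature": 0.7, "top_p": 0.9})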
|
||||
|
||||
@property
|
||||
def tokenizer(self) -> PreTrainedTokenizerBase | AutoTokenizer:
|
||||
return self._tokenizer
|
||||
@ -178,9 +366,32 @@ class TransformersManagedModel(ModelManageable):
|
||||
**batch_feature
|
||||
}
|
||||
|
||||
@property
|
||||
def repo_id(self) -> str:
|
||||
return self._repo_id
|
||||
|
||||
def __str__(self):
|
||||
if self.repo_id is not None:
|
||||
repo_id_as_path = pathlib.PurePath(self.repo_id)
|
||||
return f"<TransformersManagedModel for {'/'.join(repo_id_as_path.parts[-2:])} ({self.model.__class__.__name__})>"
|
||||
else:
|
||||
return f"<TransformersManagedModel for {self.model.__class__.__name__}>"
|
||||
|
||||
|
||||
class _ProgressTextStreamer(TextStreamer):
|
||||
def __init__(self, on_finalized_text: Callable[[str, bool], None], tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
|
||||
super().__init__(tokenizer, skip_prompt, **decode_kwargs)
|
||||
self.on_finalized_text_handler = on_finalized_text
|
||||
|
||||
def on_finalized_text(self, text: str, stream_end: bool = False):
|
||||
self.on_finalized_text_handler(text, stream_end)
|
||||
|
||||
|
||||
class _ProgressLogitsProcessor(LogitsProcessor):
|
||||
def __init__(self, model: TransformersManagedModel):
|
||||
self.eos_token_id = model.tokenizer.eos_token_id
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
probabilities = scores.softmax(dim=-1)
|
||||
self.eos_probability = probabilities[:, self.eos_token_id].item()
|
||||
return scores
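A minimal sketch of how this processor could be attached to a generate() call to read back the end-of-sequence probability; the wiring itself is not part of this diff, and model_managed, transformers_model and inputs are the names used above:

from transformers import LogitsProcessorList

eos_tracker = _ProgressLogitsProcessor(model_managed)
output_ids = transformers_model.generate(**inputs, logits_processor=LogitsProcessorList([eos_tracker]))
print(eos_tracker.eos_probability)  # P(eos) from the most recent decoding step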
|
||||
|
||||
@ -385,6 +385,10 @@ KNOWN_CONTROLNETS: Final[KnownDownloadables] = KnownDownloadables([
|
||||
HuggingFile("InstantX/FLUX.1-dev-Controlnet-Canny", "diffusion_pytorch_model.safetensors", save_with_filename="instantx-flux.1-dev-controlnet-canny.safetensors"),
|
||||
HuggingFile("InstantX/FLUX.1-dev-Controlnet-Union", "diffusion_pytorch_model.safetensors", save_with_filename="instantx-flux.1-dev-controlnet-union.safetensors"),
|
||||
HuggingFile("Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro", "diffusion_pytorch_model.safetensors", save_with_filename="shakker-labs-flux.1-dev-controlnet-union-pro.safetensors"),
|
||||
HuggingFile("TheMistoAI/MistoLine_Flux.dev", "mistoline_flux.dev_v1.safetensors"),
|
||||
HuggingFile("XLabs-AI/flux-controlnet-collections", "flux-canny-controlnet-v3.safetensors"),
|
||||
HuggingFile("XLabs-AI/flux-controlnet-collections", "flux-depth-controlnet-v3.safetensors"),
|
||||
HuggingFile("XLabs-AI/flux-controlnet-collections", "flux-hed-controlnet-v3.safetensors"),
|
||||
], folder_name="controlnet")
|
||||
|
||||
KNOWN_DIFF_CONTROLNETS: Final[KnownDownloadables] = KnownDownloadables([
|
||||
@ -418,6 +422,7 @@ KNOWN_HUGGINGFACE_MODEL_REPOS: Final[Set[str]] = {
|
||||
'llava-hf/llava-v1.6-mistral-7b-hf',
|
||||
'facebook/nllb-200-distilled-1.3B',
|
||||
'THUDM/chatglm3-6b',
|
||||
'roborovski/superprompt-v1',
|
||||
}
|
||||
|
||||
KNOWN_UNET_MODELS: Final[KnownDownloadables] = KnownDownloadables([
|
||||
|
||||
@ -24,7 +24,7 @@ from .. import model_management
|
||||
from ..cli_args import args
|
||||
|
||||
from ..cmd import folder_paths, latent_preview
|
||||
from ..component_model.tensor_types import RGBImage
|
||||
from ..component_model.tensor_types import RGBImage, RGBImageBatch, MaskBatch
|
||||
from ..execution_context import current_execution_context
|
||||
from ..images import open_image
|
||||
from ..ldm.flux.weight_dtypes import FLUX_WEIGHT_DTYPES
|
||||
@ -808,7 +808,7 @@ class ControlNetApply:
|
||||
|
||||
CATEGORY = "conditioning/controlnet"
|
||||
|
||||
def apply_controlnet(self, conditioning, control_net, image, strength):
|
||||
def apply_controlnet(self, conditioning, control_net, image: RGBImageBatch, strength):
|
||||
if strength == 0:
|
||||
return (conditioning, )
|
||||
|
||||
@ -1573,7 +1573,7 @@ class LoadImage:
|
||||
RETURN_TYPES = ("IMAGE", "MASK")
|
||||
FUNCTION = "load_image"
|
||||
|
||||
def load_image(self, image: str):
|
||||
def load_image(self, image: str) -> tuple[RGBImageBatch, MaskBatch]:
|
||||
image_path = folder_paths.get_annotated_filepath(image)
|
||||
|
||||
img = node_helpers.pillow(Image.open, image_path)
|
||||
@ -1703,7 +1703,7 @@ class ImageScale:
|
||||
|
||||
CATEGORY = "image/upscaling"
|
||||
|
||||
def upscale(self, image, upscale_method, width, height, crop):
|
||||
def upscale(self, image: RGBImageBatch, upscale_method, width, height, crop) -> tuple[RGBImageBatch]:
|
||||
if width == 0 and height == 0:
|
||||
s = image
|
||||
else:
|
||||
|
||||
@ -5,7 +5,7 @@ import torch
|
||||
from skimage import exposure
|
||||
|
||||
import comfy.utils
|
||||
from comfy.component_model.tensor_types import RGBImageBatch, ImageBatch
|
||||
from comfy.component_model.tensor_types import RGBImageBatch, ImageBatch, MaskBatch
|
||||
from comfy.nodes.package_typing import CustomNode
|
||||
|
||||
|
||||
@ -34,10 +34,7 @@ class PorterDuffMode(Enum):
|
||||
XOR = 17
|
||||
|
||||
|
||||
def porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_image: torch.Tensor, dst_alpha: torch.Tensor, mode: PorterDuffMode):
|
||||
# convert mask to alpha
|
||||
src_alpha = 1 - src_alpha
|
||||
dst_alpha = 1 - dst_alpha
|
||||
def _porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_image: torch.Tensor, dst_alpha: torch.Tensor, mode: PorterDuffMode):
|
||||
# premultiply alpha
|
||||
src_image = src_image * src_alpha
|
||||
dst_image = dst_image * dst_alpha
|
||||
@ -109,24 +106,31 @@ def porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_
|
||||
return out_image, out_alpha
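The renamed helper now expects straight alpha (the old mask inversion moves into the V1 node below) and premultiplies before compositing. For reference, a standalone sketch of the standard "source over" operator under that premultiplied convention; the per-mode branches themselves are elided from this hunk:

def src_over_premultiplied(src_image, src_alpha, dst_image, dst_alpha):
    # inputs are assumed already premultiplied by their alpha, as in _porter_duff_composite
    out_alpha = src_alpha + dst_alpha * (1 - src_alpha)
    out_image = src_image + dst_image * (1 - src_alpha)
    return out_image, out_alpha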
|
||||
|
||||
|
||||
class PorterDuffImageComposite:
|
||||
class PorterDuffImageCompositeV2:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"source": ("IMAGE",),
|
||||
"source_alpha": ("MASK",),
|
||||
"destination": ("IMAGE",),
|
||||
"destination_alpha": ("MASK",),
|
||||
"mode": ([mode.name for mode in PorterDuffMode], {"default": PorterDuffMode.DST.name}),
|
||||
},
|
||||
"optional": {
|
||||
"source_alpha": ("MASK",),
|
||||
"destination_alpha": ("MASK",),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE", "MASK")
|
||||
FUNCTION = "composite"
|
||||
CATEGORY = "mask/compositing"
|
||||
|
||||
def composite(self, source: torch.Tensor, source_alpha: torch.Tensor, destination: torch.Tensor, destination_alpha: torch.Tensor, mode):
|
||||
def composite(self, source: RGBImageBatch, destination: RGBImageBatch, mode, source_alpha: MaskBatch = None, destination_alpha: MaskBatch = None) -> tuple[RGBImageBatch, MaskBatch]:
|
||||
if source_alpha is None:
|
||||
source_alpha = torch.zeros(source.shape[:3])
|
||||
if destination_alpha is None:
|
||||
destination_alpha = torch.zeros(destination.shape[:3])
|
||||
|
||||
batch_size = min(len(source), len(source_alpha), len(destination), len(destination_alpha))
|
||||
out_images = []
|
||||
out_alphas = []
|
||||
@ -153,7 +157,7 @@ class PorterDuffImageComposite:
|
||||
upscale_output = comfy.utils.common_upscale(upscale_input, dst_alpha.shape[1], dst_alpha.shape[0], upscale_method='bicubic', crop='center')
|
||||
src_alpha = upscale_output.permute(0, 2, 3, 1).squeeze(0)
|
||||
|
||||
out_image, out_alpha = porter_duff_composite(src_image, src_alpha, dst_image, dst_alpha, PorterDuffMode[mode])
|
||||
out_image, out_alpha = _porter_duff_composite(src_image, src_alpha, dst_image, dst_alpha, PorterDuffMode[mode])
|
||||
|
||||
out_images.append(out_image)
|
||||
out_alphas.append(out_alpha.squeeze(2))
|
||||
@ -162,6 +166,28 @@ class PorterDuffImageComposite:
|
||||
return result
|
||||
|
||||
|
||||
class PorterDuffImageCompositeV1(PorterDuffImageCompositeV2):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"source": ("IMAGE",),
|
||||
"source_alpha": ("MASK",),
|
||||
"destination": ("IMAGE",),
|
||||
"destination_alpha": ("MASK",),
|
||||
"mode": ([mode.name for mode in PorterDuffMode], {"default": PorterDuffMode.DST.name}),
|
||||
},
|
||||
}
|
||||
|
||||
FUNCTION = "composite_v1"
|
||||
|
||||
def composite_v1(self, source: torch.Tensor, source_alpha: torch.Tensor, destination: torch.Tensor, destination_alpha: torch.Tensor, mode) -> tuple[RGBImageBatch, MaskBatch]:
|
||||
# convert mask to alpha
|
||||
source_alpha = 1 - source_alpha
|
||||
destination_alpha = 1 - destination_alpha
|
||||
return super().composite(source, destination, mode, source_alpha, destination_alpha)
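A minimal usage sketch contrasting the two nodes: V2 takes straight, optional alpha masks, while the V1 subclass keeps the old behaviour by inverting the masks before delegating. The src, dst, src_alpha and dst_alpha names are placeholder tensors:

composite_v2 = PorterDuffImageCompositeV2()
image, alpha = composite_v2.composite(source=src, destination=dst, mode="DST",
                                      source_alpha=src_alpha, destination_alpha=dst_alpha)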
|
||||
|
||||
|
||||
class SplitImageWithAlpha:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@ -312,7 +338,8 @@ class Posterize(CustomNode):
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"PorterDuffImageComposite": PorterDuffImageComposite,
|
||||
"PorterDuffImageComposite": PorterDuffImageCompositeV1,
|
||||
"PorterDuffImageCompositeV2": PorterDuffImageCompositeV2,
|
||||
"SplitImageWithAlpha": SplitImageWithAlpha,
|
||||
"JoinImageWithAlpha": JoinImageWithAlpha,
|
||||
"EnhanceContrast": EnhanceContrast,
|
||||
@ -321,7 +348,8 @@ NODE_CLASS_MAPPINGS = {
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"PorterDuffImageComposite": "Porter-Duff Image Composite",
|
||||
"PorterDuffImageComposite": "Porter-Duff Image Composite (V1)",
|
||||
"PorterDuffImageCompositeV2": "Image Composite",
|
||||
"SplitImageWithAlpha": "Split Image with Alpha",
|
||||
"JoinImageWithAlpha": "Join Image with Alpha",
|
||||
}
|
||||
|
||||
@ -1,80 +1,32 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import inspect
|
||||
import logging
|
||||
import operator
|
||||
import os.path
|
||||
from functools import reduce
|
||||
from typing import Any, Dict, Optional, List, Callable, Union
|
||||
from typing import Optional, List
|
||||
|
||||
import torch
|
||||
from transformers import AutoTokenizer, PreTrainedModel, LogitsProcessor, TextStreamer, \
|
||||
PreTrainedTokenizerBase, PretrainedConfig, AutoProcessor, BatchFeature, AutoModel, AutoModelForCausalLM, \
|
||||
AutoModelForSeq2SeqLM
|
||||
from transformers.models.auto.modeling_auto import MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, \
|
||||
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES, AutoModelForVision2Seq
|
||||
from transformers import AutoProcessor
|
||||
from transformers.models.m2m_100.tokenization_m2m_100 import \
|
||||
FAIRSEQ_LANGUAGE_CODES as tokenization_m2m_100_FAIRSEQ_LANGUAGE_CODES
|
||||
from transformers.models.nllb.tokenization_nllb import \
|
||||
FAIRSEQ_LANGUAGE_CODES as tokenization_nllb_FAIRSEQ_LANGUAGE_CODES
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from comfy import model_management
|
||||
from comfy.cmd import folder_paths
|
||||
from comfy.component_model.folder_path_types import SaveImagePathResponse
|
||||
from comfy.language.chat_templates import KNOWN_CHAT_TEMPLATES
|
||||
from comfy.language.language_types import ProcessorResult
|
||||
from comfy.language.language_types import GENERATION_KWARGS_TYPE, GENERATION_KWARGS_TYPE_NAME, TOKENS_TYPE, \
|
||||
TOKENS_TYPE_NAME, LanguageModel
|
||||
from comfy.language.transformers_model_management import TransformersManagedModel
|
||||
from comfy.model_downloader import get_huggingface_repo_list, get_or_download_huggingface_repo
|
||||
from comfy.model_management import get_torch_device_name, unet_dtype, unet_offload_device, load_models_gpu
|
||||
from comfy.model_management import get_torch_device_name, unet_dtype, unet_offload_device
|
||||
from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult
|
||||
from comfy.utils import comfy_tqdm, seed_for_block, comfy_progress, ProgressBar
|
||||
|
||||
_AUTO_CHAT_TEMPLATE = "default"
|
||||
|
||||
# add llava support
|
||||
try:
|
||||
from llava import model as _llava_model_side_effects
|
||||
|
||||
logging.debug("Additional LLaVA models are now supported")
|
||||
except ImportError as exc:
|
||||
logging.debug(f"Install LLavA with `pip install git+https://github.com/AppMana/appmana-comfyui-llava` for additional LLaVA support")
|
||||
|
||||
# aka kwargs type
|
||||
_GENERATION_KWARGS_TYPE = Dict[str, Any]
|
||||
_GENERATION_KWARGS_TYPE_NAME = "SAMPLER"
|
||||
|
||||
_TOKENS_TYPE = Union[ProcessorResult, BatchFeature]
|
||||
TOKENS_TYPE_NAME = "TOKENS"
|
||||
|
||||
|
||||
class _ProgressTextStreamer(TextStreamer):
|
||||
def __init__(self, on_finalized_text: Callable[[str, bool], None], tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
|
||||
super().__init__(tokenizer, skip_prompt, **decode_kwargs)
|
||||
self.on_finalized_text_handler = on_finalized_text
|
||||
|
||||
def on_finalized_text(self, text: str, stream_end: bool = False):
|
||||
self.on_finalized_text_handler(text, stream_end)
|
||||
|
||||
|
||||
class _ProgressLogitsProcessor(LogitsProcessor):
|
||||
def __init__(self, model: TransformersManagedModel):
|
||||
self.eos_token_id = model.tokenizer.eos_token_id
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
probabilities = scores.softmax(dim=-1)
|
||||
self.eos_probability = probabilities[:, self.eos_token_id].item()
|
||||
return scores
|
||||
|
||||
|
||||
# todo: for per token progress, should this really look like {"ui": {"string": [value]}} ?
|
||||
class TransformerStreamedProgress(TypedDict):
|
||||
next_token: str
|
||||
|
||||
|
||||
class TransformerSamplerBase(CustomNode):
|
||||
RETURN_TYPES = _GENERATION_KWARGS_TYPE_NAME,
|
||||
RETURN_TYPES = GENERATION_KWARGS_TYPE_NAME,
|
||||
RETURN_NAMES = "GENERATION ARGS",
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "language/samplers"
|
||||
@ -142,7 +94,7 @@ class TransformersGenerationConfig(CustomNode):
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = _GENERATION_KWARGS_TYPE_NAME,
|
||||
RETURN_TYPES = GENERATION_KWARGS_TYPE_NAME,
|
||||
RETURN_NAMES = "GENERATION ARGS",
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "language"
|
||||
@ -182,15 +134,15 @@ class TransformerBeamSearchSampler(TransformerSamplerBase):
|
||||
class TransformerMergeSamplers(CustomNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypes:
|
||||
range_ = {"value0": (_GENERATION_KWARGS_TYPE_NAME, {"forceInput": True})}
|
||||
range_.update({f"value{i}": (_GENERATION_KWARGS_TYPE_NAME, {"forceInput": True}) for i in range(1, 5)})
|
||||
range_ = {"value0": (GENERATION_KWARGS_TYPE_NAME, {"forceInput": True})}
|
||||
range_.update({f"value{i}": (GENERATION_KWARGS_TYPE_NAME, {"forceInput": True}) for i in range(1, 5)})
|
||||
|
||||
return {
|
||||
"required": range_
|
||||
}
|
||||
|
||||
CATEGORY = "language"
|
||||
RETURN_TYPES = _GENERATION_KWARGS_TYPE_NAME,
|
||||
RETURN_TYPES = GENERATION_KWARGS_TYPE_NAME,
|
||||
FUNCTION = "execute"
|
||||
|
||||
def execute(self, **kwargs):
|
||||
@ -238,98 +190,11 @@ class TransformersLoader(CustomNode):
|
||||
|
||||
CATEGORY = "language"
|
||||
RETURN_TYPES = "MODEL",
|
||||
RETURN_NAMES = "language model",
|
||||
FUNCTION = "execute"
|
||||
|
||||
def execute(self, ckpt_name: str, subfolder: Optional[str] = None, *args, **kwargs):
|
||||
hub_kwargs = {}
|
||||
if subfolder is not None and subfolder != "":
|
||||
hub_kwargs["subfolder"] = subfolder
|
||||
|
||||
ckpt_name = get_or_download_huggingface_repo(ckpt_name)
|
||||
with comfy_tqdm():
|
||||
from_pretrained_kwargs = {
|
||||
"pretrained_model_name_or_path": ckpt_name,
|
||||
"trust_remote_code": True,
|
||||
**hub_kwargs
|
||||
}
|
||||
|
||||
# if flash attention exists, use it
|
||||
|
||||
# compute bitsandbytes configuration
|
||||
try:
|
||||
import bitsandbytes
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
config_dict, _ = PretrainedConfig.get_config_dict(ckpt_name, **hub_kwargs)
|
||||
model_type = config_dict["model_type"]
|
||||
# language models prefer to use bfloat16 over float16
|
||||
kwargs_to_try = ({"torch_dtype": unet_dtype(supported_dtypes=(torch.bfloat16, torch.float16, torch.float32)),
|
||||
"low_cpu_mem_usage": True,
|
||||
"device_map": str(unet_offload_device()), }, {})
|
||||
|
||||
# if we have flash-attn installed, try to use it
|
||||
try:
|
||||
import flash_attn
|
||||
attn_override_kwargs = {
|
||||
"attn_implementation": "flash_attention_2",
|
||||
**kwargs_to_try[0]
|
||||
}
|
||||
kwargs_to_try = (attn_override_kwargs, *kwargs_to_try)
|
||||
logging.debug(f"while loading model {ckpt_name}, flash_attn was installed, so the flash_attention_2 implementation will be tried")
|
||||
except ImportError:
|
||||
pass
|
||||
for i, props in enumerate(kwargs_to_try):
|
||||
try:
|
||||
if model_type in MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES:
|
||||
model = AutoModelForVision2Seq.from_pretrained(**from_pretrained_kwargs, **props)
|
||||
elif model_type in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained(**from_pretrained_kwargs, **props)
|
||||
elif model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
|
||||
model = AutoModelForCausalLM.from_pretrained(**from_pretrained_kwargs, **props)
|
||||
else:
|
||||
model = AutoModel.from_pretrained(**from_pretrained_kwargs, **props)
|
||||
if model is not None:
|
||||
break
|
||||
except Exception as exc_info:
|
||||
if i == len(kwargs_to_try) - 1:
|
||||
raise exc_info
|
||||
else:
|
||||
logging.warning(f"tried to import transformers model {ckpt_name} but got exception when trying additional import args {props}", exc_info=exc_info)
|
||||
finally:
|
||||
torch.set_default_dtype(torch.float32)
|
||||
|
||||
for i, props in enumerate(kwargs_to_try):
|
||||
try:
|
||||
try:
|
||||
processor = AutoProcessor.from_pretrained(**from_pretrained_kwargs, **props)
|
||||
except:
|
||||
processor = None
|
||||
if isinstance(processor, PreTrainedTokenizerBase):
|
||||
tokenizer = processor
|
||||
processor = None
|
||||
else:
|
||||
tokenizer = getattr(processor, "tokenizer") if processor is not None and hasattr(processor, "tokenizer") else AutoTokenizer.from_pretrained(ckpt_name, **hub_kwargs, **props)
|
||||
if tokenizer is not None or processor is not None:
|
||||
break
|
||||
except Exception as exc_info:
|
||||
if i == len(kwargs_to_try) - 1:
|
||||
raise exc_info
|
||||
finally:
|
||||
torch.set_default_dtype(torch.float32)
|
||||
|
||||
if model_management.xformers_enabled() and hasattr(model, "enable_xformers_memory_efficient_attention"):
|
||||
model.enable_xformers_memory_efficient_attention()
|
||||
logging.debug("enabled xformers memory efficient attention")
|
||||
|
||||
model_managed = TransformersManagedModel(
|
||||
repo_id=ckpt_name,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
config_dict=config_dict,
|
||||
processor=processor
|
||||
)
|
||||
return model_managed,
|
||||
def execute(self, ckpt_name: str, subfolder: Optional[str] = None, *args, **kwargs) -> tuple[TransformersManagedModel]:
|
||||
return TransformersManagedModel.from_pretrained(ckpt_name, subfolder),
|
||||
|
||||
|
||||
class TransformersTokenize(CustomNode):
|
||||
@ -346,7 +211,7 @@ class TransformersTokenize(CustomNode):
|
||||
RETURN_TYPES = (TOKENS_TYPE_NAME,)
|
||||
FUNCTION = "execute"
|
||||
|
||||
def execute(self, model: TransformersManagedModel, prompt: str) -> ValidatedNodeResult:
|
||||
def execute(self, model: LanguageModel, prompt: str) -> ValidatedNodeResult:
|
||||
return model.tokenize(prompt, [], None),
|
||||
|
||||
|
||||
@ -452,7 +317,7 @@ class OneShotInstructTokenize(CustomNode):
|
||||
RETURN_TYPES = (TOKENS_TYPE_NAME,)
|
||||
FUNCTION = "execute"
|
||||
|
||||
def execute(self, model: TransformersManagedModel, prompt: str, images: List[torch.Tensor] | torch.Tensor = None, chat_template: str = "__auto__") -> ValidatedNodeResult:
|
||||
def execute(self, model: LanguageModel, prompt: str, images: List[torch.Tensor] | torch.Tensor = None, chat_template: str = "__auto__") -> ValidatedNodeResult:
|
||||
if chat_template == _AUTO_CHAT_TEMPLATE:
|
||||
# use an exact match
|
||||
model_name = os.path.basename(model.repo_id)
|
||||
@ -475,10 +340,9 @@ class TransformersGenerate(CustomNode):
|
||||
"max_new_tokens": ("INT", {"default": 512, "min": 1}),
|
||||
"repetition_penalty": ("FLOAT", {"default": 0.0, "min": 0}),
|
||||
"seed": ("INT", {"default": 0, "min": 0, "max": 2 ** 32 - 1}),
|
||||
"use_cache": ("BOOLEAN", {"default": True}),
|
||||
},
|
||||
"optional": {
|
||||
"sampler": (_GENERATION_KWARGS_TYPE_NAME, {}),
|
||||
"sampler": (GENERATION_KWARGS_TYPE_NAME, {}),
|
||||
}
|
||||
}
|
||||
|
||||
@ -487,110 +351,14 @@ class TransformersGenerate(CustomNode):
|
||||
FUNCTION = "execute"
|
||||
|
||||
def execute(self,
|
||||
model: Optional[TransformersManagedModel] = None,
|
||||
tokens: _TOKENS_TYPE = None,
|
||||
model: Optional[LanguageModel] = None,
|
||||
tokens: TOKENS_TYPE = None,
|
||||
max_new_tokens: int = 512,
|
||||
repetition_penalty: float = 0.0,
|
||||
seed: int = 0,
|
||||
sampler: Optional[_GENERATION_KWARGS_TYPE] = None,
|
||||
*args,
|
||||
**kwargs
|
||||
sampler: Optional[GENERATION_KWARGS_TYPE] = None,
|
||||
):
|
||||
tokens = copy.copy(tokens)
|
||||
tokens_original = copy.copy(tokens)
|
||||
sampler = sampler or {}
|
||||
generate_kwargs = copy.copy(sampler)
|
||||
load_models_gpu([model])
|
||||
transformers_model: PreTrainedModel = model.model
|
||||
tokenizer: PreTrainedTokenizerBase | AutoTokenizer = model.tokenizer
|
||||
# remove unused inputs
|
||||
# maximizes compatibility with different models
|
||||
generate_signature = inspect.signature(transformers_model.generate).parameters
|
||||
prepare_signature = inspect.signature(transformers_model.prepare_inputs_for_generation).parameters
|
||||
to_delete = set(reduce(operator.sub, map(lambda x: x.keys(), [tokens, generate_signature, prepare_signature])))
|
||||
gen_sig_keys = generate_signature.keys()
|
||||
if "tgt_lang" in tokens:
|
||||
to_delete.add("tgt_lang")
|
||||
to_delete.add("src_lang")
|
||||
to_delete.discard("input_ids")
|
||||
if "forced_bos_token_id" in tokens:
|
||||
to_delete.discard("forced_bos_token_id")
|
||||
elif hasattr(tokenizer, "convert_tokens_to_ids"):
|
||||
generate_kwargs["forced_bos_token_id"] = tokenizer.convert_tokens_to_ids(tokens["tgt_lang"])
|
||||
else:
|
||||
logging.warning(f"tokenizer {tokenizer} unexpected for translation task")
|
||||
if "input_ids" in tokens and "inputs" in tokens:
|
||||
if "input_ids" in gen_sig_keys:
|
||||
to_delete.add("inputs")
|
||||
elif "inputs" in gen_sig_keys:
|
||||
to_delete.add("input_ids")
|
||||
for unused_kwarg in to_delete:
|
||||
tokens.pop(unused_kwarg)
|
||||
logging.debug(f"{transformers_model.name_or_path}.generate does not accept {unused_kwarg}, removing")
|
||||
|
||||
# images should be moved to model
|
||||
for key in ("images", "pixel_values"):
|
||||
if key in tokens:
|
||||
tokens[key] = tokens[key].to(device=model.current_device, dtype=model.model_dtype())
|
||||
|
||||
# sets up inputs
|
||||
inputs = tokens
|
||||
|
||||
# used to determine if text streaming is supported
|
||||
num_beams = generate_kwargs.get("num_beams", transformers_model.generation_config.num_beams)
|
||||
|
||||
progress_bar: ProgressBar
|
||||
with comfy_progress(total=max_new_tokens) as progress_bar:
|
||||
# todo: deal with batches correctly, don't assume batch size 1
|
||||
token_count = 0
|
||||
|
||||
# progress
|
||||
def on_finalized_text(next_token: str, stop: bool):
|
||||
nonlocal token_count
|
||||
nonlocal progress_bar
|
||||
|
||||
token_count += 1
|
||||
preview = TransformerStreamedProgress(next_token=next_token)
|
||||
progress_bar.update_absolute(token_count, total=max_new_tokens, preview_image_or_output=preview)
|
||||
|
||||
text_streamer = _ProgressTextStreamer(on_finalized_text, tokenizer, True)
|
||||
|
||||
with seed_for_block(seed):
|
||||
if hasattr(inputs, "encodings") and inputs.encodings is not None and all(hasattr(encoding, "attention_mask") for encoding in inputs.encodings) and "attention_mask" in inputs:
|
||||
inputs.pop("attention_mask")
|
||||
output_ids = transformers_model.generate(
|
||||
**inputs,
|
||||
streamer=text_streamer if num_beams <= 1 else None,
|
||||
max_new_tokens=max_new_tokens,
|
||||
repetition_penalty=repetition_penalty if repetition_penalty != 0 else None,
|
||||
**generate_kwargs
|
||||
)
|
||||
|
||||
if not transformers_model.config.is_encoder_decoder:
|
||||
start_position = inputs["input_ids" if "input_ids" in inputs else "inputs"].shape[1]
|
||||
output_ids = output_ids[:, start_position:]
|
||||
|
||||
if hasattr(tokenizer, "src_lang") and "src_lang" in tokens_original:
|
||||
prev_src_lang = tokenizer.src_lang
|
||||
tokenizer.src_lang = tokens_original["src_lang"]
|
||||
else:
|
||||
prev_src_lang = None
|
||||
# todo: is this redundant consider I'm decoding in the on_finalized_text block?
|
||||
try:
|
||||
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||
finally:
|
||||
if prev_src_lang is not None:
|
||||
tokenizer.src_lang = prev_src_lang
|
||||
# gpu-loaded stuff like images can now be unloaded
|
||||
if hasattr(tokens, "to"):
|
||||
del tokens
|
||||
else:
|
||||
for to_delete in tokens.values():
|
||||
del to_delete
|
||||
del tokens
|
||||
|
||||
# todo: better support batches
|
||||
return outputs[0],
|
||||
return model.generate(tokens, max_new_tokens, repetition_penalty, seed, sampler),
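A short sketch of the node chain after this refactor, in which generation is delegated entirely to the LanguageModel implementation; the repo id is one of the known repos listed earlier in this diff:

loader = TransformersLoader()
tokenize = TransformersTokenize()
generate = TransformersGenerate()

model, = loader.execute("roborovski/superprompt-v1")
tokens, = tokenize.execute(model, "a cat sitting on a windowsill")
text, = generate.execute(model=model, tokens=tokens, max_new_tokens=64, seed=0)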
|
||||
|
||||
|
||||
class PreviewString(CustomNode):
|
||||
|
||||
206 comfy_extras/nodes/nodes_openai.py (new file)
@ -0,0 +1,206 @@
|
||||
import base64
|
||||
import io
|
||||
import os
|
||||
from io import BytesIO
|
||||
from typing import Literal, Optional
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
import torch
|
||||
from PIL import Image
|
||||
from openai import OpenAI
|
||||
from openai.types.chat import ChatCompletionMessageParam
|
||||
|
||||
from comfy.cli_args import args
|
||||
from comfy.component_model.tensor_types import RGBImageBatch
|
||||
from comfy.language.language_types import LanguageModel, ProcessorResult, GENERATION_KWARGS_TYPE, TOKENS_TYPE, \
|
||||
TransformerStreamedProgress
|
||||
from comfy.nodes.package_typing import CustomNode, InputTypes
|
||||
from comfy.utils import comfy_progress, ProgressBar, seed_for_block
|
||||
|
||||
|
||||
class _Client:
|
||||
_client: Optional[OpenAI] = None
|
||||
|
||||
@staticmethod
|
||||
def instance() -> OpenAI:
|
||||
if _Client._client is None:
|
||||
open_ai_api_key = args.openai_api_key
|
||||
_Client._client = OpenAI(
|
||||
api_key=open_ai_api_key,
|
||||
)
|
||||
|
||||
return _Client._client
|
||||
|
||||
|
||||
def validate_has_key():
|
||||
open_api_key = os.environ.get("OPENAI_API_KEY", args.openai_api_key)
|
||||
if open_api_key is None or open_api_key == "":
|
||||
return "set OPENAI_API_KEY environment variable"
|
||||
return True
|
||||
|
||||
|
||||
def image_to_base64(image: RGBImageBatch) -> str:
|
||||
pil_image = Image.fromarray(np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8))
|
||||
buffered = io.BytesIO()
|
||||
pil_image.save(buffered, format="JPEG")
|
||||
return base64.b64encode(buffered.getvalue()).decode("utf-8")
|
||||
|
||||
|
||||
class OpenAILanguageModelWrapper(LanguageModel):
|
||||
def __init__(self, model: str):
|
||||
self.model = model
|
||||
self.client = _Client.instance()
|
||||
|
||||
@staticmethod
|
||||
def from_pretrained(ckpt_name: str, subfolder: Optional[str] = None) -> "OpenAILanguageModelWrapper":
|
||||
return OpenAILanguageModelWrapper(ckpt_name)
|
||||
|
||||
def generate(self, tokens: TOKENS_TYPE = None,
|
||||
max_new_tokens: int = 512,
|
||||
repetition_penalty: float = 0.0,
|
||||
seed: int = 0,
|
||||
sampler: Optional[GENERATION_KWARGS_TYPE] = None,
|
||||
*args,
|
||||
**kwargs) -> str:
|
||||
sampler = sampler or {}
|
||||
prompt = tokens.get("inputs", [])
|
||||
prompt = "".join(prompt)
|
||||
images = tokens.get("images", [])
|
||||
messages: list[ChatCompletionMessageParam] = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": prompt},
|
||||
] + [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{image_to_base64(image)}"
|
||||
}
|
||||
} for image in images
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
progress_bar: ProgressBar
|
||||
with comfy_progress(total=max_new_tokens) as progress_bar:
|
||||
token_count = 0
|
||||
full_response = ""
|
||||
|
||||
def on_finalized_text(next_token: str, stop: bool):
|
||||
nonlocal token_count
|
||||
nonlocal progress_bar
|
||||
nonlocal full_response
|
||||
|
||||
token_count += 1
|
||||
full_response += next_token
|
||||
preview = TransformerStreamedProgress(next_token=next_token)
|
||||
progress_bar.update_absolute(max_new_tokens if stop else token_count, total=max_new_tokens, preview_image_or_output=preview)
|
||||
|
||||
with seed_for_block(seed):
|
||||
stream = self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
max_tokens=max_new_tokens,
|
||||
temperature=sampler.get("temperature", 1.0),
|
||||
top_p=sampler.get("top_p", 1.0),
|
||||
# n=1,
|
||||
# stop=None,
|
||||
# presence_penalty=repetition_penalty,
|
||||
seed=seed,
|
||||
stream=True
|
||||
)
|
||||
|
||||
for chunk in stream:
|
||||
if chunk.choices[0].delta.content is not None:
|
||||
on_finalized_text(chunk.choices[0].delta.content, False)
|
||||
|
||||
on_finalized_text("", True) # Signal the end of streaming
|
||||
|
||||
return full_response
|
||||
|
||||
def tokenize(self, prompt: str, images: RGBImageBatch, chat_template: str | None = None) -> ProcessorResult:
|
||||
# OpenAI API doesn't require explicit tokenization, so we'll just return the prompt and images as is
|
||||
return {
|
||||
"inputs": [prompt],
|
||||
"attention_mask": torch.ones(1, len(prompt)), # Dummy attention mask
|
||||
"images": images
|
||||
}
|
||||
|
||||
@property
|
||||
def repo_id(self) -> str:
|
||||
return f"openai/{self.model}"
|
||||
|
||||
|
||||
class OpenAILanguageModelLoader(CustomNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypes:
|
||||
return {
|
||||
"required": {
|
||||
"model": (["gpt-4o", "gpt-4o-mini", "gpt-4-turbo", "gpt-4", "gpt-3.5-turbo"], {"default": "gpt-3.5-turbo"})
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
RETURN_NAMES = ("language model",)
|
||||
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "openai"
|
||||
|
||||
def execute(self, model: str) -> tuple[LanguageModel]:
|
||||
return OpenAILanguageModelWrapper(model),
|
||||
|
||||
@classmethod
|
||||
def VALIDATE_INPUTS(cls):
|
||||
return validate_has_key()
|
||||
|
||||
|
||||
class DallEGenerate(CustomNode):
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {"required": {
|
||||
"model": (["dall-e-2", "dall-e-3"], {"default": "dall-e-3"}),
|
||||
"text": ("STRING", {"multiline": True}),
|
||||
"size": (["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"], {"default": "1024x1024"}),
|
||||
"quality": (["standard", "hd"], {"default": "standard"}),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("IMAGE", "STRING",)
|
||||
RETURN_NAMES = ("images", "revised prompt")
|
||||
FUNCTION = "generate"
|
||||
|
||||
CATEGORY = "openai"
|
||||
|
||||
def generate(self,
|
||||
model: Literal["dall-e-2", "dall-e-3"],
|
||||
text: str,
|
||||
size: Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"],
|
||||
quality: Literal["standard", "hd"]) -> tuple[RGBImageBatch, str]:
|
||||
response = _Client.instance().images.generate(
|
||||
model=model,
|
||||
prompt=text,
|
||||
size=size,
|
||||
quality=quality,
|
||||
n=1,
|
||||
)
|
||||
|
||||
image_url = response.data[0].url
|
||||
image_response = requests.get(image_url)
|
||||
|
||||
img = Image.open(BytesIO(image_response.content))
|
||||
|
||||
image = np.array(img).astype(np.float32) / 255.0
|
||||
image = torch.from_numpy(image)[None,]
|
||||
return image, response.data[0].revised_prompt
|
||||
|
||||
@classmethod
|
||||
def VALIDATE_INPUTS(cls):
|
||||
return validate_has_key()
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"DallEGenerate": DallEGenerate,
|
||||
"OpenAILanguageModelLoader": OpenAILanguageModelLoader
|
||||
}
|
||||
@ -65,4 +65,6 @@ ml_dtypes
|
||||
diffusers>=0.30.1
|
||||
vtracer
|
||||
skia-python
|
||||
pebble>=5.0.7
|
||||
pebble>=5.0.7
|
||||
openai
|
||||
anthropic
|
||||
14 tests/inference/test_language.py (new file)
@ -0,0 +1,14 @@
|
||||
import torch
|
||||
|
||||
from comfy_extras.nodes.nodes_language import TransformersLoader, OneShotInstructTokenize
|
||||
|
||||
|
||||
def test_integration_transformers_loader_and_tokenize():
|
||||
loader = TransformersLoader()
|
||||
tokenize = OneShotInstructTokenize()
|
||||
|
||||
model, = loader.execute("llava-hf/llava-v1.6-mistral-7b-hf", "")
|
||||
tokens, = tokenize.execute(model, "Describe this image:", torch.rand((1, 224, 224, 3)), "llava-v1.6-mistral-7b-hf", )
|
||||
|
||||
assert isinstance(tokens, dict)
|
||||
assert "input_ids" in tokens or "inputs" in tokens
|
||||
@ -1,10 +1,17 @@
|
||||
import io
|
||||
import os
|
||||
import tempfile
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import patch, Mock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from comfy.language.language_types import LanguageModel, ProcessorResult
|
||||
from comfy_extras.nodes.nodes_language import SaveString
|
||||
from comfy_extras.nodes.nodes_language import TransformersLoader, OneShotInstructTokenize, TransformersGenerate, \
|
||||
PreviewString
|
||||
from comfy_extras.nodes.nodes_openai import OpenAILanguageModelLoader, OpenAILanguageModelWrapper, DallEGenerate
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -57,3 +64,151 @@ def test_save_string_default_extension(save_string_node, mock_get_save_path):
|
||||
assert os.path.exists(saved_file_path)
|
||||
with open(saved_file_path, "r") as f:
|
||||
assert f.read() == test_string
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_openai_client():
|
||||
with patch('comfy_extras.nodes.nodes_openai._Client') as mock_client:
|
||||
instance = mock_client.instance.return_value
|
||||
instance.chat.completions.create = Mock()
|
||||
instance.images.generate = Mock()
|
||||
yield instance
|
||||
|
||||
|
||||
def test_transformers_loader():
|
||||
loader = TransformersLoader()
|
||||
model, = loader.execute("microsoft/Phi-3-mini-4k-instruct", "")
|
||||
assert isinstance(model, LanguageModel)
|
||||
assert model.repo_id == "microsoft/Phi-3-mini-4k-instruct"
|
||||
|
||||
|
||||
def test_one_shot_instruct_tokenize(mocker):
|
||||
tokenize = OneShotInstructTokenize()
|
||||
mock_model = mocker.Mock()
|
||||
mock_model.tokenize.return_value = {"input_ids": torch.tensor([[1, 2, 3]])}
|
||||
|
||||
tokens, = tokenize.execute(mock_model, "What comes after apple?", [], "phi-3")
|
||||
mock_model.tokenize.assert_called_once_with("What comes after apple?", [], mocker.ANY)
|
||||
assert "input_ids" in tokens
|
||||
|
||||
|
||||
def test_transformers_generate(mocker):
|
||||
generate = TransformersGenerate()
|
||||
mock_model = mocker.Mock()
|
||||
mock_model.generate.return_value = "The letter B comes after A in the alphabet."
|
||||
|
||||
tokens: ProcessorResult = {"inputs": torch.tensor([[1, 2, 3]])}
|
||||
result, = generate.execute(mock_model, tokens, 512, 0, 42)
|
||||
mock_model.generate.assert_called_once()
|
||||
assert isinstance(result, str)
|
||||
assert "letter B" in result
|
||||
|
||||
|
||||
def test_preview_string():
|
||||
preview = PreviewString()
|
||||
result = preview.execute("Test output")
|
||||
assert result == {"ui": {"string": ["Test output"]}}
|
||||
|
||||
|
||||
def test_openai_language_model_loader():
|
||||
if not "OPENAI_API_KEY" in os.environ:
|
||||
pytest.skip("must set OPENAI_API_KEY")
|
||||
loader = OpenAILanguageModelLoader()
|
||||
model, = loader.execute("gpt-3.5-turbo")
|
||||
assert isinstance(model, OpenAILanguageModelWrapper)
|
||||
assert model.model == "gpt-3.5-turbo"
|
||||
|
||||
|
||||
def test_openai_language_model_wrapper_generate(mock_openai_client):
|
||||
wrapper = OpenAILanguageModelWrapper("gpt-3.5-turbo")
|
||||
mock_stream = [
|
||||
Mock(choices=[Mock(delta=Mock(content="This "))]),
|
||||
Mock(choices=[Mock(delta=Mock(content="is "))]),
|
||||
Mock(choices=[Mock(delta=Mock(content="a "))]),
|
||||
Mock(choices=[Mock(delta=Mock(content="test "))]),
|
||||
Mock(choices=[Mock(delta=Mock(content="response."))]),
|
||||
]
|
||||
|
||||
mock_openai_client.chat.completions.create.return_value = mock_stream
|
||||
|
||||
tokens = {"inputs": ["What is the capital of France?"]}
|
||||
result = wrapper.generate(tokens, max_new_tokens=50)
|
||||
|
||||
mock_openai_client.chat.completions.create.assert_called_once_with(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content": [{"type": "text", "text": "What is the capital of France?"}]}],
|
||||
max_tokens=50,
|
||||
temperature=1.0,
|
||||
top_p=1.0,
|
||||
seed=0,
|
||||
stream=True
|
||||
)
|
||||
assert result == "This is a test response."
|
||||
|
||||
|
||||
def test_openai_language_model_wrapper_generate_with_image(mock_openai_client):
|
||||
wrapper = OpenAILanguageModelWrapper("gpt-4-vision-preview")
|
||||
mock_stream = [
|
||||
Mock(choices=[Mock(delta=Mock(content="This "))]),
|
||||
Mock(choices=[Mock(delta=Mock(content="image "))]),
|
||||
Mock(choices=[Mock(delta=Mock(content="shows "))]),
|
||||
Mock(choices=[Mock(delta=Mock(content="a "))]),
|
||||
Mock(choices=[Mock(delta=Mock(content="landscape."))]),
|
||||
]
|
||||
mock_openai_client.chat.completions.create.return_value = mock_stream
|
||||
|
||||
image_tensor = torch.rand((1, 224, 224, 3))
|
||||
tokens: ProcessorResult = {
|
||||
"inputs": ["Describe this image:"],
|
||||
"images": image_tensor
|
||||
}
|
||||
result = wrapper.generate(tokens, max_new_tokens=50)
|
||||
|
||||
mock_openai_client.chat.completions.create.assert_called_once()
|
||||
assert result == "This image shows a landscape."
|
||||
|
||||
|
||||
def test_dalle_generate(mock_openai_client):
|
||||
dalle = DallEGenerate()
|
||||
mock_openai_client.images.generate.return_value = Mock(
|
||||
data=[Mock(url="http://example.com/image.jpg", revised_prompt="A beautiful sunset")]
|
||||
)
|
||||
test_image = Image.new('RGB', (10, 10), color='red')
|
||||
img_byte_arr = io.BytesIO()
|
||||
test_image.save(img_byte_arr, format='PNG')
|
||||
img_byte_arr = img_byte_arr.getvalue()
|
||||
|
||||
with patch('requests.get') as mock_get:
|
||||
mock_get.return_value = Mock(content=img_byte_arr)
|
||||
image, revised_prompt = dalle.generate("dall-e-3", "Create a sunset image", "1024x1024", "standard")
|
||||
|
||||
assert isinstance(image, torch.Tensor)
|
||||
assert image.shape == (1, 10, 10, 3)
|
||||
assert torch.allclose(image, torch.tensor([1.0, 0, 0]).view(1, 1, 1, 3).expand(1, 10, 10, 3))
|
||||
assert revised_prompt == "A beautiful sunset"
|
||||
mock_openai_client.images.generate.assert_called_once_with(
|
||||
model="dall-e-3",
|
||||
prompt="Create a sunset image",
|
||||
size="1024x1024",
|
||||
quality="standard",
|
||||
n=1,
|
||||
)
|
||||
|
||||
|
||||
def test_integration_openai_loader_and_wrapper(mock_openai_client):
|
||||
loader = OpenAILanguageModelLoader()
|
||||
model, = loader.execute("gpt-4")
|
||||
|
||||
mock_stream = [
|
||||
Mock(choices=[Mock(delta=Mock(content="Paris "))]),
|
||||
Mock(choices=[Mock(delta=Mock(content="is "))]),
|
||||
Mock(choices=[Mock(delta=Mock(content="the "))]),
|
||||
Mock(choices=[Mock(delta=Mock(content="capital "))]),
|
||||
Mock(choices=[Mock(delta=Mock(content="of France."))]),
|
||||
]
|
||||
mock_openai_client.chat.completions.create.return_value = mock_stream
|
||||
|
||||
tokens = {"inputs": ["What is the capital of France?"]}
|
||||
result = model.generate(tokens, max_new_tokens=50)
|
||||
|
||||
assert result == "Paris is the capital of France."
|
||||
|
||||