From b0be335d594597a6167422c68a91b0dec2c41206 Mon Sep 17 00:00:00 2001 From: doctorpangloss <@hiddenswitch.com> Date: Tue, 26 Mar 2024 22:32:15 -0700 Subject: [PATCH] Improved support for ControlNet workflows with depth - ComfyUI can now load EXR files. - There are new arithmetic nodes for floats and integers. - EXR nodes can load depth maps and be remapped with ImageApplyColormap. This allows end users to use ground truth depth data from video game engines or 3D graphics tools and recolor it to the format expected by depth ControlNets: grayscale inverse depth maps and "inferno" colored inverse depth maps. - Fixed license notes. - Added an additional known ControlNet model. - Because CV2 is now used to read OpenEXR files, an environment variable must be set early on in the application, before CV2 is imported. This file, main_pre, is now imported early on in more places. --- comfy/client/embedded_comfy_client.py | 2 +- comfy/cmd/main.py | 4 +- comfy/cmd/main_pre.py | 12 + comfy/cmd/server.py | 10 +- comfy/cmd/worker.py | 21 +- comfy/component_model/images_types.py | 8 + comfy/images.py | 19 + comfy/k_diffusion/sampling.py | 27 +- comfy/model_downloader.py | 3 +- comfy/model_downloader_types.py | 4 +- comfy/nodes/base_nodes.py | 55 +- comfy/nodes/package_typing.py | 2 +- comfy/open_exr.py | 86 ++++ comfy/utils.py | 30 +- comfy/web/scripts/widgets.js | 2 +- comfy_extras/nodes/nodes_apply_color_map.py | 85 +++ comfy_extras/nodes/nodes_arithmetic.py | 514 +++++++++++++++++++ comfy_extras/nodes/nodes_freelunch.py | 26 +- comfy_extras/nodes/nodes_image_arithmetic.py | 48 ++ comfy_extras/nodes/nodes_open_api.py | 4 +- comfy_extras/nodes/nodes_sag.py | 4 +- requirements.txt | 3 +- tests/nodes/test_arithmetic_unit.py | 198 +++++++ tests/nodes/test_colormap_unit.py | 49 ++ 24 files changed, 1157 insertions(+), 59 deletions(-) create mode 100644 comfy/component_model/images_types.py create mode 100644 comfy/images.py create mode 100644 comfy/open_exr.py create mode 100644 comfy_extras/nodes/nodes_apply_color_map.py create mode 100644 comfy_extras/nodes/nodes_arithmetic.py create mode 100644 comfy_extras/nodes/nodes_image_arithmetic.py create mode 100644 tests/nodes/test_arithmetic_unit.py create mode 100644 tests/nodes/test_colormap_unit.py diff --git a/comfy/client/embedded_comfy_client.py b/comfy/client/embedded_comfy_client.py index e4271a98a..34e75c04d 100644 --- a/comfy/client/embedded_comfy_client.py +++ b/comfy/client/embedded_comfy_client.py @@ -100,7 +100,7 @@ class EmbeddedComfyClient: if self._configuration is None: options.enable_args_parsing() else: - from ..cli_args import args + from ..cmd.main_pre import args args.clear() args.update(self._configuration) diff --git a/comfy/cmd/main.py b/comfy/cmd/main.py index 4d929f996..efb03e599 100644 --- a/comfy/cmd/main.py +++ b/comfy/cmd/main.py @@ -1,4 +1,3 @@ -# Suppress warnings during import import asyncio import gc import itertools @@ -8,9 +7,10 @@ import shutil import threading import time +# main_pre must be the earliest import since it suppresses some spurious warnings +from .main_pre import args from ..utils import hijack_progress from .extra_model_paths import load_extra_path_config -from .main_pre import args from .. import model_management from ..analytics.analytics import initialize_event_tracking from ..cmd import cuda_malloc diff --git a/comfy/cmd/main_pre.py b/comfy/cmd/main_pre.py index b78ddf231..51408b682 100644 --- a/comfy/cmd/main_pre.py +++ b/comfy/cmd/main_pre.py @@ -1,3 +1,12 @@ +""" +This should be imported before entrypoints to correctly configure global options prior to importing packages like torch and cv2. + +Use this instead of cli_args to import the args: + +>>> from comfy.cmd.main_pre import args + +It will enable command line argument parsing. If this isn't desired, you must author your own implementation of these fixes. +""" import os from .. import options @@ -9,6 +18,8 @@ options.enable_args_parsing() if os.name == "nt": logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) warnings.filterwarnings("ignore", message="torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.") +warnings.filterwarnings("ignore", message="Torch was not compiled with flash attention.") +warnings.filterwarnings("ignore", message=".*Torch was not compiled with flash attention.*") from ..cli_args import args @@ -20,4 +31,5 @@ if args.deterministic: if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ: os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8" +os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" __all__ = ["args"] diff --git a/comfy/cmd/server.py b/comfy/cmd/server.py index d762c4609..84a71fb6f 100644 --- a/comfy/cmd/server.py +++ b/comfy/cmd/server.py @@ -38,6 +38,7 @@ from ..component_model.file_output_path import file_output_path from ..component_model.queue_types import QueueItem, HistoryEntry, BinaryEventTypes, TaskInvocation from ..digest import digest from ..nodes.package_typing import ExportedNodes +from ..images import open_image class HeuristicPath(NamedTuple): @@ -289,9 +290,12 @@ class PromptServer(ExecutorToClientProgress): return web.Response(status=400) if os.path.isfile(file): - if 'preview' in request.rel_url.query: - with Image.open(file) as img: - preview_info = request.rel_url.query['preview'].split(';') + # todo: any image file we upload that browsers don't support, we should encode a preview + # todo: image handling has to be a little bit more standardized, sometimes we want a Pillow Image, sometimes + # we want something that will render to the user, sometimes we want tensors + if 'preview' in request.rel_url.query or file.endswith(".exr"): + with open_image(file) as img: + preview_info = request.rel_url.query.get("preview", "jpeg;90").split(';') image_format = preview_info[0] if image_format not in ['webp', 'jpeg'] or 'a' in request.rel_url.query.get('channel', ''): image_format = 'webp' diff --git a/comfy/cmd/worker.py b/comfy/cmd/worker.py index 51224c0cc..b3a2f6cf7 100644 --- a/comfy/cmd/worker.py +++ b/comfy/cmd/worker.py @@ -4,11 +4,7 @@ import os import logging from .extra_model_paths import load_extra_path_config -from .. import options - -options.enable_args_parsing() - -from ..cli_args import args +from .main_pre import args async def main(): @@ -17,28 +13,19 @@ async def main(): args.distributed_queue_frontend = False assert args.distributed_queue_connection_uri is not None, "Set the --distributed-queue-connection-uri argument to your RabbitMQ server" - - if args.cuda_device is not None: - os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device) - logging.info(f"Set cuda device to: {args.cuda_device}") - - if args.deterministic: - if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ: - os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8" - # configure paths if args.output_directory: output_dir = os.path.abspath(args.output_directory) logging.info(f"Setting output directory to: {output_dir}") from ..cmd import folder_paths - + folder_paths.set_output_directory(output_dir) - + if args.input_directory: input_dir = os.path.abspath(args.input_directory) logging.info(f"Setting input directory to: {input_dir}") from ..cmd import folder_paths - + folder_paths.set_input_directory(input_dir) if args.temp_directory: diff --git a/comfy/component_model/images_types.py b/comfy/component_model/images_types.py new file mode 100644 index 000000000..8ef892059 --- /dev/null +++ b/comfy/component_model/images_types.py @@ -0,0 +1,8 @@ +from typing import NamedTuple + +from torch import Tensor + + +class RgbMaskTuple(NamedTuple): + rgb: Tensor + mask: Tensor diff --git a/comfy/images.py b/comfy/images.py new file mode 100644 index 000000000..660de0a03 --- /dev/null +++ b/comfy/images.py @@ -0,0 +1,19 @@ +import os.path +from contextlib import contextmanager + +import cv2 +from PIL import Image + + +def _open_exr(exr_path) -> Image.Image: + return Image.fromarray(cv2.imread(exr_path, cv2.IMREAD_COLOR)) + + +@contextmanager +def open_image(file_path: str) -> Image.Image: + _, ext = os.path.splitext(file_path) + if ext == ".exr": + yield _open_exr(file_path) + else: + with Image.open(file_path) as image: + yield image diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 7af016829..40d7e8e0e 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -755,7 +755,32 @@ def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, n @torch.no_grad() def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): - # From MIT licensed: https://github.com/Carzit/sd-webui-samplers-scheduler/ + """ + Portions of this function are adapted from the repository + https://github.com/Carzit/sd-webui-samplers-scheduler + + MIT License + + Copyright (c) 2023 Carzit + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. + """ extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) s_end = sigmas[-1] diff --git a/comfy/model_downloader.py b/comfy/model_downloader.py index 96a1061fc..8c9971076 100644 --- a/comfy/model_downloader.py +++ b/comfy/model_downloader.py @@ -51,7 +51,7 @@ def get_or_download(folder_name: str, filename: str, known_files: List[HuggingFi else: linked_filename = None try: - os.symlink(os.path.join(destination,known_file.filename), linked_filename) + os.symlink(os.path.join(destination, known_file.filename), linked_filename) except Exception as exc_info: logging.error(f"Failed to link file with alternative download save name in a way that is compatible with Hugging Face caching {repr(known_file)}", exc_info=exc_info) else: @@ -213,6 +213,7 @@ KNOWN_CONTROLNETS = [ HuggingFile("lllyasviel/sd_control_collection", "sargezt_xl_depth_faid_vidit.safetensors"), HuggingFile("lllyasviel/sd_control_collection", "sargezt_xl_depth_zeed.safetensors"), HuggingFile("lllyasviel/sd_control_collection", "sargezt_xl_softedge.safetensors"), + HuggingFile("SargeZT/controlnet-sd-xl-1.0-depth-16bit-zoe", "depth-zoe-xl-v1.0-controlnet.safetensors"), HuggingFile("lllyasviel/sd_control_collection", "t2i-adapter_diffusers_xl_canny.safetensors"), HuggingFile("lllyasviel/sd_control_collection", "t2i-adapter_diffusers_xl_depth_midas.safetensors"), HuggingFile("lllyasviel/sd_control_collection", "t2i-adapter_diffusers_xl_depth_zoe.safetensors"), diff --git a/comfy/model_downloader_types.py b/comfy/model_downloader_types.py index 788af7eb8..554a325e5 100644 --- a/comfy/model_downloader_types.py +++ b/comfy/model_downloader_types.py @@ -2,7 +2,7 @@ from __future__ import annotations import dataclasses from os.path import split -from typing import Optional, List +from typing import Optional, List, Sequence from typing_extensions import TypedDict, NotRequired @@ -15,10 +15,12 @@ class CivitFile: model_id (int): The ID of the model model_version_id (int): The version filename (str): The name of the file in the model + trigger_words (List[str]): Trigger words associated with the model """ model_id: int model_version_id: int filename: str + trigger_words: Optional[Sequence[str]] = dataclasses.field(default_factory=tuple) def __str__(self): return self.filename diff --git a/comfy/nodes/base_nodes.py b/comfy/nodes/base_nodes.py index 0c6bdb16d..081fd1d2b 100644 --- a/comfy/nodes/base_nodes.py +++ b/comfy/nodes/base_nodes.py @@ -9,6 +9,7 @@ import logging from PIL import Image, ImageOps, ImageSequence from PIL.PngImagePlugin import PngInfo +from natsort import natsorted from pkg_resources import resource_filename import numpy as np import safetensors.torch @@ -23,10 +24,13 @@ from .. import model_management from ..cli_args import args from ..cmd import folder_paths, latent_preview +from ..images import open_image from ..model_downloader import get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, \ KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS, KNOWN_LORAS, KNOWN_CONTROLNETS, KNOWN_DIFF_CONTROLNETS from ..nodes.common import MAX_RESOLUTION from .. import controlnet +from ..open_exr import load_exr + class CLIPTextEncode: @classmethod @@ -1454,38 +1458,49 @@ class PreviewImage(SaveImage): "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, } + class LoadImage: @classmethod def INPUT_TYPES(s): input_dir = folder_paths.get_input_directory() files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] - return {"required": - {"image": (sorted(files), {"image_upload": True})}, - } + return { + "required": { + "image": (natsorted(files), {"image_upload": True}), + }, + } CATEGORY = "image" RETURN_TYPES = ("IMAGE", "MASK") FUNCTION = "load_image" - def load_image(self, image): + + def load_image(self, image: str): image_path = folder_paths.get_annotated_filepath(image) - img = Image.open(image_path) output_images = [] output_masks = [] - for i in ImageSequence.Iterator(img): - i = ImageOps.exif_transpose(i) - if i.mode == 'I': - i = i.point(lambda i: i * (1 / 255)) - image = i.convert("RGB") - image = np.array(image).astype(np.float32) / 255.0 - image = torch.from_numpy(image)[None,] - if 'A' in i.getbands(): - mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0 - mask = 1. - torch.from_numpy(mask) - else: - mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") - output_images.append(image) - output_masks.append(mask.unsqueeze(0)) + + # maintain the legacy path + # this will ultimately return a tensor, so we'd rather have the tensors directly + # from cv2 rather than get them out of a PIL image + _, ext = os.path.splitext(image) + if ext == ".exr": + return load_exr(image_path, srgb=False) + with open_image(image_path) as img: + for i in ImageSequence.Iterator(img): + i = ImageOps.exif_transpose(i) + if i.mode == 'I': + i = i.point(lambda i: i * (1 / 255)) + image = i.convert("RGB") + image = np.array(image).astype(np.float32) / 255.0 + image = torch.from_numpy(image)[None,] + if 'A' in i.getbands(): + mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0 + mask = 1. - torch.from_numpy(mask) + else: + mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") + output_images.append(image) + output_masks.append(mask.unsqueeze(0)) if len(output_images) > 1: output_image = torch.cat(output_images, dim=0) @@ -1494,7 +1509,7 @@ class LoadImage: output_image = output_images[0] output_mask = output_masks[0] - return (output_image, output_mask) + return output_image, output_mask @classmethod def IS_CHANGED(s, image): diff --git a/comfy/nodes/package_typing.py b/comfy/nodes/package_typing.py index f993b9945..6de60703c 100644 --- a/comfy/nodes/package_typing.py +++ b/comfy/nodes/package_typing.py @@ -53,7 +53,7 @@ BooleanSpec = Tuple[Literal["BOOLEAN"], BoolSpecOptions] ChoiceSpec = Tuple[Union[Sequence[str], Sequence[float], Sequence[int]]] -NonPrimitiveTypeSpec = Tuple[CommonReturnTypes] +NonPrimitiveTypeSpec = Tuple[CommonReturnTypes, Any] InputTypeSpec = Union[IntSpec, FloatSpec, StringSpec, BooleanSpec, ChoiceSpec, NonPrimitiveTypeSpec] diff --git a/comfy/open_exr.py b/comfy/open_exr.py new file mode 100644 index 000000000..acd76f1bb --- /dev/null +++ b/comfy/open_exr.py @@ -0,0 +1,86 @@ +""" +Portions of this code are adapted from the repository +https://github.com/spacepxl/ComfyUI-HQ-Image-Save + +MIT License + +Copyright (c) 2023 spacepxl + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +""" + +import copy +from typing import Sequence, Tuple + +import cv2 as cv +import numpy as np +import torch +from torch import Tensor + +from .component_model.images_types import RgbMaskTuple + + +def mut_srgb_to_linear(np_array) -> None: + less = np_array <= 0.0404482362771082 + np_array[less] = np_array[less] / 12.92 + np_array[~less] = np.power((np_array[~less] + 0.055) / 1.055, 2.4) + + +def mut_linear_to_srgb(np_array) -> None: + less = np_array <= 0.0031308 + np_array[less] = np_array[less] * 12.92 + np_array[~less] = np.power(np_array[~less], 1 / 2.4) * 1.055 - 0.055 + + +def load_exr(file_path: str, srgb: bool) -> RgbMaskTuple: + image = cv.imread(file_path, cv.IMREAD_UNCHANGED).astype(np.float32) + rgb = np.flip(image[:, :, :3], 2).copy() + if srgb: + mut_linear_to_srgb(rgb) + rgb = np.clip(rgb, 0, 1) + rgb = torch.unsqueeze(torch.from_numpy(rgb), 0) + + mask = torch.zeros((1, image.shape[0], image.shape[1]), dtype=torch.float32) + if image.shape[2] > 3: + mask[0] = torch.from_numpy(np.clip(image[:, :, 3], 0, 1)) + + return RgbMaskTuple(rgb, mask) + + +def load_exr_latent(file_path: str) -> Tuple[Tensor]: + image = cv.imread(file_path, cv.IMREAD_UNCHANGED).astype(np.float32) + image = image[:, :, np.array([2, 1, 0, 3])] + image = torch.unsqueeze(torch.from_numpy(image), 0) + image = torch.movedim(image, -1, 1) + return image, + + +def save_exr(images: Tensor, filepaths_batched: Sequence[str], colorspace="linear"): + linear = images.detach().clone().cpu().numpy().astype(np.float32) + if colorspace == "linear": + mut_srgb_to_linear(linear[:, :, :, :3]) # only convert RGB, not Alpha + + bgr = copy.deepcopy(linear) + bgr[:, :, :, 0] = linear[:, :, :, 2] # flip RGB to BGR for opencv + bgr[:, :, :, 2] = linear[:, :, :, 0] + if bgr.shape[-1] > 3: + bgr[:, :, :, 3] = np.clip(1 - linear[:, :, :, 3], 0, 1) # invert alpha + + for i in range(len(linear.shape[0])): + cv.imwrite(filepaths_batched[i], bgr[i]) diff --git a/comfy/utils.py b/comfy/utils.py index 98f3f602f..acfaa7088 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -334,7 +334,7 @@ def get_attr(obj, attr): def bislerp(samples, width, height): def slerp(b1, b2, r): '''slerps batches b1, b2 according to ratio r, batches should be flat e.g. NxC''' - + c = b1.shape[-1] #norms @@ -359,16 +359,16 @@ def bislerp(samples, width, height): res *= (b1_norms * (1.0-r) + b2_norms * r).expand(-1,c) #edge cases for same or polar opposites - res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5] + res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5] res[dot < 1e-5 - 1] = (b1 * (1.0-r) + b2 * r)[dot < 1e-5 - 1] return res - + def generate_bilinear_data(length_old, length_new, device): coords_1 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1,1,1,-1)) coords_1 = torch.nn.functional.interpolate(coords_1, size=(1, length_new), mode="bilinear") ratios = coords_1 - coords_1.floor() coords_1 = coords_1.to(torch.int64) - + coords_2 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1,1,1,-1)) + 1 coords_2[:,:,:,-1] -= 1 coords_2 = torch.nn.functional.interpolate(coords_2, size=(1, length_new), mode="bilinear") @@ -379,7 +379,7 @@ def bislerp(samples, width, height): samples = samples.float() n,c,h,w = samples.shape h_new, w_new = (height, width) - + #linear w ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new, samples.device) coords_1 = coords_1.expand((n, c, h, -1)) @@ -496,6 +496,17 @@ def set_progress_bar_global_hook(function): PROGRESS_BAR_HOOK = function +class _DisabledProgressBar: + def __init__(self, *args, **kwargs): + pass + + def update(self, *args, **kwargs): + pass + + def update_absolute(self, *args, **kwargs): + pass + + class ProgressBar: def __init__(self, total: float): global PROGRESS_BAR_HOOK @@ -545,3 +556,12 @@ def comfy_tqdm(): # Restore original tqdm tqdm.__init__ = _original_init tqdm.update = _original_update + + +@contextmanager +def comfy_progress(total: float) -> ProgressBar: + global PROGRESS_BAR_ENABLED + if PROGRESS_BAR_ENABLED: + yield ProgressBar(total) + else: + yield _DisabledProgressBar() diff --git a/comfy/web/scripts/widgets.js b/comfy/web/scripts/widgets.js index 678b1b8ec..1b9680a59 100644 --- a/comfy/web/scripts/widgets.js +++ b/comfy/web/scripts/widgets.js @@ -469,7 +469,7 @@ export const ComfyWidgets = { const fileInput = document.createElement("input"); Object.assign(fileInput, { type: "file", - accept: "image/jpeg,image/png,image/webp", + accept: "image/jpeg,image/png,image/webp,image/x-exr,.exr", style: "display: none", onchange: async () => { if (fileInput.files.length) { diff --git a/comfy_extras/nodes/nodes_apply_color_map.py b/comfy_extras/nodes/nodes_apply_color_map.py new file mode 100644 index 000000000..4c2fc1793 --- /dev/null +++ b/comfy_extras/nodes/nodes_apply_color_map.py @@ -0,0 +1,85 @@ +import cv2 +import numpy as np +import torch +from torch import Tensor + +from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult + +_available_colormaps = ["Grayscale"] + [attr for attr in dir(cv2) if attr.startswith('COLORMAP')] + + +class ImageApplyColorMap(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "image": ("IMAGE", {}), + "colormap": (_available_colormaps, {"default": "COLORMAP_INFERNO"}), + "gamma": ("FLOAT", {"default": 1.0, "min": 0.001, "step": 0.001, "round": 0.001}), + "min_depth": ("FLOAT", {"default": 0.001, "min": 0.001, "round": 0.00001, "step": 0.001}), + "max_depth": ("FLOAT", {"default": 1e2, "round": 0.00001, "step": 0.1}), + "one_minus": ("BOOLEAN", {"default": False}), + "clip_min": ("BOOLEAN", {"default": True}), + "clip_max": ("BOOLEAN", {"default": False}), + } + } + + RETURN_TYPES = ("IMAGE",) + CATEGORY = "image/postprocessing" + FUNCTION = "execute" + + def execute(self, + image: Tensor, + gamma: float = 1.0, + min_depth: float = 0.001, + max_depth: float = 1e3, + colormap: str = "COLORMAP_INFERNO", + one_minus: bool = False, + clip_min: bool = True, + clip_max: bool = False, + ) -> ValidatedNodeResult: + """ + Invert and apply a colormap to a batch of absolute distance depth images. + + For Zoe and Midas, set colormap to be `COLORMAP_INFERNO`. Diffusers Depth expects `Grayscale`. + + As per https://huggingface.co/SargeZT/controlnet-v1e-sdxl-depth/discussions/7, some ControlNet checkpoints + expect one_minus to be true. + """ + colored_images = [] + + for i in range(image.shape[0]): + depth_image = image[i, :, :, 0].numpy() + depth_image = np.where(depth_image <= min_depth, np.nan if not clip_min else min_depth, depth_image) + if clip_max: + depth_image = np.where(depth_image >= max_depth, max_depth, depth_image) + depth_image = np.power(depth_image, 1.0 / gamma) + inv_depth_image = 1.0 / depth_image + + xp = [1.0 / max_depth, 1.0 / min_depth] + fp = [0, 1] + normalized_depth = np.interp(inv_depth_image, xp, fp, left=0, right=1) + normalized_depth = np.nan_to_num(normalized_depth, nan=0) + + normalized_depth_uint8 = (normalized_depth * 255).astype(np.uint8) + if one_minus: + normalized_depth_uint8 = 255 - normalized_depth_uint8 + if colormap == "Grayscale": + colored_image = normalized_depth_uint8 + else: + cv2_colormap = getattr(cv2, colormap) + colored_image = cv2.applyColorMap(normalized_depth_uint8, cv2_colormap) + colored_image_rgb = cv2.cvtColor(colored_image, cv2.COLOR_BGR2RGB) + rgb_tensor = torch.tensor(colored_image_rgb) * 1.0 / 255.0 + colored_images.append(rgb_tensor) + + return torch.stack(colored_images), + + +NODE_CLASS_MAPPINGS = { + ImageApplyColorMap.__name__: ImageApplyColorMap, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + ImageApplyColorMap.__name__: "Apply ColorMap to Image (CV2)", +} diff --git a/comfy_extras/nodes/nodes_arithmetic.py b/comfy_extras/nodes/nodes_arithmetic.py new file mode 100644 index 000000000..7527d5bfe --- /dev/null +++ b/comfy_extras/nodes/nodes_arithmetic.py @@ -0,0 +1,514 @@ +from functools import reduce +from operator import add, mul, pow + +from comfy.nodes.package_typing import CustomNode, InputTypes + + +class FloatAdd(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + range_ = {"value0": ("FLOAT", {})} + range_.update({f"value{i}": ("FLOAT", {"default": 0.0, "step": 0.01, "round": 0.000001}) for i in range(1, 5)}) + + return { + "required": range_ + } + + CATEGORY = "arithmetic" + RETURN_TYPES = ("FLOAT",) + FUNCTION = "execute" + + def execute(self, **kwargs): + return (reduce(add, kwargs.values(), 0.0),) + + +class FloatSubtract(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "value0": ("FLOAT", {}), + "value1": ("FLOAT", {"default": 0.0, "step": 0.01, "round": 0.000001}), + } + } + + CATEGORY = "arithmetic" + RETURN_TYPES = ("FLOAT",) + FUNCTION = "execute" + + def execute(self, value0, value1): + return (value0 - value1,) + + +class FloatMultiply(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + range_ = {"value0": ("FLOAT", {})} + range_.update({f"value{i}": ("FLOAT", {"default": 1.0}) for i in range(1, 5)}) + + return { + "required": range_ + } + + CATEGORY = "arithmetic" + RETURN_TYPES = ("FLOAT",) + FUNCTION = "execute" + + def execute(self, **kwargs): + return (reduce(mul, kwargs.values(), 1.0),) + + +class FloatDivide(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "value0": ("FLOAT", {}), + "value1": ("FLOAT", {"default": 1.0}), + } + } + + CATEGORY = "arithmetic" + RETURN_TYPES = ("FLOAT",) + FUNCTION = "execute" + + def execute(self, value0, value1): + return (value0 / value1 if value1 != 0 else float("inf"),) + + +class FloatPower(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "base": ("FLOAT", {}), + "exponent": ("FLOAT", {"default": 1.0}), + } + } + + CATEGORY = "arithmetic" + RETURN_TYPES = ("FLOAT",) + FUNCTION = "execute" + + def execute(self, base, exponent): + return (pow(base, exponent),) + + +class IntAdd(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + range_ = {"value0": ("INT", {})} + range_.update({f"value{i}": ("INT", {"default": 0}) for i in range(1, 5)}) + + return { + "required": range_ + } + + CATEGORY = "arithmetic" + RETURN_TYPES = ("INT",) + FUNCTION = "execute" + + def execute(self, **kwargs): + return (reduce(add, kwargs.values(), 0),) + + +class IntSubtract(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "value0": ("INT", {}), + "value1": ("INT", {"default": 0}), + } + } + + CATEGORY = "arithmetic" + RETURN_TYPES = ("INT",) + FUNCTION = "execute" + + def execute(self, value0, value1): + return (value0 - value1,) + + +class IntMultiply(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + range_ = {"value0": ("INT", {})} + range_.update({f"value{i}": ("INT", {"default": 1}) for i in range(1, 5)}) + + return { + "required": range_ + } + + CATEGORY = "arithmetic" + RETURN_TYPES = ("INT",) + FUNCTION = "execute" + + def execute(self, **kwargs): + return (reduce(mul, kwargs.values(), 1),) + + +class IntDivide(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "value0": ("INT", {}), + "value1": ("INT", {"default": 1}), + } + } + + CATEGORY = "arithmetic" + RETURN_TYPES = ("INT",) + FUNCTION = "execute" + + def execute(self, value0, value1): + return (value0 // value1 if value1 != 0 else 0,) + + +class IntMod(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "value0": ("INT", {}), + "value1": ("INT", {"default": 1}), + } + } + + CATEGORY = "arithmetic" + RETURN_TYPES = ("INT",) + FUNCTION = "execute" + + def execute(self, value0, value1): + return (value0 % value1 if value1 != 0 else 0,) + + +class IntPower(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "base": ("INT", {}), + "exponent": ("INT", {"default": 1}), + } + } + + CATEGORY = "arithmetic" + RETURN_TYPES = ("INT",) + FUNCTION = "execute" + + def execute(self, base, exponent): + return (pow(base, exponent),) + + +class FloatMin(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + range_ = {"value0": ("FLOAT", {})} + range_.update({f"value{i}": ("FLOAT", {"default": 0.0, "step": 0.01, "round": 0.000001}) for i in range(1, 5)}) + + return { + "required": range_ + } + + CATEGORY = "arithmetic" + RETURN_TYPES = ("FLOAT",) + FUNCTION = "execute" + + def execute(self, **kwargs): + return (min(kwargs.values()),) + + +class FloatMax(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + range_ = {"value0": ("FLOAT", {})} + range_.update({f"value{i}": ("FLOAT", {"default": 0.0, "step": 0.01, "round": 0.000001}) for i in range(1, 5)}) + + return { + "required": range_ + } + + CATEGORY = "arithmetic" + RETURN_TYPES = ("FLOAT",) + FUNCTION = "execute" + + def execute(self, **kwargs): + return (max(kwargs.values()),) + + +class FloatAbs(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "value": ("FLOAT", {}) + } + } + + CATEGORY = "arithmetic" + RETURN_TYPES = ("FLOAT",) + FUNCTION = "execute" + + def execute(self, value): + return (abs(value),) + + +class FloatAverage(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + range_ = {"value0": ("FLOAT", {})} + range_.update({f"value{i}": ("FLOAT", {"default": 0.0, "step": 0.01, "round": 0.000001}) for i in range(1, 5)}) + + return { + "required": range_ + } + + CATEGORY = "arithmetic" + RETURN_TYPES = ("FLOAT",) + FUNCTION = "execute" + + def execute(self, **kwargs): + return (sum(kwargs.values()) / len(kwargs),) + + +class IntMin(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + range_ = {"value0": ("INT", {})} + range_.update({f"value{i}": ("INT", {"default": 0}) for i in range(1, 5)}) + + return { + "required": range_ + } + + CATEGORY = "arithmetic" + RETURN_TYPES = ("INT",) + FUNCTION = "execute" + + def execute(self, **kwargs): + return (min(kwargs.values()),) + + +class IntMax(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + range_ = {"value0": ("INT", {})} + range_.update({f"value{i}": ("INT", {"default": 0}) for i in range(1, 5)}) + + return { + "required": range_ + } + + CATEGORY = "arithmetic" + RETURN_TYPES = ("INT",) + FUNCTION = "execute" + + def execute(self, **kwargs): + return (max(kwargs.values()),) + + +class IntAbs(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "value": ("INT", {}) + } + } + + CATEGORY = "arithmetic" + RETURN_TYPES = ("INT",) + FUNCTION = "execute" + + def execute(self, value): + return (abs(value),) + + +class IntAverage(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + range_ = {"value0": ("INT", {})} + range_.update({f"value{i}": ("INT", {"default": 0}) for i in range(1, 5)}) + + return { + "required": range_ + } + + CATEGORY = "arithmetic" + RETURN_TYPES = ("INT",) + FUNCTION = "execute" + + def execute(self, **kwargs): + return (sum(kwargs.values()) // len(kwargs),) + + +class FloatLerp(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "a": ("FLOAT", {"default": 0.0, "step": 0.01, "round": 0.000001}), + "b": ("FLOAT", {"default": 1.0}), + "t": ("FLOAT", {}), + "clamped": ("BOOLEAN", {"default": True}), + } + } + + CATEGORY = "arithmetic" + RETURN_TYPES = ("FLOAT",) + FUNCTION = "execute" + + def execute(self, a, b, t, clamped): + value = a + (b - a) * t + if clamped: + value = min(max(value, a), b) + return (value,) + + +class FloatInverseLerp(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "a": ("FLOAT", {"default": 0.0, "step": 0.01, "round": 0.000001}), + "b": ("FLOAT", {"default": 1.0}), + "value": ("FLOAT", {}), + "clamped": ("BOOLEAN", {"default": True}), + } + } + + CATEGORY = "arithmetic" + RETURN_TYPES = ("FLOAT",) + FUNCTION = "execute" + + def execute(self, a, b, value, clamped): + if a == b: + return (0.0,) + t = (value - a) / (b - a) + if clamped: + t = min(max(t, 0.0), 1.0) + return (t,) + + +class FloatClamp(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "value": ("FLOAT", {}), + "min": ("FLOAT", {"default": 0.0, "step": 0.01, "round": 0.000001}), + "max": ("FLOAT", {"default": 1.0}), + } + } + + CATEGORY = "arithmetic" + RETURN_TYPES = ("FLOAT",) + FUNCTION = "execute" + + def execute(self, value: float = 0, **kwargs): + v_min: float = kwargs['min'] + v_max: float = kwargs['max'] + return (min(max(value, v_min), v_max),) + + +class IntLerp(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "a": ("INT", {"default": 0}), + "b": ("INT", {"default": 10}), + "t": ("FLOAT", {}), + "clamped": ("BOOLEAN", {"default": True}), + } + } + + CATEGORY = "arithmetic" + RETURN_TYPES = ("INT",) + FUNCTION = "execute" + + def execute(self, a, b, t, clamped): + value = int(round(a + (b - a) * t)) + if clamped: + value = min(max(value, a), b) + return (value,) + + +class IntInverseLerp(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "a": ("INT", {"default": 0}), + "b": ("INT", {"default": 10}), + "value": ("INT", {}), + "clamped": ("BOOLEAN", {"default": True}), + } + } + + CATEGORY = "arithmetic" + RETURN_TYPES = ("FLOAT",) + FUNCTION = "execute" + + def execute(self, a, b, value, clamped): + if a == b: + return (0.0,) + t = (value - a) / (b - a) + if clamped: + t = min(max(t, 0.0), 1.0) + return (t,) + + +class IntClamp(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "value": ("INT", {}), + "min": ("INT", {"default": 0}), + "max": ("INT", {"default": 1}), + } + } + + CATEGORY = "arithmetic" + RETURN_TYPES = ("INT",) + FUNCTION = "execute" + + def execute(self, value: int = 0, **kwargs): + v_min: int = kwargs['min'] + v_max: int = kwargs['max'] + + return (min(max(value, v_min), v_max),) + + +NODE_CLASS_MAPPINGS = {} +for cls in ( + FloatAdd, + FloatSubtract, + FloatMultiply, + FloatDivide, + FloatPower, + FloatMin, + FloatMax, + FloatAbs, + FloatAverage, + FloatLerp, + FloatInverseLerp, + FloatClamp, + IntAdd, + IntSubtract, + IntMultiply, + IntDivide, + IntMod, + IntPower, + IntMin, + IntMax, + IntAbs, + IntAverage, + IntLerp, + IntInverseLerp, + IntClamp, +): + NODE_CLASS_MAPPINGS[cls.__name__] = cls diff --git a/comfy_extras/nodes/nodes_freelunch.py b/comfy_extras/nodes/nodes_freelunch.py index 6f1d87bf3..e4a0b8021 100644 --- a/comfy_extras/nodes/nodes_freelunch.py +++ b/comfy_extras/nodes/nodes_freelunch.py @@ -1,4 +1,28 @@ -#code originally taken from: https://github.com/ChenyangSi/FreeU (under MIT License) +""" +Portions of this code are adapted from the repository +https://github.com/ChenyangSi/FreeU + +MIT License + + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +""" import torch import logging diff --git a/comfy_extras/nodes/nodes_image_arithmetic.py b/comfy_extras/nodes/nodes_image_arithmetic.py new file mode 100644 index 000000000..d083bb340 --- /dev/null +++ b/comfy_extras/nodes/nodes_image_arithmetic.py @@ -0,0 +1,48 @@ +from torch import Tensor + +from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult + + +class ImageMin(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "image": ("IMAGE", {}) + } + } + + RETURN_TYPES = ("FLOAT",) + CATEGORY = "image/postprocessing" + FUNCTION = "execute" + + def execute(self, image: Tensor) -> ValidatedNodeResult: + return float(image.min().item()), + + +class ImageMax(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "image": ("IMAGE", {}) + } + } + + RETURN_TYPES = ("FLOAT",) + CATEGORY = "image/postprocessing" + FUNCTION = "execute" + + def execute(self, image: Tensor) -> ValidatedNodeResult: + return float(image.max().item()), + + +NODE_CLASS_MAPPINGS = { + ImageMin.__name__: ImageMin, + ImageMax.__name__: ImageMax, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + ImageMin.__name__: "Image Minimum Value", + ImageMax.__name__: "Image Maximum Value" +} diff --git a/comfy_extras/nodes/nodes_open_api.py b/comfy_extras/nodes/nodes_open_api.py index 095c303fb..b5bb8e63e 100644 --- a/comfy_extras/nodes/nodes_open_api.py +++ b/comfy_extras/nodes/nodes_open_api.py @@ -16,12 +16,12 @@ import fsspec import numpy as np from PIL import Image from PIL.PngImagePlugin import PngInfo -from fsspec.core import OpenFiles, OpenFile +from fsspec.core import OpenFile from fsspec.generic import GenericFileSystem from fsspec.implementations.local import LocalFileSystem from joblib import Parallel, delayed -from torch import Tensor from natsort import natsorted +from torch import Tensor from comfy.cmd import folder_paths from comfy.digest import digest diff --git a/comfy_extras/nodes/nodes_sag.py b/comfy_extras/nodes/nodes_sag.py index 16ccf04b1..e520e13e4 100644 --- a/comfy_extras/nodes/nodes_sag.py +++ b/comfy_extras/nodes/nodes_sag.py @@ -97,8 +97,8 @@ class SelfAttentionGuidance: @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), - "scale": ("FLOAT", {"default": 0.5, "min": -2.0, "max": 5.0, "step": 0.1}), - "blur_sigma": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 10.0, "step": 0.1}), + "scale": ("FLOAT", {"default": 0.5, "min": -2.0, "max": 5.0, "step": 0.01, "round": 0.01}), + "blur_sigma": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 10.0, "step": 0.01, "round": 0.01}), }} RETURN_TYPES = ("MODEL",) FUNCTION = "patch" diff --git a/requirements.txt b/requirements.txt index 2107d9046..dde0ec0cd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -35,4 +35,5 @@ huggingface_hub lazy-object-proxy can_ada fsspec -natsort \ No newline at end of file +natsort +OpenEXR \ No newline at end of file diff --git a/tests/nodes/test_arithmetic_unit.py b/tests/nodes/test_arithmetic_unit.py new file mode 100644 index 000000000..eb071e111 --- /dev/null +++ b/tests/nodes/test_arithmetic_unit.py @@ -0,0 +1,198 @@ +import pytest + +from comfy_extras.nodes.nodes_arithmetic import IntAdd, IntSubtract, IntMultiply, IntDivide, IntMod, IntPower, FloatAdd, FloatSubtract, FloatMultiply, FloatDivide, FloatPower, FloatMin, FloatMax, FloatAbs, FloatAverage, IntMin, IntMax, IntAbs, IntAverage, FloatLerp, IntLerp, IntClamp, IntInverseLerp, FloatClamp, FloatInverseLerp + + +def test_int_add(): + n = IntAdd() + res, = n.execute(value0=1, value1=2, value2=3) + assert res == 6 + + +def test_int_subtract(): + n = IntSubtract() + res, = n.execute(value0=10, value1=3) + assert res == 7 + + +def test_int_multiply(): + n = IntMultiply() + res, = n.execute(value0=2, value1=3, value2=4) + assert res == 24 + + +def test_int_divide(): + n = IntDivide() + res, = n.execute(value0=10, value1=3) + assert res == 3 + + res, = n.execute(value0=10, value1=0) + assert res == 0 + + +def test_int_mod(): + n = IntMod() + res, = n.execute(value0=10, value1=3) + assert res == 1 + + res, = n.execute(value0=10, value1=0) + assert res == 0 + + +def test_int_power(): + n = IntPower() + res, = n.execute(base=2, exponent=3) + assert res == 8 + + +def test_float_add(): + n = FloatAdd() + res, = n.execute(value0=1.5, value1=2.3, value2=3.7) + assert pytest.approx(res) == 7.5 + + +def test_float_subtract(): + n = FloatSubtract() + res, = n.execute(value0=10.5, value1=3.2) + assert pytest.approx(res) == 7.3 + + +def test_float_multiply(): + n = FloatMultiply() + res, = n.execute(value0=2.5, value1=3.0, value2=4.0) + assert pytest.approx(res) == 30.0 + + +def test_float_divide(): + n = FloatDivide() + res, = n.execute(value0=10.0, value1=4.0) + assert pytest.approx(res) == 2.5 + + res, = n.execute(value0=10.0, value1=0.0) + assert res == float("inf") + + +def test_float_power(): + n = FloatPower() + res, = n.execute(base=2.5, exponent=3.0) + assert pytest.approx(res) == 15.625 + + +def test_float_min(): + n = FloatMin() + res, = n.execute(value0=1.5, value1=2.3, value2=0.7) + assert res == 0.7 + + +def test_float_max(): + n = FloatMax() + res, = n.execute(value0=1.5, value1=2.3, value2=0.7) + assert res == 2.3 + + +def test_float_abs(): + n = FloatAbs() + res, = n.execute(value=-3.14) + assert res == 3.14 + + +def test_float_average(): + n = FloatAverage() + res, = n.execute(value0=1.5, value1=2.5, value2=3.5) + assert res == 2.5 + + +def test_int_min(): + n = IntMin() + res, = n.execute(value0=5, value1=2, value2=7) + assert res == 2 + + +def test_int_max(): + n = IntMax() + res, = n.execute(value0=5, value1=2, value2=7) + assert res == 7 + + +def test_int_abs(): + n = IntAbs() + res, = n.execute(value=-10) + assert res == 10 + + +def test_int_average(): + n = IntAverage() + res, = n.execute(value0=2, value1=4, value2=6) + assert res == 4 + + +def test_float_lerp(): + n = FloatLerp() + res, = n.execute(a=0.0, b=1.0, t=0.5, clamped=True) + assert res == 0.5 + + res, = n.execute(a=0.0, b=1.0, t=1.5, clamped=True) + assert res == 1.0 + + res, = n.execute(a=0.0, b=1.0, t=1.5, clamped=False) + assert res == 1.5 + + +def test_int_lerp(): + n = IntLerp() + res, = n.execute(a=0, b=10, t=0.5, clamped=True) + assert res == 5 + + res, = n.execute(a=0, b=10, t=1.5, clamped=True) + assert res == 10 + + res, = n.execute(a=0, b=10, t=1.5, clamped=False) + assert res == 15 + + +def test_float_inverse_lerp(): + n = FloatInverseLerp() + res, = n.execute(a=0.0, b=1.0, value=0.5, clamped=True) + assert res == 0.5 + + res, = n.execute(a=0.0, b=1.0, value=1.5, clamped=True) + assert res == 1.0 + + res, = n.execute(a=0.0, b=1.0, value=1.5, clamped=False) + assert res == 1.5 + + +def test_float_clamp(): + n = FloatClamp() + res, = n.execute(value=0.5, min=0.0, max=1.0) + assert res == 0.5 + + res, = n.execute(value=1.5, min=0.0, max=1.0) + assert res == 1.0 + + res, = n.execute(value=-0.5, min=0.0, max=1.0) + assert res == 0.0 + + +def test_int_inverse_lerp(): + n = IntInverseLerp() + res, = n.execute(a=0, b=10, value=5, clamped=True) + assert res == 0.5 + + res, = n.execute(a=0, b=10, value=15, clamped=True) + assert res == 1.0 + + res, = n.execute(a=0, b=10, value=15, clamped=False) + assert res == 1.5 + + +def test_int_clamp(): + n = IntClamp() + res, = n.execute(value=5, min=0, max=10) + assert res == 5 + + res, = n.execute(value=15, min=0, max=10) + assert res == 10 + + res, = n.execute(value=-5, min=0, max=10) + assert res == 0 diff --git a/tests/nodes/test_colormap_unit.py b/tests/nodes/test_colormap_unit.py new file mode 100644 index 000000000..16805185e --- /dev/null +++ b/tests/nodes/test_colormap_unit.py @@ -0,0 +1,49 @@ +import pytest +import torch +from comfy_extras.nodes.nodes_apply_color_map import ImageApplyColorMap + + +@pytest.fixture +def input_image(): + # Create a 1x1x2x1 tensor representing an image with absolute distances of 1.3 meters and 300 meters + return torch.tensor([[[[1.3], [300.0]]]], dtype=torch.float32) + + +def test_apply_colormap_grayscale(input_image): + node = ImageApplyColorMap() + colored_image, = node.execute(image=input_image, colormap="Grayscale", min_depth=1.3, max_depth=300.0) + + assert colored_image.shape == (1, 1, 2, 3) + assert colored_image.dtype == torch.float32 + assert torch.allclose(colored_image[0, 0, 0], torch.tensor([1.0, 1.0, 1.0])) + assert torch.allclose(colored_image[0, 0, 1], torch.tensor([0.0, 0.0, 0.0])) + + +def test_apply_colormap_inferno(input_image): + node = ImageApplyColorMap() + colored_image, = node.execute(image=input_image, colormap="COLORMAP_INFERNO", min_depth=1.3, max_depth=300.0) + + assert colored_image.shape == (1, 1, 2, 3) + assert colored_image.dtype == torch.float32 + assert torch.allclose(colored_image[0, 0, 0], torch.tensor([0.9882, 1.000, 0.6431]), atol=1e-4) + assert torch.allclose(colored_image[0, 0, 1], torch.tensor([0.0000, 0.0000, 0.0157]), atol=1e-4) + + +def test_apply_colormap_clipping(input_image): + node = ImageApplyColorMap() + + colored_image, = node.execute(image=input_image, colormap="COLORMAP_INFERNO", clip_min=False, clip_max=False, min_depth=1.3, max_depth=300.0) + assert torch.allclose(colored_image[0, 0, 0], torch.tensor([0.0, 0.0, 0.0157]), atol=1e-4) + assert torch.allclose(colored_image[0, 0, 1], torch.tensor([0.0, 0.0, 0.0157]), atol=1e-4) + + colored_image, = node.execute(image=input_image, colormap="COLORMAP_INFERNO", clip_min=True, clip_max=False, min_depth=1.3, max_depth=300.0) + assert torch.allclose(colored_image[0, 0, 0], torch.tensor([0.9882, 1.0000, 0.6431]), atol=1e-4) + assert torch.allclose(colored_image[0, 0, 1], torch.tensor([0.0000, 0.0000, 0.0157]), atol=1e-4) + + colored_image, = node.execute(image=input_image, colormap="COLORMAP_INFERNO", clip_min=False, clip_max=True, min_depth=1.3, max_depth=200.0) + assert torch.allclose(colored_image[0, 0, 0], torch.tensor([0.0, 0.0, 0.0157]), atol=1e-4) + assert torch.allclose(colored_image[0, 0, 1], torch.tensor([0.0, 0.0, 0.0157]), atol=1e-4) + + colored_image, = node.execute(image=input_image, colormap="COLORMAP_INFERNO", clip_min=True, clip_max=True, min_depth=1.3, max_depth=200.0) + assert torch.allclose(colored_image[0, 0, 0], torch.tensor([0.9882, 1.0000, 0.6431]), atol=1e-4) + assert torch.allclose(colored_image[0, 0, 1], torch.tensor([0.0000, 0.0000, 0.0157]), atol=1e-4)