diff --git a/comfy/client/embedded_comfy_client.py b/comfy/client/embedded_comfy_client.py index e4271a98a..34e75c04d 100644 --- a/comfy/client/embedded_comfy_client.py +++ b/comfy/client/embedded_comfy_client.py @@ -100,7 +100,7 @@ class EmbeddedComfyClient: if self._configuration is None: options.enable_args_parsing() else: - from ..cli_args import args + from ..cmd.main_pre import args args.clear() args.update(self._configuration) diff --git a/comfy/cmd/main.py b/comfy/cmd/main.py index 4d929f996..efb03e599 100644 --- a/comfy/cmd/main.py +++ b/comfy/cmd/main.py @@ -1,4 +1,3 @@ -# Suppress warnings during import import asyncio import gc import itertools @@ -8,9 +7,10 @@ import shutil import threading import time +# main_pre must be the earliest import since it suppresses some spurious warnings +from .main_pre import args from ..utils import hijack_progress from .extra_model_paths import load_extra_path_config -from .main_pre import args from .. import model_management from ..analytics.analytics import initialize_event_tracking from ..cmd import cuda_malloc diff --git a/comfy/cmd/main_pre.py b/comfy/cmd/main_pre.py index b78ddf231..51408b682 100644 --- a/comfy/cmd/main_pre.py +++ b/comfy/cmd/main_pre.py @@ -1,3 +1,12 @@ +""" +This should be imported before entrypoints to correctly configure global options prior to importing packages like torch and cv2. + +Use this instead of cli_args to import the args: + +>>> from comfy.cmd.main_pre import args + +It will enable command line argument parsing. If this isn't desired, you must author your own implementation of these fixes. +""" import os from .. import options @@ -9,6 +18,8 @@ options.enable_args_parsing() if os.name == "nt": logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) warnings.filterwarnings("ignore", message="torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.") +warnings.filterwarnings("ignore", message="Torch was not compiled with flash attention.") +warnings.filterwarnings("ignore", message=".*Torch was not compiled with flash attention.*") from ..cli_args import args @@ -20,4 +31,5 @@ if args.deterministic: if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ: os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8" +os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" __all__ = ["args"] diff --git a/comfy/cmd/server.py b/comfy/cmd/server.py index d762c4609..84a71fb6f 100644 --- a/comfy/cmd/server.py +++ b/comfy/cmd/server.py @@ -38,6 +38,7 @@ from ..component_model.file_output_path import file_output_path from ..component_model.queue_types import QueueItem, HistoryEntry, BinaryEventTypes, TaskInvocation from ..digest import digest from ..nodes.package_typing import ExportedNodes +from ..images import open_image class HeuristicPath(NamedTuple): @@ -289,9 +290,12 @@ class PromptServer(ExecutorToClientProgress): return web.Response(status=400) if os.path.isfile(file): - if 'preview' in request.rel_url.query: - with Image.open(file) as img: - preview_info = request.rel_url.query['preview'].split(';') + # todo: any image file we upload that browsers don't support, we should encode a preview + # todo: image handling has to be a little bit more standardized, sometimes we want a Pillow Image, sometimes + # we want something that will render to the user, sometimes we want tensors + if 'preview' in request.rel_url.query or file.endswith(".exr"): + with open_image(file) as img: + preview_info = request.rel_url.query.get("preview", "jpeg;90").split(';') image_format = preview_info[0] if image_format not in ['webp', 'jpeg'] or 'a' in request.rel_url.query.get('channel', ''): image_format = 'webp' diff --git a/comfy/cmd/worker.py b/comfy/cmd/worker.py index 51224c0cc..b3a2f6cf7 100644 --- a/comfy/cmd/worker.py +++ b/comfy/cmd/worker.py @@ -4,11 +4,7 @@ import os import logging from .extra_model_paths import load_extra_path_config -from .. import options - -options.enable_args_parsing() - -from ..cli_args import args +from .main_pre import args async def main(): @@ -17,28 +13,19 @@ async def main(): args.distributed_queue_frontend = False assert args.distributed_queue_connection_uri is not None, "Set the --distributed-queue-connection-uri argument to your RabbitMQ server" - - if args.cuda_device is not None: - os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device) - logging.info(f"Set cuda device to: {args.cuda_device}") - - if args.deterministic: - if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ: - os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8" - # configure paths if args.output_directory: output_dir = os.path.abspath(args.output_directory) logging.info(f"Setting output directory to: {output_dir}") from ..cmd import folder_paths - + folder_paths.set_output_directory(output_dir) - + if args.input_directory: input_dir = os.path.abspath(args.input_directory) logging.info(f"Setting input directory to: {input_dir}") from ..cmd import folder_paths - + folder_paths.set_input_directory(input_dir) if args.temp_directory: diff --git a/comfy/component_model/images_types.py b/comfy/component_model/images_types.py new file mode 100644 index 000000000..8ef892059 --- /dev/null +++ b/comfy/component_model/images_types.py @@ -0,0 +1,8 @@ +from typing import NamedTuple + +from torch import Tensor + + +class RgbMaskTuple(NamedTuple): + rgb: Tensor + mask: Tensor diff --git a/comfy/images.py b/comfy/images.py new file mode 100644 index 000000000..660de0a03 --- /dev/null +++ b/comfy/images.py @@ -0,0 +1,19 @@ +import os.path +from contextlib import contextmanager + +import cv2 +from PIL import Image + + +def _open_exr(exr_path) -> Image.Image: + return Image.fromarray(cv2.imread(exr_path, cv2.IMREAD_COLOR)) + + +@contextmanager +def open_image(file_path: str) -> Image.Image: + _, ext = os.path.splitext(file_path) + if ext == ".exr": + yield _open_exr(file_path) + else: + with Image.open(file_path) as image: + yield image diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 7af016829..40d7e8e0e 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -755,7 +755,32 @@ def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, n @torch.no_grad() def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): - # From MIT licensed: https://github.com/Carzit/sd-webui-samplers-scheduler/ + """ + Portions of this function are adapted from the repository + https://github.com/Carzit/sd-webui-samplers-scheduler + + MIT License + + Copyright (c) 2023 Carzit + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. + """ extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) s_end = sigmas[-1] diff --git a/comfy/model_downloader.py b/comfy/model_downloader.py index 96a1061fc..8c9971076 100644 --- a/comfy/model_downloader.py +++ b/comfy/model_downloader.py @@ -51,7 +51,7 @@ def get_or_download(folder_name: str, filename: str, known_files: List[HuggingFi else: linked_filename = None try: - os.symlink(os.path.join(destination,known_file.filename), linked_filename) + os.symlink(os.path.join(destination, known_file.filename), linked_filename) except Exception as exc_info: logging.error(f"Failed to link file with alternative download save name in a way that is compatible with Hugging Face caching {repr(known_file)}", exc_info=exc_info) else: @@ -213,6 +213,7 @@ KNOWN_CONTROLNETS = [ HuggingFile("lllyasviel/sd_control_collection", "sargezt_xl_depth_faid_vidit.safetensors"), HuggingFile("lllyasviel/sd_control_collection", "sargezt_xl_depth_zeed.safetensors"), HuggingFile("lllyasviel/sd_control_collection", "sargezt_xl_softedge.safetensors"), + HuggingFile("SargeZT/controlnet-sd-xl-1.0-depth-16bit-zoe", "depth-zoe-xl-v1.0-controlnet.safetensors"), HuggingFile("lllyasviel/sd_control_collection", "t2i-adapter_diffusers_xl_canny.safetensors"), HuggingFile("lllyasviel/sd_control_collection", "t2i-adapter_diffusers_xl_depth_midas.safetensors"), HuggingFile("lllyasviel/sd_control_collection", "t2i-adapter_diffusers_xl_depth_zoe.safetensors"), diff --git a/comfy/model_downloader_types.py b/comfy/model_downloader_types.py index 788af7eb8..554a325e5 100644 --- a/comfy/model_downloader_types.py +++ b/comfy/model_downloader_types.py @@ -2,7 +2,7 @@ from __future__ import annotations import dataclasses from os.path import split -from typing import Optional, List +from typing import Optional, List, Sequence from typing_extensions import TypedDict, NotRequired @@ -15,10 +15,12 @@ class CivitFile: model_id (int): The ID of the model model_version_id (int): The version filename (str): The name of the file in the model + trigger_words (List[str]): Trigger words associated with the model """ model_id: int model_version_id: int filename: str + trigger_words: Optional[Sequence[str]] = dataclasses.field(default_factory=tuple) def __str__(self): return self.filename diff --git a/comfy/nodes/base_nodes.py b/comfy/nodes/base_nodes.py index 0c6bdb16d..081fd1d2b 100644 --- a/comfy/nodes/base_nodes.py +++ b/comfy/nodes/base_nodes.py @@ -9,6 +9,7 @@ import logging from PIL import Image, ImageOps, ImageSequence from PIL.PngImagePlugin import PngInfo +from natsort import natsorted from pkg_resources import resource_filename import numpy as np import safetensors.torch @@ -23,10 +24,13 @@ from .. import model_management from ..cli_args import args from ..cmd import folder_paths, latent_preview +from ..images import open_image from ..model_downloader import get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, \ KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS, KNOWN_LORAS, KNOWN_CONTROLNETS, KNOWN_DIFF_CONTROLNETS from ..nodes.common import MAX_RESOLUTION from .. import controlnet +from ..open_exr import load_exr + class CLIPTextEncode: @classmethod @@ -1454,38 +1458,49 @@ class PreviewImage(SaveImage): "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, } + class LoadImage: @classmethod def INPUT_TYPES(s): input_dir = folder_paths.get_input_directory() files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] - return {"required": - {"image": (sorted(files), {"image_upload": True})}, - } + return { + "required": { + "image": (natsorted(files), {"image_upload": True}), + }, + } CATEGORY = "image" RETURN_TYPES = ("IMAGE", "MASK") FUNCTION = "load_image" - def load_image(self, image): + + def load_image(self, image: str): image_path = folder_paths.get_annotated_filepath(image) - img = Image.open(image_path) output_images = [] output_masks = [] - for i in ImageSequence.Iterator(img): - i = ImageOps.exif_transpose(i) - if i.mode == 'I': - i = i.point(lambda i: i * (1 / 255)) - image = i.convert("RGB") - image = np.array(image).astype(np.float32) / 255.0 - image = torch.from_numpy(image)[None,] - if 'A' in i.getbands(): - mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0 - mask = 1. - torch.from_numpy(mask) - else: - mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") - output_images.append(image) - output_masks.append(mask.unsqueeze(0)) + + # maintain the legacy path + # this will ultimately return a tensor, so we'd rather have the tensors directly + # from cv2 rather than get them out of a PIL image + _, ext = os.path.splitext(image) + if ext == ".exr": + return load_exr(image_path, srgb=False) + with open_image(image_path) as img: + for i in ImageSequence.Iterator(img): + i = ImageOps.exif_transpose(i) + if i.mode == 'I': + i = i.point(lambda i: i * (1 / 255)) + image = i.convert("RGB") + image = np.array(image).astype(np.float32) / 255.0 + image = torch.from_numpy(image)[None,] + if 'A' in i.getbands(): + mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0 + mask = 1. - torch.from_numpy(mask) + else: + mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") + output_images.append(image) + output_masks.append(mask.unsqueeze(0)) if len(output_images) > 1: output_image = torch.cat(output_images, dim=0) @@ -1494,7 +1509,7 @@ class LoadImage: output_image = output_images[0] output_mask = output_masks[0] - return (output_image, output_mask) + return output_image, output_mask @classmethod def IS_CHANGED(s, image): diff --git a/comfy/nodes/package_typing.py b/comfy/nodes/package_typing.py index f993b9945..6de60703c 100644 --- a/comfy/nodes/package_typing.py +++ b/comfy/nodes/package_typing.py @@ -53,7 +53,7 @@ BooleanSpec = Tuple[Literal["BOOLEAN"], BoolSpecOptions] ChoiceSpec = Tuple[Union[Sequence[str], Sequence[float], Sequence[int]]] -NonPrimitiveTypeSpec = Tuple[CommonReturnTypes] +NonPrimitiveTypeSpec = Tuple[CommonReturnTypes, Any] InputTypeSpec = Union[IntSpec, FloatSpec, StringSpec, BooleanSpec, ChoiceSpec, NonPrimitiveTypeSpec] diff --git a/comfy/open_exr.py b/comfy/open_exr.py new file mode 100644 index 000000000..acd76f1bb --- /dev/null +++ b/comfy/open_exr.py @@ -0,0 +1,86 @@ +""" +Portions of this code are adapted from the repository +https://github.com/spacepxl/ComfyUI-HQ-Image-Save + +MIT License + +Copyright (c) 2023 spacepxl + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +""" + +import copy +from typing import Sequence, Tuple + +import cv2 as cv +import numpy as np +import torch +from torch import Tensor + +from .component_model.images_types import RgbMaskTuple + + +def mut_srgb_to_linear(np_array) -> None: + less = np_array <= 0.0404482362771082 + np_array[less] = np_array[less] / 12.92 + np_array[~less] = np.power((np_array[~less] + 0.055) / 1.055, 2.4) + + +def mut_linear_to_srgb(np_array) -> None: + less = np_array <= 0.0031308 + np_array[less] = np_array[less] * 12.92 + np_array[~less] = np.power(np_array[~less], 1 / 2.4) * 1.055 - 0.055 + + +def load_exr(file_path: str, srgb: bool) -> RgbMaskTuple: + image = cv.imread(file_path, cv.IMREAD_UNCHANGED).astype(np.float32) + rgb = np.flip(image[:, :, :3], 2).copy() + if srgb: + mut_linear_to_srgb(rgb) + rgb = np.clip(rgb, 0, 1) + rgb = torch.unsqueeze(torch.from_numpy(rgb), 0) + + mask = torch.zeros((1, image.shape[0], image.shape[1]), dtype=torch.float32) + if image.shape[2] > 3: + mask[0] = torch.from_numpy(np.clip(image[:, :, 3], 0, 1)) + + return RgbMaskTuple(rgb, mask) + + +def load_exr_latent(file_path: str) -> Tuple[Tensor]: + image = cv.imread(file_path, cv.IMREAD_UNCHANGED).astype(np.float32) + image = image[:, :, np.array([2, 1, 0, 3])] + image = torch.unsqueeze(torch.from_numpy(image), 0) + image = torch.movedim(image, -1, 1) + return image, + + +def save_exr(images: Tensor, filepaths_batched: Sequence[str], colorspace="linear"): + linear = images.detach().clone().cpu().numpy().astype(np.float32) + if colorspace == "linear": + mut_srgb_to_linear(linear[:, :, :, :3]) # only convert RGB, not Alpha + + bgr = copy.deepcopy(linear) + bgr[:, :, :, 0] = linear[:, :, :, 2] # flip RGB to BGR for opencv + bgr[:, :, :, 2] = linear[:, :, :, 0] + if bgr.shape[-1] > 3: + bgr[:, :, :, 3] = np.clip(1 - linear[:, :, :, 3], 0, 1) # invert alpha + + for i in range(len(linear.shape[0])): + cv.imwrite(filepaths_batched[i], bgr[i]) diff --git a/comfy/utils.py b/comfy/utils.py index 98f3f602f..acfaa7088 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -334,7 +334,7 @@ def get_attr(obj, attr): def bislerp(samples, width, height): def slerp(b1, b2, r): '''slerps batches b1, b2 according to ratio r, batches should be flat e.g. NxC''' - + c = b1.shape[-1] #norms @@ -359,16 +359,16 @@ def bislerp(samples, width, height): res *= (b1_norms * (1.0-r) + b2_norms * r).expand(-1,c) #edge cases for same or polar opposites - res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5] + res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5] res[dot < 1e-5 - 1] = (b1 * (1.0-r) + b2 * r)[dot < 1e-5 - 1] return res - + def generate_bilinear_data(length_old, length_new, device): coords_1 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1,1,1,-1)) coords_1 = torch.nn.functional.interpolate(coords_1, size=(1, length_new), mode="bilinear") ratios = coords_1 - coords_1.floor() coords_1 = coords_1.to(torch.int64) - + coords_2 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1,1,1,-1)) + 1 coords_2[:,:,:,-1] -= 1 coords_2 = torch.nn.functional.interpolate(coords_2, size=(1, length_new), mode="bilinear") @@ -379,7 +379,7 @@ def bislerp(samples, width, height): samples = samples.float() n,c,h,w = samples.shape h_new, w_new = (height, width) - + #linear w ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new, samples.device) coords_1 = coords_1.expand((n, c, h, -1)) @@ -496,6 +496,17 @@ def set_progress_bar_global_hook(function): PROGRESS_BAR_HOOK = function +class _DisabledProgressBar: + def __init__(self, *args, **kwargs): + pass + + def update(self, *args, **kwargs): + pass + + def update_absolute(self, *args, **kwargs): + pass + + class ProgressBar: def __init__(self, total: float): global PROGRESS_BAR_HOOK @@ -545,3 +556,12 @@ def comfy_tqdm(): # Restore original tqdm tqdm.__init__ = _original_init tqdm.update = _original_update + + +@contextmanager +def comfy_progress(total: float) -> ProgressBar: + global PROGRESS_BAR_ENABLED + if PROGRESS_BAR_ENABLED: + yield ProgressBar(total) + else: + yield _DisabledProgressBar() diff --git a/comfy/web/scripts/widgets.js b/comfy/web/scripts/widgets.js index 678b1b8ec..1b9680a59 100644 --- a/comfy/web/scripts/widgets.js +++ b/comfy/web/scripts/widgets.js @@ -469,7 +469,7 @@ export const ComfyWidgets = { const fileInput = document.createElement("input"); Object.assign(fileInput, { type: "file", - accept: "image/jpeg,image/png,image/webp", + accept: "image/jpeg,image/png,image/webp,image/x-exr,.exr", style: "display: none", onchange: async () => { if (fileInput.files.length) { diff --git a/comfy_extras/nodes/nodes_apply_color_map.py b/comfy_extras/nodes/nodes_apply_color_map.py new file mode 100644 index 000000000..4c2fc1793 --- /dev/null +++ b/comfy_extras/nodes/nodes_apply_color_map.py @@ -0,0 +1,85 @@ +import cv2 +import numpy as np +import torch +from torch import Tensor + +from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult + +_available_colormaps = ["Grayscale"] + [attr for attr in dir(cv2) if attr.startswith('COLORMAP')] + + +class ImageApplyColorMap(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "image": ("IMAGE", {}), + "colormap": (_available_colormaps, {"default": "COLORMAP_INFERNO"}), + "gamma": ("FLOAT", {"default": 1.0, "min": 0.001, "step": 0.001, "round": 0.001}), + "min_depth": ("FLOAT", {"default": 0.001, "min": 0.001, "round": 0.00001, "step": 0.001}), + "max_depth": ("FLOAT", {"default": 1e2, "round": 0.00001, "step": 0.1}), + "one_minus": ("BOOLEAN", {"default": False}), + "clip_min": ("BOOLEAN", {"default": True}), + "clip_max": ("BOOLEAN", {"default": False}), + } + } + + RETURN_TYPES = ("IMAGE",) + CATEGORY = "image/postprocessing" + FUNCTION = "execute" + + def execute(self, + image: Tensor, + gamma: float = 1.0, + min_depth: float = 0.001, + max_depth: float = 1e3, + colormap: str = "COLORMAP_INFERNO", + one_minus: bool = False, + clip_min: bool = True, + clip_max: bool = False, + ) -> ValidatedNodeResult: + """ + Invert and apply a colormap to a batch of absolute distance depth images. + + For Zoe and Midas, set colormap to be `COLORMAP_INFERNO`. Diffusers Depth expects `Grayscale`. + + As per https://huggingface.co/SargeZT/controlnet-v1e-sdxl-depth/discussions/7, some ControlNet checkpoints + expect one_minus to be true. + """ + colored_images = [] + + for i in range(image.shape[0]): + depth_image = image[i, :, :, 0].numpy() + depth_image = np.where(depth_image <= min_depth, np.nan if not clip_min else min_depth, depth_image) + if clip_max: + depth_image = np.where(depth_image >= max_depth, max_depth, depth_image) + depth_image = np.power(depth_image, 1.0 / gamma) + inv_depth_image = 1.0 / depth_image + + xp = [1.0 / max_depth, 1.0 / min_depth] + fp = [0, 1] + normalized_depth = np.interp(inv_depth_image, xp, fp, left=0, right=1) + normalized_depth = np.nan_to_num(normalized_depth, nan=0) + + normalized_depth_uint8 = (normalized_depth * 255).astype(np.uint8) + if one_minus: + normalized_depth_uint8 = 255 - normalized_depth_uint8 + if colormap == "Grayscale": + colored_image = normalized_depth_uint8 + else: + cv2_colormap = getattr(cv2, colormap) + colored_image = cv2.applyColorMap(normalized_depth_uint8, cv2_colormap) + colored_image_rgb = cv2.cvtColor(colored_image, cv2.COLOR_BGR2RGB) + rgb_tensor = torch.tensor(colored_image_rgb) * 1.0 / 255.0 + colored_images.append(rgb_tensor) + + return torch.stack(colored_images), + + +NODE_CLASS_MAPPINGS = { + ImageApplyColorMap.__name__: ImageApplyColorMap, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + ImageApplyColorMap.__name__: "Apply ColorMap to Image (CV2)", +} diff --git a/comfy_extras/nodes/nodes_arithmetic.py b/comfy_extras/nodes/nodes_arithmetic.py new file mode 100644 index 000000000..7527d5bfe --- /dev/null +++ b/comfy_extras/nodes/nodes_arithmetic.py @@ -0,0 +1,514 @@ +from functools import reduce +from operator import add, mul, pow + +from comfy.nodes.package_typing import CustomNode, InputTypes + + +class FloatAdd(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + range_ = {"value0": ("FLOAT", {})} + range_.update({f"value{i}": ("FLOAT", {"default": 0.0, "step": 0.01, "round": 0.000001}) for i in range(1, 5)}) + + return { + "required": range_ + } + + CATEGORY = "arithmetic" + RETURN_TYPES = ("FLOAT",) + FUNCTION = "execute" + + def execute(self, **kwargs): + return (reduce(add, kwargs.values(), 0.0),) + + +class FloatSubtract(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "value0": ("FLOAT", {}), + "value1": ("FLOAT", {"default": 0.0, "step": 0.01, "round": 0.000001}), + } + } + + CATEGORY = "arithmetic" + RETURN_TYPES = ("FLOAT",) + FUNCTION = "execute" + + def execute(self, value0, value1): + return (value0 - value1,) + + +class FloatMultiply(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + range_ = {"value0": ("FLOAT", {})} + range_.update({f"value{i}": ("FLOAT", {"default": 1.0}) for i in range(1, 5)}) + + return { + "required": range_ + } + + CATEGORY = "arithmetic" + RETURN_TYPES = ("FLOAT",) + FUNCTION = "execute" + + def execute(self, **kwargs): + return (reduce(mul, kwargs.values(), 1.0),) + + +class FloatDivide(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "value0": ("FLOAT", {}), + "value1": ("FLOAT", {"default": 1.0}), + } + } + + CATEGORY = "arithmetic" + RETURN_TYPES = ("FLOAT",) + FUNCTION = "execute" + + def execute(self, value0, value1): + return (value0 / value1 if value1 != 0 else float("inf"),) + + +class FloatPower(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "base": ("FLOAT", {}), + "exponent": ("FLOAT", {"default": 1.0}), + } + } + + CATEGORY = "arithmetic" + RETURN_TYPES = ("FLOAT",) + FUNCTION = "execute" + + def execute(self, base, exponent): + return (pow(base, exponent),) + + +class IntAdd(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + range_ = {"value0": ("INT", {})} + range_.update({f"value{i}": ("INT", {"default": 0}) for i in range(1, 5)}) + + return { + "required": range_ + } + + CATEGORY = "arithmetic" + RETURN_TYPES = ("INT",) + FUNCTION = "execute" + + def execute(self, **kwargs): + return (reduce(add, kwargs.values(), 0),) + + +class IntSubtract(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "value0": ("INT", {}), + "value1": ("INT", {"default": 0}), + } + } + + CATEGORY = "arithmetic" + RETURN_TYPES = ("INT",) + FUNCTION = "execute" + + def execute(self, value0, value1): + return (value0 - value1,) + + +class IntMultiply(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + range_ = {"value0": ("INT", {})} + range_.update({f"value{i}": ("INT", {"default": 1}) for i in range(1, 5)}) + + return { + "required": range_ + } + + CATEGORY = "arithmetic" + RETURN_TYPES = ("INT",) + FUNCTION = "execute" + + def execute(self, **kwargs): + return (reduce(mul, kwargs.values(), 1),) + + +class IntDivide(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "value0": ("INT", {}), + "value1": ("INT", {"default": 1}), + } + } + + CATEGORY = "arithmetic" + RETURN_TYPES = ("INT",) + FUNCTION = "execute" + + def execute(self, value0, value1): + return (value0 // value1 if value1 != 0 else 0,) + + +class IntMod(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "value0": ("INT", {}), + "value1": ("INT", {"default": 1}), + } + } + + CATEGORY = "arithmetic" + RETURN_TYPES = ("INT",) + FUNCTION = "execute" + + def execute(self, value0, value1): + return (value0 % value1 if value1 != 0 else 0,) + + +class IntPower(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "base": ("INT", {}), + "exponent": ("INT", {"default": 1}), + } + } + + CATEGORY = "arithmetic" + RETURN_TYPES = ("INT",) + FUNCTION = "execute" + + def execute(self, base, exponent): + return (pow(base, exponent),) + + +class FloatMin(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + range_ = {"value0": ("FLOAT", {})} + range_.update({f"value{i}": ("FLOAT", {"default": 0.0, "step": 0.01, "round": 0.000001}) for i in range(1, 5)}) + + return { + "required": range_ + } + + CATEGORY = "arithmetic" + RETURN_TYPES = ("FLOAT",) + FUNCTION = "execute" + + def execute(self, **kwargs): + return (min(kwargs.values()),) + + +class FloatMax(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + range_ = {"value0": ("FLOAT", {})} + range_.update({f"value{i}": ("FLOAT", {"default": 0.0, "step": 0.01, "round": 0.000001}) for i in range(1, 5)}) + + return { + "required": range_ + } + + CATEGORY = "arithmetic" + RETURN_TYPES = ("FLOAT",) + FUNCTION = "execute" + + def execute(self, **kwargs): + return (max(kwargs.values()),) + + +class FloatAbs(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "value": ("FLOAT", {}) + } + } + + CATEGORY = "arithmetic" + RETURN_TYPES = ("FLOAT",) + FUNCTION = "execute" + + def execute(self, value): + return (abs(value),) + + +class FloatAverage(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + range_ = {"value0": ("FLOAT", {})} + range_.update({f"value{i}": ("FLOAT", {"default": 0.0, "step": 0.01, "round": 0.000001}) for i in range(1, 5)}) + + return { + "required": range_ + } + + CATEGORY = "arithmetic" + RETURN_TYPES = ("FLOAT",) + FUNCTION = "execute" + + def execute(self, **kwargs): + return (sum(kwargs.values()) / len(kwargs),) + + +class IntMin(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + range_ = {"value0": ("INT", {})} + range_.update({f"value{i}": ("INT", {"default": 0}) for i in range(1, 5)}) + + return { + "required": range_ + } + + CATEGORY = "arithmetic" + RETURN_TYPES = ("INT",) + FUNCTION = "execute" + + def execute(self, **kwargs): + return (min(kwargs.values()),) + + +class IntMax(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + range_ = {"value0": ("INT", {})} + range_.update({f"value{i}": ("INT", {"default": 0}) for i in range(1, 5)}) + + return { + "required": range_ + } + + CATEGORY = "arithmetic" + RETURN_TYPES = ("INT",) + FUNCTION = "execute" + + def execute(self, **kwargs): + return (max(kwargs.values()),) + + +class IntAbs(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "value": ("INT", {}) + } + } + + CATEGORY = "arithmetic" + RETURN_TYPES = ("INT",) + FUNCTION = "execute" + + def execute(self, value): + return (abs(value),) + + +class IntAverage(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + range_ = {"value0": ("INT", {})} + range_.update({f"value{i}": ("INT", {"default": 0}) for i in range(1, 5)}) + + return { + "required": range_ + } + + CATEGORY = "arithmetic" + RETURN_TYPES = ("INT",) + FUNCTION = "execute" + + def execute(self, **kwargs): + return (sum(kwargs.values()) // len(kwargs),) + + +class FloatLerp(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "a": ("FLOAT", {"default": 0.0, "step": 0.01, "round": 0.000001}), + "b": ("FLOAT", {"default": 1.0}), + "t": ("FLOAT", {}), + "clamped": ("BOOLEAN", {"default": True}), + } + } + + CATEGORY = "arithmetic" + RETURN_TYPES = ("FLOAT",) + FUNCTION = "execute" + + def execute(self, a, b, t, clamped): + value = a + (b - a) * t + if clamped: + value = min(max(value, a), b) + return (value,) + + +class FloatInverseLerp(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "a": ("FLOAT", {"default": 0.0, "step": 0.01, "round": 0.000001}), + "b": ("FLOAT", {"default": 1.0}), + "value": ("FLOAT", {}), + "clamped": ("BOOLEAN", {"default": True}), + } + } + + CATEGORY = "arithmetic" + RETURN_TYPES = ("FLOAT",) + FUNCTION = "execute" + + def execute(self, a, b, value, clamped): + if a == b: + return (0.0,) + t = (value - a) / (b - a) + if clamped: + t = min(max(t, 0.0), 1.0) + return (t,) + + +class FloatClamp(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "value": ("FLOAT", {}), + "min": ("FLOAT", {"default": 0.0, "step": 0.01, "round": 0.000001}), + "max": ("FLOAT", {"default": 1.0}), + } + } + + CATEGORY = "arithmetic" + RETURN_TYPES = ("FLOAT",) + FUNCTION = "execute" + + def execute(self, value: float = 0, **kwargs): + v_min: float = kwargs['min'] + v_max: float = kwargs['max'] + return (min(max(value, v_min), v_max),) + + +class IntLerp(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "a": ("INT", {"default": 0}), + "b": ("INT", {"default": 10}), + "t": ("FLOAT", {}), + "clamped": ("BOOLEAN", {"default": True}), + } + } + + CATEGORY = "arithmetic" + RETURN_TYPES = ("INT",) + FUNCTION = "execute" + + def execute(self, a, b, t, clamped): + value = int(round(a + (b - a) * t)) + if clamped: + value = min(max(value, a), b) + return (value,) + + +class IntInverseLerp(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "a": ("INT", {"default": 0}), + "b": ("INT", {"default": 10}), + "value": ("INT", {}), + "clamped": ("BOOLEAN", {"default": True}), + } + } + + CATEGORY = "arithmetic" + RETURN_TYPES = ("FLOAT",) + FUNCTION = "execute" + + def execute(self, a, b, value, clamped): + if a == b: + return (0.0,) + t = (value - a) / (b - a) + if clamped: + t = min(max(t, 0.0), 1.0) + return (t,) + + +class IntClamp(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "value": ("INT", {}), + "min": ("INT", {"default": 0}), + "max": ("INT", {"default": 1}), + } + } + + CATEGORY = "arithmetic" + RETURN_TYPES = ("INT",) + FUNCTION = "execute" + + def execute(self, value: int = 0, **kwargs): + v_min: int = kwargs['min'] + v_max: int = kwargs['max'] + + return (min(max(value, v_min), v_max),) + + +NODE_CLASS_MAPPINGS = {} +for cls in ( + FloatAdd, + FloatSubtract, + FloatMultiply, + FloatDivide, + FloatPower, + FloatMin, + FloatMax, + FloatAbs, + FloatAverage, + FloatLerp, + FloatInverseLerp, + FloatClamp, + IntAdd, + IntSubtract, + IntMultiply, + IntDivide, + IntMod, + IntPower, + IntMin, + IntMax, + IntAbs, + IntAverage, + IntLerp, + IntInverseLerp, + IntClamp, +): + NODE_CLASS_MAPPINGS[cls.__name__] = cls diff --git a/comfy_extras/nodes/nodes_freelunch.py b/comfy_extras/nodes/nodes_freelunch.py index 6f1d87bf3..e4a0b8021 100644 --- a/comfy_extras/nodes/nodes_freelunch.py +++ b/comfy_extras/nodes/nodes_freelunch.py @@ -1,4 +1,28 @@ -#code originally taken from: https://github.com/ChenyangSi/FreeU (under MIT License) +""" +Portions of this code are adapted from the repository +https://github.com/ChenyangSi/FreeU + +MIT License + + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +""" import torch import logging diff --git a/comfy_extras/nodes/nodes_image_arithmetic.py b/comfy_extras/nodes/nodes_image_arithmetic.py new file mode 100644 index 000000000..d083bb340 --- /dev/null +++ b/comfy_extras/nodes/nodes_image_arithmetic.py @@ -0,0 +1,48 @@ +from torch import Tensor + +from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult + + +class ImageMin(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "image": ("IMAGE", {}) + } + } + + RETURN_TYPES = ("FLOAT",) + CATEGORY = "image/postprocessing" + FUNCTION = "execute" + + def execute(self, image: Tensor) -> ValidatedNodeResult: + return float(image.min().item()), + + +class ImageMax(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "image": ("IMAGE", {}) + } + } + + RETURN_TYPES = ("FLOAT",) + CATEGORY = "image/postprocessing" + FUNCTION = "execute" + + def execute(self, image: Tensor) -> ValidatedNodeResult: + return float(image.max().item()), + + +NODE_CLASS_MAPPINGS = { + ImageMin.__name__: ImageMin, + ImageMax.__name__: ImageMax, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + ImageMin.__name__: "Image Minimum Value", + ImageMax.__name__: "Image Maximum Value" +} diff --git a/comfy_extras/nodes/nodes_open_api.py b/comfy_extras/nodes/nodes_open_api.py index 095c303fb..b5bb8e63e 100644 --- a/comfy_extras/nodes/nodes_open_api.py +++ b/comfy_extras/nodes/nodes_open_api.py @@ -16,12 +16,12 @@ import fsspec import numpy as np from PIL import Image from PIL.PngImagePlugin import PngInfo -from fsspec.core import OpenFiles, OpenFile +from fsspec.core import OpenFile from fsspec.generic import GenericFileSystem from fsspec.implementations.local import LocalFileSystem from joblib import Parallel, delayed -from torch import Tensor from natsort import natsorted +from torch import Tensor from comfy.cmd import folder_paths from comfy.digest import digest diff --git a/comfy_extras/nodes/nodes_sag.py b/comfy_extras/nodes/nodes_sag.py index 16ccf04b1..e520e13e4 100644 --- a/comfy_extras/nodes/nodes_sag.py +++ b/comfy_extras/nodes/nodes_sag.py @@ -97,8 +97,8 @@ class SelfAttentionGuidance: @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), - "scale": ("FLOAT", {"default": 0.5, "min": -2.0, "max": 5.0, "step": 0.1}), - "blur_sigma": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 10.0, "step": 0.1}), + "scale": ("FLOAT", {"default": 0.5, "min": -2.0, "max": 5.0, "step": 0.01, "round": 0.01}), + "blur_sigma": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 10.0, "step": 0.01, "round": 0.01}), }} RETURN_TYPES = ("MODEL",) FUNCTION = "patch" diff --git a/requirements.txt b/requirements.txt index 2107d9046..dde0ec0cd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -35,4 +35,5 @@ huggingface_hub lazy-object-proxy can_ada fsspec -natsort \ No newline at end of file +natsort +OpenEXR \ No newline at end of file diff --git a/tests/nodes/test_arithmetic_unit.py b/tests/nodes/test_arithmetic_unit.py new file mode 100644 index 000000000..eb071e111 --- /dev/null +++ b/tests/nodes/test_arithmetic_unit.py @@ -0,0 +1,198 @@ +import pytest + +from comfy_extras.nodes.nodes_arithmetic import IntAdd, IntSubtract, IntMultiply, IntDivide, IntMod, IntPower, FloatAdd, FloatSubtract, FloatMultiply, FloatDivide, FloatPower, FloatMin, FloatMax, FloatAbs, FloatAverage, IntMin, IntMax, IntAbs, IntAverage, FloatLerp, IntLerp, IntClamp, IntInverseLerp, FloatClamp, FloatInverseLerp + + +def test_int_add(): + n = IntAdd() + res, = n.execute(value0=1, value1=2, value2=3) + assert res == 6 + + +def test_int_subtract(): + n = IntSubtract() + res, = n.execute(value0=10, value1=3) + assert res == 7 + + +def test_int_multiply(): + n = IntMultiply() + res, = n.execute(value0=2, value1=3, value2=4) + assert res == 24 + + +def test_int_divide(): + n = IntDivide() + res, = n.execute(value0=10, value1=3) + assert res == 3 + + res, = n.execute(value0=10, value1=0) + assert res == 0 + + +def test_int_mod(): + n = IntMod() + res, = n.execute(value0=10, value1=3) + assert res == 1 + + res, = n.execute(value0=10, value1=0) + assert res == 0 + + +def test_int_power(): + n = IntPower() + res, = n.execute(base=2, exponent=3) + assert res == 8 + + +def test_float_add(): + n = FloatAdd() + res, = n.execute(value0=1.5, value1=2.3, value2=3.7) + assert pytest.approx(res) == 7.5 + + +def test_float_subtract(): + n = FloatSubtract() + res, = n.execute(value0=10.5, value1=3.2) + assert pytest.approx(res) == 7.3 + + +def test_float_multiply(): + n = FloatMultiply() + res, = n.execute(value0=2.5, value1=3.0, value2=4.0) + assert pytest.approx(res) == 30.0 + + +def test_float_divide(): + n = FloatDivide() + res, = n.execute(value0=10.0, value1=4.0) + assert pytest.approx(res) == 2.5 + + res, = n.execute(value0=10.0, value1=0.0) + assert res == float("inf") + + +def test_float_power(): + n = FloatPower() + res, = n.execute(base=2.5, exponent=3.0) + assert pytest.approx(res) == 15.625 + + +def test_float_min(): + n = FloatMin() + res, = n.execute(value0=1.5, value1=2.3, value2=0.7) + assert res == 0.7 + + +def test_float_max(): + n = FloatMax() + res, = n.execute(value0=1.5, value1=2.3, value2=0.7) + assert res == 2.3 + + +def test_float_abs(): + n = FloatAbs() + res, = n.execute(value=-3.14) + assert res == 3.14 + + +def test_float_average(): + n = FloatAverage() + res, = n.execute(value0=1.5, value1=2.5, value2=3.5) + assert res == 2.5 + + +def test_int_min(): + n = IntMin() + res, = n.execute(value0=5, value1=2, value2=7) + assert res == 2 + + +def test_int_max(): + n = IntMax() + res, = n.execute(value0=5, value1=2, value2=7) + assert res == 7 + + +def test_int_abs(): + n = IntAbs() + res, = n.execute(value=-10) + assert res == 10 + + +def test_int_average(): + n = IntAverage() + res, = n.execute(value0=2, value1=4, value2=6) + assert res == 4 + + +def test_float_lerp(): + n = FloatLerp() + res, = n.execute(a=0.0, b=1.0, t=0.5, clamped=True) + assert res == 0.5 + + res, = n.execute(a=0.0, b=1.0, t=1.5, clamped=True) + assert res == 1.0 + + res, = n.execute(a=0.0, b=1.0, t=1.5, clamped=False) + assert res == 1.5 + + +def test_int_lerp(): + n = IntLerp() + res, = n.execute(a=0, b=10, t=0.5, clamped=True) + assert res == 5 + + res, = n.execute(a=0, b=10, t=1.5, clamped=True) + assert res == 10 + + res, = n.execute(a=0, b=10, t=1.5, clamped=False) + assert res == 15 + + +def test_float_inverse_lerp(): + n = FloatInverseLerp() + res, = n.execute(a=0.0, b=1.0, value=0.5, clamped=True) + assert res == 0.5 + + res, = n.execute(a=0.0, b=1.0, value=1.5, clamped=True) + assert res == 1.0 + + res, = n.execute(a=0.0, b=1.0, value=1.5, clamped=False) + assert res == 1.5 + + +def test_float_clamp(): + n = FloatClamp() + res, = n.execute(value=0.5, min=0.0, max=1.0) + assert res == 0.5 + + res, = n.execute(value=1.5, min=0.0, max=1.0) + assert res == 1.0 + + res, = n.execute(value=-0.5, min=0.0, max=1.0) + assert res == 0.0 + + +def test_int_inverse_lerp(): + n = IntInverseLerp() + res, = n.execute(a=0, b=10, value=5, clamped=True) + assert res == 0.5 + + res, = n.execute(a=0, b=10, value=15, clamped=True) + assert res == 1.0 + + res, = n.execute(a=0, b=10, value=15, clamped=False) + assert res == 1.5 + + +def test_int_clamp(): + n = IntClamp() + res, = n.execute(value=5, min=0, max=10) + assert res == 5 + + res, = n.execute(value=15, min=0, max=10) + assert res == 10 + + res, = n.execute(value=-5, min=0, max=10) + assert res == 0 diff --git a/tests/nodes/test_colormap_unit.py b/tests/nodes/test_colormap_unit.py new file mode 100644 index 000000000..16805185e --- /dev/null +++ b/tests/nodes/test_colormap_unit.py @@ -0,0 +1,49 @@ +import pytest +import torch +from comfy_extras.nodes.nodes_apply_color_map import ImageApplyColorMap + + +@pytest.fixture +def input_image(): + # Create a 1x1x2x1 tensor representing an image with absolute distances of 1.3 meters and 300 meters + return torch.tensor([[[[1.3], [300.0]]]], dtype=torch.float32) + + +def test_apply_colormap_grayscale(input_image): + node = ImageApplyColorMap() + colored_image, = node.execute(image=input_image, colormap="Grayscale", min_depth=1.3, max_depth=300.0) + + assert colored_image.shape == (1, 1, 2, 3) + assert colored_image.dtype == torch.float32 + assert torch.allclose(colored_image[0, 0, 0], torch.tensor([1.0, 1.0, 1.0])) + assert torch.allclose(colored_image[0, 0, 1], torch.tensor([0.0, 0.0, 0.0])) + + +def test_apply_colormap_inferno(input_image): + node = ImageApplyColorMap() + colored_image, = node.execute(image=input_image, colormap="COLORMAP_INFERNO", min_depth=1.3, max_depth=300.0) + + assert colored_image.shape == (1, 1, 2, 3) + assert colored_image.dtype == torch.float32 + assert torch.allclose(colored_image[0, 0, 0], torch.tensor([0.9882, 1.000, 0.6431]), atol=1e-4) + assert torch.allclose(colored_image[0, 0, 1], torch.tensor([0.0000, 0.0000, 0.0157]), atol=1e-4) + + +def test_apply_colormap_clipping(input_image): + node = ImageApplyColorMap() + + colored_image, = node.execute(image=input_image, colormap="COLORMAP_INFERNO", clip_min=False, clip_max=False, min_depth=1.3, max_depth=300.0) + assert torch.allclose(colored_image[0, 0, 0], torch.tensor([0.0, 0.0, 0.0157]), atol=1e-4) + assert torch.allclose(colored_image[0, 0, 1], torch.tensor([0.0, 0.0, 0.0157]), atol=1e-4) + + colored_image, = node.execute(image=input_image, colormap="COLORMAP_INFERNO", clip_min=True, clip_max=False, min_depth=1.3, max_depth=300.0) + assert torch.allclose(colored_image[0, 0, 0], torch.tensor([0.9882, 1.0000, 0.6431]), atol=1e-4) + assert torch.allclose(colored_image[0, 0, 1], torch.tensor([0.0000, 0.0000, 0.0157]), atol=1e-4) + + colored_image, = node.execute(image=input_image, colormap="COLORMAP_INFERNO", clip_min=False, clip_max=True, min_depth=1.3, max_depth=200.0) + assert torch.allclose(colored_image[0, 0, 0], torch.tensor([0.0, 0.0, 0.0157]), atol=1e-4) + assert torch.allclose(colored_image[0, 0, 1], torch.tensor([0.0, 0.0, 0.0157]), atol=1e-4) + + colored_image, = node.execute(image=input_image, colormap="COLORMAP_INFERNO", clip_min=True, clip_max=True, min_depth=1.3, max_depth=200.0) + assert torch.allclose(colored_image[0, 0, 0], torch.tensor([0.9882, 1.0000, 0.6431]), atol=1e-4) + assert torch.allclose(colored_image[0, 0, 1], torch.tensor([0.0000, 0.0000, 0.0157]), atol=1e-4)