mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 06:10:50 +08:00
Improved support for ControlNet workflows with depth
- ComfyUI can now load EXR files. - There are new arithmetic nodes for floats and integers. - EXR nodes can load depth maps and be remapped with ImageApplyColormap. This allows end users to use ground truth depth data from video game engines or 3D graphics tools and recolor it to the format expected by depth ControlNets: grayscale inverse depth maps and "inferno" colored inverse depth maps. - Fixed license notes. - Added an additional known ControlNet model. - Because CV2 is now used to read OpenEXR files, an environment variable must be set early on in the application, before CV2 is imported. This file, main_pre, is now imported early on in more places.
This commit is contained in:
parent
d8846fcb39
commit
b0be335d59
@ -100,7 +100,7 @@ class EmbeddedComfyClient:
|
||||
if self._configuration is None:
|
||||
options.enable_args_parsing()
|
||||
else:
|
||||
from ..cli_args import args
|
||||
from ..cmd.main_pre import args
|
||||
args.clear()
|
||||
args.update(self._configuration)
|
||||
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
# Suppress warnings during import
|
||||
import asyncio
|
||||
import gc
|
||||
import itertools
|
||||
@ -8,9 +7,10 @@ import shutil
|
||||
import threading
|
||||
import time
|
||||
|
||||
# main_pre must be the earliest import since it suppresses some spurious warnings
|
||||
from .main_pre import args
|
||||
from ..utils import hijack_progress
|
||||
from .extra_model_paths import load_extra_path_config
|
||||
from .main_pre import args
|
||||
from .. import model_management
|
||||
from ..analytics.analytics import initialize_event_tracking
|
||||
from ..cmd import cuda_malloc
|
||||
|
||||
@ -1,3 +1,12 @@
|
||||
"""
|
||||
This should be imported before entrypoints to correctly configure global options prior to importing packages like torch and cv2.
|
||||
|
||||
Use this instead of cli_args to import the args:
|
||||
|
||||
>>> from comfy.cmd.main_pre import args
|
||||
|
||||
It will enable command line argument parsing. If this isn't desired, you must author your own implementation of these fixes.
|
||||
"""
|
||||
import os
|
||||
|
||||
from .. import options
|
||||
@ -9,6 +18,8 @@ options.enable_args_parsing()
|
||||
if os.name == "nt":
|
||||
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
|
||||
warnings.filterwarnings("ignore", message="torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.")
|
||||
warnings.filterwarnings("ignore", message="Torch was not compiled with flash attention.")
|
||||
warnings.filterwarnings("ignore", message=".*Torch was not compiled with flash attention.*")
|
||||
|
||||
from ..cli_args import args
|
||||
|
||||
@ -20,4 +31,5 @@ if args.deterministic:
|
||||
if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ:
|
||||
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
|
||||
|
||||
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
|
||||
__all__ = ["args"]
|
||||
|
||||
@ -38,6 +38,7 @@ from ..component_model.file_output_path import file_output_path
|
||||
from ..component_model.queue_types import QueueItem, HistoryEntry, BinaryEventTypes, TaskInvocation
|
||||
from ..digest import digest
|
||||
from ..nodes.package_typing import ExportedNodes
|
||||
from ..images import open_image
|
||||
|
||||
|
||||
class HeuristicPath(NamedTuple):
|
||||
@ -289,9 +290,12 @@ class PromptServer(ExecutorToClientProgress):
|
||||
return web.Response(status=400)
|
||||
|
||||
if os.path.isfile(file):
|
||||
if 'preview' in request.rel_url.query:
|
||||
with Image.open(file) as img:
|
||||
preview_info = request.rel_url.query['preview'].split(';')
|
||||
# todo: any image file we upload that browsers don't support, we should encode a preview
|
||||
# todo: image handling has to be a little bit more standardized, sometimes we want a Pillow Image, sometimes
|
||||
# we want something that will render to the user, sometimes we want tensors
|
||||
if 'preview' in request.rel_url.query or file.endswith(".exr"):
|
||||
with open_image(file) as img:
|
||||
preview_info = request.rel_url.query.get("preview", "jpeg;90").split(';')
|
||||
image_format = preview_info[0]
|
||||
if image_format not in ['webp', 'jpeg'] or 'a' in request.rel_url.query.get('channel', ''):
|
||||
image_format = 'webp'
|
||||
|
||||
@ -4,11 +4,7 @@ import os
|
||||
import logging
|
||||
|
||||
from .extra_model_paths import load_extra_path_config
|
||||
from .. import options
|
||||
|
||||
options.enable_args_parsing()
|
||||
|
||||
from ..cli_args import args
|
||||
from .main_pre import args
|
||||
|
||||
|
||||
async def main():
|
||||
@ -17,28 +13,19 @@ async def main():
|
||||
args.distributed_queue_frontend = False
|
||||
assert args.distributed_queue_connection_uri is not None, "Set the --distributed-queue-connection-uri argument to your RabbitMQ server"
|
||||
|
||||
|
||||
if args.cuda_device is not None:
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
|
||||
logging.info(f"Set cuda device to: {args.cuda_device}")
|
||||
|
||||
if args.deterministic:
|
||||
if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ:
|
||||
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
|
||||
|
||||
# configure paths
|
||||
if args.output_directory:
|
||||
output_dir = os.path.abspath(args.output_directory)
|
||||
logging.info(f"Setting output directory to: {output_dir}")
|
||||
from ..cmd import folder_paths
|
||||
|
||||
|
||||
folder_paths.set_output_directory(output_dir)
|
||||
|
||||
|
||||
if args.input_directory:
|
||||
input_dir = os.path.abspath(args.input_directory)
|
||||
logging.info(f"Setting input directory to: {input_dir}")
|
||||
from ..cmd import folder_paths
|
||||
|
||||
|
||||
folder_paths.set_input_directory(input_dir)
|
||||
|
||||
if args.temp_directory:
|
||||
|
||||
8
comfy/component_model/images_types.py
Normal file
8
comfy/component_model/images_types.py
Normal file
@ -0,0 +1,8 @@
|
||||
from typing import NamedTuple
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
class RgbMaskTuple(NamedTuple):
|
||||
rgb: Tensor
|
||||
mask: Tensor
|
||||
19
comfy/images.py
Normal file
19
comfy/images.py
Normal file
@ -0,0 +1,19 @@
|
||||
import os.path
|
||||
from contextlib import contextmanager
|
||||
|
||||
import cv2
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def _open_exr(exr_path) -> Image.Image:
|
||||
return Image.fromarray(cv2.imread(exr_path, cv2.IMREAD_COLOR))
|
||||
|
||||
|
||||
@contextmanager
|
||||
def open_image(file_path: str) -> Image.Image:
|
||||
_, ext = os.path.splitext(file_path)
|
||||
if ext == ".exr":
|
||||
yield _open_exr(file_path)
|
||||
else:
|
||||
with Image.open(file_path) as image:
|
||||
yield image
|
||||
@ -755,7 +755,32 @@ def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, n
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
|
||||
# From MIT licensed: https://github.com/Carzit/sd-webui-samplers-scheduler/
|
||||
"""
|
||||
Portions of this function are adapted from the repository
|
||||
https://github.com/Carzit/sd-webui-samplers-scheduler
|
||||
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2023 Carzit
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
"""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
s_end = sigmas[-1]
|
||||
|
||||
@ -51,7 +51,7 @@ def get_or_download(folder_name: str, filename: str, known_files: List[HuggingFi
|
||||
else:
|
||||
linked_filename = None
|
||||
try:
|
||||
os.symlink(os.path.join(destination,known_file.filename), linked_filename)
|
||||
os.symlink(os.path.join(destination, known_file.filename), linked_filename)
|
||||
except Exception as exc_info:
|
||||
logging.error(f"Failed to link file with alternative download save name in a way that is compatible with Hugging Face caching {repr(known_file)}", exc_info=exc_info)
|
||||
else:
|
||||
@ -213,6 +213,7 @@ KNOWN_CONTROLNETS = [
|
||||
HuggingFile("lllyasviel/sd_control_collection", "sargezt_xl_depth_faid_vidit.safetensors"),
|
||||
HuggingFile("lllyasviel/sd_control_collection", "sargezt_xl_depth_zeed.safetensors"),
|
||||
HuggingFile("lllyasviel/sd_control_collection", "sargezt_xl_softedge.safetensors"),
|
||||
HuggingFile("SargeZT/controlnet-sd-xl-1.0-depth-16bit-zoe", "depth-zoe-xl-v1.0-controlnet.safetensors"),
|
||||
HuggingFile("lllyasviel/sd_control_collection", "t2i-adapter_diffusers_xl_canny.safetensors"),
|
||||
HuggingFile("lllyasviel/sd_control_collection", "t2i-adapter_diffusers_xl_depth_midas.safetensors"),
|
||||
HuggingFile("lllyasviel/sd_control_collection", "t2i-adapter_diffusers_xl_depth_zoe.safetensors"),
|
||||
|
||||
@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
from os.path import split
|
||||
from typing import Optional, List
|
||||
from typing import Optional, List, Sequence
|
||||
from typing_extensions import TypedDict, NotRequired
|
||||
|
||||
|
||||
@ -15,10 +15,12 @@ class CivitFile:
|
||||
model_id (int): The ID of the model
|
||||
model_version_id (int): The version
|
||||
filename (str): The name of the file in the model
|
||||
trigger_words (List[str]): Trigger words associated with the model
|
||||
"""
|
||||
model_id: int
|
||||
model_version_id: int
|
||||
filename: str
|
||||
trigger_words: Optional[Sequence[str]] = dataclasses.field(default_factory=tuple)
|
||||
|
||||
def __str__(self):
|
||||
return self.filename
|
||||
|
||||
@ -9,6 +9,7 @@ import logging
|
||||
|
||||
from PIL import Image, ImageOps, ImageSequence
|
||||
from PIL.PngImagePlugin import PngInfo
|
||||
from natsort import natsorted
|
||||
from pkg_resources import resource_filename
|
||||
import numpy as np
|
||||
import safetensors.torch
|
||||
@ -23,10 +24,13 @@ from .. import model_management
|
||||
from ..cli_args import args
|
||||
|
||||
from ..cmd import folder_paths, latent_preview
|
||||
from ..images import open_image
|
||||
from ..model_downloader import get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, \
|
||||
KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS, KNOWN_LORAS, KNOWN_CONTROLNETS, KNOWN_DIFF_CONTROLNETS
|
||||
from ..nodes.common import MAX_RESOLUTION
|
||||
from .. import controlnet
|
||||
from ..open_exr import load_exr
|
||||
|
||||
|
||||
class CLIPTextEncode:
|
||||
@classmethod
|
||||
@ -1454,38 +1458,49 @@ class PreviewImage(SaveImage):
|
||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
||||
}
|
||||
|
||||
|
||||
class LoadImage:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
input_dir = folder_paths.get_input_directory()
|
||||
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
|
||||
return {"required":
|
||||
{"image": (sorted(files), {"image_upload": True})},
|
||||
}
|
||||
return {
|
||||
"required": {
|
||||
"image": (natsorted(files), {"image_upload": True}),
|
||||
},
|
||||
}
|
||||
|
||||
CATEGORY = "image"
|
||||
|
||||
RETURN_TYPES = ("IMAGE", "MASK")
|
||||
FUNCTION = "load_image"
|
||||
def load_image(self, image):
|
||||
|
||||
def load_image(self, image: str):
|
||||
image_path = folder_paths.get_annotated_filepath(image)
|
||||
img = Image.open(image_path)
|
||||
output_images = []
|
||||
output_masks = []
|
||||
for i in ImageSequence.Iterator(img):
|
||||
i = ImageOps.exif_transpose(i)
|
||||
if i.mode == 'I':
|
||||
i = i.point(lambda i: i * (1 / 255))
|
||||
image = i.convert("RGB")
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = torch.from_numpy(image)[None,]
|
||||
if 'A' in i.getbands():
|
||||
mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
|
||||
mask = 1. - torch.from_numpy(mask)
|
||||
else:
|
||||
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
|
||||
output_images.append(image)
|
||||
output_masks.append(mask.unsqueeze(0))
|
||||
|
||||
# maintain the legacy path
|
||||
# this will ultimately return a tensor, so we'd rather have the tensors directly
|
||||
# from cv2 rather than get them out of a PIL image
|
||||
_, ext = os.path.splitext(image)
|
||||
if ext == ".exr":
|
||||
return load_exr(image_path, srgb=False)
|
||||
with open_image(image_path) as img:
|
||||
for i in ImageSequence.Iterator(img):
|
||||
i = ImageOps.exif_transpose(i)
|
||||
if i.mode == 'I':
|
||||
i = i.point(lambda i: i * (1 / 255))
|
||||
image = i.convert("RGB")
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = torch.from_numpy(image)[None,]
|
||||
if 'A' in i.getbands():
|
||||
mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
|
||||
mask = 1. - torch.from_numpy(mask)
|
||||
else:
|
||||
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
|
||||
output_images.append(image)
|
||||
output_masks.append(mask.unsqueeze(0))
|
||||
|
||||
if len(output_images) > 1:
|
||||
output_image = torch.cat(output_images, dim=0)
|
||||
@ -1494,7 +1509,7 @@ class LoadImage:
|
||||
output_image = output_images[0]
|
||||
output_mask = output_masks[0]
|
||||
|
||||
return (output_image, output_mask)
|
||||
return output_image, output_mask
|
||||
|
||||
@classmethod
|
||||
def IS_CHANGED(s, image):
|
||||
|
||||
@ -53,7 +53,7 @@ BooleanSpec = Tuple[Literal["BOOLEAN"], BoolSpecOptions]
|
||||
|
||||
ChoiceSpec = Tuple[Union[Sequence[str], Sequence[float], Sequence[int]]]
|
||||
|
||||
NonPrimitiveTypeSpec = Tuple[CommonReturnTypes]
|
||||
NonPrimitiveTypeSpec = Tuple[CommonReturnTypes, Any]
|
||||
|
||||
InputTypeSpec = Union[IntSpec, FloatSpec, StringSpec, BooleanSpec, ChoiceSpec, NonPrimitiveTypeSpec]
|
||||
|
||||
|
||||
86
comfy/open_exr.py
Normal file
86
comfy/open_exr.py
Normal file
@ -0,0 +1,86 @@
|
||||
"""
|
||||
Portions of this code are adapted from the repository
|
||||
https://github.com/spacepxl/ComfyUI-HQ-Image-Save
|
||||
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2023 spacepxl
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
"""
|
||||
|
||||
import copy
|
||||
from typing import Sequence, Tuple
|
||||
|
||||
import cv2 as cv
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from .component_model.images_types import RgbMaskTuple
|
||||
|
||||
|
||||
def mut_srgb_to_linear(np_array) -> None:
|
||||
less = np_array <= 0.0404482362771082
|
||||
np_array[less] = np_array[less] / 12.92
|
||||
np_array[~less] = np.power((np_array[~less] + 0.055) / 1.055, 2.4)
|
||||
|
||||
|
||||
def mut_linear_to_srgb(np_array) -> None:
|
||||
less = np_array <= 0.0031308
|
||||
np_array[less] = np_array[less] * 12.92
|
||||
np_array[~less] = np.power(np_array[~less], 1 / 2.4) * 1.055 - 0.055
|
||||
|
||||
|
||||
def load_exr(file_path: str, srgb: bool) -> RgbMaskTuple:
|
||||
image = cv.imread(file_path, cv.IMREAD_UNCHANGED).astype(np.float32)
|
||||
rgb = np.flip(image[:, :, :3], 2).copy()
|
||||
if srgb:
|
||||
mut_linear_to_srgb(rgb)
|
||||
rgb = np.clip(rgb, 0, 1)
|
||||
rgb = torch.unsqueeze(torch.from_numpy(rgb), 0)
|
||||
|
||||
mask = torch.zeros((1, image.shape[0], image.shape[1]), dtype=torch.float32)
|
||||
if image.shape[2] > 3:
|
||||
mask[0] = torch.from_numpy(np.clip(image[:, :, 3], 0, 1))
|
||||
|
||||
return RgbMaskTuple(rgb, mask)
|
||||
|
||||
|
||||
def load_exr_latent(file_path: str) -> Tuple[Tensor]:
|
||||
image = cv.imread(file_path, cv.IMREAD_UNCHANGED).astype(np.float32)
|
||||
image = image[:, :, np.array([2, 1, 0, 3])]
|
||||
image = torch.unsqueeze(torch.from_numpy(image), 0)
|
||||
image = torch.movedim(image, -1, 1)
|
||||
return image,
|
||||
|
||||
|
||||
def save_exr(images: Tensor, filepaths_batched: Sequence[str], colorspace="linear"):
|
||||
linear = images.detach().clone().cpu().numpy().astype(np.float32)
|
||||
if colorspace == "linear":
|
||||
mut_srgb_to_linear(linear[:, :, :, :3]) # only convert RGB, not Alpha
|
||||
|
||||
bgr = copy.deepcopy(linear)
|
||||
bgr[:, :, :, 0] = linear[:, :, :, 2] # flip RGB to BGR for opencv
|
||||
bgr[:, :, :, 2] = linear[:, :, :, 0]
|
||||
if bgr.shape[-1] > 3:
|
||||
bgr[:, :, :, 3] = np.clip(1 - linear[:, :, :, 3], 0, 1) # invert alpha
|
||||
|
||||
for i in range(len(linear.shape[0])):
|
||||
cv.imwrite(filepaths_batched[i], bgr[i])
|
||||
@ -334,7 +334,7 @@ def get_attr(obj, attr):
|
||||
def bislerp(samples, width, height):
|
||||
def slerp(b1, b2, r):
|
||||
'''slerps batches b1, b2 according to ratio r, batches should be flat e.g. NxC'''
|
||||
|
||||
|
||||
c = b1.shape[-1]
|
||||
|
||||
#norms
|
||||
@ -359,16 +359,16 @@ def bislerp(samples, width, height):
|
||||
res *= (b1_norms * (1.0-r) + b2_norms * r).expand(-1,c)
|
||||
|
||||
#edge cases for same or polar opposites
|
||||
res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5]
|
||||
res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5]
|
||||
res[dot < 1e-5 - 1] = (b1 * (1.0-r) + b2 * r)[dot < 1e-5 - 1]
|
||||
return res
|
||||
|
||||
|
||||
def generate_bilinear_data(length_old, length_new, device):
|
||||
coords_1 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1,1,1,-1))
|
||||
coords_1 = torch.nn.functional.interpolate(coords_1, size=(1, length_new), mode="bilinear")
|
||||
ratios = coords_1 - coords_1.floor()
|
||||
coords_1 = coords_1.to(torch.int64)
|
||||
|
||||
|
||||
coords_2 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1,1,1,-1)) + 1
|
||||
coords_2[:,:,:,-1] -= 1
|
||||
coords_2 = torch.nn.functional.interpolate(coords_2, size=(1, length_new), mode="bilinear")
|
||||
@ -379,7 +379,7 @@ def bislerp(samples, width, height):
|
||||
samples = samples.float()
|
||||
n,c,h,w = samples.shape
|
||||
h_new, w_new = (height, width)
|
||||
|
||||
|
||||
#linear w
|
||||
ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new, samples.device)
|
||||
coords_1 = coords_1.expand((n, c, h, -1))
|
||||
@ -496,6 +496,17 @@ def set_progress_bar_global_hook(function):
|
||||
PROGRESS_BAR_HOOK = function
|
||||
|
||||
|
||||
class _DisabledProgressBar:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def update(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def update_absolute(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
class ProgressBar:
|
||||
def __init__(self, total: float):
|
||||
global PROGRESS_BAR_HOOK
|
||||
@ -545,3 +556,12 @@ def comfy_tqdm():
|
||||
# Restore original tqdm
|
||||
tqdm.__init__ = _original_init
|
||||
tqdm.update = _original_update
|
||||
|
||||
|
||||
@contextmanager
|
||||
def comfy_progress(total: float) -> ProgressBar:
|
||||
global PROGRESS_BAR_ENABLED
|
||||
if PROGRESS_BAR_ENABLED:
|
||||
yield ProgressBar(total)
|
||||
else:
|
||||
yield _DisabledProgressBar()
|
||||
|
||||
@ -469,7 +469,7 @@ export const ComfyWidgets = {
|
||||
const fileInput = document.createElement("input");
|
||||
Object.assign(fileInput, {
|
||||
type: "file",
|
||||
accept: "image/jpeg,image/png,image/webp",
|
||||
accept: "image/jpeg,image/png,image/webp,image/x-exr,.exr",
|
||||
style: "display: none",
|
||||
onchange: async () => {
|
||||
if (fileInput.files.length) {
|
||||
|
||||
85
comfy_extras/nodes/nodes_apply_color_map.py
Normal file
85
comfy_extras/nodes/nodes_apply_color_map.py
Normal file
@ -0,0 +1,85 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult
|
||||
|
||||
_available_colormaps = ["Grayscale"] + [attr for attr in dir(cv2) if attr.startswith('COLORMAP')]
|
||||
|
||||
|
||||
class ImageApplyColorMap(CustomNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypes:
|
||||
return {
|
||||
"required": {
|
||||
"image": ("IMAGE", {}),
|
||||
"colormap": (_available_colormaps, {"default": "COLORMAP_INFERNO"}),
|
||||
"gamma": ("FLOAT", {"default": 1.0, "min": 0.001, "step": 0.001, "round": 0.001}),
|
||||
"min_depth": ("FLOAT", {"default": 0.001, "min": 0.001, "round": 0.00001, "step": 0.001}),
|
||||
"max_depth": ("FLOAT", {"default": 1e2, "round": 0.00001, "step": 0.1}),
|
||||
"one_minus": ("BOOLEAN", {"default": False}),
|
||||
"clip_min": ("BOOLEAN", {"default": True}),
|
||||
"clip_max": ("BOOLEAN", {"default": False}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
CATEGORY = "image/postprocessing"
|
||||
FUNCTION = "execute"
|
||||
|
||||
def execute(self,
|
||||
image: Tensor,
|
||||
gamma: float = 1.0,
|
||||
min_depth: float = 0.001,
|
||||
max_depth: float = 1e3,
|
||||
colormap: str = "COLORMAP_INFERNO",
|
||||
one_minus: bool = False,
|
||||
clip_min: bool = True,
|
||||
clip_max: bool = False,
|
||||
) -> ValidatedNodeResult:
|
||||
"""
|
||||
Invert and apply a colormap to a batch of absolute distance depth images.
|
||||
|
||||
For Zoe and Midas, set colormap to be `COLORMAP_INFERNO`. Diffusers Depth expects `Grayscale`.
|
||||
|
||||
As per https://huggingface.co/SargeZT/controlnet-v1e-sdxl-depth/discussions/7, some ControlNet checkpoints
|
||||
expect one_minus to be true.
|
||||
"""
|
||||
colored_images = []
|
||||
|
||||
for i in range(image.shape[0]):
|
||||
depth_image = image[i, :, :, 0].numpy()
|
||||
depth_image = np.where(depth_image <= min_depth, np.nan if not clip_min else min_depth, depth_image)
|
||||
if clip_max:
|
||||
depth_image = np.where(depth_image >= max_depth, max_depth, depth_image)
|
||||
depth_image = np.power(depth_image, 1.0 / gamma)
|
||||
inv_depth_image = 1.0 / depth_image
|
||||
|
||||
xp = [1.0 / max_depth, 1.0 / min_depth]
|
||||
fp = [0, 1]
|
||||
normalized_depth = np.interp(inv_depth_image, xp, fp, left=0, right=1)
|
||||
normalized_depth = np.nan_to_num(normalized_depth, nan=0)
|
||||
|
||||
normalized_depth_uint8 = (normalized_depth * 255).astype(np.uint8)
|
||||
if one_minus:
|
||||
normalized_depth_uint8 = 255 - normalized_depth_uint8
|
||||
if colormap == "Grayscale":
|
||||
colored_image = normalized_depth_uint8
|
||||
else:
|
||||
cv2_colormap = getattr(cv2, colormap)
|
||||
colored_image = cv2.applyColorMap(normalized_depth_uint8, cv2_colormap)
|
||||
colored_image_rgb = cv2.cvtColor(colored_image, cv2.COLOR_BGR2RGB)
|
||||
rgb_tensor = torch.tensor(colored_image_rgb) * 1.0 / 255.0
|
||||
colored_images.append(rgb_tensor)
|
||||
|
||||
return torch.stack(colored_images),
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
ImageApplyColorMap.__name__: ImageApplyColorMap,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
ImageApplyColorMap.__name__: "Apply ColorMap to Image (CV2)",
|
||||
}
|
||||
514
comfy_extras/nodes/nodes_arithmetic.py
Normal file
514
comfy_extras/nodes/nodes_arithmetic.py
Normal file
@ -0,0 +1,514 @@
|
||||
from functools import reduce
|
||||
from operator import add, mul, pow
|
||||
|
||||
from comfy.nodes.package_typing import CustomNode, InputTypes
|
||||
|
||||
|
||||
class FloatAdd(CustomNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypes:
|
||||
range_ = {"value0": ("FLOAT", {})}
|
||||
range_.update({f"value{i}": ("FLOAT", {"default": 0.0, "step": 0.01, "round": 0.000001}) for i in range(1, 5)})
|
||||
|
||||
return {
|
||||
"required": range_
|
||||
}
|
||||
|
||||
CATEGORY = "arithmetic"
|
||||
RETURN_TYPES = ("FLOAT",)
|
||||
FUNCTION = "execute"
|
||||
|
||||
def execute(self, **kwargs):
|
||||
return (reduce(add, kwargs.values(), 0.0),)
|
||||
|
||||
|
||||
class FloatSubtract(CustomNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypes:
|
||||
return {
|
||||
"required": {
|
||||
"value0": ("FLOAT", {}),
|
||||
"value1": ("FLOAT", {"default": 0.0, "step": 0.01, "round": 0.000001}),
|
||||
}
|
||||
}
|
||||
|
||||
CATEGORY = "arithmetic"
|
||||
RETURN_TYPES = ("FLOAT",)
|
||||
FUNCTION = "execute"
|
||||
|
||||
def execute(self, value0, value1):
|
||||
return (value0 - value1,)
|
||||
|
||||
|
||||
class FloatMultiply(CustomNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypes:
|
||||
range_ = {"value0": ("FLOAT", {})}
|
||||
range_.update({f"value{i}": ("FLOAT", {"default": 1.0}) for i in range(1, 5)})
|
||||
|
||||
return {
|
||||
"required": range_
|
||||
}
|
||||
|
||||
CATEGORY = "arithmetic"
|
||||
RETURN_TYPES = ("FLOAT",)
|
||||
FUNCTION = "execute"
|
||||
|
||||
def execute(self, **kwargs):
|
||||
return (reduce(mul, kwargs.values(), 1.0),)
|
||||
|
||||
|
||||
class FloatDivide(CustomNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypes:
|
||||
return {
|
||||
"required": {
|
||||
"value0": ("FLOAT", {}),
|
||||
"value1": ("FLOAT", {"default": 1.0}),
|
||||
}
|
||||
}
|
||||
|
||||
CATEGORY = "arithmetic"
|
||||
RETURN_TYPES = ("FLOAT",)
|
||||
FUNCTION = "execute"
|
||||
|
||||
def execute(self, value0, value1):
|
||||
return (value0 / value1 if value1 != 0 else float("inf"),)
|
||||
|
||||
|
||||
class FloatPower(CustomNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypes:
|
||||
return {
|
||||
"required": {
|
||||
"base": ("FLOAT", {}),
|
||||
"exponent": ("FLOAT", {"default": 1.0}),
|
||||
}
|
||||
}
|
||||
|
||||
CATEGORY = "arithmetic"
|
||||
RETURN_TYPES = ("FLOAT",)
|
||||
FUNCTION = "execute"
|
||||
|
||||
def execute(self, base, exponent):
|
||||
return (pow(base, exponent),)
|
||||
|
||||
|
||||
class IntAdd(CustomNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypes:
|
||||
range_ = {"value0": ("INT", {})}
|
||||
range_.update({f"value{i}": ("INT", {"default": 0}) for i in range(1, 5)})
|
||||
|
||||
return {
|
||||
"required": range_
|
||||
}
|
||||
|
||||
CATEGORY = "arithmetic"
|
||||
RETURN_TYPES = ("INT",)
|
||||
FUNCTION = "execute"
|
||||
|
||||
def execute(self, **kwargs):
|
||||
return (reduce(add, kwargs.values(), 0),)
|
||||
|
||||
|
||||
class IntSubtract(CustomNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypes:
|
||||
return {
|
||||
"required": {
|
||||
"value0": ("INT", {}),
|
||||
"value1": ("INT", {"default": 0}),
|
||||
}
|
||||
}
|
||||
|
||||
CATEGORY = "arithmetic"
|
||||
RETURN_TYPES = ("INT",)
|
||||
FUNCTION = "execute"
|
||||
|
||||
def execute(self, value0, value1):
|
||||
return (value0 - value1,)
|
||||
|
||||
|
||||
class IntMultiply(CustomNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypes:
|
||||
range_ = {"value0": ("INT", {})}
|
||||
range_.update({f"value{i}": ("INT", {"default": 1}) for i in range(1, 5)})
|
||||
|
||||
return {
|
||||
"required": range_
|
||||
}
|
||||
|
||||
CATEGORY = "arithmetic"
|
||||
RETURN_TYPES = ("INT",)
|
||||
FUNCTION = "execute"
|
||||
|
||||
def execute(self, **kwargs):
|
||||
return (reduce(mul, kwargs.values(), 1),)
|
||||
|
||||
|
||||
class IntDivide(CustomNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypes:
|
||||
return {
|
||||
"required": {
|
||||
"value0": ("INT", {}),
|
||||
"value1": ("INT", {"default": 1}),
|
||||
}
|
||||
}
|
||||
|
||||
CATEGORY = "arithmetic"
|
||||
RETURN_TYPES = ("INT",)
|
||||
FUNCTION = "execute"
|
||||
|
||||
def execute(self, value0, value1):
|
||||
return (value0 // value1 if value1 != 0 else 0,)
|
||||
|
||||
|
||||
class IntMod(CustomNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypes:
|
||||
return {
|
||||
"required": {
|
||||
"value0": ("INT", {}),
|
||||
"value1": ("INT", {"default": 1}),
|
||||
}
|
||||
}
|
||||
|
||||
CATEGORY = "arithmetic"
|
||||
RETURN_TYPES = ("INT",)
|
||||
FUNCTION = "execute"
|
||||
|
||||
def execute(self, value0, value1):
|
||||
return (value0 % value1 if value1 != 0 else 0,)
|
||||
|
||||
|
||||
class IntPower(CustomNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypes:
|
||||
return {
|
||||
"required": {
|
||||
"base": ("INT", {}),
|
||||
"exponent": ("INT", {"default": 1}),
|
||||
}
|
||||
}
|
||||
|
||||
CATEGORY = "arithmetic"
|
||||
RETURN_TYPES = ("INT",)
|
||||
FUNCTION = "execute"
|
||||
|
||||
def execute(self, base, exponent):
|
||||
return (pow(base, exponent),)
|
||||
|
||||
|
||||
class FloatMin(CustomNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypes:
|
||||
range_ = {"value0": ("FLOAT", {})}
|
||||
range_.update({f"value{i}": ("FLOAT", {"default": 0.0, "step": 0.01, "round": 0.000001}) for i in range(1, 5)})
|
||||
|
||||
return {
|
||||
"required": range_
|
||||
}
|
||||
|
||||
CATEGORY = "arithmetic"
|
||||
RETURN_TYPES = ("FLOAT",)
|
||||
FUNCTION = "execute"
|
||||
|
||||
def execute(self, **kwargs):
|
||||
return (min(kwargs.values()),)
|
||||
|
||||
|
||||
class FloatMax(CustomNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypes:
|
||||
range_ = {"value0": ("FLOAT", {})}
|
||||
range_.update({f"value{i}": ("FLOAT", {"default": 0.0, "step": 0.01, "round": 0.000001}) for i in range(1, 5)})
|
||||
|
||||
return {
|
||||
"required": range_
|
||||
}
|
||||
|
||||
CATEGORY = "arithmetic"
|
||||
RETURN_TYPES = ("FLOAT",)
|
||||
FUNCTION = "execute"
|
||||
|
||||
def execute(self, **kwargs):
|
||||
return (max(kwargs.values()),)
|
||||
|
||||
|
||||
class FloatAbs(CustomNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypes:
|
||||
return {
|
||||
"required": {
|
||||
"value": ("FLOAT", {})
|
||||
}
|
||||
}
|
||||
|
||||
CATEGORY = "arithmetic"
|
||||
RETURN_TYPES = ("FLOAT",)
|
||||
FUNCTION = "execute"
|
||||
|
||||
def execute(self, value):
|
||||
return (abs(value),)
|
||||
|
||||
|
||||
class FloatAverage(CustomNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypes:
|
||||
range_ = {"value0": ("FLOAT", {})}
|
||||
range_.update({f"value{i}": ("FLOAT", {"default": 0.0, "step": 0.01, "round": 0.000001}) for i in range(1, 5)})
|
||||
|
||||
return {
|
||||
"required": range_
|
||||
}
|
||||
|
||||
CATEGORY = "arithmetic"
|
||||
RETURN_TYPES = ("FLOAT",)
|
||||
FUNCTION = "execute"
|
||||
|
||||
def execute(self, **kwargs):
|
||||
return (sum(kwargs.values()) / len(kwargs),)
|
||||
|
||||
|
||||
class IntMin(CustomNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypes:
|
||||
range_ = {"value0": ("INT", {})}
|
||||
range_.update({f"value{i}": ("INT", {"default": 0}) for i in range(1, 5)})
|
||||
|
||||
return {
|
||||
"required": range_
|
||||
}
|
||||
|
||||
CATEGORY = "arithmetic"
|
||||
RETURN_TYPES = ("INT",)
|
||||
FUNCTION = "execute"
|
||||
|
||||
def execute(self, **kwargs):
|
||||
return (min(kwargs.values()),)
|
||||
|
||||
|
||||
class IntMax(CustomNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypes:
|
||||
range_ = {"value0": ("INT", {})}
|
||||
range_.update({f"value{i}": ("INT", {"default": 0}) for i in range(1, 5)})
|
||||
|
||||
return {
|
||||
"required": range_
|
||||
}
|
||||
|
||||
CATEGORY = "arithmetic"
|
||||
RETURN_TYPES = ("INT",)
|
||||
FUNCTION = "execute"
|
||||
|
||||
def execute(self, **kwargs):
|
||||
return (max(kwargs.values()),)
|
||||
|
||||
|
||||
class IntAbs(CustomNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypes:
|
||||
return {
|
||||
"required": {
|
||||
"value": ("INT", {})
|
||||
}
|
||||
}
|
||||
|
||||
CATEGORY = "arithmetic"
|
||||
RETURN_TYPES = ("INT",)
|
||||
FUNCTION = "execute"
|
||||
|
||||
def execute(self, value):
|
||||
return (abs(value),)
|
||||
|
||||
|
||||
class IntAverage(CustomNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypes:
|
||||
range_ = {"value0": ("INT", {})}
|
||||
range_.update({f"value{i}": ("INT", {"default": 0}) for i in range(1, 5)})
|
||||
|
||||
return {
|
||||
"required": range_
|
||||
}
|
||||
|
||||
CATEGORY = "arithmetic"
|
||||
RETURN_TYPES = ("INT",)
|
||||
FUNCTION = "execute"
|
||||
|
||||
def execute(self, **kwargs):
|
||||
return (sum(kwargs.values()) // len(kwargs),)
|
||||
|
||||
|
||||
class FloatLerp(CustomNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypes:
|
||||
return {
|
||||
"required": {
|
||||
"a": ("FLOAT", {"default": 0.0, "step": 0.01, "round": 0.000001}),
|
||||
"b": ("FLOAT", {"default": 1.0}),
|
||||
"t": ("FLOAT", {}),
|
||||
"clamped": ("BOOLEAN", {"default": True}),
|
||||
}
|
||||
}
|
||||
|
||||
CATEGORY = "arithmetic"
|
||||
RETURN_TYPES = ("FLOAT",)
|
||||
FUNCTION = "execute"
|
||||
|
||||
def execute(self, a, b, t, clamped):
|
||||
value = a + (b - a) * t
|
||||
if clamped:
|
||||
value = min(max(value, a), b)
|
||||
return (value,)
|
||||
|
||||
|
||||
class FloatInverseLerp(CustomNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypes:
|
||||
return {
|
||||
"required": {
|
||||
"a": ("FLOAT", {"default": 0.0, "step": 0.01, "round": 0.000001}),
|
||||
"b": ("FLOAT", {"default": 1.0}),
|
||||
"value": ("FLOAT", {}),
|
||||
"clamped": ("BOOLEAN", {"default": True}),
|
||||
}
|
||||
}
|
||||
|
||||
CATEGORY = "arithmetic"
|
||||
RETURN_TYPES = ("FLOAT",)
|
||||
FUNCTION = "execute"
|
||||
|
||||
def execute(self, a, b, value, clamped):
|
||||
if a == b:
|
||||
return (0.0,)
|
||||
t = (value - a) / (b - a)
|
||||
if clamped:
|
||||
t = min(max(t, 0.0), 1.0)
|
||||
return (t,)
|
||||
|
||||
|
||||
class FloatClamp(CustomNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypes:
|
||||
return {
|
||||
"required": {
|
||||
"value": ("FLOAT", {}),
|
||||
"min": ("FLOAT", {"default": 0.0, "step": 0.01, "round": 0.000001}),
|
||||
"max": ("FLOAT", {"default": 1.0}),
|
||||
}
|
||||
}
|
||||
|
||||
CATEGORY = "arithmetic"
|
||||
RETURN_TYPES = ("FLOAT",)
|
||||
FUNCTION = "execute"
|
||||
|
||||
def execute(self, value: float = 0, **kwargs):
|
||||
v_min: float = kwargs['min']
|
||||
v_max: float = kwargs['max']
|
||||
return (min(max(value, v_min), v_max),)
|
||||
|
||||
|
||||
class IntLerp(CustomNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypes:
|
||||
return {
|
||||
"required": {
|
||||
"a": ("INT", {"default": 0}),
|
||||
"b": ("INT", {"default": 10}),
|
||||
"t": ("FLOAT", {}),
|
||||
"clamped": ("BOOLEAN", {"default": True}),
|
||||
}
|
||||
}
|
||||
|
||||
CATEGORY = "arithmetic"
|
||||
RETURN_TYPES = ("INT",)
|
||||
FUNCTION = "execute"
|
||||
|
||||
def execute(self, a, b, t, clamped):
|
||||
value = int(round(a + (b - a) * t))
|
||||
if clamped:
|
||||
value = min(max(value, a), b)
|
||||
return (value,)
|
||||
|
||||
|
||||
class IntInverseLerp(CustomNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypes:
|
||||
return {
|
||||
"required": {
|
||||
"a": ("INT", {"default": 0}),
|
||||
"b": ("INT", {"default": 10}),
|
||||
"value": ("INT", {}),
|
||||
"clamped": ("BOOLEAN", {"default": True}),
|
||||
}
|
||||
}
|
||||
|
||||
CATEGORY = "arithmetic"
|
||||
RETURN_TYPES = ("FLOAT",)
|
||||
FUNCTION = "execute"
|
||||
|
||||
def execute(self, a, b, value, clamped):
|
||||
if a == b:
|
||||
return (0.0,)
|
||||
t = (value - a) / (b - a)
|
||||
if clamped:
|
||||
t = min(max(t, 0.0), 1.0)
|
||||
return (t,)
|
||||
|
||||
|
||||
class IntClamp(CustomNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypes:
|
||||
return {
|
||||
"required": {
|
||||
"value": ("INT", {}),
|
||||
"min": ("INT", {"default": 0}),
|
||||
"max": ("INT", {"default": 1}),
|
||||
}
|
||||
}
|
||||
|
||||
CATEGORY = "arithmetic"
|
||||
RETURN_TYPES = ("INT",)
|
||||
FUNCTION = "execute"
|
||||
|
||||
def execute(self, value: int = 0, **kwargs):
|
||||
v_min: int = kwargs['min']
|
||||
v_max: int = kwargs['max']
|
||||
|
||||
return (min(max(value, v_min), v_max),)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {}
|
||||
for cls in (
|
||||
FloatAdd,
|
||||
FloatSubtract,
|
||||
FloatMultiply,
|
||||
FloatDivide,
|
||||
FloatPower,
|
||||
FloatMin,
|
||||
FloatMax,
|
||||
FloatAbs,
|
||||
FloatAverage,
|
||||
FloatLerp,
|
||||
FloatInverseLerp,
|
||||
FloatClamp,
|
||||
IntAdd,
|
||||
IntSubtract,
|
||||
IntMultiply,
|
||||
IntDivide,
|
||||
IntMod,
|
||||
IntPower,
|
||||
IntMin,
|
||||
IntMax,
|
||||
IntAbs,
|
||||
IntAverage,
|
||||
IntLerp,
|
||||
IntInverseLerp,
|
||||
IntClamp,
|
||||
):
|
||||
NODE_CLASS_MAPPINGS[cls.__name__] = cls
|
||||
@ -1,4 +1,28 @@
|
||||
#code originally taken from: https://github.com/ChenyangSi/FreeU (under MIT License)
|
||||
"""
|
||||
Portions of this code are adapted from the repository
|
||||
https://github.com/ChenyangSi/FreeU
|
||||
|
||||
MIT License
|
||||
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import logging
|
||||
|
||||
48
comfy_extras/nodes/nodes_image_arithmetic.py
Normal file
48
comfy_extras/nodes/nodes_image_arithmetic.py
Normal file
@ -0,0 +1,48 @@
|
||||
from torch import Tensor
|
||||
|
||||
from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult
|
||||
|
||||
|
||||
class ImageMin(CustomNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypes:
|
||||
return {
|
||||
"required": {
|
||||
"image": ("IMAGE", {})
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("FLOAT",)
|
||||
CATEGORY = "image/postprocessing"
|
||||
FUNCTION = "execute"
|
||||
|
||||
def execute(self, image: Tensor) -> ValidatedNodeResult:
|
||||
return float(image.min().item()),
|
||||
|
||||
|
||||
class ImageMax(CustomNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypes:
|
||||
return {
|
||||
"required": {
|
||||
"image": ("IMAGE", {})
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("FLOAT",)
|
||||
CATEGORY = "image/postprocessing"
|
||||
FUNCTION = "execute"
|
||||
|
||||
def execute(self, image: Tensor) -> ValidatedNodeResult:
|
||||
return float(image.max().item()),
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
ImageMin.__name__: ImageMin,
|
||||
ImageMax.__name__: ImageMax,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
ImageMin.__name__: "Image Minimum Value",
|
||||
ImageMax.__name__: "Image Maximum Value"
|
||||
}
|
||||
@ -16,12 +16,12 @@ import fsspec
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from PIL.PngImagePlugin import PngInfo
|
||||
from fsspec.core import OpenFiles, OpenFile
|
||||
from fsspec.core import OpenFile
|
||||
from fsspec.generic import GenericFileSystem
|
||||
from fsspec.implementations.local import LocalFileSystem
|
||||
from joblib import Parallel, delayed
|
||||
from torch import Tensor
|
||||
from natsort import natsorted
|
||||
from torch import Tensor
|
||||
|
||||
from comfy.cmd import folder_paths
|
||||
from comfy.digest import digest
|
||||
|
||||
@ -97,8 +97,8 @@ class SelfAttentionGuidance:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"scale": ("FLOAT", {"default": 0.5, "min": -2.0, "max": 5.0, "step": 0.1}),
|
||||
"blur_sigma": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 10.0, "step": 0.1}),
|
||||
"scale": ("FLOAT", {"default": 0.5, "min": -2.0, "max": 5.0, "step": 0.01, "round": 0.01}),
|
||||
"blur_sigma": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 10.0, "step": 0.01, "round": 0.01}),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
|
||||
@ -35,4 +35,5 @@ huggingface_hub
|
||||
lazy-object-proxy
|
||||
can_ada
|
||||
fsspec
|
||||
natsort
|
||||
natsort
|
||||
OpenEXR
|
||||
198
tests/nodes/test_arithmetic_unit.py
Normal file
198
tests/nodes/test_arithmetic_unit.py
Normal file
@ -0,0 +1,198 @@
|
||||
import pytest
|
||||
|
||||
from comfy_extras.nodes.nodes_arithmetic import IntAdd, IntSubtract, IntMultiply, IntDivide, IntMod, IntPower, FloatAdd, FloatSubtract, FloatMultiply, FloatDivide, FloatPower, FloatMin, FloatMax, FloatAbs, FloatAverage, IntMin, IntMax, IntAbs, IntAverage, FloatLerp, IntLerp, IntClamp, IntInverseLerp, FloatClamp, FloatInverseLerp
|
||||
|
||||
|
||||
def test_int_add():
|
||||
n = IntAdd()
|
||||
res, = n.execute(value0=1, value1=2, value2=3)
|
||||
assert res == 6
|
||||
|
||||
|
||||
def test_int_subtract():
|
||||
n = IntSubtract()
|
||||
res, = n.execute(value0=10, value1=3)
|
||||
assert res == 7
|
||||
|
||||
|
||||
def test_int_multiply():
|
||||
n = IntMultiply()
|
||||
res, = n.execute(value0=2, value1=3, value2=4)
|
||||
assert res == 24
|
||||
|
||||
|
||||
def test_int_divide():
|
||||
n = IntDivide()
|
||||
res, = n.execute(value0=10, value1=3)
|
||||
assert res == 3
|
||||
|
||||
res, = n.execute(value0=10, value1=0)
|
||||
assert res == 0
|
||||
|
||||
|
||||
def test_int_mod():
|
||||
n = IntMod()
|
||||
res, = n.execute(value0=10, value1=3)
|
||||
assert res == 1
|
||||
|
||||
res, = n.execute(value0=10, value1=0)
|
||||
assert res == 0
|
||||
|
||||
|
||||
def test_int_power():
|
||||
n = IntPower()
|
||||
res, = n.execute(base=2, exponent=3)
|
||||
assert res == 8
|
||||
|
||||
|
||||
def test_float_add():
|
||||
n = FloatAdd()
|
||||
res, = n.execute(value0=1.5, value1=2.3, value2=3.7)
|
||||
assert pytest.approx(res) == 7.5
|
||||
|
||||
|
||||
def test_float_subtract():
|
||||
n = FloatSubtract()
|
||||
res, = n.execute(value0=10.5, value1=3.2)
|
||||
assert pytest.approx(res) == 7.3
|
||||
|
||||
|
||||
def test_float_multiply():
|
||||
n = FloatMultiply()
|
||||
res, = n.execute(value0=2.5, value1=3.0, value2=4.0)
|
||||
assert pytest.approx(res) == 30.0
|
||||
|
||||
|
||||
def test_float_divide():
|
||||
n = FloatDivide()
|
||||
res, = n.execute(value0=10.0, value1=4.0)
|
||||
assert pytest.approx(res) == 2.5
|
||||
|
||||
res, = n.execute(value0=10.0, value1=0.0)
|
||||
assert res == float("inf")
|
||||
|
||||
|
||||
def test_float_power():
|
||||
n = FloatPower()
|
||||
res, = n.execute(base=2.5, exponent=3.0)
|
||||
assert pytest.approx(res) == 15.625
|
||||
|
||||
|
||||
def test_float_min():
|
||||
n = FloatMin()
|
||||
res, = n.execute(value0=1.5, value1=2.3, value2=0.7)
|
||||
assert res == 0.7
|
||||
|
||||
|
||||
def test_float_max():
|
||||
n = FloatMax()
|
||||
res, = n.execute(value0=1.5, value1=2.3, value2=0.7)
|
||||
assert res == 2.3
|
||||
|
||||
|
||||
def test_float_abs():
|
||||
n = FloatAbs()
|
||||
res, = n.execute(value=-3.14)
|
||||
assert res == 3.14
|
||||
|
||||
|
||||
def test_float_average():
|
||||
n = FloatAverage()
|
||||
res, = n.execute(value0=1.5, value1=2.5, value2=3.5)
|
||||
assert res == 2.5
|
||||
|
||||
|
||||
def test_int_min():
|
||||
n = IntMin()
|
||||
res, = n.execute(value0=5, value1=2, value2=7)
|
||||
assert res == 2
|
||||
|
||||
|
||||
def test_int_max():
|
||||
n = IntMax()
|
||||
res, = n.execute(value0=5, value1=2, value2=7)
|
||||
assert res == 7
|
||||
|
||||
|
||||
def test_int_abs():
|
||||
n = IntAbs()
|
||||
res, = n.execute(value=-10)
|
||||
assert res == 10
|
||||
|
||||
|
||||
def test_int_average():
|
||||
n = IntAverage()
|
||||
res, = n.execute(value0=2, value1=4, value2=6)
|
||||
assert res == 4
|
||||
|
||||
|
||||
def test_float_lerp():
|
||||
n = FloatLerp()
|
||||
res, = n.execute(a=0.0, b=1.0, t=0.5, clamped=True)
|
||||
assert res == 0.5
|
||||
|
||||
res, = n.execute(a=0.0, b=1.0, t=1.5, clamped=True)
|
||||
assert res == 1.0
|
||||
|
||||
res, = n.execute(a=0.0, b=1.0, t=1.5, clamped=False)
|
||||
assert res == 1.5
|
||||
|
||||
|
||||
def test_int_lerp():
|
||||
n = IntLerp()
|
||||
res, = n.execute(a=0, b=10, t=0.5, clamped=True)
|
||||
assert res == 5
|
||||
|
||||
res, = n.execute(a=0, b=10, t=1.5, clamped=True)
|
||||
assert res == 10
|
||||
|
||||
res, = n.execute(a=0, b=10, t=1.5, clamped=False)
|
||||
assert res == 15
|
||||
|
||||
|
||||
def test_float_inverse_lerp():
|
||||
n = FloatInverseLerp()
|
||||
res, = n.execute(a=0.0, b=1.0, value=0.5, clamped=True)
|
||||
assert res == 0.5
|
||||
|
||||
res, = n.execute(a=0.0, b=1.0, value=1.5, clamped=True)
|
||||
assert res == 1.0
|
||||
|
||||
res, = n.execute(a=0.0, b=1.0, value=1.5, clamped=False)
|
||||
assert res == 1.5
|
||||
|
||||
|
||||
def test_float_clamp():
|
||||
n = FloatClamp()
|
||||
res, = n.execute(value=0.5, min=0.0, max=1.0)
|
||||
assert res == 0.5
|
||||
|
||||
res, = n.execute(value=1.5, min=0.0, max=1.0)
|
||||
assert res == 1.0
|
||||
|
||||
res, = n.execute(value=-0.5, min=0.0, max=1.0)
|
||||
assert res == 0.0
|
||||
|
||||
|
||||
def test_int_inverse_lerp():
|
||||
n = IntInverseLerp()
|
||||
res, = n.execute(a=0, b=10, value=5, clamped=True)
|
||||
assert res == 0.5
|
||||
|
||||
res, = n.execute(a=0, b=10, value=15, clamped=True)
|
||||
assert res == 1.0
|
||||
|
||||
res, = n.execute(a=0, b=10, value=15, clamped=False)
|
||||
assert res == 1.5
|
||||
|
||||
|
||||
def test_int_clamp():
|
||||
n = IntClamp()
|
||||
res, = n.execute(value=5, min=0, max=10)
|
||||
assert res == 5
|
||||
|
||||
res, = n.execute(value=15, min=0, max=10)
|
||||
assert res == 10
|
||||
|
||||
res, = n.execute(value=-5, min=0, max=10)
|
||||
assert res == 0
|
||||
49
tests/nodes/test_colormap_unit.py
Normal file
49
tests/nodes/test_colormap_unit.py
Normal file
@ -0,0 +1,49 @@
|
||||
import pytest
|
||||
import torch
|
||||
from comfy_extras.nodes.nodes_apply_color_map import ImageApplyColorMap
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def input_image():
|
||||
# Create a 1x1x2x1 tensor representing an image with absolute distances of 1.3 meters and 300 meters
|
||||
return torch.tensor([[[[1.3], [300.0]]]], dtype=torch.float32)
|
||||
|
||||
|
||||
def test_apply_colormap_grayscale(input_image):
|
||||
node = ImageApplyColorMap()
|
||||
colored_image, = node.execute(image=input_image, colormap="Grayscale", min_depth=1.3, max_depth=300.0)
|
||||
|
||||
assert colored_image.shape == (1, 1, 2, 3)
|
||||
assert colored_image.dtype == torch.float32
|
||||
assert torch.allclose(colored_image[0, 0, 0], torch.tensor([1.0, 1.0, 1.0]))
|
||||
assert torch.allclose(colored_image[0, 0, 1], torch.tensor([0.0, 0.0, 0.0]))
|
||||
|
||||
|
||||
def test_apply_colormap_inferno(input_image):
|
||||
node = ImageApplyColorMap()
|
||||
colored_image, = node.execute(image=input_image, colormap="COLORMAP_INFERNO", min_depth=1.3, max_depth=300.0)
|
||||
|
||||
assert colored_image.shape == (1, 1, 2, 3)
|
||||
assert colored_image.dtype == torch.float32
|
||||
assert torch.allclose(colored_image[0, 0, 0], torch.tensor([0.9882, 1.000, 0.6431]), atol=1e-4)
|
||||
assert torch.allclose(colored_image[0, 0, 1], torch.tensor([0.0000, 0.0000, 0.0157]), atol=1e-4)
|
||||
|
||||
|
||||
def test_apply_colormap_clipping(input_image):
|
||||
node = ImageApplyColorMap()
|
||||
|
||||
colored_image, = node.execute(image=input_image, colormap="COLORMAP_INFERNO", clip_min=False, clip_max=False, min_depth=1.3, max_depth=300.0)
|
||||
assert torch.allclose(colored_image[0, 0, 0], torch.tensor([0.0, 0.0, 0.0157]), atol=1e-4)
|
||||
assert torch.allclose(colored_image[0, 0, 1], torch.tensor([0.0, 0.0, 0.0157]), atol=1e-4)
|
||||
|
||||
colored_image, = node.execute(image=input_image, colormap="COLORMAP_INFERNO", clip_min=True, clip_max=False, min_depth=1.3, max_depth=300.0)
|
||||
assert torch.allclose(colored_image[0, 0, 0], torch.tensor([0.9882, 1.0000, 0.6431]), atol=1e-4)
|
||||
assert torch.allclose(colored_image[0, 0, 1], torch.tensor([0.0000, 0.0000, 0.0157]), atol=1e-4)
|
||||
|
||||
colored_image, = node.execute(image=input_image, colormap="COLORMAP_INFERNO", clip_min=False, clip_max=True, min_depth=1.3, max_depth=200.0)
|
||||
assert torch.allclose(colored_image[0, 0, 0], torch.tensor([0.0, 0.0, 0.0157]), atol=1e-4)
|
||||
assert torch.allclose(colored_image[0, 0, 1], torch.tensor([0.0, 0.0, 0.0157]), atol=1e-4)
|
||||
|
||||
colored_image, = node.execute(image=input_image, colormap="COLORMAP_INFERNO", clip_min=True, clip_max=True, min_depth=1.3, max_depth=200.0)
|
||||
assert torch.allclose(colored_image[0, 0, 0], torch.tensor([0.9882, 1.0000, 0.6431]), atol=1e-4)
|
||||
assert torch.allclose(colored_image[0, 0, 1], torch.tensor([0.0000, 0.0000, 0.0157]), atol=1e-4)
|
||||
Loading…
Reference in New Issue
Block a user