Improved support for ControlNet workflows with depth

- ComfyUI can now load EXR files.
 - There are new arithmetic nodes for floats and integers.
 - EXR nodes can load depth maps, which can then be remapped with
   ImageApplyColorMap. This allows end users to use ground truth depth
   data from video game engines or 3D graphics tools and recolor it to
   the format expected by depth ControlNets: grayscale inverse depth
   maps and "inferno" colored inverse depth maps.
 - Fixed license notes.
 - Added an additional known ControlNet model.
 - Because CV2 is now used to read OpenEXR files, an environment
   variable must be set early on in the application, before CV2 is
   imported. This file, main_pre, is now imported early on in more
   places.
This commit is contained in:
doctorpangloss 2024-03-26 22:32:15 -07:00
parent d8846fcb39
commit b0be335d59
24 changed files with 1157 additions and 59 deletions

View File

@ -100,7 +100,7 @@ class EmbeddedComfyClient:
if self._configuration is None:
options.enable_args_parsing()
else:
from ..cli_args import args
from ..cmd.main_pre import args
args.clear()
args.update(self._configuration)

View File

@ -1,4 +1,3 @@
# Suppress warnings during import
import asyncio
import gc
import itertools
@ -8,9 +7,10 @@ import shutil
import threading
import time
# main_pre must be the earliest import since it suppresses some spurious warnings
from .main_pre import args
from ..utils import hijack_progress
from .extra_model_paths import load_extra_path_config
from .main_pre import args
from .. import model_management
from ..analytics.analytics import initialize_event_tracking
from ..cmd import cuda_malloc

View File

@ -1,3 +1,12 @@
"""
This should be imported before entrypoints to correctly configure global options prior to importing packages like torch and cv2.
Use this instead of cli_args to import the args:
>>> from comfy.cmd.main_pre import args
It will enable command line argument parsing. If this isn't desired, you must author your own implementation of these fixes.
"""
import os
from .. import options
@ -9,6 +18,8 @@ options.enable_args_parsing()
if os.name == "nt":
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
warnings.filterwarnings("ignore", message="torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.")
warnings.filterwarnings("ignore", message="Torch was not compiled with flash attention.")
warnings.filterwarnings("ignore", message=".*Torch was not compiled with flash attention.*")
from ..cli_args import args
@ -20,4 +31,5 @@ if args.deterministic:
if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ:
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
__all__ = ["args"]

View File

@ -38,6 +38,7 @@ from ..component_model.file_output_path import file_output_path
from ..component_model.queue_types import QueueItem, HistoryEntry, BinaryEventTypes, TaskInvocation
from ..digest import digest
from ..nodes.package_typing import ExportedNodes
from ..images import open_image
class HeuristicPath(NamedTuple):
@ -289,9 +290,12 @@ class PromptServer(ExecutorToClientProgress):
return web.Response(status=400)
if os.path.isfile(file):
if 'preview' in request.rel_url.query:
with Image.open(file) as img:
preview_info = request.rel_url.query['preview'].split(';')
# todo: any image file we upload that browsers don't support, we should encode a preview
# todo: image handling has to be a little bit more standardized, sometimes we want a Pillow Image, sometimes
# we want something that will render to the user, sometimes we want tensors
if 'preview' in request.rel_url.query or file.endswith(".exr"):
with open_image(file) as img:
preview_info = request.rel_url.query.get("preview", "jpeg;90").split(';')
image_format = preview_info[0]
if image_format not in ['webp', 'jpeg'] or 'a' in request.rel_url.query.get('channel', ''):
image_format = 'webp'

View File

@ -4,11 +4,7 @@ import os
import logging
from .extra_model_paths import load_extra_path_config
from .. import options
options.enable_args_parsing()
from ..cli_args import args
from .main_pre import args
async def main():
@ -17,28 +13,19 @@ async def main():
args.distributed_queue_frontend = False
assert args.distributed_queue_connection_uri is not None, "Set the --distributed-queue-connection-uri argument to your RabbitMQ server"
if args.cuda_device is not None:
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
logging.info(f"Set cuda device to: {args.cuda_device}")
if args.deterministic:
if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ:
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
# configure paths
if args.output_directory:
output_dir = os.path.abspath(args.output_directory)
logging.info(f"Setting output directory to: {output_dir}")
from ..cmd import folder_paths
folder_paths.set_output_directory(output_dir)
if args.input_directory:
input_dir = os.path.abspath(args.input_directory)
logging.info(f"Setting input directory to: {input_dir}")
from ..cmd import folder_paths
folder_paths.set_input_directory(input_dir)
if args.temp_directory:

View File

@ -0,0 +1,8 @@
from typing import NamedTuple
from torch import Tensor
class RgbMaskTuple(NamedTuple):
    """Pair of tensors returned together by image loaders: color data plus its mask."""
    # color data; load_exr builds this as a float32 tensor of shape (1, H, W, 3)
    # NOTE(review): confirm other producers use the same layout
    rgb: Tensor
    # mask; load_exr builds this as a float32 tensor of shape (1, H, W)
    mask: Tensor

19
comfy/images.py Normal file
View File

@ -0,0 +1,19 @@
import os.path
from contextlib import contextmanager
import cv2
from PIL import Image
def _open_exr(exr_path) -> Image.Image:
    """Read an OpenEXR file with OpenCV and return it as an RGB Pillow image.

    :param exr_path: path of the ``.exr`` file to read
    :return: a Pillow image whose channels are in R, G, B order
    :raises OSError: when OpenCV cannot read or decode the file
    """
    bgr = cv2.imread(exr_path, cv2.IMREAD_COLOR)
    if bgr is None:
        # cv2.imread signals failure by returning None rather than raising
        raise OSError(f"Could not read EXR file: {exr_path}")
    # OpenCV stores color channels as B, G, R while Pillow expects R, G, B
    return Image.fromarray(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
@contextmanager
def open_image(file_path: str) -> Image.Image:
    """Context manager that yields a Pillow image for ``file_path``.

    Files with an ``.exr`` extension (compared case-insensitively) are decoded
    through OpenCV via ``_open_exr``; everything else is opened with
    ``Image.open`` and closed again when the ``with`` block exits.

    :param file_path: path of the image to open
    :return: when used in a ``with`` statement, the opened ``Image.Image``
    """
    _, ext = os.path.splitext(file_path)
    # lower() so that "FILE.EXR" is not handed to Pillow, which cannot decode EXR
    if ext.lower() == ".exr":
        yield _open_exr(file_path)
    else:
        with Image.open(file_path) as image:
            yield image

View File

@ -755,7 +755,32 @@ def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, n
@torch.no_grad()
def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
# From MIT licensed: https://github.com/Carzit/sd-webui-samplers-scheduler/
"""
Portions of this function are adapted from the repository
https://github.com/Carzit/sd-webui-samplers-scheduler
MIT License
Copyright (c) 2023 Carzit
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
s_end = sigmas[-1]

View File

@ -51,7 +51,7 @@ def get_or_download(folder_name: str, filename: str, known_files: List[HuggingFi
else:
linked_filename = None
try:
os.symlink(os.path.join(destination,known_file.filename), linked_filename)
os.symlink(os.path.join(destination, known_file.filename), linked_filename)
except Exception as exc_info:
logging.error(f"Failed to link file with alternative download save name in a way that is compatible with Hugging Face caching {repr(known_file)}", exc_info=exc_info)
else:
@ -213,6 +213,7 @@ KNOWN_CONTROLNETS = [
HuggingFile("lllyasviel/sd_control_collection", "sargezt_xl_depth_faid_vidit.safetensors"),
HuggingFile("lllyasviel/sd_control_collection", "sargezt_xl_depth_zeed.safetensors"),
HuggingFile("lllyasviel/sd_control_collection", "sargezt_xl_softedge.safetensors"),
HuggingFile("SargeZT/controlnet-sd-xl-1.0-depth-16bit-zoe", "depth-zoe-xl-v1.0-controlnet.safetensors"),
HuggingFile("lllyasviel/sd_control_collection", "t2i-adapter_diffusers_xl_canny.safetensors"),
HuggingFile("lllyasviel/sd_control_collection", "t2i-adapter_diffusers_xl_depth_midas.safetensors"),
HuggingFile("lllyasviel/sd_control_collection", "t2i-adapter_diffusers_xl_depth_zoe.safetensors"),

View File

@ -2,7 +2,7 @@ from __future__ import annotations
import dataclasses
from os.path import split
from typing import Optional, List
from typing import Optional, List, Sequence
from typing_extensions import TypedDict, NotRequired
@ -15,10 +15,12 @@ class CivitFile:
model_id (int): The ID of the model
model_version_id (int): The version
filename (str): The name of the file in the model
trigger_words (List[str]): Trigger words associated with the model
"""
model_id: int
model_version_id: int
filename: str
trigger_words: Optional[Sequence[str]] = dataclasses.field(default_factory=tuple)
def __str__(self):
return self.filename

View File

@ -9,6 +9,7 @@ import logging
from PIL import Image, ImageOps, ImageSequence
from PIL.PngImagePlugin import PngInfo
from natsort import natsorted
from pkg_resources import resource_filename
import numpy as np
import safetensors.torch
@ -23,10 +24,13 @@ from .. import model_management
from ..cli_args import args
from ..cmd import folder_paths, latent_preview
from ..images import open_image
from ..model_downloader import get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, \
KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS, KNOWN_LORAS, KNOWN_CONTROLNETS, KNOWN_DIFF_CONTROLNETS
from ..nodes.common import MAX_RESOLUTION
from .. import controlnet
from ..open_exr import load_exr
class CLIPTextEncode:
@classmethod
@ -1454,38 +1458,49 @@ class PreviewImage(SaveImage):
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
class LoadImage:
@classmethod
def INPUT_TYPES(s):
input_dir = folder_paths.get_input_directory()
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
return {"required":
{"image": (sorted(files), {"image_upload": True})},
}
return {
"required": {
"image": (natsorted(files), {"image_upload": True}),
},
}
CATEGORY = "image"
RETURN_TYPES = ("IMAGE", "MASK")
FUNCTION = "load_image"
def load_image(self, image):
def load_image(self, image: str):
image_path = folder_paths.get_annotated_filepath(image)
img = Image.open(image_path)
output_images = []
output_masks = []
for i in ImageSequence.Iterator(img):
i = ImageOps.exif_transpose(i)
if i.mode == 'I':
i = i.point(lambda i: i * (1 / 255))
image = i.convert("RGB")
image = np.array(image).astype(np.float32) / 255.0
image = torch.from_numpy(image)[None,]
if 'A' in i.getbands():
mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
mask = 1. - torch.from_numpy(mask)
else:
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
output_images.append(image)
output_masks.append(mask.unsqueeze(0))
# maintain the legacy path
# this will ultimately return a tensor, so we'd rather have the tensors directly
# from cv2 rather than get them out of a PIL image
_, ext = os.path.splitext(image)
if ext == ".exr":
return load_exr(image_path, srgb=False)
with open_image(image_path) as img:
for i in ImageSequence.Iterator(img):
i = ImageOps.exif_transpose(i)
if i.mode == 'I':
i = i.point(lambda i: i * (1 / 255))
image = i.convert("RGB")
image = np.array(image).astype(np.float32) / 255.0
image = torch.from_numpy(image)[None,]
if 'A' in i.getbands():
mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
mask = 1. - torch.from_numpy(mask)
else:
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
output_images.append(image)
output_masks.append(mask.unsqueeze(0))
if len(output_images) > 1:
output_image = torch.cat(output_images, dim=0)
@ -1494,7 +1509,7 @@ class LoadImage:
output_image = output_images[0]
output_mask = output_masks[0]
return (output_image, output_mask)
return output_image, output_mask
@classmethod
def IS_CHANGED(s, image):

View File

@ -53,7 +53,7 @@ BooleanSpec = Tuple[Literal["BOOLEAN"], BoolSpecOptions]
ChoiceSpec = Tuple[Union[Sequence[str], Sequence[float], Sequence[int]]]
NonPrimitiveTypeSpec = Tuple[CommonReturnTypes]
NonPrimitiveTypeSpec = Tuple[CommonReturnTypes, Any]
InputTypeSpec = Union[IntSpec, FloatSpec, StringSpec, BooleanSpec, ChoiceSpec, NonPrimitiveTypeSpec]

86
comfy/open_exr.py Normal file
View File

@ -0,0 +1,86 @@
"""
Portions of this code are adapted from the repository
https://github.com/spacepxl/ComfyUI-HQ-Image-Save
MIT License
Copyright (c) 2023 spacepxl
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
import copy
from typing import Sequence, Tuple
import cv2 as cv
import numpy as np
import torch
from torch import Tensor
from .component_model.images_types import RgbMaskTuple
def mut_srgb_to_linear(np_array) -> None:
    """Convert sRGB-encoded values to linear light, in place.

    Values at or below 0.0404482362771082 use the linear segment (divide by
    12.92); the rest use the 2.4 power curve. Nothing is returned; the array
    passed in is modified.
    """
    linear_segment = np_array <= 0.0404482362771082
    power_segment = ~linear_segment
    np_array[linear_segment] = np_array[linear_segment] / 12.92
    shifted = (np_array[power_segment] + 0.055) / 1.055
    np_array[power_segment] = np.power(shifted, 2.4)
def mut_linear_to_srgb(np_array) -> None:
    """Encode linear-light values as sRGB, in place.

    Values at or below 0.0031308 use the linear segment (multiply by 12.92);
    the rest use the 1/2.4 power curve. Nothing is returned; the array passed
    in is modified.
    """
    linear_segment = np_array <= 0.0031308
    power_segment = ~linear_segment
    np_array[linear_segment] = np_array[linear_segment] * 12.92
    curved = np.power(np_array[power_segment], 1 / 2.4)
    np_array[power_segment] = curved * 1.055 - 0.055
def load_exr(file_path: str, srgb: bool) -> RgbMaskTuple:
    """Load an EXR file as float32 color and mask tensors.

    :param file_path: path of the EXR file
    :param srgb: when true, the linear pixel values are sRGB-encoded and then
        clipped to [0, 1]; when false they are returned unmodified
    :return: ``RgbMaskTuple(rgb, mask)`` where ``rgb`` has shape (1, H, W, 3)
        and ``mask`` has shape (1, H, W). The mask is the fourth channel
        clipped to [0, 1] when the file has one, otherwise all zeros.
    :raises OSError: when OpenCV cannot read or decode the file
    """
    raw = cv.imread(file_path, cv.IMREAD_UNCHANGED)
    if raw is None:
        # cv.imread signals failure by returning None rather than raising
        raise OSError(f"Could not read EXR file: {file_path}")
    image = raw.astype(np.float32)
    if image.ndim == 2:
        # single-channel files (e.g. depth-only renders) decode to (H, W);
        # replicate the channel so the color slicing below still applies
        image = np.repeat(image[:, :, np.newaxis], 3, axis=2)
    # OpenCV returns B, G, R; flip the channel axis to get R, G, B
    rgb = np.flip(image[:, :, :3], 2).copy()
    if srgb:
        mut_linear_to_srgb(rgb)
        rgb = np.clip(rgb, 0, 1)
    rgb = torch.unsqueeze(torch.from_numpy(rgb), 0)
    mask = torch.zeros((1, image.shape[0], image.shape[1]), dtype=torch.float32)
    if image.shape[2] > 3:
        mask[0] = torch.from_numpy(np.clip(image[:, :, 3], 0, 1))
    return RgbMaskTuple(rgb, mask)
def load_exr_latent(file_path: str) -> Tuple[Tensor]:
    """Load a four-channel EXR file as a channels-first float32 tensor.

    :param file_path: path of the EXR file; it must decode to four channels
    :return: a 1-tuple holding a tensor of shape (1, 4, H, W) with the first
        three channels reordered from OpenCV's B, G, R to R, G, B
    :raises OSError: when OpenCV cannot read or decode the file
    """
    raw = cv.imread(file_path, cv.IMREAD_UNCHANGED)
    if raw is None:
        # cv.imread signals failure by returning None rather than raising
        raise OSError(f"Could not read EXR file: {file_path}")
    image = raw.astype(np.float32)
    # swap channels 0 and 2, keep channel 3 in place
    image = image[:, :, np.array([2, 1, 0, 3])]
    image = torch.unsqueeze(torch.from_numpy(image), 0)
    # (1, H, W, C) -> (1, C, H, W)
    image = torch.movedim(image, -1, 1)
    return image,
def save_exr(images: Tensor, filepaths_batched: Sequence[str], colorspace="linear"):
    """Write each image of a batch to its own EXR file.

    :param images: tensor indexed as (batch, height, width, channels) with
        channels in R, G, B[, A] order
    :param filepaths_batched: one output path per image in the batch
    :param colorspace: when ``"linear"``, the color channels are converted from
        sRGB to linear before writing; any other value writes them as-is
    """
    linear = images.detach().clone().cpu().numpy().astype(np.float32)
    if colorspace == "linear":
        # basic slicing yields a view, so this mutates `linear` itself
        mut_srgb_to_linear(linear[:, :, :, :3])  # only convert RGB, not Alpha
    bgr = linear.copy()
    bgr[:, :, :, 0] = linear[:, :, :, 2]  # flip RGB to BGR for opencv
    bgr[:, :, :, 2] = linear[:, :, :, 0]
    if bgr.shape[-1] > 3:
        bgr[:, :, :, 3] = np.clip(1 - linear[:, :, :, 3], 0, 1)  # invert alpha
    # iterate over the batch dimension (previously len() was applied to an int)
    for i, frame in enumerate(bgr):
        cv.imwrite(filepaths_batched[i], frame)

View File

@ -334,7 +334,7 @@ def get_attr(obj, attr):
def bislerp(samples, width, height):
def slerp(b1, b2, r):
'''slerps batches b1, b2 according to ratio r, batches should be flat e.g. NxC'''
c = b1.shape[-1]
#norms
@ -359,16 +359,16 @@ def bislerp(samples, width, height):
res *= (b1_norms * (1.0-r) + b2_norms * r).expand(-1,c)
#edge cases for same or polar opposites
res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5]
res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5]
res[dot < 1e-5 - 1] = (b1 * (1.0-r) + b2 * r)[dot < 1e-5 - 1]
return res
def generate_bilinear_data(length_old, length_new, device):
coords_1 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1,1,1,-1))
coords_1 = torch.nn.functional.interpolate(coords_1, size=(1, length_new), mode="bilinear")
ratios = coords_1 - coords_1.floor()
coords_1 = coords_1.to(torch.int64)
coords_2 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1,1,1,-1)) + 1
coords_2[:,:,:,-1] -= 1
coords_2 = torch.nn.functional.interpolate(coords_2, size=(1, length_new), mode="bilinear")
@ -379,7 +379,7 @@ def bislerp(samples, width, height):
samples = samples.float()
n,c,h,w = samples.shape
h_new, w_new = (height, width)
#linear w
ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new, samples.device)
coords_1 = coords_1.expand((n, c, h, -1))
@ -496,6 +496,17 @@ def set_progress_bar_global_hook(function):
PROGRESS_BAR_HOOK = function
class _DisabledProgressBar:
    """No-op progress bar: every method accepts any arguments and does nothing."""
    def __init__(self, *args, **kwargs):
        # arguments are accepted and ignored
        pass
    def update(self, *args, **kwargs):
        # intentionally empty
        pass
    def update_absolute(self, *args, **kwargs):
        # intentionally empty
        pass
class ProgressBar:
def __init__(self, total: float):
global PROGRESS_BAR_HOOK
@ -545,3 +556,12 @@ def comfy_tqdm():
# Restore original tqdm
tqdm.__init__ = _original_init
tqdm.update = _original_update
@contextmanager
def comfy_progress(total: float) -> ProgressBar:
    """Context manager yielding a progress bar for ``total`` units of work.

    Yields a real ``ProgressBar`` when the module-level ``PROGRESS_BAR_ENABLED``
    flag is truthy, otherwise a ``_DisabledProgressBar`` whose methods do nothing.
    """
    bar = ProgressBar(total) if PROGRESS_BAR_ENABLED else _DisabledProgressBar()
    yield bar

View File

@ -469,7 +469,7 @@ export const ComfyWidgets = {
const fileInput = document.createElement("input");
Object.assign(fileInput, {
type: "file",
accept: "image/jpeg,image/png,image/webp",
accept: "image/jpeg,image/png,image/webp,image/x-exr,.exr",
style: "display: none",
onchange: async () => {
if (fileInput.files.length) {

View File

@ -0,0 +1,85 @@
import cv2
import numpy as np
import torch
from torch import Tensor
from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult
# Choices for the "colormap" input: the literal "Grayscale" followed by every
# attribute name of the cv2 module that starts with "COLORMAP".
_available_colormaps = ["Grayscale"] + [attr for attr in dir(cv2) if attr.startswith('COLORMAP')]
class ImageApplyColorMap(CustomNode):
    """Turns absolute-distance depth images into normalized inverse-depth images, optionally colorized with an OpenCV colormap."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "image": ("IMAGE", {}),
                "colormap": (_available_colormaps, {"default": "COLORMAP_INFERNO"}),
                "gamma": ("FLOAT", {"default": 1.0, "min": 0.001, "step": 0.001, "round": 0.001}),
                "min_depth": ("FLOAT", {"default": 0.001, "min": 0.001, "round": 0.00001, "step": 0.001}),
                "max_depth": ("FLOAT", {"default": 1e2, "round": 0.00001, "step": 0.1}),
                "one_minus": ("BOOLEAN", {"default": False}),
                "clip_min": ("BOOLEAN", {"default": True}),
                "clip_max": ("BOOLEAN", {"default": False}),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    CATEGORY = "image/postprocessing"
    FUNCTION = "execute"

    def execute(self,
                image: Tensor,
                gamma: float = 1.0,
                min_depth: float = 0.001,
                max_depth: float = 1e3,
                colormap: str = "COLORMAP_INFERNO",
                one_minus: bool = False,
                clip_min: bool = True,
                clip_max: bool = False,
                ) -> ValidatedNodeResult:
        """
        Invert and apply a colormap to a batch of absolute distance depth images.

        For Zoe and Midas, set colormap to be `COLORMAP_INFERNO`. Diffusers Depth expects `Grayscale`.

        As per https://huggingface.co/SargeZT/controlnet-v1e-sdxl-depth/discussions/7, some ControlNet checkpoints
        expect one_minus to be true.

        :param image: batch indexed as (batch, height, width, channel); only channel 0 is read as depth
        :param gamma: depth is raised to the power ``1 / gamma`` before inversion
        :param min_depth: depths at or below this are clamped to it (``clip_min``) or mapped to 0 in the output
        :param max_depth: upper end of the depth range; depths at or above it are clamped when ``clip_max`` is set.
            NOTE(review): this default (1e3) differs from the INPUT_TYPES default (1e2) — confirm which is intended
        :param colormap: ``"Grayscale"`` or the name of a ``cv2.COLORMAP_*`` constant
        :param one_minus: invert the normalized 8-bit values (``255 - v``) before coloring
        :param clip_min: see ``min_depth``
        :param clip_max: see ``max_depth``
        :return: 1-tuple with a float tensor of shape (batch, height, width, 3) scaled to [0, 1]
        """
        # interpolation endpoints are the same for every image in the batch
        xp = [1.0 / max_depth, 1.0 / min_depth]
        fp = [0, 1]
        cv2_colormap = None if colormap == "Grayscale" else getattr(cv2, colormap)

        colored_images = []
        for i in range(image.shape[0]):
            # .cpu() so tensors living on an accelerator can be converted to numpy
            depth_image = image[i, :, :, 0].cpu().numpy()
            depth_image = np.where(depth_image <= min_depth, np.nan if not clip_min else min_depth, depth_image)
            if clip_max:
                depth_image = np.where(depth_image >= max_depth, max_depth, depth_image)
            depth_image = np.power(depth_image, 1.0 / gamma)
            inv_depth_image = 1.0 / depth_image

            normalized_depth = np.interp(inv_depth_image, xp, fp, left=0, right=1)
            # pixels dropped above (NaN) become 0
            normalized_depth = np.nan_to_num(normalized_depth, nan=0)
            normalized_depth_uint8 = (normalized_depth * 255).astype(np.uint8)
            if one_minus:
                normalized_depth_uint8 = 255 - normalized_depth_uint8

            if cv2_colormap is None:
                # single-channel data: expand to three equal channels; BGR2RGB
                # cannot be applied to a one-channel image
                colored_image_rgb = cv2.cvtColor(normalized_depth_uint8, cv2.COLOR_GRAY2RGB)
            else:
                colored_image = cv2.applyColorMap(normalized_depth_uint8, cv2_colormap)
                colored_image_rgb = cv2.cvtColor(colored_image, cv2.COLOR_BGR2RGB)
            rgb_tensor = torch.tensor(colored_image_rgb) * 1.0 / 255.0
            colored_images.append(rgb_tensor)

        return torch.stack(colored_images),
# Maps the node's class name to the class itself.
# NOTE(review): presumably read by the node loader to register the node — confirm.
NODE_CLASS_MAPPINGS = {
    ImageApplyColorMap.__name__: ImageApplyColorMap,
}
# Human-readable title for the node, keyed by the same class name.
NODE_DISPLAY_NAME_MAPPINGS = {
    ImageApplyColorMap.__name__: "Apply ColorMap to Image (CV2)",
}

View File

@ -0,0 +1,514 @@
from functools import reduce
from operator import add, mul, pow
from comfy.nodes.package_typing import CustomNode, InputTypes
class FloatAdd(CustomNode):
    """Node that adds the float inputs value0 .. value4."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        # value1..value4 default to 0.0 so unused inputs do not change the sum
        extra_inputs = {
            f"value{i}": ("FLOAT", {"default": 0.0, "step": 0.01, "round": 0.000001})
            for i in range(1, 5)
        }
        return {"required": {"value0": ("FLOAT", {}), **extra_inputs}}

    CATEGORY = "arithmetic"
    RETURN_TYPES = ("FLOAT",)
    FUNCTION = "execute"

    def execute(self, **kwargs):
        """Return a 1-tuple with the sum of all keyword values, accumulated from 0.0."""
        total = 0.0
        for term in kwargs.values():
            total = total + term
        return (total,)
class FloatSubtract(CustomNode):
    """Node that subtracts one float from another."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        required = {
            "value0": ("FLOAT", {}),
            "value1": ("FLOAT", {"default": 0.0, "step": 0.01, "round": 0.000001}),
        }
        return {"required": required}

    CATEGORY = "arithmetic"
    RETURN_TYPES = ("FLOAT",)
    FUNCTION = "execute"

    def execute(self, value0, value1):
        """Return a 1-tuple with ``value0 - value1``."""
        difference = value0 - value1
        return (difference,)
class FloatMultiply(CustomNode):
    """Node that multiplies the float inputs value0 .. value4."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        # value1..value4 default to 1.0 so unused inputs do not change the product
        extra_inputs = {f"value{i}": ("FLOAT", {"default": 1.0}) for i in range(1, 5)}
        return {"required": {"value0": ("FLOAT", {}), **extra_inputs}}

    CATEGORY = "arithmetic"
    RETURN_TYPES = ("FLOAT",)
    FUNCTION = "execute"

    def execute(self, **kwargs):
        """Return a 1-tuple with the product of all keyword values, accumulated from 1.0."""
        product = 1.0
        for factor in kwargs.values():
            product = product * factor
        return (product,)
class FloatDivide(CustomNode):
    """Node that divides one float by another."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        required = {
            "value0": ("FLOAT", {}),
            "value1": ("FLOAT", {"default": 1.0}),
        }
        return {"required": required}

    CATEGORY = "arithmetic"
    RETURN_TYPES = ("FLOAT",)
    FUNCTION = "execute"

    def execute(self, value0, value1):
        """Return a 1-tuple with ``value0 / value1``; a zero divisor yields positive infinity."""
        if value1 == 0:
            return (float("inf"),)
        return (value0 / value1,)
class FloatPower(CustomNode):
    """Node that raises a float base to a float exponent."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        required = {
            "base": ("FLOAT", {}),
            "exponent": ("FLOAT", {"default": 1.0}),
        }
        return {"required": required}

    CATEGORY = "arithmetic"
    RETURN_TYPES = ("FLOAT",)
    FUNCTION = "execute"

    def execute(self, base, exponent):
        """Return a 1-tuple with ``base ** exponent``."""
        result = base ** exponent
        return (result,)
class IntAdd(CustomNode):
    """Node that adds the integer inputs value0 .. value4."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        # value1..value4 default to 0 so unused inputs do not change the sum
        extra_inputs = {f"value{i}": ("INT", {"default": 0}) for i in range(1, 5)}
        return {"required": {"value0": ("INT", {}), **extra_inputs}}

    CATEGORY = "arithmetic"
    RETURN_TYPES = ("INT",)
    FUNCTION = "execute"

    def execute(self, **kwargs):
        """Return a 1-tuple with the sum of all keyword values, accumulated from 0."""
        total = 0
        for term in kwargs.values():
            total = total + term
        return (total,)
class IntSubtract(CustomNode):
    """Node that subtracts one integer from another."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        required = {
            "value0": ("INT", {}),
            "value1": ("INT", {"default": 0}),
        }
        return {"required": required}

    CATEGORY = "arithmetic"
    RETURN_TYPES = ("INT",)
    FUNCTION = "execute"

    def execute(self, value0, value1):
        """Return a 1-tuple with ``value0 - value1``."""
        difference = value0 - value1
        return (difference,)
class IntMultiply(CustomNode):
    """Node that multiplies the integer inputs value0 .. value4."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        # value1..value4 default to 1 so unused inputs do not change the product
        extra_inputs = {f"value{i}": ("INT", {"default": 1}) for i in range(1, 5)}
        return {"required": {"value0": ("INT", {}), **extra_inputs}}

    CATEGORY = "arithmetic"
    RETURN_TYPES = ("INT",)
    FUNCTION = "execute"

    def execute(self, **kwargs):
        """Return a 1-tuple with the product of all keyword values, accumulated from 1."""
        product = 1
        for factor in kwargs.values():
            product = product * factor
        return (product,)
class IntDivide(CustomNode):
    """Node that floor-divides one integer by another."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        required = {
            "value0": ("INT", {}),
            "value1": ("INT", {"default": 1}),
        }
        return {"required": required}

    CATEGORY = "arithmetic"
    RETURN_TYPES = ("INT",)
    FUNCTION = "execute"

    def execute(self, value0, value1):
        """Return a 1-tuple with ``value0 // value1``; a zero divisor yields 0."""
        if value1 == 0:
            return (0,)
        return (value0 // value1,)
class IntMod(CustomNode):
    """Node that computes the remainder of an integer division."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        required = {
            "value0": ("INT", {}),
            "value1": ("INT", {"default": 1}),
        }
        return {"required": required}

    CATEGORY = "arithmetic"
    RETURN_TYPES = ("INT",)
    FUNCTION = "execute"

    def execute(self, value0, value1):
        """Return a 1-tuple with ``value0 % value1``; a zero divisor yields 0."""
        if value1 == 0:
            return (0,)
        return (value0 % value1,)
class IntPower(CustomNode):
    """Node that raises an integer base to an integer exponent."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        required = {
            "base": ("INT", {}),
            "exponent": ("INT", {"default": 1}),
        }
        return {"required": required}

    CATEGORY = "arithmetic"
    RETURN_TYPES = ("INT",)
    FUNCTION = "execute"

    def execute(self, base, exponent):
        """Return a 1-tuple with ``base ** exponent``."""
        result = base ** exponent
        return (result,)
class FloatMin(CustomNode):
    """Node that returns the smallest of the float inputs value0 .. value4."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        extra_inputs = {
            f"value{i}": ("FLOAT", {"default": 0.0, "step": 0.01, "round": 0.000001})
            for i in range(1, 5)
        }
        return {"required": {"value0": ("FLOAT", {}), **extra_inputs}}

    CATEGORY = "arithmetic"
    RETURN_TYPES = ("FLOAT",)
    FUNCTION = "execute"

    def execute(self, **kwargs):
        """Return a 1-tuple with the minimum of all keyword values."""
        smallest = min(kwargs.values())
        return (smallest,)
class FloatMax(CustomNode):
    """Node that returns the largest of the float inputs value0 .. value4."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        extra_inputs = {
            f"value{i}": ("FLOAT", {"default": 0.0, "step": 0.01, "round": 0.000001})
            for i in range(1, 5)
        }
        return {"required": {"value0": ("FLOAT", {}), **extra_inputs}}

    CATEGORY = "arithmetic"
    RETURN_TYPES = ("FLOAT",)
    FUNCTION = "execute"

    def execute(self, **kwargs):
        """Return a 1-tuple with the maximum of all keyword values."""
        largest = max(kwargs.values())
        return (largest,)
class FloatAbs(CustomNode):
    """Node that returns the absolute value of a float."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        required = {"value": ("FLOAT", {})}
        return {"required": required}

    CATEGORY = "arithmetic"
    RETURN_TYPES = ("FLOAT",)
    FUNCTION = "execute"

    def execute(self, value):
        """Return a 1-tuple with ``abs(value)``."""
        magnitude = abs(value)
        return (magnitude,)
class FloatAverage(CustomNode):
    """Node that returns the arithmetic mean of the float inputs value0 .. value4."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        extra_inputs = {
            f"value{i}": ("FLOAT", {"default": 0.0, "step": 0.01, "round": 0.000001})
            for i in range(1, 5)
        }
        return {"required": {"value0": ("FLOAT", {}), **extra_inputs}}

    CATEGORY = "arithmetic"
    RETURN_TYPES = ("FLOAT",)
    FUNCTION = "execute"

    def execute(self, **kwargs):
        """Return a 1-tuple with the sum of all keyword values divided by their count."""
        count = len(kwargs)
        total = sum(kwargs.values())
        return (total / count,)
class IntMin(CustomNode):
    """Node that returns the smallest of the integer inputs value0 .. value4."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        extra_inputs = {f"value{i}": ("INT", {"default": 0}) for i in range(1, 5)}
        return {"required": {"value0": ("INT", {}), **extra_inputs}}

    CATEGORY = "arithmetic"
    RETURN_TYPES = ("INT",)
    FUNCTION = "execute"

    def execute(self, **kwargs):
        """Return a 1-tuple with the minimum of all keyword values."""
        smallest = min(kwargs.values())
        return (smallest,)
class IntMax(CustomNode):
    """Node that returns the largest of the integer inputs value0 .. value4."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        extra_inputs = {f"value{i}": ("INT", {"default": 0}) for i in range(1, 5)}
        return {"required": {"value0": ("INT", {}), **extra_inputs}}

    CATEGORY = "arithmetic"
    RETURN_TYPES = ("INT",)
    FUNCTION = "execute"

    def execute(self, **kwargs):
        """Return a 1-tuple with the maximum of all keyword values."""
        largest = max(kwargs.values())
        return (largest,)
class IntAbs(CustomNode):
    """Node that returns the absolute value of an integer."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        required = {"value": ("INT", {})}
        return {"required": required}

    CATEGORY = "arithmetic"
    RETURN_TYPES = ("INT",)
    FUNCTION = "execute"

    def execute(self, value):
        """Return a 1-tuple with ``abs(value)``."""
        magnitude = abs(value)
        return (magnitude,)
class IntAverage(CustomNode):
    """Node that returns the floor of the mean of the integer inputs value0 .. value4."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        extra_inputs = {f"value{i}": ("INT", {"default": 0}) for i in range(1, 5)}
        return {"required": {"value0": ("INT", {}), **extra_inputs}}

    CATEGORY = "arithmetic"
    RETURN_TYPES = ("INT",)
    FUNCTION = "execute"

    def execute(self, **kwargs):
        """Return a 1-tuple with the sum of all keyword values floor-divided by their count."""
        count = len(kwargs)
        total = sum(kwargs.values())
        return (total // count,)
class FloatLerp(CustomNode):
    """Node that linearly interpolates between two floats."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "a": ("FLOAT", {"default": 0.0, "step": 0.01, "round": 0.000001}),
                "b": ("FLOAT", {"default": 1.0}),
                "t": ("FLOAT", {}),
                "clamped": ("BOOLEAN", {"default": True}),
            }
        }

    CATEGORY = "arithmetic"
    RETURN_TYPES = ("FLOAT",)
    FUNCTION = "execute"

    def execute(self, a, b, t, clamped):
        """Return a 1-tuple with ``a + (b - a) * t``.

        :param a: value produced at ``t == 0``
        :param b: value produced at ``t == 1``
        :param t: interpolation factor
        :param clamped: when true, the result is limited to the closed interval
            spanned by ``a`` and ``b``, whichever of the two is larger
        """
        value = a + (b - a) * t
        if clamped:
            # order the endpoints first: clamping with (a, b) directly always
            # returned b whenever a > b
            lower, upper = (a, b) if a <= b else (b, a)
            value = min(max(value, lower), upper)
        return (value,)
class FloatInverseLerp(CustomNode):
    """Node that finds where a float lies between two endpoints, as a fraction."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        required = {
            "a": ("FLOAT", {"default": 0.0, "step": 0.01, "round": 0.000001}),
            "b": ("FLOAT", {"default": 1.0}),
            "value": ("FLOAT", {}),
            "clamped": ("BOOLEAN", {"default": True}),
        }
        return {"required": required}

    CATEGORY = "arithmetic"
    RETURN_TYPES = ("FLOAT",)
    FUNCTION = "execute"

    def execute(self, a, b, value, clamped):
        """Return a 1-tuple with ``(value - a) / (b - a)``.

        Equal endpoints yield 0.0. When ``clamped`` is true the fraction is
        limited to [0.0, 1.0].
        """
        if a == b:
            # degenerate range: avoid dividing by zero
            return (0.0,)
        span = b - a
        fraction = (value - a) / span
        if not clamped:
            return (fraction,)
        return (min(max(fraction, 0.0), 1.0),)
class FloatClamp(CustomNode):
    """Node that limits a float to a minimum and a maximum."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        required = {
            "value": ("FLOAT", {}),
            "min": ("FLOAT", {"default": 0.0, "step": 0.01, "round": 0.000001}),
            "max": ("FLOAT", {"default": 1.0}),
        }
        return {"required": required}

    CATEGORY = "arithmetic"
    RETURN_TYPES = ("FLOAT",)
    FUNCTION = "execute"

    def execute(self, value: float = 0, **kwargs):
        """Return a 1-tuple with ``value`` raised to at least ``min`` and then lowered to at most ``max``.

        The bounds arrive through ``kwargs`` because their input names, ``min``
        and ``max``, are also the names of the builtins used below.
        """
        lower: float = kwargs['min']
        upper: float = kwargs['max']
        floored = max(value, lower)
        return (min(floored, upper),)
class IntLerp(CustomNode):
    """Node that linearly interpolates between two integers and rounds the result."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "a": ("INT", {"default": 0}),
                "b": ("INT", {"default": 10}),
                "t": ("FLOAT", {}),
                "clamped": ("BOOLEAN", {"default": True}),
            }
        }

    CATEGORY = "arithmetic"
    RETURN_TYPES = ("INT",)
    FUNCTION = "execute"

    def execute(self, a, b, t, clamped):
        """Return a 1-tuple with ``a + (b - a) * t`` rounded to an integer.

        :param a: value produced at ``t == 0``
        :param b: value produced at ``t == 1``
        :param t: interpolation factor
        :param clamped: when true, the result is limited to the closed interval
            spanned by ``a`` and ``b``, whichever of the two is larger
        """
        value = int(round(a + (b - a) * t))
        if clamped:
            # order the endpoints first: clamping with (a, b) directly always
            # returned b whenever a > b
            lower, upper = (a, b) if a <= b else (b, a)
            value = min(max(value, lower), upper)
        return (value,)
class IntInverseLerp(CustomNode):
    """Node that finds where an integer lies between two integer endpoints, as a float fraction."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        required = {
            "a": ("INT", {"default": 0}),
            "b": ("INT", {"default": 10}),
            "value": ("INT", {}),
            "clamped": ("BOOLEAN", {"default": True}),
        }
        return {"required": required}

    CATEGORY = "arithmetic"
    RETURN_TYPES = ("FLOAT",)
    FUNCTION = "execute"

    def execute(self, a, b, value, clamped):
        """Return a 1-tuple with ``(value - a) / (b - a)``.

        Equal endpoints yield 0.0. When ``clamped`` is true the fraction is
        limited to [0.0, 1.0].
        """
        if a == b:
            # degenerate range: avoid dividing by zero
            return (0.0,)
        span = b - a
        fraction = (value - a) / span
        if not clamped:
            return (fraction,)
        return (min(max(fraction, 0.0), 1.0),)
class IntClamp(CustomNode):
    """Node that limits an integer to a minimum and a maximum."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        required = {
            "value": ("INT", {}),
            "min": ("INT", {"default": 0}),
            "max": ("INT", {"default": 1}),
        }
        return {"required": required}

    CATEGORY = "arithmetic"
    RETURN_TYPES = ("INT",)
    FUNCTION = "execute"

    def execute(self, value: int = 0, **kwargs):
        """Return a 1-tuple with ``value`` raised to at least ``min`` and then lowered to at most ``max``.

        The bounds arrive through ``kwargs`` because their input names, ``min``
        and ``max``, are also the names of the builtins used below.
        """
        lower: int = kwargs['min']
        upper: int = kwargs['max']
        floored = max(value, lower)
        return (min(floored, upper),)
# Maps each arithmetic node's class name to the class itself.
# NOTE(review): presumably read by the node loader to register the nodes — confirm.
NODE_CLASS_MAPPINGS = {}
for cls in (
        FloatAdd,
        FloatSubtract,
        FloatMultiply,
        FloatDivide,
        FloatPower,
        FloatMin,
        FloatMax,
        FloatAbs,
        FloatAverage,
        FloatLerp,
        FloatInverseLerp,
        FloatClamp,
        IntAdd,
        IntSubtract,
        IntMultiply,
        IntDivide,
        IntMod,
        IntPower,
        IntMin,
        IntMax,
        IntAbs,
        IntAverage,
        IntLerp,
        IntInverseLerp,
        IntClamp,
):
    # keyed by class name, so every class above must have a unique name
    NODE_CLASS_MAPPINGS[cls.__name__] = cls

View File

@ -1,4 +1,28 @@
#code originally taken from: https://github.com/ChenyangSi/FreeU (under MIT License)
"""
Portions of this code are adapted from the repository
https://github.com/ChenyangSi/FreeU
MIT License
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
import torch
import logging

View File

@ -0,0 +1,48 @@
from torch import Tensor
from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult
class ImageMin(CustomNode):
    """Node that reports the smallest element of an image tensor."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        required = {"image": ("IMAGE", {})}
        return {"required": required}

    RETURN_TYPES = ("FLOAT",)
    CATEGORY = "image/postprocessing"
    FUNCTION = "execute"

    def execute(self, image: Tensor) -> ValidatedNodeResult:
        """Return a 1-tuple with the minimum over every element of ``image`` as a Python float."""
        smallest = image.min()
        return (float(smallest.item()),)
class ImageMax(CustomNode):
    """Node that outputs the maximum value of an image tensor as a FLOAT."""

    RETURN_TYPES = ("FLOAT",)
    FUNCTION = "execute"
    CATEGORY = "image/postprocessing"

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        """Declare a single required ``image`` input of type IMAGE."""
        required_inputs = {"image": ("IMAGE", {})}
        return {"required": required_inputs}

    def execute(self, image: Tensor) -> ValidatedNodeResult:
        """Return a 1-tuple holding ``image.max()`` converted to a Python float."""
        largest = image.max().item()
        return (float(largest),)
# Node classes exported by this module, keyed by their class name.
NODE_CLASS_MAPPINGS = {node.__name__: node for node in (ImageMin, ImageMax)}

# Display names for the same nodes, keyed by the same class names.
NODE_DISPLAY_NAME_MAPPINGS = dict(
    zip(
        (ImageMin.__name__, ImageMax.__name__),
        ("Image Minimum Value", "Image Maximum Value"),
    )
)

View File

@ -16,12 +16,12 @@ import fsspec
import numpy as np
from PIL import Image
from PIL.PngImagePlugin import PngInfo
from fsspec.core import OpenFiles, OpenFile
from fsspec.core import OpenFile
from fsspec.generic import GenericFileSystem
from fsspec.implementations.local import LocalFileSystem
from joblib import Parallel, delayed
from torch import Tensor
from natsort import natsorted
from torch import Tensor
from comfy.cmd import folder_paths
from comfy.digest import digest

View File

@ -97,8 +97,8 @@ class SelfAttentionGuidance:
@classmethod
def INPUT_TYPES(s):
# Declares the node's required inputs: a "model" of type MODEL plus two FLOAT
# settings, "scale" and "blur_sigma", each with a default, min/max bounds and a step.
return {"required": { "model": ("MODEL",),
# NOTE(review): "scale" and "blur_sigma" each appear twice below. The first pair
# (step 0.1) looks like the pre-change side of this diff hunk -- confirm against the file.
"scale": ("FLOAT", {"default": 0.5, "min": -2.0, "max": 5.0, "step": 0.1}),
"blur_sigma": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 10.0, "step": 0.1}),
# A later duplicate key in a dict display replaces the earlier one, so these
# entries (step 0.01 with "round": 0.01) are the ones that would take effect.
"scale": ("FLOAT", {"default": 0.5, "min": -2.0, "max": 5.0, "step": 0.01, "round": 0.01}),
"blur_sigma": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 10.0, "step": 0.01, "round": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"

View File

@ -35,4 +35,5 @@ huggingface_hub
lazy-object-proxy
can_ada
fsspec
natsort
natsort
OpenEXR

View File

@ -0,0 +1,198 @@
import pytest
from comfy_extras.nodes.nodes_arithmetic import IntAdd, IntSubtract, IntMultiply, IntDivide, IntMod, IntPower, FloatAdd, FloatSubtract, FloatMultiply, FloatDivide, FloatPower, FloatMin, FloatMax, FloatAbs, FloatAverage, IntMin, IntMax, IntAbs, IntAverage, FloatLerp, IntLerp, IntClamp, IntInverseLerp, FloatClamp, FloatInverseLerp
def test_int_add():
    """IntAdd over 1, 2 and 3 is expected to give 6."""
    (total,) = IntAdd().execute(value0=1, value1=2, value2=3)
    assert total == 6
def test_int_subtract():
    """IntSubtract of 3 from 10 is expected to give 7."""
    (difference,) = IntSubtract().execute(value0=10, value1=3)
    assert difference == 7
def test_int_multiply():
    """IntMultiply over 2, 3 and 4 is expected to give 24."""
    (product,) = IntMultiply().execute(value0=2, value1=3, value2=4)
    assert product == 24
def test_int_divide():
    """IntDivide of 10 by 3 is expected to give 3, and 10 by 0 to give 0."""
    node = IntDivide()
    # (divisor, expected quotient), evaluated in this order
    for divisor, expected in ((3, 3), (0, 0)):
        (quotient,) = node.execute(value0=10, value1=divisor)
        assert quotient == expected
def test_int_mod():
    """IntMod of 10 by 3 is expected to give 1, and 10 by 0 to give 0."""
    node = IntMod()
    # (modulus, expected remainder), evaluated in this order
    for modulus, expected in ((3, 1), (0, 0)):
        (remainder,) = node.execute(value0=10, value1=modulus)
        assert remainder == expected
def test_int_power():
    """IntPower with base 2 and exponent 3 is expected to give 8."""
    (raised,) = IntPower().execute(base=2, exponent=3)
    assert raised == 8
def test_float_add():
    """FloatAdd over 1.5, 2.3 and 3.7 is expected to give approximately 7.5."""
    (total,) = FloatAdd().execute(value0=1.5, value1=2.3, value2=3.7)
    assert pytest.approx(total) == 7.5
def test_float_subtract():
    """FloatSubtract of 3.2 from 10.5 is expected to give approximately 7.3."""
    (difference,) = FloatSubtract().execute(value0=10.5, value1=3.2)
    assert pytest.approx(difference) == 7.3
def test_float_multiply():
    """FloatMultiply over 2.5, 3.0 and 4.0 is expected to give approximately 30.0."""
    (product,) = FloatMultiply().execute(value0=2.5, value1=3.0, value2=4.0)
    assert pytest.approx(product) == 30.0
def test_float_divide():
    """FloatDivide: 10.0 / 4.0 is about 2.5; dividing 10.0 by 0.0 is expected to give +inf."""
    node = FloatDivide()

    (quotient,) = node.execute(value0=10.0, value1=4.0)
    assert pytest.approx(quotient) == 2.5

    (by_zero,) = node.execute(value0=10.0, value1=0.0)
    assert by_zero == float("inf")
def test_float_power():
    """FloatPower with base 2.5 and exponent 3.0 is expected to give approximately 15.625."""
    (raised,) = FloatPower().execute(base=2.5, exponent=3.0)
    assert pytest.approx(raised) == 15.625
def test_float_min():
    """FloatMin over 1.5, 2.3 and 0.7 is expected to give 0.7."""
    (smallest,) = FloatMin().execute(value0=1.5, value1=2.3, value2=0.7)
    assert smallest == 0.7
def test_float_max():
    """FloatMax over 1.5, 2.3 and 0.7 is expected to give 2.3."""
    (largest,) = FloatMax().execute(value0=1.5, value1=2.3, value2=0.7)
    assert largest == 2.3
def test_float_abs():
    """FloatAbs of -3.14 is expected to give 3.14."""
    (magnitude,) = FloatAbs().execute(value=-3.14)
    assert magnitude == 3.14
def test_float_average():
    """FloatAverage over 1.5, 2.5 and 3.5 is expected to give 2.5."""
    (mean,) = FloatAverage().execute(value0=1.5, value1=2.5, value2=3.5)
    assert mean == 2.5
def test_int_min():
    """IntMin over 5, 2 and 7 is expected to give 2."""
    (smallest,) = IntMin().execute(value0=5, value1=2, value2=7)
    assert smallest == 2
def test_int_max():
    """IntMax over 5, 2 and 7 is expected to give 7."""
    (largest,) = IntMax().execute(value0=5, value1=2, value2=7)
    assert largest == 7
def test_int_abs():
    """IntAbs of -10 is expected to give 10."""
    (magnitude,) = IntAbs().execute(value=-10)
    assert magnitude == 10
def test_int_average():
    """IntAverage over 2, 4 and 6 is expected to give 4."""
    (mean,) = IntAverage().execute(value0=2, value1=4, value2=6)
    assert mean == 4
def test_float_lerp():
    """FloatLerp between 0.0 and 1.0: in-range t, clamped overshoot, unclamped overshoot."""
    node = FloatLerp()
    # (t, clamped, expected result), evaluated in this order
    cases = (
        (0.5, True, 0.5),
        (1.5, True, 1.0),
        (1.5, False, 1.5),
    )
    for t, clamped, expected in cases:
        (blended,) = node.execute(a=0.0, b=1.0, t=t, clamped=clamped)
        assert blended == expected
def test_int_lerp():
    """IntLerp between 0 and 10: in-range t, clamped overshoot, unclamped overshoot."""
    node = IntLerp()
    # (t, clamped, expected result), evaluated in this order
    cases = (
        (0.5, True, 5),
        (1.5, True, 10),
        (1.5, False, 15),
    )
    for t, clamped, expected in cases:
        (blended,) = node.execute(a=0, b=10, t=t, clamped=clamped)
        assert blended == expected
def test_float_inverse_lerp():
    """FloatInverseLerp over [0.0, 1.0]: in-range value, clamped overshoot, unclamped overshoot."""
    node = FloatInverseLerp()
    # (value, clamped, expected result), evaluated in this order
    cases = (
        (0.5, True, 0.5),
        (1.5, True, 1.0),
        (1.5, False, 1.5),
    )
    for value, clamped, expected in cases:
        (fraction,) = node.execute(a=0.0, b=1.0, value=value, clamped=clamped)
        assert fraction == expected
def test_float_clamp():
    """FloatClamp to [0.0, 1.0]: value inside the range, above it, and below it."""
    node = FloatClamp()
    # (value, expected result), evaluated in this order
    for value, expected in ((0.5, 0.5), (1.5, 1.0), (-0.5, 0.0)):
        (bounded,) = node.execute(value=value, min=0.0, max=1.0)
        assert bounded == expected
def test_int_inverse_lerp():
    """IntInverseLerp over [0, 10]: in-range value, clamped overshoot, unclamped overshoot."""
    node = IntInverseLerp()
    # (value, clamped, expected result), evaluated in this order
    cases = (
        (5, True, 0.5),
        (15, True, 1.0),
        (15, False, 1.5),
    )
    for value, clamped, expected in cases:
        (fraction,) = node.execute(a=0, b=10, value=value, clamped=clamped)
        assert fraction == expected
def test_int_clamp():
    """IntClamp to [0, 10]: value inside the range, above it, and below it."""
    node = IntClamp()
    # (value, expected result), evaluated in this order
    for value, expected in ((5, 5), (15, 10), (-5, 0)):
        (bounded,) = node.execute(value=value, min=0, max=10)
        assert bounded == expected

View File

@ -0,0 +1,49 @@
import pytest
import torch
from comfy_extras.nodes.nodes_apply_color_map import ImageApplyColorMap
@pytest.fixture
def input_image():
    """Provide a float32 tensor of shape (1, 1, 2, 1) holding the values 1.3 and 300.0.

    The two entries stand for absolute distances of 1.3 meters and 300 meters.
    """
    distances = [[[[1.3], [300.0]]]]
    return torch.tensor(distances, dtype=torch.float32)
def test_apply_colormap_grayscale(input_image):
    """With the "Grayscale" colormap the 1.3 pixel should be all ones and the 300.0 pixel all zeros."""
    (result,) = ImageApplyColorMap().execute(
        image=input_image,
        colormap="Grayscale",
        min_depth=1.3,
        max_depth=300.0,
    )

    # The single input channel becomes three output channels.
    assert result.shape == (1, 1, 2, 3)
    assert result.dtype == torch.float32

    first_pixel, second_pixel = result[0, 0, 0], result[0, 0, 1]
    assert torch.allclose(first_pixel, torch.tensor([1.0, 1.0, 1.0]))
    assert torch.allclose(second_pixel, torch.tensor([0.0, 0.0, 0.0]))
def test_apply_colormap_inferno(input_image):
    """With "COLORMAP_INFERNO" the two pixels should match the expected colours within 1e-4."""
    (result,) = ImageApplyColorMap().execute(
        image=input_image,
        colormap="COLORMAP_INFERNO",
        min_depth=1.3,
        max_depth=300.0,
    )

    assert result.shape == (1, 1, 2, 3)
    assert result.dtype == torch.float32

    expected_first = torch.tensor([0.9882, 1.000, 0.6431])
    expected_second = torch.tensor([0.0000, 0.0000, 0.0157])
    assert torch.allclose(result[0, 0, 0], expected_first, atol=1e-4)
    assert torch.allclose(result[0, 0, 1], expected_second, atol=1e-4)
def test_apply_colormap_clipping(input_image):
    """Run "COLORMAP_INFERNO" through all four clip_min / clip_max combinations.

    In every case the second pixel is expected to be the dark colour; the first
    pixel is expected to be the bright colour only when ``clip_min`` is True.
    """
    node = ImageApplyColorMap()
    dark = torch.tensor([0.0, 0.0, 0.0157])
    bright = torch.tensor([0.9882, 1.0000, 0.6431])

    # (clip_min, clip_max, max_depth, expected first pixel), evaluated in this order
    cases = (
        (False, False, 300.0, dark),
        (True, False, 300.0, bright),
        (False, True, 200.0, dark),
        (True, True, 200.0, bright),
    )
    for clip_min, clip_max, max_depth, expected_first in cases:
        (result,) = node.execute(
            image=input_image,
            colormap="COLORMAP_INFERNO",
            clip_min=clip_min,
            clip_max=clip_max,
            min_depth=1.3,
            max_depth=max_depth,
        )
        assert torch.allclose(result[0, 0, 0], expected_first, atol=1e-4)
        assert torch.allclose(result[0, 0, 1], dark, atol=1e-4)