diff --git a/comfy/nodes/base_nodes.py b/comfy/nodes/base_nodes.py
index a884cfc83..7018b2170 100644
--- a/comfy/nodes/base_nodes.py
+++ b/comfy/nodes/base_nodes.py
@@ -1580,8 +1580,6 @@ class LoadImage:
     def load_image(self, image: str) -> tuple[RGBImageBatch, MaskBatch]:
         image_path = folder_paths.get_annotated_filepath(image)
-        img = node_helpers.pillow(Image.open, image_path)
-
         output_images = []
         output_masks = []
         w, h = None, None
diff --git a/comfy_extras/nodes/nodes_images.py b/comfy_extras/nodes/nodes_images.py
index 7fe573982..5f281e641 100644
--- a/comfy_extras/nodes/nodes_images.py
+++ b/comfy_extras/nodes/nodes_images.py
@@ -2,16 +2,60 @@
 import json
 import os
 from typing import Literal, Tuple
 
-import cv2
 import numpy as np
 import torch
 from PIL import Image
 from PIL.PngImagePlugin import PngInfo
 
+from comfy import utils
 from comfy.cli_args import args
 from comfy.cmd import folder_paths
-from comfy.component_model.tensor_types import ImageBatch
+from comfy.component_model.tensor_types import ImageBatch, RGBImageBatch
+from comfy.nodes.base_nodes import ImageScale
 from comfy.nodes.common import MAX_RESOLUTION
+from comfy.nodes.package_typing import CustomNode
+
+
+def levels_adjustment(image: ImageBatch, black_level: float = 0.0, mid_level: float = 0.5, white_level: float = 1.0, clip: bool = True) -> ImageBatch:
+    """
+    Apply a levels adjustment to an sRGB image.
+
+    Args:
+        image (torch.Tensor): Input image tensor of shape (B, H, W, C) with values in range [0, 1]
+        black_level (float): Black point (default: 0.0)
+        mid_level (float): Midtone point (default: 0.5)
+        white_level (float): White point (default: 1.0)
+        clip (bool): Whether to clip the output values to [0, 1] range (default: True)
+
+    Returns:
+        torch.Tensor: Adjusted image tensor of shape (B, H, W, C)
+    """
+    # Ensure the input has the expected shape and the levels are well ordered
+    assert image.dim() == 4 and image.shape[-1] == 3, "Input should be of shape (B, H, W, 3)"
+    assert 0 <= black_level < mid_level < white_level <= 1, "Levels should be in ascending order in range [0, 1]"
+
+    def srgb_to_linear(x):
+        return torch.where(x <= 0.04045, x / 12.92, ((x + 0.055) / 1.055) ** 2.4)
+
+    def linear_to_srgb(x):
+        return torch.where(x <= 0.0031308, x * 12.92, 1.055 * x ** (1 / 2.4) - 0.055)
+
+    linear = srgb_to_linear(image)
+
+    # Remap the black and white points onto [0, 1]
+    adjusted = (linear - black_level) / (white_level - black_level)
+
+    # Gamma chosen so that mid_level maps to 0.5
+    power_factor = torch.log2(torch.tensor(0.5, device=image.device)) / torch.log2(torch.tensor(mid_level, device=image.device))
+
+    # Apply the power function only to positive values to avoid NaNs
+    adjusted = torch.where(adjusted > 0, torch.pow(adjusted.clamp(min=1e-8), power_factor), adjusted)
+
+    result = linear_to_srgb(adjusted)
+
+    if clip:
+        result = torch.clamp(result, 0.0, 1.0)
+
+    return result
+
 
 class ImageCrop:
@@ -228,7 +272,8 @@ class ImageResize:
             "required": {
                 "image": ("IMAGE",),
                 "resize_mode": (["cover", "contain", "auto"], {"default": "cover"}),
-                "resolutions": (["SDXL/SD3/Flux", "SD1.5", ], {"default": "SDXL/SD3/Flux"})
+                "resolutions": (["SDXL/SD3/Flux", "SD1.5"], {"default": "SDXL/SD3/Flux"}),
+                "interpolation": (ImageScale.upscale_methods, {"default": "bilinear"}),
             }
         }
 
@@ -236,7 +281,7 @@ class ImageResize:
     FUNCTION = "resize_image"
     CATEGORY = "image/transform"
 
-    def resize_image(self, image: ImageBatch, resize_mode: Literal["cover", "contain", "auto"], resolutions: Literal["SDXL/SD3/Flux", "SD1.5",]) -> Tuple[ImageBatch]:
+    def resize_image(self, image: RGBImageBatch, resize_mode: Literal["cover", "contain", "auto"], resolutions: Literal["SDXL/SD3/Flux", "SD1.5"], interpolation: str) -> Tuple[RGBImageBatch]:
         if resolutions == "SDXL/SD3/Flux":
             supported_resolutions = [
                 (640, 1536),
@@ -256,58 +301,112 @@ class ImageResize:
         resized_images = []
         for img in image:
-            img_np = (img.cpu().numpy() * 255).astype(np.uint8)
-            h, w = img_np.shape[:2]
+            h, w = img.shape[:2]
             current_aspect_ratio = w / h
             target_resolution = min(supported_resolutions, key=lambda res: abs(res[0] / res[1] - current_aspect_ratio))
-            scale_w, scale_h = target_resolution[0] / w, target_resolution[1] / h
             if resize_mode == "cover":
-                scale = max(scale_w, scale_h)
+                scale = max(target_resolution[0] / w, target_resolution[1] / h)
                 new_w, new_h = int(w * scale), int(h * scale)
-                resized = cv2.resize(img_np, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)
-                x1 = (new_w - target_resolution[0]) // 2
-                y1 = (new_h - target_resolution[1]) // 2
-                resized = resized[y1:y1 + target_resolution[1], x1:x1 + target_resolution[0]]
             elif resize_mode == "contain":
-                scale = min(scale_w, scale_h)
+                scale = min(target_resolution[0] / w, target_resolution[1] / h)
                 new_w, new_h = int(w * scale), int(h * scale)
-                resized = cv2.resize(img_np, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)
-                canvas = np.zeros((target_resolution[1], target_resolution[0], 3), dtype=np.uint8)
-                x1 = (target_resolution[0] - new_w) // 2
-                y1 = (target_resolution[1] - new_h) // 2
-                canvas[y1:y1 + new_h, x1:x1 + new_w] = resized
-                resized = canvas
-            else:
+            else:  # auto
                 if current_aspect_ratio > target_resolution[0] / target_resolution[1]:
-                    scale = scale_w
+                    new_w, new_h = target_resolution[0], int(h * target_resolution[0] / w)
                 else:
-                    scale = scale_h
-                new_w, new_h = int(w * scale), int(h * scale)
-                resized = cv2.resize(img_np, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)
-                if new_w > target_resolution[0] or new_h > target_resolution[1]:
-                    x1 = (new_w - target_resolution[0]) // 2
-                    y1 = (new_h - target_resolution[1]) // 2
-                    resized = resized[y1:y1 + target_resolution[1], x1:x1 + target_resolution[0]]
-                else:
-                    canvas = np.zeros((target_resolution[1], target_resolution[0], 3), dtype=np.uint8)
-                    x1 = (target_resolution[0] - new_w) // 2
+                    new_w, new_h = int(w * target_resolution[1] / h), target_resolution[1]
+
+            # convert from (h, w, c) to (b, c, h, w) for common_upscale
+            img_tensor = img.permute(2, 0, 1).unsqueeze(0)
+
+            # use common_upscale for resizing
+            resized = utils.common_upscale(img_tensor, new_w, new_h, interpolation, "disabled")
+
+            # handle padding or cropping
+            if resize_mode == "contain":
+                canvas = torch.zeros((1, 3, target_resolution[1], target_resolution[0]), device=resized.device, dtype=resized.dtype)
+                y1 = (target_resolution[1] - new_h) // 2
+                x1 = (target_resolution[0] - new_w) // 2
+                canvas[:, :, y1:y1 + new_h, x1:x1 + new_w] = resized
+                resized = canvas
+            elif resize_mode == "cover":
+                y1 = (new_h - target_resolution[1]) // 2
+                x1 = (new_w - target_resolution[0]) // 2
+                resized = resized[:, :, y1:y1 + target_resolution[1], x1:x1 + target_resolution[0]]
+            else:  # auto
+                if new_w != target_resolution[0] or new_h != target_resolution[1]:
+                    canvas = torch.zeros((1, 3, target_resolution[1], target_resolution[0]), device=resized.device, dtype=resized.dtype)
                     y1 = (target_resolution[1] - new_h) // 2
-                    canvas[y1:y1 + new_h, x1:x1 + new_w] = resized
+                    x1 = (target_resolution[0] - new_w) // 2
+                    canvas[:, :, y1:y1 + new_h, x1:x1 + new_w] = resized
                     resized = canvas
-            resized_images.append(resized)
+            resized_images.append(resized.squeeze(0).permute(1, 2, 0).clamp(0.0, 1.0))
 
-        return (torch.from_numpy(np.stack(resized_images)).float() / 255.0,)
+        return (torch.stack(resized_images),)
+
+
+class ImageLevels(CustomNode):
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "image": ("IMAGE",),
+                "black_level": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
+                "mid_level": ("FLOAT", {"default": 0.5, "min": 0.01, "max": 0.99, "step": 0.01}),
+                "white_level": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
+                "clip": ("BOOLEAN", {"default": True}),
+            }
+        }
+
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "apply_levels"
+    CATEGORY = "image/adjust"
+
+    def apply_levels(self, image: ImageBatch, black_level: float, mid_level: float, white_level: float, clip: bool) -> Tuple[ImageBatch]:
+        adjusted_image = levels_adjustment(image, black_level, mid_level, white_level, clip)
+        return (adjusted_image,)
+
+
+class ImageLuminance:
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "image": ("IMAGE",),
+            }
+        }
+
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "compute_luminance"
+    CATEGORY = "image/color"
+
+    def compute_luminance(self, image: ImageBatch) -> Tuple[ImageBatch]:
+        assert image.dim() == 4 and image.shape[-1] == 3, "Input should be of shape (B, H, W, 3)"
+
+        # Rec. 709 (sRGB) luminance coefficients
+        coeffs = torch.tensor([0.2126, 0.7152, 0.0722], device=image.device, dtype=image.dtype)
+
+        luminance = torch.sum(image * coeffs, dim=-1, keepdim=True)
+        luminance = luminance.expand(-1, -1, -1, 3)
+
+        return (luminance,)
 
 
 NODE_CLASS_MAPPINGS = {
     "ImageResize": ImageResize,
     "ImageShape": ImageShape,
     "ImageCrop": ImageCrop,
+    "ImageLevels": ImageLevels,
+    "ImageLuminance": ImageLuminance,
     "RepeatImageBatch": RepeatImageBatch,
     "ImageFromBatch": ImageFromBatch,
     "SaveAnimatedWEBP": SaveAnimatedWEBP,
     "SaveAnimatedPNG": SaveAnimatedPNG,
 }
+
+NODE_DISPLAY_NAME_MAPPINGS = {
+    "ImageResize": "Fit Image to Diffusion Size"
+}
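
Reviewer note, not part of the patch: a minimal sanity check for the new levels_adjustment helper, assuming the module path shown in the diff; the tolerance is illustrative.

import torch

from comfy_extras.nodes.nodes_images import levels_adjustment

# (B, H, W, 3) sRGB batch in [0, 1]
img = torch.rand(1, 64, 64, 3)

# The default levels are an identity mapping, up to float32 error in the
# sRGB <-> linear round trip (power_factor is exactly 1.0 when mid_level=0.5).
out = levels_adjustment(img)
assert torch.allclose(out, img, atol=1e-4)

# A mid_level below 0.5 gives a power below 1 in linear space, so midtones
# are lifted and no pixel gets darker.
brighter = levels_adjustment(img, mid_level=0.3)
assert (brighter >= out - 1e-4).all()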
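
A second sketch, also not part of the patch, of how resize_image picks its target: the supported resolution with the nearest aspect ratio wins, regardless of the input's absolute size. The resolution list below is a hypothetical subset of the SDXL/SD3/Flux table; the full list lives in the diff.

# Hypothetical subset of the supported_resolutions table
supported_resolutions = [(640, 1536), (832, 1216), (1024, 1024), (1216, 832)]

w, h = 800, 1200  # a 2:3 portrait input
current_aspect_ratio = w / h  # ~0.667

# Same selection expression as in resize_image
target_resolution = min(supported_resolutions,
                        key=lambda res: abs(res[0] / res[1] - current_aspect_ratio))
print(target_resolution)  # (832, 1216): aspect ratio ~0.684 is the closest match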