From 8a93da95174c9ea1fdaf61d6a9dd57677a17dcf2 Mon Sep 17 00:00:00 2001 From: Alexis Rolland Date: Tue, 19 May 2026 10:49:26 +0800 Subject: [PATCH] Update comfy_extras/nodes_images.py --- comfy_extras/nodes_images.py | 2530 +++++++++++++++++----------------- 1 file changed, 1265 insertions(+), 1265 deletions(-) diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index fd743b611..2e7bc2400 100644 --- a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -1,1266 +1,1266 @@ -from __future__ import annotations - -import nodes -import folder_paths - -import av -import json - -import os -import re -import math -import numpy as np -import struct -import torch - -import zlib -import comfy.utils -from fractions import Fraction - -from server import PromptServer -from comfy_api.latest import ComfyExtension, IO, UI -from comfy.cli_args import args -from typing_extensions import override - -SVG = IO.SVG.Type # TODO: temporary solution for backward compatibility, will be removed later. - -MAX_RESOLUTION = nodes.MAX_RESOLUTION - -class ImageCrop(IO.ComfyNode): - @classmethod - def define_schema(cls): - return IO.Schema( - node_id="ImageCrop", - search_aliases=["trim"], - display_name="Crop Image (DEPRECATED)", - category="image/transform", - is_deprecated=True, - essentials_category="Image Tools", - inputs=[ - IO.Image.Input("image"), - IO.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1), - IO.Int.Input("height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1), - IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), - IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), - ], - outputs=[IO.Image.Output()], - ) - - @classmethod - def execute(cls, image, width, height, x, y) -> IO.NodeOutput: - x = min(x, image.shape[2] - 1) - y = min(y, image.shape[1] - 1) - to_x = width + x - to_y = height + y - img = image[:,y:to_y, x:to_x, :] - return IO.NodeOutput(img) - - crop = execute # TODO: remove - - -class ImageCropV2(IO.ComfyNode): - @classmethod - def define_schema(cls): - return IO.Schema( - node_id="ImageCropV2", - search_aliases=["trim"], - display_name="Crop Image", - category="image/transform", - essentials_category="Image Tools", - has_intermediate_output=True, - inputs=[ - IO.Image.Input("image"), - IO.BoundingBox.Input("crop_region", component="ImageCrop"), - ], - outputs=[IO.Image.Output()], - ) - - @classmethod - def execute(cls, image, crop_region) -> IO.NodeOutput: - x = crop_region.get("x", 0) - y = crop_region.get("y", 0) - width = crop_region.get("width", 512) - height = crop_region.get("height", 512) - - x = min(x, image.shape[2] - 1) - y = min(y, image.shape[1] - 1) - to_x = width + x - to_y = height + y - img = image[:,y:to_y, x:to_x, :] - return IO.NodeOutput(img, ui=UI.PreviewImage(img)) - - -class BoundingBox(IO.ComfyNode): - @classmethod - def define_schema(cls): - return IO.Schema( - node_id="PrimitiveBoundingBox", - display_name="Bounding Box", - category="utils/primitive", - inputs=[ - IO.Int.Input("x", default=0, min=0, max=MAX_RESOLUTION), - IO.Int.Input("y", default=0, min=0, max=MAX_RESOLUTION), - IO.Int.Input("width", default=512, min=1, max=MAX_RESOLUTION), - IO.Int.Input("height", default=512, min=1, max=MAX_RESOLUTION), - ], - outputs=[IO.BoundingBox.Output()], - ) - - @classmethod - def execute(cls, x, y, width, height) -> IO.NodeOutput: - return IO.NodeOutput({"x": x, "y": y, "width": width, "height": height}) - - -class RepeatImageBatch(IO.ComfyNode): - @classmethod - def define_schema(cls): - return IO.Schema( - node_id="RepeatImageBatch", - search_aliases=["duplicate image", "clone image"], - display_name="Repeat Image Batch", - category="image/batch", - inputs=[ - IO.Image.Input("image"), - IO.Int.Input("amount", default=1, min=1, max=4096), - ], - outputs=[IO.Image.Output()], - ) - - @classmethod - def execute(cls, image, amount) -> IO.NodeOutput: - s = image.repeat((amount, 1,1,1)) - return IO.NodeOutput(s) - - repeat = execute # TODO: remove - - -class ImageFromBatch(IO.ComfyNode): - @classmethod - def define_schema(cls): - return IO.Schema( - node_id="ImageFromBatch", - search_aliases=["select image", "pick from batch", "extract image"], - display_name="Get Image from Batch", - category="image/batch", - inputs=[ - IO.Image.Input("image"), - IO.Int.Input("batch_index", default=0, min=0, max=4095), - IO.Int.Input("length", default=1, min=1, max=4096), - ], - outputs=[IO.Image.Output()], - ) - - @classmethod - def execute(cls, image, batch_index, length) -> IO.NodeOutput: - s_in = image - batch_index = min(s_in.shape[0] - 1, batch_index) - length = min(s_in.shape[0] - batch_index, length) - s = s_in[batch_index:batch_index + length].clone() - return IO.NodeOutput(s) - - frombatch = execute # TODO: remove - - -class ImageAddNoise(IO.ComfyNode): - @classmethod - def define_schema(cls): - return IO.Schema( - node_id="ImageAddNoise", - search_aliases=["film grain"], - display_name="Add Noise to Image", - category="image/postprocessing", - inputs=[ - IO.Image.Input("image"), - IO.Int.Input( - "seed", - default=0, - min=0, - max=0xFFFFFFFFFFFFFFFF, - control_after_generate=True, - tooltip="The random seed used for creating the noise.", - ), - IO.Float.Input("strength", default=0.5, min=0.0, max=1.0, step=0.01), - ], - outputs=[IO.Image.Output()], - ) - - @classmethod - def execute(cls, image, seed, strength) -> IO.NodeOutput: - generator = torch.manual_seed(seed) - s = torch.clip((image + strength * torch.randn(image.size(), generator=generator, device="cpu").to(image)), min=0.0, max=1.0) - return IO.NodeOutput(s) - - repeat = execute # TODO: remove - - -class SaveAnimatedWEBP(IO.ComfyNode): - COMPRESS_METHODS = {"default": 4, "fastest": 0, "slowest": 6} - - @classmethod - def define_schema(cls): - return IO.Schema( - node_id="SaveAnimatedWEBP", - category="image/animation", - inputs=[ - IO.Image.Input("images"), - IO.String.Input("filename_prefix", default="ComfyUI"), - IO.Float.Input("fps", default=6.0, min=0.01, max=1000.0, step=0.01), - IO.Boolean.Input("lossless", default=True), - IO.Int.Input("quality", default=80, min=0, max=100), - IO.Combo.Input("method", options=list(cls.COMPRESS_METHODS.keys())), - # "num_frames": ("INT", {"default": 0, "min": 0, "max": 8192}), - ], - hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], - is_output_node=True, - ) - - @classmethod - def execute(cls, images, fps, filename_prefix, lossless, quality, method, num_frames=0) -> IO.NodeOutput: - return IO.NodeOutput( - ui=UI.ImageSaveHelper.get_save_animated_webp_ui( - images=images, - filename_prefix=filename_prefix, - cls=cls, - fps=fps, - lossless=lossless, - quality=quality, - method=cls.COMPRESS_METHODS.get(method) - ) - ) - - save_images = execute # TODO: remove - - -class SaveAnimatedPNG(IO.ComfyNode): - - @classmethod - def define_schema(cls): - return IO.Schema( - node_id="SaveAnimatedPNG", - category="image/animation", - inputs=[ - IO.Image.Input("images"), - IO.String.Input("filename_prefix", default="ComfyUI"), - IO.Float.Input("fps", default=6.0, min=0.01, max=1000.0, step=0.01), - IO.Int.Input("compress_level", default=4, min=0, max=9, advanced=True), - ], - hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], - is_output_node=True, - ) - - @classmethod - def execute(cls, images, fps, compress_level, filename_prefix="ComfyUI") -> IO.NodeOutput: - return IO.NodeOutput( - ui=UI.ImageSaveHelper.get_save_animated_png_ui( - images=images, - filename_prefix=filename_prefix, - cls=cls, - fps=fps, - compress_level=compress_level, - ) - ) - - save_images = execute # TODO: remove - - -class ImageStitch(IO.ComfyNode): - """Upstreamed from https://github.com/kijai/ComfyUI-KJNodes""" - @classmethod - def define_schema(cls): - return IO.Schema( - node_id="ImageStitch", - search_aliases=["combine images", "join images", "concatenate images", "side by side"], - display_name="Stitch Images", - description="Stitches image2 to image1 in the specified direction.\n" - "If image2 is not provided, returns image1 unchanged.\n" - "Optional spacing can be added between images.", - category="image/transform", - inputs=[ - IO.Image.Input("image1"), - IO.Combo.Input("direction", options=["right", "down", "left", "up"], default="right"), - IO.Boolean.Input("match_image_size", default=True), - IO.Int.Input("spacing_width", default=0, min=0, max=1024, step=2, advanced=True), - IO.Combo.Input("spacing_color", options=["white", "black", "red", "green", "blue"], default="white", advanced=True), - IO.Image.Input("image2", optional=True), - ], - outputs=[IO.Image.Output()], - ) - - @classmethod - def execute( - cls, - image1, - direction, - match_image_size, - spacing_width, - spacing_color, - image2=None, - ) -> IO.NodeOutput: - if image2 is None: - return IO.NodeOutput(image1) - - # Handle batch size differences - if image1.shape[0] != image2.shape[0]: - max_batch = max(image1.shape[0], image2.shape[0]) - if image1.shape[0] < max_batch: - image1 = torch.cat( - [image1, image1[-1:].repeat(max_batch - image1.shape[0], 1, 1, 1)] - ) - if image2.shape[0] < max_batch: - image2 = torch.cat( - [image2, image2[-1:].repeat(max_batch - image2.shape[0], 1, 1, 1)] - ) - - # Match image sizes if requested - if match_image_size: - h1, w1 = image1.shape[1:3] - h2, w2 = image2.shape[1:3] - aspect_ratio = w2 / h2 - - if direction in ["left", "right"]: - target_h, target_w = h1, int(h1 * aspect_ratio) - else: # up, down - target_w, target_h = w1, int(w1 / aspect_ratio) - - image2 = comfy.utils.common_upscale( - image2.movedim(-1, 1), target_w, target_h, "lanczos", "disabled" - ).movedim(1, -1) - - color_map = { - "white": 1.0, - "black": 0.0, - "red": (1.0, 0.0, 0.0), - "green": (0.0, 1.0, 0.0), - "blue": (0.0, 0.0, 1.0), - } - - color_val = color_map[spacing_color] - - # When not matching sizes, pad to align non-concat dimensions - if not match_image_size: - h1, w1 = image1.shape[1:3] - h2, w2 = image2.shape[1:3] - pad_value = 0.0 - if not isinstance(color_val, tuple): - pad_value = color_val - - if direction in ["left", "right"]: - # For horizontal concat, pad heights to match - if h1 != h2: - target_h = max(h1, h2) - if h1 < target_h: - pad_h = target_h - h1 - pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2 - image1 = torch.nn.functional.pad(image1, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=pad_value) - if h2 < target_h: - pad_h = target_h - h2 - pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2 - image2 = torch.nn.functional.pad(image2, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=pad_value) - else: # up, down - # For vertical concat, pad widths to match - if w1 != w2: - target_w = max(w1, w2) - if w1 < target_w: - pad_w = target_w - w1 - pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2 - image1 = torch.nn.functional.pad(image1, (0, 0, pad_left, pad_right), mode='constant', value=pad_value) - if w2 < target_w: - pad_w = target_w - w2 - pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2 - image2 = torch.nn.functional.pad(image2, (0, 0, pad_left, pad_right), mode='constant', value=pad_value) - - # Ensure same number of channels - if image1.shape[-1] != image2.shape[-1]: - max_channels = max(image1.shape[-1], image2.shape[-1]) - if image1.shape[-1] < max_channels: - image1 = torch.cat( - [ - image1, - torch.ones( - *image1.shape[:-1], - max_channels - image1.shape[-1], - device=image1.device, - ), - ], - dim=-1, - ) - if image2.shape[-1] < max_channels: - image2 = torch.cat( - [ - image2, - torch.ones( - *image2.shape[:-1], - max_channels - image2.shape[-1], - device=image2.device, - ), - ], - dim=-1, - ) - - # Add spacing if specified - if spacing_width > 0: - spacing_width = spacing_width + (spacing_width % 2) # Ensure even - - if direction in ["left", "right"]: - spacing_shape = ( - image1.shape[0], - max(image1.shape[1], image2.shape[1]), - spacing_width, - image1.shape[-1], - ) - else: - spacing_shape = ( - image1.shape[0], - spacing_width, - max(image1.shape[2], image2.shape[2]), - image1.shape[-1], - ) - - spacing = torch.full(spacing_shape, 0.0, device=image1.device) - if isinstance(color_val, tuple): - for i, c in enumerate(color_val): - if i < spacing.shape[-1]: - spacing[..., i] = c - if spacing.shape[-1] == 4: # Add alpha - spacing[..., 3] = 1.0 - else: - spacing[..., : min(3, spacing.shape[-1])] = color_val - if spacing.shape[-1] == 4: - spacing[..., 3] = 1.0 - - # Concatenate images - images = [image2, image1] if direction in ["left", "up"] else [image1, image2] - if spacing_width > 0: - images.insert(1, spacing) - - concat_dim = 2 if direction in ["left", "right"] else 1 - return IO.NodeOutput(torch.cat(images, dim=concat_dim)) - - stitch = execute # TODO: remove - - -class ResizeAndPadImage(IO.ComfyNode): - @classmethod - def define_schema(cls): - return IO.Schema( - node_id="ResizeAndPadImage", - search_aliases=["fit to size"], - display_name="Resize And Pad Image", - category="image/transform", - inputs=[ - IO.Image.Input("image"), - IO.Int.Input("target_width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1), - IO.Int.Input("target_height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1), - IO.Combo.Input("padding_color", options=["white", "black"], advanced=True), - IO.Combo.Input("interpolation", options=["area", "bicubic", "nearest-exact", "bilinear", "lanczos"], advanced=True), - ], - outputs=[IO.Image.Output()], - ) - - @classmethod - def execute(cls, image, target_width, target_height, padding_color, interpolation) -> IO.NodeOutput: - batch_size, orig_height, orig_width, channels = image.shape - - scale_w = target_width / orig_width - scale_h = target_height / orig_height - scale = min(scale_w, scale_h) - - new_width = int(orig_width * scale) - new_height = int(orig_height * scale) - - image_permuted = image.permute(0, 3, 1, 2) - - resized = comfy.utils.common_upscale(image_permuted, new_width, new_height, interpolation, "disabled") - - pad_value = 0.0 if padding_color == "black" else 1.0 - padded = torch.full( - (batch_size, channels, target_height, target_width), - pad_value, - dtype=image.dtype, - device=image.device - ) - - y_offset = (target_height - new_height) // 2 - x_offset = (target_width - new_width) // 2 - - padded[:, :, y_offset:y_offset + new_height, x_offset:x_offset + new_width] = resized - - output = padded.permute(0, 2, 3, 1) - return IO.NodeOutput(output) - - resize_and_pad = execute # TODO: remove - - -class SaveSVGNode(IO.ComfyNode): - @classmethod - def define_schema(cls): - return IO.Schema( - node_id="SaveSVGNode", - search_aliases=["export vector", "save vector graphics"], - display_name="Save SVG", - description="Save SVG files on disk.", - category="image/save", - inputs=[ - IO.SVG.Input("svg"), - IO.String.Input( - "filename_prefix", - default="svg/ComfyUI", - tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes.", - ), - ], - hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], - is_output_node=True, - ) - - @classmethod - def execute(cls, svg: IO.SVG.Type, filename_prefix="svg/ComfyUI") -> IO.NodeOutput: - full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory()) - results: list[UI.SavedResult] = [] - - # Prepare metadata JSON - metadata_dict = {} - if cls.hidden.prompt is not None: - metadata_dict["prompt"] = cls.hidden.prompt - if cls.hidden.extra_pnginfo is not None: - metadata_dict.update(cls.hidden.extra_pnginfo) - - # Convert metadata to JSON string - metadata_json = json.dumps(metadata_dict, indent=2) if metadata_dict else None - - - for batch_number, svg_bytes in enumerate(svg.data): - filename_with_batch_num = filename.replace("%batch_num%", str(batch_number)) - file = f"{filename_with_batch_num}_{counter:05}_.svg" - - # Read SVG content - svg_bytes.seek(0) - svg_content = svg_bytes.read().decode('utf-8') - - # Inject metadata if available - if metadata_json: - # Create metadata element with CDATA section - metadata_element = f""" - - - """ - # Insert metadata after opening svg tag using regex with a replacement function - def replacement(match): - # match.group(1) contains the captured tag - return match.group(1) + '\n' + metadata_element - - # Apply the substitution - svg_content = re.sub(r'(]*>)', replacement, svg_content, flags=re.UNICODE) - - # Write the modified SVG to file - with open(os.path.join(full_output_folder, file), 'wb') as svg_file: - svg_file.write(svg_content.encode('utf-8')) - - results.append(UI.SavedResult(filename=file, subfolder=subfolder, type=IO.FolderType.output)) - counter += 1 - return IO.NodeOutput(ui={"images": results}) - - save_svg = execute # TODO: remove - - -class GetImageSize(IO.ComfyNode): - @classmethod - def define_schema(cls): - return IO.Schema( - node_id="GetImageSize", - search_aliases=["dimensions", "resolution", "image info"], - display_name="Get Image Size", - description="Returns width and height of the image, and passes it through unchanged.", - category="image", - inputs=[ - IO.Image.Input("image"), - ], - outputs=[ - IO.Int.Output(display_name="width"), - IO.Int.Output(display_name="height"), - IO.Int.Output(display_name="batch_size"), - ], - hidden=[IO.Hidden.unique_id], - ) - - @classmethod - def execute(cls, image) -> IO.NodeOutput: - height = image.shape[1] - width = image.shape[2] - batch_size = image.shape[0] - - # Send progress text to display size on the node - if cls.hidden.unique_id: - PromptServer.instance.send_progress_text(f"width: {width}, height: {height}\n batch size: {batch_size}", cls.hidden.unique_id) - - return IO.NodeOutput(width, height, batch_size) - - get_size = execute # TODO: remove - - -class ImageRotate(IO.ComfyNode): - @classmethod - def define_schema(cls): - return IO.Schema( - node_id="ImageRotate", - display_name="Rotate Image", - search_aliases=["turn", "flip orientation"], - category="image/transform", - essentials_category="Image Tools", - inputs=[ - IO.Image.Input("image"), - IO.Combo.Input("rotation", options=["none", "90 degrees", "180 degrees", "270 degrees"]), - ], - outputs=[IO.Image.Output()], - ) - - @classmethod - def execute(cls, image, rotation) -> IO.NodeOutput: - rotate_by = 0 - if rotation.startswith("90"): - rotate_by = 1 - elif rotation.startswith("180"): - rotate_by = 2 - elif rotation.startswith("270"): - rotate_by = 3 - - image = torch.rot90(image, k=rotate_by, dims=[2, 1]) - return IO.NodeOutput(image) - - rotate = execute # TODO: remove - - -class ImageFlip(IO.ComfyNode): - @classmethod - def define_schema(cls): - return IO.Schema( - node_id="ImageFlip", - search_aliases=["mirror", "reflect"], - display_name="Flip Image", - category="image/transform", - inputs=[ - IO.Image.Input("image"), - IO.Combo.Input("flip_method", options=["x-axis: vertically", "y-axis: horizontally"]), - ], - outputs=[IO.Image.Output()], - ) - - @classmethod - def execute(cls, image, flip_method) -> IO.NodeOutput: - if flip_method.startswith("x"): - image = torch.flip(image, dims=[1]) - elif flip_method.startswith("y"): - image = torch.flip(image, dims=[2]) - - return IO.NodeOutput(image) - - flip = execute # TODO: remove - - -class ImageScaleToMaxDimension(IO.ComfyNode): - - @classmethod - def define_schema(cls): - return IO.Schema( - node_id="ImageScaleToMaxDimension", - display_name="Scale Image to Max Dimension", - category="image/upscaling", - inputs=[ - IO.Image.Input("image"), - IO.Combo.Input( - "upscale_method", - options=["area", "lanczos", "bilinear", "nearest-exact", "bilinear", "bicubic"], - ), - IO.Int.Input("largest_size", default=512, min=0, max=MAX_RESOLUTION, step=1), - ], - outputs=[IO.Image.Output()], - ) - - @classmethod - def execute(cls, image, upscale_method, largest_size) -> IO.NodeOutput: - height = image.shape[1] - width = image.shape[2] - - if height > width: - width = round((width / height) * largest_size) - height = largest_size - elif width > height: - height = round((height / width) * largest_size) - width = largest_size - else: - height = largest_size - width = largest_size - - samples = image.movedim(-1, 1) - s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled") - s = s.movedim(1, -1) - return IO.NodeOutput(s) - - upscale = execute # TODO: remove - - -class SplitImageToTileList(IO.ComfyNode): - @classmethod - def define_schema(cls): - return IO.Schema( - node_id="SplitImageToTileList", - category="image/batch", - search_aliases=["split image", "tile image", "slice image"], - display_name="Split Image into List of Tiles", - description="Splits an image into a batched list of tiles with a specified overlap.", - inputs=[ - IO.Image.Input("image"), - IO.Int.Input("tile_width", default=1024, min=64, max=MAX_RESOLUTION), - IO.Int.Input("tile_height", default=1024, min=64, max=MAX_RESOLUTION), - IO.Int.Input("overlap", default=128, min=0, max=4096), - ], - outputs=[ - IO.Image.Output(is_output_list=True), - ], - ) - - @staticmethod - def get_grid_coords(width, height, tile_width, tile_height, overlap): - coords = [] - stride_x = round(max(tile_width * 0.25, tile_width - overlap)) - stride_y = round(max(tile_height * 0.25, tile_height - overlap)) - - y = 0 - while y < height: - x = 0 - y_end = min(y + tile_height, height) - y_start = max(0, y_end - tile_height) - - while x < width: - x_end = min(x + tile_width, width) - x_start = max(0, x_end - tile_width) - - coords.append((x_start, y_start, x_end, y_end)) - - if x_end >= width: - break - x += stride_x - - if y_end >= height: - break - y += stride_y - - return coords - - @classmethod - def execute(cls, image, tile_width, tile_height, overlap): - b, h, w, c = image.shape - coords = cls.get_grid_coords(w, h, tile_width, tile_height, overlap) - - output_list = [] - for (x_start, y_start, x_end, y_end) in coords: - tile = image[:, y_start:y_end, x_start:x_end, :] - output_list.append(tile) - - return IO.NodeOutput(output_list) - - -class ImageMergeTileList(IO.ComfyNode): - @classmethod - def define_schema(cls): - return IO.Schema( - node_id="ImageMergeTileList", - display_name="Merge List of Tiles to Image", - category="image/batch", - search_aliases=["split image", "tile image", "slice image"], - is_input_list=True, - inputs=[ - IO.Image.Input("image_list"), - IO.Int.Input("final_width", default=1024, min=64, max=32768), - IO.Int.Input("final_height", default=1024, min=64, max=32768), - IO.Int.Input("overlap", default=128, min=0, max=4096), - ], - outputs=[ - IO.Image.Output(is_output_list=False), - ], - ) - - @classmethod - def execute(cls, image_list, final_width, final_height, overlap): - w = final_width[0] - h = final_height[0] - ovlp = overlap[0] - feather_str = 1.0 - - first_tile = image_list[0] - b, t_h, t_w, c = first_tile.shape - device = first_tile.device - dtype = first_tile.dtype - - coords = SplitImageToTileList.get_grid_coords(w, h, t_w, t_h, ovlp) - - canvas = torch.zeros((b, h, w, c), device=device, dtype=dtype) - weights = torch.zeros((b, h, w, 1), device=device, dtype=dtype) - - if ovlp > 0: - y_w = torch.sin(math.pi * torch.linspace(0, 1, t_h, device=device, dtype=dtype)) - x_w = torch.sin(math.pi * torch.linspace(0, 1, t_w, device=device, dtype=dtype)) - y_w = torch.clamp(y_w, min=1e-5) - x_w = torch.clamp(x_w, min=1e-5) - - sine_mask = (y_w.unsqueeze(1) * x_w.unsqueeze(0)).unsqueeze(0).unsqueeze(-1) - flat_mask = torch.ones_like(sine_mask) - - weight_mask = torch.lerp(flat_mask, sine_mask, feather_str) - else: - weight_mask = torch.ones((1, t_h, t_w, 1), device=device, dtype=dtype) - - for i, (x_start, y_start, x_end, y_end) in enumerate(coords): - if i >= len(image_list): - break - - tile = image_list[i] - - region_h = y_end - y_start - region_w = x_end - x_start - - real_h = min(region_h, tile.shape[1]) - real_w = min(region_w, tile.shape[2]) - - y_end_actual = y_start + real_h - x_end_actual = x_start + real_w - - tile_crop = tile[:, :real_h, :real_w, :] - mask_crop = weight_mask[:, :real_h, :real_w, :] - - canvas[:, y_start:y_end_actual, x_start:x_end_actual, :] += tile_crop * mask_crop - weights[:, y_start:y_end_actual, x_start:x_end_actual, :] += mask_crop - - weights[weights == 0] = 1.0 - merged_image = canvas / weights - - return IO.NodeOutput(merged_image) - - -# --------------------------------------------------------------------------- -# Format specifications -# --------------------------------------------------------------------------- - -# Maps (file_format, bit_depth, has_alpha) -> (numpy dtype scale, av pixel format, -# stream pix_fmt). Keeps the encode path declarative instead of branchy. -_FORMAT_SPECS = { - ("png", "8-bit", False): {"scale": 255.0, "dtype": np.uint8, "frame_fmt": "rgb24", "stream_fmt": "rgb24"}, - ("png", "8-bit", True): {"scale": 255.0, "dtype": np.uint8, "frame_fmt": "rgba", "stream_fmt": "rgba"}, - ("png", "16-bit", False): {"scale": 65535.0, "dtype": np.uint16, "frame_fmt": "rgb48le", "stream_fmt": "rgb48be"}, - ("png", "16-bit", True): {"scale": 65535.0, "dtype": np.uint16, "frame_fmt": "rgba64le", "stream_fmt": "rgba64be"}, - ("exr", "32-bit float", False): {"scale": 1.0, "dtype": np.float32, "frame_fmt": "gbrpf32le", "stream_fmt": "gbrpf32le"}, - ("exr", "32-bit float", True): {"scale": 1.0, "dtype": np.float32, "frame_fmt": "gbrapf32le", "stream_fmt": "gbrapf32le"}, -} - - -# --------------------------------------------------------------------------- -# Color transforms -# --------------------------------------------------------------------------- - -def srgb_to_linear(t: torch.Tensor) -> torch.Tensor: - """Inverse sRGB EOTF (IEC 61966-2-1). Operates on RGB channels only; - alpha (if present as the 4th channel) is passed through unchanged.""" - if t.shape[-1] == 4: - rgb, alpha = t[..., :3], t[..., 3:] - return torch.cat([srgb_to_linear(rgb), alpha], dim=-1) - - # Piecewise: linear toe below 0.04045, gamma curve above. - low = t / 12.92 - high = ((t.clamp(min=0.0) + 0.055) / 1.055) ** 2.4 - return torch.where(t <= 0.04045, low, high) - - -# HLG OETF constants from BT.2100 Table 5. -_HLG_A = 0.17883277 -_HLG_B = 0.28466892 -_HLG_C = 0.55991072928 # = 0.5 - a*ln(4*a) - - -def hlg_to_linear(t: torch.Tensor) -> torch.Tensor: - """Inverse HLG OETF (BT.2100). Maps a non-linear HLG signal in [0, 1] to - *scene*-linear light in [0, 1]. Per BT.2100 Note 5a, this is the correct - transform when converting HLG to a linear scene-light representation - (rather than display-light, which would also involve the HLG OOTF). - - Operates on RGB channels only; alpha is passed through unchanged.""" - if t.shape[-1] == 4: - rgb, alpha = t[..., :3], t[..., 3:] - return torch.cat([hlg_to_linear(rgb), alpha], dim=-1) - - # Piecewise: sqrt branch below 0.5, log branch above. - # Clamp inside the log branch so negative / out-of-range values don't blow up; - # values above 1.0 are allowed and extrapolate naturally. - low = (t ** 2) / 3.0 - high = (torch.exp((t.clamp(min=_HLG_C) - _HLG_C) / _HLG_A) + _HLG_B) / 12.0 - return torch.where(t <= 0.5, low, high) - - -# --------------------------------------------------------------------------- -# Metadata injection -# --------------------------------------------------------------------------- - -_PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n" - - -def _png_chunk(chunk_type: bytes, data: bytes) -> bytes: - """Build a single PNG chunk: length | type | data | CRC32(type+data).""" - crc = zlib.crc32(chunk_type + data) & 0xFFFFFFFF - return struct.pack(">I", len(data)) + chunk_type + data + struct.pack(">I", crc) - - -def _png_text_chunk(keyword: str, text: str) -> bytes: - """tEXt chunk: latin-1 keyword + NUL + latin-1 text.""" - payload = keyword.encode("latin-1") + b"\x00" + text.encode("latin-1", errors="replace") - return _png_chunk(b"tEXt", payload) - - -def inject_png_metadata(png_bytes: bytes, prompt: dict | None, extra_pnginfo: dict | None) -> bytes: - """Insert ComfyUI prompt/workflow as tEXt chunks right after IHDR.""" - if not png_bytes.startswith(_PNG_SIGNATURE): - return png_bytes - - chunks: list[bytes] = [] - if prompt is not None: - chunks.append(_png_text_chunk("prompt", json.dumps(prompt))) - if extra_pnginfo: - for key, value in extra_pnginfo.items(): - chunks.append(_png_text_chunk(key, json.dumps(value))) - if not chunks: - return png_bytes - - # IHDR is always the first chunk; insert ours immediately after it. - ihdr_length = struct.unpack(">I", png_bytes[8:12])[0] - ihdr_end = 8 + 8 + ihdr_length + 4 # signature + (len+type) + data + crc - return png_bytes[:ihdr_end] + b"".join(chunks) + png_bytes[ihdr_end:] - - -# Standard chromaticities (CIE 1931 xy) for the colorspaces this node writes. -# Each tuple is (Rx, Ry, Gx, Gy, Bx, By, Wx, Wy). All share D65 white point. -_CHROMATICITIES = { - # ITU-R BT.709 / sRGB primaries - "Rec.709": (0.6400, 0.3300, 0.3000, 0.6000, 0.1500, 0.0600, 0.3127, 0.3290), - # ITU-R BT.2020 (UHDTV / wide-gamut HDR) primaries - "Rec.2020": (0.7080, 0.2920, 0.1700, 0.7970, 0.1310, 0.0460, 0.3127, 0.3290), -} - - -def _pack_chromaticities(primaries: tuple) -> bytes: - """Serialize 8 chromaticity floats into the EXR `chromaticities` payload.""" - return struct.pack("<8f", *primaries) - - -def _exr_attribute(name: str, attr_type: str, value: bytes) -> bytes: - """Serialize one EXR header attribute: name\\0 type\\0 size:int32 value.""" - return ( - name.encode("utf-8") + b"\x00" - + attr_type.encode("utf-8") + b"\x00" - + struct.pack(" bytes: - """Insert ComfyUI metadata and color-space info into an EXR header. - - Color: EXR pixels are linear by convention. The standard way to describe - their RGB→XYZ relationship is the `chromaticities` attribute. We pick the - primaries that match what the user told us their input was: - - colorspace="sRGB" → Rec. 709 / sRGB primaries (D65) - colorspace="HDR" → Rec. 2020 / BT.2100 primaries (D65) - - Pixels are always converted to linear scene light upstream (sRGB EOTF - inverse for sRGB; HLG OETF inverse for HDR), so the file content is - scene-linear in the indicated gamut. OpenEXR has no standard transfer- - function attribute (the OpenEXR TSC has discussed adding one but it - doesn't exist), so we don't invent one — `chromaticities` plus the EXR - linear-by-convention rule fully specifies the color. - - Prompt/workflow: written as plain `string` attributes using the same keys - (`prompt`, `workflow`, ...) that Comfy uses for PNG tEXt chunks, so the - same readers can pull them out symmetrically. - - Implementation note: the chunk-offset table that follows the header stores - *absolute* byte offsets into the file. Inserting N bytes into the header - means every offset must be incremented by N or the file becomes unreadable. - """ - if len(exr_bytes) < 8 or exr_bytes[:4] != b"\x76\x2f\x31\x01": - return exr_bytes - - new_blob = b"" - if prompt is not None: - new_blob += _exr_attribute("prompt", "string", json.dumps(prompt).encode("utf-8")) - if extra_pnginfo: - for key, value in extra_pnginfo.items(): - new_blob += _exr_attribute(key, "string", json.dumps(value).encode("utf-8")) - if colorspace is not None: - # Map each colorspace option to the RGB primaries the linear pixels - # are now in. "sRGB" and "linear" both produce Rec. 709 linear; "HDR" - # (HLG-encoded Rec. 2020 input) produces Rec. 2020 linear. - primaries_name = { - "sRGB": "Rec.709", - "linear": "Rec.709", - "HDR": "Rec.2020", - }.get(colorspace, "Rec.709") - new_blob += _exr_attribute( - "chromaticities", - "chromaticities", - _pack_chromaticities(_CHROMATICITIES[primaries_name]), - ) - if not new_blob: - return exr_bytes - - # Walk header attributes to find the terminating null byte, and pick up - # dataWindow + compression so we know how many chunks the offset table has. - pos = 8 # past magic (4) + version (4) - data_window = None - compression = 0 - while pos < len(exr_bytes) and exr_bytes[pos] != 0: - name_end = exr_bytes.index(b"\x00", pos) - attr_name = exr_bytes[pos:name_end].decode("latin-1", errors="replace") - type_end = exr_bytes.index(b"\x00", name_end + 1) - attr_type = exr_bytes[name_end + 1:type_end].decode("latin-1", errors="replace") - size = struct.unpack(" bytes: - """Encode a single HxWxC tensor to PNG or EXR bytes in memory. - - For EXR the input is interpreted according to `colorspace` and converted - to scene-linear (EXR's convention) before writing: - - "sRGB" → input is sRGB-encoded Rec. 709; apply inverse sRGB EOTF. - "HDR" → input is HLG-encoded Rec. 2020 (BT.2100); apply inverse HLG - OETF to get scene-linear, per BT.2100 Note 5a. - "linear" → input is already scene-linear (Rec. 709 primaries); write - through unchanged. Use this for renderer/compositor output. - - For PNG, colorspace selection does not modify pixels — PNG is delivered - sRGB-encoded and there is no PNG path for wide-gamut HDR in this node. - """ - height, width, num_channels = img_tensor.shape - has_alpha = num_channels == 4 - - spec = _FORMAT_SPECS[(file_format, bit_depth, has_alpha)] - - if spec["dtype"] == np.float32: - # EXR path: preserve full range, no clamp. - if colorspace == "sRGB": - img_tensor = srgb_to_linear(img_tensor) - elif colorspace == "HDR": - img_tensor = hlg_to_linear(img_tensor) - img_np = img_tensor.cpu().numpy().astype(np.float32) - else: - # PNG path: quantize to integer range. - scaled = (img_tensor * spec["scale"]).clamp(0, spec["scale"]) - img_np = scaled.to(torch.int32).cpu().numpy().astype(spec["dtype"]) - - # Encode directly via CodecContext. PyAV's `image2` muxer does NOT write to - # BytesIO (it expects a real file path), so we bypass the container entirely. - # For single-frame PNG/EXR the raw codec output IS the file. - codec = av.CodecContext.create(file_format, "w") - codec.width = width - codec.height = height - codec.pix_fmt = spec["stream_fmt"] - codec.time_base = Fraction(1, 1) - - frame = av.VideoFrame.from_ndarray(img_np, format=spec["frame_fmt"]) - if spec["frame_fmt"] != spec["stream_fmt"]: - frame = frame.reformat(format=spec["stream_fmt"]) - frame.pts = 0 - frame.time_base = codec.time_base - - packets = list(codec.encode(frame)) + list(codec.encode(None)) # flush with None - return b"".join(bytes(p) for p in packets) - - -# --------------------------------------------------------------------------- -# Node -# --------------------------------------------------------------------------- - -class SaveImageAdvanced(IO.ComfyNode): - @classmethod - def define_schema(cls): - return IO.Schema( - node_id="SaveImageAdvanced", - search_aliases=["save", "save image", "export image", "output image", "write image"], - display_name="Save Image (Advanced)", - description="Saves the input images to your ComfyUI output directory.", - category="image", - essentials_category="Basics", - inputs=[ - IO.Image.Input("images", tooltip="The images to save."), - IO.String.Input( - "filename_prefix", - default="ComfyUI", - tooltip=( - "The prefix for the file to save. May include formatting tokens " - "such as %date:yyyy-MM-dd% or %Empty Latent Image.width%." - ), - ), - IO.DynamicCombo.Input( - "format", - options=[ - IO.DynamicCombo.Option("png", [ - IO.Combo.Input("bit_depth", options=["8-bit", "16-bit"], - default="8-bit", advanced=True), +from __future__ import annotations + +import nodes +import folder_paths + +import av +import json + +import os +import re +import math +import numpy as np +import struct +import torch + +import zlib +import comfy.utils +from fractions import Fraction + +from server import PromptServer +from comfy_api.latest import ComfyExtension, IO, UI +from comfy.cli_args import args +from typing_extensions import override + +SVG = IO.SVG.Type # TODO: temporary solution for backward compatibility, will be removed later. + +MAX_RESOLUTION = nodes.MAX_RESOLUTION + +class ImageCrop(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ImageCrop", + search_aliases=["trim"], + display_name="Crop Image (DEPRECATED)", + category="image/transform", + is_deprecated=True, + essentials_category="Image Tools", + inputs=[ + IO.Image.Input("image"), + IO.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + ], + outputs=[IO.Image.Output()], + ) + + @classmethod + def execute(cls, image, width, height, x, y) -> IO.NodeOutput: + x = min(x, image.shape[2] - 1) + y = min(y, image.shape[1] - 1) + to_x = width + x + to_y = height + y + img = image[:,y:to_y, x:to_x, :] + return IO.NodeOutput(img) + + crop = execute # TODO: remove + + +class ImageCropV2(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ImageCropV2", + search_aliases=["trim"], + display_name="Crop Image", + category="image/transform", + essentials_category="Image Tools", + has_intermediate_output=True, + inputs=[ + IO.Image.Input("image"), + IO.BoundingBox.Input("crop_region", component="ImageCrop"), + ], + outputs=[IO.Image.Output()], + ) + + @classmethod + def execute(cls, image, crop_region) -> IO.NodeOutput: + x = crop_region.get("x", 0) + y = crop_region.get("y", 0) + width = crop_region.get("width", 512) + height = crop_region.get("height", 512) + + x = min(x, image.shape[2] - 1) + y = min(y, image.shape[1] - 1) + to_x = width + x + to_y = height + y + img = image[:,y:to_y, x:to_x, :] + return IO.NodeOutput(img, ui=UI.PreviewImage(img)) + + +class BoundingBox(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="PrimitiveBoundingBox", + display_name="Bounding Box", + category="utils/primitive", + inputs=[ + IO.Int.Input("x", default=0, min=0, max=MAX_RESOLUTION), + IO.Int.Input("y", default=0, min=0, max=MAX_RESOLUTION), + IO.Int.Input("width", default=512, min=1, max=MAX_RESOLUTION), + IO.Int.Input("height", default=512, min=1, max=MAX_RESOLUTION), + ], + outputs=[IO.BoundingBox.Output()], + ) + + @classmethod + def execute(cls, x, y, width, height) -> IO.NodeOutput: + return IO.NodeOutput({"x": x, "y": y, "width": width, "height": height}) + + +class RepeatImageBatch(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RepeatImageBatch", + search_aliases=["duplicate image", "clone image"], + display_name="Repeat Image Batch", + category="image/batch", + inputs=[ + IO.Image.Input("image"), + IO.Int.Input("amount", default=1, min=1, max=4096), + ], + outputs=[IO.Image.Output()], + ) + + @classmethod + def execute(cls, image, amount) -> IO.NodeOutput: + s = image.repeat((amount, 1,1,1)) + return IO.NodeOutput(s) + + repeat = execute # TODO: remove + + +class ImageFromBatch(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ImageFromBatch", + search_aliases=["select image", "pick from batch", "extract image"], + display_name="Get Image from Batch", + category="image/batch", + inputs=[ + IO.Image.Input("image"), + IO.Int.Input("batch_index", default=0, min=0, max=4095), + IO.Int.Input("length", default=1, min=1, max=4096), + ], + outputs=[IO.Image.Output()], + ) + + @classmethod + def execute(cls, image, batch_index, length) -> IO.NodeOutput: + s_in = image + batch_index = min(s_in.shape[0] - 1, batch_index) + length = min(s_in.shape[0] - batch_index, length) + s = s_in[batch_index:batch_index + length].clone() + return IO.NodeOutput(s) + + frombatch = execute # TODO: remove + + +class ImageAddNoise(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ImageAddNoise", + search_aliases=["film grain"], + display_name="Add Noise to Image", + category="image/postprocessing", + inputs=[ + IO.Image.Input("image"), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="The random seed used for creating the noise.", + ), + IO.Float.Input("strength", default=0.5, min=0.0, max=1.0, step=0.01), + ], + outputs=[IO.Image.Output()], + ) + + @classmethod + def execute(cls, image, seed, strength) -> IO.NodeOutput: + generator = torch.manual_seed(seed) + s = torch.clip((image + strength * torch.randn(image.size(), generator=generator, device="cpu").to(image)), min=0.0, max=1.0) + return IO.NodeOutput(s) + + repeat = execute # TODO: remove + + +class SaveAnimatedWEBP(IO.ComfyNode): + COMPRESS_METHODS = {"default": 4, "fastest": 0, "slowest": 6} + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="SaveAnimatedWEBP", + category="image/animation", + inputs=[ + IO.Image.Input("images"), + IO.String.Input("filename_prefix", default="ComfyUI"), + IO.Float.Input("fps", default=6.0, min=0.01, max=1000.0, step=0.01), + IO.Boolean.Input("lossless", default=True), + IO.Int.Input("quality", default=80, min=0, max=100), + IO.Combo.Input("method", options=list(cls.COMPRESS_METHODS.keys())), + # "num_frames": ("INT", {"default": 0, "min": 0, "max": 8192}), + ], + hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], + is_output_node=True, + ) + + @classmethod + def execute(cls, images, fps, filename_prefix, lossless, quality, method, num_frames=0) -> IO.NodeOutput: + return IO.NodeOutput( + ui=UI.ImageSaveHelper.get_save_animated_webp_ui( + images=images, + filename_prefix=filename_prefix, + cls=cls, + fps=fps, + lossless=lossless, + quality=quality, + method=cls.COMPRESS_METHODS.get(method) + ) + ) + + save_images = execute # TODO: remove + + +class SaveAnimatedPNG(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="SaveAnimatedPNG", + category="image/animation", + inputs=[ + IO.Image.Input("images"), + IO.String.Input("filename_prefix", default="ComfyUI"), + IO.Float.Input("fps", default=6.0, min=0.01, max=1000.0, step=0.01), + IO.Int.Input("compress_level", default=4, min=0, max=9, advanced=True), + ], + hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], + is_output_node=True, + ) + + @classmethod + def execute(cls, images, fps, compress_level, filename_prefix="ComfyUI") -> IO.NodeOutput: + return IO.NodeOutput( + ui=UI.ImageSaveHelper.get_save_animated_png_ui( + images=images, + filename_prefix=filename_prefix, + cls=cls, + fps=fps, + compress_level=compress_level, + ) + ) + + save_images = execute # TODO: remove + + +class ImageStitch(IO.ComfyNode): + """Upstreamed from https://github.com/kijai/ComfyUI-KJNodes""" + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ImageStitch", + search_aliases=["combine images", "join images", "concatenate images", "side by side"], + display_name="Stitch Images", + description="Stitches image2 to image1 in the specified direction.\n" + "If image2 is not provided, returns image1 unchanged.\n" + "Optional spacing can be added between images.", + category="image/transform", + inputs=[ + IO.Image.Input("image1"), + IO.Combo.Input("direction", options=["right", "down", "left", "up"], default="right"), + IO.Boolean.Input("match_image_size", default=True), + IO.Int.Input("spacing_width", default=0, min=0, max=1024, step=2, advanced=True), + IO.Combo.Input("spacing_color", options=["white", "black", "red", "green", "blue"], default="white", advanced=True), + IO.Image.Input("image2", optional=True), + ], + outputs=[IO.Image.Output()], + ) + + @classmethod + def execute( + cls, + image1, + direction, + match_image_size, + spacing_width, + spacing_color, + image2=None, + ) -> IO.NodeOutput: + if image2 is None: + return IO.NodeOutput(image1) + + # Handle batch size differences + if image1.shape[0] != image2.shape[0]: + max_batch = max(image1.shape[0], image2.shape[0]) + if image1.shape[0] < max_batch: + image1 = torch.cat( + [image1, image1[-1:].repeat(max_batch - image1.shape[0], 1, 1, 1)] + ) + if image2.shape[0] < max_batch: + image2 = torch.cat( + [image2, image2[-1:].repeat(max_batch - image2.shape[0], 1, 1, 1)] + ) + + # Match image sizes if requested + if match_image_size: + h1, w1 = image1.shape[1:3] + h2, w2 = image2.shape[1:3] + aspect_ratio = w2 / h2 + + if direction in ["left", "right"]: + target_h, target_w = h1, int(h1 * aspect_ratio) + else: # up, down + target_w, target_h = w1, int(w1 / aspect_ratio) + + image2 = comfy.utils.common_upscale( + image2.movedim(-1, 1), target_w, target_h, "lanczos", "disabled" + ).movedim(1, -1) + + color_map = { + "white": 1.0, + "black": 0.0, + "red": (1.0, 0.0, 0.0), + "green": (0.0, 1.0, 0.0), + "blue": (0.0, 0.0, 1.0), + } + + color_val = color_map[spacing_color] + + # When not matching sizes, pad to align non-concat dimensions + if not match_image_size: + h1, w1 = image1.shape[1:3] + h2, w2 = image2.shape[1:3] + pad_value = 0.0 + if not isinstance(color_val, tuple): + pad_value = color_val + + if direction in ["left", "right"]: + # For horizontal concat, pad heights to match + if h1 != h2: + target_h = max(h1, h2) + if h1 < target_h: + pad_h = target_h - h1 + pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2 + image1 = torch.nn.functional.pad(image1, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=pad_value) + if h2 < target_h: + pad_h = target_h - h2 + pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2 + image2 = torch.nn.functional.pad(image2, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=pad_value) + else: # up, down + # For vertical concat, pad widths to match + if w1 != w2: + target_w = max(w1, w2) + if w1 < target_w: + pad_w = target_w - w1 + pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2 + image1 = torch.nn.functional.pad(image1, (0, 0, pad_left, pad_right), mode='constant', value=pad_value) + if w2 < target_w: + pad_w = target_w - w2 + pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2 + image2 = torch.nn.functional.pad(image2, (0, 0, pad_left, pad_right), mode='constant', value=pad_value) + + # Ensure same number of channels + if image1.shape[-1] != image2.shape[-1]: + max_channels = max(image1.shape[-1], image2.shape[-1]) + if image1.shape[-1] < max_channels: + image1 = torch.cat( + [ + image1, + torch.ones( + *image1.shape[:-1], + max_channels - image1.shape[-1], + device=image1.device, + ), + ], + dim=-1, + ) + if image2.shape[-1] < max_channels: + image2 = torch.cat( + [ + image2, + torch.ones( + *image2.shape[:-1], + max_channels - image2.shape[-1], + device=image2.device, + ), + ], + dim=-1, + ) + + # Add spacing if specified + if spacing_width > 0: + spacing_width = spacing_width + (spacing_width % 2) # Ensure even + + if direction in ["left", "right"]: + spacing_shape = ( + image1.shape[0], + max(image1.shape[1], image2.shape[1]), + spacing_width, + image1.shape[-1], + ) + else: + spacing_shape = ( + image1.shape[0], + spacing_width, + max(image1.shape[2], image2.shape[2]), + image1.shape[-1], + ) + + spacing = torch.full(spacing_shape, 0.0, device=image1.device) + if isinstance(color_val, tuple): + for i, c in enumerate(color_val): + if i < spacing.shape[-1]: + spacing[..., i] = c + if spacing.shape[-1] == 4: # Add alpha + spacing[..., 3] = 1.0 + else: + spacing[..., : min(3, spacing.shape[-1])] = color_val + if spacing.shape[-1] == 4: + spacing[..., 3] = 1.0 + + # Concatenate images + images = [image2, image1] if direction in ["left", "up"] else [image1, image2] + if spacing_width > 0: + images.insert(1, spacing) + + concat_dim = 2 if direction in ["left", "right"] else 1 + return IO.NodeOutput(torch.cat(images, dim=concat_dim)) + + stitch = execute # TODO: remove + + +class ResizeAndPadImage(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ResizeAndPadImage", + search_aliases=["fit to size"], + display_name="Resize And Pad Image", + category="image/transform", + inputs=[ + IO.Image.Input("image"), + IO.Int.Input("target_width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("target_height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1), + IO.Combo.Input("padding_color", options=["white", "black"], advanced=True), + IO.Combo.Input("interpolation", options=["area", "bicubic", "nearest-exact", "bilinear", "lanczos"], advanced=True), + ], + outputs=[IO.Image.Output()], + ) + + @classmethod + def execute(cls, image, target_width, target_height, padding_color, interpolation) -> IO.NodeOutput: + batch_size, orig_height, orig_width, channels = image.shape + + scale_w = target_width / orig_width + scale_h = target_height / orig_height + scale = min(scale_w, scale_h) + + new_width = int(orig_width * scale) + new_height = int(orig_height * scale) + + image_permuted = image.permute(0, 3, 1, 2) + + resized = comfy.utils.common_upscale(image_permuted, new_width, new_height, interpolation, "disabled") + + pad_value = 0.0 if padding_color == "black" else 1.0 + padded = torch.full( + (batch_size, channels, target_height, target_width), + pad_value, + dtype=image.dtype, + device=image.device + ) + + y_offset = (target_height - new_height) // 2 + x_offset = (target_width - new_width) // 2 + + padded[:, :, y_offset:y_offset + new_height, x_offset:x_offset + new_width] = resized + + output = padded.permute(0, 2, 3, 1) + return IO.NodeOutput(output) + + resize_and_pad = execute # TODO: remove + + +class SaveSVGNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="SaveSVGNode", + search_aliases=["export vector", "save vector graphics"], + display_name="Save SVG", + description="Save SVG files on disk.", + category="image/save", + inputs=[ + IO.SVG.Input("svg"), + IO.String.Input( + "filename_prefix", + default="svg/ComfyUI", + tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes.", + ), + ], + hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], + is_output_node=True, + ) + + @classmethod + def execute(cls, svg: IO.SVG.Type, filename_prefix="svg/ComfyUI") -> IO.NodeOutput: + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory()) + results: list[UI.SavedResult] = [] + + # Prepare metadata JSON + metadata_dict = {} + if cls.hidden.prompt is not None: + metadata_dict["prompt"] = cls.hidden.prompt + if cls.hidden.extra_pnginfo is not None: + metadata_dict.update(cls.hidden.extra_pnginfo) + + # Convert metadata to JSON string + metadata_json = json.dumps(metadata_dict, indent=2) if metadata_dict else None + + + for batch_number, svg_bytes in enumerate(svg.data): + filename_with_batch_num = filename.replace("%batch_num%", str(batch_number)) + file = f"{filename_with_batch_num}_{counter:05}_.svg" + + # Read SVG content + svg_bytes.seek(0) + svg_content = svg_bytes.read().decode('utf-8') + + # Inject metadata if available + if metadata_json: + # Create metadata element with CDATA section + metadata_element = f""" + + {metadata_json} + + + """ + # Insert metadata after opening svg tag using regex with a replacement function + def replacement(match): + # match.group(1) contains the captured tag + return match.group(1) + '\n' + metadata_element + + # Apply the substitution + svg_content = re.sub(r'(]*>)', replacement, svg_content, flags=re.UNICODE) + + # Write the modified SVG to file + with open(os.path.join(full_output_folder, file), 'wb') as svg_file: + svg_file.write(svg_content.encode('utf-8')) + + results.append(UI.SavedResult(filename=file, subfolder=subfolder, type=IO.FolderType.output)) + counter += 1 + return IO.NodeOutput(ui={"images": results}) + + save_svg = execute # TODO: remove + + +class GetImageSize(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="GetImageSize", + search_aliases=["dimensions", "resolution", "image info"], + display_name="Get Image Size", + description="Returns width and height of the image, and passes it through unchanged.", + category="image", + inputs=[ + IO.Image.Input("image"), + ], + outputs=[ + IO.Int.Output(display_name="width"), + IO.Int.Output(display_name="height"), + IO.Int.Output(display_name="batch_size"), + ], + hidden=[IO.Hidden.unique_id], + ) + + @classmethod + def execute(cls, image) -> IO.NodeOutput: + height = image.shape[1] + width = image.shape[2] + batch_size = image.shape[0] + + # Send progress text to display size on the node + if cls.hidden.unique_id: + PromptServer.instance.send_progress_text(f"width: {width}, height: {height}\n batch size: {batch_size}", cls.hidden.unique_id) + + return IO.NodeOutput(width, height, batch_size) + + get_size = execute # TODO: remove + + +class ImageRotate(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ImageRotate", + display_name="Rotate Image", + search_aliases=["turn", "flip orientation"], + category="image/transform", + essentials_category="Image Tools", + inputs=[ + IO.Image.Input("image"), + IO.Combo.Input("rotation", options=["none", "90 degrees", "180 degrees", "270 degrees"]), + ], + outputs=[IO.Image.Output()], + ) + + @classmethod + def execute(cls, image, rotation) -> IO.NodeOutput: + rotate_by = 0 + if rotation.startswith("90"): + rotate_by = 1 + elif rotation.startswith("180"): + rotate_by = 2 + elif rotation.startswith("270"): + rotate_by = 3 + + image = torch.rot90(image, k=rotate_by, dims=[2, 1]) + return IO.NodeOutput(image) + + rotate = execute # TODO: remove + + +class ImageFlip(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ImageFlip", + search_aliases=["mirror", "reflect"], + display_name="Flip Image", + category="image/transform", + inputs=[ + IO.Image.Input("image"), + IO.Combo.Input("flip_method", options=["x-axis: vertically", "y-axis: horizontally"]), + ], + outputs=[IO.Image.Output()], + ) + + @classmethod + def execute(cls, image, flip_method) -> IO.NodeOutput: + if flip_method.startswith("x"): + image = torch.flip(image, dims=[1]) + elif flip_method.startswith("y"): + image = torch.flip(image, dims=[2]) + + return IO.NodeOutput(image) + + flip = execute # TODO: remove + + +class ImageScaleToMaxDimension(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ImageScaleToMaxDimension", + display_name="Scale Image to Max Dimension", + category="image/upscaling", + inputs=[ + IO.Image.Input("image"), + IO.Combo.Input( + "upscale_method", + options=["area", "lanczos", "bilinear", "nearest-exact", "bilinear", "bicubic"], + ), + IO.Int.Input("largest_size", default=512, min=0, max=MAX_RESOLUTION, step=1), + ], + outputs=[IO.Image.Output()], + ) + + @classmethod + def execute(cls, image, upscale_method, largest_size) -> IO.NodeOutput: + height = image.shape[1] + width = image.shape[2] + + if height > width: + width = round((width / height) * largest_size) + height = largest_size + elif width > height: + height = round((height / width) * largest_size) + width = largest_size + else: + height = largest_size + width = largest_size + + samples = image.movedim(-1, 1) + s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled") + s = s.movedim(1, -1) + return IO.NodeOutput(s) + + upscale = execute # TODO: remove + + +class SplitImageToTileList(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="SplitImageToTileList", + category="image/batch", + search_aliases=["split image", "tile image", "slice image"], + display_name="Split Image into List of Tiles", + description="Splits an image into a batched list of tiles with a specified overlap.", + inputs=[ + IO.Image.Input("image"), + IO.Int.Input("tile_width", default=1024, min=64, max=MAX_RESOLUTION), + IO.Int.Input("tile_height", default=1024, min=64, max=MAX_RESOLUTION), + IO.Int.Input("overlap", default=128, min=0, max=4096), + ], + outputs=[ + IO.Image.Output(is_output_list=True), + ], + ) + + @staticmethod + def get_grid_coords(width, height, tile_width, tile_height, overlap): + coords = [] + stride_x = round(max(tile_width * 0.25, tile_width - overlap)) + stride_y = round(max(tile_height * 0.25, tile_height - overlap)) + + y = 0 + while y < height: + x = 0 + y_end = min(y + tile_height, height) + y_start = max(0, y_end - tile_height) + + while x < width: + x_end = min(x + tile_width, width) + x_start = max(0, x_end - tile_width) + + coords.append((x_start, y_start, x_end, y_end)) + + if x_end >= width: + break + x += stride_x + + if y_end >= height: + break + y += stride_y + + return coords + + @classmethod + def execute(cls, image, tile_width, tile_height, overlap): + b, h, w, c = image.shape + coords = cls.get_grid_coords(w, h, tile_width, tile_height, overlap) + + output_list = [] + for (x_start, y_start, x_end, y_end) in coords: + tile = image[:, y_start:y_end, x_start:x_end, :] + output_list.append(tile) + + return IO.NodeOutput(output_list) + + +class ImageMergeTileList(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ImageMergeTileList", + display_name="Merge List of Tiles to Image", + category="image/batch", + search_aliases=["split image", "tile image", "slice image"], + is_input_list=True, + inputs=[ + IO.Image.Input("image_list"), + IO.Int.Input("final_width", default=1024, min=64, max=32768), + IO.Int.Input("final_height", default=1024, min=64, max=32768), + IO.Int.Input("overlap", default=128, min=0, max=4096), + ], + outputs=[ + IO.Image.Output(is_output_list=False), + ], + ) + + @classmethod + def execute(cls, image_list, final_width, final_height, overlap): + w = final_width[0] + h = final_height[0] + ovlp = overlap[0] + feather_str = 1.0 + + first_tile = image_list[0] + b, t_h, t_w, c = first_tile.shape + device = first_tile.device + dtype = first_tile.dtype + + coords = SplitImageToTileList.get_grid_coords(w, h, t_w, t_h, ovlp) + + canvas = torch.zeros((b, h, w, c), device=device, dtype=dtype) + weights = torch.zeros((b, h, w, 1), device=device, dtype=dtype) + + if ovlp > 0: + y_w = torch.sin(math.pi * torch.linspace(0, 1, t_h, device=device, dtype=dtype)) + x_w = torch.sin(math.pi * torch.linspace(0, 1, t_w, device=device, dtype=dtype)) + y_w = torch.clamp(y_w, min=1e-5) + x_w = torch.clamp(x_w, min=1e-5) + + sine_mask = (y_w.unsqueeze(1) * x_w.unsqueeze(0)).unsqueeze(0).unsqueeze(-1) + flat_mask = torch.ones_like(sine_mask) + + weight_mask = torch.lerp(flat_mask, sine_mask, feather_str) + else: + weight_mask = torch.ones((1, t_h, t_w, 1), device=device, dtype=dtype) + + for i, (x_start, y_start, x_end, y_end) in enumerate(coords): + if i >= len(image_list): + break + + tile = image_list[i] + + region_h = y_end - y_start + region_w = x_end - x_start + + real_h = min(region_h, tile.shape[1]) + real_w = min(region_w, tile.shape[2]) + + y_end_actual = y_start + real_h + x_end_actual = x_start + real_w + + tile_crop = tile[:, :real_h, :real_w, :] + mask_crop = weight_mask[:, :real_h, :real_w, :] + + canvas[:, y_start:y_end_actual, x_start:x_end_actual, :] += tile_crop * mask_crop + weights[:, y_start:y_end_actual, x_start:x_end_actual, :] += mask_crop + + weights[weights == 0] = 1.0 + merged_image = canvas / weights + + return IO.NodeOutput(merged_image) + + +# --------------------------------------------------------------------------- +# Format specifications +# --------------------------------------------------------------------------- + +# Maps (file_format, bit_depth, has_alpha) -> (numpy dtype scale, av pixel format, +# stream pix_fmt). Keeps the encode path declarative instead of branchy. +_FORMAT_SPECS = { + ("png", "8-bit", False): {"scale": 255.0, "dtype": np.uint8, "frame_fmt": "rgb24", "stream_fmt": "rgb24"}, + ("png", "8-bit", True): {"scale": 255.0, "dtype": np.uint8, "frame_fmt": "rgba", "stream_fmt": "rgba"}, + ("png", "16-bit", False): {"scale": 65535.0, "dtype": np.uint16, "frame_fmt": "rgb48le", "stream_fmt": "rgb48be"}, + ("png", "16-bit", True): {"scale": 65535.0, "dtype": np.uint16, "frame_fmt": "rgba64le", "stream_fmt": "rgba64be"}, + ("exr", "32-bit float", False): {"scale": 1.0, "dtype": np.float32, "frame_fmt": "gbrpf32le", "stream_fmt": "gbrpf32le"}, + ("exr", "32-bit float", True): {"scale": 1.0, "dtype": np.float32, "frame_fmt": "gbrapf32le", "stream_fmt": "gbrapf32le"}, +} + + +# --------------------------------------------------------------------------- +# Color transforms +# --------------------------------------------------------------------------- + +def srgb_to_linear(t: torch.Tensor) -> torch.Tensor: + """Inverse sRGB EOTF (IEC 61966-2-1). Operates on RGB channels only; + alpha (if present as the 4th channel) is passed through unchanged.""" + if t.shape[-1] == 4: + rgb, alpha = t[..., :3], t[..., 3:] + return torch.cat([srgb_to_linear(rgb), alpha], dim=-1) + + # Piecewise: linear toe below 0.04045, gamma curve above. + low = t / 12.92 + high = ((t.clamp(min=0.0) + 0.055) / 1.055) ** 2.4 + return torch.where(t <= 0.04045, low, high) + + +# HLG OETF constants from BT.2100 Table 5. +_HLG_A = 0.17883277 +_HLG_B = 0.28466892 +_HLG_C = 0.55991072928 # = 0.5 - a*ln(4*a) + + +def hlg_to_linear(t: torch.Tensor) -> torch.Tensor: + """Inverse HLG OETF (BT.2100). Maps a non-linear HLG signal in [0, 1] to + *scene*-linear light in [0, 1]. Per BT.2100 Note 5a, this is the correct + transform when converting HLG to a linear scene-light representation + (rather than display-light, which would also involve the HLG OOTF). + + Operates on RGB channels only; alpha is passed through unchanged.""" + if t.shape[-1] == 4: + rgb, alpha = t[..., :3], t[..., 3:] + return torch.cat([hlg_to_linear(rgb), alpha], dim=-1) + + # Piecewise: sqrt branch below 0.5, log branch above. + # Clamp inside the log branch so negative / out-of-range values don't blow up; + # values above 1.0 are allowed and extrapolate naturally. + low = (t ** 2) / 3.0 + high = (torch.exp((t.clamp(min=_HLG_C) - _HLG_C) / _HLG_A) + _HLG_B) / 12.0 + return torch.where(t <= 0.5, low, high) + + +# --------------------------------------------------------------------------- +# Metadata injection +# --------------------------------------------------------------------------- + +_PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n" + + +def _png_chunk(chunk_type: bytes, data: bytes) -> bytes: + """Build a single PNG chunk: length | type | data | CRC32(type+data).""" + crc = zlib.crc32(chunk_type + data) & 0xFFFFFFFF + return struct.pack(">I", len(data)) + chunk_type + data + struct.pack(">I", crc) + + +def _png_text_chunk(keyword: str, text: str) -> bytes: + """tEXt chunk: latin-1 keyword + NUL + latin-1 text.""" + payload = keyword.encode("latin-1") + b"\x00" + text.encode("latin-1", errors="replace") + return _png_chunk(b"tEXt", payload) + + +def inject_png_metadata(png_bytes: bytes, prompt: dict | None, extra_pnginfo: dict | None) -> bytes: + """Insert ComfyUI prompt/workflow as tEXt chunks right after IHDR.""" + if not png_bytes.startswith(_PNG_SIGNATURE): + return png_bytes + + chunks: list[bytes] = [] + if prompt is not None: + chunks.append(_png_text_chunk("prompt", json.dumps(prompt))) + if extra_pnginfo: + for key, value in extra_pnginfo.items(): + chunks.append(_png_text_chunk(key, json.dumps(value))) + if not chunks: + return png_bytes + + # IHDR is always the first chunk; insert ours immediately after it. + ihdr_length = struct.unpack(">I", png_bytes[8:12])[0] + ihdr_end = 8 + 8 + ihdr_length + 4 # signature + (len+type) + data + crc + return png_bytes[:ihdr_end] + b"".join(chunks) + png_bytes[ihdr_end:] + + +# Standard chromaticities (CIE 1931 xy) for the colorspaces this node writes. +# Each tuple is (Rx, Ry, Gx, Gy, Bx, By, Wx, Wy). All share D65 white point. +_CHROMATICITIES = { + # ITU-R BT.709 / sRGB primaries + "Rec.709": (0.6400, 0.3300, 0.3000, 0.6000, 0.1500, 0.0600, 0.3127, 0.3290), + # ITU-R BT.2020 (UHDTV / wide-gamut HDR) primaries + "Rec.2020": (0.7080, 0.2920, 0.1700, 0.7970, 0.1310, 0.0460, 0.3127, 0.3290), +} + + +def _pack_chromaticities(primaries: tuple) -> bytes: + """Serialize 8 chromaticity floats into the EXR `chromaticities` payload.""" + return struct.pack("<8f", *primaries) + + +def _exr_attribute(name: str, attr_type: str, value: bytes) -> bytes: + """Serialize one EXR header attribute: name\\0 type\\0 size:int32 value.""" + return ( + name.encode("utf-8") + b"\x00" + + attr_type.encode("utf-8") + b"\x00" + + struct.pack(" bytes: + """Insert ComfyUI metadata and color-space info into an EXR header. + + Color: EXR pixels are linear by convention. The standard way to describe + their RGB→XYZ relationship is the `chromaticities` attribute. We pick the + primaries that match what the user told us their input was: + + colorspace="sRGB" → Rec. 709 / sRGB primaries (D65) + colorspace="HDR" → Rec. 2020 / BT.2100 primaries (D65) + + Pixels are always converted to linear scene light upstream (sRGB EOTF + inverse for sRGB; HLG OETF inverse for HDR), so the file content is + scene-linear in the indicated gamut. OpenEXR has no standard transfer- + function attribute (the OpenEXR TSC has discussed adding one but it + doesn't exist), so we don't invent one — `chromaticities` plus the EXR + linear-by-convention rule fully specifies the color. + + Prompt/workflow: written as plain `string` attributes using the same keys + (`prompt`, `workflow`, ...) that Comfy uses for PNG tEXt chunks, so the + same readers can pull them out symmetrically. + + Implementation note: the chunk-offset table that follows the header stores + *absolute* byte offsets into the file. Inserting N bytes into the header + means every offset must be incremented by N or the file becomes unreadable. + """ + if len(exr_bytes) < 8 or exr_bytes[:4] != b"\x76\x2f\x31\x01": + return exr_bytes + + new_blob = b"" + if prompt is not None: + new_blob += _exr_attribute("prompt", "string", json.dumps(prompt).encode("utf-8")) + if extra_pnginfo: + for key, value in extra_pnginfo.items(): + new_blob += _exr_attribute(key, "string", json.dumps(value).encode("utf-8")) + if colorspace is not None: + # Map each colorspace option to the RGB primaries the linear pixels + # are now in. "sRGB" and "linear" both produce Rec. 709 linear; "HDR" + # (HLG-encoded Rec. 2020 input) produces Rec. 2020 linear. + primaries_name = { + "sRGB": "Rec.709", + "linear": "Rec.709", + "HDR": "Rec.2020", + }.get(colorspace, "Rec.709") + new_blob += _exr_attribute( + "chromaticities", + "chromaticities", + _pack_chromaticities(_CHROMATICITIES[primaries_name]), + ) + if not new_blob: + return exr_bytes + + # Walk header attributes to find the terminating null byte, and pick up + # dataWindow + compression so we know how many chunks the offset table has. + pos = 8 # past magic (4) + version (4) + data_window = None + compression = 0 + while pos < len(exr_bytes) and exr_bytes[pos] != 0: + name_end = exr_bytes.index(b"\x00", pos) + attr_name = exr_bytes[pos:name_end].decode("latin-1", errors="replace") + type_end = exr_bytes.index(b"\x00", name_end + 1) + attr_type = exr_bytes[name_end + 1:type_end].decode("latin-1", errors="replace") + size = struct.unpack(" bytes: + """Encode a single HxWxC tensor to PNG or EXR bytes in memory. + + For EXR the input is interpreted according to `colorspace` and converted + to scene-linear (EXR's convention) before writing: + + "sRGB" → input is sRGB-encoded Rec. 709; apply inverse sRGB EOTF. + "HDR" → input is HLG-encoded Rec. 2020 (BT.2100); apply inverse HLG + OETF to get scene-linear, per BT.2100 Note 5a. + "linear" → input is already scene-linear (Rec. 709 primaries); write + through unchanged. Use this for renderer/compositor output. + + For PNG, colorspace selection does not modify pixels — PNG is delivered + sRGB-encoded and there is no PNG path for wide-gamut HDR in this node. + """ + height, width, num_channels = img_tensor.shape + has_alpha = num_channels == 4 + + spec = _FORMAT_SPECS[(file_format, bit_depth, has_alpha)] + + if spec["dtype"] == np.float32: + # EXR path: preserve full range, no clamp. + if colorspace == "sRGB": + img_tensor = srgb_to_linear(img_tensor) + elif colorspace == "HDR": + img_tensor = hlg_to_linear(img_tensor) + img_np = img_tensor.cpu().numpy().astype(np.float32) + else: + # PNG path: quantize to integer range. + scaled = (img_tensor * spec["scale"]).clamp(0, spec["scale"]) + img_np = scaled.to(torch.int32).cpu().numpy().astype(spec["dtype"]) + + # Encode directly via CodecContext. PyAV's `image2` muxer does NOT write to + # BytesIO (it expects a real file path), so we bypass the container entirely. + # For single-frame PNG/EXR the raw codec output IS the file. + codec = av.CodecContext.create(file_format, "w") + codec.width = width + codec.height = height + codec.pix_fmt = spec["stream_fmt"] + codec.time_base = Fraction(1, 1) + + frame = av.VideoFrame.from_ndarray(img_np, format=spec["frame_fmt"]) + if spec["frame_fmt"] != spec["stream_fmt"]: + frame = frame.reformat(format=spec["stream_fmt"]) + frame.pts = 0 + frame.time_base = codec.time_base + + packets = list(codec.encode(frame)) + list(codec.encode(None)) # flush with None + return b"".join(bytes(p) for p in packets) + + +# --------------------------------------------------------------------------- +# Node +# --------------------------------------------------------------------------- + +class SaveImageAdvanced(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="SaveImageAdvanced", + search_aliases=["save", "save image", "export image", "output image", "write image"], + display_name="Save Image (Advanced)", + description="Saves the input images to your ComfyUI output directory.", + category="image", + essentials_category="Basics", + inputs=[ + IO.Image.Input("images", tooltip="The images to save."), + IO.String.Input( + "filename_prefix", + default="ComfyUI", + tooltip=( + "The prefix for the file to save. May include formatting tokens " + "such as %date:yyyy-MM-dd% or %Empty Latent Image.width%." + ), + ), + IO.DynamicCombo.Input( + "format", + options=[ + IO.DynamicCombo.Option("png", [ + IO.Combo.Input("bit_depth", options=["8-bit", "16-bit"], + default="8-bit", advanced=True), IO.Combo.Input("input_color_space", options=["sRGB"], - default="sRGB", advanced=True), - ]), - IO.DynamicCombo.Option("exr", [ - IO.Combo.Input("bit_depth", options=["32-bit float"], - default="32-bit float", advanced=True), - IO.Combo.Input( - "input_color_space", - options=["sRGB", "HDR", "linear"], - default="sRGB", - advanced=True, - tooltip=( - "Colorspace of the input tensor. The EXR is " - "always written as scene-linear in the matching " - "gamut.\n" - " 'sRGB' — input is sRGB-encoded Rec.709; " - "the inverse sRGB EOTF is applied.\n" - " 'HDR' — input is HLG-encoded Rec.2020 " - "(BT.2100); the inverse HLG OETF is applied " - "to get scene-linear light.\n" - " 'linear' — input is already scene-linear " - "(Rec.709 primaries); written through unchanged. " - "Use this for renderer/compositor output." - ), - ), - ]), - ], - tooltip="The file format in which to save the image.", - ), - ], - hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], - is_output_node=True, - ) - - @classmethod - def execute(cls, images, filename_prefix: str, image_format: dict) -> IO.NodeOutput: - file_format = image_format["image_format"] - bit_depth = image_format["bit_depth"] - colorspace = image_format.get("colorspace", "sRGB") - - output_dir = folder_paths.get_output_directory() - full_output_folder, filename, counter, subfolder, filename_prefix = ( - folder_paths.get_save_image_path( - filename_prefix, output_dir, images[0].shape[1], images[0].shape[0] - ) - ) - - prompt = cls.hidden.prompt - extra_pnginfo = cls.hidden.extra_pnginfo - write_metadata = not args.disable_metadata - - results = [] - for batch_number, image in enumerate(images): - encoded = _encode_image(image, file_format, bit_depth, colorspace) - - if write_metadata: - if file_format == "png": - encoded = inject_png_metadata(encoded, prompt, extra_pnginfo) - elif file_format == "exr": - encoded = inject_exr_metadata(encoded, prompt, extra_pnginfo, colorspace) - - name = filename.replace("%batch_num%", str(batch_number)) - file = f"{name}_{counter:05}.{file_format}" - with open(os.path.join(full_output_folder, file), "wb") as f: - f.write(encoded) - - results.append({"filename": file, "subfolder": subfolder, "type": "output"}) - counter += 1 - - return IO.NodeOutput(ui={"images": results}) - - -class ImagesExtension(ComfyExtension): - @override - async def get_node_list(self) -> list[type[IO.ComfyNode]]: - return [ - ImageCrop, - ImageCropV2, - BoundingBox, - RepeatImageBatch, - ImageFromBatch, - ImageAddNoise, - SaveAnimatedWEBP, - SaveAnimatedPNG, - SaveImageAdvanced, - SaveSVGNode, - ImageStitch, - ResizeAndPadImage, - GetImageSize, - ImageRotate, - ImageFlip, - ImageScaleToMaxDimension, - SplitImageToTileList, - ImageMergeTileList, - ] - - -async def comfy_entrypoint() -> ImagesExtension: - return ImagesExtension() + default="sRGB", advanced=True), + ]), + IO.DynamicCombo.Option("exr", [ + IO.Combo.Input("bit_depth", options=["32-bit float"], + default="32-bit float", advanced=True), + IO.Combo.Input( + "input_color_space", + options=["sRGB", "HDR", "linear"], + default="sRGB", + advanced=True, + tooltip=( + "Colorspace of the input tensor. The EXR is " + "always written as scene-linear in the matching " + "gamut.\n" + " 'sRGB' — input is sRGB-encoded Rec.709; " + "the inverse sRGB EOTF is applied.\n" + " 'HDR' — input is HLG-encoded Rec.2020 " + "(BT.2100); the inverse HLG OETF is applied " + "to get scene-linear light.\n" + " 'linear' — input is already scene-linear " + "(Rec.709 primaries); written through unchanged. " + "Use this for renderer/compositor output." + ), + ), + ]), + ], + tooltip="The file format in which to save the image.", + ), + ], + hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], + is_output_node=True, + ) + + @classmethod + def execute(cls, images, filename_prefix: str, format: dict) -> IO.NodeOutput: + file_format = format["format"] + bit_depth = format["bit_depth"] + colorspace = format.get("input_color_space", "sRGB") + + output_dir = folder_paths.get_output_directory() + full_output_folder, filename, counter, subfolder, filename_prefix = ( + folder_paths.get_save_image_path( + filename_prefix, output_dir, images[0].shape[1], images[0].shape[0] + ) + ) + + prompt = cls.hidden.prompt + extra_pnginfo = cls.hidden.extra_pnginfo + write_metadata = not args.disable_metadata + + results = [] + for batch_number, image in enumerate(images): + encoded = _encode_image(image, file_format, bit_depth, colorspace) + + if write_metadata: + if file_format == "png": + encoded = inject_png_metadata(encoded, prompt, extra_pnginfo) + elif file_format == "exr": + encoded = inject_exr_metadata(encoded, prompt, extra_pnginfo, colorspace) + + name = filename.replace("%batch_num%", str(batch_number)) + file = f"{name}_{counter:05}.{file_format}" + with open(os.path.join(full_output_folder, file), "wb") as f: + f.write(encoded) + + results.append({"filename": file, "subfolder": subfolder, "type": "output"}) + counter += 1 + + return IO.NodeOutput(ui={"images": results}) + + +class ImagesExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + ImageCrop, + ImageCropV2, + BoundingBox, + RepeatImageBatch, + ImageFromBatch, + ImageAddNoise, + SaveAnimatedWEBP, + SaveAnimatedPNG, + SaveImageAdvanced, + SaveSVGNode, + ImageStitch, + ResizeAndPadImage, + GetImageSize, + ImageRotate, + ImageFlip, + ImageScaleToMaxDimension, + SplitImageToTileList, + ImageMergeTileList, + ] + + +async def comfy_entrypoint() -> ImagesExtension: + return ImagesExtension()