from __future__ import annotations import nodes import folder_paths import av import json import os import re import math import numpy as np import struct import torch import zlib import comfy.utils from fractions import Fraction from server import PromptServer from comfy_api.latest import ComfyExtension, IO, UI from comfy.cli_args import args from typing_extensions import override SVG = IO.SVG.Type # TODO: temporary solution for backward compatibility, will be removed later. MAX_RESOLUTION = nodes.MAX_RESOLUTION class ImageCrop(IO.ComfyNode): @classmethod def define_schema(cls): return IO.Schema( node_id="ImageCrop", search_aliases=["trim"], display_name="Crop Image (DEPRECATED)", category="image/transform", is_deprecated=True, essentials_category="Image Tools", inputs=[ IO.Image.Input("image"), IO.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1), IO.Int.Input("height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1), IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), ], outputs=[IO.Image.Output()], ) @classmethod def execute(cls, image, width, height, x, y) -> IO.NodeOutput: x = min(x, image.shape[2] - 1) y = min(y, image.shape[1] - 1) to_x = width + x to_y = height + y img = image[:,y:to_y, x:to_x, :] return IO.NodeOutput(img) crop = execute # TODO: remove class ImageCropV2(IO.ComfyNode): @classmethod def define_schema(cls): return IO.Schema( node_id="ImageCropV2", search_aliases=["trim"], display_name="Crop Image", category="image/transform", essentials_category="Image Tools", has_intermediate_output=True, inputs=[ IO.Image.Input("image"), IO.BoundingBox.Input("crop_region", component="ImageCrop"), ], outputs=[IO.Image.Output()], ) @classmethod def execute(cls, image, crop_region) -> IO.NodeOutput: x = crop_region.get("x", 0) y = crop_region.get("y", 0) width = crop_region.get("width", 512) height = crop_region.get("height", 512) x = min(x, 
image.shape[2] - 1) y = min(y, image.shape[1] - 1) to_x = width + x to_y = height + y img = image[:,y:to_y, x:to_x, :] return IO.NodeOutput(img, ui=UI.PreviewImage(img)) class BoundingBox(IO.ComfyNode): @classmethod def define_schema(cls): return IO.Schema( node_id="PrimitiveBoundingBox", display_name="Bounding Box", category="utils/primitive", inputs=[ IO.Int.Input("x", default=0, min=0, max=MAX_RESOLUTION), IO.Int.Input("y", default=0, min=0, max=MAX_RESOLUTION), IO.Int.Input("width", default=512, min=1, max=MAX_RESOLUTION), IO.Int.Input("height", default=512, min=1, max=MAX_RESOLUTION), ], outputs=[IO.BoundingBox.Output()], ) @classmethod def execute(cls, x, y, width, height) -> IO.NodeOutput: return IO.NodeOutput({"x": x, "y": y, "width": width, "height": height}) class RepeatImageBatch(IO.ComfyNode): @classmethod def define_schema(cls): return IO.Schema( node_id="RepeatImageBatch", search_aliases=["duplicate image", "clone image"], display_name="Repeat Image Batch", category="image/batch", inputs=[ IO.Image.Input("image"), IO.Int.Input("amount", default=1, min=1, max=4096), ], outputs=[IO.Image.Output()], ) @classmethod def execute(cls, image, amount) -> IO.NodeOutput: s = image.repeat((amount, 1,1,1)) return IO.NodeOutput(s) repeat = execute # TODO: remove class ImageFromBatch(IO.ComfyNode): @classmethod def define_schema(cls): return IO.Schema( node_id="ImageFromBatch", search_aliases=["select image", "pick from batch", "extract image"], display_name="Get Image from Batch", category="image/batch", inputs=[ IO.Image.Input("image"), IO.Int.Input("batch_index", default=0, min=0, max=4095), IO.Int.Input("length", default=1, min=1, max=4096), ], outputs=[IO.Image.Output()], ) @classmethod def execute(cls, image, batch_index, length) -> IO.NodeOutput: s_in = image batch_index = min(s_in.shape[0] - 1, batch_index) length = min(s_in.shape[0] - batch_index, length) s = s_in[batch_index:batch_index + length].clone() return IO.NodeOutput(s) frombatch = execute # TODO: 
remove class ImageAddNoise(IO.ComfyNode): @classmethod def define_schema(cls): return IO.Schema( node_id="ImageAddNoise", search_aliases=["film grain"], display_name="Add Noise to Image", category="image/postprocessing", inputs=[ IO.Image.Input("image"), IO.Int.Input( "seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, control_after_generate=True, tooltip="The random seed used for creating the noise.", ), IO.Float.Input("strength", default=0.5, min=0.0, max=1.0, step=0.01), ], outputs=[IO.Image.Output()], ) @classmethod def execute(cls, image, seed, strength) -> IO.NodeOutput: generator = torch.manual_seed(seed) s = torch.clip((image + strength * torch.randn(image.size(), generator=generator, device="cpu").to(image)), min=0.0, max=1.0) return IO.NodeOutput(s) repeat = execute # TODO: remove class SaveAnimatedWEBP(IO.ComfyNode): COMPRESS_METHODS = {"default": 4, "fastest": 0, "slowest": 6} @classmethod def define_schema(cls): return IO.Schema( node_id="SaveAnimatedWEBP", category="image/animation", inputs=[ IO.Image.Input("images"), IO.String.Input("filename_prefix", default="ComfyUI"), IO.Float.Input("fps", default=6.0, min=0.01, max=1000.0, step=0.01), IO.Boolean.Input("lossless", default=True), IO.Int.Input("quality", default=80, min=0, max=100), IO.Combo.Input("method", options=list(cls.COMPRESS_METHODS.keys())), # "num_frames": ("INT", {"default": 0, "min": 0, "max": 8192}), ], hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], is_output_node=True, ) @classmethod def execute(cls, images, fps, filename_prefix, lossless, quality, method, num_frames=0) -> IO.NodeOutput: return IO.NodeOutput( ui=UI.ImageSaveHelper.get_save_animated_webp_ui( images=images, filename_prefix=filename_prefix, cls=cls, fps=fps, lossless=lossless, quality=quality, method=cls.COMPRESS_METHODS.get(method) ) ) save_images = execute # TODO: remove class SaveAnimatedPNG(IO.ComfyNode): @classmethod def define_schema(cls): return IO.Schema( node_id="SaveAnimatedPNG", category="image/animation", 
inputs=[ IO.Image.Input("images"), IO.String.Input("filename_prefix", default="ComfyUI"), IO.Float.Input("fps", default=6.0, min=0.01, max=1000.0, step=0.01), IO.Int.Input("compress_level", default=4, min=0, max=9, advanced=True), ], hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], is_output_node=True, ) @classmethod def execute(cls, images, fps, compress_level, filename_prefix="ComfyUI") -> IO.NodeOutput: return IO.NodeOutput( ui=UI.ImageSaveHelper.get_save_animated_png_ui( images=images, filename_prefix=filename_prefix, cls=cls, fps=fps, compress_level=compress_level, ) ) save_images = execute # TODO: remove class ImageStitch(IO.ComfyNode): """Upstreamed from https://github.com/kijai/ComfyUI-KJNodes""" @classmethod def define_schema(cls): return IO.Schema( node_id="ImageStitch", search_aliases=["combine images", "join images", "concatenate images", "side by side"], display_name="Stitch Images", description="Stitches image2 to image1 in the specified direction.\n" "If image2 is not provided, returns image1 unchanged.\n" "Optional spacing can be added between images.", category="image/transform", inputs=[ IO.Image.Input("image1"), IO.Combo.Input("direction", options=["right", "down", "left", "up"], default="right"), IO.Boolean.Input("match_image_size", default=True), IO.Int.Input("spacing_width", default=0, min=0, max=1024, step=2, advanced=True), IO.Combo.Input("spacing_color", options=["white", "black", "red", "green", "blue"], default="white", advanced=True), IO.Image.Input("image2", optional=True), ], outputs=[IO.Image.Output()], ) @classmethod def execute( cls, image1, direction, match_image_size, spacing_width, spacing_color, image2=None, ) -> IO.NodeOutput: if image2 is None: return IO.NodeOutput(image1) # Handle batch size differences if image1.shape[0] != image2.shape[0]: max_batch = max(image1.shape[0], image2.shape[0]) if image1.shape[0] < max_batch: image1 = torch.cat( [image1, image1[-1:].repeat(max_batch - image1.shape[0], 1, 1, 1)] ) if 
image2.shape[0] < max_batch: image2 = torch.cat( [image2, image2[-1:].repeat(max_batch - image2.shape[0], 1, 1, 1)] ) # Match image sizes if requested if match_image_size: h1, w1 = image1.shape[1:3] h2, w2 = image2.shape[1:3] aspect_ratio = w2 / h2 if direction in ["left", "right"]: target_h, target_w = h1, int(h1 * aspect_ratio) else: # up, down target_w, target_h = w1, int(w1 / aspect_ratio) image2 = comfy.utils.common_upscale( image2.movedim(-1, 1), target_w, target_h, "lanczos", "disabled" ).movedim(1, -1) color_map = { "white": 1.0, "black": 0.0, "red": (1.0, 0.0, 0.0), "green": (0.0, 1.0, 0.0), "blue": (0.0, 0.0, 1.0), } color_val = color_map[spacing_color] # When not matching sizes, pad to align non-concat dimensions if not match_image_size: h1, w1 = image1.shape[1:3] h2, w2 = image2.shape[1:3] pad_value = 0.0 if not isinstance(color_val, tuple): pad_value = color_val if direction in ["left", "right"]: # For horizontal concat, pad heights to match if h1 != h2: target_h = max(h1, h2) if h1 < target_h: pad_h = target_h - h1 pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2 image1 = torch.nn.functional.pad(image1, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=pad_value) if h2 < target_h: pad_h = target_h - h2 pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2 image2 = torch.nn.functional.pad(image2, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=pad_value) else: # up, down # For vertical concat, pad widths to match if w1 != w2: target_w = max(w1, w2) if w1 < target_w: pad_w = target_w - w1 pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2 image1 = torch.nn.functional.pad(image1, (0, 0, pad_left, pad_right), mode='constant', value=pad_value) if w2 < target_w: pad_w = target_w - w2 pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2 image2 = torch.nn.functional.pad(image2, (0, 0, pad_left, pad_right), mode='constant', value=pad_value) # Ensure same number of channels if image1.shape[-1] != image2.shape[-1]: max_channels = 
max(image1.shape[-1], image2.shape[-1]) if image1.shape[-1] < max_channels: image1 = torch.cat( [ image1, torch.ones( *image1.shape[:-1], max_channels - image1.shape[-1], device=image1.device, ), ], dim=-1, ) if image2.shape[-1] < max_channels: image2 = torch.cat( [ image2, torch.ones( *image2.shape[:-1], max_channels - image2.shape[-1], device=image2.device, ), ], dim=-1, ) # Add spacing if specified if spacing_width > 0: spacing_width = spacing_width + (spacing_width % 2) # Ensure even if direction in ["left", "right"]: spacing_shape = ( image1.shape[0], max(image1.shape[1], image2.shape[1]), spacing_width, image1.shape[-1], ) else: spacing_shape = ( image1.shape[0], spacing_width, max(image1.shape[2], image2.shape[2]), image1.shape[-1], ) spacing = torch.full(spacing_shape, 0.0, device=image1.device) if isinstance(color_val, tuple): for i, c in enumerate(color_val): if i < spacing.shape[-1]: spacing[..., i] = c if spacing.shape[-1] == 4: # Add alpha spacing[..., 3] = 1.0 else: spacing[..., : min(3, spacing.shape[-1])] = color_val if spacing.shape[-1] == 4: spacing[..., 3] = 1.0 # Concatenate images images = [image2, image1] if direction in ["left", "up"] else [image1, image2] if spacing_width > 0: images.insert(1, spacing) concat_dim = 2 if direction in ["left", "right"] else 1 return IO.NodeOutput(torch.cat(images, dim=concat_dim)) stitch = execute # TODO: remove class ResizeAndPadImage(IO.ComfyNode): @classmethod def define_schema(cls): return IO.Schema( node_id="ResizeAndPadImage", search_aliases=["fit to size"], display_name="Resize And Pad Image", category="image/transform", inputs=[ IO.Image.Input("image"), IO.Int.Input("target_width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1), IO.Int.Input("target_height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1), IO.Combo.Input("padding_color", options=["white", "black"], advanced=True), IO.Combo.Input("interpolation", options=["area", "bicubic", "nearest-exact", "bilinear", "lanczos"], 
advanced=True), ], outputs=[IO.Image.Output()], ) @classmethod def execute(cls, image, target_width, target_height, padding_color, interpolation) -> IO.NodeOutput: batch_size, orig_height, orig_width, channels = image.shape scale_w = target_width / orig_width scale_h = target_height / orig_height scale = min(scale_w, scale_h) new_width = int(orig_width * scale) new_height = int(orig_height * scale) image_permuted = image.permute(0, 3, 1, 2) resized = comfy.utils.common_upscale(image_permuted, new_width, new_height, interpolation, "disabled") pad_value = 0.0 if padding_color == "black" else 1.0 padded = torch.full( (batch_size, channels, target_height, target_width), pad_value, dtype=image.dtype, device=image.device ) y_offset = (target_height - new_height) // 2 x_offset = (target_width - new_width) // 2 padded[:, :, y_offset:y_offset + new_height, x_offset:x_offset + new_width] = resized output = padded.permute(0, 2, 3, 1) return IO.NodeOutput(output) resize_and_pad = execute # TODO: remove class SaveSVGNode(IO.ComfyNode): @classmethod def define_schema(cls): return IO.Schema( node_id="SaveSVGNode", search_aliases=["export vector", "save vector graphics"], display_name="Save SVG", description="Save SVG files on disk.", category="image/save", inputs=[ IO.SVG.Input("svg"), IO.String.Input( "filename_prefix", default="svg/ComfyUI", tooltip="The prefix for the file to save. 
This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes.", ), ], hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], is_output_node=True, ) @classmethod def execute(cls, svg: IO.SVG.Type, filename_prefix="svg/ComfyUI") -> IO.NodeOutput: full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory()) results: list[UI.SavedResult] = [] # Prepare metadata JSON metadata_dict = {} if cls.hidden.prompt is not None: metadata_dict["prompt"] = cls.hidden.prompt if cls.hidden.extra_pnginfo is not None: metadata_dict.update(cls.hidden.extra_pnginfo) # Convert metadata to JSON string metadata_json = json.dumps(metadata_dict, indent=2) if metadata_dict else None for batch_number, svg_bytes in enumerate(svg.data): filename_with_batch_num = filename.replace("%batch_num%", str(batch_number)) file = f"{filename_with_batch_num}_{counter:05}_.svg" # Read SVG content svg_bytes.seek(0) svg_content = svg_bytes.read().decode('utf-8') # Inject metadata if available if metadata_json: # Create metadata element with CDATA section metadata_element = f""" """ # Insert metadata after opening svg tag using regex with a replacement function def replacement(match): # match.group(1) contains the captured tag return match.group(1) + '\n' + metadata_element # Apply the substitution svg_content = re.sub(r'(]*>)', replacement, svg_content, flags=re.UNICODE) # Write the modified SVG to file with open(os.path.join(full_output_folder, file), 'wb') as svg_file: svg_file.write(svg_content.encode('utf-8')) results.append(UI.SavedResult(filename=file, subfolder=subfolder, type=IO.FolderType.output)) counter += 1 return IO.NodeOutput(ui={"images": results}) save_svg = execute # TODO: remove class GetImageSize(IO.ComfyNode): @classmethod def define_schema(cls): return IO.Schema( node_id="GetImageSize", search_aliases=["dimensions", "resolution", "image 
info"], display_name="Get Image Size", description="Returns width and height of the image, and passes it through unchanged.", category="image", inputs=[ IO.Image.Input("image"), ], outputs=[ IO.Int.Output(display_name="width"), IO.Int.Output(display_name="height"), IO.Int.Output(display_name="batch_size"), ], hidden=[IO.Hidden.unique_id], ) @classmethod def execute(cls, image) -> IO.NodeOutput: height = image.shape[1] width = image.shape[2] batch_size = image.shape[0] # Send progress text to display size on the node if cls.hidden.unique_id: PromptServer.instance.send_progress_text(f"width: {width}, height: {height}\n batch size: {batch_size}", cls.hidden.unique_id) return IO.NodeOutput(width, height, batch_size) get_size = execute # TODO: remove class ImageRotate(IO.ComfyNode): @classmethod def define_schema(cls): return IO.Schema( node_id="ImageRotate", display_name="Rotate Image", search_aliases=["turn", "flip orientation"], category="image/transform", essentials_category="Image Tools", inputs=[ IO.Image.Input("image"), IO.Combo.Input("rotation", options=["none", "90 degrees", "180 degrees", "270 degrees"]), ], outputs=[IO.Image.Output()], ) @classmethod def execute(cls, image, rotation) -> IO.NodeOutput: rotate_by = 0 if rotation.startswith("90"): rotate_by = 1 elif rotation.startswith("180"): rotate_by = 2 elif rotation.startswith("270"): rotate_by = 3 image = torch.rot90(image, k=rotate_by, dims=[2, 1]) return IO.NodeOutput(image) rotate = execute # TODO: remove class ImageFlip(IO.ComfyNode): @classmethod def define_schema(cls): return IO.Schema( node_id="ImageFlip", search_aliases=["mirror", "reflect"], display_name="Flip Image", category="image/transform", inputs=[ IO.Image.Input("image"), IO.Combo.Input("flip_method", options=["x-axis: vertically", "y-axis: horizontally"]), ], outputs=[IO.Image.Output()], ) @classmethod def execute(cls, image, flip_method) -> IO.NodeOutput: if flip_method.startswith("x"): image = torch.flip(image, dims=[1]) elif 
flip_method.startswith("y"): image = torch.flip(image, dims=[2]) return IO.NodeOutput(image) flip = execute # TODO: remove class ImageScaleToMaxDimension(IO.ComfyNode): @classmethod def define_schema(cls): return IO.Schema( node_id="ImageScaleToMaxDimension", display_name="Scale Image to Max Dimension", category="image/upscaling", inputs=[ IO.Image.Input("image"), IO.Combo.Input( "upscale_method", options=["area", "lanczos", "bilinear", "nearest-exact", "bilinear", "bicubic"], ), IO.Int.Input("largest_size", default=512, min=0, max=MAX_RESOLUTION, step=1), ], outputs=[IO.Image.Output()], ) @classmethod def execute(cls, image, upscale_method, largest_size) -> IO.NodeOutput: height = image.shape[1] width = image.shape[2] if height > width: width = round((width / height) * largest_size) height = largest_size elif width > height: height = round((height / width) * largest_size) width = largest_size else: height = largest_size width = largest_size samples = image.movedim(-1, 1) s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled") s = s.movedim(1, -1) return IO.NodeOutput(s) upscale = execute # TODO: remove class SplitImageToTileList(IO.ComfyNode): @classmethod def define_schema(cls): return IO.Schema( node_id="SplitImageToTileList", category="image/batch", search_aliases=["split image", "tile image", "slice image"], display_name="Split Image into List of Tiles", description="Splits an image into a batched list of tiles with a specified overlap.", inputs=[ IO.Image.Input("image"), IO.Int.Input("tile_width", default=1024, min=64, max=MAX_RESOLUTION), IO.Int.Input("tile_height", default=1024, min=64, max=MAX_RESOLUTION), IO.Int.Input("overlap", default=128, min=0, max=4096), ], outputs=[ IO.Image.Output(is_output_list=True), ], ) @staticmethod def get_grid_coords(width, height, tile_width, tile_height, overlap): coords = [] stride_x = round(max(tile_width * 0.25, tile_width - overlap)) stride_y = round(max(tile_height * 0.25, tile_height - 
overlap)) y = 0 while y < height: x = 0 y_end = min(y + tile_height, height) y_start = max(0, y_end - tile_height) while x < width: x_end = min(x + tile_width, width) x_start = max(0, x_end - tile_width) coords.append((x_start, y_start, x_end, y_end)) if x_end >= width: break x += stride_x if y_end >= height: break y += stride_y return coords @classmethod def execute(cls, image, tile_width, tile_height, overlap): b, h, w, c = image.shape coords = cls.get_grid_coords(w, h, tile_width, tile_height, overlap) output_list = [] for (x_start, y_start, x_end, y_end) in coords: tile = image[:, y_start:y_end, x_start:x_end, :] output_list.append(tile) return IO.NodeOutput(output_list) class ImageMergeTileList(IO.ComfyNode): @classmethod def define_schema(cls): return IO.Schema( node_id="ImageMergeTileList", display_name="Merge List of Tiles to Image", category="image/batch", search_aliases=["split image", "tile image", "slice image"], is_input_list=True, inputs=[ IO.Image.Input("image_list"), IO.Int.Input("final_width", default=1024, min=64, max=32768), IO.Int.Input("final_height", default=1024, min=64, max=32768), IO.Int.Input("overlap", default=128, min=0, max=4096), ], outputs=[ IO.Image.Output(is_output_list=False), ], ) @classmethod def execute(cls, image_list, final_width, final_height, overlap): w = final_width[0] h = final_height[0] ovlp = overlap[0] feather_str = 1.0 first_tile = image_list[0] b, t_h, t_w, c = first_tile.shape device = first_tile.device dtype = first_tile.dtype coords = SplitImageToTileList.get_grid_coords(w, h, t_w, t_h, ovlp) canvas = torch.zeros((b, h, w, c), device=device, dtype=dtype) weights = torch.zeros((b, h, w, 1), device=device, dtype=dtype) if ovlp > 0: y_w = torch.sin(math.pi * torch.linspace(0, 1, t_h, device=device, dtype=dtype)) x_w = torch.sin(math.pi * torch.linspace(0, 1, t_w, device=device, dtype=dtype)) y_w = torch.clamp(y_w, min=1e-5) x_w = torch.clamp(x_w, min=1e-5) sine_mask = (y_w.unsqueeze(1) * 
x_w.unsqueeze(0)).unsqueeze(0).unsqueeze(-1) flat_mask = torch.ones_like(sine_mask) weight_mask = torch.lerp(flat_mask, sine_mask, feather_str) else: weight_mask = torch.ones((1, t_h, t_w, 1), device=device, dtype=dtype) for i, (x_start, y_start, x_end, y_end) in enumerate(coords): if i >= len(image_list): break tile = image_list[i] region_h = y_end - y_start region_w = x_end - x_start real_h = min(region_h, tile.shape[1]) real_w = min(region_w, tile.shape[2]) y_end_actual = y_start + real_h x_end_actual = x_start + real_w tile_crop = tile[:, :real_h, :real_w, :] mask_crop = weight_mask[:, :real_h, :real_w, :] canvas[:, y_start:y_end_actual, x_start:x_end_actual, :] += tile_crop * mask_crop weights[:, y_start:y_end_actual, x_start:x_end_actual, :] += mask_crop weights[weights == 0] = 1.0 merged_image = canvas / weights return IO.NodeOutput(merged_image) # --------------------------------------------------------------------------- # Format specifications # --------------------------------------------------------------------------- # Maps (file_format, bit_depth, has_alpha) -> (numpy dtype scale, av pixel format, # stream pix_fmt). Keeps the encode path declarative instead of branchy. 
_FORMAT_SPECS = {
    ("png", "8-bit", False): {"scale": 255.0, "dtype": np.uint8, "frame_fmt": "rgb24", "stream_fmt": "rgb24"},
    ("png", "8-bit", True): {"scale": 255.0, "dtype": np.uint8, "frame_fmt": "rgba", "stream_fmt": "rgba"},
    ("png", "16-bit", False): {"scale": 65535.0, "dtype": np.uint16, "frame_fmt": "rgb48le", "stream_fmt": "rgb48be"},
    ("png", "16-bit", True): {"scale": 65535.0, "dtype": np.uint16, "frame_fmt": "rgba64le", "stream_fmt": "rgba64be"},
    ("exr", "32-bit float", False): {"scale": 1.0, "dtype": np.float32, "frame_fmt": "gbrpf32le", "stream_fmt": "gbrpf32le"},
    ("exr", "32-bit float", True): {"scale": 1.0, "dtype": np.float32, "frame_fmt": "gbrapf32le", "stream_fmt": "gbrapf32le"},
}


# ---------------------------------------------------------------------------
# Color transforms
# ---------------------------------------------------------------------------

def srgb_to_linear(t: torch.Tensor) -> torch.Tensor:
    """Inverse sRGB EOTF (IEC 61966-2-1). Operates on RGB channels only;
    alpha (if present as the 4th channel) is passed through unchanged."""
    if t.shape[-1] == 4:
        rgb, alpha = t[..., :3], t[..., 3:]
        return torch.cat([srgb_to_linear(rgb), alpha], dim=-1)
    # Piecewise: linear toe below 0.04045, gamma curve above.
    low = t / 12.92
    # The clamp only protects the (unselected) power branch from negative
    # bases; negative inputs are routed through the linear toe by the where().
    high = ((t.clamp(min=0.0) + 0.055) / 1.055) ** 2.4
    return torch.where(t <= 0.04045, low, high)


# HLG OETF constants from BT.2100 Table 5.
_HLG_A = 0.17883277
_HLG_B = 0.28466892
_HLG_C = 0.55991072928  # = 0.5 - a*ln(4*a)


def hlg_to_linear(t: torch.Tensor) -> torch.Tensor:
    """Inverse HLG OETF (BT.2100). Maps a non-linear HLG signal in [0, 1]
    to *scene*-linear light in [0, 1]. Per BT.2100 Note 5a, this is the
    correct transform when converting HLG to a linear scene-light
    representation (rather than display-light, which would also involve
    the HLG OOTF). Operates on RGB channels only; alpha is passed through
    unchanged."""
    if t.shape[-1] == 4:
        rgb, alpha = t[..., :3], t[..., 3:]
        return torch.cat([hlg_to_linear(rgb), alpha], dim=-1)
    # Piecewise: sqrt branch below 0.5, log branch above.
    # Clamp inside the log branch so negative / out-of-range values don't blow up;
    # values above 1.0 are allowed and extrapolate naturally.
    low = (t ** 2) / 3.0
    high = (torch.exp((t.clamp(min=_HLG_C) - _HLG_C) / _HLG_A) + _HLG_B) / 12.0
    return torch.where(t <= 0.5, low, high)


# ---------------------------------------------------------------------------
# Metadata injection
# ---------------------------------------------------------------------------

_PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def _png_chunk(chunk_type: bytes, data: bytes) -> bytes:
    """Build a single PNG chunk: length | type | data | CRC32(type+data)."""
    crc = zlib.crc32(chunk_type + data) & 0xFFFFFFFF
    return struct.pack(">I", len(data)) + chunk_type + data + struct.pack(">I", crc)


def _png_text_chunk(keyword: str, text: str) -> bytes:
    """tEXt chunk: latin-1 keyword + NUL + latin-1 text."""
    payload = keyword.encode("latin-1") + b"\x00" + text.encode("latin-1", errors="replace")
    return _png_chunk(b"tEXt", payload)


def inject_png_metadata(png_bytes: bytes, prompt: dict | None, extra_pnginfo: dict | None) -> bytes:
    """Insert ComfyUI prompt/workflow as tEXt chunks right after IHDR.

    Returns ``png_bytes`` unchanged when the data does not start with the PNG
    signature, is too short to hold a chunk header, or there is nothing to add.
    """
    # 8 signature bytes + 4 length bytes are needed before the first chunk's
    # length can be read at all.
    if len(png_bytes) < 12 or not png_bytes.startswith(_PNG_SIGNATURE):
        return png_bytes

    chunks: list[bytes] = []
    if prompt is not None:
        chunks.append(_png_text_chunk("prompt", json.dumps(prompt)))
    if extra_pnginfo:
        for key, value in extra_pnginfo.items():
            chunks.append(_png_text_chunk(key, json.dumps(value)))
    if not chunks:
        return png_bytes

    # IHDR is always the first chunk; insert ours immediately after it.
    ihdr_length = struct.unpack(">I", png_bytes[8:12])[0]
    ihdr_end = 8 + 8 + ihdr_length + 4  # signature + (len+type) + data + crc
    return png_bytes[:ihdr_end] + b"".join(chunks) + png_bytes[ihdr_end:]


# Standard chromaticities (CIE 1931 xy) for the colorspaces this node writes.
# Each tuple is (Rx, Ry, Gx, Gy, Bx, By, Wx, Wy). All share D65 white point.
_CHROMATICITIES = {
    # ITU-R BT.709 / sRGB primaries
    "Rec.709": (0.6400, 0.3300, 0.3000, 0.6000, 0.1500, 0.0600, 0.3127, 0.3290),
    # ITU-R BT.2020 (UHDTV / wide-gamut HDR) primaries
    "Rec.2020": (0.7080, 0.2920, 0.1700, 0.7970, 0.1310, 0.0460, 0.3127, 0.3290),
}

# Scan lines stored per chunk for each OpenEXR compression code; needed to
# know how many entries the chunk-offset table has.
# NOTE(review): values taken from the OpenEXR file-layout documentation —
# confirm against the encoder actually in use.
_EXR_LINES_PER_CHUNK = {
    0: 1,    # NONE
    1: 1,    # RLE
    2: 1,    # ZIPS
    3: 16,   # ZIP
    4: 32,   # PIZ
    5: 16,   # PXR24
    6: 32,   # B44
    7: 32,   # B44A
    8: 32,   # DWAA
    9: 256,  # DWAB
}

# Version-field flag bits (tiled, deep/non-image, multipart) whose files use
# a layout this code does not know how to patch.
_EXR_UNSUPPORTED_FLAGS = 0x200 | 0x800 | 0x1000


def _pack_chromaticities(primaries: tuple) -> bytes:
    """Serialize 8 chromaticity floats into the EXR `chromaticities` payload."""
    return struct.pack("<8f", *primaries)


def _exr_attribute(name: str, attr_type: str, value: bytes) -> bytes:
    """Serialize one EXR header attribute: name\\0 type\\0 size:int32 value."""
    return (
        name.encode("utf-8") + b"\x00"
        + attr_type.encode("utf-8") + b"\x00"
        + struct.pack("<i", len(value))
        + value
    )


def inject_exr_metadata(
    exr_bytes: bytes,
    prompt: dict | None,
    extra_pnginfo: dict | None,
    colorspace: str | None = None,
) -> bytes:
    """Insert ComfyUI metadata and color-space info into an EXR header.

    Color: EXR pixels are linear by convention. The standard way to describe
    their RGB→XYZ relationship is the `chromaticities` attribute. We pick the
    primaries that match what the user told us their input was:

        colorspace="sRGB"  → Rec. 709 / sRGB primaries (D65)
        colorspace="HDR"   → Rec. 2020 / BT.2100 primaries (D65)

    Pixels are always converted to linear scene light upstream (sRGB EOTF
    inverse for sRGB; HLG OETF inverse for HDR), so the file content is
    scene-linear in the indicated gamut. OpenEXR has no standard transfer-
    function attribute (the OpenEXR TSC has discussed adding one but it
    doesn't exist), so we don't invent one — `chromaticities` plus the EXR
    linear-by-convention rule fully specifies the color.

    Prompt/workflow: written as plain `string` attributes using the same
    keys (`prompt`, `workflow`, ...) that Comfy uses for PNG tEXt chunks, so
    the same readers can pull them out symmetrically.

    Implementation note: the chunk-offset table that follows the header
    stores *absolute* byte offsets into the file. Inserting N bytes into the
    header means every offset must be incremented by N or the file becomes
    unreadable.

    Returns ``exr_bytes`` unchanged when the data is not a recognizable
    single-part scan-line EXR, the header cannot be parsed, or there is
    nothing to add.

    NOTE(review): the signature and the header-walk / offset-table section
    were reconstructed from this docstring and the call site; confirm against
    the intended implementation.
    """
    if len(exr_bytes) < 8 or exr_bytes[:4] != b"\x76\x2f\x31\x01":
        return exr_bytes

    new_blob = b""
    if prompt is not None:
        new_blob += _exr_attribute("prompt", "string", json.dumps(prompt).encode("utf-8"))
    if extra_pnginfo:
        for key, value in extra_pnginfo.items():
            new_blob += _exr_attribute(key, "string", json.dumps(value).encode("utf-8"))

    if colorspace is not None:
        # Map each colorspace option to the RGB primaries the linear pixels
        # are now in. "sRGB" and "linear" both produce Rec. 709 linear; "HDR"
        # (HLG-encoded Rec. 2020 input) produces Rec. 2020 linear.
        primaries_name = {
            "sRGB": "Rec.709",
            "linear": "Rec.709",
            "HDR": "Rec.2020",
        }.get(colorspace, "Rec.709")
        new_blob += _exr_attribute(
            "chromaticities", "chromaticities",
            _pack_chromaticities(_CHROMATICITIES[primaries_name]),
        )

    if not new_blob:
        return exr_bytes

    # Tiled, deep and multipart files lay out their offset tables differently;
    # leave them untouched rather than corrupt them.
    version_field = struct.unpack_from("<I", exr_bytes, 4)[0]
    if version_field & _EXR_UNSUPPORTED_FLAGS:
        return exr_bytes

    # Walk header attributes to find the terminating null byte, and pick up
    # dataWindow + compression so we know how many chunks the offset table has.
    pos = 8  # past magic (4) + version (4)
    data_window = None
    compression = 0
    try:
        while pos < len(exr_bytes) and exr_bytes[pos] != 0:
            name_end = exr_bytes.index(b"\x00", pos)
            attr_name = exr_bytes[pos:name_end].decode("latin-1", errors="replace")
            type_end = exr_bytes.index(b"\x00", name_end + 1)
            attr_type = exr_bytes[name_end + 1:type_end].decode("latin-1", errors="replace")
            size = struct.unpack("<i", exr_bytes[type_end + 1:type_end + 5])[0]
            if size < 0:
                return exr_bytes
            value_start = type_end + 5
            value = exr_bytes[value_start:value_start + size]
            if attr_name == "dataWindow" and attr_type == "box2i":
                data_window = struct.unpack("<4i", value)  # xMin, yMin, xMax, yMax
            elif attr_name == "compression" and attr_type == "compression":
                compression = value[0]
            pos = value_start + size
    except (ValueError, IndexError, struct.error):
        # Truncated or malformed header: write the file without metadata
        # rather than fail the save.
        return exr_bytes

    if pos >= len(exr_bytes) or data_window is None:
        return exr_bytes
    header_end = pos  # index of the header's terminating NUL byte

    _x_min, y_min, _x_max, y_max = data_window
    image_height = y_max - y_min + 1
    lines_per_chunk = _EXR_LINES_PER_CHUNK.get(compression)
    if lines_per_chunk is None or image_height <= 0:
        return exr_bytes
    chunk_count = -(-image_height // lines_per_chunk)  # ceiling division

    table_start = header_end + 1
    table_end = table_start + 8 * chunk_count
    if table_end > len(exr_bytes):
        return exr_bytes

    # Every chunk moves forward by exactly the number of bytes we insert.
    shift = len(new_blob)
    offsets = struct.unpack_from(f"<{chunk_count}Q", exr_bytes, table_start)
    shifted_table = struct.pack(f"<{chunk_count}Q", *(offset + shift for offset in offsets))

    return (
        exr_bytes[:header_end]
        + new_blob
        + b"\x00"
        + shifted_table
        + exr_bytes[table_end:]
    )


def _encode_image(img_tensor: torch.Tensor, file_format: str, bit_depth: str, colorspace: str = "sRGB") -> bytes:
    """Encode a single HxWxC tensor to PNG or EXR bytes in memory.

    For EXR the input is interpreted according to `colorspace` and converted
    to scene-linear (EXR's convention) before writing:
        "sRGB"   → input is sRGB-encoded Rec. 709; apply inverse sRGB EOTF.
        "HDR"    → input is HLG-encoded Rec. 2020 (BT.2100); apply inverse
                   HLG OETF to get scene-linear, per BT.2100 Note 5a.
        "linear" → input is already scene-linear (Rec. 709 primaries);
                   write through unchanged. Use this for renderer/compositor
                   output.

    For PNG, colorspace selection does not modify pixels — PNG is delivered
    sRGB-encoded and there is no PNG path for wide-gamut HDR in this node.

    Raises KeyError when (file_format, bit_depth, alpha) has no entry in
    ``_FORMAT_SPECS``.

    NOTE(review): the signature was reconstructed from the call site; the
    ``colorspace`` default is an assumption.
    """
    height, width, num_channels = img_tensor.shape
    has_alpha = num_channels == 4
    spec = _FORMAT_SPECS[(file_format, bit_depth, has_alpha)]

    if spec["dtype"] == np.float32:
        # EXR path: preserve full range, no clamp.
        if colorspace == "sRGB":
            img_tensor = srgb_to_linear(img_tensor)
        elif colorspace == "HDR":
            img_tensor = hlg_to_linear(img_tensor)
        img_np = img_tensor.cpu().numpy().astype(np.float32)
    else:
        # PNG path: quantize to integer range.
        scaled = (img_tensor * spec["scale"]).clamp(0, spec["scale"])
        img_np = scaled.to(torch.int32).cpu().numpy().astype(spec["dtype"])

    # Encode directly via CodecContext. PyAV's `image2` muxer does NOT write to
    # BytesIO (it expects a real file path), so we bypass the container entirely.
    # For single-frame PNG/EXR the raw codec output IS the file.
    codec = av.CodecContext.create(file_format, "w")
    codec.width = width
    codec.height = height
    codec.pix_fmt = spec["stream_fmt"]
    codec.time_base = Fraction(1, 1)

    frame = av.VideoFrame.from_ndarray(img_np, format=spec["frame_fmt"])
    # The in-memory array layout and the codec's pixel format can differ
    # (the 16-bit PNG specs pair an le frame with a be stream).
    if spec["frame_fmt"] != spec["stream_fmt"]:
        frame = frame.reformat(format=spec["stream_fmt"])
    frame.pts = 0
    frame.time_base = codec.time_base

    packets = list(codec.encode(frame)) + list(codec.encode(None))  # flush with None
    return b"".join(bytes(p) for p in packets)


# ---------------------------------------------------------------------------
# Node
# ---------------------------------------------------------------------------

class SaveImageAdvanced(IO.ComfyNode):
    """Output node that saves images as PNG (8/16-bit) or EXR (32-bit float)."""

    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="SaveImageAdvanced",
            search_aliases=["save", "save image", "export image", "output image", "write image"],
            display_name="Save Image (Advanced)",
            description="Saves the input images to your ComfyUI output directory.",
            category="image",
            essentials_category="Basics",
            inputs=[
                IO.Image.Input("images", tooltip="The images to save."),
                IO.String.Input(
                    "filename_prefix",
                    default="ComfyUI",
                    tooltip=(
                        "The prefix for the file to save. May include formatting tokens "
                        "such as %date:yyyy-MM-dd% or %Empty Latent Image.width%."
                    ),
                ),
                IO.DynamicCombo.Input(
                    "image_format",
                    options=[
                        IO.DynamicCombo.Option("png", [
                            IO.Combo.Input("bit_depth", options=["8-bit", "16-bit"], default="8-bit", advanced=True),
                            IO.Combo.Input("colorspace", options=["sRGB"], default="sRGB", advanced=True),
                        ]),
                        IO.DynamicCombo.Option("exr", [
                            IO.Combo.Input("bit_depth", options=["32-bit float"], default="32-bit float", advanced=True),
                            IO.Combo.Input(
                                "colorspace",
                                options=["sRGB", "HDR", "linear"],
                                default="sRGB",
                                advanced=True,
                                tooltip=(
                                    "Colorspace of the input tensor. The EXR is "
                                    "always written as scene-linear in the matching "
                                    "gamut.\n"
                                    " 'sRGB' — input is sRGB-encoded Rec.709; "
                                    "the inverse sRGB EOTF is applied.\n"
                                    " 'HDR' — input is HLG-encoded Rec.2020 "
                                    "(BT.2100); the inverse HLG OETF is applied "
                                    "to get scene-linear light.\n"
                                    " 'linear' — input is already scene-linear "
                                    "(Rec.709 primaries); written through unchanged. "
                                    "Use this for renderer/compositor output."
                                ),
                            ),
                        ]),
                    ],
                    tooltip="The file format in which to save the image.",
                ),
            ],
            hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
            is_output_node=True,
        )

    @classmethod
    def execute(cls, images, filename_prefix: str, image_format: dict) -> IO.NodeOutput:
        """Encode each image in the batch and write it to the output folder.

        ``image_format`` carries the selected format under "image_format"
        plus its sub-options ("bit_depth", optional "colorspace"). Metadata
        is embedded unless ``args.disable_metadata`` is set.
        """
        file_format = image_format["image_format"]
        bit_depth = image_format["bit_depth"]
        colorspace = image_format.get("colorspace", "sRGB")

        output_dir = folder_paths.get_output_directory()
        # Each image is HxWxC, so shape[1] is the width and shape[0] the height.
        full_output_folder, filename, counter, subfolder, filename_prefix = (
            folder_paths.get_save_image_path(
                filename_prefix, output_dir, images[0].shape[1], images[0].shape[0]
            )
        )

        prompt = cls.hidden.prompt
        extra_pnginfo = cls.hidden.extra_pnginfo
        write_metadata = not args.disable_metadata

        results = []
        for batch_number, image in enumerate(images):
            encoded = _encode_image(image, file_format, bit_depth, colorspace)

            if write_metadata:
                if file_format == "png":
                    encoded = inject_png_metadata(encoded, prompt, extra_pnginfo)
                elif file_format == "exr":
                    encoded = inject_exr_metadata(encoded, prompt, extra_pnginfo, colorspace)

            name = filename.replace("%batch_num%", str(batch_number))
            # Keep the "<prefix>_<counter>_.<ext>" pattern used by the other
            # save node in this file.
            # NOTE(review): the counter scan in folder_paths.get_save_image_path
            # presumably parses the digits before the trailing underscore —
            # without it, later runs may reuse the same counter. Confirm.
            file = f"{name}_{counter:05}_.{file_format}"
            with open(os.path.join(full_output_folder, file), "wb") as f:
                f.write(encoded)

            results.append({"filename": file, "subfolder": subfolder, "type": "output"})
            counter += 1

        return IO.NodeOutput(ui={"images": results})


class ImagesExtension(ComfyExtension):
    """Extension that registers every image node defined in this module."""

    @override
    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
        return [
            ImageCrop,
            ImageCropV2,
            BoundingBox,
            RepeatImageBatch,
            ImageFromBatch,
            ImageAddNoise,
            SaveAnimatedWEBP,
            SaveAnimatedPNG,
            SaveImageAdvanced,
            SaveSVGNode,
            ImageStitch,
            ResizeAndPadImage,
            GetImageSize,
            ImageRotate,
            ImageFlip,
            ImageScaleToMaxDimension,
            SplitImageToTileList,
            ImageMergeTileList,
        ]


async def comfy_entrypoint() -> ImagesExtension:
    """Return the extension instance for this module."""
    return ImagesExtension()