From 92ab48531f5f2ab3d1e1b64965e6bd36532fd2c5 Mon Sep 17 00:00:00 2001
From: Alexis Rolland
Date: Wed, 29 Apr 2026 12:12:12 +0800
Subject: [PATCH] Iterate on new Save Image node

---
 comfy_extras/nodes_convert_color_space.py | 138 ++++++++
 comfy_extras/nodes_images.py              | 308 +++++++++++++++++++++-
 nodes.py                                  |   3 +-
 3 files changed, 447 insertions(+), 2 deletions(-)
 create mode 100644 comfy_extras/nodes_convert_color_space.py

diff --git a/comfy_extras/nodes_convert_color_space.py b/comfy_extras/nodes_convert_color_space.py
new file mode 100644
index 000000000..0481eaa71
--- /dev/null
+++ b/comfy_extras/nodes_convert_color_space.py
@@ -0,0 +1,138 @@
+"""Color-space conversion node: sRGB, linear, PQ-encoded Rec.2020 and grayscale."""
+
+import torch
+from typing_extensions import override
+
+from comfy_api.latest import ComfyExtension, IO
+
+
+# Gamut conversion matrices for linear-light RGB, applied as ``rgb @ M.T``.
+# Rec.709 -> Rec.2020
+M_709_to_2020 = torch.tensor([
+    [0.6274, 0.3293, 0.0433],
+    [0.0691, 0.9195, 0.0114],
+    [0.0164, 0.0880, 0.8956],
+])
+
+# Rec.2020 -> Rec.709
+M_2020_to_709 = torch.tensor([
+    [1.6605, -0.5876, -0.0728],
+    [-0.1246, 1.1329, -0.0083],
+    [-0.0182, -0.1006, 1.1187],
+])
+
+# Rec.709 luma weights (they sum to 1.0, so neutral grays are preserved).
+_LUMA_WEIGHTS = (0.2126, 0.7152, 0.0722)
+
+# SMPTE ST 2084 (PQ) constants, shared by the encoder and the decoder.
+_PQ_M1 = 2610 / 4096 / 4
+_PQ_M2 = 2523 / 4096 * 128
+_PQ_C1 = 3424 / 4096
+_PQ_C2 = 2413 / 4096 * 32
+_PQ_C3 = 2392 / 4096 * 32
+
+
+def srgb_to_linear(tensor):
+    """Decode sRGB-encoded values to linear light (piecewise sRGB curve)."""
+    # torch.where evaluates both branches; clamping keeps pow() away from
+    # negative bases so the discarded branch never produces NaNs.
+    power = torch.pow((tensor.clamp(min=0.0) + 0.055) / 1.055, 2.4)
+    return torch.where(tensor <= 0.04045, tensor / 12.92, power)
+
+
+def linear_to_srgb(tensor):
+    """Encode linear-light values with the piecewise sRGB curve."""
+    power = 1.055 * torch.pow(tensor.clamp(min=1e-8), 1.0 / 2.4) - 0.055
+    return torch.where(tensor <= 0.0031308, tensor * 12.92, power)
+
+
+def linear_to_pq(linear_tensor):
+    """Encode linear values with SMPTE ST 2084 (PQ).
+
+    The input is clamped to [0, 1]; an input of 1.0 encodes to 1.0.
+    """
+    l_m1 = torch.pow(torch.clamp(linear_tensor, 0.0, 1.0), _PQ_M1)
+    return torch.pow((_PQ_C1 + _PQ_C2 * l_m1) / (1 + _PQ_C3 * l_m1), _PQ_M2)
+
+
+def pq_to_linear(pq_tensor):
+    """Decode SMPTE ST 2084 (PQ) values to linear; inverse of ``linear_to_pq``."""
+    n = torch.pow(torch.clamp(pq_tensor, 0.0, 1.0), 1 / _PQ_M2)
+    return torch.pow(torch.clamp((n - _PQ_C1) / (_PQ_C2 - _PQ_C3 * n), min=0.0), 1 / _PQ_M1)
+
+
+def _luma_as_rgb(rgb):
+    """Collapse linear RGB to Rec.709 luma replicated over three channels."""
+    w_r, w_g, w_b = _LUMA_WEIGHTS
+    luma = w_r * rgb[..., 0] + w_g * rgb[..., 1] + w_b * rgb[..., 2]
+    return torch.stack((luma, luma, luma), dim=-1)
+
+
+class ConvertColorSpace(IO.ComfyNode):
+    """Convert images between color spaces, using linear Rec.709 as the pivot."""
+
+    @classmethod
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="Convert Color Space",
+            category="image/color",
+            inputs=[
+                IO.Image.Input("images"),
+                IO.Combo.Input("source_color_space", options=["sRGB", "Linear", "HDR (Rec.2020)", "Grayscale"], default="sRGB"),
+                IO.Combo.Input("target_color_space", options=["sRGB", "Linear", "HDR (Rec.2020)", "Grayscale"], default="Linear"),
+            ],
+            outputs=[
+                IO.Image.Output("images"),
+            ]
+        )
+
+    @classmethod
+    def execute(cls, images, source_color_space, target_color_space) -> IO.NodeOutput:
+        """Convert ``images`` from the source to the target color space.
+
+        An alpha channel (4th channel), when present, is passed through unchanged.
+        """
+        img_tensor = images.clone()
+
+        has_alpha = img_tensor.shape[-1] == 4
+        alpha = img_tensor[..., 3:4] if has_alpha else None
+        rgb = img_tensor[..., :3]
+
+        # Step 1: decode the source into linear Rec.709 ("Linear" needs no work).
+        if source_color_space == "sRGB":
+            rgb = srgb_to_linear(rgb)
+        elif source_color_space == "Grayscale":
+            # Grayscale input is assumed to carry sRGB gamma: decode it to
+            # linear first, then take the luma.
+            rgb = _luma_as_rgb(srgb_to_linear(rgb))
+        elif source_color_space == "HDR (Rec.2020)":
+            # PQ-encoded Rec.2020 input: undo PQ, then map the gamut to Rec.709.
+            matrix = M_2020_to_709.to(device=rgb.device, dtype=rgb.dtype)
+            rgb = torch.matmul(pq_to_linear(rgb), matrix.T)
+
+        # Step 2: encode linear Rec.709 into the target space.
+        if target_color_space == "sRGB":
+            rgb = linear_to_srgb(rgb)
+        elif target_color_space == "Grayscale":
+            rgb = linear_to_srgb(_luma_as_rgb(rgb))  # reapply sRGB gamma
+        elif target_color_space == "HDR (Rec.2020)":
+            # Map the gamut to Rec.2020, drop negative components, then PQ-encode.
+            matrix = M_709_to_2020.to(device=rgb.device, dtype=rgb.dtype)
+            rgb = linear_to_pq(torch.matmul(rgb, matrix.T).clamp(min=0))
+
+        img_tensor = torch.cat([rgb, alpha], dim=-1) if has_alpha else rgb
+
+        # Outputs are passed positionally, as nodes_images.py does.
+        return IO.NodeOutput(img_tensor)
+
+
+class ConvertColorSpaceExtension(ComfyExtension):
+    @override
+    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
+        return [
+            ConvertColorSpace
+        ]
+
+
+async def comfy_entrypoint() -> ConvertColorSpaceExtension:
+    return ConvertColorSpaceExtension()
diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py
index a77f0641f..60d1fe739 100644
--- a/comfy_extras/nodes_images.py
+++ b/comfy_extras/nodes_images.py
@@ -3,15 +3,22 @@ from __future__ import annotations
 import nodes
 import folder_paths
 
+import av
+import io
 import json
+import logging
 import os
 import re
 import math
+import numpy as np
+import struct
 import torch
+import zlib
 import comfy.utils
 
 from server import PromptServer
-from comfy_api.latest import ComfyExtension, IO, UI
+from comfy_api.latest import ComfyExtension, Input, IO, UI
+from comfy.cli_args import args
 from typing_extensions import override
 
 SVG = IO.SVG.Type  # TODO: temporary solution for backward compatibility, will be removed later.
@@ -823,6 +830,304 @@ class ImageMergeTileList(IO.ComfyNode): return IO.NodeOutput(merged_image) +def _create_png_chunk(chunk_type: bytes, data: bytes) -> bytes: + """Creates a valid PNG chunk with Length, Type, Data, and CRC32.""" + chunk = struct.pack('>I', len(data)) + chunk_type + data + crc = zlib.crc32(chunk_type + data) & 0xffffffff + return chunk + struct.pack('>I', crc) + + +def _inject_metadata_png(png_bytes, prompt=None, extra_pnginfo=None): + # IEND chunk is the last 12 bytes of png files + content = png_bytes[:-12] + iend = png_bytes[-12:] + + metadata_chunks = b"" + + if prompt is not None: + payload = b'prompt\x00' + json.dumps(prompt).encode('utf-8') + metadata_chunks += _create_png_chunk(b'tEXt', payload) + + if extra_pnginfo is not None: + for k, v in extra_pnginfo.items(): + payload = k.encode('utf-8') + b'\x00' + json.dumps(v).encode('utf-8') + metadata_chunks += _create_png_chunk(b'tEXt', payload) + + return content + metadata_chunks + iend + +def _inject_metadata_exr(exr_bytes: bytes, prompt, extra_pnginfo) -> bytes: + # skip magic and version + idx = 8 + + # parse through existing attributes to find the end of the header + while True: + name_start = idx + while exr_bytes[idx] != 0: + idx += 1 + name = exr_bytes[name_start:idx] + idx += 1 + + # empty name means we hit the header terminator + if len(name) == 0: + break + + # skip attribute type string + while exr_bytes[idx] != 0: + idx += 1 + idx += 1 + + # read attribute size and skip the value + attr_size = struct.unpack(' bytes: + metadata = {} + if prompt is not None: + metadata["prompt"] = prompt + if extra_pnginfo is not None: + for k, v in extra_pnginfo.items(): + metadata[k] = v + + payload = json.dumps(metadata).encode('utf-8') + + # 16-byte uuid required by isobmff spec + # 'comfyui_workflow' is exactly 16 bytes long! 
+ comfy_uuid = b'comfyui_workflow' + + # box size: 4 (size) + 4 (type) + 16 (uuid) + payload length + box_size = 4 + 4 + 16 + len(payload) + uuid_box = struct.pack('>I', box_size) + b'uuid' + comfy_uuid + payload + + # isobmff allows top-level boxes at the end of the file. + return avif_bytes + uuid_box + + +class SaveImageAdvanced(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="SaveImageAdvanced", + search_aliases=["save", "save image", "export image", "output image", "write image", "download"], + display_name="Save Image", + description="Saves the input images to your ComfyUI output directory.", + category="image", + essentials_category="Basics", + inputs=[ + IO.Image.Input( + "images", + tooltip="The images to save." + ), + IO.String.Input( + "filename_prefix", + default="ComfyUI", + tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes.", + ), + IO.DynamicCombo.Input( + "file_format", + options=[ + IO.DynamicCombo.Option( + "png", + [ + IO.Combo.Input( + "bit_depth", + options=["8-bit", "16-bit"], + default="8-bit", + advanced=True, + ), + IO.Combo.Input( + "color_space", + options=["Raw/Data", "sRGB"], + default="sRGB", + advanced=True, + ), + ], + ), + IO.DynamicCombo.Option( + "avif", + [ + IO.Combo.Input( + "bit_depth", + options=["8-bit", "10-bit", "12-bit"], + default="8-bit", + advanced=True, + ), + IO.Combo.Input( + "color_space", + options=["sRGB"], + default="sRGB", + advanced=True, + ), + ], + ), + IO.DynamicCombo.Option( + "exr", + [ + IO.Combo.Input( + "bit_depth", + options=["16-bit (half-float)", "32-bit"], + default="16-bit (half-float)", + advanced=True, + ), + IO.Combo.Input( + "color_space", + options=["Linear", "Raw/Data"], + default="Linear", + advanced=True, + ), + ], + ), + ], + tooltip="The file format in which to save the image.", + ), + IO.Boolean.Input("embed_workflow", default=True, 
advanced=True), + ], + hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], + is_output_node=True, + ) + + @classmethod + def execute( + cls, + images: Input.Image, + filename_prefix: str, + file_format: dict, + embed_workflow: bool, + prompt=None, + extra_pnginfo=None + ) -> IO.NodeOutput: + output_dir = folder_paths.get_output_directory() + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, output_dir, images[0].shape[1], images[0].shape[0]) + results = list() + + for batch_number, image in enumerate(images): + img_tensor = image.clone() + height, width, num_channels = img_tensor.shape + has_alpha = (num_channels == 4) + + # file pathing + filename_with_batch_num = filename.replace("%batch_num%", str(batch_number)) + file = f"{filename_with_batch_num}_{counter:05}.{file_format}" + file_path = os.path.join(full_output_folder, file) + + # get widget values from dynamic combo + format = file_format["file_format"] + bit_depth = file_format["bit_depth"] + color_space = file_format["color_space"] + + if bit_depth == "32-bit": + img_np = img_tensor.cpu().numpy() + # rgba128le handles 4x32f, gbrpf32le handles 3x32f planar + av_fmt = 'rgba128le' if has_alpha else 'gbrpf32le' + elif bit_depth == "16-bit": + img_np = (img_tensor * 65535.0).clamp(0, 65535).to(torch.int32).cpu().numpy().astype(np.uint16) + if format == "png": + # png requires Big-Endian (be) for 16-bit + av_fmt = 'rgba64be' if has_alpha else 'rgb48be' + img_np = img_np.byteswap().view(img_np.dtype.newbyteorder('>')) + else: + av_fmt = 'rgba64le' if has_alpha else 'rgb48le' + else: + img_np = (img_tensor * 255.0).clamp(0, 255).to(torch.int32).cpu().numpy().astype(np.uint8) + av_fmt = 'rgba' if has_alpha else 'rgb24' + + memory_buffer = io.BytesIO() + container_format = "image2" if format in ["png", "exr"] else "avif" + container = av.open(memory_buffer, mode='w', format=container_format) + + if format == "exr": + stream = 
container.add_stream('exr', rate=1) + stream.pix_fmt = av_fmt + elif format == "avif": + stream = container.add_stream('av1', rate=1) + # YUV color spaces + stream.pix_fmt = 'yuv444p12le' if bit_depth in ["16-bit", "32-bit"] else 'yuv444p' + elif format == "png": + stream = container.add_stream('png', rate=1) + stream.pix_fmt = av_fmt + + stream.width = width + stream.height = height + + # planar: all red, all blue, all green instead of r, g, b, r, g, b + is_planar = av_fmt.startswith('gbrp') or 'p' in av_fmt.split('rgba')[-1] + if is_planar: + img_np = img_np.transpose(2, 0, 1) + + try: + frame = av.VideoFrame.from_ndarray(img_np, format=av_fmt) + except ValueError: + # FFMPEG Float32 Fallback: not all ffmpeg versions are able to handle float32 format for images + # float16 fallback conversion + logging.warning("[WARNING] Current FFMPEG Binary can't save float32 images. Fallbacking to float16") + img_np = (img_tensor * 65535.0).clamp(0, 65535).to(torch.int32).cpu().numpy().astype(np.uint16) + av_fmt = 'rgba64le' if has_alpha else 'rgb48le' + frame = av.VideoFrame.from_ndarray(img_np, format=av_fmt) + if file_format == "exr" or file_format == "png": + stream.pix_fmt = av_fmt + + for packet in stream.encode(frame): + container.mux(packet) + for packet in stream.encode(): + container.mux(packet) + + container.close() + + final_bytes = memory_buffer.getvalue() + + if embed_workflow and not args.disable_metadata: + if format == "png": + final_bytes = _inject_metadata_png(final_bytes, prompt, extra_pnginfo) + elif format == "exr": + final_bytes = _inject_metadata_exr(final_bytes, prompt, extra_pnginfo) + else: + final_bytes = _inject_metadata_avif(final_bytes, prompt, extra_pnginfo) + + with open(file_path, "wb") as f: + f.write(final_bytes) + + results.append({ + "filename": file, + "subfolder": subfolder, + "type": "output" + }) + counter += 1 + + return IO.NodeOutput(ui={"images": results}) + + class ImagesExtension(ComfyExtension): @override async def 
get_node_list(self) -> list[type[IO.ComfyNode]]: @@ -835,6 +1140,7 @@ class ImagesExtension(ComfyExtension): ImageAddNoise, SaveAnimatedWEBP, SaveAnimatedPNG, + SaveImageAdvanced, SaveSVGNode, ImageStitch, ResizeAndPadImage, diff --git a/nodes.py b/nodes.py index db989a501..2c50d3021 100644 --- a/nodes.py +++ b/nodes.py @@ -1652,6 +1652,7 @@ class SaveImage: ESSENTIALS_CATEGORY = "Basics" DESCRIPTION = "Saves the input images to your ComfyUI output directory." SEARCH_ALIASES = ["save", "save image", "export image", "output image", "write image", "download"] + DEPRECATED = True def save_images(self, images, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): filename_prefix += self.prefix_append @@ -2157,7 +2158,7 @@ NODE_DISPLAY_NAME_MAPPINGS = { "LatentFromBatch" : "Latent From Batch", "RepeatLatentBatch": "Repeat Latent Batch", # Image - "SaveImage": "Save Image", + "SaveImage": "Save Image (DEPRECATED)", "PreviewImage": "Preview Image", "LoadImage": "Load Image", "LoadImageMask": "Load Image (as Mask)",