From e517a6bf4cf2f180c4eb7d40507102369284f35a Mon Sep 17 00:00:00 2001
From: Jedrzej Kosinski
Date: Wed, 17 Dec 2025 19:35:06 -0800
Subject: [PATCH] Add Resize Image/Mask node that thanks to MatchType+DynamicCombo is 16-nodes-in-1

---
 comfy_extras/nodes_post_processing.py | 203 ++++++++++++++++++++++++++
 1 file changed, 203 insertions(+)

diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py
index 34c388a5a..0d241e327 100644
--- a/comfy_extras/nodes_post_processing.py
+++ b/comfy_extras/nodes_post_processing.py
@@ -4,11 +4,14 @@ import torch
 import torch.nn.functional as F
 from PIL import Image
 import math
+from enum import Enum
+from typing import TypedDict, Literal
 
 import comfy.utils
 import comfy.model_management
 import node_helpers
 from comfy_api.latest import ComfyExtension, io
+from nodes import MAX_RESOLUTION
 
 class Blend(io.ComfyNode):
     @classmethod
@@ -240,6 +243,205 @@ class ImageScaleToTotalPixels(io.ComfyNode):
         s = s.movedim(1,-1)
         return io.NodeOutput(s)
 
+class ResizeType(str, Enum):
+    SCALE_BY = "scale by multiplier"
+    SCALE_DIMENSIONS = "scale dimensions"
+    SCALE_LONGER_DIMENSION = "scale longer dimension"
+    SCALE_SHORTER_DIMENSION = "scale shorter dimension"
+    SCALE_WIDTH = "scale width"
+    SCALE_HEIGHT = "scale height"
+    SCALE_TOTAL_PIXELS = "scale total pixels"
+    MATCH_SIZE = "match size"
+
+def is_image(input: torch.Tensor) -> bool:
+    # images have 4 dimensions: [batch, height, width, channels]
+    # masks have 3 dimensions: [batch, height, width]
+    return len(input.shape) == 4
+
+def init_image_mask_input(input: torch.Tensor, is_type_image: bool) -> torch.Tensor:
+    if is_type_image:
+        input = input.movedim(-1, 1)
+    else:
+        input = input.unsqueeze(1)
+    return input
+
+def finalize_image_mask_input(input: torch.Tensor, is_type_image: bool) -> torch.Tensor:
+    if is_type_image:
+        input = input.movedim(1, -1)
+    else:
+        input = input.squeeze(1)
+    return input
+
+def scale_by(input: torch.Tensor, multiplier: float, scale_method: str) -> torch.Tensor:
+    is_type_image = is_image(input)
+    input = init_image_mask_input(input, is_type_image)
+    width = round(input.shape[-1] * multiplier)
+    height = round(input.shape[-2] * multiplier)
+
+    input = comfy.utils.common_upscale(input, width, height, scale_method, "disabled")
+    input = finalize_image_mask_input(input, is_type_image)
+    return input
+
+def scale_dimensions(input: torch.Tensor, width: int, height: int, scale_method: str, crop: str = "disabled") -> torch.Tensor:
+    if width == 0 and height == 0:
+        return input
+    is_type_image = is_image(input)
+    input = init_image_mask_input(input, is_type_image)
+
+    if width == 0:
+        width = max(1, round(input.shape[-1] * height / input.shape[-2]))
+    elif height == 0:
+        height = max(1, round(input.shape[-2] * width / input.shape[-1]))
+
+    input = comfy.utils.common_upscale(input, width, height, scale_method, crop)
+    input = finalize_image_mask_input(input, is_type_image)
+    return input
+
+def scale_longer_dimension(input: torch.Tensor, longer_size: int, scale_method: str) -> torch.Tensor:
+    is_type_image = is_image(input)
+    input = init_image_mask_input(input, is_type_image)
+    width = input.shape[-1]
+    height = input.shape[-2]
+
+    if height > width:
+        width = round((width / height) * longer_size)
+        height = longer_size
+    elif width > height:
+        height = round((height / width) * longer_size)
+        width = longer_size
+    else:
+        height = longer_size
+        width = longer_size
+
+    input = comfy.utils.common_upscale(input, width, height, scale_method, "disabled")
+    input = finalize_image_mask_input(input, is_type_image)
+    return input
+
+def scale_shorter_dimension(input: torch.Tensor, shorter_size: int, scale_method: str) -> torch.Tensor:
+    is_type_image = is_image(input)
+    input = init_image_mask_input(input, is_type_image)
+    width = input.shape[-1]
+    height = input.shape[-2]
+
+    if height < width:
+        width = round((width / height) * shorter_size)
+        height = shorter_size
+    elif width < height:
+        height = round((height / width) * shorter_size)
+        width = shorter_size
+    else:
+        height = shorter_size
+        width = shorter_size
+
+    input = comfy.utils.common_upscale(input, width, height, scale_method, "disabled")
+    input = finalize_image_mask_input(input, is_type_image)
+    return input
+
+def scale_total_pixels(input: torch.Tensor, megapixels: float, scale_method: str) -> torch.Tensor:
+    is_type_image = is_image(input)
+    input = init_image_mask_input(input, is_type_image)
+    total = int(megapixels * 1024 * 1024)
+
+    scale_by = math.sqrt(total / (input.shape[-1] * input.shape[-2]))
+    width = round(input.shape[-1] * scale_by)
+    height = round(input.shape[-2] * scale_by)
+
+    input = comfy.utils.common_upscale(input, width, height, scale_method, "disabled")
+    input = finalize_image_mask_input(input, is_type_image)
+    return input
+
+def scale_match_size(input: torch.Tensor, match: torch.Tensor, scale_method: str, crop: str) -> torch.Tensor:
+    is_type_image = is_image(input)
+    input = init_image_mask_input(input, is_type_image)
+    match = init_image_mask_input(match, is_image(match))
+
+    width = match.shape[-1]
+    height = match.shape[-2]
+    input = comfy.utils.common_upscale(input, width, height, scale_method, crop)
+    input = finalize_image_mask_input(input, is_type_image)
+    return input
+
+class ResizeImageMaskNode(io.ComfyNode):
+
+    scale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
+    crop_methods = ["disabled", "center"]
+
+    class ResizeTypedDict(TypedDict):
+        resize_type: ResizeType
+        scale_method: Literal["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
+        crop: Literal["disabled", "center"]
+        multiplier: float
+        width: int
+        height: int
+        longer_size: int
+        shorter_size: int
+        megapixels: float
+
+    @classmethod
+    def define_schema(cls):
+        template = io.MatchType.Template("input_type", [io.Image, io.Mask])
+        crop_combo = io.Combo.Input("crop", options=cls.crop_methods)
+        return io.Schema(
+            node_id="ResizeImageMaskNode",
+            display_name="Resize Image/Mask",
+            category="transform",
+            inputs=[
+                io.MatchType.Input("input", template=template),
+                io.DynamicCombo.Input("resize_type", options=[
+                    io.DynamicCombo.Option(ResizeType.SCALE_BY, [
+                        io.Float.Input("multiplier", default=1.00, min=0.01, max=8.0, step=0.01),
+                    ]),
+                    io.DynamicCombo.Option(ResizeType.SCALE_DIMENSIONS, [
+                        io.Int.Input("width", default=512, min=0, max=MAX_RESOLUTION, step=1),
+                        io.Int.Input("height", default=512, min=0, max=MAX_RESOLUTION, step=1),
+                        crop_combo,
+                    ]),
+                    io.DynamicCombo.Option(ResizeType.SCALE_LONGER_DIMENSION, [
+                        io.Int.Input("longer_size", default=512, min=0, max=MAX_RESOLUTION, step=1),
+                    ]),
+                    io.DynamicCombo.Option(ResizeType.SCALE_SHORTER_DIMENSION, [
+                        io.Int.Input("shorter_size", default=512, min=0, max=MAX_RESOLUTION, step=1),
+                    ]),
+                    io.DynamicCombo.Option(ResizeType.SCALE_WIDTH, [
+                        io.Int.Input("width", default=512, min=0, max=MAX_RESOLUTION, step=1),
+                    ]),
+                    io.DynamicCombo.Option(ResizeType.SCALE_HEIGHT, [
+                        io.Int.Input("height", default=512, min=0, max=MAX_RESOLUTION, step=1),
+                    ]),
+                    io.DynamicCombo.Option(ResizeType.SCALE_TOTAL_PIXELS, [
+                        io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01),
+                    ]),
+                    io.DynamicCombo.Option(ResizeType.MATCH_SIZE, [
+                        io.MultiType.Input("match", [io.Image, io.Mask]),
+                        crop_combo,
+                    ]),
+                ]),
+                io.Combo.Input("scale_method", options=cls.scale_methods, default="area"),
+            ],
+            outputs=[io.MatchType.Output(template=template, display_name="resized")]
+        )
+
+    @classmethod
+    def execute(cls, input: io.Image.Type | io.Mask.Type, scale_method: io.Combo.Type, resize_type: ResizeTypedDict) -> io.NodeOutput:
+        selected_type = resize_type["resize_type"]
+        if selected_type == ResizeType.SCALE_BY:
+            return io.NodeOutput(scale_by(input, resize_type["multiplier"], scale_method))
+        elif selected_type == ResizeType.SCALE_DIMENSIONS:
+            return io.NodeOutput(scale_dimensions(input, resize_type["width"], resize_type["height"], scale_method, resize_type["crop"]))
+        elif selected_type == ResizeType.SCALE_LONGER_DIMENSION:
+            return io.NodeOutput(scale_longer_dimension(input, resize_type["longer_size"], scale_method))
+        elif selected_type == ResizeType.SCALE_SHORTER_DIMENSION:
+            return io.NodeOutput(scale_shorter_dimension(input, resize_type["shorter_size"], scale_method))
+        elif selected_type == ResizeType.SCALE_WIDTH:
+            return io.NodeOutput(scale_dimensions(input, resize_type["width"], 0, scale_method))
+        elif selected_type == ResizeType.SCALE_HEIGHT:
+            return io.NodeOutput(scale_dimensions(input, 0, resize_type["height"], scale_method))
+        elif selected_type == ResizeType.SCALE_TOTAL_PIXELS:
+            return io.NodeOutput(scale_total_pixels(input, resize_type["megapixels"], scale_method))
+        elif selected_type == ResizeType.MATCH_SIZE:
+            return io.NodeOutput(scale_match_size(input, resize_type["match"], scale_method, resize_type["crop"]))
+        raise ValueError(f"Unsupported resize type: {selected_type}")
+
 class PostProcessingExtension(ComfyExtension):
     @override
     async def get_node_list(self) -> list[type[io.ComfyNode]]:
@@ -249,6 +451,7 @@ class PostProcessingExtension(ComfyExtension):
             Quantize,
             Sharpen,
             ImageScaleToTotalPixels,
+            ResizeImageMaskNode,
         ]
 
 async def comfy_entrypoint() -> PostProcessingExtension:
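
For reference, the target-size arithmetic used by the helpers in this patch can be checked on its own. The snippet below is a minimal sketch in plain Python (no ComfyUI imports, made-up input sizes) that mirrors the rounding done in scale_longer_dimension and scale_total_pixels:

    import math

    # Mirrors scale_longer_dimension: pin the longer side to longer_size and
    # scale the other side proportionally, rounding to the nearest pixel.
    def longer_dimension_target(width: int, height: int, longer_size: int) -> tuple[int, int]:
        if height > width:
            return round((width / height) * longer_size), longer_size
        if width > height:
            return longer_size, round((height / width) * longer_size)
        return longer_size, longer_size  # square input

    # Mirrors scale_total_pixels: scale both sides by sqrt(target / current)
    # so the output area is roughly `megapixels` (1 MP = 1024 * 1024 pixels).
    def total_pixels_target(width: int, height: int, megapixels: float) -> tuple[int, int]:
        total = int(megapixels * 1024 * 1024)
        factor = math.sqrt(total / (width * height))
        return round(width * factor), round(height * factor)

    print(longer_dimension_target(1920, 1080, 512))  # (512, 288)
    print(total_pixels_target(1920, 1080, 1.0))      # (1365, 768)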
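
Likewise, a hypothetical direct call to the new node's execute classmethod (a sketch, assuming a ComfyUI checkout on PYTHONPATH with this patch applied; the dict passed as resize_type follows ResizeTypedDict and the DynamicCombo options above):

    import torch
    from comfy_extras.nodes_post_processing import ResizeImageMaskNode, ResizeType

    # IMAGE tensors use the [batch, height, width, channels] layout noted in is_image().
    image = torch.zeros(1, 768, 1024, 3)
    result = ResizeImageMaskNode.execute(
        image,
        scale_method="area",
        resize_type={
            "resize_type": ResizeType.SCALE_DIMENSIONS,
            "width": 512,
            "height": 512,
            "crop": "center",
        },
    )
    # result is an io.NodeOutput wrapping a 1x512x512x3 tensor; in a workflow the
    # resize_type dict is supplied by the DynamicCombo input selected in the UI.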