From 0b8e03a1e7c26b15b680220f9fda002075c945f1 Mon Sep 17 00:00:00 2001
From: nomadoor
Date: Wed, 14 Jan 2026 09:01:26 +0900
Subject: [PATCH] Make scale-to-multiple shape handling explicit

---
 comfy_extras/nodes_post_processing.py | 56 +++++++++++++++++----------
 1 file changed, 35 insertions(+), 21 deletions(-)

diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py
index 48bf5f965..0433bbda2 100644
--- a/comfy_extras/nodes_post_processing.py
+++ b/comfy_extras/nodes_post_processing.py
@@ -254,7 +254,7 @@ class ResizeType(str, Enum):
     SCALE_HEIGHT = "scale height"
     SCALE_TOTAL_PIXELS = "scale total pixels"
     MATCH_SIZE = "match size"
-    CROP_TO_MULTIPLE = "crop to multiple"
+    SCALE_TO_MULTIPLE = "scale to multiple"
 
 def is_image(input: torch.Tensor) -> bool:
     # images have 4 dimensions: [batch, height, width, channels]
@@ -364,26 +364,40 @@ def scale_match_size(input: torch.Tensor, match: torch.Tensor, scale_method: str
     input = finalize_image_mask_input(input, is_type_image)
     return input
 
-def crop_to_multiple(input: torch.Tensor, multiple: int, crop: str="center") -> torch.Tensor:
+def scale_to_multiple_cover(input: torch.Tensor, multiple: int, scale_method: str) -> torch.Tensor:
     if multiple <= 1:
         return input
-    width = input.shape[2]
-    height = input.shape[1]
-    new_w = (width // multiple) * multiple
-    new_h = (height // multiple) * multiple
-    if new_w == 0 or new_h == 0:
-        return input
-    if new_w == width and new_h == height:
-        return input
-    if crop == "center":
-        x0 = (width - new_w) // 2
-        y0 = (height - new_h) // 2
+    is_type_image = is_image(input)
+    if is_type_image:
+        _, height, width, _ = input.shape
     else:
-        x0 = 0
-        y0 = 0
-    x1 = x0 + new_w
-    y1 = y0 + new_h
-    if is_image(input):
+        _, height, width = input.shape
+    target_w = (width // multiple) * multiple
+    target_h = (height // multiple) * multiple
+    if target_w == 0 or target_h == 0:
+        return input
+    if target_w == width and target_h == height:
+        return input
+    s_w = target_w / width
+    s_h = target_h / height
+    if s_w >= s_h:
+        scaled_w = target_w
+        scaled_h = int(math.ceil(height * s_w))
+        if scaled_h < target_h:
+            scaled_h = target_h
+    else:
+        scaled_h = target_h
+        scaled_w = int(math.ceil(width * s_h))
+        if scaled_w < target_w:
+            scaled_w = target_w
+    input = init_image_mask_input(input, is_type_image)
+    input = comfy.utils.common_upscale(input, scaled_w, scaled_h, scale_method, "disabled")
+    input = finalize_image_mask_input(input, is_type_image)
+    x0 = (scaled_w - target_w) // 2
+    y0 = (scaled_h - target_h) // 2
+    x1 = x0 + target_w
+    y1 = y0 + target_h
+    if is_type_image:
         return input[:, y0:y1, x0:x1, :]
     return input[:, y0:y1, x0:x1]
 
@@ -442,7 +456,7 @@ class ResizeImageMaskNode(io.ComfyNode):
                     io.MultiType.Input("match", [io.Image, io.Mask]),
                     crop_combo,
                 ]),
-                io.DynamicCombo.Option(ResizeType.CROP_TO_MULTIPLE, [
+                io.DynamicCombo.Option(ResizeType.SCALE_TO_MULTIPLE, [
                     io.Int.Input("multiple", default=8, min=1, max=MAX_RESOLUTION, step=1),
                 ]),
             ]),
@@ -470,8 +484,8 @@ class ResizeImageMaskNode(io.ComfyNode):
             return io.NodeOutput(scale_total_pixels(input, resize_type["megapixels"], scale_method))
         elif selected_type == ResizeType.MATCH_SIZE:
             return io.NodeOutput(scale_match_size(input, resize_type["match"], scale_method, resize_type["crop"]))
-        elif selected_type == ResizeType.CROP_TO_MULTIPLE:
-            return io.NodeOutput(crop_to_multiple(input, resize_type["multiple"]))
+        elif selected_type == ResizeType.SCALE_TO_MULTIPLE:
+            return io.NodeOutput(scale_to_multiple_cover(input, resize_type["multiple"], scale_method))
         raise ValueError(f"Unsupported resize type: {selected_type}")
 
 def batch_images(images: list[torch.Tensor]) -> torch.Tensor | None:
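
Note (not part of the patch): the cover-fit arithmetic in scale_to_multiple_cover is compact enough to mis-read, so below is a minimal standalone sketch of the same sizing math. cover_dims is a hypothetical helper written only for illustration; it mirrors the patch's logic but drops the tensor plumbing (init/finalize/common_upscale), returning the scaled dimensions and the centered crop window instead.

import math

def cover_dims(width: int, height: int, multiple: int):
    # Snap the target down to the nearest multiple, scale with the larger
    # of the two factors so the result covers the target, then compute the
    # centered crop window -- the same math as scale_to_multiple_cover.
    target_w = (width // multiple) * multiple
    target_h = (height // multiple) * multiple
    if target_w == 0 or target_h == 0 or (target_w, target_h) == (width, height):
        return (width, height), (0, 0, width, height)
    s_w, s_h = target_w / width, target_h / height
    if s_w >= s_h:
        scaled_w, scaled_h = target_w, max(math.ceil(height * s_w), target_h)
    else:
        scaled_w, scaled_h = max(math.ceil(width * s_h), target_w), target_h
    x0 = (scaled_w - target_w) // 2
    y0 = (scaled_h - target_h) // 2
    return (scaled_w, scaled_h), (x0, y0, x0 + target_w, y0 + target_h)

# Example: 1000x437 with multiple=8 targets 1000x432. The width factor
# s_w = 1.0 dominates, so the image keeps its size and the five surplus
# rows are trimmed symmetrically (2 on top, 3 on the bottom).
print(cover_dims(1000, 437, 8))  # ((1000, 437), (0, 2, 1000, 434))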