From a6ae08b7ee4af07c70bf8746cf56b958ff2c58b5 Mon Sep 17 00:00:00 2001
From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>
Date: Sat, 13 Dec 2025 11:22:25 +0800
Subject: [PATCH] modify existing node for needed feature instead of new node

---
 comfy_extras/nodes_dataset.py         | 84 ++-------------------------
 comfy_extras/nodes_post_processing.py | 11 ++--
 2 files changed, 10 insertions(+), 85 deletions(-)

diff --git a/comfy_extras/nodes_dataset.py b/comfy_extras/nodes_dataset.py
index 2f1f304f8..6af571117 100644
--- a/comfy_extras/nodes_dataset.py
+++ b/comfy_extras/nodes_dataset.py
@@ -1,6 +1,5 @@
 import logging
 import os
-import math
 import json
 
 import numpy as np
@@ -624,79 +623,6 @@ class TextProcessingNode(io.ComfyNode):
 # ========== Image Transform Nodes ==========
 
 
-class ResizeImagesToSameSizeNode(ImageProcessingNode):
-    node_id = "ResizeImagesToSameSize"
-    display_name = "Resize Images to Same Size"
-    description = "Resize all images to the same width and height."
-    extra_inputs = [
-        io.Int.Input("width", default=512, min=1, max=8192, tooltip="Target width."),
-        io.Int.Input("height", default=512, min=1, max=8192, tooltip="Target height."),
-        io.Combo.Input(
-            "mode",
-            options=["stretch", "crop_center", "pad"],
-            default="stretch",
-            tooltip="Resize mode.",
-        ),
-    ]
-
-    @classmethod
-    def _process(cls, image, width, height, mode):
-        img = tensor_to_pil(image)
-
-        if mode == "stretch":
-            img = img.resize((width, height), Image.Resampling.LANCZOS)
-        elif mode == "crop_center":
-            left = max(0, (img.width - width) // 2)
-            top = max(0, (img.height - height) // 2)
-            right = min(img.width, left + width)
-            bottom = min(img.height, top + height)
-            img = img.crop((left, top, right, bottom))
-            if img.width != width or img.height != height:
-                img = img.resize((width, height), Image.Resampling.LANCZOS)
-        elif mode == "pad":
-            img.thumbnail((width, height), Image.Resampling.LANCZOS)
-            new_img = Image.new("RGB", (width, height), (0, 0, 0))
-            paste_x = (width - img.width) // 2
-            paste_y = (height - img.height) // 2
-            new_img.paste(img, (paste_x, paste_y))
-            img = new_img
-
-        return pil_to_tensor(img)
-
-
-class ResizeImagesToPixelCountNode(ImageProcessingNode):
-    node_id = "ResizeImagesToPixelCount"
-    display_name = "Resize Images to Pixel Count"
-    description = "Resize images so that the total pixel count matches the specified number while preserving aspect ratio."
-    extra_inputs = [
-        io.Int.Input(
-            "pixel_count",
-            default=512 * 512,
-            min=1,
-            max=8192 * 8192,
-            tooltip="Target pixel count.",
-        ),
-        io.Int.Input(
-            "steps",
-            default=64,
-            min=1,
-            max=128,
-            tooltip="The stepping for resize width/height.",
-        ),
-    ]
-
-    @classmethod
-    def _process(cls, image, pixel_count, steps):
-        img = tensor_to_pil(image)
-        w, h = img.size
-        pixel_count_ratio = math.sqrt(pixel_count / (w * h))
-        new_w = int(w * pixel_count_ratio / steps) * steps
-        new_h = int(h * pixel_count_ratio / steps) * steps
-        logging.info(f"Resizing from {w}x{h} to {new_w}x{new_h}")
-        img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
-        return pil_to_tensor(img)
-
-
 class ResizeImagesByShorterEdgeNode(ImageProcessingNode):
     node_id = "ResizeImagesByShorterEdge"
     display_name = "Resize Images by Shorter Edge"
@@ -1285,8 +1211,8 @@ class ResolutionBucket(io.ComfyNode):
 
         buckets = {}  # (H, W) -> {"latents": list, "conditions": list}
         for latent, cond in zip(flat_latents, flat_conditions):
-            # latent shape is (C, H, W)
-            h, w = latent.shape[1], latent.shape[2]
+            # latent shape is (..., H, W): (B, C, H, W) or (B, T, C, H, W)
+            h, w = latent.shape[-2], latent.shape[-1]
             key = (h, w)
 
             if key not in buckets:
@@ -1296,11 +1222,11 @@ class ResolutionBucket(io.ComfyNode):
             buckets[key]["conditions"].append(cond)
 
         # Convert buckets to output format
-        output_latents = []  # list[{"samples": tensor}] where tensor is (Bi, C, H, W)
+        output_latents = []  # list[{"samples": tensor}] where tensor is (Bi, ..., H, W)
         output_conditions = []  # list[list[cond]] where each inner list has Bi conditions
 
         for (h, w), bucket_data in buckets.items():
-            # Stack latents into batch: list of (C, H, W) -> (Bi, C, H, W)
+            # Stack latents into batch: list of (..., H, W) -> (Bi, ..., H, W)
             stacked_latents = torch.stack(bucket_data["latents"], dim=0)
             output_latents.append({"samples": stacked_latents})
 
@@ -1589,8 +1515,6 @@ class DatasetExtension(ComfyExtension):
             SaveImageDataSetToFolderNode,
             SaveImageTextDataSetToFolderNode,
             # Image transform nodes
-            ResizeImagesToSameSizeNode,
-            ResizeImagesToPixelCountNode,
             ResizeImagesByShorterEdgeNode,
             ResizeImagesByLongerEdgeNode,
             CenterCropImagesNode,
diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py
index 34c388a5a..ca2cdeb50 100644
--- a/comfy_extras/nodes_post_processing.py
+++ b/comfy_extras/nodes_post_processing.py
@@ -221,6 +221,7 @@ class ImageScaleToTotalPixels(io.ComfyNode):
                 io.Image.Input("image"),
                 io.Combo.Input("upscale_method", options=cls.upscale_methods),
                 io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01),
+                io.Int.Input("resolution_steps", default=1, min=1, max=256),
             ],
             outputs=[
                 io.Image.Output(),
@@ -228,15 +229,15 @@ class ImageScaleToTotalPixels(io.ComfyNode):
         )
 
     @classmethod
-    def execute(cls, image, upscale_method, megapixels) -> io.NodeOutput:
+    def execute(cls, image, upscale_method, megapixels, resolution_steps) -> io.NodeOutput:
         samples = image.movedim(-1,1)
-        total = int(megapixels * 1024 * 1024)
+        total = megapixels * 1024 * 1024
 
         scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
-        width = round(samples.shape[3] * scale_by)
-        height = round(samples.shape[2] * scale_by)
+        width = round(samples.shape[3] * scale_by / resolution_steps) * resolution_steps
+        height = round(samples.shape[2] * scale_by / resolution_steps) * resolution_steps
 
-        s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled")
+        s = comfy.utils.common_upscale(samples, int(width), int(height), upscale_method, "disabled")
         s = s.movedim(1,-1)
         return io.NodeOutput(s)