modify existing node for needed feature instead of new node

This commit is contained in:
Kohaku-Blueleaf 2025-12-13 11:22:25 +08:00
parent d330bb2a37
commit a6ae08b7ee
2 changed files with 10 additions and 85 deletions

View File

@ -1,6 +1,5 @@
import logging import logging
import os import os
import math
import json import json
import numpy as np import numpy as np
@ -624,79 +623,6 @@ class TextProcessingNode(io.ComfyNode):
# ========== Image Transform Nodes ========== # ========== Image Transform Nodes ==========
class ResizeImagesToSameSizeNode(ImageProcessingNode):
node_id = "ResizeImagesToSameSize"
display_name = "Resize Images to Same Size"
description = "Resize all images to the same width and height."
extra_inputs = [
io.Int.Input("width", default=512, min=1, max=8192, tooltip="Target width."),
io.Int.Input("height", default=512, min=1, max=8192, tooltip="Target height."),
io.Combo.Input(
"mode",
options=["stretch", "crop_center", "pad"],
default="stretch",
tooltip="Resize mode.",
),
]
@classmethod
def _process(cls, image, width, height, mode):
img = tensor_to_pil(image)
if mode == "stretch":
img = img.resize((width, height), Image.Resampling.LANCZOS)
elif mode == "crop_center":
left = max(0, (img.width - width) // 2)
top = max(0, (img.height - height) // 2)
right = min(img.width, left + width)
bottom = min(img.height, top + height)
img = img.crop((left, top, right, bottom))
if img.width != width or img.height != height:
img = img.resize((width, height), Image.Resampling.LANCZOS)
elif mode == "pad":
img.thumbnail((width, height), Image.Resampling.LANCZOS)
new_img = Image.new("RGB", (width, height), (0, 0, 0))
paste_x = (width - img.width) // 2
paste_y = (height - img.height) // 2
new_img.paste(img, (paste_x, paste_y))
img = new_img
return pil_to_tensor(img)
class ResizeImagesToPixelCountNode(ImageProcessingNode):
node_id = "ResizeImagesToPixelCount"
display_name = "Resize Images to Pixel Count"
description = "Resize images so that the total pixel count matches the specified number while preserving aspect ratio."
extra_inputs = [
io.Int.Input(
"pixel_count",
default=512 * 512,
min=1,
max=8192 * 8192,
tooltip="Target pixel count.",
),
io.Int.Input(
"steps",
default=64,
min=1,
max=128,
tooltip="The stepping for resize width/height.",
),
]
@classmethod
def _process(cls, image, pixel_count, steps):
img = tensor_to_pil(image)
w, h = img.size
pixel_count_ratio = math.sqrt(pixel_count / (w * h))
new_w = int(w * pixel_count_ratio / steps) * steps
new_h = int(h * pixel_count_ratio / steps) * steps
logging.info(f"Resizing from {w}x{h} to {new_w}x{new_h}")
img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
return pil_to_tensor(img)
class ResizeImagesByShorterEdgeNode(ImageProcessingNode): class ResizeImagesByShorterEdgeNode(ImageProcessingNode):
node_id = "ResizeImagesByShorterEdge" node_id = "ResizeImagesByShorterEdge"
display_name = "Resize Images by Shorter Edge" display_name = "Resize Images by Shorter Edge"
@ -1285,8 +1211,8 @@ class ResolutionBucket(io.ComfyNode):
buckets = {} # (H, W) -> {"latents": list, "conditions": list} buckets = {} # (H, W) -> {"latents": list, "conditions": list}
for latent, cond in zip(flat_latents, flat_conditions): for latent, cond in zip(flat_latents, flat_conditions):
# latent shape is (C, H, W) # latent shape is (..., H, W) (B, C, H, W) or (B, T, C, H ,W)
h, w = latent.shape[1], latent.shape[2] h, w = latent.shape[-2], latent.shape[-1]
key = (h, w) key = (h, w)
if key not in buckets: if key not in buckets:
@ -1296,11 +1222,11 @@ class ResolutionBucket(io.ComfyNode):
buckets[key]["conditions"].append(cond) buckets[key]["conditions"].append(cond)
# Convert buckets to output format # Convert buckets to output format
output_latents = [] # list[{"samples": tensor}] where tensor is (Bi, C, H, W) output_latents = [] # list[{"samples": tensor}] where tensor is (Bi, ..., H, W)
output_conditions = [] # list[list[cond]] where each inner list has Bi conditions output_conditions = [] # list[list[cond]] where each inner list has Bi conditions
for (h, w), bucket_data in buckets.items(): for (h, w), bucket_data in buckets.items():
# Stack latents into batch: list of (C, H, W) -> (Bi, C, H, W) # Stack latents into batch: list of (..., H, W) -> (Bi, ..., H, W)
stacked_latents = torch.stack(bucket_data["latents"], dim=0) stacked_latents = torch.stack(bucket_data["latents"], dim=0)
output_latents.append({"samples": stacked_latents}) output_latents.append({"samples": stacked_latents})
@ -1589,8 +1515,6 @@ class DatasetExtension(ComfyExtension):
SaveImageDataSetToFolderNode, SaveImageDataSetToFolderNode,
SaveImageTextDataSetToFolderNode, SaveImageTextDataSetToFolderNode,
# Image transform nodes # Image transform nodes
ResizeImagesToSameSizeNode,
ResizeImagesToPixelCountNode,
ResizeImagesByShorterEdgeNode, ResizeImagesByShorterEdgeNode,
ResizeImagesByLongerEdgeNode, ResizeImagesByLongerEdgeNode,
CenterCropImagesNode, CenterCropImagesNode,

View File

@ -221,6 +221,7 @@ class ImageScaleToTotalPixels(io.ComfyNode):
io.Image.Input("image"), io.Image.Input("image"),
io.Combo.Input("upscale_method", options=cls.upscale_methods), io.Combo.Input("upscale_method", options=cls.upscale_methods),
io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01), io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01),
io.Int.Input("resolution_steps", default=1, min=1, max=256),
], ],
outputs=[ outputs=[
io.Image.Output(), io.Image.Output(),
@ -228,15 +229,15 @@ class ImageScaleToTotalPixels(io.ComfyNode):
) )
@classmethod @classmethod
def execute(cls, image, upscale_method, megapixels) -> io.NodeOutput: def execute(cls, image, upscale_method, megapixels, resolution_steps) -> io.NodeOutput:
samples = image.movedim(-1,1) samples = image.movedim(-1,1)
total = int(megapixels * 1024 * 1024) total = megapixels * 1024 * 1024
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2])) scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
width = round(samples.shape[3] * scale_by) width = round(samples.shape[3] * scale_by / resolution_steps) * resolution_steps
height = round(samples.shape[2] * scale_by) height = round(samples.shape[2] * scale_by / resolution_steps) * resolution_steps
s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled") s = comfy.utils.common_upscale(samples, int(width), int(height), upscale_method, "disabled")
s = s.movedim(1,-1) s = s.movedim(1,-1)
return io.NodeOutput(s) return io.NodeOutput(s)