mirror of https://github.com/comfyanonymous/ComfyUI.git
modify existing node for needed feature instead of adding a new node
parent d330bb2a37
commit a6ae08b7ee
@@ -1,6 +1,5 @@
 import logging
 import os
-import math
 import json

 import numpy as np
@@ -624,79 +623,6 @@ class TextProcessingNode(io.ComfyNode):
 # ========== Image Transform Nodes ==========


-class ResizeImagesToSameSizeNode(ImageProcessingNode):
-    node_id = "ResizeImagesToSameSize"
-    display_name = "Resize Images to Same Size"
-    description = "Resize all images to the same width and height."
-    extra_inputs = [
-        io.Int.Input("width", default=512, min=1, max=8192, tooltip="Target width."),
-        io.Int.Input("height", default=512, min=1, max=8192, tooltip="Target height."),
-        io.Combo.Input(
-            "mode",
-            options=["stretch", "crop_center", "pad"],
-            default="stretch",
-            tooltip="Resize mode.",
-        ),
-    ]
-
-    @classmethod
-    def _process(cls, image, width, height, mode):
-        img = tensor_to_pil(image)
-
-        if mode == "stretch":
-            img = img.resize((width, height), Image.Resampling.LANCZOS)
-        elif mode == "crop_center":
-            left = max(0, (img.width - width) // 2)
-            top = max(0, (img.height - height) // 2)
-            right = min(img.width, left + width)
-            bottom = min(img.height, top + height)
-            img = img.crop((left, top, right, bottom))
-            if img.width != width or img.height != height:
-                img = img.resize((width, height), Image.Resampling.LANCZOS)
-        elif mode == "pad":
-            img.thumbnail((width, height), Image.Resampling.LANCZOS)
-            new_img = Image.new("RGB", (width, height), (0, 0, 0))
-            paste_x = (width - img.width) // 2
-            paste_y = (height - img.height) // 2
-            new_img.paste(img, (paste_x, paste_y))
-            img = new_img
-
-        return pil_to_tensor(img)
-
-
-class ResizeImagesToPixelCountNode(ImageProcessingNode):
-    node_id = "ResizeImagesToPixelCount"
-    display_name = "Resize Images to Pixel Count"
-    description = "Resize images so that the total pixel count matches the specified number while preserving aspect ratio."
-    extra_inputs = [
-        io.Int.Input(
-            "pixel_count",
-            default=512 * 512,
-            min=1,
-            max=8192 * 8192,
-            tooltip="Target pixel count.",
-        ),
-        io.Int.Input(
-            "steps",
-            default=64,
-            min=1,
-            max=128,
-            tooltip="The stepping for resize width/height.",
-        ),
-    ]
-
-    @classmethod
-    def _process(cls, image, pixel_count, steps):
-        img = tensor_to_pil(image)
-        w, h = img.size
-        pixel_count_ratio = math.sqrt(pixel_count / (w * h))
-        new_w = int(w * pixel_count_ratio / steps) * steps
-        new_h = int(h * pixel_count_ratio / steps) * steps
-        logging.info(f"Resizing from {w}x{h} to {new_w}x{new_h}")
-        img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
-        return pil_to_tensor(img)
-
-
 class ResizeImagesByShorterEdgeNode(ImageProcessingNode):
     node_id = "ResizeImagesByShorterEdge"
     display_name = "Resize Images by Shorter Edge"
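Note: the step-rounded, aspect-preserving resize that the removed ResizeImagesToPixelCountNode performed is what the ImageScaleToTotalPixels change further down reproduces. A minimal standalone sketch of that math, assuming plain PIL (the helper name and the example numbers are illustrative, not part of this commit):

import math
from PIL import Image

def resize_to_pixel_count(img: Image.Image, pixel_count: int, steps: int) -> Image.Image:
    # Uniform scale factor that brings w*h toward pixel_count while preserving aspect ratio.
    w, h = img.size
    ratio = math.sqrt(pixel_count / (w * h))
    # Floor each edge to a multiple of steps, matching the removed node's int(...) truncation.
    new_w = int(w * ratio / steps) * steps
    new_h = int(h * ratio / steps) * steps
    return img.resize((new_w, new_h), Image.Resampling.LANCZOS)

# e.g. a 1920x1080 image with pixel_count=512*512 and steps=64 resizes to 640x384 (245,760 px).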
@@ -1285,8 +1211,8 @@ class ResolutionBucket(io.ComfyNode):
         buckets = {}  # (H, W) -> {"latents": list, "conditions": list}

         for latent, cond in zip(flat_latents, flat_conditions):
-            # latent shape is (C, H, W)
-            h, w = latent.shape[1], latent.shape[2]
+            # latent shape is (..., H, W): (B, C, H, W) or (B, T, C, H, W)
+            h, w = latent.shape[-2], latent.shape[-1]
             key = (h, w)

             if key not in buckets:
@@ -1296,11 +1222,11 @@ class ResolutionBucket(io.ComfyNode):
             buckets[key]["conditions"].append(cond)

         # Convert buckets to output format
-        output_latents = []  # list[{"samples": tensor}] where tensor is (Bi, C, H, W)
+        output_latents = []  # list[{"samples": tensor}] where tensor is (Bi, ..., H, W)
         output_conditions = []  # list[list[cond]] where each inner list has Bi conditions

         for (h, w), bucket_data in buckets.items():
-            # Stack latents into batch: list of (C, H, W) -> (Bi, C, H, W)
+            # Stack latents into batch: list of (..., H, W) -> (Bi, ..., H, W)
             stacked_latents = torch.stack(bucket_data["latents"], dim=0)
             output_latents.append({"samples": stacked_latents})

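Note on the shape change: indexing from the end (shape[-2], shape[-1]) makes the bucket key correct for any number of leading dimensions, so image latents and video latents share one code path. A minimal sketch of the grouping idea, with hypothetical shapes:

import torch

latents = [
    torch.randn(4, 64, 64),   # (C, H, W) image latent
    torch.randn(4, 32, 48),
    torch.randn(4, 64, 64),
]

buckets = {}
for latent in latents:
    # (H, W) key taken from the trailing dims, regardless of rank
    key = (latent.shape[-2], latent.shape[-1])
    buckets.setdefault(key, []).append(latent)

# Stack each bucket: list of (..., H, W) -> (Bi, ..., H, W)
batched = {k: torch.stack(v, dim=0) for k, v in buckets.items()}
print({k: tuple(t.shape) for k, t in batched.items()})
# {(64, 64): (2, 4, 64, 64), (32, 48): (1, 4, 32, 48)}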
@@ -1589,8 +1515,6 @@ class DatasetExtension(ComfyExtension):
            SaveImageDataSetToFolderNode,
            SaveImageTextDataSetToFolderNode,
            # Image transform nodes
-           ResizeImagesToSameSizeNode,
-           ResizeImagesToPixelCountNode,
            ResizeImagesByShorterEdgeNode,
            ResizeImagesByLongerEdgeNode,
            CenterCropImagesNode,
@@ -221,6 +221,7 @@ class ImageScaleToTotalPixels(io.ComfyNode):
                 io.Image.Input("image"),
                 io.Combo.Input("upscale_method", options=cls.upscale_methods),
                 io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01),
+                io.Int.Input("resolution_steps", default=1, min=1, max=256),
             ],
             outputs=[
                 io.Image.Output(),
@@ -228,15 +229,15 @@ class ImageScaleToTotalPixels(io.ComfyNode):
         )

     @classmethod
-    def execute(cls, image, upscale_method, megapixels) -> io.NodeOutput:
+    def execute(cls, image, upscale_method, megapixels, resolution_steps) -> io.NodeOutput:
         samples = image.movedim(-1,1)
-        total = int(megapixels * 1024 * 1024)
+        total = megapixels * 1024 * 1024

         scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
-        width = round(samples.shape[3] * scale_by)
-        height = round(samples.shape[2] * scale_by)
+        width = round(samples.shape[3] * scale_by / resolution_steps) * resolution_steps
+        height = round(samples.shape[2] * scale_by / resolution_steps) * resolution_steps

-        s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled")
+        s = comfy.utils.common_upscale(samples, int(width), int(height), upscale_method, "disabled")
         s = s.movedim(1,-1)
         return io.NodeOutput(s)

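Worked example of the new rounding, with assumed inputs (a 1920x1080 image, megapixels=1.0, resolution_steps=64; none of these numbers come from the commit):

import math

w, h, megapixels, resolution_steps = 1920, 1080, 1.0, 64

total = megapixels * 1024 * 1024              # 1048576.0; the int() wrap is no longer needed
scale_by = math.sqrt(total / (w * h))         # 1024 / 1440 ~= 0.7111
width = round(w * scale_by / resolution_steps) * resolution_steps   # round(21.33) * 64 = 1344
height = round(h * scale_by / resolution_steps) * resolution_steps  # round(12.0) * 64 = 768
print(width, height)  # 1344 768 -> 1,032,192 px, about 98% of the requested pixel budget

With resolution_steps=1 the rounding reduces to the old behavior (1365x768 here), so existing workflows are unaffected.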