mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-17 10:02:59 +08:00
modify existing node for needed feature instead of new node
This commit is contained in:
parent
d330bb2a37
commit
a6ae08b7ee
@ -1,6 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import math
|
|
||||||
import json
|
import json
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -624,79 +623,6 @@ class TextProcessingNode(io.ComfyNode):
|
|||||||
# ========== Image Transform Nodes ==========
|
# ========== Image Transform Nodes ==========
|
||||||
|
|
||||||
|
|
||||||
class ResizeImagesToSameSizeNode(ImageProcessingNode):
|
|
||||||
node_id = "ResizeImagesToSameSize"
|
|
||||||
display_name = "Resize Images to Same Size"
|
|
||||||
description = "Resize all images to the same width and height."
|
|
||||||
extra_inputs = [
|
|
||||||
io.Int.Input("width", default=512, min=1, max=8192, tooltip="Target width."),
|
|
||||||
io.Int.Input("height", default=512, min=1, max=8192, tooltip="Target height."),
|
|
||||||
io.Combo.Input(
|
|
||||||
"mode",
|
|
||||||
options=["stretch", "crop_center", "pad"],
|
|
||||||
default="stretch",
|
|
||||||
tooltip="Resize mode.",
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _process(cls, image, width, height, mode):
|
|
||||||
img = tensor_to_pil(image)
|
|
||||||
|
|
||||||
if mode == "stretch":
|
|
||||||
img = img.resize((width, height), Image.Resampling.LANCZOS)
|
|
||||||
elif mode == "crop_center":
|
|
||||||
left = max(0, (img.width - width) // 2)
|
|
||||||
top = max(0, (img.height - height) // 2)
|
|
||||||
right = min(img.width, left + width)
|
|
||||||
bottom = min(img.height, top + height)
|
|
||||||
img = img.crop((left, top, right, bottom))
|
|
||||||
if img.width != width or img.height != height:
|
|
||||||
img = img.resize((width, height), Image.Resampling.LANCZOS)
|
|
||||||
elif mode == "pad":
|
|
||||||
img.thumbnail((width, height), Image.Resampling.LANCZOS)
|
|
||||||
new_img = Image.new("RGB", (width, height), (0, 0, 0))
|
|
||||||
paste_x = (width - img.width) // 2
|
|
||||||
paste_y = (height - img.height) // 2
|
|
||||||
new_img.paste(img, (paste_x, paste_y))
|
|
||||||
img = new_img
|
|
||||||
|
|
||||||
return pil_to_tensor(img)
|
|
||||||
|
|
||||||
|
|
||||||
class ResizeImagesToPixelCountNode(ImageProcessingNode):
|
|
||||||
node_id = "ResizeImagesToPixelCount"
|
|
||||||
display_name = "Resize Images to Pixel Count"
|
|
||||||
description = "Resize images so that the total pixel count matches the specified number while preserving aspect ratio."
|
|
||||||
extra_inputs = [
|
|
||||||
io.Int.Input(
|
|
||||||
"pixel_count",
|
|
||||||
default=512 * 512,
|
|
||||||
min=1,
|
|
||||||
max=8192 * 8192,
|
|
||||||
tooltip="Target pixel count.",
|
|
||||||
),
|
|
||||||
io.Int.Input(
|
|
||||||
"steps",
|
|
||||||
default=64,
|
|
||||||
min=1,
|
|
||||||
max=128,
|
|
||||||
tooltip="The stepping for resize width/height.",
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _process(cls, image, pixel_count, steps):
|
|
||||||
img = tensor_to_pil(image)
|
|
||||||
w, h = img.size
|
|
||||||
pixel_count_ratio = math.sqrt(pixel_count / (w * h))
|
|
||||||
new_w = int(w * pixel_count_ratio / steps) * steps
|
|
||||||
new_h = int(h * pixel_count_ratio / steps) * steps
|
|
||||||
logging.info(f"Resizing from {w}x{h} to {new_w}x{new_h}")
|
|
||||||
img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
|
|
||||||
return pil_to_tensor(img)
|
|
||||||
|
|
||||||
|
|
||||||
class ResizeImagesByShorterEdgeNode(ImageProcessingNode):
|
class ResizeImagesByShorterEdgeNode(ImageProcessingNode):
|
||||||
node_id = "ResizeImagesByShorterEdge"
|
node_id = "ResizeImagesByShorterEdge"
|
||||||
display_name = "Resize Images by Shorter Edge"
|
display_name = "Resize Images by Shorter Edge"
|
||||||
@ -1285,8 +1211,8 @@ class ResolutionBucket(io.ComfyNode):
|
|||||||
buckets = {} # (H, W) -> {"latents": list, "conditions": list}
|
buckets = {} # (H, W) -> {"latents": list, "conditions": list}
|
||||||
|
|
||||||
for latent, cond in zip(flat_latents, flat_conditions):
|
for latent, cond in zip(flat_latents, flat_conditions):
|
||||||
# latent shape is (C, H, W)
|
# latent shape is (..., H, W) (B, C, H, W) or (B, T, C, H ,W)
|
||||||
h, w = latent.shape[1], latent.shape[2]
|
h, w = latent.shape[-2], latent.shape[-1]
|
||||||
key = (h, w)
|
key = (h, w)
|
||||||
|
|
||||||
if key not in buckets:
|
if key not in buckets:
|
||||||
@ -1296,11 +1222,11 @@ class ResolutionBucket(io.ComfyNode):
|
|||||||
buckets[key]["conditions"].append(cond)
|
buckets[key]["conditions"].append(cond)
|
||||||
|
|
||||||
# Convert buckets to output format
|
# Convert buckets to output format
|
||||||
output_latents = [] # list[{"samples": tensor}] where tensor is (Bi, C, H, W)
|
output_latents = [] # list[{"samples": tensor}] where tensor is (Bi, ..., H, W)
|
||||||
output_conditions = [] # list[list[cond]] where each inner list has Bi conditions
|
output_conditions = [] # list[list[cond]] where each inner list has Bi conditions
|
||||||
|
|
||||||
for (h, w), bucket_data in buckets.items():
|
for (h, w), bucket_data in buckets.items():
|
||||||
# Stack latents into batch: list of (C, H, W) -> (Bi, C, H, W)
|
# Stack latents into batch: list of (..., H, W) -> (Bi, ..., H, W)
|
||||||
stacked_latents = torch.stack(bucket_data["latents"], dim=0)
|
stacked_latents = torch.stack(bucket_data["latents"], dim=0)
|
||||||
output_latents.append({"samples": stacked_latents})
|
output_latents.append({"samples": stacked_latents})
|
||||||
|
|
||||||
@ -1589,8 +1515,6 @@ class DatasetExtension(ComfyExtension):
|
|||||||
SaveImageDataSetToFolderNode,
|
SaveImageDataSetToFolderNode,
|
||||||
SaveImageTextDataSetToFolderNode,
|
SaveImageTextDataSetToFolderNode,
|
||||||
# Image transform nodes
|
# Image transform nodes
|
||||||
ResizeImagesToSameSizeNode,
|
|
||||||
ResizeImagesToPixelCountNode,
|
|
||||||
ResizeImagesByShorterEdgeNode,
|
ResizeImagesByShorterEdgeNode,
|
||||||
ResizeImagesByLongerEdgeNode,
|
ResizeImagesByLongerEdgeNode,
|
||||||
CenterCropImagesNode,
|
CenterCropImagesNode,
|
||||||
|
|||||||
@ -221,6 +221,7 @@ class ImageScaleToTotalPixels(io.ComfyNode):
|
|||||||
io.Image.Input("image"),
|
io.Image.Input("image"),
|
||||||
io.Combo.Input("upscale_method", options=cls.upscale_methods),
|
io.Combo.Input("upscale_method", options=cls.upscale_methods),
|
||||||
io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01),
|
io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01),
|
||||||
|
io.Int.Input("resolution_steps", default=1, min=1, max=256),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
io.Image.Output(),
|
io.Image.Output(),
|
||||||
@ -228,15 +229,15 @@ class ImageScaleToTotalPixels(io.ComfyNode):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, image, upscale_method, megapixels) -> io.NodeOutput:
|
def execute(cls, image, upscale_method, megapixels, resolution_steps) -> io.NodeOutput:
|
||||||
samples = image.movedim(-1,1)
|
samples = image.movedim(-1,1)
|
||||||
total = int(megapixels * 1024 * 1024)
|
total = megapixels * 1024 * 1024
|
||||||
|
|
||||||
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
|
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
|
||||||
width = round(samples.shape[3] * scale_by)
|
width = round(samples.shape[3] * scale_by / resolution_steps) * resolution_steps
|
||||||
height = round(samples.shape[2] * scale_by)
|
height = round(samples.shape[2] * scale_by / resolution_steps) * resolution_steps
|
||||||
|
|
||||||
s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled")
|
s = comfy.utils.common_upscale(samples, int(width), int(height), upscale_method, "disabled")
|
||||||
s = s.movedim(1,-1)
|
s = s.movedim(1,-1)
|
||||||
return io.NodeOutput(s)
|
return io.NodeOutput(s)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user