Improve scaling and fit for diffusion
parent dbc8ee92a5
commit 667b77149e
@@ -1580,8 +1580,6 @@ class LoadImage:
     def load_image(self, image: str) -> tuple[RGBImageBatch, MaskBatch]:
         image_path = folder_paths.get_annotated_filepath(image)
 
-        img = node_helpers.pillow(Image.open, image_path)
-
         output_images = []
         output_masks = []
         w, h = None, None
@@ -2,16 +2,60 @@ import json
 import os
 from typing import Literal, Tuple
 
-import cv2
 import numpy as np
 import torch
 from PIL import Image
 from PIL.PngImagePlugin import PngInfo
 
+from comfy import utils
 from comfy.cli_args import args
 from comfy.cmd import folder_paths
-from comfy.component_model.tensor_types import ImageBatch
+from comfy.component_model.tensor_types import ImageBatch, RGBImageBatch
+from comfy.nodes.base_nodes import ImageScale
 from comfy.nodes.common import MAX_RESOLUTION
+from comfy.nodes.package_typing import CustomNode
 
 
+def levels_adjustment(image: ImageBatch, black_level: float = 0.0, mid_level: float = 0.5, white_level: float = 1.0, clip: bool = True) -> ImageBatch:
+    """
+    Apply a levels adjustment to an sRGB image.
+
+    Args:
+        image (torch.Tensor): Input image tensor of shape (B, H, W, C) with values in range [0, 1]
+        black_level (float): Black point (default: 0.0)
+        mid_level (float): Midtone point (default: 0.5)
+        white_level (float): White point (default: 1.0)
+        clip (bool): Whether to clip the output values to [0, 1] range (default: True)
+
+    Returns:
+        torch.Tensor: Adjusted image tensor of shape (B, H, W, C)
+    """
+    # Ensure input is in correct shape and range
+    assert image.dim() == 4 and image.shape[-1] == 3, "Input should be of shape (B, H, W, 3)"
+    assert 0 <= black_level < mid_level < white_level <= 1, "Levels should be in ascending order in range [0, 1]"
+
+    def srgb_to_linear(x):
+        return torch.where(x <= 0.04045, x / 12.92, ((x + 0.055) / 1.055) ** 2.4)
+
+    def linear_to_srgb(x):
+        return torch.where(x <= 0.0031308, x * 12.92, 1.055 * x ** (1 / 2.4) - 0.055)
+
+    linear = srgb_to_linear(image)
+
+    adjusted = (linear - black_level) / (white_level - black_level)
+
+    power_factor = torch.log2(torch.tensor(0.5, device=image.device)) / torch.log2(torch.tensor(mid_level, device=image.device))
+
+    # apply power function to avoid nans
+    adjusted = torch.where(adjusted > 0, torch.pow(adjusted.clamp(min=1e-8), power_factor), adjusted)
+
+    result = linear_to_srgb(adjusted)
+
+    if clip:
+        result = torch.clamp(result, 0.0, 1.0)
+
+    return result
+
+
 class ImageCrop:
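The midtone handling in levels_adjustment reduces to a single gamma: mid_level is mapped to the exponent log2(0.5) / log2(mid_level), so 0.5 leaves midtones untouched and smaller values brighten them. A minimal, self-contained sketch of that relationship (the three sample levels are arbitrary, and the commented call at the end only illustrates the new helper's signature):

import torch

# Reproduce the midtone exponent used by levels_adjustment: log2(0.5) / log2(mid_level).
# mid_level == 0.5 gives an exponent of 1.0 (midtones unchanged);
# mid_level < 0.5 gives an exponent < 1.0, which brightens midtones in linear light.
for mid_level in (0.25, 0.5, 0.75):
    power = torch.log2(torch.tensor(0.5)) / torch.log2(torch.tensor(mid_level))
    print(mid_level, round(power.item(), 3))  # 0.25 -> 0.5, 0.5 -> 1.0, 0.75 -> ~2.41

# Typical call, given a (B, H, W, 3) sRGB batch in [0, 1]:
# adjusted = levels_adjustment(image, black_level=0.05, mid_level=0.4, white_level=0.95, clip=True)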
@@ -228,7 +272,8 @@ class ImageResize:
             "required": {
                 "image": ("IMAGE",),
                 "resize_mode": (["cover", "contain", "auto"], {"default": "cover"}),
-                "resolutions": (["SDXL/SD3/Flux", "SD1.5", ], {"default": "SDXL/SD3/Flux"})
+                "resolutions": (["SDXL/SD3/Flux", "SD1.5"], {"default": "SDXL/SD3/Flux"}),
+                "interpolation": (ImageScale.upscale_methods, {"default": "bilinear"}),
             }
         }
 
@@ -236,7 +281,7 @@ class ImageResize:
     FUNCTION = "resize_image"
     CATEGORY = "image/transform"
 
-    def resize_image(self, image: ImageBatch, resize_mode: Literal["cover", "contain", "auto"], resolutions: Literal["SDXL/SD3/Flux", "SD1.5",]) -> Tuple[ImageBatch]:
+    def resize_image(self, image: RGBImageBatch, resize_mode: Literal["cover", "contain", "auto"], resolutions: Literal["SDXL/SD3/Flux", "SD1.5"], interpolation: str) -> Tuple[RGBImageBatch]:
         if resolutions == "SDXL/SD3/Flux":
             supported_resolutions = [
                 (640, 1536),
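The new interpolation input is handed straight to ComfyUI's utils.common_upscale, which the rewritten loop in the next hunk calls on (B, C, H, W) tensors. A rough sketch of that call pattern, assuming a ComfyUI checkout on the import path; the 512 -> 768 sizes and the "bilinear" choice are only examples:

import torch

from comfy import utils  # the same helper the new resize_image uses below

# ComfyUI images are (B, H, W, C) in [0, 1]; common_upscale expects (B, C, H, W)
img = torch.rand(1, 512, 512, 3)
bchw = img.permute(0, 3, 1, 2)

# interpolation comes from ImageScale.upscale_methods (e.g. "bilinear");
# crop="disabled" matches the diff, so no extra center-crop happens here
resized = utils.common_upscale(bchw, 768, 768, "bilinear", "disabled")
print(resized.permute(0, 2, 3, 1).shape)  # back to (B, H, W, C)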
@@ -256,58 +301,112 @@ class ImageResize:
 
         resized_images = []
         for img in image:
-            img_np = (img.cpu().numpy() * 255).astype(np.uint8)
-            h, w = img_np.shape[:2]
+            h, w = img.shape[:2]
             current_aspect_ratio = w / h
             target_resolution = min(supported_resolutions,
                                     key=lambda res: abs(res[0] / res[1] - current_aspect_ratio))
-            scale_w, scale_h = target_resolution[0] / w, target_resolution[1] / h
 
             if resize_mode == "cover":
-                scale = max(scale_w, scale_h)
+                scale = max(target_resolution[0] / w, target_resolution[1] / h)
                 new_w, new_h = int(w * scale), int(h * scale)
-                resized = cv2.resize(img_np, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)
-                x1 = (new_w - target_resolution[0]) // 2
-                y1 = (new_h - target_resolution[1]) // 2
-                resized = resized[y1:y1 + target_resolution[1], x1:x1 + target_resolution[0]]
             elif resize_mode == "contain":
-                scale = min(scale_w, scale_h)
+                scale = min(target_resolution[0] / w, target_resolution[1] / h)
                 new_w, new_h = int(w * scale), int(h * scale)
-                resized = cv2.resize(img_np, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)
-                canvas = np.zeros((target_resolution[1], target_resolution[0], 3), dtype=np.uint8)
-                x1 = (target_resolution[0] - new_w) // 2
-                y1 = (target_resolution[1] - new_h) // 2
-                canvas[y1:y1 + new_h, x1:x1 + new_w] = resized
-                resized = canvas
-            else:
+            else:  # auto
                 if current_aspect_ratio > target_resolution[0] / target_resolution[1]:
-                    scale = scale_w
+                    new_w, new_h = target_resolution[0], int(h * target_resolution[0] / w)
                 else:
-                    scale = scale_h
-                new_w, new_h = int(w * scale), int(h * scale)
-                resized = cv2.resize(img_np, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)
-                if new_w > target_resolution[0] or new_h > target_resolution[1]:
-                    x1 = (new_w - target_resolution[0]) // 2
-                    y1 = (new_h - target_resolution[1]) // 2
-                    resized = resized[y1:y1 + target_resolution[1], x1:x1 + target_resolution[0]]
-                else:
-                    canvas = np.zeros((target_resolution[1], target_resolution[0], 3), dtype=np.uint8)
-                    x1 = (target_resolution[0] - new_w) // 2
-                    y1 = (target_resolution[1] - new_h) // 2
-                    canvas[y1:y1 + new_h, x1:x1 + new_w] = resized
-                    resized = canvas
+                    new_w, new_h = int(w * target_resolution[1] / h), target_resolution[1]
+
+            # convert to b, c, h, w
+            img_tensor = img.permute(2, 0, 1).unsqueeze(0)
+
+            # Use common_upscale for resizing
+            resized = utils.common_upscale(img_tensor, new_w, new_h, interpolation, "disabled")
+
+            # handle padding or cropping
+            if resize_mode == "contain":
+                canvas = torch.zeros((1, 3, target_resolution[1], target_resolution[0]), device=resized.device, dtype=resized.dtype)
+                y1 = (target_resolution[1] - new_h) // 2
+                x1 = (target_resolution[0] - new_w) // 2
+                canvas[:, :, y1:y1 + new_h, x1:x1 + new_w] = resized
+                resized = canvas
+            elif resize_mode == "cover":
+                y1 = (new_h - target_resolution[1]) // 2
+                x1 = (new_w - target_resolution[0]) // 2
+                resized = resized[:, :, y1:y1 + target_resolution[1], x1:x1 + target_resolution[0]]
+            else:  # auto
+                if new_w != target_resolution[0] or new_h != target_resolution[1]:
+                    canvas = torch.zeros((1, 3, target_resolution[1], target_resolution[0]), device=resized.device, dtype=resized.dtype)
+                    y1 = (target_resolution[1] - new_h) // 2
+                    x1 = (target_resolution[0] - new_w) // 2
+                    canvas[:, :, y1:y1 + new_h, x1:x1 + new_w] = resized
+                    resized = canvas
 
-            resized_images.append(resized)
+            resized_images.append(resized.squeeze(0).permute(1, 2, 0).clamp(0.0, 1.0))
 
-        return (torch.from_numpy(np.stack(resized_images)).float() / 255.0,)
+        return (torch.stack(resized_images),)
+
+
+class ImageLevels(CustomNode):
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "image": ("IMAGE",),
+                "black_level": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
+                "mid_level": ("FLOAT", {"default": 0.5, "min": 0.01, "max": 0.99, "step": 0.01}),
+                "white_level": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
+                "clip": ("BOOLEAN", {"default": True}),
+            }
+        }
+
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "apply_levels"
+    CATEGORY = "image/adjust"
+
+    def apply_levels(self, image: ImageBatch, black_level: float, mid_level: float, white_level: float, clip: bool) -> Tuple[ImageBatch]:
+        adjusted_image = levels_adjustment(image, black_level, mid_level, white_level, clip)
+        return (adjusted_image,)
+
+
+class ImageLuminance:
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "image": ("IMAGE",),
+            }
+        }
+
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "compute_luminance"
+    CATEGORY = "image/color"
+
+    def compute_luminance(self, image: ImageBatch) -> Tuple[ImageBatch]:
+        assert image.dim() == 4 and image.shape[-1] == 3, "Input should be of shape (B, H, W, 3)"
+
+        # define srgb luminance coefficients
+        coeffs = torch.tensor([0.2126, 0.7152, 0.0722], device=image.device, dtype=image.dtype)
+
+        luminance = torch.sum(image * coeffs, dim=-1, keepdim=True)
+        luminance = luminance.expand(-1, -1, -1, 3)
+
+        return (luminance,)
+
+
 NODE_CLASS_MAPPINGS = {
     "ImageResize": ImageResize,
     "ImageShape": ImageShape,
     "ImageCrop": ImageCrop,
+    "ImageLevels": ImageLevels,
+    "ImageLuminance": ImageLuminance,
     "RepeatImageBatch": RepeatImageBatch,
     "ImageFromBatch": ImageFromBatch,
     "SaveAnimatedWEBP": SaveAnimatedWEBP,
     "SaveAnimatedPNG": SaveAnimatedPNG,
 }
+
+NODE_DISPLAY_NAME_MAPPINGS = {
+    "ImageResize": "Fit Image to Diffusion Size"
+}
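To make the new cropping/padding behaviour concrete, here is a small worked example of the bucket selection and the "cover" branch from the loop above. Only (640, 1536) is visible in this hunk, so the other two buckets in the list are illustrative stand-ins, and the 1200x900 source size is arbitrary:

# Worked example of bucket selection plus the "cover" math from the new resize loop.
supported_resolutions = [(640, 1536), (1024, 1024), (1536, 640)]  # partly illustrative

w, h = 1200, 900                      # a 4:3 source image
current_aspect_ratio = w / h

# pick the bucket whose aspect ratio is closest to the source (here the square one)
target_resolution = min(supported_resolutions,
                        key=lambda res: abs(res[0] / res[1] - current_aspect_ratio))

# "cover": scale until both sides reach the target, then center-crop the overflow
scale = max(target_resolution[0] / w, target_resolution[1] / h)
new_w, new_h = int(w * scale), int(h * scale)
x1 = (new_w - target_resolution[0]) // 2   # columns trimmed from the left edge
y1 = (new_h - target_resolution[1]) // 2   # rows trimmed from the top edge

print(target_resolution, (new_w, new_h), (x1, y1))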