Mirror of https://github.com/comfyanonymous/ComfyUI.git
Improve scaling and fit for diffusion
This commit is contained in: parent dbc8ee92a5, commit 667b77149e
@@ -1580,8 +1580,6 @@ class LoadImage:
     def load_image(self, image: str) -> tuple[RGBImageBatch, MaskBatch]:
         image_path = folder_paths.get_annotated_filepath(image)
-
         img = node_helpers.pillow(Image.open, image_path)
-
         output_images = []
         output_masks = []
         w, h = None, None
@@ -2,16 +2,60 @@ import json
 import os
 from typing import Literal, Tuple
 
 import cv2
 import numpy as np
 import torch
 from PIL import Image
 from PIL.PngImagePlugin import PngInfo
 
+from comfy import utils
 from comfy.cli_args import args
 from comfy.cmd import folder_paths
-from comfy.component_model.tensor_types import ImageBatch
+from comfy.component_model.tensor_types import ImageBatch, RGBImageBatch
+from comfy.nodes.base_nodes import ImageScale
 from comfy.nodes.common import MAX_RESOLUTION
 from comfy.nodes.package_typing import CustomNode
+
+
+def levels_adjustment(image: ImageBatch, black_level: float = 0.0, mid_level: float = 0.5, white_level: float = 1.0, clip: bool = True) -> ImageBatch:
+    """
+    Apply a levels adjustment to an sRGB image.
+
+    Args:
+        image (torch.Tensor): Input image tensor of shape (B, H, W, C) with values in range [0, 1]
+        black_level (float): Black point (default: 0.0)
+        mid_level (float): Midtone point (default: 0.5)
+        white_level (float): White point (default: 1.0)
+        clip (bool): Whether to clip the output values to [0, 1] range (default: True)
+
+    Returns:
+        torch.Tensor: Adjusted image tensor of shape (B, H, W, C)
+    """
+    # Ensure input is in correct shape and range
+    assert image.dim() == 4 and image.shape[-1] == 3, "Input should be of shape (B, H, W, 3)"
+    assert 0 <= black_level < mid_level < white_level <= 1, "Levels should be in ascending order in range [0, 1]"
+
+    def srgb_to_linear(x):
+        return torch.where(x <= 0.04045, x / 12.92, ((x + 0.055) / 1.055) ** 2.4)
+
+    def linear_to_srgb(x):
+        return torch.where(x <= 0.0031308, x * 12.92, 1.055 * x ** (1 / 2.4) - 0.055)
+
+    linear = srgb_to_linear(image)
+
+    adjusted = (linear - black_level) / (white_level - black_level)
+
+    power_factor = torch.log2(torch.tensor(0.5, device=image.device)) / torch.log2(torch.tensor(mid_level, device=image.device))
+
+    # apply power function to avoid nans
+    adjusted = torch.where(adjusted > 0, torch.pow(adjusted.clamp(min=1e-8), power_factor), adjusted)
+
+    result = linear_to_srgb(adjusted)
+
+    if clip:
+        result = torch.clamp(result, 0.0, 1.0)
+
+    return result
 
 
 class ImageCrop:
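Note (not part of the commit): a quick standalone check of the midtone mapping used in levels_adjustment above. The exponent log2(0.5) / log2(mid_level) is chosen so that, after the black/white rescale, a value equal to mid_level lands at 0.5; mid_level below 0.5 gives an exponent under 1 (lifting midtones), while mid_level above 0.5 gives an exponent over 1 (compressing them).

import math

for mid_level in (0.25, 0.5, 0.75):
    power_factor = math.log2(0.5) / math.log2(mid_level)
    # mid_level raised to this exponent comes back as (approximately) 0.5
    print(mid_level, power_factor, mid_level ** power_factor)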
@@ -228,7 +272,8 @@ class ImageResize:
             "required": {
                 "image": ("IMAGE",),
                 "resize_mode": (["cover", "contain", "auto"], {"default": "cover"}),
-                "resolutions": (["SDXL/SD3/Flux", "SD1.5", ], {"default": "SDXL/SD3/Flux"})
+                "resolutions": (["SDXL/SD3/Flux", "SD1.5"], {"default": "SDXL/SD3/Flux"}),
+                "interpolation": (ImageScale.upscale_methods, {"default": "bilinear"}),
             }
         }
@@ -236,7 +281,7 @@ class ImageResize:
     FUNCTION = "resize_image"
     CATEGORY = "image/transform"
 
-    def resize_image(self, image: ImageBatch, resize_mode: Literal["cover", "contain", "auto"], resolutions: Literal["SDXL/SD3/Flux", "SD1.5",]) -> Tuple[ImageBatch]:
+    def resize_image(self, image: RGBImageBatch, resize_mode: Literal["cover", "contain", "auto"], resolutions: Literal["SDXL/SD3/Flux", "SD1.5"], interpolation: str) -> Tuple[RGBImageBatch]:
         if resolutions == "SDXL/SD3/Flux":
             supported_resolutions = [
                 (640, 1536),
@@ -256,58 +301,112 @@ class ImageResize:
         resized_images = []
         for img in image:
-            img_np = (img.cpu().numpy() * 255).astype(np.uint8)
-            h, w = img_np.shape[:2]
+            h, w = img.shape[:2]
             current_aspect_ratio = w / h
             target_resolution = min(supported_resolutions,
                                     key=lambda res: abs(res[0] / res[1] - current_aspect_ratio))
+            scale_w, scale_h = target_resolution[0] / w, target_resolution[1] / h
 
             if resize_mode == "cover":
-                scale = max(target_resolution[0] / w, target_resolution[1] / h)
-                new_w, new_h = int(w * scale), int(h * scale)
-                resized = cv2.resize(img_np, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)
-                x1 = (new_w - target_resolution[0]) // 2
-                y1 = (new_h - target_resolution[1]) // 2
-                resized = resized[y1:y1 + target_resolution[1], x1:x1 + target_resolution[0]]
+                scale = max(scale_w, scale_h)
             elif resize_mode == "contain":
-                scale = min(target_resolution[0] / w, target_resolution[1] / h)
-                new_w, new_h = int(w * scale), int(h * scale)
-                resized = cv2.resize(img_np, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)
-                canvas = np.zeros((target_resolution[1], target_resolution[0], 3), dtype=np.uint8)
-                x1 = (target_resolution[0] - new_w) // 2
-                y1 = (target_resolution[1] - new_h) // 2
-                canvas[y1:y1 + new_h, x1:x1 + new_w] = resized
-                resized = canvas
-            else:
+                scale = min(scale_w, scale_h)
+            else:  # auto
                 if current_aspect_ratio > target_resolution[0] / target_resolution[1]:
-                    new_w, new_h = target_resolution[0], int(h * target_resolution[0] / w)
+                    scale = scale_w
                 else:
-                    new_w, new_h = int(w * target_resolution[1] / h), target_resolution[1]
-                resized = cv2.resize(img_np, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)
-                if new_w > target_resolution[0] or new_h > target_resolution[1]:
-                    x1 = (new_w - target_resolution[0]) // 2
-                    y1 = (new_h - target_resolution[1]) // 2
-                    resized = resized[y1:y1 + target_resolution[1], x1:x1 + target_resolution[0]]
-                else:
-                    canvas = np.zeros((target_resolution[1], target_resolution[0], 3), dtype=np.uint8)
-                    x1 = (target_resolution[0] - new_w) // 2
-                    y1 = (target_resolution[1] - new_h) // 2
-                    canvas[y1:y1 + new_h, x1:x1 + new_w] = resized
-                    resized = canvas
+                    scale = scale_h
+            new_w, new_h = int(w * scale), int(h * scale)
+
+            # convert to b, c, h, w
+            img_tensor = img.permute(2, 0, 1).unsqueeze(0)
+
+            # Use common_upscale for resizing
+            resized = utils.common_upscale(img_tensor, new_w, new_h, interpolation, "disabled")
+
+            # handle padding or cropping
+            if resize_mode == "contain":
+                canvas = torch.zeros((1, 3, target_resolution[1], target_resolution[0]), device=resized.device, dtype=resized.dtype)
+                y1 = (target_resolution[1] - new_h) // 2
+                x1 = (target_resolution[0] - new_w) // 2
+                canvas[:, :, y1:y1 + new_h, x1:x1 + new_w] = resized
+                resized = canvas
+            elif resize_mode == "cover":
+                y1 = (new_h - target_resolution[1]) // 2
+                x1 = (new_w - target_resolution[0]) // 2
+                resized = resized[:, :, y1:y1 + target_resolution[1], x1:x1 + target_resolution[0]]
+            else:  # auto
+                if new_w != target_resolution[0] or new_h != target_resolution[1]:
+                    canvas = torch.zeros((1, 3, target_resolution[1], target_resolution[0]), device=resized.device, dtype=resized.dtype)
+                    y1 = (target_resolution[1] - new_h) // 2
+                    x1 = (target_resolution[0] - new_w) // 2
+                    canvas[:, :, y1:y1 + new_h, x1:x1 + new_w] = resized
+                    resized = canvas
 
-            resized_images.append(resized)
+            resized_images.append(resized.squeeze(0).permute(1, 2, 0).clamp(0.0, 1.0))
 
-        return (torch.from_numpy(np.stack(resized_images)).float() / 255.0,)
+        return (torch.stack(resized_images),)
+
+
+class ImageLevels(CustomNode):
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "image": ("IMAGE",),
+                "black_level": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
+                "mid_level": ("FLOAT", {"default": 0.5, "min": 0.01, "max": 0.99, "step": 0.01}),
+                "white_level": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
+                "clip": ("BOOLEAN", {"default": True}),
+            }
+        }
+
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "apply_levels"
+    CATEGORY = "image/adjust"
+
+    def apply_levels(self, image: ImageBatch, black_level: float, mid_level: float, white_level: float, clip: bool) -> Tuple[ImageBatch]:
+        adjusted_image = levels_adjustment(image, black_level, mid_level, white_level, clip)
+        return (adjusted_image,)
+
+
+class ImageLuminance:
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "image": ("IMAGE",),
+            }
+        }
+
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "compute_luminance"
+    CATEGORY = "image/color"
+
+    def compute_luminance(self, image: ImageBatch) -> Tuple[ImageBatch]:
+        assert image.dim() == 4 and image.shape[-1] == 3, "Input should be of shape (B, H, W, 3)"
+
+        # define srgb luminance coefficients
+        coeffs = torch.tensor([0.2126, 0.7152, 0.0722], device=image.device, dtype=image.dtype)
+
+        luminance = torch.sum(image * coeffs, dim=-1, keepdim=True)
+        luminance = luminance.expand(-1, -1, -1, 3)
+
+        return (luminance,)
 
 
 NODE_CLASS_MAPPINGS = {
     "ImageResize": ImageResize,
     "ImageShape": ImageShape,
     "ImageCrop": ImageCrop,
+    "ImageLevels": ImageLevels,
+    "ImageLuminance": ImageLuminance,
     "RepeatImageBatch": RepeatImageBatch,
     "ImageFromBatch": ImageFromBatch,
     "SaveAnimatedWEBP": SaveAnimatedWEBP,
     "SaveAnimatedPNG": SaveAnimatedPNG,
 }
 
 NODE_DISPLAY_NAME_MAPPINGS = {
     "ImageResize": "Fit Image to Diffusion Size"
 }
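Note (not part of the commit): a small standalone sketch of the fit logic the reworked resize_image applies before handing the tensor to utils.common_upscale. The resolution list below is an illustrative subset of the node's SDXL/SD3/Flux table (only (640, 1536) appears verbatim in the diff), and the fit() helper exists only for this example.

supported_resolutions = [(640, 1536), (1024, 1024), (1536, 640)]  # illustrative subset

def fit(w, h, resize_mode="cover"):
    current_aspect_ratio = w / h
    # pick the supported resolution whose aspect ratio is closest to the input's
    target_w, target_h = min(supported_resolutions,
                             key=lambda res: abs(res[0] / res[1] - current_aspect_ratio))
    scale_w, scale_h = target_w / w, target_h / h
    if resize_mode == "cover":      # fill the target, then center-crop the overflow
        scale = max(scale_w, scale_h)
    elif resize_mode == "contain":  # fit inside the target, then pad the remainder
        scale = min(scale_w, scale_h)
    else:                           # "auto": scale along the limiting dimension
        scale = scale_w if current_aspect_ratio > target_w / target_h else scale_h
    return (target_w, target_h), (int(w * scale), int(h * scale))

print(fit(2048, 1024, "cover"))    # ((1536, 640), (1536, 768)) -> center-cropped to 1536x640
print(fit(2048, 1024, "contain"))  # ((1536, 640), (1280, 640)) -> padded out to 1536x640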