Improve scaling and fit for diffusion

doctorpangloss 2024-09-26 18:08:34 -07:00
parent dbc8ee92a5
commit 667b77149e
2 changed files with 133 additions and 36 deletions


@@ -1580,8 +1580,6 @@ class LoadImage:
     def load_image(self, image: str) -> tuple[RGBImageBatch, MaskBatch]:
         image_path = folder_paths.get_annotated_filepath(image)
-        img = node_helpers.pillow(Image.open, image_path)
         output_images = []
         output_masks = []
         w, h = None, None


@@ -2,16 +2,60 @@ import json
 import os
 from typing import Literal, Tuple
 
-import cv2
 import numpy as np
 import torch
 from PIL import Image
 from PIL.PngImagePlugin import PngInfo
 
+from comfy import utils
 from comfy.cli_args import args
 from comfy.cmd import folder_paths
-from comfy.component_model.tensor_types import ImageBatch
+from comfy.component_model.tensor_types import ImageBatch, RGBImageBatch
+from comfy.nodes.base_nodes import ImageScale
 from comfy.nodes.common import MAX_RESOLUTION
+from comfy.nodes.package_typing import CustomNode
+
+
+def levels_adjustment(image: ImageBatch, black_level: float = 0.0, mid_level: float = 0.5, white_level: float = 1.0, clip: bool = True) -> ImageBatch:
+    """
+    Apply a levels adjustment to an sRGB image.
+
+    Args:
+        image (torch.Tensor): Input image tensor of shape (B, H, W, C) with values in range [0, 1]
+        black_level (float): Black point (default: 0.0)
+        mid_level (float): Midtone point (default: 0.5)
+        white_level (float): White point (default: 1.0)
+        clip (bool): Whether to clip the output values to [0, 1] range (default: True)
+
+    Returns:
+        torch.Tensor: Adjusted image tensor of shape (B, H, W, C)
+    """
+    # Ensure input is in correct shape and range
+    assert image.dim() == 4 and image.shape[-1] == 3, "Input should be of shape (B, H, W, 3)"
+    assert 0 <= black_level < mid_level < white_level <= 1, "Levels should be in ascending order in range [0, 1]"
+
+    def srgb_to_linear(x):
+        return torch.where(x <= 0.04045, x / 12.92, ((x + 0.055) / 1.055) ** 2.4)
+
+    def linear_to_srgb(x):
+        return torch.where(x <= 0.0031308, x * 12.92, 1.055 * x ** (1 / 2.4) - 0.055)
+
+    linear = srgb_to_linear(image)
+    adjusted = (linear - black_level) / (white_level - black_level)
+    power_factor = torch.log2(torch.tensor(0.5, device=image.device)) / torch.log2(torch.tensor(mid_level, device=image.device))
+    # apply power function to avoid nans
+    adjusted = torch.where(adjusted > 0, torch.pow(adjusted.clamp(min=1e-8), power_factor), adjusted)
+    result = linear_to_srgb(adjusted)
+
+    if clip:
+        result = torch.clamp(result, 0.0, 1.0)
+
+    return result
+
+
 class ImageCrop:
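
A minimal usage sketch for the levels_adjustment helper added above. The import path comfy.nodes.images is an assumption (the diff does not show the module's file name), and the random batch exists purely for illustration:

    import torch

    from comfy.nodes.images import levels_adjustment  # assumed module path; not named in the diff

    # one 64x64 RGB image in (B, H, W, C) layout with values in [0, 1]
    batch = torch.rand(1, 64, 64, 3)

    # lift the black point slightly and brighten midtones (mid_level below 0.5 brightens)
    out = levels_adjustment(batch, black_level=0.05, mid_level=0.4, white_level=1.0, clip=True)
    print(out.shape)  # torch.Size([1, 64, 64, 3])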
@@ -228,7 +272,8 @@ class ImageResize:
             "required": {
                 "image": ("IMAGE",),
                 "resize_mode": (["cover", "contain", "auto"], {"default": "cover"}),
-                "resolutions": (["SDXL/SD3/Flux", "SD1.5", ], {"default": "SDXL/SD3/Flux"})
+                "resolutions": (["SDXL/SD3/Flux", "SD1.5"], {"default": "SDXL/SD3/Flux"}),
+                "interpolation": (ImageScale.upscale_methods, {"default": "bilinear"}),
             }
         }
@ -236,7 +281,7 @@ class ImageResize:
FUNCTION = "resize_image" FUNCTION = "resize_image"
CATEGORY = "image/transform" CATEGORY = "image/transform"
def resize_image(self, image: ImageBatch, resize_mode: Literal["cover", "contain", "auto"], resolutions: Literal["SDXL/SD3/Flux", "SD1.5",]) -> Tuple[ImageBatch]: def resize_image(self, image: RGBImageBatch, resize_mode: Literal["cover", "contain", "auto"], resolutions: Literal["SDXL/SD3/Flux", "SD1.5"], interpolation: str) -> Tuple[RGBImageBatch]:
if resolutions == "SDXL/SD3/Flux": if resolutions == "SDXL/SD3/Flux":
supported_resolutions = [ supported_resolutions = [
(640, 1536), (640, 1536),
@@ -256,58 +301,112 @@ class ImageResize:
         resized_images = []
         for img in image:
-            img_np = (img.cpu().numpy() * 255).astype(np.uint8)
-            h, w = img_np.shape[:2]
+            h, w = img.shape[:2]
             current_aspect_ratio = w / h
             target_resolution = min(supported_resolutions,
                                     key=lambda res: abs(res[0] / res[1] - current_aspect_ratio))
-            scale_w, scale_h = target_resolution[0] / w, target_resolution[1] / h
+
             if resize_mode == "cover":
-                scale = max(scale_w, scale_h)
+                scale = max(target_resolution[0] / w, target_resolution[1] / h)
                 new_w, new_h = int(w * scale), int(h * scale)
-                resized = cv2.resize(img_np, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)
-                x1 = (new_w - target_resolution[0]) // 2
-                y1 = (new_h - target_resolution[1]) // 2
-                resized = resized[y1:y1 + target_resolution[1], x1:x1 + target_resolution[0]]
             elif resize_mode == "contain":
-                scale = min(scale_w, scale_h)
+                scale = min(target_resolution[0] / w, target_resolution[1] / h)
                 new_w, new_h = int(w * scale), int(h * scale)
-                resized = cv2.resize(img_np, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)
-                canvas = np.zeros((target_resolution[1], target_resolution[0], 3), dtype=np.uint8)
-                x1 = (target_resolution[0] - new_w) // 2
-                y1 = (target_resolution[1] - new_h) // 2
-                canvas[y1:y1 + new_h, x1:x1 + new_w] = resized
-                resized = canvas
-            else:
+            else:  # auto
                 if current_aspect_ratio > target_resolution[0] / target_resolution[1]:
-                    scale = scale_w
+                    new_w, new_h = target_resolution[0], int(h * target_resolution[0] / w)
                 else:
-                    scale = scale_h
-                new_w, new_h = int(w * scale), int(h * scale)
-                resized = cv2.resize(img_np, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)
-                if new_w > target_resolution[0] or new_h > target_resolution[1]:
-                    x1 = (new_w - target_resolution[0]) // 2
-                    y1 = (new_h - target_resolution[1]) // 2
-                    resized = resized[y1:y1 + target_resolution[1], x1:x1 + target_resolution[0]]
-                else:
-                    canvas = np.zeros((target_resolution[1], target_resolution[0], 3), dtype=np.uint8)
-                    x1 = (target_resolution[0] - new_w) // 2
-                    y1 = (target_resolution[1] - new_h) // 2
-                    canvas[y1:y1 + new_h, x1:x1 + new_w] = resized
-                    resized = canvas
-            resized_images.append(resized)
-        return (torch.from_numpy(np.stack(resized_images)).float() / 255.0,)
+                    new_w, new_h = int(w * target_resolution[1] / h), target_resolution[1]
+
+            # convert to b, c, h, w
+            img_tensor = img.permute(2, 0, 1).unsqueeze(0)
+
+            # Use common_upscale for resizing
+            resized = utils.common_upscale(img_tensor, new_w, new_h, interpolation, "disabled")
+
+            # handle padding or cropping
+            if resize_mode == "contain":
+                canvas = torch.zeros((1, 3, target_resolution[1], target_resolution[0]), device=resized.device, dtype=resized.dtype)
+                y1 = (target_resolution[1] - new_h) // 2
+                x1 = (target_resolution[0] - new_w) // 2
+                canvas[:, :, y1:y1 + new_h, x1:x1 + new_w] = resized
+                resized = canvas
+            elif resize_mode == "cover":
+                y1 = (new_h - target_resolution[1]) // 2
+                x1 = (new_w - target_resolution[0]) // 2
+                resized = resized[:, :, y1:y1 + target_resolution[1], x1:x1 + target_resolution[0]]
+            else:  # auto
+                if new_w != target_resolution[0] or new_h != target_resolution[1]:
+                    canvas = torch.zeros((1, 3, target_resolution[1], target_resolution[0]), device=resized.device, dtype=resized.dtype)
+                    y1 = (target_resolution[1] - new_h) // 2
+                    x1 = (target_resolution[0] - new_w) // 2
+                    canvas[:, :, y1:y1 + new_h, x1:x1 + new_w] = resized
+                    resized = canvas
+
+            resized_images.append(resized.squeeze(0).permute(1, 2, 0).clamp(0.0, 1.0))
+        return (torch.stack(resized_images),)
+
+
+class ImageLevels(CustomNode):
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "image": ("IMAGE",),
+                "black_level": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
+                "mid_level": ("FLOAT", {"default": 0.5, "min": 0.01, "max": 0.99, "step": 0.01}),
+                "white_level": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
+                "clip": ("BOOLEAN", {"default": True}),
+            }
+        }
+
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "apply_levels"
+    CATEGORY = "image/adjust"
+
+    def apply_levels(self, image: ImageBatch, black_level: float, mid_level: float, white_level: float, clip: bool) -> Tuple[ImageBatch]:
+        adjusted_image = levels_adjustment(image, black_level, mid_level, white_level, clip)
+        return (adjusted_image,)
+
+
+class ImageLuminance:
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "image": ("IMAGE",),
+            }
+        }
+
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "compute_luminance"
+    CATEGORY = "image/color"
+
+    def compute_luminance(self, image: ImageBatch) -> Tuple[ImageBatch]:
+        assert image.dim() == 4 and image.shape[-1] == 3, "Input should be of shape (B, H, W, 3)"
+        # define srgb luminance coefficients
+        coeffs = torch.tensor([0.2126, 0.7152, 0.0722], device=image.device, dtype=image.dtype)
+        luminance = torch.sum(image * coeffs, dim=-1, keepdim=True)
+        luminance = luminance.expand(-1, -1, -1, 3)
+        return (luminance,)
+
+
 NODE_CLASS_MAPPINGS = {
     "ImageResize": ImageResize,
     "ImageShape": ImageShape,
     "ImageCrop": ImageCrop,
+    "ImageLevels": ImageLevels,
+    "ImageLuminance": ImageLuminance,
     "RepeatImageBatch": RepeatImageBatch,
     "ImageFromBatch": ImageFromBatch,
     "SaveAnimatedWEBP": SaveAnimatedWEBP,
     "SaveAnimatedPNG": SaveAnimatedPNG,
 }
+
+NODE_DISPLAY_NAME_MAPPINGS = {
+    "ImageResize": "Fit Image to Diffusion Size"
+}
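
As a worked illustration of the bucketing rule the new resize_image follows, the standalone sketch below picks the supported resolution whose aspect ratio is closest to the input and derives the pre-crop or pre-pad size for the "cover" and "contain" modes. The bucket list is an assumed subset of the SDXL/SD3/Flux resolutions (the diff truncates the full list), and fit_to_supported is an illustrative helper, not part of this commit:

    def fit_to_supported(w: int, h: int, resize_mode: str = "cover"):
        # assumed subset of the SDXL/SD3/Flux buckets; the commit truncates the full list
        supported_resolutions = [(640, 1536), (768, 1344), (832, 1216), (896, 1152),
                                 (1024, 1024), (1152, 896), (1216, 832), (1344, 768), (1536, 640)]
        aspect = w / h
        # pick the bucket whose aspect ratio is closest to the input, as resize_image does
        target_w, target_h = min(supported_resolutions, key=lambda res: abs(res[0] / res[1] - aspect))
        if resize_mode == "cover":
            scale = max(target_w / w, target_h / h)  # fill the bucket, then center-crop
        else:  # contain
            scale = min(target_w / w, target_h / h)  # fit inside the bucket, then pad
        return (target_w, target_h), (int(w * scale), int(h * scale))

    # a 2048x1024 frame maps to the 1344x768 bucket
    print(fit_to_supported(2048, 1024, "cover"))    # ((1344, 768), (1536, 768)) -> center-crop to 1344x768
    print(fit_to_supported(2048, 1024, "contain"))  # ((1344, 768), (1344, 672)) -> pad to 1344x768

The commit's "auto" mode, by contrast, scales so the relatively longer edge matches the bucket exactly and pads the remainder instead of cropping.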