ComfyUI/comfy_extras/nodes/pixel_art/k_centroid_downscale.py
Benjamin Berman cff13ace64 node tweaks
2023-08-21 11:44:51 -07:00

96 lines
3.1 KiB
Python

# Mara Huldra 2023
# SPDX-License-Identifier: MIT
'''
Downscale an image using k-centroid scaling. This is useful for images with a
low number of separate colors, such as those generated by Astropulse's pixel
art model.
'''
from itertools import product
import numpy as np
from PIL import Image
import torch
from comfy.nodes.package_typing import CustomNode
MAX_RESOLUTION = 1024
AUTO_FACTOR = 8
def k_centroid_downscale(images, width, height, centroids=2):
'''k-centroid scaling, based on: https://github.com/Astropulse/stable-diffusion-aseprite/blob/main/scripts/image_server.py.'''
downscaled = np.zeros((images.shape[0], height, width, 3), dtype=np.uint8)
for ii, image in enumerate(images):
i = 255. * image.cpu().numpy()
image = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
factor = (image.width // width, image.height // height)
for x, y in product(range(width), range(height)):
tile = image.crop((x * factor[0], y * factor[1], (x + 1) * factor[0], (y + 1) * factor[1]))
# quantize tile to fixed number of colors (creates palettized image)
tile = tile.quantize(colors=centroids, method=1, kmeans=centroids)
# get most common (median) color
color_counts = tile.getcolors()
most_common_idx = max(color_counts, key=lambda x: x[0])[1]
downscaled[ii, y, x, :] = tile.getpalette()[most_common_idx * 3:(most_common_idx + 1) * 3]
downscaled = downscaled.astype(np.float32) / 255.0
return torch.from_numpy(downscaled)
class ImageKCentroidDownscale(CustomNode):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("IMAGE",),
"width": ("INT", {"default": 64, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
"height": ("INT", {"default": 64, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
"centroids": ("INT", {"default": 2, "min": 1, "max": 256, "step": 1}),
}
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "downscale"
CATEGORY = "image/downscaling"
def downscale(self, image, width, height, centroids):
s = k_centroid_downscale(image, width, height, centroids)
return (s,)
class ImageKCentroidAutoDownscale(CustomNode):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("IMAGE",),
"centroids": ("INT", {"default": 2, "min": 1, "max": 256, "step": 1}),
}
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "downscale"
CATEGORY = "image/downscaling"
def downscale(self, image, centroids):
width = image.shape[2] // AUTO_FACTOR
height = image.shape[1] // AUTO_FACTOR
s = k_centroid_downscale(image, width, height, centroids)
return (s,)
NODE_CLASS_MAPPINGS = {
"ImageKCentroidDownscale": ImageKCentroidDownscale,
"ImageKCentroidAutoDownscale": ImageKCentroidAutoDownscale,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"ImageKCentroidDownscale": "K-Centroid Downscale",
"ImageKCentroidAutoDownscale": "K-Centroid Downscale (autosize)"
}