ComfyUI/comfy_extras/mara/k_centroid_downscale.py
Benjamin Berman 87cf8f613e mara nodes
2023-08-04 15:44:02 -07:00

92 lines
3.0 KiB
Python

# Mara Huldra 2023
# SPDX-License-Identifier: MIT
'''
Downscale an image using k-centroid scaling. This is useful for images with a
small number of distinct colors, such as those generated by Astropulse's pixel
art model.
'''
from itertools import product
import numpy as np
from PIL import Image
import torch
# Upper bound offered for the "width" and "height" inputs of
# ImageKCentroidDownscale.
MAX_RESOLUTION = 1024
# Divisor applied to the source width and height by
# ImageKCentroidAutoDownscale to derive its target size.
AUTO_FACTOR = 8
def k_centroid_downscale(images, width, height, centroids=2):
    '''Downscale a batch of images with k-centroid scaling.

    Based on: https://github.com/Astropulse/stable-diffusion-aseprite/blob/main/scripts/image_server.py.

    Every source image is split into a ``width`` x ``height`` grid of equally
    sized tiles. Each tile is quantized to at most ``centroids`` colors and
    the output pixel takes the palette color that covers the most pixels of
    that tile. Source rows/columns left over after the integer division into
    tiles (at the bottom/right edge) are ignored.

    Args:
        images: Batch indexed as (batch, row, column, channel). Values are
            multiplied by 255 and clipped to 0..255, so they are expected to
            lie in 0..1. Each item must support ``.cpu().numpy()``.
        width: Number of output columns; must be between 1 and the source
            width.
        height: Number of output rows; must be between 1 and the source
            height.
        centroids: Number of colors each tile is quantized to.

    Returns:
        A float32 ``torch.Tensor`` of shape (batch, height, width, 3) with
        values in 0..1.

    Raises:
        ValueError: If the target size is smaller than 1x1 or larger than the
            source image in either dimension.
    '''
    if width < 1 or height < 1:
        raise ValueError(
            f"target size must be at least 1x1, got {width}x{height}")

    # All images of the batch share one shape, so the tile size is computed
    # once instead of once per image.
    src_height, src_width = images.shape[1], images.shape[2]
    tile_w = src_width // width
    tile_h = src_height // height
    if tile_w < 1 or tile_h < 1:
        # A zero-sized tile cannot be quantized; fail early with a clear
        # message instead of an obscure error from deep inside the loop.
        raise ValueError(
            f"cannot downscale a {src_width}x{src_height} image to the "
            f"larger size {width}x{height}")

    downscaled = np.zeros((images.shape[0], height, width, 3), dtype=np.uint8)
    for index, image in enumerate(images):
        pixels = np.clip(255. * image.cpu().numpy(), 0, 255).astype(np.uint8)
        # The output always has 3 channels and quantize(method=1) does not
        # accept RGBA input, so normalize to RGB (a plain copy for RGB data).
        source = Image.fromarray(pixels).convert("RGB")
        for x, y in product(range(width), range(height)):
            left, top = x * tile_w, y * tile_h
            tile = source.crop((left, top, left + tile_w, top + tile_h))
            # quantize tile to a fixed number of colors (creates a palettized
            # image); method=1 is Pillow's maximum-coverage quantizer
            tile = tile.quantize(colors=centroids, method=1, kmeans=centroids)
            # getcolors() yields (count, palette_index) pairs; keep the index
            # of the most frequent color (first one wins on ties)
            _, dominant = max(tile.getcolors(), key=lambda entry: entry[0])
            palette = tile.getpalette()
            downscaled[index, y, x, :] = palette[dominant * 3:dominant * 3 + 3]
    return torch.from_numpy(downscaled.astype(np.float32) / 255.0)
class ImageKCentroidDownscale:
    '''Node that k-centroid downscales an image to an explicit width and height.'''

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "downscale"
    CATEGORY = "image/downscaling"

    @classmethod
    def INPUT_TYPES(cls):
        '''Describe the node inputs: the image, the target size and the centroid count.'''
        def int_input(default, maximum):
            # All integer inputs share a lower bound of 1 and a step of 1.
            return ("INT", {"default": default, "min": 1, "max": maximum, "step": 1})

        required = {
            "image": ("IMAGE",),
            "width": int_input(64, MAX_RESOLUTION),
            "height": int_input(64, MAX_RESOLUTION),
            "centroids": int_input(2, 256),
        }
        return {"required": required}

    def downscale(self, image, width, height, centroids):
        '''Return a 1-tuple holding the image downscaled to ``width`` x ``height``.'''
        return (k_centroid_downscale(image, width, height, centroids),)
class ImageKCentroidAutoDownscale:
    '''Node that k-centroid downscales an image by the fixed AUTO_FACTOR ratio.'''

    @classmethod
    def INPUT_TYPES(s):
        '''Describe the node inputs: the image and the centroid count.'''
        return {
            "required": {
                "image": ("IMAGE",),
                "centroids": ("INT", {"default": 2, "min": 1, "max": 256, "step": 1}),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "downscale"
    CATEGORY = "image/downscaling"

    def downscale(self, image, centroids):
        '''Return a 1-tuple holding the image downscaled by AUTO_FACTOR.

        ``image`` is indexed as (batch, row, column, channel): ``shape[2]`` is
        read as the source width and ``shape[1]`` as the source height.
        '''
        # Clamp to one pixel: a source smaller than AUTO_FACTOR would otherwise
        # yield a zero-sized target, which k_centroid_downscale cannot handle.
        width = max(1, image.shape[2] // AUTO_FACTOR)
        height = max(1, image.shape[1] // AUTO_FACTOR)
        return (k_centroid_downscale(image, width, height, centroids),)
# Node identifier -> class implementing that node, for the nodes defined in
# this module.
# NOTE(review): presumably read by ComfyUI's node loader — confirm.
NODE_CLASS_MAPPINGS = {
    "ImageKCentroidDownscale": ImageKCentroidDownscale,
    "ImageKCentroidAutoDownscale": ImageKCentroidAutoDownscale,
}
# Node identifier -> human-readable title; the keys mirror NODE_CLASS_MAPPINGS.
NODE_DISPLAY_NAME_MAPPINGS = {
    "ImageKCentroidDownscale": "K-Centroid Downscale",
    "ImageKCentroidAutoDownscale": "K-Centroid Downscale (autosize)"
}