diff --git a/.gitignore b/.gitignore
index 2ec09be4d..b1146714d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,4 @@
+.DS_Store
 /[Oo]utput/
 /[Ii]nput/
 !/input/example.png
@@ -167,4 +168,4 @@ dmypy.json
 
 # Cython debug symbols
 cython_debug/
-.openapi-generator/
\ No newline at end of file
+.openapi-generator/
diff --git a/comfy/app/__init__.py b/comfy/app/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/comfy_extras/mara/__init__.py b/comfy_extras/mara/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/comfy_extras/mara/k_centroid_downscale.py b/comfy_extras/mara/k_centroid_downscale.py
new file mode 100644
index 000000000..2f70c694c
--- /dev/null
+++ b/comfy_extras/mara/k_centroid_downscale.py
@@ -0,0 +1,91 @@
+# Mara Huldra 2023
+# SPDX-License-Identifier: MIT
+'''
+Downscale an image using k-centroid scaling. This is useful for images with a
+low number of separate colors, such as those generated by Astropulse's pixel
+art model.
+'''
+from itertools import product
+
+import numpy as np
+from PIL import Image
+import torch
+
+MAX_RESOLUTION = 1024
+AUTO_FACTOR = 8
+
+def k_centroid_downscale(images, width, height, centroids=2):
+    '''k-centroid scaling, based on: https://github.com/Astropulse/stable-diffusion-aseprite/blob/main/scripts/image_server.py.'''
+
+    downscaled = np.zeros((images.shape[0], height, width, 3), dtype=np.uint8)
+
+    for ii, image in enumerate(images):
+        i = 255. * image.cpu().numpy()
+        image = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
+        factor = (image.width // width, image.height // height)
+
+        for x, y in product(range(width), range(height)):
+            tile = image.crop((x * factor[0], y * factor[1], (x + 1) * factor[0], (y + 1) * factor[1]))
+            # quantize the tile to a fixed number of colors (creates a palettized image)
+            tile = tile.quantize(colors=centroids, method=1, kmeans=centroids)
+            # take the most common (modal) color
+            color_counts = tile.getcolors()
+            most_common_idx = max(color_counts, key=lambda c: c[0])[1]
+            downscaled[ii, y, x, :] = tile.getpalette()[most_common_idx*3:(most_common_idx + 1)*3]
+
+    downscaled = downscaled.astype(np.float32) / 255.0
+    return torch.from_numpy(downscaled)
+
+
+class ImageKCentroidDownscale:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "image": ("IMAGE",),
+                "width": ("INT", {"default": 64, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
+                "height": ("INT", {"default": 64, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
+                "centroids": ("INT", {"default": 2, "min": 1, "max": 256, "step": 1}),
+            }
+        }
+
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "downscale"
+
+    CATEGORY = "image/downscaling"
+
+    def downscale(self, image, width, height, centroids):
+        s = k_centroid_downscale(image, width, height, centroids)
+        return (s,)
+
+class ImageKCentroidAutoDownscale:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "image": ("IMAGE",),
+                "centroids": ("INT", {"default": 2, "min": 1, "max": 256, "step": 1}),
+            }
+        }
+
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "downscale"
+
+    CATEGORY = "image/downscaling"
+
+    def downscale(self, image, centroids):
+        width = image.shape[2] // AUTO_FACTOR
+        height = image.shape[1] // AUTO_FACTOR
+        s = k_centroid_downscale(image, width, height, centroids)
+        return (s,)
+
+
+NODE_CLASS_MAPPINGS = {
+    "ImageKCentroidDownscale": ImageKCentroidDownscale,
+    "ImageKCentroidAutoDownscale": ImageKCentroidAutoDownscale,
+}
+
+NODE_DISPLAY_NAME_MAPPINGS = {
+    "ImageKCentroidDownscale": "K-Centroid Downscale",
+    "ImageKCentroidAutoDownscale": "K-Centroid Downscale (autosize)"
+}
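A minimal usage sketch for the function above, outside of a ComfyUI graph (assumptions: ComfyUI's root is on sys.path; the random batch and the 512x512-to-64x64 sizes are made-up, what matters is the (batch, height, width, 3) float layout in [0, 1] that the function expects):

    # Sketch: drive the downscaler directly with a dummy batch.
    import torch
    from comfy_extras.mara.k_centroid_downscale import k_centroid_downscale

    images = torch.rand(1, 512, 512, 3)   # stand-in for a generated image batch
    small = k_centroid_downscale(images, 64, 64, centroids=2)
    print(small.shape)                    # torch.Size([1, 64, 64, 3])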
"K-Centroid Downscale (autosize)" +} diff --git a/comfy_extras/mara/make_model_tileable.py b/comfy_extras/mara/make_model_tileable.py new file mode 100644 index 000000000..b2192dbd8 --- /dev/null +++ b/comfy_extras/mara/make_model_tileable.py @@ -0,0 +1,78 @@ +# Mara Huldra 2023 +# SPDX-License-Identifier: MIT +''' +Patches the SD model and VAE to make it possible to generate seamlessly tilable +graphics. Horizontal and vertical direction are configurable separately. +''' +from typing import Optional + +import torch +from torch.nn import functional as F +from torch.nn.modules.utils import _pair + + +def flatten_modules(m): + '''Return submodules of module m in flattened form.''' + yield m + for c in m.children(): + yield from flatten_modules(c) + +# from: https://github.com/Astropulse/stable-diffusion-aseprite/blob/main/scripts/image_server.py +def make_seamless_xy(model, x, y): + for layer in flatten_modules(model): + if type(layer) == torch.nn.Conv2d: + layer.padding_modeX = 'circular' if x else 'constant' + layer.padding_modeY = 'circular' if y else 'constant' + layer.paddingX = (layer._reversed_padding_repeated_twice[0], layer._reversed_padding_repeated_twice[1], 0, 0) + layer.paddingY = (0, 0, layer._reversed_padding_repeated_twice[2], layer._reversed_padding_repeated_twice[3]) + layer._conv_forward = __replacementConv2DConvForward.__get__(layer, torch.nn.Conv2d) + +def restore_conv2d_methods(model): + for layer in flatten_modules(model): + if type(layer) == torch.nn.Conv2d: + layer._conv_forward = torch.nn.Conv2d._conv_forward.__get__(layer, torch.nn.Conv2d) + +def __replacementConv2DConvForward(self, input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]): + working = F.pad(input, self.paddingX, mode=self.padding_modeX) + working = F.pad(working, self.paddingY, mode=self.padding_modeY) + return F.conv2d(working, weight, bias, self.stride, _pair(0), self.dilation, self.groups) + + +class MakeModelTileable: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": ("MODEL",), + "vae": ("VAE",), + "tile_x": (["disabled", "enabled"], { "default": "disabled", }), + "tile_y": (["disabled", "enabled"], { "default": "disabled", }), + } + } + + RETURN_TYPES = ("MODEL", "VAE") + FUNCTION = "patch_models" + + CATEGORY = "advanced/patchers" + + def patch_models(self, model, vae, tile_x, tile_y): + tile_x = (tile_x == 'enabled') + tile_y = (tile_y == 'enabled') + # XXX ideally, we'd return a clone of the model, not patch the model itself + #model = model.clone() + #vae = vae.???() + + restore_conv2d_methods(model.model) + restore_conv2d_methods(vae.first_stage_model) + make_seamless_xy(model.model, tile_x, tile_y) + make_seamless_xy(vae.first_stage_model, tile_x, tile_y) + return (model, vae) + + +NODE_CLASS_MAPPINGS = { + "MakeModelTileable": MakeModelTileable, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "MakeModelTileable": "Patch model tileability" +} diff --git a/comfy_extras/mara/mask_ops.py b/comfy_extras/mara/mask_ops.py new file mode 100644 index 000000000..bd870eab4 --- /dev/null +++ b/comfy_extras/mara/mask_ops.py @@ -0,0 +1,68 @@ +# Mara Huldra 2023 +# SPDX-License-Identifier: MIT +''' +Extra mask operations. 
diff --git a/comfy_extras/mara/mask_ops.py b/comfy_extras/mara/mask_ops.py
new file mode 100644
index 000000000..bd870eab4
--- /dev/null
+++ b/comfy_extras/mara/mask_ops.py
@@ -0,0 +1,67 @@
+# Mara Huldra 2023
+# SPDX-License-Identifier: MIT
+'''
+Extra mask operations.
+'''
+import numpy as np
+import torch
+
+
+class BinarizeMask:
+    '''Binarize (threshold) a mask.'''
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "mask": ("MASK",),
+                "threshold": ("INT", {
+                    "default": 250,
+                    "min": 0,
+                    "max": 255,
+                    "step": 1,
+                }),
+            },
+        }
+
+    RETURN_TYPES = ("MASK",)
+    FUNCTION = "binarize"
+
+    CATEGORY = "mask"
+
+    def binarize(self, mask, threshold):
+        t = torch.Tensor([threshold / 255.])
+        s = (mask >= t).float()
+        return (s,)
+
+
+class ImageCutout:
+    '''Perform basic image cutout (adds alpha channel from mask).'''
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "image": ("IMAGE",),
+                "mask": ("MASK",),
+            },
+        }
+
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "cutout"
+
+    CATEGORY = "image/postprocessing"
+
+    def cutout(self, image, mask):
+        # XXX check compatible dimensions.
+        o = np.zeros((image.shape[0], image.shape[1], image.shape[2], 4))
+        o[:, :, :, 0:3] = image.cpu().numpy()
+        o[:, :, :, 3] = mask.cpu().numpy()
+        return (torch.from_numpy(o),)
+
+
+NODE_CLASS_MAPPINGS = {
+    "BinarizeMask": BinarizeMask,
+    "ImageCutout": ImageCutout,
+}
+
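A sketch of how the two nodes compose when called directly (assumptions: ComfyUI's root on sys.path, an IMAGE as a (batch, H, W, 3) float tensor and a MASK as (batch, H, W)):

    import torch
    from comfy_extras.mara.mask_ops import BinarizeMask, ImageCutout

    image = torch.rand(1, 64, 64, 3)   # stand-in image batch
    soft_mask = torch.rand(1, 64, 64)  # e.g. an estimated foreground mask

    (hard_mask,) = BinarizeMask().binarize(soft_mask, threshold=250)
    (rgba,) = ImageCutout().cutout(image, hard_mask)
    print(rgba.shape)                  # torch.Size([1, 64, 64, 4])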
diff --git a/comfy_extras/mara/palettize.py b/comfy_extras/mara/palettize.py
new file mode 100644
index 000000000..36d3155b4
--- /dev/null
+++ b/comfy_extras/mara/palettize.py
@@ -0,0 +1,155 @@
+# Mara Huldra 2023
+# SPDX-License-Identifier: MIT
+'''
+Palettize an image.
+'''
+import os
+
+import numpy as np
+from PIL import Image
+import torch
+
+PALETTES_PATH = os.path.join(os.path.dirname(__file__), '../..', 'palettes')
+PAL_EXT = '.png'
+
+QUANTIZE_METHODS = {
+    'median_cut': Image.Quantize.MEDIANCUT,
+    'max_coverage': Image.Quantize.MAXCOVERAGE,
+    'fast_octree': Image.Quantize.FASTOCTREE,
+}
+
+# Determine the optimal number of colors.
+# FROM: astropulse/sd-palettize
+#
+# Use FASTOCTREE for determining the best k, because
+# - it's faster
+# - it does a better job fitting the image to lower color counts than the other options
+# Max coverage is best for reducing an image's colors accurately, but
+# since for the best k we only care about the number of colors, a faster, more
+# predictable method is better.
+# (Astropulse, 2023-06-05)
+def determine_best_k(image, max_k, quantize_method=Image.Quantize.FASTOCTREE):
+    # Convert the image to RGB mode
+    image = image.convert("RGB")
+
+    # Prepare arrays for distortion calculation
+    pixels = np.array(image)
+    pixel_indices = np.reshape(pixels, (-1, 3))
+
+    # Calculate distortion for different values of k
+    distortions = []
+    for k in range(1, max_k + 1):
+        quantized_image = image.quantize(colors=k, method=quantize_method, kmeans=k, dither=0)
+        centroids = np.array(quantized_image.getpalette()[:k * 3]).reshape(-1, 3)
+
+        # Calculate distortions
+        distances = np.linalg.norm(pixel_indices[:, np.newaxis] - centroids, axis=2)
+        min_distances = np.min(distances, axis=1)
+        distortions.append(np.sum(min_distances ** 2))
+
+    # Calculate the rate of change of distortions
+    rate_of_change = np.diff(distortions) / np.array(distortions[:-1])
+
+    # Find the elbow point (best k value)
+    if len(rate_of_change) == 0:
+        best_k = 2
+    else:
+        elbow_index = np.argmax(rate_of_change) + 1
+        best_k = elbow_index + 2
+
+    return best_k
+
+palette_warned = False
+
+def list_palettes():
+    global palette_warned
+    palettes = []
+    try:
+        for filename in os.listdir(PALETTES_PATH):
+            if filename.endswith(PAL_EXT):
+                palettes.append(filename[0:-len(PAL_EXT)])
+    except FileNotFoundError:
+        pass
+    if not palettes and not palette_warned:
+        palette_warned = True
+        print("ImagePalettize warning: no fixed palettes found. You can put these in the palettes/ directory under the ComfyUI root.")
+    return palettes
+
+
+def get_image_colors(pal_img):
+    palette = []
+    pal_img = pal_img.convert('RGB')
+    for i in pal_img.getcolors(16777216):
+        palette.append(i[1][0])
+        palette.append(i[1][1])
+        palette.append(i[1][2])
+    return palette
+
+
+def load_palette(name):
+    return get_image_colors(Image.open(os.path.join(PALETTES_PATH, name + PAL_EXT)))
+
+
+class ImagePalettize:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "image": ("IMAGE",),
+                "palette": (["auto_best_k", "auto_fixed_k"] + list_palettes(), {
+                    "default": "auto_best_k",
+                }),
+                "max_k": ("INT", {
+                    "default": 64,
+                    "min": 1,
+                    "max": 256,
+                    "step": 1,
+                }),
+                "method": (list(QUANTIZE_METHODS.keys()), {
+                    "default": "max_coverage",
+                }),
+            }
+        }
+
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "palettize"
+
+    CATEGORY = "image/postprocessing"
+
+    def palettize(self, image, palette, max_k, method):
+        k = None
+        pal_img = None
+        if palette not in {'auto_best_k', 'auto_fixed_k'}:
+            pal_entries = load_palette(palette)
+            k = len(pal_entries) // 3
+            pal_img = Image.new('P', (1, 1))  # image size doesn't matter; it only holds the palette
+            pal_img.putpalette(pal_entries)
+
+        results = []
+
+        for i in image:
+            i = 255. * i.cpu().numpy()
+            i = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
+
+            if palette == 'auto_best_k':
+                k = determine_best_k(i, max_k)
+                print(f'Auto number of colors: {k}')
+            elif palette == 'auto_fixed_k':
+                k = max_k
+
+            i = i.quantize(colors=k, method=QUANTIZE_METHODS[method], kmeans=k, dither=0, palette=pal_img)
+            i = i.convert('RGB')
+
+            results.append(np.array(i))
+
+        result = np.array(results).astype(np.float32) / 255.0
+        return (torch.from_numpy(result), )
+
+
+NODE_CLASS_MAPPINGS = {
+    "ImagePalettize": ImagePalettize,
+}
+
+NODE_DISPLAY_NAME_MAPPINGS = {
+    "ImagePalettize": "ImagePalettize"
+}
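The fixed-palette branch can be exercised with plain PIL; a sketch with a hypothetical two-color palette, showing how quantize(palette=...) snaps an image onto a palette carried by a throwaway 'P' image, mirroring the node's call:

    import numpy as np
    from PIL import Image

    # Hypothetical palette: black and white, as a flat [r, g, b, ...] list.
    pal_entries = [0, 0, 0, 255, 255, 255]
    pal_img = Image.new('P', (1, 1))   # size is irrelevant; it only carries the palette
    pal_img.putpalette(pal_entries)

    src = Image.fromarray((np.random.rand(32, 32, 3) * 255).astype(np.uint8))
    out = src.quantize(colors=2, method=Image.Quantize.MAXCOVERAGE, kmeans=2,
                       dither=0, palette=pal_img).convert('RGB')
    print(out.getcolors())             # only the palette's colors remain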
diff --git a/comfy_extras/mara/patterngen.py b/comfy_extras/mara/patterngen.py
new file mode 100644
index 000000000..76d588af6
--- /dev/null
+++ b/comfy_extras/mara/patterngen.py
@@ -0,0 +1,41 @@
+# Mara Huldra 2023
+# SPDX-License-Identifier: MIT
+'''
+Simple image pattern generators.
+'''
+import torch
+
+MAX_RESOLUTION = 8192
+
+class ImageSolidColor:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "width": ("INT", {"default": 64, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
+                "height": ("INT", {"default": 64, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
+                "r": ("INT", {"default": 0, "min": 0, "max": 255, "step": 1}),
+                "g": ("INT", {"default": 0, "min": 0, "max": 255, "step": 1}),
+                "b": ("INT", {"default": 0, "min": 0, "max": 255, "step": 1}),
+            }
+        }
+
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "render"
+
+    CATEGORY = "image/pattern"
+
+    def render(self, width, height, r, g, b):
+        color = torch.tensor([r, g, b]) / 255.0
+        result = color.expand(1, height, width, 3)
+        return (result, )
+
+
+NODE_CLASS_MAPPINGS = {
+    "ImageSolidColor": ImageSolidColor,
+}
+
+NODE_DISPLAY_NAME_MAPPINGS = {
+    "ImageSolidColor": "Solid Color",
+}
+
diff --git a/comfy_extras/mara/remove_background.py b/comfy_extras/mara/remove_background.py
new file mode 100644
index 000000000..6ba498c5b
--- /dev/null
+++ b/comfy_extras/mara/remove_background.py
@@ -0,0 +1,118 @@
+# Mara Huldra 2023
+# SPDX-License-Identifier: MIT
+'''
+Estimate which pixels belong to the background and perform a cut-out, using the 'rembg' models.
+'''
+import numpy as np
+import rembg
+import torch
+
+
+MODELS = rembg.sessions.sessions_names
+
+
+class ImageRemoveBackground:
+    '''Remove background from image (adds an alpha channel).'''
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "image": ("IMAGE",),
+                "model": (MODELS, {
+                    "default": "u2net",
+                }),
+                "alpha_matting": (["disabled", "enabled"], {
+                    "default": "disabled",
+                }),
+                "am_foreground_thr": ("INT", {
+                    "default": 240,
+                    "min": 0,
+                    "max": 255,
+                    "step": 1,
+                }),
+                "am_background_thr": ("INT", {
+                    "default": 10,
+                    "min": 0,
+                    "max": 255,
+                    "step": 1,
+                }),
+                "am_erode_size": ("INT", {
+                    "default": 10,
+                    "min": 0,
+                    "max": 255,
+                    "step": 1,
+                }),
+            },
+        }
+
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "remove_background"
+
+    CATEGORY = "image/postprocessing"
+
+    def remove_background(self, image, model, alpha_matting, am_foreground_thr, am_background_thr, am_erode_size):
+        session = rembg.new_session(model)
+        results = []
+
+        for i in image:
+            i = 255. * i.cpu().numpy()
+            i = np.clip(i, 0, 255).astype(np.uint8)
+            i = rembg.remove(i,
+                alpha_matting=(alpha_matting == "enabled"),
+                alpha_matting_foreground_threshold=am_foreground_thr,
+                alpha_matting_background_threshold=am_background_thr,
+                alpha_matting_erode_size=am_erode_size,
+                session=session,
+            )
+            results.append(i.astype(np.float32) / 255.0)
+
+        s = torch.from_numpy(np.array(results))
+        return (s,)
+
+class ImageEstimateForegroundMask:
+    '''
+    Return a mask of which pixels are estimated to belong to the foreground.
+    Only estimates the mask; it does not perform a cutout like
+    ImageRemoveBackground.
+    '''
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "image": ("IMAGE",),
+                "model": (MODELS, {
+                    "default": "u2net",
+                }),
+            },
+        }
+
+    RETURN_TYPES = ("MASK",)
+    FUNCTION = "estimate_foreground"
+
+    CATEGORY = "image/postprocessing"
+
+    def estimate_foreground(self, image, model):
+        session = rembg.new_session(model)
+        results = []
+
+        for i in image:
+            i = 255. * i.cpu().numpy()
+            i = np.clip(i, 0, 255).astype(np.uint8)
+            i = rembg.remove(i, only_mask=True, session=session)
+            results.append(i.astype(np.float32) / 255.0)
+
+        s = torch.from_numpy(np.array(results))
+        return (s,)
+
+
+NODE_CLASS_MAPPINGS = {
+    "ImageRemoveBackground": ImageRemoveBackground,
+    "ImageEstimateForegroundMask": ImageEstimateForegroundMask,
+}
+
+NODE_DISPLAY_NAME_MAPPINGS = {
+    "ImageRemoveBackground": "Remove Background (rembg)",
+    "ImageEstimateForegroundMask": "Estimate Foreground (rembg)",
+}
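A direct-call sketch of the cutout node (assumptions: rembg installed and ComfyUI's root on sys.path; 'u2net' is rembg's default model and its weights are downloaded on first use):

    import torch
    from comfy_extras.mara.remove_background import ImageRemoveBackground

    image = torch.rand(1, 64, 64, 3)   # stand-in for a generated image batch

    (rgba,) = ImageRemoveBackground().remove_background(
        image, model="u2net", alpha_matting="disabled",
        am_foreground_thr=240, am_background_thr=10, am_erode_size=10)
    print(rgba.shape)                  # torch.Size([1, 64, 64, 4])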
diff --git a/nodes.py b/nodes.py
index 92baffe30..4598455b0 100644
--- a/nodes.py
+++ b/nodes.py
@@ -1697,4 +1697,10 @@ def init_custom_nodes():
     load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_tomesd.py"))
     load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_clip_sdxl.py"))
     load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_canny.py"))
+    load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras", "mara"), "k_centroid_downscale.py"))
+    load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras", "mara"), "make_model_tileable.py"))
+    load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras", "mara"), "mask_ops.py"))
+    load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras", "mara"), "palettize.py"))
+    load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras", "mara"), "patterngen.py"))
+    load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras", "mara"), "remove_background.py"))
     load_custom_nodes()