mara nodes

This commit is contained in:
Benjamin Berman 2023-08-01 21:50:05 -07:00
parent 66b857d069
commit 87cf8f613e
10 changed files with 564 additions and 1 deletion

3
.gitignore vendored
View File

@ -1,3 +1,4 @@
.DS_Store
/[Oo]utput/
/[Ii]nput/
!/input/example.png
@ -167,4 +168,4 @@ dmypy.json
# Cython debug symbols
cython_debug/
.openapi-generator/
.openapi-generator/

0
comfy/app/__init__.py Normal file
View File

91
comfy_extras/mara/k_centroid_downscale.py Normal file
View File

@ -0,0 +1,91 @@
# Mara Huldra 2023
# SPDX-License-Identifier: MIT
'''
Downscale an image using k-centroid scaling. This is useful for images with a
small number of distinct colors, such as those generated by Astropulse's pixel
art model.
'''
from itertools import product

import numpy as np
from PIL import Image
import torch

MAX_RESOLUTION = 1024
AUTO_FACTOR = 8


def k_centroid_downscale(images, width, height, centroids=2):
    '''k-centroid scaling, based on: https://github.com/Astropulse/stable-diffusion-aseprite/blob/main/scripts/image_server.py.'''
    downscaled = np.zeros((images.shape[0], height, width, 3), dtype=np.uint8)
    for ii, image in enumerate(images):
        i = 255. * image.cpu().numpy()
        image = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
        factor = (image.width // width, image.height // height)
        for x, y in product(range(width), range(height)):
            tile = image.crop((x * factor[0], y * factor[1], (x + 1) * factor[0], (y + 1) * factor[1]))
            # quantize the tile to a fixed number of colors (method 1 = MAXCOVERAGE; creates a palettized image)
            tile = tile.quantize(colors=centroids, method=1, kmeans=centroids)
            # pick the most frequent color in the tile
            color_counts = tile.getcolors()
            most_common_idx = max(color_counts, key=lambda x: x[0])[1]
            downscaled[ii, y, x, :] = tile.getpalette()[most_common_idx*3:(most_common_idx + 1)*3]
    downscaled = downscaled.astype(np.float32) / 255.0
    return torch.from_numpy(downscaled)


class ImageKCentroidDownscale:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "width": ("INT", {"default": 64, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
                "height": ("INT", {"default": 64, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
                "centroids": ("INT", {"default": 2, "min": 1, "max": 256, "step": 1}),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "downscale"
    CATEGORY = "image/downscaling"

    def downscale(self, image, width, height, centroids):
        s = k_centroid_downscale(image, width, height, centroids)
        return (s,)


class ImageKCentroidAutoDownscale:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "centroids": ("INT", {"default": 2, "min": 1, "max": 256, "step": 1}),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "downscale"
    CATEGORY = "image/downscaling"

    def downscale(self, image, centroids):
        width = image.shape[2] // AUTO_FACTOR
        height = image.shape[1] // AUTO_FACTOR
        s = k_centroid_downscale(image, width, height, centroids)
        return (s,)


NODE_CLASS_MAPPINGS = {
    "ImageKCentroidDownscale": ImageKCentroidDownscale,
    "ImageKCentroidAutoDownscale": ImageKCentroidAutoDownscale,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "ImageKCentroidDownscale": "K-Centroid Downscale",
    "ImageKCentroidAutoDownscale": "K-Centroid Downscale (autosize)"
}
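
For context, a minimal sketch of how k_centroid_downscale might be exercised outside ComfyUI, assuming the (batch, height, width, 3) float layout in [0, 1] that ComfyUI uses for IMAGE tensors; the input file name is hypothetical:

import numpy as np
from PIL import Image
import torch

img = Image.open('pixel_art.png').convert('RGB')  # hypothetical input
batch = torch.from_numpy(np.asarray(img, dtype=np.float32) / 255.0).unsqueeze(0)
small = k_centroid_downscale(batch, img.width // 8, img.height // 8, centroids=2)
print(small.shape)  # (1, img.height // 8, img.width // 8, 3)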

78
comfy_extras/mara/make_model_tileable.py Normal file
View File

@ -0,0 +1,78 @@
# Mara Huldra 2023
# SPDX-License-Identifier: MIT
'''
Patches the SD model and VAE to make it possible to generate seamlessly tileable
graphics. The horizontal and vertical directions are configurable separately.
'''
from typing import Optional

import torch
from torch.nn import functional as F
from torch.nn.modules.utils import _pair


def flatten_modules(m):
    '''Return submodules of module m in flattened form.'''
    yield m
    for c in m.children():
        yield from flatten_modules(c)


# from: https://github.com/Astropulse/stable-diffusion-aseprite/blob/main/scripts/image_server.py
def make_seamless_xy(model, x, y):
    for layer in flatten_modules(model):
        if type(layer) == torch.nn.Conv2d:
            layer.padding_modeX = 'circular' if x else 'constant'
            layer.padding_modeY = 'circular' if y else 'constant'
            layer.paddingX = (layer._reversed_padding_repeated_twice[0], layer._reversed_padding_repeated_twice[1], 0, 0)
            layer.paddingY = (0, 0, layer._reversed_padding_repeated_twice[2], layer._reversed_padding_repeated_twice[3])
            layer._conv_forward = __replacementConv2DConvForward.__get__(layer, torch.nn.Conv2d)


def restore_conv2d_methods(model):
    for layer in flatten_modules(model):
        if type(layer) == torch.nn.Conv2d:
            layer._conv_forward = torch.nn.Conv2d._conv_forward.__get__(layer, torch.nn.Conv2d)


def __replacementConv2DConvForward(self, input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]):
    # pad width and height separately, so each axis can use its own padding mode
    working = F.pad(input, self.paddingX, mode=self.padding_modeX)
    working = F.pad(working, self.paddingY, mode=self.padding_modeY)
    return F.conv2d(working, weight, bias, self.stride, _pair(0), self.dilation, self.groups)


class MakeModelTileable:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "vae": ("VAE",),
                "tile_x": (["disabled", "enabled"], { "default": "disabled", }),
                "tile_y": (["disabled", "enabled"], { "default": "disabled", }),
            }
        }

    RETURN_TYPES = ("MODEL", "VAE")
    FUNCTION = "patch_models"
    CATEGORY = "advanced/patchers"

    def patch_models(self, model, vae, tile_x, tile_y):
        tile_x = (tile_x == 'enabled')
        tile_y = (tile_y == 'enabled')
        # XXX ideally, we'd return a clone of the model, not patch the model itself
        #model = model.clone()
        #vae = vae.???()
        restore_conv2d_methods(model.model)
        restore_conv2d_methods(vae.first_stage_model)
        make_seamless_xy(model.model, tile_x, tile_y)
        make_seamless_xy(vae.first_stage_model, tile_x, tile_y)
        return (model, vae)


NODE_CLASS_MAPPINGS = {
    "MakeModelTileable": MakeModelTileable,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "MakeModelTileable": "Patch model tileability"
}
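
To see the circular-padding patch in isolation, here is a small sketch (assuming a bare Conv2d is representative of the convolutions inside the full model): a hot pixel in one corner should bleed into the opposite corner once the padding wraps around.

import torch

conv = torch.nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False)
torch.nn.init.constant_(conv.weight, 1.0)
x = torch.zeros(1, 1, 4, 4)
x[0, 0, 0, 0] = 1.0                   # single hot pixel in the top-left corner
make_seamless_xy(conv, True, True)    # circular padding on both axes
print(conv(x)[0, 0, -1, -1])          # non-zero: the kernel wrapped around
restore_conv2d_methods(conv)          # back to the standard forward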

68
comfy_extras/mara/mask_ops.py Normal file
View File

@ -0,0 +1,68 @@
# Mara Huldra 2023
# SPDX-License-Identifier: MIT
'''
Extra mask operations.
'''
import numpy as np
import torch


class BinarizeMask:
    '''Binarize (threshold) a mask.'''
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "mask": ("MASK",),
                "threshold": ("INT", {
                    "default": 250,
                    "min": 0,
                    "max": 255,
                    "step": 1,
                }),
            },
        }

    RETURN_TYPES = ("MASK",)
    FUNCTION = "binarize"
    CATEGORY = "mask"

    def binarize(self, mask, threshold):
        t = torch.Tensor([threshold / 255.])
        s = (mask >= t).float()
        return (s,)


class ImageCutout:
    '''Perform a basic image cutout (adds an alpha channel from the mask).'''
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "mask": ("MASK",),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "cutout"
    CATEGORY = "image/postprocessing"

    def cutout(self, image, mask):
        # XXX check compatible dimensions.
        # float32 matches the dtype ComfyUI uses for IMAGE tensors
        o = np.zeros((image.shape[0], image.shape[1], image.shape[2], 4), dtype=np.float32)
        o[:, :, :, 0:3] = image.cpu().numpy()
        o[:, :, :, 3] = mask.cpu().numpy()
        return (torch.from_numpy(o),)


NODE_CLASS_MAPPINGS = {
    "BinarizeMask": BinarizeMask,
    "ImageCutout": ImageCutout,
}
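
As a quick illustration of the threshold semantics (a sketch; the values are arbitrary): with the default threshold of 250, only mask values at or above 250/255 ≈ 0.98 survive.

import torch

node = BinarizeMask()
mask = torch.tensor([[0.20, 0.97, 0.99]])
(binary,) = node.binarize(mask, threshold=250)
print(binary)  # tensor([[0., 0., 1.]])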

155
comfy_extras/mara/palettize.py Normal file
View File

@ -0,0 +1,155 @@
# Mara Huldra 2023
# SPDX-License-Identifier: MIT
'''
Palettize an image.
'''
import os

import numpy as np
from PIL import Image
import torch

PALETTES_PATH = os.path.join(os.path.dirname(__file__), '../..', 'palettes')
PAL_EXT = '.png'

QUANTIZE_METHODS = {
    'median_cut': Image.Quantize.MEDIANCUT,
    'max_coverage': Image.Quantize.MAXCOVERAGE,
    'fast_octree': Image.Quantize.FASTOCTREE,
}


# Determine the optimal number of colors.
# FROM: astropulse/sd-palettize
#
# Use FASTOCTREE for determining the best k, because:
# - it's faster
# - it does a better job fitting the image to lower color counts than the other options
# Max coverage is best for reducing an image's colors more accurately, but
# since for best k we only care about the best number of colors, a faster, more
# predictable method is better.
# (Astropulse, 2023-06-05)
def determine_best_k(image, max_k, quantize_method=Image.Quantize.FASTOCTREE):
    # Convert the image to RGB mode
    image = image.convert("RGB")

    # Prepare arrays for distortion calculation
    pixels = np.array(image)
    pixel_indices = np.reshape(pixels, (-1, 3))

    # Calculate distortion for different values of k
    distortions = []
    for k in range(1, max_k + 1):
        quantized_image = image.quantize(colors=k, method=quantize_method, kmeans=k, dither=0)
        centroids = np.array(quantized_image.getpalette()[:k * 3]).reshape(-1, 3)

        # Calculate distortions
        distances = np.linalg.norm(pixel_indices[:, np.newaxis] - centroids, axis=2)
        min_distances = np.min(distances, axis=1)
        distortions.append(np.sum(min_distances ** 2))

    # Calculate the rate of change of distortions
    rate_of_change = np.diff(distortions) / np.array(distortions[:-1])

    # Find the elbow point (best k value)
    if len(rate_of_change) == 0:
        best_k = 2
    else:
        elbow_index = np.argmax(rate_of_change) + 1
        best_k = elbow_index + 2

    return best_k


palette_warned = False

def list_palettes():
    global palette_warned
    palettes = []
    try:
        for filename in os.listdir(PALETTES_PATH):
            if filename.endswith(PAL_EXT):
                palettes.append(filename[0:-len(PAL_EXT)])
    except FileNotFoundError:
        pass
    if not palettes and not palette_warned:
        palette_warned = True
        print("ImagePalettize warning: no fixed palettes found. You can put these in the palettes/ directory below the ComfyUI root.")
    return palettes


def get_image_colors(pal_img):
    palette = []
    pal_img = pal_img.convert('RGB')
    for i in pal_img.getcolors(16777216):
        palette.append(i[1][0])
        palette.append(i[1][1])
        palette.append(i[1][2])
    return palette


def load_palette(name):
    return get_image_colors(Image.open(os.path.join(PALETTES_PATH, name + PAL_EXT)))


class ImagePalettize:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "palette": (["auto_best_k", "auto_fixed_k"] + list_palettes(), {
                    "default": "auto_best_k",
                }),
                "max_k": ("INT", {
                    "default": 64,
                    "min": 1,
                    "max": 256,
                    "step": 1,
                }),
                "method": (list(QUANTIZE_METHODS.keys()), {
                    "default": "max_coverage",
                }),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "palettize"
    CATEGORY = "image/postprocessing"

    def palettize(self, image, palette, max_k, method):
        k = None
        pal_img = None
        if palette not in {'auto_best_k', 'auto_fixed_k'}:
            pal_entries = load_palette(palette)
            k = len(pal_entries) // 3
            pal_img = Image.new('P', (1, 1))  # image size doesn't matter; it only holds the palette
            pal_img.putpalette(pal_entries)

        results = []
        for i in image:
            i = 255. * i.cpu().numpy()
            i = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
            if palette == 'auto_best_k':
                k = determine_best_k(i, max_k)
                print(f'Auto number of colors: {k}')
            elif palette == 'auto_fixed_k':
                k = max_k
            i = i.quantize(colors=k, method=QUANTIZE_METHODS[method], kmeans=k, dither=0, palette=pal_img)
            i = i.convert('RGB')
            results.append(np.array(i))

        result = np.array(results).astype(np.float32) / 255.0
        return (torch.from_numpy(result), )


NODE_CLASS_MAPPINGS = {
    "ImagePalettize": ImagePalettize,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "ImagePalettize": "ImagePalettize"
}
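
For the fixed-palette branch, the core PIL mechanics reduce to the sketch below (the file name and palette entries are hypothetical; when a palette image is supplied, PIL maps pixels onto that palette rather than computing a new one):

from PIL import Image

pal_img = Image.new('P', (1, 1))  # size is irrelevant, only the palette matters
pal_img.putpalette([0, 0, 0, 255, 0, 0, 0, 255, 0, 0, 0, 255])  # black, red, green, blue
src = Image.open('input.png').convert('RGB')  # hypothetical input
out = src.quantize(palette=pal_img, dither=0).convert('RGB')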

45
comfy_extras/mara/patterngen.py Normal file
View File

@ -0,0 +1,45 @@
# Mara Huldra 2023
# SPDX-License-Identifier: MIT
'''
Simple image pattern generators.
'''
import torch

MAX_RESOLUTION = 8192


class ImageSolidColor:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "width": ("INT", {"default": 64, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
                "height": ("INT", {"default": 64, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
                "r": ("INT", {"default": 0, "min": 0, "max": 255, "step": 1}),
                "g": ("INT", {"default": 0, "min": 0, "max": 255, "step": 1}),
                "b": ("INT", {"default": 0, "min": 0, "max": 255, "step": 1}),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "render"
    CATEGORY = "image/pattern"

    def render(self, width, height, r, g, b):
        color = torch.tensor([r, g, b]) / 255.0
        result = color.expand(1, height, width, 3)
        return (result, )


NODE_CLASS_MAPPINGS = {
    "ImageSolidColor": ImageSolidColor,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "ImageSolidColor": "Solid Color",
}
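
One detail worth noting in render: expand does not allocate a height x width buffer, it returns a broadcast view of the single 3-element color tensor. A quick sketch of the node's output:

import torch

node = ImageSolidColor()
(img,) = node.render(width=4, height=2, r=255, g=128, b=0)
print(img.shape)     # torch.Size([1, 2, 4, 3])
print(img[0, 0, 0])  # tensor([1.0000, 0.5020, 0.0000])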

119
comfy_extras/mara/remove_background.py Normal file
View File

@ -0,0 +1,119 @@
# Mara Huldra 2023
# SPDX-License-Identifier: MIT
'''
Estimate which pixels belong to the background and perform a cut-out, using the 'rembg' models.
'''
import numpy as np
import rembg
import torch

MODELS = rembg.sessions.sessions_names


class ImageRemoveBackground:
    '''Remove the background from an image (adds an alpha channel).'''
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "model": (MODELS, {
                    "default": "u2net",
                }),
                "alpha_matting": (["disabled", "enabled"], {
                    "default": "disabled",
                }),
                "am_foreground_thr": ("INT", {
                    "default": 240,
                    "min": 0,
                    "max": 255,
                    "step": 1,
                }),
                "am_background_thr": ("INT", {
                    "default": 10,
                    "min": 0,
                    "max": 255,
                    "step": 1,
                }),
                "am_erode_size": ("INT", {
                    "default": 10,
                    "min": 0,
                    "max": 255,
                    "step": 1,
                }),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "remove_background"
    CATEGORY = "image/postprocessing"

    def remove_background(self, image, model, alpha_matting, am_foreground_thr, am_background_thr, am_erode_size):
        session = rembg.new_session(model)
        results = []
        for i in image:
            i = 255. * i.cpu().numpy()
            i = np.clip(i, 0, 255).astype(np.uint8)
            i = rembg.remove(i,
                alpha_matting=(alpha_matting == "enabled"),
                alpha_matting_foreground_threshold=am_foreground_thr,
                alpha_matting_background_threshold=am_background_thr,
                alpha_matting_erode_size=am_erode_size,
                session=session,
            )
            results.append(i.astype(np.float32) / 255.0)
        s = torch.from_numpy(np.array(results))
        return (s,)


class ImageEstimateForegroundMask:
    '''
    Return a mask of which pixels are estimated to belong to the foreground.
    This only estimates the mask; it does not perform a cutout like
    ImageRemoveBackground.
    '''
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "model": (MODELS, {
                    "default": "u2net",
                }),
            },
        }

    RETURN_TYPES = ("MASK",)
    FUNCTION = "estimate_background"
    CATEGORY = "image/postprocessing"

    def estimate_background(self, image, model):
        session = rembg.new_session(model)
        results = []
        for i in image:
            i = 255. * i.cpu().numpy()
            i = np.clip(i, 0, 255).astype(np.uint8)
            i = rembg.remove(i, only_mask=True, session=session)
            results.append(i.astype(np.float32) / 255.0)
        s = torch.from_numpy(np.array(results))
        return (s,)


NODE_CLASS_MAPPINGS = {
    "ImageRemoveBackground": ImageRemoveBackground,
    "ImageEstimateForegroundMask": ImageEstimateForegroundMask,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "ImageRemoveBackground": "Remove Background (rembg)",
    "ImageEstimateForegroundMask": "Estimate Foreground (rembg)",
}
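
A minimal sketch of driving the cutout node directly (random input standing in for a real image batch; note that rembg downloads the selected model's weights on first use):

import torch

node = ImageRemoveBackground()
image = torch.rand(1, 64, 64, 3)  # stand-in for a real (batch, H, W, 3) image
(cutout,) = node.remove_background(image, 'u2net', 'disabled', 240, 10, 10)
print(cutout.shape)  # (1, 64, 64, 4): RGBA, alpha estimated by rembg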

6
nodes.py
View File

@ -1697,4 +1697,10 @@ def init_custom_nodes():
    load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_tomesd.py"))
    load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_clip_sdxl.py"))
    load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_canny.py"))
    load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras", "mara"), "k_centroid_downscale.py"))
    load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras", "mara"), "make_model_tileable.py"))
    load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras", "mara"), "mask_ops.py"))
    load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras", "mara"), "palettize.py"))
    load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras", "mara"), "patterngen.py"))
    load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras", "mara"), "remove_background.py"))
    load_custom_nodes()