mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-21 20:00:17 +08:00
mara nodes
This commit is contained in:
parent
66b857d069
commit
87cf8f613e
1
.gitignore
vendored
1
.gitignore
vendored
@ -1,3 +1,4 @@
|
|||||||
|
.DS_Store
|
||||||
/[Oo]utput/
|
/[Oo]utput/
|
||||||
/[Ii]nput/
|
/[Ii]nput/
|
||||||
!/input/example.png
|
!/input/example.png
|
||||||
|
|||||||
0
comfy/app/__init__.py
Normal file
0
comfy/app/__init__.py
Normal file
0
comfy_extras/mara/__init__.py
Normal file
0
comfy_extras/mara/__init__.py
Normal file
91
comfy_extras/mara/k_centroid_downscale.py
Normal file
91
comfy_extras/mara/k_centroid_downscale.py
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
# Mara Huldra 2023
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
'''
|
||||||
|
Downscale an image using k-centroid scaling. This is useful for images with a
|
||||||
|
low number of separate colors, such as those generated by Astropulse's pixel
|
||||||
|
art model.
|
||||||
|
'''
|
||||||
|
from itertools import product
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
import torch
|
||||||
|
|
||||||
|
MAX_RESOLUTION = 1024
|
||||||
|
AUTO_FACTOR = 8
|
||||||
|
|
||||||
|
def k_centroid_downscale(images, width, height, centroids=2):
|
||||||
|
'''k-centroid scaling, based on: https://github.com/Astropulse/stable-diffusion-aseprite/blob/main/scripts/image_server.py.'''
|
||||||
|
|
||||||
|
downscaled = np.zeros((images.shape[0], height, width, 3), dtype=np.uint8)
|
||||||
|
|
||||||
|
for ii, image in enumerate(images):
|
||||||
|
i = 255. * image.cpu().numpy()
|
||||||
|
image = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
|
||||||
|
factor = (image.width // width, image.height // height)
|
||||||
|
|
||||||
|
for x, y in product(range(width), range(height)):
|
||||||
|
tile = image.crop((x * factor[0], y * factor[1], (x + 1) * factor[0], (y + 1) * factor[1]))
|
||||||
|
# quantize tile to fixed number of colors (creates palettized image)
|
||||||
|
tile = tile.quantize(colors=centroids, method=1, kmeans=centroids)
|
||||||
|
# get most common (median) color
|
||||||
|
color_counts = tile.getcolors()
|
||||||
|
most_common_idx = max(color_counts, key=lambda x: x[0])[1]
|
||||||
|
downscaled[ii, y, x, :] = tile.getpalette()[most_common_idx*3:(most_common_idx + 1)*3]
|
||||||
|
|
||||||
|
downscaled = downscaled.astype(np.float32) / 255.0
|
||||||
|
return torch.from_numpy(downscaled)
|
||||||
|
|
||||||
|
|
||||||
|
class ImageKCentroidDownscale:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"image": ("IMAGE",),
|
||||||
|
"width": ("INT", {"default": 64, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
|
||||||
|
"height": ("INT", {"default": 64, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
|
||||||
|
"centroids": ("INT", {"default": 2, "min": 1, "max": 256, "step": 1}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
FUNCTION = "downscale"
|
||||||
|
|
||||||
|
CATEGORY = "image/downscaling"
|
||||||
|
|
||||||
|
def downscale(self, image, width, height, centroids):
|
||||||
|
s = k_centroid_downscale(image, width, height, centroids)
|
||||||
|
return (s,)
|
||||||
|
|
||||||
|
class ImageKCentroidAutoDownscale:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"image": ("IMAGE",),
|
||||||
|
"centroids": ("INT", {"default": 2, "min": 1, "max": 256, "step": 1}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
FUNCTION = "downscale"
|
||||||
|
|
||||||
|
CATEGORY = "image/downscaling"
|
||||||
|
|
||||||
|
def downscale(self, image, centroids):
|
||||||
|
width = image.shape[2] // AUTO_FACTOR
|
||||||
|
height = image.shape[1] // AUTO_FACTOR
|
||||||
|
s = k_centroid_downscale(image, width, height, centroids)
|
||||||
|
return (s,)
|
||||||
|
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"ImageKCentroidDownscale": ImageKCentroidDownscale,
|
||||||
|
"ImageKCentroidAutoDownscale": ImageKCentroidAutoDownscale,
|
||||||
|
}
|
||||||
|
|
||||||
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
|
"ImageKCentroidDownscale": "K-Centroid Downscale",
|
||||||
|
"ImageKCentroidAutoDownscale": "K-Centroid Downscale (autosize)"
|
||||||
|
}
|
||||||
78
comfy_extras/mara/make_model_tileable.py
Normal file
78
comfy_extras/mara/make_model_tileable.py
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
# Mara Huldra 2023
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
'''
|
||||||
|
Patches the SD model and VAE to make it possible to generate seamlessly tilable
|
||||||
|
graphics. Horizontal and vertical direction are configurable separately.
|
||||||
|
'''
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from torch.nn.modules.utils import _pair
|
||||||
|
|
||||||
|
|
||||||
|
def flatten_modules(m):
|
||||||
|
'''Return submodules of module m in flattened form.'''
|
||||||
|
yield m
|
||||||
|
for c in m.children():
|
||||||
|
yield from flatten_modules(c)
|
||||||
|
|
||||||
|
# from: https://github.com/Astropulse/stable-diffusion-aseprite/blob/main/scripts/image_server.py
|
||||||
|
def make_seamless_xy(model, x, y):
|
||||||
|
for layer in flatten_modules(model):
|
||||||
|
if type(layer) == torch.nn.Conv2d:
|
||||||
|
layer.padding_modeX = 'circular' if x else 'constant'
|
||||||
|
layer.padding_modeY = 'circular' if y else 'constant'
|
||||||
|
layer.paddingX = (layer._reversed_padding_repeated_twice[0], layer._reversed_padding_repeated_twice[1], 0, 0)
|
||||||
|
layer.paddingY = (0, 0, layer._reversed_padding_repeated_twice[2], layer._reversed_padding_repeated_twice[3])
|
||||||
|
layer._conv_forward = __replacementConv2DConvForward.__get__(layer, torch.nn.Conv2d)
|
||||||
|
|
||||||
|
def restore_conv2d_methods(model):
|
||||||
|
for layer in flatten_modules(model):
|
||||||
|
if type(layer) == torch.nn.Conv2d:
|
||||||
|
layer._conv_forward = torch.nn.Conv2d._conv_forward.__get__(layer, torch.nn.Conv2d)
|
||||||
|
|
||||||
|
def __replacementConv2DConvForward(self, input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]):
|
||||||
|
working = F.pad(input, self.paddingX, mode=self.padding_modeX)
|
||||||
|
working = F.pad(working, self.paddingY, mode=self.padding_modeY)
|
||||||
|
return F.conv2d(working, weight, bias, self.stride, _pair(0), self.dilation, self.groups)
|
||||||
|
|
||||||
|
|
||||||
|
class MakeModelTileable:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"model": ("MODEL",),
|
||||||
|
"vae": ("VAE",),
|
||||||
|
"tile_x": (["disabled", "enabled"], { "default": "disabled", }),
|
||||||
|
"tile_y": (["disabled", "enabled"], { "default": "disabled", }),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("MODEL", "VAE")
|
||||||
|
FUNCTION = "patch_models"
|
||||||
|
|
||||||
|
CATEGORY = "advanced/patchers"
|
||||||
|
|
||||||
|
def patch_models(self, model, vae, tile_x, tile_y):
|
||||||
|
tile_x = (tile_x == 'enabled')
|
||||||
|
tile_y = (tile_y == 'enabled')
|
||||||
|
# XXX ideally, we'd return a clone of the model, not patch the model itself
|
||||||
|
#model = model.clone()
|
||||||
|
#vae = vae.???()
|
||||||
|
|
||||||
|
restore_conv2d_methods(model.model)
|
||||||
|
restore_conv2d_methods(vae.first_stage_model)
|
||||||
|
make_seamless_xy(model.model, tile_x, tile_y)
|
||||||
|
make_seamless_xy(vae.first_stage_model, tile_x, tile_y)
|
||||||
|
return (model, vae)
|
||||||
|
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"MakeModelTileable": MakeModelTileable,
|
||||||
|
}
|
||||||
|
|
||||||
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
|
"MakeModelTileable": "Patch model tileability"
|
||||||
|
}
|
||||||
68
comfy_extras/mara/mask_ops.py
Normal file
68
comfy_extras/mara/mask_ops.py
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
# Mara Huldra 2023
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
'''
|
||||||
|
Extra mask operations.
|
||||||
|
'''
|
||||||
|
import numpy as np
|
||||||
|
import rembg
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class BinarizeMask:
|
||||||
|
'''Binarize (threshold) a mask.'''
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"mask": ("MASK",),
|
||||||
|
"threshold": ("INT", {
|
||||||
|
"default": 250,
|
||||||
|
"min": 0,
|
||||||
|
"max": 255,
|
||||||
|
"step": 1,
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("MASK",)
|
||||||
|
FUNCTION = "binarize"
|
||||||
|
|
||||||
|
CATEGORY = "mask"
|
||||||
|
|
||||||
|
def binarize(self, mask, threshold):
|
||||||
|
t = torch.Tensor([threshold / 255.])
|
||||||
|
s = (mask >= t).float()
|
||||||
|
return (s,)
|
||||||
|
|
||||||
|
|
||||||
|
class ImageCutout:
|
||||||
|
'''Perform basic image cutout (adds alpha channel from mask).'''
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"image": ("IMAGE",),
|
||||||
|
"mask": ("MASK",),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
FUNCTION = "cutout"
|
||||||
|
|
||||||
|
CATEGORY = "image/postprocessing"
|
||||||
|
|
||||||
|
def cutout(self, image, mask):
|
||||||
|
# XXX check compatible dimensions.
|
||||||
|
o = np.zeros((image.shape[0], image.shape[1], image.shape[2], 4))
|
||||||
|
o[:, :, :, 0:3] = image.cpu().numpy()
|
||||||
|
o[:, :, :, 3] = mask.cpu().numpy()
|
||||||
|
return (torch.from_numpy(o),)
|
||||||
|
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"BinarizeMask": BinarizeMask,
|
||||||
|
"ImageCutout": ImageCutout,
|
||||||
|
}
|
||||||
|
|
||||||
155
comfy_extras/mara/palettize.py
Normal file
155
comfy_extras/mara/palettize.py
Normal file
@ -0,0 +1,155 @@
|
|||||||
|
# Mara Huldra 2023
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
'''
|
||||||
|
Palettize an image.
|
||||||
|
'''
|
||||||
|
import os
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
import torch
|
||||||
|
|
||||||
|
PALETTES_PATH = os.path.join(os.path.dirname(__file__), '../..', 'palettes')
|
||||||
|
PAL_EXT = '.png'
|
||||||
|
|
||||||
|
QUANTIZE_METHODS = {
|
||||||
|
'median_cut': Image.Quantize.MEDIANCUT,
|
||||||
|
'max_coverage': Image.Quantize.MAXCOVERAGE,
|
||||||
|
'fast_octree': Image.Quantize.FASTOCTREE,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Determine optimal number of colors.
|
||||||
|
# FROM: astropulse/sd-palettize
|
||||||
|
#
|
||||||
|
# Use FASTOCTREE for determining the best k, as it is
|
||||||
|
# - its faster
|
||||||
|
# - it does a better job fitting the image to lower color counts than the other options
|
||||||
|
# Max converge is best for reducing an image's colors more accurately, but
|
||||||
|
# since for best k we only care about the best number of colors, a faster more
|
||||||
|
# predictable method is better.
|
||||||
|
# (Astropulse, 2023-06-05)
|
||||||
|
def determine_best_k(image, max_k, quantize_method=Image.Quantize.FASTOCTREE):
|
||||||
|
# Convert the image to RGB mode
|
||||||
|
image = image.convert("RGB")
|
||||||
|
|
||||||
|
# Prepare arrays for distortion calculation
|
||||||
|
pixels = np.array(image)
|
||||||
|
pixel_indices = np.reshape(pixels, (-1, 3))
|
||||||
|
|
||||||
|
# Calculate distortion for different values of k
|
||||||
|
distortions = []
|
||||||
|
for k in range(1, max_k + 1):
|
||||||
|
quantized_image = image.quantize(colors=k, method=quantize_method, kmeans=k, dither=0)
|
||||||
|
centroids = np.array(quantized_image.getpalette()[:k * 3]).reshape(-1, 3)
|
||||||
|
|
||||||
|
# Calculate distortions
|
||||||
|
distances = np.linalg.norm(pixel_indices[:, np.newaxis] - centroids, axis=2)
|
||||||
|
min_distances = np.min(distances, axis=1)
|
||||||
|
distortions.append(np.sum(min_distances ** 2))
|
||||||
|
|
||||||
|
# Calculate the rate of change of distortions
|
||||||
|
rate_of_change = np.diff(distortions) / np.array(distortions[:-1])
|
||||||
|
|
||||||
|
# Find the elbow point (best k value)
|
||||||
|
if len(rate_of_change) == 0:
|
||||||
|
best_k = 2
|
||||||
|
else:
|
||||||
|
elbow_index = np.argmax(rate_of_change) + 1
|
||||||
|
best_k = elbow_index + 2
|
||||||
|
|
||||||
|
return best_k
|
||||||
|
|
||||||
|
palette_warned = False
|
||||||
|
|
||||||
|
def list_palettes():
|
||||||
|
global palette_warned
|
||||||
|
palettes = []
|
||||||
|
try:
|
||||||
|
for filename in os.listdir(PALETTES_PATH):
|
||||||
|
if filename.endswith(PAL_EXT):
|
||||||
|
palettes.append(filename[0:-len(PAL_EXT)])
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
|
if not palettes and not palette_warned:
|
||||||
|
palette_warned = True
|
||||||
|
print("ImagePalettize warning: no fixed palettes found. You can put these in the palettes/ directory below the ComfyUI root.")
|
||||||
|
return palettes
|
||||||
|
|
||||||
|
|
||||||
|
def get_image_colors(pal_img):
|
||||||
|
palette = []
|
||||||
|
pal_img = pal_img.convert('RGB')
|
||||||
|
for i in pal_img.getcolors(16777216):
|
||||||
|
palette.append(i[1][0])
|
||||||
|
palette.append(i[1][1])
|
||||||
|
palette.append(i[1][2])
|
||||||
|
return palette
|
||||||
|
|
||||||
|
|
||||||
|
def load_palette(name):
|
||||||
|
return get_image_colors(Image.open(os.path.join(PALETTES_PATH, name + PAL_EXT)))
|
||||||
|
|
||||||
|
|
||||||
|
class ImagePalettize:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"image": ("IMAGE",),
|
||||||
|
"palette": (["auto_best_k", "auto_fixed_k"] + list_palettes(), {
|
||||||
|
"default": "auto_best_k",
|
||||||
|
}),
|
||||||
|
"max_k": ("INT", {
|
||||||
|
"default": 64,
|
||||||
|
"min": 1,
|
||||||
|
"max": 256,
|
||||||
|
"step": 1,
|
||||||
|
}),
|
||||||
|
"method": (list(QUANTIZE_METHODS.keys()), {
|
||||||
|
"default": "max_coverage",
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
FUNCTION = "palettize"
|
||||||
|
|
||||||
|
CATEGORY = "image/postprocessing"
|
||||||
|
|
||||||
|
def palettize(self, image, palette, max_k, method):
|
||||||
|
k = None
|
||||||
|
pal_img = None
|
||||||
|
if palette not in {'auto_best_k', 'auto_fixed_k'}:
|
||||||
|
pal_entries = load_palette(palette)
|
||||||
|
k = len(pal_entries) // 3
|
||||||
|
pal_img = Image.new('P', (1, 1)) # image size doesn't matter it only holds the palette
|
||||||
|
pal_img.putpalette(pal_entries)
|
||||||
|
|
||||||
|
results = []
|
||||||
|
|
||||||
|
for i in image:
|
||||||
|
i = 255. * i.cpu().numpy()
|
||||||
|
i = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
|
||||||
|
|
||||||
|
if palette == 'auto_best_k':
|
||||||
|
k = determine_best_k(i, max_k)
|
||||||
|
print(f'Auto number of colors: {k}')
|
||||||
|
elif palette == 'auto_fixed_k':
|
||||||
|
k = max_k
|
||||||
|
|
||||||
|
i = i.quantize(colors=k, method=QUANTIZE_METHODS[method], kmeans=k, dither=0, palette=pal_img)
|
||||||
|
i = i.convert('RGB')
|
||||||
|
|
||||||
|
results.append(np.array(i))
|
||||||
|
|
||||||
|
result = np.array(results).astype(np.float32) / 255.0
|
||||||
|
return (torch.from_numpy(result), )
|
||||||
|
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"ImagePalettize": ImagePalettize,
|
||||||
|
}
|
||||||
|
|
||||||
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
|
"ImagePalettize": "ImagePalettize"
|
||||||
|
}
|
||||||
45
comfy_extras/mara/patterngen.py
Normal file
45
comfy_extras/mara/patterngen.py
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
# Mara Huldra 2023
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
'''
|
||||||
|
Simple image pattern generators.
|
||||||
|
'''
|
||||||
|
import os
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
import torch
|
||||||
|
|
||||||
|
MAX_RESOLUTION = 8192
|
||||||
|
|
||||||
|
class ImageSolidColor:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"width": ("INT", {"default": 64, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
|
||||||
|
"height": ("INT", {"default": 64, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
|
||||||
|
"r": ("INT", {"default": 0, "min": 0, "max": 255, "step": 1}),
|
||||||
|
"g": ("INT", {"default": 0, "min": 0, "max": 255, "step": 1}),
|
||||||
|
"b": ("INT", {"default": 0, "min": 0, "max": 255, "step": 1}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
FUNCTION = "render"
|
||||||
|
|
||||||
|
CATEGORY = "image/pattern"
|
||||||
|
|
||||||
|
def render(self, width, height, r, g, b):
|
||||||
|
color = torch.tensor([r, g, b]) / 255.0
|
||||||
|
result = color.expand(1, height, width, 3)
|
||||||
|
return (result, )
|
||||||
|
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"ImageSolidColor": ImageSolidColor,
|
||||||
|
}
|
||||||
|
|
||||||
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
|
"ImageSolidColor": "Solid Color",
|
||||||
|
}
|
||||||
|
|
||||||
119
comfy_extras/mara/remove_background.py
Normal file
119
comfy_extras/mara/remove_background.py
Normal file
@ -0,0 +1,119 @@
|
|||||||
|
# Mara Huldra 2023
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
'''
|
||||||
|
Estimate what pixels belong to the background and perform a cut-out, using the 'rembg' models.
|
||||||
|
'''
|
||||||
|
import numpy as np
|
||||||
|
import rembg
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
MODELS = rembg.sessions.sessions_names
|
||||||
|
|
||||||
|
|
||||||
|
class ImageRemoveBackground:
|
||||||
|
'''Remove background from image (adds an alpha channel)'''
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"image": ("IMAGE",),
|
||||||
|
"model": (MODELS, {
|
||||||
|
"default": "u2net",
|
||||||
|
}),
|
||||||
|
"alpha_matting": (["disabled", "enabled"], {
|
||||||
|
"default": "disabled",
|
||||||
|
}),
|
||||||
|
"am_foreground_thr": ("INT", {
|
||||||
|
"default": 240,
|
||||||
|
"min": 0,
|
||||||
|
"max": 255,
|
||||||
|
"step": 1,
|
||||||
|
}),
|
||||||
|
"am_background_thr": ("INT", {
|
||||||
|
"default": 10,
|
||||||
|
"min": 0,
|
||||||
|
"max": 255,
|
||||||
|
"step": 1,
|
||||||
|
}),
|
||||||
|
"am_erode_size": ("INT", {
|
||||||
|
"default": 10,
|
||||||
|
"min": 0,
|
||||||
|
"max": 255,
|
||||||
|
"step": 1,
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
FUNCTION = "remove_background"
|
||||||
|
|
||||||
|
CATEGORY = "image/postprocessing"
|
||||||
|
|
||||||
|
def remove_background(self, image, model, alpha_matting, am_foreground_thr, am_background_thr, am_erode_size):
|
||||||
|
session = rembg.new_session(model)
|
||||||
|
results = []
|
||||||
|
|
||||||
|
for i in image:
|
||||||
|
i = 255. * i.cpu().numpy()
|
||||||
|
i = np.clip(i, 0, 255).astype(np.uint8)
|
||||||
|
i = rembg.remove(i,
|
||||||
|
alpha_matting=(alpha_matting == "enabled"),
|
||||||
|
alpha_matting_foreground_threshold=am_foreground_thr,
|
||||||
|
alpha_matting_background_threshold=am_background_thr,
|
||||||
|
alpha_matting_erode_size=am_erode_size,
|
||||||
|
session=session,
|
||||||
|
)
|
||||||
|
results.append(i.astype(np.float32) / 255.0)
|
||||||
|
|
||||||
|
s = torch.from_numpy(np.array(results))
|
||||||
|
return (s,)
|
||||||
|
|
||||||
|
class ImageEstimateForegroundMask:
|
||||||
|
'''
|
||||||
|
Return a mask of which pixels are estimated to belong to foreground.
|
||||||
|
Only estimates the mask, does not perform cutout like
|
||||||
|
ImageRemoveBackground.
|
||||||
|
'''
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"image": ("IMAGE",),
|
||||||
|
"model": (MODELS, {
|
||||||
|
"default": "u2net",
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("MASK",)
|
||||||
|
FUNCTION = "estimate_background"
|
||||||
|
|
||||||
|
CATEGORY = "image/postprocessing"
|
||||||
|
|
||||||
|
def estimate_background(self, image, model):
|
||||||
|
session = rembg.new_session(model)
|
||||||
|
results = []
|
||||||
|
|
||||||
|
for i in image:
|
||||||
|
i = 255. * i.cpu().numpy()
|
||||||
|
i = np.clip(i, 0, 255).astype(np.uint8)
|
||||||
|
i = rembg.remove(i, only_mask=True, session=session)
|
||||||
|
results.append(i.astype(np.float32) / 255.0)
|
||||||
|
|
||||||
|
s = torch.from_numpy(np.array(results))
|
||||||
|
print(s.shape)
|
||||||
|
return (s,)
|
||||||
|
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"ImageRemoveBackground": ImageRemoveBackground,
|
||||||
|
"ImageEstimateForegroundMask": ImageEstimateForegroundMask,
|
||||||
|
}
|
||||||
|
|
||||||
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
|
"ImageRemoveBackground": "Remove Background (rembg)",
|
||||||
|
"ImageEstimateForegroundMask": "Estimate Foreground (rembg)",
|
||||||
|
}
|
||||||
6
nodes.py
6
nodes.py
@ -1697,4 +1697,10 @@ def init_custom_nodes():
|
|||||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_tomesd.py"))
|
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_tomesd.py"))
|
||||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_clip_sdxl.py"))
|
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_clip_sdxl.py"))
|
||||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_canny.py"))
|
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_canny.py"))
|
||||||
|
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras", "mara"), "k_centroid_downscale.py"))
|
||||||
|
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras", "mara"), "make_model_tileable.py"))
|
||||||
|
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras", "mara"), "mask_ops.py"))
|
||||||
|
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras", "mara"), "palettize.py"))
|
||||||
|
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras", "mara"), "patterngen.py"))
|
||||||
|
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras", "mara"), "remove_background.py"))
|
||||||
load_custom_nodes()
|
load_custom_nodes()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user