mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-10 06:10:50 +08:00)

make the nodes organization more sane

parent 87cf8f613e
commit b3038de648
1  .gitignore (vendored)
@@ -6,6 +6,7 @@
/[Tt]emp/
/[Cc]ustom_nodes/
!/custom_nodes/example_node.py.example
!/custom_nodes/__init__.py
/extra_model_paths.yaml
/.vs
.idea/
1  comfy/nodes/common.py (new file)
@@ -0,0 +1 @@
MAX_RESOLUTION=8192
78  comfy/nodes/package.py (new file)
@@ -0,0 +1,78 @@
from __future__ import annotations

import importlib
import pkgutil
import time
import types

import nodes as base_nodes
from comfy_extras import nodes as comfy_extras_nodes
import custom_nodes
from comfy.nodes.package_typing import ExportedNodes
from functools import reduce

_comfy_nodes = ExportedNodes()


def _import_nodes_in_module(exported_nodes: ExportedNodes, module: types.ModuleType):
    node_class_mappings = getattr(module, 'NODE_CLASS_MAPPINGS', None)
    node_display_names = getattr(module, 'NODE_DISPLAY_NAME_MAPPINGS', None)
    if node_class_mappings:
        exported_nodes.NODE_CLASS_MAPPINGS.update(node_class_mappings)
    if node_display_names:
        exported_nodes.NODE_DISPLAY_NAME_MAPPINGS.update(node_display_names)


def _import_and_enumerate_nodes_in_module(module: types.ModuleType, print_import_times=False) -> ExportedNodes:
    exported_nodes = ExportedNodes()
    timings = []
    if hasattr(module, 'NODE_CLASS_MAPPINGS'):
        node_class_mappings = getattr(module, 'NODE_CLASS_MAPPINGS', None)
        node_display_names = getattr(module, 'NODE_DISPLAY_NAME_MAPPINGS', None)
        if node_class_mappings:
            exported_nodes.NODE_CLASS_MAPPINGS.update(node_class_mappings)
        if node_display_names:
            exported_nodes.NODE_DISPLAY_NAME_MAPPINGS.update(node_display_names)
    else:
        # Iterate through all the submodules
        for _, name, is_pkg in pkgutil.iter_modules(module.__path__):
            full_name = module.__name__ + "." + name
            time_before = time.perf_counter()
            success = True

            if full_name.endswith(".disabled"):
                continue
            try:
                submodule = importlib.import_module(full_name)
                # Recursively call the function if it's a package
                exported_nodes.update(
                    _import_and_enumerate_nodes_in_module(submodule, print_import_times=print_import_times))
            except KeyboardInterrupt as interrupted:
                raise interrupted
            except Exception as x:
                success = False
            timings.append((time.perf_counter() - time_before, full_name, success))

    if print_import_times and len(timings) > 0:
        for (duration, module_name, success) in sorted(timings):
            print(f"{duration:6.1f} seconds{'' if success else ' (IMPORT FAILED)'}, {module_name}")
    return exported_nodes


def import_all_nodes_in_workspace() -> ExportedNodes:
    if len(_comfy_nodes) == 0:
        base_and_extra = reduce(lambda x, y: x.update(y),
                                map(_import_and_enumerate_nodes_in_module, [
                                    # this is the list of default nodes to import
                                    base_nodes,
                                    comfy_extras_nodes
                                ]),
                                ExportedNodes())

        custom_nodes_mappings = _import_and_enumerate_nodes_in_module(custom_nodes, print_import_times=True)

        # don't allow custom nodes to overwrite base nodes
        custom_nodes_mappings -= base_and_extra

        _comfy_nodes.update(base_and_extra + custom_nodes_mappings)
    return _comfy_nodes
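The loader above only cares that a module (or any submodule of the package it walks) exposes NODE_CLASS_MAPPINGS and, optionally, NODE_DISPLAY_NAME_MAPPINGS. As a sketch, a hypothetical custom_nodes/example_add.py that this discovery would pick up (the node itself is made up for illustration):

# custom_nodes/example_add.py (hypothetical)
class ExampleAdd:
    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"a": ("INT", {"default": 0}), "b": ("INT", {"default": 0})}}

    RETURN_TYPES = ("INT",)
    FUNCTION = "add"
    CATEGORY = "examples"

    def add(self, a, b):
        # FUNCTION names the method the executor calls; outputs are returned as a tuple
        return (a + b,)

NODE_CLASS_MAPPINGS = {"ExampleAdd": ExampleAdd}
NODE_DISPLAY_NAME_MAPPINGS = {"ExampleAdd": "Example Add"}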
43  comfy/nodes/package_typing.py (new file)
@@ -0,0 +1,43 @@
from __future__ import annotations

from typing import Protocol, ClassVar, Tuple, Dict
from dataclasses import dataclass, field


class CustomNode(Protocol):
    @classmethod
    def INPUT_TYPES(cls) -> dict: ...

    RETURN_TYPES: ClassVar[Tuple[str, ...]]
    RETURN_NAMES: ClassVar[Tuple[str, ...]] = None
    OUTPUT_IS_LIST: ClassVar[Tuple[bool, ...]] = None
    FUNCTION: ClassVar[str]
    CATEGORY: ClassVar[str]
    OUTPUT_NODE: ClassVar[bool] = None


@dataclass
class ExportedNodes:
    NODE_CLASS_MAPPINGS: Dict[str, CustomNode] = field(default_factory=dict)
    NODE_DISPLAY_NAME_MAPPINGS: Dict[str, str] = field(default_factory=dict)

    def update(self, exported_nodes: ExportedNodes) -> ExportedNodes:
        self.NODE_CLASS_MAPPINGS.update(exported_nodes.NODE_CLASS_MAPPINGS)
        self.NODE_DISPLAY_NAME_MAPPINGS.update(exported_nodes.NODE_DISPLAY_NAME_MAPPINGS)
        return self

    def __len__(self):
        return len(self.NODE_CLASS_MAPPINGS)

    def __sub__(self, other):
        exported_nodes = ExportedNodes().update(self)
        # iterate over a snapshot of the keys, since entries are popped inside the loop
        for self_key in list(exported_nodes.NODE_CLASS_MAPPINGS):
            if self_key in other.NODE_CLASS_MAPPINGS:
                exported_nodes.NODE_CLASS_MAPPINGS.pop(self_key)
            if self_key in other.NODE_DISPLAY_NAME_MAPPINGS:
                exported_nodes.NODE_DISPLAY_NAME_MAPPINGS.pop(self_key)
        return exported_nodes

    def __add__(self, other):
        exported_nodes = ExportedNodes().update(self)
        return exported_nodes.update(other)
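The __sub__/__add__ pair gives ExportedNodes set-like semantics, which package.py uses to keep custom nodes from shadowing built-ins. A quick sketch with hypothetical node names:

base = ExportedNodes(NODE_CLASS_MAPPINGS={"KSampler": object()})
custom = ExportedNodes(NODE_CLASS_MAPPINGS={"KSampler": object(), "MyNode": object()})

safe_custom = custom - base          # drops the colliding "KSampler" entry
merged = base + safe_custom          # so the base implementation wins
assert set(merged.NODE_CLASS_MAPPINGS) == {"KSampler", "MyNode"}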
@@ -1,91 +0,0 @@
# Mara Huldra 2023
# SPDX-License-Identifier: MIT
'''
Downscale an image using k-centroid scaling. This is useful for images with a
low number of separate colors, such as those generated by Astropulse's pixel
art model.
'''
from itertools import product

import numpy as np
from PIL import Image
import torch

MAX_RESOLUTION = 1024
AUTO_FACTOR = 8

def k_centroid_downscale(images, width, height, centroids=2):
    '''k-centroid scaling, based on: https://github.com/Astropulse/stable-diffusion-aseprite/blob/main/scripts/image_server.py.'''

    downscaled = np.zeros((images.shape[0], height, width, 3), dtype=np.uint8)

    for ii, image in enumerate(images):
        i = 255. * image.cpu().numpy()
        image = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
        factor = (image.width // width, image.height // height)

        for x, y in product(range(width), range(height)):
            tile = image.crop((x * factor[0], y * factor[1], (x + 1) * factor[0], (y + 1) * factor[1]))
            # quantize tile to fixed number of colors (creates palettized image)
            tile = tile.quantize(colors=centroids, method=1, kmeans=centroids)
            # get most common color
            color_counts = tile.getcolors()
            most_common_idx = max(color_counts, key=lambda x: x[0])[1]
            downscaled[ii, y, x, :] = tile.getpalette()[most_common_idx*3:(most_common_idx + 1)*3]

    downscaled = downscaled.astype(np.float32) / 255.0
    return torch.from_numpy(downscaled)


class ImageKCentroidDownscale:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "width": ("INT", {"default": 64, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
                "height": ("INT", {"default": 64, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
                "centroids": ("INT", {"default": 2, "min": 1, "max": 256, "step": 1}),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "downscale"

    CATEGORY = "image/downscaling"

    def downscale(self, image, width, height, centroids):
        s = k_centroid_downscale(image, width, height, centroids)
        return (s,)

class ImageKCentroidAutoDownscale:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "centroids": ("INT", {"default": 2, "min": 1, "max": 256, "step": 1}),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "downscale"

    CATEGORY = "image/downscaling"

    def downscale(self, image, centroids):
        width = image.shape[2] // AUTO_FACTOR
        height = image.shape[1] // AUTO_FACTOR
        s = k_centroid_downscale(image, width, height, centroids)
        return (s,)


NODE_CLASS_MAPPINGS = {
    "ImageKCentroidDownscale": ImageKCentroidDownscale,
    "ImageKCentroidAutoDownscale": ImageKCentroidAutoDownscale,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "ImageKCentroidDownscale": "K-Centroid Downscale",
    "ImageKCentroidAutoDownscale": "K-Centroid Downscale (autosize)"
}
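For a sense of the shapes involved: images arrive in ComfyUI's (batch, height, width, channel) float layout in [0, 1], and the function returns the downscaled batch in the same layout. A sketch with arbitrary data:

import torch

batch = torch.rand(1, 512, 512, 3)                        # one 512x512 RGB image
small = k_centroid_downscale(batch, 64, 64, centroids=2)  # each 8x8 tile becomes one pixel
print(small.shape)                                        # torch.Size([1, 64, 64, 3])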
@@ -1,78 +0,0 @@
# Mara Huldra 2023
# SPDX-License-Identifier: MIT
'''
Patches the SD model and VAE to make it possible to generate seamlessly tileable
graphics. Horizontal and vertical direction are configurable separately.
'''
from typing import Optional

import torch
from torch.nn import functional as F
from torch.nn.modules.utils import _pair


def flatten_modules(m):
    '''Return submodules of module m in flattened form.'''
    yield m
    for c in m.children():
        yield from flatten_modules(c)

# from: https://github.com/Astropulse/stable-diffusion-aseprite/blob/main/scripts/image_server.py
def make_seamless_xy(model, x, y):
    for layer in flatten_modules(model):
        if type(layer) == torch.nn.Conv2d:
            layer.padding_modeX = 'circular' if x else 'constant'
            layer.padding_modeY = 'circular' if y else 'constant'
            layer.paddingX = (layer._reversed_padding_repeated_twice[0], layer._reversed_padding_repeated_twice[1], 0, 0)
            layer.paddingY = (0, 0, layer._reversed_padding_repeated_twice[2], layer._reversed_padding_repeated_twice[3])
            layer._conv_forward = __replacementConv2DConvForward.__get__(layer, torch.nn.Conv2d)

def restore_conv2d_methods(model):
    for layer in flatten_modules(model):
        if type(layer) == torch.nn.Conv2d:
            layer._conv_forward = torch.nn.Conv2d._conv_forward.__get__(layer, torch.nn.Conv2d)

def __replacementConv2DConvForward(self, input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]):
    working = F.pad(input, self.paddingX, mode=self.padding_modeX)
    working = F.pad(working, self.paddingY, mode=self.padding_modeY)
    return F.conv2d(working, weight, bias, self.stride, _pair(0), self.dilation, self.groups)


class MakeModelTileable:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "vae": ("VAE",),
                "tile_x": (["disabled", "enabled"], { "default": "disabled", }),
                "tile_y": (["disabled", "enabled"], { "default": "disabled", }),
            }
        }

    RETURN_TYPES = ("MODEL", "VAE")
    FUNCTION = "patch_models"

    CATEGORY = "advanced/patchers"

    def patch_models(self, model, vae, tile_x, tile_y):
        tile_x = (tile_x == 'enabled')
        tile_y = (tile_y == 'enabled')
        # XXX ideally, we'd return a clone of the model, not patch the model itself
        #model = model.clone()
        #vae = vae.???()

        restore_conv2d_methods(model.model)
        restore_conv2d_methods(vae.first_stage_model)
        make_seamless_xy(model.model, tile_x, tile_y)
        make_seamless_xy(vae.first_stage_model, tile_x, tile_y)
        return (model, vae)


NODE_CLASS_MAPPINGS = {
    "MakeModelTileable": MakeModelTileable,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "MakeModelTileable": "Patch model tileability"
}
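The underlying trick is circular padding on every Conv2d, applied separately per axis: feature maps wrap around, so the left edge convolves against the right edge (and/or top against bottom). A standalone sketch of the same X/Y padding split used by make_seamless_xy:

import torch
from torch.nn import functional as F

x = torch.arange(16.0).reshape(1, 1, 4, 4)
h = F.pad(x, (1, 1, 0, 0), mode='circular')   # wrap horizontally only
v = F.pad(h, (0, 0, 1, 1), mode='constant')   # zero-pad vertically
print(v.shape)                                # torch.Size([1, 1, 6, 6])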
@@ -1,68 +0,0 @@
# Mara Huldra 2023
# SPDX-License-Identifier: MIT
'''
Extra mask operations.
'''
import numpy as np
import rembg
import torch


class BinarizeMask:
    '''Binarize (threshold) a mask.'''

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "mask": ("MASK",),
                "threshold": ("INT", {
                    "default": 250,
                    "min": 0,
                    "max": 255,
                    "step": 1,
                }),
            },
        }

    RETURN_TYPES = ("MASK",)
    FUNCTION = "binarize"

    CATEGORY = "mask"

    def binarize(self, mask, threshold):
        t = torch.Tensor([threshold / 255.])
        s = (mask >= t).float()
        return (s,)


class ImageCutout:
    '''Perform basic image cutout (adds alpha channel from mask).'''

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "mask": ("MASK",),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "cutout"

    CATEGORY = "image/postprocessing"

    def cutout(self, image, mask):
        # XXX check compatible dimensions.
        o = np.zeros((image.shape[0], image.shape[1], image.shape[2], 4))
        o[:, :, :, 0:3] = image.cpu().numpy()
        o[:, :, :, 3] = mask.cpu().numpy()
        return (torch.from_numpy(o),)


NODE_CLASS_MAPPINGS = {
    "BinarizeMask": BinarizeMask,
    "ImageCutout": ImageCutout,
}
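BinarizeMask is a plain threshold against the 0..1 mask values; with the default threshold of 250 the cutoff sits at 250/255, roughly 0.98. A sketch of the same comparison outside the node:

import torch

mask = torch.tensor([[0.20, 0.97],
                     [0.99, 0.50]])
binary = (mask >= (250 / 255.)).float()
print(binary)   # tensor([[0., 0.], [1., 0.]])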
@@ -1,155 +0,0 @@
# Mara Huldra 2023
# SPDX-License-Identifier: MIT
'''
Palettize an image.
'''
import os

import numpy as np
from PIL import Image
import torch

PALETTES_PATH = os.path.join(os.path.dirname(__file__), '../..', 'palettes')
PAL_EXT = '.png'

QUANTIZE_METHODS = {
    'median_cut': Image.Quantize.MEDIANCUT,
    'max_coverage': Image.Quantize.MAXCOVERAGE,
    'fast_octree': Image.Quantize.FASTOCTREE,
}

# Determine optimal number of colors.
# FROM: astropulse/sd-palettize
#
# Use FASTOCTREE for determining the best k, as:
# - it's faster
# - it does a better job fitting the image to lower color counts than the other options
# Max coverage is best for reducing an image's colors more accurately, but
# since for best k we only care about the best number of colors, a faster more
# predictable method is better.
# (Astropulse, 2023-06-05)
def determine_best_k(image, max_k, quantize_method=Image.Quantize.FASTOCTREE):
    # Convert the image to RGB mode
    image = image.convert("RGB")

    # Prepare arrays for distortion calculation
    pixels = np.array(image)
    pixel_indices = np.reshape(pixels, (-1, 3))

    # Calculate distortion for different values of k
    distortions = []
    for k in range(1, max_k + 1):
        quantized_image = image.quantize(colors=k, method=quantize_method, kmeans=k, dither=0)
        centroids = np.array(quantized_image.getpalette()[:k * 3]).reshape(-1, 3)

        # Calculate distortions
        distances = np.linalg.norm(pixel_indices[:, np.newaxis] - centroids, axis=2)
        min_distances = np.min(distances, axis=1)
        distortions.append(np.sum(min_distances ** 2))

    # Calculate the rate of change of distortions
    rate_of_change = np.diff(distortions) / np.array(distortions[:-1])

    # Find the elbow point (best k value)
    if len(rate_of_change) == 0:
        best_k = 2
    else:
        elbow_index = np.argmax(rate_of_change) + 1
        best_k = elbow_index + 2

    return best_k

palette_warned = False

def list_palettes():
    global palette_warned
    palettes = []
    try:
        for filename in os.listdir(PALETTES_PATH):
            if filename.endswith(PAL_EXT):
                palettes.append(filename[0:-len(PAL_EXT)])
    except FileNotFoundError:
        pass
    if not palettes and not palette_warned:
        palette_warned = True
        print("ImagePalettize warning: no fixed palettes found. You can put these in the palettes/ directory below the ComfyUI root.")
    return palettes


def get_image_colors(pal_img):
    palette = []
    pal_img = pal_img.convert('RGB')
    for i in pal_img.getcolors(16777216):
        palette.append(i[1][0])
        palette.append(i[1][1])
        palette.append(i[1][2])
    return palette


def load_palette(name):
    return get_image_colors(Image.open(os.path.join(PALETTES_PATH, name + PAL_EXT)))


class ImagePalettize:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "palette": (["auto_best_k", "auto_fixed_k"] + list_palettes(), {
                    "default": "auto_best_k",
                }),
                "max_k": ("INT", {
                    "default": 64,
                    "min": 1,
                    "max": 256,
                    "step": 1,
                }),
                "method": (list(QUANTIZE_METHODS.keys()), {
                    "default": "max_coverage",
                }),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "palettize"

    CATEGORY = "image/postprocessing"

    def palettize(self, image, palette, max_k, method):
        k = None
        pal_img = None
        if palette not in {'auto_best_k', 'auto_fixed_k'}:
            pal_entries = load_palette(palette)
            k = len(pal_entries) // 3
            pal_img = Image.new('P', (1, 1))  # image size doesn't matter; it only holds the palette
            pal_img.putpalette(pal_entries)

        results = []

        for i in image:
            i = 255. * i.cpu().numpy()
            i = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))

            if palette == 'auto_best_k':
                k = determine_best_k(i, max_k)
                print(f'Auto number of colors: {k}')
            elif palette == 'auto_fixed_k':
                k = max_k

            i = i.quantize(colors=k, method=QUANTIZE_METHODS[method], kmeans=k, dither=0, palette=pal_img)
            i = i.convert('RGB')

            results.append(np.array(i))

        result = np.array(results).astype(np.float32) / 255.0
        return (torch.from_numpy(result), )


NODE_CLASS_MAPPINGS = {
    "ImagePalettize": ImagePalettize,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "ImagePalettize": "ImagePalettize"
}
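get_image_colors flattens an image's distinct colors into the flat [r, g, b, r, g, b, ...] list that PIL's putpalette expects. A sketch (getcolors does not guarantee ordering, hence the "for example"):

from PIL import Image

pal_img = Image.new('RGB', (2, 1))
pal_img.putpixel((0, 0), (255, 0, 0))
pal_img.putpixel((1, 0), (0, 0, 255))
print(get_image_colors(pal_img))   # for example: [0, 0, 255, 255, 0, 0]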
@@ -1,45 +0,0 @@
# Mara Huldra 2023
# SPDX-License-Identifier: MIT
'''
Simple image pattern generators.
'''
import os

import numpy as np
from PIL import Image
import torch

MAX_RESOLUTION = 8192

class ImageSolidColor:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "width": ("INT", {"default": 64, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
                "height": ("INT", {"default": 64, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
                "r": ("INT", {"default": 0, "min": 0, "max": 255, "step": 1}),
                "g": ("INT", {"default": 0, "min": 0, "max": 255, "step": 1}),
                "b": ("INT", {"default": 0, "min": 0, "max": 255, "step": 1}),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "render"

    CATEGORY = "image/pattern"

    def render(self, width, height, r, g, b):
        color = torch.tensor([r, g, b]) / 255.0
        result = color.expand(1, height, width, 3)
        return (result, )


NODE_CLASS_MAPPINGS = {
    "ImageSolidColor": ImageSolidColor,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "ImageSolidColor": "Solid Color",
}
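The render method leans on broadcasting: expand produces a (1, height, width, 3) view of the single RGB value without copying pixel data. The same thing in isolation:

import torch

color = torch.tensor([255, 0, 0]) / 255.0
img = color.expand(1, 64, 64, 3)     # a view; every pixel reads the same 3 values
print(img.shape, img[0, 0, 0])       # torch.Size([1, 64, 64, 3]) tensor([1., 0., 0.])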
@@ -1,119 +0,0 @@
# Mara Huldra 2023
# SPDX-License-Identifier: MIT
'''
Estimate what pixels belong to the background and perform a cut-out, using the 'rembg' models.
'''
import numpy as np
import rembg
import torch


MODELS = rembg.sessions.sessions_names


class ImageRemoveBackground:
    '''Remove background from image (adds an alpha channel)'''

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "model": (MODELS, {
                    "default": "u2net",
                }),
                "alpha_matting": (["disabled", "enabled"], {
                    "default": "disabled",
                }),
                "am_foreground_thr": ("INT", {
                    "default": 240,
                    "min": 0,
                    "max": 255,
                    "step": 1,
                }),
                "am_background_thr": ("INT", {
                    "default": 10,
                    "min": 0,
                    "max": 255,
                    "step": 1,
                }),
                "am_erode_size": ("INT", {
                    "default": 10,
                    "min": 0,
                    "max": 255,
                    "step": 1,
                }),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "remove_background"

    CATEGORY = "image/postprocessing"

    def remove_background(self, image, model, alpha_matting, am_foreground_thr, am_background_thr, am_erode_size):
        session = rembg.new_session(model)
        results = []

        for i in image:
            i = 255. * i.cpu().numpy()
            i = np.clip(i, 0, 255).astype(np.uint8)
            i = rembg.remove(i,
                             alpha_matting=(alpha_matting == "enabled"),
                             alpha_matting_foreground_threshold=am_foreground_thr,
                             alpha_matting_background_threshold=am_background_thr,
                             alpha_matting_erode_size=am_erode_size,
                             session=session,
                             )
            results.append(i.astype(np.float32) / 255.0)

        s = torch.from_numpy(np.array(results))
        return (s,)

class ImageEstimateForegroundMask:
    '''
    Return a mask of which pixels are estimated to belong to foreground.
    Only estimates the mask, does not perform cutout like
    ImageRemoveBackground.
    '''

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "model": (MODELS, {
                    "default": "u2net",
                }),
            },
        }

    RETURN_TYPES = ("MASK",)
    FUNCTION = "estimate_background"

    CATEGORY = "image/postprocessing"

    def estimate_background(self, image, model):
        session = rembg.new_session(model)
        results = []

        for i in image:
            i = 255. * i.cpu().numpy()
            i = np.clip(i, 0, 255).astype(np.uint8)
            i = rembg.remove(i, only_mask=True, session=session)
            results.append(i.astype(np.float32) / 255.0)

        s = torch.from_numpy(np.array(results))
        print(s.shape)
        return (s,)


NODE_CLASS_MAPPINGS = {
    "ImageRemoveBackground": ImageRemoveBackground,
    "ImageEstimateForegroundMask": ImageEstimateForegroundMask,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "ImageRemoveBackground": "Remove Background (rembg)",
    "ImageEstimateForegroundMask": "Estimate Foreground (rembg)",
}
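Both nodes follow the same per-frame loop: scale to uint8, call rembg, scale back to float. A trimmed sketch of that round-trip for a single frame (assumes rembg is installed; the exact mask shape and dtype depend on what rembg returns for the given input):

import numpy as np
import rembg

frame = (np.random.rand(64, 64, 3) * 255).astype(np.uint8)
session = rembg.new_session("u2net")
mask = rembg.remove(frame, only_mask=True, session=session)
print(mask.shape, mask.dtype)        # e.g. (64, 64) uint8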
0  comfy_extras/nodes/__init__.py (new file)
@@ -1,5 +1,6 @@
import torch
from nodes import MAX_RESOLUTION
from comfy.nodes.common import MAX_RESOLUTION


class CLIPTextEncodeSDXLRefiner:
    @classmethod
@@ -1,6 +1,7 @@
import torch

from nodes import MAX_RESOLUTION
from comfy.nodes.common import MAX_RESOLUTION


class LatentCompositeMasked:
    @classmethod
10  execution.py
@@ -3,7 +3,6 @@ from __future__ import annotations
import asyncio
import copy
import datetime
import gc
import heapq
import threading
import traceback
@@ -14,7 +13,8 @@ import sys

import torch

import nodes
from comfy.nodes.package import import_all_nodes_in_workspace
nodes = import_all_nodes_in_workspace()
import comfy.model_management

"""
@@ -111,7 +111,7 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
    results = []
    if input_is_list:
        if allow_interrupt:
            nodes.before_node_execution()
            comfy.model_management.throw_exception_if_processing_interrupted()
        results.append(getattr(obj, func)(**input_data_all))
    elif max_len_input == 0:
        if allow_interrupt:
@@ -120,7 +120,7 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
    else:
        for i in range(max_len_input):
            if allow_interrupt:
                nodes.before_node_execution()
                comfy.model_management.throw_exception_if_processing_interrupted()
            results.append(getattr(obj, func)(**slice_dict(input_data_all, i)))
    return results

@@ -368,7 +368,7 @@ class PromptExecutor:
        del d

    def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
        nodes.interrupt_processing(False)
        comfy.model_management.interrupt_current_processing(False)

        if "client_id" in extra_data:
            self.server.client_id = extra_data["client_id"]
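Context for the map_node_over_list hunks: the function dispatches a node's FUNCTION either once (scalar inputs) or once per index when inputs are lists. slice_dict is not shown in this diff; assuming it selects element i from every input list, the per-index loop behaves like:

def slice_dict(d, i):
    # stand-in for the repo's helper: pick element i from each input list,
    # clamping to the last element when a list is shorter
    return {k: v[i if len(v) > i else -1] for k, v in d.items()}

input_data_all = {"a": [1, 2, 3], "b": [10, 20, 30]}
results = [sum(slice_dict(input_data_all, i).values()) for i in range(3)]
print(results)   # [11, 22, 33]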
2  main.py
@@ -69,7 +69,6 @@ import yaml
import execution
import server
from server import BinaryEventTypes
from nodes import init_custom_nodes
import comfy.model_management

@@ -143,7 +142,6 @@ if __name__ == "__main__":
    for config_path in itertools.chain(*args.extra_model_paths_config):
        load_extra_path_config(config_path)

    init_custom_nodes()
    server.add_routes()
    hijack_progress(server)
86  nodes.py
@@ -14,9 +14,6 @@ from PIL.PngImagePlugin import PngInfo
import numpy as np
import safetensors.torch

sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy"))


import comfy.diffusers_load
import comfy.samplers
import comfy.sample
@@ -32,14 +29,8 @@ import importlib

import folder_paths
import latent_preview
from comfy.nodes.common import MAX_RESOLUTION

def before_node_execution():
    comfy.model_management.throw_exception_if_processing_interrupted()

def interrupt_processing(value=True):
    comfy.model_management.interrupt_current_processing(value)

MAX_RESOLUTION=8192

class CLIPTextEncode:
    @classmethod
@@ -1630,77 +1621,4 @@ NODE_DISPLAY_NAME_MAPPINGS = {
    # _for_testing
    "VAEDecodeTiled": "VAE Decode (Tiled)",
    "VAEEncodeTiled": "VAE Encode (Tiled)",
}

def load_custom_node(module_path, ignore=set()):
    module_name = os.path.basename(module_path)
    if os.path.isfile(module_path):
        sp = os.path.splitext(module_path)
        module_name = sp[0]
    try:
        if os.path.isfile(module_path):
            module_spec = importlib.util.spec_from_file_location(module_name, module_path)
        else:
            module_spec = importlib.util.spec_from_file_location(module_name, os.path.join(module_path, "__init__.py"))
        module = importlib.util.module_from_spec(module_spec)
        sys.modules[module_name] = module
        module_spec.loader.exec_module(module)
        if hasattr(module, "NODE_CLASS_MAPPINGS") and getattr(module, "NODE_CLASS_MAPPINGS") is not None:
            for name in module.NODE_CLASS_MAPPINGS:
                if name not in ignore:
                    NODE_CLASS_MAPPINGS[name] = module.NODE_CLASS_MAPPINGS[name]
            if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None:
                NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS)
            return True
        else:
            print(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.")
            return False
    except Exception as e:
        print(traceback.format_exc())
        print(f"Cannot import {module_path} module for custom nodes:", e)
        return False

def load_custom_nodes():
    base_node_names = set(NODE_CLASS_MAPPINGS.keys())
    node_paths = folder_paths.get_folder_paths("custom_nodes")
    node_import_times = []
    for custom_node_path in node_paths:
        possible_modules = os.listdir(custom_node_path)
        if "__pycache__" in possible_modules:
            possible_modules.remove("__pycache__")

        for possible_module in possible_modules:
            module_path = os.path.join(custom_node_path, possible_module)
            if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue
            if module_path.endswith(".disabled"): continue
            time_before = time.perf_counter()
            success = load_custom_node(module_path, base_node_names)
            node_import_times.append((time.perf_counter() - time_before, module_path, success))

    if len(node_import_times) > 0:
        print("\nImport times for custom nodes:")
        for n in sorted(node_import_times):
            if n[2]:
                import_message = ""
            else:
                import_message = " (IMPORT FAILED)"
            print("{:6.1f} seconds{}:".format(n[0], import_message), n[1])
        print()

def init_custom_nodes():
    load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_hypernetwork.py"))
    load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py"))
    load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py"))
    load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py"))
    load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_rebatch.py"))
    load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_model_merging.py"))
    load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_tomesd.py"))
    load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_clip_sdxl.py"))
    load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_canny.py"))
    load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras", "mara"), "k_centroid_downscale.py"))
    load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras", "mara"), "make_model_tileable.py"))
    load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras", "mara"), "mask_ops.py"))
    load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras", "mara"), "palettize.py"))
    load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras", "mara"), "patterngen.py"))
    load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras", "mara"), "remove_background.py"))
    load_custom_nodes()
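All of the per-file wiring deleted above collapses into the package-based discovery from comfy/nodes/package.py; the call sites (see execution.py above and server.py below) reduce to a lookup on the returned ExportedNodes (the node name here is illustrative):

from comfy.nodes.package import import_all_nodes_in_workspace

nodes = import_all_nodes_in_workspace()
cls = nodes.NODE_CLASS_MAPPINGS["KSampler"]   # e.g. look up a node class by name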
@@ -17,15 +17,16 @@ from aiohttp import web

import execution
import folder_paths
import nodes
import mimetypes

from comfy.digest import digest
from comfy.cli_args import args
import comfy.utils
import comfy.model_management
from comfy.nodes.package import import_all_nodes_in_workspace
from comfy.vendor.appdirs import user_data_dir

nodes = import_all_nodes_in_workspace()

class BinaryEventTypes:
    PREVIEW_IMAGE = 1
@@ -488,7 +489,7 @@ class PromptServer():

    @routes.post("/interrupt")
    async def post_interrupt(request):
        nodes.interrupt_processing()
        comfy.model_management.interrupt_current_processing()
        return web.Response(status=200)

    @routes.post("/history")