make the nodes organization more sane

This commit is contained in:
Benjamin Berman 2023-08-02 15:21:46 -07:00
parent 87cf8f613e
commit b3038de648
25 changed files with 137 additions and 651 deletions

1
.gitignore vendored
View File

@ -6,6 +6,7 @@
/[Tt]emp/
/[Cc]ustom_nodes/
!/custom_nodes/example_node.py.example
!/custom_nodes/__init__.py
/extra_model_paths.yaml
/.vs
.idea/

1
comfy/nodes/common.py Normal file
View File

@ -0,0 +1 @@
MAX_RESOLUTION=8192

78
comfy/nodes/package.py Normal file
View File

@ -0,0 +1,78 @@
from __future__ import annotations
import importlib
import pkgutil
import time
import types
import nodes as base_nodes
from comfy_extras import nodes as comfy_extras_nodes
import custom_nodes
from comfy.nodes.package_typing import ExportedNodes
from functools import reduce
_comfy_nodes = ExportedNodes()
def _import_nodes_in_module(exported_nodes: ExportedNodes, module: types.ModuleType):
node_class_mappings = getattr(module, 'NODE_CLASS_MAPPINGS', None)
node_display_names = getattr(module, 'NODE_DISPLAY_NAME_MAPPINGS', None)
if node_class_mappings:
exported_nodes.NODE_CLASS_MAPPINGS.update(node_class_mappings)
if node_display_names:
exported_nodes.NODE_DISPLAY_NAME_MAPPINGS.update(node_display_names)
def _import_and_enumerate_nodes_in_module(module: types.ModuleType, print_import_times=False) -> ExportedNodes:
exported_nodes = ExportedNodes()
timings = []
if hasattr(module, 'NODE_CLASS_MAPPINGS'):
node_class_mappings = getattr(module, 'NODE_CLASS_MAPPINGS', None)
node_display_names = getattr(module, 'NODE_DISPLAY_NAME_MAPPINGS', None)
if node_class_mappings:
exported_nodes.NODE_CLASS_MAPPINGS.update(node_class_mappings)
if node_display_names:
exported_nodes.NODE_DISPLAY_NAME_MAPPINGS.update(node_display_names)
else:
# Iterate through all the submodules
for _, name, is_pkg in pkgutil.iter_modules(module.__path__):
full_name = module.__name__ + "." + name
time_before = time.perf_counter()
success = True
if full_name.endswith(".disabled"):
continue
try:
submodule = importlib.import_module(full_name)
# Recursively call the function if it's a package
exported_nodes.update(
_import_and_enumerate_nodes_in_module(submodule, print_import_times=print_import_times))
except KeyboardInterrupt as interrupted:
raise interrupted
except Exception as x:
success = False
timings.append((time.perf_counter() - time_before, full_name, success))
if print_import_times and len(timings) > 0:
for (duration, module_name, success) in sorted(timings):
print(f"{duration:6.1f} seconds{'' if success else ' (IMPORT FAILED)'}, {module_name}")
return exported_nodes
def import_all_nodes_in_workspace() -> ExportedNodes:
if len(_comfy_nodes) == 0:
base_and_extra = reduce(lambda x, y: x.update(y),
map(_import_and_enumerate_nodes_in_module, [
# this is the list of default nodes to import
base_nodes,
comfy_extras_nodes
]),
ExportedNodes())
custom_nodes_mappings = _import_and_enumerate_nodes_in_module(custom_nodes, print_import_times=True)
# don't allow custom nodes to overwrite base nodes
custom_nodes_mappings -= base_and_extra
_comfy_nodes.update(base_and_extra + custom_nodes_mappings)
return _comfy_nodes

View File

@ -0,0 +1,43 @@
from __future__ import annotations
from typing import Protocol, ClassVar, Tuple, Dict
from dataclasses import dataclass, field
class CustomNode(Protocol):
@classmethod
def INPUT_TYPES(cls) -> dict: ...
RETURN_TYPES: ClassVar[Tuple[str]]
RETURN_NAMES: ClassVar[Tuple[str]] = None
OUTPUT_IS_LIST: ClassVar[Tuple[bool]] = None
FUNCTION: ClassVar[str]
CATEGORY: ClassVar[str]
OUTPUT_NODE: ClassVar[bool] = None
@dataclass
class ExportedNodes:
NODE_CLASS_MAPPINGS: Dict[str, CustomNode] = field(default_factory=dict)
NODE_DISPLAY_NAME_MAPPINGS: Dict[str, str] = field(default_factory=dict)
def update(self, exported_nodes: ExportedNodes) -> ExportedNodes:
self.NODE_CLASS_MAPPINGS.update(exported_nodes.NODE_CLASS_MAPPINGS)
self.NODE_DISPLAY_NAME_MAPPINGS.update(exported_nodes.NODE_DISPLAY_NAME_MAPPINGS)
return self
def __len__(self):
return len(self.NODE_CLASS_MAPPINGS)
def __sub__(self, other):
exported_nodes = ExportedNodes().update(self)
for self_key in exported_nodes.NODE_CLASS_MAPPINGS:
if self_key in other.NODE_CLASS_MAPPINGS:
exported_nodes.NODE_CLASS_MAPPINGS.pop(self_key)
if self_key in other.NODE_DISPLAY_NAME_MAPPINGS:
exported_nodes.NODE_DISPLAY_NAME_MAPPINGS.pop(self_key)
return exported_nodes
def __add__(self, other):
exported_nodes = ExportedNodes().update(self)
return exported_nodes.update(other)

View File

@ -1,91 +0,0 @@
# Mara Huldra 2023
# SPDX-License-Identifier: MIT
'''
Downscale an image using k-centroid scaling. This is useful for images with a
low number of separate colors, such as those generated by Astropulse's pixel
art model.
'''
from itertools import product
import numpy as np
from PIL import Image
import torch
MAX_RESOLUTION = 1024
AUTO_FACTOR = 8
def k_centroid_downscale(images, width, height, centroids=2):
'''k-centroid scaling, based on: https://github.com/Astropulse/stable-diffusion-aseprite/blob/main/scripts/image_server.py.'''
downscaled = np.zeros((images.shape[0], height, width, 3), dtype=np.uint8)
for ii, image in enumerate(images):
i = 255. * image.cpu().numpy()
image = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
factor = (image.width // width, image.height // height)
for x, y in product(range(width), range(height)):
tile = image.crop((x * factor[0], y * factor[1], (x + 1) * factor[0], (y + 1) * factor[1]))
# quantize tile to fixed number of colors (creates palettized image)
tile = tile.quantize(colors=centroids, method=1, kmeans=centroids)
# get most common (median) color
color_counts = tile.getcolors()
most_common_idx = max(color_counts, key=lambda x: x[0])[1]
downscaled[ii, y, x, :] = tile.getpalette()[most_common_idx*3:(most_common_idx + 1)*3]
downscaled = downscaled.astype(np.float32) / 255.0
return torch.from_numpy(downscaled)
class ImageKCentroidDownscale:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("IMAGE",),
"width": ("INT", {"default": 64, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
"height": ("INT", {"default": 64, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
"centroids": ("INT", {"default": 2, "min": 1, "max": 256, "step": 1}),
}
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "downscale"
CATEGORY = "image/downscaling"
def downscale(self, image, width, height, centroids):
s = k_centroid_downscale(image, width, height, centroids)
return (s,)
class ImageKCentroidAutoDownscale:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("IMAGE",),
"centroids": ("INT", {"default": 2, "min": 1, "max": 256, "step": 1}),
}
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "downscale"
CATEGORY = "image/downscaling"
def downscale(self, image, centroids):
width = image.shape[2] // AUTO_FACTOR
height = image.shape[1] // AUTO_FACTOR
s = k_centroid_downscale(image, width, height, centroids)
return (s,)
NODE_CLASS_MAPPINGS = {
"ImageKCentroidDownscale": ImageKCentroidDownscale,
"ImageKCentroidAutoDownscale": ImageKCentroidAutoDownscale,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"ImageKCentroidDownscale": "K-Centroid Downscale",
"ImageKCentroidAutoDownscale": "K-Centroid Downscale (autosize)"
}

View File

@ -1,78 +0,0 @@
# Mara Huldra 2023
# SPDX-License-Identifier: MIT
'''
Patches the SD model and VAE to make it possible to generate seamlessly tilable
graphics. Horizontal and vertical direction are configurable separately.
'''
from typing import Optional
import torch
from torch.nn import functional as F
from torch.nn.modules.utils import _pair
def flatten_modules(m):
'''Return submodules of module m in flattened form.'''
yield m
for c in m.children():
yield from flatten_modules(c)
# from: https://github.com/Astropulse/stable-diffusion-aseprite/blob/main/scripts/image_server.py
def make_seamless_xy(model, x, y):
for layer in flatten_modules(model):
if type(layer) == torch.nn.Conv2d:
layer.padding_modeX = 'circular' if x else 'constant'
layer.padding_modeY = 'circular' if y else 'constant'
layer.paddingX = (layer._reversed_padding_repeated_twice[0], layer._reversed_padding_repeated_twice[1], 0, 0)
layer.paddingY = (0, 0, layer._reversed_padding_repeated_twice[2], layer._reversed_padding_repeated_twice[3])
layer._conv_forward = __replacementConv2DConvForward.__get__(layer, torch.nn.Conv2d)
def restore_conv2d_methods(model):
for layer in flatten_modules(model):
if type(layer) == torch.nn.Conv2d:
layer._conv_forward = torch.nn.Conv2d._conv_forward.__get__(layer, torch.nn.Conv2d)
def __replacementConv2DConvForward(self, input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]):
working = F.pad(input, self.paddingX, mode=self.padding_modeX)
working = F.pad(working, self.paddingY, mode=self.padding_modeY)
return F.conv2d(working, weight, bias, self.stride, _pair(0), self.dilation, self.groups)
class MakeModelTileable:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": ("MODEL",),
"vae": ("VAE",),
"tile_x": (["disabled", "enabled"], { "default": "disabled", }),
"tile_y": (["disabled", "enabled"], { "default": "disabled", }),
}
}
RETURN_TYPES = ("MODEL", "VAE")
FUNCTION = "patch_models"
CATEGORY = "advanced/patchers"
def patch_models(self, model, vae, tile_x, tile_y):
tile_x = (tile_x == 'enabled')
tile_y = (tile_y == 'enabled')
# XXX ideally, we'd return a clone of the model, not patch the model itself
#model = model.clone()
#vae = vae.???()
restore_conv2d_methods(model.model)
restore_conv2d_methods(vae.first_stage_model)
make_seamless_xy(model.model, tile_x, tile_y)
make_seamless_xy(vae.first_stage_model, tile_x, tile_y)
return (model, vae)
NODE_CLASS_MAPPINGS = {
"MakeModelTileable": MakeModelTileable,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"MakeModelTileable": "Patch model tileability"
}

View File

@ -1,68 +0,0 @@
# Mara Huldra 2023
# SPDX-License-Identifier: MIT
'''
Extra mask operations.
'''
import numpy as np
import rembg
import torch
class BinarizeMask:
'''Binarize (threshold) a mask.'''
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"mask": ("MASK",),
"threshold": ("INT", {
"default": 250,
"min": 0,
"max": 255,
"step": 1,
}),
},
}
RETURN_TYPES = ("MASK",)
FUNCTION = "binarize"
CATEGORY = "mask"
def binarize(self, mask, threshold):
t = torch.Tensor([threshold / 255.])
s = (mask >= t).float()
return (s,)
class ImageCutout:
'''Perform basic image cutout (adds alpha channel from mask).'''
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("IMAGE",),
"mask": ("MASK",),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "cutout"
CATEGORY = "image/postprocessing"
def cutout(self, image, mask):
# XXX check compatible dimensions.
o = np.zeros((image.shape[0], image.shape[1], image.shape[2], 4))
o[:, :, :, 0:3] = image.cpu().numpy()
o[:, :, :, 3] = mask.cpu().numpy()
return (torch.from_numpy(o),)
NODE_CLASS_MAPPINGS = {
"BinarizeMask": BinarizeMask,
"ImageCutout": ImageCutout,
}

View File

@ -1,155 +0,0 @@
# Mara Huldra 2023
# SPDX-License-Identifier: MIT
'''
Palettize an image.
'''
import os
import numpy as np
from PIL import Image
import torch
PALETTES_PATH = os.path.join(os.path.dirname(__file__), '../..', 'palettes')
PAL_EXT = '.png'
QUANTIZE_METHODS = {
'median_cut': Image.Quantize.MEDIANCUT,
'max_coverage': Image.Quantize.MAXCOVERAGE,
'fast_octree': Image.Quantize.FASTOCTREE,
}
# Determine optimal number of colors.
# FROM: astropulse/sd-palettize
#
# Use FASTOCTREE for determining the best k, as it is
# - its faster
# - it does a better job fitting the image to lower color counts than the other options
# Max converge is best for reducing an image's colors more accurately, but
# since for best k we only care about the best number of colors, a faster more
# predictable method is better.
# (Astropulse, 2023-06-05)
def determine_best_k(image, max_k, quantize_method=Image.Quantize.FASTOCTREE):
# Convert the image to RGB mode
image = image.convert("RGB")
# Prepare arrays for distortion calculation
pixels = np.array(image)
pixel_indices = np.reshape(pixels, (-1, 3))
# Calculate distortion for different values of k
distortions = []
for k in range(1, max_k + 1):
quantized_image = image.quantize(colors=k, method=quantize_method, kmeans=k, dither=0)
centroids = np.array(quantized_image.getpalette()[:k * 3]).reshape(-1, 3)
# Calculate distortions
distances = np.linalg.norm(pixel_indices[:, np.newaxis] - centroids, axis=2)
min_distances = np.min(distances, axis=1)
distortions.append(np.sum(min_distances ** 2))
# Calculate the rate of change of distortions
rate_of_change = np.diff(distortions) / np.array(distortions[:-1])
# Find the elbow point (best k value)
if len(rate_of_change) == 0:
best_k = 2
else:
elbow_index = np.argmax(rate_of_change) + 1
best_k = elbow_index + 2
return best_k
palette_warned = False
def list_palettes():
global palette_warned
palettes = []
try:
for filename in os.listdir(PALETTES_PATH):
if filename.endswith(PAL_EXT):
palettes.append(filename[0:-len(PAL_EXT)])
except FileNotFoundError:
pass
if not palettes and not palette_warned:
palette_warned = True
print("ImagePalettize warning: no fixed palettes found. You can put these in the palettes/ directory below the ComfyUI root.")
return palettes
def get_image_colors(pal_img):
palette = []
pal_img = pal_img.convert('RGB')
for i in pal_img.getcolors(16777216):
palette.append(i[1][0])
palette.append(i[1][1])
palette.append(i[1][2])
return palette
def load_palette(name):
return get_image_colors(Image.open(os.path.join(PALETTES_PATH, name + PAL_EXT)))
class ImagePalettize:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("IMAGE",),
"palette": (["auto_best_k", "auto_fixed_k"] + list_palettes(), {
"default": "auto_best_k",
}),
"max_k": ("INT", {
"default": 64,
"min": 1,
"max": 256,
"step": 1,
}),
"method": (list(QUANTIZE_METHODS.keys()), {
"default": "max_coverage",
}),
}
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "palettize"
CATEGORY = "image/postprocessing"
def palettize(self, image, palette, max_k, method):
k = None
pal_img = None
if palette not in {'auto_best_k', 'auto_fixed_k'}:
pal_entries = load_palette(palette)
k = len(pal_entries) // 3
pal_img = Image.new('P', (1, 1)) # image size doesn't matter it only holds the palette
pal_img.putpalette(pal_entries)
results = []
for i in image:
i = 255. * i.cpu().numpy()
i = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
if palette == 'auto_best_k':
k = determine_best_k(i, max_k)
print(f'Auto number of colors: {k}')
elif palette == 'auto_fixed_k':
k = max_k
i = i.quantize(colors=k, method=QUANTIZE_METHODS[method], kmeans=k, dither=0, palette=pal_img)
i = i.convert('RGB')
results.append(np.array(i))
result = np.array(results).astype(np.float32) / 255.0
return (torch.from_numpy(result), )
NODE_CLASS_MAPPINGS = {
"ImagePalettize": ImagePalettize,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"ImagePalettize": "ImagePalettize"
}

View File

@ -1,45 +0,0 @@
# Mara Huldra 2023
# SPDX-License-Identifier: MIT
'''
Simple image pattern generators.
'''
import os
import numpy as np
from PIL import Image
import torch
MAX_RESOLUTION = 8192
class ImageSolidColor:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"width": ("INT", {"default": 64, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
"height": ("INT", {"default": 64, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
"r": ("INT", {"default": 0, "min": 0, "max": 255, "step": 1}),
"g": ("INT", {"default": 0, "min": 0, "max": 255, "step": 1}),
"b": ("INT", {"default": 0, "min": 0, "max": 255, "step": 1}),
}
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "render"
CATEGORY = "image/pattern"
def render(self, width, height, r, g, b):
color = torch.tensor([r, g, b]) / 255.0
result = color.expand(1, height, width, 3)
return (result, )
NODE_CLASS_MAPPINGS = {
"ImageSolidColor": ImageSolidColor,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"ImageSolidColor": "Solid Color",
}

View File

@ -1,119 +0,0 @@
# Mara Huldra 2023
# SPDX-License-Identifier: MIT
'''
Estimate what pixels belong to the background and perform a cut-out, using the 'rembg' models.
'''
import numpy as np
import rembg
import torch
MODELS = rembg.sessions.sessions_names
class ImageRemoveBackground:
'''Remove background from image (adds an alpha channel)'''
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("IMAGE",),
"model": (MODELS, {
"default": "u2net",
}),
"alpha_matting": (["disabled", "enabled"], {
"default": "disabled",
}),
"am_foreground_thr": ("INT", {
"default": 240,
"min": 0,
"max": 255,
"step": 1,
}),
"am_background_thr": ("INT", {
"default": 10,
"min": 0,
"max": 255,
"step": 1,
}),
"am_erode_size": ("INT", {
"default": 10,
"min": 0,
"max": 255,
"step": 1,
}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "remove_background"
CATEGORY = "image/postprocessing"
def remove_background(self, image, model, alpha_matting, am_foreground_thr, am_background_thr, am_erode_size):
session = rembg.new_session(model)
results = []
for i in image:
i = 255. * i.cpu().numpy()
i = np.clip(i, 0, 255).astype(np.uint8)
i = rembg.remove(i,
alpha_matting=(alpha_matting == "enabled"),
alpha_matting_foreground_threshold=am_foreground_thr,
alpha_matting_background_threshold=am_background_thr,
alpha_matting_erode_size=am_erode_size,
session=session,
)
results.append(i.astype(np.float32) / 255.0)
s = torch.from_numpy(np.array(results))
return (s,)
class ImageEstimateForegroundMask:
'''
Return a mask of which pixels are estimated to belong to foreground.
Only estimates the mask, does not perform cutout like
ImageRemoveBackground.
'''
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("IMAGE",),
"model": (MODELS, {
"default": "u2net",
}),
},
}
RETURN_TYPES = ("MASK",)
FUNCTION = "estimate_background"
CATEGORY = "image/postprocessing"
def estimate_background(self, image, model):
session = rembg.new_session(model)
results = []
for i in image:
i = 255. * i.cpu().numpy()
i = np.clip(i, 0, 255).astype(np.uint8)
i = rembg.remove(i, only_mask=True, session=session)
results.append(i.astype(np.float32) / 255.0)
s = torch.from_numpy(np.array(results))
print(s.shape)
return (s,)
NODE_CLASS_MAPPINGS = {
"ImageRemoveBackground": ImageRemoveBackground,
"ImageEstimateForegroundMask": ImageEstimateForegroundMask,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"ImageRemoveBackground": "Remove Background (rembg)",
"ImageEstimateForegroundMask": "Estimate Foreground (rembg)",
}

View File

View File

@ -1,5 +1,6 @@
import torch
from nodes import MAX_RESOLUTION
from comfy.nodes.common import MAX_RESOLUTION
class CLIPTextEncodeSDXLRefiner:
@classmethod

View File

@ -1,6 +1,7 @@
import torch
from nodes import MAX_RESOLUTION
from comfy.nodes.common import MAX_RESOLUTION
class LatentCompositeMasked:
@classmethod

View File

@ -3,7 +3,6 @@ from __future__ import annotations
import asyncio
import copy
import datetime
import gc
import heapq
import threading
import traceback
@ -14,7 +13,8 @@ import sys
import torch
import nodes
from comfy.nodes.package import import_all_nodes_in_workspace
nodes = import_all_nodes_in_workspace()
import comfy.model_management
"""
@ -111,7 +111,7 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
results = []
if input_is_list:
if allow_interrupt:
nodes.before_node_execution()
comfy.model_management.throw_exception_if_processing_interrupted()
results.append(getattr(obj, func)(**input_data_all))
elif max_len_input == 0:
if allow_interrupt:
@ -120,7 +120,7 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
else:
for i in range(max_len_input):
if allow_interrupt:
nodes.before_node_execution()
comfy.model_management.throw_exception_if_processing_interrupted()
results.append(getattr(obj, func)(**slice_dict(input_data_all, i)))
return results
@ -368,7 +368,7 @@ class PromptExecutor:
del d
def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
nodes.interrupt_processing(False)
comfy.model_management.interrupt_current_processing(False)
if "client_id" in extra_data:
self.server.client_id = extra_data["client_id"]

View File

@ -69,7 +69,6 @@ import yaml
import execution
import server
from server import BinaryEventTypes
from nodes import init_custom_nodes
import comfy.model_management
@ -143,7 +142,6 @@ if __name__ == "__main__":
for config_path in itertools.chain(*args.extra_model_paths_config):
load_extra_path_config(config_path)
init_custom_nodes()
server.add_routes()
hijack_progress(server)

View File

@ -14,9 +14,6 @@ from PIL.PngImagePlugin import PngInfo
import numpy as np
import safetensors.torch
sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy"))
import comfy.diffusers_load
import comfy.samplers
import comfy.sample
@ -32,14 +29,8 @@ import importlib
import folder_paths
import latent_preview
from comfy.nodes.common import MAX_RESOLUTION
def before_node_execution():
comfy.model_management.throw_exception_if_processing_interrupted()
def interrupt_processing(value=True):
comfy.model_management.interrupt_current_processing(value)
MAX_RESOLUTION=8192
class CLIPTextEncode:
@classmethod
@ -1630,77 +1621,4 @@ NODE_DISPLAY_NAME_MAPPINGS = {
# _for_testing
"VAEDecodeTiled": "VAE Decode (Tiled)",
"VAEEncodeTiled": "VAE Encode (Tiled)",
}
def load_custom_node(module_path, ignore=set()):
module_name = os.path.basename(module_path)
if os.path.isfile(module_path):
sp = os.path.splitext(module_path)
module_name = sp[0]
try:
if os.path.isfile(module_path):
module_spec = importlib.util.spec_from_file_location(module_name, module_path)
else:
module_spec = importlib.util.spec_from_file_location(module_name, os.path.join(module_path, "__init__.py"))
module = importlib.util.module_from_spec(module_spec)
sys.modules[module_name] = module
module_spec.loader.exec_module(module)
if hasattr(module, "NODE_CLASS_MAPPINGS") and getattr(module, "NODE_CLASS_MAPPINGS") is not None:
for name in module.NODE_CLASS_MAPPINGS:
if name not in ignore:
NODE_CLASS_MAPPINGS[name] = module.NODE_CLASS_MAPPINGS[name]
if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None:
NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS)
return True
else:
print(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.")
return False
except Exception as e:
print(traceback.format_exc())
print(f"Cannot import {module_path} module for custom nodes:", e)
return False
def load_custom_nodes():
base_node_names = set(NODE_CLASS_MAPPINGS.keys())
node_paths = folder_paths.get_folder_paths("custom_nodes")
node_import_times = []
for custom_node_path in node_paths:
possible_modules = os.listdir(custom_node_path)
if "__pycache__" in possible_modules:
possible_modules.remove("__pycache__")
for possible_module in possible_modules:
module_path = os.path.join(custom_node_path, possible_module)
if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue
if module_path.endswith(".disabled"): continue
time_before = time.perf_counter()
success = load_custom_node(module_path, base_node_names)
node_import_times.append((time.perf_counter() - time_before, module_path, success))
if len(node_import_times) > 0:
print("\nImport times for custom nodes:")
for n in sorted(node_import_times):
if n[2]:
import_message = ""
else:
import_message = " (IMPORT FAILED)"
print("{:6.1f} seconds{}:".format(n[0], import_message), n[1])
print()
def init_custom_nodes():
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_hypernetwork.py"))
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py"))
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py"))
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py"))
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_rebatch.py"))
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_model_merging.py"))
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_tomesd.py"))
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_clip_sdxl.py"))
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_canny.py"))
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras", "mara"), "k_centroid_downscale.py"))
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras", "mara"), "make_model_tileable.py"))
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras", "mara"), "mask_ops.py"))
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras", "mara"), "palettize.py"))
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras", "mara"), "patterngen.py"))
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras", "mara"), "remove_background.py"))
load_custom_nodes()
}

View File

@ -17,15 +17,16 @@ from aiohttp import web
import execution
import folder_paths
import nodes
import mimetypes
from comfy.digest import digest
from comfy.cli_args import args
import comfy.utils
import comfy.model_management
from comfy.nodes.package import import_all_nodes_in_workspace
from comfy.vendor.appdirs import user_data_dir
nodes = import_all_nodes_in_workspace()
class BinaryEventTypes:
PREVIEW_IMAGE = 1
@ -488,7 +489,7 @@ class PromptServer():
@routes.post("/interrupt")
async def post_interrupt(request):
nodes.interrupt_processing()
comfy.model_management.interrupt_current_processing()
return web.Response(status=200)
@routes.post("/history")