diff --git a/.gitignore b/.gitignore index b1146714d..adc66db68 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ /[Tt]emp/ /[Cc]ustom_nodes/ !/custom_nodes/example_node.py.example +!/custom_nodes/__init__.py /extra_model_paths.yaml /.vs .idea/ diff --git a/comfy_extras/mara/__init__.py b/comfy/nodes/__init__.py similarity index 100% rename from comfy_extras/mara/__init__.py rename to comfy/nodes/__init__.py diff --git a/comfy/nodes/common.py b/comfy/nodes/common.py new file mode 100644 index 000000000..84474a17d --- /dev/null +++ b/comfy/nodes/common.py @@ -0,0 +1 @@ +MAX_RESOLUTION=8192 diff --git a/comfy/nodes/package.py b/comfy/nodes/package.py new file mode 100644 index 000000000..f66a8e074 --- /dev/null +++ b/comfy/nodes/package.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +import importlib +import pkgutil +import time +import types + +import nodes as base_nodes +from comfy_extras import nodes as comfy_extras_nodes +import custom_nodes +from comfy.nodes.package_typing import ExportedNodes +from functools import reduce + +_comfy_nodes = ExportedNodes() + + +def _import_nodes_in_module(exported_nodes: ExportedNodes, module: types.ModuleType): + node_class_mappings = getattr(module, 'NODE_CLASS_MAPPINGS', None) + node_display_names = getattr(module, 'NODE_DISPLAY_NAME_MAPPINGS', None) + if node_class_mappings: + exported_nodes.NODE_CLASS_MAPPINGS.update(node_class_mappings) + if node_display_names: + exported_nodes.NODE_DISPLAY_NAME_MAPPINGS.update(node_display_names) + + +def _import_and_enumerate_nodes_in_module(module: types.ModuleType, print_import_times=False) -> ExportedNodes: + exported_nodes = ExportedNodes() + timings = [] + if hasattr(module, 'NODE_CLASS_MAPPINGS'): + node_class_mappings = getattr(module, 'NODE_CLASS_MAPPINGS', None) + node_display_names = getattr(module, 'NODE_DISPLAY_NAME_MAPPINGS', None) + if node_class_mappings: + exported_nodes.NODE_CLASS_MAPPINGS.update(node_class_mappings) + if node_display_names: + exported_nodes.NODE_DISPLAY_NAME_MAPPINGS.update(node_display_names) + else: + # Iterate through all the submodules + for _, name, is_pkg in pkgutil.iter_modules(module.__path__): + full_name = module.__name__ + "." + name + time_before = time.perf_counter() + success = True + + if full_name.endswith(".disabled"): + continue + try: + submodule = importlib.import_module(full_name) + # Recursively call the function if it's a package + exported_nodes.update( + _import_and_enumerate_nodes_in_module(submodule, print_import_times=print_import_times)) + except KeyboardInterrupt as interrupted: + raise interrupted + except Exception as x: + success = False + timings.append((time.perf_counter() - time_before, full_name, success)) + + if print_import_times and len(timings) > 0: + for (duration, module_name, success) in sorted(timings): + print(f"{duration:6.1f} seconds{'' if success else ' (IMPORT FAILED)'}, {module_name}") + return exported_nodes + + +def import_all_nodes_in_workspace() -> ExportedNodes: + if len(_comfy_nodes) == 0: + base_and_extra = reduce(lambda x, y: x.update(y), + map(_import_and_enumerate_nodes_in_module, [ + # this is the list of default nodes to import + base_nodes, + comfy_extras_nodes + ]), + ExportedNodes()) + + custom_nodes_mappings = _import_and_enumerate_nodes_in_module(custom_nodes, print_import_times=True) + + # don't allow custom nodes to overwrite base nodes + custom_nodes_mappings -= base_and_extra + + _comfy_nodes.update(base_and_extra + custom_nodes_mappings) + return _comfy_nodes diff --git a/comfy/nodes/package_typing.py b/comfy/nodes/package_typing.py new file mode 100644 index 000000000..64a8cff11 --- /dev/null +++ b/comfy/nodes/package_typing.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +from typing import Protocol, ClassVar, Tuple, Dict +from dataclasses import dataclass, field + + +class CustomNode(Protocol): + @classmethod + def INPUT_TYPES(cls) -> dict: ... + + RETURN_TYPES: ClassVar[Tuple[str]] + RETURN_NAMES: ClassVar[Tuple[str]] = None + OUTPUT_IS_LIST: ClassVar[Tuple[bool]] = None + FUNCTION: ClassVar[str] + CATEGORY: ClassVar[str] + OUTPUT_NODE: ClassVar[bool] = None + + +@dataclass +class ExportedNodes: + NODE_CLASS_MAPPINGS: Dict[str, CustomNode] = field(default_factory=dict) + NODE_DISPLAY_NAME_MAPPINGS: Dict[str, str] = field(default_factory=dict) + + def update(self, exported_nodes: ExportedNodes) -> ExportedNodes: + self.NODE_CLASS_MAPPINGS.update(exported_nodes.NODE_CLASS_MAPPINGS) + self.NODE_DISPLAY_NAME_MAPPINGS.update(exported_nodes.NODE_DISPLAY_NAME_MAPPINGS) + return self + + def __len__(self): + return len(self.NODE_CLASS_MAPPINGS) + + def __sub__(self, other): + exported_nodes = ExportedNodes().update(self) + for self_key in exported_nodes.NODE_CLASS_MAPPINGS: + if self_key in other.NODE_CLASS_MAPPINGS: + exported_nodes.NODE_CLASS_MAPPINGS.pop(self_key) + if self_key in other.NODE_DISPLAY_NAME_MAPPINGS: + exported_nodes.NODE_DISPLAY_NAME_MAPPINGS.pop(self_key) + return exported_nodes + + def __add__(self, other): + exported_nodes = ExportedNodes().update(self) + return exported_nodes.update(other) diff --git a/comfy_extras/mara/k_centroid_downscale.py b/comfy_extras/mara/k_centroid_downscale.py deleted file mode 100644 index 2f70c694c..000000000 --- a/comfy_extras/mara/k_centroid_downscale.py +++ /dev/null @@ -1,91 +0,0 @@ -# Mara Huldra 2023 -# SPDX-License-Identifier: MIT -''' -Downscale an image using k-centroid scaling. This is useful for images with a -low number of separate colors, such as those generated by Astropulse's pixel -art model. -''' -from itertools import product - -import numpy as np -from PIL import Image -import torch - -MAX_RESOLUTION = 1024 -AUTO_FACTOR = 8 - -def k_centroid_downscale(images, width, height, centroids=2): - '''k-centroid scaling, based on: https://github.com/Astropulse/stable-diffusion-aseprite/blob/main/scripts/image_server.py.''' - - downscaled = np.zeros((images.shape[0], height, width, 3), dtype=np.uint8) - - for ii, image in enumerate(images): - i = 255. * image.cpu().numpy() - image = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) - factor = (image.width // width, image.height // height) - - for x, y in product(range(width), range(height)): - tile = image.crop((x * factor[0], y * factor[1], (x + 1) * factor[0], (y + 1) * factor[1])) - # quantize tile to fixed number of colors (creates palettized image) - tile = tile.quantize(colors=centroids, method=1, kmeans=centroids) - # get most common (median) color - color_counts = tile.getcolors() - most_common_idx = max(color_counts, key=lambda x: x[0])[1] - downscaled[ii, y, x, :] = tile.getpalette()[most_common_idx*3:(most_common_idx + 1)*3] - - downscaled = downscaled.astype(np.float32) / 255.0 - return torch.from_numpy(downscaled) - - -class ImageKCentroidDownscale: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": ("IMAGE",), - "width": ("INT", {"default": 64, "min": 1, "max": MAX_RESOLUTION, "step": 1}), - "height": ("INT", {"default": 64, "min": 1, "max": MAX_RESOLUTION, "step": 1}), - "centroids": ("INT", {"default": 2, "min": 1, "max": 256, "step": 1}), - } - } - - RETURN_TYPES = ("IMAGE",) - FUNCTION = "downscale" - - CATEGORY = "image/downscaling" - - def downscale(self, image, width, height, centroids): - s = k_centroid_downscale(image, width, height, centroids) - return (s,) - -class ImageKCentroidAutoDownscale: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": ("IMAGE",), - "centroids": ("INT", {"default": 2, "min": 1, "max": 256, "step": 1}), - } - } - - RETURN_TYPES = ("IMAGE",) - FUNCTION = "downscale" - - CATEGORY = "image/downscaling" - - def downscale(self, image, centroids): - width = image.shape[2] // AUTO_FACTOR - height = image.shape[1] // AUTO_FACTOR - s = k_centroid_downscale(image, width, height, centroids) - return (s,) - - -NODE_CLASS_MAPPINGS = { - "ImageKCentroidDownscale": ImageKCentroidDownscale, - "ImageKCentroidAutoDownscale": ImageKCentroidAutoDownscale, -} - -NODE_DISPLAY_NAME_MAPPINGS = { - "ImageKCentroidDownscale": "K-Centroid Downscale", - "ImageKCentroidAutoDownscale": "K-Centroid Downscale (autosize)" -} diff --git a/comfy_extras/mara/make_model_tileable.py b/comfy_extras/mara/make_model_tileable.py deleted file mode 100644 index b2192dbd8..000000000 --- a/comfy_extras/mara/make_model_tileable.py +++ /dev/null @@ -1,78 +0,0 @@ -# Mara Huldra 2023 -# SPDX-License-Identifier: MIT -''' -Patches the SD model and VAE to make it possible to generate seamlessly tilable -graphics. Horizontal and vertical direction are configurable separately. -''' -from typing import Optional - -import torch -from torch.nn import functional as F -from torch.nn.modules.utils import _pair - - -def flatten_modules(m): - '''Return submodules of module m in flattened form.''' - yield m - for c in m.children(): - yield from flatten_modules(c) - -# from: https://github.com/Astropulse/stable-diffusion-aseprite/blob/main/scripts/image_server.py -def make_seamless_xy(model, x, y): - for layer in flatten_modules(model): - if type(layer) == torch.nn.Conv2d: - layer.padding_modeX = 'circular' if x else 'constant' - layer.padding_modeY = 'circular' if y else 'constant' - layer.paddingX = (layer._reversed_padding_repeated_twice[0], layer._reversed_padding_repeated_twice[1], 0, 0) - layer.paddingY = (0, 0, layer._reversed_padding_repeated_twice[2], layer._reversed_padding_repeated_twice[3]) - layer._conv_forward = __replacementConv2DConvForward.__get__(layer, torch.nn.Conv2d) - -def restore_conv2d_methods(model): - for layer in flatten_modules(model): - if type(layer) == torch.nn.Conv2d: - layer._conv_forward = torch.nn.Conv2d._conv_forward.__get__(layer, torch.nn.Conv2d) - -def __replacementConv2DConvForward(self, input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]): - working = F.pad(input, self.paddingX, mode=self.padding_modeX) - working = F.pad(working, self.paddingY, mode=self.padding_modeY) - return F.conv2d(working, weight, bias, self.stride, _pair(0), self.dilation, self.groups) - - -class MakeModelTileable: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "model": ("MODEL",), - "vae": ("VAE",), - "tile_x": (["disabled", "enabled"], { "default": "disabled", }), - "tile_y": (["disabled", "enabled"], { "default": "disabled", }), - } - } - - RETURN_TYPES = ("MODEL", "VAE") - FUNCTION = "patch_models" - - CATEGORY = "advanced/patchers" - - def patch_models(self, model, vae, tile_x, tile_y): - tile_x = (tile_x == 'enabled') - tile_y = (tile_y == 'enabled') - # XXX ideally, we'd return a clone of the model, not patch the model itself - #model = model.clone() - #vae = vae.???() - - restore_conv2d_methods(model.model) - restore_conv2d_methods(vae.first_stage_model) - make_seamless_xy(model.model, tile_x, tile_y) - make_seamless_xy(vae.first_stage_model, tile_x, tile_y) - return (model, vae) - - -NODE_CLASS_MAPPINGS = { - "MakeModelTileable": MakeModelTileable, -} - -NODE_DISPLAY_NAME_MAPPINGS = { - "MakeModelTileable": "Patch model tileability" -} diff --git a/comfy_extras/mara/mask_ops.py b/comfy_extras/mara/mask_ops.py deleted file mode 100644 index bd870eab4..000000000 --- a/comfy_extras/mara/mask_ops.py +++ /dev/null @@ -1,68 +0,0 @@ -# Mara Huldra 2023 -# SPDX-License-Identifier: MIT -''' -Extra mask operations. -''' -import numpy as np -import rembg -import torch - - -class BinarizeMask: - '''Binarize (threshold) a mask.''' - - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "mask": ("MASK",), - "threshold": ("INT", { - "default": 250, - "min": 0, - "max": 255, - "step": 1, - }), - }, - } - - RETURN_TYPES = ("MASK",) - FUNCTION = "binarize" - - CATEGORY = "mask" - - def binarize(self, mask, threshold): - t = torch.Tensor([threshold / 255.]) - s = (mask >= t).float() - return (s,) - - -class ImageCutout: - '''Perform basic image cutout (adds alpha channel from mask).''' - - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": ("IMAGE",), - "mask": ("MASK",), - }, - } - - RETURN_TYPES = ("IMAGE",) - FUNCTION = "cutout" - - CATEGORY = "image/postprocessing" - - def cutout(self, image, mask): - # XXX check compatible dimensions. - o = np.zeros((image.shape[0], image.shape[1], image.shape[2], 4)) - o[:, :, :, 0:3] = image.cpu().numpy() - o[:, :, :, 3] = mask.cpu().numpy() - return (torch.from_numpy(o),) - - -NODE_CLASS_MAPPINGS = { - "BinarizeMask": BinarizeMask, - "ImageCutout": ImageCutout, -} - diff --git a/comfy_extras/mara/palettize.py b/comfy_extras/mara/palettize.py deleted file mode 100644 index 36d3155b4..000000000 --- a/comfy_extras/mara/palettize.py +++ /dev/null @@ -1,155 +0,0 @@ -# Mara Huldra 2023 -# SPDX-License-Identifier: MIT -''' -Palettize an image. -''' -import os - -import numpy as np -from PIL import Image -import torch - -PALETTES_PATH = os.path.join(os.path.dirname(__file__), '../..', 'palettes') -PAL_EXT = '.png' - -QUANTIZE_METHODS = { - 'median_cut': Image.Quantize.MEDIANCUT, - 'max_coverage': Image.Quantize.MAXCOVERAGE, - 'fast_octree': Image.Quantize.FASTOCTREE, -} - -# Determine optimal number of colors. -# FROM: astropulse/sd-palettize -# -# Use FASTOCTREE for determining the best k, as it is -# - its faster -# - it does a better job fitting the image to lower color counts than the other options -# Max converge is best for reducing an image's colors more accurately, but -# since for best k we only care about the best number of colors, a faster more -# predictable method is better. -# (Astropulse, 2023-06-05) -def determine_best_k(image, max_k, quantize_method=Image.Quantize.FASTOCTREE): - # Convert the image to RGB mode - image = image.convert("RGB") - - # Prepare arrays for distortion calculation - pixels = np.array(image) - pixel_indices = np.reshape(pixels, (-1, 3)) - - # Calculate distortion for different values of k - distortions = [] - for k in range(1, max_k + 1): - quantized_image = image.quantize(colors=k, method=quantize_method, kmeans=k, dither=0) - centroids = np.array(quantized_image.getpalette()[:k * 3]).reshape(-1, 3) - - # Calculate distortions - distances = np.linalg.norm(pixel_indices[:, np.newaxis] - centroids, axis=2) - min_distances = np.min(distances, axis=1) - distortions.append(np.sum(min_distances ** 2)) - - # Calculate the rate of change of distortions - rate_of_change = np.diff(distortions) / np.array(distortions[:-1]) - - # Find the elbow point (best k value) - if len(rate_of_change) == 0: - best_k = 2 - else: - elbow_index = np.argmax(rate_of_change) + 1 - best_k = elbow_index + 2 - - return best_k - -palette_warned = False - -def list_palettes(): - global palette_warned - palettes = [] - try: - for filename in os.listdir(PALETTES_PATH): - if filename.endswith(PAL_EXT): - palettes.append(filename[0:-len(PAL_EXT)]) - except FileNotFoundError: - pass - if not palettes and not palette_warned: - palette_warned = True - print("ImagePalettize warning: no fixed palettes found. You can put these in the palettes/ directory below the ComfyUI root.") - return palettes - - -def get_image_colors(pal_img): - palette = [] - pal_img = pal_img.convert('RGB') - for i in pal_img.getcolors(16777216): - palette.append(i[1][0]) - palette.append(i[1][1]) - palette.append(i[1][2]) - return palette - - -def load_palette(name): - return get_image_colors(Image.open(os.path.join(PALETTES_PATH, name + PAL_EXT))) - - -class ImagePalettize: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": ("IMAGE",), - "palette": (["auto_best_k", "auto_fixed_k"] + list_palettes(), { - "default": "auto_best_k", - }), - "max_k": ("INT", { - "default": 64, - "min": 1, - "max": 256, - "step": 1, - }), - "method": (list(QUANTIZE_METHODS.keys()), { - "default": "max_coverage", - }), - } - } - - RETURN_TYPES = ("IMAGE",) - FUNCTION = "palettize" - - CATEGORY = "image/postprocessing" - - def palettize(self, image, palette, max_k, method): - k = None - pal_img = None - if palette not in {'auto_best_k', 'auto_fixed_k'}: - pal_entries = load_palette(palette) - k = len(pal_entries) // 3 - pal_img = Image.new('P', (1, 1)) # image size doesn't matter it only holds the palette - pal_img.putpalette(pal_entries) - - results = [] - - for i in image: - i = 255. * i.cpu().numpy() - i = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) - - if palette == 'auto_best_k': - k = determine_best_k(i, max_k) - print(f'Auto number of colors: {k}') - elif palette == 'auto_fixed_k': - k = max_k - - i = i.quantize(colors=k, method=QUANTIZE_METHODS[method], kmeans=k, dither=0, palette=pal_img) - i = i.convert('RGB') - - results.append(np.array(i)) - - result = np.array(results).astype(np.float32) / 255.0 - return (torch.from_numpy(result), ) - - -NODE_CLASS_MAPPINGS = { - "ImagePalettize": ImagePalettize, -} - -NODE_DISPLAY_NAME_MAPPINGS = { - "ImagePalettize": "ImagePalettize" -} diff --git a/comfy_extras/mara/patterngen.py b/comfy_extras/mara/patterngen.py deleted file mode 100644 index 76d588af6..000000000 --- a/comfy_extras/mara/patterngen.py +++ /dev/null @@ -1,45 +0,0 @@ -# Mara Huldra 2023 -# SPDX-License-Identifier: MIT -''' -Simple image pattern generators. -''' -import os - -import numpy as np -from PIL import Image -import torch - -MAX_RESOLUTION = 8192 - -class ImageSolidColor: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "width": ("INT", {"default": 64, "min": 1, "max": MAX_RESOLUTION, "step": 1}), - "height": ("INT", {"default": 64, "min": 1, "max": MAX_RESOLUTION, "step": 1}), - "r": ("INT", {"default": 0, "min": 0, "max": 255, "step": 1}), - "g": ("INT", {"default": 0, "min": 0, "max": 255, "step": 1}), - "b": ("INT", {"default": 0, "min": 0, "max": 255, "step": 1}), - } - } - - RETURN_TYPES = ("IMAGE",) - FUNCTION = "render" - - CATEGORY = "image/pattern" - - def render(self, width, height, r, g, b): - color = torch.tensor([r, g, b]) / 255.0 - result = color.expand(1, height, width, 3) - return (result, ) - - -NODE_CLASS_MAPPINGS = { - "ImageSolidColor": ImageSolidColor, -} - -NODE_DISPLAY_NAME_MAPPINGS = { - "ImageSolidColor": "Solid Color", -} - diff --git a/comfy_extras/mara/remove_background.py b/comfy_extras/mara/remove_background.py deleted file mode 100644 index 6ba498c5b..000000000 --- a/comfy_extras/mara/remove_background.py +++ /dev/null @@ -1,119 +0,0 @@ -# Mara Huldra 2023 -# SPDX-License-Identifier: MIT -''' -Estimate what pixels belong to the background and perform a cut-out, using the 'rembg' models. -''' -import numpy as np -import rembg -import torch - - -MODELS = rembg.sessions.sessions_names - - -class ImageRemoveBackground: - '''Remove background from image (adds an alpha channel)''' - - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": ("IMAGE",), - "model": (MODELS, { - "default": "u2net", - }), - "alpha_matting": (["disabled", "enabled"], { - "default": "disabled", - }), - "am_foreground_thr": ("INT", { - "default": 240, - "min": 0, - "max": 255, - "step": 1, - }), - "am_background_thr": ("INT", { - "default": 10, - "min": 0, - "max": 255, - "step": 1, - }), - "am_erode_size": ("INT", { - "default": 10, - "min": 0, - "max": 255, - "step": 1, - }), - }, - } - - RETURN_TYPES = ("IMAGE",) - FUNCTION = "remove_background" - - CATEGORY = "image/postprocessing" - - def remove_background(self, image, model, alpha_matting, am_foreground_thr, am_background_thr, am_erode_size): - session = rembg.new_session(model) - results = [] - - for i in image: - i = 255. * i.cpu().numpy() - i = np.clip(i, 0, 255).astype(np.uint8) - i = rembg.remove(i, - alpha_matting=(alpha_matting == "enabled"), - alpha_matting_foreground_threshold=am_foreground_thr, - alpha_matting_background_threshold=am_background_thr, - alpha_matting_erode_size=am_erode_size, - session=session, - ) - results.append(i.astype(np.float32) / 255.0) - - s = torch.from_numpy(np.array(results)) - return (s,) - -class ImageEstimateForegroundMask: - ''' - Return a mask of which pixels are estimated to belong to foreground. - Only estimates the mask, does not perform cutout like - ImageRemoveBackground. - ''' - - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": ("IMAGE",), - "model": (MODELS, { - "default": "u2net", - }), - }, - } - - RETURN_TYPES = ("MASK",) - FUNCTION = "estimate_background" - - CATEGORY = "image/postprocessing" - - def estimate_background(self, image, model): - session = rembg.new_session(model) - results = [] - - for i in image: - i = 255. * i.cpu().numpy() - i = np.clip(i, 0, 255).astype(np.uint8) - i = rembg.remove(i, only_mask=True, session=session) - results.append(i.astype(np.float32) / 255.0) - - s = torch.from_numpy(np.array(results)) - print(s.shape) - return (s,) - - -NODE_CLASS_MAPPINGS = { - "ImageRemoveBackground": ImageRemoveBackground, - "ImageEstimateForegroundMask": ImageEstimateForegroundMask, -} - -NODE_DISPLAY_NAME_MAPPINGS = { - "ImageRemoveBackground": "Remove Background (rembg)", - "ImageEstimateForegroundMask": "Estimate Foreground (rembg)", -} diff --git a/comfy_extras/nodes/__init__.py b/comfy_extras/nodes/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/comfy_extras/nodes_canny.py b/comfy_extras/nodes/nodes_canny.py similarity index 100% rename from comfy_extras/nodes_canny.py rename to comfy_extras/nodes/nodes_canny.py diff --git a/comfy_extras/nodes_clip_sdxl.py b/comfy_extras/nodes/nodes_clip_sdxl.py similarity index 98% rename from comfy_extras/nodes_clip_sdxl.py rename to comfy_extras/nodes/nodes_clip_sdxl.py index dcf8859fa..94308edef 100644 --- a/comfy_extras/nodes_clip_sdxl.py +++ b/comfy_extras/nodes/nodes_clip_sdxl.py @@ -1,5 +1,6 @@ import torch -from nodes import MAX_RESOLUTION +from comfy.nodes.common import MAX_RESOLUTION + class CLIPTextEncodeSDXLRefiner: @classmethod diff --git a/comfy_extras/nodes_hypernetwork.py b/comfy_extras/nodes/nodes_hypernetwork.py similarity index 100% rename from comfy_extras/nodes_hypernetwork.py rename to comfy_extras/nodes/nodes_hypernetwork.py diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes/nodes_mask.py similarity index 99% rename from comfy_extras/nodes_mask.py rename to comfy_extras/nodes/nodes_mask.py index 15377af14..158a71c0d 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes/nodes_mask.py @@ -1,6 +1,7 @@ import torch -from nodes import MAX_RESOLUTION +from comfy.nodes.common import MAX_RESOLUTION + class LatentCompositeMasked: @classmethod diff --git a/comfy_extras/nodes_model_merging.py b/comfy_extras/nodes/nodes_model_merging.py similarity index 100% rename from comfy_extras/nodes_model_merging.py rename to comfy_extras/nodes/nodes_model_merging.py diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes/nodes_post_processing.py similarity index 100% rename from comfy_extras/nodes_post_processing.py rename to comfy_extras/nodes/nodes_post_processing.py diff --git a/comfy_extras/nodes_rebatch.py b/comfy_extras/nodes/nodes_rebatch.py similarity index 100% rename from comfy_extras/nodes_rebatch.py rename to comfy_extras/nodes/nodes_rebatch.py diff --git a/comfy_extras/nodes_tomesd.py b/comfy_extras/nodes/nodes_tomesd.py similarity index 100% rename from comfy_extras/nodes_tomesd.py rename to comfy_extras/nodes/nodes_tomesd.py diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes/nodes_upscale_model.py similarity index 100% rename from comfy_extras/nodes_upscale_model.py rename to comfy_extras/nodes/nodes_upscale_model.py diff --git a/execution.py b/execution.py index 5b577228c..0f73af9f7 100644 --- a/execution.py +++ b/execution.py @@ -3,7 +3,6 @@ from __future__ import annotations import asyncio import copy import datetime -import gc import heapq import threading import traceback @@ -14,7 +13,8 @@ import sys import torch -import nodes +from comfy.nodes.package import import_all_nodes_in_workspace +nodes = import_all_nodes_in_workspace() import comfy.model_management """ @@ -111,7 +111,7 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False): results = [] if input_is_list: if allow_interrupt: - nodes.before_node_execution() + comfy.model_management.throw_exception_if_processing_interrupted() results.append(getattr(obj, func)(**input_data_all)) elif max_len_input == 0: if allow_interrupt: @@ -120,7 +120,7 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False): else: for i in range(max_len_input): if allow_interrupt: - nodes.before_node_execution() + comfy.model_management.throw_exception_if_processing_interrupted() results.append(getattr(obj, func)(**slice_dict(input_data_all, i))) return results @@ -368,7 +368,7 @@ class PromptExecutor: del d def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): - nodes.interrupt_processing(False) + comfy.model_management.interrupt_current_processing(False) if "client_id" in extra_data: self.server.client_id = extra_data["client_id"] diff --git a/main.py b/main.py index dda776ac6..b7700b34a 100644 --- a/main.py +++ b/main.py @@ -69,7 +69,6 @@ import yaml import execution import server from server import BinaryEventTypes -from nodes import init_custom_nodes import comfy.model_management @@ -143,7 +142,6 @@ if __name__ == "__main__": for config_path in itertools.chain(*args.extra_model_paths_config): load_extra_path_config(config_path) - init_custom_nodes() server.add_routes() hijack_progress(server) diff --git a/nodes.py b/nodes.py index 4598455b0..965589994 100644 --- a/nodes.py +++ b/nodes.py @@ -14,9 +14,6 @@ from PIL.PngImagePlugin import PngInfo import numpy as np import safetensors.torch -sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy")) - - import comfy.diffusers_load import comfy.samplers import comfy.sample @@ -32,14 +29,8 @@ import importlib import folder_paths import latent_preview +from comfy.nodes.common import MAX_RESOLUTION -def before_node_execution(): - comfy.model_management.throw_exception_if_processing_interrupted() - -def interrupt_processing(value=True): - comfy.model_management.interrupt_current_processing(value) - -MAX_RESOLUTION=8192 class CLIPTextEncode: @classmethod @@ -1630,77 +1621,4 @@ NODE_DISPLAY_NAME_MAPPINGS = { # _for_testing "VAEDecodeTiled": "VAE Decode (Tiled)", "VAEEncodeTiled": "VAE Encode (Tiled)", -} - -def load_custom_node(module_path, ignore=set()): - module_name = os.path.basename(module_path) - if os.path.isfile(module_path): - sp = os.path.splitext(module_path) - module_name = sp[0] - try: - if os.path.isfile(module_path): - module_spec = importlib.util.spec_from_file_location(module_name, module_path) - else: - module_spec = importlib.util.spec_from_file_location(module_name, os.path.join(module_path, "__init__.py")) - module = importlib.util.module_from_spec(module_spec) - sys.modules[module_name] = module - module_spec.loader.exec_module(module) - if hasattr(module, "NODE_CLASS_MAPPINGS") and getattr(module, "NODE_CLASS_MAPPINGS") is not None: - for name in module.NODE_CLASS_MAPPINGS: - if name not in ignore: - NODE_CLASS_MAPPINGS[name] = module.NODE_CLASS_MAPPINGS[name] - if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None: - NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS) - return True - else: - print(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.") - return False - except Exception as e: - print(traceback.format_exc()) - print(f"Cannot import {module_path} module for custom nodes:", e) - return False - -def load_custom_nodes(): - base_node_names = set(NODE_CLASS_MAPPINGS.keys()) - node_paths = folder_paths.get_folder_paths("custom_nodes") - node_import_times = [] - for custom_node_path in node_paths: - possible_modules = os.listdir(custom_node_path) - if "__pycache__" in possible_modules: - possible_modules.remove("__pycache__") - - for possible_module in possible_modules: - module_path = os.path.join(custom_node_path, possible_module) - if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue - if module_path.endswith(".disabled"): continue - time_before = time.perf_counter() - success = load_custom_node(module_path, base_node_names) - node_import_times.append((time.perf_counter() - time_before, module_path, success)) - - if len(node_import_times) > 0: - print("\nImport times for custom nodes:") - for n in sorted(node_import_times): - if n[2]: - import_message = "" - else: - import_message = " (IMPORT FAILED)" - print("{:6.1f} seconds{}:".format(n[0], import_message), n[1]) - print() - -def init_custom_nodes(): - load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_hypernetwork.py")) - load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py")) - load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py")) - load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py")) - load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_rebatch.py")) - load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_model_merging.py")) - load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_tomesd.py")) - load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_clip_sdxl.py")) - load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_canny.py")) - load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras", "mara"), "k_centroid_downscale.py")) - load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras", "mara"), "make_model_tileable.py")) - load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras", "mara"), "mask_ops.py")) - load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras", "mara"), "palettize.py")) - load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras", "mara"), "patterngen.py")) - load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras", "mara"), "remove_background.py")) - load_custom_nodes() +} \ No newline at end of file diff --git a/server.py b/server.py index 99db7939d..631abbcc5 100644 --- a/server.py +++ b/server.py @@ -17,15 +17,16 @@ from aiohttp import web import execution import folder_paths -import nodes import mimetypes from comfy.digest import digest from comfy.cli_args import args import comfy.utils import comfy.model_management +from comfy.nodes.package import import_all_nodes_in_workspace from comfy.vendor.appdirs import user_data_dir +nodes = import_all_nodes_in_workspace() class BinaryEventTypes: PREVIEW_IMAGE = 1 @@ -488,7 +489,7 @@ class PromptServer(): @routes.post("/interrupt") async def post_interrupt(request): - nodes.interrupt_processing() + comfy.model_management.interrupt_current_processing() return web.Response(status=200) @routes.post("/history")