Add support for ComfyUI Manager

doctorpangloss 2025-07-15 14:06:29 -07:00
parent 4e91556820
commit bf3345e083
8 changed files with 160 additions and 47 deletions

View File

@@ -203,7 +203,9 @@ async def _start_comfyui(from_script_dir: Optional[Path] = None):
     # at this stage, it's safe to import nodes
     hook_breaker_ac10a0.save_functions()
-    server.nodes = get_nodes()
+    nodes_to_import = get_nodes()
+    logger.debug(f"Imported {len(nodes_to_import)} nodes")
+    server.nodes = nodes_to_import
     hook_breaker_ac10a0.restore_functions()
     # as a side effect, this also populates the nodes for execution

View File

@@ -1003,6 +1003,14 @@ class PromptServer(ExecutorToClientProgress):
         self.client_session = aiohttp.ClientSession(timeout=timeout)

     def add_routes(self):
+        # a mitigation for vanilla comfyui custom nodes that are stateful and add routes to a global
+        # prompt server instance. this is not a recommended pattern, but this mitigation is here to
+        # support it
+        from ..nodes.vanilla_node_importing import prompt_server_instance_routes
+        for route in prompt_server_instance_routes.routes:
+            self.routes.route(route.method, route.path)(route.func)
+        prompt_server_instance_routes.clear()
+
         self.user_manager.add_routes(self.routes)
         self.model_file_manager.add_routes(self.routes)
         # todo: needs to use module directories
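
For context on the mitigation above: some vanilla ComfyUI custom nodes register HTTP routes at import time against the global `PromptServer.instance`. A minimal sketch of that legacy pattern follows; the node name and route path are hypothetical, and the import reflects vanilla ComfyUI's layout rather than this repository's.

```python
# Hypothetical vanilla custom node: registers a route on the global server at import time.
# This is the stateful pattern that the add_routes() replay above exists to support.
from aiohttp import web

from server import PromptServer  # vanilla ComfyUI-style import; an assumption here


@PromptServer.instance.routes.get("/example_node/status")  # route path is hypothetical
async def example_status(request: web.Request) -> web.Response:
    return web.json_response({"ok": True})
```
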

View File

@@ -147,6 +147,11 @@ class SupportedExtensions:
         p: FolderNames = self.parent()
         p.remove_all_supported_extensions(self.folder_name)

+    def __len__(self):
+        p: FolderNames = self.parent()
+        return len(list(p.supported_extensions(self.folder_name)))
+
     __ior__ = _append_any
     add = _append_any
     update = _append_any

View File

@@ -1,6 +1,15 @@
+from typing import Callable, NamedTuple
+
+
+class RouteTuple(NamedTuple):
+    method: str
+    path: str
+    func: Callable
+
+
 class _RoutesWrapper:
     def __init__(self):
-        self.routes = []
+        self.routes: list[RouteTuple] = []

     def _decorator_factory(self, method):
         def decorator(path):
@@ -8,7 +17,8 @@ class _RoutesWrapper:
                from ..cmd.server import PromptServer
                if PromptServer.instance is not None and not isinstance(PromptServer.instance.routes, _RoutesWrapper):
                    getattr(PromptServer.instance.routes, method)(path)(func)
-                self.routes.append((method, path, func))
+                else:
+                    self.routes.append(RouteTuple(method, path, func))
                return func
            return wrapper
@@ -39,5 +49,8 @@ class _RoutesWrapper:
     def route(self, method, path):
         return self._decorator_factory(method.lower())(path)

+    def clear(self):
+        self.routes.clear()
+

 prompt_server_instance_routes = _RoutesWrapper()
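
The wrapper above queues routes registered before a real server exists and forwards them immediately once `PromptServer.instance` is live. A sketch of the intended flow, using only names that appear in this diff (the absolute import path of `prompt_server_instance_routes` is an assumption):

```python
# Sketch: registering a route before any PromptServer exists queues a RouteTuple.
from comfy.component_model.plugins import prompt_server_instance_routes  # assumed import path


@prompt_server_instance_routes.route("GET", "/manager/example")  # path is hypothetical
async def example(request):
    return {"ok": True}


# With PromptServer.instance still None, the decorator falls into the else branch and queues:
queued = prompt_server_instance_routes.routes[-1]
print(queued.method, queued.path)  # "get /manager/example" (method is lowercased by route())

# Later, PromptServer.add_routes() replays every queued RouteTuple onto the real aiohttp
# RouteTableDef and then calls prompt_server_instance_routes.clear().
```
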

View File

@@ -5,10 +5,11 @@ import torch
 from .autoencoder_dc import AutoencoderDC

+logger = logging.getLogger(__name__)
+
 try:
     import torchaudio  # pylint: disable=import-error
 except:
-    logging.warning("torchaudio missing, ACE model will be broken")
+    logger.warning("torchaudio missing, ACE model will be broken")
 import torchvision.transforms as transforms

 from .music_vocoder import ADaMoSHiFiGANV1
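
Both ACE files in this commit switch the torchaudio warning from the root logger to a module-level logger. A generic, self-contained sketch of that optional-dependency pattern (not tied to this repository):

```python
import logging

logger = logging.getLogger(__name__)  # module-level logger so the warning is attributable

try:
    import torchaudio  # optional dependency
except ImportError:  # the files above use a bare `except:`; ImportError is the narrow form
    torchaudio = None
    logger.warning("torchaudio missing, ACE model will be broken")


def load_waveform(path: str):
    """Fail at call time, not import time, when torchaudio is absent (sketch)."""
    if torchaudio is None:
        raise RuntimeError("torchaudio is required to load audio")
    return torchaudio.load(path)
```
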

View File

@@ -1,23 +1,28 @@
 # Original from: https://github.com/ace-step/ACE-Step/blob/main/music_dcae/music_log_mel.py
+import logging
+
 import torch
 import torch.nn as nn
 from torch import Tensor
-import logging

+logger = logging.getLogger(__name__)

 try:
     from torchaudio.transforms import MelScale  # pylint: disable=import-error
 except:
-    logging.warning("torchaudio missing, ACE model will be broken")
+    logger.warning("torchaudio missing, ACE model will be broken")

 from .... import model_management


 class LinearSpectrogram(nn.Module):
     def __init__(
         self,
         n_fft=2048,
         win_length=2048,
         hop_length=512,
         center=False,
         mode="pow2_sqrt",
     ):
         super().__init__()
@@ -64,15 +69,15 @@ class LinearSpectrogram(nn.Module):
 class LogMelSpectrogram(nn.Module):
     def __init__(
         self,
         sample_rate=44100,
         n_fft=2048,
         win_length=2048,
         hop_length=512,
         n_mels=128,
         center=False,
         f_min=0.0,
         f_max=None,
     ):
         super().__init__()

View File

@@ -1,5 +1,6 @@
 from __future__ import annotations

+import contextvars
 import importlib
 import logging
 import os
@@ -7,20 +8,25 @@ import sys
 import time
 import types
 from contextlib import contextmanager
+from functools import partial
 from os.path import join, basename, dirname, isdir, isfile, exists, abspath, split, splitext, realpath
 from typing import Dict, Iterable

 from . import base_nodes
 from .package_typing import ExportedNodes
+from ..cmd import folder_paths
 from ..component_model.plugins import prompt_server_instance_routes
+from ..distributed.executors import ContextVarExecutor

 logger = logging.getLogger(__name__)


 class StreamToLogger:
     """
     File-like stream object that redirects writes to a logger instance.
     This is used to capture print() statements from modules during import.
     """

     def __init__(self, logger: logging.Logger, log_level=logging.INFO):
         self.logger = logger
         self.log_level = log_level
@@ -45,9 +51,10 @@ def _vanilla_load_importing_execute_prestartup_script(node_paths: Iterable[str])
     def execute_script(script_path):
         module_name = splitext(script_path)[0]
         try:
-            spec = importlib.util.spec_from_file_location(module_name, script_path)
-            module = importlib.util.module_from_spec(spec)
-            spec.loader.exec_module(module)
+            with _stdout_intercept(module_name):
+                spec = importlib.util.spec_from_file_location(module_name, script_path)
+                module = importlib.util.module_from_spec(spec)
+                spec.loader.exec_module(module)
             return True
         except Exception as e:
             logger.error(f"Failed to execute startup-script: {script_path} / {e}")
@@ -68,22 +75,56 @@ def _vanilla_load_importing_execute_prestartup_script(node_paths: Iterable[str])
         script_path = join(module_path, "prestartup_script.py")
         if exists(script_path):
-            time_before = time.perf_counter()
-            success = execute_script(script_path)
-            node_prestartup_times.append((time.perf_counter() - time_before, module_path, success))
-
-    if len(node_prestartup_times) > 0:
-        logger.debug("\nPrestartup times for custom nodes:")
-        for n in sorted(node_prestartup_times):
-            if n[2]:
-                import_message = ""
-            else:
-                import_message = " (PRESTARTUP FAILED)"
-            logger.debug("{:6.1f} seconds{}:".format(n[0], import_message), n[1])
+            if "comfyui-manager" in module_path.lower():
+                os.environ['COMFYUI_PATH'] = str(folder_paths.base_path)
+                os.environ['COMFYUI_FOLDERS_BASE_PATH'] = str(folder_paths.models_dir)
+                # Monkey-patch ComfyUI-Manager's security check to prevent it from crashing on startup
+                # and its logging handler to prevent it from taking over logging.
+                glob_path = join(module_path, "glob")
+                glob_path_added = False
+                original_add_handler = logging.Logger.addHandler
+
+                def no_op_add_handler(self, handler):
+                    logger.info(f"Skipping addHandler for {type(handler).__name__} during ComfyUI-Manager prestartup.")
+
+                try:
+                    sys.path.insert(0, glob_path)
+                    glob_path_added = True
+                    # Patch security_check
+                    import security_check
+                    original_check = security_check.security_check
+
+                    def patched_security_check():
+                        try:
+                            return original_check()
+                        except Exception as e:
+                            logger.error(f"ComfyUI-Manager security_check failed but was caught gracefully: {e}", exc_info=e)
+
+                    security_check.security_check = patched_security_check
+                    logger.info("Patched ComfyUI-Manager's security_check to fail gracefully.")
+                    # Patch logging
+                    logging.Logger.addHandler = no_op_add_handler
+                    logger.info("Patched logging.Logger.addHandler to prevent ComfyUI-Manager from adding a logging handler.")
+                    time_before = time.perf_counter()
+                    success = execute_script(script_path)
+                    node_prestartup_times.append((time.perf_counter() - time_before, module_path, success))
+                except Exception as e:
+                    logger.error(f"Failed to patch and execute ComfyUI-Manager's prestartup script: {e}", exc_info=e)
+                finally:
+                    if glob_path_added and glob_path in sys.path:
+                        sys.path.remove(glob_path)
+                    logging.Logger.addHandler = original_add_handler
+            else:
+                time_before = time.perf_counter()
+                success = execute_script(script_path)
+                node_prestartup_times.append((time.perf_counter() - time_before, module_path, success))


 @contextmanager
 def _exec_mitigations(module: types.ModuleType, module_path: str) -> ExportedNodes:
-    if module.__name__ == "ComfyUI-Manager":
+    if module.__name__.lower() == "comfyui-manager":
         from ..cmd import folder_paths
         old_file = folder_paths.__file__
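
The ComfyUI-Manager branch above temporarily patches two globals around the prestartup script: `security_check.security_check` is wrapped so failures are logged rather than raised, and `logging.Logger.addHandler` is replaced with a no-op so the script cannot take over logging (only `addHandler` is restored in the `finally` block). A self-contained sketch of that wrap-log-and-restore idea, with hypothetical names:

```python
import logging
from contextlib import contextmanager

logger = logging.getLogger(__name__)


@contextmanager
def swallow_exceptions(module, attr_name):
    """Temporarily wrap module.<attr_name> so exceptions are logged, not raised (sketch)."""
    original = getattr(module, attr_name)

    def patched(*args, **kwargs):
        try:
            return original(*args, **kwargs)
        except Exception as e:
            logger.error("%s.%s failed but was caught: %s", module.__name__, attr_name, e, exc_info=e)
            return None

    setattr(module, attr_name, patched)
    try:
        yield
    finally:
        # Unlike the diff, which leaves security_check patched, this sketch always restores.
        setattr(module, attr_name, original)


# Hypothetical usage:
# with swallow_exceptions(security_check_module, "security_check"):
#     run_prestartup_script()
```
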
@@ -93,21 +134,29 @@ def _exec_mitigations(module: types.ModuleType, module_path: str) -> ExportedNodes:
             folder_paths.__file__ = new_path
             # mitigate JS copy
             sys.modules['nodes'].EXTENSION_WEB_DIRS = {}
             yield ExportedNodes()
         finally:
             folder_paths.__file__ = old_file
             # todo: mitigate "/manager/reboot"
             # todo: mitigate process_wrap
+            # todo: unfortunately, we shouldn't restore the patches here, they will have to be applied forever.
+            # concurrent.futures.ThreadPoolExecutor = _ThreadPoolExecutor
+            # threading.Thread.start = original_thread_start
     else:
-        # redirect stdout to the module's logger during import
-        original_stdout = sys.stdout
-        module_logger = logging.getLogger(module.__name__)
-        sys.stdout = StreamToLogger(module_logger, logging.INFO)
-        try:
-            yield ExportedNodes()
-        finally:
-            # Restore original stdout to ensure this change is temporary
-            sys.stdout = original_stdout
+        yield ExportedNodes()
+
+
+@contextmanager
+def _stdout_intercept(name: str):
+    original_stdout = sys.stdout
+    try:
+        module_logger = logging.getLogger(name)
+        sys.stdout = StreamToLogger(module_logger, logging.INFO)
+        yield
+    finally:
+        sys.stdout = original_stdout


 def _vanilla_load_custom_nodes_1(module_path, ignore=set()) -> ExportedNodes:
@@ -127,7 +176,7 @@ def _vanilla_load_custom_nodes_1(module_path, ignore=set()) -> ExportedNodes:
         module = importlib.util.module_from_spec(module_spec)
         sys.modules[module_name] = module

-        with _exec_mitigations(module, module_path) as mitigated_exported_nodes:
+        with _exec_mitigations(module, module_path) as mitigated_exported_nodes, _stdout_intercept(module_name):
             module_spec.loader.exec_module(module)
             exported_nodes.update(mitigated_exported_nodes)
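
`_stdout_intercept` (introduced in the previous hunk) wraps module execution so that stray `print()` calls from custom nodes land in that module's logger rather than on the raw console. A standalone sketch of the same idea, independent of this repository:

```python
import logging
import sys
from contextlib import contextmanager


class PrintToLogger:
    """File-like object that forwards write() calls to a logger (sketch)."""

    def __init__(self, target_logger: logging.Logger, level: int = logging.INFO):
        self.target_logger = target_logger
        self.level = level

    def write(self, buf: str) -> None:
        for line in buf.rstrip().splitlines():
            self.target_logger.log(self.level, line.rstrip())

    def flush(self) -> None:
        pass  # nothing buffered; the logging handlers take care of output


@contextmanager
def stdout_to_logger(name: str):
    original_stdout = sys.stdout
    sys.stdout = PrintToLogger(logging.getLogger(name))
    try:
        yield
    finally:
        sys.stdout = original_stdout  # always restore, even if the wrapped import raised


# Example: the print() below is emitted as an INFO record on the "my_custom_node" logger
# (logging's default handlers write to stderr, so there is no recursion through stdout).
logging.basicConfig(level=logging.INFO)
with stdout_to_logger("my_custom_node"):
    print("this line becomes a log record instead of raw stdout")
```
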
@@ -197,6 +246,8 @@ def mitigated_import_of_vanilla_custom_nodes() -> ExportedNodes:
     from ..cmd import cuda_malloc, folder_paths, latent_preview, protocol
     from .. import node_helpers
     from .. import __version__
+    import concurrent.futures
+    import threading
     for module in (cuda_malloc, folder_paths, latent_preview, node_helpers, protocol):
         module_short_name = module.__name__.split(".")[-1]
         sys.modules[module_short_name] = module
@@ -237,6 +288,22 @@ def mitigated_import_of_vanilla_custom_nodes() -> ExportedNodes:
     node_paths = frozenset(abspath(custom_node_path) for custom_node_path in node_paths)

+    _ThreadPoolExecutor = concurrent.futures.ThreadPoolExecutor
+    original_thread_start = threading.Thread.start
+    concurrent.futures.ThreadPoolExecutor = ContextVarExecutor
+
+    # mitigate missing folder names and paths context
+    def patched_start(self, *args, **kwargs):
+        if not hasattr(self.run, '__wrapped_by_context__'):
+            ctx = contextvars.copy_context()
+            self.run = partial(ctx.run, self.run)
+            setattr(self.run, '__wrapped_by_context__', True)
+        original_thread_start(self, *args, **kwargs)
+
+    if not getattr(threading.Thread.start, '__is_patched_by_us', False):
+        threading.Thread.start = patched_start
+        setattr(threading.Thread.start, '__is_patched_by_us', True)
+        logger.info("Patched `threading.Thread.start` to propagate contextvars.")
+
     _vanilla_load_importing_execute_prestartup_script(node_paths)
     vanilla_custom_nodes = _vanilla_load_custom_nodes_2(node_paths)
     return vanilla_custom_nodes
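
The `Thread.start` patch above exists because a thread started by a custom node does not inherit `contextvars` values set in the importing thread (a new thread begins with an empty context), so per-context folder and path configuration would silently fall back to defaults. A minimal demonstration of the problem and of the `copy_context()` fix, using a hypothetical context variable:

```python
import contextvars
import threading
from functools import partial

# Hypothetical variable standing in for the folder-paths context mentioned in the diff.
current_models_dir = contextvars.ContextVar("current_models_dir", default="<default>")


def worker() -> None:
    print("worker sees:", current_models_dir.get())


current_models_dir.set("/data/comfyui/models")

# Unpatched behaviour: the new thread starts with an empty context and sees the default.
t1 = threading.Thread(target=worker)
t1.start()
t1.join()  # prints "worker sees: <default>"

# What the patched Thread.start effectively does: run the target inside a copy of the
# caller's context so ContextVar values propagate into the thread.
ctx = contextvars.copy_context()
t2 = threading.Thread(target=partial(ctx.run, worker))
t2.start()
t2.join()  # prints "worker sees: /data/comfyui/models"
```
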

View File

@@ -57,7 +57,7 @@ dependencies = [
     "pyjwt[crypto]",
     "kornia>=0.7.0",
     "mpmath>=1.0,!=1.4.0a0",
-    "huggingface_hub[hf_transfer]",
+    "huggingface_hub[hf_transfer]>0.20",
     "lazy-object-proxy",
     "lazy_loader>=0.3",
     "can_ada",
@@ -134,6 +134,18 @@ dev = [
     "astroid",
 ]

+comfyui_manager = [
+    "GitPython",
+    "PyGithub",
+    "matrix-client==0.4.0",
+    "rich",
+    "typing-extensions",
+    "toml",
+    "uv",
+    "chardet",
+    "pip",
+]
+
 [project.optional-dependencies]
 cpu = [
     "torch",