Add support for ComfyUI Manager

doctorpangloss 2025-07-15 14:06:29 -07:00
parent 4e91556820
commit bf3345e083
8 changed files with 160 additions and 47 deletions

View File

@ -203,7 +203,9 @@ async def _start_comfyui(from_script_dir: Optional[Path] = None):
# at this stage, it's safe to import nodes
hook_breaker_ac10a0.save_functions()
server.nodes = get_nodes()
nodes_to_import = get_nodes()
logger.debug(f"Imported {len(nodes_to_import)} nodes")
server.nodes = nodes_to_import
hook_breaker_ac10a0.restore_functions()
# as a side effect, this also populates the nodes for execution

View File

@ -1003,6 +1003,14 @@ class PromptServer(ExecutorToClientProgress):
self.client_session = aiohttp.ClientSession(timeout=timeout)
def add_routes(self):
# a mitigation for vanilla comfyui custom nodes that are stateful and add routes to the global
# prompt server instance. registering routes this way is not a recommended pattern, but this
# mitigation is here to support it
from ..nodes.vanilla_node_importing import prompt_server_instance_routes
for route in prompt_server_instance_routes.routes:
self.routes.route(route.method, route.path)(route.func)
prompt_server_instance_routes.clear()
self.user_manager.add_routes(self.routes)
self.model_file_manager.add_routes(self.routes)
# todo: needs to use module directories
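For reference, this is the vanilla registration pattern the replay above exists to support; the node module below is a hypothetical sketch, not part of this commit:

```python
# Hypothetical vanilla custom node (illustration only). At import time,
# PromptServer.instance.routes may still be the _RoutesWrapper, which records the
# call; add_routes() above later replays it onto the real aiohttp RouteTableDef.
from aiohttp import web
from server import PromptServer  # vanilla import path, assumed to be aliased by the importer

@PromptServer.instance.routes.get("/my_node/status")
async def my_node_status(request):
    return web.json_response({"ok": True})
```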

View File

@ -147,6 +147,11 @@ class SupportedExtensions:
p: FolderNames = self.parent()
p.remove_all_supported_extensions(self.folder_name)
def __len__(self):
p: FolderNames = self.parent()
return len(list(p.supported_extensions(self.folder_name)))
__ior__ = _append_any
add = _append_any
update = _append_any
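The set-style aliases above keep vanilla code working when it mutates the supported-extensions entry directly; a rough usage sketch, assuming the vanilla folder_names_and_paths lookup shape maps onto this view:

```python
# Sketch only; the folder name and lookup path are assumptions, not part of this commit.
import folder_paths  # vanilla short name, aliased to this package's module by the importer

extensions = folder_paths.folder_names_and_paths["checkpoints"][1]
extensions.add(".gguf")       # add     -> _append_any
extensions.update({".bin"})   # update  -> _append_any
extensions |= {".onnx"}       # __ior__ -> _append_any
print(len(extensions))        # __len__ reports how many extensions are currently registered
```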

View File

@ -1,6 +1,15 @@
from typing import Callable, NamedTuple
class RouteTuple(NamedTuple):
method: str
path: str
func: Callable
class _RoutesWrapper:
def __init__(self):
self.routes = []
self.routes: list[RouteTuple] = []
def _decorator_factory(self, method):
def decorator(path):
@ -8,7 +17,8 @@ class _RoutesWrapper:
from ..cmd.server import PromptServer
if PromptServer.instance is not None and not isinstance(PromptServer.instance.routes, _RoutesWrapper):
getattr(PromptServer.instance.routes, method)(path)(func)
self.routes.append((method, path, func))
else:
self.routes.append(RouteTuple(method, path, func))
return func
return wrapper
@ -39,5 +49,8 @@ class _RoutesWrapper:
def route(self, method, path):
return self._decorator_factory(method.lower())(path)
def clear(self):
self.routes.clear()
prompt_server_instance_routes = _RoutesWrapper()
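A record-then-replay sketch of how the wrapper behaves before a PromptServer exists; the handler and path are illustrative, the package root is assumed, and the aiohttp-style `get` helper is assumed from the decorator factory above:

```python
from comfy.component_model.plugins import prompt_server_instance_routes as routes  # package root assumed

@routes.get("/manager/example")   # no live PromptServer yet, so the call is recorded
async def example(request):
    ...

# routes.routes now holds RouteTuple(method='get', path='/manager/example', func=example);
# PromptServer.add_routes() replays it onto the real RouteTableDef and then calls clear()
# so nothing is registered twice.
```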

View File

@ -5,10 +5,11 @@ import torch
from .autoencoder_dc import AutoencoderDC
logger = logging.getLogger(__name__)
try:
import torchaudio # pylint: disable=import-error
except:
logging.warning("torchaudio missing, ACE model will be broken")
logger.warning("torchaudio missing, ACE model will be broken")
import torchvision.transforms as transforms
from .music_vocoder import ADaMoSHiFiGANV1

View File

@ -1,23 +1,28 @@
# Original from: https://github.com/ace-step/ACE-Step/blob/main/music_dcae/music_log_mel.py
import logging
import torch
import torch.nn as nn
from torch import Tensor
import logging
logger = logging.getLogger(__name__)
try:
from torchaudio.transforms import MelScale # pylint: disable=import-error
except:
logging.warning("torchaudio missing, ACE model will be broken")
logger.warning("torchaudio missing, ACE model will be broken")
from .... import model_management
class LinearSpectrogram(nn.Module):
def __init__(
self,
n_fft=2048,
win_length=2048,
hop_length=512,
center=False,
mode="pow2_sqrt",
self,
n_fft=2048,
win_length=2048,
hop_length=512,
center=False,
mode="pow2_sqrt",
):
super().__init__()
@ -64,15 +69,15 @@ class LinearSpectrogram(nn.Module):
class LogMelSpectrogram(nn.Module):
def __init__(
self,
sample_rate=44100,
n_fft=2048,
win_length=2048,
hop_length=512,
n_mels=128,
center=False,
f_min=0.0,
f_max=None,
self,
sample_rate=44100,
n_fft=2048,
win_length=2048,
hop_length=512,
n_mels=128,
center=False,
f_min=0.0,
f_max=None,
):
super().__init__()

View File

@ -1,5 +1,6 @@
from __future__ import annotations
import contextvars
import importlib
import logging
import os
@ -7,20 +8,25 @@ import sys
import time
import types
from contextlib import contextmanager
from functools import partial
from os.path import join, basename, dirname, isdir, isfile, exists, abspath, split, splitext, realpath
from typing import Dict, Iterable
from . import base_nodes
from .package_typing import ExportedNodes
from ..cmd import folder_paths
from ..component_model.plugins import prompt_server_instance_routes
from ..distributed.executors import ContextVarExecutor
logger = logging.getLogger(__name__)
class StreamToLogger:
"""
File-like stream object that redirects writes to a logger instance.
This is used to capture print() statements from modules during import.
"""
def __init__(self, logger: logging.Logger, log_level=logging.INFO):
self.logger = logger
self.log_level = log_level
@ -45,9 +51,10 @@ def _vanilla_load_importing_execute_prestartup_script(node_paths: Iterable[str])
def execute_script(script_path):
module_name = splitext(script_path)[0]
try:
spec = importlib.util.spec_from_file_location(module_name, script_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
with _stdout_intercept(module_name):
spec = importlib.util.spec_from_file_location(module_name, script_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return True
except Exception as e:
logger.error(f"Failed to execute startup-script: {script_path} / {e}")
@ -68,22 +75,56 @@ def _vanilla_load_importing_execute_prestartup_script(node_paths: Iterable[str])
script_path = join(module_path, "prestartup_script.py")
if exists(script_path):
time_before = time.perf_counter()
success = execute_script(script_path)
node_prestartup_times.append((time.perf_counter() - time_before, module_path, success))
if len(node_prestartup_times) > 0:
logger.debug("\nPrestartup times for custom nodes:")
for n in sorted(node_prestartup_times):
if n[2]:
import_message = ""
else:
import_message = " (PRESTARTUP FAILED)"
logger.debug("{:6.1f} seconds{}:".format(n[0], import_message), n[1])
if "comfyui-manager" in module_path.lower():
os.environ['COMFYUI_PATH'] = str(folder_paths.base_path)
os.environ['COMFYUI_FOLDERS_BASE_PATH'] = str(folder_paths.models_dir)
# Monkey-patch ComfyUI-Manager's security check to prevent it from crashing on startup
# and its logging handler to prevent it from taking over logging.
glob_path = join(module_path, "glob")
glob_path_added = False
original_add_handler = logging.Logger.addHandler
def no_op_add_handler(self, handler):
logger.info(f"Skipping addHandler for {type(handler).__name__} during ComfyUI-Manager prestartup.")
try:
sys.path.insert(0, glob_path)
glob_path_added = True
# Patch security_check
import security_check
original_check = security_check.security_check
def patched_security_check():
try:
return original_check()
except Exception as e:
logger.error(f"ComfyUI-Manager security_check failed but was caught gracefully: {e}", exc_info=e)
security_check.security_check = patched_security_check
logger.info("Patched ComfyUI-Manager's security_check to fail gracefully.")
# Patch logging
logging.Logger.addHandler = no_op_add_handler
logger.info("Patched logging.Logger.addHandler to prevent ComfyUI-Manager from adding a logging handler.")
time_before = time.perf_counter()
success = execute_script(script_path)
node_prestartup_times.append((time.perf_counter() - time_before, module_path, success))
except Exception as e:
logger.error(f"Failed to patch and execute ComfyUI-Manager's prestartup script: {e}", exc_info=e)
finally:
if glob_path_added and glob_path in sys.path:
sys.path.remove(glob_path)
logging.Logger.addHandler = original_add_handler
else:
time_before = time.perf_counter()
success = execute_script(script_path)
node_prestartup_times.append((time.perf_counter() - time_before, module_path, success))
@contextmanager
def _exec_mitigations(module: types.ModuleType, module_path: str) -> ExportedNodes:
if module.__name__ == "ComfyUI-Manager":
if module.__name__.lower() == "comfyui-manager":
from ..cmd import folder_paths
old_file = folder_paths.__file__
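The addHandler suppression in the prestartup handling above follows a standard patch-and-restore shape; a minimal standalone sketch of that shape (not the commit's code):

```python
import contextlib
import logging

@contextlib.contextmanager
def suppress_add_handler():
    """Temporarily turn logging.Logger.addHandler into a no-op, then restore it."""
    original = logging.Logger.addHandler

    def no_op(self, handler):
        logging.getLogger(__name__).debug("Skipping addHandler(%s)", type(handler).__name__)

    logging.Logger.addHandler = no_op
    try:
        yield
    finally:
        logging.Logger.addHandler = original

# usage sketch:
# with suppress_add_handler():
#     import prestartup_script  # hypothetical third-party module
```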
@ -93,21 +134,29 @@ def _exec_mitigations(module: types.ModuleType, module_path: str) -> ExportedNod
folder_paths.__file__ = new_path
# mitigate JS copy
sys.modules['nodes'].EXTENSION_WEB_DIRS = {}
yield ExportedNodes()
finally:
folder_paths.__file__ = old_file
# todo: mitigate "/manager/reboot"
# todo: mitigate process_wrap
# todo: unfortunately, these patches cannot be restored here; they have to remain applied for the lifetime of the process.
# concurrent.futures.ThreadPoolExecutor = _ThreadPoolExecutor
# threading.Thread.start = original_thread_start
else:
# redirect stdout to the module's logger during import
original_stdout = sys.stdout
module_logger = logging.getLogger(module.__name__)
yield ExportedNodes()
@contextmanager
def _stdout_intercept(name: str):
original_stdout = sys.stdout
try:
module_logger = logging.getLogger(name)
sys.stdout = StreamToLogger(module_logger, logging.INFO)
try:
yield ExportedNodes()
finally:
# Restore original stdout to ensure this change is temporary
sys.stdout = original_stdout
yield
finally:
sys.stdout = original_stdout
def _vanilla_load_custom_nodes_1(module_path, ignore=set()) -> ExportedNodes:
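The StreamToLogger class used by _stdout_intercept is only partially shown in the hunks above; a minimal file-like sketch consistent with that usage, offered as an assumption rather than the commit's exact implementation:

```python
import logging

class StreamToLoggerSketch:
    """File-like object that forwards each written line to a logger (stand-in for StreamToLogger)."""

    def __init__(self, logger: logging.Logger, log_level=logging.INFO):
        self.logger = logger
        self.log_level = log_level

    def write(self, buf: str) -> None:
        # print() writes text plus a trailing newline; log each non-empty line separately
        for line in buf.rstrip().splitlines():
            self.logger.log(self.log_level, line.rstrip())

    def flush(self) -> None:
        # nothing is buffered locally, so flush is a no-op
        pass
```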
@ -127,7 +176,7 @@ def _vanilla_load_custom_nodes_1(module_path, ignore=set()) -> ExportedNodes:
module = importlib.util.module_from_spec(module_spec)
sys.modules[module_name] = module
with _exec_mitigations(module, module_path) as mitigated_exported_nodes:
with _exec_mitigations(module, module_path) as mitigated_exported_nodes, _stdout_intercept(module_name):
module_spec.loader.exec_module(module)
exported_nodes.update(mitigated_exported_nodes)
@ -197,6 +246,8 @@ def mitigated_import_of_vanilla_custom_nodes() -> ExportedNodes:
from ..cmd import cuda_malloc, folder_paths, latent_preview, protocol
from .. import node_helpers
from .. import __version__
import concurrent.futures
import threading
for module in (cuda_malloc, folder_paths, latent_preview, node_helpers, protocol):
module_short_name = module.__name__.split(".")[-1]
sys.modules[module_short_name] = module
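The aliasing loop above works because Python consults sys.modules before searching sys.path; a self-contained sketch of the trick with a stand-in module (names are illustrative):

```python
import sys
import types

# Stand-in for a packaged module such as the folder_paths module aliased above (name is illustrative).
packaged = types.ModuleType("example_package.cmd.folder_paths")
packaged.models_dir = "/models"

sys.modules["folder_paths"] = packaged  # register it under the vanilla short name

import folder_paths  # a vanilla custom node's import now resolves to the alias
assert folder_paths is packaged
```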
@ -237,6 +288,22 @@ def mitigated_import_of_vanilla_custom_nodes() -> ExportedNodes:
node_paths = frozenset(abspath(custom_node_path) for custom_node_path in node_paths)
_ThreadPoolExecutor = concurrent.futures.ThreadPoolExecutor
original_thread_start = threading.Thread.start
concurrent.futures.ThreadPoolExecutor = ContextVarExecutor
# mitigate missing folder names and paths context
def patched_start(self, *args, **kwargs):
if not hasattr(self.run, '__wrapped_by_context__'):
ctx = contextvars.copy_context()
self.run = partial(ctx.run, self.run)
setattr(self.run, '__wrapped_by_context__', True)
original_thread_start(self, *args, **kwargs)
if not getattr(threading.Thread.start, '__is_patched_by_us', False):
threading.Thread.start = patched_start
setattr(threading.Thread.start, '__is_patched_by_us', True)
logger.info("Patched `threading.Thread.start` to propagate contextvars.")
_vanilla_load_importing_execute_prestartup_script(node_paths)
vanilla_custom_nodes = _vanilla_load_custom_nodes_2(node_paths)
return vanilla_custom_nodes
return vanilla_custom_nodes
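The Thread.start patch is needed because a stock thread runs in a fresh contextvars Context and does not see values set by the caller; a standalone sketch of the behavior being fixed (variable names are illustrative):

```python
import contextvars
import threading

request_folder = contextvars.ContextVar("request_folder", default=None)
request_folder.set("/srv/models")

def worker(out: list):
    # without the patch, a plain Thread sees the default (None) here
    out.append(request_folder.get())

out = []
t = threading.Thread(target=worker, args=(out,))
t.start()
t.join()
print(out)  # [None] with stock threading; ['/srv/models'] once start() wraps run in copy_context()
```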

View File

@ -57,7 +57,7 @@ dependencies = [
"pyjwt[crypto]",
"kornia>=0.7.0",
"mpmath>=1.0,!=1.4.0a0",
"huggingface_hub[hf_transfer]",
"huggingface_hub[hf_transfer]>0.20",
"lazy-object-proxy",
"lazy_loader>=0.3",
"can_ada",
@ -134,6 +134,18 @@ dev = [
"astroid",
]
comfyui_manager = [
"GitPython",
"PyGithub",
"matrix-client==0.4.0",
"rich",
"typing-extensions",
"toml",
"uv",
"chardet",
"pip",
]
[project.optional-dependencies]
cpu = [
"torch",