mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 06:10:50 +08:00
Add support for ComfyUI Manager
This commit is contained in:
parent
4e91556820
commit
bf3345e083
@ -203,7 +203,9 @@ async def _start_comfyui(from_script_dir: Optional[Path] = None):
|
||||
|
||||
# at this stage, it's safe to import nodes
|
||||
hook_breaker_ac10a0.save_functions()
|
||||
server.nodes = get_nodes()
|
||||
nodes_to_import = get_nodes()
|
||||
logger.debug(f"Imported {len(nodes_to_import)} nodes")
|
||||
server.nodes = nodes_to_import
|
||||
hook_breaker_ac10a0.restore_functions()
|
||||
# as a side effect, this also populates the nodes for execution
|
||||
|
||||
|
||||
@ -1003,6 +1003,14 @@ class PromptServer(ExecutorToClientProgress):
|
||||
self.client_session = aiohttp.ClientSession(timeout=timeout)
|
||||
|
||||
def add_routes(self):
|
||||
# a mitigation for vanilla comfyui custom nodes that are stateful and add routes to a global
|
||||
# prompt server instance. this is not a recommended pattern, but this mitigation is here to
|
||||
# support it
|
||||
from ..nodes.vanilla_node_importing import prompt_server_instance_routes
|
||||
for route in prompt_server_instance_routes.routes:
|
||||
self.routes.route(route.method, route.path)(route.handler)
|
||||
prompt_server_instance_routes.clear()
|
||||
|
||||
self.user_manager.add_routes(self.routes)
|
||||
self.model_file_manager.add_routes(self.routes)
|
||||
# todo: needs to use module directories
|
||||
|
||||
@ -147,6 +147,11 @@ class SupportedExtensions:
|
||||
p: FolderNames = self.parent()
|
||||
p.remove_all_supported_extensions(self.folder_name)
|
||||
|
||||
def __len__(self):
|
||||
p: FolderNames = self.parent()
|
||||
return len(list(p.supported_extensions(self.folder_name)))
|
||||
|
||||
|
||||
__ior__ = _append_any
|
||||
add = _append_any
|
||||
update = _append_any
|
||||
|
||||
@ -1,6 +1,15 @@
|
||||
from typing import Callable, NamedTuple
|
||||
|
||||
|
||||
class RouteTuple(NamedTuple):
|
||||
method: str
|
||||
path: str
|
||||
func: Callable
|
||||
|
||||
|
||||
class _RoutesWrapper:
|
||||
def __init__(self):
|
||||
self.routes = []
|
||||
self.routes: list[RouteTuple] = []
|
||||
|
||||
def _decorator_factory(self, method):
|
||||
def decorator(path):
|
||||
@ -8,7 +17,8 @@ class _RoutesWrapper:
|
||||
from ..cmd.server import PromptServer
|
||||
if PromptServer.instance is not None and not isinstance(PromptServer.instance.routes, _RoutesWrapper):
|
||||
getattr(PromptServer.instance.routes, method)(path)(func)
|
||||
self.routes.append((method, path, func))
|
||||
else:
|
||||
self.routes.append(RouteTuple(method, path, func))
|
||||
return func
|
||||
|
||||
return wrapper
|
||||
@ -39,5 +49,8 @@ class _RoutesWrapper:
|
||||
def route(self, method, path):
|
||||
return self._decorator_factory(method.lower())(path)
|
||||
|
||||
def clear(self):
|
||||
self.routes.clear()
|
||||
|
||||
|
||||
prompt_server_instance_routes = _RoutesWrapper()
|
||||
|
||||
@ -5,10 +5,11 @@ import torch
|
||||
|
||||
from .autoencoder_dc import AutoencoderDC
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
try:
|
||||
import torchaudio # pylint: disable=import-error
|
||||
except:
|
||||
logging.warning("torchaudio missing, ACE model will be broken")
|
||||
logger.warning("torchaudio missing, ACE model will be broken")
|
||||
|
||||
import torchvision.transforms as transforms
|
||||
from .music_vocoder import ADaMoSHiFiGANV1
|
||||
|
||||
@ -1,23 +1,28 @@
|
||||
# Original from: https://github.com/ace-step/ACE-Step/blob/main/music_dcae/music_log_mel.py
|
||||
import logging
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
from torchaudio.transforms import MelScale # pylint: disable=import-error
|
||||
except:
|
||||
logging.warning("torchaudio missing, ACE model will be broken")
|
||||
logger.warning("torchaudio missing, ACE model will be broken")
|
||||
|
||||
from .... import model_management
|
||||
|
||||
|
||||
class LinearSpectrogram(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
n_fft=2048,
|
||||
win_length=2048,
|
||||
hop_length=512,
|
||||
center=False,
|
||||
mode="pow2_sqrt",
|
||||
self,
|
||||
n_fft=2048,
|
||||
win_length=2048,
|
||||
hop_length=512,
|
||||
center=False,
|
||||
mode="pow2_sqrt",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@ -64,15 +69,15 @@ class LinearSpectrogram(nn.Module):
|
||||
|
||||
class LogMelSpectrogram(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
sample_rate=44100,
|
||||
n_fft=2048,
|
||||
win_length=2048,
|
||||
hop_length=512,
|
||||
n_mels=128,
|
||||
center=False,
|
||||
f_min=0.0,
|
||||
f_max=None,
|
||||
self,
|
||||
sample_rate=44100,
|
||||
n_fft=2048,
|
||||
win_length=2048,
|
||||
hop_length=512,
|
||||
n_mels=128,
|
||||
center=False,
|
||||
f_min=0.0,
|
||||
f_max=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextvars
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
@ -7,20 +8,25 @@ import sys
|
||||
import time
|
||||
import types
|
||||
from contextlib import contextmanager
|
||||
from functools import partial
|
||||
from os.path import join, basename, dirname, isdir, isfile, exists, abspath, split, splitext, realpath
|
||||
from typing import Dict, Iterable
|
||||
|
||||
from . import base_nodes
|
||||
from .package_typing import ExportedNodes
|
||||
from ..cmd import folder_paths
|
||||
from ..component_model.plugins import prompt_server_instance_routes
|
||||
from ..distributed.executors import ContextVarExecutor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StreamToLogger:
|
||||
"""
|
||||
File-like stream object that redirects writes to a logger instance.
|
||||
This is used to capture print() statements from modules during import.
|
||||
"""
|
||||
|
||||
def __init__(self, logger: logging.Logger, log_level=logging.INFO):
|
||||
self.logger = logger
|
||||
self.log_level = log_level
|
||||
@ -45,9 +51,10 @@ def _vanilla_load_importing_execute_prestartup_script(node_paths: Iterable[str])
|
||||
def execute_script(script_path):
|
||||
module_name = splitext(script_path)[0]
|
||||
try:
|
||||
spec = importlib.util.spec_from_file_location(module_name, script_path)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
with _stdout_intercept(module_name):
|
||||
spec = importlib.util.spec_from_file_location(module_name, script_path)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to execute startup-script: {script_path} / {e}")
|
||||
@ -68,22 +75,56 @@ def _vanilla_load_importing_execute_prestartup_script(node_paths: Iterable[str])
|
||||
|
||||
script_path = join(module_path, "prestartup_script.py")
|
||||
if exists(script_path):
|
||||
time_before = time.perf_counter()
|
||||
success = execute_script(script_path)
|
||||
node_prestartup_times.append((time.perf_counter() - time_before, module_path, success))
|
||||
if len(node_prestartup_times) > 0:
|
||||
logger.debug("\nPrestartup times for custom nodes:")
|
||||
for n in sorted(node_prestartup_times):
|
||||
if n[2]:
|
||||
import_message = ""
|
||||
else:
|
||||
import_message = " (PRESTARTUP FAILED)"
|
||||
logger.debug("{:6.1f} seconds{}:".format(n[0], import_message), n[1])
|
||||
if "comfyui-manager" in module_path.lower():
|
||||
os.environ['COMFYUI_PATH'] = str(folder_paths.base_path)
|
||||
os.environ['COMFYUI_FOLDERS_BASE_PATH'] = str(folder_paths.models_dir)
|
||||
# Monkey-patch ComfyUI-Manager's security check to prevent it from crashing on startup
|
||||
# and its logging handler to prevent it from taking over logging.
|
||||
glob_path = join(module_path, "glob")
|
||||
glob_path_added = False
|
||||
original_add_handler = logging.Logger.addHandler
|
||||
|
||||
def no_op_add_handler(self, handler):
|
||||
logger.info(f"Skipping addHandler for {type(handler).__name__} during ComfyUI-Manager prestartup.")
|
||||
|
||||
try:
|
||||
sys.path.insert(0, glob_path)
|
||||
glob_path_added = True
|
||||
# Patch security_check
|
||||
import security_check
|
||||
original_check = security_check.security_check
|
||||
|
||||
def patched_security_check():
|
||||
try:
|
||||
return original_check()
|
||||
except Exception as e:
|
||||
logger.error(f"ComfyUI-Manager security_check failed but was caught gracefully: {e}", exc_info=e)
|
||||
|
||||
security_check.security_check = patched_security_check
|
||||
logger.info("Patched ComfyUI-Manager's security_check to fail gracefully.")
|
||||
|
||||
# Patch logging
|
||||
logging.Logger.addHandler = no_op_add_handler
|
||||
logger.info("Patched logging.Logger.addHandler to prevent ComfyUI-Manager from adding a logging handler.")
|
||||
|
||||
time_before = time.perf_counter()
|
||||
success = execute_script(script_path)
|
||||
node_prestartup_times.append((time.perf_counter() - time_before, module_path, success))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to patch and execute ComfyUI-Manager's prestartup script: {e}", exc_info=e)
|
||||
finally:
|
||||
if glob_path_added and glob_path in sys.path:
|
||||
sys.path.remove(glob_path)
|
||||
logging.Logger.addHandler = original_add_handler
|
||||
else:
|
||||
time_before = time.perf_counter()
|
||||
success = execute_script(script_path)
|
||||
node_prestartup_times.append((time.perf_counter() - time_before, module_path, success))
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _exec_mitigations(module: types.ModuleType, module_path: str) -> ExportedNodes:
|
||||
if module.__name__ == "ComfyUI-Manager":
|
||||
if module.__name__.lower() == "comfyui-manager":
|
||||
from ..cmd import folder_paths
|
||||
old_file = folder_paths.__file__
|
||||
|
||||
@ -93,21 +134,29 @@ def _exec_mitigations(module: types.ModuleType, module_path: str) -> ExportedNod
|
||||
folder_paths.__file__ = new_path
|
||||
# mitigate JS copy
|
||||
sys.modules['nodes'].EXTENSION_WEB_DIRS = {}
|
||||
|
||||
yield ExportedNodes()
|
||||
finally:
|
||||
folder_paths.__file__ = old_file
|
||||
# todo: mitigate "/manager/reboot"
|
||||
# todo: mitigate process_wrap
|
||||
# todo: unfortunately, we shouldn't restore the patches here, they will have to be applied forever.
|
||||
# concurrent.futures.ThreadPoolExecutor = _ThreadPoolExecutor
|
||||
# threading.Thread.start = original_thread_start
|
||||
else:
|
||||
# redirect stdout to the module's logger during import
|
||||
original_stdout = sys.stdout
|
||||
module_logger = logging.getLogger(module.__name__)
|
||||
yield ExportedNodes()
|
||||
|
||||
@contextmanager
|
||||
def _stdout_intercept(name: str):
|
||||
original_stdout = sys.stdout
|
||||
|
||||
try:
|
||||
module_logger = logging.getLogger(name)
|
||||
sys.stdout = StreamToLogger(module_logger, logging.INFO)
|
||||
try:
|
||||
yield ExportedNodes()
|
||||
finally:
|
||||
# Restore original stdout to ensure this change is temporary
|
||||
sys.stdout = original_stdout
|
||||
yield
|
||||
finally:
|
||||
sys.stdout = original_stdout
|
||||
|
||||
|
||||
|
||||
def _vanilla_load_custom_nodes_1(module_path, ignore=set()) -> ExportedNodes:
|
||||
@ -127,7 +176,7 @@ def _vanilla_load_custom_nodes_1(module_path, ignore=set()) -> ExportedNodes:
|
||||
module = importlib.util.module_from_spec(module_spec)
|
||||
sys.modules[module_name] = module
|
||||
|
||||
with _exec_mitigations(module, module_path) as mitigated_exported_nodes:
|
||||
with _exec_mitigations(module, module_path) as mitigated_exported_nodes, _stdout_intercept(module_name):
|
||||
module_spec.loader.exec_module(module)
|
||||
exported_nodes.update(mitigated_exported_nodes)
|
||||
|
||||
@ -197,6 +246,8 @@ def mitigated_import_of_vanilla_custom_nodes() -> ExportedNodes:
|
||||
from ..cmd import cuda_malloc, folder_paths, latent_preview, protocol
|
||||
from .. import node_helpers
|
||||
from .. import __version__
|
||||
import concurrent.futures
|
||||
import threading
|
||||
for module in (cuda_malloc, folder_paths, latent_preview, node_helpers, protocol):
|
||||
module_short_name = module.__name__.split(".")[-1]
|
||||
sys.modules[module_short_name] = module
|
||||
@ -237,6 +288,22 @@ def mitigated_import_of_vanilla_custom_nodes() -> ExportedNodes:
|
||||
|
||||
node_paths = frozenset(abspath(custom_node_path) for custom_node_path in node_paths)
|
||||
|
||||
_ThreadPoolExecutor = concurrent.futures.ThreadPoolExecutor
|
||||
original_thread_start = threading.Thread.start
|
||||
concurrent.futures.ThreadPoolExecutor = ContextVarExecutor
|
||||
|
||||
# mitigate missing folder names and paths context
|
||||
def patched_start(self, *args, **kwargs):
|
||||
if not hasattr(self.run, '__wrapped_by_context__'):
|
||||
ctx = contextvars.copy_context()
|
||||
self.run = partial(ctx.run, self.run)
|
||||
setattr(self.run, '__wrapped_by_context__', True)
|
||||
original_thread_start(self, *args, **kwargs)
|
||||
|
||||
if not getattr(threading.Thread.start, '__is_patched_by_us', False):
|
||||
threading.Thread.start = patched_start
|
||||
setattr(threading.Thread.start, '__is_patched_by_us', True)
|
||||
logger.info("Patched `threading.Thread.start` to propagate contextvars.")
|
||||
_vanilla_load_importing_execute_prestartup_script(node_paths)
|
||||
vanilla_custom_nodes = _vanilla_load_custom_nodes_2(node_paths)
|
||||
return vanilla_custom_nodes
|
||||
return vanilla_custom_nodes
|
||||
|
||||
@ -57,7 +57,7 @@ dependencies = [
|
||||
"pyjwt[crypto]",
|
||||
"kornia>=0.7.0",
|
||||
"mpmath>=1.0,!=1.4.0a0",
|
||||
"huggingface_hub[hf_transfer]",
|
||||
"huggingface_hub[hf_transfer]>0.20",
|
||||
"lazy-object-proxy",
|
||||
"lazy_loader>=0.3",
|
||||
"can_ada",
|
||||
@ -134,6 +134,18 @@ dev = [
|
||||
"astroid",
|
||||
]
|
||||
|
||||
comfyui_manager = [
|
||||
"GitPython",
|
||||
"PyGithub",
|
||||
"matrix-client==0.4.0",
|
||||
"rich",
|
||||
"typing-extensions",
|
||||
"toml",
|
||||
"uv",
|
||||
"chardet",
|
||||
"pip",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
cpu = [
|
||||
"torch",
|
||||
|
||||
Loading…
Reference in New Issue
Block a user