diff --git a/comfy/cmd/main.py b/comfy/cmd/main.py index bfd578cd8..04b02f9f0 100644 --- a/comfy/cmd/main.py +++ b/comfy/cmd/main.py @@ -203,7 +203,9 @@ async def _start_comfyui(from_script_dir: Optional[Path] = None): # at this stage, it's safe to import nodes hook_breaker_ac10a0.save_functions() - server.nodes = get_nodes() + nodes_to_import = get_nodes() + logger.debug(f"Imported {len(nodes_to_import)} nodes") + server.nodes = nodes_to_import hook_breaker_ac10a0.restore_functions() # as a side effect, this also populates the nodes for execution diff --git a/comfy/cmd/server.py b/comfy/cmd/server.py index e08ea38a3..9d8405871 100644 --- a/comfy/cmd/server.py +++ b/comfy/cmd/server.py @@ -1003,6 +1003,14 @@ class PromptServer(ExecutorToClientProgress): self.client_session = aiohttp.ClientSession(timeout=timeout) def add_routes(self): + # a mitigation for vanilla comfyui custom nodes that are stateful and add routes to a global + # prompt server instance. this is not a recommended pattern, but this mitigation is here to + # support it + from ..nodes.vanilla_node_importing import prompt_server_instance_routes + for route in prompt_server_instance_routes.routes: + self.routes.route(route.method, route.path)(route.handler) + prompt_server_instance_routes.clear() + self.user_manager.add_routes(self.routes) self.model_file_manager.add_routes(self.routes) # todo: needs to use module directories diff --git a/comfy/component_model/folder_path_types.py b/comfy/component_model/folder_path_types.py index 19aa930d8..42932e217 100644 --- a/comfy/component_model/folder_path_types.py +++ b/comfy/component_model/folder_path_types.py @@ -147,6 +147,11 @@ class SupportedExtensions: p: FolderNames = self.parent() p.remove_all_supported_extensions(self.folder_name) + def __len__(self): + p: FolderNames = self.parent() + return len(list(p.supported_extensions(self.folder_name))) + + __ior__ = _append_any add = _append_any update = _append_any diff --git a/comfy/component_model/plugins.py b/comfy/component_model/plugins.py index 09aa334e7..2db02b8bf 100644 --- a/comfy/component_model/plugins.py +++ b/comfy/component_model/plugins.py @@ -1,6 +1,15 @@ +from typing import Callable, NamedTuple + + +class RouteTuple(NamedTuple): + method: str + path: str + func: Callable + + class _RoutesWrapper: def __init__(self): - self.routes = [] + self.routes: list[RouteTuple] = [] def _decorator_factory(self, method): def decorator(path): @@ -8,7 +17,8 @@ class _RoutesWrapper: from ..cmd.server import PromptServer if PromptServer.instance is not None and not isinstance(PromptServer.instance.routes, _RoutesWrapper): getattr(PromptServer.instance.routes, method)(path)(func) - self.routes.append((method, path, func)) + else: + self.routes.append(RouteTuple(method, path, func)) return func return wrapper @@ -39,5 +49,8 @@ class _RoutesWrapper: def route(self, method, path): return self._decorator_factory(method.lower())(path) + def clear(self): + self.routes.clear() + prompt_server_instance_routes = _RoutesWrapper() diff --git a/comfy/ldm/ace/vae/music_dcae_pipeline.py b/comfy/ldm/ace/vae/music_dcae_pipeline.py index 6357cd84e..e29e8fc2c 100644 --- a/comfy/ldm/ace/vae/music_dcae_pipeline.py +++ b/comfy/ldm/ace/vae/music_dcae_pipeline.py @@ -5,10 +5,11 @@ import torch from .autoencoder_dc import AutoencoderDC +logger = logging.getLogger(__name__) try: import torchaudio # pylint: disable=import-error except: - logging.warning("torchaudio missing, ACE model will be broken") + logger.warning("torchaudio missing, ACE model will be broken") import torchvision.transforms as transforms from .music_vocoder import ADaMoSHiFiGANV1 diff --git a/comfy/ldm/ace/vae/music_log_mel.py b/comfy/ldm/ace/vae/music_log_mel.py index caa3c64d0..8cc07084d 100755 --- a/comfy/ldm/ace/vae/music_log_mel.py +++ b/comfy/ldm/ace/vae/music_log_mel.py @@ -1,23 +1,28 @@ # Original from: https://github.com/ace-step/ACE-Step/blob/main/music_dcae/music_log_mel.py +import logging + import torch import torch.nn as nn from torch import Tensor -import logging + +logger = logging.getLogger(__name__) + try: from torchaudio.transforms import MelScale # pylint: disable=import-error except: - logging.warning("torchaudio missing, ACE model will be broken") + logger.warning("torchaudio missing, ACE model will be broken") from .... import model_management + class LinearSpectrogram(nn.Module): def __init__( - self, - n_fft=2048, - win_length=2048, - hop_length=512, - center=False, - mode="pow2_sqrt", + self, + n_fft=2048, + win_length=2048, + hop_length=512, + center=False, + mode="pow2_sqrt", ): super().__init__() @@ -64,15 +69,15 @@ class LinearSpectrogram(nn.Module): class LogMelSpectrogram(nn.Module): def __init__( - self, - sample_rate=44100, - n_fft=2048, - win_length=2048, - hop_length=512, - n_mels=128, - center=False, - f_min=0.0, - f_max=None, + self, + sample_rate=44100, + n_fft=2048, + win_length=2048, + hop_length=512, + n_mels=128, + center=False, + f_min=0.0, + f_max=None, ): super().__init__() diff --git a/comfy/nodes/vanilla_node_importing.py b/comfy/nodes/vanilla_node_importing.py index 164536077..57bcf3d90 100644 --- a/comfy/nodes/vanilla_node_importing.py +++ b/comfy/nodes/vanilla_node_importing.py @@ -1,5 +1,6 @@ from __future__ import annotations +import contextvars import importlib import logging import os @@ -7,20 +8,25 @@ import sys import time import types from contextlib import contextmanager +from functools import partial from os.path import join, basename, dirname, isdir, isfile, exists, abspath, split, splitext, realpath from typing import Dict, Iterable from . import base_nodes from .package_typing import ExportedNodes +from ..cmd import folder_paths from ..component_model.plugins import prompt_server_instance_routes +from ..distributed.executors import ContextVarExecutor logger = logging.getLogger(__name__) + class StreamToLogger: """ File-like stream object that redirects writes to a logger instance. This is used to capture print() statements from modules during import. """ + def __init__(self, logger: logging.Logger, log_level=logging.INFO): self.logger = logger self.log_level = log_level @@ -45,9 +51,10 @@ def _vanilla_load_importing_execute_prestartup_script(node_paths: Iterable[str]) def execute_script(script_path): module_name = splitext(script_path)[0] try: - spec = importlib.util.spec_from_file_location(module_name, script_path) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) + with _stdout_intercept(module_name): + spec = importlib.util.spec_from_file_location(module_name, script_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) return True except Exception as e: logger.error(f"Failed to execute startup-script: {script_path} / {e}") @@ -68,22 +75,56 @@ def _vanilla_load_importing_execute_prestartup_script(node_paths: Iterable[str]) script_path = join(module_path, "prestartup_script.py") if exists(script_path): - time_before = time.perf_counter() - success = execute_script(script_path) - node_prestartup_times.append((time.perf_counter() - time_before, module_path, success)) - if len(node_prestartup_times) > 0: - logger.debug("\nPrestartup times for custom nodes:") - for n in sorted(node_prestartup_times): - if n[2]: - import_message = "" - else: - import_message = " (PRESTARTUP FAILED)" - logger.debug("{:6.1f} seconds{}:".format(n[0], import_message), n[1]) + if "comfyui-manager" in module_path.lower(): + os.environ['COMFYUI_PATH'] = str(folder_paths.base_path) + os.environ['COMFYUI_FOLDERS_BASE_PATH'] = str(folder_paths.models_dir) + # Monkey-patch ComfyUI-Manager's security check to prevent it from crashing on startup + # and its logging handler to prevent it from taking over logging. + glob_path = join(module_path, "glob") + glob_path_added = False + original_add_handler = logging.Logger.addHandler + + def no_op_add_handler(self, handler): + logger.info(f"Skipping addHandler for {type(handler).__name__} during ComfyUI-Manager prestartup.") + + try: + sys.path.insert(0, glob_path) + glob_path_added = True + # Patch security_check + import security_check + original_check = security_check.security_check + + def patched_security_check(): + try: + return original_check() + except Exception as e: + logger.error(f"ComfyUI-Manager security_check failed but was caught gracefully: {e}", exc_info=e) + + security_check.security_check = patched_security_check + logger.info("Patched ComfyUI-Manager's security_check to fail gracefully.") + + # Patch logging + logging.Logger.addHandler = no_op_add_handler + logger.info("Patched logging.Logger.addHandler to prevent ComfyUI-Manager from adding a logging handler.") + + time_before = time.perf_counter() + success = execute_script(script_path) + node_prestartup_times.append((time.perf_counter() - time_before, module_path, success)) + except Exception as e: + logger.error(f"Failed to patch and execute ComfyUI-Manager's prestartup script: {e}", exc_info=e) + finally: + if glob_path_added and glob_path in sys.path: + sys.path.remove(glob_path) + logging.Logger.addHandler = original_add_handler + else: + time_before = time.perf_counter() + success = execute_script(script_path) + node_prestartup_times.append((time.perf_counter() - time_before, module_path, success)) @contextmanager def _exec_mitigations(module: types.ModuleType, module_path: str) -> ExportedNodes: - if module.__name__ == "ComfyUI-Manager": + if module.__name__.lower() == "comfyui-manager": from ..cmd import folder_paths old_file = folder_paths.__file__ @@ -93,21 +134,29 @@ def _exec_mitigations(module: types.ModuleType, module_path: str) -> ExportedNod folder_paths.__file__ = new_path # mitigate JS copy sys.modules['nodes'].EXTENSION_WEB_DIRS = {} + yield ExportedNodes() finally: folder_paths.__file__ = old_file # todo: mitigate "/manager/reboot" # todo: mitigate process_wrap + # todo: unfortunately, we shouldn't restore the patches here, they will have to be applied forever. + # concurrent.futures.ThreadPoolExecutor = _ThreadPoolExecutor + # threading.Thread.start = original_thread_start else: - # redirect stdout to the module's logger during import - original_stdout = sys.stdout - module_logger = logging.getLogger(module.__name__) + yield ExportedNodes() + +@contextmanager +def _stdout_intercept(name: str): + original_stdout = sys.stdout + + try: + module_logger = logging.getLogger(name) sys.stdout = StreamToLogger(module_logger, logging.INFO) - try: - yield ExportedNodes() - finally: - # Restore original stdout to ensure this change is temporary - sys.stdout = original_stdout + yield + finally: + sys.stdout = original_stdout + def _vanilla_load_custom_nodes_1(module_path, ignore=set()) -> ExportedNodes: @@ -127,7 +176,7 @@ def _vanilla_load_custom_nodes_1(module_path, ignore=set()) -> ExportedNodes: module = importlib.util.module_from_spec(module_spec) sys.modules[module_name] = module - with _exec_mitigations(module, module_path) as mitigated_exported_nodes: + with _exec_mitigations(module, module_path) as mitigated_exported_nodes, _stdout_intercept(module_name): module_spec.loader.exec_module(module) exported_nodes.update(mitigated_exported_nodes) @@ -197,6 +246,8 @@ def mitigated_import_of_vanilla_custom_nodes() -> ExportedNodes: from ..cmd import cuda_malloc, folder_paths, latent_preview, protocol from .. import node_helpers from .. import __version__ + import concurrent.futures + import threading for module in (cuda_malloc, folder_paths, latent_preview, node_helpers, protocol): module_short_name = module.__name__.split(".")[-1] sys.modules[module_short_name] = module @@ -237,6 +288,22 @@ def mitigated_import_of_vanilla_custom_nodes() -> ExportedNodes: node_paths = frozenset(abspath(custom_node_path) for custom_node_path in node_paths) + _ThreadPoolExecutor = concurrent.futures.ThreadPoolExecutor + original_thread_start = threading.Thread.start + concurrent.futures.ThreadPoolExecutor = ContextVarExecutor + + # mitigate missing folder names and paths context + def patched_start(self, *args, **kwargs): + if not hasattr(self.run, '__wrapped_by_context__'): + ctx = contextvars.copy_context() + self.run = partial(ctx.run, self.run) + setattr(self.run, '__wrapped_by_context__', True) + original_thread_start(self, *args, **kwargs) + + if not getattr(threading.Thread.start, '__is_patched_by_us', False): + threading.Thread.start = patched_start + setattr(threading.Thread.start, '__is_patched_by_us', True) + logger.info("Patched `threading.Thread.start` to propagate contextvars.") _vanilla_load_importing_execute_prestartup_script(node_paths) vanilla_custom_nodes = _vanilla_load_custom_nodes_2(node_paths) - return vanilla_custom_nodes \ No newline at end of file + return vanilla_custom_nodes diff --git a/pyproject.toml b/pyproject.toml index 745415567..4916a6f64 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,7 +57,7 @@ dependencies = [ "pyjwt[crypto]", "kornia>=0.7.0", "mpmath>=1.0,!=1.4.0a0", - "huggingface_hub[hf_transfer]", + "huggingface_hub[hf_transfer]>0.20", "lazy-object-proxy", "lazy_loader>=0.3", "can_ada", @@ -134,6 +134,18 @@ dev = [ "astroid", ] +comfyui_manager = [ + "GitPython", + "PyGithub", + "matrix-client==0.4.0", + "rich", + "typing-extensions", + "toml", + "uv", + "chardet", + "pip", +] + [project.optional-dependencies] cpu = [ "torch",