From f8eea225d47a8b5924e4fdd89fd591e0ee7e4acb Mon Sep 17 00:00:00 2001
From: doctorpangloss <@hiddenswitch.com>
Date: Tue, 24 Jun 2025 12:30:54 -0700
Subject: [PATCH] modify main.py

---
 comfy/cmd/main.py | 396 +++++++++++++++++++++-------------------------
 1 file changed, 184 insertions(+), 212 deletions(-)

diff --git a/comfy/cmd/main.py b/comfy/cmd/main.py
index 0d7c97dcb..6ef0b1f8b 100644
--- a/comfy/cmd/main.py
+++ b/comfy/cmd/main.py
@@ -1,154 +1,50 @@
-import comfy.options
-comfy.options.enable_args_parsing()
-
-import os
-import importlib.util
-import folder_paths
-import time
-from comfy.cli_args import args
-from app.logger import setup_logger
-import itertools
-import utils.extra_config
-import logging
-import sys
-
-if __name__ == "__main__":
-    #NOTE: These do not do anything on core ComfyUI, they are for custom nodes.
-    os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
-    os.environ['DO_NOT_TRACK'] = '1'
-
-setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
-
-def apply_custom_paths():
-    # extra model paths
-    extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
-    if os.path.isfile(extra_model_paths_config_path):
-        utils.extra_config.load_extra_path_config(extra_model_paths_config_path)
-
-    if args.extra_model_paths_config:
-        for config_path in itertools.chain(*args.extra_model_paths_config):
-            utils.extra_config.load_extra_path_config(config_path)
-
-    # --output-directory, --input-directory, --user-directory
-    if args.output_directory:
-        output_dir = os.path.abspath(args.output_directory)
-        logging.info(f"Setting output directory to: {output_dir}")
-        folder_paths.set_output_directory(output_dir)
-
-    # These are the default folders that checkpoints, clip and vae models will be saved to when using CheckpointSave, etc.. nodes
-    folder_paths.add_model_folder_path("checkpoints", os.path.join(folder_paths.get_output_directory(), "checkpoints"))
-    folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip"))
-    folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae"))
-    folder_paths.add_model_folder_path("diffusion_models",
-                                       os.path.join(folder_paths.get_output_directory(), "diffusion_models"))
-    folder_paths.add_model_folder_path("loras", os.path.join(folder_paths.get_output_directory(), "loras"))
-
-    if args.input_directory:
-        input_dir = os.path.abspath(args.input_directory)
-        logging.info(f"Setting input directory to: {input_dir}")
-        folder_paths.set_input_directory(input_dir)
-
-    if args.user_directory:
-        user_dir = os.path.abspath(args.user_directory)
-        logging.info(f"Setting user directory to: {user_dir}")
-        folder_paths.set_user_directory(user_dir)
-
-
-def execute_prestartup_script():
-    def execute_script(script_path):
-        module_name = os.path.splitext(script_path)[0]
-        try:
-            spec = importlib.util.spec_from_file_location(module_name, script_path)
-            module = importlib.util.module_from_spec(spec)
-            spec.loader.exec_module(module)
-            return True
-        except Exception as e:
-            logging.error(f"Failed to execute startup-script: {script_path} / {e}")
-        return False
-
-    if args.disable_all_custom_nodes:
-        return
-
-    node_paths = folder_paths.get_folder_paths("custom_nodes")
-    for custom_node_path in node_paths:
-        possible_modules = os.listdir(custom_node_path)
-        node_prestartup_times = []
-
-        for possible_module in possible_modules:
-            module_path = os.path.join(custom_node_path, possible_module)
-            if os.path.isfile(module_path) or module_path.endswith(".disabled") or module_path == "__pycache__":
-                continue
-
-            script_path = os.path.join(module_path, "prestartup_script.py")
-            if os.path.exists(script_path):
-                time_before = time.perf_counter()
-                success = execute_script(script_path)
-                node_prestartup_times.append((time.perf_counter() - time_before, module_path, success))
-    if len(node_prestartup_times) > 0:
-        logging.info("\nPrestartup times for custom nodes:")
-        for n in sorted(node_prestartup_times):
-            if n[2]:
-                import_message = ""
-            else:
-                import_message = " (PRESTARTUP FAILED)"
-            logging.info("{:6.1f} seconds{}: {}".format(n[0], import_message, n[1]))
-        logging.info("")
-
-apply_custom_paths()
-execute_prestartup_script()
-
-
-# Main code
 import asyncio
+import contextvars
+import gc
+import itertools
+import logging
+import os
 import shutil
 import threading
-import gc
+import time
+from pathlib import Path
+from typing import Optional
+from comfy.component_model.entrypoints_common import configure_application_paths, executor_from_args
+# main_pre must be the earliest import since it suppresses some spurious warnings
+from .main_pre import args
+from . import hook_breaker_ac10a0
+from .extra_model_paths import load_extra_path_config
+from .. import model_management
+from ..analytics.analytics import initialize_event_tracking
+from ..cmd import cuda_malloc
+from ..cmd import folder_paths
+from ..cmd import server as server_module
+from ..component_model.abstract_prompt_queue import AbstractPromptQueue
+from ..distributed.distributed_prompt_queue import DistributedPromptQueue
+from ..distributed.server_stub import ServerStub
+from ..nodes.package import import_all_nodes_in_workspace
 
-if os.name == "nt":
-    logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
+logger = logging.getLogger(__name__)
 
-if __name__ == "__main__":
-    if args.cuda_device is not None:
-        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
-        os.environ['HIP_VISIBLE_DEVICES'] = str(args.cuda_device)
-        logging.info("Set cuda device to: {}".format(args.cuda_device))
-
-    if args.oneapi_device_selector is not None:
-        os.environ['ONEAPI_DEVICE_SELECTOR'] = args.oneapi_device_selector
-        logging.info("Set oneapi device selector to: {}".format(args.oneapi_device_selector))
-
-    if args.deterministic:
-        if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ:
-            os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
-
-    import cuda_malloc
-
-import comfy.utils
-
-import execution
-import server
-from server import BinaryEventTypes
-import nodes
-import comfy.model_management
-import comfyui_version
-import app.logger
-import hook_breaker_ac10a0
 
 
 def cuda_malloc_warning():
-    device = comfy.model_management.get_torch_device()
-    device_name = comfy.model_management.get_torch_device_name(device)
+    device = model_management.get_torch_device()
+    device_name = model_management.get_torch_device_name(device)
     cuda_malloc_warning = False
     if "cudaMallocAsync" in device_name:
         for b in cuda_malloc.blacklist:
             if b in device_name:
                 cuda_malloc_warning = True
     if cuda_malloc_warning:
-        logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
+        logger.warning(
+            "\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
 
 
-def prompt_worker(q, server_instance):
-    current_time: float = 0.0
+def prompt_worker(q: AbstractPromptQueue, server_instance: server_module.PromptServer):
+    from ..cmd import execution
+    from ..component_model import queue_types
+    from .. import model_management
     cache_type = execution.CacheType.CLASSIC
     if args.cache_lru > 0:
         cache_type = execution.CacheType.LRU
@@ -159,7 +55,7 @@ def prompt_worker(q, server_instance):
     last_gc_collect = 0
     need_gc = False
     gc_collect_interval = 10.0
-
+    current_time = 0.0
     while True:
         timeout = 1000.0
         if need_gc:
@@ -176,28 +72,23 @@ def prompt_worker(q, server_instance):
             need_gc = True
             q.task_done(item_id, e.history_result,
-                        status=execution.PromptQueue.ExecutionStatus(
+                        status=queue_types.ExecutionStatus(
                             status_str='success' if e.success else 'error',
                             completed=e.success,
                             messages=e.status_messages))
             if server_instance.client_id is not None:
-                server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id)
+                server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id},
+                                          server_instance.client_id)
 
             current_time = time.perf_counter()
             execution_time = current_time - execution_start_time
-
-            # Log Time in a more readable way after 10 minutes
-            if execution_time > 600:
-                execution_time = time.strftime("%H:%M:%S", time.gmtime(execution_time))
-                logging.info(f"Prompt executed in {execution_time}")
-            else:
-                logging.info("Prompt executed in {:.2f} seconds".format(execution_time))
+            logger.debug("Prompt executed in {:.2f} seconds".format(execution_time))
 
         flags = q.get_flags()
         free_memory = flags.get("free_memory", False)
 
         if flags.get("unload_models", free_memory):
-            comfy.model_management.unload_all_models()
+            model_management.unload_all_models()
             need_gc = True
             last_gc_collect = 0
@@ -210,7 +101,7 @@ def prompt_worker(q, server_instance):
            current_time = time.perf_counter()
            if (current_time - last_gc_collect) > gc_collect_interval:
                gc.collect()
-                model_management.soft_empty_cache()
+                model_management.soft_empty_cache()
                last_gc_collect = current_time
                need_gc = False
                hook_breaker_ac10a0.restore_functions()
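The hunks above move the worker's completion reporting from execution.PromptQueue.ExecutionStatus to the standalone queue_types.ExecutionStatus. A minimal sketch of reporting a finished prompt against that interface, reusing only the names visible in the hunk (the helper function and the absolute import path comfy.component_model.queue_types are assumptions inferred from the relative import):

    from comfy.component_model import queue_types

    def report_done(q, item_id, history_result, success, status_messages):
        # Mirrors how prompt_worker marks a queue item finished in the new code path.
        q.task_done(item_id, history_result,
                    status=queue_types.ExecutionStatus(
                        status_str='success' if success else 'error',
                        completed=success,
                        messages=status_messages))
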
@@ -220,110 +111,191 @@ async def run(server_instance, address='', port=8188, verbose=True, call_on_star
     addresses = []
     for addr in address.split(","):
         addresses.append((addr, port))
-    await asyncio.gather(
-        server_instance.start_multi_address(addresses, call_on_start, verbose), server_instance.publish_loop()
-    )
-
-
-def hijack_progress(server_instance):
-    def hook(value, total, preview_image):
-        comfy.model_management.throw_exception_if_processing_interrupted()
-        progress = {"value": value, "max": total, "prompt_id": server_instance.last_prompt_id, "node": server_instance.last_node_id}
-
-        server_instance.send_sync("progress", progress, server_instance.client_id)
-        if preview_image is not None:
-            server_instance.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server_instance.client_id)
-
-    comfy.utils.set_progress_bar_global_hook(hook)
+    await asyncio.gather(server_instance.start_multi_address(addresses, call_on_start), server_instance.publish_loop())
 
 
 def cleanup_temp():
-    temp_dir = folder_paths.get_temp_directory()
-    if os.path.exists(temp_dir):
-        shutil.rmtree(temp_dir, ignore_errors=True)
-
-
-def setup_database():
     try:
-        from app.database.db import init_db, dependencies_available
-        if dependencies_available():
-            init_db()
-    except Exception as e:
-        logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}")
+        folder_paths.get_temp_directory()
+        temp_dir = folder_paths.get_temp_directory()
+        if os.path.exists(temp_dir):
+            shutil.rmtree(temp_dir, ignore_errors=True)
+    except NameError:
+        # __file__ was not defined
+        pass
 
 
-def start_comfyui(asyncio_loop=None):
+def start_comfyui(asyncio_loop: asyncio.AbstractEventLoop = None):
+    asyncio_loop = asyncio_loop or asyncio.get_event_loop()
+    asyncio_loop.run_until_complete(_start_comfyui())
+
+
+async def _start_comfyui(from_script_dir: Optional[Path] = None):
     """
-    Starts the ComfyUI server using the provided asyncio event loop or creates a new one.
-    Returns the event loop, server instance, and a function to start the server asynchronously.
+    Runs ComfyUI's frontend and backend like upstream.
+    :param from_script_dir: when set to a path, assumes that you are running ComfyUI's legacy main.py entrypoint at the root of the git repository located at the path
     """
+    if not from_script_dir:
+        os_getcwd = os.getcwd()
+    else:
+        os_getcwd = str(from_script_dir)
+
     if args.temp_directory:
         temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp")
-        logging.info(f"Setting temp directory to: {temp_dir}")
+        logger.debug(f"Setting temp directory to: {temp_dir}")
         folder_paths.set_temp_directory(temp_dir)
 
     cleanup_temp()
 
+    if args.user_directory:
+        user_dir = os.path.abspath(args.user_directory)
+        logger.info(f"Setting user directory to: {user_dir}")
+        folder_paths.set_user_directory(user_dir)
+
+    # configure extra model paths earlier
+    try:
+        extra_model_paths_config_path = os.path.join(os_getcwd, "extra_model_paths.yaml")
+        if os.path.isfile(extra_model_paths_config_path):
+            load_extra_path_config(extra_model_paths_config_path)
+    except NameError:
+        pass
+
+    if args.extra_model_paths_config:
+        for config_path in itertools.chain(*args.extra_model_paths_config):
+            load_extra_path_config(config_path)
+
+    # always create directories when started interactively
+    folder_paths.create_directories()
+    if args.create_directories:
+        import_all_nodes_in_workspace(raise_on_failure=False)
+        folder_paths.create_directories()
+        exit(0)
+
     if args.windows_standalone_build:
+        folder_paths.create_directories()
         try:
-            import new_updater
+            from . import new_updater
             new_updater.update_windows_updater()
         except:
            pass
 
-    if not asyncio_loop:
-        asyncio_loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(asyncio_loop)
-    prompt_server = server.PromptServer(asyncio_loop)
+    loop = asyncio.get_event_loop()
+    server = server_module.PromptServer(loop)
+    if args.external_address is not None:
+        server.external_address = args.external_address
 
+    # at this stage, it's safe to import nodes
     hook_breaker_ac10a0.save_functions()
-    nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes, init_api_nodes=not args.disable_api_nodes)
+    server.nodes = import_all_nodes_in_workspace()
     hook_breaker_ac10a0.restore_functions()
 
+    # as a side effect, this also populates the nodes for execution
+    if args.distributed_queue_connection_uri is not None:
+        distributed = True
+        q = DistributedPromptQueue(
+            caller_server=server if args.distributed_queue_frontend else None,
+            connection_uri=args.distributed_queue_connection_uri,
+            is_caller=args.distributed_queue_frontend,
+            is_callee=args.distributed_queue_worker,
+            loop=loop,
+            queue_name=args.distributed_queue_name
+        )
+        await q.init()
+    else:
+        distributed = False
+        from .execution import PromptQueue
+        q = PromptQueue(server)
+    server.prompt_queue = q
+
+    server.add_routes()
     cuda_malloc_warning()
 
-    setup_database()
-    prompt_server.add_routes()
-    hijack_progress(prompt_server)
+    # in a distributed setting, the default prompt worker will not be able to send execution events via the websocket
+    worker_thread_server = server if not distributed else ServerStub()
+    if not distributed or args.distributed_queue_worker:
+        if distributed:
+            logger.warning(
+                f"Distributed workers started in the default thread loop cannot notify clients of progress updates. Instead of comfyui or main.py, use comfyui-worker.")
+        # todo: this should really be using an executor instead of doing things this jankilicious way
+        ctx = contextvars.copy_context()
+        threading.Thread(target=lambda _q, _worker_thread_server: ctx.run(prompt_worker, _q, _worker_thread_server),
+                         daemon=True, args=(q, worker_thread_server,)).start()
 
-    threading.Thread(target=prompt_worker, daemon=True, args=(prompt_server.prompt_queue, prompt_server,)).start()
+    # server has been imported and things should be looking good
+    initialize_event_tracking(loop)
+
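The worker thread above is started through contextvars.copy_context().run(...) rather than calling prompt_worker directly, so context-local state captured at startup is visible inside the worker thread. A small stdlib-only illustration of that hand-off pattern (the variable names here are illustrative, not ComfyUI's own):

    import contextvars
    import threading

    request_id = contextvars.ContextVar("request_id", default=None)

    def worker():
        # Prints "abc123": the copied context carries the value into the new thread.
        print(request_id.get())

    request_id.set("abc123")
    ctx = contextvars.copy_context()
    threading.Thread(target=lambda: ctx.run(worker), daemon=True).start()
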
+    if args.output_directory:
+        output_dir = os.path.abspath(args.output_directory)
+        logger.debug(f"Setting output directory to: {output_dir}")
+        folder_paths.set_output_directory(output_dir)
+
+    # These are the default folders that checkpoints, clip and vae models will be saved to when using CheckpointSave, etc.. nodes
+    folder_paths.add_model_folder_path("checkpoints", os.path.join(folder_paths.get_output_directory(), "checkpoints"))
+    folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip"))
+    folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae"))
+    folder_paths.add_model_folder_path("diffusion_models",
+                                       os.path.join(folder_paths.get_output_directory(), "diffusion_models"))
+    folder_paths.add_model_folder_path("loras", os.path.join(folder_paths.get_output_directory(), "loras"))
+
+    if args.input_directory:
+        input_dir = os.path.abspath(args.input_directory)
+        logger.debug(f"Setting input directory to: {input_dir}")
+        folder_paths.set_input_directory(input_dir)
 
     if args.quick_test_for_ci:
-        exit(0)
+        # for CI purposes, try importing all the nodes
+        import_all_nodes_in_workspace(raise_on_failure=True)
+        return
+    else:
+        # we no longer lazily load nodes. we'll do it now for the sake of creating directories
+        import_all_nodes_in_workspace(raise_on_failure=False)
+        # now that nodes are loaded, create more directories if appropriate
+        folder_paths.create_directories()
 
-    os.makedirs(folder_paths.get_temp_directory(), exist_ok=True)
+    if len(args.workflows) > 0:
+        configure_application_paths(args)
+        executor = await executor_from_args(args)
+        from ..entrypoints.workflow import run_workflows
+        await run_workflows(executor, args.workflows)
+        return
+
+    # replaced my folder_paths.create_directories call
     call_on_start = None
     if args.auto_launch:
-        def startup_server(scheme, address, port):
+        def startup_server(scheme="http", address="localhost", port=8188):
             import webbrowser
-            if os.name == 'nt' and address == '0.0.0.0':
+            if os.name == 'nt' and address == '0.0.0.0' or address == '':
                 address = '127.0.0.1'
             if ':' in address:
                 address = "[{}]".format(address)
             webbrowser.open(f"{scheme}://{address}:{port}")
+
         call_on_start = startup_server
 
-    async def start_all():
-        await prompt_server.setup()
-        await run(prompt_server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start)
+    first_listen_addr = args.listen.split(',')[0] if ',' in args.listen else args.listen
+    server.address = first_listen_addr
+    server.port = args.port
 
-    # Returning these so that other code can integrate with the ComfyUI loop and server
-    return asyncio_loop, prompt_server, start_all
+    try:
+        await server.setup()
+        await run(server, address=first_listen_addr, port=args.port, verbose=not args.dont_print_server,
+                  call_on_start=call_on_start)
+    except (asyncio.CancelledError, KeyboardInterrupt):
+        logger.debug("Stopped server")
+    finally:
+        if distributed:
+            await q.close()
+        cleanup_temp()
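When --listen carries several comma-separated addresses, only the first entry becomes server.address above, while run() still splits the full string and binds every address. A worked example of that split, mirroring the expressions in the hunk (plain Python, values are illustrative):

    listen = "0.0.0.0,::"
    first_listen_addr = listen.split(',')[0] if ',' in listen else listen
    # first_listen_addr == "0.0.0.0"; run() still receives the whole string and
    # binds [("0.0.0.0", 8188), ("::", 8188)].
    addresses = [(addr, 8188) for addr in listen.split(",")]
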
3.12 and above is recommended.") - - event_loop, _, start_all_func = start_comfyui() - try: - x = start_all_func() - app.logger.print_startup_warnings() - event_loop.run_until_complete(x) - except KeyboardInterrupt: - logging.info("\nStopped server") - - cleanup_temp() + entrypoint()