move and fix main

This commit is contained in:
Benjamin Berman 2025-04-30 11:59:59 -07:00
parent eed79e210e
commit da2cbf7c91
3 changed files with 177 additions and 206 deletions

View File

@ -1,7 +1,7 @@
# Prevent custom nodes from hooking anything important # Prevent custom nodes from hooking anything important
import comfy.model_management from .. import model_management
HOOK_BREAK = [(comfy.model_management, "cast_to")] HOOK_BREAK = [(model_management, "cast_to")]
SAVED_FUNCTIONS = [] SAVED_FUNCTIONS = []

View File

@ -1,162 +1,49 @@
import comfy.options
comfy.options.enable_args_parsing()
import os
import importlib.util
import folder_paths
import time
from comfy.cli_args import args
from app.logger import setup_logger
import itertools
import utils.extra_config
import logging
import sys
if __name__ == "__main__":
#NOTE: These do not do anything on core ComfyUI which should already have no communication with the internet, they are for custom nodes.
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
os.environ['DO_NOT_TRACK'] = '1'
setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
def apply_custom_paths():
# extra model paths
extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
if os.path.isfile(extra_model_paths_config_path):
utils.extra_config.load_extra_path_config(extra_model_paths_config_path)
if args.extra_model_paths_config:
for config_path in itertools.chain(*args.extra_model_paths_config):
utils.extra_config.load_extra_path_config(config_path)
# --output-directory, --input-directory, --user-directory
if args.output_directory:
output_dir = os.path.abspath(args.output_directory)
logging.info(f"Setting output directory to: {output_dir}")
folder_paths.set_output_directory(output_dir)
# These are the default folders that checkpoints, clip and vae models will be saved to when using CheckpointSave, etc.. nodes
folder_paths.add_model_folder_path("checkpoints", os.path.join(folder_paths.get_output_directory(), "checkpoints"))
folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip"))
folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae"))
folder_paths.add_model_folder_path("diffusion_models",
os.path.join(folder_paths.get_output_directory(), "diffusion_models"))
folder_paths.add_model_folder_path("loras", os.path.join(folder_paths.get_output_directory(), "loras"))
if args.input_directory:
input_dir = os.path.abspath(args.input_directory)
logging.info(f"Setting input directory to: {input_dir}")
folder_paths.set_input_directory(input_dir)
if args.user_directory:
user_dir = os.path.abspath(args.user_directory)
logging.info(f"Setting user directory to: {user_dir}")
folder_paths.set_user_directory(user_dir)
def execute_prestartup_script():
def execute_script(script_path):
module_name = os.path.splitext(script_path)[0]
try:
spec = importlib.util.spec_from_file_location(module_name, script_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return True
except Exception as e:
logging.error(f"Failed to execute startup-script: {script_path} / {e}")
return False
if args.disable_all_custom_nodes:
return
node_paths = folder_paths.get_folder_paths("custom_nodes")
for custom_node_path in node_paths:
possible_modules = os.listdir(custom_node_path)
node_prestartup_times = []
for possible_module in possible_modules:
module_path = os.path.join(custom_node_path, possible_module)
if os.path.isfile(module_path) or module_path.endswith(".disabled") or module_path == "__pycache__":
continue
script_path = os.path.join(module_path, "prestartup_script.py")
if os.path.exists(script_path):
time_before = time.perf_counter()
success = execute_script(script_path)
node_prestartup_times.append((time.perf_counter() - time_before, module_path, success))
if len(node_prestartup_times) > 0:
logging.info("\nPrestartup times for custom nodes:")
for n in sorted(node_prestartup_times):
if n[2]:
import_message = ""
else:
import_message = " (PRESTARTUP FAILED)"
logging.info("{:6.1f} seconds{}: {}".format(n[0], import_message, n[1]))
logging.info("")
apply_custom_paths()
execute_prestartup_script()
# Main code
import asyncio import asyncio
import contextvars
import gc
import itertools
import logging
import os
import shutil import shutil
import threading import threading
import gc import time
from pathlib import Path
from typing import Optional
# main_pre must be the earliest import since it suppresses some spurious warnings
from .main_pre import args
from . import hook_breaker_ac10a0
from .extra_model_paths import load_extra_path_config
from .. import model_management
from ..analytics.analytics import initialize_event_tracking
from ..cmd import cuda_malloc
from ..cmd import folder_paths
from ..cmd import server as server_module
from ..component_model.abstract_prompt_queue import AbstractPromptQueue
from ..distributed.distributed_prompt_queue import DistributedPromptQueue
from ..distributed.server_stub import ServerStub
from ..nodes.package import import_all_nodes_in_workspace
if os.name == "nt": logger = logging.getLogger(__name__)
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
if __name__ == "__main__":
if args.cuda_device is not None:
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
os.environ['HIP_VISIBLE_DEVICES'] = str(args.cuda_device)
logging.info("Set cuda device to: {}".format(args.cuda_device))
if args.oneapi_device_selector is not None:
os.environ['ONEAPI_DEVICE_SELECTOR'] = args.oneapi_device_selector
logging.info("Set oneapi device selector to: {}".format(args.oneapi_device_selector))
if args.deterministic:
if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ:
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
import cuda_malloc
if args.windows_standalone_build:
try:
from fix_torch import fix_pytorch_libomp
fix_pytorch_libomp()
except:
pass
import comfy.utils
import execution
import server
from server import BinaryEventTypes
import nodes
import comfy.model_management
import comfyui_version
import app.logger
import hook_breaker_ac10a0
def cuda_malloc_warning(): def cuda_malloc_warning():
device = comfy.model_management.get_torch_device() device = model_management.get_torch_device()
device_name = comfy.model_management.get_torch_device_name(device) device_name = model_management.get_torch_device_name(device)
cuda_malloc_warning = False cuda_malloc_warning = False
if "cudaMallocAsync" in device_name: if "cudaMallocAsync" in device_name:
for b in cuda_malloc.blacklist: for b in cuda_malloc.blacklist:
if b in device_name: if b in device_name:
cuda_malloc_warning = True cuda_malloc_warning = True
if cuda_malloc_warning: if cuda_malloc_warning:
logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n") logger.warning(
"\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
def prompt_worker(q, server_instance): def prompt_worker(q: AbstractPromptQueue, server_instance: server_module.PromptServer):
current_time: float = 0.0 from ..cmd import execution
from ..component_model import queue_types
from .. import model_management
cache_type = execution.CacheType.CLASSIC cache_type = execution.CacheType.CLASSIC
if args.cache_lru > 0: if args.cache_lru > 0:
cache_type = execution.CacheType.LRU cache_type = execution.CacheType.LRU
@ -167,7 +54,7 @@ def prompt_worker(q, server_instance):
last_gc_collect = 0 last_gc_collect = 0
need_gc = False need_gc = False
gc_collect_interval = 10.0 gc_collect_interval = 10.0
current_time = 0.0
while True: while True:
timeout = 1000.0 timeout = 1000.0
if need_gc: if need_gc:
@ -184,22 +71,23 @@ def prompt_worker(q, server_instance):
need_gc = True need_gc = True
q.task_done(item_id, q.task_done(item_id,
e.history_result, e.history_result,
status=execution.PromptQueue.ExecutionStatus( status=queue_types.ExecutionStatus(
status_str='success' if e.success else 'error', status_str='success' if e.success else 'error',
completed=e.success, completed=e.success,
messages=e.status_messages)) messages=e.status_messages))
if server_instance.client_id is not None: if server_instance.client_id is not None:
server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id) server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id},
server_instance.client_id)
current_time = time.perf_counter() current_time = time.perf_counter()
execution_time = current_time - execution_start_time execution_time = current_time - execution_start_time
logging.info("Prompt executed in {:.2f} seconds".format(execution_time)) logger.debug("Prompt executed in {:.2f} seconds".format(execution_time))
flags = q.get_flags() flags = q.get_flags()
free_memory = flags.get("free_memory", False) free_memory = flags.get("free_memory", False)
if flags.get("unload_models", free_memory): if flags.get("unload_models", free_memory):
comfy.model_management.unload_all_models() model_management.unload_all_models()
need_gc = True need_gc = True
last_gc_collect = 0 last_gc_collect = 0
@ -212,7 +100,7 @@ def prompt_worker(q, server_instance):
current_time = time.perf_counter() current_time = time.perf_counter()
if (current_time - last_gc_collect) > gc_collect_interval: if (current_time - last_gc_collect) > gc_collect_interval:
gc.collect() gc.collect()
comfy.model_management.soft_empty_cache() model_management.soft_empty_cache()
last_gc_collect = current_time last_gc_collect = current_time
need_gc = False need_gc = False
hook_breaker_ac10a0.restore_functions() hook_breaker_ac10a0.restore_functions()
@ -222,98 +110,181 @@ async def run(server_instance, address='', port=8188, verbose=True, call_on_star
addresses = [] addresses = []
for addr in address.split(","): for addr in address.split(","):
addresses.append((addr, port)) addresses.append((addr, port))
await asyncio.gather( await asyncio.gather(server_instance.start_multi_address(addresses, call_on_start), server_instance.publish_loop())
server_instance.start_multi_address(addresses, call_on_start, verbose), server_instance.publish_loop()
)
def hijack_progress(server_instance):
def hook(value, total, preview_image):
comfy.model_management.throw_exception_if_processing_interrupted()
progress = {"value": value, "max": total, "prompt_id": server_instance.last_prompt_id, "node": server_instance.last_node_id}
server_instance.send_sync("progress", progress, server_instance.client_id)
if preview_image is not None:
server_instance.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server_instance.client_id)
comfy.utils.set_progress_bar_global_hook(hook)
def cleanup_temp(): def cleanup_temp():
temp_dir = folder_paths.get_temp_directory() try:
if os.path.exists(temp_dir): folder_paths.get_temp_directory()
shutil.rmtree(temp_dir, ignore_errors=True) temp_dir = folder_paths.get_temp_directory()
if os.path.exists(temp_dir):
shutil.rmtree(temp_dir, ignore_errors=True)
except NameError:
# __file__ was not defined
pass
def start_comfyui(asyncio_loop=None): def start_comfyui(asyncio_loop: asyncio.AbstractEventLoop = None):
asyncio_loop = asyncio_loop or asyncio.get_event_loop()
asyncio_loop.run_until_complete(_start_comfyui())
async def _start_comfyui(from_script_dir: Optional[Path] = None):
""" """
Starts the ComfyUI server using the provided asyncio event loop or creates a new one. Runs ComfyUI's frontend and backend like upstream.
Returns the event loop, server instance, and a function to start the server asynchronously. :param from_script_dir: when set to a path, assumes that you are running ComfyUI's legacy main.py entrypoint at the root of the git repository located at the path
""" """
if not from_script_dir:
os_getcwd = os.getcwd()
else:
os_getcwd = str(from_script_dir)
if args.temp_directory: if args.temp_directory:
temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp") temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp")
logging.info(f"Setting temp directory to: {temp_dir}") logger.debug(f"Setting temp directory to: {temp_dir}")
folder_paths.set_temp_directory(temp_dir) folder_paths.set_temp_directory(temp_dir)
cleanup_temp() cleanup_temp()
if args.user_directory:
user_dir = os.path.abspath(args.user_directory)
logger.info(f"Setting user directory to: {user_dir}")
folder_paths.set_user_directory(user_dir)
# configure extra model paths earlier
try:
extra_model_paths_config_path = os.path.join(os_getcwd, "extra_model_paths.yaml")
if os.path.isfile(extra_model_paths_config_path):
load_extra_path_config(extra_model_paths_config_path)
except NameError:
pass
if args.extra_model_paths_config:
for config_path in itertools.chain(*args.extra_model_paths_config):
load_extra_path_config(config_path)
# always create directories when started interactively
folder_paths.create_directories()
if args.create_directories:
import_all_nodes_in_workspace(raise_on_failure=False)
folder_paths.create_directories()
exit(0)
if args.windows_standalone_build: if args.windows_standalone_build:
folder_paths.create_directories()
try: try:
import new_updater from . import new_updater
new_updater.update_windows_updater() new_updater.update_windows_updater()
except: except:
pass pass
if not asyncio_loop: loop = asyncio.get_event_loop()
asyncio_loop = asyncio.new_event_loop() server = server_module.PromptServer(loop)
asyncio.set_event_loop(asyncio_loop) if args.external_address is not None:
prompt_server = server.PromptServer(asyncio_loop) server.external_address = args.external_address
q = execution.PromptQueue(prompt_server)
# at this stage, it's safe to import nodes
hook_breaker_ac10a0.save_functions() hook_breaker_ac10a0.save_functions()
nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes) server.nodes = import_all_nodes_in_workspace()
hook_breaker_ac10a0.restore_functions() hook_breaker_ac10a0.restore_functions()
# as a side effect, this also populates the nodes for execution
if args.distributed_queue_connection_uri is not None:
distributed = True
q = DistributedPromptQueue(
caller_server=server if args.distributed_queue_frontend else None,
connection_uri=args.distributed_queue_connection_uri,
is_caller=args.distributed_queue_frontend,
is_callee=args.distributed_queue_worker,
loop=loop,
queue_name=args.distributed_queue_name
)
await q.init()
else:
distributed = False
from .execution import PromptQueue
q = PromptQueue(server)
server.prompt_queue = q
server.add_routes()
cuda_malloc_warning() cuda_malloc_warning()
prompt_server.add_routes() # in a distributed setting, the default prompt worker will not be able to send execution events via the websocket
hijack_progress(prompt_server) worker_thread_server = server if not distributed else ServerStub()
if not distributed or args.distributed_queue_worker:
if distributed:
logger.warning(
f"Distributed workers started in the default thread loop cannot notify clients of progress updates. Instead of comfyui or main.py, use comfyui-worker.")
# todo: this should really be using an executor instead of doing things this jankilicious way
ctx = contextvars.copy_context()
threading.Thread(target=lambda _q, _worker_thread_server: ctx.run(prompt_worker, _q, _worker_thread_server),
daemon=True, args=(q, worker_thread_server,)).start()
threading.Thread(target=prompt_worker, daemon=True, args=(q, prompt_server,)).start() # server has been imported and things should be looking good
initialize_event_tracking(loop)
if args.output_directory:
output_dir = os.path.abspath(args.output_directory)
logger.debug(f"Setting output directory to: {output_dir}")
folder_paths.set_output_directory(output_dir)
# These are the default folders that checkpoints, clip and vae models will be saved to when using CheckpointSave, etc.. nodes
folder_paths.add_model_folder_path("checkpoints", os.path.join(folder_paths.get_output_directory(), "checkpoints"))
folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip"))
folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae"))
folder_paths.add_model_folder_path("diffusion_models",
os.path.join(folder_paths.get_output_directory(), "diffusion_models"))
folder_paths.add_model_folder_path("loras", os.path.join(folder_paths.get_output_directory(), "loras"))
if args.input_directory:
input_dir = os.path.abspath(args.input_directory)
logger.debug(f"Setting input directory to: {input_dir}")
folder_paths.set_input_directory(input_dir)
if args.quick_test_for_ci: if args.quick_test_for_ci:
# for CI purposes, try importing all the nodes
import_all_nodes_in_workspace(raise_on_failure=True)
exit(0) exit(0)
else:
# we no longer lazily load nodes. we'll do it now for the sake of creating directories
import_all_nodes_in_workspace(raise_on_failure=False)
# now that nodes are loaded, create more directories if appropriate
folder_paths.create_directories()
os.makedirs(folder_paths.get_temp_directory(), exist_ok=True) # replaced my folder_paths.create_directories
call_on_start = None call_on_start = None
if args.auto_launch: if args.auto_launch:
def startup_server(scheme, address, port): def startup_server(scheme="http", address="localhost", port=8188):
import webbrowser import webbrowser
if os.name == 'nt' and address == '0.0.0.0': if os.name == 'nt' and address == '0.0.0.0' or address == '':
address = '127.0.0.1' address = '127.0.0.1'
if ':' in address: if ':' in address:
address = "[{}]".format(address) address = "[{}]".format(address)
webbrowser.open(f"{scheme}://{address}:{port}") webbrowser.open(f"{scheme}://{address}:{port}")
call_on_start = startup_server call_on_start = startup_server
async def start_all(): first_listen_addr = args.listen.split(',')[0] if ',' in args.listen else args.listen
await prompt_server.setup() server.address = first_listen_addr
await run(prompt_server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start) server.port = args.port
# Returning these so that other code can integrate with the ComfyUI loop and server try:
return asyncio_loop, prompt_server, start_all await server.setup()
await run(server, address=first_listen_addr, port=args.port, verbose=not args.dont_print_server,
call_on_start=call_on_start)
except (asyncio.CancelledError, KeyboardInterrupt):
logger.debug("\nStopped server")
finally:
if distributed:
await q.close()
cleanup_temp()
def entrypoint():
try:
asyncio.run(_start_comfyui())
except KeyboardInterrupt:
logger.info(f"Gracefully shutting down due to KeyboardInterrupt")
if __name__ == "__main__": if __name__ == "__main__":
# Running directly, just start ComfyUI. entrypoint()
logging.info("Python version: {}".format(sys.version))
logging.info("ComfyUI version: {}".format(comfyui_version.__version__))
event_loop, _, start_all_func = start_comfyui()
try:
x = start_all_func()
app.logger.print_startup_warnings()
event_loop.run_until_complete(x)
except KeyboardInterrupt:
logging.info("\nStopped server")
cleanup_temp()

View File

@ -7,10 +7,10 @@ from comfy.component_model.folder_path_types import FolderNames
if __name__ == "__main__": if __name__ == "__main__":
warnings.warn("main.py is deprecated. Start comfyui by installing the package through the instructions in the README, not by cloning the repository.", DeprecationWarning) warnings.warn("main.py is deprecated. Start comfyui by installing the package through the instructions in the README, not by cloning the repository.", DeprecationWarning)
this_file_parent_dir = Path(__file__).parent this_file_parent_dir = Path(__file__).parent
from comfy.cmd.main import main from comfy.cmd.main import _start_comfyui
from comfy.cmd.folder_paths import folder_names_and_paths # type: FolderNames from comfy.cmd.folder_paths import folder_names_and_paths # type: FolderNames
fn: FolderNames = folder_names_and_paths fn: FolderNames = folder_names_and_paths
fn.base_paths.clear() fn.base_paths.clear()
fn.base_paths.append(this_file_parent_dir) fn.base_paths.append(this_file_parent_dir)
asyncio.run(main(from_script_dir=this_file_parent_dir)) asyncio.run(_start_comfyui(from_script_dir=this_file_parent_dir))