From 0673262940518b5f74a016aa991e1ad4fb5e866f Mon Sep 17 00:00:00 2001 From: doctorpangloss <@hiddenswitch.com> Date: Thu, 8 Feb 2024 19:08:42 -0800 Subject: [PATCH] Fix entrypoints, add comfyui-worker entrypoint --- comfy/analytics/analytics.py | 4 +- comfy/cli_args.py | 2 +- comfy/cmd/main.py | 10 +++-- comfy/cmd/server.py | 2 +- comfy/cmd/worker.py | 30 +++++++++++++ .../distributed/distributed_prompt_worker.py | 42 ++++++++++++++++--- setup.py | 4 +- 7 files changed, 80 insertions(+), 14 deletions(-) create mode 100644 comfy/cmd/worker.py diff --git a/comfy/analytics/analytics.py b/comfy/analytics/analytics.py index d0b8bfd15..a5256d4f3 100644 --- a/comfy/analytics/analytics.py +++ b/comfy/analytics/analytics.py @@ -11,7 +11,9 @@ from ..api.components.schema.prompt import Prompt _event_tracker: MultiEventTracker -def initialize_event_tracking(loop: asyncio.AbstractEventLoop): +def initialize_event_tracking(loop: Optional[asyncio.AbstractEventLoop] = None): + loop = loop or asyncio.get_event_loop() + assert loop is not None _event_trackers = [] # perform the imports at the time this is invoked to prevent side effects and ordering issues from ..cli_args import args diff --git a/comfy/cli_args.py b/comfy/cli_args.py index ebbbd4466..4fbc3c4b4 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -120,7 +120,7 @@ parser.add_argument("--plausible-analytics-domain", required=False, parser.add_argument("--analytics-use-identity-provider", action="store_true", help="Uses platform identifiers for unique visitor analytics.") parser.add_argument("--distributed-queue-connection-uri", type=str, default=None, - help="Servers and clients will connect to this AMPQ URL to form a distributed queue and exchange prompt execution requests and progress updates.") + help="EXAMPLE: \"amqp://guest:guest@127.0.0.1\" - Servers and clients will connect to this AMPQ URL to form a distributed queue and exchange prompt execution requests and progress updates.") parser.add_argument( '--distributed-queue-roles', action='append', diff --git a/comfy/cmd/main.py b/comfy/cmd/main.py index c546636cf..cba09c61c 100644 --- a/comfy/cmd/main.py +++ b/comfy/cmd/main.py @@ -235,7 +235,6 @@ async def main(): queue_name=args.distributed_queue_name ) await q.init() - loop.add_signal_handler(signal.SIGINT, lambda *args, **kwargs: q.close()) else: distributed = False q = execution.PromptQueue(server) @@ -260,7 +259,8 @@ async def main(): # the distributed prompt queue will be responsible for simulating those events until the broker is configured to # pass those messages to the appropriate user worker_thread_server = server if not distributed else ServerStub() - threading.Thread(target=prompt_worker, daemon=True, args=(q, worker_thread_server,)).start() + if "worker" in args.distributed_queue_roles: + threading.Thread(target=prompt_worker, daemon=True, args=(q, worker_thread_server,)).start() # server has been imported and things should be looking good initialize_event_tracking(loop) @@ -305,5 +305,9 @@ async def main(): cleanup_temp() -if __name__ == "__main__": +def entrypoint(): asyncio.run(main()) + + +if __name__ == "__main__": + entrypoint() diff --git a/comfy/cmd/server.py b/comfy/cmd/server.py index c709baf2a..8310e26b5 100644 --- a/comfy/cmd/server.py +++ b/comfy/cmd/server.py @@ -782,7 +782,7 @@ class PromptServer(ExecutorToClientProgress): if verbose: print("Starting server\n") - print("To see the GUI go to: http://{}:{}".format(address, port)) + print("To see the GUI go to: http://{}:{}".format("localhost" if address == "0.0.0.0" else address, port)) if call_on_start is not None: call_on_start(address, port) diff --git a/comfy/cmd/worker.py b/comfy/cmd/worker.py new file mode 100644 index 000000000..a3e155e66 --- /dev/null +++ b/comfy/cmd/worker.py @@ -0,0 +1,30 @@ +import asyncio + +from .. import options +from ..distributed.distributed_prompt_worker import DistributedPromptWorker + +options.enable_args_parsing() + +from ..cli_args import args + + +async def main(): + # assume we are a worker + args.distributed_queue_roles = ["worker"] + assert args.distributed_queue_connection_uri is not None, "Set the --distributed-queue-connection-uri argument to your RabbitMQ server" + + async with DistributedPromptWorker(connection_uri=args.distributed_queue_connection_uri, + queue_name=args.distributed_queue_name): + stop = asyncio.Event() + try: + await stop.wait() + except asyncio.CancelledError: + pass + + +def entrypoint(): + asyncio.run(main()) + + +if __name__ == "__main__": + entrypoint() diff --git a/comfy/distributed/distributed_prompt_worker.py b/comfy/distributed/distributed_prompt_worker.py index 2fbbdf9de..e2b335f59 100644 --- a/comfy/distributed/distributed_prompt_worker.py +++ b/comfy/distributed/distributed_prompt_worker.py @@ -1,4 +1,6 @@ import asyncio +import logging +import sys from asyncio import AbstractEventLoop from contextlib import AsyncExitStack from dataclasses import asdict @@ -6,6 +8,7 @@ from typing import Optional from aio_pika import connect_robust from aio_pika.patterns import JsonRPC +from aiormq import AMQPConnectionError from .distributed_types import RpcRequest, RpcReply from ..client.embedded_comfy_client import EmbeddedComfyClient @@ -28,36 +31,63 @@ class DistributedPromptWorker: self._embedded_comfy_client = embedded_comfy_client or EmbeddedComfyClient() async def _do_work_item(self, request: dict) -> dict: + await self.on_will_complete_work_item(request) try: request_obj = RpcRequest.from_dict(request) except Exception as e: request_dict_prompt_id_recovered = request["prompt_id"] \ if request is not None and "prompt_id" in request else "" return asdict(RpcReply(request_dict_prompt_id_recovered, "", {}, - ExecutionStatus("error", False, [str(e)]))) + ExecutionStatus("error", False, [str(e)]))) + reply: RpcReply try: output_dict = await self._embedded_comfy_client.queue_prompt(request_obj.prompt, request_obj.prompt_id, client_id=request_obj.user_id) - return asdict(RpcReply(request_obj.prompt_id, request_obj.user_token, output_dict, ExecutionStatus("success", True, []))) + reply = RpcReply(request_obj.prompt_id, request_obj.user_token, output_dict, + ExecutionStatus("success", True, [])) except Exception as e: - return asdict(RpcReply(request_obj.prompt_id, request_obj.user_token, {}, ExecutionStatus("error", False, [str(e)]))) + reply = RpcReply(request_obj.prompt_id, request_obj.user_token, {}, + ExecutionStatus("error", False, [str(e)])) - async def __aenter__(self) -> "DistributedPromptWorker": + await self.on_did_complete_work_item(request_obj, reply) + return asdict(reply) + + async def init(self): await self._exit_stack.__aenter__() if not self._embedded_comfy_client.is_running: await self._exit_stack.enter_async_context(self._embedded_comfy_client) - self._connection = await connect_robust(self._connection_uri, loop=self._loop) + try: + self._connection = await connect_robust(self._connection_uri, loop=self._loop) + except AMQPConnectionError as connection_error: + logging.error(f"failed to connect to self._connection_uri={self._connection_uri}", connection_error) + raise connection_error self._channel = await self._connection.channel() self._rpc = await JsonRPC.create(channel=self._channel) self._rpc.host_exceptions = True await self._rpc.register(self._queue_name, self._do_work_item) + + async def __aenter__(self) -> "DistributedPromptWorker": + await self.init() return self - async def __aexit__(self, *args): + async def _close(self): await self._rpc.close() await self._channel.close() await self._connection.close() + + async def close(self): + await self._close() + await self._exit_stack.aclose() + + async def __aexit__(self, *args): + await self._close() return await self._exit_stack.__aexit__(*args) + + async def on_did_complete_work_item(self, request: RpcRequest, reply: RpcReply): + pass + + async def on_will_complete_work_item(self, request: dict): + pass diff --git a/setup.py b/setup.py index 00b439c69..c37ee6cae 100644 --- a/setup.py +++ b/setup.py @@ -57,7 +57,6 @@ try: print(f"comfyui setup.py: torch version was {torch.__version__} and built without build isolation, using this torch instead of upgrading", file=sys.stderr) is_build_isolated_and_torch_version = torch.__version__ except Exception as e: - print(f"comfyui setup.py: torch could not be imported because running with build isolation or not installed ({e}), installing torch for your platform", file=sys.stderr) is_build_isolated_and_torch_version = None def _is_nvidia() -> bool: @@ -186,7 +185,8 @@ setup( entry_points={ 'console_scripts': [ 'comfyui-openapi-gen = comfy.cmd.openapi_gen:main', - 'comfyui = comfy.cmd.main:main' + 'comfyui = comfy.cmd.main:entrypoint', + 'comfyui-worker = comfy.cmd.worker:entrypoint' ], }, package_data={