Fix entrypoints, add comfyui-worker entrypoint

This commit is contained in:
doctorpangloss 2024-02-08 19:08:42 -08:00
parent 72e92514a4
commit 0673262940
7 changed files with 80 additions and 14 deletions

View File

@ -11,7 +11,9 @@ from ..api.components.schema.prompt import Prompt
_event_tracker: MultiEventTracker _event_tracker: MultiEventTracker
def initialize_event_tracking(loop: asyncio.AbstractEventLoop): def initialize_event_tracking(loop: Optional[asyncio.AbstractEventLoop] = None):
loop = loop or asyncio.get_event_loop()
assert loop is not None
_event_trackers = [] _event_trackers = []
# perform the imports at the time this is invoked to prevent side effects and ordering issues # perform the imports at the time this is invoked to prevent side effects and ordering issues
from ..cli_args import args from ..cli_args import args

View File

@ -120,7 +120,7 @@ parser.add_argument("--plausible-analytics-domain", required=False,
parser.add_argument("--analytics-use-identity-provider", action="store_true", parser.add_argument("--analytics-use-identity-provider", action="store_true",
help="Uses platform identifiers for unique visitor analytics.") help="Uses platform identifiers for unique visitor analytics.")
parser.add_argument("--distributed-queue-connection-uri", type=str, default=None, parser.add_argument("--distributed-queue-connection-uri", type=str, default=None,
help="Servers and clients will connect to this AMPQ URL to form a distributed queue and exchange prompt execution requests and progress updates.") help="EXAMPLE: \"amqp://guest:guest@127.0.0.1\" - Servers and clients will connect to this AMPQ URL to form a distributed queue and exchange prompt execution requests and progress updates.")
parser.add_argument( parser.add_argument(
'--distributed-queue-roles', '--distributed-queue-roles',
action='append', action='append',

View File

@ -235,7 +235,6 @@ async def main():
queue_name=args.distributed_queue_name queue_name=args.distributed_queue_name
) )
await q.init() await q.init()
loop.add_signal_handler(signal.SIGINT, lambda *args, **kwargs: q.close())
else: else:
distributed = False distributed = False
q = execution.PromptQueue(server) q = execution.PromptQueue(server)
@ -260,7 +259,8 @@ async def main():
# the distributed prompt queue will be responsible for simulating those events until the broker is configured to # the distributed prompt queue will be responsible for simulating those events until the broker is configured to
# pass those messages to the appropriate user # pass those messages to the appropriate user
worker_thread_server = server if not distributed else ServerStub() worker_thread_server = server if not distributed else ServerStub()
threading.Thread(target=prompt_worker, daemon=True, args=(q, worker_thread_server,)).start() if "worker" in args.distributed_queue_roles:
threading.Thread(target=prompt_worker, daemon=True, args=(q, worker_thread_server,)).start()
# server has been imported and things should be looking good # server has been imported and things should be looking good
initialize_event_tracking(loop) initialize_event_tracking(loop)
@ -305,5 +305,9 @@ async def main():
cleanup_temp() cleanup_temp()
if __name__ == "__main__": def entrypoint():
asyncio.run(main()) asyncio.run(main())
if __name__ == "__main__":
entrypoint()

View File

@ -782,7 +782,7 @@ class PromptServer(ExecutorToClientProgress):
if verbose: if verbose:
print("Starting server\n") print("Starting server\n")
print("To see the GUI go to: http://{}:{}".format(address, port)) print("To see the GUI go to: http://{}:{}".format("localhost" if address == "0.0.0.0" else address, port))
if call_on_start is not None: if call_on_start is not None:
call_on_start(address, port) call_on_start(address, port)

30
comfy/cmd/worker.py Normal file
View File

@ -0,0 +1,30 @@
import asyncio
from .. import options
from ..distributed.distributed_prompt_worker import DistributedPromptWorker
options.enable_args_parsing()
from ..cli_args import args
async def main():
# assume we are a worker
args.distributed_queue_roles = ["worker"]
assert args.distributed_queue_connection_uri is not None, "Set the --distributed-queue-connection-uri argument to your RabbitMQ server"
async with DistributedPromptWorker(connection_uri=args.distributed_queue_connection_uri,
queue_name=args.distributed_queue_name):
stop = asyncio.Event()
try:
await stop.wait()
except asyncio.CancelledError:
pass
def entrypoint():
asyncio.run(main())
if __name__ == "__main__":
entrypoint()

View File

@ -1,4 +1,6 @@
import asyncio import asyncio
import logging
import sys
from asyncio import AbstractEventLoop from asyncio import AbstractEventLoop
from contextlib import AsyncExitStack from contextlib import AsyncExitStack
from dataclasses import asdict from dataclasses import asdict
@ -6,6 +8,7 @@ from typing import Optional
from aio_pika import connect_robust from aio_pika import connect_robust
from aio_pika.patterns import JsonRPC from aio_pika.patterns import JsonRPC
from aiormq import AMQPConnectionError
from .distributed_types import RpcRequest, RpcReply from .distributed_types import RpcRequest, RpcReply
from ..client.embedded_comfy_client import EmbeddedComfyClient from ..client.embedded_comfy_client import EmbeddedComfyClient
@ -28,36 +31,63 @@ class DistributedPromptWorker:
self._embedded_comfy_client = embedded_comfy_client or EmbeddedComfyClient() self._embedded_comfy_client = embedded_comfy_client or EmbeddedComfyClient()
async def _do_work_item(self, request: dict) -> dict: async def _do_work_item(self, request: dict) -> dict:
await self.on_will_complete_work_item(request)
try: try:
request_obj = RpcRequest.from_dict(request) request_obj = RpcRequest.from_dict(request)
except Exception as e: except Exception as e:
request_dict_prompt_id_recovered = request["prompt_id"] \ request_dict_prompt_id_recovered = request["prompt_id"] \
if request is not None and "prompt_id" in request else "" if request is not None and "prompt_id" in request else ""
return asdict(RpcReply(request_dict_prompt_id_recovered, "", {}, return asdict(RpcReply(request_dict_prompt_id_recovered, "", {},
ExecutionStatus("error", False, [str(e)]))) ExecutionStatus("error", False, [str(e)])))
reply: RpcReply
try: try:
output_dict = await self._embedded_comfy_client.queue_prompt(request_obj.prompt, output_dict = await self._embedded_comfy_client.queue_prompt(request_obj.prompt,
request_obj.prompt_id, request_obj.prompt_id,
client_id=request_obj.user_id) client_id=request_obj.user_id)
return asdict(RpcReply(request_obj.prompt_id, request_obj.user_token, output_dict, ExecutionStatus("success", True, []))) reply = RpcReply(request_obj.prompt_id, request_obj.user_token, output_dict,
ExecutionStatus("success", True, []))
except Exception as e: except Exception as e:
return asdict(RpcReply(request_obj.prompt_id, request_obj.user_token, {}, ExecutionStatus("error", False, [str(e)]))) reply = RpcReply(request_obj.prompt_id, request_obj.user_token, {},
ExecutionStatus("error", False, [str(e)]))
async def __aenter__(self) -> "DistributedPromptWorker": await self.on_did_complete_work_item(request_obj, reply)
return asdict(reply)
async def init(self):
await self._exit_stack.__aenter__() await self._exit_stack.__aenter__()
if not self._embedded_comfy_client.is_running: if not self._embedded_comfy_client.is_running:
await self._exit_stack.enter_async_context(self._embedded_comfy_client) await self._exit_stack.enter_async_context(self._embedded_comfy_client)
self._connection = await connect_robust(self._connection_uri, loop=self._loop) try:
self._connection = await connect_robust(self._connection_uri, loop=self._loop)
except AMQPConnectionError as connection_error:
logging.error(f"failed to connect to self._connection_uri={self._connection_uri}", connection_error)
raise connection_error
self._channel = await self._connection.channel() self._channel = await self._connection.channel()
self._rpc = await JsonRPC.create(channel=self._channel) self._rpc = await JsonRPC.create(channel=self._channel)
self._rpc.host_exceptions = True self._rpc.host_exceptions = True
await self._rpc.register(self._queue_name, self._do_work_item) await self._rpc.register(self._queue_name, self._do_work_item)
async def __aenter__(self) -> "DistributedPromptWorker":
await self.init()
return self return self
async def __aexit__(self, *args): async def _close(self):
await self._rpc.close() await self._rpc.close()
await self._channel.close() await self._channel.close()
await self._connection.close() await self._connection.close()
async def close(self):
await self._close()
await self._exit_stack.aclose()
async def __aexit__(self, *args):
await self._close()
return await self._exit_stack.__aexit__(*args) return await self._exit_stack.__aexit__(*args)
async def on_did_complete_work_item(self, request: RpcRequest, reply: RpcReply):
pass
async def on_will_complete_work_item(self, request: dict):
pass

View File

@ -57,7 +57,6 @@ try:
print(f"comfyui setup.py: torch version was {torch.__version__} and built without build isolation, using this torch instead of upgrading", file=sys.stderr) print(f"comfyui setup.py: torch version was {torch.__version__} and built without build isolation, using this torch instead of upgrading", file=sys.stderr)
is_build_isolated_and_torch_version = torch.__version__ is_build_isolated_and_torch_version = torch.__version__
except Exception as e: except Exception as e:
print(f"comfyui setup.py: torch could not be imported because running with build isolation or not installed ({e}), installing torch for your platform", file=sys.stderr)
is_build_isolated_and_torch_version = None is_build_isolated_and_torch_version = None
def _is_nvidia() -> bool: def _is_nvidia() -> bool:
@ -186,7 +185,8 @@ setup(
entry_points={ entry_points={
'console_scripts': [ 'console_scripts': [
'comfyui-openapi-gen = comfy.cmd.openapi_gen:main', 'comfyui-openapi-gen = comfy.cmd.openapi_gen:main',
'comfyui = comfy.cmd.main:main' 'comfyui = comfy.cmd.main:entrypoint',
'comfyui-worker = comfy.cmd.worker:entrypoint'
], ],
}, },
package_data={ package_data={