Fix entrypoints, add comfyui-worker entrypoint

This commit is contained in:
doctorpangloss 2024-02-08 19:08:42 -08:00
parent 72e92514a4
commit 0673262940
7 changed files with 80 additions and 14 deletions

View File

@ -11,7 +11,9 @@ from ..api.components.schema.prompt import Prompt
_event_tracker: MultiEventTracker
def initialize_event_tracking(loop: asyncio.AbstractEventLoop):
def initialize_event_tracking(loop: Optional[asyncio.AbstractEventLoop] = None):
loop = loop or asyncio.get_event_loop()
assert loop is not None
_event_trackers = []
# perform the imports at the time this is invoked to prevent side effects and ordering issues
from ..cli_args import args

View File

@ -120,7 +120,7 @@ parser.add_argument("--plausible-analytics-domain", required=False,
parser.add_argument("--analytics-use-identity-provider", action="store_true",
help="Uses platform identifiers for unique visitor analytics.")
parser.add_argument("--distributed-queue-connection-uri", type=str, default=None,
help="Servers and clients will connect to this AMPQ URL to form a distributed queue and exchange prompt execution requests and progress updates.")
help="EXAMPLE: \"amqp://guest:guest@127.0.0.1\" - Servers and clients will connect to this AMPQ URL to form a distributed queue and exchange prompt execution requests and progress updates.")
parser.add_argument(
'--distributed-queue-roles',
action='append',

View File

@ -235,7 +235,6 @@ async def main():
queue_name=args.distributed_queue_name
)
await q.init()
loop.add_signal_handler(signal.SIGINT, lambda *args, **kwargs: q.close())
else:
distributed = False
q = execution.PromptQueue(server)
@ -260,7 +259,8 @@ async def main():
# the distributed prompt queue will be responsible for simulating those events until the broker is configured to
# pass those messages to the appropriate user
worker_thread_server = server if not distributed else ServerStub()
threading.Thread(target=prompt_worker, daemon=True, args=(q, worker_thread_server,)).start()
if "worker" in args.distributed_queue_roles:
threading.Thread(target=prompt_worker, daemon=True, args=(q, worker_thread_server,)).start()
# server has been imported and things should be looking good
initialize_event_tracking(loop)
@ -305,5 +305,9 @@ async def main():
cleanup_temp()
if __name__ == "__main__":
def entrypoint():
asyncio.run(main())
if __name__ == "__main__":
entrypoint()

View File

@ -782,7 +782,7 @@ class PromptServer(ExecutorToClientProgress):
if verbose:
print("Starting server\n")
print("To see the GUI go to: http://{}:{}".format(address, port))
print("To see the GUI go to: http://{}:{}".format("localhost" if address == "0.0.0.0" else address, port))
if call_on_start is not None:
call_on_start(address, port)

30
comfy/cmd/worker.py Normal file
View File

@ -0,0 +1,30 @@
import asyncio
from .. import options
from ..distributed.distributed_prompt_worker import DistributedPromptWorker
options.enable_args_parsing()
from ..cli_args import args
async def main():
# assume we are a worker
args.distributed_queue_roles = ["worker"]
assert args.distributed_queue_connection_uri is not None, "Set the --distributed-queue-connection-uri argument to your RabbitMQ server"
async with DistributedPromptWorker(connection_uri=args.distributed_queue_connection_uri,
queue_name=args.distributed_queue_name):
stop = asyncio.Event()
try:
await stop.wait()
except asyncio.CancelledError:
pass
def entrypoint():
asyncio.run(main())
if __name__ == "__main__":
entrypoint()

View File

@ -1,4 +1,6 @@
import asyncio
import logging
import sys
from asyncio import AbstractEventLoop
from contextlib import AsyncExitStack
from dataclasses import asdict
@ -6,6 +8,7 @@ from typing import Optional
from aio_pika import connect_robust
from aio_pika.patterns import JsonRPC
from aiormq import AMQPConnectionError
from .distributed_types import RpcRequest, RpcReply
from ..client.embedded_comfy_client import EmbeddedComfyClient
@ -28,36 +31,63 @@ class DistributedPromptWorker:
self._embedded_comfy_client = embedded_comfy_client or EmbeddedComfyClient()
async def _do_work_item(self, request: dict) -> dict:
await self.on_will_complete_work_item(request)
try:
request_obj = RpcRequest.from_dict(request)
except Exception as e:
request_dict_prompt_id_recovered = request["prompt_id"] \
if request is not None and "prompt_id" in request else ""
return asdict(RpcReply(request_dict_prompt_id_recovered, "", {},
ExecutionStatus("error", False, [str(e)])))
ExecutionStatus("error", False, [str(e)])))
reply: RpcReply
try:
output_dict = await self._embedded_comfy_client.queue_prompt(request_obj.prompt,
request_obj.prompt_id,
client_id=request_obj.user_id)
return asdict(RpcReply(request_obj.prompt_id, request_obj.user_token, output_dict, ExecutionStatus("success", True, [])))
reply = RpcReply(request_obj.prompt_id, request_obj.user_token, output_dict,
ExecutionStatus("success", True, []))
except Exception as e:
return asdict(RpcReply(request_obj.prompt_id, request_obj.user_token, {}, ExecutionStatus("error", False, [str(e)])))
reply = RpcReply(request_obj.prompt_id, request_obj.user_token, {},
ExecutionStatus("error", False, [str(e)]))
async def __aenter__(self) -> "DistributedPromptWorker":
await self.on_did_complete_work_item(request_obj, reply)
return asdict(reply)
async def init(self):
await self._exit_stack.__aenter__()
if not self._embedded_comfy_client.is_running:
await self._exit_stack.enter_async_context(self._embedded_comfy_client)
self._connection = await connect_robust(self._connection_uri, loop=self._loop)
try:
self._connection = await connect_robust(self._connection_uri, loop=self._loop)
except AMQPConnectionError as connection_error:
logging.error(f"failed to connect to self._connection_uri={self._connection_uri}", connection_error)
raise connection_error
self._channel = await self._connection.channel()
self._rpc = await JsonRPC.create(channel=self._channel)
self._rpc.host_exceptions = True
await self._rpc.register(self._queue_name, self._do_work_item)
async def __aenter__(self) -> "DistributedPromptWorker":
await self.init()
return self
async def __aexit__(self, *args):
async def _close(self):
await self._rpc.close()
await self._channel.close()
await self._connection.close()
async def close(self):
await self._close()
await self._exit_stack.aclose()
async def __aexit__(self, *args):
await self._close()
return await self._exit_stack.__aexit__(*args)
async def on_did_complete_work_item(self, request: RpcRequest, reply: RpcReply):
pass
async def on_will_complete_work_item(self, request: dict):
pass

View File

@ -57,7 +57,6 @@ try:
print(f"comfyui setup.py: torch version was {torch.__version__} and built without build isolation, using this torch instead of upgrading", file=sys.stderr)
is_build_isolated_and_torch_version = torch.__version__
except Exception as e:
print(f"comfyui setup.py: torch could not be imported because running with build isolation or not installed ({e}), installing torch for your platform", file=sys.stderr)
is_build_isolated_and_torch_version = None
def _is_nvidia() -> bool:
@ -186,7 +185,8 @@ setup(
entry_points={
'console_scripts': [
'comfyui-openapi-gen = comfy.cmd.openapi_gen:main',
'comfyui = comfy.cmd.main:main'
'comfyui = comfy.cmd.main:entrypoint',
'comfyui-worker = comfy.cmd.worker:entrypoint'
],
},
package_data={