mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 06:10:50 +08:00
Fix entrypoints, add comfyui-worker entrypoint
This commit is contained in:
parent
72e92514a4
commit
0673262940
@ -11,7 +11,9 @@ from ..api.components.schema.prompt import Prompt
|
||||
_event_tracker: MultiEventTracker
|
||||
|
||||
|
||||
def initialize_event_tracking(loop: asyncio.AbstractEventLoop):
|
||||
def initialize_event_tracking(loop: Optional[asyncio.AbstractEventLoop] = None):
|
||||
loop = loop or asyncio.get_event_loop()
|
||||
assert loop is not None
|
||||
_event_trackers = []
|
||||
# perform the imports at the time this is invoked to prevent side effects and ordering issues
|
||||
from ..cli_args import args
|
||||
|
||||
@ -120,7 +120,7 @@ parser.add_argument("--plausible-analytics-domain", required=False,
|
||||
parser.add_argument("--analytics-use-identity-provider", action="store_true",
|
||||
help="Uses platform identifiers for unique visitor analytics.")
|
||||
parser.add_argument("--distributed-queue-connection-uri", type=str, default=None,
|
||||
help="Servers and clients will connect to this AMPQ URL to form a distributed queue and exchange prompt execution requests and progress updates.")
|
||||
help="EXAMPLE: \"amqp://guest:guest@127.0.0.1\" - Servers and clients will connect to this AMPQ URL to form a distributed queue and exchange prompt execution requests and progress updates.")
|
||||
parser.add_argument(
|
||||
'--distributed-queue-roles',
|
||||
action='append',
|
||||
|
||||
@ -235,7 +235,6 @@ async def main():
|
||||
queue_name=args.distributed_queue_name
|
||||
)
|
||||
await q.init()
|
||||
loop.add_signal_handler(signal.SIGINT, lambda *args, **kwargs: q.close())
|
||||
else:
|
||||
distributed = False
|
||||
q = execution.PromptQueue(server)
|
||||
@ -260,7 +259,8 @@ async def main():
|
||||
# the distributed prompt queue will be responsible for simulating those events until the broker is configured to
|
||||
# pass those messages to the appropriate user
|
||||
worker_thread_server = server if not distributed else ServerStub()
|
||||
threading.Thread(target=prompt_worker, daemon=True, args=(q, worker_thread_server,)).start()
|
||||
if "worker" in args.distributed_queue_roles:
|
||||
threading.Thread(target=prompt_worker, daemon=True, args=(q, worker_thread_server,)).start()
|
||||
|
||||
# server has been imported and things should be looking good
|
||||
initialize_event_tracking(loop)
|
||||
@ -305,5 +305,9 @@ async def main():
|
||||
cleanup_temp()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def entrypoint():
|
||||
asyncio.run(main())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
entrypoint()
|
||||
|
||||
@ -782,7 +782,7 @@ class PromptServer(ExecutorToClientProgress):
|
||||
|
||||
if verbose:
|
||||
print("Starting server\n")
|
||||
print("To see the GUI go to: http://{}:{}".format(address, port))
|
||||
print("To see the GUI go to: http://{}:{}".format("localhost" if address == "0.0.0.0" else address, port))
|
||||
if call_on_start is not None:
|
||||
call_on_start(address, port)
|
||||
|
||||
|
||||
30
comfy/cmd/worker.py
Normal file
30
comfy/cmd/worker.py
Normal file
@ -0,0 +1,30 @@
|
||||
import asyncio
|
||||
|
||||
from .. import options
|
||||
from ..distributed.distributed_prompt_worker import DistributedPromptWorker
|
||||
|
||||
options.enable_args_parsing()
|
||||
|
||||
from ..cli_args import args
|
||||
|
||||
|
||||
async def main():
|
||||
# assume we are a worker
|
||||
args.distributed_queue_roles = ["worker"]
|
||||
assert args.distributed_queue_connection_uri is not None, "Set the --distributed-queue-connection-uri argument to your RabbitMQ server"
|
||||
|
||||
async with DistributedPromptWorker(connection_uri=args.distributed_queue_connection_uri,
|
||||
queue_name=args.distributed_queue_name):
|
||||
stop = asyncio.Event()
|
||||
try:
|
||||
await stop.wait()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
|
||||
def entrypoint():
|
||||
asyncio.run(main())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
entrypoint()
|
||||
@ -1,4 +1,6 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
from asyncio import AbstractEventLoop
|
||||
from contextlib import AsyncExitStack
|
||||
from dataclasses import asdict
|
||||
@ -6,6 +8,7 @@ from typing import Optional
|
||||
|
||||
from aio_pika import connect_robust
|
||||
from aio_pika.patterns import JsonRPC
|
||||
from aiormq import AMQPConnectionError
|
||||
|
||||
from .distributed_types import RpcRequest, RpcReply
|
||||
from ..client.embedded_comfy_client import EmbeddedComfyClient
|
||||
@ -28,36 +31,63 @@ class DistributedPromptWorker:
|
||||
self._embedded_comfy_client = embedded_comfy_client or EmbeddedComfyClient()
|
||||
|
||||
async def _do_work_item(self, request: dict) -> dict:
|
||||
await self.on_will_complete_work_item(request)
|
||||
try:
|
||||
request_obj = RpcRequest.from_dict(request)
|
||||
except Exception as e:
|
||||
request_dict_prompt_id_recovered = request["prompt_id"] \
|
||||
if request is not None and "prompt_id" in request else ""
|
||||
return asdict(RpcReply(request_dict_prompt_id_recovered, "", {},
|
||||
ExecutionStatus("error", False, [str(e)])))
|
||||
ExecutionStatus("error", False, [str(e)])))
|
||||
reply: RpcReply
|
||||
try:
|
||||
output_dict = await self._embedded_comfy_client.queue_prompt(request_obj.prompt,
|
||||
request_obj.prompt_id,
|
||||
client_id=request_obj.user_id)
|
||||
return asdict(RpcReply(request_obj.prompt_id, request_obj.user_token, output_dict, ExecutionStatus("success", True, [])))
|
||||
reply = RpcReply(request_obj.prompt_id, request_obj.user_token, output_dict,
|
||||
ExecutionStatus("success", True, []))
|
||||
except Exception as e:
|
||||
return asdict(RpcReply(request_obj.prompt_id, request_obj.user_token, {}, ExecutionStatus("error", False, [str(e)])))
|
||||
reply = RpcReply(request_obj.prompt_id, request_obj.user_token, {},
|
||||
ExecutionStatus("error", False, [str(e)]))
|
||||
|
||||
async def __aenter__(self) -> "DistributedPromptWorker":
|
||||
await self.on_did_complete_work_item(request_obj, reply)
|
||||
return asdict(reply)
|
||||
|
||||
async def init(self):
|
||||
await self._exit_stack.__aenter__()
|
||||
if not self._embedded_comfy_client.is_running:
|
||||
await self._exit_stack.enter_async_context(self._embedded_comfy_client)
|
||||
|
||||
self._connection = await connect_robust(self._connection_uri, loop=self._loop)
|
||||
try:
|
||||
self._connection = await connect_robust(self._connection_uri, loop=self._loop)
|
||||
except AMQPConnectionError as connection_error:
|
||||
logging.error(f"failed to connect to self._connection_uri={self._connection_uri}", connection_error)
|
||||
raise connection_error
|
||||
self._channel = await self._connection.channel()
|
||||
self._rpc = await JsonRPC.create(channel=self._channel)
|
||||
self._rpc.host_exceptions = True
|
||||
|
||||
await self._rpc.register(self._queue_name, self._do_work_item)
|
||||
|
||||
async def __aenter__(self) -> "DistributedPromptWorker":
|
||||
await self.init()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
async def _close(self):
|
||||
await self._rpc.close()
|
||||
await self._channel.close()
|
||||
await self._connection.close()
|
||||
|
||||
async def close(self):
|
||||
await self._close()
|
||||
await self._exit_stack.aclose()
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
await self._close()
|
||||
return await self._exit_stack.__aexit__(*args)
|
||||
|
||||
async def on_did_complete_work_item(self, request: RpcRequest, reply: RpcReply):
|
||||
pass
|
||||
|
||||
async def on_will_complete_work_item(self, request: dict):
|
||||
pass
|
||||
|
||||
4
setup.py
4
setup.py
@ -57,7 +57,6 @@ try:
|
||||
print(f"comfyui setup.py: torch version was {torch.__version__} and built without build isolation, using this torch instead of upgrading", file=sys.stderr)
|
||||
is_build_isolated_and_torch_version = torch.__version__
|
||||
except Exception as e:
|
||||
print(f"comfyui setup.py: torch could not be imported because running with build isolation or not installed ({e}), installing torch for your platform", file=sys.stderr)
|
||||
is_build_isolated_and_torch_version = None
|
||||
|
||||
def _is_nvidia() -> bool:
|
||||
@ -186,7 +185,8 @@ setup(
|
||||
entry_points={
|
||||
'console_scripts': [
|
||||
'comfyui-openapi-gen = comfy.cmd.openapi_gen:main',
|
||||
'comfyui = comfy.cmd.main:main'
|
||||
'comfyui = comfy.cmd.main:entrypoint',
|
||||
'comfyui-worker = comfy.cmd.worker:entrypoint'
|
||||
],
|
||||
},
|
||||
package_data={
|
||||
|
||||
Loading…
Reference in New Issue
Block a user