ComfyUI/comfy/distributed/distributed_prompt_worker.py
doctorpangloss 80f8c40248 Distributed queueing with amqp-compatible servers like RabbitMQ.
- Binary previews are not yet supported
 - Use `--distributed-queue-connection-uri=amqp://guest:guest@rabbitmqserver/`
 - Roles supported: frontend, worker or both (see `--help`)
 - Run `comfy-worker` for a lightweight, easy-to-understand worker
 - Workers and frontends must have the same directory structure (set
   with `--cwd`) and supported nodes. Frontends must still have access
   to inputs and outputs.
 - Configuration notes:

   distributed_queue_connection_uri (Optional[str]): Servers and clients will connect to this AMQP URL to form a distributed queue and exchange prompt execution requests and progress updates.
   distributed_queue_roles (List[str]): Specifies one or more roles for the distributed queue. Acceptable values are "worker" and "frontend"; to take both roles, pass the flag twice, once with each role. Frontends will start the web UI and connect to the provided AMQP URL to submit prompts; workers will pull requests off the AMQP URL.
   distributed_queue_name (str): This name will be used by the frontends and workers to exchange prompt requests and replies. Progress updates will be prefixed by the queue name, followed by a '.', then the user ID.
2024-02-08 20:24:27 -08:00

100 lines
4.0 KiB
Python

import asyncio
import logging
import sys
from asyncio import AbstractEventLoop
from contextlib import AsyncExitStack
from dataclasses import asdict
from typing import Optional
from aio_pika import connect_robust
from aio_pika.patterns import JsonRPC
from aiormq import AMQPConnectionError
from .distributed_progress import DistributedExecutorToClientProgress
from .distributed_types import RpcRequest, RpcReply
from ..client.embedded_comfy_client import EmbeddedComfyClient
from ..component_model.queue_types import ExecutionStatus
class DistributedPromptWorker:
    """
    A work in progress distributed prompt worker.

    Connects to an AMQP server, registers :meth:`_do_work_item` as a JSON-RPC handler under
    ``queue_name`` and executes each received prompt with an :class:`EmbeddedComfyClient`,
    returning the serialized :class:`RpcReply`.

    Use it as an async context manager (``async with DistributedPromptWorker(...)``), or call
    :meth:`init` and :meth:`close` explicitly.
    """

    def __init__(self, embedded_comfy_client: Optional[EmbeddedComfyClient] = None,
                 connection_uri: str = "amqp://localhost:5672/",
                 queue_name: str = "comfyui",
                 loop: Optional[AbstractEventLoop] = None):
        """
        :param embedded_comfy_client: client used to execute prompts. When None, :meth:`init`
            creates one whose progress handler is given the RPC object and queue name.
        :param connection_uri: AMQP URL passed to ``connect_robust``.
        :param queue_name: name under which the work handler is registered with the RPC.
        :param loop: event loop passed to ``connect_robust``; defaults to ``asyncio.get_event_loop()``.
        """
        self._rpc = None
        self._channel = None
        # Initialised here so _close() is safe even when init() was never called or failed early.
        self._connection = None
        self._exit_stack = AsyncExitStack()
        self._queue_name = queue_name
        self._connection_uri = connection_uri
        self._loop = loop or asyncio.get_event_loop()
        self._embedded_comfy_client = embedded_comfy_client

    async def _do_work_item(self, request: dict) -> dict:
        """
        Execute a single RPC request and return the serialized reply.

        Failures while parsing the request or executing the prompt are not raised. They are
        returned as an :class:`RpcReply` whose :class:`ExecutionStatus` is ``"error"``, carrying
        the exception message.

        :param request: raw request payload, converted with ``RpcRequest.from_dict``
        :return: ``dataclasses.asdict`` of the reply
        """
        await self.on_will_complete_work_item(request)
        try:
            request_obj = RpcRequest.from_dict(request)
        except Exception as e:
            # Recover the prompt id when possible so the caller can correlate the error. Only
            # dicts are probed: a membership test on an arbitrary payload (e.g. an int) raises.
            request_dict_prompt_id_recovered = request.get("prompt_id", "") \
                if isinstance(request, dict) else ""
            return asdict(RpcReply(request_dict_prompt_id_recovered, "", {},
                                   ExecutionStatus("error", False, [str(e)])))

        reply: RpcReply
        try:
            output_dict = await self._embedded_comfy_client.queue_prompt(request_obj.prompt,
                                                                         request_obj.prompt_id,
                                                                         client_id=request_obj.user_id)
            reply = RpcReply(request_obj.prompt_id, request_obj.user_token, output_dict,
                             ExecutionStatus("success", True, []))
        except Exception as e:
            reply = RpcReply(request_obj.prompt_id, request_obj.user_token, {},
                             ExecutionStatus("error", False, [str(e)]))

        await self.on_did_complete_work_item(request_obj, reply)
        return asdict(reply)

    async def init(self):
        """
        Connect to the AMQP server, start the embedded client if needed and register the handler.

        If any step fails, everything opened so far is released before the exception propagates.
        This matters because ``__aexit__`` is never called when ``__aenter__`` raises.

        :raises AMQPConnectionError: if connecting to ``connection_uri`` fails
        """
        await self._exit_stack.__aenter__()
        try:
            await self._connect_and_register()
        except BaseException:
            await self.close()
            raise

    async def _connect_and_register(self):
        """Open the connection, channel and RPC; start the embedded client; register the handler."""
        try:
            self._connection = await connect_robust(self._connection_uri, loop=self._loop)
        except AMQPConnectionError:
            # Attach the exception via exc_info. Passing it as a positional argument with no
            # %-placeholder in the message makes the logging module fail to format the record.
            logging.error("failed to connect to self._connection_uri=%s", self._connection_uri, exc_info=True)
            raise
        self._channel = await self._connection.channel()
        self._rpc = await JsonRPC.create(channel=self._channel)
        # NOTE(review): aio_pika RPC option controlling how handler exceptions are treated on this
        # (host) side; confirm exact semantics against the aio_pika RPC pattern docs.
        self._rpc.host_exceptions = True
        if self._embedded_comfy_client is None:
            # The progress handler receives the RPC object and queue name, presumably so progress
            # updates can be published through the same AMQP connection.
            self._embedded_comfy_client = EmbeddedComfyClient(
                progress_handler=DistributedExecutorToClientProgress(self._rpc, self._queue_name, self._loop))
        if not self._embedded_comfy_client.is_running:
            # Only manage the lifetime of a client that is not already running.
            await self._exit_stack.enter_async_context(self._embedded_comfy_client)
        await self._rpc.register(self._queue_name, self._do_work_item)

    async def __aenter__(self) -> "DistributedPromptWorker":
        """Initialise the worker and return it."""
        await self.init()
        return self

    async def _close(self):
        """
        Close the RPC, channel and connection, in that order, skipping any that were never opened.

        Every opened resource gets its ``close()`` called even if an earlier one raises; the
        failure is re-raised afterwards. The attributes are cleared first, so calling this twice
        is a no-op.
        """
        connection, channel, rpc = self._connection, self._channel, self._rpc
        self._rpc = self._channel = self._connection = None
        async with AsyncExitStack() as closers:
            # Callbacks run LIFO: pushing connection first means it is closed last.
            for resource in (connection, channel, rpc):
                if resource is not None:
                    closers.push_async_callback(resource.close)

    async def close(self):
        """Close the AMQP resources, then unwind the exit stack (which stops a client it started)."""
        try:
            await self._close()
        finally:
            await self._exit_stack.aclose()

    async def __aexit__(self, *args):
        """
        Close the AMQP resources, then unwind the exit stack with the given exception info.

        The exit stack is unwound even when closing the AMQP resources raises.

        :return: the exit stack's result (truthy suppresses the exception from the ``async with`` body)
        """
        try:
            await self._close()
        finally:
            suppress = await self._exit_stack.__aexit__(*args)
        return suppress

    async def on_did_complete_work_item(self, request: RpcRequest, reply: RpcReply):
        """
        Hook called after a successfully parsed request has been executed (successfully or not).

        Not called when the request fails to parse. The default implementation does nothing.
        """
        pass

    async def on_will_complete_work_item(self, request: dict):
        """Hook called with the raw request payload before it is parsed. The default does nothing."""
        pass