mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 06:40:48 +08:00
- Binary previews are not yet supported.
- Use `--distributed-queue-connection-uri=amqp://guest:guest@rabbitmqserver/`.
- Roles supported: frontend, worker, or both (see `--help`).
- Run `comfy-worker` for a lightweight worker you can wrap your head around.
- Workers and frontends must have the same directory structure (set with `--cwd`) and the same supported nodes. Frontends must still have access to inputs and outputs.
- Configuration notes:
  - distributed_queue_connection_uri (Optional[str]): Servers and clients will connect to this AMQP URL to form a distributed queue and exchange prompt execution requests and progress updates.
  - distributed_queue_roles (List[str]): Specifies one or more roles for the distributed queue. Acceptable values are "worker" or "frontend", or both by writing the flag twice with each role. Frontends will start the web UI and connect to the provided AMQP URL to submit prompts; workers will pull requests off the AMQP URL.
  - distributed_queue_name (str): This name will be used by the frontends and workers to exchange prompt requests and replies. Progress updates will be prefixed by the queue name, followed by a '.', then the user ID.
100 lines · 4.0 KiB · Python
import asyncio
|
|
import logging
|
|
import sys
|
|
from asyncio import AbstractEventLoop
|
|
from contextlib import AsyncExitStack
|
|
from dataclasses import asdict
|
|
from typing import Optional
|
|
|
|
from aio_pika import connect_robust
|
|
from aio_pika.patterns import JsonRPC
|
|
from aiormq import AMQPConnectionError
|
|
|
|
from .distributed_progress import DistributedExecutorToClientProgress
|
|
from .distributed_types import RpcRequest, RpcReply
|
|
from ..client.embedded_comfy_client import EmbeddedComfyClient
|
|
from ..component_model.queue_types import ExecutionStatus
|
|
|
|
|
|
class DistributedPromptWorker:
    """
    A work in progress distributed prompt worker.

    Connects to an AMQP broker, opens a channel with a JSON-RPC endpoint
    registered under ``queue_name``, and executes incoming prompt requests
    with an :class:`EmbeddedComfyClient`. Intended for use as an async
    context manager::

        async with DistributedPromptWorker(connection_uri=...) as worker:
            ...
    """

    def __init__(self, embedded_comfy_client: Optional[EmbeddedComfyClient] = None,
                 connection_uri: str = "amqp://localhost:5672/",
                 queue_name: str = "comfyui",
                 loop: Optional[AbstractEventLoop] = None):
        """
        :param embedded_comfy_client: client used to execute prompts; created
            lazily in :meth:`init` when not provided
        :param connection_uri: AMQP URL of the broker to connect to
        :param queue_name: RPC method name used to exchange requests and replies
        :param loop: event loop passed to the AMQP connection; defaults to the
            current event loop
        """
        # Connection-related members are populated in init(). Initialize them
        # all to None here so _close()/__aexit__ are safe even when init() was
        # never called or failed partway through (the original never assigned
        # self._connection in __init__, causing AttributeError on teardown).
        self._rpc = None
        self._channel = None
        self._connection = None
        self._exit_stack = AsyncExitStack()
        self._queue_name = queue_name
        self._connection_uri = connection_uri
        self._loop = loop or asyncio.get_event_loop()
        self._embedded_comfy_client = embedded_comfy_client

    async def _do_work_item(self, request: dict) -> dict:
        """
        RPC handler: execute a single prompt request.

        :param request: serialized :class:`RpcRequest` as a plain dict
        :return: serialized :class:`RpcReply` (success or error status)
        """
        await self.on_will_complete_work_item(request)
        try:
            request_obj = RpcRequest.from_dict(request)
        except Exception as e:
            # The request could not even be parsed. Salvage the prompt_id if
            # present so the caller can still correlate this error reply.
            request_dict_prompt_id_recovered = request["prompt_id"] \
                if request is not None and "prompt_id" in request else ""
            return asdict(RpcReply(request_dict_prompt_id_recovered, "", {},
                                   ExecutionStatus("error", False, [str(e)])))

        reply: RpcReply
        try:
            output_dict = await self._embedded_comfy_client.queue_prompt(request_obj.prompt,
                                                                         request_obj.prompt_id,
                                                                         client_id=request_obj.user_id)
            reply = RpcReply(request_obj.prompt_id, request_obj.user_token, output_dict,
                             ExecutionStatus("success", True, []))
        except Exception as e:
            # Execution failures are reported back to the caller, not raised,
            # so the worker keeps serving subsequent requests.
            reply = RpcReply(request_obj.prompt_id, request_obj.user_token, {},
                             ExecutionStatus("error", False, [str(e)]))

        await self.on_did_complete_work_item(request_obj, reply)
        return asdict(reply)

    async def init(self):
        """Connect to the broker and start serving RPC requests."""
        await self._exit_stack.__aenter__()
        try:
            self._connection = await connect_robust(self._connection_uri, loop=self._loop)
        except AMQPConnectionError:
            # Fixed: the original passed the exception as a stray positional
            # logging argument with no format placeholder, which itself raised
            # a string-formatting error. Use lazy %-args and logging.exception
            # to record the traceback, then re-raise with a bare `raise` to
            # preserve the original traceback.
            logging.exception("failed to connect to self._connection_uri=%s", self._connection_uri)
            raise
        self._channel = await self._connection.channel()
        self._rpc = await JsonRPC.create(channel=self._channel)
        self._rpc.host_exceptions = True

        if self._embedded_comfy_client is None:
            # Progress updates are forwarded over the same RPC connection.
            self._embedded_comfy_client = EmbeddedComfyClient(
                progress_handler=DistributedExecutorToClientProgress(self._rpc, self._queue_name, self._loop))
        if not self._embedded_comfy_client.is_running:
            # Let the exit stack own the client's lifetime.
            await self._exit_stack.enter_async_context(self._embedded_comfy_client)

        await self._rpc.register(self._queue_name, self._do_work_item)

    async def __aenter__(self) -> "DistributedPromptWorker":
        await self.init()
        return self

    async def _close(self):
        # Tear down in reverse order of creation; each member is guarded so
        # this is safe after a partially failed (or never attempted) init().
        if self._rpc is not None:
            await self._rpc.close()
        if self._channel is not None:
            await self._channel.close()
        if self._connection is not None:
            await self._connection.close()

    async def close(self):
        """Shut down the RPC endpoint, channel, connection and owned resources."""
        await self._close()
        await self._exit_stack.aclose()

    async def __aexit__(self, *args):
        await self._close()
        return await self._exit_stack.__aexit__(*args)

    async def on_did_complete_work_item(self, request: RpcRequest, reply: RpcReply):
        """Hook invoked after each work item completes. Subclasses may override."""
        pass

    async def on_will_complete_work_item(self, request: dict):
        """Hook invoked before a work item is parsed and executed. Subclasses may override."""
        pass
|