ComfyUI/comfy/distributed/distributed_prompt_worker.py
2024-02-08 19:08:42 -08:00

94 lines
3.7 KiB
Python

import asyncio
import logging
import sys
from asyncio import AbstractEventLoop
from contextlib import AsyncExitStack
from dataclasses import asdict
from typing import Optional
from aio_pika import connect_robust
from aio_pika.patterns import JsonRPC
from aiormq import AMQPConnectionError
from .distributed_types import RpcRequest, RpcReply
from ..client.embedded_comfy_client import EmbeddedComfyClient
from ..component_model.queue_types import ExecutionStatus
class DistributedPromptWorker:
    """
    A work in progress distributed prompt worker.

    Registers a JSON-RPC handler named ``queue_name`` on an AMQP broker and
    executes each received prompt on an :class:`EmbeddedComfyClient`,
    returning the resulting :class:`RpcReply` as a plain dict.

    Use as an async context manager (``async with DistributedPromptWorker():``)
    or call :meth:`init` and :meth:`close` explicitly.
    """

    def __init__(self, embedded_comfy_client: Optional[EmbeddedComfyClient] = None,
                 connection_uri: str = "amqp://localhost:5672/",
                 queue_name: str = "comfyui",
                 loop: Optional[AbstractEventLoop] = None):
        """
        :param embedded_comfy_client: client that executes prompts; a new
            ``EmbeddedComfyClient()`` is created when omitted.
        :param connection_uri: AMQP URI handed to ``connect_robust``.
        :param queue_name: name under which :meth:`_do_work_item` is registered
            with the JSON-RPC pattern.
        :param loop: event loop handed to ``connect_robust``.
        """
        self._exit_stack = AsyncExitStack()
        self._queue_name = queue_name
        self._connection_uri = connection_uri
        # NOTE(review): asyncio.get_event_loop() without a current loop is
        # deprecated on recent Python versions; consider passing `loop` or
        # constructing this object from inside a coroutine.
        self._loop = loop or asyncio.get_event_loop()
        self._embedded_comfy_client = embedded_comfy_client or EmbeddedComfyClient()
        # Broker resources opened by init(); None until opened so that
        # _close() can tell which ones actually need closing.
        self._connection = None
        self._channel = None
        self._rpc = None

    async def _do_work_item(self, request: dict) -> dict:
        """
        Execute one RPC work item and return the reply as a plain dict.

        Failures are reported in the reply rather than raised: if the payload
        cannot be parsed by ``RpcRequest.from_dict``, or prompt execution
        raises, the reply carries ``ExecutionStatus("error", False, [str(e)])``.

        :param request: decoded JSON-RPC payload for ``RpcRequest.from_dict``.
        :return: ``asdict`` of the resulting :class:`RpcReply`.
        """
        await self.on_will_complete_work_item(request)
        try:
            request_obj = RpcRequest.from_dict(request)
        except Exception as e:
            # Best effort: echo back the prompt id so the caller can correlate
            # the failure. Guard the type so this handler itself cannot raise
            # on a malformed (non-dict) payload.
            recovered_prompt_id = request.get("prompt_id", "") if isinstance(request, dict) else ""
            return asdict(RpcReply(recovered_prompt_id, "", {},
                                   ExecutionStatus("error", False, [str(e)])))

        reply: RpcReply
        try:
            output_dict = await self._embedded_comfy_client.queue_prompt(request_obj.prompt,
                                                                         request_obj.prompt_id,
                                                                         client_id=request_obj.user_id)
            reply = RpcReply(request_obj.prompt_id, request_obj.user_token, output_dict,
                             ExecutionStatus("success", True, []))
        except Exception as e:
            reply = RpcReply(request_obj.prompt_id, request_obj.user_token, {},
                             ExecutionStatus("error", False, [str(e)]))
        await self.on_did_complete_work_item(request_obj, reply)
        return asdict(reply)

    async def init(self):
        """
        Start the embedded client if it is not running, connect to the broker,
        and register :meth:`_do_work_item` under ``queue_name``.

        If any step fails, everything opened so far is released before the
        exception propagates. ``__aexit__`` is not invoked when ``__aenter__``
        raises, so nobody else would clean up.

        :raises AMQPConnectionError: when the broker connection cannot be made.
        """
        await self._exit_stack.__aenter__()
        try:
            if not self._embedded_comfy_client.is_running:
                await self._exit_stack.enter_async_context(self._embedded_comfy_client)
            try:
                self._connection = await connect_robust(self._connection_uri, loop=self._loop)
            except AMQPConnectionError:
                # Lazy %-args for the message; exc_info attaches the traceback.
                logging.error("failed to connect to self._connection_uri=%s", self._connection_uri,
                              exc_info=True)
                raise
            self._channel = await self._connection.channel()
            self._rpc = await JsonRPC.create(channel=self._channel)
            self._rpc.host_exceptions = True
            await self._rpc.register(self._queue_name, self._do_work_item)
        except BaseException:
            try:
                await self.close()
            except Exception:
                # Keep the original failure as the one the caller sees.
                logging.warning("error while cleaning up after a failed init", exc_info=True)
            raise

    async def __aenter__(self) -> "DistributedPromptWorker":
        await self.init()
        return self

    async def _close(self):
        """
        Close the RPC pattern, channel and connection opened by :meth:`init`.

        Resources that were never opened are skipped, and every opened
        resource is closed even if closing another one raises (the error is
        re-raised afterwards). Calling this again after it ran is a no-op.
        """
        rpc, channel, connection = self._rpc, self._channel, self._connection
        self._rpc = self._channel = self._connection = None
        async with AsyncExitStack() as stack:
            # Callbacks run in LIFO order: rpc, then channel, then connection.
            for resource in (connection, channel, rpc):
                if resource is not None:
                    stack.push_async_callback(resource.close)

    async def close(self):
        """Release broker resources, then exit the embedded client context."""
        try:
            await self._close()
        finally:
            await self._exit_stack.aclose()

    async def __aexit__(self, *args):
        try:
            await self._close()
        except BaseException:
            # Still release the embedded client before propagating.
            await self._exit_stack.aclose()
            raise
        return await self._exit_stack.__aexit__(*args)

    async def on_did_complete_work_item(self, request: RpcRequest, reply: RpcReply):
        """
        Hook awaited after a parsed request has been executed, whether it
        succeeded or failed. Not called when the payload fails to parse.
        No-op by default.
        """
        pass

    async def on_will_complete_work_item(self, request: dict):
        """
        Hook awaited with the raw payload before it is parsed or executed.
        No-op by default.
        """
        pass