ComfyUI/comfy/distributed/distributed_prompt_worker.py
2025-05-07 14:53:39 -07:00

157 lines
6.3 KiB
Python

import asyncio
import logging
from asyncio import AbstractEventLoop
from contextlib import AsyncExitStack
from dataclasses import asdict
from typing import Optional
from aio_pika import connect_robust
from aio_pika.patterns import JsonRPC
from aiohttp import web
from aiormq import AMQPConnectionError
from .executors import ContextVarExecutor
from .distributed_progress import DistributedExecutorToClientProgress
from .distributed_types import RpcRequest, RpcReply
from .process_pool_executor import ProcessPoolExecutor
from ..client.embedded_comfy_client import Comfy
from ..cmd.main_pre import tracer
from ..component_model.queue_types import ExecutionStatus
class DistributedPromptWorker:
"""
A distributed prompt worker.
"""
def __init__(self, embedded_comfy_client: Optional[Comfy] = None,
connection_uri: str = "amqp://localhost:5672/",
queue_name: str = "comfyui",
health_check_port: int = 9090,
loop: Optional[AbstractEventLoop] = None,
executor: Optional[ContextVarExecutor | ProcessPoolExecutor] = None):
self._rpc = None
self._channel = None
self._exit_stack = AsyncExitStack()
self._queue_name = queue_name
self._connection_uri = connection_uri
self._loop = loop or asyncio.get_event_loop()
self._executor = executor
self._embedded_comfy_client = embedded_comfy_client
self._health_check_port = health_check_port
self._health_check_site: Optional[web.TCPSite] = None
async def _health_check(self, request):
if not hasattr(self, "_connection") or self._connection is None:
return web.Response(text="UNHEALTHY: RabbitMQ connection is not established", status=503)
is_healthy = await self._is_connection_healthy()
if is_healthy:
return web.Response(text="HEALTHY", status=200)
else:
return web.Response(text="UNHEALTHY: RabbitMQ connection is not healthy", status=503)
async def _start_health_check_server(self):
app = web.Application()
app.router.add_get('/health', self._health_check)
runner = web.AppRunner(app)
await runner.setup()
try:
site = web.TCPSite(runner, port=self._health_check_port)
await site.start()
self._health_check_site = site
logging.info(f"health check server started on port {self._health_check_port}")
except OSError as e:
if e.errno == 98:
logging.warning(f"port {self._health_check_port} is already in use, health check disabled but starting anyway")
else:
logging.error(f"failed to start health check server with error {str(e)}, starting anyway")
@tracer.start_as_current_span("Do Work Item")
async def _do_work_item(self, request: dict) -> dict:
await self.on_will_complete_work_item(request)
try:
request_obj = RpcRequest.from_dict(request)
except Exception as e:
request_dict_prompt_id_recovered = request["prompt_id"] \
if request is not None and "prompt_id" in request else ""
return asdict(RpcReply(request_dict_prompt_id_recovered, "", {},
ExecutionStatus("error", False, [str(e)])))
reply: RpcReply
try:
output_dict = await self._embedded_comfy_client.queue_prompt(request_obj.prompt,
request_obj.prompt_id,
client_id=request_obj.user_id)
reply = RpcReply(request_obj.prompt_id, request_obj.user_token, output_dict,
ExecutionStatus("success", True, []))
except Exception as e:
reply = RpcReply(request_obj.prompt_id, request_obj.user_token, {},
ExecutionStatus("error", False, [str(e)]))
await self.on_did_complete_work_item(request_obj, reply)
return asdict(reply)
async def _is_connection_healthy(self):
if self._connection is None:
return False
return (
not self._connection.is_closed
and self._connection.connected.is_set()
and await self._check_connection_ready()
)
async def _check_connection_ready(self):
try:
await asyncio.wait_for(self._connection.ready(), timeout=1.0)
return True
except asyncio.TimeoutError:
return False
@tracer.start_as_current_span("Initialize Prompt Worker")
async def init(self):
await self._exit_stack.__aenter__()
await self._start_health_check_server()
try:
self._connection = await connect_robust(self._connection_uri, loop=self._loop)
except AMQPConnectionError as connection_error:
logging.error(f"failed to connect to self._connection_uri={self._connection_uri}", connection_error)
raise connection_error
self._channel = await self._connection.channel()
await self._channel.set_qos(prefetch_count=1)
self._rpc = await JsonRPC.create(channel=self._channel, auto_delete=True, durable=False)
if self._embedded_comfy_client is None:
self._embedded_comfy_client = Comfy(progress_handler=DistributedExecutorToClientProgress(self._rpc, self._queue_name, self._loop), executor=self._executor)
if not self._embedded_comfy_client.is_running:
await self._exit_stack.enter_async_context(self._embedded_comfy_client)
await self._rpc.register(self._queue_name, self._do_work_item)
async def __aenter__(self) -> "DistributedPromptWorker":
await self.init()
return self
async def _close(self):
await self._rpc.close()
await self._channel.close()
await self._connection.close()
if self._health_check_site:
await self._health_check_site.stop()
async def close(self):
await self._close()
await self._exit_stack.aclose()
async def __aexit__(self, *args):
await self._close()
return await self._exit_stack.__aexit__(*args)
async def on_did_complete_work_item(self, request: RpcRequest, reply: RpcReply):
pass
async def on_will_complete_work_item(self, request: dict):
pass