Distributed queueing with amqp-compatible servers like RabbitMQ.

- Binary previews are not yet supported
 - Use `--distributed-queue-connection-uri=amqp://guest:guest@rabbitmqserver/`
 - Roles supported: frontend, worker or both (see `--help`)
 - Run `comfy-worker` for a lightweight worker you can wrap your head
   around
 - Workers and frontends must have the same directory structure (set
   with `--cwd`) and supported nodes. Frontends must still have access
   to inputs and outputs.
 - Configuration notes:

   distributed_queue_connection_uri (Optional[str]): Servers and clients will connect to this AMQP URL to form a distributed queue and exchange prompt execution requests and progress updates.
   distributed_queue_roles (List[str]): Specifies one or more roles for the distributed queue. Acceptable values are "worker" or "frontend", or both by writing the flag twice with each role. Frontends will start the web UI and connect to the provided AMQP URL to submit prompts; workers will pull requests off the AMQP URL.
   distributed_queue_name (str): This name will be used by the frontends and workers to exchange prompt requests and replies. Progress updates will be prefixed by the queue name, followed by a '.', then the user ID.
This commit is contained in:
doctorpangloss 2024-02-08 20:24:27 -08:00
parent 0673262940
commit 80f8c40248
9 changed files with 161 additions and 30 deletions

0
comfy/auth/__init__.py Normal file
View File

14
comfy/auth/permissions.py Normal file
View File

@ -0,0 +1,14 @@
from typing import TypedDict
import jwt
class ComfyJwt(TypedDict, total=False):
sub: str
def jwt_decode(user_token: str) -> ComfyJwt:
# todo: set up a way for users to override this behavior easily
return ComfyJwt(**jwt.decode(user_token, algorithms=['HS256', "none"],
# todo: this should be configurable
options={"verify_signature": False, 'verify_aud': False, 'verify_iss': False}))

View File

@ -13,6 +13,7 @@ from ..component_model.make_mutable import make_mutable
from ..component_model.executor_types import ExecutorToClientProgress
from ..distributed.server_stub import ServerStub
_server_stub_instance = ServerStub()
class EmbeddedComfyClient:
"""
@ -117,14 +118,21 @@ class EmbeddedComfyClient:
client_id = client_id or self._progress_handler.client_id or None
def execute_prompt() -> dict:
from ..cmd.execution import validate_prompt
from ..cmd.execution import PromptExecutor, validate_prompt
prompt_mut = make_mutable(prompt)
validation_tuple = validate_prompt(prompt_mut)
self._prompt_executor.execute(prompt_mut, prompt_id, {"client_id": client_id},
prompt_executor: PromptExecutor = self._prompt_executor
if client_id is None:
prompt_executor.server = _server_stub_instance
else:
prompt_executor.server = self._progress_handler
prompt_executor.execute(prompt_mut, prompt_id, {"client_id": client_id},
execute_outputs=validation_tuple[2])
if self._prompt_executor.success:
return self._prompt_executor.outputs_ui
if prompt_executor.success:
return prompt_executor.outputs_ui
else:
raise RuntimeError("\n".join(event for (event, data) in self._prompt_executor.status_messages))

View File

@ -227,7 +227,7 @@ async def main():
if args.distributed_queue_connection_uri is not None:
distributed = True
q = DistributedPromptQueue(
caller_server=server if "worker" in args.distributed_queue_roles else None,
caller_server=server if "frontend" in args.distributed_queue_roles else None,
connection_uri=args.distributed_queue_connection_uri,
is_caller="frontend" in args.distributed_queue_roles,
is_callee="worker" in args.distributed_queue_roles,

View File

@ -1,7 +1,7 @@
from __future__ import annotations # for Python 3.7-3.9
from typing_extensions import NotRequired, TypedDict
from typing import Optional, Literal, Protocol
from typing import Optional, Literal, Protocol, TypeAlias, Union
from comfy.component_model.queue_types import BinaryEventTypes
@ -32,6 +32,13 @@ class ProgressMessage(TypedDict):
node: Optional[str]
ExecutedMessage: TypeAlias = ExecutingMessage
SendSyncEvent: TypeAlias = Union[Literal["status", "executing", "progress", "executed"], BinaryEventTypes, None]
SendSyncData: TypeAlias = Union[StatusMessage, ExecutingMessage, ProgressMessage, bytes, bytearray, None]
class ExecutorToClientProgress(Protocol):
"""
Specifies the interface for the dependencies a prompt executor needs from a server.
@ -47,8 +54,9 @@ class ExecutorToClientProgress(Protocol):
last_prompt_id: Optional[str]
def send_sync(self,
event: Literal["status", "executing", "progress"] | BinaryEventTypes | str | None,
data: StatusMessage | ExecutingMessage | ProgressMessage | bytes | bytearray | None, sid: str | None = None):
event: SendSyncEvent,
data: SendSyncData,
sid: Optional[str] = None):
"""
Sends feedback to the client with the specified ID about a specific node

View File

@ -0,0 +1,84 @@
from __future__ import annotations
import asyncio
from asyncio import AbstractEventLoop
from functools import partial
from typing import Optional, Dict, Any
from aio_pika.patterns import RPC
from ..component_model.executor_types import SendSyncEvent, SendSyncData, ExecutorToClientProgress
from ..component_model.queue_types import BinaryEventTypes
async def _progress(event: SendSyncEvent, data: SendSyncData, user_id: Optional[str] = None,
caller_server: Optional[ExecutorToClientProgress] = None) -> None:
assert caller_server is not None
assert user_id is not None
caller_server.send_sync(event, data, sid=user_id)
def _get_name(queue_name: str, user_id: str) -> str:
return f"{queue_name}.{user_id}.progress"
class DistributedExecutorToClientProgress(ExecutorToClientProgress):
def __init__(self, rpc: RPC, queue_name: str, loop: AbstractEventLoop):
self._rpc = rpc
self._queue_name = queue_name
self._loop = loop
self.client_id = None
self.node_id = None
self.last_node_id = None
async def send(self, event: SendSyncEvent, data: SendSyncData, user_id: Optional[str]) -> None:
# for now, do not send binary data this way, since it cannot be json serialized / it's impractical
if event == BinaryEventTypes.PREVIEW_IMAGE or event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE:
return
if isinstance(data, bytes) or isinstance(data, bytearray):
return
if user_id is None:
# todo: user_id should never be none here
return
await self._rpc.call(_get_name(self._queue_name, user_id), {"event": event, "data": data})
def send_sync(self,
event: SendSyncEvent,
data: SendSyncData,
sid: Optional[str] = None):
asyncio.run_coroutine_threadsafe(self.send(event, data, sid), self._loop)
def queue_updated(self):
# todo: this should gather the global queue data
pass
class ProgressHandlers:
def __init__(self, rpc: RPC, caller_server: Optional[ExecutorToClientProgress], queue_name: str):
self._rpc = rpc
self._caller_server = caller_server
self._progress_handlers: Dict[str, Any] = {}
self._queue_name = queue_name
async def register_progress(self, user_id: str):
if user_id in self._progress_handlers:
return
handler = partial(_progress, user_id=user_id, caller_server=self._caller_server)
self._progress_handlers[user_id] = handler
await self._rpc.register(_get_name(self._queue_name, user_id), handler)
async def unregister_progress(self, user_id: str):
if user_id not in self._progress_handlers:
return
handler = self._progress_handlers.pop(user_id)
await self._rpc.unregister(handler)
async def unregister_all(self):
await asyncio.gather(*[self._rpc.unregister(handler) for handler in self._progress_handlers.values()])
self._progress_handlers.clear()

View File

@ -1,12 +1,10 @@
from __future__ import annotations
import asyncio
import threading
import time
import uuid
from asyncio import AbstractEventLoop, Queue, QueueEmpty
from dataclasses import asdict
from functools import partial
from time import sleep
from typing import Optional, Dict, List, Mapping, Tuple, Callable
@ -15,10 +13,12 @@ from aio_pika import connect_robust
from aio_pika.abc import AbstractConnection, AbstractChannel
from aio_pika.patterns import JsonRPC
from .distributed_progress import ProgressHandlers
from .distributed_types import RpcRequest, RpcReply
from .server_stub import ServerStub
from ..auth.permissions import jwt_decode
from ..component_model.abstract_prompt_queue import AbstractPromptQueue
from ..component_model.executor_types import ExecutorToClientProgress
from ..component_model.executor_types import ExecutorToClientProgress, SendSyncEvent, SendSyncData
from ..component_model.queue_types import Flags, HistoryEntry, QueueTuple, QueueItem, ExecutionStatus, TaskInvocation
from .history import History
from ..cmd.server import PromptServer
@ -26,7 +26,7 @@ from ..cmd.server import PromptServer
class DistributedPromptQueue(AbstractPromptQueue):
"""
A distributed prompt queue for
A distributed prompt queue for the ComfyUI web client and single-threaded worker.
"""
def size(self) -> int:
@ -36,8 +36,13 @@ class DistributedPromptQueue(AbstractPromptQueue):
"""
return len(self._caller_local_in_progress)
async def progress(self, event: SendSyncEvent, data: SendSyncData, sid: Optional[str]) -> None:
self._caller_server.send_sync(event, data, sid=sid)
async def put_async(self, queue_item: QueueItem):
assert self._is_caller
assert self._rpc is not None
if self._closing:
return
self._caller_local_in_progress[queue_item.prompt_id] = queue_item
@ -46,23 +51,26 @@ class DistributedPromptQueue(AbstractPromptQueue):
try:
if "token" in queue_item.extra_data:
user_token = queue_item.extra_data["token"]
user_id = jwt_decode(user_token)["sub"]
else:
if "client_id" in queue_item.extra_data:
client_id = queue_item.extra_data["client_id"]
user_id = queue_item.extra_data["client_id"]
elif self._caller_server.client_id is not None:
client_id = self._caller_server.client_id
user_id = self._caller_server.client_id
else:
client_id = str(uuid.uuid4())
user_id = str(uuid.uuid4())
# todo: should we really do this?
self._caller_server.client_id = client_id
self._caller_server.client_id = user_id
# create a stub token
user_token = jwt.encode({"sub": client_id}, key="", algorithm="none")
user_token = jwt.encode({"sub": user_id}, key="", algorithm="none")
# register callbacks for progress
assert self._caller_progress_handlers is not None
await self._caller_progress_handlers.register_progress(user_id)
request = RpcRequest(prompt_id=queue_item.prompt_id, user_token=user_token, prompt=queue_item.prompt)
assert self._rpc is not None
res: TaskInvocation = RpcReply(
**(await self._rpc.call(self._queue_name, {"request": asdict(request)}))).as_task_invocation()
self._caller_history.put(queue_item, res.outputs, res.status)
if self._caller_server is not None:
self._caller_server.queue_updated()
@ -241,6 +249,7 @@ class DistributedPromptQueue(AbstractPromptQueue):
# as rpc caller
self._caller_server = caller_server or ServerStub()
self._caller_progress_handlers: Optional[ProgressHandlers] = None
self._caller_local_in_progress: dict[str | int, QueueItem] = {}
self._caller_history: History = History()
@ -263,6 +272,8 @@ class DistributedPromptQueue(AbstractPromptQueue):
self._channel = await self._connection.channel()
self._rpc = await JsonRPC.create(channel=self._channel)
self._rpc.host_exceptions = True
if self._is_caller:
self._caller_progress_handlers = ProgressHandlers(self._rpc, self._caller_server, self._queue_name)
# this makes the queue available to complete work items
if self._is_callee:
await self._rpc.register(self._queue_name, self._callee_do_work_item)
@ -273,6 +284,9 @@ class DistributedPromptQueue(AbstractPromptQueue):
return
self._closing = True
if self._is_caller:
await self._caller_progress_handlers.unregister_all()
await self._rpc.close()
await self._channel.close()
await self._connection.close()

View File

@ -10,6 +10,7 @@ from aio_pika import connect_robust
from aio_pika.patterns import JsonRPC
from aiormq import AMQPConnectionError
from .distributed_progress import DistributedExecutorToClientProgress
from .distributed_types import RpcRequest, RpcReply
from ..client.embedded_comfy_client import EmbeddedComfyClient
from ..component_model.queue_types import ExecutionStatus
@ -24,11 +25,13 @@ class DistributedPromptWorker:
connection_uri: str = "amqp://localhost:5672/",
queue_name: str = "comfyui",
loop: Optional[AbstractEventLoop] = None):
self._rpc = None
self._channel = None
self._exit_stack = AsyncExitStack()
self._queue_name = queue_name
self._connection_uri = connection_uri
self._loop = loop or asyncio.get_event_loop()
self._embedded_comfy_client = embedded_comfy_client or EmbeddedComfyClient()
self._embedded_comfy_client = embedded_comfy_client
async def _do_work_item(self, request: dict) -> dict:
await self.on_will_complete_work_item(request)
@ -55,9 +58,6 @@ class DistributedPromptWorker:
async def init(self):
await self._exit_stack.__aenter__()
if not self._embedded_comfy_client.is_running:
await self._exit_stack.enter_async_context(self._embedded_comfy_client)
try:
self._connection = await connect_robust(self._connection_uri, loop=self._loop)
except AMQPConnectionError as connection_error:
@ -67,6 +67,12 @@ class DistributedPromptWorker:
self._rpc = await JsonRPC.create(channel=self._channel)
self._rpc.host_exceptions = True
if self._embedded_comfy_client is None:
self._embedded_comfy_client = EmbeddedComfyClient(
progress_handler=DistributedExecutorToClientProgress(self._rpc, self._queue_name, self._loop))
if not self._embedded_comfy_client.is_running:
await self._exit_stack.enter_async_context(self._embedded_comfy_client)
await self._rpc.register(self._queue_name, self._do_work_item)
async def __aenter__(self) -> "DistributedPromptWorker":

View File

@ -3,9 +3,8 @@ from __future__ import annotations
from dataclasses import dataclass
from typing import Tuple, Literal, List
import jwt
from ..api.components.schema.prompt import PromptDict, Prompt
from ..auth.permissions import ComfyJwt, jwt_decode
from ..component_model.queue_types import NamedQueueTuple, TaskInvocation, ExecutionStatus
@ -19,10 +18,8 @@ class DistributedBase:
return self.decoded_token["sub"]
@property
def decoded_token(self) -> dict:
return jwt.decode(self.user_token, algorithms=['HS256', "none"],
# todo: this should be configurable
options={"verify_signature": False, 'verify_aud': False, 'verify_iss': False})
def decoded_token(self) -> ComfyJwt:
return jwt_decode(self.user_token)
@dataclass