From 80f8c4024888b2ca5dafbe9d3e34d86a4a7b9d1f Mon Sep 17 00:00:00 2001
From: doctorpangloss <@hiddenswitch.com>
Date: Thu, 8 Feb 2024 20:24:27 -0800
Subject: [PATCH] Distributed queueing with AMQP-compatible servers like
 RabbitMQ.

- Binary previews are not yet supported.
- Use `--distributed-queue-connection-uri=amqp://guest:guest@rabbitmqserver/`.
- Supported roles: frontend, worker, or both (see `--help`).
- Run `comfy-worker` to start a lightweight, easy-to-understand worker.
- Workers and frontends must have the same directory structure (set with
  `--cwd`) and support the same nodes. Frontends must still have access to
  inputs and outputs.
- Configuration notes:

  distributed_queue_connection_uri (Optional[str]): Servers and clients
  connect to this AMQP URL to form a distributed queue and to exchange
  prompt execution requests and progress updates.

  distributed_queue_roles (List[str]): Specifies one or more roles for the
  distributed queue. Acceptable values are "worker" and "frontend"; pass the
  flag twice, once with each value, to enable both roles. Frontends start
  the web UI and connect to the AMQP URL to submit prompts; workers pull
  requests off the queue at that URL.

  distributed_queue_name (str): Frontends and workers use this queue name to
  exchange prompt requests and replies. Progress updates are prefixed with
  the queue name, followed by a '.', then the user ID.
---
 comfy/auth/__init__.py                        |  0
 comfy/auth/permissions.py                     | 14 ++++
 comfy/client/embedded_comfy_client.py         | 18 ++--
 comfy/cmd/main.py                             |  2 +-
 comfy/component_model/executor_types.py       | 14 +++-
 comfy/distributed/distributed_progress.py     | 84 +++++++++++++++++++
 comfy/distributed/distributed_prompt_queue.py | 36 +++++---
 .../distributed/distributed_prompt_worker.py  | 14 +++-
 comfy/distributed/distributed_types.py        |  9 +-
 9 files changed, 161 insertions(+), 30 deletions(-)
 create mode 100644 comfy/auth/__init__.py
 create mode 100644 comfy/auth/permissions.py
 create mode 100644 comfy/distributed/distributed_progress.py

diff --git a/comfy/auth/__init__.py b/comfy/auth/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/comfy/auth/permissions.py b/comfy/auth/permissions.py
new file mode 100644
index 000000000..d96533677
--- /dev/null
+++ b/comfy/auth/permissions.py
@@ -0,0 +1,14 @@
+from typing import TypedDict
+
+import jwt
+
+
+class ComfyJwt(TypedDict, total=False):
+    sub: str
+
+
+def jwt_decode(user_token: str) -> ComfyJwt:
+    # todo: set up a way for users to override this behavior easily
+    return ComfyJwt(**jwt.decode(user_token, algorithms=['HS256', "none"],
+                                 # todo: this should be configurable
+                                 options={"verify_signature": False, 'verify_aud': False, 'verify_iss': False}))
diff --git a/comfy/client/embedded_comfy_client.py b/comfy/client/embedded_comfy_client.py
index a5bdf6b51..e4271a98a 100644
--- a/comfy/client/embedded_comfy_client.py
+++ b/comfy/client/embedded_comfy_client.py
@@ -13,6 +13,7 @@ from ..component_model.make_mutable import make_mutable
 from ..component_model.executor_types import ExecutorToClientProgress
 from ..distributed.server_stub import ServerStub
 
+_server_stub_instance = ServerStub()
 
 class EmbeddedComfyClient:
     """
@@ -117,14 +118,21 @@ class EmbeddedComfyClient:
         client_id = client_id or self._progress_handler.client_id or None
 
         def execute_prompt() -> dict:
-            from ..cmd.execution import validate_prompt
+            from ..cmd.execution import PromptExecutor, validate_prompt
             prompt_mut = make_mutable(prompt)
             validation_tuple = validate_prompt(prompt_mut)
 
-            self._prompt_executor.execute(prompt_mut, prompt_id, {"client_id": client_id},
-                                          execute_outputs=validation_tuple[2])
-            if self._prompt_executor.success:
-                return self._prompt_executor.outputs_ui
+            prompt_executor: PromptExecutor = self._prompt_executor
+
+            if client_id is None:
+                prompt_executor.server = _server_stub_instance
+            else:
+                prompt_executor.server = self._progress_handler
+
+            prompt_executor.execute(prompt_mut, prompt_id, {"client_id": client_id},
+                                    execute_outputs=validation_tuple[2])
+            if prompt_executor.success:
+                return prompt_executor.outputs_ui
             else:
                 raise RuntimeError("\n".join(event for (event, data) in self._prompt_executor.status_messages))
 
diff --git a/comfy/cmd/main.py b/comfy/cmd/main.py
index cba09c61c..4d1e82e8f 100644
--- a/comfy/cmd/main.py
+++ b/comfy/cmd/main.py
@@ -227,7 +227,7 @@ async def main():
     if args.distributed_queue_connection_uri is not None:
         distributed = True
         q = DistributedPromptQueue(
-            caller_server=server if "worker" in args.distributed_queue_roles else None,
+            caller_server=server if "frontend" in args.distributed_queue_roles else None,
             connection_uri=args.distributed_queue_connection_uri,
             is_caller="frontend" in args.distributed_queue_roles,
             is_callee="worker" in args.distributed_queue_roles,
diff --git a/comfy/component_model/executor_types.py b/comfy/component_model/executor_types.py
index c1de2d9df..151d085ca 100644
--- a/comfy/component_model/executor_types.py
+++ b/comfy/component_model/executor_types.py
@@ -1,7 +1,7 @@
 from __future__ import annotations  # for Python 3.7-3.9
 
 from typing_extensions import NotRequired, TypedDict
-from typing import Optional, Literal, Protocol
+from typing import Optional, Literal, Protocol, TypeAlias, Union
 
 from comfy.component_model.queue_types import BinaryEventTypes
 
@@ -32,6 +32,13 @@ class ProgressMessage(TypedDict):
     node: Optional[str]
 
 
+ExecutedMessage: TypeAlias = ExecutingMessage
+
+SendSyncEvent: TypeAlias = Union[Literal["status", "executing", "progress", "executed"], BinaryEventTypes, None]
+
+SendSyncData: TypeAlias = Union[StatusMessage, ExecutingMessage, ProgressMessage, bytes, bytearray, None]
+
+
 class ExecutorToClientProgress(Protocol):
     """
     Specifies the interface for the dependencies a prompt executor needs from a server.
@@ -47,8 +54,9 @@ class ExecutorToClientProgress(Protocol):
     last_prompt_id: Optional[str]
 
     def send_sync(self,
-                  event: Literal["status", "executing", "progress"] | BinaryEventTypes | str | None,
-                  data: StatusMessage | ExecutingMessage | ProgressMessage | bytes | bytearray | None, sid: str | None = None):
+                  event: SendSyncEvent,
+                  data: SendSyncData,
+                  sid: Optional[str] = None):
         """
         Sends feedback to the client with the specified ID about a specific node
 
diff --git a/comfy/distributed/distributed_progress.py b/comfy/distributed/distributed_progress.py
new file mode 100644
index 000000000..048b15c18
--- /dev/null
+++ b/comfy/distributed/distributed_progress.py
@@ -0,0 +1,84 @@
+from __future__ import annotations
+
+import asyncio
+from asyncio import AbstractEventLoop
+from functools import partial
+
+from typing import Optional, Dict, Any
+
+from aio_pika.patterns import RPC
+
+from ..component_model.executor_types import SendSyncEvent, SendSyncData, ExecutorToClientProgress
+from ..component_model.queue_types import BinaryEventTypes
+
+
+async def _progress(event: SendSyncEvent, data: SendSyncData, user_id: Optional[str] = None,
+                    caller_server: Optional[ExecutorToClientProgress] = None) -> None:
+    assert caller_server is not None
+    assert user_id is not None
+    caller_server.send_sync(event, data, sid=user_id)
+
+
+def _get_name(queue_name: str, user_id: str) -> str:
+    return f"{queue_name}.{user_id}.progress"
+
+
+class DistributedExecutorToClientProgress(ExecutorToClientProgress):
+    def __init__(self, rpc: RPC, queue_name: str, loop: AbstractEventLoop):
+        self._rpc = rpc
+        self._queue_name = queue_name
+        self._loop = loop
+
+        self.client_id = None
+        self.node_id = None
+        self.last_node_id = None
+
+    async def send(self, event: SendSyncEvent, data: SendSyncData, user_id: Optional[str]) -> None:
+        # for now, do not send binary data this way, since it cannot be json serialized / it's impractical
+        if event == BinaryEventTypes.PREVIEW_IMAGE or event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE:
+            return
+
+        if isinstance(data, bytes) or isinstance(data, bytearray):
+            return
+
+        if user_id is None:
+            # todo: user_id should never be none here
+            return
+
+        await self._rpc.call(_get_name(self._queue_name, user_id), {"event": event, "data": data})
+
+    def send_sync(self,
+                  event: SendSyncEvent,
+                  data: SendSyncData,
+                  sid: Optional[str] = None):
+        asyncio.run_coroutine_threadsafe(self.send(event, data, sid), self._loop)
+
+    def queue_updated(self):
+        # todo: this should gather the global queue data
+        pass
+
+
+class ProgressHandlers:
+    def __init__(self, rpc: RPC, caller_server: Optional[ExecutorToClientProgress], queue_name: str):
+        self._rpc = rpc
+        self._caller_server = caller_server
+        self._progress_handlers: Dict[str, Any] = {}
+        self._queue_name = queue_name
+
+    async def register_progress(self, user_id: str):
+        if user_id in self._progress_handlers:
+            return
+
+        handler = partial(_progress, user_id=user_id, caller_server=self._caller_server)
+        self._progress_handlers[user_id] = handler
+        await self._rpc.register(_get_name(self._queue_name, user_id), handler)
+
+    async def unregister_progress(self, user_id: str):
+        if user_id not in self._progress_handlers:
+            return
+        handler = self._progress_handlers.pop(user_id)
+        await self._rpc.unregister(handler)
+
+    async def unregister_all(self):
+        await asyncio.gather(*[self._rpc.unregister(handler) for handler in self._progress_handlers.values()])
+        self._progress_handlers.clear()
diff --git a/comfy/distributed/distributed_prompt_queue.py b/comfy/distributed/distributed_prompt_queue.py
index 1ae74492d..8d3c810bd 100644
--- a/comfy/distributed/distributed_prompt_queue.py
+++ b/comfy/distributed/distributed_prompt_queue.py
@@ -1,12 +1,10 @@
 from __future__ import annotations
 
 import asyncio
-import threading
 import time
 import uuid
 from asyncio import AbstractEventLoop, Queue, QueueEmpty
 from dataclasses import asdict
-from functools import partial
 from time import sleep
 from typing import Optional, Dict, List, Mapping, Tuple, Callable
 
@@ -15,10 +13,12 @@ from aio_pika import connect_robust
 from aio_pika.abc import AbstractConnection, AbstractChannel
 from aio_pika.patterns import JsonRPC
 
+from .distributed_progress import ProgressHandlers
 from .distributed_types import RpcRequest, RpcReply
 from .server_stub import ServerStub
+from ..auth.permissions import jwt_decode
 from ..component_model.abstract_prompt_queue import AbstractPromptQueue
-from ..component_model.executor_types import ExecutorToClientProgress
+from ..component_model.executor_types import ExecutorToClientProgress, SendSyncEvent, SendSyncData
 from ..component_model.queue_types import Flags, HistoryEntry, QueueTuple, QueueItem, ExecutionStatus, TaskInvocation
 from .history import History
 from ..cmd.server import PromptServer
@@ -26,7 +26,7 @@ from ..cmd.server import PromptServer
 
 class DistributedPromptQueue(AbstractPromptQueue):
     """
-    A distributed prompt queue for
+    A distributed prompt queue for the ComfyUI web client and single-threaded worker.
     """
 
     def size(self) -> int:
@@ -36,8 +36,13 @@ class DistributedPromptQueue(AbstractPromptQueue):
         """
         return len(self._caller_local_in_progress)
 
+    async def progress(self, event: SendSyncEvent, data: SendSyncData, sid: Optional[str]) -> None:
+        self._caller_server.send_sync(event, data, sid=sid)
+
     async def put_async(self, queue_item: QueueItem):
         assert self._is_caller
+        assert self._rpc is not None
+
         if self._closing:
             return
         self._caller_local_in_progress[queue_item.prompt_id] = queue_item
@@ -46,23 +51,26 @@ class DistributedPromptQueue(AbstractPromptQueue):
         try:
             if "token" in queue_item.extra_data:
                 user_token = queue_item.extra_data["token"]
+                user_id = jwt_decode(user_token)["sub"]
             else:
                 if "client_id" in queue_item.extra_data:
-                    client_id = queue_item.extra_data["client_id"]
+                    user_id = queue_item.extra_data["client_id"]
                 elif self._caller_server.client_id is not None:
-                    client_id = self._caller_server.client_id
+                    user_id = self._caller_server.client_id
                 else:
-                    client_id = str(uuid.uuid4())
+                    user_id = str(uuid.uuid4())
                     # todo: should we really do this?
-                    self._caller_server.client_id = client_id
+                    self._caller_server.client_id = user_id
                 # create a stub token
-                user_token = jwt.encode({"sub": client_id}, key="", algorithm="none")
+                user_token = jwt.encode({"sub": user_id}, key="", algorithm="none")
+
+            # register callbacks for progress
+            assert self._caller_progress_handlers is not None
+            await self._caller_progress_handlers.register_progress(user_id)
 
             request = RpcRequest(prompt_id=queue_item.prompt_id, user_token=user_token, prompt=queue_item.prompt)
-            assert self._rpc is not None
             res: TaskInvocation = RpcReply(
                 **(await self._rpc.call(self._queue_name, {"request": asdict(request)}))).as_task_invocation()
-
             self._caller_history.put(queue_item, res.outputs, res.status)
             if self._caller_server is not None:
                 self._caller_server.queue_updated()
@@ -241,6 +249,7 @@ class DistributedPromptQueue(AbstractPromptQueue):
 
         # as rpc caller
         self._caller_server = caller_server or ServerStub()
+        self._caller_progress_handlers: Optional[ProgressHandlers] = None
         self._caller_local_in_progress: dict[str | int, QueueItem] = {}
         self._caller_history: History = History()
 
@@ -263,6 +272,8 @@ class DistributedPromptQueue(AbstractPromptQueue):
         self._channel = await self._connection.channel()
         self._rpc = await JsonRPC.create(channel=self._channel)
         self._rpc.host_exceptions = True
+        if self._is_caller:
+            self._caller_progress_handlers = ProgressHandlers(self._rpc, self._caller_server, self._queue_name)
         # this makes the queue available to complete work items
         if self._is_callee:
             await self._rpc.register(self._queue_name, self._callee_do_work_item)
@@ -273,6 +284,9 @@ class DistributedPromptQueue(AbstractPromptQueue):
             return
         self._closing = True
 
+        if self._is_caller:
+            await self._caller_progress_handlers.unregister_all()
+
         await self._rpc.close()
         await self._channel.close()
         await self._connection.close()
diff --git a/comfy/distributed/distributed_prompt_worker.py b/comfy/distributed/distributed_prompt_worker.py
index e2b335f59..e37043353 100644
--- a/comfy/distributed/distributed_prompt_worker.py
+++ b/comfy/distributed/distributed_prompt_worker.py
@@ -10,6 +10,7 @@ from aio_pika import connect_robust
 from aio_pika.patterns import JsonRPC
 from aiormq import AMQPConnectionError
 
+from .distributed_progress import DistributedExecutorToClientProgress
 from .distributed_types import RpcRequest, RpcReply
 from ..client.embedded_comfy_client import EmbeddedComfyClient
 from ..component_model.queue_types import ExecutionStatus
@@ -24,11 +25,13 @@ class DistributedPromptWorker:
                  connection_uri: str = "amqp://localhost:5672/",
                  queue_name: str = "comfyui",
                  loop: Optional[AbstractEventLoop] = None):
+        self._rpc = None
+        self._channel = None
         self._exit_stack = AsyncExitStack()
         self._queue_name = queue_name
         self._connection_uri = connection_uri
         self._loop = loop or asyncio.get_event_loop()
-        self._embedded_comfy_client = embedded_comfy_client or EmbeddedComfyClient()
+        self._embedded_comfy_client = embedded_comfy_client
 
     async def _do_work_item(self, request: dict) -> dict:
         await self.on_will_complete_work_item(request)
@@ -55,9 +58,6 @@ class DistributedPromptWorker:
     async def init(self):
         await self._exit_stack.__aenter__()
 
-        if not self._embedded_comfy_client.is_running:
-            await self._exit_stack.enter_async_context(self._embedded_comfy_client)
-
         try:
             self._connection = await connect_robust(self._connection_uri, loop=self._loop)
         except AMQPConnectionError as connection_error:
@@ -67,6 +67,12 @@ class DistributedPromptWorker:
 
         self._rpc = await JsonRPC.create(channel=self._channel)
         self._rpc.host_exceptions = True
+        if self._embedded_comfy_client is None:
+            self._embedded_comfy_client = EmbeddedComfyClient(
+                progress_handler=DistributedExecutorToClientProgress(self._rpc, self._queue_name, self._loop))
+        if not self._embedded_comfy_client.is_running:
+            await self._exit_stack.enter_async_context(self._embedded_comfy_client)
+
         await self._rpc.register(self._queue_name, self._do_work_item)
 
     async def __aenter__(self) -> "DistributedPromptWorker":
diff --git a/comfy/distributed/distributed_types.py b/comfy/distributed/distributed_types.py
index 78fcd29d0..35c420d66 100644
--- a/comfy/distributed/distributed_types.py
+++ b/comfy/distributed/distributed_types.py
@@ -3,9 +3,8 @@ from __future__ import annotations
 from dataclasses import dataclass
 from typing import Tuple, Literal, List
 
-import jwt
-
 from ..api.components.schema.prompt import PromptDict, Prompt
+from ..auth.permissions import ComfyJwt, jwt_decode
 from ..component_model.queue_types import NamedQueueTuple, TaskInvocation, ExecutionStatus
 
 
@@ -19,10 +18,8 @@ class DistributedBase:
         return self.decoded_token["sub"]
 
     @property
-    def decoded_token(self) -> dict:
-        return jwt.decode(self.user_token, algorithms=['HS256', "none"],
-                          # todo: this should be configurable
-                          options={"verify_signature": False, 'verify_aud': False, 'verify_iss': False})
+    def decoded_token(self) -> ComfyJwt:
+        return jwt_decode(self.user_token)
 
 
 @dataclass
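A minimal worker sketch against the classes added above, for orientation only:
it assumes `embedded_comfy_client` may be omitted from the
`DistributedPromptWorker` constructor (so the worker builds its own
`EmbeddedComfyClient` with a `DistributedExecutorToClientProgress` handler)
and that entering the async context performs the same broker setup as
`init()`.

    import asyncio

    from comfy.distributed.distributed_prompt_worker import DistributedPromptWorker


    async def main() -> None:
        # Connect to the same AMQP broker the frontends use; the worker registers
        # an RPC handler on the shared queue and executes prompts as they arrive.
        async with DistributedPromptWorker(connection_uri="amqp://guest:guest@rabbitmqserver/",
                                           queue_name="comfyui"):
            # Prompt requests are handled by the RPC callback registered during
            # setup, so the process only needs to stay alive here.
            await asyncio.Event().wait()


    if __name__ == "__main__":
        asyncio.run(main())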