Distributed queueing with AMQP-compatible servers like RabbitMQ.

 - Binary previews are not yet supported
 - Use `--distributed-queue-connection-uri=amqp://guest:guest@rabbitmqserver/`
 - Roles supported: frontend, worker, or both (see `--help`)
 - Run `comfy-worker` for a lightweight worker you can wrap your head
   around (a minimal launch sketch follows this list)
 - Workers and frontends must have the same directory structure (set
   with `--cwd`) and the same supported nodes. Frontends must still have
   access to inputs and outputs.
 - Configuration notes:

   distributed_queue_connection_uri (Optional[str]): Servers and clients
     connect to this AMQP URL to form a distributed queue and exchange
     prompt execution requests and progress updates.
   distributed_queue_roles (List[str]): Specifies one or more roles for
     the distributed queue. Acceptable values are "worker" and
     "frontend"; pass the flag twice, once with each value, to enable
     both roles. Frontends start the web UI and connect to the provided
     AMQP URL to submit prompts; workers pull requests off the AMQP URL.
   distributed_queue_name (str): This name is used by frontends and
     workers to exchange prompt requests and replies. Progress updates
     are prefixed by the queue name, followed by a '.', then the user ID.
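
A minimal launch sketch for the worker role. The module path
`comfy.distributed.distributed_prompt_worker` is an assumption (it is not shown
in this diff); the constructor keywords match the `DistributedPromptWorker`
signature in the diff below.

```python
# Sketch only: run a single lightweight worker against the AMQP broker.
import asyncio

from comfy.distributed.distributed_prompt_worker import DistributedPromptWorker


async def run_worker() -> None:
    async with DistributedPromptWorker(
            connection_uri="amqp://guest:guest@rabbitmqserver/",
            queue_name="comfyui"):
        # On entry the worker connects to the broker, builds an
        # EmbeddedComfyClient wired to DistributedExecutorToClientProgress,
        # and registers itself on the queue (see the worker diff below);
        # block here until cancelled.
        await asyncio.Event().wait()


if __name__ == "__main__":
    asyncio.run(run_worker())
```

The `comfy-worker` command referenced above presumably wraps this same class;
frontends instead keep running the normal web server with the
`--distributed-queue-*` flags set.
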
doctorpangloss 2024-02-08 20:24:27 -08:00
parent 0673262940
commit 80f8c40248
9 changed files with 161 additions and 30 deletions

comfy/auth/__init__.py (new, empty file)

comfy/auth/permissions.py (new file)

@@ -0,0 +1,14 @@
+from typing import TypedDict
+
+import jwt
+
+
+class ComfyJwt(TypedDict, total=False):
+    sub: str
+
+
+def jwt_decode(user_token: str) -> ComfyJwt:
+    # todo: set up a way for users to override this behavior easily
+    return ComfyJwt(**jwt.decode(user_token, algorithms=['HS256', "none"],
+                                 # todo: this should be configurable
+                                 options={"verify_signature": False, 'verify_aud': False, 'verify_iss': False}))
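
For illustration, here is the token round trip these helpers support, as used
by `DistributedPromptQueue.put_async` further below; the snippet itself is a
sketch and not part of the commit.

```python
# Frontends mint an unsigned ("none" algorithm) stub token when a request
# carries no real token; jwt_decode recovers the user ID on either side.
import jwt  # PyJWT, as imported by comfy/auth/permissions.py

from comfy.auth.permissions import jwt_decode

user_id = "some-client-id"
user_token = jwt.encode({"sub": user_id}, key="", algorithm="none")
assert jwt_decode(user_token)["sub"] == user_id
```
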

comfy/client/embedded_comfy_client.py

@@ -13,6 +13,7 @@ from ..component_model.make_mutable import make_mutable
 from ..component_model.executor_types import ExecutorToClientProgress
 from ..distributed.server_stub import ServerStub
 
+_server_stub_instance = ServerStub()
 
 class EmbeddedComfyClient:
     """
@@ -117,14 +118,21 @@ class EmbeddedComfyClient:
         client_id = client_id or self._progress_handler.client_id or None
 
         def execute_prompt() -> dict:
-            from ..cmd.execution import validate_prompt
+            from ..cmd.execution import PromptExecutor, validate_prompt
             prompt_mut = make_mutable(prompt)
             validation_tuple = validate_prompt(prompt_mut)
-            self._prompt_executor.execute(prompt_mut, prompt_id, {"client_id": client_id},
-                                          execute_outputs=validation_tuple[2])
-            if self._prompt_executor.success:
-                return self._prompt_executor.outputs_ui
+            prompt_executor: PromptExecutor = self._prompt_executor
+            if client_id is None:
+                prompt_executor.server = _server_stub_instance
+            else:
+                prompt_executor.server = self._progress_handler
+            prompt_executor.execute(prompt_mut, prompt_id, {"client_id": client_id},
+                                    execute_outputs=validation_tuple[2])
+            if prompt_executor.success:
+                return prompt_executor.outputs_ui
             else:
                 raise RuntimeError("\n".join(event for (event, data) in self._prompt_executor.status_messages))

(main entrypoint; file name not shown)

@@ -227,7 +227,7 @@ async def main():
     if args.distributed_queue_connection_uri is not None:
         distributed = True
         q = DistributedPromptQueue(
-            caller_server=server if "worker" in args.distributed_queue_roles else None,
+            caller_server=server if "frontend" in args.distributed_queue_roles else None,
             connection_uri=args.distributed_queue_connection_uri,
             is_caller="frontend" in args.distributed_queue_roles,
             is_callee="worker" in args.distributed_queue_roles,

comfy/component_model/executor_types.py

@@ -1,7 +1,7 @@
 from __future__ import annotations  # for Python 3.7-3.9
 
 from typing_extensions import NotRequired, TypedDict
-from typing import Optional, Literal, Protocol
+from typing import Optional, Literal, Protocol, TypeAlias, Union
 
 from comfy.component_model.queue_types import BinaryEventTypes
@@ -32,6 +32,13 @@ class ProgressMessage(TypedDict):
     node: Optional[str]
 
 
+ExecutedMessage: TypeAlias = ExecutingMessage
+
+SendSyncEvent: TypeAlias = Union[Literal["status", "executing", "progress", "executed"], BinaryEventTypes, None]
+SendSyncData: TypeAlias = Union[StatusMessage, ExecutingMessage, ProgressMessage, bytes, bytearray, None]
+
+
 class ExecutorToClientProgress(Protocol):
     """
     Specifies the interface for the dependencies a prompt executor needs from a server.
@@ -47,8 +54,9 @@ class ExecutorToClientProgress(Protocol):
     last_prompt_id: Optional[str]
 
     def send_sync(self,
-                  event: Literal["status", "executing", "progress"] | BinaryEventTypes | str | None,
-                  data: StatusMessage | ExecutingMessage | ProgressMessage | bytes | bytearray | None, sid: str | None = None):
+                  event: SendSyncEvent,
+                  data: SendSyncData,
+                  sid: Optional[str] = None):
         """
         Sends feedback to the client with the specified ID about a specific node

comfy/distributed/distributed_progress.py (new file)

@@ -0,0 +1,84 @@
+from __future__ import annotations
+
+import asyncio
+from asyncio import AbstractEventLoop
+from functools import partial
+from typing import Optional, Dict, Any
+
+from aio_pika.patterns import RPC
+
+from ..component_model.executor_types import SendSyncEvent, SendSyncData, ExecutorToClientProgress
+from ..component_model.queue_types import BinaryEventTypes
+
+
+async def _progress(event: SendSyncEvent, data: SendSyncData, user_id: Optional[str] = None,
+                    caller_server: Optional[ExecutorToClientProgress] = None) -> None:
+    assert caller_server is not None
+    assert user_id is not None
+    caller_server.send_sync(event, data, sid=user_id)
+
+
+def _get_name(queue_name: str, user_id: str) -> str:
+    return f"{queue_name}.{user_id}.progress"
+
+
+class DistributedExecutorToClientProgress(ExecutorToClientProgress):
+    def __init__(self, rpc: RPC, queue_name: str, loop: AbstractEventLoop):
+        self._rpc = rpc
+        self._queue_name = queue_name
+        self._loop = loop
+        self.client_id = None
+        self.node_id = None
+        self.last_node_id = None
+
+    async def send(self, event: SendSyncEvent, data: SendSyncData, user_id: Optional[str]) -> None:
+        # for now, do not send binary data this way, since it cannot be json serialized / it's impractical
+        if event == BinaryEventTypes.PREVIEW_IMAGE or event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE:
+            return
+        if isinstance(data, bytes) or isinstance(data, bytearray):
+            return
+        if user_id is None:
+            # todo: user_id should never be none here
+            return
+
+        await self._rpc.call(_get_name(self._queue_name, user_id), {"event": event, "data": data})
+
+    def send_sync(self,
+                  event: SendSyncEvent,
+                  data: SendSyncData,
+                  sid: Optional[str] = None):
+        asyncio.run_coroutine_threadsafe(self.send(event, data, sid), self._loop)
+
+    def queue_updated(self):
+        # todo: this should gather the global queue data
+        pass
+
+
+class ProgressHandlers:
+    def __init__(self, rpc: RPC, caller_server: Optional[ExecutorToClientProgress], queue_name: str):
+        self._rpc = rpc
+        self._caller_server = caller_server
+        self._progress_handlers: Dict[str, Any] = {}
+        self._queue_name = queue_name
+
+    async def register_progress(self, user_id: str):
+        if user_id in self._progress_handlers:
+            return
+        handler = partial(_progress, user_id=user_id, caller_server=self._caller_server)
+        self._progress_handlers[user_id] = handler
+        await self._rpc.register(_get_name(self._queue_name, user_id), handler)

+    async def unregister_progress(self, user_id: str):
+        if user_id not in self._progress_handlers:
+            return
+        handler = self._progress_handlers.pop(user_id)
+        await self._rpc.unregister(handler)
+
+    async def unregister_all(self):
+        await asyncio.gather(*[self._rpc.unregister(handler) for handler in self._progress_handlers.values()])
+        self._progress_handlers.clear()
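
A sketch of the progress round trip these two classes implement, assuming a
reachable broker and that this new file is importable as
`comfy.distributed.distributed_progress` (inferred from the imports elsewhere
in this commit); the payload shape is also an assumption.

```python
# Frontend registers a per-user progress handler; the worker-side
# DistributedExecutorToClientProgress forwards send_sync() calls to it over RPC.
import asyncio

from aio_pika import connect_robust
from aio_pika.patterns import JsonRPC

from comfy.distributed.distributed_progress import (
    DistributedExecutorToClientProgress,
    ProgressHandlers,
)


async def demo(frontend_server, user_id: str) -> None:
    # frontend_server is whatever implements ExecutorToClientProgress locally,
    # e.g. the frontend's PromptServer.
    connection = await connect_robust("amqp://guest:guest@rabbitmqserver/")
    channel = await connection.channel()
    rpc = await JsonRPC.create(channel=channel)

    # Frontend side: listen on "comfyui.<user_id>.progress" and replay events
    # onto the local server via send_sync().
    handlers = ProgressHandlers(rpc, frontend_server, "comfyui")
    await handlers.register_progress(user_id)

    # Worker side: a server stand-in that publishes send_sync() calls to the
    # same RPC method name instead of a local websocket.
    progress = DistributedExecutorToClientProgress(rpc, "comfyui", asyncio.get_running_loop())
    progress.send_sync("executing", {"node": None}, sid=user_id)  # event payload assumed

    await asyncio.sleep(1)  # let the scheduled RPC call complete
    await handlers.unregister_all()
    await connection.close()
```
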

(DistributedPromptQueue; file name not shown)

@@ -1,12 +1,10 @@
 from __future__ import annotations
 
 import asyncio
-import threading
 import time
 import uuid
 from asyncio import AbstractEventLoop, Queue, QueueEmpty
 from dataclasses import asdict
-from functools import partial
 from time import sleep
 from typing import Optional, Dict, List, Mapping, Tuple, Callable
@@ -15,10 +13,12 @@ from aio_pika import connect_robust
 from aio_pika.abc import AbstractConnection, AbstractChannel
 from aio_pika.patterns import JsonRPC
 
+from .distributed_progress import ProgressHandlers
 from .distributed_types import RpcRequest, RpcReply
 from .server_stub import ServerStub
+from ..auth.permissions import jwt_decode
 from ..component_model.abstract_prompt_queue import AbstractPromptQueue
-from ..component_model.executor_types import ExecutorToClientProgress
+from ..component_model.executor_types import ExecutorToClientProgress, SendSyncEvent, SendSyncData
 from ..component_model.queue_types import Flags, HistoryEntry, QueueTuple, QueueItem, ExecutionStatus, TaskInvocation
 from .history import History
 from ..cmd.server import PromptServer
@@ -26,7 +26,7 @@ from ..cmd.server import PromptServer
 
 class DistributedPromptQueue(AbstractPromptQueue):
     """
-    A distributed prompt queue for
+    A distributed prompt queue for the ComfyUI web client and single-threaded worker.
     """
 
     def size(self) -> int:
@@ -36,8 +36,13 @@ class DistributedPromptQueue(AbstractPromptQueue):
         """
         return len(self._caller_local_in_progress)
 
+    async def progress(self, event: SendSyncEvent, data: SendSyncData, sid: Optional[str]) -> None:
+        self._caller_server.send_sync(event, data, sid=sid)
+
     async def put_async(self, queue_item: QueueItem):
         assert self._is_caller
+        assert self._rpc is not None
         if self._closing:
             return
         self._caller_local_in_progress[queue_item.prompt_id] = queue_item
@@ -46,23 +51,26 @@ class DistributedPromptQueue(AbstractPromptQueue):
         try:
             if "token" in queue_item.extra_data:
                 user_token = queue_item.extra_data["token"]
+                user_id = jwt_decode(user_token)["sub"]
             else:
                 if "client_id" in queue_item.extra_data:
-                    client_id = queue_item.extra_data["client_id"]
+                    user_id = queue_item.extra_data["client_id"]
                 elif self._caller_server.client_id is not None:
-                    client_id = self._caller_server.client_id
+                    user_id = self._caller_server.client_id
                 else:
-                    client_id = str(uuid.uuid4())
+                    user_id = str(uuid.uuid4())
                 # todo: should we really do this?
-                self._caller_server.client_id = client_id
+                self._caller_server.client_id = user_id
                 # create a stub token
-                user_token = jwt.encode({"sub": client_id}, key="", algorithm="none")
+                user_token = jwt.encode({"sub": user_id}, key="", algorithm="none")
+
+            # register callbacks for progress
+            assert self._caller_progress_handlers is not None
+            await self._caller_progress_handlers.register_progress(user_id)
 
             request = RpcRequest(prompt_id=queue_item.prompt_id, user_token=user_token, prompt=queue_item.prompt)
-            assert self._rpc is not None
             res: TaskInvocation = RpcReply(
                 **(await self._rpc.call(self._queue_name, {"request": asdict(request)}))).as_task_invocation()
             self._caller_history.put(queue_item, res.outputs, res.status)
             if self._caller_server is not None:
                 self._caller_server.queue_updated()
@@ -241,6 +249,7 @@ class DistributedPromptQueue(AbstractPromptQueue):
         # as rpc caller
         self._caller_server = caller_server or ServerStub()
+        self._caller_progress_handlers: Optional[ProgressHandlers] = None
         self._caller_local_in_progress: dict[str | int, QueueItem] = {}
         self._caller_history: History = History()
@@ -263,6 +272,8 @@ class DistributedPromptQueue(AbstractPromptQueue):
         self._channel = await self._connection.channel()
         self._rpc = await JsonRPC.create(channel=self._channel)
         self._rpc.host_exceptions = True
+        if self._is_caller:
+            self._caller_progress_handlers = ProgressHandlers(self._rpc, self._caller_server, self._queue_name)
         # this makes the queue available to complete work items
         if self._is_callee:
             await self._rpc.register(self._queue_name, self._callee_do_work_item)
@@ -273,6 +284,9 @@ class DistributedPromptQueue(AbstractPromptQueue):
             return
         self._closing = True
+        if self._is_caller:
+            await self._caller_progress_handlers.unregister_all()
+
         await self._rpc.close()
         await self._channel.close()
         await self._connection.close()

(DistributedPromptWorker; file name not shown)

@@ -10,6 +10,7 @@ from aio_pika import connect_robust
 from aio_pika.patterns import JsonRPC
 from aiormq import AMQPConnectionError
 
+from .distributed_progress import DistributedExecutorToClientProgress
 from .distributed_types import RpcRequest, RpcReply
 from ..client.embedded_comfy_client import EmbeddedComfyClient
 from ..component_model.queue_types import ExecutionStatus
@@ -24,11 +25,13 @@ class DistributedPromptWorker:
                  connection_uri: str = "amqp://localhost:5672/",
                  queue_name: str = "comfyui",
                  loop: Optional[AbstractEventLoop] = None):
+        self._rpc = None
+        self._channel = None
         self._exit_stack = AsyncExitStack()
         self._queue_name = queue_name
         self._connection_uri = connection_uri
         self._loop = loop or asyncio.get_event_loop()
-        self._embedded_comfy_client = embedded_comfy_client or EmbeddedComfyClient()
+        self._embedded_comfy_client = embedded_comfy_client
 
     async def _do_work_item(self, request: dict) -> dict:
         await self.on_will_complete_work_item(request)
@@ -55,9 +58,6 @@ class DistributedPromptWorker:
     async def init(self):
         await self._exit_stack.__aenter__()
-        if not self._embedded_comfy_client.is_running:
-            await self._exit_stack.enter_async_context(self._embedded_comfy_client)
-
         try:
             self._connection = await connect_robust(self._connection_uri, loop=self._loop)
         except AMQPConnectionError as connection_error:
@@ -67,6 +67,12 @@ class DistributedPromptWorker:
         self._rpc = await JsonRPC.create(channel=self._channel)
         self._rpc.host_exceptions = True
 
+        if self._embedded_comfy_client is None:
+            self._embedded_comfy_client = EmbeddedComfyClient(
+                progress_handler=DistributedExecutorToClientProgress(self._rpc, self._queue_name, self._loop))
+        if not self._embedded_comfy_client.is_running:
+            await self._exit_stack.enter_async_context(self._embedded_comfy_client)
+
         await self._rpc.register(self._queue_name, self._do_work_item)
 
     async def __aenter__(self) -> "DistributedPromptWorker":

comfy/distributed/distributed_types.py

@@ -3,9 +3,8 @@ from __future__ import annotations
 from dataclasses import dataclass
 from typing import Tuple, Literal, List
 
-import jwt
-
 from ..api.components.schema.prompt import PromptDict, Prompt
+from ..auth.permissions import ComfyJwt, jwt_decode
 from ..component_model.queue_types import NamedQueueTuple, TaskInvocation, ExecutionStatus
@@ -19,10 +18,8 @@ class DistributedBase:
         return self.decoded_token["sub"]
 
     @property
-    def decoded_token(self) -> dict:
-        return jwt.decode(self.user_token, algorithms=['HS256', "none"],
-                          # todo: this should be configurable
-                          options={"verify_signature": False, 'verify_aud': False, 'verify_iss': False})
+    def decoded_token(self) -> ComfyJwt:
+        return jwt_decode(self.user_token)
 
 
 @dataclass