mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 06:10:50 +08:00
Distributed queueing with amqp-compatible servers like RabbitMQ.
- Binary previews are not yet supported - Use `--distributed-queue-connection-uri=amqp://guest:guest@rabbitmqserver/` - Roles supported: frontend, worker or both (see `--help`) - Run `comfy-worker` for a lightweight worker you can wrap your head around - Workers and frontends must have the same directory structure (set with `--cwd`) and supported nodes. Frontends must still have access to inputs and outputs. - Configuration notes: distributed_queue_connection_uri (Optional[str]): Servers and clients will connect to this AMQP URL to form a distributed queue and exchange prompt execution requests and progress updates. distributed_queue_roles (List[str]): Specifies one or more roles for the distributed queue. Acceptable values are "worker" or "frontend", or both by writing the flag twice with each role. Frontends will start the web UI and connect to the provided AMQP URL to submit prompts; workers will pull requests off the AMQP URL. distributed_queue_name (str): This name will be used by the frontends and workers to exchange prompt requests and replies. Progress updates will be prefixed by the queue name, followed by a '.', then the user ID.
This commit is contained in:
parent
0673262940
commit
80f8c40248
0
comfy/auth/__init__.py
Normal file
0
comfy/auth/__init__.py
Normal file
14
comfy/auth/permissions.py
Normal file
14
comfy/auth/permissions.py
Normal file
@ -0,0 +1,14 @@
|
||||
from typing import TypedDict
|
||||
|
||||
import jwt
|
||||
|
||||
|
||||
class ComfyJwt(TypedDict, total=False):
|
||||
sub: str
|
||||
|
||||
|
||||
def jwt_decode(user_token: str) -> ComfyJwt:
|
||||
# todo: set up a way for users to override this behavior easily
|
||||
return ComfyJwt(**jwt.decode(user_token, algorithms=['HS256', "none"],
|
||||
# todo: this should be configurable
|
||||
options={"verify_signature": False, 'verify_aud': False, 'verify_iss': False}))
|
||||
@ -13,6 +13,7 @@ from ..component_model.make_mutable import make_mutable
|
||||
from ..component_model.executor_types import ExecutorToClientProgress
|
||||
from ..distributed.server_stub import ServerStub
|
||||
|
||||
_server_stub_instance = ServerStub()
|
||||
|
||||
class EmbeddedComfyClient:
|
||||
"""
|
||||
@ -117,14 +118,21 @@ class EmbeddedComfyClient:
|
||||
client_id = client_id or self._progress_handler.client_id or None
|
||||
|
||||
def execute_prompt() -> dict:
|
||||
from ..cmd.execution import validate_prompt
|
||||
from ..cmd.execution import PromptExecutor, validate_prompt
|
||||
prompt_mut = make_mutable(prompt)
|
||||
validation_tuple = validate_prompt(prompt_mut)
|
||||
|
||||
self._prompt_executor.execute(prompt_mut, prompt_id, {"client_id": client_id},
|
||||
prompt_executor: PromptExecutor = self._prompt_executor
|
||||
|
||||
if client_id is None:
|
||||
prompt_executor.server = _server_stub_instance
|
||||
else:
|
||||
prompt_executor.server = self._progress_handler
|
||||
|
||||
prompt_executor.execute(prompt_mut, prompt_id, {"client_id": client_id},
|
||||
execute_outputs=validation_tuple[2])
|
||||
if self._prompt_executor.success:
|
||||
return self._prompt_executor.outputs_ui
|
||||
if prompt_executor.success:
|
||||
return prompt_executor.outputs_ui
|
||||
else:
|
||||
raise RuntimeError("\n".join(event for (event, data) in self._prompt_executor.status_messages))
|
||||
|
||||
|
||||
@ -227,7 +227,7 @@ async def main():
|
||||
if args.distributed_queue_connection_uri is not None:
|
||||
distributed = True
|
||||
q = DistributedPromptQueue(
|
||||
caller_server=server if "worker" in args.distributed_queue_roles else None,
|
||||
caller_server=server if "frontend" in args.distributed_queue_roles else None,
|
||||
connection_uri=args.distributed_queue_connection_uri,
|
||||
is_caller="frontend" in args.distributed_queue_roles,
|
||||
is_callee="worker" in args.distributed_queue_roles,
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from __future__ import annotations # for Python 3.7-3.9
|
||||
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
from typing import Optional, Literal, Protocol
|
||||
from typing import Optional, Literal, Protocol, TypeAlias, Union
|
||||
|
||||
from comfy.component_model.queue_types import BinaryEventTypes
|
||||
|
||||
@ -32,6 +32,13 @@ class ProgressMessage(TypedDict):
|
||||
node: Optional[str]
|
||||
|
||||
|
||||
ExecutedMessage: TypeAlias = ExecutingMessage
|
||||
|
||||
SendSyncEvent: TypeAlias = Union[Literal["status", "executing", "progress", "executed"], BinaryEventTypes, None]
|
||||
|
||||
SendSyncData: TypeAlias = Union[StatusMessage, ExecutingMessage, ProgressMessage, bytes, bytearray, None]
|
||||
|
||||
|
||||
class ExecutorToClientProgress(Protocol):
|
||||
"""
|
||||
Specifies the interface for the dependencies a prompt executor needs from a server.
|
||||
@ -47,8 +54,9 @@ class ExecutorToClientProgress(Protocol):
|
||||
last_prompt_id: Optional[str]
|
||||
|
||||
def send_sync(self,
|
||||
event: Literal["status", "executing", "progress"] | BinaryEventTypes | str | None,
|
||||
data: StatusMessage | ExecutingMessage | ProgressMessage | bytes | bytearray | None, sid: str | None = None):
|
||||
event: SendSyncEvent,
|
||||
data: SendSyncData,
|
||||
sid: Optional[str] = None):
|
||||
"""
|
||||
Sends feedback to the client with the specified ID about a specific node
|
||||
|
||||
|
||||
84
comfy/distributed/distributed_progress.py
Normal file
84
comfy/distributed/distributed_progress.py
Normal file
@ -0,0 +1,84 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from asyncio import AbstractEventLoop
|
||||
from functools import partial
|
||||
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
from aio_pika.patterns import RPC
|
||||
|
||||
from ..component_model.executor_types import SendSyncEvent, SendSyncData, ExecutorToClientProgress
|
||||
from ..component_model.queue_types import BinaryEventTypes
|
||||
|
||||
|
||||
async def _progress(event: SendSyncEvent, data: SendSyncData, user_id: Optional[str] = None,
|
||||
caller_server: Optional[ExecutorToClientProgress] = None) -> None:
|
||||
assert caller_server is not None
|
||||
assert user_id is not None
|
||||
caller_server.send_sync(event, data, sid=user_id)
|
||||
|
||||
|
||||
def _get_name(queue_name: str, user_id: str) -> str:
|
||||
return f"{queue_name}.{user_id}.progress"
|
||||
|
||||
|
||||
class DistributedExecutorToClientProgress(ExecutorToClientProgress):
|
||||
def __init__(self, rpc: RPC, queue_name: str, loop: AbstractEventLoop):
|
||||
self._rpc = rpc
|
||||
self._queue_name = queue_name
|
||||
self._loop = loop
|
||||
|
||||
self.client_id = None
|
||||
self.node_id = None
|
||||
self.last_node_id = None
|
||||
|
||||
async def send(self, event: SendSyncEvent, data: SendSyncData, user_id: Optional[str]) -> None:
|
||||
# for now, do not send binary data this way, since it cannot be json serialized / it's impractical
|
||||
if event == BinaryEventTypes.PREVIEW_IMAGE or event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE:
|
||||
return
|
||||
|
||||
if isinstance(data, bytes) or isinstance(data, bytearray):
|
||||
return
|
||||
|
||||
if user_id is None:
|
||||
# todo: user_id should never be none here
|
||||
return
|
||||
|
||||
await self._rpc.call(_get_name(self._queue_name, user_id), {"event": event, "data": data})
|
||||
|
||||
def send_sync(self,
|
||||
event: SendSyncEvent,
|
||||
data: SendSyncData,
|
||||
sid: Optional[str] = None):
|
||||
asyncio.run_coroutine_threadsafe(self.send(event, data, sid), self._loop)
|
||||
|
||||
def queue_updated(self):
|
||||
# todo: this should gather the global queue data
|
||||
pass
|
||||
|
||||
|
||||
class ProgressHandlers:
|
||||
def __init__(self, rpc: RPC, caller_server: Optional[ExecutorToClientProgress], queue_name: str):
|
||||
self._rpc = rpc
|
||||
self._caller_server = caller_server
|
||||
self._progress_handlers: Dict[str, Any] = {}
|
||||
self._queue_name = queue_name
|
||||
|
||||
async def register_progress(self, user_id: str):
|
||||
if user_id in self._progress_handlers:
|
||||
return
|
||||
|
||||
handler = partial(_progress, user_id=user_id, caller_server=self._caller_server)
|
||||
self._progress_handlers[user_id] = handler
|
||||
await self._rpc.register(_get_name(self._queue_name, user_id), handler)
|
||||
|
||||
async def unregister_progress(self, user_id: str):
|
||||
if user_id not in self._progress_handlers:
|
||||
return
|
||||
handler = self._progress_handlers.pop(user_id)
|
||||
await self._rpc.unregister(handler)
|
||||
|
||||
async def unregister_all(self):
|
||||
await asyncio.gather(*[self._rpc.unregister(handler) for handler in self._progress_handlers.values()])
|
||||
self._progress_handlers.clear()
|
||||
@ -1,12 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from asyncio import AbstractEventLoop, Queue, QueueEmpty
|
||||
from dataclasses import asdict
|
||||
from functools import partial
|
||||
from time import sleep
|
||||
from typing import Optional, Dict, List, Mapping, Tuple, Callable
|
||||
|
||||
@ -15,10 +13,12 @@ from aio_pika import connect_robust
|
||||
from aio_pika.abc import AbstractConnection, AbstractChannel
|
||||
from aio_pika.patterns import JsonRPC
|
||||
|
||||
from .distributed_progress import ProgressHandlers
|
||||
from .distributed_types import RpcRequest, RpcReply
|
||||
from .server_stub import ServerStub
|
||||
from ..auth.permissions import jwt_decode
|
||||
from ..component_model.abstract_prompt_queue import AbstractPromptQueue
|
||||
from ..component_model.executor_types import ExecutorToClientProgress
|
||||
from ..component_model.executor_types import ExecutorToClientProgress, SendSyncEvent, SendSyncData
|
||||
from ..component_model.queue_types import Flags, HistoryEntry, QueueTuple, QueueItem, ExecutionStatus, TaskInvocation
|
||||
from .history import History
|
||||
from ..cmd.server import PromptServer
|
||||
@ -26,7 +26,7 @@ from ..cmd.server import PromptServer
|
||||
|
||||
class DistributedPromptQueue(AbstractPromptQueue):
|
||||
"""
|
||||
A distributed prompt queue for
|
||||
A distributed prompt queue for the ComfyUI web client and single-threaded worker.
|
||||
"""
|
||||
|
||||
def size(self) -> int:
|
||||
@ -36,8 +36,13 @@ class DistributedPromptQueue(AbstractPromptQueue):
|
||||
"""
|
||||
return len(self._caller_local_in_progress)
|
||||
|
||||
async def progress(self, event: SendSyncEvent, data: SendSyncData, sid: Optional[str]) -> None:
|
||||
self._caller_server.send_sync(event, data, sid=sid)
|
||||
|
||||
async def put_async(self, queue_item: QueueItem):
|
||||
assert self._is_caller
|
||||
assert self._rpc is not None
|
||||
|
||||
if self._closing:
|
||||
return
|
||||
self._caller_local_in_progress[queue_item.prompt_id] = queue_item
|
||||
@ -46,23 +51,26 @@ class DistributedPromptQueue(AbstractPromptQueue):
|
||||
try:
|
||||
if "token" in queue_item.extra_data:
|
||||
user_token = queue_item.extra_data["token"]
|
||||
user_id = jwt_decode(user_token)["sub"]
|
||||
else:
|
||||
if "client_id" in queue_item.extra_data:
|
||||
client_id = queue_item.extra_data["client_id"]
|
||||
user_id = queue_item.extra_data["client_id"]
|
||||
elif self._caller_server.client_id is not None:
|
||||
client_id = self._caller_server.client_id
|
||||
user_id = self._caller_server.client_id
|
||||
else:
|
||||
client_id = str(uuid.uuid4())
|
||||
user_id = str(uuid.uuid4())
|
||||
# todo: should we really do this?
|
||||
self._caller_server.client_id = client_id
|
||||
self._caller_server.client_id = user_id
|
||||
|
||||
# create a stub token
|
||||
user_token = jwt.encode({"sub": client_id}, key="", algorithm="none")
|
||||
user_token = jwt.encode({"sub": user_id}, key="", algorithm="none")
|
||||
|
||||
# register callbacks for progress
|
||||
assert self._caller_progress_handlers is not None
|
||||
await self._caller_progress_handlers.register_progress(user_id)
|
||||
request = RpcRequest(prompt_id=queue_item.prompt_id, user_token=user_token, prompt=queue_item.prompt)
|
||||
assert self._rpc is not None
|
||||
res: TaskInvocation = RpcReply(
|
||||
**(await self._rpc.call(self._queue_name, {"request": asdict(request)}))).as_task_invocation()
|
||||
|
||||
self._caller_history.put(queue_item, res.outputs, res.status)
|
||||
if self._caller_server is not None:
|
||||
self._caller_server.queue_updated()
|
||||
@ -241,6 +249,7 @@ class DistributedPromptQueue(AbstractPromptQueue):
|
||||
|
||||
# as rpc caller
|
||||
self._caller_server = caller_server or ServerStub()
|
||||
self._caller_progress_handlers: Optional[ProgressHandlers] = None
|
||||
self._caller_local_in_progress: dict[str | int, QueueItem] = {}
|
||||
self._caller_history: History = History()
|
||||
|
||||
@ -263,6 +272,8 @@ class DistributedPromptQueue(AbstractPromptQueue):
|
||||
self._channel = await self._connection.channel()
|
||||
self._rpc = await JsonRPC.create(channel=self._channel)
|
||||
self._rpc.host_exceptions = True
|
||||
if self._is_caller:
|
||||
self._caller_progress_handlers = ProgressHandlers(self._rpc, self._caller_server, self._queue_name)
|
||||
# this makes the queue available to complete work items
|
||||
if self._is_callee:
|
||||
await self._rpc.register(self._queue_name, self._callee_do_work_item)
|
||||
@ -273,6 +284,9 @@ class DistributedPromptQueue(AbstractPromptQueue):
|
||||
return
|
||||
|
||||
self._closing = True
|
||||
if self._is_caller:
|
||||
await self._caller_progress_handlers.unregister_all()
|
||||
|
||||
await self._rpc.close()
|
||||
await self._channel.close()
|
||||
await self._connection.close()
|
||||
|
||||
@ -10,6 +10,7 @@ from aio_pika import connect_robust
|
||||
from aio_pika.patterns import JsonRPC
|
||||
from aiormq import AMQPConnectionError
|
||||
|
||||
from .distributed_progress import DistributedExecutorToClientProgress
|
||||
from .distributed_types import RpcRequest, RpcReply
|
||||
from ..client.embedded_comfy_client import EmbeddedComfyClient
|
||||
from ..component_model.queue_types import ExecutionStatus
|
||||
@ -24,11 +25,13 @@ class DistributedPromptWorker:
|
||||
connection_uri: str = "amqp://localhost:5672/",
|
||||
queue_name: str = "comfyui",
|
||||
loop: Optional[AbstractEventLoop] = None):
|
||||
self._rpc = None
|
||||
self._channel = None
|
||||
self._exit_stack = AsyncExitStack()
|
||||
self._queue_name = queue_name
|
||||
self._connection_uri = connection_uri
|
||||
self._loop = loop or asyncio.get_event_loop()
|
||||
self._embedded_comfy_client = embedded_comfy_client or EmbeddedComfyClient()
|
||||
self._embedded_comfy_client = embedded_comfy_client
|
||||
|
||||
async def _do_work_item(self, request: dict) -> dict:
|
||||
await self.on_will_complete_work_item(request)
|
||||
@ -55,9 +58,6 @@ class DistributedPromptWorker:
|
||||
|
||||
async def init(self):
|
||||
await self._exit_stack.__aenter__()
|
||||
if not self._embedded_comfy_client.is_running:
|
||||
await self._exit_stack.enter_async_context(self._embedded_comfy_client)
|
||||
|
||||
try:
|
||||
self._connection = await connect_robust(self._connection_uri, loop=self._loop)
|
||||
except AMQPConnectionError as connection_error:
|
||||
@ -67,6 +67,12 @@ class DistributedPromptWorker:
|
||||
self._rpc = await JsonRPC.create(channel=self._channel)
|
||||
self._rpc.host_exceptions = True
|
||||
|
||||
if self._embedded_comfy_client is None:
|
||||
self._embedded_comfy_client = EmbeddedComfyClient(
|
||||
progress_handler=DistributedExecutorToClientProgress(self._rpc, self._queue_name, self._loop))
|
||||
if not self._embedded_comfy_client.is_running:
|
||||
await self._exit_stack.enter_async_context(self._embedded_comfy_client)
|
||||
|
||||
await self._rpc.register(self._queue_name, self._do_work_item)
|
||||
|
||||
async def __aenter__(self) -> "DistributedPromptWorker":
|
||||
|
||||
@ -3,9 +3,8 @@ from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple, Literal, List
|
||||
|
||||
import jwt
|
||||
|
||||
from ..api.components.schema.prompt import PromptDict, Prompt
|
||||
from ..auth.permissions import ComfyJwt, jwt_decode
|
||||
from ..component_model.queue_types import NamedQueueTuple, TaskInvocation, ExecutionStatus
|
||||
|
||||
|
||||
@ -19,10 +18,8 @@ class DistributedBase:
|
||||
return self.decoded_token["sub"]
|
||||
|
||||
@property
|
||||
def decoded_token(self) -> dict:
|
||||
return jwt.decode(self.user_token, algorithms=['HS256', "none"],
|
||||
# todo: this should be configurable
|
||||
options={"verify_signature": False, 'verify_aud': False, 'verify_iss': False})
|
||||
def decoded_token(self) -> ComfyJwt:
|
||||
return jwt_decode(self.user_token)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
Loading…
Reference in New Issue
Block a user