diff --git a/comfy/auth/__init__.py b/comfy/auth/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/comfy/auth/permissions.py b/comfy/auth/permissions.py
new file mode 100644
index 000000000..d96533677
--- /dev/null
+++ b/comfy/auth/permissions.py
@@ -0,0 +1,14 @@
+from typing import TypedDict
+
+import jwt
+
+
+class ComfyJwt(TypedDict, total=False):
+    sub: str
+
+
+def jwt_decode(user_token: str) -> ComfyJwt:
+    # todo: set up a way for users to override this behavior easily
+    return ComfyJwt(**jwt.decode(user_token, algorithms=['HS256', "none"],
+                                 # todo: this should be configurable
+                                 options={"verify_signature": False, 'verify_aud': False, 'verify_iss': False}))
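Note (not part of the patch): a minimal sketch of how the unsigned stub tokens minted by the queue round-trip through this helper. It assumes PyJWT is installed; the "example-user" subject is made up for illustration, and jwt_decode deliberately skips signature verification for now.

import jwt

from comfy.auth.permissions import jwt_decode

user_id = "example-user"  # hypothetical subject
# the caller mints an unsigned stub token when a queue item carries no "token"
stub_token = jwt.encode({"sub": user_id}, key="", algorithm="none")
# the worker side only reads the claims; verification is disabled inside jwt_decode for now
assert jwt_decode(stub_token)["sub"] == user_id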
diff --git a/comfy/client/embedded_comfy_client.py b/comfy/client/embedded_comfy_client.py
index a5bdf6b51..e4271a98a 100644
--- a/comfy/client/embedded_comfy_client.py
+++ b/comfy/client/embedded_comfy_client.py
@@ -13,6 +13,7 @@
 from ..component_model.make_mutable import make_mutable
 from ..component_model.executor_types import ExecutorToClientProgress
 from ..distributed.server_stub import ServerStub
+_server_stub_instance = ServerStub()
 
 class EmbeddedComfyClient:
     """
@@ -117,14 +118,21 @@ class EmbeddedComfyClient:
         client_id = client_id or self._progress_handler.client_id or None
 
         def execute_prompt() -> dict:
-            from ..cmd.execution import validate_prompt
+            from ..cmd.execution import PromptExecutor, validate_prompt
             prompt_mut = make_mutable(prompt)
             validation_tuple = validate_prompt(prompt_mut)
 
-            self._prompt_executor.execute(prompt_mut, prompt_id, {"client_id": client_id},
-                                          execute_outputs=validation_tuple[2])
-            if self._prompt_executor.success:
-                return self._prompt_executor.outputs_ui
+            prompt_executor: PromptExecutor = self._prompt_executor
+
+            if client_id is None:
+                prompt_executor.server = _server_stub_instance
+            else:
+                prompt_executor.server = self._progress_handler
+
+            prompt_executor.execute(prompt_mut, prompt_id, {"client_id": client_id},
+                                    execute_outputs=validation_tuple[2])
+            if prompt_executor.success:
+                return prompt_executor.outputs_ui
             else:
                 raise RuntimeError("\n".join(event for (event, data) in self._prompt_executor.status_messages))
 
diff --git a/comfy/cmd/main.py b/comfy/cmd/main.py
index cba09c61c..4d1e82e8f 100644
--- a/comfy/cmd/main.py
+++ b/comfy/cmd/main.py
@@ -227,7 +227,7 @@ async def main():
     if args.distributed_queue_connection_uri is not None:
         distributed = True
         q = DistributedPromptQueue(
-            caller_server=server if "worker" in args.distributed_queue_roles else None,
+            caller_server=server if "frontend" in args.distributed_queue_roles else None,
             connection_uri=args.distributed_queue_connection_uri,
             is_caller="frontend" in args.distributed_queue_roles,
             is_callee="worker" in args.distributed_queue_roles,
diff --git a/comfy/component_model/executor_types.py b/comfy/component_model/executor_types.py
index c1de2d9df..151d085ca 100644
--- a/comfy/component_model/executor_types.py
+++ b/comfy/component_model/executor_types.py
@@ -1,7 +1,7 @@
 from __future__ import annotations  # for Python 3.7-3.9
 
 from typing_extensions import NotRequired, TypedDict
-from typing import Optional, Literal, Protocol
+from typing import Optional, Literal, Protocol, TypeAlias, Union
 
 from comfy.component_model.queue_types import BinaryEventTypes
 
@@ -32,6 +32,13 @@ class ProgressMessage(TypedDict):
     node: Optional[str]
 
 
+ExecutedMessage: TypeAlias = ExecutingMessage
+
+SendSyncEvent: TypeAlias = Union[Literal["status", "executing", "progress", "executed"], BinaryEventTypes, None]
+
+SendSyncData: TypeAlias = Union[StatusMessage, ExecutingMessage, ProgressMessage, bytes, bytearray, None]
+
+
 class ExecutorToClientProgress(Protocol):
     """
     Specifies the interface for the dependencies a prompt executor needs from a server.
@@ -47,8 +54,9 @@ class ExecutorToClientProgress(Protocol):
     last_prompt_id: Optional[str]
 
     def send_sync(self,
-                  event: Literal["status", "executing", "progress"] | BinaryEventTypes | str | None,
-                  data: StatusMessage | ExecutingMessage | ProgressMessage | bytes | bytearray | None, sid: str | None = None):
+                  event: SendSyncEvent,
+                  data: SendSyncData,
+                  sid: Optional[str] = None):
         """
         Sends feedback to the client with the specified ID about a specific node
 
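Note (not part of the patch): `TypeAlias` only exists in `typing` from Python 3.10 onward; if 3.9 is still targeted (the "# for Python 3.7-3.9" comment suggests it may be), the alias would need to come from `typing_extensions` instead. Below is a minimal sketch of an in-process implementation of the narrowed `send_sync` signature; the class name `LoggingProgress` is made up for illustration and is not part of the patch.

from typing import Optional

from comfy.component_model.executor_types import SendSyncEvent, SendSyncData


class LoggingProgress:
    """Hypothetical ExecutorToClientProgress implementation that just logs events."""
    client_id: Optional[str] = None
    node_id: Optional[str] = None
    last_node_id: Optional[str] = None
    last_prompt_id: Optional[str] = None

    def send_sync(self, event: SendSyncEvent, data: SendSyncData, sid: Optional[str] = None):
        # a real implementation would forward this to a websocket or message queue
        print(f"[{sid}] {event}: {data!r}")

    def queue_updated(self):
        pass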
diff --git a/comfy/distributed/distributed_progress.py b/comfy/distributed/distributed_progress.py
new file mode 100644
index 000000000..048b15c18
--- /dev/null
+++ b/comfy/distributed/distributed_progress.py
@@ -0,0 +1,84 @@
+from __future__ import annotations
+
+import asyncio
+from asyncio import AbstractEventLoop
+from functools import partial
+
+from typing import Optional, Dict, Any
+
+from aio_pika.patterns import RPC
+
+from ..component_model.executor_types import SendSyncEvent, SendSyncData, ExecutorToClientProgress
+from ..component_model.queue_types import BinaryEventTypes
+
+
+async def _progress(event: SendSyncEvent, data: SendSyncData, user_id: Optional[str] = None,
+                    caller_server: Optional[ExecutorToClientProgress] = None) -> None:
+    assert caller_server is not None
+    assert user_id is not None
+    caller_server.send_sync(event, data, sid=user_id)
+
+
+def _get_name(queue_name: str, user_id: str) -> str:
+    return f"{queue_name}.{user_id}.progress"
+
+
+class DistributedExecutorToClientProgress(ExecutorToClientProgress):
+    def __init__(self, rpc: RPC, queue_name: str, loop: AbstractEventLoop):
+        self._rpc = rpc
+        self._queue_name = queue_name
+        self._loop = loop
+
+        self.client_id = None
+        self.node_id = None
+        self.last_node_id = None
+
+    async def send(self, event: SendSyncEvent, data: SendSyncData, user_id: Optional[str]) -> None:
+        # for now, do not send binary data this way, since it cannot be json serialized / it's impractical
+        if event == BinaryEventTypes.PREVIEW_IMAGE or event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE:
+            return
+
+        if isinstance(data, bytes) or isinstance(data, bytearray):
+            return
+
+        if user_id is None:
+            # todo: user_id should never be none here
+            return
+
+        await self._rpc.call(_get_name(self._queue_name, user_id), {"event": event, "data": data})
+
+    def send_sync(self,
+                  event: SendSyncEvent,
+                  data: SendSyncData,
+                  sid: Optional[str] = None):
+        asyncio.run_coroutine_threadsafe(self.send(event, data, sid), self._loop)
+
+    def queue_updated(self):
+        # todo: this should gather the global queue data
+        pass
+
+
+class ProgressHandlers:
+    def __init__(self, rpc: RPC, caller_server: Optional[ExecutorToClientProgress], queue_name: str):
+        self._rpc = rpc
+        self._caller_server = caller_server
+        self._progress_handlers: Dict[str, Any] = {}
+        self._queue_name = queue_name
+
+    async def register_progress(self, user_id: str):
+        if user_id in self._progress_handlers:
+            return
+
+        handler = partial(_progress, user_id=user_id, caller_server=self._caller_server)
+        self._progress_handlers[user_id] = handler
+        await self._rpc.register(_get_name(self._queue_name, user_id), handler)
+
+    async def unregister_progress(self, user_id: str):
+        if user_id not in self._progress_handlers:
+            return
+        handler = self._progress_handlers.pop(user_id)
+        await self._rpc.unregister(handler)
+
+    async def unregister_all(self):
+        await asyncio.gather(*[self._rpc.unregister(handler) for handler in self._progress_handlers.values()])
+        self._progress_handlers.clear()
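Note (not part of the patch): the per-user progress channel is just an RPC method name derived by `_get_name`. The frontend registers a handler under that name via `ProgressHandlers.register_progress`, and the worker's `DistributedExecutorToClientProgress.send_sync` schedules an `rpc.call` against the same name, so events land back on the originating server. The sketch below shows both halves in one function purely for illustration; the function name, user id, and payload fields are assumptions, not part of the patch.

from asyncio import AbstractEventLoop

from aio_pika.patterns import RPC

from comfy.component_model.executor_types import ExecutorToClientProgress
from comfy.distributed.distributed_progress import DistributedExecutorToClientProgress, ProgressHandlers


async def wire_progress(rpc: RPC, caller_server: ExecutorToClientProgress, loop: AbstractEventLoop, user_id: str):
    # frontend (caller) side: listen on "comfyui.<user_id>.progress" and replay events to caller_server
    handlers = ProgressHandlers(rpc, caller_server, "comfyui")
    await handlers.register_progress(user_id)

    # worker (callee) side: send_sync schedules an rpc.call to the same "comfyui.<user_id>.progress" name
    progress = DistributedExecutorToClientProgress(rpc, "comfyui", loop)
    progress.send_sync("progress", {"value": 1, "max": 10, "node": None}, sid=user_id)  # payload fields illustrative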
diff --git a/comfy/distributed/distributed_prompt_queue.py b/comfy/distributed/distributed_prompt_queue.py
index 1ae74492d..8d3c810bd 100644
--- a/comfy/distributed/distributed_prompt_queue.py
+++ b/comfy/distributed/distributed_prompt_queue.py
@@ -1,12 +1,10 @@
 from __future__ import annotations
 
 import asyncio
-import threading
 import time
 import uuid
 
 from asyncio import AbstractEventLoop, Queue, QueueEmpty
 from dataclasses import asdict
-from functools import partial
 from time import sleep
 from typing import Optional, Dict, List, Mapping, Tuple, Callable
@@ -15,10 +13,12 @@
 from aio_pika import connect_robust
 from aio_pika.abc import AbstractConnection, AbstractChannel
 from aio_pika.patterns import JsonRPC
+from .distributed_progress import ProgressHandlers
 from .distributed_types import RpcRequest, RpcReply
 from .server_stub import ServerStub
+from ..auth.permissions import jwt_decode
 from ..component_model.abstract_prompt_queue import AbstractPromptQueue
-from ..component_model.executor_types import ExecutorToClientProgress
+from ..component_model.executor_types import ExecutorToClientProgress, SendSyncEvent, SendSyncData
 from ..component_model.queue_types import Flags, HistoryEntry, QueueTuple, QueueItem, ExecutionStatus, TaskInvocation
 from .history import History
 from ..cmd.server import PromptServer
@@ -26,7 +26,7 @@ from ..cmd.server import PromptServer
 
 class DistributedPromptQueue(AbstractPromptQueue):
     """
-    A distributed prompt queue for
+    A distributed prompt queue for the ComfyUI web client and single-threaded worker.
     """
 
     def size(self) -> int:
@@ -36,8 +36,13 @@ class DistributedPromptQueue(AbstractPromptQueue):
         """
         return len(self._caller_local_in_progress)
 
+    async def progress(self, event: SendSyncEvent, data: SendSyncData, sid: Optional[str]) -> None:
+        self._caller_server.send_sync(event, data, sid=sid)
+
     async def put_async(self, queue_item: QueueItem):
         assert self._is_caller
+        assert self._rpc is not None
+
         if self._closing:
             return
         self._caller_local_in_progress[queue_item.prompt_id] = queue_item
@@ -46,23 +51,26 @@
         try:
             if "token" in queue_item.extra_data:
                 user_token = queue_item.extra_data["token"]
+                user_id = jwt_decode(user_token)["sub"]
             else:
                 if "client_id" in queue_item.extra_data:
-                    client_id = queue_item.extra_data["client_id"]
+                    user_id = queue_item.extra_data["client_id"]
                 elif self._caller_server.client_id is not None:
-                    client_id = self._caller_server.client_id
+                    user_id = self._caller_server.client_id
                 else:
-                    client_id = str(uuid.uuid4())
+                    user_id = str(uuid.uuid4())
                     # todo: should we really do this?
-                    self._caller_server.client_id = client_id
+                    self._caller_server.client_id = user_id
                 # create a stub token
-                user_token = jwt.encode({"sub": client_id}, key="", algorithm="none")
+                user_token = jwt.encode({"sub": user_id}, key="", algorithm="none")
+
+            # register callbacks for progress
+            assert self._caller_progress_handlers is not None
+            await self._caller_progress_handlers.register_progress(user_id)
 
             request = RpcRequest(prompt_id=queue_item.prompt_id, user_token=user_token, prompt=queue_item.prompt)
-            assert self._rpc is not None
             res: TaskInvocation = RpcReply(
                 **(await self._rpc.call(self._queue_name, {"request": asdict(request)}))).as_task_invocation()
-
             self._caller_history.put(queue_item, res.outputs, res.status)
             if self._caller_server is not None:
                 self._caller_server.queue_updated()
@@ -241,6 +249,7 @@
 
         # as rpc caller
         self._caller_server = caller_server or ServerStub()
+        self._caller_progress_handlers: Optional[ProgressHandlers] = None
         self._caller_local_in_progress: dict[str | int, QueueItem] = {}
         self._caller_history: History = History()
 
@@ -263,6 +272,8 @@
         self._channel = await self._connection.channel()
         self._rpc = await JsonRPC.create(channel=self._channel)
         self._rpc.host_exceptions = True
+        if self._is_caller:
+            self._caller_progress_handlers = ProgressHandlers(self._rpc, self._caller_server, self._queue_name)
         # this makes the queue available to complete work items
        if self._is_callee:
             await self._rpc.register(self._queue_name, self._callee_do_work_item)
@@ -273,6 +284,9 @@
             return
 
         self._closing = True
+        if self._is_caller:
+            await self._caller_progress_handlers.unregister_all()
+
         await self._rpc.close()
         await self._channel.close()
         await self._connection.close()
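Note (not part of the patch): `put_async` now resolves the user id in a fixed order before registering a per-user progress handler. The sketch below restates that precedence; `resolve_user_id` is a hypothetical helper for reviewers, not something the patch adds.

import uuid
from typing import Optional

from comfy.auth.permissions import jwt_decode


def resolve_user_id(extra_data: dict, server_client_id: Optional[str]) -> str:
    if "token" in extra_data:
        return jwt_decode(extra_data["token"])["sub"]  # an explicit token wins
    if "client_id" in extra_data:
        return extra_data["client_id"]                 # then the queue item's client_id
    if server_client_id is not None:
        return server_client_id                        # then the caller server's client_id
    return str(uuid.uuid4())                           # otherwise mint a fresh id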
diff --git a/comfy/distributed/distributed_prompt_worker.py b/comfy/distributed/distributed_prompt_worker.py
index e2b335f59..e37043353 100644
--- a/comfy/distributed/distributed_prompt_worker.py
+++ b/comfy/distributed/distributed_prompt_worker.py
@@ -10,6 +10,7 @@
 from aio_pika import connect_robust
 from aio_pika.patterns import JsonRPC
 from aiormq import AMQPConnectionError
+from .distributed_progress import DistributedExecutorToClientProgress
 from .distributed_types import RpcRequest, RpcReply
 from ..client.embedded_comfy_client import EmbeddedComfyClient
 from ..component_model.queue_types import ExecutionStatus
@@ -24,11 +25,13 @@ class DistributedPromptWorker:
                  connection_uri: str = "amqp://localhost:5672/",
                  queue_name: str = "comfyui",
                  loop: Optional[AbstractEventLoop] = None):
+        self._rpc = None
+        self._channel = None
         self._exit_stack = AsyncExitStack()
         self._queue_name = queue_name
         self._connection_uri = connection_uri
         self._loop = loop or asyncio.get_event_loop()
-        self._embedded_comfy_client = embedded_comfy_client or EmbeddedComfyClient()
+        self._embedded_comfy_client = embedded_comfy_client
 
     async def _do_work_item(self, request: dict) -> dict:
         await self.on_will_complete_work_item(request)
@@ -55,9 +58,6 @@
     async def init(self):
         await self._exit_stack.__aenter__()
 
-        if not self._embedded_comfy_client.is_running:
-            await self._exit_stack.enter_async_context(self._embedded_comfy_client)
-
         try:
             self._connection = await connect_robust(self._connection_uri, loop=self._loop)
         except AMQPConnectionError as connection_error:
@@ -67,6 +67,12 @@
         self._rpc = await JsonRPC.create(channel=self._channel)
         self._rpc.host_exceptions = True
 
+        if self._embedded_comfy_client is None:
+            self._embedded_comfy_client = EmbeddedComfyClient(
+                progress_handler=DistributedExecutorToClientProgress(self._rpc, self._queue_name, self._loop))
+        if not self._embedded_comfy_client.is_running:
+            await self._exit_stack.enter_async_context(self._embedded_comfy_client)
+
         await self._rpc.register(self._queue_name, self._do_work_item)
 
     async def __aenter__(self) -> "DistributedPromptWorker":
diff --git a/comfy/distributed/distributed_types.py b/comfy/distributed/distributed_types.py
index 78fcd29d0..35c420d66 100644
--- a/comfy/distributed/distributed_types.py
+++ b/comfy/distributed/distributed_types.py
@@ -3,9 +3,8 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
 from typing import Tuple, Literal, List
 
-import jwt
-
 from ..api.components.schema.prompt import PromptDict, Prompt
+from ..auth.permissions import ComfyJwt, jwt_decode
 from ..component_model.queue_types import NamedQueueTuple, TaskInvocation, ExecutionStatus
@@ -19,10 +18,8 @@ class DistributedBase:
         return self.decoded_token["sub"]
 
     @property
-    def decoded_token(self) -> dict:
-        return jwt.decode(self.user_token, algorithms=['HS256', "none"],
-                          # todo: this should be configurable
-                          options={"verify_signature": False, 'verify_aud': False, 'verify_iss': False})
+    def decoded_token(self) -> ComfyJwt:
+        return jwt_decode(self.user_token)
 
 
 @dataclass
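Note (not part of the patch): a sketch of running a worker with the new wiring. When `embedded_comfy_client` is left unset (assuming it defaults to `None`), the worker now constructs an `EmbeddedComfyClient` whose progress handler is a `DistributedExecutorToClientProgress`, so node progress is relayed back to whichever frontend registered the matching `<queue_name>.<user_id>.progress` handler. The run-forever loop below is illustrative only.

import asyncio

from comfy.distributed.distributed_prompt_worker import DistributedPromptWorker


async def main():
    # the worker registers itself on the "comfyui" queue and serves prompts from frontends
    async with DistributedPromptWorker(connection_uri="amqp://localhost:5672/", queue_name="comfyui"):
        await asyncio.Event().wait()  # keep serving work items until cancelled


if __name__ == "__main__":
    asyncio.run(main())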