diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index 0fa656a48..ebbbd4466 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -124,11 +124,11 @@ parser.add_argument("--distributed-queue-connection-uri", type=str, default=None
 parser.add_argument(
     '--distributed-queue-roles',
     action='append',
-    choices=['worker', 'prompter'],
-    help='Specifies one or more roles for the distributed queue. Acceptable values are "worker" or "prompter", or both by writing the flag twice with each role. Prompters will start the web UI and connect to the provided AMPQ URL to submit prompts; workers will pull requests off the AMPQ URL.'
+    choices=['worker', 'frontend'],
+    help='Specifies one or more roles for the distributed queue. Acceptable values are "worker" or "frontend", or both by writing the flag twice with each role. Frontends will start the web UI and connect to the provided AMQP URL to submit prompts; workers will pull requests off the AMQP URL.'
 )
 parser.add_argument("--distributed-queue-name", type=str, default="comfyui",
-                    help="This name will be used by the prompters and workers to exchange prompt requests and replies. Progress updates will be prefixed by the queue name, followed by a '.', then the user ID")
+                    help="This name will be used by the frontends and workers to exchange prompt requests and replies. Progress updates will be prefixed by the queue name, followed by a '.', then the user ID")

 if options.args_parsing:
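For orientation, the renamed role plays out at the command line and in role checks roughly as follows. This is an illustrative sketch, not part of the patch, and assumes a local broker:

```python
# Hypothetical launch commands (shell, shown as comments) for the renamed role:
#   python main.py --distributed-queue-connection-uri amqp://localhost/ --distributed-queue-roles frontend
#   python main.py --distributed-queue-connection-uri amqp://localhost/ --distributed-queue-roles worker
# Both roles in one process, by repeating the flag:
#   python main.py --distributed-queue-connection-uri amqp://localhost/ \
#       --distributed-queue-roles frontend --distributed-queue-roles worker

# Downstream, main() derives the two capabilities from the appended list:
roles = ["frontend", "worker"]   # what argparse's action='append' accumulates
is_caller = "frontend" in roles  # serves the web UI and submits prompts
is_callee = "worker" in roles    # pulls prompt requests off the AMQP queue
```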
""" def __init__(self, **kwargs): super().__init__() @@ -131,7 +131,7 @@ class Configuration(dict): self.write_out_config_file: bool = False self.create_directories: bool = False self.distributed_queue_connection_uri: Optional[str] = None - self.distributed_queue_roles: List[str] = [] + self.distributed_queue_roles: List[str] = ["worker", "frontend"] self.distributed_queue_name: str = "comfyui" for key, value in kwargs.items(): self[key] = value diff --git a/comfy/client/embedded_comfy_client.py b/comfy/client/embedded_comfy_client.py index c3d6384cd..a5bdf6b51 100644 --- a/comfy/client/embedded_comfy_client.py +++ b/comfy/client/embedded_comfy_client.py @@ -5,32 +5,13 @@ import gc import uuid from asyncio import AbstractEventLoop from concurrent.futures import ThreadPoolExecutor -from typing import Literal, Optional +from typing import Optional from ..api.components.schema.prompt import PromptDict from ..cli_args_types import Configuration from ..component_model.make_mutable import make_mutable -from ..component_model.queue_types import BinaryEventTypes -from ..component_model.executor_types import ExecutorToClientProgress, StatusMessage, ExecutingMessage - - -class ServerStub(ExecutorToClientProgress): - """ - This class is a stub implementation of ExecutorToClientProgress. This will handle progress events. - """ - - def __init__(self): - self.client_id = str(uuid.uuid4()) - self.last_node_id = None - self.last_prompt_id = None - - def send_sync(self, - event: Literal["status", "executing"] | BinaryEventTypes | str | None, - data: StatusMessage | ExecutingMessage | bytes | bytearray | None, sid: str | None = None): - pass - - def queue_updated(self): - pass +from ..component_model.executor_types import ExecutorToClientProgress +from ..distributed.server_stub import ServerStub class EmbeddedComfyClient: diff --git a/comfy/cmd/execution.py b/comfy/cmd/execution.py index e11a323ea..e92278142 100644 --- a/comfy/cmd/execution.py +++ b/comfy/cmd/execution.py @@ -137,7 +137,7 @@ def format_value(x): return str(x) -def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui, +def recursive_execute(server: ExecutorToClientProgress, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui, object_storage): unique_id = current_item inputs = prompt[unique_id]['inputs'] @@ -766,14 +766,12 @@ class PromptQueue(AbstractPromptQueue): self.server = server self.mutex = threading.RLock() self.not_empty = threading.Condition(self.mutex) - self.next_task_id = 0 self.queue: typing.List[QueueItem] = [] - self.currently_running: typing.Dict[int, QueueItem] = {} + self.currently_running: typing.Dict[str, QueueItem] = {} # history maps the second integer prompt id in the queue tuple to a dictionary with keys "prompt" and "outputs # todo: use the new History class for the sake of simplicity self.history: typing.Dict[str, HistoryEntry] = {} self.flags = {} - server.prompt_queue = self def size(self) -> int: return len(self.queue) @@ -784,20 +782,23 @@ class PromptQueue(AbstractPromptQueue): self.server.queue_updated() self.not_empty.notify() - def get(self, timeout=None) -> typing.Optional[typing.Tuple[QueueTuple, int]]: + def get(self, timeout=None) -> typing.Optional[typing.Tuple[QueueTuple, str]]: with self.not_empty: while len(self.queue) == 0: self.not_empty.wait(timeout=timeout) if timeout is not None and len(self.queue) == 0: return None item_with_future: QueueItem = heapq.heappop(self.queue) - task_id = self.next_task_id + assert 
diff --git a/comfy/cmd/execution.py b/comfy/cmd/execution.py
index e11a323ea..e92278142 100644
--- a/comfy/cmd/execution.py
+++ b/comfy/cmd/execution.py
@@ -137,7 +137,7 @@ def format_value(x):
         return str(x)


-def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui,
+def recursive_execute(server: ExecutorToClientProgress, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui,
                       object_storage):
     unique_id = current_item
     inputs = prompt[unique_id]['inputs']
@@ -766,14 +766,12 @@ class PromptQueue(AbstractPromptQueue):
         self.server = server
         self.mutex = threading.RLock()
         self.not_empty = threading.Condition(self.mutex)
-        self.next_task_id = 0
         self.queue: typing.List[QueueItem] = []
-        self.currently_running: typing.Dict[int, QueueItem] = {}
+        self.currently_running: typing.Dict[str, QueueItem] = {}
        # history maps the second prompt id in the queue tuple to a dictionary with keys "prompt" and "outputs"
         # todo: use the new History class for the sake of simplicity
         self.history: typing.Dict[str, HistoryEntry] = {}
         self.flags = {}
-        server.prompt_queue = self

     def size(self) -> int:
         return len(self.queue)
@@ -784,20 +782,23 @@
             self.server.queue_updated()
             self.not_empty.notify()

-    def get(self, timeout=None) -> typing.Optional[typing.Tuple[QueueTuple, int]]:
+    def get(self, timeout=None) -> typing.Optional[typing.Tuple[QueueTuple, str]]:
         with self.not_empty:
             while len(self.queue) == 0:
                 self.not_empty.wait(timeout=timeout)
                 if timeout is not None and len(self.queue) == 0:
                     return None
             item_with_future: QueueItem = heapq.heappop(self.queue)
-            task_id = self.next_task_id
+            assert item_with_future.prompt_id is not None
+            assert item_with_future.prompt_id != ""
+            assert item_with_future.prompt_id not in self.currently_running
+            assert isinstance(item_with_future.prompt_id, str)
+            task_id = item_with_future.prompt_id
             self.currently_running[task_id] = item_with_future
-            self.next_task_id += 1
             self.server.queue_updated()
             return copy.deepcopy(item_with_future.queue_tuple), task_id

-    def task_done(self, item_id, outputs: dict,
+    def task_done(self, item_id: str, outputs: dict,
                   status: Optional[ExecutionStatus]):
         with self.mutex:
             queue_item = self.currently_running.pop(item_id)
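The local queue no longer mints integer task ids; the task id handed back by get() is the queue tuple's own prompt id string. A minimal round-trip sketch, assuming the five-element QueueTuple of (number, prompt_id, prompt, extra_data, outputs_to_execute) used by the tests below, with an empty prompt for illustration only:

```python
import uuid

from comfy.cmd.execution import PromptQueue
from comfy.component_model.queue_types import ExecutionStatus, QueueItem
from comfy.distributed.server_stub import ServerStub

q = PromptQueue(ServerStub())  # note: the queue no longer assigns server.prompt_queue itself
prompt_id = str(uuid.uuid4())
q.put(QueueItem((0, prompt_id, {}, {}, []), None))  # illustrative empty prompt

queue_tuple, task_id = q.get(timeout=1.0)
assert task_id == prompt_id  # the task id is the prompt id now
q.task_done(task_id, outputs={}, status=ExecutionStatus("success", True, []))
```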
diff --git a/comfy/cmd/main.py b/comfy/cmd/main.py
index e8ae96f17..c546636cf 100644
--- a/comfy/cmd/main.py
+++ b/comfy/cmd/main.py
@@ -1,3 +1,4 @@
+import signal
 import sys

 from .. import options
@@ -78,12 +79,16 @@ if args.deterministic:
 from .. import utils
 import yaml

+from contextlib import AsyncExitStack
 from ..cmd import execution
 from ..cmd import server as server_module
 from ..component_model.abstract_prompt_queue import AbstractPromptQueue
 from ..component_model.queue_types import BinaryEventTypes, ExecutionStatus
 from .. import model_management
+from ..distributed.distributed_prompt_queue import DistributedPromptQueue
+from ..component_model.executor_types import ExecutorToClientProgress
+from ..distributed.server_stub import ServerStub


 def prompt_worker(q: AbstractPromptQueue, _server: server_module.PromptServer):
@@ -145,8 +150,8 @@ async def run(server, address='', port=8188, verbose=True, call_on_start=None):
     await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())


-def hijack_progress(server):
-    def hook(value, total, preview_image):
+def hijack_progress(server: ExecutorToClientProgress):
+    def hook(value: float, total: float, preview_image):
         model_management.throw_exception_if_processing_interrupted()
         progress = {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": server.last_node_id}
@@ -201,7 +206,7 @@ def cuda_malloc_warning():
             "\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")


-def main():
+async def main():
     if args.temp_directory:
         temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp")
         print(f"Setting temp directory to: {temp_dir}")
@@ -217,10 +222,24 @@ def main():
     if args.windows_standalone_build:
         folder_paths.create_directories()

-    loop = asyncio.new_event_loop()
-    asyncio.set_event_loop(loop)
+    loop = asyncio.get_event_loop()
     server = server_module.PromptServer(loop)
-    q = execution.PromptQueue(server)
+    if args.distributed_queue_connection_uri is not None:
+        distributed = True
+        q = DistributedPromptQueue(
+            caller_server=server if "frontend" in args.distributed_queue_roles else None,
+            connection_uri=args.distributed_queue_connection_uri,
+            is_caller="frontend" in args.distributed_queue_roles,
+            is_callee="worker" in args.distributed_queue_roles,
+            loop=loop,
+            queue_name=args.distributed_queue_name
+        )
+        await q.init()
+        loop.add_signal_handler(signal.SIGINT, lambda *args, **kwargs: asyncio.ensure_future(q.close()))
+    else:
+        distributed = False
+        q = execution.PromptQueue(server)
+    server.prompt_queue = q

     try:
         extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
@@ -237,7 +256,11 @@ def main():
     hijack_progress(server)
     cuda_malloc_warning()

-    threading.Thread(target=prompt_worker, daemon=True, args=(q, server,)).start()
+    # in a distributed setting, the prompt worker will not be able to send execution events via the websocket;
+    # the distributed prompt queue will be responsible for simulating those events until the broker is configured to
+    # pass those messages to the appropriate user
+    worker_thread_server = server if not distributed else ServerStub()
+    threading.Thread(target=prompt_worker, daemon=True, args=(q, worker_thread_server,)).start()

     # server has been imported and things should be looking good
     initialize_event_tracking(loop)
@@ -273,13 +296,14 @@ def main():
     server.address = args.listen
     server.port = args.port
     try:
-        loop.run_until_complete(run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server,
-                                    call_on_start=call_on_start))
+        await run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server,
+                  call_on_start=call_on_start)
     except KeyboardInterrupt:
+        await q.close()
         print("\nStopped server")

     cleanup_temp()


 if __name__ == "__main__":
-    main()
+    asyncio.run(main())
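One subtlety in the wiring above: DistributedPromptQueue.close() is a coroutine, so a signal handler can only schedule it, never call it directly. A minimal sketch of that pattern, with a hypothetical helper name:

```python
import asyncio
import signal


def install_sigint_close(loop: asyncio.AbstractEventLoop, q) -> None:
    # the handler must stay synchronous; scheduling the coroutine lets the
    # running loop await it, whereas a bare q.close() would never be awaited
    loop.add_signal_handler(signal.SIGINT, lambda: asyncio.ensure_future(q.close()))
```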
diff --git a/comfy/component_model/abstract_prompt_queue.py b/comfy/component_model/abstract_prompt_queue.py
index d4054e9d4..aac095ca7 100644
--- a/comfy/component_model/abstract_prompt_queue.py
+++ b/comfy/component_model/abstract_prompt_queue.py
@@ -14,6 +14,7 @@ class AbstractPromptQueue(metaclass=ABCMeta):
     get is intended to be used by a worker.
     """

+    @abstractmethod
     def size(self) -> int:
         """
@@ -32,7 +33,7 @@ class AbstractPromptQueue(metaclass=ABCMeta):
         pass

     @abstractmethod
-    def get(self, timeout: float | None = None) -> typing.Optional[typing.Tuple[QueueTuple, int]]:
+    def get(self, timeout: float | None = None) -> typing.Optional[typing.Tuple[QueueTuple, str]]:
         """
         Pops an item off the queue. Blocking. If a timeout is provided, this will return None after the timeout elapses.
         :param timeout: the number of seconds to time out for a blocking get
         """
         pass

@@ -41,7 +42,7 @@ class AbstractPromptQueue(metaclass=ABCMeta):
-    def task_done(self, item_id: int, outputs: dict,
+    def task_done(self, item_id: str, outputs: dict,
                   status: typing.Optional[ExecutionStatus]):
         """
         Signals to the user interface that the task with the specified id is completed
@@ -110,5 +111,10 @@ class AbstractPromptQueue(metaclass=ABCMeta):
         pass

     @abstractmethod
-    def get_flags(self, reset) -> Flags:
+    def get_flags(self, reset: bool = True) -> Flags:
+        """
+        Returns the flags and, when reset is True, clears them for the next model unload or free memory request.
+        :param reset: when True, the flags are cleared after they are read
+        :return: the current flags
+        """
         pass
diff --git a/comfy/component_model/executor_types.py b/comfy/component_model/executor_types.py
index a7a966b3b..c1de2d9df 100644
--- a/comfy/component_model/executor_types.py
+++ b/comfy/component_model/executor_types.py
@@ -22,6 +22,14 @@ class StatusMessage(TypedDict):
 class ExecutingMessage(TypedDict):
     node: str | None
     prompt_id: NotRequired[str]
+    output: NotRequired[dict]
+
+
+class ProgressMessage(TypedDict):
+    value: float
+    max: float
+    prompt_id: Optional[str]
+    node: Optional[str]


 class ExecutorToClientProgress(Protocol):
@@ -39,8 +47,8 @@ class ExecutorToClientProgress(Protocol):
     last_prompt_id: Optional[str]

     def send_sync(self,
-                  event: Literal["status", "executing"] | BinaryEventTypes | str | None,
-                  data: StatusMessage | ExecutingMessage | bytes | bytearray | None, sid: str | None = None):
+                  event: Literal["status", "executing", "progress"] | BinaryEventTypes | str | None,
+                  data: StatusMessage | ExecutingMessage | ProgressMessage | bytes | bytearray | None, sid: str | None = None):
         """
         Sends feedback to the client with the specified ID about a specific node
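send_sync now also carries "progress" events typed as ProgressMessage. Because ExecutorToClientProgress is a Protocol, any duck-typed sink works; a minimal sketch (the LoggingProgress name is hypothetical):

```python
import uuid

from comfy.component_model.executor_types import ProgressMessage


class LoggingProgress:
    """Hypothetical sink that prints progress events and ignores everything else."""

    def __init__(self):
        self.client_id = str(uuid.uuid4())
        self.last_node_id = None
        self.last_prompt_id = None

    def send_sync(self, event, data, sid=None):
        if event == "progress":
            msg: ProgressMessage = data
            print(f"prompt {msg['prompt_id']} node {msg['node']}: {msg['value']}/{msg['max']}")

    def queue_updated(self):
        pass
```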
diff --git a/comfy/distributed/distributed_prompt_queue.py b/comfy/distributed/distributed_prompt_queue.py
index 3812849b4..1ae74492d 100644
--- a/comfy/distributed/distributed_prompt_queue.py
+++ b/comfy/distributed/distributed_prompt_queue.py
@@ -1,9 +1,13 @@
 from __future__ import annotations

 import asyncio
+import threading
+import time
 import uuid
-from asyncio import AbstractEventLoop, Queue
+from asyncio import AbstractEventLoop, Queue, QueueEmpty
 from dataclasses import asdict
+from functools import partial
+from time import sleep
 from typing import Optional, Dict, List, Mapping, Tuple, Callable

 import jwt
@@ -12,6 +16,7 @@ from aio_pika.abc import AbstractConnection, AbstractChannel
 from aio_pika.patterns import JsonRPC

 from .distributed_types import RpcRequest, RpcReply
+from .server_stub import ServerStub
 from ..component_model.abstract_prompt_queue import AbstractPromptQueue
 from ..component_model.executor_types import ExecutorToClientProgress
 from ..component_model.queue_types import Flags, HistoryEntry, QueueTuple, QueueItem, ExecutionStatus, TaskInvocation
@@ -29,38 +34,38 @@ class DistributedPromptQueue(AbstractPromptQueue):
         In a distributed queue, this only returns the client's apparent number of items it is waiting for
         :return:
         """
-        return len(self.caller_local_in_progress)
+        return len(self._caller_local_in_progress)

     async def put_async(self, queue_item: QueueItem):
-        assert self.is_caller
+        assert self._is_caller
         if self._closing:
             return
-        self.caller_local_in_progress[queue_item.prompt_id] = queue_item
-        if self.caller_server is not None:
-            self.caller_server.queue_updated()
+        self._caller_local_in_progress[queue_item.prompt_id] = queue_item
+        if self._caller_server is not None:
+            self._caller_server.queue_updated()

         try:
             if "token" in queue_item.extra_data:
                 user_token = queue_item.extra_data["token"]
             else:
                 if "client_id" in queue_item.extra_data:
                     client_id = queue_item.extra_data["client_id"]
-                elif self.caller_server.client_id is not None:
-                    client_id = self.caller_server.client_id
+                elif self._caller_server.client_id is not None:
+                    client_id = self._caller_server.client_id
                 else:
                     client_id = str(uuid.uuid4())
                     # todo: should we really do this?
-                    self.caller_server.client_id = client_id
+                    self._caller_server.client_id = client_id

                 # create a stub token
                 user_token = jwt.encode({"sub": client_id}, key="", algorithm="none")

             request = RpcRequest(prompt_id=queue_item.prompt_id, user_token=user_token, prompt=queue_item.prompt)
-            assert self.rpc is not None
+            assert self._rpc is not None
             res: TaskInvocation = RpcReply(
-                **(await self.rpc.call(self.queue_name, {"request": asdict(request)}))).as_task_invocation()
+                **(await self._rpc.call(self._queue_name, {"request": asdict(request)}))).as_task_invocation()

-            self.caller_history.put(queue_item, res.outputs, res.status)
-            if self.caller_server is not None:
-                self.caller_server.queue_updated()
+            self._caller_history.put(queue_item, res.outputs, res.status)
+            if self._caller_server is not None:
+                self._caller_server.queue_updated()

             # if this has a completion future, complete it
@@ -69,8 +74,8 @@ class DistributedPromptQueue(AbstractPromptQueue):
         except Exception as e:
             # if a caller-side error occurred, use the passed error for the messages
             # we didn't receive any outputs here
-            self.caller_history.put(queue_item, outputs={},
-                                    status=ExecutionStatus(status_str="error", completed=False, messages=[str(e)]))
+            self._caller_history.put(queue_item, outputs={},
+                                     status=ExecutionStatus(status_str="error", completed=False, messages=[str(e)]))

             # if we have a completer, propagate the exception to it
@@ -79,28 +84,31 @@ class DistributedPromptQueue(AbstractPromptQueue):
             # otherwise, this should raise in the event loop, which I suppose isn't handled
             raise e
         finally:
-            self.caller_local_in_progress.pop(queue_item.prompt_id)
-            if self.caller_server is not None:
-                self.caller_server.queue_updated()
+            self._caller_local_in_progress.pop(queue_item.prompt_id)
+            if self._caller_server is not None:
+                # todo: this ensures that the web ui is notified about the completed task, but it should really be done by the worker
+                self._caller_server.send_sync("executing", {"node": None, "prompt_id": queue_item.prompt_id},
+                                              self._caller_server.client_id)
+                self._caller_server.queue_updated()

     def put(self, item: QueueItem):
         # caller: execute on main thread
-        assert self.is_caller
+        assert self._is_caller
         if self._closing:
             return
         # this is called by the web server and its event loop is perfectly fine to use
         # the future is now ignored
-        self.loop.call_soon_threadsafe(self.put_async, item)
+        asyncio.run_coroutine_threadsafe(self.put_async(item), self._loop)

     async def _callee_do_work_item(self, request: dict) -> dict:
-        assert self.is_callee
+        assert self._is_callee
         request_obj = RpcRequest.from_dict(request)
         item = request_obj.as_queue_tuple().queue_tuple
-        item_with_completer = QueueItem(item, self.loop.create_future())
-        self.callee_local_in_progress[item_with_completer.prompt_id] = item_with_completer
+        item_with_completer = QueueItem(item, self._loop.create_future())
+        self._callee_local_in_progress[item_with_completer.prompt_id] = item_with_completer
         # todo: check if we have the local model content needed to execute this request and if not, reject it
         # todo: check if we have enough memory to execute this request, and if not, reject it
-        await self.callee_local_queue.put(item)
+        self._callee_local_queue.put_nowait(item)

         # technically this could be messed with or overwritten
         assert item_with_completer.completed is not None
@@ -110,11 +118,35 @@
         invocation = await item_with_completer.completed
         return asdict(RpcReply.from_task_invocation(invocation, request_obj.user_token))

-    def get(self, timeout: float | None = None) -> Optional[Tuple[QueueTuple, str | int]]:
+    def get(self, timeout: float | None = None) -> Optional[Tuple[QueueTuple, str]]:
         # callee: executed on the worker thread
-        assert self.is_callee
+        assert self._is_callee
+        # the loop receiving messages must not be mounted on the worker thread,
+        # otherwise receiving messages will be blocked forever
+        worker_event_loop = asyncio.get_event_loop()
+        assert self._loop != worker_event_loop, "get only makes sense in the context of the legacy comfyui prompt worker"
+        # spin wait
+        timeout = timeout or 30.0
+        item = None
+        while timeout > 0:
+            try:
+                item = self._callee_local_queue.get_nowait()
+                break
+            except QueueEmpty:
+                start_time = time.time()
+                sleep(0.1)
+                timeout -= time.time() - start_time
+
+        if item is None:
+            return None
+
+        return item, item[1]
+
+    async def get_async(self, timeout: float | None = None) -> Optional[Tuple[QueueTuple, str]]:
+        # callee: executed anywhere
+        assert self._is_callee
         try:
-            item = asyncio.run_coroutine_threadsafe(self.callee_local_queue.get(), self.loop).result(timeout)
+            item: QueueTuple = await asyncio.wait_for(self._callee_local_queue.get(), timeout)
         except TimeoutError:
             return None

@@ -122,20 +154,21 @@ class DistributedPromptQueue(AbstractPromptQueue):
     def task_done(self, item_id: int, outputs: dict, status: Optional[ExecutionStatus]):
         # callee: executed on the worker thread
-        assert self.is_callee
-        pending = self.callee_local_in_progress.pop(item_id)
+        assert self._is_callee
+        pending = self._callee_local_in_progress.pop(item_id)
         assert pending is not None
         assert pending.completed is not None
         assert not pending.completed.done()
         # finish the task. status will transmit the errors in comfy's domain-specific way
         pending.completed.set_result(TaskInvocation(item_id=item_id, outputs=outputs, status=status))
+        # todo: the caller is responsible for sending a websocket message right now that the UI expects for updates

     def get_current_queue(self) -> Tuple[List[QueueTuple], List[QueueTuple]]:
         """
         In a distributed queue, all queue items are assumed to be currently in progress
         :return:
         """
-        return [], [item.queue_tuple for item in self.caller_local_in_progress.values()]
+        return [], [item.queue_tuple for item in self._caller_local_in_progress.values()]

     def get_tasks_remaining(self) -> int:
         """
@@ -143,7 +176,7 @@ class DistributedPromptQueue(AbstractPromptQueue):
         :return:
         """
         # caller: executed on main thread
-        return len(self.caller_local_in_progress)
+        return len(self._caller_local_in_progress)

     def wipe_queue(self) -> None:
         """
@@ -162,13 +195,13 @@ class DistributedPromptQueue(AbstractPromptQueue):

     def get_history(self, prompt_id: Optional[int] = None, max_items=None, offset=-1) \
             -> Mapping[str, HistoryEntry]:
-        return self.caller_history.copy(prompt_id=prompt_id, max_items=max_items, offset=offset)
+        return self._caller_history.copy(prompt_id=prompt_id, max_items=max_items, offset=offset)

     def wipe_history(self):
-        self.caller_history.clear()
+        self._caller_history.clear()

     def delete_history_item(self, id_to_delete):
-        self.caller_history.pop(id_to_delete)
+        self._caller_history.pop(id_to_delete)

     def set_flag(self, name: str, data: bool) -> None:
         """
@@ -179,7 +212,7 @@ class DistributedPromptQueue(AbstractPromptQueue):
         """
         pass

-    def get_flags(self, reset) -> Flags:
+    def get_flags(self, reset=True) -> Flags:
         """
         Does nothing on distributed queues. Workers must manage their own memory.
         :param reset:
         :return:
         """
         return Flags()

@@ -188,7 +221,7 @@ class DistributedPromptQueue(AbstractPromptQueue):
     def __init__(self,
-                 server: Optional[ExecutorToClientProgress | PromptServer] = None,
+                 caller_server: Optional[ExecutorToClientProgress | PromptServer] = None,
                  queue_name: str = "comfyui",
                  connection_uri="amqp://localhost/",
                  is_caller=True,
@@ -196,40 +229,52 @@ class DistributedPromptQueue(AbstractPromptQueue):
                  loop: Optional[AbstractEventLoop] = None):
         super().__init__()
         # this constructor is called on the main thread
-        self.loop = loop or asyncio.get_event_loop() or asyncio.new_event_loop()
-        self.queue_name = queue_name
-        self.connection_uri = connection_uri
-        self.connection: Optional[AbstractConnection] = None  # Connection will be set up asynchronously
-        self.channel: Optional[AbstractChannel] = None  # Channel will be set up asynchronously
-        self.is_caller = is_caller
-        self.is_callee = is_callee
+        self._loop = loop or asyncio.get_event_loop() or asyncio.new_event_loop()
+        self._queue_name = queue_name
+        self._connection_uri = connection_uri
+        self._connection: Optional[AbstractConnection] = None  # Connection will be set up asynchronously
+        self._channel: Optional[AbstractChannel] = None  # Channel will be set up asynchronously
+        self._is_caller = is_caller
+        self._is_callee = is_callee
         self._closing = False
+        self._initialized = False

         # as rpc caller
-        self.caller_server = server
-        self.caller_local_in_progress: dict[str | int, QueueItem] = {}
-        self.caller_history: History = History()
+        self._caller_server = caller_server or ServerStub()
+        self._caller_local_in_progress: dict[str | int, QueueItem] = {}
+        self._caller_history: History = History()

         # as rpc callee
-        self.callee_local_queue: Queue = Queue()
-        self.callee_local_in_progress: Dict[int | str, QueueItem] = {}
-        self.rpc: Optional[JsonRPC] = None
+        self._callee_local_queue: Queue = Queue()
+        self._callee_local_in_progress: Dict[int | str, QueueItem] = {}
+        self._rpc: Optional[JsonRPC] = None

-        # todo: the prompt queue really shouldn't do this
-        if server is not None:
-            server.prompt_queue = self
+    async def __aenter__(self):
+        await self.init()
+        return self
+
+    async def __aexit__(self, *args):
+        await self.close()

     async def init(self):
-        self.connection = await connect_robust(self.connection_uri, loop=self.loop)
-        self.channel = await self.connection.channel()
-        self.rpc = await JsonRPC.create(channel=self.channel)
-        self.rpc.host_exceptions = True
+        if self._initialized:
+            return
+        self._connection = await connect_robust(self._connection_uri, loop=self._loop)
+        self._channel = await self._connection.channel()
+        self._rpc = await JsonRPC.create(channel=self._channel)
+        self._rpc.host_exceptions = True
         # this makes the queue available to complete work items
-        if self.is_callee:
-            await self.rpc.register(self.queue_name, self._callee_do_work_item)
+        if self._is_callee:
+            await self._rpc.register(self._queue_name, self._callee_do_work_item)
+        self._initialized = True

     async def close(self):
+        if self._closing or not self._initialized:
+            return
+
         self._closing = True
-        await self.rpc.close()
-        await self.channel.close()
-        await self.connection.close()
+        await self._rpc.close()
+        await self._channel.close()
+        await self._connection.close()
+        self._initialized = False
+        self._closing = False
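With __aenter__/__aexit__ in place, init() and close() can be scoped by async with. A sketch of a frontend-side submission, assuming a reachable broker and an already-validated QueueItem:

```python
from comfy.component_model.queue_types import QueueItem, TaskInvocation
from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
from comfy.distributed.server_stub import ServerStub


async def submit(queue_item: QueueItem, connection_uri: str = "amqp://localhost/") -> TaskInvocation:
    async with DistributedPromptQueue(ServerStub(), is_caller=True, is_callee=False,
                                      connection_uri=connection_uri) as frontend:
        # resolves once a worker replies over RPC with the outputs and status
        return await frontend.put_async(queue_item)
```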
diff --git a/comfy/distributed/server_stub.py b/comfy/distributed/server_stub.py
new file mode 100644
index 000000000..490492a02
--- /dev/null
+++ b/comfy/distributed/server_stub.py
@@ -0,0 +1,26 @@
+from __future__ import annotations
+
+import uuid
+from typing import Literal
+
+from ..component_model.executor_types import ExecutorToClientProgress, StatusMessage, ExecutingMessage
+from ..component_model.queue_types import BinaryEventTypes
+
+
+class ServerStub(ExecutorToClientProgress):
+    """
+    This class is a stub implementation of ExecutorToClientProgress. It receives progress events and discards them.
+    """
+
+    def __init__(self):
+        self.client_id = str(uuid.uuid4())
+        self.last_node_id = None
+        self.last_prompt_id = None
+
+    def send_sync(self,
+                  event: Literal["status", "executing"] | BinaryEventTypes | str | None,
+                  data: StatusMessage | ExecutingMessage | bytes | bytearray | None, sid: str | None = None):
+        pass
+
+    def queue_updated(self):
+        pass
diff --git a/comfy/utils.py b/comfy/utils.py
index 3437aa258..c0bef4d8b 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -446,10 +446,10 @@ def set_progress_bar_global_hook(function):
     PROGRESS_BAR_HOOK = function

 class ProgressBar:
-    def __init__(self, total):
+    def __init__(self, total: float):
         global PROGRESS_BAR_HOOK
-        self.total = total
-        self.current = 0
+        self.total: float = total
+        self.current: float = 0.0
         self.hook = PROGRESS_BAR_HOOK

     def update_absolute(self, value, total=None, preview=None):
diff --git a/main.py b/main.py
index 8649fd84c..0e3c9e6eb 100644
--- a/main.py
+++ b/main.py
@@ -1,4 +1,6 @@
+import asyncio
+
 from comfy.cmd.main import main

 if __name__ == "__main__":
-    main()
+    asyncio.run(main())
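The ProgressBar change above makes totals and positions floats, so hooks may now see fractional values. A small sketch, assuming update_absolute forwards (value, total, preview) to the installed hook the way hijack_progress's hook expects:

```python
from comfy.utils import ProgressBar, set_progress_bar_global_hook

set_progress_bar_global_hook(lambda value, total, preview: print(f"{value}/{total}"))

pbar = ProgressBar(total=4.0)
for i in range(8):
    pbar.update_absolute((i + 1) * 0.5)  # fractional steps are now well-typed
```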
diff --git a/tests/distributed/test_distributed_queue.py b/tests/distributed/test_distributed_queue.py
index 2f421159b..3801c3ef3 100644
--- a/tests/distributed/test_distributed_queue.py
+++ b/tests/distributed/test_distributed_queue.py
@@ -1,13 +1,16 @@
+import asyncio
 import os
 import uuid
+from concurrent.futures import ThreadPoolExecutor

 import jwt
 import pytest

-from comfy.client.embedded_comfy_client import EmbeddedComfyClient, ServerStub
+from comfy.client.embedded_comfy_client import EmbeddedComfyClient
+from comfy.distributed.server_stub import ServerStub
 from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner
 from comfy.component_model.make_mutable import make_mutable
-from comfy.component_model.queue_types import QueueItem, QueueTuple, TaskInvocation
+from comfy.component_model.queue_types import QueueItem, QueueTuple, TaskInvocation, NamedQueueTuple, ExecutionStatus
 from comfy.distributed.distributed_prompt_worker import DistributedPromptWorker

 from testcontainers.rabbitmq import RabbitMqContainer
@@ -15,6 +18,16 @@ from testcontainers.rabbitmq import RabbitMqContainer
 os.environ["TC_HOST"] = "localhost"


+def create_test_prompt() -> QueueItem:
+    from comfy.cmd.execution import validate_prompt
+
+    prompt = make_mutable(sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1))
+    validation_tuple = validate_prompt(prompt)
+    item_id = str(uuid.uuid4())
+    queue_tuple: QueueTuple = (0, item_id, prompt, {}, validation_tuple[2])
+    return QueueItem(queue_tuple, None)
+
+
 @pytest.mark.asyncio
 async def test_sign_jwt_auth_none():
     client_id = str(uuid.uuid4())
@@ -29,21 +42,56 @@ async def test_basic_queue_worker() -> None:
     with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
         params = rabbitmq.get_connection_params()

-        async with DistributedPromptWorker(connection_uri=f"amqp://guest:guest@127.0.0.1:{params.port}") as worker:
+        async with DistributedPromptWorker(connection_uri=f"amqp://guest:guest@127.0.0.1:{params.port}"):
             # this unfortunately does a bunch of initialization on the test thread
-            from comfy.cmd.execution import validate_prompt
             from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue

             # now submit some jobs
             distributed_queue = DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True,
                                                        connection_uri=f"amqp://guest:guest@127.0.0.1:{params.port}")
             await distributed_queue.init()

-            prompt = make_mutable(sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1))
-            validation_tuple = validate_prompt(prompt)
-            item_id = str(uuid.uuid4())
-            queue_tuple: QueueTuple = (0, item_id, prompt, {}, validation_tuple[2])
-            res: TaskInvocation = await distributed_queue.put_async(QueueItem(queue_tuple, None))
-            assert res.item_id == item_id
+            queue_item = create_test_prompt()
+            res: TaskInvocation = await distributed_queue.put_async(queue_item)
+            assert res.item_id == queue_item.prompt_id
             assert len(res.outputs) == 1
             assert res.status is not None
             assert res.status.status_str == "success"
             await distributed_queue.close()
+
+
+@pytest.mark.asyncio
+async def test_distributed_prompt_queues_same_process():
+    with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
+        params = rabbitmq.get_connection_params()
+        connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"
+
+        from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
+        async with DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True,
+                                          connection_uri=connection_uri) as frontend:
+            async with DistributedPromptQueue(ServerStub(), is_callee=True, is_caller=False,
+                                              connection_uri=f"amqp://guest:guest@127.0.0.1:{params.port}") as worker:
+                test_prompt = create_test_prompt()
+                test_prompt.completed = asyncio.Future()
+
+                frontend.put(test_prompt)
+
+                # start a worker thread
+                thread_pool = ThreadPoolExecutor(max_workers=1)
+
+                async def in_thread():
+                    incoming, incoming_prompt_id = worker.get()
+                    assert incoming is not None
+                    incoming_named = NamedQueueTuple(incoming)
+                    assert incoming_named.prompt_id == incoming_prompt_id
+                    async with EmbeddedComfyClient() as embedded_comfy_client:
+                        outputs = await embedded_comfy_client.queue_prompt(incoming_named.prompt,
+                                                                           incoming_named.prompt_id)
+                        worker.task_done(incoming_named.prompt_id, outputs, ExecutionStatus("success", True, []))
+
+                thread_pool.submit(lambda: asyncio.run(in_thread()))
+                # this was completed over the comfyui queue interface, so it should be a task invocation
+                frontend_pov_result: TaskInvocation = await test_prompt.completed
+                assert frontend_pov_result is not None
+                assert frontend_pov_result.item_id == test_prompt.prompt_id
+                assert frontend_pov_result.outputs is not None
+                assert len(frontend_pov_result.outputs) == 1
+                assert frontend_pov_result.status is not None
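For completeness, a sketch of a worker-side consumption loop built on get_async and task_done, mirroring the in_thread body above but staying on the queue's own event loop; it assumes get_async yields the same (queue_tuple, prompt_id) pair as get():

```python
from comfy.client.embedded_comfy_client import EmbeddedComfyClient
from comfy.component_model.queue_types import ExecutionStatus, NamedQueueTuple
from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
from comfy.distributed.server_stub import ServerStub


async def worker_loop(connection_uri: str) -> None:
    async with DistributedPromptQueue(ServerStub(), is_caller=False, is_callee=True,
                                      connection_uri=connection_uri) as worker:
        async with EmbeddedComfyClient() as client:
            while True:
                popped = await worker.get_async(timeout=30.0)
                if popped is None:
                    continue  # timed out; poll again
                queue_tuple, prompt_id = popped
                named = NamedQueueTuple(queue_tuple)
                outputs = await client.queue_prompt(named.prompt, named.prompt_id)
                worker.task_done(prompt_id, outputs, ExecutionStatus("success", True, []))
```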