Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2026-01-10 06:10:50 +08:00
Better compatibility with pre-existing prompt_worker method
parent 92898b8c9d
commit 72e92514a4
@@ -124,11 +124,11 @@ parser.add_argument("--distributed-queue-connection-uri", type=str, default=None
 parser.add_argument(
     '--distributed-queue-roles',
     action='append',
-    choices=['worker', 'prompter'],
-    help='Specifies one or more roles for the distributed queue. Acceptable values are "worker" or "prompter", or both by writing the flag twice with each role. Prompters will start the web UI and connect to the provided AMPQ URL to submit prompts; workers will pull requests off the AMPQ URL.'
+    choices=['worker', 'frontend'],
+    help='Specifies one or more roles for the distributed queue. Acceptable values are "worker" or "frontend", or both by writing the flag twice with each role. Frontends will start the web UI and connect to the provided AMQP URL to submit prompts; workers will pull requests off the AMQP URL.'
 )
 parser.add_argument("--distributed-queue-name", type=str, default="comfyui",
-                    help="This name will be used by the prompters and workers to exchange prompt requests and replies. Progress updates will be prefixed by the queue name, followed by a '.', then the user ID")
+                    help="This name will be used by the frontends and workers to exchange prompt requests and replies. Progress updates will be prefixed by the queue name, followed by a '.', then the user ID")

 if options.args_parsing:
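Because the flag uses argparse's action='append', one process can take both roles by repeating it. A minimal sketch of those semantics; the main.py invocation in the comment is illustrative, not taken from this diff:

    # Hypothetical invocation:
    #   python main.py --distributed-queue-connection-uri amqp://guest:guest@localhost/ \
    #       --distributed-queue-roles frontend --distributed-queue-roles worker
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--distributed-queue-roles', action='append', choices=['worker', 'frontend'])
    args = parser.parse_args(['--distributed-queue-roles', 'frontend', '--distributed-queue-roles', 'worker'])
    assert args.distributed_queue_roles == ['frontend', 'worker']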
@@ -71,8 +71,8 @@ class Configuration(dict):
         write_out_config_file (bool): Enable writing out the configuration file.
         create_directories (bool): Creates the default models/, input/, output/ and temp/ directories, then exits.
         distributed_queue_connection_uri (Optional[str]): Servers and clients will connect to this AMQP URL to form a distributed queue and exchange prompt execution requests and progress updates.
-        distributed_queue_roles (List[str]): Specifies one or more roles for the distributed queue. Acceptable values are "worker" or "prompter", or both by writing the flag twice with each role. Prompters will start the web UI and connect to the provided AMQP URL to submit prompts; workers will pull requests off the AMQP URL.
-        distributed_queue_name (str): This name will be used by the prompters and workers to exchange prompt requests and replies. Progress updates will be prefixed by the queue name, followed by a '.', then the user ID.
+        distributed_queue_roles (List[str]): Specifies one or more roles for the distributed queue. Acceptable values are "worker" or "frontend", or both by writing the flag twice with each role. Frontends will start the web UI and connect to the provided AMQP URL to submit prompts; workers will pull requests off the AMQP URL.
+        distributed_queue_name (str): This name will be used by the frontends and workers to exchange prompt requests and replies. Progress updates will be prefixed by the queue name, followed by a '.', then the user ID.
     """

     def __init__(self, **kwargs):
         super().__init__()
@@ -131,7 +131,7 @@ class Configuration(dict):
         self.write_out_config_file: bool = False
         self.create_directories: bool = False
         self.distributed_queue_connection_uri: Optional[str] = None
-        self.distributed_queue_roles: List[str] = []
+        self.distributed_queue_roles: List[str] = ["worker", "frontend"]
         self.distributed_queue_name: str = "comfyui"
         for key, value in kwargs.items():
             self[key] = value
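A hedged sketch of what the new default means for Configuration consumers (attribute-style access assumed from the dict subclass shown above):

    from comfy.cli_args_types import Configuration

    config = Configuration()
    # both roles by default, so a single-process deployment keeps working unchanged
    assert config.distributed_queue_roles == ["worker", "frontend"]
    assert config.distributed_queue_name == "comfyui"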
@@ -5,32 +5,13 @@ import gc
 import uuid
 from asyncio import AbstractEventLoop
 from concurrent.futures import ThreadPoolExecutor
-from typing import Literal, Optional
+from typing import Optional

 from ..api.components.schema.prompt import PromptDict
 from ..cli_args_types import Configuration
 from ..component_model.make_mutable import make_mutable
-from ..component_model.queue_types import BinaryEventTypes
-from ..component_model.executor_types import ExecutorToClientProgress, StatusMessage, ExecutingMessage
-
-
-class ServerStub(ExecutorToClientProgress):
-    """
-    This class is a stub implementation of ExecutorToClientProgress. This will handle progress events.
-    """
-
-    def __init__(self):
-        self.client_id = str(uuid.uuid4())
-        self.last_node_id = None
-        self.last_prompt_id = None
-
-    def send_sync(self,
-                  event: Literal["status", "executing"] | BinaryEventTypes | str | None,
-                  data: StatusMessage | ExecutingMessage | bytes | bytearray | None, sid: str | None = None):
-        pass
-
-    def queue_updated(self):
-        pass
+from ..component_model.executor_types import ExecutorToClientProgress
+from ..distributed.server_stub import ServerStub


 class EmbeddedComfyClient:
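With the stub moved out, this module is left holding only the client. A usage sketch, with the queue_prompt signature inferred from the test near the end of this diff:

    import asyncio
    from comfy.client.embedded_comfy_client import EmbeddedComfyClient

    async def run_prompt(prompt: dict, prompt_id: str) -> dict:
        # the async context manager owns executor startup and shutdown
        async with EmbeddedComfyClient() as client:
            return await client.queue_prompt(prompt, prompt_id)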
@@ -137,7 +137,7 @@ def format_value(x):
         return str(x)


-def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui,
+def recursive_execute(server: ExecutorToClientProgress, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui,
                       object_storage):
     unique_id = current_item
     inputs = prompt[unique_id]['inputs']
@@ -766,14 +766,12 @@ class PromptQueue(AbstractPromptQueue):
         self.server = server
         self.mutex = threading.RLock()
         self.not_empty = threading.Condition(self.mutex)
-        self.next_task_id = 0
         self.queue: typing.List[QueueItem] = []
-        self.currently_running: typing.Dict[int, QueueItem] = {}
+        self.currently_running: typing.Dict[str, QueueItem] = {}
         # history maps the prompt id (the second element in the queue tuple) to a dictionary with keys "prompt" and "outputs"
         # todo: use the new History class for the sake of simplicity
         self.history: typing.Dict[str, HistoryEntry] = {}
         self.flags = {}
         server.prompt_queue = self

     def size(self) -> int:
         return len(self.queue)
@@ -784,20 +782,23 @@ class PromptQueue(AbstractPromptQueue):
             self.server.queue_updated()
             self.not_empty.notify()

-    def get(self, timeout=None) -> typing.Optional[typing.Tuple[QueueTuple, int]]:
+    def get(self, timeout=None) -> typing.Optional[typing.Tuple[QueueTuple, str]]:
         with self.not_empty:
             while len(self.queue) == 0:
                 self.not_empty.wait(timeout=timeout)
                 if timeout is not None and len(self.queue) == 0:
                     return None
             item_with_future: QueueItem = heapq.heappop(self.queue)
-            task_id = self.next_task_id
+            assert item_with_future.prompt_id is not None
+            assert item_with_future.prompt_id != ""
+            assert item_with_future.prompt_id not in self.currently_running
+            assert isinstance(item_with_future.prompt_id, str)
+            task_id = item_with_future.prompt_id
             self.currently_running[task_id] = item_with_future
-            self.next_task_id += 1
             self.server.queue_updated()
             return copy.deepcopy(item_with_future.queue_tuple), task_id

-    def task_done(self, item_id, outputs: dict,
+    def task_done(self, item_id: str, outputs: dict,
                   status: Optional[ExecutionStatus]):
         with self.mutex:
             queue_item = self.currently_running.pop(item_id)
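A self-contained sketch of the id change above: the task id handed to workers is now the item's own string prompt id rather than a per-server counter, so a reply can be correlated with its request across processes:

    import heapq
    import uuid

    queue: list = []
    currently_running: dict = {}

    prompt_id = str(uuid.uuid4())
    heapq.heappush(queue, (0, prompt_id))          # stand-in for a QueueTuple

    number, task_id = heapq.heappop(queue)
    assert task_id == prompt_id                    # identity, not a sequence number
    currently_running[task_id] = (number, task_id)
    currently_running.pop(task_id)                 # task_done() pops the same string key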
@@ -1,3 +1,4 @@
+import signal
 import sys

 from .. import options
@@ -78,12 +79,16 @@ if args.deterministic:

 from .. import utils
 import yaml
+from contextlib import AsyncExitStack

 from ..cmd import execution
 from ..cmd import server as server_module
 from ..component_model.abstract_prompt_queue import AbstractPromptQueue
 from ..component_model.queue_types import BinaryEventTypes, ExecutionStatus
 from .. import model_management
+from ..distributed.distributed_prompt_queue import DistributedPromptQueue
+from ..component_model.executor_types import ExecutorToClientProgress
+from ..distributed.server_stub import ServerStub


 def prompt_worker(q: AbstractPromptQueue, _server: server_module.PromptServer):
@@ -145,8 +150,8 @@ async def run(server, address='', port=8188, verbose=True, call_on_start=None):
     await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())


-def hijack_progress(server):
-    def hook(value, total, preview_image):
+def hijack_progress(server: ExecutorToClientProgress):
+    def hook(value: float, total: float, preview_image):
         model_management.throw_exception_if_processing_interrupted()
         progress = {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": server.last_node_id}
@@ -201,7 +206,7 @@ def cuda_malloc_warning():
             "\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")


-def main():
+async def main():
     if args.temp_directory:
         temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp")
         print(f"Setting temp directory to: {temp_dir}")
@@ -217,10 +222,24 @@ def main():
     if args.windows_standalone_build:
         folder_paths.create_directories()

-    loop = asyncio.new_event_loop()
-    asyncio.set_event_loop(loop)
+    loop = asyncio.get_event_loop()
     server = server_module.PromptServer(loop)
-    q = execution.PromptQueue(server)
+    if args.distributed_queue_connection_uri is not None:
+        distributed = True
+        q = DistributedPromptQueue(
+            caller_server=server if "worker" in args.distributed_queue_roles else None,
+            connection_uri=args.distributed_queue_connection_uri,
+            is_caller="frontend" in args.distributed_queue_roles,
+            is_callee="worker" in args.distributed_queue_roles,
+            loop=loop,
+            queue_name=args.distributed_queue_name
+        )
+        await q.init()
+        loop.add_signal_handler(signal.SIGINT, lambda *args, **kwargs: q.close())
+    else:
+        distributed = False
+        q = execution.PromptQueue(server)
+    server.prompt_queue = q

     try:
         extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
@@ -237,7 +256,11 @@ def main():
     hijack_progress(server)
     cuda_malloc_warning()

-    threading.Thread(target=prompt_worker, daemon=True, args=(q, server,)).start()
+    # in a distributed setting, the prompt worker will not be able to send execution events via the websocket
+    # the distributed prompt queue will be responsible for simulating those events until the broker is configured to
+    # pass those messages to the appropriate user
+    worker_thread_server = server if not distributed else ServerStub()
+    threading.Thread(target=prompt_worker, daemon=True, args=(q, worker_thread_server,)).start()

     # server has been imported and things should be looking good
     initialize_event_tracking(loop)
@@ -273,13 +296,14 @@ def main():
     server.address = args.listen
     server.port = args.port
     try:
-        loop.run_until_complete(run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server,
-                                    call_on_start=call_on_start))
+        await run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server,
+                  call_on_start=call_on_start)
     except KeyboardInterrupt:
+        await q.close()
         print("\nStopped server")

     cleanup_temp()


 if __name__ == "__main__":
-    main()
+    asyncio.run(main())
@@ -14,6 +14,7 @@ class AbstractPromptQueue(metaclass=ABCMeta):

     get is intended to be used by a worker.
     """

     @abstractmethod
     def size(self) -> int:
         """
@@ -32,7 +33,7 @@ class AbstractPromptQueue(metaclass=ABCMeta):
         pass

     @abstractmethod
-    def get(self, timeout: float | None = None) -> typing.Optional[typing.Tuple[QueueTuple, int]]:
+    def get(self, timeout: float | None = None) -> typing.Optional[typing.Tuple[QueueTuple, str]]:
         """
         Pops an item off the queue. Blocking. If a timeout is provided, this will return None after
         :param timeout: the number of seconds to time out for a blocking get
@@ -41,7 +42,7 @@ class AbstractPromptQueue(metaclass=ABCMeta):
         pass

     @abstractmethod
-    def task_done(self, item_id: int, outputs: dict,
+    def task_done(self, item_id: str, outputs: dict,
                   status: typing.Optional[ExecutionStatus]):
         """
         Signals to the user interface that the task with the specified id is completed
@@ -110,5 +111,10 @@ class AbstractPromptQueue(metaclass=ABCMeta):
         pass

     @abstractmethod
-    def get_flags(self, reset) -> Flags:
+    def get_flags(self, reset: bool = True) -> Flags:
+        """
+        Resets the flags for the next model unload or free memory request.
+        :param reset:
+        :return:
+        """
         pass
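A hedged sketch of the get() contract as now typed: block until an item arrives or the timeout elapses, return None on timeout, otherwise the tuple plus its string prompt id. Here queue.Queue stands in for a concrete backing store:

    import queue
    import typing

    def get(backing: "queue.Queue", timeout: typing.Optional[float] = None):
        try:
            item = backing.get(timeout=timeout)  # blocks; a None timeout blocks forever
        except queue.Empty:
            return None
        return item, item[1]                      # item[1] is the str prompt id in a QueueTuple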
@@ -22,6 +22,14 @@ class StatusMessage(TypedDict):
 class ExecutingMessage(TypedDict):
     node: str | None
+    prompt_id: NotRequired[str]
+    output: NotRequired[dict]
+
+
+class ProgressMessage(TypedDict):
+    value: float
+    max: float
+    prompt_id: Optional[str]
+    node: Optional[str]


 class ExecutorToClientProgress(Protocol):
@@ -39,8 +47,8 @@ class ExecutorToClientProgress(Protocol):
     last_prompt_id: Optional[str]

     def send_sync(self,
-                  event: Literal["status", "executing"] | BinaryEventTypes | str | None,
-                  data: StatusMessage | ExecutingMessage | bytes | bytearray | None, sid: str | None = None):
+                  event: Literal["status", "executing", "progress"] | BinaryEventTypes | str | None,
+                  data: StatusMessage | ExecutingMessage | ProgressMessage | bytes | bytearray | None, sid: str | None = None):
         """
         Sends feedback to the client with the specified ID about a specific node
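A sketch of the new message shape; the module path comes from the relative imports elsewhere in this diff, and the commented send_sync call mirrors how hijack_progress builds its progress dict:

    from comfy.component_model.executor_types import ProgressMessage

    msg: ProgressMessage = {"value": 3.0, "max": 20.0, "prompt_id": "some-uuid", "node": "4"}
    # a server implementing ExecutorToClientProgress would then emit:
    # server.send_sync("progress", msg, sid=client_id)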
@@ -1,9 +1,13 @@
 from __future__ import annotations

+import asyncio
+import threading
+import time
 import uuid
-from asyncio import AbstractEventLoop, Queue
+from asyncio import AbstractEventLoop, Queue, QueueEmpty
 from dataclasses import asdict
 from functools import partial
+from time import sleep
 from typing import Optional, Dict, List, Mapping, Tuple, Callable

 import jwt
@@ -12,6 +16,7 @@ from aio_pika.abc import AbstractConnection, AbstractChannel
 from aio_pika.patterns import JsonRPC

 from .distributed_types import RpcRequest, RpcReply
+from .server_stub import ServerStub
 from ..component_model.abstract_prompt_queue import AbstractPromptQueue
 from ..component_model.executor_types import ExecutorToClientProgress
 from ..component_model.queue_types import Flags, HistoryEntry, QueueTuple, QueueItem, ExecutionStatus, TaskInvocation
@@ -29,38 +34,38 @@ class DistributedPromptQueue(AbstractPromptQueue):
         In a distributed queue, this only returns the client's apparent number of items it is waiting for
         :return:
         """
-        return len(self.caller_local_in_progress)
+        return len(self._caller_local_in_progress)

     async def put_async(self, queue_item: QueueItem):
-        assert self.is_caller
-        self.caller_local_in_progress[queue_item.prompt_id] = queue_item
-        if self.caller_server is not None:
-            self.caller_server.queue_updated()
+        assert self._is_caller
+        if self._closing:
+            return
+        self._caller_local_in_progress[queue_item.prompt_id] = queue_item
+        if self._caller_server is not None:
+            self._caller_server.queue_updated()
         try:
             if "token" in queue_item.extra_data:
                 user_token = queue_item.extra_data["token"]
             else:
                 if "client_id" in queue_item.extra_data:
                     client_id = queue_item.extra_data["client_id"]
-                elif self.caller_server.client_id is not None:
-                    client_id = self.caller_server.client_id
+                elif self._caller_server.client_id is not None:
+                    client_id = self._caller_server.client_id
                 else:
                     client_id = str(uuid.uuid4())
                     # todo: should we really do this?
-                    self.caller_server.client_id = client_id
+                    self._caller_server.client_id = client_id

                 # create a stub token
                 user_token = jwt.encode({"sub": client_id}, key="", algorithm="none")
             request = RpcRequest(prompt_id=queue_item.prompt_id, user_token=user_token, prompt=queue_item.prompt)
-            assert self.rpc is not None
+            assert self._rpc is not None
             res: TaskInvocation = RpcReply(
-                **(await self.rpc.call(self.queue_name, {"request": asdict(request)}))).as_task_invocation()
+                **(await self._rpc.call(self._queue_name, {"request": asdict(request)}))).as_task_invocation()

-            self.caller_history.put(queue_item, res.outputs, res.status)
-            if self.caller_server is not None:
-                self.caller_server.queue_updated()
+            self._caller_history.put(queue_item, res.outputs, res.status)
+            if self._caller_server is not None:
+                self._caller_server.queue_updated()

             # if this has a completion future, complete it
             if queue_item.completed is not None:
@@ -69,8 +74,8 @@ class DistributedPromptQueue(AbstractPromptQueue):
         except Exception as e:
             # if a caller-side error occurred, use the passed error for the messages
             # we didn't receive any outputs here
-            self.caller_history.put(queue_item, outputs={},
-                                    status=ExecutionStatus(status_str="error", completed=False, messages=[str(e)]))
+            self._caller_history.put(queue_item, outputs={},
+                                     status=ExecutionStatus(status_str="error", completed=False, messages=[str(e)]))

             # if we have a completer, propagate the exception to it
             if queue_item.completed is not None:
@@ -79,28 +84,31 @@ class DistributedPromptQueue(AbstractPromptQueue):
                 # otherwise, this should raise in the event loop, which I suppose isn't handled
                 raise e
         finally:
-            self.caller_local_in_progress.pop(queue_item.prompt_id)
-            if self.caller_server is not None:
-                self.caller_server.queue_updated()
+            self._caller_local_in_progress.pop(queue_item.prompt_id)
+            if self._caller_server is not None:
+                # todo: this ensures that the web ui is notified about the completed task, but it should really be done by worker
+                self._caller_server.send_sync("executing", {"node": None, "prompt_id": queue_item.prompt_id},
+                                              self._caller_server.client_id)
+                self._caller_server.queue_updated()

     def put(self, item: QueueItem):
         # caller: execute on main thread
-        assert self.is_caller
+        assert self._is_caller
+        if self._closing:
+            return
         # this is called by the web server and its event loop is perfectly fine to use
         # the future is now ignored
-        self.loop.call_soon_threadsafe(self.put_async, item)
+        asyncio.run_coroutine_threadsafe(self.put_async(item), self._loop)

     async def _callee_do_work_item(self, request: dict) -> dict:
-        assert self.is_callee
+        assert self._is_callee
         request_obj = RpcRequest.from_dict(request)
         item = request_obj.as_queue_tuple().queue_tuple
-        item_with_completer = QueueItem(item, self.loop.create_future())
-        self.callee_local_in_progress[item_with_completer.prompt_id] = item_with_completer
+        item_with_completer = QueueItem(item, self._loop.create_future())
+        self._callee_local_in_progress[item_with_completer.prompt_id] = item_with_completer
         # todo: check if we have the local model content needed to execute this request and if not, reject it
         # todo: check if we have enough memory to execute this request, and if not, reject it
-        await self.callee_local_queue.put(item)
+        self._callee_local_queue.put_nowait(item)

         # technically this could be messed with or overwritten
         assert item_with_completer.completed is not None
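The put() change above is a fix, not just a rename: loop.call_soon_threadsafe(self.put_async, item) merely called the coroutine function, producing a coroutine object that was never awaited, whereas asyncio.run_coroutine_threadsafe actually schedules it on the target loop from another thread. A runnable illustration:

    import asyncio
    import threading

    async def work(x: int) -> int:
        return x * 2

    loop = asyncio.new_event_loop()
    threading.Thread(target=loop.run_forever, daemon=True).start()

    # schedules the coroutine on the loop's thread and hands back a concurrent future
    future = asyncio.run_coroutine_threadsafe(work(21), loop)
    print(future.result(timeout=5))  # 42
    loop.call_soon_threadsafe(loop.stop)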
@@ -110,11 +118,35 @@ class DistributedPromptQueue(AbstractPromptQueue):
         invocation = await item_with_completer.completed
         return asdict(RpcReply.from_task_invocation(invocation, request_obj.user_token))

-    def get(self, timeout: float | None = None) -> Optional[Tuple[QueueTuple, str | int]]:
+    def get(self, timeout: float | None = None) -> Optional[Tuple[QueueTuple, str]]:
         # callee: executed on the worker thread
-        assert self.is_callee
+        assert self._is_callee
+        # the loop receiving messages must not be mounted on the worker thread
+        # otherwise receiving messages will be blocked forever
+        worker_event_loop = asyncio.get_event_loop()
+        assert self._loop != worker_event_loop, "get only makes sense in the context of the legacy comfyui prompt worker"
+        # spin wait
+        timeout = timeout or 30.0
+        item = None
+        while timeout > 0:
+            try:
+                item = self._callee_local_queue.get_nowait()
+                break
+            except QueueEmpty:
+                start_time = time.time()
+                sleep(0.1)
+                timeout -= time.time() - start_time
+
+        if item is None:
+            return None
+
+        return item, item[1]
+
+    async def get_async(self, timeout: float | None = None) -> Optional[Tuple[QueueTuple, str]]:
+        # callee: executed anywhere
+        assert self._is_callee
         try:
-            item = asyncio.run_coroutine_threadsafe(self.callee_local_queue.get(), self.loop).result(timeout)
+            item: QueueTuple = await asyncio.wait_for(self._callee_local_queue.get(), timeout)
         except TimeoutError:
             return None
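get_async() now awaits the local queue directly and maps a timeout to None. The behavior in isolation (asyncio.TimeoutError is the builtin TimeoutError on Python 3.11+, which is why the hunk above can catch the builtin):

    import asyncio

    async def demo() -> None:
        q: asyncio.Queue = asyncio.Queue()
        try:
            item = await asyncio.wait_for(q.get(), timeout=0.1)
        except asyncio.TimeoutError:
            item = None
        assert item is None  # empty queue, so the timeout fires

    asyncio.run(demo())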
@@ -122,20 +154,21 @@ class DistributedPromptQueue(AbstractPromptQueue):

     def task_done(self, item_id: int, outputs: dict, status: Optional[ExecutionStatus]):
         # callee: executed on the worker thread
-        assert self.is_callee
-        pending = self.callee_local_in_progress.pop(item_id)
+        assert self._is_callee
+        pending = self._callee_local_in_progress.pop(item_id)
         assert pending is not None
         assert pending.completed is not None
         assert not pending.completed.done()
         # finish the task. status will transmit the errors in comfy's domain-specific way
         pending.completed.set_result(TaskInvocation(item_id=item_id, outputs=outputs, status=status))
+        # todo: the caller is responsible for sending a websocket message right now that the UI expects for updates

     def get_current_queue(self) -> Tuple[List[QueueTuple], List[QueueTuple]]:
         """
         In a distributed queue, all queue items are assumed to be currently in progress
         :return:
         """
-        return [], [item.queue_tuple for item in self.caller_local_in_progress.values()]
+        return [], [item.queue_tuple for item in self._caller_local_in_progress.values()]
@@ -143,7 +176,7 @@ class DistributedPromptQueue(AbstractPromptQueue):
         :return:
         """
         # caller: executed on main thread
-        return len(self.caller_local_in_progress)
+        return len(self._caller_local_in_progress)

     def wipe_queue(self) -> None:
         """
@@ -162,13 +195,13 @@ class DistributedPromptQueue(AbstractPromptQueue):

     def get_history(self, prompt_id: Optional[int] = None, max_items=None, offset=-1) \
             -> Mapping[str, HistoryEntry]:
-        return self.caller_history.copy(prompt_id=prompt_id, max_items=max_items, offset=offset)
+        return self._caller_history.copy(prompt_id=prompt_id, max_items=max_items, offset=offset)

     def wipe_history(self):
-        self.caller_history.clear()
+        self._caller_history.clear()

     def delete_history_item(self, id_to_delete):
-        self.caller_history.pop(id_to_delete)
+        self._caller_history.pop(id_to_delete)

     def set_flag(self, name: str, data: bool) -> None:
         """
@@ -179,7 +212,7 @@ class DistributedPromptQueue(AbstractPromptQueue):
         """
         pass

-    def get_flags(self, reset) -> Flags:
+    def get_flags(self, reset=True) -> Flags:
         """
         Does nothing on distributed queues. Workers must manage their own memory.
         :param reset:
@@ -188,7 +221,7 @@ class DistributedPromptQueue(AbstractPromptQueue):
         return Flags()

     def __init__(self,
-                 server: Optional[ExecutorToClientProgress | PromptServer] = None,
+                 caller_server: Optional[ExecutorToClientProgress | PromptServer] = None,
                  queue_name: str = "comfyui",
                  connection_uri="amqp://localhost/",
                  is_caller=True,
@@ -196,40 +229,52 @@ class DistributedPromptQueue(AbstractPromptQueue):
                  loop: Optional[AbstractEventLoop] = None):
         super().__init__()
         # this constructor is called on the main thread
-        self.loop = loop or asyncio.get_event_loop() or asyncio.new_event_loop()
-        self.queue_name = queue_name
-        self.connection_uri = connection_uri
-        self.connection: Optional[AbstractConnection] = None  # Connection will be set up asynchronously
-        self.channel: Optional[AbstractChannel] = None  # Channel will be set up asynchronously
-        self.is_caller = is_caller
-        self.is_callee = is_callee
+        self._loop = loop or asyncio.get_event_loop() or asyncio.new_event_loop()
+        self._queue_name = queue_name
+        self._connection_uri = connection_uri
+        self._connection: Optional[AbstractConnection] = None  # Connection will be set up asynchronously
+        self._channel: Optional[AbstractChannel] = None  # Channel will be set up asynchronously
+        self._is_caller = is_caller
+        self._is_callee = is_callee
+        self._closing = False
+        self._initialized = False

         # as rpc caller
-        self.caller_server = server
-        self.caller_local_in_progress: dict[str | int, QueueItem] = {}
-        self.caller_history: History = History()
+        self._caller_server = caller_server or ServerStub()
+        self._caller_local_in_progress: dict[str | int, QueueItem] = {}
+        self._caller_history: History = History()

         # as rpc callee
-        self.callee_local_queue: Queue = Queue()
-        self.callee_local_in_progress: Dict[int | str, QueueItem] = {}
-        self.rpc: Optional[JsonRPC] = None
+        self._callee_local_queue: Queue = Queue()
+        self._callee_local_in_progress: Dict[int | str, QueueItem] = {}
+        self._rpc: Optional[JsonRPC] = None

-        # todo: the prompt queue really shouldn't do this
-        if server is not None:
-            server.prompt_queue = self
+    async def __aenter__(self):
+        await self.init()
+        return self
+
+    async def __aexit__(self, *args):
+        await self.close()

     async def init(self):
-        self.connection = await connect_robust(self.connection_uri, loop=self.loop)
-        self.channel = await self.connection.channel()
-        self.rpc = await JsonRPC.create(channel=self.channel)
-        self.rpc.host_exceptions = True
+        if self._initialized:
+            return
+        self._connection = await connect_robust(self._connection_uri, loop=self._loop)
+        self._channel = await self._connection.channel()
+        self._rpc = await JsonRPC.create(channel=self._channel)
+        self._rpc.host_exceptions = True
         # this makes the queue available to complete work items
-        if self.is_callee:
-            await self.rpc.register(self.queue_name, self._callee_do_work_item)
+        if self._is_callee:
+            await self._rpc.register(self._queue_name, self._callee_do_work_item)
+        self._initialized = True

     async def close(self):
-        await self.rpc.close()
-        await self.channel.close()
-        await self.connection.close()
+        if self._closing or not self._initialized:
+            return
+
+        self._closing = True
+        await self._rpc.close()
+        await self._channel.close()
+        await self._connection.close()
+        self._initialized = False
+        self._closing = False
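The new __aenter__/__aexit__ pair makes the queue usable as an async context manager, with init() and close() now idempotent. A usage sketch; it assumes a reachable AMQP broker at the given URI:

    import asyncio
    from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue

    async def main() -> None:
        async with DistributedPromptQueue(is_caller=True, is_callee=False,
                                          connection_uri="amqp://guest:guest@localhost/") as q:
            ...  # submit prompts here; init()/close() are handled by the context manager

    asyncio.run(main())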
comfy/distributed/server_stub.py (new file, 26 lines)
@@ -0,0 +1,26 @@
+from __future__ import annotations
+
+import uuid
+from typing import Literal
+
+from ..component_model.executor_types import ExecutorToClientProgress, StatusMessage, ExecutingMessage
+from ..component_model.queue_types import BinaryEventTypes
+
+
+class ServerStub(ExecutorToClientProgress):
+    """
+    This class is a stub implementation of ExecutorToClientProgress. It accepts progress events and discards them.
+    """
+
+    def __init__(self):
+        self.client_id = str(uuid.uuid4())
+        self.last_node_id = None
+        self.last_prompt_id = None
+
+    def send_sync(self,
+                  event: Literal["status", "executing"] | BinaryEventTypes | str | None,
+                  data: StatusMessage | ExecutingMessage | bytes | bytearray | None, sid: str | None = None):
+        pass
+
+    def queue_updated(self):
+        pass
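The stub is a null object: anywhere an ExecutorToClientProgress is required but no real PromptServer exists (the distributed worker thread above, headless clients), it absorbs the calls. A quick sketch:

    from comfy.distributed.server_stub import ServerStub

    stub = ServerStub()
    stub.send_sync("executing", {"node": None, "prompt_id": "example"})  # silently dropped
    stub.queue_updated()                                                 # no-op
    print(stub.client_id)  # still provides a usable random client id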
@@ -446,10 +446,10 @@ def set_progress_bar_global_hook(function):
     PROGRESS_BAR_HOOK = function


 class ProgressBar:
-    def __init__(self, total):
+    def __init__(self, total: float):
         global PROGRESS_BAR_HOOK
-        self.total = total
-        self.current = 0
+        self.total: float = total
+        self.current: float = 0.0
         self.hook = PROGRESS_BAR_HOOK

     def update_absolute(self, value, total=None, preview=None):
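Typing total and current as floats legitimizes fractional progress, matching the float hook signature in hijack_progress. A sketch, assuming this hunk is comfy/utils.py as in upstream ComfyUI:

    from comfy.utils import ProgressBar  # module path assumed, not named in this diff

    bar = ProgressBar(20.0)
    bar.update_absolute(3.5)  # value may now be fractional; total defaults to the ctor value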
main.py (4 lines changed)
@@ -1,4 +1,6 @@
+import asyncio
+
 from comfy.cmd.main import main

 if __name__ == "__main__":
-    main()
+    asyncio.run(main())
@@ -1,13 +1,16 @@
+import asyncio
 import os
 import uuid
+from concurrent.futures import ThreadPoolExecutor

 import jwt
 import pytest

-from comfy.client.embedded_comfy_client import EmbeddedComfyClient, ServerStub
+from comfy.client.embedded_comfy_client import EmbeddedComfyClient
+from comfy.distributed.server_stub import ServerStub
 from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner
 from comfy.component_model.make_mutable import make_mutable
-from comfy.component_model.queue_types import QueueItem, QueueTuple, TaskInvocation
+from comfy.component_model.queue_types import QueueItem, QueueTuple, TaskInvocation, NamedQueueTuple, ExecutionStatus
 from comfy.distributed.distributed_prompt_worker import DistributedPromptWorker
 from testcontainers.rabbitmq import RabbitMqContainer
@@ -15,6 +18,16 @@ from testcontainers.rabbitmq import RabbitMqContainer
 os.environ["TC_HOST"] = "localhost"


+def create_test_prompt() -> QueueItem:
+    from comfy.cmd.execution import validate_prompt
+
+    prompt = make_mutable(sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1))
+    validation_tuple = validate_prompt(prompt)
+    item_id = str(uuid.uuid4())
+    queue_tuple: QueueTuple = (0, item_id, prompt, {}, validation_tuple[2])
+    return QueueItem(queue_tuple, None)
+
+
 @pytest.mark.asyncio
 async def test_sign_jwt_auth_none():
     client_id = str(uuid.uuid4())
@@ -29,21 +42,56 @@ async def test_basic_queue_worker() -> None:

     with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
         params = rabbitmq.get_connection_params()
-        async with DistributedPromptWorker(connection_uri=f"amqp://guest:guest@127.0.0.1:{params.port}") as worker:
+        async with DistributedPromptWorker(connection_uri=f"amqp://guest:guest@127.0.0.1:{params.port}"):
             # this unfortunately does a bunch of initialization on the test thread
-            from comfy.cmd.execution import validate_prompt
             from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
             # now submit some jobs
             distributed_queue = DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True,
                                                        connection_uri=f"amqp://guest:guest@127.0.0.1:{params.port}")
             await distributed_queue.init()
-            prompt = make_mutable(sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1))
-            validation_tuple = validate_prompt(prompt)
-            item_id = str(uuid.uuid4())
-            queue_tuple: QueueTuple = (0, item_id, prompt, {}, validation_tuple[2])
-            res: TaskInvocation = await distributed_queue.put_async(QueueItem(queue_tuple, None))
-            assert res.item_id == item_id
+            queue_item = create_test_prompt()
+            res: TaskInvocation = await distributed_queue.put_async(queue_item)
+            assert res.item_id == queue_item.prompt_id
             assert len(res.outputs) == 1
             assert res.status is not None
             assert res.status.status_str == "success"
             await distributed_queue.close()
+
+
+@pytest.mark.asyncio
+async def test_distributed_prompt_queues_same_process():
+    with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
+        params = rabbitmq.get_connection_params()
+        connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"
+
+        from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
+        async with DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True,
+                                          connection_uri=connection_uri) as frontend:
+            async with DistributedPromptQueue(ServerStub(), is_callee=True, is_caller=False,
+                                              connection_uri=f"amqp://guest:guest@127.0.0.1:{params.port}") as worker:
+                test_prompt = create_test_prompt()
+                test_prompt.completed = asyncio.Future()
+
+                frontend.put(test_prompt)
+
+                # start a worker thread
+                thread_pool = ThreadPoolExecutor(max_workers=1)
+
+                async def in_thread():
+                    incoming, incoming_prompt_id = worker.get()
+                    assert incoming is not None
+                    incoming_named = NamedQueueTuple(incoming)
+                    assert incoming_named.prompt_id == incoming_prompt_id
+                    async with EmbeddedComfyClient() as embedded_comfy_client:
+                        outputs = await embedded_comfy_client.queue_prompt(incoming_named.prompt,
+                                                                           incoming_named.prompt_id)
+                    worker.task_done(incoming_named.prompt_id, outputs, ExecutionStatus("success", True, []))
+
+                thread_pool.submit(lambda: asyncio.run(in_thread()))
+                # this was completed over the comfyui queue interface, so it should be a task invocation
+                frontend_pov_result: TaskInvocation = await test_prompt.completed
+                assert frontend_pov_result is not None
+                assert frontend_pov_result.item_id == test_prompt.prompt_id
+                assert frontend_pov_result.outputs is not None
+                assert len(frontend_pov_result.outputs) == 1
+                assert frontend_pov_result.status is not None