Better compatibility with pre-existing prompt_worker method

This commit is contained in:
doctorpangloss 2024-02-08 18:07:37 -08:00
parent 92898b8c9d
commit 72e92514a4
12 changed files with 268 additions and 127 deletions

View File

@ -124,11 +124,11 @@ parser.add_argument("--distributed-queue-connection-uri", type=str, default=None
parser.add_argument(
'--distributed-queue-roles',
action='append',
choices=['worker', 'prompter'],
help='Specifies one or more roles for the distributed queue. Acceptable values are "worker" or "prompter", or both by writing the flag twice with each role. Prompters will start the web UI and connect to the provided AMPQ URL to submit prompts; workers will pull requests off the AMPQ URL.'
choices=['worker', 'frontend'],
help='Specifies one or more roles for the distributed queue. Acceptable values are "worker" or "frontend", or both by writing the flag twice with each role. Prompters will start the web UI and connect to the provided AMQP URL to submit prompts; workers will pull requests off the AMQP URL.'
)
parser.add_argument("--distributed-queue-name", type=str, default="comfyui",
help="This name will be used by the prompters and workers to exchange prompt requests and replies. Progress updates will be prefixed by the queue name, followed by a '.', then the user ID")
help="This name will be used by the frontends and workers to exchange prompt requests and replies. Progress updates will be prefixed by the queue name, followed by a '.', then the user ID")
if options.args_parsing:

View File

@ -71,8 +71,8 @@ class Configuration(dict):
write_out_config_file (bool): Enable writing out the configuration file.
create_directories (bool): Creates the default models/, input/, output/ and temp/ directories, then exits.
distributed_queue_connection_uri (Optional[str]): Servers and clients will connect to this AMQP URL to form a distributed queue and exchange prompt execution requests and progress updates.
distributed_queue_roles (List[str]): Specifies one or more roles for the distributed queue. Acceptable values are "worker" or "prompter", or both by writing the flag twice with each role. Prompters will start the web UI and connect to the provided AMQP URL to submit prompts; workers will pull requests off the AMQP URL.
distributed_queue_name (str): This name will be used by the prompters and workers to exchange prompt requests and replies. Progress updates will be prefixed by the queue name, followed by a '.', then the user ID.
distributed_queue_roles (List[str]): Specifies one or more roles for the distributed queue. Acceptable values are "worker" or "frontend", or both by writing the flag twice with each role. Frontends will start the web UI and connect to the provided AMQP URL to submit prompts; workers will pull requests off the AMQP URL.
distributed_queue_name (str): This name will be used by the frontends and workers to exchange prompt requests and replies. Progress updates will be prefixed by the queue name, followed by a '.', then the user ID.
"""
def __init__(self, **kwargs):
super().__init__()
@ -131,7 +131,7 @@ class Configuration(dict):
self.write_out_config_file: bool = False
self.create_directories: bool = False
self.distributed_queue_connection_uri: Optional[str] = None
self.distributed_queue_roles: List[str] = []
self.distributed_queue_roles: List[str] = ["worker", "frontend"]
self.distributed_queue_name: str = "comfyui"
for key, value in kwargs.items():
self[key] = value

View File

@ -5,32 +5,13 @@ import gc
import uuid
from asyncio import AbstractEventLoop
from concurrent.futures import ThreadPoolExecutor
from typing import Literal, Optional
from typing import Optional
from ..api.components.schema.prompt import PromptDict
from ..cli_args_types import Configuration
from ..component_model.make_mutable import make_mutable
from ..component_model.queue_types import BinaryEventTypes
from ..component_model.executor_types import ExecutorToClientProgress, StatusMessage, ExecutingMessage
class ServerStub(ExecutorToClientProgress):
"""
This class is a stub implementation of ExecutorToClientProgress. This will handle progress events.
"""
def __init__(self):
self.client_id = str(uuid.uuid4())
self.last_node_id = None
self.last_prompt_id = None
def send_sync(self,
event: Literal["status", "executing"] | BinaryEventTypes | str | None,
data: StatusMessage | ExecutingMessage | bytes | bytearray | None, sid: str | None = None):
pass
def queue_updated(self):
pass
from ..component_model.executor_types import ExecutorToClientProgress
from ..distributed.server_stub import ServerStub
class EmbeddedComfyClient:

View File

@ -137,7 +137,7 @@ def format_value(x):
return str(x)
def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui,
def recursive_execute(server: ExecutorToClientProgress, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui,
object_storage):
unique_id = current_item
inputs = prompt[unique_id]['inputs']
@ -766,14 +766,12 @@ class PromptQueue(AbstractPromptQueue):
self.server = server
self.mutex = threading.RLock()
self.not_empty = threading.Condition(self.mutex)
self.next_task_id = 0
self.queue: typing.List[QueueItem] = []
self.currently_running: typing.Dict[int, QueueItem] = {}
self.currently_running: typing.Dict[str, QueueItem] = {}
# history maps the second integer prompt id in the queue tuple to a dictionary with keys "prompt" and "outputs
# todo: use the new History class for the sake of simplicity
self.history: typing.Dict[str, HistoryEntry] = {}
self.flags = {}
server.prompt_queue = self
def size(self) -> int:
return len(self.queue)
@ -784,20 +782,23 @@ class PromptQueue(AbstractPromptQueue):
self.server.queue_updated()
self.not_empty.notify()
def get(self, timeout=None) -> typing.Optional[typing.Tuple[QueueTuple, int]]:
def get(self, timeout=None) -> typing.Optional[typing.Tuple[QueueTuple, str]]:
with self.not_empty:
while len(self.queue) == 0:
self.not_empty.wait(timeout=timeout)
if timeout is not None and len(self.queue) == 0:
return None
item_with_future: QueueItem = heapq.heappop(self.queue)
task_id = self.next_task_id
assert item_with_future.prompt_id is not None
assert item_with_future.prompt_id != ""
assert item_with_future.prompt_id not in self.currently_running
assert isinstance(item_with_future.prompt_id, str)
task_id = item_with_future.prompt_id
self.currently_running[task_id] = item_with_future
self.next_task_id += 1
self.server.queue_updated()
return copy.deepcopy(item_with_future.queue_tuple), task_id
def task_done(self, item_id, outputs: dict,
def task_done(self, item_id: str, outputs: dict,
status: Optional[ExecutionStatus]):
with self.mutex:
queue_item = self.currently_running.pop(item_id)

View File

@ -1,3 +1,4 @@
import signal
import sys
from .. import options
@ -78,12 +79,16 @@ if args.deterministic:
from .. import utils
import yaml
from contextlib import AsyncExitStack
from ..cmd import execution
from ..cmd import server as server_module
from ..component_model.abstract_prompt_queue import AbstractPromptQueue
from ..component_model.queue_types import BinaryEventTypes, ExecutionStatus
from .. import model_management
from ..distributed.distributed_prompt_queue import DistributedPromptQueue
from ..component_model.executor_types import ExecutorToClientProgress
from ..distributed.server_stub import ServerStub
def prompt_worker(q: AbstractPromptQueue, _server: server_module.PromptServer):
@ -145,8 +150,8 @@ async def run(server, address='', port=8188, verbose=True, call_on_start=None):
await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())
def hijack_progress(server):
def hook(value, total, preview_image):
def hijack_progress(server: ExecutorToClientProgress):
def hook(value: float, total: float, preview_image):
model_management.throw_exception_if_processing_interrupted()
progress = {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": server.last_node_id}
@ -201,7 +206,7 @@ def cuda_malloc_warning():
"\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
def main():
async def main():
if args.temp_directory:
temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp")
print(f"Setting temp directory to: {temp_dir}")
@ -217,10 +222,24 @@ def main():
if args.windows_standalone_build:
folder_paths.create_directories()
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop = asyncio.get_event_loop()
server = server_module.PromptServer(loop)
q = execution.PromptQueue(server)
if args.distributed_queue_connection_uri is not None:
distributed = True
q = DistributedPromptQueue(
caller_server=server if "worker" in args.distributed_queue_roles else None,
connection_uri=args.distributed_queue_connection_uri,
is_caller="frontend" in args.distributed_queue_roles,
is_callee="worker" in args.distributed_queue_roles,
loop=loop,
queue_name=args.distributed_queue_name
)
await q.init()
loop.add_signal_handler(signal.SIGINT, lambda *args, **kwargs: q.close())
else:
distributed = False
q = execution.PromptQueue(server)
server.prompt_queue = q
try:
extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
@ -237,7 +256,11 @@ def main():
hijack_progress(server)
cuda_malloc_warning()
threading.Thread(target=prompt_worker, daemon=True, args=(q, server,)).start()
# in a distributed setting, the prompt worker will not be able to send execution events via the websocket
# the distributed prompt queue will be responsible for simulating those events until the broker is configured to
# pass those messages to the appropriate user
worker_thread_server = server if not distributed else ServerStub()
threading.Thread(target=prompt_worker, daemon=True, args=(q, worker_thread_server,)).start()
# server has been imported and things should be looking good
initialize_event_tracking(loop)
@ -273,13 +296,14 @@ def main():
server.address = args.listen
server.port = args.port
try:
loop.run_until_complete(run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server,
call_on_start=call_on_start))
await run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server,
call_on_start=call_on_start)
except KeyboardInterrupt:
await q.close()
print("\nStopped server")
cleanup_temp()
if __name__ == "__main__":
main()
asyncio.run(main())

View File

@ -14,6 +14,7 @@ class AbstractPromptQueue(metaclass=ABCMeta):
get is intended to be used by a worker.
"""
@abstractmethod
def size(self) -> int:
"""
@ -32,7 +33,7 @@ class AbstractPromptQueue(metaclass=ABCMeta):
pass
@abstractmethod
def get(self, timeout: float | None = None) -> typing.Optional[typing.Tuple[QueueTuple, int]]:
def get(self, timeout: float | None = None) -> typing.Optional[typing.Tuple[QueueTuple, str]]:
"""
Pops an item off the queue. Blocking. If a timeout is provided, this will return None after
:param timeout: the number of seconds to time out for a blocking get
@ -41,7 +42,7 @@ class AbstractPromptQueue(metaclass=ABCMeta):
pass
@abstractmethod
def task_done(self, item_id: int, outputs: dict,
def task_done(self, item_id: str, outputs: dict,
status: typing.Optional[ExecutionStatus]):
"""
Signals to the user interface that the task with the specified id is completed
@ -110,5 +111,10 @@ class AbstractPromptQueue(metaclass=ABCMeta):
pass
@abstractmethod
def get_flags(self, reset) -> Flags:
def get_flags(self, reset: bool = True) -> Flags:
"""
Resets the flags for the next model unload or free memory request.
:param reset:
:return:
"""
pass

View File

@ -22,6 +22,14 @@ class StatusMessage(TypedDict):
class ExecutingMessage(TypedDict):
node: str | None
prompt_id: NotRequired[str]
output: NotRequired[dict]
class ProgressMessage(TypedDict):
value: float
max: float
prompt_id: Optional[str]
node: Optional[str]
class ExecutorToClientProgress(Protocol):
@ -39,8 +47,8 @@ class ExecutorToClientProgress(Protocol):
last_prompt_id: Optional[str]
def send_sync(self,
event: Literal["status", "executing"] | BinaryEventTypes | str | None,
data: StatusMessage | ExecutingMessage | bytes | bytearray | None, sid: str | None = None):
event: Literal["status", "executing", "progress"] | BinaryEventTypes | str | None,
data: StatusMessage | ExecutingMessage | ProgressMessage | bytes | bytearray | None, sid: str | None = None):
"""
Sends feedback to the client with the specified ID about a specific node

View File

@ -1,9 +1,13 @@
from __future__ import annotations
import asyncio
import threading
import time
import uuid
from asyncio import AbstractEventLoop, Queue
from asyncio import AbstractEventLoop, Queue, QueueEmpty
from dataclasses import asdict
from functools import partial
from time import sleep
from typing import Optional, Dict, List, Mapping, Tuple, Callable
import jwt
@ -12,6 +16,7 @@ from aio_pika.abc import AbstractConnection, AbstractChannel
from aio_pika.patterns import JsonRPC
from .distributed_types import RpcRequest, RpcReply
from .server_stub import ServerStub
from ..component_model.abstract_prompt_queue import AbstractPromptQueue
from ..component_model.executor_types import ExecutorToClientProgress
from ..component_model.queue_types import Flags, HistoryEntry, QueueTuple, QueueItem, ExecutionStatus, TaskInvocation
@ -29,38 +34,38 @@ class DistributedPromptQueue(AbstractPromptQueue):
In a distributed queue, this only returns the client's apparent number of items it is waiting for
:return:
"""
return len(self.caller_local_in_progress)
return len(self._caller_local_in_progress)
async def put_async(self, queue_item: QueueItem):
assert self.is_caller
assert self._is_caller
if self._closing:
return
self.caller_local_in_progress[queue_item.prompt_id] = queue_item
if self.caller_server is not None:
self.caller_server.queue_updated()
self._caller_local_in_progress[queue_item.prompt_id] = queue_item
if self._caller_server is not None:
self._caller_server.queue_updated()
try:
if "token" in queue_item.extra_data:
user_token = queue_item.extra_data["token"]
else:
if "client_id" in queue_item.extra_data:
client_id = queue_item.extra_data["client_id"]
elif self.caller_server.client_id is not None:
client_id = self.caller_server.client_id
elif self._caller_server.client_id is not None:
client_id = self._caller_server.client_id
else:
client_id = str(uuid.uuid4())
# todo: should we really do this?
self.caller_server.client_id = client_id
self._caller_server.client_id = client_id
# create a stub token
user_token = jwt.encode({"sub": client_id}, key="", algorithm="none")
request = RpcRequest(prompt_id=queue_item.prompt_id, user_token=user_token, prompt=queue_item.prompt)
assert self.rpc is not None
assert self._rpc is not None
res: TaskInvocation = RpcReply(
**(await self.rpc.call(self.queue_name, {"request": asdict(request)}))).as_task_invocation()
**(await self._rpc.call(self._queue_name, {"request": asdict(request)}))).as_task_invocation()
self.caller_history.put(queue_item, res.outputs, res.status)
if self.caller_server is not None:
self.caller_server.queue_updated()
self._caller_history.put(queue_item, res.outputs, res.status)
if self._caller_server is not None:
self._caller_server.queue_updated()
# if this has a completion future, complete it
if queue_item.completed is not None:
@ -69,8 +74,8 @@ class DistributedPromptQueue(AbstractPromptQueue):
except Exception as e:
# if a caller-side error occurred, use the passed error for the messages
# we didn't receive any outputs here
self.caller_history.put(queue_item, outputs={},
status=ExecutionStatus(status_str="error", completed=False, messages=[str(e)]))
self._caller_history.put(queue_item, outputs={},
status=ExecutionStatus(status_str="error", completed=False, messages=[str(e)]))
# if we have a completer, propoagate the exception to it
if queue_item.completed is not None:
@ -79,28 +84,31 @@ class DistributedPromptQueue(AbstractPromptQueue):
# otherwise, this should raise in the event loop, which I suppose isn't handled
raise e
finally:
self.caller_local_in_progress.pop(queue_item.prompt_id)
if self.caller_server is not None:
self.caller_server.queue_updated()
self._caller_local_in_progress.pop(queue_item.prompt_id)
if self._caller_server is not None:
# todo: this ensures that the web ui is notified about the completed task, but it should really be done by worker
self._caller_server.send_sync("executing", {"node": None, "prompt_id": queue_item.prompt_id},
self._caller_server.client_id)
self._caller_server.queue_updated()
def put(self, item: QueueItem):
# caller: execute on main thread
assert self.is_caller
assert self._is_caller
if self._closing:
return
# this is called by the web server and its event loop is perfectly fine to use
# the future is now ignored
self.loop.call_soon_threadsafe(self.put_async, item)
asyncio.run_coroutine_threadsafe(self.put_async(item), self._loop)
async def _callee_do_work_item(self, request: dict) -> dict:
assert self.is_callee
assert self._is_callee
request_obj = RpcRequest.from_dict(request)
item = request_obj.as_queue_tuple().queue_tuple
item_with_completer = QueueItem(item, self.loop.create_future())
self.callee_local_in_progress[item_with_completer.prompt_id] = item_with_completer
item_with_completer = QueueItem(item, self._loop.create_future())
self._callee_local_in_progress[item_with_completer.prompt_id] = item_with_completer
# todo: check if we have the local model content needed to execute this request and if not, reject it
# todo: check if we have enough memory to execute this request, and if not, reject it
await self.callee_local_queue.put(item)
self._callee_local_queue.put_nowait(item)
# technically this could be messed with or overwritten
assert item_with_completer.completed is not None
@ -110,11 +118,35 @@ class DistributedPromptQueue(AbstractPromptQueue):
invocation = await item_with_completer.completed
return asdict(RpcReply.from_task_invocation(invocation, request_obj.user_token))
def get(self, timeout: float | None = None) -> Optional[Tuple[QueueTuple, str | int]]:
def get(self, timeout: float | None = None) -> Optional[Tuple[QueueTuple, str]]:
# callee: executed on the worker thread
assert self.is_callee
assert self._is_callee
# the loop receiving messages must not be mounted on the worker thread
# otherwise receiving messages will be blocked forever
worker_event_loop = asyncio.get_event_loop()
assert self._loop != worker_event_loop, "get only makes sense in the context of the legacy comfyui prompt worker"
# spin wait
timeout = timeout or 30.0
item = None
while timeout > 0:
try:
item = self._callee_local_queue.get_nowait()
break
except QueueEmpty:
start_time = time.time()
sleep(0.1)
timeout -= time.time() - start_time
if item is None:
return None
return item, item[1]
async def get_async(self, timeout: float | None = None) -> Optional[Tuple[QueueTuple, str]]:
# callee: executed anywhere
assert self._is_callee
try:
item = asyncio.run_coroutine_threadsafe(self.callee_local_queue.get(), self.loop).result(timeout)
item: QueueTuple = await asyncio.wait_for(self._callee_local_queue.get(), timeout)
except TimeoutError:
return None
@ -122,20 +154,21 @@ class DistributedPromptQueue(AbstractPromptQueue):
def task_done(self, item_id: int, outputs: dict, status: Optional[ExecutionStatus]):
# callee: executed on the worker thread
assert self.is_callee
pending = self.callee_local_in_progress.pop(item_id)
assert self._is_callee
pending = self._callee_local_in_progress.pop(item_id)
assert pending is not None
assert pending.completed is not None
assert not pending.completed.done()
# finish the task. status will transmit the errors in comfy's domain-specific way
pending.completed.set_result(TaskInvocation(item_id=item_id, outputs=outputs, status=status))
# todo: the caller is responsible for sending a websocket message right now that the UI expects for updates
def get_current_queue(self) -> Tuple[List[QueueTuple], List[QueueTuple]]:
"""
In a distributed queue, all queue items are assumed to be currently in progress
:return:
"""
return [], [item.queue_tuple for item in self.caller_local_in_progress.values()]
return [], [item.queue_tuple for item in self._caller_local_in_progress.values()]
def get_tasks_remaining(self) -> int:
"""
@ -143,7 +176,7 @@ class DistributedPromptQueue(AbstractPromptQueue):
:return:
"""
# caller: executed on main thread
return len(self.caller_local_in_progress)
return len(self._caller_local_in_progress)
def wipe_queue(self) -> None:
"""
@ -162,13 +195,13 @@ class DistributedPromptQueue(AbstractPromptQueue):
def get_history(self, prompt_id: Optional[int] = None, max_items=None, offset=-1) \
-> Mapping[str, HistoryEntry]:
return self.caller_history.copy(prompt_id=prompt_id, max_items=max_items, offset=offset)
return self._caller_history.copy(prompt_id=prompt_id, max_items=max_items, offset=offset)
def wipe_history(self):
self.caller_history.clear()
self._caller_history.clear()
def delete_history_item(self, id_to_delete):
self.caller_history.pop(id_to_delete)
self._caller_history.pop(id_to_delete)
def set_flag(self, name: str, data: bool) -> None:
"""
@ -179,7 +212,7 @@ class DistributedPromptQueue(AbstractPromptQueue):
"""
pass
def get_flags(self, reset) -> Flags:
def get_flags(self, reset=True) -> Flags:
"""
Does nothing on distributed queues. Workers must manage their own memory.
:param reset:
@ -188,7 +221,7 @@ class DistributedPromptQueue(AbstractPromptQueue):
return Flags()
def __init__(self,
server: Optional[ExecutorToClientProgress | PromptServer] = None,
caller_server: Optional[ExecutorToClientProgress | PromptServer] = None,
queue_name: str = "comfyui",
connection_uri="amqp://localhost/",
is_caller=True,
@ -196,40 +229,52 @@ class DistributedPromptQueue(AbstractPromptQueue):
loop: Optional[AbstractEventLoop] = None):
super().__init__()
# this constructor is called on the main thread
self.loop = loop or asyncio.get_event_loop() or asyncio.new_event_loop()
self.queue_name = queue_name
self.connection_uri = connection_uri
self.connection: Optional[AbstractConnection] = None # Connection will be set up asynchronously
self.channel: Optional[AbstractChannel] = None # Channel will be set up asynchronously
self.is_caller = is_caller
self.is_callee = is_callee
self._loop = loop or asyncio.get_event_loop() or asyncio.new_event_loop()
self._queue_name = queue_name
self._connection_uri = connection_uri
self._connection: Optional[AbstractConnection] = None # Connection will be set up asynchronously
self._channel: Optional[AbstractChannel] = None # Channel will be set up asynchronously
self._is_caller = is_caller
self._is_callee = is_callee
self._closing = False
self._initialized = False
# as rpc caller
self.caller_server = server
self.caller_local_in_progress: dict[str | int, QueueItem] = {}
self.caller_history: History = History()
self._caller_server = caller_server or ServerStub()
self._caller_local_in_progress: dict[str | int, QueueItem] = {}
self._caller_history: History = History()
# as rpc callee
self.callee_local_queue: Queue = Queue()
self.callee_local_in_progress: Dict[int | str, QueueItem] = {}
self.rpc: Optional[JsonRPC] = None
self._callee_local_queue: Queue = Queue()
self._callee_local_in_progress: Dict[int | str, QueueItem] = {}
self._rpc: Optional[JsonRPC] = None
# todo: the prompt queue really shouldn't do this
if server is not None:
server.prompt_queue = self
async def __aenter__(self):
await self.init()
return self
async def __aexit__(self, *args):
await self.close()
async def init(self):
self.connection = await connect_robust(self.connection_uri, loop=self.loop)
self.channel = await self.connection.channel()
self.rpc = await JsonRPC.create(channel=self.channel)
self.rpc.host_exceptions = True
if self._initialized:
return
self._connection = await connect_robust(self._connection_uri, loop=self._loop)
self._channel = await self._connection.channel()
self._rpc = await JsonRPC.create(channel=self._channel)
self._rpc.host_exceptions = True
# this makes the queue available to complete work items
if self.is_callee:
await self.rpc.register(self.queue_name, self._callee_do_work_item)
if self._is_callee:
await self._rpc.register(self._queue_name, self._callee_do_work_item)
self._initialized = True
async def close(self):
if self._closing or not self._initialized:
return
self._closing = True
await self.rpc.close()
await self.channel.close()
await self.connection.close()
await self._rpc.close()
await self._channel.close()
await self._connection.close()
self._initialized = False
self._closing = False

View File

@ -0,0 +1,26 @@
from __future__ import annotations
import uuid
from typing import Literal
from ..component_model.executor_types import ExecutorToClientProgress, StatusMessage, ExecutingMessage
from ..component_model.queue_types import BinaryEventTypes
class ServerStub(ExecutorToClientProgress):
"""
This class is a stub implementation of ExecutorToClientProgress. This will handle progress events.
"""
def __init__(self):
self.client_id = str(uuid.uuid4())
self.last_node_id = None
self.last_prompt_id = None
def send_sync(self,
event: Literal["status", "executing"] | BinaryEventTypes | str | None,
data: StatusMessage | ExecutingMessage | bytes | bytearray | None, sid: str | None = None):
pass
def queue_updated(self):
pass

View File

@ -446,10 +446,10 @@ def set_progress_bar_global_hook(function):
PROGRESS_BAR_HOOK = function
class ProgressBar:
def __init__(self, total):
def __init__(self, total: float):
global PROGRESS_BAR_HOOK
self.total = total
self.current = 0
self.total: float = total
self.current: float = 0.0
self.hook = PROGRESS_BAR_HOOK
def update_absolute(self, value, total=None, preview=None):

View File

@ -1,4 +1,6 @@
import asyncio
from comfy.cmd.main import main
if __name__ == "__main__":
main()
asyncio.run(main())

View File

@ -1,13 +1,16 @@
import asyncio
import os
import uuid
from concurrent.futures import ThreadPoolExecutor
import jwt
import pytest
from comfy.client.embedded_comfy_client import EmbeddedComfyClient, ServerStub
from comfy.client.embedded_comfy_client import EmbeddedComfyClient
from comfy.distributed.server_stub import ServerStub
from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner
from comfy.component_model.make_mutable import make_mutable
from comfy.component_model.queue_types import QueueItem, QueueTuple, TaskInvocation
from comfy.component_model.queue_types import QueueItem, QueueTuple, TaskInvocation, NamedQueueTuple, ExecutionStatus
from comfy.distributed.distributed_prompt_worker import DistributedPromptWorker
from testcontainers.rabbitmq import RabbitMqContainer
@ -15,6 +18,16 @@ from testcontainers.rabbitmq import RabbitMqContainer
os.environ["TC_HOST"] = "localhost"
def create_test_prompt() -> QueueItem:
from comfy.cmd.execution import validate_prompt
prompt = make_mutable(sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1))
validation_tuple = validate_prompt(prompt)
item_id = str(uuid.uuid4())
queue_tuple: QueueTuple = (0, item_id, prompt, {}, validation_tuple[2])
return QueueItem(queue_tuple, None)
@pytest.mark.asyncio
async def test_sign_jwt_auth_none():
client_id = str(uuid.uuid4())
@ -29,21 +42,56 @@ async def test_basic_queue_worker() -> None:
with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
params = rabbitmq.get_connection_params()
async with DistributedPromptWorker(connection_uri=f"amqp://guest:guest@127.0.0.1:{params.port}") as worker:
async with DistributedPromptWorker(connection_uri=f"amqp://guest:guest@127.0.0.1:{params.port}"):
# this unfortunately does a bunch of initialization on the test thread
from comfy.cmd.execution import validate_prompt
from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
# now submit some jobs
distributed_queue = DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True,
connection_uri=f"amqp://guest:guest@127.0.0.1:{params.port}")
await distributed_queue.init()
prompt = make_mutable(sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1))
validation_tuple = validate_prompt(prompt)
item_id = str(uuid.uuid4())
queue_tuple: QueueTuple = (0, item_id, prompt, {}, validation_tuple[2])
res: TaskInvocation = await distributed_queue.put_async(QueueItem(queue_tuple, None))
assert res.item_id == item_id
queue_item = create_test_prompt()
res: TaskInvocation = await distributed_queue.put_async(queue_item)
assert res.item_id == queue_item.prompt_id
assert len(res.outputs) == 1
assert res.status is not None
assert res.status.status_str == "success"
await distributed_queue.close()
@pytest.mark.asyncio
async def test_distributed_prompt_queues_same_process():
with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
params = rabbitmq.get_connection_params()
connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"
from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
async with DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True,
connection_uri=connection_uri) as frontend:
async with DistributedPromptQueue(ServerStub(), is_callee=True, is_caller=False,
connection_uri=f"amqp://guest:guest@127.0.0.1:{params.port}") as worker:
test_prompt = create_test_prompt()
test_prompt.completed = asyncio.Future()
frontend.put(test_prompt)
# start a worker thread
thread_pool = ThreadPoolExecutor(max_workers=1)
async def in_thread():
incoming, incoming_prompt_id = worker.get()
assert incoming is not None
incoming_named = NamedQueueTuple(incoming)
assert incoming_named.prompt_id == incoming_prompt_id
async with EmbeddedComfyClient() as embedded_comfy_client:
outputs = await embedded_comfy_client.queue_prompt(incoming_named.prompt,
incoming_named.prompt_id)
worker.task_done(incoming_named.prompt_id, outputs, ExecutionStatus("success", True, []))
thread_pool.submit(lambda: asyncio.run(in_thread()))
# this was completed over the comfyui queue interface, so it should be a task invocation
frontend_pov_result: TaskInvocation = await test_prompt.completed
assert frontend_pov_result is not None
assert frontend_pov_result.item_id == test_prompt.prompt_id
assert frontend_pov_result.outputs is not None
assert len(frontend_pov_result.outputs) == 1
assert frontend_pov_result.status is not None