Improved support for distributed queues

This commit is contained in:
doctorpangloss 2024-02-08 14:55:07 -08:00
parent 3367362cec
commit 92898b8c9d
14 changed files with 278 additions and 72 deletions

View File

@ -186,7 +186,7 @@ On macOS, install exactly Python 3.11 using `brew`, which you can download from
You can use `comfyui` as an API. Visit the [OpenAPI specification](comfy/api/openapi.yaml). This file can be used to generate typed clients for your preferred language. You can use `comfyui` as an API. Visit the [OpenAPI specification](comfy/api/openapi.yaml). This file can be used to generate typed clients for your preferred language.
7. To create the standalone binary: 7. To create the standalone binary:
```shell ```shell
python -m pyinstaller --onefile --noupx -n ComfyUI --add-data="comfy/;comfy/" --paths $(pwd) --paths comfy/cmd main.py python -m PyInstaller --onefile --noupx -n ComfyUI --add-data="comfy/;comfy/" --paths $(pwd) --paths comfy/cmd main.py
``` ```
### Authoring Custom Nodes ### Authoring Custom Nodes

View File

@ -119,6 +119,17 @@ parser.add_argument("--plausible-analytics-domain", required=False,
help="Specifies the domain name for analytics events.") help="Specifies the domain name for analytics events.")
parser.add_argument("--analytics-use-identity-provider", action="store_true", parser.add_argument("--analytics-use-identity-provider", action="store_true",
help="Uses platform identifiers for unique visitor analytics.") help="Uses platform identifiers for unique visitor analytics.")
parser.add_argument("--distributed-queue-connection-uri", type=str, default=None,
help="Servers and clients will connect to this AMPQ URL to form a distributed queue and exchange prompt execution requests and progress updates.")
parser.add_argument(
'--distributed-queue-roles',
action='append',
choices=['worker', 'prompter'],
help='Specifies one or more roles for the distributed queue. Acceptable values are "worker" or "prompter", or both by writing the flag twice with each role. Prompters will start the web UI and connect to the provided AMPQ URL to submit prompts; workers will pull requests off the AMPQ URL.'
)
parser.add_argument("--distributed-queue-name", type=str, default="comfyui",
help="This name will be used by the prompters and workers to exchange prompt requests and replies. Progress updates will be prefixed by the queue name, followed by a '.', then the user ID")
if options.args_parsing: if options.args_parsing:
args, _ = parser.parse_known_args() args, _ = parser.parse_known_args()

View File

@ -70,6 +70,9 @@ class Configuration(dict):
analytics_use_identity_provider (bool): Use platform identifiers for analytics. analytics_use_identity_provider (bool): Use platform identifiers for analytics.
write_out_config_file (bool): Enable writing out the configuration file. write_out_config_file (bool): Enable writing out the configuration file.
create_directories (bool): Creates the default models/, input/, output/ and temp/ directories, then exits. create_directories (bool): Creates the default models/, input/, output/ and temp/ directories, then exits.
distributed_queue_connection_uri (Optional[str]): Servers and clients will connect to this AMQP URL to form a distributed queue and exchange prompt execution requests and progress updates.
distributed_queue_roles (List[str]): Specifies one or more roles for the distributed queue. Acceptable values are "worker" or "prompter", or both by writing the flag twice with each role. Prompters will start the web UI and connect to the provided AMQP URL to submit prompts; workers will pull requests off the AMQP URL.
distributed_queue_name (str): This name will be used by the prompters and workers to exchange prompt requests and replies. Progress updates will be prefixed by the queue name, followed by a '.', then the user ID.
""" """
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__() super().__init__()
@ -127,6 +130,9 @@ class Configuration(dict):
self.analytics_use_identity_provider: bool = False self.analytics_use_identity_provider: bool = False
self.write_out_config_file: bool = False self.write_out_config_file: bool = False
self.create_directories: bool = False self.create_directories: bool = False
self.distributed_queue_connection_uri: Optional[str] = None
self.distributed_queue_roles: List[str] = []
self.distributed_queue_name: str = "comfyui"
for key, value in kwargs.items(): for key, value in kwargs.items():
self[key] = value self[key] = value

View File

@ -18,6 +18,7 @@ class ServerStub(ExecutorToClientProgress):
""" """
This class is a stub implementation of ExecutorToClientProgress. This will handle progress events. This class is a stub implementation of ExecutorToClientProgress. This will handle progress events.
""" """
def __init__(self): def __init__(self):
self.client_id = str(uuid.uuid4()) self.client_id = str(uuid.uuid4())
self.last_node_id = None self.last_node_id = None
@ -69,18 +70,25 @@ class EmbeddedComfyClient:
In order to use this in blocking methods, learn more about asyncio online. In order to use this in blocking methods, learn more about asyncio online.
""" """
def __init__(self, configuration: Optional[Configuration] = None, loop: Optional[AbstractEventLoop] = None, def __init__(self, configuration: Optional[Configuration] = None,
progress_handler: Optional[ExecutorToClientProgress] = None,
loop: Optional[AbstractEventLoop] = None,
max_workers: int = 1): max_workers: int = 1):
self._server_stub = ServerStub() self._progress_handler = progress_handler or ServerStub()
self._executor = ThreadPoolExecutor(max_workers=max_workers) self._executor = ThreadPoolExecutor(max_workers=max_workers)
self._loop = loop or asyncio.get_event_loop() self._loop = loop or asyncio.get_event_loop()
self._configuration = configuration self._configuration = configuration
# we don't want to import the executor yet # we don't want to import the executor yet
self._prompt_executor: Optional["comfy.cmd.execution.PromptExecutor"] = None self._prompt_executor: Optional["comfy.cmd.execution.PromptExecutor"] = None
self._is_running = False
@property
def is_running(self) -> bool:
return self._is_running
async def __aenter__(self): async def __aenter__(self):
# Perform asynchronous initialization here, if needed
await self._initialize_prompt_executor() await self._initialize_prompt_executor()
self._is_running = True
return self return self
async def __aexit__(self, *args): async def __aexit__(self, *args):
@ -101,6 +109,7 @@ class EmbeddedComfyClient:
await self._loop.run_in_executor(self._executor, cleanup) await self._loop.run_in_executor(self._executor, cleanup)
self._executor.shutdown(wait=True) self._executor.shutdown(wait=True)
self._is_running = False
async def _initialize_prompt_executor(self): async def _initialize_prompt_executor(self):
# This method must be async since it's used in __aenter__ # This method must be async since it's used in __aenter__
@ -115,19 +124,23 @@ class EmbeddedComfyClient:
from ..cmd.execution import PromptExecutor from ..cmd.execution import PromptExecutor
self._prompt_executor = PromptExecutor(self._server_stub) self._prompt_executor = PromptExecutor(self._progress_handler)
await self._loop.run_in_executor(self._executor, create_executor_in_thread) await self._loop.run_in_executor(self._executor, create_executor_in_thread)
async def queue_prompt(self, prompt: PromptDict) -> dict: async def queue_prompt(self,
prompt_id = str(uuid.uuid4()) prompt: PromptDict | dict,
prompt_id: Optional[str] = None,
client_id: Optional[str] = None) -> dict:
prompt_id = prompt_id or str(uuid.uuid4())
client_id = client_id or self._progress_handler.client_id or None
def execute_prompt() -> dict: def execute_prompt() -> dict:
from ..cmd.execution import validate_prompt from ..cmd.execution import validate_prompt
prompt_mut = make_mutable(prompt) prompt_mut = make_mutable(prompt)
validation_tuple = validate_prompt(prompt_mut) validation_tuple = validate_prompt(prompt_mut)
self._prompt_executor.execute(prompt_mut, prompt_id, {"client_id": self._server_stub.client_id}, self._prompt_executor.execute(prompt_mut, prompt_id, {"client_id": client_id},
execute_outputs=validation_tuple[2]) execute_outputs=validation_tuple[2])
if self._prompt_executor.success: if self._prompt_executor.success:
return self._prompt_executor.outputs_ui return self._prompt_executor.outputs_ui

View File

@ -807,7 +807,7 @@ class PromptQueue(AbstractPromptQueue):
status_dict: Optional[dict] = None status_dict: Optional[dict] = None
if status is not None: if status is not None:
status_dict = copy.deepcopy(status._asdict()) status_dict = copy.deepcopy(ExecutionStatus(*status)._asdict())
self.history[prompt[1]] = { self.history[prompt[1]] = {
"prompt": prompt, "prompt": prompt,

View File

@ -1,12 +1,9 @@
from __future__ import annotations
from typing import Mapping, Any from typing import Mapping, Any
def make_mutable(obj: Any) -> dict: def _make_mutable(obj: Any) -> Any:
"""
Makes an immutable dict, frozenset or tuple mutable. Otherwise, returns the value.
:param obj: any object
:return:
"""
if isinstance(obj, Mapping) and not isinstance(obj, dict) and not hasattr(obj, "__setitem__"): if isinstance(obj, Mapping) and not isinstance(obj, dict) and not hasattr(obj, "__setitem__"):
obj = dict(obj) obj = dict(obj)
for key, value in obj.items(): for key, value in obj.items():
@ -18,3 +15,17 @@ def make_mutable(obj: Any) -> dict:
if isinstance(obj, frozenset): if isinstance(obj, frozenset):
obj = set([make_mutable(x) for x in obj]) obj = set([make_mutable(x) for x in obj])
return obj return obj
def make_mutable(obj: Mapping) -> dict:
"""
Makes a copy of an immutable dict into a mutable dict.
If the object is already a mutable type or a value type like a string or integer, returns the value.
Returns dict, set or tuple depending on its input, but you should not use it this way.
:param obj: any object
:return:
"""
return _make_mutable(obj)

View File

@ -40,21 +40,21 @@ class HistoryEntry(TypedDict):
status: NotRequired[ExecutionStatusAsDict] status: NotRequired[ExecutionStatusAsDict]
class ExtraData(TypedDict):
client_id: NotRequired[str]
extra_pnginfo: NotRequired[str]
token: NotRequired[str]
@dataclass @dataclass
class QueueItem: class NamedQueueTuple:
""" """
An item awaiting processing in the queue A wrapper class for a queue tuple, the object that is given to executors.
Attributes: Attributes:
queue_tuple (QueueTuple): the corresponding queued workflow and other related data queue_tuple (QueueTuple): the corresponding queued workflow and other related data
completed (Optional[Future[TaskInvocation | dict]]): A future of a task invocation (the signature of the task_done method)
or a dictionary of outputs
""" """
queue_tuple: QueueTuple queue_tuple: QueueTuple
completed: asyncio.Future[TaskInvocation | dict] | None
def __lt__(self, other: QueueItem):
return self.queue_tuple[0] < other.queue_tuple[0]
@property @property
def priority(self) -> float: def priority(self) -> float:
@ -69,7 +69,7 @@ class QueueItem:
return self.queue_tuple[2] return self.queue_tuple[2]
@property @property
def extra_data(self) -> Optional[dict]: def extra_data(self) -> Optional[ExtraData]:
if len(self.queue_tuple) > 2: if len(self.queue_tuple) > 2:
return self.queue_tuple[3] return self.queue_tuple[3]
else: else:
@ -83,6 +83,22 @@ class QueueItem:
return None return None
@dataclass
class QueueItem(NamedQueueTuple):
"""
An item awaiting processing in the queue: a NamedQueueTuple with a future that is completed when the item is done
processing.
Attributes:
completed (Optional[Future[TaskInvocation | dict]]): A future of a task invocation (the signature of the task_done method)
or a dictionary of outputs
"""
completed: asyncio.Future[TaskInvocation | dict] | None
def __lt__(self, other: QueueItem):
return self.queue_tuple[0] < other.queue_tuple[0]
class BinaryEventTypes(Enum): class BinaryEventTypes(Enum):
PREVIEW_IMAGE = 1 PREVIEW_IMAGE = 1
UNENCODED_PREVIEW_IMAGE = 2 UNENCODED_PREVIEW_IMAGE = 2

View File

@ -0,0 +1,32 @@
import asyncio
from asyncio import AbstractEventLoop
from dataclasses import asdict
from typing import Optional
from aio_pika import connect_robust
from aio_pika.patterns import RPC
from comfy.distributed.distributed_types import RpcRequest, RpcReply
class DistributedPromptClient:
def __init__(self, queue_name: str = "comfyui",
connection_uri="amqp://localhost/",
loop: Optional[AbstractEventLoop] = None):
self.queue_name = queue_name
self.connection_uri = connection_uri
self.loop = loop or asyncio.get_event_loop()
async def __aenter__(self):
self.connection = await connect_robust(self.connection_uri, loop=self.loop)
self.channel = await self.connection.channel()
self.rpc = await RPC.create(channel=self.channel)
self.rpc.host_exceptions = True
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.channel.close()
await self.rpc.close()
await self.connection.close()
async def queue_prompt(self, request: RpcRequest) -> RpcReply:
return RpcReply(**(await self.rpc.call(self.queue_name, {"request": asdict(request)})))

View File

@ -1,13 +1,17 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from asyncio import AbstractEventLoop import uuid
from asyncio import AbstractEventLoop, Queue
from dataclasses import asdict
from typing import Optional, Dict, List, Mapping, Tuple, Callable from typing import Optional, Dict, List, Mapping, Tuple, Callable
import jwt
from aio_pika import connect_robust from aio_pika import connect_robust
from aio_pika.abc import AbstractConnection, AbstractChannel from aio_pika.abc import AbstractConnection, AbstractChannel
from aio_pika.patterns import RPC from aio_pika.patterns import JsonRPC
from .distributed_types import RpcRequest, RpcReply
from ..component_model.abstract_prompt_queue import AbstractPromptQueue from ..component_model.abstract_prompt_queue import AbstractPromptQueue
from ..component_model.executor_types import ExecutorToClientProgress from ..component_model.executor_types import ExecutorToClientProgress
from ..component_model.queue_types import Flags, HistoryEntry, QueueTuple, QueueItem, ExecutionStatus, TaskInvocation from ..component_model.queue_types import Flags, HistoryEntry, QueueTuple, QueueItem, ExecutionStatus, TaskInvocation
@ -29,11 +33,30 @@ class DistributedPromptQueue(AbstractPromptQueue):
async def put_async(self, queue_item: QueueItem): async def put_async(self, queue_item: QueueItem):
assert self.is_caller assert self.is_caller
if self._closing:
return
self.caller_local_in_progress[queue_item.prompt_id] = queue_item self.caller_local_in_progress[queue_item.prompt_id] = queue_item
if self.caller_server is not None: if self.caller_server is not None:
self.caller_server.queue_updated() self.caller_server.queue_updated()
try: try:
res: TaskInvocation = await self.rpc.call(self.queue_name, {"item": queue_item.queue_tuple}) if "token" in queue_item.extra_data:
user_token = queue_item.extra_data["token"]
else:
if "client_id" in queue_item.extra_data:
client_id = queue_item.extra_data["client_id"]
elif self.caller_server.client_id is not None:
client_id = self.caller_server.client_id
else:
client_id = str(uuid.uuid4())
# todo: should we really do this?
self.caller_server.client_id = client_id
# create a stub token
user_token = jwt.encode({"sub": client_id}, key="", algorithm="none")
request = RpcRequest(prompt_id=queue_item.prompt_id, user_token=user_token, prompt=queue_item.prompt)
assert self.rpc is not None
res: TaskInvocation = RpcReply(
**(await self.rpc.call(self.queue_name, {"request": asdict(request)}))).as_task_invocation()
self.caller_history.put(queue_item, res.outputs, res.status) self.caller_history.put(queue_item, res.outputs, res.status)
if self.caller_server is not None: if self.caller_server is not None:
@ -63,12 +86,16 @@ class DistributedPromptQueue(AbstractPromptQueue):
def put(self, item: QueueItem): def put(self, item: QueueItem):
# caller: execute on main thread # caller: execute on main thread
assert self.is_caller assert self.is_caller
if self._closing:
return
# this is called by the web server and its event loop is perfectly fine to use # this is called by the web server and its event loop is perfectly fine to use
# the future is now ignored # the future is now ignored
self.loop.call_soon_threadsafe(self.put_async, item) self.loop.call_soon_threadsafe(self.put_async, item)
async def _callee_do_work_item(self, item: QueueTuple) -> TaskInvocation: async def _callee_do_work_item(self, request: dict) -> dict:
assert self.is_callee assert self.is_callee
request_obj = RpcRequest.from_dict(request)
item = request_obj.as_queue_tuple().queue_tuple
item_with_completer = QueueItem(item, self.loop.create_future()) item_with_completer = QueueItem(item, self.loop.create_future())
self.callee_local_in_progress[item_with_completer.prompt_id] = item_with_completer self.callee_local_in_progress[item_with_completer.prompt_id] = item_with_completer
# todo: check if we have the local model content needed to execute this request and if not, reject it # todo: check if we have the local model content needed to execute this request and if not, reject it
@ -80,9 +107,10 @@ class DistributedPromptQueue(AbstractPromptQueue):
assert not item_with_completer.completed.done() assert not item_with_completer.completed.done()
# now we wait for the worker thread to complete the item # now we wait for the worker thread to complete the item
return await item_with_completer.completed invocation = await item_with_completer.completed
return asdict(RpcReply.from_task_invocation(invocation, request_obj.user_token))
def get(self, timeout: float | None = None) -> Optional[Tuple[QueueTuple, int]]: def get(self, timeout: float | None = None) -> Optional[Tuple[QueueTuple, str | int]]:
# callee: executed on the worker thread # callee: executed on the worker thread
assert self.is_callee assert self.is_callee
try: try:
@ -175,6 +203,7 @@ class DistributedPromptQueue(AbstractPromptQueue):
self.channel: Optional[AbstractChannel] = None # Channel will be set up asynchronously self.channel: Optional[AbstractChannel] = None # Channel will be set up asynchronously
self.is_caller = is_caller self.is_caller = is_caller
self.is_callee = is_callee self.is_callee = is_callee
self._closing = False
# as rpc caller # as rpc caller
self.caller_server = server self.caller_server = server
@ -182,9 +211,9 @@ class DistributedPromptQueue(AbstractPromptQueue):
self.caller_history: History = History() self.caller_history: History = History()
# as rpc callee # as rpc callee
self.callee_local_queue = asyncio.Queue() self.callee_local_queue: Queue = Queue()
self.callee_local_in_progress: Dict[int | str, QueueItem] = {} self.callee_local_in_progress: Dict[int | str, QueueItem] = {}
self.rpc: Optional[RPC] = None self.rpc: Optional[JsonRPC] = None
# todo: the prompt queue really shouldn't do this # todo: the prompt queue really shouldn't do this
if server is not None: if server is not None:
@ -193,8 +222,14 @@ class DistributedPromptQueue(AbstractPromptQueue):
async def init(self): async def init(self):
self.connection = await connect_robust(self.connection_uri, loop=self.loop) self.connection = await connect_robust(self.connection_uri, loop=self.loop)
self.channel = await self.connection.channel() self.channel = await self.connection.channel()
self.rpc = await RPC.create(channel=self.channel) self.rpc = await JsonRPC.create(channel=self.channel)
self.rpc.host_exceptions = True self.rpc.host_exceptions = True
# this makes the queue available to complete work items # this makes the queue available to complete work items
if self.is_callee: if self.is_callee:
await self.rpc.register(self.queue_name, self._callee_do_work_item) await self.rpc.register(self.queue_name, self._callee_do_work_item)
async def close(self):
self._closing = True
await self.rpc.close()
await self.channel.close()
await self.connection.close()

View File

@ -1,14 +1,15 @@
import asyncio import asyncio
from asyncio import AbstractEventLoop from asyncio import AbstractEventLoop
from contextlib import AsyncExitStack
from dataclasses import asdict
from typing import Optional from typing import Optional
from aio_pika import connect_robust from aio_pika import connect_robust
from aio_pika.patterns import RPC from aio_pika.patterns import JsonRPC
from ..api.components.schema.prompt import Prompt from .distributed_types import RpcRequest, RpcReply
from ..cli_args_types import Configuration
from ..client.embedded_comfy_client import EmbeddedComfyClient from ..client.embedded_comfy_client import EmbeddedComfyClient
from ..component_model.queue_types import TaskInvocation, QueueTuple, QueueItem, ExecutionStatus from ..component_model.queue_types import ExecutionStatus
class DistributedPromptWorker: class DistributedPromptWorker:
@ -16,31 +17,42 @@ class DistributedPromptWorker:
A work in progress distributed prompt worker. A work in progress distributed prompt worker.
""" """
def __init__(self, embedded_comfy_client: EmbeddedComfyClient, def __init__(self, embedded_comfy_client: Optional[EmbeddedComfyClient] = None,
connection_uri: str = "amqp://localhost:5672/", connection_uri: str = "amqp://localhost:5672/",
queue_name: str = "comfyui", queue_name: str = "comfyui",
loop: Optional[AbstractEventLoop] = None, configuration: Configuration = None): loop: Optional[AbstractEventLoop] = None):
self._exit_stack = AsyncExitStack()
self._queue_name = queue_name self._queue_name = queue_name
self._configuration = configuration
self._connection_uri = connection_uri self._connection_uri = connection_uri
self._loop = loop or asyncio.get_event_loop() self._loop = loop or asyncio.get_event_loop()
self._embedded_comfy_client = embedded_comfy_client self._embedded_comfy_client = embedded_comfy_client or EmbeddedComfyClient()
async def _do_work_item(self, item: QueueTuple) -> TaskInvocation: async def _do_work_item(self, request: dict) -> dict:
item_without_completer = QueueItem(item, completed=None)
try: try:
output_dict = await self._embedded_comfy_client.queue_prompt(Prompt.validate(item_without_completer.prompt)) request_obj = RpcRequest.from_dict(request)
return TaskInvocation(item_without_completer.prompt_id, outputs=output_dict,
status=ExecutionStatus("success", True, []))
except Exception as e: except Exception as e:
return TaskInvocation(item_without_completer.prompt_id, outputs={}, request_dict_prompt_id_recovered = request["prompt_id"] \
status=ExecutionStatus("error", False, [str(e)])) if request is not None and "prompt_id" in request else ""
return asdict(RpcReply(request_dict_prompt_id_recovered, "", {},
ExecutionStatus("error", False, [str(e)])))
try:
output_dict = await self._embedded_comfy_client.queue_prompt(request_obj.prompt,
request_obj.prompt_id,
client_id=request_obj.user_id)
return asdict(RpcReply(request_obj.prompt_id, request_obj.user_token, output_dict, ExecutionStatus("success", True, [])))
except Exception as e:
return asdict(RpcReply(request_obj.prompt_id, request_obj.user_token, {}, ExecutionStatus("error", False, [str(e)])))
async def __aenter__(self) -> "DistributedPromptWorker": async def __aenter__(self) -> "DistributedPromptWorker":
await self._exit_stack.__aenter__()
if not self._embedded_comfy_client.is_running:
await self._exit_stack.enter_async_context(self._embedded_comfy_client)
self._connection = await connect_robust(self._connection_uri, loop=self._loop) self._connection = await connect_robust(self._connection_uri, loop=self._loop)
self._channel = await self._connection.channel() self._channel = await self._connection.channel()
self._rpc = await RPC.create(channel=self._channel) self._rpc = await JsonRPC.create(channel=self._channel)
self._rpc.host_exceptions = True self._rpc.host_exceptions = True
await self._rpc.register(self._queue_name, self._do_work_item) await self._rpc.register(self._queue_name, self._do_work_item)
return self return self
@ -48,3 +60,4 @@ class DistributedPromptWorker:
await self._rpc.close() await self._rpc.close()
await self._channel.close() await self._channel.close()
await self._connection.close() await self._connection.close()
return await self._exit_stack.__aexit__(*args)

View File

@ -0,0 +1,58 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Tuple, Literal, List
import jwt
from ..api.components.schema.prompt import PromptDict, Prompt
from ..component_model.queue_types import NamedQueueTuple, TaskInvocation, ExecutionStatus
@dataclass
class DistributedBase:
prompt_id: str
user_token: str
@property
def user_id(self) -> str:
return self.decoded_token["sub"]
@property
def decoded_token(self) -> dict:
return jwt.decode(self.user_token, algorithms=['HS256', "none"],
# todo: this should be configurable
options={"verify_signature": False, 'verify_aud': False, 'verify_iss': False})
@dataclass
class RpcRequest(DistributedBase):
prompt: dict | PromptDict
def as_queue_tuple(self) -> NamedQueueTuple:
# this loads the nodes in this instance
# should always be okay to call in an executor
from ..cmd.execution import validate_prompt
from ..component_model.make_mutable import make_mutable
mutated_prompt_dict = make_mutable(self.prompt)
validation_tuple = validate_prompt(mutated_prompt_dict)
return NamedQueueTuple(queue_tuple=(0, self.prompt_id, mutated_prompt_dict, {}, validation_tuple[2]))
@classmethod
def from_dict(cls, request_dict):
request = RpcRequest(**request_dict)
request.prompt = Prompt.validate(request.prompt)
return request
@dataclass
class RpcReply(DistributedBase):
outputs: dict
execution_status: ExecutionStatus | Tuple[Literal['success', 'error'], bool, List[str]]
@staticmethod
def from_task_invocation(task_invocation: TaskInvocation, user_token: str) -> 'RpcReply':
return RpcReply(str(task_invocation.item_id), user_token, task_invocation.outputs, task_invocation.status)
def as_task_invocation(self):
return TaskInvocation(self.prompt_id, self.outputs, ExecutionStatus(*self.execution_status))

View File

@ -15,7 +15,7 @@ class History:
def put(self, queue_item: QueueItem, outputs: dict, status: ExecutionStatus): def put(self, queue_item: QueueItem, outputs: dict, status: ExecutionStatus):
self.history[queue_item.prompt_id] = HistoryEntry(prompt=queue_item.queue_tuple, self.history[queue_item.prompt_id] = HistoryEntry(prompt=queue_item.queue_tuple,
outputs=outputs, outputs=outputs,
status=status._asdict()) status=ExecutionStatus(*status)._asdict())
def copy(self, prompt_id: Optional[str | int] = None, max_items: int = MAXIMUM_HISTORY_SIZE, def copy(self, prompt_id: Optional[str | int] = None, max_items: int = MAXIMUM_HISTORY_SIZE,
offset: int = 0) -> Dict[str, HistoryEntry]: offset: int = 0) -> Dict[str, HistoryEntry]:

View File

@ -28,4 +28,5 @@ tqdm
protobuf==3.20.3 protobuf==3.20.3
psutil psutil
ConfigArgParse ConfigArgParse
aio-pika aio-pika
pyjwt[crypto]

View File

@ -1,6 +1,7 @@
import os import os
import uuid import uuid
import jwt
import pytest import pytest
from comfy.client.embedded_comfy_client import EmbeddedComfyClient, ServerStub from comfy.client.embedded_comfy_client import EmbeddedComfyClient, ServerStub
@ -8,32 +9,41 @@ from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner
from comfy.component_model.make_mutable import make_mutable from comfy.component_model.make_mutable import make_mutable
from comfy.component_model.queue_types import QueueItem, QueueTuple, TaskInvocation from comfy.component_model.queue_types import QueueItem, QueueTuple, TaskInvocation
from comfy.distributed.distributed_prompt_worker import DistributedPromptWorker from comfy.distributed.distributed_prompt_worker import DistributedPromptWorker
from testcontainers.rabbitmq import RabbitMqContainer
# fixes issues with running the testcontainers rabbitmqcontainer on Windows
os.environ["TC_HOST"] = "localhost"
@pytest.mark.asyncio
async def test_sign_jwt_auth_none():
client_id = str(uuid.uuid4())
user_token_str = jwt.encode({"sub": client_id}, None, algorithm="none")
user_token = jwt.decode(user_token_str, None, algorithms=["none"], options={"verify_signature": False})
assert user_token["sub"] == client_id
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_basic_queue_worker() -> None: async def test_basic_queue_worker() -> None:
os.environ["TC_HOST"] = "localhost"
# there are lots of side effects from importing that we have to deal with # there are lots of side effects from importing that we have to deal with
from testcontainers.rabbitmq import RabbitMqContainer
with RabbitMqContainer("rabbitmq:latest") as rabbitmq: with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
params = rabbitmq.get_connection_params() params = rabbitmq.get_connection_params()
async with EmbeddedComfyClient() as client: async with DistributedPromptWorker(connection_uri=f"amqp://guest:guest@127.0.0.1:{params.port}") as worker:
async with DistributedPromptWorker(client, # this unfortunately does a bunch of initialization on the test thread
connection_uri=f"amqp://guest:guest@127.0.0.1:{params.port}") as worker: from comfy.cmd.execution import validate_prompt
# this unfortunately does a bunch of initialization on the test thread from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
from comfy.cmd.execution import validate_prompt # now submit some jobs
from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue distributed_queue = DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True,
# now submit some jobs connection_uri=f"amqp://guest:guest@127.0.0.1:{params.port}")
distributed_queue = DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, await distributed_queue.init()
connection_uri=f"amqp://guest:guest@127.0.0.1:{params.port}") prompt = make_mutable(sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1))
await distributed_queue.init() validation_tuple = validate_prompt(prompt)
prompt = make_mutable(sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)) item_id = str(uuid.uuid4())
validation_tuple = validate_prompt(prompt) queue_tuple: QueueTuple = (0, item_id, prompt, {}, validation_tuple[2])
item_id = str(uuid.uuid4()) res: TaskInvocation = await distributed_queue.put_async(QueueItem(queue_tuple, None))
queue_tuple: QueueTuple = (0, item_id, prompt, {}, validation_tuple[2]) assert res.item_id == item_id
res: TaskInvocation = await distributed_queue.put_async(QueueItem(queue_tuple, None)) assert len(res.outputs) == 1
assert res.item_id == item_id assert res.status is not None
assert len(res.outputs) == 1 assert res.status.status_str == "success"
assert res.status is not None await distributed_queue.close()
assert res.status.status_str == "success"