Improved support for distributed queues

doctorpangloss 2024-02-08 14:55:07 -08:00
parent 3367362cec
commit 92898b8c9d
14 changed files with 278 additions and 72 deletions

View File

@@ -186,7 +186,7 @@ On macOS, install exactly Python 3.11 using `brew`, which you can download from
You can use `comfyui` as an API. Visit the [OpenAPI specification](comfy/api/openapi.yaml). This file can be used to generate typed clients for your preferred language.
7. To create the standalone binary:
```shell
python -m pyinstaller --onefile --noupx -n ComfyUI --add-data="comfy/;comfy/" --paths $(pwd) --paths comfy/cmd main.py
python -m PyInstaller --onefile --noupx -n ComfyUI --add-data="comfy/;comfy/" --paths $(pwd) --paths comfy/cmd main.py
```
### Authoring Custom Nodes

View File

@@ -119,6 +119,17 @@ parser.add_argument("--plausible-analytics-domain", required=False,
help="Specifies the domain name for analytics events.")
parser.add_argument("--analytics-use-identity-provider", action="store_true",
help="Uses platform identifiers for unique visitor analytics.")
parser.add_argument("--distributed-queue-connection-uri", type=str, default=None,
help="Servers and clients will connect to this AMPQ URL to form a distributed queue and exchange prompt execution requests and progress updates.")
parser.add_argument(
'--distributed-queue-roles',
action='append',
choices=['worker', 'prompter'],
help='Specifies one or more roles for the distributed queue. Acceptable values are "worker" and "prompter"; pass the flag once per role to take on both. Prompters will start the web UI and connect to the provided AMQP URL to submit prompts; workers will pull requests off the AMQP URL.'
)
parser.add_argument("--distributed-queue-name", type=str, default="comfyui",
help="This name will be used by the prompters and workers to exchange prompt requests and replies. Progress updates will be prefixed by the queue name, followed by a '.', then the user ID")
if options.args_parsing:
args, _ = parser.parse_known_args()
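For illustration, a hypothetical pair of invocations combining the new flags, assuming a RabbitMQ broker on localhost (the broker URL and credentials are placeholders):

```shell
# worker: pulls prompt execution requests off the AMQP queue
python main.py --distributed-queue-connection-uri amqp://guest:guest@localhost:5672/ --distributed-queue-roles worker

# prompter: serves the web UI and submits prompts to the same queue
python main.py --distributed-queue-connection-uri amqp://guest:guest@localhost:5672/ --distributed-queue-roles prompter --distributed-queue-name comfyui
```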

View File

@@ -70,6 +70,9 @@ class Configuration(dict):
analytics_use_identity_provider (bool): Use platform identifiers for analytics.
write_out_config_file (bool): Enable writing out the configuration file.
create_directories (bool): Creates the default models/, input/, output/ and temp/ directories, then exits.
distributed_queue_connection_uri (Optional[str]): Servers and clients will connect to this AMQP URL to form a distributed queue and exchange prompt execution requests and progress updates.
distributed_queue_roles (List[str]): Specifies one or more roles for the distributed queue. Acceptable values are "worker" and "prompter"; include both to take on both roles. Prompters will start the web UI and connect to the provided AMQP URL to submit prompts; workers will pull requests off the AMQP URL.
distributed_queue_name (str): This name will be used by the prompters and workers to exchange prompt requests and replies. Progress updates will be prefixed by the queue name, followed by a '.', then the user ID.
"""
def __init__(self, **kwargs):
super().__init__()
@@ -127,6 +130,9 @@ class Configuration(dict):
self.analytics_use_identity_provider: bool = False
self.write_out_config_file: bool = False
self.create_directories: bool = False
self.distributed_queue_connection_uri: Optional[str] = None
self.distributed_queue_roles: List[str] = []
self.distributed_queue_name: str = "comfyui"
for key, value in kwargs.items():
self[key] = value
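The same options can be set programmatically; a minimal sketch, assuming `Configuration` is imported from `comfy.cli_args_types` as elsewhere in this commit:

```python
from comfy.cli_args_types import Configuration

# attributes mirror the new command-line flags; the broker URL is a placeholder
config = Configuration()
config.distributed_queue_connection_uri = "amqp://guest:guest@localhost:5672/"
config.distributed_queue_roles = ["prompter", "worker"]
config.distributed_queue_name = "comfyui"
```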

View File

@@ -18,6 +18,7 @@ class ServerStub(ExecutorToClientProgress):
"""
A stub implementation of ExecutorToClientProgress that handles progress events when no real server is attached.
"""
def __init__(self):
self.client_id = str(uuid.uuid4())
self.last_node_id = None
@@ -69,18 +70,25 @@ class EmbeddedComfyClient:
To call this from blocking (synchronous) code, see the asyncio documentation on running coroutines from other threads.
"""
def __init__(self, configuration: Optional[Configuration] = None, loop: Optional[AbstractEventLoop] = None,
def __init__(self, configuration: Optional[Configuration] = None,
progress_handler: Optional[ExecutorToClientProgress] = None,
loop: Optional[AbstractEventLoop] = None,
max_workers: int = 1):
self._server_stub = ServerStub()
self._progress_handler = progress_handler or ServerStub()
self._executor = ThreadPoolExecutor(max_workers=max_workers)
self._loop = loop or asyncio.get_event_loop()
self._configuration = configuration
# we don't want to import the executor yet
self._prompt_executor: Optional["comfy.cmd.execution.PromptExecutor"] = None
self._is_running = False
@property
def is_running(self) -> bool:
return self._is_running
async def __aenter__(self):
# Perform asynchronous initialization here, if needed
await self._initialize_prompt_executor()
self._is_running = True
return self
async def __aexit__(self, *args):
@@ -101,6 +109,7 @@ class EmbeddedComfyClient:
await self._loop.run_in_executor(self._executor, cleanup)
self._executor.shutdown(wait=True)
self._is_running = False
async def _initialize_prompt_executor(self):
# This method must be async since it's used in __aenter__
@@ -115,19 +124,23 @@
from ..cmd.execution import PromptExecutor
self._prompt_executor = PromptExecutor(self._server_stub)
self._prompt_executor = PromptExecutor(self._progress_handler)
await self._loop.run_in_executor(self._executor, create_executor_in_thread)
async def queue_prompt(self, prompt: PromptDict) -> dict:
prompt_id = str(uuid.uuid4())
async def queue_prompt(self,
prompt: PromptDict | dict,
prompt_id: Optional[str] = None,
client_id: Optional[str] = None) -> dict:
prompt_id = prompt_id or str(uuid.uuid4())
client_id = client_id or self._progress_handler.client_id or None
def execute_prompt() -> dict:
from ..cmd.execution import validate_prompt
prompt_mut = make_mutable(prompt)
validation_tuple = validate_prompt(prompt_mut)
self._prompt_executor.execute(prompt_mut, prompt_id, {"client_id": self._server_stub.client_id},
self._prompt_executor.execute(prompt_mut, prompt_id, {"client_id": client_id},
execute_outputs=validation_tuple[2])
if self._prompt_executor.success:
return self._prompt_executor.outputs_ui
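A sketch of the updated `queue_prompt` signature: `prompt_id` and `client_id` are now optional and default to fresh values, and a caller-supplied `progress_handler` replaces the internal stub. The workflow dict here is a hypothetical stand-in:

```python
import asyncio
import uuid

from comfy.client.embedded_comfy_client import EmbeddedComfyClient

async def run_one(prompt: dict) -> dict:
    # prompt is a hypothetical workflow dict; ids are minted when omitted
    async with EmbeddedComfyClient() as client:
        return await client.queue_prompt(prompt, prompt_id=str(uuid.uuid4()))

# outputs = asyncio.run(run_one(my_prompt_dict))
```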

View File

@@ -807,7 +807,7 @@ class PromptQueue(AbstractPromptQueue):
status_dict: Optional[dict] = None
if status is not None:
status_dict = copy.deepcopy(status._asdict())
status_dict = copy.deepcopy(ExecutionStatus(*status)._asdict())
self.history[prompt[1]] = {
"prompt": prompt,

View File

@@ -1,12 +1,9 @@
from __future__ import annotations
from typing import Mapping, Any
def make_mutable(obj: Any) -> dict:
"""
Makes an immutable dict, frozenset or tuple mutable. Otherwise, returns the value.
:param obj: any object
:return:
"""
def _make_mutable(obj: Any) -> Any:
if isinstance(obj, Mapping) and not isinstance(obj, dict) and not hasattr(obj, "__setitem__"):
obj = dict(obj)
for key, value in obj.items():
@@ -18,3 +15,17 @@ def make_mutable(obj: Any) -> dict:
if isinstance(obj, frozenset):
obj = set([make_mutable(x) for x in obj])
return obj
def make_mutable(obj: Mapping) -> dict:
"""
Recursively copies an immutable Mapping into a mutable dict.
If the object is already mutable, or is a value type like a string or integer, it is returned unchanged.
The return value may be a dict, set or tuple depending on the input, but callers should treat it as a dict.
:param obj: any object
:return: a mutable copy of the input
"""
return _make_mutable(obj)
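For example, an immutable mapping such as a `MappingProxyType` should come back as plain mutable containers (a sketch, not taken from the tests):

```python
from types import MappingProxyType

from comfy.component_model.make_mutable import make_mutable

frozen = MappingProxyType({"nodes": (1, 2), "tags": frozenset({"a"})})
mutable = make_mutable(frozen)
assert isinstance(mutable, dict)  # the immutable mapping becomes a dict
assert isinstance(mutable["tags"], set)  # frozensets become sets
```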

View File

@@ -40,21 +40,21 @@ class HistoryEntry(TypedDict):
status: NotRequired[ExecutionStatusAsDict]
class ExtraData(TypedDict):
client_id: NotRequired[str]
extra_pnginfo: NotRequired[str]
token: NotRequired[str]
@dataclass
class QueueItem:
class NamedQueueTuple:
"""
An item awaiting processing in the queue
A wrapper class for a queue tuple, the object that is given to executors.
Attributes:
queue_tuple (QueueTuple): the corresponding queued workflow and other related data
completed (Optional[Future[TaskInvocation | dict]]): A future of a task invocation (the signature of the task_done method)
or a dictionary of outputs
"""
queue_tuple: QueueTuple
completed: asyncio.Future[TaskInvocation | dict] | None
def __lt__(self, other: QueueItem):
return self.queue_tuple[0] < other.queue_tuple[0]
@property
def priority(self) -> float:
@@ -69,7 +69,7 @@ class QueueItem:
return self.queue_tuple[2]
@property
def extra_data(self) -> Optional[dict]:
def extra_data(self) -> Optional[ExtraData]:
if len(self.queue_tuple) > 2:
return self.queue_tuple[3]
else:
@@ -83,6 +83,22 @@ class QueueItem:
return None
@dataclass
class QueueItem(NamedQueueTuple):
"""
An item awaiting processing in the queue: a NamedQueueTuple with a future that is completed when the item is done
processing.
Attributes:
completed (Optional[Future[TaskInvocation | dict]]): A future resolving to a task invocation (matching the
signature of the task_done method) or a dictionary of outputs
"""
completed: asyncio.Future[TaskInvocation | dict] | None
def __lt__(self, other: QueueItem):
return self.queue_tuple[0] < other.queue_tuple[0]
class BinaryEventTypes(Enum):
PREVIEW_IMAGE = 1
UNENCODED_PREVIEW_IMAGE = 2
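A sketch of the split: `NamedQueueTuple` names the positional fields of the raw queue tuple, and `QueueItem` layers a completion future on top. The tuple layout `(priority, prompt_id, prompt, extra_data, outputs_to_execute)` follows the usage in the tests below:

```python
import asyncio

from comfy.component_model.queue_types import QueueItem

loop = asyncio.new_event_loop()
first = QueueItem((0, "prompt-a", {}, {}, []), loop.create_future())
second = QueueItem((1, "prompt-b", {}, {}, []), loop.create_future())
assert first < second  # ordering compares the priority element of the queue tuple
```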

View File

@@ -0,0 +1,32 @@
import asyncio
from asyncio import AbstractEventLoop
from dataclasses import asdict
from typing import Optional
from aio_pika import connect_robust
from aio_pika.patterns import RPC
from comfy.distributed.distributed_types import RpcRequest, RpcReply
class DistributedPromptClient:
def __init__(self, queue_name: str = "comfyui",
connection_uri="amqp://localhost/",
loop: Optional[AbstractEventLoop] = None):
self.queue_name = queue_name
self.connection_uri = connection_uri
self.loop = loop or asyncio.get_event_loop()
async def __aenter__(self):
self.connection = await connect_robust(self.connection_uri, loop=self.loop)
self.channel = await self.connection.channel()
self.rpc = await RPC.create(channel=self.channel)
self.rpc.host_exceptions = True
# __aenter__ must return self so `async with ... as client:` binds the client
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.channel.close()
await self.rpc.close()
await self.connection.close()
async def queue_prompt(self, request: RpcRequest) -> RpcReply:
return RpcReply(**(await self.rpc.call(self.queue_name, {"request": asdict(request)})))
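A hedged usage sketch; the module path `comfy.distributed.distributed_prompt_client` is assumed from the package layout, and the broker URL is a placeholder:

```python
from comfy.distributed.distributed_prompt_client import DistributedPromptClient  # assumed module path
from comfy.distributed.distributed_types import RpcRequest, RpcReply

async def submit(request: RpcRequest) -> RpcReply:
    async with DistributedPromptClient(connection_uri="amqp://guest:guest@localhost:5672/") as client:
        # performs the RPC round trip and rebuilds the reply dataclass
        return await client.queue_prompt(request)
```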

View File

@@ -1,13 +1,17 @@
from __future__ import annotations
import asyncio
from asyncio import AbstractEventLoop
import uuid
from asyncio import AbstractEventLoop, Queue
from dataclasses import asdict
from typing import Optional, Dict, List, Mapping, Tuple, Callable
import jwt
from aio_pika import connect_robust
from aio_pika.abc import AbstractConnection, AbstractChannel
from aio_pika.patterns import RPC
from aio_pika.patterns import JsonRPC
from .distributed_types import RpcRequest, RpcReply
from ..component_model.abstract_prompt_queue import AbstractPromptQueue
from ..component_model.executor_types import ExecutorToClientProgress
from ..component_model.queue_types import Flags, HistoryEntry, QueueTuple, QueueItem, ExecutionStatus, TaskInvocation
@@ -29,11 +33,30 @@ class DistributedPromptQueue(AbstractPromptQueue):
async def put_async(self, queue_item: QueueItem):
assert self.is_caller
if self._closing:
return
self.caller_local_in_progress[queue_item.prompt_id] = queue_item
if self.caller_server is not None:
self.caller_server.queue_updated()
try:
res: TaskInvocation = await self.rpc.call(self.queue_name, {"item": queue_item.queue_tuple})
if "token" in queue_item.extra_data:
user_token = queue_item.extra_data["token"]
else:
if "client_id" in queue_item.extra_data:
client_id = queue_item.extra_data["client_id"]
elif self.caller_server.client_id is not None:
client_id = self.caller_server.client_id
else:
client_id = str(uuid.uuid4())
# todo: should we really do this?
self.caller_server.client_id = client_id
# create a stub token
user_token = jwt.encode({"sub": client_id}, key="", algorithm="none")
request = RpcRequest(prompt_id=queue_item.prompt_id, user_token=user_token, prompt=queue_item.prompt)
assert self.rpc is not None
res: TaskInvocation = RpcReply(
**(await self.rpc.call(self.queue_name, {"request": asdict(request)}))).as_task_invocation()
self.caller_history.put(queue_item, res.outputs, res.status)
if self.caller_server is not None:
@@ -63,12 +86,16 @@ class DistributedPromptQueue(AbstractPromptQueue):
def put(self, item: QueueItem):
# caller: execute on main thread
assert self.is_caller
if self._closing:
return
# this is called by the web server and its event loop is perfectly fine to use
# the future is now ignored
asyncio.run_coroutine_threadsafe(self.put_async(item), self.loop)
async def _callee_do_work_item(self, item: QueueTuple) -> TaskInvocation:
async def _callee_do_work_item(self, request: dict) -> dict:
assert self.is_callee
request_obj = RpcRequest.from_dict(request)
item = request_obj.as_queue_tuple().queue_tuple
item_with_completer = QueueItem(item, self.loop.create_future())
self.callee_local_in_progress[item_with_completer.prompt_id] = item_with_completer
# todo: check if we have the local model content needed to execute this request and if not, reject it
@@ -80,9 +107,10 @@ class DistributedPromptQueue(AbstractPromptQueue):
assert not item_with_completer.completed.done()
# now we wait for the worker thread to complete the item
return await item_with_completer.completed
invocation = await item_with_completer.completed
return asdict(RpcReply.from_task_invocation(invocation, request_obj.user_token))
def get(self, timeout: float | None = None) -> Optional[Tuple[QueueTuple, int]]:
def get(self, timeout: float | None = None) -> Optional[Tuple[QueueTuple, str | int]]:
# callee: executed on the worker thread
assert self.is_callee
try:
@@ -175,6 +203,7 @@ class DistributedPromptQueue(AbstractPromptQueue):
self.channel: Optional[AbstractChannel] = None # Channel will be set up asynchronously
self.is_caller = is_caller
self.is_callee = is_callee
self._closing = False
# as rpc caller
self.caller_server = server
@@ -182,9 +211,9 @@
self.caller_history: History = History()
# as rpc callee
self.callee_local_queue = asyncio.Queue()
self.callee_local_queue: Queue = Queue()
self.callee_local_in_progress: Dict[int | str, QueueItem] = {}
self.rpc: Optional[RPC] = None
self.rpc: Optional[JsonRPC] = None
# todo: the prompt queue really shouldn't do this
if server is not None:
@@ -193,8 +222,14 @@
async def init(self):
self.connection = await connect_robust(self.connection_uri, loop=self.loop)
self.channel = await self.connection.channel()
self.rpc = await RPC.create(channel=self.channel)
self.rpc = await JsonRPC.create(channel=self.channel)
self.rpc.host_exceptions = True
# this makes the queue available to complete work items
if self.is_callee:
await self.rpc.register(self.queue_name, self._callee_do_work_item)
async def close(self):
self._closing = True
await self.rpc.close()
await self.channel.close()
await self.connection.close()
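Caller-side setup mirrors the integration test below: construct the queue with `is_caller=True`, then `await init()` before submitting work (the broker URL is a placeholder):

```python
from comfy.client.embedded_comfy_client import ServerStub
from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue

async def make_caller_queue() -> DistributedPromptQueue:
    queue = DistributedPromptQueue(ServerStub(), is_caller=True, is_callee=False,
                                   connection_uri="amqp://guest:guest@localhost:5672/")
    await queue.init()  # connects, opens a channel and registers the JsonRPC pattern
    return queue
```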

View File

@@ -1,14 +1,15 @@
import asyncio
from asyncio import AbstractEventLoop
from contextlib import AsyncExitStack
from dataclasses import asdict
from typing import Optional
from aio_pika import connect_robust
from aio_pika.patterns import RPC
from aio_pika.patterns import JsonRPC
from ..api.components.schema.prompt import Prompt
from ..cli_args_types import Configuration
from .distributed_types import RpcRequest, RpcReply
from ..client.embedded_comfy_client import EmbeddedComfyClient
from ..component_model.queue_types import TaskInvocation, QueueTuple, QueueItem, ExecutionStatus
from ..component_model.queue_types import ExecutionStatus
class DistributedPromptWorker:
@@ -16,31 +17,42 @@
A work-in-progress distributed prompt worker.
"""
def __init__(self, embedded_comfy_client: EmbeddedComfyClient,
def __init__(self, embedded_comfy_client: Optional[EmbeddedComfyClient] = None,
connection_uri: str = "amqp://localhost:5672/",
queue_name: str = "comfyui",
loop: Optional[AbstractEventLoop] = None, configuration: Configuration = None):
loop: Optional[AbstractEventLoop] = None):
self._exit_stack = AsyncExitStack()
self._queue_name = queue_name
self._configuration = configuration
self._connection_uri = connection_uri
self._loop = loop or asyncio.get_event_loop()
self._embedded_comfy_client = embedded_comfy_client
self._embedded_comfy_client = embedded_comfy_client or EmbeddedComfyClient()
async def _do_work_item(self, item: QueueTuple) -> TaskInvocation:
item_without_completer = QueueItem(item, completed=None)
async def _do_work_item(self, request: dict) -> dict:
try:
output_dict = await self._embedded_comfy_client.queue_prompt(Prompt.validate(item_without_completer.prompt))
return TaskInvocation(item_without_completer.prompt_id, outputs=output_dict,
status=ExecutionStatus("success", True, []))
request_obj = RpcRequest.from_dict(request)
except Exception as e:
return TaskInvocation(item_without_completer.prompt_id, outputs={},
status=ExecutionStatus("error", False, [str(e)]))
request_dict_prompt_id_recovered = request["prompt_id"] \
if request is not None and "prompt_id" in request else ""
return asdict(RpcReply(request_dict_prompt_id_recovered, "", {},
ExecutionStatus("error", False, [str(e)])))
try:
output_dict = await self._embedded_comfy_client.queue_prompt(request_obj.prompt,
request_obj.prompt_id,
client_id=request_obj.user_id)
return asdict(RpcReply(request_obj.prompt_id, request_obj.user_token, output_dict, ExecutionStatus("success", True, [])))
except Exception as e:
return asdict(RpcReply(request_obj.prompt_id, request_obj.user_token, {}, ExecutionStatus("error", False, [str(e)])))
async def __aenter__(self) -> "DistributedPromptWorker":
await self._exit_stack.__aenter__()
if not self._embedded_comfy_client.is_running:
await self._exit_stack.enter_async_context(self._embedded_comfy_client)
self._connection = await connect_robust(self._connection_uri, loop=self._loop)
self._channel = await self._connection.channel()
self._rpc = await RPC.create(channel=self._channel)
self._rpc = await JsonRPC.create(channel=self._channel)
self._rpc.host_exceptions = True
await self._rpc.register(self._queue_name, self._do_work_item)
return self
@@ -48,3 +60,4 @@ class DistributedPromptWorker:
await self._rpc.close()
await self._channel.close()
await self._connection.close()
return await self._exit_stack.__aexit__(*args)
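A sketch of running a standalone worker; when no client is passed in, the worker now creates and manages its own `EmbeddedComfyClient` through the exit stack:

```python
import asyncio

from comfy.distributed.distributed_prompt_worker import DistributedPromptWorker

async def main() -> None:
    async with DistributedPromptWorker(connection_uri="amqp://guest:guest@localhost:5672/"):
        # the worker has registered _do_work_item on the queue; park until cancelled
        await asyncio.Event().wait()

# asyncio.run(main())
```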

View File

@@ -0,0 +1,58 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Tuple, Literal, List
import jwt
from ..api.components.schema.prompt import PromptDict, Prompt
from ..component_model.queue_types import NamedQueueTuple, TaskInvocation, ExecutionStatus
@dataclass
class DistributedBase:
prompt_id: str
user_token: str
@property
def user_id(self) -> str:
return self.decoded_token["sub"]
@property
def decoded_token(self) -> dict:
return jwt.decode(self.user_token, algorithms=['HS256', "none"],
# todo: this should be configurable
options={"verify_signature": False, 'verify_aud': False, 'verify_iss': False})
@dataclass
class RpcRequest(DistributedBase):
prompt: dict | PromptDict
def as_queue_tuple(self) -> NamedQueueTuple:
# this loads the nodes in this instance
# should always be okay to call in an executor
from ..cmd.execution import validate_prompt
from ..component_model.make_mutable import make_mutable
mutated_prompt_dict = make_mutable(self.prompt)
validation_tuple = validate_prompt(mutated_prompt_dict)
return NamedQueueTuple(queue_tuple=(0, self.prompt_id, mutated_prompt_dict, {}, validation_tuple[2]))
@classmethod
def from_dict(cls, request_dict):
request = RpcRequest(**request_dict)
request.prompt = Prompt.validate(request.prompt)
return request
@dataclass
class RpcReply(DistributedBase):
outputs: dict
execution_status: ExecutionStatus | Tuple[Literal['success', 'error'], bool, List[str]]
@staticmethod
def from_task_invocation(task_invocation: TaskInvocation, user_token: str) -> 'RpcReply':
return RpcReply(str(task_invocation.item_id), user_token, task_invocation.outputs, task_invocation.status)
def as_task_invocation(self):
return TaskInvocation(self.prompt_id, self.outputs, ExecutionStatus(*self.execution_status))
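A sketch of the stub-token flow matching how `put_async` mints unsigned tokens above; `some_prompt` is a hypothetical workflow dict:

```python
import uuid

import jwt

from comfy.distributed.distributed_types import RpcRequest

client_id = str(uuid.uuid4())
# an unsigned ("none" algorithm) stub token carrying the client id as its subject
token = jwt.encode({"sub": client_id}, key="", algorithm="none")
request = RpcRequest(prompt_id=str(uuid.uuid4()), user_token=token, prompt=some_prompt)  # some_prompt is hypothetical
assert request.user_id == client_id  # user_id decodes the token without verifying the signature
```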

View File

@@ -15,7 +15,7 @@ class History:
def put(self, queue_item: QueueItem, outputs: dict, status: ExecutionStatus):
self.history[queue_item.prompt_id] = HistoryEntry(prompt=queue_item.queue_tuple,
outputs=outputs,
status=status._asdict())
status=ExecutionStatus(*status)._asdict())
def copy(self, prompt_id: Optional[str | int] = None, max_items: int = MAXIMUM_HISTORY_SIZE,
offset: int = 0) -> Dict[str, HistoryEntry]:

View File

@@ -28,4 +28,5 @@ tqdm
protobuf==3.20.3
psutil
ConfigArgParse
aio-pika
pyjwt[crypto]

View File

@@ -1,6 +1,7 @@
import os
import uuid
import jwt
import pytest
from comfy.client.embedded_comfy_client import EmbeddedComfyClient, ServerStub
@@ -8,32 +9,41 @@ from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner
from comfy.component_model.make_mutable import make_mutable
from comfy.component_model.queue_types import QueueItem, QueueTuple, TaskInvocation
from comfy.distributed.distributed_prompt_worker import DistributedPromptWorker
from testcontainers.rabbitmq import RabbitMqContainer
# fixes issues with running the testcontainers RabbitMqContainer on Windows
os.environ["TC_HOST"] = "localhost"
@pytest.mark.asyncio
async def test_sign_jwt_auth_none():
client_id = str(uuid.uuid4())
user_token_str = jwt.encode({"sub": client_id}, None, algorithm="none")
user_token = jwt.decode(user_token_str, None, algorithms=["none"], options={"verify_signature": False})
assert user_token["sub"] == client_id
@pytest.mark.asyncio
async def test_basic_queue_worker() -> None:
os.environ["TC_HOST"] = "localhost"
# there are lots of side effects from importing that we have to deal with
from testcontainers.rabbitmq import RabbitMqContainer
with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
params = rabbitmq.get_connection_params()
async with EmbeddedComfyClient() as client:
async with DistributedPromptWorker(client,
connection_uri=f"amqp://guest:guest@127.0.0.1:{params.port}") as worker:
# this unfortunately does a bunch of initialization on the test thread
from comfy.cmd.execution import validate_prompt
from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
# now submit some jobs
distributed_queue = DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True,
connection_uri=f"amqp://guest:guest@127.0.0.1:{params.port}")
await distributed_queue.init()
prompt = make_mutable(sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1))
validation_tuple = validate_prompt(prompt)
item_id = str(uuid.uuid4())
queue_tuple: QueueTuple = (0, item_id, prompt, {}, validation_tuple[2])
res: TaskInvocation = await distributed_queue.put_async(QueueItem(queue_tuple, None))
assert res.item_id == item_id
assert len(res.outputs) == 1
assert res.status is not None
assert res.status.status_str == "success"
async with DistributedPromptWorker(connection_uri=f"amqp://guest:guest@127.0.0.1:{params.port}") as worker:
# this unfortunately does a bunch of initialization on the test thread
from comfy.cmd.execution import validate_prompt
from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
# now submit some jobs
distributed_queue = DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True,
connection_uri=f"amqp://guest:guest@127.0.0.1:{params.port}")
await distributed_queue.init()
prompt = make_mutable(sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1))
validation_tuple = validate_prompt(prompt)
item_id = str(uuid.uuid4())
queue_tuple: QueueTuple = (0, item_id, prompt, {}, validation_tuple[2])
res: TaskInvocation = await distributed_queue.put_async(QueueItem(queue_tuple, None))
assert res.item_id == item_id
assert len(res.outputs) == 1
assert res.status is not None
assert res.status.status_str == "success"
await distributed_queue.close()