mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 14:20:49 +08:00
Improved support for distributed queues
This commit is contained in:
parent
3367362cec
commit
92898b8c9d
@ -186,7 +186,7 @@ On macOS, install exactly Python 3.11 using `brew`, which you can download from
|
|||||||
You can use `comfyui` as an API. Visit the [OpenAPI specification](comfy/api/openapi.yaml). This file can be used to generate typed clients for your preferred language.
|
You can use `comfyui` as an API. Visit the [OpenAPI specification](comfy/api/openapi.yaml). This file can be used to generate typed clients for your preferred language.
|
||||||
7. To create the standalone binary:
|
7. To create the standalone binary:
|
||||||
```shell
|
```shell
|
||||||
python -m pyinstaller --onefile --noupx -n ComfyUI --add-data="comfy/;comfy/" --paths $(pwd) --paths comfy/cmd main.py
|
python -m PyInstaller --onefile --noupx -n ComfyUI --add-data="comfy/;comfy/" --paths $(pwd) --paths comfy/cmd main.py
|
||||||
```
|
```
|
||||||
|
|
||||||
### Authoring Custom Nodes
|
### Authoring Custom Nodes
|
||||||
|
|||||||
@ -119,6 +119,17 @@ parser.add_argument("--plausible-analytics-domain", required=False,
|
|||||||
help="Specifies the domain name for analytics events.")
|
help="Specifies the domain name for analytics events.")
|
||||||
parser.add_argument("--analytics-use-identity-provider", action="store_true",
|
parser.add_argument("--analytics-use-identity-provider", action="store_true",
|
||||||
help="Uses platform identifiers for unique visitor analytics.")
|
help="Uses platform identifiers for unique visitor analytics.")
|
||||||
|
parser.add_argument("--distributed-queue-connection-uri", type=str, default=None,
|
||||||
|
help="Servers and clients will connect to this AMPQ URL to form a distributed queue and exchange prompt execution requests and progress updates.")
|
||||||
|
parser.add_argument(
|
||||||
|
'--distributed-queue-roles',
|
||||||
|
action='append',
|
||||||
|
choices=['worker', 'prompter'],
|
||||||
|
help='Specifies one or more roles for the distributed queue. Acceptable values are "worker" or "prompter", or both by writing the flag twice with each role. Prompters will start the web UI and connect to the provided AMPQ URL to submit prompts; workers will pull requests off the AMPQ URL.'
|
||||||
|
)
|
||||||
|
parser.add_argument("--distributed-queue-name", type=str, default="comfyui",
|
||||||
|
help="This name will be used by the prompters and workers to exchange prompt requests and replies. Progress updates will be prefixed by the queue name, followed by a '.', then the user ID")
|
||||||
|
|
||||||
|
|
||||||
if options.args_parsing:
|
if options.args_parsing:
|
||||||
args, _ = parser.parse_known_args()
|
args, _ = parser.parse_known_args()
|
||||||
|
|||||||
@ -70,6 +70,9 @@ class Configuration(dict):
|
|||||||
analytics_use_identity_provider (bool): Use platform identifiers for analytics.
|
analytics_use_identity_provider (bool): Use platform identifiers for analytics.
|
||||||
write_out_config_file (bool): Enable writing out the configuration file.
|
write_out_config_file (bool): Enable writing out the configuration file.
|
||||||
create_directories (bool): Creates the default models/, input/, output/ and temp/ directories, then exits.
|
create_directories (bool): Creates the default models/, input/, output/ and temp/ directories, then exits.
|
||||||
|
distributed_queue_connection_uri (Optional[str]): Servers and clients will connect to this AMQP URL to form a distributed queue and exchange prompt execution requests and progress updates.
|
||||||
|
distributed_queue_roles (List[str]): Specifies one or more roles for the distributed queue. Acceptable values are "worker" or "prompter", or both by writing the flag twice with each role. Prompters will start the web UI and connect to the provided AMQP URL to submit prompts; workers will pull requests off the AMQP URL.
|
||||||
|
distributed_queue_name (str): This name will be used by the prompters and workers to exchange prompt requests and replies. Progress updates will be prefixed by the queue name, followed by a '.', then the user ID.
|
||||||
"""
|
"""
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -127,6 +130,9 @@ class Configuration(dict):
|
|||||||
self.analytics_use_identity_provider: bool = False
|
self.analytics_use_identity_provider: bool = False
|
||||||
self.write_out_config_file: bool = False
|
self.write_out_config_file: bool = False
|
||||||
self.create_directories: bool = False
|
self.create_directories: bool = False
|
||||||
|
self.distributed_queue_connection_uri: Optional[str] = None
|
||||||
|
self.distributed_queue_roles: List[str] = []
|
||||||
|
self.distributed_queue_name: str = "comfyui"
|
||||||
for key, value in kwargs.items():
|
for key, value in kwargs.items():
|
||||||
self[key] = value
|
self[key] = value
|
||||||
|
|
||||||
|
|||||||
@ -18,6 +18,7 @@ class ServerStub(ExecutorToClientProgress):
|
|||||||
"""
|
"""
|
||||||
This class is a stub implementation of ExecutorToClientProgress. This will handle progress events.
|
This class is a stub implementation of ExecutorToClientProgress. This will handle progress events.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.client_id = str(uuid.uuid4())
|
self.client_id = str(uuid.uuid4())
|
||||||
self.last_node_id = None
|
self.last_node_id = None
|
||||||
@ -69,18 +70,25 @@ class EmbeddedComfyClient:
|
|||||||
In order to use this in blocking methods, learn more about asyncio online.
|
In order to use this in blocking methods, learn more about asyncio online.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, configuration: Optional[Configuration] = None, loop: Optional[AbstractEventLoop] = None,
|
def __init__(self, configuration: Optional[Configuration] = None,
|
||||||
|
progress_handler: Optional[ExecutorToClientProgress] = None,
|
||||||
|
loop: Optional[AbstractEventLoop] = None,
|
||||||
max_workers: int = 1):
|
max_workers: int = 1):
|
||||||
self._server_stub = ServerStub()
|
self._progress_handler = progress_handler or ServerStub()
|
||||||
self._executor = ThreadPoolExecutor(max_workers=max_workers)
|
self._executor = ThreadPoolExecutor(max_workers=max_workers)
|
||||||
self._loop = loop or asyncio.get_event_loop()
|
self._loop = loop or asyncio.get_event_loop()
|
||||||
self._configuration = configuration
|
self._configuration = configuration
|
||||||
# we don't want to import the executor yet
|
# we don't want to import the executor yet
|
||||||
self._prompt_executor: Optional["comfy.cmd.execution.PromptExecutor"] = None
|
self._prompt_executor: Optional["comfy.cmd.execution.PromptExecutor"] = None
|
||||||
|
self._is_running = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_running(self) -> bool:
|
||||||
|
return self._is_running
|
||||||
|
|
||||||
async def __aenter__(self):
|
async def __aenter__(self):
|
||||||
# Perform asynchronous initialization here, if needed
|
|
||||||
await self._initialize_prompt_executor()
|
await self._initialize_prompt_executor()
|
||||||
|
self._is_running = True
|
||||||
return self
|
return self
|
||||||
|
|
||||||
async def __aexit__(self, *args):
|
async def __aexit__(self, *args):
|
||||||
@ -101,6 +109,7 @@ class EmbeddedComfyClient:
|
|||||||
await self._loop.run_in_executor(self._executor, cleanup)
|
await self._loop.run_in_executor(self._executor, cleanup)
|
||||||
|
|
||||||
self._executor.shutdown(wait=True)
|
self._executor.shutdown(wait=True)
|
||||||
|
self._is_running = False
|
||||||
|
|
||||||
async def _initialize_prompt_executor(self):
|
async def _initialize_prompt_executor(self):
|
||||||
# This method must be async since it's used in __aenter__
|
# This method must be async since it's used in __aenter__
|
||||||
@ -115,19 +124,23 @@ class EmbeddedComfyClient:
|
|||||||
|
|
||||||
from ..cmd.execution import PromptExecutor
|
from ..cmd.execution import PromptExecutor
|
||||||
|
|
||||||
self._prompt_executor = PromptExecutor(self._server_stub)
|
self._prompt_executor = PromptExecutor(self._progress_handler)
|
||||||
|
|
||||||
await self._loop.run_in_executor(self._executor, create_executor_in_thread)
|
await self._loop.run_in_executor(self._executor, create_executor_in_thread)
|
||||||
|
|
||||||
async def queue_prompt(self, prompt: PromptDict) -> dict:
|
async def queue_prompt(self,
|
||||||
prompt_id = str(uuid.uuid4())
|
prompt: PromptDict | dict,
|
||||||
|
prompt_id: Optional[str] = None,
|
||||||
|
client_id: Optional[str] = None) -> dict:
|
||||||
|
prompt_id = prompt_id or str(uuid.uuid4())
|
||||||
|
client_id = client_id or self._progress_handler.client_id or None
|
||||||
|
|
||||||
def execute_prompt() -> dict:
|
def execute_prompt() -> dict:
|
||||||
from ..cmd.execution import validate_prompt
|
from ..cmd.execution import validate_prompt
|
||||||
prompt_mut = make_mutable(prompt)
|
prompt_mut = make_mutable(prompt)
|
||||||
validation_tuple = validate_prompt(prompt_mut)
|
validation_tuple = validate_prompt(prompt_mut)
|
||||||
|
|
||||||
self._prompt_executor.execute(prompt_mut, prompt_id, {"client_id": self._server_stub.client_id},
|
self._prompt_executor.execute(prompt_mut, prompt_id, {"client_id": client_id},
|
||||||
execute_outputs=validation_tuple[2])
|
execute_outputs=validation_tuple[2])
|
||||||
if self._prompt_executor.success:
|
if self._prompt_executor.success:
|
||||||
return self._prompt_executor.outputs_ui
|
return self._prompt_executor.outputs_ui
|
||||||
|
|||||||
@ -807,7 +807,7 @@ class PromptQueue(AbstractPromptQueue):
|
|||||||
|
|
||||||
status_dict: Optional[dict] = None
|
status_dict: Optional[dict] = None
|
||||||
if status is not None:
|
if status is not None:
|
||||||
status_dict = copy.deepcopy(status._asdict())
|
status_dict = copy.deepcopy(ExecutionStatus(*status)._asdict())
|
||||||
|
|
||||||
self.history[prompt[1]] = {
|
self.history[prompt[1]] = {
|
||||||
"prompt": prompt,
|
"prompt": prompt,
|
||||||
|
|||||||
@ -1,12 +1,9 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Mapping, Any
|
from typing import Mapping, Any
|
||||||
|
|
||||||
|
|
||||||
def make_mutable(obj: Any) -> dict:
|
def _make_mutable(obj: Any) -> Any:
|
||||||
"""
|
|
||||||
Makes an immutable dict, frozenset or tuple mutable. Otherwise, returns the value.
|
|
||||||
:param obj: any object
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
if isinstance(obj, Mapping) and not isinstance(obj, dict) and not hasattr(obj, "__setitem__"):
|
if isinstance(obj, Mapping) and not isinstance(obj, dict) and not hasattr(obj, "__setitem__"):
|
||||||
obj = dict(obj)
|
obj = dict(obj)
|
||||||
for key, value in obj.items():
|
for key, value in obj.items():
|
||||||
@ -18,3 +15,17 @@ def make_mutable(obj: Any) -> dict:
|
|||||||
if isinstance(obj, frozenset):
|
if isinstance(obj, frozenset):
|
||||||
obj = set([make_mutable(x) for x in obj])
|
obj = set([make_mutable(x) for x in obj])
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
|
||||||
|
def make_mutable(obj: Mapping) -> dict:
|
||||||
|
"""
|
||||||
|
Makes a copy of an immutable dict into a mutable dict.
|
||||||
|
|
||||||
|
If the object is already a mutable type or a value type like a string or integer, returns the value.
|
||||||
|
|
||||||
|
Returns dict, set or tuple depending on its input, but you should not use it this way.
|
||||||
|
|
||||||
|
:param obj: any object
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return _make_mutable(obj)
|
||||||
|
|||||||
@ -40,21 +40,21 @@ class HistoryEntry(TypedDict):
|
|||||||
status: NotRequired[ExecutionStatusAsDict]
|
status: NotRequired[ExecutionStatusAsDict]
|
||||||
|
|
||||||
|
|
||||||
|
class ExtraData(TypedDict):
|
||||||
|
client_id: NotRequired[str]
|
||||||
|
extra_pnginfo: NotRequired[str]
|
||||||
|
token: NotRequired[str]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class QueueItem:
|
class NamedQueueTuple:
|
||||||
"""
|
"""
|
||||||
An item awaiting processing in the queue
|
A wrapper class for a queue tuple, the object that is given to executors.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
queue_tuple (QueueTuple): the corresponding queued workflow and other related data
|
queue_tuple (QueueTuple): the corresponding queued workflow and other related data
|
||||||
completed (Optional[Future[TaskInvocation | dict]]): A future of a task invocation (the signature of the task_done method)
|
|
||||||
or a dictionary of outputs
|
|
||||||
"""
|
"""
|
||||||
queue_tuple: QueueTuple
|
queue_tuple: QueueTuple
|
||||||
completed: asyncio.Future[TaskInvocation | dict] | None
|
|
||||||
|
|
||||||
def __lt__(self, other: QueueItem):
|
|
||||||
return self.queue_tuple[0] < other.queue_tuple[0]
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def priority(self) -> float:
|
def priority(self) -> float:
|
||||||
@ -69,7 +69,7 @@ class QueueItem:
|
|||||||
return self.queue_tuple[2]
|
return self.queue_tuple[2]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def extra_data(self) -> Optional[dict]:
|
def extra_data(self) -> Optional[ExtraData]:
|
||||||
if len(self.queue_tuple) > 2:
|
if len(self.queue_tuple) > 2:
|
||||||
return self.queue_tuple[3]
|
return self.queue_tuple[3]
|
||||||
else:
|
else:
|
||||||
@ -83,6 +83,22 @@ class QueueItem:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class QueueItem(NamedQueueTuple):
|
||||||
|
"""
|
||||||
|
An item awaiting processing in the queue: a NamedQueueTuple with a future that is completed when the item is done
|
||||||
|
processing.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
completed (Optional[Future[TaskInvocation | dict]]): A future of a task invocation (the signature of the task_done method)
|
||||||
|
or a dictionary of outputs
|
||||||
|
"""
|
||||||
|
completed: asyncio.Future[TaskInvocation | dict] | None
|
||||||
|
|
||||||
|
def __lt__(self, other: QueueItem):
|
||||||
|
return self.queue_tuple[0] < other.queue_tuple[0]
|
||||||
|
|
||||||
|
|
||||||
class BinaryEventTypes(Enum):
|
class BinaryEventTypes(Enum):
|
||||||
PREVIEW_IMAGE = 1
|
PREVIEW_IMAGE = 1
|
||||||
UNENCODED_PREVIEW_IMAGE = 2
|
UNENCODED_PREVIEW_IMAGE = 2
|
||||||
|
|||||||
32
comfy/distributed/distributed_prompt_client.py
Normal file
32
comfy/distributed/distributed_prompt_client.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
import asyncio
|
||||||
|
from asyncio import AbstractEventLoop
|
||||||
|
from dataclasses import asdict
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from aio_pika import connect_robust
|
||||||
|
from aio_pika.patterns import RPC
|
||||||
|
|
||||||
|
from comfy.distributed.distributed_types import RpcRequest, RpcReply
|
||||||
|
|
||||||
|
|
||||||
|
class DistributedPromptClient:
|
||||||
|
def __init__(self, queue_name: str = "comfyui",
|
||||||
|
connection_uri="amqp://localhost/",
|
||||||
|
loop: Optional[AbstractEventLoop] = None):
|
||||||
|
self.queue_name = queue_name
|
||||||
|
self.connection_uri = connection_uri
|
||||||
|
self.loop = loop or asyncio.get_event_loop()
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
self.connection = await connect_robust(self.connection_uri, loop=self.loop)
|
||||||
|
self.channel = await self.connection.channel()
|
||||||
|
self.rpc = await RPC.create(channel=self.channel)
|
||||||
|
self.rpc.host_exceptions = True
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
await self.channel.close()
|
||||||
|
await self.rpc.close()
|
||||||
|
await self.connection.close()
|
||||||
|
|
||||||
|
async def queue_prompt(self, request: RpcRequest) -> RpcReply:
|
||||||
|
return RpcReply(**(await self.rpc.call(self.queue_name, {"request": asdict(request)})))
|
||||||
@ -1,13 +1,17 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from asyncio import AbstractEventLoop
|
import uuid
|
||||||
|
from asyncio import AbstractEventLoop, Queue
|
||||||
|
from dataclasses import asdict
|
||||||
from typing import Optional, Dict, List, Mapping, Tuple, Callable
|
from typing import Optional, Dict, List, Mapping, Tuple, Callable
|
||||||
|
|
||||||
|
import jwt
|
||||||
from aio_pika import connect_robust
|
from aio_pika import connect_robust
|
||||||
from aio_pika.abc import AbstractConnection, AbstractChannel
|
from aio_pika.abc import AbstractConnection, AbstractChannel
|
||||||
from aio_pika.patterns import RPC
|
from aio_pika.patterns import JsonRPC
|
||||||
|
|
||||||
|
from .distributed_types import RpcRequest, RpcReply
|
||||||
from ..component_model.abstract_prompt_queue import AbstractPromptQueue
|
from ..component_model.abstract_prompt_queue import AbstractPromptQueue
|
||||||
from ..component_model.executor_types import ExecutorToClientProgress
|
from ..component_model.executor_types import ExecutorToClientProgress
|
||||||
from ..component_model.queue_types import Flags, HistoryEntry, QueueTuple, QueueItem, ExecutionStatus, TaskInvocation
|
from ..component_model.queue_types import Flags, HistoryEntry, QueueTuple, QueueItem, ExecutionStatus, TaskInvocation
|
||||||
@ -29,11 +33,30 @@ class DistributedPromptQueue(AbstractPromptQueue):
|
|||||||
|
|
||||||
async def put_async(self, queue_item: QueueItem):
|
async def put_async(self, queue_item: QueueItem):
|
||||||
assert self.is_caller
|
assert self.is_caller
|
||||||
|
if self._closing:
|
||||||
|
return
|
||||||
self.caller_local_in_progress[queue_item.prompt_id] = queue_item
|
self.caller_local_in_progress[queue_item.prompt_id] = queue_item
|
||||||
if self.caller_server is not None:
|
if self.caller_server is not None:
|
||||||
self.caller_server.queue_updated()
|
self.caller_server.queue_updated()
|
||||||
try:
|
try:
|
||||||
res: TaskInvocation = await self.rpc.call(self.queue_name, {"item": queue_item.queue_tuple})
|
if "token" in queue_item.extra_data:
|
||||||
|
user_token = queue_item.extra_data["token"]
|
||||||
|
else:
|
||||||
|
if "client_id" in queue_item.extra_data:
|
||||||
|
client_id = queue_item.extra_data["client_id"]
|
||||||
|
elif self.caller_server.client_id is not None:
|
||||||
|
client_id = self.caller_server.client_id
|
||||||
|
else:
|
||||||
|
client_id = str(uuid.uuid4())
|
||||||
|
# todo: should we really do this?
|
||||||
|
self.caller_server.client_id = client_id
|
||||||
|
|
||||||
|
# create a stub token
|
||||||
|
user_token = jwt.encode({"sub": client_id}, key="", algorithm="none")
|
||||||
|
request = RpcRequest(prompt_id=queue_item.prompt_id, user_token=user_token, prompt=queue_item.prompt)
|
||||||
|
assert self.rpc is not None
|
||||||
|
res: TaskInvocation = RpcReply(
|
||||||
|
**(await self.rpc.call(self.queue_name, {"request": asdict(request)}))).as_task_invocation()
|
||||||
|
|
||||||
self.caller_history.put(queue_item, res.outputs, res.status)
|
self.caller_history.put(queue_item, res.outputs, res.status)
|
||||||
if self.caller_server is not None:
|
if self.caller_server is not None:
|
||||||
@ -63,12 +86,16 @@ class DistributedPromptQueue(AbstractPromptQueue):
|
|||||||
def put(self, item: QueueItem):
|
def put(self, item: QueueItem):
|
||||||
# caller: execute on main thread
|
# caller: execute on main thread
|
||||||
assert self.is_caller
|
assert self.is_caller
|
||||||
|
if self._closing:
|
||||||
|
return
|
||||||
# this is called by the web server and its event loop is perfectly fine to use
|
# this is called by the web server and its event loop is perfectly fine to use
|
||||||
# the future is now ignored
|
# the future is now ignored
|
||||||
self.loop.call_soon_threadsafe(self.put_async, item)
|
self.loop.call_soon_threadsafe(self.put_async, item)
|
||||||
|
|
||||||
async def _callee_do_work_item(self, item: QueueTuple) -> TaskInvocation:
|
async def _callee_do_work_item(self, request: dict) -> dict:
|
||||||
assert self.is_callee
|
assert self.is_callee
|
||||||
|
request_obj = RpcRequest.from_dict(request)
|
||||||
|
item = request_obj.as_queue_tuple().queue_tuple
|
||||||
item_with_completer = QueueItem(item, self.loop.create_future())
|
item_with_completer = QueueItem(item, self.loop.create_future())
|
||||||
self.callee_local_in_progress[item_with_completer.prompt_id] = item_with_completer
|
self.callee_local_in_progress[item_with_completer.prompt_id] = item_with_completer
|
||||||
# todo: check if we have the local model content needed to execute this request and if not, reject it
|
# todo: check if we have the local model content needed to execute this request and if not, reject it
|
||||||
@ -80,9 +107,10 @@ class DistributedPromptQueue(AbstractPromptQueue):
|
|||||||
assert not item_with_completer.completed.done()
|
assert not item_with_completer.completed.done()
|
||||||
|
|
||||||
# now we wait for the worker thread to complete the item
|
# now we wait for the worker thread to complete the item
|
||||||
return await item_with_completer.completed
|
invocation = await item_with_completer.completed
|
||||||
|
return asdict(RpcReply.from_task_invocation(invocation, request_obj.user_token))
|
||||||
|
|
||||||
def get(self, timeout: float | None = None) -> Optional[Tuple[QueueTuple, int]]:
|
def get(self, timeout: float | None = None) -> Optional[Tuple[QueueTuple, str | int]]:
|
||||||
# callee: executed on the worker thread
|
# callee: executed on the worker thread
|
||||||
assert self.is_callee
|
assert self.is_callee
|
||||||
try:
|
try:
|
||||||
@ -175,6 +203,7 @@ class DistributedPromptQueue(AbstractPromptQueue):
|
|||||||
self.channel: Optional[AbstractChannel] = None # Channel will be set up asynchronously
|
self.channel: Optional[AbstractChannel] = None # Channel will be set up asynchronously
|
||||||
self.is_caller = is_caller
|
self.is_caller = is_caller
|
||||||
self.is_callee = is_callee
|
self.is_callee = is_callee
|
||||||
|
self._closing = False
|
||||||
|
|
||||||
# as rpc caller
|
# as rpc caller
|
||||||
self.caller_server = server
|
self.caller_server = server
|
||||||
@ -182,9 +211,9 @@ class DistributedPromptQueue(AbstractPromptQueue):
|
|||||||
self.caller_history: History = History()
|
self.caller_history: History = History()
|
||||||
|
|
||||||
# as rpc callee
|
# as rpc callee
|
||||||
self.callee_local_queue = asyncio.Queue()
|
self.callee_local_queue: Queue = Queue()
|
||||||
self.callee_local_in_progress: Dict[int | str, QueueItem] = {}
|
self.callee_local_in_progress: Dict[int | str, QueueItem] = {}
|
||||||
self.rpc: Optional[RPC] = None
|
self.rpc: Optional[JsonRPC] = None
|
||||||
|
|
||||||
# todo: the prompt queue really shouldn't do this
|
# todo: the prompt queue really shouldn't do this
|
||||||
if server is not None:
|
if server is not None:
|
||||||
@ -193,8 +222,14 @@ class DistributedPromptQueue(AbstractPromptQueue):
|
|||||||
async def init(self):
|
async def init(self):
|
||||||
self.connection = await connect_robust(self.connection_uri, loop=self.loop)
|
self.connection = await connect_robust(self.connection_uri, loop=self.loop)
|
||||||
self.channel = await self.connection.channel()
|
self.channel = await self.connection.channel()
|
||||||
self.rpc = await RPC.create(channel=self.channel)
|
self.rpc = await JsonRPC.create(channel=self.channel)
|
||||||
self.rpc.host_exceptions = True
|
self.rpc.host_exceptions = True
|
||||||
# this makes the queue available to complete work items
|
# this makes the queue available to complete work items
|
||||||
if self.is_callee:
|
if self.is_callee:
|
||||||
await self.rpc.register(self.queue_name, self._callee_do_work_item)
|
await self.rpc.register(self.queue_name, self._callee_do_work_item)
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
self._closing = True
|
||||||
|
await self.rpc.close()
|
||||||
|
await self.channel.close()
|
||||||
|
await self.connection.close()
|
||||||
|
|||||||
@ -1,14 +1,15 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from asyncio import AbstractEventLoop
|
from asyncio import AbstractEventLoop
|
||||||
|
from contextlib import AsyncExitStack
|
||||||
|
from dataclasses import asdict
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from aio_pika import connect_robust
|
from aio_pika import connect_robust
|
||||||
from aio_pika.patterns import RPC
|
from aio_pika.patterns import JsonRPC
|
||||||
|
|
||||||
from ..api.components.schema.prompt import Prompt
|
from .distributed_types import RpcRequest, RpcReply
|
||||||
from ..cli_args_types import Configuration
|
|
||||||
from ..client.embedded_comfy_client import EmbeddedComfyClient
|
from ..client.embedded_comfy_client import EmbeddedComfyClient
|
||||||
from ..component_model.queue_types import TaskInvocation, QueueTuple, QueueItem, ExecutionStatus
|
from ..component_model.queue_types import ExecutionStatus
|
||||||
|
|
||||||
|
|
||||||
class DistributedPromptWorker:
|
class DistributedPromptWorker:
|
||||||
@ -16,31 +17,42 @@ class DistributedPromptWorker:
|
|||||||
A work in progress distributed prompt worker.
|
A work in progress distributed prompt worker.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, embedded_comfy_client: EmbeddedComfyClient,
|
def __init__(self, embedded_comfy_client: Optional[EmbeddedComfyClient] = None,
|
||||||
connection_uri: str = "amqp://localhost:5672/",
|
connection_uri: str = "amqp://localhost:5672/",
|
||||||
queue_name: str = "comfyui",
|
queue_name: str = "comfyui",
|
||||||
loop: Optional[AbstractEventLoop] = None, configuration: Configuration = None):
|
loop: Optional[AbstractEventLoop] = None):
|
||||||
|
self._exit_stack = AsyncExitStack()
|
||||||
self._queue_name = queue_name
|
self._queue_name = queue_name
|
||||||
self._configuration = configuration
|
|
||||||
self._connection_uri = connection_uri
|
self._connection_uri = connection_uri
|
||||||
self._loop = loop or asyncio.get_event_loop()
|
self._loop = loop or asyncio.get_event_loop()
|
||||||
self._embedded_comfy_client = embedded_comfy_client
|
self._embedded_comfy_client = embedded_comfy_client or EmbeddedComfyClient()
|
||||||
|
|
||||||
async def _do_work_item(self, item: QueueTuple) -> TaskInvocation:
|
async def _do_work_item(self, request: dict) -> dict:
|
||||||
item_without_completer = QueueItem(item, completed=None)
|
|
||||||
try:
|
try:
|
||||||
output_dict = await self._embedded_comfy_client.queue_prompt(Prompt.validate(item_without_completer.prompt))
|
request_obj = RpcRequest.from_dict(request)
|
||||||
return TaskInvocation(item_without_completer.prompt_id, outputs=output_dict,
|
|
||||||
status=ExecutionStatus("success", True, []))
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return TaskInvocation(item_without_completer.prompt_id, outputs={},
|
request_dict_prompt_id_recovered = request["prompt_id"] \
|
||||||
status=ExecutionStatus("error", False, [str(e)]))
|
if request is not None and "prompt_id" in request else ""
|
||||||
|
return asdict(RpcReply(request_dict_prompt_id_recovered, "", {},
|
||||||
|
ExecutionStatus("error", False, [str(e)])))
|
||||||
|
try:
|
||||||
|
output_dict = await self._embedded_comfy_client.queue_prompt(request_obj.prompt,
|
||||||
|
request_obj.prompt_id,
|
||||||
|
client_id=request_obj.user_id)
|
||||||
|
return asdict(RpcReply(request_obj.prompt_id, request_obj.user_token, output_dict, ExecutionStatus("success", True, [])))
|
||||||
|
except Exception as e:
|
||||||
|
return asdict(RpcReply(request_obj.prompt_id, request_obj.user_token, {}, ExecutionStatus("error", False, [str(e)])))
|
||||||
|
|
||||||
async def __aenter__(self) -> "DistributedPromptWorker":
|
async def __aenter__(self) -> "DistributedPromptWorker":
|
||||||
|
await self._exit_stack.__aenter__()
|
||||||
|
if not self._embedded_comfy_client.is_running:
|
||||||
|
await self._exit_stack.enter_async_context(self._embedded_comfy_client)
|
||||||
|
|
||||||
self._connection = await connect_robust(self._connection_uri, loop=self._loop)
|
self._connection = await connect_robust(self._connection_uri, loop=self._loop)
|
||||||
self._channel = await self._connection.channel()
|
self._channel = await self._connection.channel()
|
||||||
self._rpc = await RPC.create(channel=self._channel)
|
self._rpc = await JsonRPC.create(channel=self._channel)
|
||||||
self._rpc.host_exceptions = True
|
self._rpc.host_exceptions = True
|
||||||
|
|
||||||
await self._rpc.register(self._queue_name, self._do_work_item)
|
await self._rpc.register(self._queue_name, self._do_work_item)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@ -48,3 +60,4 @@ class DistributedPromptWorker:
|
|||||||
await self._rpc.close()
|
await self._rpc.close()
|
||||||
await self._channel.close()
|
await self._channel.close()
|
||||||
await self._connection.close()
|
await self._connection.close()
|
||||||
|
return await self._exit_stack.__aexit__(*args)
|
||||||
|
|||||||
58
comfy/distributed/distributed_types.py
Normal file
58
comfy/distributed/distributed_types.py
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Tuple, Literal, List
|
||||||
|
|
||||||
|
import jwt
|
||||||
|
|
||||||
|
from ..api.components.schema.prompt import PromptDict, Prompt
|
||||||
|
from ..component_model.queue_types import NamedQueueTuple, TaskInvocation, ExecutionStatus
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DistributedBase:
|
||||||
|
prompt_id: str
|
||||||
|
user_token: str
|
||||||
|
|
||||||
|
@property
|
||||||
|
def user_id(self) -> str:
|
||||||
|
return self.decoded_token["sub"]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def decoded_token(self) -> dict:
|
||||||
|
return jwt.decode(self.user_token, algorithms=['HS256', "none"],
|
||||||
|
# todo: this should be configurable
|
||||||
|
options={"verify_signature": False, 'verify_aud': False, 'verify_iss': False})
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RpcRequest(DistributedBase):
|
||||||
|
prompt: dict | PromptDict
|
||||||
|
|
||||||
|
def as_queue_tuple(self) -> NamedQueueTuple:
|
||||||
|
# this loads the nodes in this instance
|
||||||
|
# should always be okay to call in an executor
|
||||||
|
from ..cmd.execution import validate_prompt
|
||||||
|
from ..component_model.make_mutable import make_mutable
|
||||||
|
mutated_prompt_dict = make_mutable(self.prompt)
|
||||||
|
validation_tuple = validate_prompt(mutated_prompt_dict)
|
||||||
|
return NamedQueueTuple(queue_tuple=(0, self.prompt_id, mutated_prompt_dict, {}, validation_tuple[2]))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, request_dict):
|
||||||
|
request = RpcRequest(**request_dict)
|
||||||
|
request.prompt = Prompt.validate(request.prompt)
|
||||||
|
return request
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RpcReply(DistributedBase):
|
||||||
|
outputs: dict
|
||||||
|
execution_status: ExecutionStatus | Tuple[Literal['success', 'error'], bool, List[str]]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_task_invocation(task_invocation: TaskInvocation, user_token: str) -> 'RpcReply':
|
||||||
|
return RpcReply(str(task_invocation.item_id), user_token, task_invocation.outputs, task_invocation.status)
|
||||||
|
|
||||||
|
def as_task_invocation(self):
|
||||||
|
return TaskInvocation(self.prompt_id, self.outputs, ExecutionStatus(*self.execution_status))
|
||||||
@ -15,7 +15,7 @@ class History:
|
|||||||
def put(self, queue_item: QueueItem, outputs: dict, status: ExecutionStatus):
|
def put(self, queue_item: QueueItem, outputs: dict, status: ExecutionStatus):
|
||||||
self.history[queue_item.prompt_id] = HistoryEntry(prompt=queue_item.queue_tuple,
|
self.history[queue_item.prompt_id] = HistoryEntry(prompt=queue_item.queue_tuple,
|
||||||
outputs=outputs,
|
outputs=outputs,
|
||||||
status=status._asdict())
|
status=ExecutionStatus(*status)._asdict())
|
||||||
|
|
||||||
def copy(self, prompt_id: Optional[str | int] = None, max_items: int = MAXIMUM_HISTORY_SIZE,
|
def copy(self, prompt_id: Optional[str | int] = None, max_items: int = MAXIMUM_HISTORY_SIZE,
|
||||||
offset: int = 0) -> Dict[str, HistoryEntry]:
|
offset: int = 0) -> Dict[str, HistoryEntry]:
|
||||||
|
|||||||
@ -28,4 +28,5 @@ tqdm
|
|||||||
protobuf==3.20.3
|
protobuf==3.20.3
|
||||||
psutil
|
psutil
|
||||||
ConfigArgParse
|
ConfigArgParse
|
||||||
aio-pika
|
aio-pika
|
||||||
|
pyjwt[crypto]
|
||||||
@ -1,6 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
|
import jwt
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from comfy.client.embedded_comfy_client import EmbeddedComfyClient, ServerStub
|
from comfy.client.embedded_comfy_client import EmbeddedComfyClient, ServerStub
|
||||||
@ -8,32 +9,41 @@ from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner
|
|||||||
from comfy.component_model.make_mutable import make_mutable
|
from comfy.component_model.make_mutable import make_mutable
|
||||||
from comfy.component_model.queue_types import QueueItem, QueueTuple, TaskInvocation
|
from comfy.component_model.queue_types import QueueItem, QueueTuple, TaskInvocation
|
||||||
from comfy.distributed.distributed_prompt_worker import DistributedPromptWorker
|
from comfy.distributed.distributed_prompt_worker import DistributedPromptWorker
|
||||||
|
from testcontainers.rabbitmq import RabbitMqContainer
|
||||||
|
|
||||||
|
# fixes issues with running the testcontainers rabbitmqcontainer on Windows
|
||||||
|
os.environ["TC_HOST"] = "localhost"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_sign_jwt_auth_none():
|
||||||
|
client_id = str(uuid.uuid4())
|
||||||
|
user_token_str = jwt.encode({"sub": client_id}, None, algorithm="none")
|
||||||
|
user_token = jwt.decode(user_token_str, None, algorithms=["none"], options={"verify_signature": False})
|
||||||
|
assert user_token["sub"] == client_id
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_basic_queue_worker() -> None:
|
async def test_basic_queue_worker() -> None:
|
||||||
os.environ["TC_HOST"] = "localhost"
|
|
||||||
# there are lots of side effects from importing that we have to deal with
|
# there are lots of side effects from importing that we have to deal with
|
||||||
from testcontainers.rabbitmq import RabbitMqContainer
|
|
||||||
|
|
||||||
with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
|
with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
|
||||||
params = rabbitmq.get_connection_params()
|
params = rabbitmq.get_connection_params()
|
||||||
async with EmbeddedComfyClient() as client:
|
async with DistributedPromptWorker(connection_uri=f"amqp://guest:guest@127.0.0.1:{params.port}") as worker:
|
||||||
async with DistributedPromptWorker(client,
|
# this unfortunately does a bunch of initialization on the test thread
|
||||||
connection_uri=f"amqp://guest:guest@127.0.0.1:{params.port}") as worker:
|
from comfy.cmd.execution import validate_prompt
|
||||||
# this unfortunately does a bunch of initialization on the test thread
|
from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
|
||||||
from comfy.cmd.execution import validate_prompt
|
# now submit some jobs
|
||||||
from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
|
distributed_queue = DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True,
|
||||||
# now submit some jobs
|
connection_uri=f"amqp://guest:guest@127.0.0.1:{params.port}")
|
||||||
distributed_queue = DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True,
|
await distributed_queue.init()
|
||||||
connection_uri=f"amqp://guest:guest@127.0.0.1:{params.port}")
|
prompt = make_mutable(sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1))
|
||||||
await distributed_queue.init()
|
validation_tuple = validate_prompt(prompt)
|
||||||
prompt = make_mutable(sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1))
|
item_id = str(uuid.uuid4())
|
||||||
validation_tuple = validate_prompt(prompt)
|
queue_tuple: QueueTuple = (0, item_id, prompt, {}, validation_tuple[2])
|
||||||
item_id = str(uuid.uuid4())
|
res: TaskInvocation = await distributed_queue.put_async(QueueItem(queue_tuple, None))
|
||||||
queue_tuple: QueueTuple = (0, item_id, prompt, {}, validation_tuple[2])
|
assert res.item_id == item_id
|
||||||
res: TaskInvocation = await distributed_queue.put_async(QueueItem(queue_tuple, None))
|
assert len(res.outputs) == 1
|
||||||
assert res.item_id == item_id
|
assert res.status is not None
|
||||||
assert len(res.outputs) == 1
|
assert res.status.status_str == "success"
|
||||||
assert res.status is not None
|
await distributed_queue.close()
|
||||||
assert res.status.status_str == "success"
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user