mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 06:10:50 +08:00
Improved support for distributed queues
This commit is contained in:
parent
3367362cec
commit
92898b8c9d
@ -186,7 +186,7 @@ On macOS, install exactly Python 3.11 using `brew`, which you can download from
|
||||
You can use `comfyui` as an API. Visit the [OpenAPI specification](comfy/api/openapi.yaml). This file can be used to generate typed clients for your preferred language.
|
||||
7. To create the standalone binary:
|
||||
```shell
|
||||
python -m pyinstaller --onefile --noupx -n ComfyUI --add-data="comfy/;comfy/" --paths $(pwd) --paths comfy/cmd main.py
|
||||
python -m PyInstaller --onefile --noupx -n ComfyUI --add-data="comfy/;comfy/" --paths $(pwd) --paths comfy/cmd main.py
|
||||
```
|
||||
|
||||
### Authoring Custom Nodes
|
||||
|
||||
@ -119,6 +119,17 @@ parser.add_argument("--plausible-analytics-domain", required=False,
|
||||
help="Specifies the domain name for analytics events.")
|
||||
parser.add_argument("--analytics-use-identity-provider", action="store_true",
|
||||
help="Uses platform identifiers for unique visitor analytics.")
|
||||
parser.add_argument("--distributed-queue-connection-uri", type=str, default=None,
|
||||
help="Servers and clients will connect to this AMPQ URL to form a distributed queue and exchange prompt execution requests and progress updates.")
|
||||
parser.add_argument(
|
||||
'--distributed-queue-roles',
|
||||
action='append',
|
||||
choices=['worker', 'prompter'],
|
||||
help='Specifies one or more roles for the distributed queue. Acceptable values are "worker" or "prompter", or both by writing the flag twice with each role. Prompters will start the web UI and connect to the provided AMPQ URL to submit prompts; workers will pull requests off the AMPQ URL.'
|
||||
)
|
||||
parser.add_argument("--distributed-queue-name", type=str, default="comfyui",
|
||||
help="This name will be used by the prompters and workers to exchange prompt requests and replies. Progress updates will be prefixed by the queue name, followed by a '.', then the user ID")
|
||||
|
||||
|
||||
if options.args_parsing:
|
||||
args, _ = parser.parse_known_args()
|
||||
|
||||
@ -70,6 +70,9 @@ class Configuration(dict):
|
||||
analytics_use_identity_provider (bool): Use platform identifiers for analytics.
|
||||
write_out_config_file (bool): Enable writing out the configuration file.
|
||||
create_directories (bool): Creates the default models/, input/, output/ and temp/ directories, then exits.
|
||||
distributed_queue_connection_uri (Optional[str]): Servers and clients will connect to this AMQP URL to form a distributed queue and exchange prompt execution requests and progress updates.
|
||||
distributed_queue_roles (List[str]): Specifies one or more roles for the distributed queue. Acceptable values are "worker" or "prompter", or both by writing the flag twice with each role. Prompters will start the web UI and connect to the provided AMQP URL to submit prompts; workers will pull requests off the AMQP URL.
|
||||
distributed_queue_name (str): This name will be used by the prompters and workers to exchange prompt requests and replies. Progress updates will be prefixed by the queue name, followed by a '.', then the user ID.
|
||||
"""
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__()
|
||||
@ -127,6 +130,9 @@ class Configuration(dict):
|
||||
self.analytics_use_identity_provider: bool = False
|
||||
self.write_out_config_file: bool = False
|
||||
self.create_directories: bool = False
|
||||
self.distributed_queue_connection_uri: Optional[str] = None
|
||||
self.distributed_queue_roles: List[str] = []
|
||||
self.distributed_queue_name: str = "comfyui"
|
||||
for key, value in kwargs.items():
|
||||
self[key] = value
|
||||
|
||||
|
||||
@ -18,6 +18,7 @@ class ServerStub(ExecutorToClientProgress):
|
||||
"""
|
||||
This class is a stub implementation of ExecutorToClientProgress. This will handle progress events.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.client_id = str(uuid.uuid4())
|
||||
self.last_node_id = None
|
||||
@ -69,18 +70,25 @@ class EmbeddedComfyClient:
|
||||
In order to use this in blocking methods, learn more about asyncio online.
|
||||
"""
|
||||
|
||||
def __init__(self, configuration: Optional[Configuration] = None, loop: Optional[AbstractEventLoop] = None,
|
||||
def __init__(self, configuration: Optional[Configuration] = None,
|
||||
progress_handler: Optional[ExecutorToClientProgress] = None,
|
||||
loop: Optional[AbstractEventLoop] = None,
|
||||
max_workers: int = 1):
|
||||
self._server_stub = ServerStub()
|
||||
self._progress_handler = progress_handler or ServerStub()
|
||||
self._executor = ThreadPoolExecutor(max_workers=max_workers)
|
||||
self._loop = loop or asyncio.get_event_loop()
|
||||
self._configuration = configuration
|
||||
# we don't want to import the executor yet
|
||||
self._prompt_executor: Optional["comfy.cmd.execution.PromptExecutor"] = None
|
||||
self._is_running = False
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
return self._is_running
|
||||
|
||||
async def __aenter__(self):
|
||||
# Perform asynchronous initialization here, if needed
|
||||
await self._initialize_prompt_executor()
|
||||
self._is_running = True
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
@ -101,6 +109,7 @@ class EmbeddedComfyClient:
|
||||
await self._loop.run_in_executor(self._executor, cleanup)
|
||||
|
||||
self._executor.shutdown(wait=True)
|
||||
self._is_running = False
|
||||
|
||||
async def _initialize_prompt_executor(self):
|
||||
# This method must be async since it's used in __aenter__
|
||||
@ -115,19 +124,23 @@ class EmbeddedComfyClient:
|
||||
|
||||
from ..cmd.execution import PromptExecutor
|
||||
|
||||
self._prompt_executor = PromptExecutor(self._server_stub)
|
||||
self._prompt_executor = PromptExecutor(self._progress_handler)
|
||||
|
||||
await self._loop.run_in_executor(self._executor, create_executor_in_thread)
|
||||
|
||||
async def queue_prompt(self, prompt: PromptDict) -> dict:
|
||||
prompt_id = str(uuid.uuid4())
|
||||
async def queue_prompt(self,
|
||||
prompt: PromptDict | dict,
|
||||
prompt_id: Optional[str] = None,
|
||||
client_id: Optional[str] = None) -> dict:
|
||||
prompt_id = prompt_id or str(uuid.uuid4())
|
||||
client_id = client_id or self._progress_handler.client_id or None
|
||||
|
||||
def execute_prompt() -> dict:
|
||||
from ..cmd.execution import validate_prompt
|
||||
prompt_mut = make_mutable(prompt)
|
||||
validation_tuple = validate_prompt(prompt_mut)
|
||||
|
||||
self._prompt_executor.execute(prompt_mut, prompt_id, {"client_id": self._server_stub.client_id},
|
||||
self._prompt_executor.execute(prompt_mut, prompt_id, {"client_id": client_id},
|
||||
execute_outputs=validation_tuple[2])
|
||||
if self._prompt_executor.success:
|
||||
return self._prompt_executor.outputs_ui
|
||||
|
||||
@ -807,7 +807,7 @@ class PromptQueue(AbstractPromptQueue):
|
||||
|
||||
status_dict: Optional[dict] = None
|
||||
if status is not None:
|
||||
status_dict = copy.deepcopy(status._asdict())
|
||||
status_dict = copy.deepcopy(ExecutionStatus(*status)._asdict())
|
||||
|
||||
self.history[prompt[1]] = {
|
||||
"prompt": prompt,
|
||||
|
||||
@ -1,12 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Mapping, Any
|
||||
|
||||
|
||||
def make_mutable(obj: Any) -> dict:
|
||||
"""
|
||||
Makes an immutable dict, frozenset or tuple mutable. Otherwise, returns the value.
|
||||
:param obj: any object
|
||||
:return:
|
||||
"""
|
||||
def _make_mutable(obj: Any) -> Any:
|
||||
if isinstance(obj, Mapping) and not isinstance(obj, dict) and not hasattr(obj, "__setitem__"):
|
||||
obj = dict(obj)
|
||||
for key, value in obj.items():
|
||||
@ -18,3 +15,17 @@ def make_mutable(obj: Any) -> dict:
|
||||
if isinstance(obj, frozenset):
|
||||
obj = set([make_mutable(x) for x in obj])
|
||||
return obj
|
||||
|
||||
|
||||
def make_mutable(obj: Mapping) -> dict:
|
||||
"""
|
||||
Makes a copy of an immutable dict into a mutable dict.
|
||||
|
||||
If the object is already a mutable type or a value type like a string or integer, returns the value.
|
||||
|
||||
Returns dict, set or tuple depending on its input, but you should not use it this way.
|
||||
|
||||
:param obj: any object
|
||||
:return:
|
||||
"""
|
||||
return _make_mutable(obj)
|
||||
|
||||
@ -40,21 +40,21 @@ class HistoryEntry(TypedDict):
|
||||
status: NotRequired[ExecutionStatusAsDict]
|
||||
|
||||
|
||||
class ExtraData(TypedDict):
|
||||
client_id: NotRequired[str]
|
||||
extra_pnginfo: NotRequired[str]
|
||||
token: NotRequired[str]
|
||||
|
||||
|
||||
@dataclass
|
||||
class QueueItem:
|
||||
class NamedQueueTuple:
|
||||
"""
|
||||
An item awaiting processing in the queue
|
||||
A wrapper class for a queue tuple, the object that is given to executors.
|
||||
|
||||
Attributes:
|
||||
queue_tuple (QueueTuple): the corresponding queued workflow and other related data
|
||||
completed (Optional[Future[TaskInvocation | dict]]): A future of a task invocation (the signature of the task_done method)
|
||||
or a dictionary of outputs
|
||||
"""
|
||||
queue_tuple: QueueTuple
|
||||
completed: asyncio.Future[TaskInvocation | dict] | None
|
||||
|
||||
def __lt__(self, other: QueueItem):
|
||||
return self.queue_tuple[0] < other.queue_tuple[0]
|
||||
|
||||
@property
|
||||
def priority(self) -> float:
|
||||
@ -69,7 +69,7 @@ class QueueItem:
|
||||
return self.queue_tuple[2]
|
||||
|
||||
@property
|
||||
def extra_data(self) -> Optional[dict]:
|
||||
def extra_data(self) -> Optional[ExtraData]:
|
||||
if len(self.queue_tuple) > 2:
|
||||
return self.queue_tuple[3]
|
||||
else:
|
||||
@ -83,6 +83,22 @@ class QueueItem:
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class QueueItem(NamedQueueTuple):
|
||||
"""
|
||||
An item awaiting processing in the queue: a NamedQueueTuple with a future that is completed when the item is done
|
||||
processing.
|
||||
|
||||
Attributes:
|
||||
completed (Optional[Future[TaskInvocation | dict]]): A future of a task invocation (the signature of the task_done method)
|
||||
or a dictionary of outputs
|
||||
"""
|
||||
completed: asyncio.Future[TaskInvocation | dict] | None
|
||||
|
||||
def __lt__(self, other: QueueItem):
|
||||
return self.queue_tuple[0] < other.queue_tuple[0]
|
||||
|
||||
|
||||
class BinaryEventTypes(Enum):
|
||||
PREVIEW_IMAGE = 1
|
||||
UNENCODED_PREVIEW_IMAGE = 2
|
||||
|
||||
32
comfy/distributed/distributed_prompt_client.py
Normal file
32
comfy/distributed/distributed_prompt_client.py
Normal file
@ -0,0 +1,32 @@
|
||||
import asyncio
|
||||
from asyncio import AbstractEventLoop
|
||||
from dataclasses import asdict
|
||||
from typing import Optional
|
||||
|
||||
from aio_pika import connect_robust
|
||||
from aio_pika.patterns import RPC
|
||||
|
||||
from comfy.distributed.distributed_types import RpcRequest, RpcReply
|
||||
|
||||
|
||||
class DistributedPromptClient:
|
||||
def __init__(self, queue_name: str = "comfyui",
|
||||
connection_uri="amqp://localhost/",
|
||||
loop: Optional[AbstractEventLoop] = None):
|
||||
self.queue_name = queue_name
|
||||
self.connection_uri = connection_uri
|
||||
self.loop = loop or asyncio.get_event_loop()
|
||||
|
||||
async def __aenter__(self):
|
||||
self.connection = await connect_robust(self.connection_uri, loop=self.loop)
|
||||
self.channel = await self.connection.channel()
|
||||
self.rpc = await RPC.create(channel=self.channel)
|
||||
self.rpc.host_exceptions = True
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
await self.channel.close()
|
||||
await self.rpc.close()
|
||||
await self.connection.close()
|
||||
|
||||
async def queue_prompt(self, request: RpcRequest) -> RpcReply:
|
||||
return RpcReply(**(await self.rpc.call(self.queue_name, {"request": asdict(request)})))
|
||||
@ -1,13 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from asyncio import AbstractEventLoop
|
||||
import uuid
|
||||
from asyncio import AbstractEventLoop, Queue
|
||||
from dataclasses import asdict
|
||||
from typing import Optional, Dict, List, Mapping, Tuple, Callable
|
||||
|
||||
import jwt
|
||||
from aio_pika import connect_robust
|
||||
from aio_pika.abc import AbstractConnection, AbstractChannel
|
||||
from aio_pika.patterns import RPC
|
||||
from aio_pika.patterns import JsonRPC
|
||||
|
||||
from .distributed_types import RpcRequest, RpcReply
|
||||
from ..component_model.abstract_prompt_queue import AbstractPromptQueue
|
||||
from ..component_model.executor_types import ExecutorToClientProgress
|
||||
from ..component_model.queue_types import Flags, HistoryEntry, QueueTuple, QueueItem, ExecutionStatus, TaskInvocation
|
||||
@ -29,11 +33,30 @@ class DistributedPromptQueue(AbstractPromptQueue):
|
||||
|
||||
async def put_async(self, queue_item: QueueItem):
|
||||
assert self.is_caller
|
||||
if self._closing:
|
||||
return
|
||||
self.caller_local_in_progress[queue_item.prompt_id] = queue_item
|
||||
if self.caller_server is not None:
|
||||
self.caller_server.queue_updated()
|
||||
try:
|
||||
res: TaskInvocation = await self.rpc.call(self.queue_name, {"item": queue_item.queue_tuple})
|
||||
if "token" in queue_item.extra_data:
|
||||
user_token = queue_item.extra_data["token"]
|
||||
else:
|
||||
if "client_id" in queue_item.extra_data:
|
||||
client_id = queue_item.extra_data["client_id"]
|
||||
elif self.caller_server.client_id is not None:
|
||||
client_id = self.caller_server.client_id
|
||||
else:
|
||||
client_id = str(uuid.uuid4())
|
||||
# todo: should we really do this?
|
||||
self.caller_server.client_id = client_id
|
||||
|
||||
# create a stub token
|
||||
user_token = jwt.encode({"sub": client_id}, key="", algorithm="none")
|
||||
request = RpcRequest(prompt_id=queue_item.prompt_id, user_token=user_token, prompt=queue_item.prompt)
|
||||
assert self.rpc is not None
|
||||
res: TaskInvocation = RpcReply(
|
||||
**(await self.rpc.call(self.queue_name, {"request": asdict(request)}))).as_task_invocation()
|
||||
|
||||
self.caller_history.put(queue_item, res.outputs, res.status)
|
||||
if self.caller_server is not None:
|
||||
@ -63,12 +86,16 @@ class DistributedPromptQueue(AbstractPromptQueue):
|
||||
def put(self, item: QueueItem):
|
||||
# caller: execute on main thread
|
||||
assert self.is_caller
|
||||
if self._closing:
|
||||
return
|
||||
# this is called by the web server and its event loop is perfectly fine to use
|
||||
# the future is now ignored
|
||||
self.loop.call_soon_threadsafe(self.put_async, item)
|
||||
|
||||
async def _callee_do_work_item(self, item: QueueTuple) -> TaskInvocation:
|
||||
async def _callee_do_work_item(self, request: dict) -> dict:
|
||||
assert self.is_callee
|
||||
request_obj = RpcRequest.from_dict(request)
|
||||
item = request_obj.as_queue_tuple().queue_tuple
|
||||
item_with_completer = QueueItem(item, self.loop.create_future())
|
||||
self.callee_local_in_progress[item_with_completer.prompt_id] = item_with_completer
|
||||
# todo: check if we have the local model content needed to execute this request and if not, reject it
|
||||
@ -80,9 +107,10 @@ class DistributedPromptQueue(AbstractPromptQueue):
|
||||
assert not item_with_completer.completed.done()
|
||||
|
||||
# now we wait for the worker thread to complete the item
|
||||
return await item_with_completer.completed
|
||||
invocation = await item_with_completer.completed
|
||||
return asdict(RpcReply.from_task_invocation(invocation, request_obj.user_token))
|
||||
|
||||
def get(self, timeout: float | None = None) -> Optional[Tuple[QueueTuple, int]]:
|
||||
def get(self, timeout: float | None = None) -> Optional[Tuple[QueueTuple, str | int]]:
|
||||
# callee: executed on the worker thread
|
||||
assert self.is_callee
|
||||
try:
|
||||
@ -175,6 +203,7 @@ class DistributedPromptQueue(AbstractPromptQueue):
|
||||
self.channel: Optional[AbstractChannel] = None # Channel will be set up asynchronously
|
||||
self.is_caller = is_caller
|
||||
self.is_callee = is_callee
|
||||
self._closing = False
|
||||
|
||||
# as rpc caller
|
||||
self.caller_server = server
|
||||
@ -182,9 +211,9 @@ class DistributedPromptQueue(AbstractPromptQueue):
|
||||
self.caller_history: History = History()
|
||||
|
||||
# as rpc callee
|
||||
self.callee_local_queue = asyncio.Queue()
|
||||
self.callee_local_queue: Queue = Queue()
|
||||
self.callee_local_in_progress: Dict[int | str, QueueItem] = {}
|
||||
self.rpc: Optional[RPC] = None
|
||||
self.rpc: Optional[JsonRPC] = None
|
||||
|
||||
# todo: the prompt queue really shouldn't do this
|
||||
if server is not None:
|
||||
@ -193,8 +222,14 @@ class DistributedPromptQueue(AbstractPromptQueue):
|
||||
async def init(self):
|
||||
self.connection = await connect_robust(self.connection_uri, loop=self.loop)
|
||||
self.channel = await self.connection.channel()
|
||||
self.rpc = await RPC.create(channel=self.channel)
|
||||
self.rpc = await JsonRPC.create(channel=self.channel)
|
||||
self.rpc.host_exceptions = True
|
||||
# this makes the queue available to complete work items
|
||||
if self.is_callee:
|
||||
await self.rpc.register(self.queue_name, self._callee_do_work_item)
|
||||
|
||||
async def close(self):
|
||||
self._closing = True
|
||||
await self.rpc.close()
|
||||
await self.channel.close()
|
||||
await self.connection.close()
|
||||
|
||||
@ -1,14 +1,15 @@
|
||||
import asyncio
|
||||
from asyncio import AbstractEventLoop
|
||||
from contextlib import AsyncExitStack
|
||||
from dataclasses import asdict
|
||||
from typing import Optional
|
||||
|
||||
from aio_pika import connect_robust
|
||||
from aio_pika.patterns import RPC
|
||||
from aio_pika.patterns import JsonRPC
|
||||
|
||||
from ..api.components.schema.prompt import Prompt
|
||||
from ..cli_args_types import Configuration
|
||||
from .distributed_types import RpcRequest, RpcReply
|
||||
from ..client.embedded_comfy_client import EmbeddedComfyClient
|
||||
from ..component_model.queue_types import TaskInvocation, QueueTuple, QueueItem, ExecutionStatus
|
||||
from ..component_model.queue_types import ExecutionStatus
|
||||
|
||||
|
||||
class DistributedPromptWorker:
|
||||
@ -16,31 +17,42 @@ class DistributedPromptWorker:
|
||||
A work in progress distributed prompt worker.
|
||||
"""
|
||||
|
||||
def __init__(self, embedded_comfy_client: EmbeddedComfyClient,
|
||||
def __init__(self, embedded_comfy_client: Optional[EmbeddedComfyClient] = None,
|
||||
connection_uri: str = "amqp://localhost:5672/",
|
||||
queue_name: str = "comfyui",
|
||||
loop: Optional[AbstractEventLoop] = None, configuration: Configuration = None):
|
||||
loop: Optional[AbstractEventLoop] = None):
|
||||
self._exit_stack = AsyncExitStack()
|
||||
self._queue_name = queue_name
|
||||
self._configuration = configuration
|
||||
self._connection_uri = connection_uri
|
||||
self._loop = loop or asyncio.get_event_loop()
|
||||
self._embedded_comfy_client = embedded_comfy_client
|
||||
self._embedded_comfy_client = embedded_comfy_client or EmbeddedComfyClient()
|
||||
|
||||
async def _do_work_item(self, item: QueueTuple) -> TaskInvocation:
|
||||
item_without_completer = QueueItem(item, completed=None)
|
||||
async def _do_work_item(self, request: dict) -> dict:
|
||||
try:
|
||||
output_dict = await self._embedded_comfy_client.queue_prompt(Prompt.validate(item_without_completer.prompt))
|
||||
return TaskInvocation(item_without_completer.prompt_id, outputs=output_dict,
|
||||
status=ExecutionStatus("success", True, []))
|
||||
request_obj = RpcRequest.from_dict(request)
|
||||
except Exception as e:
|
||||
return TaskInvocation(item_without_completer.prompt_id, outputs={},
|
||||
status=ExecutionStatus("error", False, [str(e)]))
|
||||
request_dict_prompt_id_recovered = request["prompt_id"] \
|
||||
if request is not None and "prompt_id" in request else ""
|
||||
return asdict(RpcReply(request_dict_prompt_id_recovered, "", {},
|
||||
ExecutionStatus("error", False, [str(e)])))
|
||||
try:
|
||||
output_dict = await self._embedded_comfy_client.queue_prompt(request_obj.prompt,
|
||||
request_obj.prompt_id,
|
||||
client_id=request_obj.user_id)
|
||||
return asdict(RpcReply(request_obj.prompt_id, request_obj.user_token, output_dict, ExecutionStatus("success", True, [])))
|
||||
except Exception as e:
|
||||
return asdict(RpcReply(request_obj.prompt_id, request_obj.user_token, {}, ExecutionStatus("error", False, [str(e)])))
|
||||
|
||||
async def __aenter__(self) -> "DistributedPromptWorker":
|
||||
await self._exit_stack.__aenter__()
|
||||
if not self._embedded_comfy_client.is_running:
|
||||
await self._exit_stack.enter_async_context(self._embedded_comfy_client)
|
||||
|
||||
self._connection = await connect_robust(self._connection_uri, loop=self._loop)
|
||||
self._channel = await self._connection.channel()
|
||||
self._rpc = await RPC.create(channel=self._channel)
|
||||
self._rpc = await JsonRPC.create(channel=self._channel)
|
||||
self._rpc.host_exceptions = True
|
||||
|
||||
await self._rpc.register(self._queue_name, self._do_work_item)
|
||||
return self
|
||||
|
||||
@ -48,3 +60,4 @@ class DistributedPromptWorker:
|
||||
await self._rpc.close()
|
||||
await self._channel.close()
|
||||
await self._connection.close()
|
||||
return await self._exit_stack.__aexit__(*args)
|
||||
|
||||
58
comfy/distributed/distributed_types.py
Normal file
58
comfy/distributed/distributed_types.py
Normal file
@ -0,0 +1,58 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple, Literal, List
|
||||
|
||||
import jwt
|
||||
|
||||
from ..api.components.schema.prompt import PromptDict, Prompt
|
||||
from ..component_model.queue_types import NamedQueueTuple, TaskInvocation, ExecutionStatus
|
||||
|
||||
|
||||
@dataclass
|
||||
class DistributedBase:
|
||||
prompt_id: str
|
||||
user_token: str
|
||||
|
||||
@property
|
||||
def user_id(self) -> str:
|
||||
return self.decoded_token["sub"]
|
||||
|
||||
@property
|
||||
def decoded_token(self) -> dict:
|
||||
return jwt.decode(self.user_token, algorithms=['HS256', "none"],
|
||||
# todo: this should be configurable
|
||||
options={"verify_signature": False, 'verify_aud': False, 'verify_iss': False})
|
||||
|
||||
|
||||
@dataclass
|
||||
class RpcRequest(DistributedBase):
|
||||
prompt: dict | PromptDict
|
||||
|
||||
def as_queue_tuple(self) -> NamedQueueTuple:
|
||||
# this loads the nodes in this instance
|
||||
# should always be okay to call in an executor
|
||||
from ..cmd.execution import validate_prompt
|
||||
from ..component_model.make_mutable import make_mutable
|
||||
mutated_prompt_dict = make_mutable(self.prompt)
|
||||
validation_tuple = validate_prompt(mutated_prompt_dict)
|
||||
return NamedQueueTuple(queue_tuple=(0, self.prompt_id, mutated_prompt_dict, {}, validation_tuple[2]))
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, request_dict):
|
||||
request = RpcRequest(**request_dict)
|
||||
request.prompt = Prompt.validate(request.prompt)
|
||||
return request
|
||||
|
||||
|
||||
@dataclass
|
||||
class RpcReply(DistributedBase):
|
||||
outputs: dict
|
||||
execution_status: ExecutionStatus | Tuple[Literal['success', 'error'], bool, List[str]]
|
||||
|
||||
@staticmethod
|
||||
def from_task_invocation(task_invocation: TaskInvocation, user_token: str) -> 'RpcReply':
|
||||
return RpcReply(str(task_invocation.item_id), user_token, task_invocation.outputs, task_invocation.status)
|
||||
|
||||
def as_task_invocation(self):
|
||||
return TaskInvocation(self.prompt_id, self.outputs, ExecutionStatus(*self.execution_status))
|
||||
@ -15,7 +15,7 @@ class History:
|
||||
def put(self, queue_item: QueueItem, outputs: dict, status: ExecutionStatus):
|
||||
self.history[queue_item.prompt_id] = HistoryEntry(prompt=queue_item.queue_tuple,
|
||||
outputs=outputs,
|
||||
status=status._asdict())
|
||||
status=ExecutionStatus(*status)._asdict())
|
||||
|
||||
def copy(self, prompt_id: Optional[str | int] = None, max_items: int = MAXIMUM_HISTORY_SIZE,
|
||||
offset: int = 0) -> Dict[str, HistoryEntry]:
|
||||
|
||||
@ -28,4 +28,5 @@ tqdm
|
||||
protobuf==3.20.3
|
||||
psutil
|
||||
ConfigArgParse
|
||||
aio-pika
|
||||
aio-pika
|
||||
pyjwt[crypto]
|
||||
@ -1,6 +1,7 @@
|
||||
import os
|
||||
import uuid
|
||||
|
||||
import jwt
|
||||
import pytest
|
||||
|
||||
from comfy.client.embedded_comfy_client import EmbeddedComfyClient, ServerStub
|
||||
@ -8,32 +9,41 @@ from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner
|
||||
from comfy.component_model.make_mutable import make_mutable
|
||||
from comfy.component_model.queue_types import QueueItem, QueueTuple, TaskInvocation
|
||||
from comfy.distributed.distributed_prompt_worker import DistributedPromptWorker
|
||||
from testcontainers.rabbitmq import RabbitMqContainer
|
||||
|
||||
# fixes issues with running the testcontainers rabbitmqcontainer on Windows
|
||||
os.environ["TC_HOST"] = "localhost"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sign_jwt_auth_none():
|
||||
client_id = str(uuid.uuid4())
|
||||
user_token_str = jwt.encode({"sub": client_id}, None, algorithm="none")
|
||||
user_token = jwt.decode(user_token_str, None, algorithms=["none"], options={"verify_signature": False})
|
||||
assert user_token["sub"] == client_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_queue_worker() -> None:
|
||||
os.environ["TC_HOST"] = "localhost"
|
||||
# there are lots of side effects from importing that we have to deal with
|
||||
from testcontainers.rabbitmq import RabbitMqContainer
|
||||
|
||||
with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
|
||||
params = rabbitmq.get_connection_params()
|
||||
async with EmbeddedComfyClient() as client:
|
||||
async with DistributedPromptWorker(client,
|
||||
connection_uri=f"amqp://guest:guest@127.0.0.1:{params.port}") as worker:
|
||||
# this unfortunately does a bunch of initialization on the test thread
|
||||
from comfy.cmd.execution import validate_prompt
|
||||
from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
|
||||
# now submit some jobs
|
||||
distributed_queue = DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True,
|
||||
connection_uri=f"amqp://guest:guest@127.0.0.1:{params.port}")
|
||||
await distributed_queue.init()
|
||||
prompt = make_mutable(sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1))
|
||||
validation_tuple = validate_prompt(prompt)
|
||||
item_id = str(uuid.uuid4())
|
||||
queue_tuple: QueueTuple = (0, item_id, prompt, {}, validation_tuple[2])
|
||||
res: TaskInvocation = await distributed_queue.put_async(QueueItem(queue_tuple, None))
|
||||
assert res.item_id == item_id
|
||||
assert len(res.outputs) == 1
|
||||
assert res.status is not None
|
||||
assert res.status.status_str == "success"
|
||||
async with DistributedPromptWorker(connection_uri=f"amqp://guest:guest@127.0.0.1:{params.port}") as worker:
|
||||
# this unfortunately does a bunch of initialization on the test thread
|
||||
from comfy.cmd.execution import validate_prompt
|
||||
from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
|
||||
# now submit some jobs
|
||||
distributed_queue = DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True,
|
||||
connection_uri=f"amqp://guest:guest@127.0.0.1:{params.port}")
|
||||
await distributed_queue.init()
|
||||
prompt = make_mutable(sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1))
|
||||
validation_tuple = validate_prompt(prompt)
|
||||
item_id = str(uuid.uuid4())
|
||||
queue_tuple: QueueTuple = (0, item_id, prompt, {}, validation_tuple[2])
|
||||
res: TaskInvocation = await distributed_queue.put_async(QueueItem(queue_tuple, None))
|
||||
assert res.item_id == item_id
|
||||
assert len(res.outputs) == 1
|
||||
assert res.status is not None
|
||||
assert res.status.status_str == "success"
|
||||
await distributed_queue.close()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user