From d25394d386a2e911c56543d5cbcccaf5d3975557 Mon Sep 17 00:00:00 2001 From: doctorpangloss <@hiddenswitch.com> Date: Fri, 27 Sep 2024 12:07:54 -0700 Subject: [PATCH] API now supports fire-and-forget, checking on queue status; prefetch_count now expressly set to 1 for workers --- comfy/api/openapi.yaml | 117 +++++++------ comfy/client/aio_client.py | 162 ++++++++++-------- comfy/client/embedded_comfy_client.py | 2 +- comfy/cmd/server.py | 68 ++++++-- comfy/component_model/executor_types.py | 9 +- comfy/distributed/distributed_progress.py | 7 +- .../distributed/distributed_prompt_worker.py | 1 + comfy/distributed/history.py | 10 +- comfy/distributed/server_stub.py | 5 +- comfy/model_management.py | 7 +- tests/conftest.py | 31 ++-- tests/distributed/test_distributed_queue.py | 98 +++++++++++ 12 files changed, 361 insertions(+), 156 deletions(-) diff --git a/comfy/api/openapi.yaml b/comfy/api/openapi.yaml index 6087e5809..605e2baab 100644 --- a/comfy/api/openapi.yaml +++ b/comfy/api/openapi.yaml @@ -331,9 +331,25 @@ paths: description: >- A POST request to /free with: {"unload_models":true} will unload models from vram. A POST request to /free with: {"free_memory":true} will unload models and free all cached data from the last run workflow. + /api/v1/prompts/{prompt_id}: + get: + summary: (API) Get prompt status + responses: + 204: + description: | + The prompt is still in progress + 200: + description: | + Prompt outputs + content: + application/json: + $ref: "#/components/schemas/Outputs" + 404: + description: | + The prompt was not found /api/v1/prompts: get: - summary: (API) Get prompt + summary: (API) Get last prompt description: | Return the last prompt run anywhere that was used to produce an image @@ -395,56 +411,6 @@ paths: For each SaveImage node, there will be two URLs: the internal URL returned by the worker, and the URL for the image based on the `--external-address` / `COMFYUI_EXTERNAL_ADDRESS` configuration. - - Hashing function for web browsers: - - ```js - async function generateHash(body) { - // Stringify and sort keys in the JSON object - let str = JSON.stringify(body); - - // Encode the string as a Uint8Array - let encoder = new TextEncoder(); - let data = encoder.encode(str); - - // Create a SHA-256 hash of the data - let hash = await window.crypto.subtle.digest('SHA-256', data); - - // Convert the hash (which is an ArrayBuffer) to a hex string - let hashArray = Array.from(new Uint8Array(hash)); - let hashHex = hashArray.map(b => b.toString(16).padStart(2, '0')).join(''); - - return hashHex; - } - ``` - - Hashing function for nodejs: - - ```js - const crypto = require('crypto'); - - function generateHash(body) { - // Stringify and sort keys in the JSON object - let str = JSON.stringify(body); - - // Create a SHA-256 hash of the string - let hash = crypto.createHash('sha256'); - hash.update(str); - - // Return the hexadecimal representation of the hash - return hash.digest('hex'); - } - ``` - - Hashing function for python: - ```python - def digest(data: dict | str) -> str: - json_str = data if isinstance(data, str) else json.dumps(data, separators=(',', ':')) - json_bytes = json_str.encode('utf-8') - hash_object = hashlib.sha256(json_bytes) - return hash_object.hexdigest() - - ``` type: object required: - urls @@ -475,6 +441,28 @@ paths: "^\\d+$": type: string format: binary + 202: + description: | + The prompt was successfully queued. 
+ content: + application/json: + description: Information about the item that was queued + schema: + type: object + properties: + prompt_id: + type: string + description: The ID of the prompt that was queued + headers: + Location: + description: The relative URL to check on the status of the request + schema: + type: string + Retry-After: + description: | + A hint for the number of seconds to check the provided Location for the status of your request. + + This is the server's estimate for when the request will be completed. 204: description: | The prompt was run but did not contain any SaveImage outputs, so nothing will be returned. @@ -517,11 +505,38 @@ paths: - "application/json" - "image/png" - "multipart/mixed" + - "application/json+respond-async" + - "image/png+respond-async" + - "multipart/mixed+respond-async" required: false description: | Specifies the media type the client is willing to receive. multipart/mixed will soon be supported to return all the images from the workflow. + + If +respond-async is specified after your Accept mimetype, the request will be run async and you will get 202 when the prompt was queued. + - in: header + name: Prefer + schema: + type: string + enum: + - "respond-async" + - "" + required: false + allowEmptyValue: true + description: | + When respond-async is in your Prefer header, the request will be run async and you will get 202 when the prompt was queued. + - in: path + name: prefer + schema: + type: string + enum: + - "respond-async" + - "" + required: false + allowEmptyValue: true + description: | + When respond-async is in the prefer query parameter, the request will be run async and you will get 202 when the prompt was queued. requestBody: content: application/json: diff --git a/comfy/client/aio_client.py b/comfy/client/aio_client.py index 812594fdd..2e3f33c88 100644 --- a/comfy/client/aio_client.py +++ b/comfy/client/aio_client.py @@ -7,8 +7,7 @@ from typing import Optional, List from urllib.parse import urlparse, urljoin import aiohttp -from aiohttp import WSMessage, ClientResponse -from typing_extensions import Dict +from aiohttp import WSMessage, ClientResponse, ClientTimeout from .client_types import V1QueuePromptResponse from ..api.api_client import JSONEncoder @@ -33,15 +32,46 @@ class AsyncRemoteComfyClient: self.websocket_address = websocket_address if websocket_address is not None else urljoin( f"ws://{server_address_url.hostname}:{server_address_url.port}", f"/ws?clientId={client_id}") self.loop = loop or asyncio.get_event_loop() + self._session: aiohttp.ClientSession | None = None + try: + if asyncio.get_event_loop() is not None: + self._ensure_session() + except RuntimeError as no_running_event_loop: + pass + + def _ensure_session(self) -> aiohttp.ClientSession: + if self._session is None: + self._session = aiohttp.ClientSession(timeout=ClientTimeout(total=10 * 60.0, connect=60.0)) + return self._session + + @property + def session(self) -> aiohttp.ClientSession: + return self._ensure_session() async def len_queue(self) -> int: - async with aiohttp.ClientSession() as session: - async with session.get(urljoin(self.server_address, "/prompt"), headers={'Accept': 'application.json'}) as response: - if response.status == 200: - exec_info_dict = await response.json() - return exec_info_dict["exec_info"]["queue_remaining"] - else: - raise RuntimeError(f"unexpected response: {response.status}: {await response.text()}") + async with self.session.get(urljoin(self.server_address, "/prompt"), headers={'Accept': 'application/json'}) as response: 
+ if response.status == 200: + exec_info_dict = await response.json() + return exec_info_dict["exec_info"]["queue_remaining"] + else: + raise RuntimeError(f"unexpected response: {response.status}: {await response.text()}") + + async def queue_and_forget_prompt_api(self, prompt: PromptDict) -> str: + """ + Calls the API to queue a prompt, and forgets about it + :param prompt: + :return: the task ID + """ + prompt_json = AsyncRemoteComfyClient.__json_encoder.encode(prompt) + response: ClientResponse + async with self.session.post(urljoin(self.server_address, "/api/v1/prompts"), data=prompt_json, + headers={'Content-Type': 'application/json', 'Accept': 'application/json', 'Prefer': 'respond-async'}) as response: + + if 200 <= response.status < 400: + response_json = await response.json() + return response_json["prompt_id"] + else: + raise RuntimeError(f"could not prompt: {response.status}: {await response.text()}") async def queue_prompt_api(self, prompt: PromptDict) -> V1QueuePromptResponse: """ @@ -50,15 +80,14 @@ class AsyncRemoteComfyClient: :return: the API response from the server containing URLs and the outputs for the UI (nodes with OUTPUT_NODE == true) """ prompt_json = AsyncRemoteComfyClient.__json_encoder.encode(prompt) - async with aiohttp.ClientSession() as session: - response: ClientResponse - async with session.post(urljoin(self.server_address, "/api/v1/prompts"), data=prompt_json, - headers={'Content-Type': 'application/json', 'Accept': 'application/json'}) as response: + response: ClientResponse + async with self.session.post(urljoin(self.server_address, "/api/v1/prompts"), data=prompt_json, + headers={'Content-Type': 'application/json', 'Accept': 'application/json'}) as response: - if response.status == 200: - return V1QueuePromptResponse(**(await response.json())) - else: - raise RuntimeError(f"could not prompt: {response.status}: {await response.text()}") + if 200 <= response.status < 400: + return V1QueuePromptResponse(**(await response.json())) + else: + raise RuntimeError(f"could not prompt: {response.status}: {await response.text()}") async def queue_prompt_uris(self, prompt: PromptDict) -> List[str]: """ @@ -68,24 +97,24 @@ class AsyncRemoteComfyClient: """ return (await self.queue_prompt_api(prompt)).urls - async def queue_prompt(self, prompt: PromptDict) -> bytes: + async def queue_prompt(self, prompt: PromptDict) -> bytes | None: """ Calls the API to queue a prompt. Returns the bytes of the first PNG returned by a SaveImage node. 
:param prompt: :return: """ prompt_json = AsyncRemoteComfyClient.__json_encoder.encode(prompt) - async with aiohttp.ClientSession() as session: - response: ClientResponse - async with session.post(urljoin(self.server_address, "/api/v1/prompts"), data=prompt_json, - headers={'Content-Type': 'application/json', 'Accept': 'image/png'}) as response: + response: ClientResponse + headers = {'Content-Type': 'application/json', 'Accept': 'image/png'} + async with self.session.post(urljoin(self.server_address, "/api/v1/prompts"), data=prompt_json, + headers=headers) as response: - if 200 <= response.status < 400: - return await response.read() - else: - raise RuntimeError(f"could not prompt: {response.status}: {await response.text()}") + if 200 <= response.status < 400: + return await response.read() + else: + raise RuntimeError(f"could not prompt: {response.status}: {await response.text()}") - async def queue_prompt_ui(self, prompt: PromptDict) -> Dict[str, List[Path]]: + async def queue_prompt_ui(self, prompt: PromptDict) -> dict[str, List[Path]]: """ Uses the comfyui UI API calls to retrieve a list of paths of output files :param prompt: @@ -93,46 +122,45 @@ class AsyncRemoteComfyClient: """ prompt_request = PromptRequest.validate({"prompt": prompt, "client_id": self.client_id}) prompt_request_json = AsyncRemoteComfyClient.__json_encoder.encode(prompt_request) - async with aiohttp.ClientSession() as session: - async with session.ws_connect(self.websocket_address) as ws: - async with session.post(urljoin(self.server_address, "/prompt"), data=prompt_request_json, - headers={'Content-Type': 'application/json'}) as response: - if response.status == 200: - prompt_id = (await response.json())["prompt_id"] - else: - raise RuntimeError("could not prompt") - msg: WSMessage - async for msg in ws: - # Handle incoming messages - if msg.type == aiohttp.WSMsgType.TEXT: - msg_json = msg.json() - if msg_json["type"] == "executing": - data = msg_json["data"] - if data['node'] is None and data['prompt_id'] == prompt_id: - break - elif msg.type == aiohttp.WSMsgType.CLOSED: - break - elif msg.type == aiohttp.WSMsgType.ERROR: - break - async with session.get(urljoin(self.server_address, "/history")) as response: + async with self.session.ws_connect(self.websocket_address) as ws: + async with self.session.post(urljoin(self.server_address, "/prompt"), data=prompt_request_json, + headers={'Content-Type': 'application/json'}) as response: if response.status == 200: - history_json = immutabledict(GetHistoryDict.validate(await response.json())) + prompt_id = (await response.json())["prompt_id"] else: - raise RuntimeError("Couldn't get history") + raise RuntimeError("could not prompt") + msg: WSMessage + async for msg in ws: + # Handle incoming messages + if msg.type == aiohttp.WSMsgType.TEXT: + msg_json = msg.json() + if msg_json["type"] == "executing": + data = msg_json["data"] + if data['node'] is None and data['prompt_id'] == prompt_id: + break + elif msg.type == aiohttp.WSMsgType.CLOSED: + break + elif msg.type == aiohttp.WSMsgType.ERROR: + break + async with self.session.get(urljoin(self.server_address, "/history")) as response: + if response.status == 200: + history_json = immutabledict(GetHistoryDict.validate(await response.json())) + else: + raise RuntimeError("Couldn't get history") - # images have filename, subfolder, type keys - # todo: use the OpenAPI spec for this when I get around to updating it - outputs_by_node_id = history_json[prompt_id].outputs - res: Dict[str, List[Path]] = {} - for node_id, output in 
outputs_by_node_id.items(): - if 'images' in output: - images = [] - image_dicts: List[dict] = output['images'] - for image_file_output_dict in image_dicts: - image_file_output_dict = defaultdict(None, image_file_output_dict) - filename = image_file_output_dict['filename'] - subfolder = image_file_output_dict['subfolder'] - type = image_file_output_dict['type'] - images.append(Path(file_output_path(filename, subfolder=subfolder, type=type))) - res[node_id] = images - return res + # images have filename, subfolder, type keys + # todo: use the OpenAPI spec for this when I get around to updating it + outputs_by_node_id = history_json[prompt_id].outputs + res: dict[str, List[Path]] = {} + for node_id, output in outputs_by_node_id.items(): + if 'images' in output: + images = [] + image_dicts: List[dict] = output['images'] + for image_file_output_dict in image_dicts: + image_file_output_dict = defaultdict(None, image_file_output_dict) + filename = image_file_output_dict['filename'] + subfolder = image_file_output_dict['subfolder'] + type = image_file_output_dict['type'] + images.append(Path(file_output_path(filename, subfolder=subfolder, type=type))) + res[node_id] = images + return res diff --git a/comfy/client/embedded_comfy_client.py b/comfy/client/embedded_comfy_client.py index fd3d26bc0..35551e91e 100644 --- a/comfy/client/embedded_comfy_client.py +++ b/comfy/client/embedded_comfy_client.py @@ -167,7 +167,7 @@ class EmbeddedComfyClient: async def queue_prompt_api(self, prompt: PromptDict) -> V1QueuePromptResponse: outputs = await self.queue_prompt(prompt) - return V1QueuePromptResponse(**outputs) + return V1QueuePromptResponse(urls=[], outputs=outputs) @tracer.start_as_current_span("Queue Prompt") async def queue_prompt(self, diff --git a/comfy/cmd/server.py b/comfy/cmd/server.py index d9a20b1dc..8b182c15d 100644 --- a/comfy/cmd/server.py +++ b/comfy/cmd/server.py @@ -13,7 +13,7 @@ import sys import traceback import urllib import uuid -from asyncio import Future, AbstractEventLoop +from asyncio import Future, AbstractEventLoop, Task from enum import Enum from io import BytesIO from posixpath import join as urljoin @@ -190,7 +190,7 @@ class PromptServer(ExecutorToClientProgress): self.number: int = 0 self.port: int = 8188 self._external_address: Optional[str] = None - self.receive_all_progress_notifications = True + self.background_tasks: dict[str, Task] = dict() middlewares = [cache_control] if args.enable_cors_header: @@ -726,11 +726,35 @@ class PromptServer(ExecutorToClientProgress): return web.json_response(task.result().to_dict()) + @routes.get("/api/v1/prompts/{prompt_id}") + async def get_api_prompt(request: web.Request) -> web.Response | web.FileResponse: + prompt_id: str = request.match_info.get("prompt_id", "") + if prompt_id == "": + return web.Response(status=404) + + history_items = self.prompt_queue.get_history(prompt_id) + if len(history_items) == 0: + # todo: this should really be moved to a stateful queue abstraction + if prompt_id in self.background_tasks: + return web.Response(status=204) + else: + # todo: this should check a stateful queue abstraction + return web.Response(status=404) + else: + history_entry = history_items[prompt_id] + return web.json_response(history_entry["outputs"]) + @routes.post("/api/v1/prompts") async def post_api_prompt(request: web.Request) -> web.Response | web.FileResponse: - # check if the queue is too long accept = request.headers.get("accept", "application/json") content_type = request.headers.get("content-type", "application/json") + 
preferences = request.headers.get("prefer", "") + request.query.get("prefer", "") + " " + content_type + if "+" in content_type: + content_type = content_type.split("+")[0] + + wait = not "respond-async" in preferences + + # check if the queue is too long queue_size = self.prompt_queue.size() queue_too_busy_size = PromptServer.get_too_busy_queue_size() if queue_size > queue_too_busy_size: @@ -778,17 +802,25 @@ class PromptServer(ExecutorToClientProgress): result: TaskInvocation completed: Future[TaskInvocation | dict] = self.loop.create_future() - item = QueueItem(queue_tuple=(number, str(uuid.uuid4()), prompt_dict, {}, valid[2]), completed=completed) + task_id = str(uuid.uuid4()) + item = QueueItem(queue_tuple=(number, task_id, prompt_dict, {}, valid[2]), completed=completed) try: if hasattr(self.prompt_queue, "put_async") or isinstance(self.prompt_queue, AsyncAbstractPromptQueue): # this enables span propagation seamlessly - result = await self.prompt_queue.put_async(item) - if result is None: - return web.Response(body="the queue is shutting down", status=503) + fut = self.prompt_queue.put_async(item) + if wait: + result = await fut + if result is None: + return web.Response(body="the queue is shutting down", status=503) + else: + return await self._schedule_background_task_with_web_response(fut, task_id) else: self.prompt_queue.put(item) - await completed + if wait: + await completed + else: + return await self._schedule_background_task_with_web_response(completed, task_id) task_invocation_or_dict: TaskInvocation | dict = completed.result() if isinstance(task_invocation_or_dict, dict): result = TaskInvocation(item_id=item.prompt_id, outputs=task_invocation_or_dict, status=ExecutionStatus("success", True, [])) @@ -867,6 +899,18 @@ class PromptServer(ExecutorToClientProgress): prompt = last_history_item['prompt'][2] return web.json_response(prompt, status=200) + async def _schedule_background_task_with_web_response(self, fut, task_id): + task = asyncio.create_task(fut, name=task_id) + self.background_tasks[task_id] = task + task.add_done_callback(lambda _: self.background_tasks.pop(task_id)) + # todo: type this from the OpenAPI spec + return web.json_response({ + "prompt_id": task_id + }, status=202, headers={ + "Location": f"api/v1/prompts/{task_id}", + "Retry-After": "60" + }) + @property def external_address(self): return self._external_address if self._external_address is not None else f"http://{'localhost' if self.address == '0.0.0.0' else self.address}:{self.port}" @@ -875,6 +919,10 @@ class PromptServer(ExecutorToClientProgress): def external_address(self, value): self._external_address = value + @property + def receive_all_progress_notifications(self) -> bool: + return True + async def setup(self): timeout = aiohttp.ClientTimeout(total=None) # no timeout self.client_session = aiohttp.ClientSession(timeout=timeout) @@ -989,11 +1037,11 @@ class PromptServer(ExecutorToClientProgress): for addr in addresses: address = addr[0] port = addr[1] - site = web.TCPSite(runner, address, port) + site = web.TCPSite(runner, address, port, backlog=PromptServer.get_too_busy_queue_size()) await site.start() if not hasattr(self, 'address'): - self.address = address #TODO: remove this + self.address = address # TODO: remove this self.port = port if ':' in address: diff --git a/comfy/component_model/executor_types.py b/comfy/component_model/executor_types.py index 74e6b1bf1..1bd60a6d3 100644 --- a/comfy/component_model/executor_types.py +++ b/comfy/component_model/executor_types.py @@ -80,7 +80,14 @@ 
class ExecutorToClientProgress(Protocol): client_id: Optional[str] last_node_id: Optional[str] last_prompt_id: Optional[str] - receive_all_progress_notifications: Optional[bool] + + @property + def receive_all_progress_notifications(self): + """ + Set to true if this should receive progress bar updates, in addition to the standard execution lifecycle messages + :return: + """ + return False def send_sync(self, event: SendSyncEvent, diff --git a/comfy/distributed/distributed_progress.py b/comfy/distributed/distributed_progress.py index 2db22f540..0177465a5 100644 --- a/comfy/distributed/distributed_progress.py +++ b/comfy/distributed/distributed_progress.py @@ -28,7 +28,7 @@ def _get_name(queue_name: str, user_id: str) -> str: class DistributedExecutorToClientProgress(ExecutorToClientProgress): - def __init__(self, rpc: RPC, queue_name: str, loop: AbstractEventLoop, receive_all_progress_notifications=True): + def __init__(self, rpc: RPC, queue_name: str, loop: AbstractEventLoop): self._rpc = rpc self._queue_name = queue_name self._loop = loop @@ -37,7 +37,10 @@ class DistributedExecutorToClientProgress(ExecutorToClientProgress): self.node_id = None self.last_node_id = None self.last_prompt_id = None - self.receive_all_progress_notifications = receive_all_progress_notifications + + @property + def receive_all_progress_notifications(self) -> bool: + return True async def send(self, event: SendSyncEvent, data: SendSyncData, user_id: Optional[str]) -> None: if event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE: diff --git a/comfy/distributed/distributed_prompt_worker.py b/comfy/distributed/distributed_prompt_worker.py index e66188064..79086ac16 100644 --- a/comfy/distributed/distributed_prompt_worker.py +++ b/comfy/distributed/distributed_prompt_worker.py @@ -119,6 +119,7 @@ class DistributedPromptWorker: logging.error(f"failed to connect to self._connection_uri={self._connection_uri}", connection_error) raise connection_error self._channel = await self._connection.channel() + await self._channel.set_qos(prefetch_count=1) self._rpc = await JsonRPC.create(channel=self._channel, auto_delete=True, durable=False) if self._embedded_comfy_client is None: diff --git a/comfy/distributed/history.py b/comfy/distributed/history.py index ae2e3d39b..d81461902 100644 --- a/comfy/distributed/history.py +++ b/comfy/distributed/history.py @@ -1,8 +1,8 @@ from __future__ import annotations import copy -from typing import Optional, OrderedDict, List, Dict import collections +import typing from itertools import islice from ..component_model.queue_types import HistoryEntry, QueueItem, ExecutionStatus, MAXIMUM_HISTORY_SIZE @@ -10,22 +10,22 @@ from ..component_model.queue_types import HistoryEntry, QueueItem, ExecutionStat class History: def __init__(self): - self.history: OrderedDict[str, HistoryEntry] = collections.OrderedDict() + self.history: typing.OrderedDict[str, HistoryEntry] = collections.OrderedDict() def put(self, queue_item: QueueItem, outputs: dict, status: ExecutionStatus): self.history[queue_item.prompt_id] = HistoryEntry(prompt=queue_item.queue_tuple, outputs=outputs, status=ExecutionStatus(*status)._asdict()) - def copy(self, prompt_id: Optional[str | int] = None, max_items: Optional[int] = None, - offset: Optional[int] = None) -> Dict[str, HistoryEntry]: + def copy(self, prompt_id: typing.Optional[str | int] = None, max_items: typing.Optional[int] = None, + offset: typing.Optional[int] = None) -> dict[str, HistoryEntry]: if offset is not None and offset < 0: offset = max(len(self.history) + offset, 
0) max_items = max_items or MAXIMUM_HISTORY_SIZE if prompt_id in self.history: return {prompt_id: copy.deepcopy(self.history[prompt_id])} else: - ordered_dict = OrderedDict() + ordered_dict = collections.OrderedDict() for k in islice(self.history, offset, max_items): ordered_dict[k] = copy.deepcopy(self.history[k]) return ordered_dict diff --git a/comfy/distributed/server_stub.py b/comfy/distributed/server_stub.py index 35a8ed5a4..c21d3a6cc 100644 --- a/comfy/distributed/server_stub.py +++ b/comfy/distributed/server_stub.py @@ -16,7 +16,6 @@ class ServerStub(ExecutorToClientProgress): self.client_id = str(uuid.uuid4()) self.last_node_id = None self.last_prompt_id = None - self.receive_all_progress_notifications = False def send_sync(self, event: Literal["status", "executing"] | BinaryEventTypes | str | None, @@ -25,3 +24,7 @@ class ServerStub(ExecutorToClientProgress): def queue_updated(self, queue_remaining: Optional[int] = None): pass + + @property + def receive_all_progress_notifications(self) -> bool: + return False diff --git a/comfy/model_management.py b/comfy/model_management.py index e0ff23812..79e11f1e2 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -173,10 +173,15 @@ try: except: pass + +class _ComfyOutOfMemoryException(RuntimeError): + pass + + try: OOM_EXCEPTION = torch.cuda.OutOfMemoryError except: - OOM_EXCEPTION = Exception + OOM_EXCEPTION = _ComfyOutOfMemoryException XFORMERS_VERSION = "" XFORMERS_ENABLED_VAE = True diff --git a/tests/conftest.py b/tests/conftest.py index 2224783f6..add7fd5e2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,3 @@ -import logging import multiprocessing import os import pathlib @@ -87,11 +86,7 @@ def has_gpu() -> bool: @pytest.fixture(scope="module", autouse=False, params=["ThreadPoolExecutor", "ProcessPoolExecutor"]) -def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory) -> str: - """ - populates the cache with the sdxl checkpoints, starts a frontend and backend worker against a started rabbitmq, and yields the address of the frontend - :return: - """ +def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory, num_workers: int = 1): from huggingface_hub import hf_hub_download hf_hub_download("stabilityai/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors") hf_hub_download("stabilityai/stable-diffusion-xl-refiner-1.0", "sd_xl_refiner_1.0.safetensors") @@ -99,6 +94,7 @@ def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory) -> str: tmp_path = tmp_path_factory.mktemp("comfy_background_server") executor_factory = request.param processes_to_close: List[subprocess.Popen] = [] + from testcontainers.rabbitmq import RabbitMqContainer with RabbitMqContainer("rabbitmq:latest") as rabbitmq: params = rabbitmq.get_connection_params() @@ -115,15 +111,18 @@ def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory) -> str: ] processes_to_close.append(subprocess.Popen(frontend_command, stdout=sys.stdout, stderr=sys.stderr)) - backend_command = [ - "comfyui-worker", - "--port=9002", - f"-w={str(tmp_path)}", - f"--distributed-queue-connection-uri={connection_uri}", - f"--executor-factory={executor_factory}" - ] - processes_to_close.append(subprocess.Popen(backend_command, stdout=sys.stdout, stderr=sys.stderr)) + # Start multiple workers + for i in range(num_workers): + backend_command = [ + "comfyui-worker", + f"--port={9002 + i}", + f"-w={str(tmp_path)}", + f"--distributed-queue-connection-uri={connection_uri}", + f"--executor-factory={executor_factory}" + ] + 
processes_to_close.append(subprocess.Popen(backend_command, stdout=sys.stdout, stderr=sys.stderr)) + try: server_address = f"http://{get_lan_ip()}:9001" start_time = time.time() @@ -134,10 +133,8 @@ def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory) -> str: if response.status_code == 200: connected = True break - except ConnectionRefusedError: + except requests.exceptions.ConnectionError: pass - except Exception as exc: - logging.warning("", exc_info=exc) time.sleep(1) if not connected: raise RuntimeError("could not connect to frontend") diff --git a/tests/distributed/test_distributed_queue.py b/tests/distributed/test_distributed_queue.py index 3e5084143..73437fe29 100644 --- a/tests/distributed/test_distributed_queue.py +++ b/tests/distributed/test_distributed_queue.py @@ -139,3 +139,101 @@ async def test_basic_queue_worker_with_health_check(executor_factory): health_check_ok = await check_health(health_check_url) assert health_check_ok, "Health check server did not start properly" + + +@pytest.mark.asyncio +async def test_queue_and_forget_prompt_api_integration(frontend_backend_worker_with_rabbitmq): + # Create the client using the server address from the fixture + client = AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) + + # Create a test prompt + prompt = sdxl_workflow_with_refiner("test prompt", inference_steps=1, refiner_steps=1) + + # Queue the prompt + task_id = await client.queue_and_forget_prompt_api(prompt) + + assert task_id is not None, "Failed to get a valid task ID" + + # Poll for the result + max_attempts = 60 # Increase max attempts for integration test + poll_interval = 1 # Increase poll interval for integration test + for _ in range(max_attempts): + try: + response = await client.session.get(f"{frontend_backend_worker_with_rabbitmq}/api/v1/prompts/{task_id}") + if response.status == 200: + result = await response.json() + assert result is not None, "Received empty result" + + # Find the first output node with images + output_node = next((node for node in result.values() if 'images' in node), None) + assert output_node is not None, "No output node with images found" + + assert len(output_node['images']) > 0, "No images in output node" + assert 'filename' in output_node['images'][0], "No filename in image output" + assert 'subfolder' in output_node['images'][0], "No subfolder in image output" + assert 'type' in output_node['images'][0], "No type in image output" + + # Check if we can access the image + image_url = f"{client.server_address}/view?filename={output_node['images'][0]['filename']}&type={output_node['images'][0]['type']}&subfolder={output_node['images'][0]['subfolder']}" + image_response = await client.session.get(image_url) + assert image_response.status == 200, f"Failed to retrieve image from {image_url}" + + return # Test passed + elif response.status == 204: + await asyncio.sleep(poll_interval) + else: + response.raise_for_status() + except Exception as e: + print(f"Error while polling: {e}") + await asyncio.sleep(poll_interval) + + pytest.fail("Failed to get a 200 response with valid data within the timeout period") + + +class TestWorker(DistributedPromptWorker): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.processed_workflows: set[str] = set() + + async def on_will_complete_work_item(self, request: dict): + workflow_id = request.get('prompt_id', 'unknown') + self.processed_workflows.add(workflow_id) + await super().on_will_complete_work_item(request) + + 
+@pytest.mark.asyncio +async def test_two_workers_distinct_requests(): + with RabbitMqContainer("rabbitmq:latest") as rabbitmq: + params = rabbitmq.get_connection_params() + connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}" + + # Start two test workers + workers: list[TestWorker] = [] + for i in range(2): + worker = TestWorker(connection_uri=connection_uri, health_check_port=9090 + i, executor=ProcessPoolExecutor(max_workers=1)) + await worker.init() + workers.append(worker) + + from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue + queue = DistributedPromptQueue(is_callee=False, is_caller=True, connection_uri=connection_uri) + await queue.init() + + # Submit two prompts + task1 = asyncio.create_task(queue.put_async(create_test_prompt())) + task2 = asyncio.create_task(queue.put_async(create_test_prompt())) + + # Wait for tasks to complete + await asyncio.gather(task1, task2) + + # Clean up + for worker in workers: + await worker.close() + await queue.close() + + # Assert that each worker processed exactly one distinct workflow + all_workflows = set() + for worker in workers: + assert len(worker.processed_workflows) == 1, f"Worker processed {len(worker.processed_workflows)} workflows instead of 1" + all_workflows.update(worker.processed_workflows) + + assert len(all_workflows) == 2, f"Expected 2 distinct workflows, but got {len(all_workflows)}"
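
---

Usage note: below is a minimal sketch of the fire-and-forget flow this patch adds, assuming a ComfyUI server at `http://127.0.0.1:8188` and an API-format workflow dict bound to `prompt` (both placeholders). It drives the raw HTTP endpoints directly: POST `/api/v1/prompts` with `Prefer: respond-async` (the same request the new `AsyncRemoteComfyClient.queue_and_forget_prompt_api` helper issues), then poll the returned `Location` until the 204 "still in progress" responses turn into a 200 carrying the outputs, as the new integration test does. The polling loop and its sleep cap are illustrative choices, not part of the API.

```python
import asyncio
import aiohttp


async def queue_and_poll(server: str, prompt: dict) -> dict | None:
    """Queue a prompt asynchronously and poll its status URL until it completes."""
    async with aiohttp.ClientSession() as session:
        # Fire-and-forget: the server answers 202 with a prompt_id, a relative
        # Location header, and a Retry-After hint.
        async with session.post(
                f"{server}/api/v1/prompts",
                json=prompt,
                headers={"Accept": "application/json", "Prefer": "respond-async"}) as response:
            if response.status != 202:
                raise RuntimeError(f"could not queue: {response.status}: {await response.text()}")
            status_url = f"{server}/{response.headers['Location']}"
            retry_after = float(response.headers.get("Retry-After", "1"))

        # Poll: 204 means still in progress, 200 carries the outputs, 404 means unknown prompt.
        while True:
            async with session.get(status_url, headers={"Accept": "application/json"}) as status:
                if status.status == 200:
                    return await status.json()
                if status.status == 404:
                    return None
            # Cap the wait so this sketch stays responsive even with the server's 60 s hint.
            await asyncio.sleep(min(retry_after, 5.0))


# Example (placeholder server address and workflow dict):
# outputs = asyncio.run(queue_and_poll("http://127.0.0.1:8188", prompt))
```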