API now supports fire-and-forget prompt submission and checking queue status; prefetch_count is now explicitly set to 1 for workers

This commit is contained in:
doctorpangloss 2024-09-27 12:07:54 -07:00
parent a664a1fbc9
commit d25394d386
12 changed files with 361 additions and 156 deletions

View File

@ -331,9 +331,25 @@ paths:
description: >-
A POST request to /free with: {"unload_models":true} will unload models from vram.
A POST request to /free with: {"free_memory":true} will unload models and free all cached data from the last run workflow.
/api/v1/prompts/{prompt_id}:
get:
summary: (API) Get prompt status
responses:
204:
description: |
The prompt is still in progress
200:
description: |
Prompt outputs
content:
application/json:
$ref: "#/components/schemas/Outputs"
404:
description: |
The prompt was not found
/api/v1/prompts:
get:
summary: (API) Get prompt
summary: (API) Get last prompt
description: |
Return the last prompt run anywhere that was used to produce an image
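The status endpoint above is meant to be polled: 204 means the prompt is still running, 200 carries the outputs, and 404 means the server does not know the prompt. A minimal polling sketch with aiohttp, assuming a placeholder server address and a prompt_id obtained from a 202 response (neither is defined by this commit):
```python
import asyncio
import aiohttp

async def wait_for_outputs(server: str, prompt_id: str, poll_interval: float = 5.0) -> dict:
    """Poll GET /api/v1/prompts/{prompt_id} until the prompt finishes and return its outputs."""
    async with aiohttp.ClientSession() as session:
        while True:
            async with session.get(f"{server}/api/v1/prompts/{prompt_id}") as response:
                if response.status == 200:
                    # done: the body is the Outputs object described in the spec
                    return await response.json()
                if response.status == 204:
                    # still in progress, wait and try again
                    await asyncio.sleep(poll_interval)
                    continue
                # 404 or anything else: the prompt is unknown or something went wrong
                raise RuntimeError(f"unexpected status {response.status}: {await response.text()}")

# asyncio.run(wait_for_outputs("http://localhost:8188", "<prompt id from a 202 response>"))
```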
@ -395,56 +411,6 @@ paths:
For each SaveImage node, there will be two URLs: the internal URL returned by the worker, and
the URL for the image based on the `--external-address` / `COMFYUI_EXTERNAL_ADDRESS` configuration.
Hashing function for web browsers:
```js
async function generateHash(body) {
// Stringify and sort keys in the JSON object
let str = JSON.stringify(body);
// Encode the string as a Uint8Array
let encoder = new TextEncoder();
let data = encoder.encode(str);
// Create a SHA-256 hash of the data
let hash = await window.crypto.subtle.digest('SHA-256', data);
// Convert the hash (which is an ArrayBuffer) to a hex string
let hashArray = Array.from(new Uint8Array(hash));
let hashHex = hashArray.map(b => b.toString(16).padStart(2, '0')).join('');
return hashHex;
}
```
Hashing function for nodejs:
```js
const crypto = require('crypto');
function generateHash(body) {
// Stringify and sort keys in the JSON object
let str = JSON.stringify(body);
// Create a SHA-256 hash of the string
let hash = crypto.createHash('sha256');
hash.update(str);
// Return the hexadecimal representation of the hash
return hash.digest('hex');
}
```
Hashing function for python:
```python
def digest(data: dict | str) -> str:
json_str = data if isinstance(data, str) else json.dumps(data, separators=(',', ':'))
json_bytes = json_str.encode('utf-8')
hash_object = hashlib.sha256(json_bytes)
return hash_object.hexdigest()
```
type: object
required:
- urls
@ -475,6 +441,28 @@ paths:
"^\\d+$":
type: string
format: binary
202:
description: |
The prompt was successfully queued.
content:
application/json:
description: Information about the item that was queued
schema:
type: object
properties:
prompt_id:
type: string
description: The ID of the prompt that was queued
headers:
Location:
description: The relative URL to poll for the status of the request
schema:
type: string
Retry-After:
description: |
A hint for how many seconds to wait before checking the provided Location for the status of your request.
This is the server's estimate of when the request will be completed.
204:
description: |
The prompt was run but did not contain any SaveImage outputs, so nothing will be returned.
@ -517,11 +505,38 @@ paths:
- "application/json"
- "image/png"
- "multipart/mixed"
- "application/json+respond-async"
- "image/png+respond-async"
- "multipart/mixed+respond-async"
required: false
description: |
Specifies the media type the client is willing to receive.
multipart/mixed will soon be supported to return all the images from the workflow.
If +respond-async is appended to your Accept media type, the request will run asynchronously and you will receive a 202 response once the prompt is queued.
- in: header
name: Prefer
schema:
type: string
enum:
- "respond-async"
- ""
required: false
allowEmptyValue: true
description: |
When respond-async is present in your Prefer header, the request will run asynchronously and you will receive a 202 response once the prompt is queued.
- in: query
name: prefer
schema:
type: string
enum:
- "respond-async"
- ""
required: false
allowEmptyValue: true
description: |
When respond-async is passed in the prefer query parameter, the request will run asynchronously and you will receive a 202 response once the prompt is queued.
requestBody:
content:
application/json:

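Tying the new spec pieces together: the sketch below submits a workflow asynchronously with the Prefer header (an Accept value ending in +respond-async behaves the same way), then reads the prompt ID, Location, and Retry-After hint from the 202 response. The server address and workflow payload are placeholders, not defined by this commit.
```python
import json
import aiohttp

async def submit_fire_and_forget(server: str, workflow: dict) -> tuple[str, str, int]:
    """POST /api/v1/prompts with Prefer: respond-async and return (prompt_id, location, retry_after)."""
    headers = {
        "Content-Type": "application/json",
        "Accept": "application/json",
        "Prefer": "respond-async",
    }
    async with aiohttp.ClientSession() as session:
        async with session.post(f"{server}/api/v1/prompts", data=json.dumps(workflow), headers=headers) as response:
            if response.status != 202:
                raise RuntimeError(f"expected 202, got {response.status}: {await response.text()}")
            body = await response.json()
            # Location is relative (e.g. api/v1/prompts/<prompt_id>); Retry-After suggests how long to wait before polling
            return body["prompt_id"], response.headers["Location"], int(response.headers.get("Retry-After", "60"))
```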
View File

@ -7,8 +7,7 @@ from typing import Optional, List
from urllib.parse import urlparse, urljoin
import aiohttp
from aiohttp import WSMessage, ClientResponse
from typing_extensions import Dict
from aiohttp import WSMessage, ClientResponse, ClientTimeout
from .client_types import V1QueuePromptResponse
from ..api.api_client import JSONEncoder
@ -33,15 +32,46 @@ class AsyncRemoteComfyClient:
self.websocket_address = websocket_address if websocket_address is not None else urljoin(
f"ws://{server_address_url.hostname}:{server_address_url.port}", f"/ws?clientId={client_id}")
self.loop = loop or asyncio.get_event_loop()
self._session: aiohttp.ClientSession | None = None
try:
if asyncio.get_event_loop() is not None:
self._ensure_session()
except RuntimeError as no_running_event_loop:
pass
def _ensure_session(self) -> aiohttp.ClientSession:
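# create the shared session on first use; the long total timeout leaves room for slow prompts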
if self._session is None:
self._session = aiohttp.ClientSession(timeout=ClientTimeout(total=10 * 60.0, connect=60.0))
return self._session
@property
def session(self) -> aiohttp.ClientSession:
return self._ensure_session()
async def len_queue(self) -> int:
async with aiohttp.ClientSession() as session:
async with session.get(urljoin(self.server_address, "/prompt"), headers={'Accept': 'application.json'}) as response:
if response.status == 200:
exec_info_dict = await response.json()
return exec_info_dict["exec_info"]["queue_remaining"]
else:
raise RuntimeError(f"unexpected response: {response.status}: {await response.text()}")
async with self.session.get(urljoin(self.server_address, "/prompt"), headers={'Accept': 'application/json'}) as response:
if response.status == 200:
exec_info_dict = await response.json()
return exec_info_dict["exec_info"]["queue_remaining"]
else:
raise RuntimeError(f"unexpected response: {response.status}: {await response.text()}")
async def queue_and_forget_prompt_api(self, prompt: PromptDict) -> str:
"""
Calls the API to queue a prompt and returns as soon as it is queued, without waiting for the result
:param prompt: the workflow to queue
:return: the prompt ID of the queued task
"""
prompt_json = AsyncRemoteComfyClient.__json_encoder.encode(prompt)
response: ClientResponse
async with self.session.post(urljoin(self.server_address, "/api/v1/prompts"), data=prompt_json,
headers={'Content-Type': 'application/json', 'Accept': 'application/json', 'Prefer': 'respond-async'}) as response:
if 200 <= response.status < 400:
response_json = await response.json()
return response_json["prompt_id"]
else:
raise RuntimeError(f"could not prompt: {response.status}: {await response.text()}")
async def queue_prompt_api(self, prompt: PromptDict) -> V1QueuePromptResponse:
"""
@ -50,15 +80,14 @@ class AsyncRemoteComfyClient:
:return: the API response from the server containing URLs and the outputs for the UI (nodes with OUTPUT_NODE == true)
"""
prompt_json = AsyncRemoteComfyClient.__json_encoder.encode(prompt)
async with aiohttp.ClientSession() as session:
response: ClientResponse
async with session.post(urljoin(self.server_address, "/api/v1/prompts"), data=prompt_json,
headers={'Content-Type': 'application/json', 'Accept': 'application/json'}) as response:
response: ClientResponse
async with self.session.post(urljoin(self.server_address, "/api/v1/prompts"), data=prompt_json,
headers={'Content-Type': 'application/json', 'Accept': 'application/json'}) as response:
if response.status == 200:
return V1QueuePromptResponse(**(await response.json()))
else:
raise RuntimeError(f"could not prompt: {response.status}: {await response.text()}")
if 200 <= response.status < 400:
return V1QueuePromptResponse(**(await response.json()))
else:
raise RuntimeError(f"could not prompt: {response.status}: {await response.text()}")
async def queue_prompt_uris(self, prompt: PromptDict) -> List[str]:
"""
@ -68,24 +97,24 @@ class AsyncRemoteComfyClient:
"""
return (await self.queue_prompt_api(prompt)).urls
async def queue_prompt(self, prompt: PromptDict) -> bytes:
async def queue_prompt(self, prompt: PromptDict) -> bytes | None:
"""
Calls the API to queue a prompt. Returns the bytes of the first PNG returned by a SaveImage node.
:param prompt:
:return:
"""
prompt_json = AsyncRemoteComfyClient.__json_encoder.encode(prompt)
async with aiohttp.ClientSession() as session:
response: ClientResponse
async with session.post(urljoin(self.server_address, "/api/v1/prompts"), data=prompt_json,
headers={'Content-Type': 'application/json', 'Accept': 'image/png'}) as response:
response: ClientResponse
headers = {'Content-Type': 'application/json', 'Accept': 'image/png'}
async with self.session.post(urljoin(self.server_address, "/api/v1/prompts"), data=prompt_json,
headers=headers) as response:
if 200 <= response.status < 400:
return await response.read()
else:
raise RuntimeError(f"could not prompt: {response.status}: {await response.text()}")
if 200 <= response.status < 400:
return await response.read()
else:
raise RuntimeError(f"could not prompt: {response.status}: {await response.text()}")
async def queue_prompt_ui(self, prompt: PromptDict) -> Dict[str, List[Path]]:
async def queue_prompt_ui(self, prompt: PromptDict) -> dict[str, List[Path]]:
"""
Uses the comfyui UI API calls to retrieve a list of paths of output files
:param prompt:
@ -93,46 +122,45 @@ class AsyncRemoteComfyClient:
"""
prompt_request = PromptRequest.validate({"prompt": prompt, "client_id": self.client_id})
prompt_request_json = AsyncRemoteComfyClient.__json_encoder.encode(prompt_request)
async with aiohttp.ClientSession() as session:
async with session.ws_connect(self.websocket_address) as ws:
async with session.post(urljoin(self.server_address, "/prompt"), data=prompt_request_json,
headers={'Content-Type': 'application/json'}) as response:
if response.status == 200:
prompt_id = (await response.json())["prompt_id"]
else:
raise RuntimeError("could not prompt")
msg: WSMessage
async for msg in ws:
# Handle incoming messages
if msg.type == aiohttp.WSMsgType.TEXT:
msg_json = msg.json()
if msg_json["type"] == "executing":
data = msg_json["data"]
if data['node'] is None and data['prompt_id'] == prompt_id:
break
elif msg.type == aiohttp.WSMsgType.CLOSED:
break
elif msg.type == aiohttp.WSMsgType.ERROR:
break
async with session.get(urljoin(self.server_address, "/history")) as response:
async with self.session.ws_connect(self.websocket_address) as ws:
async with self.session.post(urljoin(self.server_address, "/prompt"), data=prompt_request_json,
headers={'Content-Type': 'application/json'}) as response:
if response.status == 200:
history_json = immutabledict(GetHistoryDict.validate(await response.json()))
prompt_id = (await response.json())["prompt_id"]
else:
raise RuntimeError("Couldn't get history")
raise RuntimeError("could not prompt")
msg: WSMessage
async for msg in ws:
# Handle incoming messages
if msg.type == aiohttp.WSMsgType.TEXT:
msg_json = msg.json()
if msg_json["type"] == "executing":
data = msg_json["data"]
if data['node'] is None and data['prompt_id'] == prompt_id:
break
elif msg.type == aiohttp.WSMsgType.CLOSED:
break
elif msg.type == aiohttp.WSMsgType.ERROR:
break
async with self.session.get(urljoin(self.server_address, "/history")) as response:
if response.status == 200:
history_json = immutabledict(GetHistoryDict.validate(await response.json()))
else:
raise RuntimeError("Couldn't get history")
# images have filename, subfolder, type keys
# todo: use the OpenAPI spec for this when I get around to updating it
outputs_by_node_id = history_json[prompt_id].outputs
res: Dict[str, List[Path]] = {}
for node_id, output in outputs_by_node_id.items():
if 'images' in output:
images = []
image_dicts: List[dict] = output['images']
for image_file_output_dict in image_dicts:
image_file_output_dict = defaultdict(None, image_file_output_dict)
filename = image_file_output_dict['filename']
subfolder = image_file_output_dict['subfolder']
type = image_file_output_dict['type']
images.append(Path(file_output_path(filename, subfolder=subfolder, type=type)))
res[node_id] = images
return res
# images have filename, subfolder, type keys
# todo: use the OpenAPI spec for this when I get around to updating it
outputs_by_node_id = history_json[prompt_id].outputs
res: dict[str, List[Path]] = {}
for node_id, output in outputs_by_node_id.items():
if 'images' in output:
images = []
image_dicts: List[dict] = output['images']
for image_file_output_dict in image_dicts:
image_file_output_dict = defaultdict(None, image_file_output_dict)
filename = image_file_output_dict['filename']
subfolder = image_file_output_dict['subfolder']
type = image_file_output_dict['type']
images.append(Path(file_output_path(filename, subfolder=subfolder, type=type)))
res[node_id] = images
return res
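For callers that use the bundled client instead of raw HTTP, a minimal usage sketch of the new fire-and-forget method; the server address, the import path, and the prompt dictionary are assumptions for illustration:
```python
import asyncio
from comfy.client.aio_client import AsyncRemoteComfyClient  # import path assumed for this sketch

async def main(prompt) -> None:
    client = AsyncRemoteComfyClient(server_address="http://localhost:8188")
    # returns as soon as the prompt is queued; the result can be fetched later from /api/v1/prompts/{prompt_id}
    prompt_id = await client.queue_and_forget_prompt_api(prompt)
    print(f"queued {prompt_id}; {await client.len_queue()} item(s) still waiting in the queue")

# asyncio.run(main(some_prompt_dict))  # some_prompt_dict: a valid PromptDict built elsewhere
```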

View File

@ -167,7 +167,7 @@ class EmbeddedComfyClient:
async def queue_prompt_api(self,
prompt: PromptDict) -> V1QueuePromptResponse:
outputs = await self.queue_prompt(prompt)
return V1QueuePromptResponse(**outputs)
return V1QueuePromptResponse(urls=[], outputs=outputs)
@tracer.start_as_current_span("Queue Prompt")
async def queue_prompt(self,

View File

@ -13,7 +13,7 @@ import sys
import traceback
import urllib
import uuid
from asyncio import Future, AbstractEventLoop
from asyncio import Future, AbstractEventLoop, Task
from enum import Enum
from io import BytesIO
from posixpath import join as urljoin
@ -190,7 +190,7 @@ class PromptServer(ExecutorToClientProgress):
self.number: int = 0
self.port: int = 8188
self._external_address: Optional[str] = None
self.receive_all_progress_notifications = True
self.background_tasks: dict[str, Task] = dict()
middlewares = [cache_control]
if args.enable_cors_header:
@ -726,11 +726,35 @@ class PromptServer(ExecutorToClientProgress):
return web.json_response(task.result().to_dict())
@routes.get("/api/v1/prompts/{prompt_id}")
async def get_api_prompt(request: web.Request) -> web.Response | web.FileResponse:
prompt_id: str = request.match_info.get("prompt_id", "")
if prompt_id == "":
return web.Response(status=404)
history_items = self.prompt_queue.get_history(prompt_id)
if len(history_items) == 0:
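# no history yet: the prompt is either still running (204) or unknown to this server (404)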
# todo: this should really be moved to a stateful queue abstraction
if prompt_id in self.background_tasks:
return web.Response(status=204)
else:
# todo: this should check a stateful queue abstraction
return web.Response(status=404)
else:
history_entry = history_items[prompt_id]
return web.json_response(history_entry["outputs"])
@routes.post("/api/v1/prompts")
async def post_api_prompt(request: web.Request) -> web.Response | web.FileResponse:
# check if the queue is too long
accept = request.headers.get("accept", "application/json")
content_type = request.headers.get("content-type", "application/json")
preferences = request.headers.get("prefer", "") + request.query.get("prefer", "") + " " + content_type
if "+" in content_type:
content_type = content_type.split("+")[0]
wait = "respond-async" not in preferences
# check if the queue is too long
queue_size = self.prompt_queue.size()
queue_too_busy_size = PromptServer.get_too_busy_queue_size()
if queue_size > queue_too_busy_size:
@ -778,17 +802,25 @@ class PromptServer(ExecutorToClientProgress):
result: TaskInvocation
completed: Future[TaskInvocation | dict] = self.loop.create_future()
item = QueueItem(queue_tuple=(number, str(uuid.uuid4()), prompt_dict, {}, valid[2]), completed=completed)
task_id = str(uuid.uuid4())
item = QueueItem(queue_tuple=(number, task_id, prompt_dict, {}, valid[2]), completed=completed)
try:
if hasattr(self.prompt_queue, "put_async") or isinstance(self.prompt_queue, AsyncAbstractPromptQueue):
# this enables span propagation seamlessly
result = await self.prompt_queue.put_async(item)
if result is None:
return web.Response(body="the queue is shutting down", status=503)
fut = self.prompt_queue.put_async(item)
if wait:
result = await fut
if result is None:
return web.Response(body="the queue is shutting down", status=503)
else:
return await self._schedule_background_task_with_web_response(fut, task_id)
else:
self.prompt_queue.put(item)
await completed
if wait:
await completed
else:
return await self._schedule_background_task_with_web_response(completed, task_id)
task_invocation_or_dict: TaskInvocation | dict = completed.result()
if isinstance(task_invocation_or_dict, dict):
result = TaskInvocation(item_id=item.prompt_id, outputs=task_invocation_or_dict, status=ExecutionStatus("success", True, []))
@ -867,6 +899,18 @@ class PromptServer(ExecutorToClientProgress):
prompt = last_history_item['prompt'][2]
return web.json_response(prompt, status=200)
async def _schedule_background_task_with_web_response(self, fut, task_id):
task = asyncio.create_task(fut, name=task_id)
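# track the running prompt so GET /api/v1/prompts/{prompt_id} can answer 204 while it is in progress; the done callback removes it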
self.background_tasks[task_id] = task
task.add_done_callback(lambda _: self.background_tasks.pop(task_id))
# todo: type this from the OpenAPI spec
return web.json_response({
"prompt_id": task_id
}, status=202, headers={
"Location": f"api/v1/prompts/{task_id}",
"Retry-After": "60"
})
@property
def external_address(self):
return self._external_address if self._external_address is not None else f"http://{'localhost' if self.address == '0.0.0.0' else self.address}:{self.port}"
@ -875,6 +919,10 @@ class PromptServer(ExecutorToClientProgress):
def external_address(self, value):
self._external_address = value
@property
def receive_all_progress_notifications(self) -> bool:
return True
async def setup(self):
timeout = aiohttp.ClientTimeout(total=None) # no timeout
self.client_session = aiohttp.ClientSession(timeout=timeout)
@ -989,11 +1037,11 @@ class PromptServer(ExecutorToClientProgress):
for addr in addresses:
address = addr[0]
port = addr[1]
site = web.TCPSite(runner, address, port)
site = web.TCPSite(runner, address, port, backlog=PromptServer.get_too_busy_queue_size())
await site.start()
if not hasattr(self, 'address'):
self.address = address #TODO: remove this
self.address = address # TODO: remove this
self.port = port
if ':' in address:

View File

@ -80,7 +80,14 @@ class ExecutorToClientProgress(Protocol):
client_id: Optional[str]
last_node_id: Optional[str]
last_prompt_id: Optional[str]
receive_all_progress_notifications: Optional[bool]
@property
def receive_all_progress_notifications(self):
"""
Returns True if this should receive progress bar updates, in addition to the standard execution lifecycle messages
:return:
"""
return False
def send_sync(self,
event: SendSyncEvent,

View File

@ -28,7 +28,7 @@ def _get_name(queue_name: str, user_id: str) -> str:
class DistributedExecutorToClientProgress(ExecutorToClientProgress):
def __init__(self, rpc: RPC, queue_name: str, loop: AbstractEventLoop, receive_all_progress_notifications=True):
def __init__(self, rpc: RPC, queue_name: str, loop: AbstractEventLoop):
self._rpc = rpc
self._queue_name = queue_name
self._loop = loop
@ -37,7 +37,10 @@ class DistributedExecutorToClientProgress(ExecutorToClientProgress):
self.node_id = None
self.last_node_id = None
self.last_prompt_id = None
self.receive_all_progress_notifications = receive_all_progress_notifications
@property
def receive_all_progress_notifications(self) -> bool:
return True
async def send(self, event: SendSyncEvent, data: SendSyncData, user_id: Optional[str]) -> None:
if event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE:

View File

@ -119,6 +119,7 @@ class DistributedPromptWorker:
logging.error(f"failed to connect to self._connection_uri={self._connection_uri}", connection_error)
raise connection_error
self._channel = await self._connection.channel()
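# prefetch_count=1: RabbitMQ delivers at most one unacknowledged message to this worker at a time, so queued prompts go to idle workers instead of piling up on a busy one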
await self._channel.set_qos(prefetch_count=1)
self._rpc = await JsonRPC.create(channel=self._channel, auto_delete=True, durable=False)
if self._embedded_comfy_client is None:

View File

@ -1,8 +1,8 @@
from __future__ import annotations
import copy
from typing import Optional, OrderedDict, List, Dict
import collections
import typing
from itertools import islice
from ..component_model.queue_types import HistoryEntry, QueueItem, ExecutionStatus, MAXIMUM_HISTORY_SIZE
@ -10,22 +10,22 @@ from ..component_model.queue_types import HistoryEntry, QueueItem, ExecutionStat
class History:
def __init__(self):
self.history: OrderedDict[str, HistoryEntry] = collections.OrderedDict()
self.history: typing.OrderedDict[str, HistoryEntry] = collections.OrderedDict()
def put(self, queue_item: QueueItem, outputs: dict, status: ExecutionStatus):
self.history[queue_item.prompt_id] = HistoryEntry(prompt=queue_item.queue_tuple,
outputs=outputs,
status=ExecutionStatus(*status)._asdict())
def copy(self, prompt_id: Optional[str | int] = None, max_items: Optional[int] = None,
offset: Optional[int] = None) -> Dict[str, HistoryEntry]:
def copy(self, prompt_id: typing.Optional[str | int] = None, max_items: typing.Optional[int] = None,
offset: typing.Optional[int] = None) -> dict[str, HistoryEntry]:
if offset is not None and offset < 0:
offset = max(len(self.history) + offset, 0)
max_items = max_items or MAXIMUM_HISTORY_SIZE
if prompt_id in self.history:
return {prompt_id: copy.deepcopy(self.history[prompt_id])}
else:
ordered_dict = OrderedDict()
ordered_dict = collections.OrderedDict()
for k in islice(self.history, offset, max_items):
ordered_dict[k] = copy.deepcopy(self.history[k])
return ordered_dict

View File

@ -16,7 +16,6 @@ class ServerStub(ExecutorToClientProgress):
self.client_id = str(uuid.uuid4())
self.last_node_id = None
self.last_prompt_id = None
self.receive_all_progress_notifications = False
def send_sync(self,
event: Literal["status", "executing"] | BinaryEventTypes | str | None,
@ -25,3 +24,7 @@ class ServerStub(ExecutorToClientProgress):
def queue_updated(self, queue_remaining: Optional[int] = None):
pass
@property
def receive_all_progress_notifications(self) -> bool:
return False

View File

@ -173,10 +173,15 @@ try:
except:
pass
class _ComfyOutOfMemoryException(RuntimeError):
pass
try:
OOM_EXCEPTION = torch.cuda.OutOfMemoryError
except:
OOM_EXCEPTION = Exception
OOM_EXCEPTION = _ComfyOutOfMemoryException
XFORMERS_VERSION = ""
XFORMERS_ENABLED_VAE = True

View File

@ -1,4 +1,3 @@
import logging
import multiprocessing
import os
import pathlib
@ -87,11 +86,7 @@ def has_gpu() -> bool:
@pytest.fixture(scope="module", autouse=False, params=["ThreadPoolExecutor", "ProcessPoolExecutor"])
def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory) -> str:
"""
populates the cache with the sdxl checkpoints, starts a frontend and backend worker against a started rabbitmq, and yields the address of the frontend
:return:
"""
def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory, num_workers: int = 1):
from huggingface_hub import hf_hub_download
hf_hub_download("stabilityai/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors")
hf_hub_download("stabilityai/stable-diffusion-xl-refiner-1.0", "sd_xl_refiner_1.0.safetensors")
@ -99,6 +94,7 @@ def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory) -> str:
tmp_path = tmp_path_factory.mktemp("comfy_background_server")
executor_factory = request.param
processes_to_close: List[subprocess.Popen] = []
from testcontainers.rabbitmq import RabbitMqContainer
with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
params = rabbitmq.get_connection_params()
@ -115,15 +111,18 @@ def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory) -> str:
]
processes_to_close.append(subprocess.Popen(frontend_command, stdout=sys.stdout, stderr=sys.stderr))
backend_command = [
"comfyui-worker",
"--port=9002",
f"-w={str(tmp_path)}",
f"--distributed-queue-connection-uri={connection_uri}",
f"--executor-factory={executor_factory}"
]
processes_to_close.append(subprocess.Popen(backend_command, stdout=sys.stdout, stderr=sys.stderr))
# Start multiple workers
for i in range(num_workers):
backend_command = [
"comfyui-worker",
f"--port={9002 + i}",
f"-w={str(tmp_path)}",
f"--distributed-queue-connection-uri={connection_uri}",
f"--executor-factory={executor_factory}"
]
processes_to_close.append(subprocess.Popen(backend_command, stdout=sys.stdout, stderr=sys.stderr))
try:
server_address = f"http://{get_lan_ip()}:9001"
start_time = time.time()
@ -134,10 +133,8 @@ def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory) -> str:
if response.status_code == 200:
connected = True
break
except ConnectionRefusedError:
except requests.exceptions.ConnectionError:
pass
except Exception as exc:
logging.warning("", exc_info=exc)
time.sleep(1)
if not connected:
raise RuntimeError("could not connect to frontend")

View File

@ -139,3 +139,101 @@ async def test_basic_queue_worker_with_health_check(executor_factory):
health_check_ok = await check_health(health_check_url)
assert health_check_ok, "Health check server did not start properly"
@pytest.mark.asyncio
async def test_queue_and_forget_prompt_api_integration(frontend_backend_worker_with_rabbitmq):
# Create the client using the server address from the fixture
client = AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq)
# Create a test prompt
prompt = sdxl_workflow_with_refiner("test prompt", inference_steps=1, refiner_steps=1)
# Queue the prompt
task_id = await client.queue_and_forget_prompt_api(prompt)
assert task_id is not None, "Failed to get a valid task ID"
# Poll for the result
max_attempts = 60 # Increase max attempts for integration test
poll_interval = 1 # Increase poll interval for integration test
for _ in range(max_attempts):
try:
response = await client.session.get(f"{frontend_backend_worker_with_rabbitmq}/api/v1/prompts/{task_id}")
if response.status == 200:
result = await response.json()
assert result is not None, "Received empty result"
# Find the first output node with images
output_node = next((node for node in result.values() if 'images' in node), None)
assert output_node is not None, "No output node with images found"
assert len(output_node['images']) > 0, "No images in output node"
assert 'filename' in output_node['images'][0], "No filename in image output"
assert 'subfolder' in output_node['images'][0], "No subfolder in image output"
assert 'type' in output_node['images'][0], "No type in image output"
# Check if we can access the image
image_url = f"{client.server_address}/view?filename={output_node['images'][0]['filename']}&type={output_node['images'][0]['type']}&subfolder={output_node['images'][0]['subfolder']}"
image_response = await client.session.get(image_url)
assert image_response.status == 200, f"Failed to retrieve image from {image_url}"
return # Test passed
elif response.status == 204:
await asyncio.sleep(poll_interval)
else:
response.raise_for_status()
except Exception as e:
print(f"Error while polling: {e}")
await asyncio.sleep(poll_interval)
pytest.fail("Failed to get a 200 response with valid data within the timeout period")
class TestWorker(DistributedPromptWorker):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.processed_workflows: set[str] = set()
async def on_will_complete_work_item(self, request: dict):
workflow_id = request.get('prompt_id', 'unknown')
self.processed_workflows.add(workflow_id)
await super().on_will_complete_work_item(request)
@pytest.mark.asyncio
async def test_two_workers_distinct_requests():
with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
params = rabbitmq.get_connection_params()
connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"
# Start two test workers
workers: list[TestWorker] = []
for i in range(2):
worker = TestWorker(connection_uri=connection_uri, health_check_port=9090 + i, executor=ProcessPoolExecutor(max_workers=1))
await worker.init()
workers.append(worker)
from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
queue = DistributedPromptQueue(is_callee=False, is_caller=True, connection_uri=connection_uri)
await queue.init()
# Submit two prompts
task1 = asyncio.create_task(queue.put_async(create_test_prompt()))
task2 = asyncio.create_task(queue.put_async(create_test_prompt()))
# Wait for tasks to complete
await asyncio.gather(task1, task2)
# Clean up
for worker in workers:
await worker.close()
await queue.close()
# Assert that each worker processed exactly one distinct workflow
all_workflows = set()
for worker in workers:
assert len(worker.processed_workflows) == 1, f"Worker processed {len(worker.processed_workflows)} workflows instead of 1"
all_workflows.update(worker.processed_workflows)
assert len(all_workflows) == 2, f"Expected 2 distinct workflows, but got {len(all_workflows)}"