mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 22:30:50 +08:00
API now supports fire-and-forget, checking on queue status; prefetch_count now expressly set to 1 for workers
This commit is contained in:
parent
a664a1fbc9
commit
d25394d386
@ -331,9 +331,25 @@ paths:
|
||||
description: >-
|
||||
A POST request to /free with: {"unload_models":true} will unload models from vram.
|
||||
A POST request to /free with: {"free_memory":true} will unload models and free all cached data from the last run workflow.
|
||||
/api/v1/prompts/{prompt_id}:
|
||||
get:
|
||||
summary: (API) Get prompt status
|
||||
responses:
|
||||
204:
|
||||
description: |
|
||||
The prompt is still in progress
|
||||
200:
|
||||
description: |
|
||||
Prompt outputs
|
||||
content:
|
||||
application/json:
|
||||
$ref: "#/components/schemas/Outputs"
|
||||
404:
|
||||
description: |
|
||||
The prompt was not found
|
||||
/api/v1/prompts:
|
||||
get:
|
||||
summary: (API) Get prompt
|
||||
summary: (API) Get last prompt
|
||||
description: |
|
||||
Return the last prompt run anywhere that was used to produce an image
|
||||
|
||||
@ -395,56 +411,6 @@ paths:
|
||||
|
||||
For each SaveImage node, there will be two URLs: the internal URL returned by the worker, and
|
||||
the URL for the image based on the `--external-address` / `COMFYUI_EXTERNAL_ADDRESS` configuration.
|
||||
|
||||
Hashing function for web browsers:
|
||||
|
||||
```js
|
||||
async function generateHash(body) {
|
||||
// Stringify and sort keys in the JSON object
|
||||
let str = JSON.stringify(body);
|
||||
|
||||
// Encode the string as a Uint8Array
|
||||
let encoder = new TextEncoder();
|
||||
let data = encoder.encode(str);
|
||||
|
||||
// Create a SHA-256 hash of the data
|
||||
let hash = await window.crypto.subtle.digest('SHA-256', data);
|
||||
|
||||
// Convert the hash (which is an ArrayBuffer) to a hex string
|
||||
let hashArray = Array.from(new Uint8Array(hash));
|
||||
let hashHex = hashArray.map(b => b.toString(16).padStart(2, '0')).join('');
|
||||
|
||||
return hashHex;
|
||||
}
|
||||
```
|
||||
|
||||
Hashing function for nodejs:
|
||||
|
||||
```js
|
||||
const crypto = require('crypto');
|
||||
|
||||
function generateHash(body) {
|
||||
// Stringify and sort keys in the JSON object
|
||||
let str = JSON.stringify(body);
|
||||
|
||||
// Create a SHA-256 hash of the string
|
||||
let hash = crypto.createHash('sha256');
|
||||
hash.update(str);
|
||||
|
||||
// Return the hexadecimal representation of the hash
|
||||
return hash.digest('hex');
|
||||
}
|
||||
```
|
||||
|
||||
Hashing function for python:
|
||||
```python
|
||||
def digest(data: dict | str) -> str:
|
||||
json_str = data if isinstance(data, str) else json.dumps(data, separators=(',', ':'))
|
||||
json_bytes = json_str.encode('utf-8')
|
||||
hash_object = hashlib.sha256(json_bytes)
|
||||
return hash_object.hexdigest()
|
||||
|
||||
```
|
||||
type: object
|
||||
required:
|
||||
- urls
|
||||
@ -475,6 +441,28 @@ paths:
|
||||
"^\\d+$":
|
||||
type: string
|
||||
format: binary
|
||||
202:
|
||||
description: |
|
||||
The prompt was successfully queued.
|
||||
content:
|
||||
application/json:
|
||||
description: Information about the item that was queued
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
prompt_id:
|
||||
type: string
|
||||
description: The ID of the prompt that was queued
|
||||
headers:
|
||||
Location:
|
||||
description: The relative URL to check on the status of the request
|
||||
schema:
|
||||
type: string
|
||||
Retry-After:
|
||||
description: |
|
||||
A hint for the number of seconds to check the provided Location for the status of your request.
|
||||
|
||||
This is the server's estimate for when the request will be completed.
|
||||
204:
|
||||
description: |
|
||||
The prompt was run but did not contain any SaveImage outputs, so nothing will be returned.
|
||||
@ -517,11 +505,38 @@ paths:
|
||||
- "application/json"
|
||||
- "image/png"
|
||||
- "multipart/mixed"
|
||||
- "application/json+respond-async"
|
||||
- "image/png+respond-async"
|
||||
- "multipart/mixed+respond-async"
|
||||
required: false
|
||||
description: |
|
||||
Specifies the media type the client is willing to receive.
|
||||
|
||||
multipart/mixed will soon be supported to return all the images from the workflow.
|
||||
|
||||
If +respond-async is specified after your Accept mimetype, the request will be run async and you will get 202 when the prompt was queued.
|
||||
- in: header
|
||||
name: Prefer
|
||||
schema:
|
||||
type: string
|
||||
enum:
|
||||
- "respond-async"
|
||||
- ""
|
||||
required: false
|
||||
allowEmptyValue: true
|
||||
description: |
|
||||
When respond-async is in your Prefer header, the request will be run async and you will get 202 when the prompt was queued.
|
||||
- in: path
|
||||
name: prefer
|
||||
schema:
|
||||
type: string
|
||||
enum:
|
||||
- "respond-async"
|
||||
- ""
|
||||
required: false
|
||||
allowEmptyValue: true
|
||||
description: |
|
||||
When respond-async is in the prefer query parameter, the request will be run async and you will get 202 when the prompt was queued.
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
|
||||
@ -7,8 +7,7 @@ from typing import Optional, List
|
||||
from urllib.parse import urlparse, urljoin
|
||||
|
||||
import aiohttp
|
||||
from aiohttp import WSMessage, ClientResponse
|
||||
from typing_extensions import Dict
|
||||
from aiohttp import WSMessage, ClientResponse, ClientTimeout
|
||||
|
||||
from .client_types import V1QueuePromptResponse
|
||||
from ..api.api_client import JSONEncoder
|
||||
@ -33,15 +32,46 @@ class AsyncRemoteComfyClient:
|
||||
self.websocket_address = websocket_address if websocket_address is not None else urljoin(
|
||||
f"ws://{server_address_url.hostname}:{server_address_url.port}", f"/ws?clientId={client_id}")
|
||||
self.loop = loop or asyncio.get_event_loop()
|
||||
self._session: aiohttp.ClientSession | None = None
|
||||
try:
|
||||
if asyncio.get_event_loop() is not None:
|
||||
self._ensure_session()
|
||||
except RuntimeError as no_running_event_loop:
|
||||
pass
|
||||
|
||||
def _ensure_session(self) -> aiohttp.ClientSession:
|
||||
if self._session is None:
|
||||
self._session = aiohttp.ClientSession(timeout=ClientTimeout(total=10 * 60.0, connect=60.0))
|
||||
return self._session
|
||||
|
||||
@property
|
||||
def session(self) -> aiohttp.ClientSession:
|
||||
return self._ensure_session()
|
||||
|
||||
async def len_queue(self) -> int:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(urljoin(self.server_address, "/prompt"), headers={'Accept': 'application.json'}) as response:
|
||||
if response.status == 200:
|
||||
exec_info_dict = await response.json()
|
||||
return exec_info_dict["exec_info"]["queue_remaining"]
|
||||
else:
|
||||
raise RuntimeError(f"unexpected response: {response.status}: {await response.text()}")
|
||||
async with self.session.get(urljoin(self.server_address, "/prompt"), headers={'Accept': 'application/json'}) as response:
|
||||
if response.status == 200:
|
||||
exec_info_dict = await response.json()
|
||||
return exec_info_dict["exec_info"]["queue_remaining"]
|
||||
else:
|
||||
raise RuntimeError(f"unexpected response: {response.status}: {await response.text()}")
|
||||
|
||||
async def queue_and_forget_prompt_api(self, prompt: PromptDict) -> str:
|
||||
"""
|
||||
Calls the API to queue a prompt, and forgets about it
|
||||
:param prompt:
|
||||
:return: the task ID
|
||||
"""
|
||||
prompt_json = AsyncRemoteComfyClient.__json_encoder.encode(prompt)
|
||||
response: ClientResponse
|
||||
async with self.session.post(urljoin(self.server_address, "/api/v1/prompts"), data=prompt_json,
|
||||
headers={'Content-Type': 'application/json', 'Accept': 'application/json', 'Prefer': 'respond-async'}) as response:
|
||||
|
||||
if 200 <= response.status < 400:
|
||||
response_json = await response.json()
|
||||
return response_json["prompt_id"]
|
||||
else:
|
||||
raise RuntimeError(f"could not prompt: {response.status}: {await response.text()}")
|
||||
|
||||
async def queue_prompt_api(self, prompt: PromptDict) -> V1QueuePromptResponse:
|
||||
"""
|
||||
@ -50,15 +80,14 @@ class AsyncRemoteComfyClient:
|
||||
:return: the API response from the server containing URLs and the outputs for the UI (nodes with OUTPUT_NODE == true)
|
||||
"""
|
||||
prompt_json = AsyncRemoteComfyClient.__json_encoder.encode(prompt)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
response: ClientResponse
|
||||
async with session.post(urljoin(self.server_address, "/api/v1/prompts"), data=prompt_json,
|
||||
headers={'Content-Type': 'application/json', 'Accept': 'application/json'}) as response:
|
||||
response: ClientResponse
|
||||
async with self.session.post(urljoin(self.server_address, "/api/v1/prompts"), data=prompt_json,
|
||||
headers={'Content-Type': 'application/json', 'Accept': 'application/json'}) as response:
|
||||
|
||||
if response.status == 200:
|
||||
return V1QueuePromptResponse(**(await response.json()))
|
||||
else:
|
||||
raise RuntimeError(f"could not prompt: {response.status}: {await response.text()}")
|
||||
if 200 <= response.status < 400:
|
||||
return V1QueuePromptResponse(**(await response.json()))
|
||||
else:
|
||||
raise RuntimeError(f"could not prompt: {response.status}: {await response.text()}")
|
||||
|
||||
async def queue_prompt_uris(self, prompt: PromptDict) -> List[str]:
|
||||
"""
|
||||
@ -68,24 +97,24 @@ class AsyncRemoteComfyClient:
|
||||
"""
|
||||
return (await self.queue_prompt_api(prompt)).urls
|
||||
|
||||
async def queue_prompt(self, prompt: PromptDict) -> bytes:
|
||||
async def queue_prompt(self, prompt: PromptDict) -> bytes | None:
|
||||
"""
|
||||
Calls the API to queue a prompt. Returns the bytes of the first PNG returned by a SaveImage node.
|
||||
:param prompt:
|
||||
:return:
|
||||
"""
|
||||
prompt_json = AsyncRemoteComfyClient.__json_encoder.encode(prompt)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
response: ClientResponse
|
||||
async with session.post(urljoin(self.server_address, "/api/v1/prompts"), data=prompt_json,
|
||||
headers={'Content-Type': 'application/json', 'Accept': 'image/png'}) as response:
|
||||
response: ClientResponse
|
||||
headers = {'Content-Type': 'application/json', 'Accept': 'image/png'}
|
||||
async with self.session.post(urljoin(self.server_address, "/api/v1/prompts"), data=prompt_json,
|
||||
headers=headers) as response:
|
||||
|
||||
if 200 <= response.status < 400:
|
||||
return await response.read()
|
||||
else:
|
||||
raise RuntimeError(f"could not prompt: {response.status}: {await response.text()}")
|
||||
if 200 <= response.status < 400:
|
||||
return await response.read()
|
||||
else:
|
||||
raise RuntimeError(f"could not prompt: {response.status}: {await response.text()}")
|
||||
|
||||
async def queue_prompt_ui(self, prompt: PromptDict) -> Dict[str, List[Path]]:
|
||||
async def queue_prompt_ui(self, prompt: PromptDict) -> dict[str, List[Path]]:
|
||||
"""
|
||||
Uses the comfyui UI API calls to retrieve a list of paths of output files
|
||||
:param prompt:
|
||||
@ -93,46 +122,45 @@ class AsyncRemoteComfyClient:
|
||||
"""
|
||||
prompt_request = PromptRequest.validate({"prompt": prompt, "client_id": self.client_id})
|
||||
prompt_request_json = AsyncRemoteComfyClient.__json_encoder.encode(prompt_request)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.ws_connect(self.websocket_address) as ws:
|
||||
async with session.post(urljoin(self.server_address, "/prompt"), data=prompt_request_json,
|
||||
headers={'Content-Type': 'application/json'}) as response:
|
||||
if response.status == 200:
|
||||
prompt_id = (await response.json())["prompt_id"]
|
||||
else:
|
||||
raise RuntimeError("could not prompt")
|
||||
msg: WSMessage
|
||||
async for msg in ws:
|
||||
# Handle incoming messages
|
||||
if msg.type == aiohttp.WSMsgType.TEXT:
|
||||
msg_json = msg.json()
|
||||
if msg_json["type"] == "executing":
|
||||
data = msg_json["data"]
|
||||
if data['node'] is None and data['prompt_id'] == prompt_id:
|
||||
break
|
||||
elif msg.type == aiohttp.WSMsgType.CLOSED:
|
||||
break
|
||||
elif msg.type == aiohttp.WSMsgType.ERROR:
|
||||
break
|
||||
async with session.get(urljoin(self.server_address, "/history")) as response:
|
||||
async with self.session.ws_connect(self.websocket_address) as ws:
|
||||
async with self.session.post(urljoin(self.server_address, "/prompt"), data=prompt_request_json,
|
||||
headers={'Content-Type': 'application/json'}) as response:
|
||||
if response.status == 200:
|
||||
history_json = immutabledict(GetHistoryDict.validate(await response.json()))
|
||||
prompt_id = (await response.json())["prompt_id"]
|
||||
else:
|
||||
raise RuntimeError("Couldn't get history")
|
||||
raise RuntimeError("could not prompt")
|
||||
msg: WSMessage
|
||||
async for msg in ws:
|
||||
# Handle incoming messages
|
||||
if msg.type == aiohttp.WSMsgType.TEXT:
|
||||
msg_json = msg.json()
|
||||
if msg_json["type"] == "executing":
|
||||
data = msg_json["data"]
|
||||
if data['node'] is None and data['prompt_id'] == prompt_id:
|
||||
break
|
||||
elif msg.type == aiohttp.WSMsgType.CLOSED:
|
||||
break
|
||||
elif msg.type == aiohttp.WSMsgType.ERROR:
|
||||
break
|
||||
async with self.session.get(urljoin(self.server_address, "/history")) as response:
|
||||
if response.status == 200:
|
||||
history_json = immutabledict(GetHistoryDict.validate(await response.json()))
|
||||
else:
|
||||
raise RuntimeError("Couldn't get history")
|
||||
|
||||
# images have filename, subfolder, type keys
|
||||
# todo: use the OpenAPI spec for this when I get around to updating it
|
||||
outputs_by_node_id = history_json[prompt_id].outputs
|
||||
res: Dict[str, List[Path]] = {}
|
||||
for node_id, output in outputs_by_node_id.items():
|
||||
if 'images' in output:
|
||||
images = []
|
||||
image_dicts: List[dict] = output['images']
|
||||
for image_file_output_dict in image_dicts:
|
||||
image_file_output_dict = defaultdict(None, image_file_output_dict)
|
||||
filename = image_file_output_dict['filename']
|
||||
subfolder = image_file_output_dict['subfolder']
|
||||
type = image_file_output_dict['type']
|
||||
images.append(Path(file_output_path(filename, subfolder=subfolder, type=type)))
|
||||
res[node_id] = images
|
||||
return res
|
||||
# images have filename, subfolder, type keys
|
||||
# todo: use the OpenAPI spec for this when I get around to updating it
|
||||
outputs_by_node_id = history_json[prompt_id].outputs
|
||||
res: dict[str, List[Path]] = {}
|
||||
for node_id, output in outputs_by_node_id.items():
|
||||
if 'images' in output:
|
||||
images = []
|
||||
image_dicts: List[dict] = output['images']
|
||||
for image_file_output_dict in image_dicts:
|
||||
image_file_output_dict = defaultdict(None, image_file_output_dict)
|
||||
filename = image_file_output_dict['filename']
|
||||
subfolder = image_file_output_dict['subfolder']
|
||||
type = image_file_output_dict['type']
|
||||
images.append(Path(file_output_path(filename, subfolder=subfolder, type=type)))
|
||||
res[node_id] = images
|
||||
return res
|
||||
|
||||
@ -167,7 +167,7 @@ class EmbeddedComfyClient:
|
||||
async def queue_prompt_api(self,
|
||||
prompt: PromptDict) -> V1QueuePromptResponse:
|
||||
outputs = await self.queue_prompt(prompt)
|
||||
return V1QueuePromptResponse(**outputs)
|
||||
return V1QueuePromptResponse(urls=[], outputs=outputs)
|
||||
|
||||
@tracer.start_as_current_span("Queue Prompt")
|
||||
async def queue_prompt(self,
|
||||
|
||||
@ -13,7 +13,7 @@ import sys
|
||||
import traceback
|
||||
import urllib
|
||||
import uuid
|
||||
from asyncio import Future, AbstractEventLoop
|
||||
from asyncio import Future, AbstractEventLoop, Task
|
||||
from enum import Enum
|
||||
from io import BytesIO
|
||||
from posixpath import join as urljoin
|
||||
@ -190,7 +190,7 @@ class PromptServer(ExecutorToClientProgress):
|
||||
self.number: int = 0
|
||||
self.port: int = 8188
|
||||
self._external_address: Optional[str] = None
|
||||
self.receive_all_progress_notifications = True
|
||||
self.background_tasks: dict[str, Task] = dict()
|
||||
|
||||
middlewares = [cache_control]
|
||||
if args.enable_cors_header:
|
||||
@ -726,11 +726,35 @@ class PromptServer(ExecutorToClientProgress):
|
||||
|
||||
return web.json_response(task.result().to_dict())
|
||||
|
||||
@routes.get("/api/v1/prompts/{prompt_id}")
|
||||
async def get_api_prompt(request: web.Request) -> web.Response | web.FileResponse:
|
||||
prompt_id: str = request.match_info.get("prompt_id", "")
|
||||
if prompt_id == "":
|
||||
return web.Response(status=404)
|
||||
|
||||
history_items = self.prompt_queue.get_history(prompt_id)
|
||||
if len(history_items) == 0:
|
||||
# todo: this should really be moved to a stateful queue abstraction
|
||||
if prompt_id in self.background_tasks:
|
||||
return web.Response(status=204)
|
||||
else:
|
||||
# todo: this should check a stateful queue abstraction
|
||||
return web.Response(status=404)
|
||||
else:
|
||||
history_entry = history_items[prompt_id]
|
||||
return web.json_response(history_entry["outputs"])
|
||||
|
||||
@routes.post("/api/v1/prompts")
|
||||
async def post_api_prompt(request: web.Request) -> web.Response | web.FileResponse:
|
||||
# check if the queue is too long
|
||||
accept = request.headers.get("accept", "application/json")
|
||||
content_type = request.headers.get("content-type", "application/json")
|
||||
preferences = request.headers.get("prefer", "") + request.query.get("prefer", "") + " " + content_type
|
||||
if "+" in content_type:
|
||||
content_type = content_type.split("+")[0]
|
||||
|
||||
wait = not "respond-async" in preferences
|
||||
|
||||
# check if the queue is too long
|
||||
queue_size = self.prompt_queue.size()
|
||||
queue_too_busy_size = PromptServer.get_too_busy_queue_size()
|
||||
if queue_size > queue_too_busy_size:
|
||||
@ -778,17 +802,25 @@ class PromptServer(ExecutorToClientProgress):
|
||||
|
||||
result: TaskInvocation
|
||||
completed: Future[TaskInvocation | dict] = self.loop.create_future()
|
||||
item = QueueItem(queue_tuple=(number, str(uuid.uuid4()), prompt_dict, {}, valid[2]), completed=completed)
|
||||
task_id = str(uuid.uuid4())
|
||||
item = QueueItem(queue_tuple=(number, task_id, prompt_dict, {}, valid[2]), completed=completed)
|
||||
|
||||
try:
|
||||
if hasattr(self.prompt_queue, "put_async") or isinstance(self.prompt_queue, AsyncAbstractPromptQueue):
|
||||
# this enables span propagation seamlessly
|
||||
result = await self.prompt_queue.put_async(item)
|
||||
if result is None:
|
||||
return web.Response(body="the queue is shutting down", status=503)
|
||||
fut = self.prompt_queue.put_async(item)
|
||||
if wait:
|
||||
result = await fut
|
||||
if result is None:
|
||||
return web.Response(body="the queue is shutting down", status=503)
|
||||
else:
|
||||
return await self._schedule_background_task_with_web_response(fut, task_id)
|
||||
else:
|
||||
self.prompt_queue.put(item)
|
||||
await completed
|
||||
if wait:
|
||||
await completed
|
||||
else:
|
||||
return await self._schedule_background_task_with_web_response(completed, task_id)
|
||||
task_invocation_or_dict: TaskInvocation | dict = completed.result()
|
||||
if isinstance(task_invocation_or_dict, dict):
|
||||
result = TaskInvocation(item_id=item.prompt_id, outputs=task_invocation_or_dict, status=ExecutionStatus("success", True, []))
|
||||
@ -867,6 +899,18 @@ class PromptServer(ExecutorToClientProgress):
|
||||
prompt = last_history_item['prompt'][2]
|
||||
return web.json_response(prompt, status=200)
|
||||
|
||||
async def _schedule_background_task_with_web_response(self, fut, task_id):
|
||||
task = asyncio.create_task(fut, name=task_id)
|
||||
self.background_tasks[task_id] = task
|
||||
task.add_done_callback(lambda _: self.background_tasks.pop(task_id))
|
||||
# todo: type this from the OpenAPI spec
|
||||
return web.json_response({
|
||||
"prompt_id": task_id
|
||||
}, status=202, headers={
|
||||
"Location": f"api/v1/prompts/{task_id}",
|
||||
"Retry-After": "60"
|
||||
})
|
||||
|
||||
@property
|
||||
def external_address(self):
|
||||
return self._external_address if self._external_address is not None else f"http://{'localhost' if self.address == '0.0.0.0' else self.address}:{self.port}"
|
||||
@ -875,6 +919,10 @@ class PromptServer(ExecutorToClientProgress):
|
||||
def external_address(self, value):
|
||||
self._external_address = value
|
||||
|
||||
@property
|
||||
def receive_all_progress_notifications(self) -> bool:
|
||||
return True
|
||||
|
||||
async def setup(self):
|
||||
timeout = aiohttp.ClientTimeout(total=None) # no timeout
|
||||
self.client_session = aiohttp.ClientSession(timeout=timeout)
|
||||
@ -989,11 +1037,11 @@ class PromptServer(ExecutorToClientProgress):
|
||||
for addr in addresses:
|
||||
address = addr[0]
|
||||
port = addr[1]
|
||||
site = web.TCPSite(runner, address, port)
|
||||
site = web.TCPSite(runner, address, port, backlog=PromptServer.get_too_busy_queue_size())
|
||||
await site.start()
|
||||
|
||||
if not hasattr(self, 'address'):
|
||||
self.address = address #TODO: remove this
|
||||
self.address = address # TODO: remove this
|
||||
self.port = port
|
||||
|
||||
if ':' in address:
|
||||
|
||||
@ -80,7 +80,14 @@ class ExecutorToClientProgress(Protocol):
|
||||
client_id: Optional[str]
|
||||
last_node_id: Optional[str]
|
||||
last_prompt_id: Optional[str]
|
||||
receive_all_progress_notifications: Optional[bool]
|
||||
|
||||
@property
|
||||
def receive_all_progress_notifications(self):
|
||||
"""
|
||||
Set to true if this should receive progress bar updates, in addition to the standard execution lifecycle messages
|
||||
:return:
|
||||
"""
|
||||
return False
|
||||
|
||||
def send_sync(self,
|
||||
event: SendSyncEvent,
|
||||
|
||||
@ -28,7 +28,7 @@ def _get_name(queue_name: str, user_id: str) -> str:
|
||||
|
||||
|
||||
class DistributedExecutorToClientProgress(ExecutorToClientProgress):
|
||||
def __init__(self, rpc: RPC, queue_name: str, loop: AbstractEventLoop, receive_all_progress_notifications=True):
|
||||
def __init__(self, rpc: RPC, queue_name: str, loop: AbstractEventLoop):
|
||||
self._rpc = rpc
|
||||
self._queue_name = queue_name
|
||||
self._loop = loop
|
||||
@ -37,7 +37,10 @@ class DistributedExecutorToClientProgress(ExecutorToClientProgress):
|
||||
self.node_id = None
|
||||
self.last_node_id = None
|
||||
self.last_prompt_id = None
|
||||
self.receive_all_progress_notifications = receive_all_progress_notifications
|
||||
|
||||
@property
|
||||
def receive_all_progress_notifications(self) -> bool:
|
||||
return True
|
||||
|
||||
async def send(self, event: SendSyncEvent, data: SendSyncData, user_id: Optional[str]) -> None:
|
||||
if event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE:
|
||||
|
||||
@ -119,6 +119,7 @@ class DistributedPromptWorker:
|
||||
logging.error(f"failed to connect to self._connection_uri={self._connection_uri}", connection_error)
|
||||
raise connection_error
|
||||
self._channel = await self._connection.channel()
|
||||
await self._channel.set_qos(prefetch_count=1)
|
||||
self._rpc = await JsonRPC.create(channel=self._channel, auto_delete=True, durable=False)
|
||||
|
||||
if self._embedded_comfy_client is None:
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from typing import Optional, OrderedDict, List, Dict
|
||||
import collections
|
||||
import typing
|
||||
from itertools import islice
|
||||
|
||||
from ..component_model.queue_types import HistoryEntry, QueueItem, ExecutionStatus, MAXIMUM_HISTORY_SIZE
|
||||
@ -10,22 +10,22 @@ from ..component_model.queue_types import HistoryEntry, QueueItem, ExecutionStat
|
||||
|
||||
class History:
|
||||
def __init__(self):
|
||||
self.history: OrderedDict[str, HistoryEntry] = collections.OrderedDict()
|
||||
self.history: typing.OrderedDict[str, HistoryEntry] = collections.OrderedDict()
|
||||
|
||||
def put(self, queue_item: QueueItem, outputs: dict, status: ExecutionStatus):
|
||||
self.history[queue_item.prompt_id] = HistoryEntry(prompt=queue_item.queue_tuple,
|
||||
outputs=outputs,
|
||||
status=ExecutionStatus(*status)._asdict())
|
||||
|
||||
def copy(self, prompt_id: Optional[str | int] = None, max_items: Optional[int] = None,
|
||||
offset: Optional[int] = None) -> Dict[str, HistoryEntry]:
|
||||
def copy(self, prompt_id: typing.Optional[str | int] = None, max_items: typing.Optional[int] = None,
|
||||
offset: typing.Optional[int] = None) -> dict[str, HistoryEntry]:
|
||||
if offset is not None and offset < 0:
|
||||
offset = max(len(self.history) + offset, 0)
|
||||
max_items = max_items or MAXIMUM_HISTORY_SIZE
|
||||
if prompt_id in self.history:
|
||||
return {prompt_id: copy.deepcopy(self.history[prompt_id])}
|
||||
else:
|
||||
ordered_dict = OrderedDict()
|
||||
ordered_dict = collections.OrderedDict()
|
||||
for k in islice(self.history, offset, max_items):
|
||||
ordered_dict[k] = copy.deepcopy(self.history[k])
|
||||
return ordered_dict
|
||||
|
||||
@ -16,7 +16,6 @@ class ServerStub(ExecutorToClientProgress):
|
||||
self.client_id = str(uuid.uuid4())
|
||||
self.last_node_id = None
|
||||
self.last_prompt_id = None
|
||||
self.receive_all_progress_notifications = False
|
||||
|
||||
def send_sync(self,
|
||||
event: Literal["status", "executing"] | BinaryEventTypes | str | None,
|
||||
@ -25,3 +24,7 @@ class ServerStub(ExecutorToClientProgress):
|
||||
|
||||
def queue_updated(self, queue_remaining: Optional[int] = None):
|
||||
pass
|
||||
|
||||
@property
|
||||
def receive_all_progress_notifications(self) -> bool:
|
||||
return False
|
||||
|
||||
@ -173,10 +173,15 @@ try:
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
class _ComfyOutOfMemoryException(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
try:
|
||||
OOM_EXCEPTION = torch.cuda.OutOfMemoryError
|
||||
except:
|
||||
OOM_EXCEPTION = Exception
|
||||
OOM_EXCEPTION = _ComfyOutOfMemoryException
|
||||
|
||||
XFORMERS_VERSION = ""
|
||||
XFORMERS_ENABLED_VAE = True
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
import pathlib
|
||||
@ -87,11 +86,7 @@ def has_gpu() -> bool:
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", autouse=False, params=["ThreadPoolExecutor", "ProcessPoolExecutor"])
|
||||
def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory) -> str:
|
||||
"""
|
||||
populates the cache with the sdxl checkpoints, starts a frontend and backend worker against a started rabbitmq, and yields the address of the frontend
|
||||
:return:
|
||||
"""
|
||||
def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory, num_workers: int = 1):
|
||||
from huggingface_hub import hf_hub_download
|
||||
hf_hub_download("stabilityai/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors")
|
||||
hf_hub_download("stabilityai/stable-diffusion-xl-refiner-1.0", "sd_xl_refiner_1.0.safetensors")
|
||||
@ -99,6 +94,7 @@ def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory) -> str:
|
||||
tmp_path = tmp_path_factory.mktemp("comfy_background_server")
|
||||
executor_factory = request.param
|
||||
processes_to_close: List[subprocess.Popen] = []
|
||||
|
||||
from testcontainers.rabbitmq import RabbitMqContainer
|
||||
with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
|
||||
params = rabbitmq.get_connection_params()
|
||||
@ -115,15 +111,18 @@ def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory) -> str:
|
||||
]
|
||||
|
||||
processes_to_close.append(subprocess.Popen(frontend_command, stdout=sys.stdout, stderr=sys.stderr))
|
||||
backend_command = [
|
||||
"comfyui-worker",
|
||||
"--port=9002",
|
||||
f"-w={str(tmp_path)}",
|
||||
f"--distributed-queue-connection-uri={connection_uri}",
|
||||
f"--executor-factory={executor_factory}"
|
||||
]
|
||||
|
||||
processes_to_close.append(subprocess.Popen(backend_command, stdout=sys.stdout, stderr=sys.stderr))
|
||||
# Start multiple workers
|
||||
for i in range(num_workers):
|
||||
backend_command = [
|
||||
"comfyui-worker",
|
||||
f"--port={9002 + i}",
|
||||
f"-w={str(tmp_path)}",
|
||||
f"--distributed-queue-connection-uri={connection_uri}",
|
||||
f"--executor-factory={executor_factory}"
|
||||
]
|
||||
processes_to_close.append(subprocess.Popen(backend_command, stdout=sys.stdout, stderr=sys.stderr))
|
||||
|
||||
try:
|
||||
server_address = f"http://{get_lan_ip()}:9001"
|
||||
start_time = time.time()
|
||||
@ -134,10 +133,8 @@ def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory) -> str:
|
||||
if response.status_code == 200:
|
||||
connected = True
|
||||
break
|
||||
except ConnectionRefusedError:
|
||||
except requests.exceptions.ConnectionError:
|
||||
pass
|
||||
except Exception as exc:
|
||||
logging.warning("", exc_info=exc)
|
||||
time.sleep(1)
|
||||
if not connected:
|
||||
raise RuntimeError("could not connect to frontend")
|
||||
|
||||
@ -139,3 +139,101 @@ async def test_basic_queue_worker_with_health_check(executor_factory):
|
||||
|
||||
health_check_ok = await check_health(health_check_url)
|
||||
assert health_check_ok, "Health check server did not start properly"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_queue_and_forget_prompt_api_integration(frontend_backend_worker_with_rabbitmq):
|
||||
# Create the client using the server address from the fixture
|
||||
client = AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq)
|
||||
|
||||
# Create a test prompt
|
||||
prompt = sdxl_workflow_with_refiner("test prompt", inference_steps=1, refiner_steps=1)
|
||||
|
||||
# Queue the prompt
|
||||
task_id = await client.queue_and_forget_prompt_api(prompt)
|
||||
|
||||
assert task_id is not None, "Failed to get a valid task ID"
|
||||
|
||||
# Poll for the result
|
||||
max_attempts = 60 # Increase max attempts for integration test
|
||||
poll_interval = 1 # Increase poll interval for integration test
|
||||
for _ in range(max_attempts):
|
||||
try:
|
||||
response = await client.session.get(f"{frontend_backend_worker_with_rabbitmq}/api/v1/prompts/{task_id}")
|
||||
if response.status == 200:
|
||||
result = await response.json()
|
||||
assert result is not None, "Received empty result"
|
||||
|
||||
# Find the first output node with images
|
||||
output_node = next((node for node in result.values() if 'images' in node), None)
|
||||
assert output_node is not None, "No output node with images found"
|
||||
|
||||
assert len(output_node['images']) > 0, "No images in output node"
|
||||
assert 'filename' in output_node['images'][0], "No filename in image output"
|
||||
assert 'subfolder' in output_node['images'][0], "No subfolder in image output"
|
||||
assert 'type' in output_node['images'][0], "No type in image output"
|
||||
|
||||
# Check if we can access the image
|
||||
image_url = f"{client.server_address}/view?filename={output_node['images'][0]['filename']}&type={output_node['images'][0]['type']}&subfolder={output_node['images'][0]['subfolder']}"
|
||||
image_response = await client.session.get(image_url)
|
||||
assert image_response.status == 200, f"Failed to retrieve image from {image_url}"
|
||||
|
||||
return # Test passed
|
||||
elif response.status == 204:
|
||||
await asyncio.sleep(poll_interval)
|
||||
else:
|
||||
response.raise_for_status()
|
||||
except Exception as e:
|
||||
print(f"Error while polling: {e}")
|
||||
await asyncio.sleep(poll_interval)
|
||||
|
||||
pytest.fail("Failed to get a 200 response with valid data within the timeout period")
|
||||
|
||||
|
||||
class TestWorker(DistributedPromptWorker):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.processed_workflows: set[str] = set()
|
||||
|
||||
async def on_will_complete_work_item(self, request: dict):
|
||||
workflow_id = request.get('prompt_id', 'unknown')
|
||||
self.processed_workflows.add(workflow_id)
|
||||
await super().on_will_complete_work_item(request)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_two_workers_distinct_requests():
|
||||
with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
|
||||
params = rabbitmq.get_connection_params()
|
||||
connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"
|
||||
|
||||
# Start two test workers
|
||||
workers: list[TestWorker] = []
|
||||
for i in range(2):
|
||||
worker = TestWorker(connection_uri=connection_uri, health_check_port=9090 + i, executor=ProcessPoolExecutor(max_workers=1))
|
||||
await worker.init()
|
||||
workers.append(worker)
|
||||
|
||||
from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
|
||||
queue = DistributedPromptQueue(is_callee=False, is_caller=True, connection_uri=connection_uri)
|
||||
await queue.init()
|
||||
|
||||
# Submit two prompts
|
||||
task1 = asyncio.create_task(queue.put_async(create_test_prompt()))
|
||||
task2 = asyncio.create_task(queue.put_async(create_test_prompt()))
|
||||
|
||||
# Wait for tasks to complete
|
||||
await asyncio.gather(task1, task2)
|
||||
|
||||
# Clean up
|
||||
for worker in workers:
|
||||
await worker.close()
|
||||
await queue.close()
|
||||
|
||||
# Assert that each worker processed exactly one distinct workflow
|
||||
all_workflows = set()
|
||||
for worker in workers:
|
||||
assert len(worker.processed_workflows) == 1, f"Worker processed {len(worker.processed_workflows)} workflows instead of 1"
|
||||
all_workflows.update(worker.processed_workflows)
|
||||
|
||||
assert len(all_workflows) == 2, f"Expected 2 distinct workflows, but got {len(all_workflows)}"
|
||||
|
||||
Loading…
Reference in New Issue
Block a user