mirror of https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-25 05:40:15 +08:00

API now supports fire-and-forget queueing and checking on queue status; prefetch_count is now expressly set to 1 for workers

This commit is contained in:
parent a664a1fbc9
commit d25394d386

@@ -331,9 +331,25 @@ paths:
       description: >-
         A POST request to /free with: {"unload_models":true} will unload models from vram.
         A POST request to /free with: {"free_memory":true} will unload models and free all cached data from the last run workflow.
+  /api/v1/prompts/{prompt_id}:
+    get:
+      summary: (API) Get prompt status
+      responses:
+        204:
+          description: |
+            The prompt is still in progress
+        200:
+          description: |
+            Prompt outputs
+          content:
+            application/json:
+              $ref: "#/components/schemas/Outputs"
+        404:
+          description: |
+            The prompt was not found
   /api/v1/prompts:
     get:
-      summary: (API) Get prompt
+      summary: (API) Get last prompt
       description: |
         Return the last prompt run anywhere that was used to produce an image
 
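The status endpoint added above maps naturally onto a poll loop: 204 means the prompt is still in progress, 200 carries the Outputs schema, and 404 means the prompt is unknown. A minimal client-side sketch, assuming aiohttp; the `server` base URL and function name are illustrative, not part of this commit:

```python
import asyncio

import aiohttp


async def poll_prompt_status(server: str, prompt_id: str, interval: float = 1.0) -> dict:
    # Poll GET /api/v1/prompts/{prompt_id} until the prompt settles.
    async with aiohttp.ClientSession() as session:
        while True:
            async with session.get(f"{server}/api/v1/prompts/{prompt_id}") as response:
                if response.status == 200:
                    return await response.json()  # the Outputs schema
                if response.status == 404:
                    raise KeyError(f"prompt {prompt_id} was not found")
                # 204: still in progress, wait and retry
            await asyncio.sleep(interval)
```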
@@ -395,56 +411,6 @@ paths:
 
             For each SaveImage node, there will be two URLs: the internal URL returned by the worker, and
             the URL for the image based on the `--external-address` / `COMFYUI_EXTERNAL_ADDRESS` configuration.
 
-            Hashing function for web browsers:
-
-            ```js
-            async function generateHash(body) {
-              // Stringify and sort keys in the JSON object
-              let str = JSON.stringify(body);
-
-              // Encode the string as a Uint8Array
-              let encoder = new TextEncoder();
-              let data = encoder.encode(str);
-
-              // Create a SHA-256 hash of the data
-              let hash = await window.crypto.subtle.digest('SHA-256', data);
-
-              // Convert the hash (which is an ArrayBuffer) to a hex string
-              let hashArray = Array.from(new Uint8Array(hash));
-              let hashHex = hashArray.map(b => b.toString(16).padStart(2, '0')).join('');
-
-              return hashHex;
-            }
-            ```
-
-            Hashing function for nodejs:
-
-            ```js
-            const crypto = require('crypto');
-
-            function generateHash(body) {
-              // Stringify and sort keys in the JSON object
-              let str = JSON.stringify(body);
-
-              // Create a SHA-256 hash of the string
-              let hash = crypto.createHash('sha256');
-              hash.update(str);
-
-              // Return the hexadecimal representation of the hash
-              return hash.digest('hex');
-            }
-            ```
-
-            Hashing function for python:
-            ```python
-            def digest(data: dict | str) -> str:
-                json_str = data if isinstance(data, str) else json.dumps(data, separators=(',', ':'))
-                json_bytes = json_str.encode('utf-8')
-                hash_object = hashlib.sha256(json_bytes)
-                return hash_object.hexdigest()
-
-            ```
       type: object
       required:
         - urls
@@ -475,6 +441,28 @@ paths:
                   "^\\d+$":
                     type: string
                     format: binary
+        202:
+          description: |
+            The prompt was successfully queued.
+          content:
+            application/json:
+              description: Information about the item that was queued
+              schema:
+                type: object
+                properties:
+                  prompt_id:
+                    type: string
+                    description: The ID of the prompt that was queued
+          headers:
+            Location:
+              description: The relative URL to check on the status of the request
+              schema:
+                type: string
+            Retry-After:
+              description: |
+                A hint for the number of seconds to check the provided Location for the status of your request.
+
+                This is the server's estimate for when the request will be completed.
         204:
           description: |
             The prompt was run but did not contain any SaveImage outputs, so nothing will be returned.
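A 202 from POST /api/v1/prompts carries everything a client needs to come back later: the prompt_id in the body, the Location to poll, and Retry-After as a wait hint. A hedged sketch of consuming that response, assuming aiohttp; the helper name is made up for the example:

```python
import asyncio

import aiohttp


async def handle_accepted(server: str, response: aiohttp.ClientResponse) -> str:
    # Returns the absolute URL to poll for the queued prompt's status.
    assert response.status == 202
    body = await response.json()
    location = response.headers.get("Location", f"api/v1/prompts/{body['prompt_id']}")
    retry_after = float(response.headers.get("Retry-After", "60"))
    await asyncio.sleep(retry_after)  # the server's completion estimate
    return f"{server}/{location.lstrip('/')}"
```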
@@ -517,11 +505,38 @@ paths:
             - "application/json"
             - "image/png"
             - "multipart/mixed"
+            - "application/json+respond-async"
+            - "image/png+respond-async"
+            - "multipart/mixed+respond-async"
           required: false
           description: |
             Specifies the media type the client is willing to receive.
 
             multipart/mixed will soon be supported to return all the images from the workflow.
 
+            If +respond-async is specified after your Accept mimetype, the request will be run async and you will get a 202 when the prompt was queued.
+        - in: header
+          name: Prefer
+          schema:
+            type: string
+            enum:
+              - "respond-async"
+              - ""
+          required: false
+          allowEmptyValue: true
+          description: |
+            When respond-async is in your Prefer header, the request will be run async and you will get a 202 when the prompt was queued.
+        - in: query
+          name: prefer
+          schema:
+            type: string
+            enum:
+              - "respond-async"
+              - ""
+          required: false
+          allowEmptyValue: true
+          description: |
+            When respond-async is in the prefer query parameter, the request will be run async and you will get a 202 when the prompt was queued.
       requestBody:
         content:
           application/json:
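The three opt-in mechanisms above are equivalent: a `+respond-async` suffix on the media type, a `Prefer: respond-async` header, or a `prefer=respond-async` query parameter. A sketch of all three against a local server (the URL and default port are assumptions), each of which should yield a 202 per the spec:

```python
import aiohttp

PROMPTS_URL = "http://localhost:8188/api/v1/prompts"  # assumption: default port


async def queue_async(session: aiohttp.ClientSession, prompt_json: str) -> None:
    base = {"Content-Type": "application/json"}
    # 1. media type suffix
    await session.post(PROMPTS_URL, data=prompt_json,
                       headers={**base, "Accept": "application/json+respond-async"})
    # 2. Prefer header
    await session.post(PROMPTS_URL, data=prompt_json,
                       headers={**base, "Accept": "application/json", "Prefer": "respond-async"})
    # 3. query parameter
    await session.post(f"{PROMPTS_URL}?prefer=respond-async", data=prompt_json,
                       headers={**base, "Accept": "application/json"})
```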
@@ -7,8 +7,7 @@ from typing import Optional, List
 from urllib.parse import urlparse, urljoin
 
 import aiohttp
-from aiohttp import WSMessage, ClientResponse
-from typing_extensions import Dict
+from aiohttp import WSMessage, ClientResponse, ClientTimeout
 
 from .client_types import V1QueuePromptResponse
 from ..api.api_client import JSONEncoder
@@ -33,15 +32,46 @@ class AsyncRemoteComfyClient:
         self.websocket_address = websocket_address if websocket_address is not None else urljoin(
             f"ws://{server_address_url.hostname}:{server_address_url.port}", f"/ws?clientId={client_id}")
         self.loop = loop or asyncio.get_event_loop()
+        self._session: aiohttp.ClientSession | None = None
+        try:
+            if asyncio.get_event_loop() is not None:
+                self._ensure_session()
+        except RuntimeError as no_running_event_loop:
+            pass
+
+    def _ensure_session(self) -> aiohttp.ClientSession:
+        if self._session is None:
+            self._session = aiohttp.ClientSession(timeout=ClientTimeout(total=10 * 60.0, connect=60.0))
+        return self._session
+
+    @property
+    def session(self) -> aiohttp.ClientSession:
+        return self._ensure_session()
 
     async def len_queue(self) -> int:
-        async with aiohttp.ClientSession() as session:
-            async with session.get(urljoin(self.server_address, "/prompt"), headers={'Accept': 'application.json'}) as response:
-                if response.status == 200:
-                    exec_info_dict = await response.json()
-                    return exec_info_dict["exec_info"]["queue_remaining"]
-                else:
-                    raise RuntimeError(f"unexpected response: {response.status}: {await response.text()}")
+        async with self.session.get(urljoin(self.server_address, "/prompt"), headers={'Accept': 'application/json'}) as response:
+            if response.status == 200:
+                exec_info_dict = await response.json()
+                return exec_info_dict["exec_info"]["queue_remaining"]
+            else:
+                raise RuntimeError(f"unexpected response: {response.status}: {await response.text()}")
+
+    async def queue_and_forget_prompt_api(self, prompt: PromptDict) -> str:
+        """
+        Calls the API to queue a prompt, and forgets about it
+        :param prompt:
+        :return: the task ID
+        """
+        prompt_json = AsyncRemoteComfyClient.__json_encoder.encode(prompt)
+        response: ClientResponse
+        async with self.session.post(urljoin(self.server_address, "/api/v1/prompts"), data=prompt_json,
+                                     headers={'Content-Type': 'application/json', 'Accept': 'application/json', 'Prefer': 'respond-async'}) as response:
+
+            if 200 <= response.status < 400:
+                response_json = await response.json()
+                return response_json["prompt_id"]
+            else:
+                raise RuntimeError(f"could not prompt: {response.status}: {await response.text()}")
 
     async def queue_prompt_api(self, prompt: PromptDict) -> V1QueuePromptResponse:
         """
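Putting the new client method together with the status route gives a full fire-and-forget round trip. A sketch under stated assumptions: the server address and the one-second polling cadence are illustrative, and the `prompt` argument is any PromptDict such as the one the integration test below builds:

```python
import asyncio


async def fire_and_forget(prompt) -> dict | None:
    client = AsyncRemoteComfyClient(server_address="http://localhost:8188")  # assumed address
    task_id = await client.queue_and_forget_prompt_api(prompt)  # returns immediately with the task ID
    while True:
        async with client.session.get(f"{client.server_address}/api/v1/prompts/{task_id}") as response:
            if response.status == 200:
                return await response.json()  # outputs are ready
            if response.status == 404:
                return None  # unknown prompt
        await asyncio.sleep(1.0)  # 204: still running
```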
@@ -50,15 +80,14 @@ class AsyncRemoteComfyClient:
         :return: the API response from the server containing URLs and the outputs for the UI (nodes with OUTPUT_NODE == true)
         """
         prompt_json = AsyncRemoteComfyClient.__json_encoder.encode(prompt)
-        async with aiohttp.ClientSession() as session:
-            response: ClientResponse
-            async with session.post(urljoin(self.server_address, "/api/v1/prompts"), data=prompt_json,
-                                    headers={'Content-Type': 'application/json', 'Accept': 'application/json'}) as response:
+        response: ClientResponse
+        async with self.session.post(urljoin(self.server_address, "/api/v1/prompts"), data=prompt_json,
+                                     headers={'Content-Type': 'application/json', 'Accept': 'application/json'}) as response:
 
-            if response.status == 200:
+            if 200 <= response.status < 400:
                 return V1QueuePromptResponse(**(await response.json()))
             else:
                 raise RuntimeError(f"could not prompt: {response.status}: {await response.text()}")
 
     async def queue_prompt_uris(self, prompt: PromptDict) -> List[str]:
         """
@@ -68,24 +97,24 @@ class AsyncRemoteComfyClient:
         """
         return (await self.queue_prompt_api(prompt)).urls
 
-    async def queue_prompt(self, prompt: PromptDict) -> bytes:
+    async def queue_prompt(self, prompt: PromptDict) -> bytes | None:
         """
         Calls the API to queue a prompt. Returns the bytes of the first PNG returned by a SaveImage node.
         :param prompt:
         :return:
         """
         prompt_json = AsyncRemoteComfyClient.__json_encoder.encode(prompt)
-        async with aiohttp.ClientSession() as session:
-            response: ClientResponse
-            async with session.post(urljoin(self.server_address, "/api/v1/prompts"), data=prompt_json,
-                                    headers={'Content-Type': 'application/json', 'Accept': 'image/png'}) as response:
+        response: ClientResponse
+        headers = {'Content-Type': 'application/json', 'Accept': 'image/png'}
+        async with self.session.post(urljoin(self.server_address, "/api/v1/prompts"), data=prompt_json,
+                                     headers=headers) as response:
 
             if 200 <= response.status < 400:
                 return await response.read()
             else:
                 raise RuntimeError(f"could not prompt: {response.status}: {await response.text()}")
 
-    async def queue_prompt_ui(self, prompt: PromptDict) -> Dict[str, List[Path]]:
+    async def queue_prompt_ui(self, prompt: PromptDict) -> dict[str, List[Path]]:
         """
         Uses the comfyui UI API calls to retrieve a list of paths of output files
         :param prompt:
@@ -93,46 +122,45 @@ class AsyncRemoteComfyClient:
         """
         prompt_request = PromptRequest.validate({"prompt": prompt, "client_id": self.client_id})
         prompt_request_json = AsyncRemoteComfyClient.__json_encoder.encode(prompt_request)
-        async with aiohttp.ClientSession() as session:
-            async with session.ws_connect(self.websocket_address) as ws:
-                async with session.post(urljoin(self.server_address, "/prompt"), data=prompt_request_json,
-                                        headers={'Content-Type': 'application/json'}) as response:
-                    if response.status == 200:
-                        prompt_id = (await response.json())["prompt_id"]
-                    else:
-                        raise RuntimeError("could not prompt")
-                msg: WSMessage
-                async for msg in ws:
-                    # Handle incoming messages
-                    if msg.type == aiohttp.WSMsgType.TEXT:
-                        msg_json = msg.json()
-                        if msg_json["type"] == "executing":
-                            data = msg_json["data"]
-                            if data['node'] is None and data['prompt_id'] == prompt_id:
-                                break
-                    elif msg.type == aiohttp.WSMsgType.CLOSED:
-                        break
-                    elif msg.type == aiohttp.WSMsgType.ERROR:
-                        break
-            async with session.get(urljoin(self.server_address, "/history")) as response:
-                if response.status == 200:
-                    history_json = immutabledict(GetHistoryDict.validate(await response.json()))
-                else:
-                    raise RuntimeError("Couldn't get history")
+        async with self.session.ws_connect(self.websocket_address) as ws:
+            async with self.session.post(urljoin(self.server_address, "/prompt"), data=prompt_request_json,
+                                         headers={'Content-Type': 'application/json'}) as response:
 
+                if response.status == 200:
+                    prompt_id = (await response.json())["prompt_id"]
+                else:
+                    raise RuntimeError("could not prompt")
+            msg: WSMessage
+            async for msg in ws:
+                # Handle incoming messages
+                if msg.type == aiohttp.WSMsgType.TEXT:
+                    msg_json = msg.json()
+                    if msg_json["type"] == "executing":
+                        data = msg_json["data"]
+                        if data['node'] is None and data['prompt_id'] == prompt_id:
+                            break
+                elif msg.type == aiohttp.WSMsgType.CLOSED:
+                    break
+                elif msg.type == aiohttp.WSMsgType.ERROR:
+                    break
+        async with self.session.get(urljoin(self.server_address, "/history")) as response:
+            if response.status == 200:
+                history_json = immutabledict(GetHistoryDict.validate(await response.json()))
+            else:
+                raise RuntimeError("Couldn't get history")
 
         # images have filename, subfolder, type keys
         # todo: use the OpenAPI spec for this when I get around to updating it
         outputs_by_node_id = history_json[prompt_id].outputs
-        res: Dict[str, List[Path]] = {}
+        res: dict[str, List[Path]] = {}
         for node_id, output in outputs_by_node_id.items():
             if 'images' in output:
                 images = []
                 image_dicts: List[dict] = output['images']
                 for image_file_output_dict in image_dicts:
                     image_file_output_dict = defaultdict(None, image_file_output_dict)
                     filename = image_file_output_dict['filename']
                     subfolder = image_file_output_dict['subfolder']
                     type = image_file_output_dict['type']
                     images.append(Path(file_output_path(filename, subfolder=subfolder, type=type)))
                 res[node_id] = images
         return res
@@ -167,7 +167,7 @@ class EmbeddedComfyClient:
     async def queue_prompt_api(self,
                                prompt: PromptDict) -> V1QueuePromptResponse:
         outputs = await self.queue_prompt(prompt)
-        return V1QueuePromptResponse(**outputs)
+        return V1QueuePromptResponse(urls=[], outputs=outputs)
 
     @tracer.start_as_current_span("Queue Prompt")
     async def queue_prompt(self,
@@ -13,7 +13,7 @@ import sys
 import traceback
 import urllib
 import uuid
-from asyncio import Future, AbstractEventLoop
+from asyncio import Future, AbstractEventLoop, Task
 from enum import Enum
 from io import BytesIO
 from posixpath import join as urljoin
@@ -190,7 +190,7 @@ class PromptServer(ExecutorToClientProgress):
         self.number: int = 0
         self.port: int = 8188
         self._external_address: Optional[str] = None
-        self.receive_all_progress_notifications = True
+        self.background_tasks: dict[str, Task] = dict()
 
         middlewares = [cache_control]
         if args.enable_cors_header:
@@ -726,11 +726,35 @@ class PromptServer(ExecutorToClientProgress):
 
             return web.json_response(task.result().to_dict())
 
+        @routes.get("/api/v1/prompts/{prompt_id}")
+        async def get_api_prompt(request: web.Request) -> web.Response | web.FileResponse:
+            prompt_id: str = request.match_info.get("prompt_id", "")
+            if prompt_id == "":
+                return web.Response(status=404)
+
+            history_items = self.prompt_queue.get_history(prompt_id)
+            if len(history_items) == 0:
+                # todo: this should really be moved to a stateful queue abstraction
+                if prompt_id in self.background_tasks:
+                    return web.Response(status=204)
+                else:
+                    # todo: this should check a stateful queue abstraction
+                    return web.Response(status=404)
+            else:
+                history_entry = history_items[prompt_id]
+                return web.json_response(history_entry["outputs"])
+
         @routes.post("/api/v1/prompts")
         async def post_api_prompt(request: web.Request) -> web.Response | web.FileResponse:
-            # check if the queue is too long
             accept = request.headers.get("accept", "application/json")
             content_type = request.headers.get("content-type", "application/json")
+            preferences = request.headers.get("prefer", "") + request.query.get("prefer", "") + " " + content_type
+            if "+" in content_type:
+                content_type = content_type.split("+")[0]
+
+            wait = not "respond-async" in preferences
+
+            # check if the queue is too long
             queue_size = self.prompt_queue.size()
             queue_too_busy_size = PromptServer.get_too_busy_queue_size()
             if queue_size > queue_too_busy_size:
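The `preferences` string above is a deliberately loose match: the Prefer header, the prefer query parameter, and the raw content type are concatenated, and the handler only asks whether "respond-async" occurs anywhere in the result (note the suffix is detected on the request's Content-Type here, while the spec documents it on Accept). The same decision reproduced in isolation; the function name is local to this sketch:

```python
def wants_async(prefer_header: str, prefer_query: str, content_type: str) -> bool:
    # Mirrors the handler: any source mentioning respond-async switches off waiting.
    preferences = prefer_header + prefer_query + " " + content_type
    return "respond-async" in preferences


assert wants_async("respond-async", "", "application/json")    # Prefer header
assert wants_async("", "respond-async", "application/json")    # query parameter
assert wants_async("", "", "application/json+respond-async")   # media type suffix
assert not wants_async("", "", "application/json")             # default: wait
```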
@@ -778,17 +802,25 @@ class PromptServer(ExecutorToClientProgress):
 
             result: TaskInvocation
             completed: Future[TaskInvocation | dict] = self.loop.create_future()
-            item = QueueItem(queue_tuple=(number, str(uuid.uuid4()), prompt_dict, {}, valid[2]), completed=completed)
+            task_id = str(uuid.uuid4())
+            item = QueueItem(queue_tuple=(number, task_id, prompt_dict, {}, valid[2]), completed=completed)
+
             try:
                 if hasattr(self.prompt_queue, "put_async") or isinstance(self.prompt_queue, AsyncAbstractPromptQueue):
                     # this enables span propagation seamlessly
-                    result = await self.prompt_queue.put_async(item)
-                    if result is None:
-                        return web.Response(body="the queue is shutting down", status=503)
+                    fut = self.prompt_queue.put_async(item)
+                    if wait:
+                        result = await fut
+                        if result is None:
+                            return web.Response(body="the queue is shutting down", status=503)
+                    else:
+                        return await self._schedule_background_task_with_web_response(fut, task_id)
                 else:
                     self.prompt_queue.put(item)
-                    await completed
+                    if wait:
+                        await completed
+                    else:
+                        return await self._schedule_background_task_with_web_response(completed, task_id)
                     task_invocation_or_dict: TaskInvocation | dict = completed.result()
                     if isinstance(task_invocation_or_dict, dict):
                         result = TaskInvocation(item_id=item.prompt_id, outputs=task_invocation_or_dict, status=ExecutionStatus("success", True, []))
@@ -867,6 +899,18 @@ class PromptServer(ExecutorToClientProgress):
             prompt = last_history_item['prompt'][2]
             return web.json_response(prompt, status=200)
 
+    async def _schedule_background_task_with_web_response(self, fut, task_id):
+        task = asyncio.create_task(fut, name=task_id)
+        self.background_tasks[task_id] = task
+        task.add_done_callback(lambda _: self.background_tasks.pop(task_id))
+        # todo: type this from the OpenAPI spec
+        return web.json_response({
+            "prompt_id": task_id
+        }, status=202, headers={
+            "Location": f"api/v1/prompts/{task_id}",
+            "Retry-After": "60"
+        })
+
     @property
     def external_address(self):
         return self._external_address if self._external_address is not None else f"http://{'localhost' if self.address == '0.0.0.0' else self.address}:{self.port}"
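`_schedule_background_task_with_web_response` is the bookkeeping half of fire-and-forget: the named task keeps the queue future alive, the `background_tasks` map lets the GET route answer 204 for in-flight prompts, and the done callback evicts finished entries. The same pattern in isolation, as a sketch rather than the server's code:

```python
import asyncio

background_tasks: dict[str, asyncio.Task] = {}


def schedule(coro, task_id: str) -> asyncio.Task:
    task = asyncio.create_task(coro, name=task_id)
    background_tasks[task_id] = task
    # Evict on completion so "unknown id" (404) and "still running" (204) stay distinguishable.
    task.add_done_callback(lambda _: background_tasks.pop(task_id, None))
    return task
```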
@@ -875,6 +919,10 @@ class PromptServer(ExecutorToClientProgress):
     def external_address(self, value):
         self._external_address = value
 
+    @property
+    def receive_all_progress_notifications(self) -> bool:
+        return True
+
     async def setup(self):
         timeout = aiohttp.ClientTimeout(total=None)  # no timeout
         self.client_session = aiohttp.ClientSession(timeout=timeout)
@@ -989,11 +1037,11 @@ class PromptServer(ExecutorToClientProgress):
         for addr in addresses:
             address = addr[0]
             port = addr[1]
-            site = web.TCPSite(runner, address, port)
+            site = web.TCPSite(runner, address, port, backlog=PromptServer.get_too_busy_queue_size())
             await site.start()
 
             if not hasattr(self, 'address'):
-                self.address = address #TODO: remove this
+                self.address = address  # TODO: remove this
                 self.port = port
 
             if ':' in address:
@@ -80,7 +80,14 @@ class ExecutorToClientProgress(Protocol):
     client_id: Optional[str]
     last_node_id: Optional[str]
     last_prompt_id: Optional[str]
-    receive_all_progress_notifications: Optional[bool]
+
+    @property
+    def receive_all_progress_notifications(self):
+        """
+        Set to true if this should receive progress bar updates, in addition to the standard execution lifecycle messages
+        :return:
+        """
+        return False
 
     def send_sync(self,
                   event: SendSyncEvent,
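Turning the bare `receive_all_progress_notifications` attribute into a property gives the protocol a default of False that concrete classes override, as `PromptServer` (True), `DistributedExecutorToClientProgress` (True), and `ServerStub` (False) do elsewhere in this diff. A condensed sketch of the shape; both class names here are illustrative stand-ins:

```python
from typing import Protocol


class ProgressProtocol(Protocol):  # stands in for ExecutorToClientProgress
    @property
    def receive_all_progress_notifications(self) -> bool:
        return False


class ChattyServer:  # hypothetical implementer that wants progress bar updates
    @property
    def receive_all_progress_notifications(self) -> bool:
        return True
```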
@@ -28,7 +28,7 @@ def _get_name(queue_name: str, user_id: str) -> str:
 
 
 class DistributedExecutorToClientProgress(ExecutorToClientProgress):
-    def __init__(self, rpc: RPC, queue_name: str, loop: AbstractEventLoop, receive_all_progress_notifications=True):
+    def __init__(self, rpc: RPC, queue_name: str, loop: AbstractEventLoop):
         self._rpc = rpc
         self._queue_name = queue_name
         self._loop = loop
@@ -37,7 +37,10 @@ class DistributedExecutorToClientProgress(ExecutorToClientProgress):
         self.node_id = None
         self.last_node_id = None
         self.last_prompt_id = None
-        self.receive_all_progress_notifications = receive_all_progress_notifications
+
+    @property
+    def receive_all_progress_notifications(self) -> bool:
+        return True
 
     async def send(self, event: SendSyncEvent, data: SendSyncData, user_id: Optional[str]) -> None:
         if event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE:
@@ -119,6 +119,7 @@ class DistributedPromptWorker:
             logging.error(f"failed to connect to self._connection_uri={self._connection_uri}", connection_error)
             raise connection_error
         self._channel = await self._connection.channel()
+        await self._channel.set_qos(prefetch_count=1)
         self._rpc = await JsonRPC.create(channel=self._channel, auto_delete=True, durable=False)
 
         if self._embedded_comfy_client is None:
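`set_qos(prefetch_count=1)` is what makes multiple workers share a queue fairly: the broker delivers at most one unacknowledged message to each consumer, so a worker busy with a long render does not accumulate prompts that an idle worker could take (the behavior `test_two_workers_distinct_requests` below asserts). A minimal aio_pika sketch of the same setup; the connection URI is illustrative:

```python
import asyncio

import aio_pika


async def main() -> None:
    connection = await aio_pika.connect_robust("amqp://guest:guest@127.0.0.1/")
    channel = await connection.channel()
    # At most one unacked delivery per consumer; the broker round-robins the rest.
    await channel.set_qos(prefetch_count=1)


asyncio.run(main())
```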
@@ -1,8 +1,8 @@
 from __future__ import annotations
 
 import copy
-from typing import Optional, OrderedDict, List, Dict
 import collections
+import typing
 from itertools import islice
 
 from ..component_model.queue_types import HistoryEntry, QueueItem, ExecutionStatus, MAXIMUM_HISTORY_SIZE
@@ -10,22 +10,22 @@ from ..component_model.queue_types import HistoryEntry, QueueItem, ExecutionStat
 
 class History:
     def __init__(self):
-        self.history: OrderedDict[str, HistoryEntry] = collections.OrderedDict()
+        self.history: typing.OrderedDict[str, HistoryEntry] = collections.OrderedDict()
 
     def put(self, queue_item: QueueItem, outputs: dict, status: ExecutionStatus):
         self.history[queue_item.prompt_id] = HistoryEntry(prompt=queue_item.queue_tuple,
                                                           outputs=outputs,
                                                           status=ExecutionStatus(*status)._asdict())
 
-    def copy(self, prompt_id: Optional[str | int] = None, max_items: Optional[int] = None,
-             offset: Optional[int] = None) -> Dict[str, HistoryEntry]:
+    def copy(self, prompt_id: typing.Optional[str | int] = None, max_items: typing.Optional[int] = None,
+             offset: typing.Optional[int] = None) -> dict[str, HistoryEntry]:
         if offset is not None and offset < 0:
             offset = max(len(self.history) + offset, 0)
         max_items = max_items or MAXIMUM_HISTORY_SIZE
         if prompt_id in self.history:
             return {prompt_id: copy.deepcopy(self.history[prompt_id])}
         else:
-            ordered_dict = OrderedDict()
+            ordered_dict = collections.OrderedDict()
             for k in islice(self.history, offset, max_items):
                 ordered_dict[k] = copy.deepcopy(self.history[k])
             return ordered_dict
@@ -16,7 +16,6 @@ class ServerStub(ExecutorToClientProgress):
         self.client_id = str(uuid.uuid4())
         self.last_node_id = None
         self.last_prompt_id = None
-        self.receive_all_progress_notifications = False
 
     def send_sync(self,
                   event: Literal["status", "executing"] | BinaryEventTypes | str | None,
@@ -25,3 +24,7 @@
 
     def queue_updated(self, queue_remaining: Optional[int] = None):
         pass
+
+    @property
+    def receive_all_progress_notifications(self) -> bool:
+        return False
@@ -173,10 +173,15 @@ try:
 except:
     pass
 
 
+class _ComfyOutOfMemoryException(RuntimeError):
+    pass
+
+
 try:
     OOM_EXCEPTION = torch.cuda.OutOfMemoryError
 except:
-    OOM_EXCEPTION = Exception
+    OOM_EXCEPTION = _ComfyOutOfMemoryException
 
 XFORMERS_VERSION = ""
 XFORMERS_ENABLED_VAE = True
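The fallback sentinel matters because `OOM_EXCEPTION` is used in `except` clauses: with the old `OOM_EXCEPTION = Exception`, any error on a build without `torch.cuda.OutOfMemoryError` was treated as out-of-memory, while the private `RuntimeError` subclass is never raised, so those handlers simply never fire. An illustrative before/after:

```python
class _ComfyOutOfMemoryException(RuntimeError):
    pass


OOM_EXCEPTION = _ComfyOutOfMemoryException  # fallback when torch.cuda is unavailable

try:
    raise ValueError("some unrelated failure")
except OOM_EXCEPTION:
    print("misclassified as OOM")  # reached under the old Exception fallback
except ValueError:
    print("propagates as a normal error")  # reached with the sentinel
```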
@@ -1,4 +1,3 @@
-import logging
 import multiprocessing
 import os
 import pathlib
@@ -87,11 +86,7 @@ def has_gpu() -> bool:
 
 
 @pytest.fixture(scope="module", autouse=False, params=["ThreadPoolExecutor", "ProcessPoolExecutor"])
-def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory) -> str:
-    """
-    populates the cache with the sdxl checkpoints, starts a frontend and backend worker against a started rabbitmq, and yields the address of the frontend
-    :return:
-    """
+def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory, num_workers: int = 1):
     from huggingface_hub import hf_hub_download
     hf_hub_download("stabilityai/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors")
     hf_hub_download("stabilityai/stable-diffusion-xl-refiner-1.0", "sd_xl_refiner_1.0.safetensors")
@@ -99,6 +94,7 @@ def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory) -> str:
     tmp_path = tmp_path_factory.mktemp("comfy_background_server")
     executor_factory = request.param
     processes_to_close: List[subprocess.Popen] = []
+
     from testcontainers.rabbitmq import RabbitMqContainer
     with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
         params = rabbitmq.get_connection_params()
@@ -115,15 +111,18 @@ def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory) -> str:
         ]
 
         processes_to_close.append(subprocess.Popen(frontend_command, stdout=sys.stdout, stderr=sys.stderr))
-        backend_command = [
-            "comfyui-worker",
-            "--port=9002",
-            f"-w={str(tmp_path)}",
-            f"--distributed-queue-connection-uri={connection_uri}",
-            f"--executor-factory={executor_factory}"
-        ]
 
-        processes_to_close.append(subprocess.Popen(backend_command, stdout=sys.stdout, stderr=sys.stderr))
+        # Start multiple workers
+        for i in range(num_workers):
+            backend_command = [
+                "comfyui-worker",
+                f"--port={9002 + i}",
+                f"-w={str(tmp_path)}",
+                f"--distributed-queue-connection-uri={connection_uri}",
+                f"--executor-factory={executor_factory}"
+            ]
+            processes_to_close.append(subprocess.Popen(backend_command, stdout=sys.stdout, stderr=sys.stderr))
 
         try:
             server_address = f"http://{get_lan_ip()}:9001"
             start_time = time.time()
@@ -134,10 +133,8 @@ def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory) -> str:
                 if response.status_code == 200:
                     connected = True
                     break
-            except ConnectionRefusedError:
+            except requests.exceptions.ConnectionError:
                 pass
-            except Exception as exc:
-                logging.warning("", exc_info=exc)
             time.sleep(1)
         if not connected:
             raise RuntimeError("could not connect to frontend")
@@ -139,3 +139,101 @@ async def test_basic_queue_worker_with_health_check(executor_factory):
 
     health_check_ok = await check_health(health_check_url)
     assert health_check_ok, "Health check server did not start properly"
+
+
+@pytest.mark.asyncio
+async def test_queue_and_forget_prompt_api_integration(frontend_backend_worker_with_rabbitmq):
+    # Create the client using the server address from the fixture
+    client = AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq)
+
+    # Create a test prompt
+    prompt = sdxl_workflow_with_refiner("test prompt", inference_steps=1, refiner_steps=1)
+
+    # Queue the prompt
+    task_id = await client.queue_and_forget_prompt_api(prompt)
+
+    assert task_id is not None, "Failed to get a valid task ID"
+
+    # Poll for the result
+    max_attempts = 60  # Increase max attempts for integration test
+    poll_interval = 1  # Increase poll interval for integration test
+    for _ in range(max_attempts):
+        try:
+            response = await client.session.get(f"{frontend_backend_worker_with_rabbitmq}/api/v1/prompts/{task_id}")
+            if response.status == 200:
+                result = await response.json()
+                assert result is not None, "Received empty result"
+
+                # Find the first output node with images
+                output_node = next((node for node in result.values() if 'images' in node), None)
+                assert output_node is not None, "No output node with images found"
+
+                assert len(output_node['images']) > 0, "No images in output node"
+                assert 'filename' in output_node['images'][0], "No filename in image output"
+                assert 'subfolder' in output_node['images'][0], "No subfolder in image output"
+                assert 'type' in output_node['images'][0], "No type in image output"
+
+                # Check if we can access the image
+                image_url = f"{client.server_address}/view?filename={output_node['images'][0]['filename']}&type={output_node['images'][0]['type']}&subfolder={output_node['images'][0]['subfolder']}"
+                image_response = await client.session.get(image_url)
+                assert image_response.status == 200, f"Failed to retrieve image from {image_url}"
+
+                return  # Test passed
+            elif response.status == 204:
+                await asyncio.sleep(poll_interval)
+            else:
+                response.raise_for_status()
+        except Exception as e:
+            print(f"Error while polling: {e}")
+            await asyncio.sleep(poll_interval)
+
+    pytest.fail("Failed to get a 200 response with valid data within the timeout period")
+
+
+class TestWorker(DistributedPromptWorker):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.processed_workflows: set[str] = set()
+
+    async def on_will_complete_work_item(self, request: dict):
+        workflow_id = request.get('prompt_id', 'unknown')
+        self.processed_workflows.add(workflow_id)
+        await super().on_will_complete_work_item(request)
+
+
+@pytest.mark.asyncio
+async def test_two_workers_distinct_requests():
+    with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
+        params = rabbitmq.get_connection_params()
+        connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"
+
+        # Start two test workers
+        workers: list[TestWorker] = []
+        for i in range(2):
+            worker = TestWorker(connection_uri=connection_uri, health_check_port=9090 + i, executor=ProcessPoolExecutor(max_workers=1))
+            await worker.init()
+            workers.append(worker)
+
+        from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
+        queue = DistributedPromptQueue(is_callee=False, is_caller=True, connection_uri=connection_uri)
+        await queue.init()
+
+        # Submit two prompts
+        task1 = asyncio.create_task(queue.put_async(create_test_prompt()))
+        task2 = asyncio.create_task(queue.put_async(create_test_prompt()))
+
+        # Wait for tasks to complete
+        await asyncio.gather(task1, task2)
+
+        # Clean up
+        for worker in workers:
+            await worker.close()
+        await queue.close()
+
+        # Assert that each worker processed exactly one distinct workflow
+        all_workflows = set()
+        for worker in workers:
+            assert len(worker.processed_workflows) == 1, f"Worker processed {len(worker.processed_workflows)} workflows instead of 1"
+            all_workflows.update(worker.processed_workflows)
+
+        assert len(all_workflows) == 2, f"Expected 2 distinct workflows, but got {len(all_workflows)}"