WIP: align OpenAPI spec and API methods for error handling

This commit is contained in:
Benjamin Berman 2025-11-06 11:11:46 -08:00
parent 152524e8b1
commit be255c2691
5 changed files with 366 additions and 24 deletions

View File

@ -354,19 +354,29 @@ paths:
required: true required: true
description: | description: |
The ID of the prompt to query. The ID of the prompt to query.
responses: responses:
204: 204:
description: | description: |
The prompt is still in progress The prompt is still in progress
200: 200:
description: | description: |
Prompt outputs Prompt outputs
content: content:
application/json: application/json:
$ref: "#/components/schemas/Outputs" schema:
404: $ref: "#/components/schemas/Outputs"
description: | 404:
The prompt was not found description: |
The prompt was not found
500:
description: |
An execution error occurred while processing the prompt.
content:
application/json:
description:
An execution status directly from the workers
schema:
$ref: "#/components/schemas/ExecutionStatusAsDict"
/api/v1/prompts: /api/v1/prompts:
get: get:
operationId: list_prompts operationId: list_prompts

View File

@ -65,16 +65,21 @@ class AsyncRemoteComfyClient:
else: else:
raise RuntimeError(f"unexpected response: {response.status}: {await response.text()}") raise RuntimeError(f"unexpected response: {response.status}: {await response.text()}")
async def queue_and_forget_prompt_api(self, prompt: PromptDict) -> str: async def queue_and_forget_prompt_api(self, prompt: PromptDict, prefer_header: Optional[str] = "respond-async", accept_header: str = "application/json") -> str:
""" """
Calls the API to queue a prompt, and forgets about it Calls the API to queue a prompt, and forgets about it
:param prompt: :param prompt:
:param prefer_header: The Prefer header value (e.g., "respond-async" or None)
:param accept_header: The Accept header value (e.g., "application/json", "application/json+respond-async")
:return: the task ID :return: the task ID
""" """
prompt_json = AsyncRemoteComfyClient.__json_encoder.encode(prompt) prompt_json = AsyncRemoteComfyClient.__json_encoder.encode(prompt)
response: ClientResponse response: ClientResponse
headers = {'Content-Type': 'application/json', 'Accept': accept_header}
if prefer_header:
headers['Prefer'] = prefer_header
async with self.session.post(urljoin(self.server_address, "/api/v1/prompts"), data=prompt_json, async with self.session.post(urljoin(self.server_address, "/api/v1/prompts"), data=prompt_json,
headers={'Content-Type': 'application/json', 'Accept': 'application/json', 'Prefer': 'respond-async'}) as response: headers=headers) as response:
if 200 <= response.status < 400: if 200 <= response.status < 400:
response_json = await response.json() response_json = await response.json()
@ -82,16 +87,21 @@ class AsyncRemoteComfyClient:
else: else:
raise RuntimeError(f"could not prompt: {response.status}: {await response.text()}") raise RuntimeError(f"could not prompt: {response.status}: {await response.text()}")
async def queue_prompt_api(self, prompt: PromptDict) -> V1QueuePromptResponse: async def queue_prompt_api(self, prompt: PromptDict, prefer_header: Optional[str] = None, accept_header: str = "application/json") -> V1QueuePromptResponse:
""" """
Calls the API to queue a prompt. Calls the API to queue a prompt.
:param prompt: :param prompt:
:param prefer_header: The Prefer header value (e.g., "respond-async" or None)
:param accept_header: The Accept header value (e.g., "application/json", "application/json+respond-async")
:return: the API response from the server containing URLs and the outputs for the UI (nodes with OUTPUT_NODE == true) :return: the API response from the server containing URLs and the outputs for the UI (nodes with OUTPUT_NODE == true)
""" """
prompt_json = AsyncRemoteComfyClient.__json_encoder.encode(prompt) prompt_json = AsyncRemoteComfyClient.__json_encoder.encode(prompt)
response: ClientResponse response: ClientResponse
headers = {'Content-Type': 'application/json', 'Accept': accept_header}
if prefer_header:
headers['Prefer'] = prefer_header
async with self.session.post(urljoin(self.server_address, "/api/v1/prompts"), data=prompt_json, async with self.session.post(urljoin(self.server_address, "/api/v1/prompts"), data=prompt_json,
headers={'Content-Type': 'application/json', 'Accept': 'application/json'}) as response: headers=headers) as response:
if 200 <= response.status < 400: if 200 <= response.status < 400:
return V1QueuePromptResponse(**(await response.json())) return V1QueuePromptResponse(**(await response.json()))
@ -160,3 +170,36 @@ class AsyncRemoteComfyClient:
# images have filename, subfolder, type keys # images have filename, subfolder, type keys
# todo: use the OpenAPI spec for this when I get around to updating it # todo: use the OpenAPI spec for this when I get around to updating it
return history_json[prompt_id].outputs return history_json[prompt_id].outputs
async def get_prompt_status(self, prompt_id: str) -> ClientResponse:
"""
Get the status of a prompt by ID using the API endpoint.
:param prompt_id: The prompt ID to query
:return: The ClientResponse object (caller should check status and read body)
"""
# NOTE: the returned response is NOT entered as a context manager here; callers
# must use it in an `async with` (as poll_prompt_until_done does) to release the connection.
return await self.session.get(urljoin(self.server_address, f"/api/v1/prompts/{prompt_id}"))
async def poll_prompt_until_done(self, prompt_id: str, max_attempts: int = 60, poll_interval: float = 1.0) -> tuple[int, dict | None]:
    """
    Poll a prompt until it's done (200), errors (500), or times out.

    :param prompt_id: The prompt ID to poll
    :param max_attempts: Maximum number of polling attempts
    :param poll_interval: Time to wait between polls in seconds
    :return: Tuple of (status_code, response_json or None). A 408 is synthesized
             client-side on timeout; it is not a status returned by the server.
    """
    for _ in range(max_attempts):
        async with await self.get_prompt_status(prompt_id) as response:
            # 200 (outputs) and 500 (execution error) are terminal states that
            # both carry a JSON body, so handle them in a single branch.
            if response.status in (200, 500):
                return response.status, await response.json()
            elif response.status == 404:
                # Prompt unknown to the server
                return response.status, None
            elif response.status == 204:
                # Still in progress; wait before the next poll
                await asyncio.sleep(poll_interval)
            else:
                # Unexpected status; surface it without a body
                return response.status, None
    # Timed out waiting for a terminal state
    return 408, None

View File

@ -116,6 +116,24 @@ async def compress_body(request: web.Request, handler):
return response return response
@web.middleware
async def opentelemetry_middleware(request: web.Request, handler):
"""Middleware to extract and propagate OpenTelemetry context from request headers (e.g. traceparent/tracestate)."""
# Imported lazily so the middleware is only a hard dependency when actually invoked.
from opentelemetry import propagate, context
# Extract OpenTelemetry context from headers
carrier = dict(request.headers)
ctx = propagate.extract(carrier)
# Attach context and execute handler
token = context.attach(ctx)
try:
response = await handler(request)
return response
finally:
# Always detach, even if the handler raised, to avoid leaking context across requests.
context.detach(token)
def create_cors_middleware(allowed_origin: str): def create_cors_middleware(allowed_origin: str):
@web.middleware @web.middleware
async def cors_middleware(request: web.Request, handler): async def cors_middleware(request: web.Request, handler):
@ -127,7 +145,7 @@ def create_cors_middleware(allowed_origin: str):
response.headers['Access-Control-Allow-Origin'] = allowed_origin response.headers['Access-Control-Allow-Origin'] = allowed_origin
response.headers['Access-Control-Allow-Methods'] = 'POST, GET, DELETE, PUT, OPTIONS' response.headers['Access-Control-Allow-Methods'] = 'POST, GET, DELETE, PUT, OPTIONS'
response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization' response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization, traceparent, tracestate'
response.headers['Access-Control-Allow-Credentials'] = 'true' response.headers['Access-Control-Allow-Credentials'] = 'true'
return response return response
@ -224,7 +242,7 @@ class PromptServer(ExecutorToClientProgress):
self._external_address: Optional[str] = None self._external_address: Optional[str] = None
self.background_tasks: dict[str, Task] = dict() self.background_tasks: dict[str, Task] = dict()
middlewares = [cache_control, deprecation_warning] middlewares = [opentelemetry_middleware, cache_control, deprecation_warning]
if args.enable_compress_response_body: if args.enable_compress_response_body:
middlewares.append(compress_body) middlewares.append(compress_body)
@ -867,9 +885,19 @@ class PromptServer(ExecutorToClientProgress):
return web.json_response(status=404) return web.json_response(status=404)
elif prompt_id in history_items: elif prompt_id in history_items:
history_entry = history_items[prompt_id] history_entry = history_items[prompt_id]
# Check if execution resulted in an error
if "status" in history_entry:
status = history_entry["status"]
if isinstance(status, dict) and status.get("status_str") == "error":
# Return ExecutionStatusAsDict format with status 500, matching POST /api/v1/prompts behavior
return web.Response(
body=json.dumps(status),
status=500,
content_type="application/json"
)
return web.json_response(history_entry["outputs"]) return web.json_response(history_entry["outputs"])
else: else:
return web.json_response(status=500) return web.Response(status=404, reason="prompt not found in expected state")
@routes.post("/api/v1/prompts") @routes.post("/api/v1/prompts")
async def post_api_prompt(request: web.Request) -> web.Response | web.FileResponse: async def post_api_prompt(request: web.Request) -> web.Response | web.FileResponse:

View File

@ -3,8 +3,6 @@ import multiprocessing
import os import os
import pathlib import pathlib
import subprocess import subprocess
import sys
import time
import urllib import urllib
from contextvars import ContextVar from contextvars import ContextVar
from multiprocessing import Process from multiprocessing import Process
@ -12,9 +10,9 @@ from typing import List, Any, Generator
import pytest import pytest
import requests import requests
import sys
import time
from comfy.cli_args import default_configuration
from comfy.execution_context import context_configuration
os.environ['OTEL_METRICS_EXPORTER'] = 'none' os.environ['OTEL_METRICS_EXPORTER'] = 'none'
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
@ -22,6 +20,7 @@ os.environ["HF_XET_HIGH_PERFORMANCE"] = "True"
# fixes issues with running the testcontainers rabbitmqcontainer on Windows # fixes issues with running the testcontainers rabbitmqcontainer on Windows
os.environ["TC_HOST"] = "localhost" os.environ["TC_HOST"] = "localhost"
from comfy.cli_args import default_configuration
from comfy.cli_args_types import Configuration from comfy.cli_args_types import Configuration
logging.getLogger("pika").setLevel(logging.CRITICAL + 1) logging.getLogger("pika").setLevel(logging.CRITICAL + 1)

View File

@ -243,3 +243,265 @@ async def test_two_workers_distinct_requests():
all_workflows.update(worker.processed_workflows) all_workflows.update(worker.processed_workflows)
assert len(all_workflows) == 2, f"Expected 2 distinct workflows, but got {len(all_workflows)}" assert len(all_workflows) == 2, f"Expected 2 distinct workflows, but got {len(all_workflows)}"
# ============================================================================
# API Error Reporting Tests
# ============================================================================
@pytest.mark.asyncio
async def test_api_error_reporting_blocking_request(frontend_backend_worker_with_rabbitmq):
"""Test error reporting with blocking request (no async preference)"""
async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
# Create an invalid prompt that will cause a validation error
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
# Make the prompt invalid by referencing a non-existent checkpoint
prompt["4"]["inputs"]["ckpt_name"] = "nonexistent_checkpoint.safetensors"
# Post with blocking behavior (no prefer header for async)
# NOTE(review): name-mangled access to the client's private encoder; consider a public encode helper.
prompt_json = client._AsyncRemoteComfyClient__json_encoder.encode(prompt)
async with client.session.post(
f"{frontend_backend_worker_with_rabbitmq}/api/v1/prompts",
data=prompt_json,
headers={'Content-Type': 'application/json', 'Accept': 'application/json'}
) as response:
# Should return 400 for validation error (invalid checkpoint)
assert response.status == 400, f"Expected 400, got {response.status}"
error_body = await response.json()
# Verify ValidationErrorDict structure per OpenAPI spec
assert "type" in error_body, "Missing 'type' field in error response"
assert "message" in error_body, "Missing 'message' field in error response"
assert "details" in error_body, "Missing 'details' field in error response"
assert "extra_info" in error_body, "Missing 'extra_info' field in error response"
@pytest.mark.asyncio
async def test_api_error_reporting_async_prefer_header(frontend_backend_worker_with_rabbitmq):
"""Test error reporting with Prefer: respond-async header"""
async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
# Create a valid prompt structure but with invalid checkpoint
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
prompt["4"]["inputs"]["ckpt_name"] = "nonexistent.safetensors"
# Post with Prefer: respond-async header
prompt_json = client._AsyncRemoteComfyClient__json_encoder.encode(prompt)
async with client.session.post(
f"{frontend_backend_worker_with_rabbitmq}/api/v1/prompts",
data=prompt_json,
headers={
'Content-Type': 'application/json',
'Accept': 'application/json',
'Prefer': 'respond-async'
}
) as response:
# Validation happens before queueing, so the error is returned synchronously
# even when the client asked for an async response.
# Should return 400 immediately for validation error
assert response.status == 400, f"Expected 400 for validation error, got {response.status}"
error_body = await response.json()
assert "type" in error_body
@pytest.mark.asyncio
async def test_api_error_reporting_async_accept_mimetype(frontend_backend_worker_with_rabbitmq):
"""Test error reporting with +respond-async in Accept mimetype"""
async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
# Create a prompt with validation error
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
prompt["4"]["inputs"]["ckpt_name"] = "invalid_model.safetensors"
# Post with +respond-async in Accept header
prompt_json = client._AsyncRemoteComfyClient__json_encoder.encode(prompt)
async with client.session.post(
f"{frontend_backend_worker_with_rabbitmq}/api/v1/prompts",
data=prompt_json,
headers={
'Content-Type': 'application/json',
'Accept': 'application/json+respond-async'
}
) as response:
# Should return 400 for validation error (happens before queuing)
assert response.status == 400, f"Expected 400, got {response.status}"
error_body = await response.json()
assert "type" in error_body
@pytest.mark.asyncio
async def test_api_get_prompt_status_success(frontend_backend_worker_with_rabbitmq):
"""Test GET /api/v1/prompts/{prompt_id} returns 200 with Outputs on success"""
async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
# Create a valid prompt
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
# Queue async to get prompt_id
task_id = await client.queue_and_forget_prompt_api(prompt, prefer_header="respond-async")
assert task_id is not None
# Poll until done (up to ~60s with these parameters)
status_code, result = await client.poll_prompt_until_done(task_id, max_attempts=60, poll_interval=1.0)
# For a valid prompt, should get 200
assert status_code == 200, f"Expected 200 for successful execution, got {status_code}"
assert result is not None
# Verify it returns outputs structure (dict with node IDs)
assert isinstance(result, dict)
assert len(result) > 0, "Expected non-empty outputs"
@pytest.mark.asyncio
async def test_api_get_prompt_status_404(frontend_backend_worker_with_rabbitmq):
"""Test GET /api/v1/prompts/{prompt_id} returns 404 for non-existent prompt"""
async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
# Request a non-existent prompt ID; a random UUID is effectively guaranteed unknown
fake_prompt_id = str(uuid.uuid4())
async with await client.get_prompt_status(fake_prompt_id) as response:
assert response.status == 404, f"Expected 404 for non-existent prompt, got {response.status}"
@pytest.mark.asyncio
async def test_api_get_prompt_status_204_in_progress(frontend_backend_worker_with_rabbitmq):
"""Test GET /api/v1/prompts/{prompt_id} returns 204 while prompt is in progress"""
async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
# Create a prompt that takes some time to execute
prompt = sdxl_workflow_with_refiner("test", inference_steps=10, refiner_steps=10)
# Queue async
task_id = await client.queue_and_forget_prompt_api(prompt, prefer_header="respond-async")
# Immediately check status (should be 204 or 200 if very fast)
async with await client.get_prompt_status(task_id) as response:
# Should be either 204 (in progress) or 200 (completed very fast)
assert response.status in [200, 204], f"Expected 200 or 204, got {response.status}"
if response.status == 204:
# No content for in-progress
content = await response.read()
# NOTE(review): the two sides of this `or` are equivalent; `content == b''` alone suffices.
assert len(content) == 0 or content == b'', "Expected no content for 204 response"
@pytest.mark.asyncio
async def test_api_async_workflow_both_methods(frontend_backend_worker_with_rabbitmq):
"""Test full async workflow: queue with respond-async, then poll for completion"""
async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
# Create a valid prompt
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
# Method 1: Prefer header
task_id_1 = await client.queue_and_forget_prompt_api(prompt, prefer_header="respond-async")
assert task_id_1 is not None
# Method 2: +respond-async in Accept header
task_id_2 = await client.queue_and_forget_prompt_api(
prompt, prefer_header=None, accept_header="application/json+respond-async"
)
assert task_id_2 is not None
# Poll both until done
status_1, result_1 = await client.poll_prompt_until_done(task_id_1, max_attempts=60, poll_interval=1.0)
status_2, result_2 = await client.poll_prompt_until_done(task_id_2, max_attempts=60, poll_interval=1.0)
# Both should succeed
assert status_1 == 200, f"Task 1 failed with status {status_1}"
assert status_2 == 200, f"Task 2 failed with status {status_2}"
# Both should have outputs
assert result_1 is not None and len(result_1) > 0
assert result_2 is not None and len(result_2) > 0
@pytest.mark.asyncio
async def test_api_validation_error_structure(frontend_backend_worker_with_rabbitmq):
"""Test that validation errors return proper ValidationErrorDict structure"""
async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
# Create an invalid prompt (invalid checkpoint name)
prompt = sdxl_workflow_with_refiner("test", 1, 1)
prompt["4"]["inputs"]["ckpt_name"] = "fake.safetensors"
prompt_json = client._AsyncRemoteComfyClient__json_encoder.encode(prompt)
async with client.session.post(
f"{frontend_backend_worker_with_rabbitmq}/api/v1/prompts",
data=prompt_json,
headers={'Content-Type': 'application/json', 'Accept': 'application/json'}
) as response:
assert response.status == 400, f"Expected 400, got {response.status}"
error_body = await response.json()
# Verify ValidationErrorDict structure per OpenAPI spec
assert "type" in error_body, "Missing 'type'"
assert "message" in error_body, "Missing 'message'"
assert "details" in error_body, "Missing 'details'"
assert "extra_info" in error_body, "Missing 'extra_info'"
# extra_info should have exception_type and traceback
assert "exception_type" in error_body["extra_info"], "Missing 'exception_type' in extra_info"
assert "traceback" in error_body["extra_info"], "Missing 'traceback' in extra_info"
assert isinstance(error_body["extra_info"]["traceback"], list), "traceback should be a list"
@pytest.mark.asyncio
async def test_api_success_response_contract(frontend_backend_worker_with_rabbitmq):
"""Test that successful execution returns proper response structure"""
async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
# Create a valid prompt
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
# Queue and wait for blocking response
prompt_json = client._AsyncRemoteComfyClient__json_encoder.encode(prompt)
async with client.session.post(
f"{frontend_backend_worker_with_rabbitmq}/api/v1/prompts",
data=prompt_json,
headers={'Content-Type': 'application/json', 'Accept': 'application/json'}
) as response:
assert response.status == 200, f"Expected 200, got {response.status}"
result = await response.json()
# Should have 'outputs' key (and deprecated 'urls' key)
assert "outputs" in result, "Missing 'outputs' in response"
# outputs should be a dict with node IDs as keys
outputs = result["outputs"]
assert isinstance(outputs, dict), "outputs should be a dict"
assert len(outputs) > 0, "outputs should not be empty"
# Each output should follow the Output schema
for node_id, output in outputs.items():
assert isinstance(output, dict), f"Output for node {node_id} should be a dict"
# Should have images or other output types
if "images" in output:
assert isinstance(output["images"], list), f"images for node {node_id} should be a list"
for image in output["images"]:
assert "filename" in image, f"image missing 'filename' in node {node_id}"
assert "subfolder" in image, f"image missing 'subfolder' in node {node_id}"
assert "type" in image, f"image missing 'type' in node {node_id}"
@pytest.mark.asyncio
async def test_api_get_prompt_returns_outputs_directly(frontend_backend_worker_with_rabbitmq):
"""Test GET /api/v1/prompts/{prompt_id} returns Outputs directly (not wrapped in history entry)"""
async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
# Create and queue a prompt
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
task_id = await client.queue_and_forget_prompt_api(prompt)
# Poll until done
status_code, result = await client.poll_prompt_until_done(task_id, max_attempts=60, poll_interval=1.0)
assert status_code == 200, f"Expected 200, got {status_code}"
assert result is not None, "Result should not be None"
# Per OpenAPI spec, GET should return Outputs directly, not wrapped
# result should be a dict with node IDs as keys
assert isinstance(result, dict), "Result should be a dict (Outputs)"
# Should NOT have 'prompt', 'outputs', 'status' keys (those are in history entry)
# Should have node IDs directly
# NOTE(review): the comment above mentions 'outputs' but the assertion below only
# excludes 'prompt' and 'status' — confirm whether 'outputs' should also be excluded.
for key in result.keys():
# Node IDs are typically numeric strings like "4", "13", etc.
# Should not be "prompt", "outputs", "status"
assert key not in ["prompt", "status"], \
f"GET endpoint should return Outputs directly, not history entry. Found key: {key}"