wip align openapi and api methods for error handling

Benjamin Berman 2025-11-06 11:11:46 -08:00
parent 152524e8b1
commit be255c2691
5 changed files with 366 additions and 24 deletions

View File

@@ -354,19 +354,29 @@ paths:
           required: true
           description: |
             The ID of the prompt to query.
-      responses:
-        204:
-          description: |
-            The prompt is still in progress
-        200:
-          description: |
-            Prompt outputs
-          content:
-            application/json:
-              $ref: "#/components/schemas/Outputs"
-        404:
-          description: |
-            The prompt was not found
+      responses:
+        204:
+          description: |
+            The prompt is still in progress
+        200:
+          description: |
+            Prompt outputs
+          content:
+            application/json:
+              schema:
+                $ref: "#/components/schemas/Outputs"
+        404:
+          description: |
+            The prompt was not found
+        500:
+          description: |
+            An execution error occurred while processing the prompt.
+          content:
+            application/json:
+              description:
+                An execution status directly from the workers
+              schema:
+                $ref: "#/components/schemas/ExecutionStatusAsDict"
   /api/v1/prompts:
     get:
      operationId: list_prompts
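
For orientation, a rough sketch of the ExecutionStatusAsDict payload the new 500 response would carry. Only status_str == "error" is established by this diff (the GET handler below keys on it); the other fields are assumptions and may not match the actual schema.

    # Hypothetical ExecutionStatusAsDict payload for a failed prompt; only
    # "status_str" is implied by this change, the other fields are assumed.
    execution_status = {
        "status_str": "error",  # the discriminator the GET handler checks
        "completed": False,     # assumed: the prompt did not finish
        "messages": [],         # assumed: execution messages from the worker
    }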

View File

@@ -65,16 +65,21 @@ class AsyncRemoteComfyClient:
             else:
                 raise RuntimeError(f"unexpected response: {response.status}: {await response.text()}")

-    async def queue_and_forget_prompt_api(self, prompt: PromptDict) -> str:
+    async def queue_and_forget_prompt_api(self, prompt: PromptDict, prefer_header: Optional[str] = "respond-async", accept_header: str = "application/json") -> str:
         """
         Calls the API to queue a prompt, and forgets about it

         :param prompt:
+        :param prefer_header: The Prefer header value (e.g., "respond-async" or None)
+        :param accept_header: The Accept header value (e.g., "application/json", "application/json+respond-async")
         :return: the task ID
         """
         prompt_json = AsyncRemoteComfyClient.__json_encoder.encode(prompt)
         response: ClientResponse
+        headers = {'Content-Type': 'application/json', 'Accept': accept_header}
+        if prefer_header:
+            headers['Prefer'] = prefer_header
         async with self.session.post(urljoin(self.server_address, "/api/v1/prompts"), data=prompt_json,
-                                     headers={'Content-Type': 'application/json', 'Accept': 'application/json', 'Prefer': 'respond-async'}) as response:
+                                     headers=headers) as response:
             if 200 <= response.status < 400:
                 response_json = await response.json()
@@ -82,16 +87,21 @@ class AsyncRemoteComfyClient:
             else:
                 raise RuntimeError(f"could not prompt: {response.status}: {await response.text()}")

-    async def queue_prompt_api(self, prompt: PromptDict) -> V1QueuePromptResponse:
+    async def queue_prompt_api(self, prompt: PromptDict, prefer_header: Optional[str] = None, accept_header: str = "application/json") -> V1QueuePromptResponse:
         """
         Calls the API to queue a prompt.

         :param prompt:
+        :param prefer_header: The Prefer header value (e.g., "respond-async" or None)
+        :param accept_header: The Accept header value (e.g., "application/json", "application/json+respond-async")
         :return: the API response from the server containing URLs and the outputs for the UI (nodes with OUTPUT_NODE == true)
         """
         prompt_json = AsyncRemoteComfyClient.__json_encoder.encode(prompt)
         response: ClientResponse
+        headers = {'Content-Type': 'application/json', 'Accept': accept_header}
+        if prefer_header:
+            headers['Prefer'] = prefer_header
         async with self.session.post(urljoin(self.server_address, "/api/v1/prompts"), data=prompt_json,
-                                     headers={'Content-Type': 'application/json', 'Accept': 'application/json'}) as response:
+                                     headers=headers) as response:
             if 200 <= response.status < 400:
                 return V1QueuePromptResponse(**(await response.json()))
@@ -160,3 +170,36 @@ class AsyncRemoteComfyClient:
         # images have filename, subfolder, type keys
         # todo: use the OpenAPI spec for this when I get around to updating it
         return history_json[prompt_id].outputs
+
+    async def get_prompt_status(self, prompt_id: str) -> ClientResponse:
+        """
+        Get the status of a prompt by ID using the API endpoint.
+
+        :param prompt_id: The prompt ID to query
+        :return: The ClientResponse object (caller should check status and read body)
+        """
+        return await self.session.get(urljoin(self.server_address, f"/api/v1/prompts/{prompt_id}"))
+
+    async def poll_prompt_until_done(self, prompt_id: str, max_attempts: int = 60, poll_interval: float = 1.0) -> tuple[int, dict | None]:
+        """
+        Poll a prompt until it's done (200), errors (500), or times out.
+
+        :param prompt_id: The prompt ID to poll
+        :param max_attempts: Maximum number of polling attempts
+        :param poll_interval: Time to wait between polls in seconds
+        :return: Tuple of (status_code, response_json or None)
+        """
+        for _ in range(max_attempts):
+            async with await self.get_prompt_status(prompt_id) as response:
+                if response.status == 200:
+                    return response.status, await response.json()
+                elif response.status == 500:
+                    return response.status, await response.json()
+                elif response.status == 404:
+                    return response.status, None
+                elif response.status == 204:
+                    # Still in progress
+                    await asyncio.sleep(poll_interval)
+                else:
+                    # Unexpected status
+                    return response.status, None
+        # Timeout
+        return 408, None
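
Putting the new client surface together, a minimal usage sketch, assuming a server at http://localhost:8188; the module path for importing AsyncRemoteComfyClient is not shown in this diff, so it is left as a comment.

    import asyncio

    # from ... import AsyncRemoteComfyClient  # module path not shown in this diff

    prompt: dict = {}  # stand-in; a real PromptDict workflow goes here

    async def main():
        async with AsyncRemoteComfyClient(server_address="http://localhost:8188") as client:
            # Queue without blocking, then poll the v1 endpoint until a terminal status
            task_id = await client.queue_and_forget_prompt_api(prompt, prefer_header="respond-async")
            status, body = await client.poll_prompt_until_done(task_id, max_attempts=60, poll_interval=1.0)
            if status == 200:
                print("outputs:", body)          # Outputs keyed by node ID
            elif status == 500:
                print("execution error:", body)  # ExecutionStatusAsDict
            else:
                print("no result, status:", status)  # 404, unexpected status, or 408 on timeout

    asyncio.run(main())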

View File

@@ -116,6 +116,24 @@ async def compress_body(request: web.Request, handler):
     return response


+@web.middleware
+async def opentelemetry_middleware(request: web.Request, handler):
+    """Middleware to extract and propagate OpenTelemetry context from request headers"""
+    from opentelemetry import propagate, context
+
+    # Extract OpenTelemetry context from headers
+    carrier = dict(request.headers)
+    ctx = propagate.extract(carrier)
+
+    # Attach context and execute handler
+    token = context.attach(ctx)
+    try:
+        response = await handler(request)
+        return response
+    finally:
+        context.detach(token)
+
+
 def create_cors_middleware(allowed_origin: str):
     @web.middleware
     async def cors_middleware(request: web.Request, handler):
@@ -127,7 +145,7 @@ def create_cors_middleware(allowed_origin: str):
         response.headers['Access-Control-Allow-Origin'] = allowed_origin
         response.headers['Access-Control-Allow-Methods'] = 'POST, GET, DELETE, PUT, OPTIONS'
-        response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization'
+        response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization, traceparent, tracestate'
         response.headers['Access-Control-Allow-Credentials'] = 'true'

         return response
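
With the middleware extracting W3C trace context and CORS now admitting the traceparent and tracestate headers, a caller can stitch its spans to the server's. A minimal client-side sketch using the standard opentelemetry-python propagation API (not part of this diff; the tracer name is hypothetical):

    from opentelemetry import propagate, trace

    tracer = trace.get_tracer("comfy-client-example")  # hypothetical tracer name

    with tracer.start_as_current_span("queue-prompt"):
        headers = {"Content-Type": "application/json"}
        # Writes 'traceparent' (and 'tracestate' when present) into the dict,
        # which opentelemetry_middleware above will extract and attach.
        propagate.inject(headers)
        # session.post(".../api/v1/prompts", data=prompt_json, headers=headers)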
@@ -224,7 +242,7 @@ class PromptServer(ExecutorToClientProgress):
         self._external_address: Optional[str] = None
         self.background_tasks: dict[str, Task] = dict()

-        middlewares = [cache_control, deprecation_warning]
+        middlewares = [opentelemetry_middleware, cache_control, deprecation_warning]
         if args.enable_compress_response_body:
             middlewares.append(compress_body)
@@ -867,9 +885,19 @@ class PromptServer(ExecutorToClientProgress):
                 return web.json_response(status=404)
             elif prompt_id in history_items:
                 history_entry = history_items[prompt_id]
+                # Check if execution resulted in an error
+                if "status" in history_entry:
+                    status = history_entry["status"]
+                    if isinstance(status, dict) and status.get("status_str") == "error":
+                        # Return ExecutionStatusAsDict format with status 500, matching POST /api/v1/prompts behavior
+                        return web.Response(
+                            body=json.dumps(status),
+                            status=500,
+                            content_type="application/json"
+                        )
                 return web.json_response(history_entry["outputs"])
             else:
-                return web.json_response(status=500)
+                return web.Response(status=404, reason="prompt not found in expected state")

         @routes.post("/api/v1/prompts")
         async def post_api_prompt(request: web.Request) -> web.Response | web.FileResponse:
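
To illustrate the branch above: the handler returns history_entry["outputs"] directly on success, and serializes history_entry["status"] verbatim as the 500 body on failure. A sketch of the entry shape it inspects (assumed for illustration; only the keys the handler touches are shown):

    # Hypothetical history entry for a failed prompt
    history_entry = {
        "outputs": {},                       # returned directly on success (200)
        "status": {"status_str": "error"},   # returned verbatim as the 500 body
    }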

View File

@@ -3,8 +3,6 @@ import multiprocessing
 import os
 import pathlib
 import subprocess
-import sys
-import time
 import urllib
 from contextvars import ContextVar
 from multiprocessing import Process
@@ -12,9 +10,9 @@ from typing import List, Any, Generator

 import pytest
 import requests
+import sys
+import time

-from comfy.cli_args import default_configuration
-from comfy.execution_context import context_configuration

 os.environ['OTEL_METRICS_EXPORTER'] = 'none'
 os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
@@ -22,6 +20,7 @@ os.environ["HF_XET_HIGH_PERFORMANCE"] = "True"
 # fixes issues with running the testcontainers rabbitmqcontainer on Windows
 os.environ["TC_HOST"] = "localhost"

+from comfy.cli_args import default_configuration
 from comfy.cli_args_types import Configuration

 logging.getLogger("pika").setLevel(logging.CRITICAL + 1)

View File

@@ -243,3 +243,265 @@ async def test_two_workers_distinct_requests():
         all_workflows.update(worker.processed_workflows)

     assert len(all_workflows) == 2, f"Expected 2 distinct workflows, but got {len(all_workflows)}"
+
+
+# ============================================================================
+# API Error Reporting Tests
+# ============================================================================
+
+@pytest.mark.asyncio
+async def test_api_error_reporting_blocking_request(frontend_backend_worker_with_rabbitmq):
+    """Test error reporting with blocking request (no async preference)"""
+    async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
+        # Create an invalid prompt that will cause a validation error
+        prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
+        # Make the prompt invalid by referencing a non-existent checkpoint
+        prompt["4"]["inputs"]["ckpt_name"] = "nonexistent_checkpoint.safetensors"
+
+        # Post with blocking behavior (no prefer header for async)
+        prompt_json = client._AsyncRemoteComfyClient__json_encoder.encode(prompt)
+        async with client.session.post(
+                f"{frontend_backend_worker_with_rabbitmq}/api/v1/prompts",
+                data=prompt_json,
+                headers={'Content-Type': 'application/json', 'Accept': 'application/json'}
+        ) as response:
+            # Should return 400 for validation error (invalid checkpoint)
+            assert response.status == 400, f"Expected 400, got {response.status}"
+            error_body = await response.json()
+
+            # Verify ValidationErrorDict structure per OpenAPI spec
+            assert "type" in error_body, "Missing 'type' field in error response"
+            assert "message" in error_body, "Missing 'message' field in error response"
+            assert "details" in error_body, "Missing 'details' field in error response"
+            assert "extra_info" in error_body, "Missing 'extra_info' field in error response"
+
+
+@pytest.mark.asyncio
+async def test_api_error_reporting_async_prefer_header(frontend_backend_worker_with_rabbitmq):
+    """Test error reporting with Prefer: respond-async header"""
+    async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
+        # Create a valid prompt structure but with invalid checkpoint
+        prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
+        prompt["4"]["inputs"]["ckpt_name"] = "nonexistent.safetensors"
+
+        # Post with Prefer: respond-async header
+        prompt_json = client._AsyncRemoteComfyClient__json_encoder.encode(prompt)
+        async with client.session.post(
+                f"{frontend_backend_worker_with_rabbitmq}/api/v1/prompts",
+                data=prompt_json,
+                headers={
+                    'Content-Type': 'application/json',
+                    'Accept': 'application/json',
+                    'Prefer': 'respond-async'
+                }
+        ) as response:
+            # Should return 400 immediately for validation error
+            assert response.status == 400, f"Expected 400 for validation error, got {response.status}"
+            error_body = await response.json()
+            assert "type" in error_body
+
+
+@pytest.mark.asyncio
+async def test_api_error_reporting_async_accept_mimetype(frontend_backend_worker_with_rabbitmq):
+    """Test error reporting with +respond-async in Accept mimetype"""
+    async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
+        # Create a prompt with validation error
+        prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
+        prompt["4"]["inputs"]["ckpt_name"] = "invalid_model.safetensors"
+
+        # Post with +respond-async in Accept header
+        prompt_json = client._AsyncRemoteComfyClient__json_encoder.encode(prompt)
+        async with client.session.post(
+                f"{frontend_backend_worker_with_rabbitmq}/api/v1/prompts",
+                data=prompt_json,
+                headers={
+                    'Content-Type': 'application/json',
+                    'Accept': 'application/json+respond-async'
+                }
+        ) as response:
+            # Should return 400 for validation error (happens before queuing)
+            assert response.status == 400, f"Expected 400, got {response.status}"
+            error_body = await response.json()
+            assert "type" in error_body
+
+
+@pytest.mark.asyncio
+async def test_api_get_prompt_status_success(frontend_backend_worker_with_rabbitmq):
+    """Test GET /api/v1/prompts/{prompt_id} returns 200 with Outputs on success"""
+    async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
+        # Create a valid prompt
+        prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
+
+        # Queue async to get prompt_id
+        task_id = await client.queue_and_forget_prompt_api(prompt, prefer_header="respond-async")
+        assert task_id is not None
+
+        # Poll until done
+        status_code, result = await client.poll_prompt_until_done(task_id, max_attempts=60, poll_interval=1.0)
+
+        # For a valid prompt, should get 200
+        assert status_code == 200, f"Expected 200 for successful execution, got {status_code}"
+        assert result is not None
+        # Verify it returns outputs structure (dict with node IDs)
+        assert isinstance(result, dict)
+        assert len(result) > 0, "Expected non-empty outputs"
+
+
+@pytest.mark.asyncio
+async def test_api_get_prompt_status_404(frontend_backend_worker_with_rabbitmq):
+    """Test GET /api/v1/prompts/{prompt_id} returns 404 for non-existent prompt"""
+    async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
+        # Request a non-existent prompt ID
+        fake_prompt_id = str(uuid.uuid4())
+
+        async with await client.get_prompt_status(fake_prompt_id) as response:
+            assert response.status == 404, f"Expected 404 for non-existent prompt, got {response.status}"
+
+
+@pytest.mark.asyncio
+async def test_api_get_prompt_status_204_in_progress(frontend_backend_worker_with_rabbitmq):
+    """Test GET /api/v1/prompts/{prompt_id} returns 204 while prompt is in progress"""
+    async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
+        # Create a prompt that takes some time to execute
+        prompt = sdxl_workflow_with_refiner("test", inference_steps=10, refiner_steps=10)
+
+        # Queue async
+        task_id = await client.queue_and_forget_prompt_api(prompt, prefer_header="respond-async")
+
+        # Immediately check status (should be 204 or 200 if very fast)
+        async with await client.get_prompt_status(task_id) as response:
+            # Should be either 204 (in progress) or 200 (completed very fast)
+            assert response.status in [200, 204], f"Expected 200 or 204, got {response.status}"
+
+            if response.status == 204:
+                # No content for in-progress
+                content = await response.read()
+                assert len(content) == 0 or content == b'', "Expected no content for 204 response"
+
+
+@pytest.mark.asyncio
+async def test_api_async_workflow_both_methods(frontend_backend_worker_with_rabbitmq):
+    """Test full async workflow: queue with respond-async, then poll for completion"""
+    async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
+        # Create a valid prompt
+        prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
+
+        # Method 1: Prefer header
+        task_id_1 = await client.queue_and_forget_prompt_api(prompt, prefer_header="respond-async")
+        assert task_id_1 is not None
+
+        # Method 2: +respond-async in Accept header
+        task_id_2 = await client.queue_and_forget_prompt_api(
+            prompt, prefer_header=None, accept_header="application/json+respond-async"
+        )
+        assert task_id_2 is not None
+
+        # Poll both until done
+        status_1, result_1 = await client.poll_prompt_until_done(task_id_1, max_attempts=60, poll_interval=1.0)
+        status_2, result_2 = await client.poll_prompt_until_done(task_id_2, max_attempts=60, poll_interval=1.0)
+
+        # Both should succeed
+        assert status_1 == 200, f"Task 1 failed with status {status_1}"
+        assert status_2 == 200, f"Task 2 failed with status {status_2}"
+
+        # Both should have outputs
+        assert result_1 is not None and len(result_1) > 0
+        assert result_2 is not None and len(result_2) > 0
+
+
+@pytest.mark.asyncio
+async def test_api_validation_error_structure(frontend_backend_worker_with_rabbitmq):
+    """Test that validation errors return proper ValidationErrorDict structure"""
+    async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
+        # Create an invalid prompt (invalid checkpoint name)
+        prompt = sdxl_workflow_with_refiner("test", 1, 1)
+        prompt["4"]["inputs"]["ckpt_name"] = "fake.safetensors"
+
+        prompt_json = client._AsyncRemoteComfyClient__json_encoder.encode(prompt)
+        async with client.session.post(
+                f"{frontend_backend_worker_with_rabbitmq}/api/v1/prompts",
+                data=prompt_json,
+                headers={'Content-Type': 'application/json', 'Accept': 'application/json'}
+        ) as response:
+            assert response.status == 400, f"Expected 400, got {response.status}"
+            error_body = await response.json()
+
+            # Verify ValidationErrorDict structure per OpenAPI spec
+            assert "type" in error_body, "Missing 'type'"
+            assert "message" in error_body, "Missing 'message'"
+            assert "details" in error_body, "Missing 'details'"
+            assert "extra_info" in error_body, "Missing 'extra_info'"
+
+            # extra_info should have exception_type and traceback
+            assert "exception_type" in error_body["extra_info"], "Missing 'exception_type' in extra_info"
+            assert "traceback" in error_body["extra_info"], "Missing 'traceback' in extra_info"
+            assert isinstance(error_body["extra_info"]["traceback"], list), "traceback should be a list"
+
+
+@pytest.mark.asyncio
+async def test_api_success_response_contract(frontend_backend_worker_with_rabbitmq):
+    """Test that successful execution returns proper response structure"""
+    async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
+        # Create a valid prompt
+        prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
+
+        # Queue and wait for blocking response
+        prompt_json = client._AsyncRemoteComfyClient__json_encoder.encode(prompt)
+        async with client.session.post(
+                f"{frontend_backend_worker_with_rabbitmq}/api/v1/prompts",
+                data=prompt_json,
+                headers={'Content-Type': 'application/json', 'Accept': 'application/json'}
+        ) as response:
+            assert response.status == 200, f"Expected 200, got {response.status}"
+            result = await response.json()
+
+            # Should have 'outputs' key (and deprecated 'urls' key)
+            assert "outputs" in result, "Missing 'outputs' in response"
+
+            # outputs should be a dict with node IDs as keys
+            outputs = result["outputs"]
+            assert isinstance(outputs, dict), "outputs should be a dict"
+            assert len(outputs) > 0, "outputs should not be empty"
+
+            # Each output should follow the Output schema
+            for node_id, output in outputs.items():
+                assert isinstance(output, dict), f"Output for node {node_id} should be a dict"
+                # Should have images or other output types
+                if "images" in output:
+                    assert isinstance(output["images"], list), f"images for node {node_id} should be a list"
+                    for image in output["images"]:
+                        assert "filename" in image, f"image missing 'filename' in node {node_id}"
+                        assert "subfolder" in image, f"image missing 'subfolder' in node {node_id}"
+                        assert "type" in image, f"image missing 'type' in node {node_id}"
+
+
+@pytest.mark.asyncio
+async def test_api_get_prompt_returns_outputs_directly(frontend_backend_worker_with_rabbitmq):
+    """Test GET /api/v1/prompts/{prompt_id} returns Outputs directly (not wrapped in history entry)"""
+    async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
+        # Create and queue a prompt
+        prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
+        task_id = await client.queue_and_forget_prompt_api(prompt)
+
+        # Poll until done
+        status_code, result = await client.poll_prompt_until_done(task_id, max_attempts=60, poll_interval=1.0)
+
+        assert status_code == 200, f"Expected 200, got {status_code}"
+        assert result is not None, "Result should not be None"
+
+        # Per OpenAPI spec, GET should return Outputs directly, not wrapped
+        # result should be a dict with node IDs as keys
+        assert isinstance(result, dict), "Result should be a dict (Outputs)"
+
+        # Should NOT have 'prompt', 'outputs', 'status' keys (those are in history entry)
+        # Should have node IDs directly
+        for key in result.keys():
+            # Node IDs are typically numeric strings like "4", "13", etc.
+            # Should not be "prompt", "outputs", "status"
+            assert key not in ["prompt", "status"], \
+                f"GET endpoint should return Outputs directly, not history entry. Found key: {key}"