Improve the OpenAPI contract in a distributed context, propagating validation and execution errors correctly.

doctorpangloss 2025-11-06 12:54:35 -08:00
parent be255c2691
commit 243f34f282
10 changed files with 181 additions and 21 deletions

View File

@ -871,6 +871,22 @@ components:
type: array
items:
type: string
node_errors:
type: object
description: "Detailed validation errors per node"
additionalProperties:
type: object
properties:
errors:
type: array
items:
$ref: "#/components/schemas/ValidationErrorDict"
dependent_outputs:
type: array
items:
type: string
class_type:
type: string
required:
- type
- details
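
For reference, a validation failure matching this schema might serialize roughly as follows; the node ids, dependent outputs, and details strings are illustrative, modeled on the invalid-checkpoint scenario exercised in the tests below:

{
  "type": "prompt_outputs_failed_validation",
  "message": "Prompt outputs failed validation",
  "details": "Value not in list: ckpt_name: 'fake.safetensors' not in [...]",
  "extra_info": {
    "exception_type": "ValidationError",
    "traceback": ["..."],
    "node_errors": {
      "4": {
        "errors": [
          {
            "type": "value_not_in_list",
            "message": "Value not in list",
            "details": "ckpt_name: 'fake.safetensors' not in [...]",
            "extra_info": {}
          }
        ],
        "dependent_outputs": ["13"],
        "class_type": "CheckpointLoaderSimple"
      }
    }
  }
}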

View File

@ -85,7 +85,7 @@ class AsyncRemoteComfyClient:
response_json = await response.json()
return response_json["prompt_id"]
else:
raise RuntimeError(f"could not prompt: {response.status}: {await response.text()}")
raise RuntimeError(f"could not prompt: {response.status}, reason={response.reason}: {await response.text()}")
async def queue_prompt_api(self, prompt: PromptDict, prefer_header: Optional[str] = None, accept_header: str = "application/json") -> V1QueuePromptResponse:
"""

View File

@ -161,7 +161,7 @@ def sdxl_workflow_with_refiner(prompt: str,
sampler="euler_ancestral",
scheduler="normal",
filename_prefix="sdxl_",
seed=42) -> PromptDict:
seed=42) -> dict:
prompt_dict: JSON = copy.deepcopy(_BASE_PROMPT)
prompt_dict["17"]["inputs"]["text"] = prompt
prompt_dict["20"]["inputs"]["text"] = negative_prompt
@ -188,4 +188,4 @@ def sdxl_workflow_with_refiner(prompt: str,
prompt_dict["14"]["inputs"]["scheduler"] = scheduler
prompt_dict["13"]["inputs"]["filename_prefix"] = filename_prefix
return Prompt.validate(prompt_dict)
return prompt_dict
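
Since the helper now returns a plain dict rather than a pre-validated Prompt, tests can mutate the graph freely and let the server do the rejecting, while a caller that still wants client-side validation can opt in explicitly. A minimal sketch, mirroring the call used in the tests below:

prompt_dict = sdxl_workflow_with_refiner("test", "", 1, refiner_steps=1)
prompt_dict["4"]["inputs"]["ckpt_name"] = "fake.safetensors"  # intentionally invalid; only the server will reject it
# optional: restore the previous client-side behaviour
# validated = Prompt.validate(prompt_dict)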

View File

@ -1246,16 +1246,36 @@ async def _validate_prompt(prompt_id: typing.Any, prompt: typing.Mapping[str, ty
if len(good_outputs) == 0:
errors_list = []
extra_info = {}
for o, _errors in errors:
for error in _errors:
errors_list.append(f"{error['message']}: {error['details']}")
# Aggregate exception_type and traceback from validation errors
if 'extra_info' in error and error['extra_info']:
if 'exception_type' in error['extra_info'] and 'exception_type' not in extra_info:
extra_info['exception_type'] = error['extra_info']['exception_type']
if 'traceback' in error['extra_info'] and 'traceback' not in extra_info:
extra_info['traceback'] = error['extra_info']['traceback']
# Per OpenAPI spec, extra_info must have exception_type and traceback
# For non-exception validation errors, provide synthetic values
if 'exception_type' not in extra_info:
extra_info['exception_type'] = 'ValidationError'
if 'traceback' not in extra_info:
# Capture current stack for validation errors that don't have their own traceback
extra_info['traceback'] = traceback.format_stack()
# Include detailed node_errors for actionable debugging information
if node_errors:
extra_info['node_errors'] = node_errors
errors_list = "\n".join(errors_list)
error = {
"type": "prompt_outputs_failed_validation",
"message": "Prompt outputs failed validation",
"details": errors_list,
"extra_info": {}
"extra_info": extra_info
}
return ValidationTuple(False, error, list(good_outputs), node_errors)
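
A short sketch of the fallback behaviour added here: even when none of the collected errors carries its own exception info, the aggregated extra_info still satisfies the contract (values illustrative):

# Assuming a single "value_not_in_list" error whose own extra_info is empty:
# errors_list == "Value not in list: ckpt_name: 'fake.safetensors' not in [...]"
# extra_info == {
#     "exception_type": "ValidationError",     # synthetic: no exception was actually raised
#     "traceback": traceback.format_stack(),   # stand-in stack captured at validation time
#     "node_errors": node_errors,              # per-node details for actionable debugging
# }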
@ -1301,7 +1321,7 @@ class PromptQueue(AbstractPromptQueue):
return copy.deepcopy(item_with_future.queue_tuple), task_id
def task_done(self, item_id: str, outputs: HistoryResultDict,
status: Optional[ExecutionStatus]):
status: Optional[ExecutionStatus], error_details: Optional[ExecutionErrorMessage] = None):
history_result = outputs
with self.mutex:
queue_item = self.currently_running.pop(item_id)
@ -1311,7 +1331,7 @@ class PromptQueue(AbstractPromptQueue):
status_dict = None
if status is not None:
status_dict: Optional[ExecutionStatusAsDict] = status.as_dict()
status_dict: Optional[ExecutionStatusAsDict] = status.as_dict(error_details=error_details)
outputs_ = history_result["outputs"]
# Remove sensitive data from extra_data before storing in history

View File

@ -79,12 +79,25 @@ async def _prompt_worker(q: AbstractPromptQueue, server_instance: server_module.
await e.execute_async(item[2], prompt_id, item[3], item[4])
need_gc = True
# Extract error details from status_messages if there's an error
error_details = None
if not e.success:
for event, data in e.status_messages:
if event == "execution_error":
error_details = data
break
# Convert status_messages tuples to string messages for backward compatibility
messages = [f"{event}: {data.get('exception_message', str(data))}" if isinstance(data, dict) and 'exception_message' in data else f"{event}" for event, data in e.status_messages]
q.task_done(item_id,
e.history_result,
status=queue_types.ExecutionStatus(
status_str='success' if e.success else 'error',
completed=e.success,
messages=e.status_messages))
messages=messages),
error_details=error_details)
if server_instance.client_id is not None:
server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id},
server_instance.client_id)
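
To make the conversion above concrete, a failed run might carry status_messages like the tuples below (payloads illustrative); the comprehension keeps ExecutionStatus.messages as plain strings while the raw execution_error payload travels separately as error_details:

status_messages = [
    ("execution_started", {"prompt_id": "abc123"}),
    ("execution_error", {"node_id": "7", "exception_type": "IndexError",
                         "exception_message": "no such group", "traceback": []}),
]
messages = [f"{event}: {data.get('exception_message', str(data))}"
            if isinstance(data, dict) and 'exception_message' in data else f"{event}"
            for event, data in status_messages]
assert messages == ["execution_started", "execution_error: no such group"]
error_details = next((data for event, data in status_messages if event == "execution_error"), None)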

View File

@ -905,9 +905,13 @@ class PromptServer(ExecutorToClientProgress):
if accept == '*/*':
accept = "application/json"
content_type = request.headers.get("content-type", "application/json")
preferences = request.headers.get("prefer", "") + request.query.get("prefer", "") + " " + content_type
preferences = request.headers.get("prefer", "") + request.query.get("prefer", "") + " " + content_type + " " + accept
# handle media type parameters like "application/json+respond-async"
if "+" in content_type:
content_type = content_type.split("+")[0]
if "+" in accept:
accept = accept.split("+")[0]
wait = not "respond-async" in preferences
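
A minimal sketch of the negotiation above, for a request that carries the preference in the media-type suffix rather than in a Prefer header (header values illustrative):

content_type = "application/json+respond-async"  # e.g. from Content-Type
accept = "application/json"                      # e.g. from Accept
preferences = "" + "" + " " + content_type + " " + accept
if "+" in content_type:
    content_type = content_type.split("+")[0]
if "+" in accept:
    accept = accept.split("+")[0]
wait = not "respond-async" in preferences
assert (content_type, accept, wait) == ("application/json", "application/json", False)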
@ -993,7 +997,8 @@ class PromptServer(ExecutorToClientProgress):
return web.Response(body=str(ex), status=500)
if result.status is not None and result.status.status_str == "error":
return web.Response(body=json.dumps(result.status._asdict()), status=500, content_type="application/json")
status_dict = result.status.as_dict(error_details=result.error_details)
return web.Response(body=json.dumps(status_dict), status=500, content_type="application/json")
# find images and read them
output_images: List[FileOutput] = []
for node_id, node in result.outputs.items():

View File

@ -199,9 +199,8 @@ class ValidationErrorExtraInfoDict(TypedDict, total=False):
input_config: NotRequired[Dict[str, InputTypeSpec]]
received_value: NotRequired[Any]
linked_node: NotRequired[str]
traceback: NotRequired[list[str]]
exception_message: NotRequired[str]
exception_type: NotRequired[str]
node_errors: NotRequired[Dict[str, 'NodeErrorsDictValue']]
class ValidationErrorDict(TypedDict):

View File

@ -18,6 +18,7 @@ class TaskInvocation(NamedTuple):
item_id: int | str
outputs: OutputsDict
status: Optional[ExecutionStatus]
error_details: Optional['ExecutionErrorMessage'] = None
class ExecutionStatus(NamedTuple):
@ -25,12 +26,15 @@ class ExecutionStatus(NamedTuple):
completed: bool
messages: List[str]
def as_dict(self) -> ExecutionStatusAsDict:
return {
def as_dict(self, error_details: Optional['ExecutionErrorMessage'] = None) -> ExecutionStatusAsDict:
result: ExecutionStatusAsDict = {
"status_str": self.status_str,
"completed": self.completed,
"messages": copy.copy(self.messages),
}
if error_details is not None:
result["error_details"] = error_details
return result
class ExecutionError(RuntimeError):
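
A sketch of the widened serialization; callers that pass no error_details get exactly the old shape (error payload illustrative):

status = ExecutionStatus(status_str="error", completed=False,
                         messages=["execution_error: no such group"])
status.as_dict()
# -> {"status_str": "error", "completed": False, "messages": ["execution_error: no such group"]}
status.as_dict(error_details={"node_id": "7", "exception_message": "no such group"})
# -> the same dict plus an "error_details" key carrying the raw execution_error payload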

View File

@ -162,7 +162,7 @@ class DistributedPromptQueue(AbstractPromptQueue, AsyncAbstractPromptQueue):
return item, item[1]
def task_done(self, item_id: int, outputs: dict, status: Optional[ExecutionStatus]):
def task_done(self, item_id: int, outputs: dict, status: Optional[ExecutionStatus], error_details: Optional['ExecutionErrorMessage'] = None):
# callee: executed on the worker thread
if "outputs" in outputs:
outputs: HistoryResultDict
@ -173,7 +173,7 @@ class DistributedPromptQueue(AbstractPromptQueue, AsyncAbstractPromptQueue):
assert pending.completed is not None
assert not pending.completed.done()
# finish the task. status will transmit the errors in comfy's domain-specific way
pending.completed.set_result(TaskInvocation(item_id=item_id, outputs=outputs, status=status))
pending.completed.set_result(TaskInvocation(item_id=item_id, outputs=outputs, status=status, error_details=error_details))
# todo: the caller is responsible for sending a websocket message right now that the UI expects for updates
def get_current_queue(self) -> Tuple[List[QueueTuple], List[QueueTuple]]:
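
On the caller side the extra field rides along on the TaskInvocation future, so a frontend can hand the structured error straight to the HTTP layer. A rough sketch (wait_for_invocation is a hypothetical helper), assuming pending.completed is the per-item asyncio future resolved in task_done above:

async def wait_for_invocation(pending) -> dict:
    invocation = await pending.completed  # TaskInvocation(item_id, outputs, status, error_details)
    if invocation.status is not None and invocation.status.status_str == "error":
        # same shape as the 500 body PromptServer now returns
        return invocation.status.as_dict(error_details=invocation.error_details)
    return invocation.outputs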

View File

@ -245,11 +245,6 @@ async def test_two_workers_distinct_requests():
assert len(all_workflows) == 2, f"Expected 2 distinct workflows, but got {len(all_workflows)}"
# ============================================================================
# API Error Reporting Tests
# ============================================================================
@pytest.mark.asyncio
async def test_api_error_reporting_blocking_request(frontend_backend_worker_with_rabbitmq):
"""Test error reporting with blocking request (no async preference)"""
@ -416,7 +411,7 @@ async def test_api_validation_error_structure(frontend_backend_worker_with_rabbi
"""Test that validation errors return proper ValidationErrorDict structure"""
async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
# Create an invalid prompt (invalid checkpoint name)
prompt = sdxl_workflow_with_refiner("test", 1, 1)
prompt = sdxl_workflow_with_refiner("test", "", 1, refiner_steps=1)
prompt["4"]["inputs"]["ckpt_name"] = "fake.safetensors"
prompt_json = client._AsyncRemoteComfyClient__json_encoder.encode(prompt)
@ -436,11 +431,37 @@ async def test_api_validation_error_structure(frontend_backend_worker_with_rabbi
assert "details" in error_body, "Missing 'details'"
assert "extra_info" in error_body, "Missing 'extra_info'"
assert error_body["type"] == "prompt_outputs_failed_validation", "unexpected type"
# extra_info should have exception_type and traceback
assert "exception_type" in error_body["extra_info"], "Missing 'exception_type' in extra_info"
assert "traceback" in error_body["extra_info"], "Missing 'traceback' in extra_info"
assert isinstance(error_body["extra_info"]["traceback"], list), "traceback should be a list"
# extra_info should have node_errors with detailed validation information
assert "node_errors" in error_body["extra_info"], "Missing 'node_errors' in extra_info"
node_errors = error_body["extra_info"]["node_errors"]
assert isinstance(node_errors, dict), "node_errors should be a dict"
assert len(node_errors) > 0, "node_errors should contain at least one node"
# Verify node_errors structure for node "4" (CheckpointLoaderSimple with invalid ckpt_name)
assert "4" in node_errors, "Node '4' should have validation errors"
node_4_errors = node_errors["4"]
assert "errors" in node_4_errors, "Node '4' should have 'errors' field"
assert "class_type" in node_4_errors, "Node '4' should have 'class_type' field"
assert "dependent_outputs" in node_4_errors, "Node '4' should have 'dependent_outputs' field"
assert node_4_errors["class_type"] == "CheckpointLoaderSimple", "Node '4' class_type should be CheckpointLoaderSimple"
assert len(node_4_errors["errors"]) > 0, "Node '4' should have at least one error"
# Verify the error details include the validation error type and message
first_error = node_4_errors["errors"][0]
assert "type" in first_error, "Error should have 'type' field"
assert "message" in first_error, "Error should have 'message' field"
assert "details" in first_error, "Error should have 'details' field"
assert first_error["type"] == "value_not_in_list", f"Expected 'value_not_in_list' error, got {first_error['type']}"
assert "fake.safetensors" in first_error["details"], "Error details should mention 'fake.safetensors'"
@pytest.mark.asyncio
async def test_api_success_response_contract(frontend_backend_worker_with_rabbitmq):
@ -505,3 +526,85 @@ async def test_api_get_prompt_returns_outputs_directly(frontend_backend_worker_w
# Should not be "prompt", "outputs", "status"
assert key not in ["prompt", "status"], \
f"GET endpoint should return Outputs directly, not history entry. Found key: {key}"
@pytest.mark.asyncio
async def test_api_execution_error_blocking_mode(frontend_backend_worker_with_rabbitmq):
"""Test that execution errors (not validation) return proper error structure in blocking mode"""
from comfy_execution.graph_utils import GraphBuilder
async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
# Create a prompt that will fail during execution (not validation)
# Use Regex with a group name that doesn't exist - validation passes but execution fails
g = GraphBuilder()
regex_match = g.node("Regex", pattern="hello", string="hello world")
# Request a non-existent group name - this will pass validation but fail during execution
match_group = g.node("RegexMatchGroupByName", match=regex_match.out(0), name="nonexistent_group")
g.node("SaveString", value=match_group.out(0), filename_prefix="test")
prompt = g.finalize()
prompt_json = client._AsyncRemoteComfyClient__json_encoder.encode(prompt)
async with client.session.post(
f"{frontend_backend_worker_with_rabbitmq}/api/v1/prompts",
data=prompt_json,
headers={'Content-Type': 'application/json', 'Accept': 'application/json'}
) as response:
# Execution errors return 500
assert response.status == 500, f"Expected 500 for execution error, got {response.status}"
error_body = await response.json()
# Verify ExecutionStatus structure
assert "status_str" in error_body, "Missing 'status_str'"
assert "completed" in error_body, "Missing 'completed'"
assert "messages" in error_body, "Missing 'messages'"
assert error_body["status_str"] == "error", f"Expected 'error', got {error_body['status_str']}"
assert error_body["completed"] == False, "completed should be False for errors"
assert isinstance(error_body["messages"], list), "messages should be a list"
assert len(error_body["messages"]) > 0, "messages should contain error details"
@pytest.mark.asyncio
async def test_api_execution_error_async_mode(frontend_backend_worker_with_rabbitmq):
"""Test that execution errors return proper error structure in respond-async mode"""
from comfy_execution.graph_utils import GraphBuilder
async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
# Create a prompt that will fail during execution (not validation)
# Use Regex with a group name that doesn't exist - validation passes but execution fails
g = GraphBuilder()
regex_match = g.node("Regex", pattern="hello", string="hello world")
# Request a non-existent group name - this will pass validation but fail during execution
match_group = g.node("RegexMatchGroupByName", match=regex_match.out(0), name="nonexistent_group")
g.node("SaveString", value=match_group.out(0), filename_prefix="test")
prompt = g.finalize()
# Queue with respond-async
task_id = await client.queue_and_forget_prompt_api(prompt, prefer_header="respond-async")
assert task_id is not None, "Should get task_id in async mode"
# Poll for completion
status_code, result = await client.poll_prompt_until_done(task_id, max_attempts=60, poll_interval=1.0)
# In async mode with polling, errors come back as 200 with error in the response body
# because the prompt was accepted (202) and we're just retrieving the completed result
assert status_code in (200, 500), f"Expected 200 or 500, got {status_code}"
if status_code == 500:
# Error returned directly - should be ExecutionStatus
assert "status_str" in result, "Missing 'status_str'"
assert "completed" in result, "Missing 'completed'"
assert "messages" in result, "Missing 'messages'"
assert result["status_str"] == "error"
assert result["completed"] == False
assert len(result["messages"]) > 0
else:
# Error in successful response - result might be ExecutionStatus or empty outputs
# If it's a dict with status info, verify it
if "status_str" in result:
assert result["status_str"] == "error"
assert result["completed"] == False
assert len(result["messages"]) > 0