Improve OpenAPI contract in distributed context, propagating validation and execution errors correctly.
parent be255c2691
commit 243f34f282

@@ -871,6 +871,22 @@ components:
           type: array
           items:
             type: string
+        node_errors:
+          type: object
+          description: "Detailed validation errors per node"
+          additionalProperties:
+            type: object
+            properties:
+              errors:
+                type: array
+                items:
+                  $ref: "#/components/schemas/ValidationErrorDict"
+              dependent_outputs:
+                type: array
+                items:
+                  type: string
+              class_type:
+                type: string
       required:
         - type
         - details

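Note: as a hedged illustration of the schema above (the field names come from this diff; the node id, class_type, and details are invented, mirroring the validation test further down), a node_errors value might serialize like this:

    # Hypothetical instance of the node_errors schema; values are illustrative only.
    node_errors = {
        "4": {
            "errors": [
                {
                    "type": "value_not_in_list",
                    "message": "Value not in list",
                    "details": "ckpt_name: 'fake.safetensors' not in list",
                    "extra_info": {},
                },
            ],
            "dependent_outputs": ["13"],
            "class_type": "CheckpointLoaderSimple",
        },
    }
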
@@ -85,7 +85,7 @@ class AsyncRemoteComfyClient:
                 response_json = await response.json()
                 return response_json["prompt_id"]
             else:
-                raise RuntimeError(f"could not prompt: {response.status}: {await response.text()}")
+                raise RuntimeError(f"could not prompt: {response.status}, reason={response.reason}: {await response.text()}")
 
     async def queue_prompt_api(self, prompt: PromptDict, prefer_header: Optional[str] = None, accept_header: str = "application/json") -> V1QueuePromptResponse:
         """

@@ -161,7 +161,7 @@ def sdxl_workflow_with_refiner(prompt: str,
                                sampler="euler_ancestral",
                                scheduler="normal",
                                filename_prefix="sdxl_",
-                               seed=42) -> PromptDict:
+                               seed=42) -> dict:
     prompt_dict: JSON = copy.deepcopy(_BASE_PROMPT)
     prompt_dict["17"]["inputs"]["text"] = prompt
     prompt_dict["20"]["inputs"]["text"] = negative_prompt

@@ -188,4 +188,4 @@ def sdxl_workflow_with_refiner(prompt: str,
     prompt_dict["14"]["inputs"]["scheduler"] = scheduler
 
     prompt_dict["13"]["inputs"]["filename_prefix"] = filename_prefix
-    return Prompt.validate(prompt_dict)
+    return prompt_dict

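Note: returning the plain dict instead of a validated Prompt lets callers mutate the workflow before submitting it, which the validation tests below depend on. A caller that still wants client-side validation can presumably opt in explicitly (sketch, assuming Prompt.validate keeps its current signature):

    prompt = sdxl_workflow_with_refiner("test", "", 1, refiner_steps=1)
    prompt["4"]["inputs"]["ckpt_name"] = "fake.safetensors"  # deliberately invalid
    # validated = Prompt.validate(prompt)  # schema-level validation, if still wanted
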
@@ -1246,16 +1246,36 @@ async def _validate_prompt(prompt_id: typing.Any, prompt: typing.Mapping[str, ty
 
     if len(good_outputs) == 0:
         errors_list = []
+        extra_info = {}
         for o, _errors in errors:
             for error in _errors:
                 errors_list.append(f"{error['message']}: {error['details']}")
+                # Aggregate exception_type and traceback from validation errors
+                if 'extra_info' in error and error['extra_info']:
+                    if 'exception_type' in error['extra_info'] and 'exception_type' not in extra_info:
+                        extra_info['exception_type'] = error['extra_info']['exception_type']
+                    if 'traceback' in error['extra_info'] and 'traceback' not in extra_info:
+                        extra_info['traceback'] = error['extra_info']['traceback']
+
+        # Per OpenAPI spec, extra_info must have exception_type and traceback
+        # For non-exception validation errors, provide synthetic values
+        if 'exception_type' not in extra_info:
+            extra_info['exception_type'] = 'ValidationError'
+        if 'traceback' not in extra_info:
+            # Capture current stack for validation errors that don't have their own traceback
+            extra_info['traceback'] = traceback.format_stack()
+
+        # Include detailed node_errors for actionable debugging information
+        if node_errors:
+            extra_info['node_errors'] = node_errors
+
         errors_list = "\n".join(errors_list)
 
         error = {
             "type": "prompt_outputs_failed_validation",
             "message": "Prompt outputs failed validation",
             "details": errors_list,
-            "extra_info": {}
+            "extra_info": extra_info
         }
 
         return ValidationTuple(False, error, list(good_outputs), node_errors)

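Note: a minimal self-contained sketch of the aggregation rule added above (the first exception_type/traceback seen wins, with synthetic fallbacks so the contract always holds); the sample errors are invented:

    import traceback

    def aggregate_extra_info(validation_errors):
        # Keep the first exception_type/traceback encountered, then fall back
        # to synthetic values so extra_info always satisfies the OpenAPI spec.
        extra_info = {}
        for error in validation_errors:
            info = error.get('extra_info') or {}
            if 'exception_type' in info and 'exception_type' not in extra_info:
                extra_info['exception_type'] = info['exception_type']
            if 'traceback' in info and 'traceback' not in extra_info:
                extra_info['traceback'] = info['traceback']
        extra_info.setdefault('exception_type', 'ValidationError')
        extra_info.setdefault('traceback', traceback.format_stack())
        return extra_info

    # One bare validation error, one carrying exception details:
    sample = [{'extra_info': {}},
              {'extra_info': {'exception_type': 'KeyError', 'traceback': ['frame 0']}}]
    assert aggregate_extra_info(sample)['exception_type'] == 'KeyError'
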
@@ -1301,7 +1321,7 @@ class PromptQueue(AbstractPromptQueue):
         return copy.deepcopy(item_with_future.queue_tuple), task_id
 
     def task_done(self, item_id: str, outputs: HistoryResultDict,
-                  status: Optional[ExecutionStatus]):
+                  status: Optional[ExecutionStatus], error_details: Optional[ExecutionErrorMessage] = None):
         history_result = outputs
         with self.mutex:
             queue_item = self.currently_running.pop(item_id)

@@ -1311,7 +1331,7 @@ class PromptQueue(AbstractPromptQueue):
 
         status_dict = None
         if status is not None:
-            status_dict: Optional[ExecutionStatusAsDict] = status.as_dict()
+            status_dict: Optional[ExecutionStatusAsDict] = status.as_dict(error_details=error_details)
 
         outputs_ = history_result["outputs"]
         # Remove sensitive data from extra_data before storing in history

@@ -79,12 +79,25 @@ async def _prompt_worker(q: AbstractPromptQueue, server_instance: server_module.
 
             await e.execute_async(item[2], prompt_id, item[3], item[4])
             need_gc = True
+
+            # Extract error details from status_messages if there's an error
+            error_details = None
+            if not e.success:
+                for event, data in e.status_messages:
+                    if event == "execution_error":
+                        error_details = data
+                        break
+
+            # Convert status_messages tuples to string messages for backward compatibility
+            messages = [f"{event}: {data.get('exception_message', str(data))}" if isinstance(data, dict) and 'exception_message' in data else f"{event}" for event, data in e.status_messages]
+
             q.task_done(item_id,
                         e.history_result,
                         status=queue_types.ExecutionStatus(
                             status_str='success' if e.success else 'error',
                             completed=e.success,
-                            messages=e.status_messages))
+                            messages=messages),
+                        error_details=error_details)
             if server_instance.client_id is not None:
                 server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id},
                                           server_instance.client_id)

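Note: the tuple-to-string conversion above is easy to check in isolation; the sample status_messages are invented:

    status_messages = [
        ("execution_start", {"prompt_id": "abc"}),
        ("execution_error", {"exception_message": "name 'x' is not defined"}),
    ]
    messages = [f"{event}: {data.get('exception_message', str(data))}"
                if isinstance(data, dict) and 'exception_message' in data
                else f"{event}"
                for event, data in status_messages]
    assert messages == ["execution_start", "execution_error: name 'x' is not defined"]
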
@@ -905,9 +905,13 @@ class PromptServer(ExecutorToClientProgress):
         if accept == '*/*':
             accept = "application/json"
         content_type = request.headers.get("content-type", "application/json")
-        preferences = request.headers.get("prefer", "") + request.query.get("prefer", "") + " " + content_type
+        preferences = request.headers.get("prefer", "") + request.query.get("prefer", "") + " " + content_type + " " + accept
 
         # handle media type parameters like "application/json+respond-async"
         if "+" in content_type:
             content_type = content_type.split("+")[0]
+        if "+" in accept:
+            accept = accept.split("+")[0]
 
         wait = not "respond-async" in preferences

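Note: a hedged standalone sketch of the media-type handling above; a "+" suffix such as respond-async is folded into preferences first, then stripped to recover the bare media type:

    def strip_media_type_parameter(value: str) -> str:
        # "application/json+respond-async" -> "application/json"
        return value.split("+")[0] if "+" in value else value

    content_type = "application/json+respond-async"
    accept = "application/json"
    preferences = "" + " " + content_type + " " + accept  # prefer header/query omitted here
    wait = not "respond-async" in preferences
    assert strip_media_type_parameter(content_type) == "application/json"
    assert wait is False  # the client asked for an async response
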
@@ -993,7 +997,8 @@ class PromptServer(ExecutorToClientProgress):
             return web.Response(body=str(ex), status=500)
 
         if result.status is not None and result.status.status_str == "error":
-            return web.Response(body=json.dumps(result.status._asdict()), status=500, content_type="application/json")
+            status_dict = result.status.as_dict(error_details=result.error_details)
+            return web.Response(body=json.dumps(status_dict), status=500, content_type="application/json")
         # find images and read them
         output_images: List[FileOutput] = []
         for node_id, node in result.outputs.items():

@@ -199,9 +199,8 @@ class ValidationErrorExtraInfoDict(TypedDict, total=False):
     input_config: NotRequired[Dict[str, InputTypeSpec]]
     received_value: NotRequired[Any]
     linked_node: NotRequired[str]
     traceback: NotRequired[list[str]]
     exception_message: NotRequired[str]
     exception_type: NotRequired[str]
+    node_errors: NotRequired[Dict[str, 'NodeErrorsDictValue']]
 
 
 class ValidationErrorDict(TypedDict):

@@ -18,6 +18,7 @@ class TaskInvocation(NamedTuple):
     item_id: int | str
     outputs: OutputsDict
     status: Optional[ExecutionStatus]
+    error_details: Optional['ExecutionErrorMessage'] = None
 
 
 class ExecutionStatus(NamedTuple):

@@ -25,12 +26,15 @@ class ExecutionStatus(NamedTuple):
     completed: bool
     messages: List[str]
 
-    def as_dict(self) -> ExecutionStatusAsDict:
-        return {
+    def as_dict(self, error_details: Optional['ExecutionErrorMessage'] = None) -> ExecutionStatusAsDict:
+        result: ExecutionStatusAsDict = {
             "status_str": self.status_str,
             "completed": self.completed,
             "messages": copy.copy(self.messages),
         }
+        if error_details is not None:
+            result["error_details"] = error_details
+        return result
 
 
 class ExecutionError(RuntimeError):

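Note: usage sketch for the extended as_dict; the class is re-declared only to make the snippet self-contained, and the error_details payload is invented:

    import copy
    from typing import List, NamedTuple, Optional

    class ExecutionStatus(NamedTuple):
        status_str: str
        completed: bool
        messages: List[str]

        def as_dict(self, error_details: Optional[dict] = None) -> dict:
            result = {
                "status_str": self.status_str,
                "completed": self.completed,
                "messages": copy.copy(self.messages),
            }
            if error_details is not None:
                result["error_details"] = error_details
            return result

    status = ExecutionStatus("error", False, ["execution_error: boom"])
    assert status.as_dict(error_details={"node_id": "7"})["error_details"] == {"node_id": "7"}
    assert "error_details" not in status.as_dict()  # key is omitted when there are no details
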
@@ -162,7 +162,7 @@ class DistributedPromptQueue(AbstractPromptQueue, AsyncAbstractPromptQueue):
 
         return item, item[1]
 
-    def task_done(self, item_id: int, outputs: dict, status: Optional[ExecutionStatus]):
+    def task_done(self, item_id: int, outputs: dict, status: Optional[ExecutionStatus], error_details: Optional['ExecutionErrorMessage'] = None):
         # callee: executed on the worker thread
         if "outputs" in outputs:
             outputs: HistoryResultDict

@@ -173,7 +173,7 @@ class DistributedPromptQueue(AbstractPromptQueue, AsyncAbstractPromptQueue):
             assert pending.completed is not None
             assert not pending.completed.done()
             # finish the task. status will transmit the errors in comfy's domain-specific way
-            pending.completed.set_result(TaskInvocation(item_id=item_id, outputs=outputs, status=status))
+            pending.completed.set_result(TaskInvocation(item_id=item_id, outputs=outputs, status=status, error_details=error_details))
             # todo: the caller is responsible for sending a websocket message right now that the UI expects for updates
 
     def get_current_queue(self) -> Tuple[List[QueueTuple], List[QueueTuple]]:

@@ -245,11 +245,6 @@ async def test_two_workers_distinct_requests():
     assert len(all_workflows) == 2, f"Expected 2 distinct workflows, but got {len(all_workflows)}"
 
 
-# ============================================================================
-# API Error Reporting Tests
-# ============================================================================
-
-
 @pytest.mark.asyncio
 async def test_api_error_reporting_blocking_request(frontend_backend_worker_with_rabbitmq):
     """Test error reporting with blocking request (no async preference)"""

@@ -416,7 +411,7 @@ async def test_api_validation_error_structure(frontend_backend_worker_with_rabbi
     """Test that validation errors return proper ValidationErrorDict structure"""
     async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
         # Create an invalid prompt (invalid checkpoint name)
-        prompt = sdxl_workflow_with_refiner("test", 1, 1)
+        prompt = sdxl_workflow_with_refiner("test", "", 1, refiner_steps=1)
         prompt["4"]["inputs"]["ckpt_name"] = "fake.safetensors"
 
         prompt_json = client._AsyncRemoteComfyClient__json_encoder.encode(prompt)

@@ -436,11 +431,37 @@ async def test_api_validation_error_structure(frontend_backend_worker_with_rabbi
             assert "details" in error_body, "Missing 'details'"
             assert "extra_info" in error_body, "Missing 'extra_info'"
 
             assert error_body["type"] == "prompt_outputs_failed_validation", "unexpected type"
 
+            # extra_info should have exception_type and traceback
+            assert "exception_type" in error_body["extra_info"], "Missing 'exception_type' in extra_info"
+            assert "traceback" in error_body["extra_info"], "Missing 'traceback' in extra_info"
+            assert isinstance(error_body["extra_info"]["traceback"], list), "traceback should be a list"
+
+            # extra_info should have node_errors with detailed validation information
+            assert "node_errors" in error_body["extra_info"], "Missing 'node_errors' in extra_info"
+            node_errors = error_body["extra_info"]["node_errors"]
+            assert isinstance(node_errors, dict), "node_errors should be a dict"
+            assert len(node_errors) > 0, "node_errors should contain at least one node"
+
+            # Verify node_errors structure for node "4" (CheckpointLoaderSimple with invalid ckpt_name)
+            assert "4" in node_errors, "Node '4' should have validation errors"
+            node_4_errors = node_errors["4"]
+            assert "errors" in node_4_errors, "Node '4' should have 'errors' field"
+            assert "class_type" in node_4_errors, "Node '4' should have 'class_type' field"
+            assert "dependent_outputs" in node_4_errors, "Node '4' should have 'dependent_outputs' field"
+
+            assert node_4_errors["class_type"] == "CheckpointLoaderSimple", "Node '4' class_type should be CheckpointLoaderSimple"
+            assert len(node_4_errors["errors"]) > 0, "Node '4' should have at least one error"
+
+            # Verify the error details include the validation error type and message
+            first_error = node_4_errors["errors"][0]
+            assert "type" in first_error, "Error should have 'type' field"
+            assert "message" in first_error, "Error should have 'message' field"
+            assert "details" in first_error, "Error should have 'details' field"
+            assert first_error["type"] == "value_not_in_list", f"Expected 'value_not_in_list' error, got {first_error['type']}"
+            assert "fake.safetensors" in first_error["details"], "Error details should mention 'fake.safetensors'"
+
 
 @pytest.mark.asyncio
 async def test_api_success_response_contract(frontend_backend_worker_with_rabbitmq):

@@ -505,3 +526,85 @@ async def test_api_get_prompt_returns_outputs_directly(frontend_backend_worker_w
         # Should not be "prompt", "outputs", "status"
         assert key not in ["prompt", "status"], \
             f"GET endpoint should return Outputs directly, not history entry. Found key: {key}"
+
+
+@pytest.mark.asyncio
+async def test_api_execution_error_blocking_mode(frontend_backend_worker_with_rabbitmq):
+    """Test that execution errors (not validation) return proper error structure in blocking mode"""
+    from comfy_execution.graph_utils import GraphBuilder
+
+    async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
+        # Create a prompt that will fail during execution (not validation)
+        # Use Regex with a group name that doesn't exist - validation passes but execution fails
+        g = GraphBuilder()
+        regex_match = g.node("Regex", pattern="hello", string="hello world")
+        # Request a non-existent group name - this will pass validation but fail during execution
+        match_group = g.node("RegexMatchGroupByName", match=regex_match.out(0), name="nonexistent_group")
+        g.node("SaveString", value=match_group.out(0), filename_prefix="test")
+
+        prompt = g.finalize()
+        prompt_json = client._AsyncRemoteComfyClient__json_encoder.encode(prompt)
+
+        async with client.session.post(
+                f"{frontend_backend_worker_with_rabbitmq}/api/v1/prompts",
+                data=prompt_json,
+                headers={'Content-Type': 'application/json', 'Accept': 'application/json'}
+        ) as response:
+            # Execution errors return 500
+            assert response.status == 500, f"Expected 500 for execution error, got {response.status}"
+
+            error_body = await response.json()
+
+            # Verify ExecutionStatus structure
+            assert "status_str" in error_body, "Missing 'status_str'"
+            assert "completed" in error_body, "Missing 'completed'"
+            assert "messages" in error_body, "Missing 'messages'"
+
+            assert error_body["status_str"] == "error", f"Expected 'error', got {error_body['status_str']}"
+            assert error_body["completed"] == False, "completed should be False for errors"
+            assert isinstance(error_body["messages"], list), "messages should be a list"
+            assert len(error_body["messages"]) > 0, "messages should contain error details"
+
+
+@pytest.mark.asyncio
+async def test_api_execution_error_async_mode(frontend_backend_worker_with_rabbitmq):
+    """Test that execution errors return proper error structure in respond-async mode"""
+    from comfy_execution.graph_utils import GraphBuilder
+
+    async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
+        # Create a prompt that will fail during execution (not validation)
+        # Use Regex with a group name that doesn't exist - validation passes but execution fails
+        g = GraphBuilder()
+        regex_match = g.node("Regex", pattern="hello", string="hello world")
+        # Request a non-existent group name - this will pass validation but fail during execution
+        match_group = g.node("RegexMatchGroupByName", match=regex_match.out(0), name="nonexistent_group")
+        g.node("SaveString", value=match_group.out(0), filename_prefix="test")
+
+        prompt = g.finalize()
+
+        # Queue with respond-async
+        task_id = await client.queue_and_forget_prompt_api(prompt, prefer_header="respond-async")
+        assert task_id is not None, "Should get task_id in async mode"
+
+        # Poll for completion
+        status_code, result = await client.poll_prompt_until_done(task_id, max_attempts=60, poll_interval=1.0)
+
+        # In async mode with polling, errors come back as 200 with the error in the response body
+        # because the prompt was accepted (202) and we're just retrieving the completed result
+        assert status_code in (200, 500), f"Expected 200 or 500, got {status_code}"
+
+        if status_code == 500:
+            # Error returned directly - should be ExecutionStatus
+            assert "status_str" in result, "Missing 'status_str'"
+            assert "completed" in result, "Missing 'completed'"
+            assert "messages" in result, "Missing 'messages'"
+            assert result["status_str"] == "error"
+            assert result["completed"] == False
+            assert len(result["messages"]) > 0
+        else:
+            # Error in successful response - result might be ExecutionStatus or empty outputs
+            # If it's a dict with status info, verify it
+            if "status_str" in result:
+                assert result["status_str"] == "error"
+                assert result["completed"] == False
+                assert len(result["messages"]) > 0

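Note: taken together, the two execution-error tests pin down the client-observable contract; a hedged summary, with the helper name invented:

    def classify_error_response(status_code: int, body: dict) -> str:
        # Summarizes what the tests above assert; not part of the API surface.
        if status_code == 500:
            # Blocking mode: an ExecutionStatus body with status_str/completed/messages.
            assert body["status_str"] == "error" and body["completed"] is False
            return "execution-error"
        if status_code == 200 and body.get("status_str") == "error":
            # respond-async mode: the prompt was accepted, so the error arrives
            # in the completed result retrieved by polling.
            return "execution-error-from-poll"
        return "success"
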