mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 14:20:49 +08:00
Improve OpenAPI contract in distributed context, propagating validation and execution errors correctly.
This commit is contained in:
parent
be255c2691
commit
243f34f282
@ -871,6 +871,22 @@ components:
|
|||||||
type: array
|
type: array
|
||||||
items:
|
items:
|
||||||
type: string
|
type: string
|
||||||
|
node_errors:
|
||||||
|
type: object
|
||||||
|
description: "Detailed validation errors per node"
|
||||||
|
additionalProperties:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
errors:
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
$ref: "#/components/schemas/ValidationErrorDict"
|
||||||
|
dependent_outputs:
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
type: string
|
||||||
|
class_type:
|
||||||
|
type: string
|
||||||
required:
|
required:
|
||||||
- type
|
- type
|
||||||
- details
|
- details
|
||||||
|
|||||||
@ -85,7 +85,7 @@ class AsyncRemoteComfyClient:
|
|||||||
response_json = await response.json()
|
response_json = await response.json()
|
||||||
return response_json["prompt_id"]
|
return response_json["prompt_id"]
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(f"could not prompt: {response.status}: {await response.text()}")
|
raise RuntimeError(f"could not prompt: {response.status}, reason={response.reason}: {await response.text()}")
|
||||||
|
|
||||||
async def queue_prompt_api(self, prompt: PromptDict, prefer_header: Optional[str] = None, accept_header: str = "application/json") -> V1QueuePromptResponse:
|
async def queue_prompt_api(self, prompt: PromptDict, prefer_header: Optional[str] = None, accept_header: str = "application/json") -> V1QueuePromptResponse:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -161,7 +161,7 @@ def sdxl_workflow_with_refiner(prompt: str,
|
|||||||
sampler="euler_ancestral",
|
sampler="euler_ancestral",
|
||||||
scheduler="normal",
|
scheduler="normal",
|
||||||
filename_prefix="sdxl_",
|
filename_prefix="sdxl_",
|
||||||
seed=42) -> PromptDict:
|
seed=42) -> dict:
|
||||||
prompt_dict: JSON = copy.deepcopy(_BASE_PROMPT)
|
prompt_dict: JSON = copy.deepcopy(_BASE_PROMPT)
|
||||||
prompt_dict["17"]["inputs"]["text"] = prompt
|
prompt_dict["17"]["inputs"]["text"] = prompt
|
||||||
prompt_dict["20"]["inputs"]["text"] = negative_prompt
|
prompt_dict["20"]["inputs"]["text"] = negative_prompt
|
||||||
@ -188,4 +188,4 @@ def sdxl_workflow_with_refiner(prompt: str,
|
|||||||
prompt_dict["14"]["inputs"]["scheduler"] = scheduler
|
prompt_dict["14"]["inputs"]["scheduler"] = scheduler
|
||||||
|
|
||||||
prompt_dict["13"]["inputs"]["filename_prefix"] = filename_prefix
|
prompt_dict["13"]["inputs"]["filename_prefix"] = filename_prefix
|
||||||
return Prompt.validate(prompt_dict)
|
return prompt_dict
|
||||||
|
|||||||
@ -1246,16 +1246,36 @@ async def _validate_prompt(prompt_id: typing.Any, prompt: typing.Mapping[str, ty
|
|||||||
|
|
||||||
if len(good_outputs) == 0:
|
if len(good_outputs) == 0:
|
||||||
errors_list = []
|
errors_list = []
|
||||||
|
extra_info = {}
|
||||||
for o, _errors in errors:
|
for o, _errors in errors:
|
||||||
for error in _errors:
|
for error in _errors:
|
||||||
errors_list.append(f"{error['message']}: {error['details']}")
|
errors_list.append(f"{error['message']}: {error['details']}")
|
||||||
|
# Aggregate exception_type and traceback from validation errors
|
||||||
|
if 'extra_info' in error and error['extra_info']:
|
||||||
|
if 'exception_type' in error['extra_info'] and 'exception_type' not in extra_info:
|
||||||
|
extra_info['exception_type'] = error['extra_info']['exception_type']
|
||||||
|
if 'traceback' in error['extra_info'] and 'traceback' not in extra_info:
|
||||||
|
extra_info['traceback'] = error['extra_info']['traceback']
|
||||||
|
|
||||||
|
# Per OpenAPI spec, extra_info must have exception_type and traceback
|
||||||
|
# For non-exception validation errors, provide synthetic values
|
||||||
|
if 'exception_type' not in extra_info:
|
||||||
|
extra_info['exception_type'] = 'ValidationError'
|
||||||
|
if 'traceback' not in extra_info:
|
||||||
|
# Capture current stack for validation errors that don't have their own traceback
|
||||||
|
extra_info['traceback'] = traceback.format_stack()
|
||||||
|
|
||||||
|
# Include detailed node_errors for actionable debugging information
|
||||||
|
if node_errors:
|
||||||
|
extra_info['node_errors'] = node_errors
|
||||||
|
|
||||||
errors_list = "\n".join(errors_list)
|
errors_list = "\n".join(errors_list)
|
||||||
|
|
||||||
error = {
|
error = {
|
||||||
"type": "prompt_outputs_failed_validation",
|
"type": "prompt_outputs_failed_validation",
|
||||||
"message": "Prompt outputs failed validation",
|
"message": "Prompt outputs failed validation",
|
||||||
"details": errors_list,
|
"details": errors_list,
|
||||||
"extra_info": {}
|
"extra_info": extra_info
|
||||||
}
|
}
|
||||||
|
|
||||||
return ValidationTuple(False, error, list(good_outputs), node_errors)
|
return ValidationTuple(False, error, list(good_outputs), node_errors)
|
||||||
@ -1301,7 +1321,7 @@ class PromptQueue(AbstractPromptQueue):
|
|||||||
return copy.deepcopy(item_with_future.queue_tuple), task_id
|
return copy.deepcopy(item_with_future.queue_tuple), task_id
|
||||||
|
|
||||||
def task_done(self, item_id: str, outputs: HistoryResultDict,
|
def task_done(self, item_id: str, outputs: HistoryResultDict,
|
||||||
status: Optional[ExecutionStatus]):
|
status: Optional[ExecutionStatus], error_details: Optional[ExecutionErrorMessage] = None):
|
||||||
history_result = outputs
|
history_result = outputs
|
||||||
with self.mutex:
|
with self.mutex:
|
||||||
queue_item = self.currently_running.pop(item_id)
|
queue_item = self.currently_running.pop(item_id)
|
||||||
@ -1311,7 +1331,7 @@ class PromptQueue(AbstractPromptQueue):
|
|||||||
|
|
||||||
status_dict = None
|
status_dict = None
|
||||||
if status is not None:
|
if status is not None:
|
||||||
status_dict: Optional[ExecutionStatusAsDict] = status.as_dict()
|
status_dict: Optional[ExecutionStatusAsDict] = status.as_dict(error_details=error_details)
|
||||||
|
|
||||||
outputs_ = history_result["outputs"]
|
outputs_ = history_result["outputs"]
|
||||||
# Remove sensitive data from extra_data before storing in history
|
# Remove sensitive data from extra_data before storing in history
|
||||||
|
|||||||
@ -79,12 +79,25 @@ async def _prompt_worker(q: AbstractPromptQueue, server_instance: server_module.
|
|||||||
|
|
||||||
await e.execute_async(item[2], prompt_id, item[3], item[4])
|
await e.execute_async(item[2], prompt_id, item[3], item[4])
|
||||||
need_gc = True
|
need_gc = True
|
||||||
|
|
||||||
|
# Extract error details from status_messages if there's an error
|
||||||
|
error_details = None
|
||||||
|
if not e.success:
|
||||||
|
for event, data in e.status_messages:
|
||||||
|
if event == "execution_error":
|
||||||
|
error_details = data
|
||||||
|
break
|
||||||
|
|
||||||
|
# Convert status_messages tuples to string messages for backward compatibility
|
||||||
|
messages = [f"{event}: {data.get('exception_message', str(data))}" if isinstance(data, dict) and 'exception_message' in data else f"{event}" for event, data in e.status_messages]
|
||||||
|
|
||||||
q.task_done(item_id,
|
q.task_done(item_id,
|
||||||
e.history_result,
|
e.history_result,
|
||||||
status=queue_types.ExecutionStatus(
|
status=queue_types.ExecutionStatus(
|
||||||
status_str='success' if e.success else 'error',
|
status_str='success' if e.success else 'error',
|
||||||
completed=e.success,
|
completed=e.success,
|
||||||
messages=e.status_messages))
|
messages=messages),
|
||||||
|
error_details=error_details)
|
||||||
if server_instance.client_id is not None:
|
if server_instance.client_id is not None:
|
||||||
server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id},
|
server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id},
|
||||||
server_instance.client_id)
|
server_instance.client_id)
|
||||||
|
|||||||
@ -905,9 +905,13 @@ class PromptServer(ExecutorToClientProgress):
|
|||||||
if accept == '*/*':
|
if accept == '*/*':
|
||||||
accept = "application/json"
|
accept = "application/json"
|
||||||
content_type = request.headers.get("content-type", "application/json")
|
content_type = request.headers.get("content-type", "application/json")
|
||||||
preferences = request.headers.get("prefer", "") + request.query.get("prefer", "") + " " + content_type
|
preferences = request.headers.get("prefer", "") + request.query.get("prefer", "") + " " + content_type + " " + accept
|
||||||
|
|
||||||
|
# handle media type parameters like "application/json+respond-async"
|
||||||
if "+" in content_type:
|
if "+" in content_type:
|
||||||
content_type = content_type.split("+")[0]
|
content_type = content_type.split("+")[0]
|
||||||
|
if "+" in accept:
|
||||||
|
accept = accept.split("+")[0]
|
||||||
|
|
||||||
wait = not "respond-async" in preferences
|
wait = not "respond-async" in preferences
|
||||||
|
|
||||||
@ -993,7 +997,8 @@ class PromptServer(ExecutorToClientProgress):
|
|||||||
return web.Response(body=str(ex), status=500)
|
return web.Response(body=str(ex), status=500)
|
||||||
|
|
||||||
if result.status is not None and result.status.status_str == "error":
|
if result.status is not None and result.status.status_str == "error":
|
||||||
return web.Response(body=json.dumps(result.status._asdict()), status=500, content_type="application/json")
|
status_dict = result.status.as_dict(error_details=result.error_details)
|
||||||
|
return web.Response(body=json.dumps(status_dict), status=500, content_type="application/json")
|
||||||
# find images and read them
|
# find images and read them
|
||||||
output_images: List[FileOutput] = []
|
output_images: List[FileOutput] = []
|
||||||
for node_id, node in result.outputs.items():
|
for node_id, node in result.outputs.items():
|
||||||
|
|||||||
@ -199,9 +199,8 @@ class ValidationErrorExtraInfoDict(TypedDict, total=False):
|
|||||||
input_config: NotRequired[Dict[str, InputTypeSpec]]
|
input_config: NotRequired[Dict[str, InputTypeSpec]]
|
||||||
received_value: NotRequired[Any]
|
received_value: NotRequired[Any]
|
||||||
linked_node: NotRequired[str]
|
linked_node: NotRequired[str]
|
||||||
traceback: NotRequired[list[str]]
|
|
||||||
exception_message: NotRequired[str]
|
exception_message: NotRequired[str]
|
||||||
exception_type: NotRequired[str]
|
node_errors: NotRequired[Dict[str, 'NodeErrorsDictValue']]
|
||||||
|
|
||||||
|
|
||||||
class ValidationErrorDict(TypedDict):
|
class ValidationErrorDict(TypedDict):
|
||||||
|
|||||||
@ -18,6 +18,7 @@ class TaskInvocation(NamedTuple):
|
|||||||
item_id: int | str
|
item_id: int | str
|
||||||
outputs: OutputsDict
|
outputs: OutputsDict
|
||||||
status: Optional[ExecutionStatus]
|
status: Optional[ExecutionStatus]
|
||||||
|
error_details: Optional['ExecutionErrorMessage'] = None
|
||||||
|
|
||||||
|
|
||||||
class ExecutionStatus(NamedTuple):
|
class ExecutionStatus(NamedTuple):
|
||||||
@ -25,12 +26,15 @@ class ExecutionStatus(NamedTuple):
|
|||||||
completed: bool
|
completed: bool
|
||||||
messages: List[str]
|
messages: List[str]
|
||||||
|
|
||||||
def as_dict(self) -> ExecutionStatusAsDict:
|
def as_dict(self, error_details: Optional['ExecutionErrorMessage'] = None) -> ExecutionStatusAsDict:
|
||||||
return {
|
result: ExecutionStatusAsDict = {
|
||||||
"status_str": self.status_str,
|
"status_str": self.status_str,
|
||||||
"completed": self.completed,
|
"completed": self.completed,
|
||||||
"messages": copy.copy(self.messages),
|
"messages": copy.copy(self.messages),
|
||||||
}
|
}
|
||||||
|
if error_details is not None:
|
||||||
|
result["error_details"] = error_details
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
class ExecutionError(RuntimeError):
|
class ExecutionError(RuntimeError):
|
||||||
|
|||||||
@ -162,7 +162,7 @@ class DistributedPromptQueue(AbstractPromptQueue, AsyncAbstractPromptQueue):
|
|||||||
|
|
||||||
return item, item[1]
|
return item, item[1]
|
||||||
|
|
||||||
def task_done(self, item_id: int, outputs: dict, status: Optional[ExecutionStatus]):
|
def task_done(self, item_id: int, outputs: dict, status: Optional[ExecutionStatus], error_details: Optional['ExecutionErrorMessage'] = None):
|
||||||
# callee: executed on the worker thread
|
# callee: executed on the worker thread
|
||||||
if "outputs" in outputs:
|
if "outputs" in outputs:
|
||||||
outputs: HistoryResultDict
|
outputs: HistoryResultDict
|
||||||
@ -173,7 +173,7 @@ class DistributedPromptQueue(AbstractPromptQueue, AsyncAbstractPromptQueue):
|
|||||||
assert pending.completed is not None
|
assert pending.completed is not None
|
||||||
assert not pending.completed.done()
|
assert not pending.completed.done()
|
||||||
# finish the task. status will transmit the errors in comfy's domain-specific way
|
# finish the task. status will transmit the errors in comfy's domain-specific way
|
||||||
pending.completed.set_result(TaskInvocation(item_id=item_id, outputs=outputs, status=status))
|
pending.completed.set_result(TaskInvocation(item_id=item_id, outputs=outputs, status=status, error_details=error_details))
|
||||||
# todo: the caller is responsible for sending a websocket message right now that the UI expects for updates
|
# todo: the caller is responsible for sending a websocket message right now that the UI expects for updates
|
||||||
|
|
||||||
def get_current_queue(self) -> Tuple[List[QueueTuple], List[QueueTuple]]:
|
def get_current_queue(self) -> Tuple[List[QueueTuple], List[QueueTuple]]:
|
||||||
|
|||||||
@ -245,11 +245,6 @@ async def test_two_workers_distinct_requests():
|
|||||||
assert len(all_workflows) == 2, f"Expected 2 distinct workflows, but got {len(all_workflows)}"
|
assert len(all_workflows) == 2, f"Expected 2 distinct workflows, but got {len(all_workflows)}"
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# API Error Reporting Tests
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_api_error_reporting_blocking_request(frontend_backend_worker_with_rabbitmq):
|
async def test_api_error_reporting_blocking_request(frontend_backend_worker_with_rabbitmq):
|
||||||
"""Test error reporting with blocking request (no async preference)"""
|
"""Test error reporting with blocking request (no async preference)"""
|
||||||
@ -416,7 +411,7 @@ async def test_api_validation_error_structure(frontend_backend_worker_with_rabbi
|
|||||||
"""Test that validation errors return proper ValidationErrorDict structure"""
|
"""Test that validation errors return proper ValidationErrorDict structure"""
|
||||||
async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
|
async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
|
||||||
# Create an invalid prompt (invalid checkpoint name)
|
# Create an invalid prompt (invalid checkpoint name)
|
||||||
prompt = sdxl_workflow_with_refiner("test", 1, 1)
|
prompt = sdxl_workflow_with_refiner("test", "", 1, refiner_steps=1)
|
||||||
prompt["4"]["inputs"]["ckpt_name"] = "fake.safetensors"
|
prompt["4"]["inputs"]["ckpt_name"] = "fake.safetensors"
|
||||||
|
|
||||||
prompt_json = client._AsyncRemoteComfyClient__json_encoder.encode(prompt)
|
prompt_json = client._AsyncRemoteComfyClient__json_encoder.encode(prompt)
|
||||||
@ -436,11 +431,37 @@ async def test_api_validation_error_structure(frontend_backend_worker_with_rabbi
|
|||||||
assert "details" in error_body, "Missing 'details'"
|
assert "details" in error_body, "Missing 'details'"
|
||||||
assert "extra_info" in error_body, "Missing 'extra_info'"
|
assert "extra_info" in error_body, "Missing 'extra_info'"
|
||||||
|
|
||||||
|
assert error_body["type"] == "prompt_outputs_failed_validation", "unexpected type"
|
||||||
|
|
||||||
# extra_info should have exception_type and traceback
|
# extra_info should have exception_type and traceback
|
||||||
assert "exception_type" in error_body["extra_info"], "Missing 'exception_type' in extra_info"
|
assert "exception_type" in error_body["extra_info"], "Missing 'exception_type' in extra_info"
|
||||||
assert "traceback" in error_body["extra_info"], "Missing 'traceback' in extra_info"
|
assert "traceback" in error_body["extra_info"], "Missing 'traceback' in extra_info"
|
||||||
assert isinstance(error_body["extra_info"]["traceback"], list), "traceback should be a list"
|
assert isinstance(error_body["extra_info"]["traceback"], list), "traceback should be a list"
|
||||||
|
|
||||||
|
# extra_info should have node_errors with detailed validation information
|
||||||
|
assert "node_errors" in error_body["extra_info"], "Missing 'node_errors' in extra_info"
|
||||||
|
node_errors = error_body["extra_info"]["node_errors"]
|
||||||
|
assert isinstance(node_errors, dict), "node_errors should be a dict"
|
||||||
|
assert len(node_errors) > 0, "node_errors should contain at least one node"
|
||||||
|
|
||||||
|
# Verify node_errors structure for node "4" (CheckpointLoaderSimple with invalid ckpt_name)
|
||||||
|
assert "4" in node_errors, "Node '4' should have validation errors"
|
||||||
|
node_4_errors = node_errors["4"]
|
||||||
|
assert "errors" in node_4_errors, "Node '4' should have 'errors' field"
|
||||||
|
assert "class_type" in node_4_errors, "Node '4' should have 'class_type' field"
|
||||||
|
assert "dependent_outputs" in node_4_errors, "Node '4' should have 'dependent_outputs' field"
|
||||||
|
|
||||||
|
assert node_4_errors["class_type"] == "CheckpointLoaderSimple", "Node '4' class_type should be CheckpointLoaderSimple"
|
||||||
|
assert len(node_4_errors["errors"]) > 0, "Node '4' should have at least one error"
|
||||||
|
|
||||||
|
# Verify the error details include the validation error type and message
|
||||||
|
first_error = node_4_errors["errors"][0]
|
||||||
|
assert "type" in first_error, "Error should have 'type' field"
|
||||||
|
assert "message" in first_error, "Error should have 'message' field"
|
||||||
|
assert "details" in first_error, "Error should have 'details' field"
|
||||||
|
assert first_error["type"] == "value_not_in_list", f"Expected 'value_not_in_list' error, got {first_error['type']}"
|
||||||
|
assert "fake.safetensors" in first_error["details"], "Error details should mention 'fake.safetensors'"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_api_success_response_contract(frontend_backend_worker_with_rabbitmq):
|
async def test_api_success_response_contract(frontend_backend_worker_with_rabbitmq):
|
||||||
@ -505,3 +526,85 @@ async def test_api_get_prompt_returns_outputs_directly(frontend_backend_worker_w
|
|||||||
# Should not be "prompt", "outputs", "status"
|
# Should not be "prompt", "outputs", "status"
|
||||||
assert key not in ["prompt", "status"], \
|
assert key not in ["prompt", "status"], \
|
||||||
f"GET endpoint should return Outputs directly, not history entry. Found key: {key}"
|
f"GET endpoint should return Outputs directly, not history entry. Found key: {key}"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_api_execution_error_blocking_mode(frontend_backend_worker_with_rabbitmq):
|
||||||
|
"""Test that execution errors (not validation) return proper error structure in blocking mode"""
|
||||||
|
from comfy_execution.graph_utils import GraphBuilder
|
||||||
|
|
||||||
|
async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
|
||||||
|
# Create a prompt that will fail during execution (not validation)
|
||||||
|
# Use Regex with a group name that doesn't exist - validation passes but execution fails
|
||||||
|
g = GraphBuilder()
|
||||||
|
regex_match = g.node("Regex", pattern="hello", string="hello world")
|
||||||
|
# Request a non-existent group name - this will pass validation but fail during execution
|
||||||
|
match_group = g.node("RegexMatchGroupByName", match=regex_match.out(0), name="nonexistent_group")
|
||||||
|
g.node("SaveString", value=match_group.out(0), filename_prefix="test")
|
||||||
|
|
||||||
|
prompt = g.finalize()
|
||||||
|
prompt_json = client._AsyncRemoteComfyClient__json_encoder.encode(prompt)
|
||||||
|
|
||||||
|
async with client.session.post(
|
||||||
|
f"{frontend_backend_worker_with_rabbitmq}/api/v1/prompts",
|
||||||
|
data=prompt_json,
|
||||||
|
headers={'Content-Type': 'application/json', 'Accept': 'application/json'}
|
||||||
|
) as response:
|
||||||
|
# Execution errors return 500
|
||||||
|
assert response.status == 500, f"Expected 500 for execution error, got {response.status}"
|
||||||
|
|
||||||
|
error_body = await response.json()
|
||||||
|
|
||||||
|
# Verify ExecutionStatus structure
|
||||||
|
assert "status_str" in error_body, "Missing 'status_str'"
|
||||||
|
assert "completed" in error_body, "Missing 'completed'"
|
||||||
|
assert "messages" in error_body, "Missing 'messages'"
|
||||||
|
|
||||||
|
assert error_body["status_str"] == "error", f"Expected 'error', got {error_body['status_str']}"
|
||||||
|
assert error_body["completed"] == False, "completed should be False for errors"
|
||||||
|
assert isinstance(error_body["messages"], list), "messages should be a list"
|
||||||
|
assert len(error_body["messages"]) > 0, "messages should contain error details"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_api_execution_error_async_mode(frontend_backend_worker_with_rabbitmq):
|
||||||
|
"""Test that execution errors return proper error structure in respond-async mode"""
|
||||||
|
from comfy_execution.graph_utils import GraphBuilder
|
||||||
|
|
||||||
|
async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
|
||||||
|
# Create a prompt that will fail during execution (not validation)
|
||||||
|
# Use Regex with a group name that doesn't exist - validation passes but execution fails
|
||||||
|
g = GraphBuilder()
|
||||||
|
regex_match = g.node("Regex", pattern="hello", string="hello world")
|
||||||
|
# Request a non-existent group name - this will pass validation but fail during execution
|
||||||
|
match_group = g.node("RegexMatchGroupByName", match=regex_match.out(0), name="nonexistent_group")
|
||||||
|
g.node("SaveString", value=match_group.out(0), filename_prefix="test")
|
||||||
|
|
||||||
|
prompt = g.finalize()
|
||||||
|
|
||||||
|
# Queue with respond-async
|
||||||
|
task_id = await client.queue_and_forget_prompt_api(prompt, prefer_header="respond-async")
|
||||||
|
assert task_id is not None, "Should get task_id in async mode"
|
||||||
|
|
||||||
|
# Poll for completion
|
||||||
|
status_code, result = await client.poll_prompt_until_done(task_id, max_attempts=60, poll_interval=1.0)
|
||||||
|
|
||||||
|
# In async mode with polling, errors come back as 200 with error in the response body
|
||||||
|
# because the prompt was accepted (202) and we're just retrieving the completed result
|
||||||
|
assert status_code in (200, 500), f"Expected 200 or 500, got {status_code}"
|
||||||
|
|
||||||
|
if status_code == 500:
|
||||||
|
# Error returned directly - should be ExecutionStatus
|
||||||
|
assert "status_str" in result, "Missing 'status_str'"
|
||||||
|
assert "completed" in result, "Missing 'completed'"
|
||||||
|
assert "messages" in result, "Missing 'messages'"
|
||||||
|
assert result["status_str"] == "error"
|
||||||
|
assert result["completed"] == False
|
||||||
|
assert len(result["messages"]) > 0
|
||||||
|
else:
|
||||||
|
# Error in successful response - result might be ExecutionStatus or empty outputs
|
||||||
|
# If it's a dict with status info, verify it
|
||||||
|
if "status_str" in result:
|
||||||
|
assert result["status_str"] == "error"
|
||||||
|
assert result["completed"] == False
|
||||||
|
assert len(result["messages"]) > 0
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user