diff --git a/comfy/api/openapi.yaml b/comfy/api/openapi.yaml
index 1e983a6c4..bcfe29d8f 100644
--- a/comfy/api/openapi.yaml
+++ b/comfy/api/openapi.yaml
@@ -354,19 +354,28 @@ paths:
           required: true
           description: |
             The ID of the prompt to query.
-      responses:
-        204:
-          description: |
-            The prompt is still in progress
-        200:
-          description: |
-            Prompt outputs
-          content:
-            application/json:
-              $ref: "#/components/schemas/Outputs"
-        404:
-          description: |
-            The prompt was not found
+      responses:
+        204:
+          description: |
+            The prompt is still in progress
+        200:
+          description: |
+            Prompt outputs
+          content:
+            application/json:
+              schema:
+                $ref: "#/components/schemas/Outputs"
+        404:
+          description: |
+            The prompt was not found
+        500:
+          description: |
+            An execution error occurred while processing the prompt.
+            The body is the execution status reported directly by the workers.
+          content:
+            application/json:
+              schema:
+                $ref: "#/components/schemas/ExecutionStatusAsDict"
   /api/v1/prompts:
     get:
       operationId: list_prompts
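
For orientation, the new 500 response returns the worker's execution status verbatim: the same object the server stores under the history entry's "status" key. A plausible body, assuming ExecutionStatusAsDict mirrors the upstream ExecutionStatus fields (status_str, completed, messages); the authoritative field list is the schema referenced above:

    {
      "status_str": "error",
      "completed": false,
      "messages": []
    }
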
diff --git a/comfy/client/aio_client.py b/comfy/client/aio_client.py
index cdf16e0f1..68ac7d1e3 100644
--- a/comfy/client/aio_client.py
+++ b/comfy/client/aio_client.py
@@ -65,16 +65,21 @@ class AsyncRemoteComfyClient:
             else:
                 raise RuntimeError(f"unexpected response: {response.status}: {await response.text()}")
 
-    async def queue_and_forget_prompt_api(self, prompt: PromptDict) -> str:
+    async def queue_and_forget_prompt_api(self, prompt: PromptDict, prefer_header: Optional[str] = "respond-async", accept_header: str = "application/json") -> str:
         """
         Calls the API to queue a prompt, and forgets about it
 
         :param prompt:
+        :param prefer_header: The Prefer header value (e.g., "respond-async" or None)
+        :param accept_header: The Accept header value (e.g., "application/json", "application/json+respond-async")
         :return: the task ID
         """
         prompt_json = AsyncRemoteComfyClient.__json_encoder.encode(prompt)
         response: ClientResponse
+        headers = {'Content-Type': 'application/json', 'Accept': accept_header}
+        if prefer_header:
+            headers['Prefer'] = prefer_header
         async with self.session.post(urljoin(self.server_address, "/api/v1/prompts"), data=prompt_json,
-                                     headers={'Content-Type': 'application/json', 'Accept': 'application/json', 'Prefer': 'respond-async'}) as response:
+                                     headers=headers) as response:
             if 200 <= response.status < 400:
                 response_json = await response.json()
@@ -82,16 +87,21 @@ class AsyncRemoteComfyClient:
             else:
                 raise RuntimeError(f"could not prompt: {response.status}: {await response.text()}")
 
-    async def queue_prompt_api(self, prompt: PromptDict) -> V1QueuePromptResponse:
+    async def queue_prompt_api(self, prompt: PromptDict, prefer_header: Optional[str] = None, accept_header: str = "application/json") -> V1QueuePromptResponse:
         """
         Calls the API to queue a prompt.
 
         :param prompt:
+        :param prefer_header: The Prefer header value (e.g., "respond-async" or None)
+        :param accept_header: The Accept header value (e.g., "application/json", "application/json+respond-async")
         :return: the API response from the server containing URLs and the outputs for the UI (nodes with OUTPUT_NODE == true)
         """
         prompt_json = AsyncRemoteComfyClient.__json_encoder.encode(prompt)
         response: ClientResponse
+        headers = {'Content-Type': 'application/json', 'Accept': accept_header}
+        if prefer_header:
+            headers['Prefer'] = prefer_header
         async with self.session.post(urljoin(self.server_address, "/api/v1/prompts"), data=prompt_json,
-                                     headers={'Content-Type': 'application/json', 'Accept': 'application/json'}) as response:
+                                     headers=headers) as response:
             if 200 <= response.status < 400:
                 return V1QueuePromptResponse(**(await response.json()))
@@ -160,3 +170,35 @@ class AsyncRemoteComfyClient:
         # images have filename, subfolder, type keys
         # todo: use the OpenAPI spec for this when I get around to updating it
         return history_json[prompt_id].outputs
+
+    async def get_prompt_status(self, prompt_id: str) -> ClientResponse:
+        """
+        Get the status of a prompt by ID using the API endpoint.
+        :param prompt_id: The prompt ID to query
+        :return: The ClientResponse object (caller should check status and read body)
+        """
+        return await self.session.get(urljoin(self.server_address, f"/api/v1/prompts/{prompt_id}"))
+
+    async def poll_prompt_until_done(self, prompt_id: str, max_attempts: int = 60, poll_interval: float = 1.0) -> tuple[int, dict | None]:
+        """
+        Poll a prompt until it's done (200), errors (500), or times out.
+        :param prompt_id: The prompt ID to poll
+        :param max_attempts: Maximum number of polling attempts
+        :param poll_interval: Time to wait between polls in seconds
+        :return: Tuple of (status_code, response_json or None)
+        """
+        for _ in range(max_attempts):
+            async with await self.get_prompt_status(prompt_id) as response:
+                if response.status in (200, 500):
+                    # Done: 200 carries the outputs, 500 carries the workers' execution status
+                    return response.status, await response.json()
+                elif response.status == 404:
+                    return response.status, None
+                elif response.status == 204:
+                    # Still in progress
+                    await asyncio.sleep(poll_interval)
+                else:
+                    # Unexpected status
+                    return response.status, None
+        # Timeout
+        return 408, None
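
The two helpers added above compose with the queueing methods into a fire-and-poll loop against GET /api/v1/prompts/{prompt_id}. A minimal sketch, assuming a server reachable at http://localhost:8188 (the address and the prompt contents are placeholders; the client is used as an async context manager, as the tests below do):

    import asyncio

    from comfy.client.aio_client import AsyncRemoteComfyClient


    async def run_prompt(prompt: dict) -> None:
        async with AsyncRemoteComfyClient(server_address="http://localhost:8188") as client:
            # Queue without waiting for execution; the server answers with the task ID
            task_id = await client.queue_and_forget_prompt_api(prompt, prefer_header="respond-async")
            # 204 -> keep polling, 200 -> outputs, 500 -> worker execution status, 404 -> unknown ID
            status, body = await client.poll_prompt_until_done(task_id, max_attempts=60, poll_interval=1.0)
            if status == 200:
                print("outputs:", body)
            elif status == 500:
                print("execution error:", body)
            else:
                print("no result, status:", status)

    # asyncio.run(run_prompt(some_prompt_dict)), where some_prompt_dict is a workflow in API format
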
diff --git a/comfy/cmd/server.py b/comfy/cmd/server.py
index df37647d3..b346a711e 100644
--- a/comfy/cmd/server.py
+++ b/comfy/cmd/server.py
@@ -116,6 +116,24 @@ async def compress_body(request: web.Request, handler):
     return response
 
 
+@web.middleware
+async def opentelemetry_middleware(request: web.Request, handler):
+    """Middleware to extract and propagate OpenTelemetry context from request headers"""
+    from opentelemetry import propagate, context
+
+    # Extract the OpenTelemetry context from the request headers
+    carrier = dict(request.headers)
+    ctx = propagate.extract(carrier)
+
+    # Attach the context and execute the handler
+    token = context.attach(ctx)
+    try:
+        response = await handler(request)
+        return response
+    finally:
+        context.detach(token)
+
+
 def create_cors_middleware(allowed_origin: str):
     @web.middleware
     async def cors_middleware(request: web.Request, handler):
@@ -127,7 +145,7 @@ def create_cors_middleware(allowed_origin: str):
 
             response.headers['Access-Control-Allow-Origin'] = allowed_origin
             response.headers['Access-Control-Allow-Methods'] = 'POST, GET, DELETE, PUT, OPTIONS'
-            response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization'
+            response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization, traceparent, tracestate'
             response.headers['Access-Control-Allow-Credentials'] = 'true'
 
             return response
@@ -224,7 +242,7 @@ class PromptServer(ExecutorToClientProgress):
         self._external_address: Optional[str] = None
         self.background_tasks: dict[str, Task] = dict()
 
-        middlewares = [cache_control, deprecation_warning]
+        middlewares = [opentelemetry_middleware, cache_control, deprecation_warning]
         if args.enable_compress_response_body:
             middlewares.append(compress_body)
 
@@ -867,9 +885,19 @@ class PromptServer(ExecutorToClientProgress):
                 return web.json_response(status=404)
             elif prompt_id in history_items:
                 history_entry = history_items[prompt_id]
+                # Check if execution resulted in an error
+                if "status" in history_entry:
+                    status = history_entry["status"]
+                    if isinstance(status, dict) and status.get("status_str") == "error":
+                        # Return ExecutionStatusAsDict format with status 500, matching POST /api/v1/prompts behavior
+                        return web.Response(
+                            body=json.dumps(status),
+                            status=500,
+                            content_type="application/json"
+                        )
                 return web.json_response(history_entry["outputs"])
             else:
-                return web.json_response(status=500)
+                return web.Response(status=404, reason="prompt not found in expected state")
 
     @routes.post("/api/v1/prompts")
     async def post_api_prompt(request: web.Request) -> web.Response | web.FileResponse:
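
On the other side of this middleware, a caller joins its trace to the server's by injecting the current span context into the outgoing request headers (which the CORS change above now allows from browsers). A sketch, assuming an OpenTelemetry tracer provider is already configured; the endpoint and session handling are illustrative:

    import aiohttp
    from opentelemetry import propagate, trace

    tracer = trace.get_tracer(__name__)


    async def traced_queue(prompt_json: str) -> dict:
        with tracer.start_as_current_span("queue_prompt"):
            headers = {'Content-Type': 'application/json'}
            # Writes traceparent (and tracestate, if any) into the carrier dict,
            # the counterpart of the propagate.extract() call in the middleware
            propagate.inject(headers)
            async with aiohttp.ClientSession() as session:
                async with session.post("http://localhost:8188/api/v1/prompts",
                                        data=prompt_json, headers=headers) as response:
                    return await response.json()
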
= "nonexistent_checkpoint.safetensors" + + # Post with blocking behavior (no prefer header for async) + prompt_json = client._AsyncRemoteComfyClient__json_encoder.encode(prompt) + async with client.session.post( + f"{frontend_backend_worker_with_rabbitmq}/api/v1/prompts", + data=prompt_json, + headers={'Content-Type': 'application/json', 'Accept': 'application/json'} + ) as response: + # Should return 400 for validation error (invalid checkpoint) + assert response.status == 400, f"Expected 400, got {response.status}" + error_body = await response.json() + + # Verify ValidationErrorDict structure per OpenAPI spec + assert "type" in error_body, "Missing 'type' field in error response" + assert "message" in error_body, "Missing 'message' field in error response" + assert "details" in error_body, "Missing 'details' field in error response" + assert "extra_info" in error_body, "Missing 'extra_info' field in error response" + + +@pytest.mark.asyncio +async def test_api_error_reporting_async_prefer_header(frontend_backend_worker_with_rabbitmq): + """Test error reporting with Prefer: respond-async header""" + async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client: + # Create a valid prompt structure but with invalid checkpoint + prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1) + prompt["4"]["inputs"]["ckpt_name"] = "nonexistent.safetensors" + + # Post with Prefer: respond-async header + prompt_json = client._AsyncRemoteComfyClient__json_encoder.encode(prompt) + async with client.session.post( + f"{frontend_backend_worker_with_rabbitmq}/api/v1/prompts", + data=prompt_json, + headers={ + 'Content-Type': 'application/json', + 'Accept': 'application/json', + 'Prefer': 'respond-async' + } + ) as response: + # Should return 400 immediately for validation error + assert response.status == 400, f"Expected 400 for validation error, got {response.status}" + error_body = await response.json() + assert "type" in error_body + + +@pytest.mark.asyncio +async def test_api_error_reporting_async_accept_mimetype(frontend_backend_worker_with_rabbitmq): + """Test error reporting with +respond-async in Accept mimetype""" + async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client: + # Create a prompt with validation error + prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1) + prompt["4"]["inputs"]["ckpt_name"] = "invalid_model.safetensors" + + # Post with +respond-async in Accept header + prompt_json = client._AsyncRemoteComfyClient__json_encoder.encode(prompt) + async with client.session.post( + f"{frontend_backend_worker_with_rabbitmq}/api/v1/prompts", + data=prompt_json, + headers={ + 'Content-Type': 'application/json', + 'Accept': 'application/json+respond-async' + } + ) as response: + # Should return 400 for validation error (happens before queuing) + assert response.status == 400, f"Expected 400, got {response.status}" + error_body = await response.json() + assert "type" in error_body + + +@pytest.mark.asyncio +async def test_api_get_prompt_status_success(frontend_backend_worker_with_rabbitmq): + """Test GET /api/v1/prompts/{prompt_id} returns 200 with Outputs on success""" + async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client: + # Create a valid prompt + prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1) + + # Queue async to get prompt_id + task_id = await client.queue_and_forget_prompt_api(prompt, 
prefer_header="respond-async") + assert task_id is not None + + # Poll until done + status_code, result = await client.poll_prompt_until_done(task_id, max_attempts=60, poll_interval=1.0) + + # For a valid prompt, should get 200 + assert status_code == 200, f"Expected 200 for successful execution, got {status_code}" + assert result is not None + + # Verify it returns outputs structure (dict with node IDs) + assert isinstance(result, dict) + assert len(result) > 0, "Expected non-empty outputs" + + +@pytest.mark.asyncio +async def test_api_get_prompt_status_404(frontend_backend_worker_with_rabbitmq): + """Test GET /api/v1/prompts/{prompt_id} returns 404 for non-existent prompt""" + async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client: + # Request a non-existent prompt ID + fake_prompt_id = str(uuid.uuid4()) + + async with await client.get_prompt_status(fake_prompt_id) as response: + assert response.status == 404, f"Expected 404 for non-existent prompt, got {response.status}" + + +@pytest.mark.asyncio +async def test_api_get_prompt_status_204_in_progress(frontend_backend_worker_with_rabbitmq): + """Test GET /api/v1/prompts/{prompt_id} returns 204 while prompt is in progress""" + async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client: + # Create a prompt that takes some time to execute + prompt = sdxl_workflow_with_refiner("test", inference_steps=10, refiner_steps=10) + + # Queue async + task_id = await client.queue_and_forget_prompt_api(prompt, prefer_header="respond-async") + + # Immediately check status (should be 204 or 200 if very fast) + async with await client.get_prompt_status(task_id) as response: + # Should be either 204 (in progress) or 200 (completed very fast) + assert response.status in [200, 204], f"Expected 200 or 204, got {response.status}" + + if response.status == 204: + # No content for in-progress + content = await response.read() + assert len(content) == 0 or content == b'', "Expected no content for 204 response" + + +@pytest.mark.asyncio +async def test_api_async_workflow_both_methods(frontend_backend_worker_with_rabbitmq): + """Test full async workflow: queue with respond-async, then poll for completion""" + async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client: + # Create a valid prompt + prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1) + + # Method 1: Prefer header + task_id_1 = await client.queue_and_forget_prompt_api(prompt, prefer_header="respond-async") + assert task_id_1 is not None + + # Method 2: +respond-async in Accept header + task_id_2 = await client.queue_and_forget_prompt_api( + prompt, prefer_header=None, accept_header="application/json+respond-async" + ) + assert task_id_2 is not None + + # Poll both until done + status_1, result_1 = await client.poll_prompt_until_done(task_id_1, max_attempts=60, poll_interval=1.0) + status_2, result_2 = await client.poll_prompt_until_done(task_id_2, max_attempts=60, poll_interval=1.0) + + # Both should succeed + assert status_1 == 200, f"Task 1 failed with status {status_1}" + assert status_2 == 200, f"Task 2 failed with status {status_2}" + + # Both should have outputs + assert result_1 is not None and len(result_1) > 0 + assert result_2 is not None and len(result_2) > 0 + + +@pytest.mark.asyncio +async def test_api_validation_error_structure(frontend_backend_worker_with_rabbitmq): + """Test that validation errors return proper ValidationErrorDict 
structure""" + async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client: + # Create an invalid prompt (invalid checkpoint name) + prompt = sdxl_workflow_with_refiner("test", 1, 1) + prompt["4"]["inputs"]["ckpt_name"] = "fake.safetensors" + + prompt_json = client._AsyncRemoteComfyClient__json_encoder.encode(prompt) + + async with client.session.post( + f"{frontend_backend_worker_with_rabbitmq}/api/v1/prompts", + data=prompt_json, + headers={'Content-Type': 'application/json', 'Accept': 'application/json'} + ) as response: + assert response.status == 400, f"Expected 400, got {response.status}" + + error_body = await response.json() + + # Verify ValidationErrorDict structure per OpenAPI spec + assert "type" in error_body, "Missing 'type'" + assert "message" in error_body, "Missing 'message'" + assert "details" in error_body, "Missing 'details'" + assert "extra_info" in error_body, "Missing 'extra_info'" + + # extra_info should have exception_type and traceback + assert "exception_type" in error_body["extra_info"], "Missing 'exception_type' in extra_info" + assert "traceback" in error_body["extra_info"], "Missing 'traceback' in extra_info" + assert isinstance(error_body["extra_info"]["traceback"], list), "traceback should be a list" + + +@pytest.mark.asyncio +async def test_api_success_response_contract(frontend_backend_worker_with_rabbitmq): + """Test that successful execution returns proper response structure""" + async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client: + # Create a valid prompt + prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1) + + # Queue and wait for blocking response + prompt_json = client._AsyncRemoteComfyClient__json_encoder.encode(prompt) + async with client.session.post( + f"{frontend_backend_worker_with_rabbitmq}/api/v1/prompts", + data=prompt_json, + headers={'Content-Type': 'application/json', 'Accept': 'application/json'} + ) as response: + assert response.status == 200, f"Expected 200, got {response.status}" + + result = await response.json() + + # Should have 'outputs' key (and deprecated 'urls' key) + assert "outputs" in result, "Missing 'outputs' in response" + + # outputs should be a dict with node IDs as keys + outputs = result["outputs"] + assert isinstance(outputs, dict), "outputs should be a dict" + assert len(outputs) > 0, "outputs should not be empty" + + # Each output should follow the Output schema + for node_id, output in outputs.items(): + assert isinstance(output, dict), f"Output for node {node_id} should be a dict" + # Should have images or other output types + if "images" in output: + assert isinstance(output["images"], list), f"images for node {node_id} should be a list" + for image in output["images"]: + assert "filename" in image, f"image missing 'filename' in node {node_id}" + assert "subfolder" in image, f"image missing 'subfolder' in node {node_id}" + assert "type" in image, f"image missing 'type' in node {node_id}" + + +@pytest.mark.asyncio +async def test_api_get_prompt_returns_outputs_directly(frontend_backend_worker_with_rabbitmq): + """Test GET /api/v1/prompts/{prompt_id} returns Outputs directly (not wrapped in history entry)""" + async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client: + # Create and queue a prompt + prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1) + task_id = await client.queue_and_forget_prompt_api(prompt) + + # Poll until done + 
+
+
+@pytest.mark.asyncio
+async def test_api_get_prompt_returns_outputs_directly(frontend_backend_worker_with_rabbitmq):
+    """Test GET /api/v1/prompts/{prompt_id} returns Outputs directly (not wrapped in a history entry)"""
+    async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
+        # Create and queue a prompt
+        prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
+        task_id = await client.queue_and_forget_prompt_api(prompt)
+
+        # Poll until done
+        status_code, result = await client.poll_prompt_until_done(task_id, max_attempts=60, poll_interval=1.0)
+
+        assert status_code == 200, f"Expected 200, got {status_code}"
+        assert result is not None, "Result should not be None"
+
+        # Per the OpenAPI spec, GET should return Outputs directly, not wrapped;
+        # result should be a dict with node IDs as keys
+        assert isinstance(result, dict), "Result should be a dict (Outputs)"
+
+        # Should NOT have 'prompt', 'outputs', 'status' keys (those are in the history entry);
+        # should have node IDs directly
+        for key in result.keys():
+            # Node IDs are typically numeric strings like "4", "13", etc.,
+            # never "prompt", "outputs", or "status"
+            assert key not in ["prompt", "outputs", "status"], \
+                f"GET endpoint should return Outputs directly, not a history entry. Found key: {key}"