From be255c269111912e655d125bf90aee6da35292fd Mon Sep 17 00:00:00 2001
From: Benjamin Berman
Date: Thu, 6 Nov 2025 11:11:46 -0800
Subject: [PATCH 1/9] wip align openapi and api methods for error handling

---
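Notes: this patch settles the GET /api/v1/prompts/{prompt_id} contract at
204 while the prompt is still queued or executing, 200 with the Outputs body
on success, 404 for an unknown id, and 500 with an ExecutionStatusAsDict body
on execution failure. A minimal client-side sketch of polling against that
contract, assuming a server listening at http://127.0.0.1:8188 and a
prompt_id previously returned by POST /api/v1/prompts:

    import asyncio
    import aiohttp

    async def poll(prompt_id: str, base: str = "http://127.0.0.1:8188") -> dict | None:
        async with aiohttp.ClientSession() as session:
            for _ in range(60):
                async with session.get(f"{base}/api/v1/prompts/{prompt_id}") as resp:
                    if resp.status == 204:          # still in progress
                        await asyncio.sleep(1.0)
                        continue
                    if resp.status in (200, 500):   # outputs, or execution error status
                        return await resp.json()
                    return None                     # 404 or anything unexpected
        return None

The AsyncRemoteComfyClient.poll_prompt_until_done method added below
implements the same loop with configurable attempts and interval.
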
 comfy/api/openapi.yaml                      |  35 ++-
 comfy/client/aio_client.py                  |  51 +++-
 comfy/cmd/server.py                         |  34 ++-
 tests/conftest.py                           |   7 +-
 tests/distributed/test_distributed_queue.py | 262 ++++++++++++++++++++
 5 files changed, 365 insertions(+), 24 deletions(-)

diff --git a/comfy/api/openapi.yaml b/comfy/api/openapi.yaml
index 1e983a6c4..bcfe29d8f 100644
--- a/comfy/api/openapi.yaml
+++ b/comfy/api/openapi.yaml
@@ -354,19 +354,28 @@ paths:
         required: true
         description: |
           The ID of the prompt to query.
-      responses:
-        204:
-          description: |
-            The prompt is still in progress
-        200:
-          description: |
-            Prompt outputs
-          content:
-            application/json:
-              $ref: "#/components/schemas/Outputs"
-        404:
-          description: |
-            The prompt was not found
+      responses:
+        204:
+          description: |
+            The prompt is still in progress
+        200:
+          description: |
+            Prompt outputs
+          content:
+            application/json:
+              schema:
+                $ref: "#/components/schemas/Outputs"
+        404:
+          description: |
+            The prompt was not found
+        500:
+          description: |
+            An execution error occurred while processing the prompt.
+            The body is an execution status directly from the workers.
+          content:
+            application/json:
+              schema:
+                $ref: "#/components/schemas/ExecutionStatusAsDict"
   /api/v1/prompts:
     get:
      operationId: list_prompts
diff --git a/comfy/client/aio_client.py b/comfy/client/aio_client.py
index cdf16e0f1..68ac7d1e3 100644
--- a/comfy/client/aio_client.py
+++ b/comfy/client/aio_client.py
@@ -65,16 +65,21 @@ class AsyncRemoteComfyClient:
         else:
             raise RuntimeError(f"unexpected response: {response.status}: {await response.text()}")
 
-    async def queue_and_forget_prompt_api(self, prompt: PromptDict) -> str:
+    async def queue_and_forget_prompt_api(self, prompt: PromptDict, prefer_header: Optional[str] = "respond-async", accept_header: str = "application/json") -> str:
         """
         Calls the API to queue a prompt, and forgets about it
         :param prompt:
+        :param prefer_header: The Prefer header value (e.g., "respond-async" or None)
+        :param accept_header: The Accept header value (e.g., "application/json", "application/json+respond-async")
         :return: the task ID
         """
         prompt_json = AsyncRemoteComfyClient.__json_encoder.encode(prompt)
         response: ClientResponse
+        headers = {'Content-Type': 'application/json', 'Accept': accept_header}
+        if prefer_header:
+            headers['Prefer'] = prefer_header
         async with self.session.post(urljoin(self.server_address, "/api/v1/prompts"), data=prompt_json,
-                                     headers={'Content-Type': 'application/json', 'Accept': 'application/json', 'Prefer': 'respond-async'}) as response:
+                                     headers=headers) as response:
 
             if 200 <= response.status < 400:
                 response_json = await response.json()
@@ -82,16 +87,21 @@ class AsyncRemoteComfyClient:
         else:
             raise RuntimeError(f"could not prompt: {response.status}: {await response.text()}")
 
-    async def queue_prompt_api(self, prompt: PromptDict) -> V1QueuePromptResponse:
+    async def queue_prompt_api(self, prompt: PromptDict, prefer_header: Optional[str] = None, accept_header: str = "application/json") -> V1QueuePromptResponse:
         """
         Calls the API to queue a prompt.
         :param prompt:
+        :param prefer_header: The Prefer header value (e.g., "respond-async" or None)
+        :param accept_header: The Accept header value (e.g., "application/json", "application/json+respond-async")
         :return: the API response from the server containing URLs and the outputs for the UI (nodes with OUTPUT_NODE == true)
         """
         prompt_json = AsyncRemoteComfyClient.__json_encoder.encode(prompt)
         response: ClientResponse
+        headers = {'Content-Type': 'application/json', 'Accept': accept_header}
+        if prefer_header:
+            headers['Prefer'] = prefer_header
         async with self.session.post(urljoin(self.server_address, "/api/v1/prompts"), data=prompt_json,
-                                     headers={'Content-Type': 'application/json', 'Accept': 'application/json'}) as response:
+                                     headers=headers) as response:
 
             if 200 <= response.status < 400:
                 return V1QueuePromptResponse(**(await response.json()))
@@ -160,3 +170,36 @@ class AsyncRemoteComfyClient:
             # images have filename, subfolder, type keys
             # todo: use the OpenAPI spec for this when I get around to updating it
             return history_json[prompt_id].outputs
+
+    async def get_prompt_status(self, prompt_id: str) -> ClientResponse:
+        """
+        Get the status of a prompt by ID using the API endpoint.
+        :param prompt_id: The prompt ID to query
+        :return: The ClientResponse object (caller should check status and read body)
+        """
+        return await self.session.get(urljoin(self.server_address, f"/api/v1/prompts/{prompt_id}"))
+
+    async def poll_prompt_until_done(self, prompt_id: str, max_attempts: int = 60, poll_interval: float = 1.0) -> tuple[int, dict | None]:
+        """
+        Poll a prompt until it's done (200), errors (500), or times out.
+        :param prompt_id: The prompt ID to poll
+        :param max_attempts: Maximum number of polling attempts
+        :param poll_interval: Time to wait between polls in seconds
+        :return: Tuple of (status_code, response_json or None)
+        """
+        for _ in range(max_attempts):
+            async with await self.get_prompt_status(prompt_id) as response:
+                if response.status == 200:
+                    return response.status, await response.json()
+                elif response.status == 500:
+                    return response.status, await response.json()
+                elif response.status == 404:
+                    return response.status, None
+                elif response.status == 204:
+                    # Still in progress
+                    await asyncio.sleep(poll_interval)
+                else:
+                    # Unexpected status
+                    return response.status, None
+        # Timeout
+        return 408, None
diff --git a/comfy/cmd/server.py b/comfy/cmd/server.py
index df37647d3..b346a711e 100644
--- a/comfy/cmd/server.py
+++ b/comfy/cmd/server.py
@@ -116,6 +116,24 @@ async def compress_body(request: web.Request, handler):
     return response
 
 
+@web.middleware
+async def opentelemetry_middleware(request: web.Request, handler):
+    """Middleware to extract and propagate OpenTelemetry context from request headers"""
+    from opentelemetry import propagate, context
+
+    # Extract OpenTelemetry context from headers
+    carrier = dict(request.headers)
+    ctx = propagate.extract(carrier)
+
+    # Attach context and execute handler
+    token = context.attach(ctx)
+    try:
+        response = await handler(request)
+        return response
+    finally:
+        context.detach(token)
+
+
 def create_cors_middleware(allowed_origin: str):
     @web.middleware
     async def cors_middleware(request: web.Request, handler):
@@ -127,7 +145,7 @@ def create_cors_middleware(allowed_origin: str):
 
         response.headers['Access-Control-Allow-Origin'] = allowed_origin
         response.headers['Access-Control-Allow-Methods'] = 'POST, GET, DELETE, PUT, OPTIONS'
-        response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization'
+        response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization, traceparent, tracestate'
         response.headers['Access-Control-Allow-Credentials'] = 'true'
         return response
 
@@ -224,7 +242,7 @@ class PromptServer(ExecutorToClientProgress):
         self._external_address: Optional[str] = None
         self.background_tasks: dict[str, Task] = dict()
 
-        middlewares = [cache_control, deprecation_warning]
+        middlewares = [opentelemetry_middleware, cache_control, deprecation_warning]
         if args.enable_compress_response_body:
             middlewares.append(compress_body)
 
@@ -867,9 +885,19 @@ class PromptServer(ExecutorToClientProgress):
                 return web.json_response(status=404)
             elif prompt_id in history_items:
                 history_entry = history_items[prompt_id]
+                # Check if execution resulted in an error
+                if "status" in history_entry:
+                    status = history_entry["status"]
+                    if isinstance(status, dict) and status.get("status_str") == "error":
+                        # Return ExecutionStatusAsDict format with status 500, matching POST /api/v1/prompts behavior
+                        return web.Response(
+                            body=json.dumps(status),
+                            status=500,
+                            content_type="application/json"
+                        )
                 return web.json_response(history_entry["outputs"])
             else:
-                return web.json_response(status=500)
+                return web.Response(status=404, reason="prompt not found in expected state")
 
         @routes.post("/api/v1/prompts")
         async def post_api_prompt(request: web.Request) -> web.Response | web.FileResponse:
diff --git a/tests/conftest.py b/tests/conftest.py
index 712cfc0dc..33ae98f9d 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -3,8 +3,6 @@ import multiprocessing
 import os
 import pathlib
 import subprocess
-import sys
-import time
 import urllib
 from contextvars import ContextVar
 from multiprocessing import Process
@@ -12,9 +10,9 @@ from typing import List, Any, Generator
 
 import pytest
 import requests
+import sys
+import time
 
-from comfy.cli_args import default_configuration
-from comfy.execution_context import context_configuration
 
 os.environ['OTEL_METRICS_EXPORTER'] = 'none'
 os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
@@ -22,6 +20,7 @@ os.environ["HF_XET_HIGH_PERFORMANCE"] = "True"
 # fixes issues with running the testcontainers rabbitmqcontainer on Windows
 os.environ["TC_HOST"] = "localhost"
 
+from comfy.cli_args import default_configuration
 from comfy.cli_args_types import Configuration
 
 logging.getLogger("pika").setLevel(logging.CRITICAL + 1)
diff --git a/tests/distributed/test_distributed_queue.py b/tests/distributed/test_distributed_queue.py
index 5aa28c722..098072645 100644
--- a/tests/distributed/test_distributed_queue.py
+++ b/tests/distributed/test_distributed_queue.py
@@ -243,3 +243,265 @@ async def test_two_workers_distinct_requests():
         all_workflows.update(worker.processed_workflows)
 
     assert len(all_workflows) == 2, f"Expected 2 distinct workflows, but got {len(all_workflows)}"
+
+
+# ============================================================================
+# API Error Reporting Tests
+# ============================================================================
+
+
+@pytest.mark.asyncio
+async def test_api_error_reporting_blocking_request(frontend_backend_worker_with_rabbitmq):
+    """Test error reporting with blocking request (no async preference)"""
+    async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
+        # Create an invalid prompt that will cause a validation error
+        prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
+        # Make the prompt invalid by referencing a non-existent checkpoint
+        prompt["4"]["inputs"]["ckpt_name"] = "nonexistent_checkpoint.safetensors"
+
+        # Post with blocking behavior (no prefer header for async)
+        prompt_json = client._AsyncRemoteComfyClient__json_encoder.encode(prompt)
+        async with client.session.post(
+                f"{frontend_backend_worker_with_rabbitmq}/api/v1/prompts",
+                data=prompt_json,
+                headers={'Content-Type': 'application/json', 'Accept': 'application/json'}
+        ) as response:
+            # Should return 400 for validation error (invalid checkpoint)
+            assert response.status == 400, f"Expected 400, got {response.status}"
+            error_body = await response.json()
+
+            # Verify ValidationErrorDict structure per OpenAPI spec
+            assert "type" in error_body, "Missing 'type' field in error response"
+            assert "message" in error_body, "Missing 'message' field in error response"
+            assert "details" in error_body, "Missing 'details' field in error response"
+            assert "extra_info" in error_body, "Missing 'extra_info' field in error response"
= "nonexistent_checkpoint.safetensors" + + # Post with blocking behavior (no prefer header for async) + prompt_json = client._AsyncRemoteComfyClient__json_encoder.encode(prompt) + async with client.session.post( + f"{frontend_backend_worker_with_rabbitmq}/api/v1/prompts", + data=prompt_json, + headers={'Content-Type': 'application/json', 'Accept': 'application/json'} + ) as response: + # Should return 400 for validation error (invalid checkpoint) + assert response.status == 400, f"Expected 400, got {response.status}" + error_body = await response.json() + + # Verify ValidationErrorDict structure per OpenAPI spec + assert "type" in error_body, "Missing 'type' field in error response" + assert "message" in error_body, "Missing 'message' field in error response" + assert "details" in error_body, "Missing 'details' field in error response" + assert "extra_info" in error_body, "Missing 'extra_info' field in error response" + + +@pytest.mark.asyncio +async def test_api_error_reporting_async_prefer_header(frontend_backend_worker_with_rabbitmq): + """Test error reporting with Prefer: respond-async header""" + async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client: + # Create a valid prompt structure but with invalid checkpoint + prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1) + prompt["4"]["inputs"]["ckpt_name"] = "nonexistent.safetensors" + + # Post with Prefer: respond-async header + prompt_json = client._AsyncRemoteComfyClient__json_encoder.encode(prompt) + async with client.session.post( + f"{frontend_backend_worker_with_rabbitmq}/api/v1/prompts", + data=prompt_json, + headers={ + 'Content-Type': 'application/json', + 'Accept': 'application/json', + 'Prefer': 'respond-async' + } + ) as response: + # Should return 400 immediately for validation error + assert response.status == 400, f"Expected 400 for validation error, got {response.status}" + error_body = await response.json() + assert "type" in error_body + + +@pytest.mark.asyncio +async def test_api_error_reporting_async_accept_mimetype(frontend_backend_worker_with_rabbitmq): + """Test error reporting with +respond-async in Accept mimetype""" + async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client: + # Create a prompt with validation error + prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1) + prompt["4"]["inputs"]["ckpt_name"] = "invalid_model.safetensors" + + # Post with +respond-async in Accept header + prompt_json = client._AsyncRemoteComfyClient__json_encoder.encode(prompt) + async with client.session.post( + f"{frontend_backend_worker_with_rabbitmq}/api/v1/prompts", + data=prompt_json, + headers={ + 'Content-Type': 'application/json', + 'Accept': 'application/json+respond-async' + } + ) as response: + # Should return 400 for validation error (happens before queuing) + assert response.status == 400, f"Expected 400, got {response.status}" + error_body = await response.json() + assert "type" in error_body + + +@pytest.mark.asyncio +async def test_api_get_prompt_status_success(frontend_backend_worker_with_rabbitmq): + """Test GET /api/v1/prompts/{prompt_id} returns 200 with Outputs on success""" + async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client: + # Create a valid prompt + prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1) + + # Queue async to get prompt_id + task_id = await client.queue_and_forget_prompt_api(prompt, 
prefer_header="respond-async") + assert task_id is not None + + # Poll until done + status_code, result = await client.poll_prompt_until_done(task_id, max_attempts=60, poll_interval=1.0) + + # For a valid prompt, should get 200 + assert status_code == 200, f"Expected 200 for successful execution, got {status_code}" + assert result is not None + + # Verify it returns outputs structure (dict with node IDs) + assert isinstance(result, dict) + assert len(result) > 0, "Expected non-empty outputs" + + +@pytest.mark.asyncio +async def test_api_get_prompt_status_404(frontend_backend_worker_with_rabbitmq): + """Test GET /api/v1/prompts/{prompt_id} returns 404 for non-existent prompt""" + async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client: + # Request a non-existent prompt ID + fake_prompt_id = str(uuid.uuid4()) + + async with await client.get_prompt_status(fake_prompt_id) as response: + assert response.status == 404, f"Expected 404 for non-existent prompt, got {response.status}" + + +@pytest.mark.asyncio +async def test_api_get_prompt_status_204_in_progress(frontend_backend_worker_with_rabbitmq): + """Test GET /api/v1/prompts/{prompt_id} returns 204 while prompt is in progress""" + async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client: + # Create a prompt that takes some time to execute + prompt = sdxl_workflow_with_refiner("test", inference_steps=10, refiner_steps=10) + + # Queue async + task_id = await client.queue_and_forget_prompt_api(prompt, prefer_header="respond-async") + + # Immediately check status (should be 204 or 200 if very fast) + async with await client.get_prompt_status(task_id) as response: + # Should be either 204 (in progress) or 200 (completed very fast) + assert response.status in [200, 204], f"Expected 200 or 204, got {response.status}" + + if response.status == 204: + # No content for in-progress + content = await response.read() + assert len(content) == 0 or content == b'', "Expected no content for 204 response" + + +@pytest.mark.asyncio +async def test_api_async_workflow_both_methods(frontend_backend_worker_with_rabbitmq): + """Test full async workflow: queue with respond-async, then poll for completion""" + async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client: + # Create a valid prompt + prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1) + + # Method 1: Prefer header + task_id_1 = await client.queue_and_forget_prompt_api(prompt, prefer_header="respond-async") + assert task_id_1 is not None + + # Method 2: +respond-async in Accept header + task_id_2 = await client.queue_and_forget_prompt_api( + prompt, prefer_header=None, accept_header="application/json+respond-async" + ) + assert task_id_2 is not None + + # Poll both until done + status_1, result_1 = await client.poll_prompt_until_done(task_id_1, max_attempts=60, poll_interval=1.0) + status_2, result_2 = await client.poll_prompt_until_done(task_id_2, max_attempts=60, poll_interval=1.0) + + # Both should succeed + assert status_1 == 200, f"Task 1 failed with status {status_1}" + assert status_2 == 200, f"Task 2 failed with status {status_2}" + + # Both should have outputs + assert result_1 is not None and len(result_1) > 0 + assert result_2 is not None and len(result_2) > 0 + + +@pytest.mark.asyncio +async def test_api_validation_error_structure(frontend_backend_worker_with_rabbitmq): + """Test that validation errors return proper ValidationErrorDict 
structure""" + async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client: + # Create an invalid prompt (invalid checkpoint name) + prompt = sdxl_workflow_with_refiner("test", 1, 1) + prompt["4"]["inputs"]["ckpt_name"] = "fake.safetensors" + + prompt_json = client._AsyncRemoteComfyClient__json_encoder.encode(prompt) + + async with client.session.post( + f"{frontend_backend_worker_with_rabbitmq}/api/v1/prompts", + data=prompt_json, + headers={'Content-Type': 'application/json', 'Accept': 'application/json'} + ) as response: + assert response.status == 400, f"Expected 400, got {response.status}" + + error_body = await response.json() + + # Verify ValidationErrorDict structure per OpenAPI spec + assert "type" in error_body, "Missing 'type'" + assert "message" in error_body, "Missing 'message'" + assert "details" in error_body, "Missing 'details'" + assert "extra_info" in error_body, "Missing 'extra_info'" + + # extra_info should have exception_type and traceback + assert "exception_type" in error_body["extra_info"], "Missing 'exception_type' in extra_info" + assert "traceback" in error_body["extra_info"], "Missing 'traceback' in extra_info" + assert isinstance(error_body["extra_info"]["traceback"], list), "traceback should be a list" + + +@pytest.mark.asyncio +async def test_api_success_response_contract(frontend_backend_worker_with_rabbitmq): + """Test that successful execution returns proper response structure""" + async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client: + # Create a valid prompt + prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1) + + # Queue and wait for blocking response + prompt_json = client._AsyncRemoteComfyClient__json_encoder.encode(prompt) + async with client.session.post( + f"{frontend_backend_worker_with_rabbitmq}/api/v1/prompts", + data=prompt_json, + headers={'Content-Type': 'application/json', 'Accept': 'application/json'} + ) as response: + assert response.status == 200, f"Expected 200, got {response.status}" + + result = await response.json() + + # Should have 'outputs' key (and deprecated 'urls' key) + assert "outputs" in result, "Missing 'outputs' in response" + + # outputs should be a dict with node IDs as keys + outputs = result["outputs"] + assert isinstance(outputs, dict), "outputs should be a dict" + assert len(outputs) > 0, "outputs should not be empty" + + # Each output should follow the Output schema + for node_id, output in outputs.items(): + assert isinstance(output, dict), f"Output for node {node_id} should be a dict" + # Should have images or other output types + if "images" in output: + assert isinstance(output["images"], list), f"images for node {node_id} should be a list" + for image in output["images"]: + assert "filename" in image, f"image missing 'filename' in node {node_id}" + assert "subfolder" in image, f"image missing 'subfolder' in node {node_id}" + assert "type" in image, f"image missing 'type' in node {node_id}" + + +@pytest.mark.asyncio +async def test_api_get_prompt_returns_outputs_directly(frontend_backend_worker_with_rabbitmq): + """Test GET /api/v1/prompts/{prompt_id} returns Outputs directly (not wrapped in history entry)""" + async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client: + # Create and queue a prompt + prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1) + task_id = await client.queue_and_forget_prompt_api(prompt) + + # Poll until done + 
+
+
+@pytest.mark.asyncio
+async def test_api_get_prompt_returns_outputs_directly(frontend_backend_worker_with_rabbitmq):
+    """Test GET /api/v1/prompts/{prompt_id} returns Outputs directly (not wrapped in history entry)"""
+    async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
+        # Create and queue a prompt
+        prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
+        task_id = await client.queue_and_forget_prompt_api(prompt)
+
+        # Poll until done
+        status_code, result = await client.poll_prompt_until_done(task_id, max_attempts=60, poll_interval=1.0)
+
+        assert status_code == 200, f"Expected 200, got {status_code}"
+        assert result is not None, "Result should not be None"
+
+        # Per OpenAPI spec, GET should return Outputs directly, not wrapped
+        # result should be a dict with node IDs as keys
+        assert isinstance(result, dict), "Result should be a dict (Outputs)"
+
+        # Should NOT have 'prompt', 'outputs', 'status' keys (those are in history entry)
+        # Should have node IDs directly
+        for key in result.keys():
+            # Node IDs are typically numeric strings like "4", "13", etc.
+            # Should not be "prompt", "outputs", "status"
+            assert key not in ["prompt", "status"], \
+                f"GET endpoint should return Outputs directly, not history entry. Found key: {key}"

From 243f34f282243bc71f664fe41f8e144852e5365a Mon Sep 17 00:00:00 2001
From: doctorpangloss <2229300+doctorpangloss@users.noreply.github.com>
Date: Thu, 6 Nov 2025 12:54:35 -0800
Subject: [PATCH 2/9] Improve OpenAPI contract in distributed context,
 propagating validation and execution errors correctly.

---
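Notes: the error_details plumbing below threads the worker's
"execution_error" message payload through task_done() and
ExecutionStatus.as_dict() so the HTTP layer can serialize it with a 500
response. A self-contained sketch of the as_dict() behavior this relies on
(field names mirror the patch; the error payload itself is illustrative
only):

    import copy
    from typing import List, NamedTuple, Optional

    class ExecutionStatus(NamedTuple):
        status_str: str
        completed: bool
        messages: List[str]

        def as_dict(self, error_details: Optional[dict] = None) -> dict:
            result = {
                "status_str": self.status_str,
                "completed": self.completed,
                "messages": copy.copy(self.messages),
            }
            # error_details is only attached when a worker reported one
            if error_details is not None:
                result["error_details"] = error_details
            return result

    status = ExecutionStatus("error", False, ["execution_error: KeyError"])
    print(status.as_dict(error_details={"exception_type": "KeyError"}))
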
 comfy/api/openapi.yaml                        |  16 +++
 comfy/client/aio_client.py                    |   2 +-
 comfy/client/sdxl_with_refiner_workflow.py    |   4 +-
 comfy/cmd/execution.py                        |  26 +++-
 comfy/cmd/main.py                             |  15 ++-
 comfy/cmd/server.py                           |   9 +-
 comfy/component_model/executor_types.py       |   3 +-
 comfy/component_model/queue_types.py          |   8 +-
 comfy/distributed/distributed_prompt_queue.py |   4 +-
 tests/distributed/test_distributed_queue.py   | 115 +++++++++++++++++-
 10 files changed, 181 insertions(+), 21 deletions(-)

diff --git a/comfy/api/openapi.yaml b/comfy/api/openapi.yaml
index bcfe29d8f..260c75a38 100644
--- a/comfy/api/openapi.yaml
+++ b/comfy/api/openapi.yaml
@@ -871,6 +871,22 @@ components:
           type: array
           items:
             type: string
+        node_errors:
+          type: object
+          description: "Detailed validation errors per node"
+          additionalProperties:
+            type: object
+            properties:
+              errors:
+                type: array
+                items:
+                  $ref: "#/components/schemas/ValidationErrorDict"
+              dependent_outputs:
+                type: array
+                items:
+                  type: string
+              class_type:
+                type: string
       required:
         - type
         - details
diff --git a/comfy/client/aio_client.py b/comfy/client/aio_client.py
index 68ac7d1e3..f1e35e8ef 100644
--- a/comfy/client/aio_client.py
+++ b/comfy/client/aio_client.py
@@ -85,7 +85,7 @@ class AsyncRemoteComfyClient:
                 response_json = await response.json()
                 return response_json["prompt_id"]
             else:
-                raise RuntimeError(f"could not prompt: {response.status}: {await response.text()}")
+                raise RuntimeError(f"could not prompt: {response.status}, reason={response.reason}: {await response.text()}")
 
     async def queue_prompt_api(self, prompt: PromptDict, prefer_header: Optional[str] = None, accept_header: str = "application/json") -> V1QueuePromptResponse:
         """
         Calls the API to queue a prompt.
diff --git a/comfy/client/sdxl_with_refiner_workflow.py b/comfy/client/sdxl_with_refiner_workflow.py
index 8a4923b86..e08c877d4 100644
--- a/comfy/client/sdxl_with_refiner_workflow.py
+++ b/comfy/client/sdxl_with_refiner_workflow.py
@@ -161,7 +161,7 @@ def sdxl_workflow_with_refiner(prompt: str,
                                sampler="euler_ancestral",
                                scheduler="normal",
                                filename_prefix="sdxl_",
-                               seed=42) -> PromptDict:
+                               seed=42) -> dict:
     prompt_dict: JSON = copy.deepcopy(_BASE_PROMPT)
     prompt_dict["17"]["inputs"]["text"] = prompt
     prompt_dict["20"]["inputs"]["text"] = negative_prompt
@@ -188,4 +188,4 @@ def sdxl_workflow_with_refiner(prompt: str,
     prompt_dict["14"]["inputs"]["scheduler"] = scheduler
     prompt_dict["13"]["inputs"]["filename_prefix"] = filename_prefix
 
-    return Prompt.validate(prompt_dict)
+    return prompt_dict
diff --git a/comfy/cmd/execution.py b/comfy/cmd/execution.py
index 8095d53be..5ad5865c2 100644
--- a/comfy/cmd/execution.py
+++ b/comfy/cmd/execution.py
@@ -1246,16 +1246,36 @@ async def _validate_prompt(prompt_id: typing.Any, prompt: typing.Mapping[str, ty
 
     if len(good_outputs) == 0:
         errors_list = []
+        extra_info = {}
         for o, _errors in errors:
             for error in _errors:
                 errors_list.append(f"{error['message']}: {error['details']}")
+                # Aggregate exception_type and traceback from validation errors
+                if 'extra_info' in error and error['extra_info']:
+                    if 'exception_type' in error['extra_info'] and 'exception_type' not in extra_info:
+                        extra_info['exception_type'] = error['extra_info']['exception_type']
+                    if 'traceback' in error['extra_info'] and 'traceback' not in extra_info:
+                        extra_info['traceback'] = error['extra_info']['traceback']
+
+        # Per OpenAPI spec, extra_info must have exception_type and traceback
+        # For non-exception validation errors, provide synthetic values
+        if 'exception_type' not in extra_info:
+            extra_info['exception_type'] = 'ValidationError'
+        if 'traceback' not in extra_info:
+            # Capture current stack for validation errors that don't have their own traceback
+            extra_info['traceback'] = traceback.format_stack()
+
+        # Include detailed node_errors for actionable debugging information
+        if node_errors:
+            extra_info['node_errors'] = node_errors
+
         errors_list = "\n".join(errors_list)
 
         error = {
             "type": "prompt_outputs_failed_validation",
             "message": "Prompt outputs failed validation",
             "details": errors_list,
-            "extra_info": {}
+            "extra_info": extra_info
         }
 
         return ValidationTuple(False, error, list(good_outputs), node_errors)
@@ -1301,7 +1321,7 @@ class PromptQueue(AbstractPromptQueue):
             return copy.deepcopy(item_with_future.queue_tuple), task_id
 
     def task_done(self, item_id: str, outputs: HistoryResultDict,
-                  status: Optional[ExecutionStatus]):
+                  status: Optional[ExecutionStatus], error_details: Optional[ExecutionErrorMessage] = None):
         history_result = outputs
         with self.mutex:
             queue_item = self.currently_running.pop(item_id)
@@ -1311,7 +1331,7 @@ class PromptQueue(AbstractPromptQueue):
 
             status_dict = None
             if status is not None:
-                status_dict: Optional[ExecutionStatusAsDict] = status.as_dict()
+                status_dict: Optional[ExecutionStatusAsDict] = status.as_dict(error_details=error_details)
 
             outputs_ = history_result["outputs"]
             # Remove sensitive data from extra_data before storing in history
diff --git a/comfy/cmd/main.py b/comfy/cmd/main.py
index 182fa2107..a921a8eed 100644
--- a/comfy/cmd/main.py
+++ b/comfy/cmd/main.py
@@ -79,12 +79,25 @@ async def _prompt_worker(q: AbstractPromptQueue, server_instance: server_module.
         await e.execute_async(item[2], prompt_id, item[3], item[4])
         need_gc = True
+
+        # Extract error details from status_messages if there's an error
+        error_details = None
+        if not e.success:
+            for event, data in e.status_messages:
+                if event == "execution_error":
+                    error_details = data
+                    break
+
+        # Convert status_messages tuples to string messages for backward compatibility
+        messages = [f"{event}: {data.get('exception_message', str(data))}" if isinstance(data, dict) and 'exception_message' in data else f"{event}" for event, data in e.status_messages]
+
         q.task_done(item_id,
                     e.history_result,
                     status=queue_types.ExecutionStatus(
                         status_str='success' if e.success else 'error',
                         completed=e.success,
-                        messages=e.status_messages))
+                        messages=messages),
+                    error_details=error_details)
         if server_instance.client_id is not None:
             server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id)
diff --git a/comfy/cmd/server.py b/comfy/cmd/server.py
index b346a711e..7a4be89a8 100644
--- a/comfy/cmd/server.py
+++ b/comfy/cmd/server.py
@@ -905,9 +905,13 @@ class PromptServer(ExecutorToClientProgress):
             if accept == '*/*':
                 accept = "application/json"
             content_type = request.headers.get("content-type", "application/json")
-            preferences = request.headers.get("prefer", "") + request.query.get("prefer", "") + " " + content_type
+            preferences = request.headers.get("prefer", "") + request.query.get("prefer", "") + " " + content_type + " " + accept
+
+            # handle media type parameters like "application/json+respond-async"
             if "+" in content_type:
                 content_type = content_type.split("+")[0]
+            if "+" in accept:
+                accept = accept.split("+")[0]
 
             wait = not "respond-async" in preferences
 
@@ -993,7 +997,8 @@ class PromptServer(ExecutorToClientProgress):
                 return web.Response(body=str(ex), status=500)
 
             if result.status is not None and result.status.status_str == "error":
-                return web.Response(body=json.dumps(result.status._asdict()), status=500, content_type="application/json")
+                status_dict = result.status.as_dict(error_details=result.error_details)
+                return web.Response(body=json.dumps(status_dict), status=500, content_type="application/json")
             # find images and read them
             output_images: List[FileOutput] = []
             for node_id, node in result.outputs.items():
diff --git a/comfy/component_model/executor_types.py b/comfy/component_model/executor_types.py
index 1b3d5de4c..34b5819e1 100644
--- a/comfy/component_model/executor_types.py
+++ b/comfy/component_model/executor_types.py
@@ -199,9 +199,8 @@ class ValidationErrorExtraInfoDict(TypedDict, total=False):
     input_config: NotRequired[Dict[str, InputTypeSpec]]
     received_value: NotRequired[Any]
     linked_node: NotRequired[str]
-    traceback: NotRequired[list[str]]
     exception_message: NotRequired[str]
-    exception_type: NotRequired[str]
+    node_errors: NotRequired[Dict[str, 'NodeErrorsDictValue']]
 
 
 class ValidationErrorDict(TypedDict):
diff --git a/comfy/component_model/queue_types.py b/comfy/component_model/queue_types.py
index 21be0e5f2..f5189cc64 100644
--- a/comfy/component_model/queue_types.py
+++ b/comfy/component_model/queue_types.py
@@ -18,6 +18,7 @@ class TaskInvocation(NamedTuple):
     item_id: int | str
     outputs: OutputsDict
     status: Optional[ExecutionStatus]
+    error_details: Optional['ExecutionErrorMessage'] = None
 
 
 class ExecutionStatus(NamedTuple):
@@ -25,12 +26,15 @@ class ExecutionStatus(NamedTuple):
     completed: bool
     messages: List[str]
 
-    def as_dict(self) -> ExecutionStatusAsDict:
-        return {
+    def as_dict(self, error_details: Optional['ExecutionErrorMessage'] = None) -> ExecutionStatusAsDict:
+        result: ExecutionStatusAsDict = {
             "status_str": self.status_str,
             "completed": self.completed,
             "messages": copy.copy(self.messages),
         }
+        if error_details is not None:
+            result["error_details"] = error_details
+        return result
 
 
 class ExecutionError(RuntimeError):
diff --git a/comfy/distributed/distributed_prompt_queue.py b/comfy/distributed/distributed_prompt_queue.py
index 486657347..894921057 100644
--- a/comfy/distributed/distributed_prompt_queue.py
+++ b/comfy/distributed/distributed_prompt_queue.py
@@ -162,7 +162,7 @@ class DistributedPromptQueue(AbstractPromptQueue, AsyncAbstractPromptQueue):
 
         return item, item[1]
 
-    def task_done(self, item_id: int, outputs: dict, status: Optional[ExecutionStatus]):
+    def task_done(self, item_id: int, outputs: dict, status: Optional[ExecutionStatus], error_details: Optional['ExecutionErrorMessage'] = None):
         # callee: executed on the worker thread
         if "outputs" in outputs:
             outputs: HistoryResultDict
@@ -173,7 +173,7 @@ class DistributedPromptQueue(AbstractPromptQueue, AsyncAbstractPromptQueue):
         assert pending.completed is not None
         assert not pending.completed.done()
         # finish the task. status will transmit the errors in comfy's domain-specific way
-        pending.completed.set_result(TaskInvocation(item_id=item_id, outputs=outputs, status=status))
+        pending.completed.set_result(TaskInvocation(item_id=item_id, outputs=outputs, status=status, error_details=error_details))
         # todo: the caller is responsible for sending a websocket message right now that the UI expects for updates
 
     def get_current_queue(self) -> Tuple[List[QueueTuple], List[QueueTuple]]:
extra_info" assert isinstance(error_body["extra_info"]["traceback"], list), "traceback should be a list" + # extra_info should have node_errors with detailed validation information + assert "node_errors" in error_body["extra_info"], "Missing 'node_errors' in extra_info" + node_errors = error_body["extra_info"]["node_errors"] + assert isinstance(node_errors, dict), "node_errors should be a dict" + assert len(node_errors) > 0, "node_errors should contain at least one node" + + # Verify node_errors structure for node "4" (CheckpointLoaderSimple with invalid ckpt_name) + assert "4" in node_errors, "Node '4' should have validation errors" + node_4_errors = node_errors["4"] + assert "errors" in node_4_errors, "Node '4' should have 'errors' field" + assert "class_type" in node_4_errors, "Node '4' should have 'class_type' field" + assert "dependent_outputs" in node_4_errors, "Node '4' should have 'dependent_outputs' field" + + assert node_4_errors["class_type"] == "CheckpointLoaderSimple", "Node '4' class_type should be CheckpointLoaderSimple" + assert len(node_4_errors["errors"]) > 0, "Node '4' should have at least one error" + + # Verify the error details include the validation error type and message + first_error = node_4_errors["errors"][0] + assert "type" in first_error, "Error should have 'type' field" + assert "message" in first_error, "Error should have 'message' field" + assert "details" in first_error, "Error should have 'details' field" + assert first_error["type"] == "value_not_in_list", f"Expected 'value_not_in_list' error, got {first_error['type']}" + assert "fake.safetensors" in first_error["details"], "Error details should mention 'fake.safetensors'" + @pytest.mark.asyncio async def test_api_success_response_contract(frontend_backend_worker_with_rabbitmq): @@ -505,3 +526,85 @@ async def test_api_get_prompt_returns_outputs_directly(frontend_backend_worker_w # Should not be "prompt", "outputs", "status" assert key not in ["prompt", "status"], \ f"GET endpoint should return Outputs directly, not history entry. 
Found key: {key}" + + +@pytest.mark.asyncio +async def test_api_execution_error_blocking_mode(frontend_backend_worker_with_rabbitmq): + """Test that execution errors (not validation) return proper error structure in blocking mode""" + from comfy_execution.graph_utils import GraphBuilder + + async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client: + # Create a prompt that will fail during execution (not validation) + # Use Regex with a group name that doesn't exist - validation passes but execution fails + g = GraphBuilder() + regex_match = g.node("Regex", pattern="hello", string="hello world") + # Request a non-existent group name - this will pass validation but fail during execution + match_group = g.node("RegexMatchGroupByName", match=regex_match.out(0), name="nonexistent_group") + g.node("SaveString", value=match_group.out(0), filename_prefix="test") + + prompt = g.finalize() + prompt_json = client._AsyncRemoteComfyClient__json_encoder.encode(prompt) + + async with client.session.post( + f"{frontend_backend_worker_with_rabbitmq}/api/v1/prompts", + data=prompt_json, + headers={'Content-Type': 'application/json', 'Accept': 'application/json'} + ) as response: + # Execution errors return 500 + assert response.status == 500, f"Expected 500 for execution error, got {response.status}" + + error_body = await response.json() + + # Verify ExecutionStatus structure + assert "status_str" in error_body, "Missing 'status_str'" + assert "completed" in error_body, "Missing 'completed'" + assert "messages" in error_body, "Missing 'messages'" + + assert error_body["status_str"] == "error", f"Expected 'error', got {error_body['status_str']}" + assert error_body["completed"] == False, "completed should be False for errors" + assert isinstance(error_body["messages"], list), "messages should be a list" + assert len(error_body["messages"]) > 0, "messages should contain error details" + + +@pytest.mark.asyncio +async def test_api_execution_error_async_mode(frontend_backend_worker_with_rabbitmq): + """Test that execution errors return proper error structure in respond-async mode""" + from comfy_execution.graph_utils import GraphBuilder + + async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client: + # Create a prompt that will fail during execution (not validation) + # Use Regex with a group name that doesn't exist - validation passes but execution fails + g = GraphBuilder() + regex_match = g.node("Regex", pattern="hello", string="hello world") + # Request a non-existent group name - this will pass validation but fail during execution + match_group = g.node("RegexMatchGroupByName", match=regex_match.out(0), name="nonexistent_group") + g.node("SaveString", value=match_group.out(0), filename_prefix="test") + + prompt = g.finalize() + + # Queue with respond-async + task_id = await client.queue_and_forget_prompt_api(prompt, prefer_header="respond-async") + assert task_id is not None, "Should get task_id in async mode" + + # Poll for completion + status_code, result = await client.poll_prompt_until_done(task_id, max_attempts=60, poll_interval=1.0) + + # In async mode with polling, errors come back as 200 with error in the response body + # because the prompt was accepted (202) and we're just retrieving the completed result + assert status_code in (200, 500), f"Expected 200 or 500, got {status_code}" + + if status_code == 500: + # Error returned directly - should be ExecutionStatus + assert "status_str" in result, "Missing 
'status_str'" + assert "completed" in result, "Missing 'completed'" + assert "messages" in result, "Missing 'messages'" + assert result["status_str"] == "error" + assert result["completed"] == False + assert len(result["messages"]) > 0 + else: + # Error in successful response - result might be ExecutionStatus or empty outputs + # If it's a dict with status info, verify it + if "status_str" in result: + assert result["status_str"] == "error" + assert result["completed"] == False + assert len(result["messages"]) > 0 From 2f520a4cb4bda5af177a769c497b96e5b3b55e75 Mon Sep 17 00:00:00 2001 From: doctorpangloss <2229300+doctorpangloss@users.noreply.github.com> Date: Fri, 7 Nov 2025 14:27:25 -0800 Subject: [PATCH 3/9] Workflow templates constraint --- .../vendor/aiohttp_server_instrumentation.py | 271 ------------------ pyproject.toml | 11 +- tests/distributed/test_tracing.py | 0 3 files changed, 8 insertions(+), 274 deletions(-) delete mode 100644 comfy/vendor/aiohttp_server_instrumentation.py create mode 100644 tests/distributed/test_tracing.py diff --git a/comfy/vendor/aiohttp_server_instrumentation.py b/comfy/vendor/aiohttp_server_instrumentation.py deleted file mode 100644 index 5e334c49f..000000000 --- a/comfy/vendor/aiohttp_server_instrumentation.py +++ /dev/null @@ -1,271 +0,0 @@ -# Copyright 2020, OpenTelemetry Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
 .../vendor/aiohttp_server_instrumentation.py  | 271 ------------------
 pyproject.toml                                |  11 +-
 tests/distributed/test_tracing.py             |   0
 3 files changed, 8 insertions(+), 274 deletions(-)
 delete mode 100644 comfy/vendor/aiohttp_server_instrumentation.py
 create mode 100644 tests/distributed/test_tracing.py

diff --git a/comfy/vendor/aiohttp_server_instrumentation.py b/comfy/vendor/aiohttp_server_instrumentation.py
deleted file mode 100644
index 5e334c49f..000000000
--- a/comfy/vendor/aiohttp_server_instrumentation.py
+++ /dev/null
@@ -1,271 +0,0 @@
-# Copyright 2020, OpenTelemetry Authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import urllib
-from timeit import default_timer
-from typing import Dict, List, Tuple, Union
-
-from aiohttp import web
-from multidict import CIMultiDictProxy
-
-from opentelemetry import metrics, trace
-_instruments = ("aiohttp ~= 3.0",)
-__version__ = "0.49b0.dev"
-from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
-from opentelemetry.instrumentation.utils import (
-    http_status_to_status_code,
-    is_http_instrumentation_enabled,
-)
-from opentelemetry.propagate import extract
-from opentelemetry.propagators.textmap import Getter
-from opentelemetry.semconv.metrics import MetricInstruments
-from opentelemetry.semconv.trace import SpanAttributes
-from opentelemetry.trace.status import Status, StatusCode
-from opentelemetry.util.http import get_excluded_urls, remove_url_credentials
-
-_duration_attrs = [
-    SpanAttributes.HTTP_METHOD,
-    SpanAttributes.HTTP_HOST,
-    SpanAttributes.HTTP_SCHEME,
-    SpanAttributes.HTTP_STATUS_CODE,
-    SpanAttributes.HTTP_FLAVOR,
-    SpanAttributes.HTTP_SERVER_NAME,
-    SpanAttributes.NET_HOST_NAME,
-    SpanAttributes.NET_HOST_PORT,
-    SpanAttributes.HTTP_ROUTE,
-]
-
-_active_requests_count_attrs = [
-    SpanAttributes.HTTP_METHOD,
-    SpanAttributes.HTTP_HOST,
-    SpanAttributes.HTTP_SCHEME,
-    SpanAttributes.HTTP_FLAVOR,
-    SpanAttributes.HTTP_SERVER_NAME,
-]
-
-tracer = trace.get_tracer(__name__)
-meter = metrics.get_meter(__name__, __version__)
-_excluded_urls = get_excluded_urls("AIOHTTP_SERVER")
-
-
-def _parse_duration_attrs(req_attrs):
-    duration_attrs = {}
-    for attr_key in _duration_attrs:
-        if req_attrs.get(attr_key) is not None:
-            duration_attrs[attr_key] = req_attrs[attr_key]
-    return duration_attrs
-
-
-def _parse_active_request_count_attrs(req_attrs):
-    active_requests_count_attrs = {}
-    for attr_key in _active_requests_count_attrs:
-        if req_attrs.get(attr_key) is not None:
-            active_requests_count_attrs[attr_key] = req_attrs[attr_key]
-    return active_requests_count_attrs
-
-
-def get_default_span_details(request: web.Request) -> Tuple[str, dict]:
-    """Default implementation for get_default_span_details
-    Args:
-        request: the request object itself.
-    Returns:
-        a tuple of the span name, and any attributes to attach to the span.
-    """
-    span_name = request.path.strip() or f"HTTP {request.method}"
-    return span_name, {}
-
-
-def _get_view_func(request: web.Request) -> str:
-    """Returns the name of the request handler.
-    Args:
-        request: the request object itself.
-    Returns:
-        a string containing the name of the handler function
-    """
-    try:
-        return request.match_info.handler.__name__
-    except AttributeError:
-        return "unknown"
-
-
-def collect_request_attributes(request: web.Request) -> Dict:
-    """Collects HTTP request attributes from the ASGI scope and returns a
-    dictionary to be used as span creation attributes."""
-
-    server_host, port, http_url = (
-        request.url.host,
-        request.url.port,
-        str(request.url),
-    )
-    query_string = request.query_string
-    if query_string and http_url:
-        if isinstance(query_string, bytes):
-            query_string = query_string.decode("utf8")
-        http_url += "?" + urllib.parse.unquote(query_string)
-
-    result = {
-        SpanAttributes.HTTP_SCHEME: request.scheme,
-        SpanAttributes.HTTP_HOST: server_host,
-        SpanAttributes.NET_HOST_PORT: port,
-        SpanAttributes.HTTP_ROUTE: _get_view_func(request),
-        SpanAttributes.HTTP_FLAVOR: f"{request.version.major}.{request.version.minor}",
-        SpanAttributes.HTTP_TARGET: request.path,
-        SpanAttributes.HTTP_URL: remove_url_credentials(http_url),
-    }
-
-    http_method = request.method
-    if http_method:
-        result[SpanAttributes.HTTP_METHOD] = http_method
-
-    http_host_value_list = (
-        [request.host] if not isinstance(request.host, list) else request.host
-    )
-    if http_host_value_list:
-        result[SpanAttributes.HTTP_SERVER_NAME] = ",".join(
-            http_host_value_list
-        )
-    http_user_agent = request.headers.get("user-agent")
-    if http_user_agent:
-        result[SpanAttributes.HTTP_USER_AGENT] = http_user_agent
-
-    # remove None values
-    result = {k: v for k, v in result.items() if v is not None}
-
-    return result
-
-
-def set_status_code(span, status_code: int) -> None:
-    """Adds HTTP response attributes to span using the status_code argument."""
-
-    try:
-        status_code = int(status_code)
-    except ValueError:
-        span.set_status(
-            Status(
-                StatusCode.ERROR,
-                "Non-integer HTTP status: " + repr(status_code),
-            )
-        )
-    else:
-        span.set_attribute(SpanAttributes.HTTP_STATUS_CODE, status_code)
-        span.set_status(
-            Status(http_status_to_status_code(status_code, server_span=True))
-        )
-
-
-class AiohttpGetter(Getter):
-    """Extract current trace from headers"""
-
-    def get(self, carrier, key: str) -> Union[List, None]:
-        """Getter implementation to retrieve an HTTP header value from the ASGI
-        scope.
-
-        Args:
-            carrier: ASGI scope object
-            key: header name in scope
-        Returns:
-            A list of all header values matching the key, or None if the key
-            does not match any header.
- """ - headers: CIMultiDictProxy = carrier.headers - if not headers: - return None - return headers.getall(key, None) - - def keys(self, carrier: Dict) -> List: - return list(carrier.keys()) - - -getter = AiohttpGetter() - - -@web.middleware -async def middleware(request, handler): - """Middleware for aiohttp implementing tracing logic""" - if not is_http_instrumentation_enabled() or _excluded_urls.url_disabled( - request.url.path - ): - return await handler(request) - - span_name, additional_attributes = get_default_span_details(request) - - req_attrs = collect_request_attributes(request) - duration_attrs = _parse_duration_attrs(req_attrs) - active_requests_count_attrs = _parse_active_request_count_attrs(req_attrs) - - duration_histogram = meter.create_histogram( - name=MetricInstruments.HTTP_SERVER_DURATION, - unit="ms", - description="Measures the duration of inbound HTTP requests.", - ) - - active_requests_counter = meter.create_up_down_counter( - name=MetricInstruments.HTTP_SERVER_ACTIVE_REQUESTS, - unit="requests", - description="measures the number of concurrent HTTP requests those are currently in flight", - ) - - with tracer.start_as_current_span( - span_name, - context=extract(request, getter=getter), - kind=trace.SpanKind.SERVER, - ) as span: - attributes = collect_request_attributes(request) - attributes.update(additional_attributes) - span.set_attributes(attributes) - start = default_timer() - active_requests_counter.add(1, active_requests_count_attrs) - try: - resp = await handler(request) - set_status_code(span, resp.status) - except web.HTTPException as ex: - set_status_code(span, ex.status_code) - raise - except AttributeError: - # No response was returned or a NoneType response was returned, handle gracefully - set_status_code(span, 500) - raise - finally: - duration = max((default_timer() - start) * 1000, 0) - duration_histogram.record(duration, duration_attrs) - active_requests_counter.add(-1, active_requests_count_attrs) - return resp - - -class _InstrumentedApplication(web.Application): - """Insert tracing middleware""" - - def __init__(self, *args, **kwargs): - middlewares = kwargs.pop("middlewares", []) - middlewares.insert(0, middleware) - kwargs["middlewares"] = middlewares - super().__init__(*args, **kwargs) - - -class AioHttpServerInstrumentor(BaseInstrumentor): - # pylint: disable=protected-access,attribute-defined-outside-init - """An instrumentor for aiohttp.web.Application - - See `BaseInstrumentor` - """ - - def _instrument(self, **kwargs): - self._original_app = web.Application - setattr(web, "Application", _InstrumentedApplication) - - def _uninstrument(self, **kwargs): - setattr(web, "Application", self._original_app) - - def instrumentation_dependencies(self): - return _instruments diff --git a/pyproject.toml b/pyproject.toml index 65f42c34d..f6e6f375c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,7 @@ classifiers = [ dependencies = [ "comfyui-frontend-package>=1.28.7", - "comfyui-workflow-templates>=0.1.95", + "comfyui-workflow-templates>=0.1.95,<0.3.0", "comfyui-embedded-docs>=0.3.0", "torch", "torchvision", @@ -65,13 +65,18 @@ dependencies = [ "natsort", "OpenEXR", "opentelemetry-distro", - "opentelemetry-sdk<1.34.0", - "opentelemetry-exporter-otlp<=1.27.0", + "opentelemetry-sdk", + "opentelemetry-exporter-otlp", "opentelemetry-propagator-jaeger", "opentelemetry-instrumentation", "opentelemetry-util-http", "opentelemetry-instrumentation-aio-pika", "opentelemetry-instrumentation-requests", + 
"opentelemetry-instrumentation-aiohttp-server", + "opentelemetry-instrumentation-aiohttp-client", + "opentelemetry-instrumentation-asyncio", + "opentelemetry-instrumentation-urllib3", + "opentelemetry-processor-baggage", "opentelemetry-semantic-conventions", "wrapt>=1.16.0", "certifi", diff --git a/tests/distributed/test_tracing.py b/tests/distributed/test_tracing.py new file mode 100644 index 000000000..e69de29bb From 69d8f1b120229740337d0c345b4ef7d3de621a8f Mon Sep 17 00:00:00 2001 From: doctorpangloss <2229300+doctorpangloss@users.noreply.github.com> Date: Fri, 7 Nov 2025 14:27:31 -0800 Subject: [PATCH 4/9] Tracing tests --- comfy/client/aio_client.py | 66 ++- comfy/client/sdxl_with_refiner_workflow.py | 2 - comfy/cmd/main_pre.py | 17 +- comfy/tracing_compatibility.py | 6 +- tests/conftest.py | 12 +- tests/distributed/test_tracing.py | 128 +++++ tests/distributed/test_tracing_integration.py | 497 ++++++++++++++++++ 7 files changed, 688 insertions(+), 40 deletions(-) create mode 100644 tests/distributed/test_tracing_integration.py diff --git a/comfy/client/aio_client.py b/comfy/client/aio_client.py index f1e35e8ef..f79f44ac1 100644 --- a/comfy/client/aio_client.py +++ b/comfy/client/aio_client.py @@ -1,23 +1,23 @@ -from asyncio import AbstractEventLoop -from collections import defaultdict - -import aiohttp import asyncio import uuid -from aiohttp import WSMessage, ClientResponse, ClientTimeout -from pathlib import Path +from asyncio import AbstractEventLoop from typing import Optional, List from urllib.parse import urlparse, urljoin +import aiohttp +from aiohttp import WSMessage, ClientResponse, ClientTimeout +from opentelemetry import trace + from .client_types import V1QueuePromptResponse from ..api.api_client import JSONEncoder from ..api.components.schema.prompt import PromptDict from ..api.components.schema.prompt_request import PromptRequest from ..api.paths.history.get.responses.response_200.content.application_json.schema import Schema as GetHistoryDict from ..api.schemas import immutabledict -from ..component_model.file_output_path import file_output_path from ..component_model.outputs_types import OutputsDict +tracer = trace.get_tracer(__name__) + class AsyncRemoteComfyClient: """ @@ -57,6 +57,27 @@ class AsyncRemoteComfyClient: def session(self) -> aiohttp.ClientSession: return self._ensure_session() + def _build_headers(self, accept_header: str, prefer_header: Optional[str] = None, content_type: str = "application/json") -> dict: + """Build HTTP headers for requests.""" + headers = {'Content-Type': content_type, 'Accept': accept_header} + if prefer_header: + headers['Prefer'] = prefer_header + return headers + + @tracer.start_as_current_span("Post Prompt") + async def _post_prompt(self, prompt: PromptDict, endpoint: str, accept_header: str, prefer_header: Optional[str] = None) -> ClientResponse: + """ + Common method to POST a prompt to a given endpoint. 
 comfy/client/aio_client.py                    |  66 ++-
 comfy/client/sdxl_with_refiner_workflow.py    |   2 -
 comfy/cmd/main_pre.py                         |  17 +-
 comfy/tracing_compatibility.py                |   6 +-
 tests/conftest.py                             |  12 +-
 tests/distributed/test_tracing.py             | 128 +++++
 tests/distributed/test_tracing_integration.py | 497 ++++++++++++++++++
 7 files changed, 688 insertions(+), 40 deletions(-)
 create mode 100644 tests/distributed/test_tracing_integration.py

diff --git a/comfy/client/aio_client.py b/comfy/client/aio_client.py
index f1e35e8ef..f79f44ac1 100644
--- a/comfy/client/aio_client.py
+++ b/comfy/client/aio_client.py
@@ -1,23 +1,23 @@
-from asyncio import AbstractEventLoop
-from collections import defaultdict
-
-import aiohttp
 import asyncio
 import uuid
-from aiohttp import WSMessage, ClientResponse, ClientTimeout
-from pathlib import Path
+from asyncio import AbstractEventLoop
 from typing import Optional, List
 from urllib.parse import urlparse, urljoin
 
+import aiohttp
+from aiohttp import WSMessage, ClientResponse, ClientTimeout
+from opentelemetry import trace
+
 from .client_types import V1QueuePromptResponse
 from ..api.api_client import JSONEncoder
 from ..api.components.schema.prompt import PromptDict
 from ..api.components.schema.prompt_request import PromptRequest
 from ..api.paths.history.get.responses.response_200.content.application_json.schema import Schema as GetHistoryDict
 from ..api.schemas import immutabledict
-from ..component_model.file_output_path import file_output_path
 from ..component_model.outputs_types import OutputsDict
 
+tracer = trace.get_tracer(__name__)
+
 
 class AsyncRemoteComfyClient:
     """
@@ -57,6 +57,27 @@ class AsyncRemoteComfyClient:
     def session(self) -> aiohttp.ClientSession:
         return self._ensure_session()
 
+    def _build_headers(self, accept_header: str, prefer_header: Optional[str] = None, content_type: str = "application/json") -> dict:
+        """Build HTTP headers for requests."""
+        headers = {'Content-Type': content_type, 'Accept': accept_header}
+        if prefer_header:
+            headers['Prefer'] = prefer_header
+        return headers
+
+    @tracer.start_as_current_span("Post Prompt")
+    async def _post_prompt(self, prompt: PromptDict, endpoint: str, accept_header: str, prefer_header: Optional[str] = None) -> ClientResponse:
+        """
+        Common method to POST a prompt to a given endpoint.
+
+        :param prompt: The prompt to send
+        :param endpoint: The API endpoint (e.g., "/api/v1/prompts")
+        :param accept_header: The Accept header value
+        :param prefer_header: Optional Prefer header value
+        :return: The response object
+        """
+        prompt_json = AsyncRemoteComfyClient.__json_encoder.encode(prompt)
+        headers = self._build_headers(accept_header, prefer_header)
+        return await self.session.post(urljoin(self.server_address, endpoint), data=prompt_json, headers=headers)
+
     async def len_queue(self) -> int:
         async with self.session.get(urljoin(self.server_address, "/prompt"), headers={'Accept': 'application/json'}) as response:
             if response.status == 200:
@@ -73,14 +94,7 @@ class AsyncRemoteComfyClient:
         :param accept_header: The Accept header value (e.g., "application/json", "application/json+respond-async")
         :return: the task ID
         """
-        prompt_json = AsyncRemoteComfyClient.__json_encoder.encode(prompt)
-        response: ClientResponse
-        headers = {'Content-Type': 'application/json', 'Accept': accept_header}
-        if prefer_header:
-            headers['Prefer'] = prefer_header
-        async with self.session.post(urljoin(self.server_address, "/api/v1/prompts"), data=prompt_json,
-                                     headers=headers) as response:
-
+        async with await self._post_prompt(prompt, "/api/v1/prompts", accept_header, prefer_header) as response:
             if 200 <= response.status < 400:
                 response_json = await response.json()
                 return response_json["prompt_id"]
@@ -95,14 +109,7 @@ class AsyncRemoteComfyClient:
         :param accept_header: The Accept header value (e.g., "application/json", "application/json+respond-async")
         :return: the API response from the server containing URLs and the outputs for the UI (nodes with OUTPUT_NODE == true)
         """
-        prompt_json = AsyncRemoteComfyClient.__json_encoder.encode(prompt)
-        response: ClientResponse
-        headers = {'Content-Type': 'application/json', 'Accept': accept_header}
-        if prefer_header:
-            headers['Prefer'] = prefer_header
-        async with self.session.post(urljoin(self.server_address, "/api/v1/prompts"), data=prompt_json,
-                                     headers=headers) as response:
-
+        async with await self._post_prompt(prompt, "/api/v1/prompts", accept_header, prefer_header) as response:
             if 200 <= response.status < 400:
                 return V1QueuePromptResponse(**(await response.json()))
             else:
@@ -122,17 +129,13 @@ class AsyncRemoteComfyClient:
         :param prompt:
         :return:
         """
-        prompt_json = AsyncRemoteComfyClient.__json_encoder.encode(prompt)
-        response: ClientResponse
-        headers = {'Content-Type': 'application/json', 'Accept': 'image/png'}
-        async with self.session.post(urljoin(self.server_address, "/api/v1/prompts"), data=prompt_json,
-                                     headers=headers) as response:
-
+        async with await self._post_prompt(prompt, "/api/v1/prompts", "image/png") as response:
             if 200 <= response.status < 400:
                 return await response.read()
             else:
                 raise RuntimeError(f"could not prompt: {response.status}: {await response.text()}")
 
+    @tracer.start_as_current_span("Post Prompt (UI)")
     async def queue_prompt_ui(self, prompt: PromptDict) -> OutputsDict:
         """
         Uses the comfyui UI API calls to retrieve the outputs dictionary
@@ -179,6 +182,7 @@ class AsyncRemoteComfyClient:
         """
         return await self.session.get(urljoin(self.server_address, f"/api/v1/prompts/{prompt_id}"))
 
+    @tracer.start_as_current_span("Poll Prompt Until Done")
     async def poll_prompt_until_done(self, prompt_id: str, max_attempts: int = 60, poll_interval: float = 1.0) -> tuple[int, dict | None]:
         """
         Poll a prompt until it's done (200), errors (500), or times out.
@@ -187,6 +191,10 @@ class AsyncRemoteComfyClient:
         :param prompt_id: The prompt ID to poll
         :param max_attempts: Maximum number of polling attempts
         :param poll_interval: Time to wait between polls in seconds
         :return: Tuple of (status_code, response_json or None)
         """
+        span = trace.get_current_span()
+        span.set_attribute("prompt_id", prompt_id)
+        span.set_attribute("max_attempts", max_attempts)
+
         for _ in range(max_attempts):
             async with await self.get_prompt_status(prompt_id) as response:
                 if response.status == 200:
diff --git a/comfy/client/sdxl_with_refiner_workflow.py b/comfy/client/sdxl_with_refiner_workflow.py
index e08c877d4..d7dab309a 100644
--- a/comfy/client/sdxl_with_refiner_workflow.py
+++ b/comfy/client/sdxl_with_refiner_workflow.py
@@ -1,8 +1,6 @@
 import copy
 from typing import TypeAlias, Union
 
-from ..api.components.schema.prompt import PromptDict, Prompt
-
 JSON: TypeAlias = Union[dict[str, "JSON"], list["JSON"], str, int, float, bool, None]
 _BASE_PROMPT: JSON = {
     "4": {
diff --git a/comfy/cmd/main_pre.py b/comfy/cmd/main_pre.py
index 5a8d0f7f8..249482188 100644
--- a/comfy/cmd/main_pre.py
+++ b/comfy/cmd/main_pre.py
@@ -15,6 +15,7 @@ import shutil
 import warnings
 
 import fsspec
+from opentelemetry.instrumentation.urllib3 import URLLib3Instrumentor
 
 from .. import options
 from ..app import logger
@@ -125,9 +126,11 @@ def _create_tracer():
     from opentelemetry.sdk.resources import Resource
     from opentelemetry.sdk.trace import TracerProvider
     from opentelemetry.sdk.trace.export import BatchSpanProcessor, SpanExporter
+    from opentelemetry.processor.baggage import BaggageSpanProcessor, ALLOW_ALL_BAGGAGE_KEYS
+    from opentelemetry.instrumentation.aiohttp_server import AioHttpServerInstrumentor
+    from opentelemetry.instrumentation.aiohttp_client import AioHttpClientInstrumentor
     from ..tracing_compatibility import ProgressSpanSampler
     from ..tracing_compatibility import patch_spanbuilder_set_channel
-    from ..vendor.aiohttp_server_instrumentation import AioHttpServerInstrumentor
 
     resource = Resource.create({
         service_attributes.SERVICE_NAME: args.otel_service_name,
@@ -141,18 +144,24 @@ def _create_tracer():
     has_endpoint = args.otel_exporter_otlp_endpoint is not None
 
     if has_endpoint:
-        otlp_exporter = OTLPSpanExporter()
+        exporter = OTLPSpanExporter()
     else:
-        otlp_exporter = SpanExporter()
+        exporter = SpanExporter()
 
-    processor = BatchSpanProcessor(otlp_exporter)
+    processor = BatchSpanProcessor(exporter)
     provider.add_span_processor(processor)
 
     # enable instrumentation
     patch_spanbuilder_set_channel()
+
     AioPikaInstrumentor().instrument()
     AioHttpServerInstrumentor().instrument()
+    AioHttpClientInstrumentor().instrument()
     RequestsInstrumentor().instrument()
+    URLLib3Instrumentor().instrument()
+
+
+    provider.add_span_processor(BaggageSpanProcessor(ALLOW_ALL_BAGGAGE_KEYS))
 
     # makes this behave better as a library
     return trace.get_tracer(args.otel_service_name, tracer_provider=provider)
diff --git a/comfy/tracing_compatibility.py b/comfy/tracing_compatibility.py
index 6c0ef16f5..a77f1e79b 100644
--- a/comfy/tracing_compatibility.py
+++ b/comfy/tracing_compatibility.py
@@ -3,7 +3,7 @@ from typing import Optional, Sequence
 from aio_pika.abc import AbstractChannel
 from opentelemetry.context import Context
 from opentelemetry.sdk.trace.sampling import Sampler, SamplingResult, Decision
-from opentelemetry.semconv.trace import SpanAttributes
+from opentelemetry.semconv.attributes.network_attributes import NETWORK_PEER_ADDRESS, NETWORK_PEER_PORT
 from opentelemetry.trace import SpanKind, Link, TraceState
 from opentelemetry.util.types import Attributes
 
@@ -22,8 +22,8 @@ def patch_spanbuilder_set_channel() -> None:
-> None: port = url.port or 5672 self._attributes.update( { - SpanAttributes.NET_PEER_NAME: url.host, - SpanAttributes.NET_PEER_PORT: port, + NETWORK_PEER_ADDRESS: url.host, + NETWORK_PEER_PORT: port, } ) diff --git a/tests/conftest.py b/tests/conftest.py index 33ae98f9d..931ec8deb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -90,6 +90,14 @@ def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory, num_workers params = rabbitmq.get_connection_params() connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}" + # Check if OTEL endpoint is configured for integration testing + otel_endpoint = os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT") + + env = os.environ.copy() + if otel_endpoint: + env["OTEL_EXPORTER_OTLP_ENDPOINT"] = otel_endpoint + logging.info(f"Configuring services to export traces to: {otel_endpoint}") + frontend_command = [ "comfyui", "--listen=127.0.0.1", @@ -100,7 +108,7 @@ def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory, num_workers f"--distributed-queue-connection-uri={connection_uri}", ] - processes_to_close.append(subprocess.Popen(frontend_command, stdout=sys.stdout, stderr=sys.stderr)) + processes_to_close.append(subprocess.Popen(frontend_command, stdout=sys.stdout, stderr=sys.stderr, env=env)) # Start multiple workers for i in range(num_workers): @@ -111,7 +119,7 @@ def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory, num_workers f"--distributed-queue-connection-uri={connection_uri}", f"--executor-factory={executor_factory}" ] - processes_to_close.append(subprocess.Popen(backend_command, stdout=sys.stdout, stderr=sys.stderr)) + processes_to_close.append(subprocess.Popen(backend_command, stdout=sys.stdout, stderr=sys.stderr, env=env)) try: server_address = f"http://127.0.0.1:19001" diff --git a/tests/distributed/test_tracing.py b/tests/distributed/test_tracing.py index e69de29bb..8c156a9bf 100644 --- a/tests/distributed/test_tracing.py +++ b/tests/distributed/test_tracing.py @@ -0,0 +1,128 @@ +import asyncio +import logging + +logging.basicConfig(level=logging.ERROR) + +import uuid + +import pytest +from testcontainers.rabbitmq import RabbitMqContainer +from opentelemetry import trace, propagate, context +from opentelemetry.trace import SpanKind +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import SimpleSpanProcessor +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + +from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner +from comfy.component_model.make_mutable import make_mutable +from comfy.component_model.queue_types import QueueItem, QueueTuple, ExecutionStatus +from comfy.distributed.server_stub import ServerStub + + +async def create_test_prompt() -> QueueItem: + from comfy.cmd.execution import validate_prompt + + prompt = make_mutable(sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)) + item_id = str(uuid.uuid4()) + + validation_tuple = await validate_prompt(item_id, prompt) + queue_tuple: QueueTuple = (0, item_id, prompt, {}, validation_tuple[2]) + return QueueItem(queue_tuple, None) + + +@pytest.mark.asyncio +async def test_rabbitmq_message_properties_contain_trace_context(): + with RabbitMqContainer("rabbitmq:latest") as rabbitmq: + params = rabbitmq.get_connection_params() + connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}" + + from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue + import aio_pika + + exporter = InMemorySpanExporter() + 
provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + tracer = trace.get_tracer(__name__, tracer_provider=provider) + + with tracer.start_as_current_span("test_message_headers", kind=SpanKind.PRODUCER): + async with DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=connection_uri) as frontend: + async with DistributedPromptQueue(ServerStub(), is_callee=True, is_caller=False, connection_uri=connection_uri) as worker: + queue_item = await create_test_prompt() + + put_task = asyncio.create_task(frontend.put_async(queue_item)) + + incoming, incoming_prompt_id = await worker.get_async(timeout=5.0) + assert incoming is not None, "Worker should receive message" + + worker.task_done(incoming_prompt_id, {}, ExecutionStatus("success", True, [])) + + result = await put_task + assert result is not None, "Frontend should get result" + + # Now inspect the RabbitMQ queue directly to see message structure + connection = await aio_pika.connect_robust(connection_uri) + channel = await connection.channel() + + # Declare a test queue to inspect message format + test_queue = await channel.declare_queue("test_inspection_queue", durable=False, auto_delete=True) + + # Publish a test message with trace context + carrier = {} + propagate.inject(carrier) + + test_message = aio_pika.Message( + body=b"test", + headers=carrier + ) + + await channel.default_exchange.publish( + test_message, + routing_key=test_queue.name + ) + + # Get and inspect the message + received = await test_queue.get(timeout=2, fail=False) + if received: + headers = received.headers or {} + + # Document what trace headers should be present + # OpenTelemetry uses 'traceparent' header for W3C Trace Context + has_traceparent = "traceparent" in headers + + assert has_traceparent + + await received.ack() + + await connection.close() + + +@pytest.mark.asyncio +async def test_distributed_queue_uses_async_interface(): + """ + Test that demonstrates the correct way to use DistributedPromptQueue in async context. + The synchronous get() method cannot be used in async tests due to event loop assertions. 
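+    Use get_async()/put_async() from within the running event loop instead, e.g.:
+
+        put_task = asyncio.create_task(frontend.put_async(queue_item))
+        incoming, prompt_id = await worker.get_async(timeout=5.0)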
+ """ + with RabbitMqContainer("rabbitmq:latest") as rabbitmq: + params = rabbitmq.get_connection_params() + connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}" + + from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue + + async with DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=connection_uri) as frontend: + async with DistributedPromptQueue(ServerStub(), is_callee=True, is_caller=False, connection_uri=connection_uri) as worker: + queue_item = await create_test_prompt() + + # Start consuming in background + result_future = asyncio.create_task(frontend.put_async(queue_item)) + + # Worker gets item asynchronously (not using blocking get()) + incoming, incoming_prompt_id = await worker.get_async(timeout=5.0) + assert incoming is not None, "Should receive a queue item" + + # Complete the work + worker.task_done(incoming_prompt_id, {}, ExecutionStatus("success", True, [])) + + # Wait for frontend to complete + result = await result_future + assert result is not None, "Should get result from worker" + assert result.status.status_str == "success" diff --git a/tests/distributed/test_tracing_integration.py b/tests/distributed/test_tracing_integration.py new file mode 100644 index 000000000..412b14c02 --- /dev/null +++ b/tests/distributed/test_tracing_integration.py @@ -0,0 +1,497 @@ +""" +Integration tests for distributed tracing across RabbitMQ and services. + +These tests validate that trace context propagates correctly from frontend +to backend workers through RabbitMQ, and that Jaeger can reconstruct the +full distributed trace. +""" +import asyncio +import logging +import os +import time +import uuid + +import pytest +import requests +from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter +from opentelemetry.instrumentation.aiohttp_client import AioHttpClientInstrumentor +from opentelemetry.sdk.resources import Resource +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import BatchSpanProcessor +from opentelemetry.semconv.attributes import service_attributes +from testcontainers.core.container import DockerContainer +from testcontainers.core.waiting_utils import wait_for_logs + +from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +class JaegerContainer(DockerContainer): + """Testcontainer for Jaeger all-in-one with OTLP support.""" + + def __init__(self, image: str = "jaegertracing/all-in-one:latest"): + super().__init__(image) + self.with_exposed_ports(16686, 4318, 14268) # UI, OTLP HTTP, Jaeger HTTP + self.with_env("COLLECTOR_OTLP_ENABLED", "true") + + def get_query_url(self) -> str: + """Get Jaeger Query API URL.""" + host = self.get_container_host_ip() + port = self.get_exposed_port(16686) + return f"http://{host}:{port}" + + def get_otlp_endpoint(self) -> str: + """Get OTLP HTTP endpoint for sending traces.""" + host = self.get_container_host_ip() + port = self.get_exposed_port(4318) + return f"http://{host}:{port}" + + def start(self): + super().start() + wait_for_logs(self, ".*Starting GRPC server.*", timeout=30) + return self + + +@pytest.fixture(scope="module") +def jaeger_container(): + """ + Provide a Jaeger container for collecting traces. + + This fixture automatically sets OTEL_EXPORTER_OTLP_ENDPOINT to point to the + Jaeger container, and cleans it up when the container stops. 
+ """ + container = JaegerContainer() + container.start() + + # Wait for Jaeger to be fully ready + query_url = container.get_query_url() + otlp_endpoint = container.get_otlp_endpoint() + + for _ in range(30): + try: + response = requests.get(f"{query_url}/api/services") + if response.status_code == 200: + logger.info(f"Jaeger ready at {query_url}") + logger.info(f"OTLP endpoint: {otlp_endpoint}") + break + except Exception: + pass + time.sleep(1) + + # Set OTEL_EXPORTER_OTLP_ENDPOINT for the duration of the test + old_endpoint = os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT") + os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"] = otlp_endpoint + logger.info(f"Set OTEL_EXPORTER_OTLP_ENDPOINT={otlp_endpoint}") + + try: + yield container + finally: + # Restore original OTEL_EXPORTER_OTLP_ENDPOINT + if old_endpoint is not None: + os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"] = old_endpoint + logger.info(f"Restored OTEL_EXPORTER_OTLP_ENDPOINT={old_endpoint}") + else: + os.environ.pop("OTEL_EXPORTER_OTLP_ENDPOINT", None) + logger.info("Removed OTEL_EXPORTER_OTLP_ENDPOINT") + + container.stop() + + +def query_jaeger_traces(jaeger_url: str, service: str, operation: str = None, + lookback: str = "1h", limit: int = 100) -> dict: + """ + Query Jaeger for traces. + + Args: + jaeger_url: Base URL of Jaeger query service + service: Service name to query + operation: Optional operation name filter + lookback: Lookback period (e.g., "1h", "30m") + limit: Maximum number of traces to return + + Returns: + JSON response from Jaeger API + """ + params = { + "service": service, + "lookback": lookback, + "limit": limit + } + if operation: + params["operation"] = operation + + response = requests.get(f"{jaeger_url}/api/traces", params=params) + response.raise_for_status() + return response.json() + + +def find_trace_by_operation(traces_response: dict, operation_name: str) -> dict: + """Find a specific trace by operation name.""" + for trace in traces_response.get("data", []): + for span in trace.get("spans", []): + if span.get("operationName") == operation_name: + return trace + return None + + +def verify_trace_continuity(trace: dict, expected_services: list[str]) -> bool: + """ + Verify that a trace spans multiple services and maintains parent-child relationships. 
+ + Args: + trace: Jaeger trace object + expected_services: List of service names expected in the trace + + Returns: + True if trace shows proper distributed tracing across services + """ + if not trace: + return False + + spans = trace.get("spans", []) + if not spans: + return False + + # Check that all expected services are present + trace_services = set() + for span in spans: + process_id = span.get("processID") + if process_id: + process = trace.get("processes", {}).get(process_id, {}) + service_name = process.get("serviceName") + if service_name: + trace_services.add(service_name) + + logger.info(f"Trace contains services: {trace_services}") + logger.info(f"Expected services: {set(expected_services)}") + + # Verify all expected services are present + for service in expected_services: + if service not in trace_services: + logger.warning(f"Expected service '{service}' not found in trace") + return False + + # Verify all spans share the same trace ID + trace_ids = set(span.get("traceID") for span in spans) + if len(trace_ids) != 1: + logger.warning(f"Multiple trace IDs found: {trace_ids}") + return False + + # Verify parent-child relationships exist + span_ids = {span.get("spanID") for span in spans} + has_parent_refs = False + + for span in spans: + references = span.get("references", []) + for ref in references: + if ref.get("refType") == "CHILD_OF": + parent_span_id = ref.get("spanID") + if parent_span_id in span_ids: + has_parent_refs = True + logger.info(f"Found parent-child relationship: {parent_span_id} -> {span.get('spanID')}") + + if not has_parent_refs: + logger.warning("No parent-child relationships found in trace") + return False + + return True + + +# order matters, execute jaeger_container first +@pytest.mark.asyncio +async def test_tracing_integration(jaeger_container, frontend_backend_worker_with_rabbitmq): + """ + Integration test for distributed tracing across services. + + This test: + 1. Starts ComfyUI frontend and worker with RabbitMQ + 2. Configures OTLP export to Jaeger testcontainer + 3. Submits a workflow through the frontend + 4. Queries Jaeger to verify trace propagation + 5. Validates that the trace spans multiple services with proper relationships + + Note: The frontend_backend_worker_with_rabbitmq fixture is parameterized, + so this test will run with both ThreadPoolExecutor and ProcessPoolExecutor. + """ + server_address = frontend_backend_worker_with_rabbitmq + jaeger_url = jaeger_container.get_query_url() + otlp_endpoint = jaeger_container.get_otlp_endpoint() + + logger.info(f"Frontend server: {server_address}") + logger.info(f"Jaeger UI: {jaeger_url}") + logger.info(f"OTLP endpoint: {otlp_endpoint}") + + # Set up tracing for the async HTTP client + resource = Resource.create({ + service_attributes.SERVICE_NAME: "comfyui-client", + }) + provider = TracerProvider(resource=resource) + exporter = OTLPSpanExporter(endpoint=f"{otlp_endpoint}/v1/traces") + processor = BatchSpanProcessor(exporter) + provider.add_span_processor(processor) + from opentelemetry import trace + + trace.set_tracer_provider(provider) + + # Instrument aiohttp client + AioHttpClientInstrumentor().instrument() + + # we have to call this very late, so that the instrumentation isn't initialized too early + from comfy.client.aio_client import AsyncRemoteComfyClient + + # Note: In a real integration test, you'd need to configure the ComfyUI + # services to export traces to this Jaeger instance. For now, this test + # documents the expected behavior. 
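+    # Expected propagation path, per the changes in this patch: the client span
+    # travels as a traceparent HTTP header, the frontend's OpenTelemetry
+    # middleware extracts it, and the instrumented queue forwards it in the
+    # RabbitMQ message headers so worker spans join the same trace.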
+ + # Create a unique prompt to identify our trace + test_id = str(uuid.uuid4())[:8] + prompt = sdxl_workflow_with_refiner(f"test_trace_{test_id}", inference_steps=1, refiner_steps=1) + + # Get the tracer for the client + client_tracer = trace.get_tracer("test_tracing_integration") + + # Submit the workflow - wrap in a span to capture the trace ID + with client_tracer.start_as_current_span("submit_workflow") as workflow_span: + trace_id = format(workflow_span.get_span_context().trace_id, '032x') + logger.info(f"Started trace with trace_id: {trace_id}") + + async with AsyncRemoteComfyClient(server_address=server_address) as client: + logger.info(f"Submitting workflow with test_id: {test_id}") + + # Queue the prompt with async response + task_id = await client.queue_and_forget_prompt_api(prompt, prefer_header="respond-async") + assert task_id is not None, "Failed to get task ID" + + logger.info(f"Queued task: {task_id}") + + # Poll for completion + status_code, result = await client.poll_prompt_until_done(task_id, max_attempts=60, poll_interval=1.0) + assert status_code == 200, f"Task failed with status {status_code}" + logger.info("Task completed successfully") + + # Give Jaeger time to receive and process spans + await asyncio.sleep(5) + + # Query Jaeger for traces + # Note: The actual service names depend on how your services are configured + # Common service names might be: "slack-bot", "comfyui-frontend", "comfyui-worker" + + expected_services = ["comfyui", "comfyui-client"] # Adjust based on actual service names + + logger.info(f"Querying Jaeger for traces with trace_id: {trace_id}...") + + # First, try to find our specific trace by trace_id from the client service + our_trace = None + for service in expected_services: + try: + traces_response = query_jaeger_traces(jaeger_url, service, lookback="5m") + if traces_response.get("data"): + logger.info(f"Found {len(traces_response['data'])} traces for service '{service}'") + for trace in traces_response["data"]: + if trace.get("traceID") == trace_id: + our_trace = trace + logger.info(f"Found our trace in service '{service}'") + break + if our_trace: + break + except Exception as e: + logger.warning(f"Could not query traces for service '{service}': {e}") + + # Assert we can find the trace we just created + assert our_trace is not None, ( + f"Could not find trace with trace_id {trace_id} in Jaeger. " + f"This indicates that spans from comfyui-client are not being exported correctly." + ) + + logger.info(f"Successfully found trace with trace_id {trace_id}") + + # Extract services from the trace + trace_services = set() + for span in our_trace.get("spans", []): + process_id = span.get("processID") + if process_id: + process = our_trace.get("processes", {}).get(process_id, {}) + service_name = process.get("serviceName") + if service_name: + trace_services.add(service_name) + + logger.info(f"Services found in trace: {trace_services}") + + # Assert that comfyui-client service is present (since we instrumented it) + assert "comfyui-client" in trace_services, ( + f"Expected 'comfyui-client' service in trace, but found only: {trace_services}. " + f"This indicates the client instrumentation is not working." 
+ ) + + # Validate trace structure + logger.info(f"Analyzing trace with {len(our_trace.get('spans', []))} spans") + + # Log all spans for debugging + for span in our_trace.get("spans", []): + process_id = span.get("processID") + process = our_trace.get("processes", {}).get(process_id, {}) + service_name = process.get("serviceName", "unknown") + operation = span.get("operationName", "unknown") + logger.info(f" Span: {service_name}.{operation}") + + # Verify trace continuity - only if both services are present + assert "comfyui" in trace_services + is_continuous = verify_trace_continuity(our_trace, expected_services) + + # This assertion documents what SHOULD happen when distributed tracing works + assert is_continuous, ( + "Trace does not show proper distributed tracing. " + "Expected to see spans from multiple services with parent-child relationships. " + "This indicates that trace context is not being propagated correctly through RabbitMQ." + ) + +@pytest.mark.asyncio +async def test_trace_context_in_http_headers(frontend_backend_worker_with_rabbitmq): + """ + Test that HTTP requests include traceparent headers. + + This validates that the HTTP layer is properly instrumented for tracing. + """ + server_address = frontend_backend_worker_with_rabbitmq + + # Make a simple HTTP request and check for trace headers + # Note: We're checking the server's response headers to see if it's trace-aware + response = requests.get(f"{server_address}/system_stats") + + logger.info(f"Response headers: {dict(response.headers)}") + + # The server should be instrumented and may include trace context in responses + # or at minimum, should accept traceparent headers in requests + + # Test sending a traceparent header + test_traceparent = "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01" + response_with_trace = requests.get( + f"{server_address}/system_stats", + headers={"traceparent": test_traceparent} + ) + + # Should not error when traceparent is provided + assert response_with_trace.status_code == 200, "Server should accept traceparent header" + + logger.info("✓ Server accepts traceparent headers in HTTP requests") + + +@pytest.mark.asyncio +async def test_multiple_requests_different_traces(frontend_backend_worker_with_rabbitmq, jaeger_container): + """ + Test that multiple independent requests create separate traces. + + This validates that trace context is properly scoped per request. 
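+    Each prompt submitted here should therefore appear in Jaeger under its own
+    distinct trace ID.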
+ """ + server_address = frontend_backend_worker_with_rabbitmq + + # Submit multiple workflows + task_ids = [] + + from comfy.client.aio_client import AsyncRemoteComfyClient + async with AsyncRemoteComfyClient(server_address=server_address) as client: + for i in range(3): + prompt = sdxl_workflow_with_refiner(f"test_{i}", inference_steps=1, refiner_steps=1) + task_id = await client.queue_and_forget_prompt_api(prompt, prefer_header="respond-async") + task_ids.append(task_id) + logger.info(f"Queued task {i}: {task_id}") + + # Wait for all to complete + for i, task_id in enumerate(task_ids): + status_code, result = await client.poll_prompt_until_done(task_id, max_attempts=60, poll_interval=1.0) + assert status_code == 200, f"Task {i} failed" + logger.info(f"Task {i} completed") + + # Give Jaeger time to receive spans + await asyncio.sleep(5) + + # Query Jaeger and verify we have multiple distinct traces + jaeger_url = jaeger_container.get_query_url() + + try: + traces_response = query_jaeger_traces(jaeger_url, "comfyui", lookback="5m", limit=10) + traces = traces_response.get("data", []) + + if len(traces) >= 2: + # Get trace IDs + trace_ids = [trace.get("traceID") for trace in traces] + unique_trace_ids = set(trace_ids) + + logger.info(f"Found {len(unique_trace_ids)} unique traces") + + # Verify we have multiple distinct traces + assert len(unique_trace_ids) >= 2, ( + f"Expected at least 2 distinct traces, found {len(unique_trace_ids)}. " + "Each request should create its own trace." + ) + + logger.info("✓ Multiple requests created distinct traces") + else: + pytest.skip("Not enough traces to validate") + except Exception as e: + pytest.skip(f"Could not query Jaeger: {e}") + + +@pytest.mark.asyncio +async def test_trace_contains_rabbitmq_operations(frontend_backend_worker_with_rabbitmq, jaeger_container): + """ + Test that traces include RabbitMQ publish/consume operations. + + This is critical for distributed tracing - the RabbitMQ operations + are what link the frontend and backend spans together. + """ + server_address = frontend_backend_worker_with_rabbitmq + jaeger_url = jaeger_container.get_query_url() + + # Submit a workflow + from comfy.client.aio_client import AsyncRemoteComfyClient + async with AsyncRemoteComfyClient(server_address=server_address) as client: + prompt = sdxl_workflow_with_refiner("test_rmq", inference_steps=1, refiner_steps=1) + task_id = await client.queue_and_forget_prompt_api(prompt) + status_code, result = await client.poll_prompt_until_done(task_id, max_attempts=60) + assert status_code == 200 + + await asyncio.sleep(5) + + try: + traces_response = query_jaeger_traces(jaeger_url, "comfyui", lookback="5m") + traces = traces_response.get("data", []) + + if traces: + # Look for RabbitMQ-related operations in any trace + rabbitmq_operations = [ + "publish", "consume", "amq_queue_publish", "amq_queue_consume", + "amq.basic.publish", "amq.basic.consume", "send", "receive" + ] + + found_rabbitmq_ops = [] + for trace in traces: + for span in trace.get("spans", []): + op_name = span.get("operationName", "").lower() + for rmq_op in rabbitmq_operations: + if rmq_op in op_name: + found_rabbitmq_ops.append(op_name) + + if found_rabbitmq_ops: + logger.info(f"✓ Found RabbitMQ operations in traces: {set(found_rabbitmq_ops)}") + else: + logger.warning( + "No RabbitMQ operations found in traces. " + "This suggests that either:\n" + "1. AioPikaInstrumentor is not creating spans, or\n" + "2. The spans are being filtered out by the collector, or\n" + "3. 
The spans exist but use different operation names" + ) + + # Log all operation names to help debug + all_ops = set() + for trace in traces[:3]: # First 3 traces + for span in trace.get("spans", []): + all_ops.add(span.get("operationName")) + logger.info(f"Sample operation names: {all_ops}") + else: + pytest.skip("No traces found") + except Exception as e: + pytest.skip(f"Could not query Jaeger: {e}") From 8700c4fadf6fb10c7ba6f6766dbe1573359a40b5 Mon Sep 17 00:00:00 2001 From: doctorpangloss <2229300+doctorpangloss@users.noreply.github.com> Date: Fri, 7 Nov 2025 16:50:55 -0800 Subject: [PATCH 5/9] wip eval nodes, test tracing with full integration test, fix dockerfile barfing on flash_attn 2.8.3 --- Dockerfile | 4 +- comfy/cli_args.py | 1 + comfy/cli_args_types.py | 3 + comfy_extras/eval_web/__init__.py | 0 comfy_extras/eval_web/ace_utils.js | 769 ++++++++++++++++++ comfy_extras/eval_web/ky_eval_python.js | 377 +++++++++ comfy_extras/nodes/nodes_eval.py | 110 +++ tests/conftest.py | 2 +- tests/distributed/test_tracing_integration.py | 200 +++-- 9 files changed, 1399 insertions(+), 67 deletions(-) create mode 100644 comfy_extras/eval_web/__init__.py create mode 100644 comfy_extras/eval_web/ace_utils.js create mode 100644 comfy_extras/eval_web/ky_eval_python.js create mode 100644 comfy_extras/nodes/nodes_eval.py diff --git a/Dockerfile b/Dockerfile index bb0c04d1b..942cacb57 100644 --- a/Dockerfile +++ b/Dockerfile @@ -33,7 +33,7 @@ RUN pip install uv && uv --version && \ # install sageattention ADD pkg/sageattention-2.2.0-cp312-cp312-linux_x86_64.whl /workspace/pkg/sageattention-2.2.0-cp312-cp312-linux_x86_64.whl -RUN uv pip install -U --no-deps --no-build-isolation spandrel timm tensorboard poetry flash-attn "xformers==0.0.31.post1" "file:./pkg/sageattention-2.2.0-cp312-cp312-linux_x86_64.whl" +RUN uv pip install -U --no-deps --no-build-isolation spandrel timm tensorboard poetry "flash-attn<=2.8.0" "xformers==0.0.31.post1" "file:./pkg/sageattention-2.2.0-cp312-cp312-linux_x86_64.whl" # this exotic command will determine the correct torchaudio to install for the image RUN <<-EOF python -c 'import torch, re, subprocess @@ -66,7 +66,7 @@ WORKDIR /workspace # addresses https://github.com/pytorch/pytorch/issues/104801 # and issues reported by importing nodes_canny # smoke test -RUN python -c "import torch; import xformers; import sageattention; import cv2" && comfyui --quick-test-for-ci --cpu --cwd /workspace +RUN python -c "import torch; import xformers; import sageattention; import cv2; import diffusers.hooks" && comfyui --quick-test-for-ci --cpu --cwd /workspace EXPOSE 8188 CMD ["python", "-m", "comfy.cmd.main", "--listen", "--use-sage-attention", "--reserve-vram=0", "--logging-level=INFO", "--enable-cors"] diff --git a/comfy/cli_args.py b/comfy/cli_args.py index fe964ba45..63157dd24 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -156,6 +156,7 @@ def _create_parser() -> EnhancedConfigArgParser: parser.add_argument("--whitelist-custom-nodes", type=str, action=FlattenAndAppendAction, nargs='+', default=[], help="Specify custom node folders to load even when --disable-all-custom-nodes is enabled.") parser.add_argument("--blacklist-custom-nodes", type=str, action=FlattenAndAppendAction, nargs='+', default=[], help="Specify custom node folders to never load. 
Accepts shell-style globs.") parser.add_argument("--disable-api-nodes", action="store_true", help="Disable loading all api nodes.") + parser.add_argument("--enable-eval", action="store_true", help="Enable nodes that can evaluate Python code in workflows.") parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.") parser.add_argument("--create-directories", action="store_true", diff --git a/comfy/cli_args_types.py b/comfy/cli_args_types.py index 953747f95..903f46f1a 100644 --- a/comfy/cli_args_types.py +++ b/comfy/cli_args_types.py @@ -169,6 +169,7 @@ class Configuration(dict): whitelist_custom_nodes (list[str]): Specify custom node folders to load even when --disable-all-custom-nodes is enabled. default_device (Optional[int]): Set the id of the default device, all other devices will stay visible. block_runtime_package_installation (Optional[bool]): When set, custom nodes like ComfyUI Manager, Easy Use, Nunchaku and others will not be able to use pip or uv to install packages at runtime (experimental). + enable_eval (Optional[bool]): Enable nodes that can evaluate Python code in workflows. """ def __init__(self, **kwargs): @@ -288,6 +289,7 @@ class Configuration(dict): self.database_url: str = db_config() self.default_device: Optional[int] = None self.block_runtime_package_installation = None + self.enable_eval: Optional[bool] = False for key, value in kwargs.items(): self[key] = value @@ -420,6 +422,7 @@ class FlattenAndAppendAction(argparse.Action): Custom action to handle comma-separated values and multiple invocations of the same argument, flattening them into a single list. """ + def __call__(self, parser, namespace, values, option_string=None): items = getattr(namespace, self.dest, None) if items is None: diff --git a/comfy_extras/eval_web/__init__.py b/comfy_extras/eval_web/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/comfy_extras/eval_web/ace_utils.js b/comfy_extras/eval_web/ace_utils.js new file mode 100644 index 000000000..78c00d809 --- /dev/null +++ b/comfy_extras/eval_web/ace_utils.js @@ -0,0 +1,769 @@ +/** + * Uses code adapted from https://github.com/yorkane/ComfyUI-KYNode + * + * MIT License + * + * Copyright (c) 2024 Kevin Yuan + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ + +// Make modal window +function makeModal({ title = "Message", text = "No text", type = "info", parent = null, stylePos = "fixed", classes = [] } = {}) { + const overlay = document.createElement("div"); + Object.assign(overlay.style, { + display: "none", + position: stylePos, + background: "rgba(0 0 0 / 0.8)", + opacity: 0, + top: "0", + left: "0", + right: "0", + bottom: "0", + zIndex: "500", + transition: "all .8s", + cursor: "pointer", + }); + + const boxModal = document.createElement("div"); + Object.assign(boxModal.style, { + transition: "all 0.5s", + opacity: 0, + display: "none", + position: stylePos, + top: "50%", + left: "50%", + transform: "translate(-50%,-50%)", + background: "#525252", + minWidth: "300px", + fontFamily: "sans-serif", + zIndex: "501", + border: "1px solid rgb(255 255 255 / 45%)", + }); + + boxModal.className = "alekpet_modal_window"; + boxModal.classList.add(...classes); + + const boxModalBody = document.createElement("div"); + Object.assign(boxModalBody.style, { + display: "flex", + flexDirection: "column", + textAlign: "center", + }); + + boxModalBody.className = "alekpet_modal_body"; + + const boxModalHtml = ` +
<div class="alekpet_modal_header">
+<div class="alekpet_modal_title">${title}</div>
+<div class="alekpet_modal_close">✖</div>
+</div>
+<div class="alekpet_modal_body_text">${text}</div>
`; + boxModalBody.innerHTML = boxModalHtml; + + const alekpet_modal_header = boxModalBody.querySelector(".alekpet_modal_header"); + Object.assign(alekpet_modal_header.style, { + display: "flex", + alignItems: "center", + }); + + const close = boxModalBody.querySelector(".alekpet_modal_close"); + Object.assign(close.style, { + cursor: "pointer", + }); + + let parentElement = document.body; + if (parent && parent.nodeType === 1) { + parentElement = parent; + } + + boxModal.append(boxModalBody); + parentElement.append(overlay, boxModal); + + const removeEvent = new Event("removeElements"); + const remove = () => { + animateTransitionProps(boxModal, { opacity: 0 }).then(() => + animateTransitionProps(overlay, { opacity: 0 }).then(() => { + parentElement.removeChild(boxModal); + parentElement.removeChild(overlay); + }), + ); + }; + + boxModal.addEventListener("removeElements", remove); + overlay.addEventListener("removeElements", remove); + + animateTransitionProps(overlay) + .then(() => { + overlay.addEventListener("click", () => { + overlay.dispatchEvent(removeEvent); + }); + animateTransitionProps(boxModal); + }) + .then(() => boxModal.querySelector(".alekpet_modal_close").addEventListener("click", () => boxModal.dispatchEvent(removeEvent))); +} + +function findWidget(node, value, attr = "name", func = "find") { + return node?.widgets ? node.widgets[func]((w) => (Array.isArray(value) ? value.includes(w[attr]) : w[attr] === value)) : null; +} + +function animateTransitionProps(el, props = { opacity: 1 }, preStyles = { display: "block" }) { + Object.assign(el.style, preStyles); + + el.style.transition = !el.style.transition || !window.getComputedStyle(el).getPropertyValue("transition") ? "all .8s" : el.style.transition; + + return new Promise((res) => { + setTimeout(() => { + Object.assign(el.style, props); + + const transstart = () => (el.isAnimating = true); + const transchancel = () => (el.isAnimating = false); + el.addEventListener("transitionstart", transstart); + el.addEventListener("transitioncancel", transchancel); + + el.addEventListener("transitionend", function transend() { + el.isAnimating = false; + el.removeEventListener("transitionend", transend); + el.removeEventListener("transitionend", transchancel); + el.removeEventListener("transitionend", transstart); + res(el); + }); + }, 100); + }); +} + +function animateClick(target, params = {}) { + const { opacityVal = 0.9, callback = () => {} } = params; + if (target?.isAnimating) return; + + const hide = +target.style.opacity === 0; + return animateTransitionProps(target, { + opacity: hide ? opacityVal : 0, + }).then((el) => { + const isHide = hide || el.style.display === "none"; + showHide({ elements: [target], hide: !hide }); + callback(); + return isHide; + }); +} + +function showHide({ elements = [], hide = null, displayProp = "block" } = {}) { + Array.from(elements).forEach((el) => { + if (hide !== null) { + el.style.display = !hide ? displayProp : "none"; + } else { + el.style.display = !el.style.display || el.style.display === "none" ? 
displayProp : "none"; + } + }); +} + +function isEmptyObject(obj) { + if (!obj) return true; + return Object.keys(obj).length === 0 && obj.constructor === Object; +} + +function makeElement(tag, attrs = {}) { + if (!tag) tag = "div"; + const element = document.createElement(tag); + Object.keys(attrs).forEach((key) => { + const currValue = attrs[key]; + if (key === "class") { + if (Array.isArray(currValue)) { + element.classList.add(...currValue); + } else if (currValue instanceof String || typeof currValue === "string") { + element.className = currValue; + } + } else if (key === "dataset") { + try { + if (Array.isArray(currValue)) { + currValue.forEach((datasetArr) => { + const [prop, propval] = Object.entries(datasetArr)[0]; + element.dataset[prop] = propval; + }); + } else { + Object.entries(currValue).forEach((datasetArr) => { + const [prop, propval] = datasetArr; + element.dataset[prop] = propval; + }); + } + } catch (err) { + console.log(err); + } + } else if (key === "style") { + if (typeof currValue === "object" && !Array.isArray(currValue) && Object.keys(currValue).length) { + Object.assign(element[key], currValue); + } else if (typeof currValue === "object" && Array.isArray(currValue) && currValue.length) { + element[key] = [...currValue]; + } else if (currValue instanceof String || typeof currValue === "string") { + element[key] = currValue; + } + } else if (["for"].includes(key)) { + element.setAttribute(key, currValue); + } else if (key === "children") { + element.append(...(currValue instanceof Array ? currValue : [currValue])); + } else if (key === "parent") { + currValue.append(element); + } else { + element[key] = currValue; + } + }); + return element; +} + +function isValidStyle(opt, strColor) { + let op = new Option().style; + if (!op.hasOwnProperty(opt)) return { result: false, color: "", color_hex: "" }; + + op[opt] = strColor; + + return { + result: op[opt] !== "", + color_rgb: op[opt], + color_hex: rgbToHex(op[opt]), + }; +} + +function rgbToHex(rgb) { + const regEx = new RegExp(/\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)/); + if (regEx.test(rgb)) { + let [, r, g, b] = regEx.exec(rgb); + r = parseInt(r).toString(16); + g = parseInt(g).toString(16); + b = parseInt(b).toString(16); + + r = r.length === 1 ? r + "0" : r; + g = g.length === 1 ? g + "0" : g; + b = b.length === 1 ? 
b + "0" : b; + + return `#${r}${g}${b}`; + } +} + +async function getDataJSON(url) { + try { + const response = await fetch(url); + const jsonData = await response.json(); + return jsonData; + } catch (err) { + return new Error(err); + } +} + +function deepMerge(target, source) { + if (source?.nodeType) return; + for (let key in source) { + if (source[key] instanceof Object && key in target) { + Object.assign(source[key], deepMerge(target[key], source[key])); + } + } + + Object.assign(target || {}, source); + return target; +} + +const THEME_MODAL_WINDOW_BASE = { + stylesTitle: { + background: "auto", + padding: "5px", + borderRadius: "6px", + marginBottom: "5px", + alignSelf: "stretch", + }, + stylesWrapper: { + display: "none", + opacity: 0, + minWidth: "220px", + position: "absolute", + left: "50%", + top: "50%", + transform: "translate(-50%, -50%)", + transition: "all .8s", + fontFamily: "monospace", + zIndex: 99999, + }, + stylesBox: { + display: "flex", + flexDirection: "column", + background: "#0e0e0e", + padding: "6px", + justifyContent: "center", + alignItems: "center", + gap: "3px", + textAlign: "center", + borderRadius: "6px", + color: "white", + border: "2px solid silver", + boxShadow: "2px 2px 4px silver", + maxWidth: "300px", + }, + stylesClose: { + position: "absolute", + top: "-10px", + right: "-10px", + background: "silver", + borderRadius: "50%", + width: "20px", + height: "20px", + cursor: "pointer", + display: "flex", + justifyContent: "center", + alignItems: "center", + fontSize: "0.8rem", + }, +}; + +const THEMES_MODAL_WINDOW = { + error: { + stylesTitle: { + ...THEME_MODAL_WINDOW_BASE.stylesTitle, + background: "#8f210f", + }, + stylesBox: { + ...THEME_MODAL_WINDOW_BASE.stylesBox, + background: "#3b2222", + boxShadow: "3px 3px 6px #141414", + border: "1px solid #f91b1b", + }, + stylesWrapper: { ...THEME_MODAL_WINDOW_BASE.stylesWrapper }, + stylesClose: { + ...THEME_MODAL_WINDOW_BASE.stylesClose, + background: "#3b2222", + }, + }, + warning: { + stylesTitle: { + ...THEME_MODAL_WINDOW_BASE.stylesTitle, + background: "#e99818", + }, + stylesBox: { + ...THEME_MODAL_WINDOW_BASE.stylesBox, + background: "#594e32", + boxShadow: "3px 3px 6px #141414", + border: "1px solid #e99818", + }, + stylesWrapper: { ...THEME_MODAL_WINDOW_BASE.stylesWrapper }, + stylesClose: { + ...THEME_MODAL_WINDOW_BASE.stylesClose, + background: "#594e32", + }, + }, + normal: { + stylesTitle: { + ...THEME_MODAL_WINDOW_BASE.stylesTitle, + background: "#108f0f", + }, + stylesBox: { + ...THEME_MODAL_WINDOW_BASE.stylesBox, + background: "#223b2a", + boxShadow: "3px 3px 6px #141414", + border: "1px solid #108f0f", + }, + stylesWrapper: { ...THEME_MODAL_WINDOW_BASE.stylesWrapper }, + stylesClose: { + ...THEME_MODAL_WINDOW_BASE.stylesClose, + background: "#223b2a", + }, + }, +}; + +const defaultOptions = { + auto: { + autohide: false, + autoshow: false, + autoremove: false, + propStyles: { opacity: 0 }, + propPreStyles: {}, + timewait: 2000, + }, + overlay: { + overlay_enabled: false, + overlayClasses: [], + overlayStyles: {}, + }, + close: { closeRemove: false, showClose: true }, + parent: null, +}; + +function createWindowModal({ textTitle = "Message", textBody = "Hello world!", textFooter = null, classesWrapper = [], stylesWrapper = {}, classesBox = [], stylesBox = {}, classesTitle = [], stylesTitle = {}, classesBody = [], stylesBody = {}, classesClose = [], stylesClose = {}, classesFooter = [], stylesFooter = {}, options = defaultOptions } = {}) { + // Check all options exist + const _options = 
deepMerge(JSON.parse(JSON.stringify(defaultOptions)), options); + + const { + parent, + overlay: { overlay_enabled, overlayClasses, overlayStyles }, + close: { closeRemove, showClose }, + auto: { autohide, autoshow, autoremove, timewait, propStyles, propPreStyles }, + } = _options; + + // Function past text(html) + function addText(text, parent) { + if (!parent) return; + + switch (typeof text) { + case "string": + if (/^\<.*\/?\>$/.test(text)) { + parent.innerHTML = text; + } else { + parent.textContent = text; + } + break; + case "object": + default: + if (Array.isArray(text)) { + text.forEach((element) => (element.nodeType === 1 || element.nodeType === 3) && parent.append(element)); + } else if (text.nodeType === 1 || text.nodeType === 3) parent.append(text); + } + } + + // Overlay + let overlayElement = null; + if (overlay_enabled) { + overlayElement = makeElement("div", { + class: [...overlayClasses], + style: { + display: "none", + position: "fixed", + background: "rgba(0 0 0 / 0.8)", + opacity: 0, + top: 0, + left: 0, + right: 0, + bottom: 0, + zIndex: 99999, + transition: "all .8s", + cursor: "pointer", + ...overlayStyles, + }, + }); + } + + // Wrapper + const wrapper_settings = makeElement("div", { + class: ["alekpet__wrapper__window", ...classesWrapper], + }); + + Object.assign(wrapper_settings.style, { + ...THEME_MODAL_WINDOW_BASE.stylesWrapper, + ...stylesWrapper, + }); + + // Box + const box__settings = makeElement("div", { + class: ["alekpet__window__box", ...classesBox], + }); + Object.assign(box__settings.style, { + ...THEME_MODAL_WINDOW_BASE.stylesBox, + ...stylesBox, + }); + + // Title + let box_settings_title = ""; + if (textTitle) { + box_settings_title = makeElement("div", { + class: ["alekpet__window__title", ...classesTitle], + }); + + Object.assign(box_settings_title.style, { + ...THEME_MODAL_WINDOW_BASE.stylesTitle, + ...stylesTitle, + }); + + // Add text (html) to title + addText(textTitle, box_settings_title); + } + // Body + let box_settings_body = ""; + if (textBody) { + box_settings_body = makeElement("div", { + class: ["alekpet__window__body", ...classesBody], + }); + + Object.assign(box_settings_body.style, { + display: "flex", + flexDirection: "column", + alignItems: "flex-end", + gap: "5px", + textWrap: "wrap", + ...stylesBody, + }); + + // Add text (html) to body + addText(textBody, box_settings_body); + } + + // Close button + const close__box__button = makeElement("div", { + class: ["close__box__button", ...classesClose], + textContent: "✖", + }); + + Object.assign(close__box__button.style, { + ...THEME_MODAL_WINDOW_BASE.stylesClose, + ...stylesClose, + }); + + if (!showClose) close__box__button.style.display = "none"; + + const closeEvent = new Event("closeModal"); + const closeModalWindow = function () { + overlay_enabled + ? 
animateTransitionProps(overlayElement, { + opacity: 0, + }) + .then(() => + animateTransitionProps(wrapper_settings, { + opacity: 0, + }), + ) + .then(() => { + if (closeRemove) { + parent.removeChild(wrapper_settings); + parent.removeChild(overlayElement); + } else { + showHide({ elements: [wrapper_settings, overlayElement] }); + } + }) + : animateTransitionProps(wrapper_settings, { + opacity: 0, + }).then(() => { + showHide({ elements: [wrapper_settings] }); + }); + }; + + close__box__button.addEventListener("closeModal", closeModalWindow); + + close__box__button.addEventListener("click", () => close__box__button.dispatchEvent(closeEvent)); + + close__box__button.onmouseenter = () => { + close__box__button.style.opacity = 0.8; + }; + + close__box__button.onmouseleave = () => { + close__box__button.style.opacity = 1; + }; + + box__settings.append(box_settings_title, box_settings_body); + + // Footer + if (textFooter) { + const box_settings_footer = makeElement("div", { + class: [...classesFooter], + }); + Object.assign(box_settings_footer.style, { + ...stylesFooter, + }); + + // Add text (html) to body + addText(textFooter, box_settings_footer); + + box__settings.append(box_settings_footer); + } + + wrapper_settings.append(close__box__button, box__settings); + + if (parent && parent.nodeType === 1) { + if (overlay_enabled) parent.append(overlayElement); + parent.append(wrapper_settings); + + if (autoshow) { + overlay_enabled + ? animateClick(overlayElement).then(() => + animateClick(wrapper_settings).then( + () => + autohide && + setTimeout( + () => + animateTransitionProps(wrapper_settings, { ...propStyles }, { ...propPreStyles }) + .then(() => animateTransitionProps(overlayElement, { ...propStyles }, { ...propPreStyles })) + .then(() => { + if (autoremove) { + parent.removeChild(wrapper_settings); + parent.removeChild(overlayElement); + } + }), + timewait, + ), + ), + ) + : animateClick(wrapper_settings).then(() => autohide && setTimeout(() => animateTransitionProps(wrapper_settings, { ...propStyles }, { ...propPreStyles }).then(() => autoremove && parent.removeChild(wrapper_settings)), timewait)); + } + } + + return wrapper_settings; +} + +// Prompt +async function comfyuiDesktopPrompt(title, message, defaultValue) { + try { + return await app.extensionManager.dialog.prompt({ + title, + message, + defaultValue, + }); + } catch (err) { + return prompt(title, message); + } +} + +// Alert +function comfyuiDesktopAlert(message) { + try { + app.extensionManager.toast.addAlert(message); + } catch (err) { + alert(message); + } +} + +// Confirm +function confirmModal({ title, message }) { + return new Promise((res) => { + const overlay = makeElement("div", { + class: ["alekpet_confOverlay"], + style: { + background: "rgba(0, 0, 0, 0.7)", + position: "fixed", + top: 0, + left: 0, + right: 0, + bottom: 0, + zIndex: 9999, + userSelect: "none", + }, + }); + + const modal = makeElement("div", { + class: ["alekpet_confModal"], + style: { + ...THEME_MODAL_WINDOW_BASE.stylesBox, + position: "fixed", + top: "50%", + left: "50%", + fontFamily: "monospace", + background: "rgb(92 186 255 / 20%)", + transform: "translate(-50%, -50%)", + borderColor: "rgba(92, 186, 255, 0.63)", + boxShadow: "rgba(92, 186, 255, 0.63) 2px 2px 4px", + }, + }); + + const titleEl = makeElement("div", { + class: ["alekpet_confTitle"], + style: { + ...THEME_MODAL_WINDOW_BASE.stylesTitle, + background: "rgba(92, 186, 255, 0.63)", + }, + textContent: title, + }); + + const messageEl = makeElement("div", { + class: 
["alekpet_confMessage"], + style: { + display: "flex", + flexDirection: "column", + alignItems: "flex-end", + gap: "5px", + textWrap: "wrap", + }, + textContent: message, + }); + + const action_box = makeElement("div", { + class: ["alekpet_confActions"], + style: { + display: "flex", + gap: "5px", + width: "100%", + padding: "4px", + justifyContent: "flex-end", + }, + }); + + const remove = () => { + modal.remove(); + overlay.remove(); + }; + + const ok = makeElement("div", { + class: ["alekpet_confButtons", "alekpet_confButtonOk"], + style: { + background: "linear-gradient(45deg, green, limegreen) rgb(21, 100, 6)", + }, + textContent: "Ok", + onclick: (e) => { + res(true); + remove(); + }, + }); + + const Cancel = makeElement("div", { + class: ["alekpet_confButtons", "alekpet_confButtonCancel"], + style: { + background: "linear-gradient(45deg, #b64396, #a52a8b) rgb(135 3 161)", + }, + textContent: "Cancel", + onclick: (e) => { + res(false); + remove(); + }, + }); + + action_box.append(ok, Cancel); + modal.append(titleEl, messageEl, action_box); + overlay.append(modal); + document.body.append(overlay); + }); +} + +async function comfyuiDesktopConfirm(message) { + try { + const result = await confirmModal({ + title: "Confirm", + message: message, + }); + + // Wait update comfyui frontend! Confirm Cancel not return value! Fixed in ComfyUI_frontend ver. v1.10.8 + // https://github.com/Comfy-Org/ComfyUI_frontend/issues/2649 + // const result = await app.extensionManager.dialog.confirm({ + // title: "Confirm", + // message: message, + // }); + return result; + } catch (err) { + return confirm(message); + } +} + +export { + makeModal, + createWindowModal, + animateTransitionProps, + animateClick, + showHide, + makeElement, + getDataJSON, + isEmptyObject, + isValidStyle, + rgbToHex, + findWidget, + THEMES_MODAL_WINDOW, + // + comfyuiDesktopConfirm, + comfyuiDesktopPrompt, + comfyuiDesktopAlert, +}; diff --git a/comfy_extras/eval_web/ky_eval_python.js b/comfy_extras/eval_web/ky_eval_python.js new file mode 100644 index 000000000..3d65aa5c0 --- /dev/null +++ b/comfy_extras/eval_web/ky_eval_python.js @@ -0,0 +1,377 @@ +/** + * Uses code adapted from https://github.com/yorkane/ComfyUI-KYNode + * + * MIT License + * + * Copyright (c) 2024 Kevin Yuan + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ +import { app } from "../../scripts/app.js"; + +import * as ace from "https://cdn.jsdelivr.net/npm/ace-code@1.43.4/+esm"; +import { makeElement, findWidget } from "./ace_utils.js"; + +// Constants +const varTypes = ["int", "boolean", "string", "float", "json", "list", "dict"]; +const typeMap = { + int: "int", + boolean: "bool", + string: "str", + float: "float", + json: "json", + list: "list", + dict: "dict", +}; + +ace.config.setModuleLoader('ace/mode/python', () => + import('https://cdn.jsdelivr.net/npm/ace-builds@1.43.4/src/mode-python.js') +); + +ace.config.setModuleLoader('ace/theme/monokai', () => + import('https://cdn.jsdelivr.net/npm/ace-builds@1.43.4/src/theme-monokai.js') +); + +function getPostition(node, ctx, w_width, y, n_height) { + const margin = 5; + + const rect = ctx.canvas.getBoundingClientRect(); + const transform = new DOMMatrix() + .scaleSelf(rect.width / ctx.canvas.width, rect.height / ctx.canvas.height) + .multiplySelf(ctx.getTransform()) + .translateSelf(margin, margin + y); + const scale = new DOMMatrix().scaleSelf(transform.a, transform.d); + + return { + transformOrigin: "0 0", + transform: scale, + left: `${transform.a + transform.e + rect.left}px`, + top: `${transform.d + transform.f + rect.top}px`, + maxWidth: `${w_width - margin * 2}px`, + maxHeight: `${n_height - margin * 2 - y - 15}px`, + width: `${w_width - margin * 2}px`, + height: "90%", + position: "absolute", + scrollbarColor: "var(--descrip-text) var(--bg-color)", + scrollbarWidth: "thin", + zIndex: app.graph._nodes.indexOf(node), + }; +} + +// Create editor code +function codeEditor(node, inputName, inputData) { + const widget = { + type: "pycode", + name: inputName, + options: { hideOnZoom: true }, + value: + inputData[1]?.default || + `def my(a, b=1): + return a * b
+ +r0 = str(my(23, 9))`, + draw(ctx, node, widget_width, y, widget_height) { + const hidden = node.flags?.collapsed || (!!widget.options.hideOnZoom && app.canvas.ds.scale < 0.5) || widget.type === "converted-widget" || widget.type === "hidden"; + + widget.codeElement.hidden = hidden; + + if (hidden) { + widget.options.onHide?.(widget); + return; + } + + Object.assign(this.codeElement.style, getPostition(node, ctx, widget_width, y, node.size[1])); + }, + computeSize(...args) { + return [500, 250]; + }, + }; + + widget.codeElement = makeElement("pre", { + innerHTML: widget.value, + }); + + widget.editor = ace.edit(widget.codeElement); + widget.editor.setTheme("ace/theme/monokai"); + widget.editor.session.setMode("ace/mode/python"); + widget.editor.setOptions({ + enableAutoIndent: true, + enableLiveAutocompletion: true, + enableBasicAutocompletion: true, + fontFamily: "monospace", + }); + widget.codeElement.hidden = true; + + document.body.appendChild(widget.codeElement); + + const collapse = node.collapse; + node.collapse = function () { + collapse.apply(this, arguments); + if (this.flags?.collapsed) { + widget.codeElement.hidden = true; + } else { + if (this.flags?.collapsed === false) { + widget.codeElement.hidden = false; + } + } + }; + + return widget; +} + +// Save data to workflow forced! +function saveValue() { + app?.extensionManager?.workflow?.activeWorkflow?.changeTracker?.checkState(); +} + +// Register extensions +app.registerExtension({ + name: "KYNode.KY_Eval_Python", + getCustomWidgets(app) { + return { + PYCODE: (node, inputName, inputData, app) => { + const widget = codeEditor(node, inputName, inputData); + + widget.editor.getSession().on("change", function (e) { + widget.value = widget.editor.getValue(); + saveValue(); + }); + + const varTypeList = node.addWidget( + "combo", + "select_type", + "string", + (v) => { + // widget.editor.setTheme(`ace/theme/${varTypeList.value}`); + }, + { + values: varTypes, + serialize: false, + }, + ); + + // 6. 使用 addDOMWidget 将容器添加到节点上 + // - 第一个参数是 widget 的名称,在节点内部需要是唯一的。 + // - 第二个参数是 widget 的类型,对于自定义 DOM 元素,通常是 "div"。 + // - 第三个参数是您创建的 DOM 元素。 + // - 第四个参数是一个选项对象,可以用来配置 widget。 + // node.addDOMWidget("rowOfButtons", "div", container, { + // }); + node.addWidget("button", "Add Input variable", "add_input_variable", async () => { + // Input name variable and check + let nameInput = node?.inputs?.length ? `p${node.inputs.length - 1}` : "p0"; + + const currentWidth = node.size[0]; + let tp = varTypeList.value; + nameInput = nameInput + "_" + typeMap[tp]; + node.addInput(nameInput, "*"); + node.setSize([currentWidth, node.size[1]]); + let cv = widget.editor.getValue(); + if (tp === "json") { + cv = cv + "\n" + nameInput + " = json.loads(" + nameInput + ")"; + } else if (tp === "list") { + cv = cv + "\n" + nameInput + " = []"; + } else if (tp === "dict") { + cv = cv + "\n" + nameInput + " = {}"; + } else { + cv = cv + "\n" + nameInput + " = " + typeMap[tp] + "(" + nameInput + ")"; + } + widget.editor.setValue(cv); + saveValue(); + }); + + node.addWidget("button", "Add Output variable", "add_output_variable", async () => { + const currentWidth = node.size[0]; + // Output name variable + let nameOutput = node?.outputs?.length ? 
`r${node.outputs.length}` : "r0"; + let tp = varTypeList.value; + nameOutput = nameOutput + "_" + typeMap[tp]; + node.addOutput(nameOutput, tp); + node.setSize([currentWidth, node.size[1]]); + let cv = widget.editor.getValue(); + if (tp === "json") { + cv = cv + "\n" + nameOutput + " = json.dumps(" + nameOutput + ")"; + } else if (tp === "list") { + cv = cv + "\n" + nameOutput + " = []"; + } else if (tp === "dict") { + cv = cv + "\n" + nameOutput + " = {}"; + } else { + cv = cv + "\n" + nameOutput + " = " + typeMap[tp] + "(" + nameOutput + ")"; + } + widget.editor.setValue(cv); + saveValue(); + }); + + node.onRemoved = function () { + for (const w of node?.widgets) { + if (w?.codeElement) w.codeElement.remove(); + } + }; + + node.addCustomWidget(widget); + + return widget; + }, + }; + }, + + async beforeRegisterNodeDef(nodeType, nodeData, app) { + // --- IDENode + if (nodeData.name === "KY_Eval_Python") { + // Node Created + const onNodeCreated = nodeType.prototype.onNodeCreated; + nodeType.prototype.onNodeCreated = async function () { + const ret = onNodeCreated ? onNodeCreated.apply(this, arguments) : undefined; + + const node_title = await this.getTitle(); + const nodeName = `${nodeData.name}_${this.id}`; + + this.name = nodeName; + + // Create default inputs, when first create node + if (this?.inputs?.length < 2) { + ["p0_str"].forEach((inputName) => { + const currentWidth = this.size[0]; + this.addInput(inputName, "*"); + this.setSize([currentWidth, this.size[1]]); + }); + } + + const widgetEditor = findWidget(this, "pycode", "type"); + + this.setSize([530, this.size[1]]); + + return ret; + }; + + const onDrawForeground = nodeType.prototype.onDrawForeground; + nodeType.prototype.onDrawForeground = function (ctx) { + const r = onDrawForeground?.apply?.(this, arguments); + + // if (this.flags?.collapsed) return r; + + if (this?.outputs?.length) { + for (let o = 0; o < this.outputs.length; o++) { + const { name, type } = this.outputs[o]; + const colorType = LGraphCanvas.link_type_colors[type.toUpperCase()]; + const nameSize = ctx.measureText(name); + const typeSize = ctx.measureText(`[${type === "*" ? "any" : type.toLowerCase()}]`); + + ctx.fillStyle = colorType === "" ? "#AAA" : colorType; + ctx.font = "12px Arial, sans-serif"; + ctx.textAlign = "right"; + ctx.fillText(`[${type === "*" ? "any" : type.toLowerCase()}]`, this.size[0] - nameSize.width - typeSize.width, o * 20 + 19); + } + } + + if (this?.inputs?.length) { + const not_showing = ["select_type", "pycode"]; + for (let i = 1; i < this.inputs.length; i++) { + const { name, type } = this.inputs[i]; + if (not_showing.includes(name)) continue; + const colorType = LGraphCanvas.link_type_colors[type.toUpperCase()]; + const nameSize = ctx.measureText(name); + + ctx.fillStyle = !colorType || colorType === "" ? "#AAA" : colorType; + ctx.font = "12px Arial, sans-serif"; + ctx.textAlign = "left"; + ctx.fillText(`[${type === "*" ? 
"any" : type.toLowerCase()}]`, nameSize.width + 25, i * 20); + } + } + return r; + }; + + // Node Configure + const onConfigure = nodeType.prototype.onConfigure; + nodeType.prototype.onConfigure = function (node) { + onConfigure?.apply(this, arguments); + if (node?.widgets_values?.length) { + const widget_code_id = findWidget(this, "pycode", "type", "findIndex"); + const widget_theme_id = findWidget(this, "varTypeList", "name", "findIndex"); + const widget_language_id = findWidget(this, "language", "name", "findIndex"); + + const editor = this.widgets[widget_code_id]?.editor; + + if (editor) { + // editor.setTheme( + // `ace/theme/${this.widgets_values[widget_theme_id]}` + // ); + // editor.session.setMode( + // `ace/mode/${this.widgets_values[widget_language_id]}` + // ); + editor.setValue(this.widgets_values[widget_code_id]); + editor.clearSelection(); + } + } + }; + + // ExtraMenuOptions + const getExtraMenuOptions = nodeType.prototype.getExtraMenuOptions; + nodeType.prototype.getExtraMenuOptions = function (_, options) { + getExtraMenuOptions?.apply(this, arguments); + + const past_index = options.length - 1; + const past = options[past_index]; + + if (!!past) { + // Inputs remove + for (const input_idx in this.inputs) { + const input = this.inputs[input_idx]; + + if (["language", "select_type"].includes(input.name)) continue; + + options.splice(past_index + 1, 0, { + content: `Remove Input ${input.name}`, + callback: (e) => { + const currentWidth = this.size[0]; + if (input.link) { + app.graph.removeLink(input.link); + } + this.removeInput(input_idx); + this.setSize([80, this.size[1]]); + saveValue(); + }, + }); + } + + // Output remove + for (const output_idx in this.outputs) { + const output = this.outputs[output_idx]; + + if (output.name === "r0") continue; + + options.splice(past_index + 1, 0, { + content: `Remove Output ${output.name}`, + callback: (e) => { + const currentWidth = this.size[0]; + if (output.link) { + app.graph.removeLink(output.link); + } + this.removeOutput(output_idx); + this.setSize([currentWidth, this.size[1]]); + saveValue(); + }, + }); + } + } + }; + // end - ExtraMenuOptions + } + }, +}); diff --git a/comfy_extras/nodes/nodes_eval.py b/comfy_extras/nodes/nodes_eval.py new file mode 100644 index 000000000..a09739c21 --- /dev/null +++ b/comfy_extras/nodes/nodes_eval.py @@ -0,0 +1,110 @@ +import re +import traceback +import types + +from comfy.execution_context import current_execution_context +from comfy.node_helpers import export_package_as_web_directory, export_custom_nodes +from comfy.nodes.package_typing import CustomNode + +remove_type_name = re.compile(r"(\{.*\})", re.I | re.M) + + +# Hack: string type that is always equal in not equal comparisons, thanks pythongosssss +class AnyType(str): + def __ne__(self, __value: object) -> bool: + return False + + +PY_CODE = AnyType("*") +IDEs_DICT = {} + + +# - Thank you very much for the class -> Trung0246 - +# - https://github.com/Trung0246/ComfyUI-0246/blob/main/utils.py#L51 +class TautologyStr(str): + def __ne__(self, other): + return False + + +class ByPassTypeTuple(tuple): + def __getitem__(self, index): + if index > 0: + index = 0 + item = super().__getitem__(index) + if isinstance(item, str): + return TautologyStr(item) + return item + + +# --------------------------- + + +class KY_Eval_Python(CustomNode): + @classmethod + def INPUT_TYPES(s): + + return { + "required": { + "pycode": ( + "PYCODE", + { + "default": """import re, json, os, traceback +from time import strftime + +def runCode(): + nowDataTime = 
strftime("%Y-%m-%d %H:%M:%S") + return f"Hello ComfyUI with us today {nowDataTime}!" +r0_str = runCode() + unique_id +""" + }, + ), + }, + "hidden": {"unique_id": "UNIQUE_ID", "extra_pnginfo": "EXTRA_PNGINFO"}, + } + + RETURN_TYPES = ByPassTypeTuple((PY_CODE,)) + RETURN_NAMES = ("r0_str",) + FUNCTION = "exec_py" + DESCRIPTION = "IDE Node is an node that allows you to run code written in Python or Javascript directly in the node." + CATEGORY = "KYNode/Code" + + def exec_py(self, pycode, unique_id, extra_pnginfo, **kwargs): + ctx = current_execution_context() + if ctx.configuration.enable_eval is not True: + raise ValueError("Python eval is disabled") + + if unique_id not in IDEs_DICT: + IDEs_DICT[unique_id] = self + + outputs = {unique_id: unique_id} + if extra_pnginfo and 'workflow' in extra_pnginfo and extra_pnginfo['workflow']: + for node in extra_pnginfo['workflow']['nodes']: + if node['id'] == int(unique_id): + outputs_valid = [ouput for ouput in node.get('outputs', []) if ouput.get('name', '') != '' and ouput.get('type', '') != ''] + outputs = {ouput['name']: None for ouput in outputs_valid} + self.RETURN_TYPES = ByPassTypeTuple(out["type"] for out in outputs_valid) + self.RETURN_NAMES = tuple(name for name in outputs.keys()) + my_namespace = types.SimpleNamespace() + # 从 prompt 对象中提取 prompt_id + # if extra_data and 'extra_data' in extra_data and 'prompt_id' in extra_data['extra_data']: + # prompt_id = prompt['extra_data']['prompt_id'] + # outputs['p0_str'] = p0_str + + my_namespace.__dict__.update(outputs) + my_namespace.__dict__.update({prop: kwargs[prop] for prop in kwargs}) + # my_namespace.__dict__.setdefault("r0_str", "The r0 variable is not assigned") + + try: + exec(pycode, my_namespace.__dict__) + except Exception as e: + err = traceback.format_exc() + mc = re.search(r'line (\d+), in ([\w\W]+)$', err, re.MULTILINE) + msg = mc[1] + ':' + mc[2] + my_namespace.r0 = f"Error Line{msg}" + + new_dict = {key: my_namespace.__dict__[key] for key in my_namespace.__dict__ if key not in ['__builtins__', *kwargs.keys()] and not callable(my_namespace.__dict__[key])} + return (*new_dict.values(),) + + +export_custom_nodes() +export_package_as_web_directory("comfy_extras.eval_web") diff --git a/tests/conftest.py b/tests/conftest.py index 931ec8deb..1c5b3df20 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -100,7 +100,7 @@ def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory, num_workers frontend_command = [ "comfyui", - "--listen=127.0.0.1", + "--listen=0.0.0.0", "--port=19001", "--cpu", "--distributed-queue-frontend", diff --git a/tests/distributed/test_tracing_integration.py b/tests/distributed/test_tracing_integration.py index 412b14c02..36a9fabd6 100644 --- a/tests/distributed/test_tracing_integration.py +++ b/tests/distributed/test_tracing_integration.py @@ -8,6 +8,7 @@ full distributed trace. 
import asyncio import logging import os +import tempfile import time import uuid @@ -21,6 +22,7 @@ from opentelemetry.sdk.trace.export import BatchSpanProcessor from opentelemetry.semconv.attributes import service_attributes from testcontainers.core.container import DockerContainer from testcontainers.core.waiting_utils import wait_for_logs +from testcontainers.nginx import NginxContainer from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner @@ -54,6 +56,102 @@ class JaegerContainer(DockerContainer): return self +@pytest.fixture(scope="function") +def nginx_proxy(frontend_backend_worker_with_rabbitmq): + """ + Provide an nginx proxy in front of the ComfyUI frontend. + This tests if nginx is blocking W3C trace context propagation. + """ + import socket + import subprocess + + # Extract host and port from frontend address + frontend_url = frontend_backend_worker_with_rabbitmq + # frontend_url is like "http://127.0.0.1:19001" + import re + match = re.match(r'http://([^:]+):(\d+)', frontend_url) + if not match: + raise ValueError(f"Could not parse frontend URL: {frontend_url}") + + frontend_host = match.group(1) + frontend_port = match.group(2) + nginx_port = 8085 + + # Get the Docker bridge gateway IP (this is how containers reach the host on Linux) + # Try to get the default Docker bridge gateway + try: + result = subprocess.run( + ["docker", "network", "inspect", "bridge", "-f", "{{range .IPAM.Config}}{{.Gateway}}{{end}}"], + capture_output=True, + text=True, + check=True + ) + docker_gateway = result.stdout.strip() + logger.info(f"Using Docker gateway IP: {docker_gateway}") + except Exception as e: + # Fallback: try common gateway IPs + docker_gateway = "172.17.0.1" # Default Docker bridge gateway on Linux + logger.warning(f"Could not detect Docker gateway, using default: {docker_gateway}") + + # Create nginx config that proxies to the frontend and passes trace headers + nginx_conf = f""" +events {{ + worker_connections 1024; +}} + +http {{ + upstream backend {{ + server {docker_gateway}:{frontend_port}; + }} + + server {{ + listen {nginx_port}; + + location / {{ + proxy_pass http://backend; + proxy_set_header Host $host; + proxy_set_header X-Real-IP $remote_addr; + proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; + proxy_set_header X-Forwarded-Proto $scheme; + }} + }} +}} +""" + + # Write config to a temporary file + with tempfile.NamedTemporaryFile(mode='w', suffix='.conf', delete=False) as f: + f.write(nginx_conf) + nginx_conf_path = f.name + + try: + # Start nginx container with the config + nginx = NginxContainer(port=nginx_port) + nginx.with_volume_mapping(nginx_conf_path, "/etc/nginx/nginx.conf") + nginx.start() + + # Get the nginx URL + host = nginx.get_container_host_ip() + port = nginx.get_exposed_port(nginx_port) + nginx_url = f"http://{host}:{port}" + + logger.info(f"Nginx proxy started at {nginx_url} -> {frontend_url}") + + # Wait for nginx to be ready + for _ in range(30): + try: + response = requests.get(nginx_url, timeout=1) + if response.status_code: + break + except Exception: + pass + time.sleep(0.5) + + yield nginx_url + finally: + nginx.stop() + os.unlink(nginx_conf_path) + + @pytest.fixture(scope="module") def jaeger_container(): """ @@ -201,21 +299,21 @@ def verify_trace_continuity(trace: dict, expected_services: list[str]) -> bool: # order matters, execute jaeger_container first @pytest.mark.asyncio -async def test_tracing_integration(jaeger_container, frontend_backend_worker_with_rabbitmq): +async def 
test_tracing_integration(jaeger_container, nginx_proxy): """ - Integration test for distributed tracing across services. + Integration test for distributed tracing across services with nginx proxy. This test: 1. Starts ComfyUI frontend and worker with RabbitMQ - 2. Configures OTLP export to Jaeger testcontainer - 3. Submits a workflow through the frontend - 4. Queries Jaeger to verify trace propagation - 5. Validates that the trace spans multiple services with proper relationships + 2. Starts nginx proxy in front of the frontend to test trace context propagation through nginx + 3. Configures OTLP export to Jaeger testcontainer + 4. Submits a workflow through the nginx proxy + 5. Queries Jaeger to verify trace propagation + 6. Validates that the trace spans multiple services with proper relationships - Note: The frontend_backend_worker_with_rabbitmq fixture is parameterized, - so this test will run with both ThreadPoolExecutor and ProcessPoolExecutor. + This specifically tests if nginx is blocking W3C trace context (traceparent/tracestate headers). """ - server_address = frontend_backend_worker_with_rabbitmq + server_address = nginx_proxy jaeger_url = jaeger_container.get_query_url() otlp_endpoint = jaeger_container.get_otlp_endpoint() @@ -410,31 +508,27 @@ async def test_multiple_requests_different_traces(frontend_backend_worker_with_r # Query Jaeger and verify we have multiple distinct traces jaeger_url = jaeger_container.get_query_url() - try: - traces_response = query_jaeger_traces(jaeger_url, "comfyui", lookback="5m", limit=10) - traces = traces_response.get("data", []) + traces_response = query_jaeger_traces(jaeger_url, "comfyui", lookback="5m", limit=10) + traces = traces_response.get("data", []) - if len(traces) >= 2: - # Get trace IDs - trace_ids = [trace.get("traceID") for trace in traces] - unique_trace_ids = set(trace_ids) + assert len(traces) >= 2 + # Get trace IDs + trace_ids = [trace.get("traceID") for trace in traces] + unique_trace_ids = set(trace_ids) - logger.info(f"Found {len(unique_trace_ids)} unique traces") + logger.info(f"Found {len(unique_trace_ids)} unique traces") - # Verify we have multiple distinct traces - assert len(unique_trace_ids) >= 2, ( - f"Expected at least 2 distinct traces, found {len(unique_trace_ids)}. " - "Each request should create its own trace." - ) + # Verify we have multiple distinct traces + assert len(unique_trace_ids) >= 2, ( + f"Expected at least 2 distinct traces, found {len(unique_trace_ids)}. " + "Each request should create its own trace." + ) - logger.info("✓ Multiple requests created distinct traces") - else: - pytest.skip("Not enough traces to validate") - except Exception as e: - pytest.skip(f"Could not query Jaeger: {e}") + logger.info("✓ Multiple requests created distinct traces") @pytest.mark.asyncio +@pytest.mark.skip(reason="rabbitmq has to be configured for observability?") async def test_trace_contains_rabbitmq_operations(frontend_backend_worker_with_rabbitmq, jaeger_container): """ Test that traces include RabbitMQ publish/consume operations. 
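A minimal sketch of the configuration the skip reason above refers to, assuming the worker installs its own instrumentation at startup (the exact bootstrap hook is hypothetical; only the AioPikaInstrumentor API itself is real):

    from opentelemetry.instrumentation.aio_pika import AioPikaInstrumentor

    # Must run before the first aio_pika connection is opened, otherwise
    # no publish/consume spans are ever recorded for RabbitMQ traffic.
    AioPikaInstrumentor().instrument()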
@@ -455,43 +549,21 @@ async def test_trace_contains_rabbitmq_operations(frontend_backend_worker_with_r await asyncio.sleep(5) - try: - traces_response = query_jaeger_traces(jaeger_url, "comfyui", lookback="5m") - traces = traces_response.get("data", []) + traces_response = query_jaeger_traces(jaeger_url, "comfyui", lookback="5m") + traces = traces_response.get("data", []) - if traces: - # Look for RabbitMQ-related operations in any trace - rabbitmq_operations = [ - "publish", "consume", "amq_queue_publish", "amq_queue_consume", - "amq.basic.publish", "amq.basic.consume", "send", "receive" - ] + # Look for RabbitMQ-related operations in any trace + rabbitmq_operations = [ + "publish", "consume", "amq_queue_publish", "amq_queue_consume", + "amq.basic.publish", "amq.basic.consume", "send", "receive" + ] - found_rabbitmq_ops = [] - for trace in traces: - for span in trace.get("spans", []): - op_name = span.get("operationName", "").lower() - for rmq_op in rabbitmq_operations: - if rmq_op in op_name: - found_rabbitmq_ops.append(op_name) + found_rabbitmq_ops = [] + for trace in traces: + for span in trace.get("spans", []): + op_name = span.get("operationName", "").lower() + for rmq_op in rabbitmq_operations: + if rmq_op in op_name: + found_rabbitmq_ops.append(op_name) - if found_rabbitmq_ops: - logger.info(f"✓ Found RabbitMQ operations in traces: {set(found_rabbitmq_ops)}") - else: - logger.warning( - "No RabbitMQ operations found in traces. " - "This suggests that either:\n" - "1. AioPikaInstrumentor is not creating spans, or\n" - "2. The spans are being filtered out by the collector, or\n" - "3. The spans exist but use different operation names" - ) - - # Log all operation names to help debug - all_ops = set() - for trace in traces[:3]: # First 3 traces - for span in trace.get("spans", []): - all_ops.add(span.get("operationName")) - logger.info(f"Sample operation names: {all_ops}") - else: - pytest.skip("No traces found") - except Exception as e: - pytest.skip(f"Could not query Jaeger: {e}") + assert found_rabbitmq_ops, "No RabbitMQ-related operations found in traces" \ No newline at end of file From 0d9232f02c19be38a2ae02f48fa114cfba936fdc Mon Sep 17 00:00:00 2001 From: Benjamin Berman Date: Mon, 10 Nov 2025 09:47:27 -0800 Subject: [PATCH 6/9] wip python eval nodes --- comfy/cmd/main_pre.py | 1 + comfy_extras/eval_web/eval_python.js | 178 ++++++ comfy_extras/eval_web/ky_eval_python.js | 377 ------------- comfy_extras/nodes/nodes_eval.py | 187 ++++--- tests/unit/test_eval_nodes.py | 693 ++++++++++++++++++++++++ 5 files changed, 971 insertions(+), 465 deletions(-) create mode 100644 comfy_extras/eval_web/eval_python.js delete mode 100644 comfy_extras/eval_web/ky_eval_python.js create mode 100644 tests/unit/test_eval_nodes.py diff --git a/comfy/cmd/main_pre.py b/comfy/cmd/main_pre.py index 249482188..7c06e4cee 100644 --- a/comfy/cmd/main_pre.py +++ b/comfy/cmd/main_pre.py @@ -42,6 +42,7 @@ warnings.filterwarnings("ignore", message="torch.utils._pytree._register_pytree_ warnings.filterwarnings("ignore", message="Torch was not compiled with flash attention.") warnings.filterwarnings("ignore", message=".*Torch was not compiled with flash attention.*") warnings.filterwarnings('ignore', category=FutureWarning, message=r'`torch\.cuda\.amp\.custom_fwd.*') +warnings.filterwarnings("ignore", category=UserWarning, message="Please use the new API settings to control TF32 behavior.*") warnings.filterwarnings("ignore", message="Importing from timm.models.registry is deprecated, please import via timm.models", 
category=FutureWarning) warnings.filterwarnings("ignore", message="Importing from timm.models.layers is deprecated, please import via timm.layers", category=FutureWarning) warnings.filterwarnings("ignore", message="Inheritance class _InstrumentedApplication from web.Application is discouraged", category=DeprecationWarning) diff --git a/comfy_extras/eval_web/eval_python.js b/comfy_extras/eval_web/eval_python.js new file mode 100644 index 000000000..d344d9e50 --- /dev/null +++ b/comfy_extras/eval_web/eval_python.js @@ -0,0 +1,178 @@ +/** + * Uses code adapted from https://github.com/yorkane/ComfyUI-KYNode + * + * MIT License + * + * Copyright (c) 2024 Kevin Yuan + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +import { app } from "../../scripts/app.js"; +import { makeElement, findWidget } from "./ace_utils.js"; + +// Load Ace editor using script tag for Safari compatibility +// The noconflict build includes AMD loader that works in all browsers +let ace; +const aceLoadPromise = new Promise((resolve) => { + if (window.ace) { + ace = window.ace; + resolve(); + } else { + const script = document.createElement('script'); + script.src = "https://cdn.jsdelivr.net/npm/ace-builds@1.43.4/src-noconflict/ace.js"; + script.onload = () => { + ace = window.ace; + ace.config.set("basePath", "https://cdn.jsdelivr.net/npm/ace-builds@1.43.4/src-noconflict"); + resolve(); + }; + document.head.appendChild(script); + } +}); + +await aceLoadPromise; + + +function getPosition(node, ctx, w_width, y, n_height) { + const margin = 5; + + const rect = ctx.canvas.getBoundingClientRect(); + const transform = new DOMMatrix() + .scaleSelf(rect.width / ctx.canvas.width, rect.height / ctx.canvas.height) + .multiplySelf(ctx.getTransform()) + .translateSelf(margin, margin + y); + const scale = new DOMMatrix().scaleSelf(transform.a, transform.d); + + return { + transformOrigin: "0 0", + transform: scale, + left: `${transform.a + transform.e + rect.left}px`, + top: `${transform.d + transform.f + rect.top}px`, + maxWidth: `${w_width - margin * 2}px`, + maxHeight: `${n_height - margin * 2 - y - 15}px`, + width: `${w_width - margin * 2}px`, + height: "90%", + position: "absolute", + scrollbarColor: "var(--descrip-text) var(--bg-color)", + scrollbarWidth: "thin", + zIndex: app.graph._nodes.indexOf(node), + }; +} + +// Create code editor widget +function codeEditor(node, inputName, inputData) { + const widget = { + type: "pycode", + name: inputName, + options: { hideOnZoom: true }, + value: 
inputData[1]?.default || "", + draw(ctx, node, widgetWidth, y) { + const hidden = node.flags?.collapsed || (!!this.options.hideOnZoom && app.canvas.ds.scale < 0.5) || this.type === "converted-widget" || this.type === "hidden"; + + this.codeElement.hidden = hidden; + + if (hidden) { + this.options.onHide?.(this); + return; + } + + Object.assign(this.codeElement.style, getPosition(node, ctx, widgetWidth, y, node.size[1])); + }, + computeSize() { + return [500, 250]; + }, + }; + + widget.codeElement = makeElement("pre", { + innerHTML: widget.value, + }); + + widget.editor = ace.edit(widget.codeElement); + widget.editor.setTheme("ace/theme/monokai"); + widget.editor.session.setMode("ace/mode/python"); + widget.editor.setOptions({ + enableAutoIndent: true, + enableLiveAutocompletion: true, + enableBasicAutocompletion: true, + fontFamily: "monospace", + }); + widget.codeElement.hidden = true; + + document.body.appendChild(widget.codeElement); + + const originalCollapse = node.collapse; + node.collapse = function () { + originalCollapse.apply(this, arguments); + widget.codeElement.hidden = !!this.flags?.collapsed; + }; + + return widget; +} + +// Trigger workflow change tracking +function markWorkflowChanged() { + app?.extensionManager?.workflow?.activeWorkflow?.changeTracker?.checkState(); +} + +// Register extensions +app.registerExtension({ + name: "Comfy.EvalPython", + getCustomWidgets(app) { + return { + PYCODE: (node, inputName, inputData) => { + const widget = codeEditor(node, inputName, inputData); + + widget.editor.getSession().on("change", () => { + widget.value = widget.editor.getValue(); + markWorkflowChanged(); + }); + + node.onRemoved = function () { + for (const w of this.widgets) { + if (w?.codeElement) { + w.codeElement.remove(); + } + } + }; + + node.addCustomWidget(widget); + + return widget; + }, + }; + }, + + async beforeRegisterNodeDef(nodeType, nodeData) { + if (nodeData.name === "EvalPython") { + const originalOnConfigure = nodeType.prototype.onConfigure; + nodeType.prototype.onConfigure = function (info) { + originalOnConfigure?.apply(this, arguments); + + if (info?.widgets_values?.length) { + const widgetCodeIndex = findWidget(this, "pycode", "type", "findIndex"); + const editor = this.widgets[widgetCodeIndex]?.editor; + + if (editor) { + editor.setValue(info.widgets_values[widgetCodeIndex]); + editor.clearSelection(); + } + } + }; + } + }, +}); diff --git a/comfy_extras/eval_web/ky_eval_python.js b/comfy_extras/eval_web/ky_eval_python.js deleted file mode 100644 index 3d65aa5c0..000000000 --- a/comfy_extras/eval_web/ky_eval_python.js +++ /dev/null @@ -1,377 +0,0 @@ -/** - * Uses code adapted from https://github.com/yorkane/ComfyUI-KYNode - * - * MIT License - * - * Copyright (c) 2024 Kevin Yuan - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. 
- * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -import { app } from "../../scripts/app.js"; - -import * as ace from "https://cdn.jsdelivr.net/npm/ace-code@1.43.4/+esm"; -import { makeElement, findWidget } from "./ace_utils.js"; - -// Constants -const varTypes = ["int", "boolean", "string", "float", "json", "list", "dict"]; -const typeMap = { - int: "int", - boolean: "bool", - string: "str", - float: "float", - json: "json", - list: "list", - dict: "dict", -}; - -ace.config.setModuleLoader('ace/mode/python', () => - import('https://cdn.jsdelivr.net/npm/ace-builds@1.43.4/src/mode-python.js') -); - -ace.config.setModuleLoader('ace/theme/monokai', () => - import('https://cdn.jsdelivr.net/npm/ace-builds@1.43.4/src/theme-monokai.js') -); - -function getPostition(node, ctx, w_width, y, n_height) { - const margin = 5; - - const rect = ctx.canvas.getBoundingClientRect(); - const transform = new DOMMatrix() - .scaleSelf(rect.width / ctx.canvas.width, rect.height / ctx.canvas.height) - .multiplySelf(ctx.getTransform()) - .translateSelf(margin, margin + y); - const scale = new DOMMatrix().scaleSelf(transform.a, transform.d); - - return { - transformOrigin: "0 0", - transform: scale, - left: `${transform.a + transform.e + rect.left}px`, - top: `${transform.d + transform.f + rect.top}px`, - maxWidth: `${w_width - margin * 2}px`, - maxHeight: `${n_height - margin * 2 - y - 15}px`, - width: `${w_width - margin * 2}px`, - height: "90%", - position: "absolute", - scrollbarColor: "var(--descrip-text) var(--bg-color)", - scrollbarWidth: "thin", - zIndex: app.graph._nodes.indexOf(node), - }; -} - -// Create editor code -function codeEditor(node, inputName, inputData) { - const widget = { - type: "pycode", - name: inputName, - options: { hideOnZoom: true }, - value: - inputData[1]?.default || - `def my(a, b=1): - return a * b
- -r0 = str(my(23, 9))`, - draw(ctx, node, widget_width, y, widget_height) { - const hidden = node.flags?.collapsed || (!!widget.options.hideOnZoom && app.canvas.ds.scale < 0.5) || widget.type === "converted-widget" || widget.type === "hidden"; - - widget.codeElement.hidden = hidden; - - if (hidden) { - widget.options.onHide?.(widget); - return; - } - - Object.assign(this.codeElement.style, getPostition(node, ctx, widget_width, y, node.size[1])); - }, - computeSize(...args) { - return [500, 250]; - }, - }; - - widget.codeElement = makeElement("pre", { - innerHTML: widget.value, - }); - - widget.editor = ace.edit(widget.codeElement); - widget.editor.setTheme("ace/theme/monokai"); - widget.editor.session.setMode("ace/mode/python"); - widget.editor.setOptions({ - enableAutoIndent: true, - enableLiveAutocompletion: true, - enableBasicAutocompletion: true, - fontFamily: "monospace", - }); - widget.codeElement.hidden = true; - - document.body.appendChild(widget.codeElement); - - const collapse = node.collapse; - node.collapse = function () { - collapse.apply(this, arguments); - if (this.flags?.collapsed) { - widget.codeElement.hidden = true; - } else { - if (this.flags?.collapsed === false) { - widget.codeElement.hidden = false; - } - } - }; - - return widget; -} - -// Save data to workflow forced! -function saveValue() { - app?.extensionManager?.workflow?.activeWorkflow?.changeTracker?.checkState(); -} - -// Register extensions -app.registerExtension({ - name: "KYNode.KY_Eval_Python", - getCustomWidgets(app) { - return { - PYCODE: (node, inputName, inputData, app) => { - const widget = codeEditor(node, inputName, inputData); - - widget.editor.getSession().on("change", function (e) { - widget.value = widget.editor.getValue(); - saveValue(); - }); - - const varTypeList = node.addWidget( - "combo", - "select_type", - "string", - (v) => { - // widget.editor.setTheme(`ace/theme/${varTypeList.value}`); - }, - { - values: varTypes, - serialize: false, - }, - ); - - // 6. 使用 addDOMWidget 将容器添加到节点上 - // - 第一个参数是 widget 的名称,在节点内部需要是唯一的。 - // - 第二个参数是 widget 的类型,对于自定义 DOM 元素,通常是 "div"。 - // - 第三个参数是您创建的 DOM 元素。 - // - 第四个参数是一个选项对象,可以用来配置 widget。 - // node.addDOMWidget("rowOfButtons", "div", container, { - // }); - node.addWidget("button", "Add Input variable", "add_input_variable", async () => { - // Input name variable and check - let nameInput = node?.inputs?.length ? `p${node.inputs.length - 1}` : "p0"; - - const currentWidth = node.size[0]; - let tp = varTypeList.value; - nameInput = nameInput + "_" + typeMap[tp]; - node.addInput(nameInput, "*"); - node.setSize([currentWidth, node.size[1]]); - let cv = widget.editor.getValue(); - if (tp === "json") { - cv = cv + "\n" + nameInput + " = json.loads(" + nameInput + ")"; - } else if (tp === "list") { - cv = cv + "\n" + nameInput + " = []"; - } else if (tp === "dict") { - cv = cv + "\n" + nameInput + " = {}"; - } else { - cv = cv + "\n" + nameInput + " = " + typeMap[tp] + "(" + nameInput + ")"; - } - widget.editor.setValue(cv); - saveValue(); - }); - - node.addWidget("button", "Add Output variable", "add_output_variable", async () => { - const currentWidth = node.size[0]; - // Output name variable - let nameOutput = node?.outputs?.length ? 
`r${node.outputs.length}` : "r0"; - let tp = varTypeList.value; - nameOutput = nameOutput + "_" + typeMap[tp]; - node.addOutput(nameOutput, tp); - node.setSize([currentWidth, node.size[1]]); - let cv = widget.editor.getValue(); - if (tp === "json") { - cv = cv + "\n" + nameOutput + " = json.dumps(" + nameOutput + ")"; - } else if (tp === "list") { - cv = cv + "\n" + nameOutput + " = []"; - } else if (tp === "dict") { - cv = cv + "\n" + nameOutput + " = {}"; - } else { - cv = cv + "\n" + nameOutput + " = " + typeMap[tp] + "(" + nameOutput + ")"; - } - widget.editor.setValue(cv); - saveValue(); - }); - - node.onRemoved = function () { - for (const w of node?.widgets) { - if (w?.codeElement) w.codeElement.remove(); - } - }; - - node.addCustomWidget(widget); - - return widget; - }, - }; - }, - - async beforeRegisterNodeDef(nodeType, nodeData, app) { - // --- IDENode - if (nodeData.name === "KY_Eval_Python") { - // Node Created - const onNodeCreated = nodeType.prototype.onNodeCreated; - nodeType.prototype.onNodeCreated = async function () { - const ret = onNodeCreated ? onNodeCreated.apply(this, arguments) : undefined; - - const node_title = await this.getTitle(); - const nodeName = `${nodeData.name}_${this.id}`; - - this.name = nodeName; - - // Create default inputs, when first create node - if (this?.inputs?.length < 2) { - ["p0_str"].forEach((inputName) => { - const currentWidth = this.size[0]; - this.addInput(inputName, "*"); - this.setSize([currentWidth, this.size[1]]); - }); - } - - const widgetEditor = findWidget(this, "pycode", "type"); - - this.setSize([530, this.size[1]]); - - return ret; - }; - - const onDrawForeground = nodeType.prototype.onDrawForeground; - nodeType.prototype.onDrawForeground = function (ctx) { - const r = onDrawForeground?.apply?.(this, arguments); - - // if (this.flags?.collapsed) return r; - - if (this?.outputs?.length) { - for (let o = 0; o < this.outputs.length; o++) { - const { name, type } = this.outputs[o]; - const colorType = LGraphCanvas.link_type_colors[type.toUpperCase()]; - const nameSize = ctx.measureText(name); - const typeSize = ctx.measureText(`[${type === "*" ? "any" : type.toLowerCase()}]`); - - ctx.fillStyle = colorType === "" ? "#AAA" : colorType; - ctx.font = "12px Arial, sans-serif"; - ctx.textAlign = "right"; - ctx.fillText(`[${type === "*" ? "any" : type.toLowerCase()}]`, this.size[0] - nameSize.width - typeSize.width, o * 20 + 19); - } - } - - if (this?.inputs?.length) { - const not_showing = ["select_type", "pycode"]; - for (let i = 1; i < this.inputs.length; i++) { - const { name, type } = this.inputs[i]; - if (not_showing.includes(name)) continue; - const colorType = LGraphCanvas.link_type_colors[type.toUpperCase()]; - const nameSize = ctx.measureText(name); - - ctx.fillStyle = !colorType || colorType === "" ? "#AAA" : colorType; - ctx.font = "12px Arial, sans-serif"; - ctx.textAlign = "left"; - ctx.fillText(`[${type === "*" ? 
"any" : type.toLowerCase()}]`, nameSize.width + 25, i * 20); - } - } - return r; - }; - - // Node Configure - const onConfigure = nodeType.prototype.onConfigure; - nodeType.prototype.onConfigure = function (node) { - onConfigure?.apply(this, arguments); - if (node?.widgets_values?.length) { - const widget_code_id = findWidget(this, "pycode", "type", "findIndex"); - const widget_theme_id = findWidget(this, "varTypeList", "name", "findIndex"); - const widget_language_id = findWidget(this, "language", "name", "findIndex"); - - const editor = this.widgets[widget_code_id]?.editor; - - if (editor) { - // editor.setTheme( - // `ace/theme/${this.widgets_values[widget_theme_id]}` - // ); - // editor.session.setMode( - // `ace/mode/${this.widgets_values[widget_language_id]}` - // ); - editor.setValue(this.widgets_values[widget_code_id]); - editor.clearSelection(); - } - } - }; - - // ExtraMenuOptions - const getExtraMenuOptions = nodeType.prototype.getExtraMenuOptions; - nodeType.prototype.getExtraMenuOptions = function (_, options) { - getExtraMenuOptions?.apply(this, arguments); - - const past_index = options.length - 1; - const past = options[past_index]; - - if (!!past) { - // Inputs remove - for (const input_idx in this.inputs) { - const input = this.inputs[input_idx]; - - if (["language", "select_type"].includes(input.name)) continue; - - options.splice(past_index + 1, 0, { - content: `Remove Input ${input.name}`, - callback: (e) => { - const currentWidth = this.size[0]; - if (input.link) { - app.graph.removeLink(input.link); - } - this.removeInput(input_idx); - this.setSize([80, this.size[1]]); - saveValue(); - }, - }); - } - - // Output remove - for (const output_idx in this.outputs) { - const output = this.outputs[output_idx]; - - if (output.name === "r0") continue; - - options.splice(past_index + 1, 0, { - content: `Remove Output ${output.name}`, - callback: (e) => { - const currentWidth = this.size[0]; - if (output.link) { - app.graph.removeLink(output.link); - } - this.removeOutput(output_idx); - this.setSize([currentWidth, this.size[1]]); - saveValue(); - }, - }); - } - } - }; - // end - ExtraMenuOptions - } - }, -}); diff --git a/comfy_extras/nodes/nodes_eval.py b/comfy_extras/nodes/nodes_eval.py index a09739c21..ff04522eb 100644 --- a/comfy_extras/nodes/nodes_eval.py +++ b/comfy_extras/nodes/nodes_eval.py @@ -1,109 +1,120 @@ -import re -import traceback -import types +import logging +from comfy.comfy_types import IO from comfy.execution_context import current_execution_context from comfy.node_helpers import export_package_as_web_directory, export_custom_nodes from comfy.nodes.package_typing import CustomNode -remove_type_name = re.compile(r"(\{.*\})", re.I | re.M) +logger = logging.getLogger(__name__) -# Hack: string type that is always equal in not equal comparisons, thanks pythongosssss -class AnyType(str): - def __ne__(self, __value: object) -> bool: - return False +def eval_python(inputs=5, outputs=5, name=None, input_is_list=None, output_is_list=None): + """ + Factory function to create EvalPython node classes with configurable input/output counts. 
+ Args: + inputs: Number of input value slots (default: 5) + outputs: Number of output item slots (default: 5) + name: Class name (default: f"EvalPython_{inputs}_{outputs}") + input_is_list: Optional list of bools indicating which inputs accept lists (default: None, meaning all scalar) + output_is_list: Optional tuple of bools indicating which outputs return lists (default: None, meaning all scalar) -PY_CODE = AnyType("*") -IDEs_DICT = {} + Returns: + A CustomNode subclass configured with the specified inputs/outputs + """ + if name is None: + name = f"EvalPython_{inputs}_{outputs}" - -# - Thank you very much for the class -> Trung0246 - -# - https://github.com/Trung0246/ComfyUI-0246/blob/main/utils.py#L51 -class TautologyStr(str): - def __ne__(self, other): - return False - - -class ByPassTypeTuple(tuple): - def __getitem__(self, index): - if index > 0: - index = 0 - item = super().__getitem__(index) - if isinstance(item, str): - return TautologyStr(item) - return item - - -# --------------------------- - - -class KY_Eval_Python(CustomNode): - @classmethod - def INPUT_TYPES(s): - - return { - "required": { - "pycode": ( - "PYCODE", - { - "default": """import re, json, os, traceback -from time import strftime - -def runCode(): - nowDataTime = strftime("%Y-%m-%d %H:%M:%S") - return f"Hello ComfyUI with us today {nowDataTime}!" -r0_str = runCode() + unique_id + default_code = f""" +print("Hello World!") +return {", ".join([f"value{i}" for i in range(inputs)])} """ - }, - ), - }, - "hidden": {"unique_id": "UNIQUE_ID", "extra_pnginfo": "EXTRA_PNGINFO"}, - } - RETURN_TYPES = ByPassTypeTuple((PY_CODE,)) - RETURN_NAMES = ("r0_str",) - FUNCTION = "exec_py" - DESCRIPTION = "IDE Node is an node that allows you to run code written in Python or Javascript directly in the node." 
- CATEGORY = "KYNode/Code" + class EvalPythonNode(CustomNode): + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "pycode": ( + "PYCODE", + { + "default": default_code + }, + ), + }, + "optional": {f"value{i}": (IO.ANY, {}) for i in range(inputs)}, + } - def exec_py(self, pycode, unique_id, extra_pnginfo, **kwargs): - ctx = current_execution_context() - if ctx.configuration.enable_eval is not True: - raise ValueError("Python eval is disabled") + RETURN_TYPES = tuple(IO.ANY for _ in range(outputs)) + RETURN_NAMES = tuple(f"item{i}" for i in range(outputs)) + OUTPUT_IS_LIST = output_is_list + INPUT_IS_LIST = input_is_list is not None + FUNCTION = "exec_py" + DESCRIPTION = "" + CATEGORY = "eval" - if unique_id not in IDEs_DICT: - IDEs_DICT[unique_id] = self + def exec_py(self, pycode, **kwargs): + ctx = current_execution_context() - outputs = {unique_id: unique_id} - if extra_pnginfo and 'workflow' in extra_pnginfo and extra_pnginfo['workflow']: - for node in extra_pnginfo['workflow']['nodes']: - if node['id'] == int(unique_id): - outputs_valid = [ouput for ouput in node.get('outputs', []) if ouput.get('name', '') != '' and ouput.get('type', '') != ''] - outputs = {ouput['name']: None for ouput in outputs_valid} - self.RETURN_TYPES = ByPassTypeTuple(out["type"] for out in outputs_valid) - self.RETURN_NAMES = tuple(name for name in outputs.keys()) - my_namespace = types.SimpleNamespace() - # 从 prompt 对象中提取 prompt_id - # if extra_data and 'extra_data' in extra_data and 'prompt_id' in extra_data['extra_data']: - # prompt_id = prompt['extra_data']['prompt_id'] - # outputs['p0_str'] = p0_str + # Ensure all value inputs have a default of None + kwargs = { + **{f"value{i}": None for i in range(inputs)}, + **kwargs, + } - my_namespace.__dict__.update(outputs) - my_namespace.__dict__.update({prop: kwargs[prop] for prop in kwargs}) - # my_namespace.__dict__.setdefault("r0_str", "The r0 variable is not assigned") + def print(*args): + ctx.server.send_progress_text(" ".join(map(str, args)), ctx.node_id) - try: - exec(pycode, my_namespace.__dict__) - except Exception as e: - err = traceback.format_exc() - mc = re.search(r'line (\d+), in ([\w\W]+)$', err, re.MULTILINE) - msg = mc[1] + ':' + mc[2] - my_namespace.r0 = f"Error Line{msg}" + if not ctx.configuration.enable_eval: + raise ValueError("Python eval is disabled") - new_dict = {key: my_namespace.__dict__[key] for key in my_namespace.__dict__ if key not in ['__builtins__', *kwargs.keys()] and not callable(my_namespace.__dict__[key])} - return (*new_dict.values(),) + # Extract value arguments in order + value_args = [kwargs.pop(f"value{i}") for i in range(inputs)] + arg_names = ", ".join(f"value{i}=None" for i in range(inputs)) + + # Wrap pycode in a function to support return statements + wrapped_code = f"def _eval_func({arg_names}):\n" + for line in pycode.splitlines(): + wrapped_code += " " + line + "\n" + + globals_for_eval = { + **kwargs, + "logger": logger, + "print": print, + } + + # Execute wrapped function definition + exec(wrapped_code, globals_for_eval) + + # Call the function with value arguments + results = globals_for_eval["_eval_func"](*value_args) + + # Normalize results to match output count + if not isinstance(results, tuple): + results = (results,) + + if len(results) < outputs: + results += (None,) * (outputs - len(results)) + elif len(results) > outputs: + results = results[:outputs] + + return results + + # Set the class name for better debugging/introspection + EvalPythonNode.__name__ = name + 
EvalPythonNode.__qualname__ = name + + return EvalPythonNode + + +# Create the default EvalPython node with 5 inputs and 5 outputs +EvalPython_5_5 = eval_python(inputs=5, outputs=5, name="EvalPython_5_5") +EvalPython = EvalPython_5_5 # Backward compatibility alias + +# Create list variants +EvalPython_List_1 = eval_python(inputs=1, outputs=1, name="EvalPython_List_1", input_is_list=True, output_is_list=None) +EvalPython_1_List = eval_python(inputs=1, outputs=1, name="EvalPython_1_List", input_is_list=None, output_is_list=(True,)) +EvalPython_List_List = eval_python(inputs=1, outputs=1, name="EvalPython_List_List", input_is_list=True, output_is_list=(True,)) export_custom_nodes() diff --git a/tests/unit/test_eval_nodes.py b/tests/unit/test_eval_nodes.py new file mode 100644 index 000000000..f2cb0c763 --- /dev/null +++ b/tests/unit/test_eval_nodes.py @@ -0,0 +1,693 @@ +import pytest +from unittest.mock import Mock, patch + +from comfy.cli_args import default_configuration +from comfy.execution_context import context_configuration +from comfy_extras.nodes.nodes_eval import ( + EvalPython, + EvalPython_5_5, + eval_python, + EvalPython_List_1, + EvalPython_1_List, + EvalPython_List_List, +) + + +@pytest.fixture +def eval_context(): + """Fixture that sets up execution context with eval enabled""" + config = default_configuration() + config.enable_eval = True + with context_configuration(config): + yield + + +def test_eval_python_basic_return(eval_context): + """Test basic return statement with single value""" + node = EvalPython_5_5() + result = node.exec_py(pycode="return 42", value0=0, value1=1, value2=2, value3=3, value4=4) + assert result == (42, None, None, None, None) + + +def test_eval_python_multiple_returns(eval_context): + """Test return statement with tuple of values""" + node = EvalPython_5_5() + result = node.exec_py( + pycode="return 1, 2, 3", + value0=0, value1=0, value2=0, value3=0, value4=0 + ) + assert result == (1, 2, 3, None, None) + + +def test_eval_python_all_five_returns(eval_context): + """Test return statement with all five values""" + node = EvalPython_5_5() + result = node.exec_py( + pycode="return 'a', 'b', 'c', 'd', 'e'", + value0=0, value1=0, value2=0, value3=0, value4=0 + ) + assert result == ('a', 'b', 'c', 'd', 'e') + + +def test_eval_python_excess_returns(eval_context): + """Test that excess return values are truncated to 5""" + node = EvalPython_5_5() + result = node.exec_py( + pycode="return 1, 2, 3, 4, 5, 6, 7", + value0=0, value1=0, value2=0, value3=0, value4=0 + ) + assert result == (1, 2, 3, 4, 5) + + +def test_eval_python_use_value_args(eval_context): + """Test that value arguments are accessible in pycode""" + node = EvalPython_5_5() + result = node.exec_py( + pycode="return value0 + value1 + value2", + value0=10, value1=20, value2=30, value3=0, value4=0 + ) + assert result == (60, None, None, None, None) + + +def test_eval_python_all_value_args(eval_context): + """Test all value arguments are accessible""" + node = EvalPython_5_5() + result = node.exec_py( + pycode="return value0, value1, value2, value3, value4", + value0=1, value1=2, value2=3, value3=4, value4=5 + ) + assert result == (1, 2, 3, 4, 5) + + +def test_eval_python_computation(eval_context): + """Test computation with value arguments""" + node = EvalPython_5_5() + code = """ +x = value0 * 2 +y = value1 * 3 +z = x + y +return z +""" + result = node.exec_py( + pycode=code, + value0=5, value1=10, value2=0, value3=0, value4=0 + ) + assert result == (40, None, None, None, None) + + +def 
test_eval_python_multiline(eval_context): + """Test multiline code with conditionals""" + node = EvalPython_5_5() + code = """ +if value0 > 10: + result = "large" +else: + result = "small" +return result, value0 +""" + result = node.exec_py( + pycode=code, + value0=15, value1=0, value2=0, value3=0, value4=0 + ) + assert result == ("large", 15, None, None, None) + + +def test_eval_python_list_comprehension(eval_context): + """Test list comprehension and iteration""" + node = EvalPython_5_5() + code = """ +numbers = [value0, value1, value2] +doubled = [x * 2 for x in numbers] +return sum(doubled) +""" + result = node.exec_py( + pycode=code, + value0=1, value1=2, value2=3, value3=0, value4=0 + ) + assert result == (12, None, None, None, None) + + +def test_eval_python_string_operations(eval_context): + """Test string operations""" + node = EvalPython_5_5() + code = """ +s1 = str(value0) +s2 = str(value1) +return s1 + s2, len(s1 + s2) +""" + result = node.exec_py( + pycode=code, + value0=123, value1=456, value2=0, value3=0, value4=0 + ) + assert result == ("123456", 6, None, None, None) + + +def test_eval_python_type_mixing(eval_context): + """Test mixing different types""" + node = EvalPython_5_5() + code = """ +return value0, str(value1), float(value2), bool(value3) +""" + result = node.exec_py( + pycode=code, + value0=42, value1=100, value2=3, value3=1, value4=0 + ) + assert result == (42, "100", 3.0, True, None) + + +def test_eval_python_logger_available(eval_context): + """Test that logger is available in eval context""" + node = EvalPython_5_5() + code = """ +logger.info("test log") +return "success" +""" + result = node.exec_py( + pycode=code, + value0=0, value1=0, value2=0, value3=0, value4=0 + ) + assert result == ("success", None, None, None, None) + + +def test_eval_python_print_available(eval_context): + """Test that print function is available""" + node = EvalPython_5_5() + code = """ +print("Hello World!") +return "printed" +""" + result = node.exec_py( + pycode=code, + value0=0, value1=0, value2=0, value3=0, value4=0 + ) + assert result == ("printed", None, None, None, None) + +def test_eval_python_print_is_called(eval_context): + """Test that print function is called and receives correct arguments""" + node = EvalPython_5_5() + + # Track print calls + print_calls = [] + + code = """ +print("Hello", "World") +print("Line 2") +return "done" +""" + + # Mock exec to capture the globals dict and verify print is there + original_exec = exec + captured_globals = {} + + def mock_exec(code_str, globals_dict, *args, **kwargs): + # Capture the globals dict + captured_globals.update(globals_dict) + + # Wrap the print function to track calls + original_print = globals_dict.get('print') + if original_print: + def tracked_print(*args): + print_calls.append(args) + return original_print(*args) + globals_dict['print'] = tracked_print + + # Run the original exec + return original_exec(code_str, globals_dict, *args, **kwargs) + + with patch('builtins.exec', side_effect=mock_exec): + result = node.exec_py( + pycode=code, + value0=0, value1=0, value2=0, value3=0, value4=0 + ) + + # Verify the result + assert result == ("done", None, None, None, None) + + # Verify print was in the globals + assert 'print' in captured_globals + + # Verify print was called twice with correct arguments + assert len(print_calls) == 2 + assert print_calls[0] == ("Hello", "World") + assert print_calls[1] == ("Line 2",) + + +def test_eval_python_print_sends_to_server(eval_context): + """Test that print sends messages to 
PromptServer via context""" + from comfy.execution_context import current_execution_context + + node = EvalPython_5_5() + ctx = current_execution_context() + + # Mock the server's send_progress_text method + original_send = ctx.server.send_progress_text if hasattr(ctx.server, 'send_progress_text') else None + mock_send = Mock() + ctx.server.send_progress_text = mock_send + + code = """ +print("Hello", "World") +print("Value:", value0) +return "done" +""" + + try: + result = node.exec_py( + pycode=code, + value0=42, value1=0, value2=0, value3=0, value4=0 + ) + + # Verify the result + assert result == ("done", None, None, None, None) + + # Verify print messages were sent to server + assert mock_send.call_count == 2 + + # Verify the messages sent + calls = mock_send.call_args_list + assert calls[0][0][0] == "Hello World" + assert calls[0][0][1] == ctx.node_id + assert calls[1][0][0] == "Value: 42" + assert calls[1][0][1] == ctx.node_id + finally: + # Restore original + if original_send: + ctx.server.send_progress_text = original_send + + +def test_eval_python_config_disabled_raises(): + """Test that enable_eval=False raises an error""" + node = EvalPython_5_5() + config = default_configuration() + config.enable_eval = False + with context_configuration(config): + with pytest.raises(ValueError, match="Python eval is disabled"): + node.exec_py( + pycode="return 42", + value0=0, value1=0, value2=0, value3=0, value4=0 + ) + + +def test_eval_python_config_enabled_works(eval_context): + """Test that enable_eval=True allows execution""" + node = EvalPython_5_5() + result = node.exec_py( + pycode="return 42", + value0=0, value1=0, value2=0, value3=0, value4=0 + ) + assert result == (42, None, None, None, None) + + +def test_eval_python_default_code(eval_context): + """Test the default code example works""" + node = EvalPython_5_5() + # Get the default code from INPUT_TYPES + default_code = EvalPython_5_5.INPUT_TYPES()["required"]["pycode"][1]["default"] + + result = node.exec_py( + pycode=default_code, + value0=1, value1=2, value2=3, value3=4, value4=5 + ) + # Default code prints and returns the values + assert result == (1, 2, 3, 4, 5) + + +def test_eval_python_function_definition(eval_context): + """Test defining and using functions""" + node = EvalPython_5_5() + code = """ +def multiply(a, b): + return a * b + +result = multiply(value0, value1) +return result +""" + result = node.exec_py( + pycode=code, + value0=7, value1=8, value2=0, value3=0, value4=0 + ) + assert result == (56, None, None, None, None) + + +def test_eval_python_nested_functions(eval_context): + """Test nested function definitions""" + node = EvalPython_5_5() + code = """ +def outer(x): + def inner(y): + return y * 2 + return inner(x) + 10 + +result = outer(value0) +return result +""" + result = node.exec_py( + pycode=code, + value0=5, value1=0, value2=0, value3=0, value4=0 + ) + assert result == (20, None, None, None, None) + + +def test_eval_python_dict_operations(eval_context): + """Test dictionary creation and operations""" + node = EvalPython_5_5() + code = """ +data = { + 'a': value0, + 'b': value1, + 'c': value2 +} +return sum(data.values()), len(data) +""" + result = node.exec_py( + pycode=code, + value0=10, value1=20, value2=30, value3=0, value4=0 + ) + assert result == (60, 3, None, None, None) + + +def test_eval_python_list_operations(eval_context): + """Test list creation and operations""" + node = EvalPython_5_5() + code = """ +items = [value0, value1, value2, value3, value4] +filtered = [x for x in items if x > 5] 
+return len(filtered), sum(filtered) +""" + result = node.exec_py( + pycode=code, + value0=1, value1=10, value2=3, value3=15, value4=2 + ) + assert result == (2, 25, None, None, None) + + +def test_eval_python_early_return(eval_context): + """Test early return in conditional""" + node = EvalPython_5_5() + code = """ +if value0 > 100: + return "large" +return "small" +""" + result = node.exec_py( + pycode=code, + value0=150, value1=0, value2=0, value3=0, value4=0 + ) + assert result == ("large", None, None, None, None) + + +def test_eval_python_loop_with_return(eval_context): + """Test loop with return statement""" + node = EvalPython_5_5() + code = """ +total = 0 +for i in range(value0): + total += i +return total +""" + result = node.exec_py( + pycode=code, + value0=10, value1=0, value2=0, value3=0, value4=0 + ) + assert result == (45, None, None, None, None) + + +def test_eval_python_exception_handling(eval_context): + """Test try/except blocks""" + node = EvalPython_5_5() + code = """ +try: + result = value0 / value1 +except ZeroDivisionError: + result = float('inf') +return result +""" + result = node.exec_py( + pycode=code, + value0=10, value1=0, value2=0, value3=0, value4=0 + ) + assert result == (float('inf'), None, None, None, None) + + +def test_eval_python_none_values(eval_context): + """Test handling None values in inputs""" + node = EvalPython_5_5() + code = """ +return value0, value1 is None, value2 is None +""" + result = node.exec_py( + pycode=code, + value0=42, value1=None, value2=None, value3=0, value4=0 + ) + assert result == (42, True, True, None, None) + + +def test_eval_python_input_types(): + """Test that INPUT_TYPES returns correct structure""" + input_types = EvalPython_5_5.INPUT_TYPES() + assert "required" in input_types + assert "optional" in input_types + assert "pycode" in input_types["required"] + assert input_types["required"]["pycode"][0] == "PYCODE" + + # Check optional inputs + for i in range(5): + assert f"value{i}" in input_types["optional"] + + +def test_eval_python_metadata(): + """Test node metadata""" + assert EvalPython_5_5.FUNCTION == "exec_py" + assert EvalPython_5_5.CATEGORY == "eval" + assert len(EvalPython_5_5.RETURN_TYPES) == 5 + assert len(EvalPython_5_5.RETURN_NAMES) == 5 + assert all(name.startswith("item") for name in EvalPython_5_5.RETURN_NAMES) + + +def test_eval_python_factory_custom_inputs_outputs(eval_context): + """Test creating nodes with custom input/output counts""" + # Create a node with 3 inputs and 2 outputs + CustomNode = eval_python(inputs=3, outputs=2) + + node = CustomNode() + + # Verify INPUT_TYPES has correct number of inputs + input_types = CustomNode.INPUT_TYPES() + assert len(input_types["optional"]) == 3 + assert "value0" in input_types["optional"] + assert "value1" in input_types["optional"] + assert "value2" in input_types["optional"] + assert "value3" not in input_types["optional"] + + # Verify RETURN_TYPES has correct number of outputs + assert len(CustomNode.RETURN_TYPES) == 2 + assert len(CustomNode.RETURN_NAMES) == 2 + + # Test execution + result = node.exec_py( + pycode="return value0 + value1 + value2, value0 * 2", + value0=1, value1=2, value2=3 + ) + assert result == (6, 2) + + +def test_eval_python_factory_custom_name(eval_context): + """Test creating nodes with custom names""" + CustomNode = eval_python(inputs=2, outputs=2, name="MyCustomEval") + + assert CustomNode.__name__ == "MyCustomEval" + assert CustomNode.__qualname__ == "MyCustomEval" + + +def test_eval_python_factory_default_name(eval_context): + 
"""Test that default name follows pattern""" + CustomNode = eval_python(inputs=3, outputs=4) + + assert CustomNode.__name__ == "EvalPython_3_4" + assert CustomNode.__qualname__ == "EvalPython_3_4" + + +def test_eval_python_factory_single_output(eval_context): + """Test node with single output""" + SingleOutputNode = eval_python(inputs=2, outputs=1) + + node = SingleOutputNode() + result = node.exec_py( + pycode="return value0 + value1", + value0=10, value1=20 + ) + assert result == (30,) + + +def test_eval_python_factory_many_outputs(eval_context): + """Test node with many outputs""" + ManyOutputNode = eval_python(inputs=1, outputs=10) + + node = ManyOutputNode() + result = node.exec_py( + pycode="return tuple(range(10))", + value0=0 + ) + assert result == (0, 1, 2, 3, 4, 5, 6, 7, 8, 9) + + +def test_eval_python_factory_fewer_returns_than_outputs(eval_context): + """Test that fewer returns are padded with None""" + Node = eval_python(inputs=2, outputs=5) + + node = Node() + result = node.exec_py( + pycode="return value0, value1", + value0=1, value1=2 + ) + assert result == (1, 2, None, None, None) + + +def test_eval_python_factory_more_returns_than_outputs(eval_context): + """Test that excess returns are truncated""" + Node = eval_python(inputs=2, outputs=3) + + node = Node() + result = node.exec_py( + pycode="return 1, 2, 3, 4, 5", + value0=0, value1=0 + ) + assert result == (1, 2, 3) + + +def test_eval_python_list_1_input_is_list(eval_context): + """Test EvalPython_List_1 with list input""" + node = EvalPython_List_1() + + # Verify INPUT_IS_LIST is set + assert EvalPython_List_1.INPUT_IS_LIST is True + assert EvalPython_List_1.OUTPUT_IS_LIST is None + + # Test that value0 receives a list + result = node.exec_py( + pycode="return sum(value0)", + value0=[1, 2, 3, 4, 5] + ) + assert result == (15,) + + +def test_eval_python_list_1_iterate_list(eval_context): + """Test EvalPython_List_1 iterating over list input""" + node = EvalPython_List_1() + + result = node.exec_py( + pycode="return [x * 2 for x in value0]", + value0=[1, 2, 3] + ) + assert result == ([2, 4, 6],) + + +def test_eval_python_1_list_output_is_list(eval_context): + """Test EvalPython_1_List with list output""" + node = EvalPython_1_List() + + # Verify OUTPUT_IS_LIST is set + assert EvalPython_1_List.INPUT_IS_LIST is False + assert EvalPython_1_List.OUTPUT_IS_LIST == (True,) + + # Test that returns a list + result = node.exec_py( + pycode="return list(range(value0))", + value0=5 + ) + assert result == ([0, 1, 2, 3, 4],) + + +def test_eval_python_1_list_multiple_items(eval_context): + """Test EvalPython_1_List returning multiple items in list""" + node = EvalPython_1_List() + + result = node.exec_py( + pycode="return ['a', 'b', 'c']", + value0=0 + ) + assert result == (['a', 'b', 'c'],) + + +def test_eval_python_list_list_both(eval_context): + """Test EvalPython_List_List with both list input and output""" + node = EvalPython_List_List() + + # Verify both are set + assert EvalPython_List_List.INPUT_IS_LIST is True + assert EvalPython_List_List.OUTPUT_IS_LIST == (True,) + + # Test processing list input and returning list output + result = node.exec_py( + pycode="return [x ** 2 for x in value0]", + value0=[1, 2, 3, 4] + ) + assert result == ([1, 4, 9, 16],) + + +def test_eval_python_list_list_filter(eval_context): + """Test EvalPython_List_List filtering a list""" + node = EvalPython_List_List() + + result = node.exec_py( + pycode="return [x for x in value0 if x > 5]", + value0=[1, 3, 5, 7, 9, 11] + ) + assert result == ([7, 9, 
11],) + + +def test_eval_python_list_list_transform(eval_context): + """Test EvalPython_List_List transforming list elements""" + node = EvalPython_List_List() + + result = node.exec_py( + pycode="return [str(x).upper() for x in value0]", + value0=['hello', 'world', 'python'] + ) + assert result == (['HELLO', 'WORLD', 'PYTHON'],) + + +def test_eval_python_factory_with_list_flags(eval_context): + """Test factory function with custom list flags""" + # Create node with input as list but output scalar + ListInputNode = eval_python(inputs=1, outputs=1, input_is_list=True, output_is_list=None) + + assert ListInputNode.INPUT_IS_LIST is True + assert ListInputNode.OUTPUT_IS_LIST is None + + node = ListInputNode() + result = node.exec_py( + pycode="return len(value0)", + value0=[1, 2, 3, 4, 5] + ) + assert result == (5,) + + +def test_eval_python_factory_scalar_output_list(eval_context): + """Test factory function with scalar input and list output""" + ScalarToListNode = eval_python(inputs=1, outputs=1, input_is_list=None, output_is_list=(True,)) + + assert ScalarToListNode.INPUT_IS_LIST is False + assert ScalarToListNode.OUTPUT_IS_LIST == (True,) + + node = ScalarToListNode() + result = node.exec_py( + pycode="return [value0] * 3", + value0='x' + ) + assert result == (['x', 'x', 'x'],) + + +def test_eval_python_list_empty_list(eval_context): + """Test list nodes with empty lists""" + node = EvalPython_List_List() + + result = node.exec_py( + pycode="return []", + value0=[] + ) + assert result == ([],) + + +def test_eval_python_backward_compatibility(): + """Test that EvalPython is an alias for EvalPython_5_5""" + assert EvalPython is EvalPython_5_5 From cc5f16caeb511c4332f623ed66a1a005cde044a6 Mon Sep 17 00:00:00 2001 From: Benjamin Berman Date: Mon, 10 Nov 2025 10:06:14 -0800 Subject: [PATCH 7/9] tweak --- comfy_extras/eval_web/eval_python.js | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/comfy_extras/eval_web/eval_python.js b/comfy_extras/eval_web/eval_python.js index d344d9e50..d8021a7b2 100644 --- a/comfy_extras/eval_web/eval_python.js +++ b/comfy_extras/eval_web/eval_python.js @@ -158,7 +158,8 @@ app.registerExtension({ }, async beforeRegisterNodeDef(nodeType, nodeData) { - if (nodeData.name === "EvalPython") { + // Handle all EvalPython node variants + if (nodeData.name.startsWith("EvalPython")) { const originalOnConfigure = nodeType.prototype.onConfigure; nodeType.prototype.onConfigure = function (info) { originalOnConfigure?.apply(this, arguments); From 37048fc1a2b816e48bdd3793d803f9fa22fa28b4 Mon Sep 17 00:00:00 2001 From: doctorpangloss <2229300+doctorpangloss@users.noreply.github.com> Date: Mon, 10 Nov 2025 11:23:34 -0800 Subject: [PATCH 8/9] fix issues with zooming in editor, simplify, improve list inputs and outputs --- comfy_extras/eval_web/ace_utils.js | 769 --------------------------- comfy_extras/eval_web/eval_python.js | 113 +++- comfy_extras/nodes/nodes_eval.py | 32 +- tests/unit/test_eval_nodes.py | 14 +- 4 files changed, 105 insertions(+), 823 deletions(-) delete mode 100644 comfy_extras/eval_web/ace_utils.js diff --git a/comfy_extras/eval_web/ace_utils.js b/comfy_extras/eval_web/ace_utils.js deleted file mode 100644 index 78c00d809..000000000 --- a/comfy_extras/eval_web/ace_utils.js +++ /dev/null @@ -1,769 +0,0 @@ -/** - * Uses code adapted from https://github.com/yorkane/ComfyUI-KYNode - * - * MIT License - * - * Copyright (c) 2024 Kevin Yuan - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and 
associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ - -// Make modal window -function makeModal({ title = "Message", text = "No text", type = "info", parent = null, stylePos = "fixed", classes = [] } = {}) { - const overlay = document.createElement("div"); - Object.assign(overlay.style, { - display: "none", - position: stylePos, - background: "rgba(0 0 0 / 0.8)", - opacity: 0, - top: "0", - left: "0", - right: "0", - bottom: "0", - zIndex: "500", - transition: "all .8s", - cursor: "pointer", - }); - - const boxModal = document.createElement("div"); - Object.assign(boxModal.style, { - transition: "all 0.5s", - opacity: 0, - display: "none", - position: stylePos, - top: "50%", - left: "50%", - transform: "translate(-50%,-50%)", - background: "#525252", - minWidth: "300px", - fontFamily: "sans-serif", - zIndex: "501", - border: "1px solid rgb(255 255 255 / 45%)", - }); - - boxModal.className = "alekpet_modal_window"; - boxModal.classList.add(...classes); - - const boxModalBody = document.createElement("div"); - Object.assign(boxModalBody.style, { - display: "flex", - flexDirection: "column", - textAlign: "center", - }); - - boxModalBody.className = "alekpet_modal_body"; - - const boxModalHtml = ` -
-      <div class="alekpet_modal_header">
-        <div class="alekpet_modal_title">${title}</div>
-        <div class="alekpet_modal_close">✖</div>
-      </div>
-      <div class="alekpet_modal_description">
-        ${text}</div>
`; - boxModalBody.innerHTML = boxModalHtml; - - const alekpet_modal_header = boxModalBody.querySelector(".alekpet_modal_header"); - Object.assign(alekpet_modal_header.style, { - display: "flex", - alignItems: "center", - }); - - const close = boxModalBody.querySelector(".alekpet_modal_close"); - Object.assign(close.style, { - cursor: "pointer", - }); - - let parentElement = document.body; - if (parent && parent.nodeType === 1) { - parentElement = parent; - } - - boxModal.append(boxModalBody); - parentElement.append(overlay, boxModal); - - const removeEvent = new Event("removeElements"); - const remove = () => { - animateTransitionProps(boxModal, { opacity: 0 }).then(() => - animateTransitionProps(overlay, { opacity: 0 }).then(() => { - parentElement.removeChild(boxModal); - parentElement.removeChild(overlay); - }), - ); - }; - - boxModal.addEventListener("removeElements", remove); - overlay.addEventListener("removeElements", remove); - - animateTransitionProps(overlay) - .then(() => { - overlay.addEventListener("click", () => { - overlay.dispatchEvent(removeEvent); - }); - animateTransitionProps(boxModal); - }) - .then(() => boxModal.querySelector(".alekpet_modal_close").addEventListener("click", () => boxModal.dispatchEvent(removeEvent))); -} - -function findWidget(node, value, attr = "name", func = "find") { - return node?.widgets ? node.widgets[func]((w) => (Array.isArray(value) ? value.includes(w[attr]) : w[attr] === value)) : null; -} - -function animateTransitionProps(el, props = { opacity: 1 }, preStyles = { display: "block" }) { - Object.assign(el.style, preStyles); - - el.style.transition = !el.style.transition || !window.getComputedStyle(el).getPropertyValue("transition") ? "all .8s" : el.style.transition; - - return new Promise((res) => { - setTimeout(() => { - Object.assign(el.style, props); - - const transstart = () => (el.isAnimating = true); - const transchancel = () => (el.isAnimating = false); - el.addEventListener("transitionstart", transstart); - el.addEventListener("transitioncancel", transchancel); - - el.addEventListener("transitionend", function transend() { - el.isAnimating = false; - el.removeEventListener("transitionend", transend); - el.removeEventListener("transitionend", transchancel); - el.removeEventListener("transitionend", transstart); - res(el); - }); - }, 100); - }); -} - -function animateClick(target, params = {}) { - const { opacityVal = 0.9, callback = () => {} } = params; - if (target?.isAnimating) return; - - const hide = +target.style.opacity === 0; - return animateTransitionProps(target, { - opacity: hide ? opacityVal : 0, - }).then((el) => { - const isHide = hide || el.style.display === "none"; - showHide({ elements: [target], hide: !hide }); - callback(); - return isHide; - }); -} - -function showHide({ elements = [], hide = null, displayProp = "block" } = {}) { - Array.from(elements).forEach((el) => { - if (hide !== null) { - el.style.display = !hide ? displayProp : "none"; - } else { - el.style.display = !el.style.display || el.style.display === "none" ? 
displayProp : "none"; - } - }); -} - -function isEmptyObject(obj) { - if (!obj) return true; - return Object.keys(obj).length === 0 && obj.constructor === Object; -} - -function makeElement(tag, attrs = {}) { - if (!tag) tag = "div"; - const element = document.createElement(tag); - Object.keys(attrs).forEach((key) => { - const currValue = attrs[key]; - if (key === "class") { - if (Array.isArray(currValue)) { - element.classList.add(...currValue); - } else if (currValue instanceof String || typeof currValue === "string") { - element.className = currValue; - } - } else if (key === "dataset") { - try { - if (Array.isArray(currValue)) { - currValue.forEach((datasetArr) => { - const [prop, propval] = Object.entries(datasetArr)[0]; - element.dataset[prop] = propval; - }); - } else { - Object.entries(currValue).forEach((datasetArr) => { - const [prop, propval] = datasetArr; - element.dataset[prop] = propval; - }); - } - } catch (err) { - console.log(err); - } - } else if (key === "style") { - if (typeof currValue === "object" && !Array.isArray(currValue) && Object.keys(currValue).length) { - Object.assign(element[key], currValue); - } else if (typeof currValue === "object" && Array.isArray(currValue) && currValue.length) { - element[key] = [...currValue]; - } else if (currValue instanceof String || typeof currValue === "string") { - element[key] = currValue; - } - } else if (["for"].includes(key)) { - element.setAttribute(key, currValue); - } else if (key === "children") { - element.append(...(currValue instanceof Array ? currValue : [currValue])); - } else if (key === "parent") { - currValue.append(element); - } else { - element[key] = currValue; - } - }); - return element; -} - -function isValidStyle(opt, strColor) { - let op = new Option().style; - if (!op.hasOwnProperty(opt)) return { result: false, color: "", color_hex: "" }; - - op[opt] = strColor; - - return { - result: op[opt] !== "", - color_rgb: op[opt], - color_hex: rgbToHex(op[opt]), - }; -} - -function rgbToHex(rgb) { - const regEx = new RegExp(/\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)/); - if (regEx.test(rgb)) { - let [, r, g, b] = regEx.exec(rgb); - r = parseInt(r).toString(16); - g = parseInt(g).toString(16); - b = parseInt(b).toString(16); - - r = r.length === 1 ? r + "0" : r; - g = g.length === 1 ? g + "0" : g; - b = b.length === 1 ? 
b + "0" : b; - - return `#${r}${g}${b}`; - } -} - -async function getDataJSON(url) { - try { - const response = await fetch(url); - const jsonData = await response.json(); - return jsonData; - } catch (err) { - return new Error(err); - } -} - -function deepMerge(target, source) { - if (source?.nodeType) return; - for (let key in source) { - if (source[key] instanceof Object && key in target) { - Object.assign(source[key], deepMerge(target[key], source[key])); - } - } - - Object.assign(target || {}, source); - return target; -} - -const THEME_MODAL_WINDOW_BASE = { - stylesTitle: { - background: "auto", - padding: "5px", - borderRadius: "6px", - marginBottom: "5px", - alignSelf: "stretch", - }, - stylesWrapper: { - display: "none", - opacity: 0, - minWidth: "220px", - position: "absolute", - left: "50%", - top: "50%", - transform: "translate(-50%, -50%)", - transition: "all .8s", - fontFamily: "monospace", - zIndex: 99999, - }, - stylesBox: { - display: "flex", - flexDirection: "column", - background: "#0e0e0e", - padding: "6px", - justifyContent: "center", - alignItems: "center", - gap: "3px", - textAlign: "center", - borderRadius: "6px", - color: "white", - border: "2px solid silver", - boxShadow: "2px 2px 4px silver", - maxWidth: "300px", - }, - stylesClose: { - position: "absolute", - top: "-10px", - right: "-10px", - background: "silver", - borderRadius: "50%", - width: "20px", - height: "20px", - cursor: "pointer", - display: "flex", - justifyContent: "center", - alignItems: "center", - fontSize: "0.8rem", - }, -}; - -const THEMES_MODAL_WINDOW = { - error: { - stylesTitle: { - ...THEME_MODAL_WINDOW_BASE.stylesTitle, - background: "#8f210f", - }, - stylesBox: { - ...THEME_MODAL_WINDOW_BASE.stylesBox, - background: "#3b2222", - boxShadow: "3px 3px 6px #141414", - border: "1px solid #f91b1b", - }, - stylesWrapper: { ...THEME_MODAL_WINDOW_BASE.stylesWrapper }, - stylesClose: { - ...THEME_MODAL_WINDOW_BASE.stylesClose, - background: "#3b2222", - }, - }, - warning: { - stylesTitle: { - ...THEME_MODAL_WINDOW_BASE.stylesTitle, - background: "#e99818", - }, - stylesBox: { - ...THEME_MODAL_WINDOW_BASE.stylesBox, - background: "#594e32", - boxShadow: "3px 3px 6px #141414", - border: "1px solid #e99818", - }, - stylesWrapper: { ...THEME_MODAL_WINDOW_BASE.stylesWrapper }, - stylesClose: { - ...THEME_MODAL_WINDOW_BASE.stylesClose, - background: "#594e32", - }, - }, - normal: { - stylesTitle: { - ...THEME_MODAL_WINDOW_BASE.stylesTitle, - background: "#108f0f", - }, - stylesBox: { - ...THEME_MODAL_WINDOW_BASE.stylesBox, - background: "#223b2a", - boxShadow: "3px 3px 6px #141414", - border: "1px solid #108f0f", - }, - stylesWrapper: { ...THEME_MODAL_WINDOW_BASE.stylesWrapper }, - stylesClose: { - ...THEME_MODAL_WINDOW_BASE.stylesClose, - background: "#223b2a", - }, - }, -}; - -const defaultOptions = { - auto: { - autohide: false, - autoshow: false, - autoremove: false, - propStyles: { opacity: 0 }, - propPreStyles: {}, - timewait: 2000, - }, - overlay: { - overlay_enabled: false, - overlayClasses: [], - overlayStyles: {}, - }, - close: { closeRemove: false, showClose: true }, - parent: null, -}; - -function createWindowModal({ textTitle = "Message", textBody = "Hello world!", textFooter = null, classesWrapper = [], stylesWrapper = {}, classesBox = [], stylesBox = {}, classesTitle = [], stylesTitle = {}, classesBody = [], stylesBody = {}, classesClose = [], stylesClose = {}, classesFooter = [], stylesFooter = {}, options = defaultOptions } = {}) { - // Check all options exist - const _options = 
deepMerge(JSON.parse(JSON.stringify(defaultOptions)), options); - - const { - parent, - overlay: { overlay_enabled, overlayClasses, overlayStyles }, - close: { closeRemove, showClose }, - auto: { autohide, autoshow, autoremove, timewait, propStyles, propPreStyles }, - } = _options; - - // Function past text(html) - function addText(text, parent) { - if (!parent) return; - - switch (typeof text) { - case "string": - if (/^\<.*\/?\>$/.test(text)) { - parent.innerHTML = text; - } else { - parent.textContent = text; - } - break; - case "object": - default: - if (Array.isArray(text)) { - text.forEach((element) => (element.nodeType === 1 || element.nodeType === 3) && parent.append(element)); - } else if (text.nodeType === 1 || text.nodeType === 3) parent.append(text); - } - } - - // Overlay - let overlayElement = null; - if (overlay_enabled) { - overlayElement = makeElement("div", { - class: [...overlayClasses], - style: { - display: "none", - position: "fixed", - background: "rgba(0 0 0 / 0.8)", - opacity: 0, - top: 0, - left: 0, - right: 0, - bottom: 0, - zIndex: 99999, - transition: "all .8s", - cursor: "pointer", - ...overlayStyles, - }, - }); - } - - // Wrapper - const wrapper_settings = makeElement("div", { - class: ["alekpet__wrapper__window", ...classesWrapper], - }); - - Object.assign(wrapper_settings.style, { - ...THEME_MODAL_WINDOW_BASE.stylesWrapper, - ...stylesWrapper, - }); - - // Box - const box__settings = makeElement("div", { - class: ["alekpet__window__box", ...classesBox], - }); - Object.assign(box__settings.style, { - ...THEME_MODAL_WINDOW_BASE.stylesBox, - ...stylesBox, - }); - - // Title - let box_settings_title = ""; - if (textTitle) { - box_settings_title = makeElement("div", { - class: ["alekpet__window__title", ...classesTitle], - }); - - Object.assign(box_settings_title.style, { - ...THEME_MODAL_WINDOW_BASE.stylesTitle, - ...stylesTitle, - }); - - // Add text (html) to title - addText(textTitle, box_settings_title); - } - // Body - let box_settings_body = ""; - if (textBody) { - box_settings_body = makeElement("div", { - class: ["alekpet__window__body", ...classesBody], - }); - - Object.assign(box_settings_body.style, { - display: "flex", - flexDirection: "column", - alignItems: "flex-end", - gap: "5px", - textWrap: "wrap", - ...stylesBody, - }); - - // Add text (html) to body - addText(textBody, box_settings_body); - } - - // Close button - const close__box__button = makeElement("div", { - class: ["close__box__button", ...classesClose], - textContent: "✖", - }); - - Object.assign(close__box__button.style, { - ...THEME_MODAL_WINDOW_BASE.stylesClose, - ...stylesClose, - }); - - if (!showClose) close__box__button.style.display = "none"; - - const closeEvent = new Event("closeModal"); - const closeModalWindow = function () { - overlay_enabled - ? 
animateTransitionProps(overlayElement, { - opacity: 0, - }) - .then(() => - animateTransitionProps(wrapper_settings, { - opacity: 0, - }), - ) - .then(() => { - if (closeRemove) { - parent.removeChild(wrapper_settings); - parent.removeChild(overlayElement); - } else { - showHide({ elements: [wrapper_settings, overlayElement] }); - } - }) - : animateTransitionProps(wrapper_settings, { - opacity: 0, - }).then(() => { - showHide({ elements: [wrapper_settings] }); - }); - }; - - close__box__button.addEventListener("closeModal", closeModalWindow); - - close__box__button.addEventListener("click", () => close__box__button.dispatchEvent(closeEvent)); - - close__box__button.onmouseenter = () => { - close__box__button.style.opacity = 0.8; - }; - - close__box__button.onmouseleave = () => { - close__box__button.style.opacity = 1; - }; - - box__settings.append(box_settings_title, box_settings_body); - - // Footer - if (textFooter) { - const box_settings_footer = makeElement("div", { - class: [...classesFooter], - }); - Object.assign(box_settings_footer.style, { - ...stylesFooter, - }); - - // Add text (html) to body - addText(textFooter, box_settings_footer); - - box__settings.append(box_settings_footer); - } - - wrapper_settings.append(close__box__button, box__settings); - - if (parent && parent.nodeType === 1) { - if (overlay_enabled) parent.append(overlayElement); - parent.append(wrapper_settings); - - if (autoshow) { - overlay_enabled - ? animateClick(overlayElement).then(() => - animateClick(wrapper_settings).then( - () => - autohide && - setTimeout( - () => - animateTransitionProps(wrapper_settings, { ...propStyles }, { ...propPreStyles }) - .then(() => animateTransitionProps(overlayElement, { ...propStyles }, { ...propPreStyles })) - .then(() => { - if (autoremove) { - parent.removeChild(wrapper_settings); - parent.removeChild(overlayElement); - } - }), - timewait, - ), - ), - ) - : animateClick(wrapper_settings).then(() => autohide && setTimeout(() => animateTransitionProps(wrapper_settings, { ...propStyles }, { ...propPreStyles }).then(() => autoremove && parent.removeChild(wrapper_settings)), timewait)); - } - } - - return wrapper_settings; -} - -// Prompt -async function comfyuiDesktopPrompt(title, message, defaultValue) { - try { - return await app.extensionManager.dialog.prompt({ - title, - message, - defaultValue, - }); - } catch (err) { - return prompt(title, message); - } -} - -// Alert -function comfyuiDesktopAlert(message) { - try { - app.extensionManager.toast.addAlert(message); - } catch (err) { - alert(message); - } -} - -// Confirm -function confirmModal({ title, message }) { - return new Promise((res) => { - const overlay = makeElement("div", { - class: ["alekpet_confOverlay"], - style: { - background: "rgba(0, 0, 0, 0.7)", - position: "fixed", - top: 0, - left: 0, - right: 0, - bottom: 0, - zIndex: 9999, - userSelect: "none", - }, - }); - - const modal = makeElement("div", { - class: ["alekpet_confModal"], - style: { - ...THEME_MODAL_WINDOW_BASE.stylesBox, - position: "fixed", - top: "50%", - left: "50%", - fontFamily: "monospace", - background: "rgb(92 186 255 / 20%)", - transform: "translate(-50%, -50%)", - borderColor: "rgba(92, 186, 255, 0.63)", - boxShadow: "rgba(92, 186, 255, 0.63) 2px 2px 4px", - }, - }); - - const titleEl = makeElement("div", { - class: ["alekpet_confTitle"], - style: { - ...THEME_MODAL_WINDOW_BASE.stylesTitle, - background: "rgba(92, 186, 255, 0.63)", - }, - textContent: title, - }); - - const messageEl = makeElement("div", { - class: 
["alekpet_confMessage"], - style: { - display: "flex", - flexDirection: "column", - alignItems: "flex-end", - gap: "5px", - textWrap: "wrap", - }, - textContent: message, - }); - - const action_box = makeElement("div", { - class: ["alekpet_confActions"], - style: { - display: "flex", - gap: "5px", - width: "100%", - padding: "4px", - justifyContent: "flex-end", - }, - }); - - const remove = () => { - modal.remove(); - overlay.remove(); - }; - - const ok = makeElement("div", { - class: ["alekpet_confButtons", "alekpet_confButtonOk"], - style: { - background: "linear-gradient(45deg, green, limegreen) rgb(21, 100, 6)", - }, - textContent: "Ok", - onclick: (e) => { - res(true); - remove(); - }, - }); - - const Cancel = makeElement("div", { - class: ["alekpet_confButtons", "alekpet_confButtonCancel"], - style: { - background: "linear-gradient(45deg, #b64396, #a52a8b) rgb(135 3 161)", - }, - textContent: "Cancel", - onclick: (e) => { - res(false); - remove(); - }, - }); - - action_box.append(ok, Cancel); - modal.append(titleEl, messageEl, action_box); - overlay.append(modal); - document.body.append(overlay); - }); -} - -async function comfyuiDesktopConfirm(message) { - try { - const result = await confirmModal({ - title: "Confirm", - message: message, - }); - - // Wait update comfyui frontend! Confirm Cancel not return value! Fixed in ComfyUI_frontend ver. v1.10.8 - // https://github.com/Comfy-Org/ComfyUI_frontend/issues/2649 - // const result = await app.extensionManager.dialog.confirm({ - // title: "Confirm", - // message: message, - // }); - return result; - } catch (err) { - return confirm(message); - } -} - -export { - makeModal, - createWindowModal, - animateTransitionProps, - animateClick, - showHide, - makeElement, - getDataJSON, - isEmptyObject, - isValidStyle, - rgbToHex, - findWidget, - THEMES_MODAL_WINDOW, - // - comfyuiDesktopConfirm, - comfyuiDesktopPrompt, - comfyuiDesktopAlert, -}; diff --git a/comfy_extras/eval_web/eval_python.js b/comfy_extras/eval_web/eval_python.js index d8021a7b2..a7e2fa19f 100644 --- a/comfy_extras/eval_web/eval_python.js +++ b/comfy_extras/eval_web/eval_python.js @@ -24,7 +24,6 @@ * SOFTWARE. */ import { app } from "../../scripts/app.js"; -import { makeElement, findWidget } from "./ace_utils.js"; // Load Ace editor using script tag for Safari compatibility // The noconflict build includes AMD loader that works in all browsers @@ -34,7 +33,7 @@ const aceLoadPromise = new Promise((resolve) => { ace = window.ace; resolve(); } else { - const script = document.createElement('script'); + const script = document.createElement("script"); script.src = "https://cdn.jsdelivr.net/npm/ace-builds@1.43.4/src-noconflict/ace.js"; script.onload = () => { ace = window.ace; @@ -45,44 +44,108 @@ const aceLoadPromise = new Promise((resolve) => { } }); +// todo: do we really want to do this here? await aceLoadPromise; +const findWidget = (node, value, attr = "name", func = "find") => { + return node?.widgets ? node.widgets[func]((w) => (Array.isArray(value) ? 
value.includes(w[attr]) : w[attr] === value)) : null; +}; +const makeElement = (tag, attrs = {}) => { + if (!tag) tag = "div"; + const element = document.createElement(tag); + Object.keys(attrs).forEach((key) => { + const currValue = attrs[key]; + if (key === "class") { + if (Array.isArray(currValue)) { + element.classList.add(...currValue); + } else if (currValue instanceof String || typeof currValue === "string") { + element.className = currValue; + } + } else if (key === "dataset") { + try { + if (Array.isArray(currValue)) { + currValue.forEach((datasetArr) => { + const [prop, propval] = Object.entries(datasetArr)[0]; + element.dataset[prop] = propval; + }); + } else { + Object.entries(currValue).forEach((datasetArr) => { + const [prop, propval] = datasetArr; + element.dataset[prop] = propval; + }); + } + } catch (err) { + // todo: what is this trying to do? + } + } else if (key === "style") { + if (typeof currValue === "object" && !Array.isArray(currValue) && Object.keys(currValue).length) { + Object.assign(element[key], currValue); + } else if (typeof currValue === "object" && Array.isArray(currValue) && currValue.length) { + element[key] = [...currValue]; + } else if (currValue instanceof String || typeof currValue === "string") { + element[key] = currValue; + } + } else if (["for"].includes(key)) { + element.setAttribute(key, currValue); + } else if (key === "children") { + element.append(...(currValue instanceof Array ? currValue : [currValue])); + } else if (key === "parent") { + currValue.append(element); + } else { + element[key] = currValue; + } + }); + return element; +}; -function getPosition(node, ctx, w_width, y, n_height) { +const getPosition = (node, ctx, w_width, y, n_height) => { const margin = 5; const rect = ctx.canvas.getBoundingClientRect(); - const transform = new DOMMatrix() - .scaleSelf(rect.width / ctx.canvas.width, rect.height / ctx.canvas.height) - .multiplySelf(ctx.getTransform()) - .translateSelf(margin, margin + y); - const scale = new DOMMatrix().scaleSelf(transform.a, transform.d); + const transform = ctx.getTransform(); + const scale = app.canvas.ds.scale; + + // The context is already transformed to draw at the widget position + // transform.e and transform.f give us the canvas coordinates (in canvas pixels) + // We need to convert these to screen pixels by accounting for the canvas scale + // rect gives us the canvas element's position on the page + + // The transform matrix has scale baked in (transform.a = transform.d = scale) + // transform.e and transform.f are the translation in canvas-pixel space + const canvasPixelToScreenPixel = rect.width / ctx.canvas.width; + + const x = transform.e * canvasPixelToScreenPixel + rect.left; + const y_pos = transform.f * canvasPixelToScreenPixel + rect.top; + + // Convert widget dimensions from canvas coordinates to screen pixels + const scaledWidth = w_width * scale; + const scaledHeight = (n_height - y - 15) * scale; + const scaledMargin = margin * scale; + const scaledY = y * scale; return { - transformOrigin: "0 0", - transform: scale, - left: `${transform.a + transform.e + rect.left}px`, - top: `${transform.d + transform.f + rect.top}px`, - maxWidth: `${w_width - margin * 2}px`, - maxHeight: `${n_height - margin * 2 - y - 15}px`, - width: `${w_width - margin * 2}px`, - height: "90%", + left: `${x + scaledMargin}px`, + top: `${y_pos + scaledY + scaledMargin}px`, + width: `${scaledWidth - scaledMargin * 2}px`, + maxWidth: `${scaledWidth - scaledMargin * 2}px`, + height: `${scaledHeight - scaledMargin * 2}px`, 
+ maxHeight: `${scaledHeight - scaledMargin * 2}px`, position: "absolute", scrollbarColor: "var(--descrip-text) var(--bg-color)", scrollbarWidth: "thin", zIndex: app.graph._nodes.indexOf(node), }; -} +}; // Create code editor widget -function codeEditor(node, inputName, inputData) { +const codeEditor = (node, inputName, inputData) => { const widget = { - type: "pycode", + type: "code_block_python", name: inputName, options: { hideOnZoom: true }, value: inputData[1]?.default || "", draw(ctx, node, widgetWidth, y) { - const hidden = node.flags?.collapsed || (!!this.options.hideOnZoom && app.canvas.ds.scale < 0.5) || this.type === "converted-widget" || this.type === "hidden"; + const hidden = node.flags?.collapsed || (!!this.options.hideOnZoom && app.canvas.ds.scale < 0.5) || this.type === "converted-widget" || this.type === "hidden" || this.type === "converted-widget"; this.codeElement.hidden = hidden; @@ -122,19 +185,19 @@ function codeEditor(node, inputName, inputData) { }; return widget; -} +}; // Trigger workflow change tracking -function markWorkflowChanged() { +const markWorkflowChanged = () => { app?.extensionManager?.workflow?.activeWorkflow?.changeTracker?.checkState(); -} +}; // Register extensions app.registerExtension({ name: "Comfy.EvalPython", getCustomWidgets(app) { return { - PYCODE: (node, inputName, inputData) => { + CODE_BLOCK_PYTHON: (node, inputName, inputData) => { const widget = codeEditor(node, inputName, inputData); widget.editor.getSession().on("change", () => { @@ -165,7 +228,7 @@ app.registerExtension({ originalOnConfigure?.apply(this, arguments); if (info?.widgets_values?.length) { - const widgetCodeIndex = findWidget(this, "pycode", "type", "findIndex"); + const widgetCodeIndex = findWidget(this, "code_block_python", "type", "findIndex"); const editor = this.widgets[widgetCodeIndex]?.editor; if (editor) { diff --git a/comfy_extras/nodes/nodes_eval.py b/comfy_extras/nodes/nodes_eval.py index ff04522eb..bda41a02d 100644 --- a/comfy_extras/nodes/nodes_eval.py +++ b/comfy_extras/nodes/nodes_eval.py @@ -36,7 +36,7 @@ return {", ".join([f"value{i}" for i in range(inputs)])} return { "required": { "pycode": ( - "PYCODE", + "CODE_BLOCK_PYTHON", { "default": default_code }, @@ -47,16 +47,19 @@ return {", ".join([f"value{i}" for i in range(inputs)])} RETURN_TYPES = tuple(IO.ANY for _ in range(outputs)) RETURN_NAMES = tuple(f"item{i}" for i in range(outputs)) - OUTPUT_IS_LIST = output_is_list - INPUT_IS_LIST = input_is_list is not None FUNCTION = "exec_py" DESCRIPTION = "" CATEGORY = "eval" + @classmethod + def VALIDATE_INPUTS(cls, *args, **kwargs): + ctx = current_execution_context() + + return ctx.configuration.enable_eval + def exec_py(self, pycode, **kwargs): ctx = current_execution_context() - # Ensure all value inputs have a default of None kwargs = { **{f"value{i}": None for i in range(inputs)}, **kwargs, @@ -68,11 +71,9 @@ return {", ".join([f"value{i}" for i in range(inputs)])} if not ctx.configuration.enable_eval: raise ValueError("Python eval is disabled") - # Extract value arguments in order value_args = [kwargs.pop(f"value{i}") for i in range(inputs)] arg_names = ", ".join(f"value{i}=None" for i in range(inputs)) - # Wrap pycode in a function to support return statements wrapped_code = f"def _eval_func({arg_names}):\n" for line in pycode.splitlines(): wrapped_code += " " + line + "\n" @@ -83,13 +84,8 @@ return {", ".join([f"value{i}" for i in range(inputs)])} "print": print, } - # Execute wrapped function definition exec(wrapped_code, globals_for_eval) - - # 
Call the function with value arguments results = globals_for_eval["_eval_func"](*value_args) - - # Normalize results to match output count if not isinstance(results, tuple): results = (results,) @@ -100,22 +96,24 @@ return {", ".join([f"value{i}" for i in range(inputs)])} return results - # Set the class name for better debugging/introspection + # todo: interact better with the weird comfyui machinery for this + if input_is_list is not None: + setattr(EvalPythonNode, "INPUT_IS_LIST", input_is_list) + + if output_is_list is not None: + setattr(EvalPythonNode, "OUTPUT_IS_LIST", output_is_list) + EvalPythonNode.__name__ = name EvalPythonNode.__qualname__ = name return EvalPythonNode -# Create the default EvalPython node with 5 inputs and 5 outputs +EvalPython_1_1 = eval_python(inputs=1, outputs=1, name="EvalPython_1_1") EvalPython_5_5 = eval_python(inputs=5, outputs=5, name="EvalPython_5_5") -EvalPython = EvalPython_5_5 # Backward compatibility alias - -# Create list variants EvalPython_List_1 = eval_python(inputs=1, outputs=1, name="EvalPython_List_1", input_is_list=True, output_is_list=None) EvalPython_1_List = eval_python(inputs=1, outputs=1, name="EvalPython_1_List", input_is_list=None, output_is_list=(True,)) EvalPython_List_List = eval_python(inputs=1, outputs=1, name="EvalPython_List_List", input_is_list=True, output_is_list=(True,)) - export_custom_nodes() export_package_as_web_directory("comfy_extras.eval_web") diff --git a/tests/unit/test_eval_nodes.py b/tests/unit/test_eval_nodes.py index f2cb0c763..71076daef 100644 --- a/tests/unit/test_eval_nodes.py +++ b/tests/unit/test_eval_nodes.py @@ -4,9 +4,8 @@ from unittest.mock import Mock, patch from comfy.cli_args import default_configuration from comfy.execution_context import context_configuration from comfy_extras.nodes.nodes_eval import ( - EvalPython, - EvalPython_5_5, eval_python, + EvalPython_5_5, EvalPython_List_1, EvalPython_1_List, EvalPython_List_List, @@ -447,7 +446,7 @@ def test_eval_python_input_types(): assert "required" in input_types assert "optional" in input_types assert "pycode" in input_types["required"] - assert input_types["required"]["pycode"][0] == "PYCODE" + assert input_types["required"]["pycode"][0] == "CODE_BLOCK_PYTHON" # Check optional inputs for i in range(5): @@ -560,7 +559,6 @@ def test_eval_python_list_1_input_is_list(eval_context): # Verify INPUT_IS_LIST is set assert EvalPython_List_1.INPUT_IS_LIST is True - assert EvalPython_List_1.OUTPUT_IS_LIST is None # Test that value0 receives a list result = node.exec_py( @@ -586,7 +584,6 @@ def test_eval_python_1_list_output_is_list(eval_context): node = EvalPython_1_List() # Verify OUTPUT_IS_LIST is set - assert EvalPython_1_List.INPUT_IS_LIST is False assert EvalPython_1_List.OUTPUT_IS_LIST == (True,) # Test that returns a list @@ -652,7 +649,6 @@ def test_eval_python_factory_with_list_flags(eval_context): ListInputNode = eval_python(inputs=1, outputs=1, input_is_list=True, output_is_list=None) assert ListInputNode.INPUT_IS_LIST is True - assert ListInputNode.OUTPUT_IS_LIST is None node = ListInputNode() result = node.exec_py( @@ -666,7 +662,6 @@ def test_eval_python_factory_scalar_output_list(eval_context): """Test factory function with scalar input and list output""" ScalarToListNode = eval_python(inputs=1, outputs=1, input_is_list=None, output_is_list=(True,)) - assert ScalarToListNode.INPUT_IS_LIST is False assert ScalarToListNode.OUTPUT_IS_LIST == (True,) node = ScalarToListNode() @@ -686,8 +681,3 @@ def test_eval_python_list_empty_list(eval_context): 
value0=[] ) assert result == ([],) - - -def test_eval_python_backward_compatibility(): - """Test that EvalPython is an alias for EvalPython_5_5""" - assert EvalPython is EvalPython_5_5 From 83fddf4cb5d7565166ad29e107cfc2496f0e09a3 Mon Sep 17 00:00:00 2001 From: doctorpangloss <2229300+doctorpangloss@users.noreply.github.com> Date: Mon, 10 Nov 2025 11:51:49 -0800 Subject: [PATCH 9/9] fix tagging --- .github/workflows/docker-build-amd.yml | 4 +++- .github/workflows/docker-build.yml | 8 ++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/.github/workflows/docker-build-amd.yml b/.github/workflows/docker-build-amd.yml index d1a61cf0d..4c4e9c04f 100644 --- a/.github/workflows/docker-build-amd.yml +++ b/.github/workflows/docker-build-amd.yml @@ -34,7 +34,9 @@ jobs: tags: | type=raw,value=latest-rocm,enable={{is_default_branch}} type=sha,prefix=,suffix=-rocm - type=semver,pattern={{version}},suffix=-rocm + type=match,pattern=v?(\d+\.\d+\.\d+\.\d+),group=1,suffix=-rocm + type=match,pattern=v?(\d+\.\d+\.\d+),group=1,suffix=-rocm + type=match,pattern=v?(\d+\.\d+),group=1,suffix=-rocm - name: Build and push ROCm (AMD) image uses: docker/build-push-action@v6 with: diff --git a/.github/workflows/docker-build.yml b/.github/workflows/docker-build.yml index af8e27a76..861ae3dcf 100644 --- a/.github/workflows/docker-build.yml +++ b/.github/workflows/docker-build.yml @@ -41,8 +41,12 @@ jobs: type=raw,value=latest-cuda,enable={{is_default_branch}} type=sha,prefix= type=sha,prefix=,suffix=-cuda - type=semver,pattern={{version}} - type=semver,pattern={{version}},suffix=-cuda + type=match,pattern=v?(\d+\.\d+\.\d+\.\d+),group=1 + type=match,pattern=v?(\d+\.\d+\.\d+\.\d+),group=1,suffix=-cuda + type=match,pattern=v?(\d+\.\d+\.\d+),group=1 + type=match,pattern=v?(\d+\.\d+\.\d+),group=1,suffix=-cuda + type=match,pattern=v?(\d+\.\d+),group=1 + type=match,pattern=v?(\d+\.\d+),group=1,suffix=-cuda - name: Build and push CUDA (NVIDIA) image uses: docker/build-push-action@v6 with:
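
Note on the tagging fix in PATCH 9/9: `type=semver` only matches three-part semantic versions, so a four-part release tag such as `v1.2.3.4` is not valid semver and would produce no image tag; the explicit `type=match` rules cover two-, three-, and four-part versions, with or without a leading `v`, and group 1 of each regex becomes the tag (plus the `-cuda`/`-rocm` suffix variants). A rough Python sketch of what each pattern extracts — the sample refs are hypothetical, and full-match semantics are assumed here for simplicity:

    import re

    # Regexes mirrored from the type=match rules above; group 1 becomes
    # the image tag, optionally suffixed with -cuda or -rocm.
    patterns = [
        r"v?(\d+\.\d+\.\d+\.\d+)",
        r"v?(\d+\.\d+\.\d+)",
        r"v?(\d+\.\d+)",
    ]

    for ref in ["v1.2.3.4", "v1.2.3", "1.2"]:  # hypothetical git tags
        for pattern in patterns:
            match = re.fullmatch(pattern, ref)
            if match:
                print(f"{ref} -> {match.group(1)}")  # e.g. v1.2.3.4 -> 1.2.3.4
                break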
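Note on `exec_py` in nodes_eval.py: bare `return` statements work inside the `pycode` snippet because the node indents the snippet into a generated `_eval_func` definition and `exec`s that definition before calling it. A minimal standalone sketch of the same technique — the helper name and keyword-only signature here are illustrative, not the node's API:

    def run_snippet(pycode, **values):
        # Indenting the user code under a generated def makes bare
        # "return" legal; exec of a module-level "return" would raise
        # SyntaxError.
        arg_names = ", ".join(f"{name}=None" for name in values)
        wrapped = f"def _snippet({arg_names}):\n"
        for line in pycode.splitlines():
            wrapped += "    " + line + "\n"
        scope = {"__builtins__": {}, "print": print}  # minimal globals for this sketch
        exec(wrapped, scope)  # defines _snippet in scope
        return scope["_snippet"](**values)

    print(run_snippet("return value0 + value1", value0=1, value1=2))  # 3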
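Note on the `getPosition` rewrite in eval_python.js: rather than composing a `DOMMatrix`, it reads the widget translation directly off the 2D context transform (`transform.e`, `transform.f`, in canvas pixels), rescales into CSS pixels with `rect.width / ctx.canvas.width`, and sizes the overlay by the graph zoom `scale` — which is what fixes the editor drifting while zooming. A worked example of the conversion, with all numbers hypothetical:

    # A 2000px-wide backing canvas rendered into a 1000px-wide element
    # (devicePixelRatio 2); widget translation (400, 600) in canvas
    # pixels; canvas element at page offset (10, 20).
    canvas_px_to_screen_px = 1000 / 2000        # rect.width / ctx.canvas.width
    left = 400 * canvas_px_to_screen_px + 10    # transform.e -> 210.0
    top = 600 * canvas_px_to_screen_px + 20     # transform.f -> 320.0
    print(left, top)  # 210.0 320.0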