diff --git a/.github/workflows/docker-build-amd.yml b/.github/workflows/docker-build-amd.yml index d1a61cf0d..4c4e9c04f 100644 --- a/.github/workflows/docker-build-amd.yml +++ b/.github/workflows/docker-build-amd.yml @@ -34,7 +34,9 @@ jobs: tags: | type=raw,value=latest-rocm,enable={{is_default_branch}} type=sha,prefix=,suffix=-rocm - type=semver,pattern={{version}},suffix=-rocm + type=match,pattern=v?(\d+\.\d+\.\d+\.\d+),group=1,suffix=-rocm + type=match,pattern=v?(\d+\.\d+\.\d+),group=1,suffix=-rocm + type=match,pattern=v?(\d+\.\d+),group=1,suffix=-rocm - name: Build and push ROCm (AMD) image uses: docker/build-push-action@v6 with: diff --git a/.github/workflows/docker-build.yml b/.github/workflows/docker-build.yml index af8e27a76..861ae3dcf 100644 --- a/.github/workflows/docker-build.yml +++ b/.github/workflows/docker-build.yml @@ -41,8 +41,12 @@ jobs: type=raw,value=latest-cuda,enable={{is_default_branch}} type=sha,prefix= type=sha,prefix=,suffix=-cuda - type=semver,pattern={{version}} - type=semver,pattern={{version}},suffix=-cuda + type=match,pattern=v?(\d+\.\d+\.\d+\.\d+),group=1 + type=match,pattern=v?(\d+\.\d+\.\d+\.\d+),group=1,suffix=-cuda + type=match,pattern=v?(\d+\.\d+\.\d+),group=1 + type=match,pattern=v?(\d+\.\d+\.\d+),group=1,suffix=-cuda + type=match,pattern=v?(\d+\.\d+),group=1 + type=match,pattern=v?(\d+\.\d+),group=1,suffix=-cuda - name: Build and push CUDA (NVIDIA) image uses: docker/build-push-action@v6 with: diff --git a/Dockerfile b/Dockerfile index bb0c04d1b..942cacb57 100644 --- a/Dockerfile +++ b/Dockerfile @@ -33,7 +33,7 @@ RUN pip install uv && uv --version && \ # install sageattention ADD pkg/sageattention-2.2.0-cp312-cp312-linux_x86_64.whl /workspace/pkg/sageattention-2.2.0-cp312-cp312-linux_x86_64.whl -RUN uv pip install -U --no-deps --no-build-isolation spandrel timm tensorboard poetry flash-attn "xformers==0.0.31.post1" "file:./pkg/sageattention-2.2.0-cp312-cp312-linux_x86_64.whl" +RUN uv pip install -U --no-deps --no-build-isolation spandrel timm tensorboard poetry "flash-attn<=2.8.0" "xformers==0.0.31.post1" "file:./pkg/sageattention-2.2.0-cp312-cp312-linux_x86_64.whl" # this exotic command will determine the correct torchaudio to install for the image RUN <<-EOF python -c 'import torch, re, subprocess @@ -66,7 +66,7 @@ WORKDIR /workspace # addresses https://github.com/pytorch/pytorch/issues/104801 # and issues reported by importing nodes_canny # smoke test -RUN python -c "import torch; import xformers; import sageattention; import cv2" && comfyui --quick-test-for-ci --cpu --cwd /workspace +RUN python -c "import torch; import xformers; import sageattention; import cv2; import diffusers.hooks" && comfyui --quick-test-for-ci --cpu --cwd /workspace EXPOSE 8188 CMD ["python", "-m", "comfy.cmd.main", "--listen", "--use-sage-attention", "--reserve-vram=0", "--logging-level=INFO", "--enable-cors"] diff --git a/comfy/api/openapi.yaml b/comfy/api/openapi.yaml index 1e983a6c4..260c75a38 100644 --- a/comfy/api/openapi.yaml +++ b/comfy/api/openapi.yaml @@ -354,19 +354,29 @@ paths: required: true description: | The ID of the prompt to query. 
- responses: - 204: - description: | - The prompt is still in progress - 200: - description: | - Prompt outputs - content: - application/json: - $ref: "#/components/schemas/Outputs" - 404: - description: | - The prompt was not found + responses: + 204: + description: | + The prompt is still in progress + 200: + description: | + Prompt outputs + content: + application/json: + schema: + $ref: "#/components/schemas/Outputs" + 404: + description: | + The prompt was not found + 500: + description: | + An execution error occurred while processing the prompt. + content: + application/json: + description: + An execution status directly from the workers + schema: + $ref: "#/components/schemas/ExecutionStatusAsDict" /api/v1/prompts: get: operationId: list_prompts @@ -861,6 +871,22 @@ components: type: array items: type: string + node_errors: + type: object + description: "Detailed validation errors per node" + additionalProperties: + type: object + properties: + errors: + type: array + items: + $ref: "#/components/schemas/ValidationErrorDict" + dependent_outputs: + type: array + items: + type: string + class_type: + type: string required: - type - details diff --git a/comfy/cli_args.py b/comfy/cli_args.py index fe964ba45..63157dd24 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -156,6 +156,7 @@ def _create_parser() -> EnhancedConfigArgParser: parser.add_argument("--whitelist-custom-nodes", type=str, action=FlattenAndAppendAction, nargs='+', default=[], help="Specify custom node folders to load even when --disable-all-custom-nodes is enabled.") parser.add_argument("--blacklist-custom-nodes", type=str, action=FlattenAndAppendAction, nargs='+', default=[], help="Specify custom node folders to never load. Accepts shell-style globs.") parser.add_argument("--disable-api-nodes", action="store_true", help="Disable loading all api nodes.") + parser.add_argument("--enable-eval", action="store_true", help="Enable nodes that can evaluate Python code in workflows.") parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.") parser.add_argument("--create-directories", action="store_true", diff --git a/comfy/cli_args_types.py b/comfy/cli_args_types.py index 953747f95..903f46f1a 100644 --- a/comfy/cli_args_types.py +++ b/comfy/cli_args_types.py @@ -169,6 +169,7 @@ class Configuration(dict): whitelist_custom_nodes (list[str]): Specify custom node folders to load even when --disable-all-custom-nodes is enabled. default_device (Optional[int]): Set the id of the default device, all other devices will stay visible. block_runtime_package_installation (Optional[bool]): When set, custom nodes like ComfyUI Manager, Easy Use, Nunchaku and others will not be able to use pip or uv to install packages at runtime (experimental). + enable_eval (Optional[bool]): Enable nodes that can evaluate Python code in workflows. """ def __init__(self, **kwargs): @@ -288,6 +289,7 @@ class Configuration(dict): self.database_url: str = db_config() self.default_device: Optional[int] = None self.block_runtime_package_installation = None + self.enable_eval: Optional[bool] = False for key, value in kwargs.items(): self[key] = value @@ -420,6 +422,7 @@ class FlattenAndAppendAction(argparse.Action): Custom action to handle comma-separated values and multiple invocations of the same argument, flattening them into a single list. 
""" + def __call__(self, parser, namespace, values, option_string=None): items = getattr(namespace, self.dest, None) if items is None: diff --git a/comfy/client/aio_client.py b/comfy/client/aio_client.py index cdf16e0f1..f79f44ac1 100644 --- a/comfy/client/aio_client.py +++ b/comfy/client/aio_client.py @@ -1,23 +1,23 @@ -from asyncio import AbstractEventLoop -from collections import defaultdict - -import aiohttp import asyncio import uuid -from aiohttp import WSMessage, ClientResponse, ClientTimeout -from pathlib import Path +from asyncio import AbstractEventLoop from typing import Optional, List from urllib.parse import urlparse, urljoin +import aiohttp +from aiohttp import WSMessage, ClientResponse, ClientTimeout +from opentelemetry import trace + from .client_types import V1QueuePromptResponse from ..api.api_client import JSONEncoder from ..api.components.schema.prompt import PromptDict from ..api.components.schema.prompt_request import PromptRequest from ..api.paths.history.get.responses.response_200.content.application_json.schema import Schema as GetHistoryDict from ..api.schemas import immutabledict -from ..component_model.file_output_path import file_output_path from ..component_model.outputs_types import OutputsDict +tracer = trace.get_tracer(__name__) + class AsyncRemoteComfyClient: """ @@ -57,6 +57,27 @@ class AsyncRemoteComfyClient: def session(self) -> aiohttp.ClientSession: return self._ensure_session() + def _build_headers(self, accept_header: str, prefer_header: Optional[str] = None, content_type: str = "application/json") -> dict: + """Build HTTP headers for requests.""" + headers = {'Content-Type': content_type, 'Accept': accept_header} + if prefer_header: + headers['Prefer'] = prefer_header + return headers + + @tracer.start_as_current_span("Post Prompt") + async def _post_prompt(self, prompt: PromptDict, endpoint: str, accept_header: str, prefer_header: Optional[str] = None) -> ClientResponse: + """ + Common method to POST a prompt to a given endpoint. 
+ :param prompt: The prompt to send + :param endpoint: The API endpoint (e.g., "/api/v1/prompts") + :param accept_header: The Accept header value + :param prefer_header: Optional Prefer header value + :return: The response object + """ + prompt_json = AsyncRemoteComfyClient.__json_encoder.encode(prompt) + headers = self._build_headers(accept_header, prefer_header) + return await self.session.post(urljoin(self.server_address, endpoint), data=prompt_json, headers=headers) + async def len_queue(self) -> int: async with self.session.get(urljoin(self.server_address, "/prompt"), headers={'Accept': 'application/json'}) as response: if response.status == 200: @@ -65,34 +86,30 @@ class AsyncRemoteComfyClient: else: raise RuntimeError(f"unexpected response: {response.status}: {await response.text()}") - async def queue_and_forget_prompt_api(self, prompt: PromptDict) -> str: + async def queue_and_forget_prompt_api(self, prompt: PromptDict, prefer_header: Optional[str] = "respond-async", accept_header: str = "application/json") -> str: """ Calls the API to queue a prompt, and forgets about it :param prompt: + :param prefer_header: The Prefer header value (e.g., "respond-async" or None) + :param accept_header: The Accept header value (e.g., "application/json", "application/json+respond-async") :return: the task ID """ - prompt_json = AsyncRemoteComfyClient.__json_encoder.encode(prompt) - response: ClientResponse - async with self.session.post(urljoin(self.server_address, "/api/v1/prompts"), data=prompt_json, - headers={'Content-Type': 'application/json', 'Accept': 'application/json', 'Prefer': 'respond-async'}) as response: - + async with await self._post_prompt(prompt, "/api/v1/prompts", accept_header, prefer_header) as response: if 200 <= response.status < 400: response_json = await response.json() return response_json["prompt_id"] else: - raise RuntimeError(f"could not prompt: {response.status}: {await response.text()}") + raise RuntimeError(f"could not prompt: {response.status}, reason={response.reason}: {await response.text()}") - async def queue_prompt_api(self, prompt: PromptDict) -> V1QueuePromptResponse: + async def queue_prompt_api(self, prompt: PromptDict, prefer_header: Optional[str] = None, accept_header: str = "application/json") -> V1QueuePromptResponse: """ Calls the API to queue a prompt. 
:param prompt: + :param prefer_header: The Prefer header value (e.g., "respond-async" or None) + :param accept_header: The Accept header value (e.g., "application/json", "application/json+respond-async") :return: the API response from the server containing URLs and the outputs for the UI (nodes with OUTPUT_NODE == true) """ - prompt_json = AsyncRemoteComfyClient.__json_encoder.encode(prompt) - response: ClientResponse - async with self.session.post(urljoin(self.server_address, "/api/v1/prompts"), data=prompt_json, - headers={'Content-Type': 'application/json', 'Accept': 'application/json'}) as response: - + async with await self._post_prompt(prompt, "/api/v1/prompts", accept_header, prefer_header) as response: if 200 <= response.status < 400: return V1QueuePromptResponse(**(await response.json())) else: @@ -112,17 +129,13 @@ class AsyncRemoteComfyClient: :param prompt: :return: """ - prompt_json = AsyncRemoteComfyClient.__json_encoder.encode(prompt) - response: ClientResponse - headers = {'Content-Type': 'application/json', 'Accept': 'image/png'} - async with self.session.post(urljoin(self.server_address, "/api/v1/prompts"), data=prompt_json, - headers=headers) as response: - + async with await self._post_prompt(prompt, "/api/v1/prompts", "image/png") as response: if 200 <= response.status < 400: return await response.read() else: raise RuntimeError(f"could not prompt: {response.status}: {await response.text()}") + @tracer.start_as_current_span("Post Prompt (UI)") async def queue_prompt_ui(self, prompt: PromptDict) -> OutputsDict: """ Uses the comfyui UI API calls to retrieve the outputs dictionary @@ -160,3 +173,41 @@ class AsyncRemoteComfyClient: # images have filename, subfolder, type keys # todo: use the OpenAPI spec for this when I get around to updating it return history_json[prompt_id].outputs + + async def get_prompt_status(self, prompt_id: str) -> ClientResponse: + """ + Get the status of a prompt by ID using the API endpoint. + :param prompt_id: The prompt ID to query + :return: The ClientResponse object (caller should check status and read body) + """ + return await self.session.get(urljoin(self.server_address, f"/api/v1/prompts/{prompt_id}")) + + @tracer.start_as_current_span("Poll Prompt Until Done") + async def poll_prompt_until_done(self, prompt_id: str, max_attempts: int = 60, poll_interval: float = 1.0) -> tuple[int, dict | None]: + """ + Poll a prompt until it's done (200), errors (500), or times out. 
+ :param prompt_id: The prompt ID to poll + :param max_attempts: Maximum number of polling attempts + :param poll_interval: Time to wait between polls in seconds + :return: Tuple of (status_code, response_json or None) + """ + span = trace.get_current_span() + span.set_attribute("prompt_id", prompt_id) + span.set_attribute("max_attempts", max_attempts) + + for _ in range(max_attempts): + async with await self.get_prompt_status(prompt_id) as response: + if response.status == 200: + return response.status, await response.json() + elif response.status == 500: + return response.status, await response.json() + elif response.status == 404: + return response.status, None + elif response.status == 204: + # Still in progress + await asyncio.sleep(poll_interval) + else: + # Unexpected status + return response.status, None + # Timeout + return 408, None diff --git a/comfy/client/sdxl_with_refiner_workflow.py b/comfy/client/sdxl_with_refiner_workflow.py index 8a4923b86..d7dab309a 100644 --- a/comfy/client/sdxl_with_refiner_workflow.py +++ b/comfy/client/sdxl_with_refiner_workflow.py @@ -1,8 +1,6 @@ import copy from typing import TypeAlias, Union -from ..api.components.schema.prompt import PromptDict, Prompt - JSON: TypeAlias = Union[dict[str, "JSON"], list["JSON"], str, int, float, bool, None] _BASE_PROMPT: JSON = { "4": { @@ -161,7 +159,7 @@ def sdxl_workflow_with_refiner(prompt: str, sampler="euler_ancestral", scheduler="normal", filename_prefix="sdxl_", - seed=42) -> PromptDict: + seed=42) -> dict: prompt_dict: JSON = copy.deepcopy(_BASE_PROMPT) prompt_dict["17"]["inputs"]["text"] = prompt prompt_dict["20"]["inputs"]["text"] = negative_prompt @@ -188,4 +186,4 @@ def sdxl_workflow_with_refiner(prompt: str, prompt_dict["14"]["inputs"]["scheduler"] = scheduler prompt_dict["13"]["inputs"]["filename_prefix"] = filename_prefix - return Prompt.validate(prompt_dict) + return prompt_dict diff --git a/comfy/cmd/execution.py b/comfy/cmd/execution.py index 8095d53be..5ad5865c2 100644 --- a/comfy/cmd/execution.py +++ b/comfy/cmd/execution.py @@ -1246,16 +1246,36 @@ async def _validate_prompt(prompt_id: typing.Any, prompt: typing.Mapping[str, ty if len(good_outputs) == 0: errors_list = [] + extra_info = {} for o, _errors in errors: for error in _errors: errors_list.append(f"{error['message']}: {error['details']}") + # Aggregate exception_type and traceback from validation errors + if 'extra_info' in error and error['extra_info']: + if 'exception_type' in error['extra_info'] and 'exception_type' not in extra_info: + extra_info['exception_type'] = error['extra_info']['exception_type'] + if 'traceback' in error['extra_info'] and 'traceback' not in extra_info: + extra_info['traceback'] = error['extra_info']['traceback'] + + # Per OpenAPI spec, extra_info must have exception_type and traceback + # For non-exception validation errors, provide synthetic values + if 'exception_type' not in extra_info: + extra_info['exception_type'] = 'ValidationError' + if 'traceback' not in extra_info: + # Capture current stack for validation errors that don't have their own traceback + extra_info['traceback'] = traceback.format_stack() + + # Include detailed node_errors for actionable debugging information + if node_errors: + extra_info['node_errors'] = node_errors + errors_list = "\n".join(errors_list) error = { "type": "prompt_outputs_failed_validation", "message": "Prompt outputs failed validation", "details": errors_list, - "extra_info": {} + "extra_info": extra_info } return ValidationTuple(False, error, list(good_outputs), 
node_errors) @@ -1301,7 +1321,7 @@ class PromptQueue(AbstractPromptQueue): return copy.deepcopy(item_with_future.queue_tuple), task_id def task_done(self, item_id: str, outputs: HistoryResultDict, - status: Optional[ExecutionStatus]): + status: Optional[ExecutionStatus], error_details: Optional[ExecutionErrorMessage] = None): history_result = outputs with self.mutex: queue_item = self.currently_running.pop(item_id) @@ -1311,7 +1331,7 @@ class PromptQueue(AbstractPromptQueue): status_dict = None if status is not None: - status_dict: Optional[ExecutionStatusAsDict] = status.as_dict() + status_dict: Optional[ExecutionStatusAsDict] = status.as_dict(error_details=error_details) outputs_ = history_result["outputs"] # Remove sensitive data from extra_data before storing in history diff --git a/comfy/cmd/main.py b/comfy/cmd/main.py index 182fa2107..a921a8eed 100644 --- a/comfy/cmd/main.py +++ b/comfy/cmd/main.py @@ -79,12 +79,25 @@ async def _prompt_worker(q: AbstractPromptQueue, server_instance: server_module. await e.execute_async(item[2], prompt_id, item[3], item[4]) need_gc = True + + # Extract error details from status_messages if there's an error + error_details = None + if not e.success: + for event, data in e.status_messages: + if event == "execution_error": + error_details = data + break + + # Convert status_messages tuples to string messages for backward compatibility + messages = [f"{event}: {data.get('exception_message', str(data))}" if isinstance(data, dict) and 'exception_message' in data else f"{event}" for event, data in e.status_messages] + q.task_done(item_id, e.history_result, status=queue_types.ExecutionStatus( status_str='success' if e.success else 'error', completed=e.success, - messages=e.status_messages)) + messages=messages), + error_details=error_details) if server_instance.client_id is not None: server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id) diff --git a/comfy/cmd/main_pre.py b/comfy/cmd/main_pre.py index 5a8d0f7f8..7c06e4cee 100644 --- a/comfy/cmd/main_pre.py +++ b/comfy/cmd/main_pre.py @@ -15,6 +15,7 @@ import shutil import warnings import fsspec +from opentelemetry.instrumentation.urllib3 import URLLib3Instrumentor from .. 
import options from ..app import logger @@ -41,6 +42,7 @@ warnings.filterwarnings("ignore", message="torch.utils._pytree._register_pytree_ warnings.filterwarnings("ignore", message="Torch was not compiled with flash attention.") warnings.filterwarnings("ignore", message=".*Torch was not compiled with flash attention.*") warnings.filterwarnings('ignore', category=FutureWarning, message=r'`torch\.cuda\.amp\.custom_fwd.*') +warnings.filterwarnings("ignore", category=UserWarning, message="Please use the new API settings to control TF32 behavior.*") warnings.filterwarnings("ignore", message="Importing from timm.models.registry is deprecated, please import via timm.models", category=FutureWarning) warnings.filterwarnings("ignore", message="Importing from timm.models.layers is deprecated, please import via timm.layers", category=FutureWarning) warnings.filterwarnings("ignore", message="Inheritance class _InstrumentedApplication from web.Application is discouraged", category=DeprecationWarning) @@ -125,9 +127,11 @@ def _create_tracer(): from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor, SpanExporter + from opentelemetry.processor.baggage import BaggageSpanProcessor, ALLOW_ALL_BAGGAGE_KEYS + from opentelemetry.instrumentation.aiohttp_server import AioHttpServerInstrumentor + from opentelemetry.instrumentation.aiohttp_client import AioHttpClientInstrumentor from ..tracing_compatibility import ProgressSpanSampler from ..tracing_compatibility import patch_spanbuilder_set_channel - from ..vendor.aiohttp_server_instrumentation import AioHttpServerInstrumentor resource = Resource.create({ service_attributes.SERVICE_NAME: args.otel_service_name, @@ -141,18 +145,24 @@ def _create_tracer(): has_endpoint = args.otel_exporter_otlp_endpoint is not None if has_endpoint: - otlp_exporter = OTLPSpanExporter() + exporter = OTLPSpanExporter() else: - otlp_exporter = SpanExporter() + exporter = SpanExporter() - processor = BatchSpanProcessor(otlp_exporter) + processor = BatchSpanProcessor(exporter) provider.add_span_processor(processor) # enable instrumentation patch_spanbuilder_set_channel() + AioPikaInstrumentor().instrument() AioHttpServerInstrumentor().instrument() + AioHttpClientInstrumentor().instrument() RequestsInstrumentor().instrument() + URLLib3Instrumentor().instrument() + + + provider.add_span_processor(BaggageSpanProcessor(ALLOW_ALL_BAGGAGE_KEYS)) # makes this behave better as a library return trace.get_tracer(args.otel_service_name, tracer_provider=provider) diff --git a/comfy/cmd/server.py b/comfy/cmd/server.py index df37647d3..7a4be89a8 100644 --- a/comfy/cmd/server.py +++ b/comfy/cmd/server.py @@ -116,6 +116,24 @@ async def compress_body(request: web.Request, handler): return response +@web.middleware +async def opentelemetry_middleware(request: web.Request, handler): + """Middleware to extract and propagate OpenTelemetry context from request headers""" + from opentelemetry import propagate, context + + # Extract OpenTelemetry context from headers + carrier = dict(request.headers) + ctx = propagate.extract(carrier) + + # Attach context and execute handler + token = context.attach(ctx) + try: + response = await handler(request) + return response + finally: + context.detach(token) + + def create_cors_middleware(allowed_origin: str): @web.middleware async def cors_middleware(request: web.Request, handler): @@ -127,7 +145,7 @@ def create_cors_middleware(allowed_origin: str): 
response.headers['Access-Control-Allow-Origin'] = allowed_origin response.headers['Access-Control-Allow-Methods'] = 'POST, GET, DELETE, PUT, OPTIONS' - response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization' + response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization, traceparent, tracestate' response.headers['Access-Control-Allow-Credentials'] = 'true' return response @@ -224,7 +242,7 @@ class PromptServer(ExecutorToClientProgress): self._external_address: Optional[str] = None self.background_tasks: dict[str, Task] = dict() - middlewares = [cache_control, deprecation_warning] + middlewares = [opentelemetry_middleware, cache_control, deprecation_warning] if args.enable_compress_response_body: middlewares.append(compress_body) @@ -867,9 +885,19 @@ class PromptServer(ExecutorToClientProgress): return web.json_response(status=404) elif prompt_id in history_items: history_entry = history_items[prompt_id] + # Check if execution resulted in an error + if "status" in history_entry: + status = history_entry["status"] + if isinstance(status, dict) and status.get("status_str") == "error": + # Return ExecutionStatusAsDict format with status 500, matching POST /api/v1/prompts behavior + return web.Response( + body=json.dumps(status), + status=500, + content_type="application/json" + ) return web.json_response(history_entry["outputs"]) else: - return web.json_response(status=500) + return web.Response(status=404, reason="prompt not found in expected state") @routes.post("/api/v1/prompts") async def post_api_prompt(request: web.Request) -> web.Response | web.FileResponse: @@ -877,9 +905,13 @@ class PromptServer(ExecutorToClientProgress): if accept == '*/*': accept = "application/json" content_type = request.headers.get("content-type", "application/json") - preferences = request.headers.get("prefer", "") + request.query.get("prefer", "") + " " + content_type + preferences = request.headers.get("prefer", "") + request.query.get("prefer", "") + " " + content_type + " " + accept + + # handle media type parameters like "application/json+respond-async" if "+" in content_type: content_type = content_type.split("+")[0] + if "+" in accept: + accept = accept.split("+")[0] wait = not "respond-async" in preferences @@ -965,7 +997,8 @@ class PromptServer(ExecutorToClientProgress): return web.Response(body=str(ex), status=500) if result.status is not None and result.status.status_str == "error": - return web.Response(body=json.dumps(result.status._asdict()), status=500, content_type="application/json") + status_dict = result.status.as_dict(error_details=result.error_details) + return web.Response(body=json.dumps(status_dict), status=500, content_type="application/json") # find images and read them output_images: List[FileOutput] = [] for node_id, node in result.outputs.items(): diff --git a/comfy/component_model/executor_types.py b/comfy/component_model/executor_types.py index 1b3d5de4c..34b5819e1 100644 --- a/comfy/component_model/executor_types.py +++ b/comfy/component_model/executor_types.py @@ -199,9 +199,8 @@ class ValidationErrorExtraInfoDict(TypedDict, total=False): input_config: NotRequired[Dict[str, InputTypeSpec]] received_value: NotRequired[Any] linked_node: NotRequired[str] - traceback: NotRequired[list[str]] exception_message: NotRequired[str] - exception_type: NotRequired[str] + node_errors: NotRequired[Dict[str, 'NodeErrorsDictValue']] class ValidationErrorDict(TypedDict): diff --git a/comfy/component_model/queue_types.py 
b/comfy/component_model/queue_types.py index 21be0e5f2..f5189cc64 100644 --- a/comfy/component_model/queue_types.py +++ b/comfy/component_model/queue_types.py @@ -18,6 +18,7 @@ class TaskInvocation(NamedTuple): item_id: int | str outputs: OutputsDict status: Optional[ExecutionStatus] + error_details: Optional['ExecutionErrorMessage'] = None class ExecutionStatus(NamedTuple): @@ -25,12 +26,15 @@ class ExecutionStatus(NamedTuple): completed: bool messages: List[str] - def as_dict(self) -> ExecutionStatusAsDict: - return { + def as_dict(self, error_details: Optional['ExecutionErrorMessage'] = None) -> ExecutionStatusAsDict: + result: ExecutionStatusAsDict = { "status_str": self.status_str, "completed": self.completed, "messages": copy.copy(self.messages), } + if error_details is not None: + result["error_details"] = error_details + return result class ExecutionError(RuntimeError): diff --git a/comfy/distributed/distributed_prompt_queue.py b/comfy/distributed/distributed_prompt_queue.py index 486657347..894921057 100644 --- a/comfy/distributed/distributed_prompt_queue.py +++ b/comfy/distributed/distributed_prompt_queue.py @@ -162,7 +162,7 @@ class DistributedPromptQueue(AbstractPromptQueue, AsyncAbstractPromptQueue): return item, item[1] - def task_done(self, item_id: int, outputs: dict, status: Optional[ExecutionStatus]): + def task_done(self, item_id: int, outputs: dict, status: Optional[ExecutionStatus], error_details: Optional['ExecutionErrorMessage'] = None): # callee: executed on the worker thread if "outputs" in outputs: outputs: HistoryResultDict @@ -173,7 +173,7 @@ class DistributedPromptQueue(AbstractPromptQueue, AsyncAbstractPromptQueue): assert pending.completed is not None assert not pending.completed.done() # finish the task. status will transmit the errors in comfy's domain-specific way - pending.completed.set_result(TaskInvocation(item_id=item_id, outputs=outputs, status=status)) + pending.completed.set_result(TaskInvocation(item_id=item_id, outputs=outputs, status=status, error_details=error_details)) # todo: the caller is responsible for sending a websocket message right now that the UI expects for updates def get_current_queue(self) -> Tuple[List[QueueTuple], List[QueueTuple]]: diff --git a/comfy/tracing_compatibility.py b/comfy/tracing_compatibility.py index 6c0ef16f5..a77f1e79b 100644 --- a/comfy/tracing_compatibility.py +++ b/comfy/tracing_compatibility.py @@ -3,7 +3,7 @@ from typing import Optional, Sequence from aio_pika.abc import AbstractChannel from opentelemetry.context import Context from opentelemetry.sdk.trace.sampling import Sampler, SamplingResult, Decision -from opentelemetry.semconv.trace import SpanAttributes +from opentelemetry.semconv.attributes.network_attributes import NETWORK_PEER_ADDRESS, NETWORK_PEER_PORT from opentelemetry.trace import SpanKind, Link, TraceState from opentelemetry.util.types import Attributes @@ -22,8 +22,8 @@ def patch_spanbuilder_set_channel() -> None: port = url.port or 5672 self._attributes.update( { - SpanAttributes.NET_PEER_NAME: url.host, - SpanAttributes.NET_PEER_PORT: port, + NETWORK_PEER_ADDRESS: url.host, + NETWORK_PEER_PORT: port, } ) diff --git a/comfy/vendor/aiohttp_server_instrumentation.py b/comfy/vendor/aiohttp_server_instrumentation.py deleted file mode 100644 index 5e334c49f..000000000 --- a/comfy/vendor/aiohttp_server_instrumentation.py +++ /dev/null @@ -1,271 +0,0 @@ -# Copyright 2020, OpenTelemetry Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except 
in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import urllib -from timeit import default_timer -from typing import Dict, List, Tuple, Union - -from aiohttp import web -from multidict import CIMultiDictProxy - -from opentelemetry import metrics, trace -_instruments = ("aiohttp ~= 3.0",) -__version__ = "0.49b0.dev" -from opentelemetry.instrumentation.instrumentor import BaseInstrumentor -from opentelemetry.instrumentation.utils import ( - http_status_to_status_code, - is_http_instrumentation_enabled, -) -from opentelemetry.propagate import extract -from opentelemetry.propagators.textmap import Getter -from opentelemetry.semconv.metrics import MetricInstruments -from opentelemetry.semconv.trace import SpanAttributes -from opentelemetry.trace.status import Status, StatusCode -from opentelemetry.util.http import get_excluded_urls, remove_url_credentials - -_duration_attrs = [ - SpanAttributes.HTTP_METHOD, - SpanAttributes.HTTP_HOST, - SpanAttributes.HTTP_SCHEME, - SpanAttributes.HTTP_STATUS_CODE, - SpanAttributes.HTTP_FLAVOR, - SpanAttributes.HTTP_SERVER_NAME, - SpanAttributes.NET_HOST_NAME, - SpanAttributes.NET_HOST_PORT, - SpanAttributes.HTTP_ROUTE, -] - -_active_requests_count_attrs = [ - SpanAttributes.HTTP_METHOD, - SpanAttributes.HTTP_HOST, - SpanAttributes.HTTP_SCHEME, - SpanAttributes.HTTP_FLAVOR, - SpanAttributes.HTTP_SERVER_NAME, -] - -tracer = trace.get_tracer(__name__) -meter = metrics.get_meter(__name__, __version__) -_excluded_urls = get_excluded_urls("AIOHTTP_SERVER") - - -def _parse_duration_attrs(req_attrs): - duration_attrs = {} - for attr_key in _duration_attrs: - if req_attrs.get(attr_key) is not None: - duration_attrs[attr_key] = req_attrs[attr_key] - return duration_attrs - - -def _parse_active_request_count_attrs(req_attrs): - active_requests_count_attrs = {} - for attr_key in _active_requests_count_attrs: - if req_attrs.get(attr_key) is not None: - active_requests_count_attrs[attr_key] = req_attrs[attr_key] - return active_requests_count_attrs - - -def get_default_span_details(request: web.Request) -> Tuple[str, dict]: - """Default implementation for get_default_span_details - Args: - request: the request object itself. - Returns: - a tuple of the span name, and any attributes to attach to the span. - """ - span_name = request.path.strip() or f"HTTP {request.method}" - return span_name, {} - - -def _get_view_func(request: web.Request) -> str: - """Returns the name of the request handler. - Args: - request: the request object itself. - Returns: - a string containing the name of the handler function - """ - try: - return request.match_info.handler.__name__ - except AttributeError: - return "unknown" - - -def collect_request_attributes(request: web.Request) -> Dict: - """Collects HTTP request attributes from the ASGI scope and returns a - dictionary to be used as span creation attributes.""" - - server_host, port, http_url = ( - request.url.host, - request.url.port, - str(request.url), - ) - query_string = request.query_string - if query_string and http_url: - if isinstance(query_string, bytes): - query_string = query_string.decode("utf8") - http_url += "?" 
+ urllib.parse.unquote(query_string) - - result = { - SpanAttributes.HTTP_SCHEME: request.scheme, - SpanAttributes.HTTP_HOST: server_host, - SpanAttributes.NET_HOST_PORT: port, - SpanAttributes.HTTP_ROUTE: _get_view_func(request), - SpanAttributes.HTTP_FLAVOR: f"{request.version.major}.{request.version.minor}", - SpanAttributes.HTTP_TARGET: request.path, - SpanAttributes.HTTP_URL: remove_url_credentials(http_url), - } - - http_method = request.method - if http_method: - result[SpanAttributes.HTTP_METHOD] = http_method - - http_host_value_list = ( - [request.host] if not isinstance(request.host, list) else request.host - ) - if http_host_value_list: - result[SpanAttributes.HTTP_SERVER_NAME] = ",".join( - http_host_value_list - ) - http_user_agent = request.headers.get("user-agent") - if http_user_agent: - result[SpanAttributes.HTTP_USER_AGENT] = http_user_agent - - # remove None values - result = {k: v for k, v in result.items() if v is not None} - - return result - - -def set_status_code(span, status_code: int) -> None: - """Adds HTTP response attributes to span using the status_code argument.""" - - try: - status_code = int(status_code) - except ValueError: - span.set_status( - Status( - StatusCode.ERROR, - "Non-integer HTTP status: " + repr(status_code), - ) - ) - else: - span.set_attribute(SpanAttributes.HTTP_STATUS_CODE, status_code) - span.set_status( - Status(http_status_to_status_code(status_code, server_span=True)) - ) - - -class AiohttpGetter(Getter): - """Extract current trace from headers""" - - def get(self, carrier, key: str) -> Union[List, None]: - """Getter implementation to retrieve an HTTP header value from the ASGI - scope. - - Args: - carrier: ASGI scope object - key: header name in scope - Returns: - A list of all header values matching the key, or None if the key - does not match any header. 
- """ - headers: CIMultiDictProxy = carrier.headers - if not headers: - return None - return headers.getall(key, None) - - def keys(self, carrier: Dict) -> List: - return list(carrier.keys()) - - -getter = AiohttpGetter() - - -@web.middleware -async def middleware(request, handler): - """Middleware for aiohttp implementing tracing logic""" - if not is_http_instrumentation_enabled() or _excluded_urls.url_disabled( - request.url.path - ): - return await handler(request) - - span_name, additional_attributes = get_default_span_details(request) - - req_attrs = collect_request_attributes(request) - duration_attrs = _parse_duration_attrs(req_attrs) - active_requests_count_attrs = _parse_active_request_count_attrs(req_attrs) - - duration_histogram = meter.create_histogram( - name=MetricInstruments.HTTP_SERVER_DURATION, - unit="ms", - description="Measures the duration of inbound HTTP requests.", - ) - - active_requests_counter = meter.create_up_down_counter( - name=MetricInstruments.HTTP_SERVER_ACTIVE_REQUESTS, - unit="requests", - description="measures the number of concurrent HTTP requests those are currently in flight", - ) - - with tracer.start_as_current_span( - span_name, - context=extract(request, getter=getter), - kind=trace.SpanKind.SERVER, - ) as span: - attributes = collect_request_attributes(request) - attributes.update(additional_attributes) - span.set_attributes(attributes) - start = default_timer() - active_requests_counter.add(1, active_requests_count_attrs) - try: - resp = await handler(request) - set_status_code(span, resp.status) - except web.HTTPException as ex: - set_status_code(span, ex.status_code) - raise - except AttributeError: - # No response was returned or a NoneType response was returned, handle gracefully - set_status_code(span, 500) - raise - finally: - duration = max((default_timer() - start) * 1000, 0) - duration_histogram.record(duration, duration_attrs) - active_requests_counter.add(-1, active_requests_count_attrs) - return resp - - -class _InstrumentedApplication(web.Application): - """Insert tracing middleware""" - - def __init__(self, *args, **kwargs): - middlewares = kwargs.pop("middlewares", []) - middlewares.insert(0, middleware) - kwargs["middlewares"] = middlewares - super().__init__(*args, **kwargs) - - -class AioHttpServerInstrumentor(BaseInstrumentor): - # pylint: disable=protected-access,attribute-defined-outside-init - """An instrumentor for aiohttp.web.Application - - See `BaseInstrumentor` - """ - - def _instrument(self, **kwargs): - self._original_app = web.Application - setattr(web, "Application", _InstrumentedApplication) - - def _uninstrument(self, **kwargs): - setattr(web, "Application", self._original_app) - - def instrumentation_dependencies(self): - return _instruments diff --git a/comfy_extras/eval_web/__init__.py b/comfy_extras/eval_web/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/comfy_extras/eval_web/eval_python.js b/comfy_extras/eval_web/eval_python.js new file mode 100644 index 000000000..a7e2fa19f --- /dev/null +++ b/comfy_extras/eval_web/eval_python.js @@ -0,0 +1,242 @@ +/** + * Uses code adapted from https://github.com/yorkane/ComfyUI-KYNode + * + * MIT License + * + * Copyright (c) 2024 Kevin Yuan + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, 
sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +import { app } from "../../scripts/app.js"; + +// Load Ace editor using script tag for Safari compatibility +// The noconflict build includes AMD loader that works in all browsers +let ace; +const aceLoadPromise = new Promise((resolve) => { + if (window.ace) { + ace = window.ace; + resolve(); + } else { + const script = document.createElement("script"); + script.src = "https://cdn.jsdelivr.net/npm/ace-builds@1.43.4/src-noconflict/ace.js"; + script.onload = () => { + ace = window.ace; + ace.config.set("basePath", "https://cdn.jsdelivr.net/npm/ace-builds@1.43.4/src-noconflict"); + resolve(); + }; + document.head.appendChild(script); + } +}); + +// todo: do we really want to do this here? +await aceLoadPromise; +const findWidget = (node, value, attr = "name", func = "find") => { + return node?.widgets ? node.widgets[func]((w) => (Array.isArray(value) ? value.includes(w[attr]) : w[attr] === value)) : null; +}; + +const makeElement = (tag, attrs = {}) => { + if (!tag) tag = "div"; + const element = document.createElement(tag); + Object.keys(attrs).forEach((key) => { + const currValue = attrs[key]; + if (key === "class") { + if (Array.isArray(currValue)) { + element.classList.add(...currValue); + } else if (currValue instanceof String || typeof currValue === "string") { + element.className = currValue; + } + } else if (key === "dataset") { + try { + if (Array.isArray(currValue)) { + currValue.forEach((datasetArr) => { + const [prop, propval] = Object.entries(datasetArr)[0]; + element.dataset[prop] = propval; + }); + } else { + Object.entries(currValue).forEach((datasetArr) => { + const [prop, propval] = datasetArr; + element.dataset[prop] = propval; + }); + } + } catch (err) { + // todo: what is this trying to do? + } + } else if (key === "style") { + if (typeof currValue === "object" && !Array.isArray(currValue) && Object.keys(currValue).length) { + Object.assign(element[key], currValue); + } else if (typeof currValue === "object" && Array.isArray(currValue) && currValue.length) { + element[key] = [...currValue]; + } else if (currValue instanceof String || typeof currValue === "string") { + element[key] = currValue; + } + } else if (["for"].includes(key)) { + element.setAttribute(key, currValue); + } else if (key === "children") { + element.append(...(currValue instanceof Array ? 
currValue : [currValue])); + } else if (key === "parent") { + currValue.append(element); + } else { + element[key] = currValue; + } + }); + return element; +}; + +const getPosition = (node, ctx, w_width, y, n_height) => { + const margin = 5; + + const rect = ctx.canvas.getBoundingClientRect(); + const transform = ctx.getTransform(); + const scale = app.canvas.ds.scale; + + // The context is already transformed to draw at the widget position + // transform.e and transform.f give us the canvas coordinates (in canvas pixels) + // We need to convert these to screen pixels by accounting for the canvas scale + // rect gives us the canvas element's position on the page + + // The transform matrix has scale baked in (transform.a = transform.d = scale) + // transform.e and transform.f are the translation in canvas-pixel space + const canvasPixelToScreenPixel = rect.width / ctx.canvas.width; + + const x = transform.e * canvasPixelToScreenPixel + rect.left; + const y_pos = transform.f * canvasPixelToScreenPixel + rect.top; + + // Convert widget dimensions from canvas coordinates to screen pixels + const scaledWidth = w_width * scale; + const scaledHeight = (n_height - y - 15) * scale; + const scaledMargin = margin * scale; + const scaledY = y * scale; + + return { + left: `${x + scaledMargin}px`, + top: `${y_pos + scaledY + scaledMargin}px`, + width: `${scaledWidth - scaledMargin * 2}px`, + maxWidth: `${scaledWidth - scaledMargin * 2}px`, + height: `${scaledHeight - scaledMargin * 2}px`, + maxHeight: `${scaledHeight - scaledMargin * 2}px`, + position: "absolute", + scrollbarColor: "var(--descrip-text) var(--bg-color)", + scrollbarWidth: "thin", + zIndex: app.graph._nodes.indexOf(node), + }; +}; + +// Create code editor widget +const codeEditor = (node, inputName, inputData) => { + const widget = { + type: "code_block_python", + name: inputName, + options: { hideOnZoom: true }, + value: inputData[1]?.default || "", + draw(ctx, node, widgetWidth, y) { + const hidden = node.flags?.collapsed || (!!this.options.hideOnZoom && app.canvas.ds.scale < 0.5) || this.type === "converted-widget" || this.type === "hidden" || this.type === "converted-widget"; + + this.codeElement.hidden = hidden; + + if (hidden) { + this.options.onHide?.(this); + return; + } + + Object.assign(this.codeElement.style, getPosition(node, ctx, widgetWidth, y, node.size[1])); + }, + computeSize() { + return [500, 250]; + }, + }; + + widget.codeElement = makeElement("pre", { + innerHTML: widget.value, + }); + + widget.editor = ace.edit(widget.codeElement); + widget.editor.setTheme("ace/theme/monokai"); + widget.editor.session.setMode("ace/mode/python"); + widget.editor.setOptions({ + enableAutoIndent: true, + enableLiveAutocompletion: true, + enableBasicAutocompletion: true, + fontFamily: "monospace", + }); + widget.codeElement.hidden = true; + + document.body.appendChild(widget.codeElement); + + const originalCollapse = node.collapse; + node.collapse = function () { + originalCollapse.apply(this, arguments); + widget.codeElement.hidden = !!this.flags?.collapsed; + }; + + return widget; +}; + +// Trigger workflow change tracking +const markWorkflowChanged = () => { + app?.extensionManager?.workflow?.activeWorkflow?.changeTracker?.checkState(); +}; + +// Register extensions +app.registerExtension({ + name: "Comfy.EvalPython", + getCustomWidgets(app) { + return { + CODE_BLOCK_PYTHON: (node, inputName, inputData) => { + const widget = codeEditor(node, inputName, inputData); + + widget.editor.getSession().on("change", () => { + widget.value = 
widget.editor.getValue(); + markWorkflowChanged(); + }); + + node.onRemoved = function () { + for (const w of this.widgets) { + if (w?.codeElement) { + w.codeElement.remove(); + } + } + }; + + node.addCustomWidget(widget); + + return widget; + }, + }; + }, + + async beforeRegisterNodeDef(nodeType, nodeData) { + // Handle all EvalPython node variants + if (nodeData.name.startsWith("EvalPython")) { + const originalOnConfigure = nodeType.prototype.onConfigure; + nodeType.prototype.onConfigure = function (info) { + originalOnConfigure?.apply(this, arguments); + + if (info?.widgets_values?.length) { + const widgetCodeIndex = findWidget(this, "code_block_python", "type", "findIndex"); + const editor = this.widgets[widgetCodeIndex]?.editor; + + if (editor) { + editor.setValue(info.widgets_values[widgetCodeIndex]); + editor.clearSelection(); + } + } + }; + } + }, +}); diff --git a/comfy_extras/nodes/nodes_eval.py b/comfy_extras/nodes/nodes_eval.py new file mode 100644 index 000000000..bda41a02d --- /dev/null +++ b/comfy_extras/nodes/nodes_eval.py @@ -0,0 +1,119 @@ +import logging + +from comfy.comfy_types import IO +from comfy.execution_context import current_execution_context +from comfy.node_helpers import export_package_as_web_directory, export_custom_nodes +from comfy.nodes.package_typing import CustomNode + +logger = logging.getLogger(__name__) + + +def eval_python(inputs=5, outputs=5, name=None, input_is_list=None, output_is_list=None): + """ + Factory function to create EvalPython node classes with configurable input/output counts. + + Args: + inputs: Number of input value slots (default: 5) + outputs: Number of output item slots (default: 5) + name: Class name (default: f"EvalPython_{inputs}_{outputs}") + input_is_list: Optional list of bools indicating which inputs accept lists (default: None, meaning all scalar) + output_is_list: Optional tuple of bools indicating which outputs return lists (default: None, meaning all scalar) + + Returns: + A CustomNode subclass configured with the specified inputs/outputs + """ + if name is None: + name = f"EvalPython_{inputs}_{outputs}" + + default_code = f""" +print("Hello World!") +return {", ".join([f"value{i}" for i in range(inputs)])} +""" + + class EvalPythonNode(CustomNode): + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "pycode": ( + "CODE_BLOCK_PYTHON", + { + "default": default_code + }, + ), + }, + "optional": {f"value{i}": (IO.ANY, {}) for i in range(inputs)}, + } + + RETURN_TYPES = tuple(IO.ANY for _ in range(outputs)) + RETURN_NAMES = tuple(f"item{i}" for i in range(outputs)) + FUNCTION = "exec_py" + DESCRIPTION = "" + CATEGORY = "eval" + + @classmethod + def VALIDATE_INPUTS(cls, *args, **kwargs): + ctx = current_execution_context() + + return ctx.configuration.enable_eval + + def exec_py(self, pycode, **kwargs): + ctx = current_execution_context() + + kwargs = { + **{f"value{i}": None for i in range(inputs)}, + **kwargs, + } + + def print(*args): + ctx.server.send_progress_text(" ".join(map(str, args)), ctx.node_id) + + if not ctx.configuration.enable_eval: + raise ValueError("Python eval is disabled") + + value_args = [kwargs.pop(f"value{i}") for i in range(inputs)] + arg_names = ", ".join(f"value{i}=None" for i in range(inputs)) + + wrapped_code = f"def _eval_func({arg_names}):\n" + for line in pycode.splitlines(): + wrapped_code += " " + line + "\n" + + globals_for_eval = { + **kwargs, + "logger": logger, + "print": print, + } + + exec(wrapped_code, globals_for_eval) + results = 
globals_for_eval["_eval_func"](*value_args) + if not isinstance(results, tuple): + results = (results,) + + if len(results) < outputs: + results += (None,) * (outputs - len(results)) + elif len(results) > outputs: + results = results[:outputs] + + return results + + # todo: interact better with the weird comfyui machinery for this + if input_is_list is not None: + setattr(EvalPythonNode, "INPUT_IS_LIST", input_is_list) + + if output_is_list is not None: + setattr(EvalPythonNode, "OUTPUT_IS_LIST", output_is_list) + + EvalPythonNode.__name__ = name + EvalPythonNode.__qualname__ = name + + return EvalPythonNode + + +EvalPython_1_1 = eval_python(inputs=1, outputs=1, name="EvalPython_1_1") +EvalPython_5_5 = eval_python(inputs=5, outputs=5, name="EvalPython_5_5") +EvalPython_List_1 = eval_python(inputs=1, outputs=1, name="EvalPython_List_1", input_is_list=True, output_is_list=None) +EvalPython_1_List = eval_python(inputs=1, outputs=1, name="EvalPython_1_List", input_is_list=None, output_is_list=(True,)) +EvalPython_List_List = eval_python(inputs=1, outputs=1, name="EvalPython_List_List", input_is_list=True, output_is_list=(True,)) + +export_custom_nodes() +export_package_as_web_directory("comfy_extras.eval_web") diff --git a/pyproject.toml b/pyproject.toml index e694216d8..f6e6f375c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,13 +65,18 @@ dependencies = [ "natsort", "OpenEXR", "opentelemetry-distro", - "opentelemetry-sdk<1.34.0", - "opentelemetry-exporter-otlp<=1.27.0", + "opentelemetry-sdk", + "opentelemetry-exporter-otlp", "opentelemetry-propagator-jaeger", "opentelemetry-instrumentation", "opentelemetry-util-http", "opentelemetry-instrumentation-aio-pika", "opentelemetry-instrumentation-requests", + "opentelemetry-instrumentation-aiohttp-server", + "opentelemetry-instrumentation-aiohttp-client", + "opentelemetry-instrumentation-asyncio", + "opentelemetry-instrumentation-urllib3", + "opentelemetry-processor-baggage", "opentelemetry-semantic-conventions", "wrapt>=1.16.0", "certifi", diff --git a/tests/conftest.py b/tests/conftest.py index 712cfc0dc..1c5b3df20 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,8 +3,6 @@ import multiprocessing import os import pathlib import subprocess -import sys -import time import urllib from contextvars import ContextVar from multiprocessing import Process @@ -12,9 +10,9 @@ from typing import List, Any, Generator import pytest import requests +import sys +import time -from comfy.cli_args import default_configuration -from comfy.execution_context import context_configuration os.environ['OTEL_METRICS_EXPORTER'] = 'none' os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" @@ -22,6 +20,7 @@ os.environ["HF_XET_HIGH_PERFORMANCE"] = "True" # fixes issues with running the testcontainers rabbitmqcontainer on Windows os.environ["TC_HOST"] = "localhost" +from comfy.cli_args import default_configuration from comfy.cli_args_types import Configuration logging.getLogger("pika").setLevel(logging.CRITICAL + 1) @@ -91,9 +90,17 @@ def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory, num_workers params = rabbitmq.get_connection_params() connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}" + # Check if OTEL endpoint is configured for integration testing + otel_endpoint = os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT") + + env = os.environ.copy() + if otel_endpoint: + env["OTEL_EXPORTER_OTLP_ENDPOINT"] = otel_endpoint + logging.info(f"Configuring services to export traces to: {otel_endpoint}") + frontend_command = [ "comfyui", - 
"--listen=127.0.0.1", + "--listen=0.0.0.0", "--port=19001", "--cpu", "--distributed-queue-frontend", @@ -101,7 +108,7 @@ def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory, num_workers f"--distributed-queue-connection-uri={connection_uri}", ] - processes_to_close.append(subprocess.Popen(frontend_command, stdout=sys.stdout, stderr=sys.stderr)) + processes_to_close.append(subprocess.Popen(frontend_command, stdout=sys.stdout, stderr=sys.stderr, env=env)) # Start multiple workers for i in range(num_workers): @@ -112,7 +119,7 @@ def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory, num_workers f"--distributed-queue-connection-uri={connection_uri}", f"--executor-factory={executor_factory}" ] - processes_to_close.append(subprocess.Popen(backend_command, stdout=sys.stdout, stderr=sys.stderr)) + processes_to_close.append(subprocess.Popen(backend_command, stdout=sys.stdout, stderr=sys.stderr, env=env)) try: server_address = f"http://127.0.0.1:19001" diff --git a/tests/distributed/test_distributed_queue.py b/tests/distributed/test_distributed_queue.py index 5aa28c722..3d8757824 100644 --- a/tests/distributed/test_distributed_queue.py +++ b/tests/distributed/test_distributed_queue.py @@ -243,3 +243,368 @@ async def test_two_workers_distinct_requests(): all_workflows.update(worker.processed_workflows) assert len(all_workflows) == 2, f"Expected 2 distinct workflows, but got {len(all_workflows)}" + + +@pytest.mark.asyncio +async def test_api_error_reporting_blocking_request(frontend_backend_worker_with_rabbitmq): + """Test error reporting with blocking request (no async preference)""" + async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client: + # Create an invalid prompt that will cause a validation error + prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1) + # Make the prompt invalid by referencing a non-existent checkpoint + prompt["4"]["inputs"]["ckpt_name"] = "nonexistent_checkpoint.safetensors" + + # Post with blocking behavior (no prefer header for async) + prompt_json = client._AsyncRemoteComfyClient__json_encoder.encode(prompt) + async with client.session.post( + f"{frontend_backend_worker_with_rabbitmq}/api/v1/prompts", + data=prompt_json, + headers={'Content-Type': 'application/json', 'Accept': 'application/json'} + ) as response: + # Should return 400 for validation error (invalid checkpoint) + assert response.status == 400, f"Expected 400, got {response.status}" + error_body = await response.json() + + # Verify ValidationErrorDict structure per OpenAPI spec + assert "type" in error_body, "Missing 'type' field in error response" + assert "message" in error_body, "Missing 'message' field in error response" + assert "details" in error_body, "Missing 'details' field in error response" + assert "extra_info" in error_body, "Missing 'extra_info' field in error response" + + +@pytest.mark.asyncio +async def test_api_error_reporting_async_prefer_header(frontend_backend_worker_with_rabbitmq): + """Test error reporting with Prefer: respond-async header""" + async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client: + # Create a valid prompt structure but with invalid checkpoint + prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1) + prompt["4"]["inputs"]["ckpt_name"] = "nonexistent.safetensors" + + # Post with Prefer: respond-async header + prompt_json = client._AsyncRemoteComfyClient__json_encoder.encode(prompt) + async with 
client.session.post( + f"{frontend_backend_worker_with_rabbitmq}/api/v1/prompts", + data=prompt_json, + headers={ + 'Content-Type': 'application/json', + 'Accept': 'application/json', + 'Prefer': 'respond-async' + } + ) as response: + # Should return 400 immediately for validation error + assert response.status == 400, f"Expected 400 for validation error, got {response.status}" + error_body = await response.json() + assert "type" in error_body + + +@pytest.mark.asyncio +async def test_api_error_reporting_async_accept_mimetype(frontend_backend_worker_with_rabbitmq): + """Test error reporting with +respond-async in Accept mimetype""" + async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client: + # Create a prompt with validation error + prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1) + prompt["4"]["inputs"]["ckpt_name"] = "invalid_model.safetensors" + + # Post with +respond-async in Accept header + prompt_json = client._AsyncRemoteComfyClient__json_encoder.encode(prompt) + async with client.session.post( + f"{frontend_backend_worker_with_rabbitmq}/api/v1/prompts", + data=prompt_json, + headers={ + 'Content-Type': 'application/json', + 'Accept': 'application/json+respond-async' + } + ) as response: + # Should return 400 for validation error (happens before queuing) + assert response.status == 400, f"Expected 400, got {response.status}" + error_body = await response.json() + assert "type" in error_body + + +@pytest.mark.asyncio +async def test_api_get_prompt_status_success(frontend_backend_worker_with_rabbitmq): + """Test GET /api/v1/prompts/{prompt_id} returns 200 with Outputs on success""" + async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client: + # Create a valid prompt + prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1) + + # Queue async to get prompt_id + task_id = await client.queue_and_forget_prompt_api(prompt, prefer_header="respond-async") + assert task_id is not None + + # Poll until done + status_code, result = await client.poll_prompt_until_done(task_id, max_attempts=60, poll_interval=1.0) + + # For a valid prompt, should get 200 + assert status_code == 200, f"Expected 200 for successful execution, got {status_code}" + assert result is not None + + # Verify it returns outputs structure (dict with node IDs) + assert isinstance(result, dict) + assert len(result) > 0, "Expected non-empty outputs" + + +@pytest.mark.asyncio +async def test_api_get_prompt_status_404(frontend_backend_worker_with_rabbitmq): + """Test GET /api/v1/prompts/{prompt_id} returns 404 for non-existent prompt""" + async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client: + # Request a non-existent prompt ID + fake_prompt_id = str(uuid.uuid4()) + + async with await client.get_prompt_status(fake_prompt_id) as response: + assert response.status == 404, f"Expected 404 for non-existent prompt, got {response.status}" + + +@pytest.mark.asyncio +async def test_api_get_prompt_status_204_in_progress(frontend_backend_worker_with_rabbitmq): + """Test GET /api/v1/prompts/{prompt_id} returns 204 while prompt is in progress""" + async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client: + # Create a prompt that takes some time to execute + prompt = sdxl_workflow_with_refiner("test", inference_steps=10, refiner_steps=10) + + # Queue async + task_id = await client.queue_and_forget_prompt_api(prompt, 
prefer_header="respond-async") + + # Immediately check status (should be 204 or 200 if very fast) + async with await client.get_prompt_status(task_id) as response: + # Should be either 204 (in progress) or 200 (completed very fast) + assert response.status in [200, 204], f"Expected 200 or 204, got {response.status}" + + if response.status == 204: + # No content for in-progress + content = await response.read() + assert len(content) == 0 or content == b'', "Expected no content for 204 response" + + +@pytest.mark.asyncio +async def test_api_async_workflow_both_methods(frontend_backend_worker_with_rabbitmq): + """Test full async workflow: queue with respond-async, then poll for completion""" + async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client: + # Create a valid prompt + prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1) + + # Method 1: Prefer header + task_id_1 = await client.queue_and_forget_prompt_api(prompt, prefer_header="respond-async") + assert task_id_1 is not None + + # Method 2: +respond-async in Accept header + task_id_2 = await client.queue_and_forget_prompt_api( + prompt, prefer_header=None, accept_header="application/json+respond-async" + ) + assert task_id_2 is not None + + # Poll both until done + status_1, result_1 = await client.poll_prompt_until_done(task_id_1, max_attempts=60, poll_interval=1.0) + status_2, result_2 = await client.poll_prompt_until_done(task_id_2, max_attempts=60, poll_interval=1.0) + + # Both should succeed + assert status_1 == 200, f"Task 1 failed with status {status_1}" + assert status_2 == 200, f"Task 2 failed with status {status_2}" + + # Both should have outputs + assert result_1 is not None and len(result_1) > 0 + assert result_2 is not None and len(result_2) > 0 + + +@pytest.mark.asyncio +async def test_api_validation_error_structure(frontend_backend_worker_with_rabbitmq): + """Test that validation errors return proper ValidationErrorDict structure""" + async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client: + # Create an invalid prompt (invalid checkpoint name) + prompt = sdxl_workflow_with_refiner("test", "", 1, refiner_steps=1) + prompt["4"]["inputs"]["ckpt_name"] = "fake.safetensors" + + prompt_json = client._AsyncRemoteComfyClient__json_encoder.encode(prompt) + + async with client.session.post( + f"{frontend_backend_worker_with_rabbitmq}/api/v1/prompts", + data=prompt_json, + headers={'Content-Type': 'application/json', 'Accept': 'application/json'} + ) as response: + assert response.status == 400, f"Expected 400, got {response.status}" + + error_body = await response.json() + + # Verify ValidationErrorDict structure per OpenAPI spec + assert "type" in error_body, "Missing 'type'" + assert "message" in error_body, "Missing 'message'" + assert "details" in error_body, "Missing 'details'" + assert "extra_info" in error_body, "Missing 'extra_info'" + + assert error_body["type"] == "prompt_outputs_failed_validation", "unexpected type" + + # extra_info should have exception_type and traceback + assert "exception_type" in error_body["extra_info"], "Missing 'exception_type' in extra_info" + assert "traceback" in error_body["extra_info"], "Missing 'traceback' in extra_info" + assert isinstance(error_body["extra_info"]["traceback"], list), "traceback should be a list" + + # extra_info should have node_errors with detailed validation information + assert "node_errors" in error_body["extra_info"], "Missing 'node_errors' in extra_info" + 
node_errors = error_body["extra_info"]["node_errors"] + assert isinstance(node_errors, dict), "node_errors should be a dict" + assert len(node_errors) > 0, "node_errors should contain at least one node" + + # Verify node_errors structure for node "4" (CheckpointLoaderSimple with invalid ckpt_name) + assert "4" in node_errors, "Node '4' should have validation errors" + node_4_errors = node_errors["4"] + assert "errors" in node_4_errors, "Node '4' should have 'errors' field" + assert "class_type" in node_4_errors, "Node '4' should have 'class_type' field" + assert "dependent_outputs" in node_4_errors, "Node '4' should have 'dependent_outputs' field" + + assert node_4_errors["class_type"] == "CheckpointLoaderSimple", "Node '4' class_type should be CheckpointLoaderSimple" + assert len(node_4_errors["errors"]) > 0, "Node '4' should have at least one error" + + # Verify the error details include the validation error type and message + first_error = node_4_errors["errors"][0] + assert "type" in first_error, "Error should have 'type' field" + assert "message" in first_error, "Error should have 'message' field" + assert "details" in first_error, "Error should have 'details' field" + assert first_error["type"] == "value_not_in_list", f"Expected 'value_not_in_list' error, got {first_error['type']}" + assert "fake.safetensors" in first_error["details"], "Error details should mention 'fake.safetensors'" + + +@pytest.mark.asyncio +async def test_api_success_response_contract(frontend_backend_worker_with_rabbitmq): + """Test that successful execution returns proper response structure""" + async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client: + # Create a valid prompt + prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1) + + # Queue and wait for blocking response + prompt_json = client._AsyncRemoteComfyClient__json_encoder.encode(prompt) + async with client.session.post( + f"{frontend_backend_worker_with_rabbitmq}/api/v1/prompts", + data=prompt_json, + headers={'Content-Type': 'application/json', 'Accept': 'application/json'} + ) as response: + assert response.status == 200, f"Expected 200, got {response.status}" + + result = await response.json() + + # Should have 'outputs' key (and deprecated 'urls' key) + assert "outputs" in result, "Missing 'outputs' in response" + + # outputs should be a dict with node IDs as keys + outputs = result["outputs"] + assert isinstance(outputs, dict), "outputs should be a dict" + assert len(outputs) > 0, "outputs should not be empty" + + # Each output should follow the Output schema + for node_id, output in outputs.items(): + assert isinstance(output, dict), f"Output for node {node_id} should be a dict" + # Should have images or other output types + if "images" in output: + assert isinstance(output["images"], list), f"images for node {node_id} should be a list" + for image in output["images"]: + assert "filename" in image, f"image missing 'filename' in node {node_id}" + assert "subfolder" in image, f"image missing 'subfolder' in node {node_id}" + assert "type" in image, f"image missing 'type' in node {node_id}" + + +@pytest.mark.asyncio +async def test_api_get_prompt_returns_outputs_directly(frontend_backend_worker_with_rabbitmq): + """Test GET /api/v1/prompts/{prompt_id} returns Outputs directly (not wrapped in history entry)""" + async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client: + # Create and queue a prompt + prompt = 
sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1) + task_id = await client.queue_and_forget_prompt_api(prompt) + + # Poll until done + status_code, result = await client.poll_prompt_until_done(task_id, max_attempts=60, poll_interval=1.0) + + assert status_code == 200, f"Expected 200, got {status_code}" + assert result is not None, "Result should not be None" + + # Per OpenAPI spec, GET should return Outputs directly, not wrapped + # result should be a dict with node IDs as keys + assert isinstance(result, dict), "Result should be a dict (Outputs)" + + # Should NOT have 'prompt', 'outputs', 'status' keys (those are in history entry) + # Should have node IDs directly + for key in result.keys(): + # Node IDs are typically numeric strings like "4", "13", etc. + # Should not be "prompt", "outputs", "status" + assert key not in ["prompt", "status"], \ + f"GET endpoint should return Outputs directly, not history entry. Found key: {key}" + + +@pytest.mark.asyncio +async def test_api_execution_error_blocking_mode(frontend_backend_worker_with_rabbitmq): + """Test that execution errors (not validation) return proper error structure in blocking mode""" + from comfy_execution.graph_utils import GraphBuilder + + async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client: + # Create a prompt that will fail during execution (not validation) + # Use Regex with a group name that doesn't exist - validation passes but execution fails + g = GraphBuilder() + regex_match = g.node("Regex", pattern="hello", string="hello world") + # Request a non-existent group name - this will pass validation but fail during execution + match_group = g.node("RegexMatchGroupByName", match=regex_match.out(0), name="nonexistent_group") + g.node("SaveString", value=match_group.out(0), filename_prefix="test") + + prompt = g.finalize() + prompt_json = client._AsyncRemoteComfyClient__json_encoder.encode(prompt) + + async with client.session.post( + f"{frontend_backend_worker_with_rabbitmq}/api/v1/prompts", + data=prompt_json, + headers={'Content-Type': 'application/json', 'Accept': 'application/json'} + ) as response: + # Execution errors return 500 + assert response.status == 500, f"Expected 500 for execution error, got {response.status}" + + error_body = await response.json() + + # Verify ExecutionStatus structure + assert "status_str" in error_body, "Missing 'status_str'" + assert "completed" in error_body, "Missing 'completed'" + assert "messages" in error_body, "Missing 'messages'" + + assert error_body["status_str"] == "error", f"Expected 'error', got {error_body['status_str']}" + assert error_body["completed"] == False, "completed should be False for errors" + assert isinstance(error_body["messages"], list), "messages should be a list" + assert len(error_body["messages"]) > 0, "messages should contain error details" + + +@pytest.mark.asyncio +async def test_api_execution_error_async_mode(frontend_backend_worker_with_rabbitmq): + """Test that execution errors return proper error structure in respond-async mode""" + from comfy_execution.graph_utils import GraphBuilder + + async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client: + # Create a prompt that will fail during execution (not validation) + # Use Regex with a group name that doesn't exist - validation passes but execution fails + g = GraphBuilder() + regex_match = g.node("Regex", pattern="hello", string="hello world") + # Request a non-existent group name - this will pass 
validation but fail during execution + match_group = g.node("RegexMatchGroupByName", match=regex_match.out(0), name="nonexistent_group") + g.node("SaveString", value=match_group.out(0), filename_prefix="test") + + prompt = g.finalize() + + # Queue with respond-async + task_id = await client.queue_and_forget_prompt_api(prompt, prefer_header="respond-async") + assert task_id is not None, "Should get task_id in async mode" + + # Poll for completion + status_code, result = await client.poll_prompt_until_done(task_id, max_attempts=60, poll_interval=1.0) + + # In async mode with polling, errors come back as 200 with error in the response body + # because the prompt was accepted (202) and we're just retrieving the completed result + assert status_code in (200, 500), f"Expected 200 or 500, got {status_code}" + + if status_code == 500: + # Error returned directly - should be ExecutionStatus + assert "status_str" in result, "Missing 'status_str'" + assert "completed" in result, "Missing 'completed'" + assert "messages" in result, "Missing 'messages'" + assert result["status_str"] == "error" + assert result["completed"] == False + assert len(result["messages"]) > 0 + else: + # Error in successful response - result might be ExecutionStatus or empty outputs + # If it's a dict with status info, verify it + if "status_str" in result: + assert result["status_str"] == "error" + assert result["completed"] == False + assert len(result["messages"]) > 0 diff --git a/tests/distributed/test_tracing.py b/tests/distributed/test_tracing.py new file mode 100644 index 000000000..8c156a9bf --- /dev/null +++ b/tests/distributed/test_tracing.py @@ -0,0 +1,128 @@ +import asyncio +import logging + +logging.basicConfig(level=logging.ERROR) + +import uuid + +import pytest +from testcontainers.rabbitmq import RabbitMqContainer +from opentelemetry import trace, propagate, context +from opentelemetry.trace import SpanKind +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import SimpleSpanProcessor +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + +from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner +from comfy.component_model.make_mutable import make_mutable +from comfy.component_model.queue_types import QueueItem, QueueTuple, ExecutionStatus +from comfy.distributed.server_stub import ServerStub + + +async def create_test_prompt() -> QueueItem: + from comfy.cmd.execution import validate_prompt + + prompt = make_mutable(sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)) + item_id = str(uuid.uuid4()) + + validation_tuple = await validate_prompt(item_id, prompt) + queue_tuple: QueueTuple = (0, item_id, prompt, {}, validation_tuple[2]) + return QueueItem(queue_tuple, None) + + +@pytest.mark.asyncio +async def test_rabbitmq_message_properties_contain_trace_context(): + with RabbitMqContainer("rabbitmq:latest") as rabbitmq: + params = rabbitmq.get_connection_params() + connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}" + + from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue + import aio_pika + + exporter = InMemorySpanExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + tracer = trace.get_tracer(__name__, tracer_provider=provider) + + with tracer.start_as_current_span("test_message_headers", kind=SpanKind.PRODUCER): + async with DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, 
connection_uri=connection_uri) as frontend: + async with DistributedPromptQueue(ServerStub(), is_callee=True, is_caller=False, connection_uri=connection_uri) as worker: + queue_item = await create_test_prompt() + + put_task = asyncio.create_task(frontend.put_async(queue_item)) + + incoming, incoming_prompt_id = await worker.get_async(timeout=5.0) + assert incoming is not None, "Worker should receive message" + + worker.task_done(incoming_prompt_id, {}, ExecutionStatus("success", True, [])) + + result = await put_task + assert result is not None, "Frontend should get result" + + # Now inspect the RabbitMQ queue directly to see message structure + connection = await aio_pika.connect_robust(connection_uri) + channel = await connection.channel() + + # Declare a test queue to inspect message format + test_queue = await channel.declare_queue("test_inspection_queue", durable=False, auto_delete=True) + + # Publish a test message with trace context + carrier = {} + propagate.inject(carrier) + + test_message = aio_pika.Message( + body=b"test", + headers=carrier + ) + + await channel.default_exchange.publish( + test_message, + routing_key=test_queue.name + ) + + # Get and inspect the message + received = await test_queue.get(timeout=2, fail=False) + if received: + headers = received.headers or {} + + # Document what trace headers should be present + # OpenTelemetry uses 'traceparent' header for W3C Trace Context + has_traceparent = "traceparent" in headers + + assert has_traceparent + + await received.ack() + + await connection.close() + + +@pytest.mark.asyncio +async def test_distributed_queue_uses_async_interface(): + """ + Test that demonstrates the correct way to use DistributedPromptQueue in async context. + The synchronous get() method cannot be used in async tests due to event loop assertions. + """ + with RabbitMqContainer("rabbitmq:latest") as rabbitmq: + params = rabbitmq.get_connection_params() + connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}" + + from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue + + async with DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=connection_uri) as frontend: + async with DistributedPromptQueue(ServerStub(), is_callee=True, is_caller=False, connection_uri=connection_uri) as worker: + queue_item = await create_test_prompt() + + # Start consuming in background + result_future = asyncio.create_task(frontend.put_async(queue_item)) + + # Worker gets item asynchronously (not using blocking get()) + incoming, incoming_prompt_id = await worker.get_async(timeout=5.0) + assert incoming is not None, "Should receive a queue item" + + # Complete the work + worker.task_done(incoming_prompt_id, {}, ExecutionStatus("success", True, [])) + + # Wait for frontend to complete + result = await result_future + assert result is not None, "Should get result from worker" + assert result.status.status_str == "success" diff --git a/tests/distributed/test_tracing_integration.py b/tests/distributed/test_tracing_integration.py new file mode 100644 index 000000000..36a9fabd6 --- /dev/null +++ b/tests/distributed/test_tracing_integration.py @@ -0,0 +1,569 @@ +""" +Integration tests for distributed tracing across RabbitMQ and services. + +These tests validate that trace context propagates correctly from frontend +to backend workers through RabbitMQ, and that Jaeger can reconstruct the +full distributed trace. 
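+
+The propagation these tests depend on boils down to the standard OpenTelemetry
+carrier API (shown here purely as an illustration of the mechanism; the services
+may wire it up differently internally):
+
+    from opentelemetry import propagate
+
+    carrier: dict = {}
+    propagate.inject(carrier)         # producer side: adds e.g. a 'traceparent' entry
+    ctx = propagate.extract(carrier)  # consumer side: restores the remote context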
+""" +import asyncio +import logging +import os +import tempfile +import time +import uuid + +import pytest +import requests +from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter +from opentelemetry.instrumentation.aiohttp_client import AioHttpClientInstrumentor +from opentelemetry.sdk.resources import Resource +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import BatchSpanProcessor +from opentelemetry.semconv.attributes import service_attributes +from testcontainers.core.container import DockerContainer +from testcontainers.core.waiting_utils import wait_for_logs +from testcontainers.nginx import NginxContainer + +from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +class JaegerContainer(DockerContainer): + """Testcontainer for Jaeger all-in-one with OTLP support.""" + + def __init__(self, image: str = "jaegertracing/all-in-one:latest"): + super().__init__(image) + self.with_exposed_ports(16686, 4318, 14268) # UI, OTLP HTTP, Jaeger HTTP + self.with_env("COLLECTOR_OTLP_ENABLED", "true") + + def get_query_url(self) -> str: + """Get Jaeger Query API URL.""" + host = self.get_container_host_ip() + port = self.get_exposed_port(16686) + return f"http://{host}:{port}" + + def get_otlp_endpoint(self) -> str: + """Get OTLP HTTP endpoint for sending traces.""" + host = self.get_container_host_ip() + port = self.get_exposed_port(4318) + return f"http://{host}:{port}" + + def start(self): + super().start() + wait_for_logs(self, ".*Starting GRPC server.*", timeout=30) + return self + + +@pytest.fixture(scope="function") +def nginx_proxy(frontend_backend_worker_with_rabbitmq): + """ + Provide an nginx proxy in front of the ComfyUI frontend. + This tests if nginx is blocking W3C trace context propagation. 
+ """ + import socket + import subprocess + + # Extract host and port from frontend address + frontend_url = frontend_backend_worker_with_rabbitmq + # frontend_url is like "http://127.0.0.1:19001" + import re + match = re.match(r'http://([^:]+):(\d+)', frontend_url) + if not match: + raise ValueError(f"Could not parse frontend URL: {frontend_url}") + + frontend_host = match.group(1) + frontend_port = match.group(2) + nginx_port = 8085 + + # Get the Docker bridge gateway IP (this is how containers reach the host on Linux) + # Try to get the default Docker bridge gateway + try: + result = subprocess.run( + ["docker", "network", "inspect", "bridge", "-f", "{{range .IPAM.Config}}{{.Gateway}}{{end}}"], + capture_output=True, + text=True, + check=True + ) + docker_gateway = result.stdout.strip() + logger.info(f"Using Docker gateway IP: {docker_gateway}") + except Exception as e: + # Fallback: try common gateway IPs + docker_gateway = "172.17.0.1" # Default Docker bridge gateway on Linux + logger.warning(f"Could not detect Docker gateway, using default: {docker_gateway}") + + # Create nginx config that proxies to the frontend and passes trace headers + nginx_conf = f""" +events {{ + worker_connections 1024; +}} + +http {{ + upstream backend {{ + server {docker_gateway}:{frontend_port}; + }} + + server {{ + listen {nginx_port}; + + location / {{ + proxy_pass http://backend; + proxy_set_header Host $host; + proxy_set_header X-Real-IP $remote_addr; + proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; + proxy_set_header X-Forwarded-Proto $scheme; + }} + }} +}} +""" + + # Write config to a temporary file + with tempfile.NamedTemporaryFile(mode='w', suffix='.conf', delete=False) as f: + f.write(nginx_conf) + nginx_conf_path = f.name + + try: + # Start nginx container with the config + nginx = NginxContainer(port=nginx_port) + nginx.with_volume_mapping(nginx_conf_path, "/etc/nginx/nginx.conf") + nginx.start() + + # Get the nginx URL + host = nginx.get_container_host_ip() + port = nginx.get_exposed_port(nginx_port) + nginx_url = f"http://{host}:{port}" + + logger.info(f"Nginx proxy started at {nginx_url} -> {frontend_url}") + + # Wait for nginx to be ready + for _ in range(30): + try: + response = requests.get(nginx_url, timeout=1) + if response.status_code: + break + except Exception: + pass + time.sleep(0.5) + + yield nginx_url + finally: + nginx.stop() + os.unlink(nginx_conf_path) + + +@pytest.fixture(scope="module") +def jaeger_container(): + """ + Provide a Jaeger container for collecting traces. + + This fixture automatically sets OTEL_EXPORTER_OTLP_ENDPOINT to point to the + Jaeger container, and cleans it up when the container stops. 
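+
+    OTLP exporters constructed without an explicit endpoint resolve this variable
+    themselves, so (assuming the instrumented services build their exporter that way)
+    pointing the variable at the container is sufficient:
+
+        from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
+
+        exporter = OTLPSpanExporter()  # resolves OTEL_EXPORTER_OTLP_ENDPOINT + "/v1/traces"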
+ """ + container = JaegerContainer() + container.start() + + # Wait for Jaeger to be fully ready + query_url = container.get_query_url() + otlp_endpoint = container.get_otlp_endpoint() + + for _ in range(30): + try: + response = requests.get(f"{query_url}/api/services") + if response.status_code == 200: + logger.info(f"Jaeger ready at {query_url}") + logger.info(f"OTLP endpoint: {otlp_endpoint}") + break + except Exception: + pass + time.sleep(1) + + # Set OTEL_EXPORTER_OTLP_ENDPOINT for the duration of the test + old_endpoint = os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT") + os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"] = otlp_endpoint + logger.info(f"Set OTEL_EXPORTER_OTLP_ENDPOINT={otlp_endpoint}") + + try: + yield container + finally: + # Restore original OTEL_EXPORTER_OTLP_ENDPOINT + if old_endpoint is not None: + os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"] = old_endpoint + logger.info(f"Restored OTEL_EXPORTER_OTLP_ENDPOINT={old_endpoint}") + else: + os.environ.pop("OTEL_EXPORTER_OTLP_ENDPOINT", None) + logger.info("Removed OTEL_EXPORTER_OTLP_ENDPOINT") + + container.stop() + + +def query_jaeger_traces(jaeger_url: str, service: str, operation: str = None, + lookback: str = "1h", limit: int = 100) -> dict: + """ + Query Jaeger for traces. + + Args: + jaeger_url: Base URL of Jaeger query service + service: Service name to query + operation: Optional operation name filter + lookback: Lookback period (e.g., "1h", "30m") + limit: Maximum number of traces to return + + Returns: + JSON response from Jaeger API + """ + params = { + "service": service, + "lookback": lookback, + "limit": limit + } + if operation: + params["operation"] = operation + + response = requests.get(f"{jaeger_url}/api/traces", params=params) + response.raise_for_status() + return response.json() + + +def find_trace_by_operation(traces_response: dict, operation_name: str) -> dict: + """Find a specific trace by operation name.""" + for trace in traces_response.get("data", []): + for span in trace.get("spans", []): + if span.get("operationName") == operation_name: + return trace + return None + + +def verify_trace_continuity(trace: dict, expected_services: list[str]) -> bool: + """ + Verify that a trace spans multiple services and maintains parent-child relationships. 
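+
+    The parts of a Jaeger trace object this check relies on look roughly like:
+
+        {
+            "traceID": "...",
+            "spans": [{"spanID": "...", "processID": "p1",
+                       "references": [{"refType": "CHILD_OF", "spanID": "..."}]}],
+            "processes": {"p1": {"serviceName": "comfyui"}}
+        }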
+ + Args: + trace: Jaeger trace object + expected_services: List of service names expected in the trace + + Returns: + True if trace shows proper distributed tracing across services + """ + if not trace: + return False + + spans = trace.get("spans", []) + if not spans: + return False + + # Check that all expected services are present + trace_services = set() + for span in spans: + process_id = span.get("processID") + if process_id: + process = trace.get("processes", {}).get(process_id, {}) + service_name = process.get("serviceName") + if service_name: + trace_services.add(service_name) + + logger.info(f"Trace contains services: {trace_services}") + logger.info(f"Expected services: {set(expected_services)}") + + # Verify all expected services are present + for service in expected_services: + if service not in trace_services: + logger.warning(f"Expected service '{service}' not found in trace") + return False + + # Verify all spans share the same trace ID + trace_ids = set(span.get("traceID") for span in spans) + if len(trace_ids) != 1: + logger.warning(f"Multiple trace IDs found: {trace_ids}") + return False + + # Verify parent-child relationships exist + span_ids = {span.get("spanID") for span in spans} + has_parent_refs = False + + for span in spans: + references = span.get("references", []) + for ref in references: + if ref.get("refType") == "CHILD_OF": + parent_span_id = ref.get("spanID") + if parent_span_id in span_ids: + has_parent_refs = True + logger.info(f"Found parent-child relationship: {parent_span_id} -> {span.get('spanID')}") + + if not has_parent_refs: + logger.warning("No parent-child relationships found in trace") + return False + + return True + + +# order matters, execute jaeger_container first +@pytest.mark.asyncio +async def test_tracing_integration(jaeger_container, nginx_proxy): + """ + Integration test for distributed tracing across services with nginx proxy. + + This test: + 1. Starts ComfyUI frontend and worker with RabbitMQ + 2. Starts nginx proxy in front of the frontend to test trace context propagation through nginx + 3. Configures OTLP export to Jaeger testcontainer + 4. Submits a workflow through the nginx proxy + 5. Queries Jaeger to verify trace propagation + 6. Validates that the trace spans multiple services with proper relationships + + This specifically tests if nginx is blocking W3C trace context (traceparent/tracestate headers). + """ + server_address = nginx_proxy + jaeger_url = jaeger_container.get_query_url() + otlp_endpoint = jaeger_container.get_otlp_endpoint() + + logger.info(f"Frontend server: {server_address}") + logger.info(f"Jaeger UI: {jaeger_url}") + logger.info(f"OTLP endpoint: {otlp_endpoint}") + + # Set up tracing for the async HTTP client + resource = Resource.create({ + service_attributes.SERVICE_NAME: "comfyui-client", + }) + provider = TracerProvider(resource=resource) + exporter = OTLPSpanExporter(endpoint=f"{otlp_endpoint}/v1/traces") + processor = BatchSpanProcessor(exporter) + provider.add_span_processor(processor) + from opentelemetry import trace + + trace.set_tracer_provider(provider) + + # Instrument aiohttp client + AioHttpClientInstrumentor().instrument() + + # we have to call this very late, so that the instrumentation isn't initialized too early + from comfy.client.aio_client import AsyncRemoteComfyClient + + # Note: In a real integration test, you'd need to configure the ComfyUI + # services to export traces to this Jaeger instance. For now, this test + # documents the expected behavior. 
+ + # Create a unique prompt to identify our trace + test_id = str(uuid.uuid4())[:8] + prompt = sdxl_workflow_with_refiner(f"test_trace_{test_id}", inference_steps=1, refiner_steps=1) + + # Get the tracer for the client + client_tracer = trace.get_tracer("test_tracing_integration") + + # Submit the workflow - wrap in a span to capture the trace ID + with client_tracer.start_as_current_span("submit_workflow") as workflow_span: + trace_id = format(workflow_span.get_span_context().trace_id, '032x') + logger.info(f"Started trace with trace_id: {trace_id}") + + async with AsyncRemoteComfyClient(server_address=server_address) as client: + logger.info(f"Submitting workflow with test_id: {test_id}") + + # Queue the prompt with async response + task_id = await client.queue_and_forget_prompt_api(prompt, prefer_header="respond-async") + assert task_id is not None, "Failed to get task ID" + + logger.info(f"Queued task: {task_id}") + + # Poll for completion + status_code, result = await client.poll_prompt_until_done(task_id, max_attempts=60, poll_interval=1.0) + assert status_code == 200, f"Task failed with status {status_code}" + logger.info("Task completed successfully") + + # Give Jaeger time to receive and process spans + await asyncio.sleep(5) + + # Query Jaeger for traces + # Note: The actual service names depend on how your services are configured + # Common service names might be: "slack-bot", "comfyui-frontend", "comfyui-worker" + + expected_services = ["comfyui", "comfyui-client"] # Adjust based on actual service names + + logger.info(f"Querying Jaeger for traces with trace_id: {trace_id}...") + + # First, try to find our specific trace by trace_id from the client service + our_trace = None + for service in expected_services: + try: + traces_response = query_jaeger_traces(jaeger_url, service, lookback="5m") + if traces_response.get("data"): + logger.info(f"Found {len(traces_response['data'])} traces for service '{service}'") + for trace in traces_response["data"]: + if trace.get("traceID") == trace_id: + our_trace = trace + logger.info(f"Found our trace in service '{service}'") + break + if our_trace: + break + except Exception as e: + logger.warning(f"Could not query traces for service '{service}': {e}") + + # Assert we can find the trace we just created + assert our_trace is not None, ( + f"Could not find trace with trace_id {trace_id} in Jaeger. " + f"This indicates that spans from comfyui-client are not being exported correctly." + ) + + logger.info(f"Successfully found trace with trace_id {trace_id}") + + # Extract services from the trace + trace_services = set() + for span in our_trace.get("spans", []): + process_id = span.get("processID") + if process_id: + process = our_trace.get("processes", {}).get(process_id, {}) + service_name = process.get("serviceName") + if service_name: + trace_services.add(service_name) + + logger.info(f"Services found in trace: {trace_services}") + + # Assert that comfyui-client service is present (since we instrumented it) + assert "comfyui-client" in trace_services, ( + f"Expected 'comfyui-client' service in trace, but found only: {trace_services}. " + f"This indicates the client instrumentation is not working." 
+ ) + + # Validate trace structure + logger.info(f"Analyzing trace with {len(our_trace.get('spans', []))} spans") + + # Log all spans for debugging + for span in our_trace.get("spans", []): + process_id = span.get("processID") + process = our_trace.get("processes", {}).get(process_id, {}) + service_name = process.get("serviceName", "unknown") + operation = span.get("operationName", "unknown") + logger.info(f" Span: {service_name}.{operation}") + + # Verify trace continuity - only if both services are present + assert "comfyui" in trace_services + is_continuous = verify_trace_continuity(our_trace, expected_services) + + # This assertion documents what SHOULD happen when distributed tracing works + assert is_continuous, ( + "Trace does not show proper distributed tracing. " + "Expected to see spans from multiple services with parent-child relationships. " + "This indicates that trace context is not being propagated correctly through RabbitMQ." + ) + +@pytest.mark.asyncio +async def test_trace_context_in_http_headers(frontend_backend_worker_with_rabbitmq): + """ + Test that HTTP requests include traceparent headers. + + This validates that the HTTP layer is properly instrumented for tracing. + """ + server_address = frontend_backend_worker_with_rabbitmq + + # Make a simple HTTP request and check for trace headers + # Note: We're checking the server's response headers to see if it's trace-aware + response = requests.get(f"{server_address}/system_stats") + + logger.info(f"Response headers: {dict(response.headers)}") + + # The server should be instrumented and may include trace context in responses + # or at minimum, should accept traceparent headers in requests + + # Test sending a traceparent header + test_traceparent = "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01" + response_with_trace = requests.get( + f"{server_address}/system_stats", + headers={"traceparent": test_traceparent} + ) + + # Should not error when traceparent is provided + assert response_with_trace.status_code == 200, "Server should accept traceparent header" + + logger.info("✓ Server accepts traceparent headers in HTTP requests") + + +@pytest.mark.asyncio +async def test_multiple_requests_different_traces(frontend_backend_worker_with_rabbitmq, jaeger_container): + """ + Test that multiple independent requests create separate traces. + + This validates that trace context is properly scoped per request. 
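+
+    Separation is asserted through the Jaeger HTTP query API, roughly:
+
+        GET /api/traces?service=comfyui&lookback=5m&limit=10
+
+    where each returned trace carries its own 32-hex-character traceID.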
+ """ + server_address = frontend_backend_worker_with_rabbitmq + + # Submit multiple workflows + task_ids = [] + + from comfy.client.aio_client import AsyncRemoteComfyClient + async with AsyncRemoteComfyClient(server_address=server_address) as client: + for i in range(3): + prompt = sdxl_workflow_with_refiner(f"test_{i}", inference_steps=1, refiner_steps=1) + task_id = await client.queue_and_forget_prompt_api(prompt, prefer_header="respond-async") + task_ids.append(task_id) + logger.info(f"Queued task {i}: {task_id}") + + # Wait for all to complete + for i, task_id in enumerate(task_ids): + status_code, result = await client.poll_prompt_until_done(task_id, max_attempts=60, poll_interval=1.0) + assert status_code == 200, f"Task {i} failed" + logger.info(f"Task {i} completed") + + # Give Jaeger time to receive spans + await asyncio.sleep(5) + + # Query Jaeger and verify we have multiple distinct traces + jaeger_url = jaeger_container.get_query_url() + + traces_response = query_jaeger_traces(jaeger_url, "comfyui", lookback="5m", limit=10) + traces = traces_response.get("data", []) + + assert len(traces) >= 2 + # Get trace IDs + trace_ids = [trace.get("traceID") for trace in traces] + unique_trace_ids = set(trace_ids) + + logger.info(f"Found {len(unique_trace_ids)} unique traces") + + # Verify we have multiple distinct traces + assert len(unique_trace_ids) >= 2, ( + f"Expected at least 2 distinct traces, found {len(unique_trace_ids)}. " + "Each request should create its own trace." + ) + + logger.info("✓ Multiple requests created distinct traces") + + +@pytest.mark.asyncio +@pytest.mark.skip(reason="rabbitmq has to be configured for observability?") +async def test_trace_contains_rabbitmq_operations(frontend_backend_worker_with_rabbitmq, jaeger_container): + """ + Test that traces include RabbitMQ publish/consume operations. + + This is critical for distributed tracing - the RabbitMQ operations + are what link the frontend and backend spans together. 
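+
+    That linkage relies on the trace context travelling in the AMQP message
+    headers, e.g. (exercised directly in test_tracing.py):
+
+        carrier = {}
+        propagate.inject(carrier)                        # writes 'traceparent' into the dict
+        message = aio_pika.Message(body=b"...", headers=carrier)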
+ """ + server_address = frontend_backend_worker_with_rabbitmq + jaeger_url = jaeger_container.get_query_url() + + # Submit a workflow + from comfy.client.aio_client import AsyncRemoteComfyClient + async with AsyncRemoteComfyClient(server_address=server_address) as client: + prompt = sdxl_workflow_with_refiner("test_rmq", inference_steps=1, refiner_steps=1) + task_id = await client.queue_and_forget_prompt_api(prompt) + status_code, result = await client.poll_prompt_until_done(task_id, max_attempts=60) + assert status_code == 200 + + await asyncio.sleep(5) + + traces_response = query_jaeger_traces(jaeger_url, "comfyui", lookback="5m") + traces = traces_response.get("data", []) + + # Look for RabbitMQ-related operations in any trace + rabbitmq_operations = [ + "publish", "consume", "amq_queue_publish", "amq_queue_consume", + "amq.basic.publish", "amq.basic.consume", "send", "receive" + ] + + found_rabbitmq_ops = [] + for trace in traces: + for span in trace.get("spans", []): + op_name = span.get("operationName", "").lower() + for rmq_op in rabbitmq_operations: + if rmq_op in op_name: + found_rabbitmq_ops.append(op_name) + + assert found_rabbitmq_ops, "No RabbitMQ-related operations found in traces" \ No newline at end of file diff --git a/tests/unit/test_eval_nodes.py b/tests/unit/test_eval_nodes.py new file mode 100644 index 000000000..71076daef --- /dev/null +++ b/tests/unit/test_eval_nodes.py @@ -0,0 +1,683 @@ +import pytest +from unittest.mock import Mock, patch + +from comfy.cli_args import default_configuration +from comfy.execution_context import context_configuration +from comfy_extras.nodes.nodes_eval import ( + eval_python, + EvalPython_5_5, + EvalPython_List_1, + EvalPython_1_List, + EvalPython_List_List, +) + + +@pytest.fixture +def eval_context(): + """Fixture that sets up execution context with eval enabled""" + config = default_configuration() + config.enable_eval = True + with context_configuration(config): + yield + + +def test_eval_python_basic_return(eval_context): + """Test basic return statement with single value""" + node = EvalPython_5_5() + result = node.exec_py(pycode="return 42", value0=0, value1=1, value2=2, value3=3, value4=4) + assert result == (42, None, None, None, None) + + +def test_eval_python_multiple_returns(eval_context): + """Test return statement with tuple of values""" + node = EvalPython_5_5() + result = node.exec_py( + pycode="return 1, 2, 3", + value0=0, value1=0, value2=0, value3=0, value4=0 + ) + assert result == (1, 2, 3, None, None) + + +def test_eval_python_all_five_returns(eval_context): + """Test return statement with all five values""" + node = EvalPython_5_5() + result = node.exec_py( + pycode="return 'a', 'b', 'c', 'd', 'e'", + value0=0, value1=0, value2=0, value3=0, value4=0 + ) + assert result == ('a', 'b', 'c', 'd', 'e') + + +def test_eval_python_excess_returns(eval_context): + """Test that excess return values are truncated to 5""" + node = EvalPython_5_5() + result = node.exec_py( + pycode="return 1, 2, 3, 4, 5, 6, 7", + value0=0, value1=0, value2=0, value3=0, value4=0 + ) + assert result == (1, 2, 3, 4, 5) + + +def test_eval_python_use_value_args(eval_context): + """Test that value arguments are accessible in pycode""" + node = EvalPython_5_5() + result = node.exec_py( + pycode="return value0 + value1 + value2", + value0=10, value1=20, value2=30, value3=0, value4=0 + ) + assert result == (60, None, None, None, None) + + +def test_eval_python_all_value_args(eval_context): + """Test all value arguments are accessible""" + node = 
EvalPython_5_5() + result = node.exec_py( + pycode="return value0, value1, value2, value3, value4", + value0=1, value1=2, value2=3, value3=4, value4=5 + ) + assert result == (1, 2, 3, 4, 5) + + +def test_eval_python_computation(eval_context): + """Test computation with value arguments""" + node = EvalPython_5_5() + code = """ +x = value0 * 2 +y = value1 * 3 +z = x + y +return z +""" + result = node.exec_py( + pycode=code, + value0=5, value1=10, value2=0, value3=0, value4=0 + ) + assert result == (40, None, None, None, None) + + +def test_eval_python_multiline(eval_context): + """Test multiline code with conditionals""" + node = EvalPython_5_5() + code = """ +if value0 > 10: + result = "large" +else: + result = "small" +return result, value0 +""" + result = node.exec_py( + pycode=code, + value0=15, value1=0, value2=0, value3=0, value4=0 + ) + assert result == ("large", 15, None, None, None) + + +def test_eval_python_list_comprehension(eval_context): + """Test list comprehension and iteration""" + node = EvalPython_5_5() + code = """ +numbers = [value0, value1, value2] +doubled = [x * 2 for x in numbers] +return sum(doubled) +""" + result = node.exec_py( + pycode=code, + value0=1, value1=2, value2=3, value3=0, value4=0 + ) + assert result == (12, None, None, None, None) + + +def test_eval_python_string_operations(eval_context): + """Test string operations""" + node = EvalPython_5_5() + code = """ +s1 = str(value0) +s2 = str(value1) +return s1 + s2, len(s1 + s2) +""" + result = node.exec_py( + pycode=code, + value0=123, value1=456, value2=0, value3=0, value4=0 + ) + assert result == ("123456", 6, None, None, None) + + +def test_eval_python_type_mixing(eval_context): + """Test mixing different types""" + node = EvalPython_5_5() + code = """ +return value0, str(value1), float(value2), bool(value3) +""" + result = node.exec_py( + pycode=code, + value0=42, value1=100, value2=3, value3=1, value4=0 + ) + assert result == (42, "100", 3.0, True, None) + + +def test_eval_python_logger_available(eval_context): + """Test that logger is available in eval context""" + node = EvalPython_5_5() + code = """ +logger.info("test log") +return "success" +""" + result = node.exec_py( + pycode=code, + value0=0, value1=0, value2=0, value3=0, value4=0 + ) + assert result == ("success", None, None, None, None) + + +def test_eval_python_print_available(eval_context): + """Test that print function is available""" + node = EvalPython_5_5() + code = """ +print("Hello World!") +return "printed" +""" + result = node.exec_py( + pycode=code, + value0=0, value1=0, value2=0, value3=0, value4=0 + ) + assert result == ("printed", None, None, None, None) + +def test_eval_python_print_is_called(eval_context): + """Test that print function is called and receives correct arguments""" + node = EvalPython_5_5() + + # Track print calls + print_calls = [] + + code = """ +print("Hello", "World") +print("Line 2") +return "done" +""" + + # Mock exec to capture the globals dict and verify print is there + original_exec = exec + captured_globals = {} + + def mock_exec(code_str, globals_dict, *args, **kwargs): + # Capture the globals dict + captured_globals.update(globals_dict) + + # Wrap the print function to track calls + original_print = globals_dict.get('print') + if original_print: + def tracked_print(*args): + print_calls.append(args) + return original_print(*args) + globals_dict['print'] = tracked_print + + # Run the original exec + return original_exec(code_str, globals_dict, *args, **kwargs) + + with patch('builtins.exec', 
side_effect=mock_exec): + result = node.exec_py( + pycode=code, + value0=0, value1=0, value2=0, value3=0, value4=0 + ) + + # Verify the result + assert result == ("done", None, None, None, None) + + # Verify print was in the globals + assert 'print' in captured_globals + + # Verify print was called twice with correct arguments + assert len(print_calls) == 2 + assert print_calls[0] == ("Hello", "World") + assert print_calls[1] == ("Line 2",) + + +def test_eval_python_print_sends_to_server(eval_context): + """Test that print sends messages to PromptServer via context""" + from comfy.execution_context import current_execution_context + + node = EvalPython_5_5() + ctx = current_execution_context() + + # Mock the server's send_progress_text method + original_send = ctx.server.send_progress_text if hasattr(ctx.server, 'send_progress_text') else None + mock_send = Mock() + ctx.server.send_progress_text = mock_send + + code = """ +print("Hello", "World") +print("Value:", value0) +return "done" +""" + + try: + result = node.exec_py( + pycode=code, + value0=42, value1=0, value2=0, value3=0, value4=0 + ) + + # Verify the result + assert result == ("done", None, None, None, None) + + # Verify print messages were sent to server + assert mock_send.call_count == 2 + + # Verify the messages sent + calls = mock_send.call_args_list + assert calls[0][0][0] == "Hello World" + assert calls[0][0][1] == ctx.node_id + assert calls[1][0][0] == "Value: 42" + assert calls[1][0][1] == ctx.node_id + finally: + # Restore original + if original_send: + ctx.server.send_progress_text = original_send + + +def test_eval_python_config_disabled_raises(): + """Test that enable_eval=False raises an error""" + node = EvalPython_5_5() + config = default_configuration() + config.enable_eval = False + with context_configuration(config): + with pytest.raises(ValueError, match="Python eval is disabled"): + node.exec_py( + pycode="return 42", + value0=0, value1=0, value2=0, value3=0, value4=0 + ) + + +def test_eval_python_config_enabled_works(eval_context): + """Test that enable_eval=True allows execution""" + node = EvalPython_5_5() + result = node.exec_py( + pycode="return 42", + value0=0, value1=0, value2=0, value3=0, value4=0 + ) + assert result == (42, None, None, None, None) + + +def test_eval_python_default_code(eval_context): + """Test the default code example works""" + node = EvalPython_5_5() + # Get the default code from INPUT_TYPES + default_code = EvalPython_5_5.INPUT_TYPES()["required"]["pycode"][1]["default"] + + result = node.exec_py( + pycode=default_code, + value0=1, value1=2, value2=3, value3=4, value4=5 + ) + # Default code prints and returns the values + assert result == (1, 2, 3, 4, 5) + + +def test_eval_python_function_definition(eval_context): + """Test defining and using functions""" + node = EvalPython_5_5() + code = """ +def multiply(a, b): + return a * b + +result = multiply(value0, value1) +return result +""" + result = node.exec_py( + pycode=code, + value0=7, value1=8, value2=0, value3=0, value4=0 + ) + assert result == (56, None, None, None, None) + + +def test_eval_python_nested_functions(eval_context): + """Test nested function definitions""" + node = EvalPython_5_5() + code = """ +def outer(x): + def inner(y): + return y * 2 + return inner(x) + 10 + +result = outer(value0) +return result +""" + result = node.exec_py( + pycode=code, + value0=5, value1=0, value2=0, value3=0, value4=0 + ) + assert result == (20, None, None, None, None) + + +def test_eval_python_dict_operations(eval_context): + """Test 
dictionary creation and operations""" + node = EvalPython_5_5() + code = """ +data = { + 'a': value0, + 'b': value1, + 'c': value2 +} +return sum(data.values()), len(data) +""" + result = node.exec_py( + pycode=code, + value0=10, value1=20, value2=30, value3=0, value4=0 + ) + assert result == (60, 3, None, None, None) + + +def test_eval_python_list_operations(eval_context): + """Test list creation and operations""" + node = EvalPython_5_5() + code = """ +items = [value0, value1, value2, value3, value4] +filtered = [x for x in items if x > 5] +return len(filtered), sum(filtered) +""" + result = node.exec_py( + pycode=code, + value0=1, value1=10, value2=3, value3=15, value4=2 + ) + assert result == (2, 25, None, None, None) + + +def test_eval_python_early_return(eval_context): + """Test early return in conditional""" + node = EvalPython_5_5() + code = """ +if value0 > 100: + return "large" +return "small" +""" + result = node.exec_py( + pycode=code, + value0=150, value1=0, value2=0, value3=0, value4=0 + ) + assert result == ("large", None, None, None, None) + + +def test_eval_python_loop_with_return(eval_context): + """Test loop with return statement""" + node = EvalPython_5_5() + code = """ +total = 0 +for i in range(value0): + total += i +return total +""" + result = node.exec_py( + pycode=code, + value0=10, value1=0, value2=0, value3=0, value4=0 + ) + assert result == (45, None, None, None, None) + + +def test_eval_python_exception_handling(eval_context): + """Test try/except blocks""" + node = EvalPython_5_5() + code = """ +try: + result = value0 / value1 +except ZeroDivisionError: + result = float('inf') +return result +""" + result = node.exec_py( + pycode=code, + value0=10, value1=0, value2=0, value3=0, value4=0 + ) + assert result == (float('inf'), None, None, None, None) + + +def test_eval_python_none_values(eval_context): + """Test handling None values in inputs""" + node = EvalPython_5_5() + code = """ +return value0, value1 is None, value2 is None +""" + result = node.exec_py( + pycode=code, + value0=42, value1=None, value2=None, value3=0, value4=0 + ) + assert result == (42, True, True, None, None) + + +def test_eval_python_input_types(): + """Test that INPUT_TYPES returns correct structure""" + input_types = EvalPython_5_5.INPUT_TYPES() + assert "required" in input_types + assert "optional" in input_types + assert "pycode" in input_types["required"] + assert input_types["required"]["pycode"][0] == "CODE_BLOCK_PYTHON" + + # Check optional inputs + for i in range(5): + assert f"value{i}" in input_types["optional"] + + +def test_eval_python_metadata(): + """Test node metadata""" + assert EvalPython_5_5.FUNCTION == "exec_py" + assert EvalPython_5_5.CATEGORY == "eval" + assert len(EvalPython_5_5.RETURN_TYPES) == 5 + assert len(EvalPython_5_5.RETURN_NAMES) == 5 + assert all(name.startswith("item") for name in EvalPython_5_5.RETURN_NAMES) + + +def test_eval_python_factory_custom_inputs_outputs(eval_context): + """Test creating nodes with custom input/output counts""" + # Create a node with 3 inputs and 2 outputs + CustomNode = eval_python(inputs=3, outputs=2) + + node = CustomNode() + + # Verify INPUT_TYPES has correct number of inputs + input_types = CustomNode.INPUT_TYPES() + assert len(input_types["optional"]) == 3 + assert "value0" in input_types["optional"] + assert "value1" in input_types["optional"] + assert "value2" in input_types["optional"] + assert "value3" not in input_types["optional"] + + # Verify RETURN_TYPES has correct number of outputs + assert 
len(CustomNode.RETURN_TYPES) == 2 + assert len(CustomNode.RETURN_NAMES) == 2 + + # Test execution + result = node.exec_py( + pycode="return value0 + value1 + value2, value0 * 2", + value0=1, value1=2, value2=3 + ) + assert result == (6, 2) + + +def test_eval_python_factory_custom_name(eval_context): + """Test creating nodes with custom names""" + CustomNode = eval_python(inputs=2, outputs=2, name="MyCustomEval") + + assert CustomNode.__name__ == "MyCustomEval" + assert CustomNode.__qualname__ == "MyCustomEval" + + +def test_eval_python_factory_default_name(eval_context): + """Test that default name follows pattern""" + CustomNode = eval_python(inputs=3, outputs=4) + + assert CustomNode.__name__ == "EvalPython_3_4" + assert CustomNode.__qualname__ == "EvalPython_3_4" + + +def test_eval_python_factory_single_output(eval_context): + """Test node with single output""" + SingleOutputNode = eval_python(inputs=2, outputs=1) + + node = SingleOutputNode() + result = node.exec_py( + pycode="return value0 + value1", + value0=10, value1=20 + ) + assert result == (30,) + + +def test_eval_python_factory_many_outputs(eval_context): + """Test node with many outputs""" + ManyOutputNode = eval_python(inputs=1, outputs=10) + + node = ManyOutputNode() + result = node.exec_py( + pycode="return tuple(range(10))", + value0=0 + ) + assert result == (0, 1, 2, 3, 4, 5, 6, 7, 8, 9) + + +def test_eval_python_factory_fewer_returns_than_outputs(eval_context): + """Test that fewer returns are padded with None""" + Node = eval_python(inputs=2, outputs=5) + + node = Node() + result = node.exec_py( + pycode="return value0, value1", + value0=1, value1=2 + ) + assert result == (1, 2, None, None, None) + + +def test_eval_python_factory_more_returns_than_outputs(eval_context): + """Test that excess returns are truncated""" + Node = eval_python(inputs=2, outputs=3) + + node = Node() + result = node.exec_py( + pycode="return 1, 2, 3, 4, 5", + value0=0, value1=0 + ) + assert result == (1, 2, 3) + + +def test_eval_python_list_1_input_is_list(eval_context): + """Test EvalPython_List_1 with list input""" + node = EvalPython_List_1() + + # Verify INPUT_IS_LIST is set + assert EvalPython_List_1.INPUT_IS_LIST is True + + # Test that value0 receives a list + result = node.exec_py( + pycode="return sum(value0)", + value0=[1, 2, 3, 4, 5] + ) + assert result == (15,) + + +def test_eval_python_list_1_iterate_list(eval_context): + """Test EvalPython_List_1 iterating over list input""" + node = EvalPython_List_1() + + result = node.exec_py( + pycode="return [x * 2 for x in value0]", + value0=[1, 2, 3] + ) + assert result == ([2, 4, 6],) + + +def test_eval_python_1_list_output_is_list(eval_context): + """Test EvalPython_1_List with list output""" + node = EvalPython_1_List() + + # Verify OUTPUT_IS_LIST is set + assert EvalPython_1_List.OUTPUT_IS_LIST == (True,) + + # Test that returns a list + result = node.exec_py( + pycode="return list(range(value0))", + value0=5 + ) + assert result == ([0, 1, 2, 3, 4],) + + +def test_eval_python_1_list_multiple_items(eval_context): + """Test EvalPython_1_List returning multiple items in list""" + node = EvalPython_1_List() + + result = node.exec_py( + pycode="return ['a', 'b', 'c']", + value0=0 + ) + assert result == (['a', 'b', 'c'],) + + +def test_eval_python_list_list_both(eval_context): + """Test EvalPython_List_List with both list input and output""" + node = EvalPython_List_List() + + # Verify both are set + assert EvalPython_List_List.INPUT_IS_LIST is True + assert EvalPython_List_List.OUTPUT_IS_LIST 
== (True,) + + # Test processing list input and returning list output + result = node.exec_py( + pycode="return [x ** 2 for x in value0]", + value0=[1, 2, 3, 4] + ) + assert result == ([1, 4, 9, 16],) + + +def test_eval_python_list_list_filter(eval_context): + """Test EvalPython_List_List filtering a list""" + node = EvalPython_List_List() + + result = node.exec_py( + pycode="return [x for x in value0 if x > 5]", + value0=[1, 3, 5, 7, 9, 11] + ) + assert result == ([7, 9, 11],) + + +def test_eval_python_list_list_transform(eval_context): + """Test EvalPython_List_List transforming list elements""" + node = EvalPython_List_List() + + result = node.exec_py( + pycode="return [str(x).upper() for x in value0]", + value0=['hello', 'world', 'python'] + ) + assert result == (['HELLO', 'WORLD', 'PYTHON'],) + + +def test_eval_python_factory_with_list_flags(eval_context): + """Test factory function with custom list flags""" + # Create node with input as list but output scalar + ListInputNode = eval_python(inputs=1, outputs=1, input_is_list=True, output_is_list=None) + + assert ListInputNode.INPUT_IS_LIST is True + + node = ListInputNode() + result = node.exec_py( + pycode="return len(value0)", + value0=[1, 2, 3, 4, 5] + ) + assert result == (5,) + + +def test_eval_python_factory_scalar_output_list(eval_context): + """Test factory function with scalar input and list output""" + ScalarToListNode = eval_python(inputs=1, outputs=1, input_is_list=None, output_is_list=(True,)) + + assert ScalarToListNode.OUTPUT_IS_LIST == (True,) + + node = ScalarToListNode() + result = node.exec_py( + pycode="return [value0] * 3", + value0='x' + ) + assert result == (['x', 'x', 'x'],) + + +def test_eval_python_list_empty_list(eval_context): + """Test list nodes with empty lists""" + node = EvalPython_List_List() + + result = node.exec_py( + pycode="return []", + value0=[] + ) + assert result == ([],)
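
A note on the eval-node tests above: they pin down a contract rather than an implementation. `pycode` may use bare `return` statements, it sees `value0..valueN` plus `logger` and `print`, execution is gated behind `--enable-eval`, and whatever the code returns is normalized to the node's output arity. A minimal sketch of that normalization step, assuming the node ends up with a plain value or tuple (`_normalize_outputs` is a hypothetical helper, not the actual nodes_eval code):

    def _normalize_outputs(result, outputs: int) -> tuple:
        # Hypothetical helper: wrap single values, truncate excess items and pad
        # with None, matching the behaviour asserted by the tests above.
        if not isinstance(result, tuple):
            result = (result,)
        result = result[:outputs]
        return result + (None,) * (outputs - len(result))

With outputs=5, `return 1, 2, 3` yields (1, 2, 3, None, None) and `return 1, 2, 3, 4, 5, 6, 7` yields (1, 2, 3, 4, 5), which is exactly what test_eval_python_multiple_returns and test_eval_python_excess_returns assert.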