Merge branch 'develop'

doctorpangloss 2025-11-10 11:51:58 -08:00
commit 4f6615e939
26 changed files with 2361 additions and 350 deletions

View File

@ -34,7 +34,9 @@ jobs:
tags: | tags: |
type=raw,value=latest-rocm,enable={{is_default_branch}} type=raw,value=latest-rocm,enable={{is_default_branch}}
type=sha,prefix=,suffix=-rocm type=sha,prefix=,suffix=-rocm
type=semver,pattern={{version}},suffix=-rocm type=match,pattern=v?(\d+\.\d+\.\d+\.\d+),group=1,suffix=-rocm
type=match,pattern=v?(\d+\.\d+\.\d+),group=1,suffix=-rocm
type=match,pattern=v?(\d+\.\d+),group=1,suffix=-rocm
- name: Build and push ROCm (AMD) image - name: Build and push ROCm (AMD) image
uses: docker/build-push-action@v6 uses: docker/build-push-action@v6
with: with:

View File

@ -41,8 +41,12 @@ jobs:
type=raw,value=latest-cuda,enable={{is_default_branch}} type=raw,value=latest-cuda,enable={{is_default_branch}}
type=sha,prefix= type=sha,prefix=
type=sha,prefix=,suffix=-cuda type=sha,prefix=,suffix=-cuda
type=semver,pattern={{version}} type=match,pattern=v?(\d+\.\d+\.\d+\.\d+),group=1
type=semver,pattern={{version}},suffix=-cuda type=match,pattern=v?(\d+\.\d+\.\d+\.\d+),group=1,suffix=-cuda
type=match,pattern=v?(\d+\.\d+\.\d+),group=1
type=match,pattern=v?(\d+\.\d+\.\d+),group=1,suffix=-cuda
type=match,pattern=v?(\d+\.\d+),group=1
type=match,pattern=v?(\d+\.\d+),group=1,suffix=-cuda
- name: Build and push CUDA (NVIDIA) image - name: Build and push CUDA (NVIDIA) image
uses: docker/build-push-action@v6 uses: docker/build-push-action@v6
with: with:
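For reference, a rough sketch of what the `type=match` rules above produce: docker/metadata-action evaluates each pattern against the pushed git ref and keeps capture group 1, optionally appending the suffix. Unlike `type=semver`, this also covers four-part tags, which is presumably the motivation for the switch. The tag value below is hypothetical:

```python
import re

# Illustrative only: mirrors the group=1 extraction of the type=match rules above.
patterns = [
    r"v?(\d+\.\d+\.\d+\.\d+)",  # four-part releases, e.g. 1.2.3.4
    r"v?(\d+\.\d+\.\d+)",       # three-part releases
    r"v?(\d+\.\d+)",            # major.minor shorthand
]

git_tag = "v1.2.3.4"  # hypothetical release tag
for pattern in patterns:
    match = re.match(pattern, git_tag)
    if match:
        version = match.group(1)
        print(version, f"{version}-cuda")  # e.g. "1.2.3.4" and "1.2.3.4-cuda"
```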

View File

@ -33,7 +33,7 @@ RUN pip install uv && uv --version && \
# install sageattention # install sageattention
ADD pkg/sageattention-2.2.0-cp312-cp312-linux_x86_64.whl /workspace/pkg/sageattention-2.2.0-cp312-cp312-linux_x86_64.whl ADD pkg/sageattention-2.2.0-cp312-cp312-linux_x86_64.whl /workspace/pkg/sageattention-2.2.0-cp312-cp312-linux_x86_64.whl
RUN uv pip install -U --no-deps --no-build-isolation spandrel timm tensorboard poetry flash-attn "xformers==0.0.31.post1" "file:./pkg/sageattention-2.2.0-cp312-cp312-linux_x86_64.whl" RUN uv pip install -U --no-deps --no-build-isolation spandrel timm tensorboard poetry "flash-attn<=2.8.0" "xformers==0.0.31.post1" "file:./pkg/sageattention-2.2.0-cp312-cp312-linux_x86_64.whl"
# this exotic command will determine the correct torchaudio to install for the image # this exotic command will determine the correct torchaudio to install for the image
RUN <<-EOF RUN <<-EOF
python -c 'import torch, re, subprocess python -c 'import torch, re, subprocess
@ -66,7 +66,7 @@ WORKDIR /workspace
# addresses https://github.com/pytorch/pytorch/issues/104801 # addresses https://github.com/pytorch/pytorch/issues/104801
# and issues reported by importing nodes_canny # and issues reported by importing nodes_canny
# smoke test # smoke test
RUN python -c "import torch; import xformers; import sageattention; import cv2" && comfyui --quick-test-for-ci --cpu --cwd /workspace RUN python -c "import torch; import xformers; import sageattention; import cv2; import diffusers.hooks" && comfyui --quick-test-for-ci --cpu --cwd /workspace
EXPOSE 8188 EXPOSE 8188
CMD ["python", "-m", "comfy.cmd.main", "--listen", "--use-sage-attention", "--reserve-vram=0", "--logging-level=INFO", "--enable-cors"] CMD ["python", "-m", "comfy.cmd.main", "--listen", "--use-sage-attention", "--reserve-vram=0", "--logging-level=INFO", "--enable-cors"]

View File

@ -354,19 +354,29 @@ paths:
required: true required: true
description: | description: |
The ID of the prompt to query. The ID of the prompt to query.
responses: responses:
204: 204:
description: | description: |
The prompt is still in progress The prompt is still in progress
200: 200:
description: | description: |
Prompt outputs Prompt outputs
content: content:
application/json: application/json:
$ref: "#/components/schemas/Outputs" schema:
404: $ref: "#/components/schemas/Outputs"
description: | 404:
The prompt was not found description: |
The prompt was not found
500:
description: |
An execution error occurred while processing the prompt.
content:
application/json:
description:
An execution status directly from the workers
schema:
$ref: "#/components/schemas/ExecutionStatusAsDict"
/api/v1/prompts: /api/v1/prompts:
get: get:
operationId: list_prompts operationId: list_prompts
@ -861,6 +871,22 @@ components:
type: array type: array
items: items:
type: string type: string
node_errors:
type: object
description: "Detailed validation errors per node"
additionalProperties:
type: object
properties:
errors:
type: array
items:
$ref: "#/components/schemas/ValidationErrorDict"
dependent_outputs:
type: array
items:
type: string
class_type:
type: string
required: required:
- type - type
- details - details
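Going only by the schema above, an instance of the new `node_errors` property might look like the following; the node id, class name, and error text are invented for illustration:

```python
# Hypothetical payload shaped like the node_errors schema above (values invented).
node_errors = {
    "4": {
        "errors": [
            {
                "type": "value_not_in_list",
                "message": "Value not in list",
                "details": "ckpt_name: 'nonexistent_checkpoint.safetensors' is not an available checkpoint",
                "extra_info": {},
            }
        ],
        "dependent_outputs": ["13"],
        "class_type": "CheckpointLoaderSimple",
    }
}
```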

View File

@ -156,6 +156,7 @@ def _create_parser() -> EnhancedConfigArgParser:
parser.add_argument("--whitelist-custom-nodes", type=str, action=FlattenAndAppendAction, nargs='+', default=[], help="Specify custom node folders to load even when --disable-all-custom-nodes is enabled.") parser.add_argument("--whitelist-custom-nodes", type=str, action=FlattenAndAppendAction, nargs='+', default=[], help="Specify custom node folders to load even when --disable-all-custom-nodes is enabled.")
parser.add_argument("--blacklist-custom-nodes", type=str, action=FlattenAndAppendAction, nargs='+', default=[], help="Specify custom node folders to never load. Accepts shell-style globs.") parser.add_argument("--blacklist-custom-nodes", type=str, action=FlattenAndAppendAction, nargs='+', default=[], help="Specify custom node folders to never load. Accepts shell-style globs.")
parser.add_argument("--disable-api-nodes", action="store_true", help="Disable loading all api nodes.") parser.add_argument("--disable-api-nodes", action="store_true", help="Disable loading all api nodes.")
parser.add_argument("--enable-eval", action="store_true", help="Enable nodes that can evaluate Python code in workflows.")
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.") parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
parser.add_argument("--create-directories", action="store_true", parser.add_argument("--create-directories", action="store_true",

View File

@ -169,6 +169,7 @@ class Configuration(dict):
whitelist_custom_nodes (list[str]): Specify custom node folders to load even when --disable-all-custom-nodes is enabled. whitelist_custom_nodes (list[str]): Specify custom node folders to load even when --disable-all-custom-nodes is enabled.
default_device (Optional[int]): Set the id of the default device, all other devices will stay visible. default_device (Optional[int]): Set the id of the default device, all other devices will stay visible.
block_runtime_package_installation (Optional[bool]): When set, custom nodes like ComfyUI Manager, Easy Use, Nunchaku and others will not be able to use pip or uv to install packages at runtime (experimental). block_runtime_package_installation (Optional[bool]): When set, custom nodes like ComfyUI Manager, Easy Use, Nunchaku and others will not be able to use pip or uv to install packages at runtime (experimental).
enable_eval (Optional[bool]): Enable nodes that can evaluate Python code in workflows.
""" """
def __init__(self, **kwargs): def __init__(self, **kwargs):
@ -288,6 +289,7 @@ class Configuration(dict):
self.database_url: str = db_config() self.database_url: str = db_config()
self.default_device: Optional[int] = None self.default_device: Optional[int] = None
self.block_runtime_package_installation = None self.block_runtime_package_installation = None
self.enable_eval: Optional[bool] = False
for key, value in kwargs.items(): for key, value in kwargs.items():
self[key] = value self[key] = value
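Because `Configuration.__init__` copies keyword arguments straight into the underlying dict, the new flag can also be set programmatically rather than via `--enable-eval`; a minimal sketch for embedded use:

```python
from comfy.cli_args_types import Configuration

# Sketch only: enable the EvalPython nodes when building a configuration in code
# (equivalent to passing --enable-eval on the command line).
config = Configuration(enable_eval=True)
print(config["enable_eval"])  # Configuration subclasses dict, so item access is safe
```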
@ -420,6 +422,7 @@ class FlattenAndAppendAction(argparse.Action):
Custom action to handle comma-separated values and multiple invocations Custom action to handle comma-separated values and multiple invocations
of the same argument, flattening them into a single list. of the same argument, flattening them into a single list.
""" """
def __call__(self, parser, namespace, values, option_string=None): def __call__(self, parser, namespace, values, option_string=None):
items = getattr(namespace, self.dest, None) items = getattr(namespace, self.dest, None)
if items is None: if items is None:

View File

@ -1,23 +1,23 @@
from asyncio import AbstractEventLoop
from collections import defaultdict
import aiohttp
import asyncio import asyncio
import uuid import uuid
from aiohttp import WSMessage, ClientResponse, ClientTimeout from asyncio import AbstractEventLoop
from pathlib import Path
from typing import Optional, List from typing import Optional, List
from urllib.parse import urlparse, urljoin from urllib.parse import urlparse, urljoin
import aiohttp
from aiohttp import WSMessage, ClientResponse, ClientTimeout
from opentelemetry import trace
from .client_types import V1QueuePromptResponse from .client_types import V1QueuePromptResponse
from ..api.api_client import JSONEncoder from ..api.api_client import JSONEncoder
from ..api.components.schema.prompt import PromptDict from ..api.components.schema.prompt import PromptDict
from ..api.components.schema.prompt_request import PromptRequest from ..api.components.schema.prompt_request import PromptRequest
from ..api.paths.history.get.responses.response_200.content.application_json.schema import Schema as GetHistoryDict from ..api.paths.history.get.responses.response_200.content.application_json.schema import Schema as GetHistoryDict
from ..api.schemas import immutabledict from ..api.schemas import immutabledict
from ..component_model.file_output_path import file_output_path
from ..component_model.outputs_types import OutputsDict from ..component_model.outputs_types import OutputsDict
tracer = trace.get_tracer(__name__)
class AsyncRemoteComfyClient: class AsyncRemoteComfyClient:
""" """
@ -57,6 +57,27 @@ class AsyncRemoteComfyClient:
def session(self) -> aiohttp.ClientSession: def session(self) -> aiohttp.ClientSession:
return self._ensure_session() return self._ensure_session()
def _build_headers(self, accept_header: str, prefer_header: Optional[str] = None, content_type: str = "application/json") -> dict:
"""Build HTTP headers for requests."""
headers = {'Content-Type': content_type, 'Accept': accept_header}
if prefer_header:
headers['Prefer'] = prefer_header
return headers
@tracer.start_as_current_span("Post Prompt")
async def _post_prompt(self, prompt: PromptDict, endpoint: str, accept_header: str, prefer_header: Optional[str] = None) -> ClientResponse:
"""
Common method to POST a prompt to a given endpoint.
:param prompt: The prompt to send
:param endpoint: The API endpoint (e.g., "/api/v1/prompts")
:param accept_header: The Accept header value
:param prefer_header: Optional Prefer header value
:return: The response object
"""
prompt_json = AsyncRemoteComfyClient.__json_encoder.encode(prompt)
headers = self._build_headers(accept_header, prefer_header)
return await self.session.post(urljoin(self.server_address, endpoint), data=prompt_json, headers=headers)
async def len_queue(self) -> int: async def len_queue(self) -> int:
async with self.session.get(urljoin(self.server_address, "/prompt"), headers={'Accept': 'application/json'}) as response: async with self.session.get(urljoin(self.server_address, "/prompt"), headers={'Accept': 'application/json'}) as response:
if response.status == 200: if response.status == 200:
@ -65,34 +86,30 @@ class AsyncRemoteComfyClient:
else: else:
raise RuntimeError(f"unexpected response: {response.status}: {await response.text()}") raise RuntimeError(f"unexpected response: {response.status}: {await response.text()}")
async def queue_and_forget_prompt_api(self, prompt: PromptDict) -> str: async def queue_and_forget_prompt_api(self, prompt: PromptDict, prefer_header: Optional[str] = "respond-async", accept_header: str = "application/json") -> str:
""" """
Calls the API to queue a prompt, and forgets about it Calls the API to queue a prompt, and forgets about it
:param prompt: :param prompt:
:param prefer_header: The Prefer header value (e.g., "respond-async" or None)
:param accept_header: The Accept header value (e.g., "application/json", "application/json+respond-async")
:return: the task ID :return: the task ID
""" """
prompt_json = AsyncRemoteComfyClient.__json_encoder.encode(prompt) async with await self._post_prompt(prompt, "/api/v1/prompts", accept_header, prefer_header) as response:
response: ClientResponse
async with self.session.post(urljoin(self.server_address, "/api/v1/prompts"), data=prompt_json,
headers={'Content-Type': 'application/json', 'Accept': 'application/json', 'Prefer': 'respond-async'}) as response:
if 200 <= response.status < 400: if 200 <= response.status < 400:
response_json = await response.json() response_json = await response.json()
return response_json["prompt_id"] return response_json["prompt_id"]
else: else:
raise RuntimeError(f"could not prompt: {response.status}: {await response.text()}") raise RuntimeError(f"could not prompt: {response.status}, reason={response.reason}: {await response.text()}")
async def queue_prompt_api(self, prompt: PromptDict) -> V1QueuePromptResponse: async def queue_prompt_api(self, prompt: PromptDict, prefer_header: Optional[str] = None, accept_header: str = "application/json") -> V1QueuePromptResponse:
""" """
Calls the API to queue a prompt. Calls the API to queue a prompt.
:param prompt: :param prompt:
:param prefer_header: The Prefer header value (e.g., "respond-async" or None)
:param accept_header: The Accept header value (e.g., "application/json", "application/json+respond-async")
:return: the API response from the server containing URLs and the outputs for the UI (nodes with OUTPUT_NODE == true) :return: the API response from the server containing URLs and the outputs for the UI (nodes with OUTPUT_NODE == true)
""" """
prompt_json = AsyncRemoteComfyClient.__json_encoder.encode(prompt) async with await self._post_prompt(prompt, "/api/v1/prompts", accept_header, prefer_header) as response:
response: ClientResponse
async with self.session.post(urljoin(self.server_address, "/api/v1/prompts"), data=prompt_json,
headers={'Content-Type': 'application/json', 'Accept': 'application/json'}) as response:
if 200 <= response.status < 400: if 200 <= response.status < 400:
return V1QueuePromptResponse(**(await response.json())) return V1QueuePromptResponse(**(await response.json()))
else: else:
@ -112,17 +129,13 @@ class AsyncRemoteComfyClient:
:param prompt: :param prompt:
:return: :return:
""" """
prompt_json = AsyncRemoteComfyClient.__json_encoder.encode(prompt) async with await self._post_prompt(prompt, "/api/v1/prompts", "image/png") as response:
response: ClientResponse
headers = {'Content-Type': 'application/json', 'Accept': 'image/png'}
async with self.session.post(urljoin(self.server_address, "/api/v1/prompts"), data=prompt_json,
headers=headers) as response:
if 200 <= response.status < 400: if 200 <= response.status < 400:
return await response.read() return await response.read()
else: else:
raise RuntimeError(f"could not prompt: {response.status}: {await response.text()}") raise RuntimeError(f"could not prompt: {response.status}: {await response.text()}")
@tracer.start_as_current_span("Post Prompt (UI)")
async def queue_prompt_ui(self, prompt: PromptDict) -> OutputsDict: async def queue_prompt_ui(self, prompt: PromptDict) -> OutputsDict:
""" """
Uses the comfyui UI API calls to retrieve the outputs dictionary Uses the comfyui UI API calls to retrieve the outputs dictionary
@ -160,3 +173,41 @@ class AsyncRemoteComfyClient:
# images have filename, subfolder, type keys # images have filename, subfolder, type keys
# todo: use the OpenAPI spec for this when I get around to updating it # todo: use the OpenAPI spec for this when I get around to updating it
return history_json[prompt_id].outputs return history_json[prompt_id].outputs
async def get_prompt_status(self, prompt_id: str) -> ClientResponse:
"""
Get the status of a prompt by ID using the API endpoint.
:param prompt_id: The prompt ID to query
:return: The ClientResponse object (caller should check status and read body)
"""
return await self.session.get(urljoin(self.server_address, f"/api/v1/prompts/{prompt_id}"))
@tracer.start_as_current_span("Poll Prompt Until Done")
async def poll_prompt_until_done(self, prompt_id: str, max_attempts: int = 60, poll_interval: float = 1.0) -> tuple[int, dict | None]:
"""
Poll a prompt until it's done (200), errors (500), or times out.
:param prompt_id: The prompt ID to poll
:param max_attempts: Maximum number of polling attempts
:param poll_interval: Time to wait between polls in seconds
:return: Tuple of (status_code, response_json or None)
"""
span = trace.get_current_span()
span.set_attribute("prompt_id", prompt_id)
span.set_attribute("max_attempts", max_attempts)
for _ in range(max_attempts):
async with await self.get_prompt_status(prompt_id) as response:
if response.status == 200:
return response.status, await response.json()
elif response.status == 500:
return response.status, await response.json()
elif response.status == 404:
return response.status, None
elif response.status == 204:
# Still in progress
await asyncio.sleep(poll_interval)
else:
# Unexpected status
return response.status, None
# Timeout
return 408, None
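Taken together, the new helpers support a fire-and-poll pattern against `/api/v1/prompts`. A hedged usage sketch built from the methods above; the import paths and server address are assumptions, and the workflow helper mirrors the one used in the tests later in this commit:

```python
import asyncio

from comfy.client.aio_client import AsyncRemoteComfyClient  # module path assumed
from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner  # module path assumed

async def main():
    async with AsyncRemoteComfyClient(server_address="http://127.0.0.1:8188") as client:
        prompt = sdxl_workflow_with_refiner("a photo of a cat", inference_steps=1, refiner_steps=1)
        # Queue with Prefer: respond-async, then poll until the server reports a result.
        prompt_id = await client.queue_and_forget_prompt_api(prompt, prefer_header="respond-async")
        status, body = await client.poll_prompt_until_done(prompt_id, max_attempts=120, poll_interval=1.0)
        if status == 200:
            print("outputs:", body)
        elif status == 500:
            print("execution error:", body)  # ExecutionStatusAsDict from the worker
        else:
            print("gave up with status", status)

asyncio.run(main())
```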

View File

@ -1,8 +1,6 @@
import copy import copy
from typing import TypeAlias, Union from typing import TypeAlias, Union
from ..api.components.schema.prompt import PromptDict, Prompt
JSON: TypeAlias = Union[dict[str, "JSON"], list["JSON"], str, int, float, bool, None] JSON: TypeAlias = Union[dict[str, "JSON"], list["JSON"], str, int, float, bool, None]
_BASE_PROMPT: JSON = { _BASE_PROMPT: JSON = {
"4": { "4": {
@ -161,7 +159,7 @@ def sdxl_workflow_with_refiner(prompt: str,
sampler="euler_ancestral", sampler="euler_ancestral",
scheduler="normal", scheduler="normal",
filename_prefix="sdxl_", filename_prefix="sdxl_",
seed=42) -> PromptDict: seed=42) -> dict:
prompt_dict: JSON = copy.deepcopy(_BASE_PROMPT) prompt_dict: JSON = copy.deepcopy(_BASE_PROMPT)
prompt_dict["17"]["inputs"]["text"] = prompt prompt_dict["17"]["inputs"]["text"] = prompt
prompt_dict["20"]["inputs"]["text"] = negative_prompt prompt_dict["20"]["inputs"]["text"] = negative_prompt
@ -188,4 +186,4 @@ def sdxl_workflow_with_refiner(prompt: str,
prompt_dict["14"]["inputs"]["scheduler"] = scheduler prompt_dict["14"]["inputs"]["scheduler"] = scheduler
prompt_dict["13"]["inputs"]["filename_prefix"] = filename_prefix prompt_dict["13"]["inputs"]["filename_prefix"] = filename_prefix
return Prompt.validate(prompt_dict) return prompt_dict

View File

@ -1246,16 +1246,36 @@ async def _validate_prompt(prompt_id: typing.Any, prompt: typing.Mapping[str, ty
if len(good_outputs) == 0: if len(good_outputs) == 0:
errors_list = [] errors_list = []
extra_info = {}
for o, _errors in errors: for o, _errors in errors:
for error in _errors: for error in _errors:
errors_list.append(f"{error['message']}: {error['details']}") errors_list.append(f"{error['message']}: {error['details']}")
# Aggregate exception_type and traceback from validation errors
if 'extra_info' in error and error['extra_info']:
if 'exception_type' in error['extra_info'] and 'exception_type' not in extra_info:
extra_info['exception_type'] = error['extra_info']['exception_type']
if 'traceback' in error['extra_info'] and 'traceback' not in extra_info:
extra_info['traceback'] = error['extra_info']['traceback']
# Per OpenAPI spec, extra_info must have exception_type and traceback
# For non-exception validation errors, provide synthetic values
if 'exception_type' not in extra_info:
extra_info['exception_type'] = 'ValidationError'
if 'traceback' not in extra_info:
# Capture current stack for validation errors that don't have their own traceback
extra_info['traceback'] = traceback.format_stack()
# Include detailed node_errors for actionable debugging information
if node_errors:
extra_info['node_errors'] = node_errors
errors_list = "\n".join(errors_list) errors_list = "\n".join(errors_list)
error = { error = {
"type": "prompt_outputs_failed_validation", "type": "prompt_outputs_failed_validation",
"message": "Prompt outputs failed validation", "message": "Prompt outputs failed validation",
"details": errors_list, "details": errors_list,
"extra_info": {} "extra_info": extra_info
} }
return ValidationTuple(False, error, list(good_outputs), node_errors) return ValidationTuple(False, error, list(good_outputs), node_errors)
@ -1301,7 +1321,7 @@ class PromptQueue(AbstractPromptQueue):
return copy.deepcopy(item_with_future.queue_tuple), task_id return copy.deepcopy(item_with_future.queue_tuple), task_id
def task_done(self, item_id: str, outputs: HistoryResultDict, def task_done(self, item_id: str, outputs: HistoryResultDict,
status: Optional[ExecutionStatus]): status: Optional[ExecutionStatus], error_details: Optional[ExecutionErrorMessage] = None):
history_result = outputs history_result = outputs
with self.mutex: with self.mutex:
queue_item = self.currently_running.pop(item_id) queue_item = self.currently_running.pop(item_id)
@ -1311,7 +1331,7 @@ class PromptQueue(AbstractPromptQueue):
status_dict = None status_dict = None
if status is not None: if status is not None:
status_dict: Optional[ExecutionStatusAsDict] = status.as_dict() status_dict: Optional[ExecutionStatusAsDict] = status.as_dict(error_details=error_details)
outputs_ = history_result["outputs"] outputs_ = history_result["outputs"]
# Remove sensitive data from extra_data before storing in history # Remove sensitive data from extra_data before storing in history

View File

@ -79,12 +79,25 @@ async def _prompt_worker(q: AbstractPromptQueue, server_instance: server_module.
await e.execute_async(item[2], prompt_id, item[3], item[4]) await e.execute_async(item[2], prompt_id, item[3], item[4])
need_gc = True need_gc = True
# Extract error details from status_messages if there's an error
error_details = None
if not e.success:
for event, data in e.status_messages:
if event == "execution_error":
error_details = data
break
# Convert status_messages tuples to string messages for backward compatibility
messages = [f"{event}: {data.get('exception_message', str(data))}" if isinstance(data, dict) and 'exception_message' in data else f"{event}" for event, data in e.status_messages]
q.task_done(item_id, q.task_done(item_id,
e.history_result, e.history_result,
status=queue_types.ExecutionStatus( status=queue_types.ExecutionStatus(
status_str='success' if e.success else 'error', status_str='success' if e.success else 'error',
completed=e.success, completed=e.success,
messages=e.status_messages)) messages=messages),
error_details=error_details)
if server_instance.client_id is not None: if server_instance.client_id is not None:
server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id},
server_instance.client_id) server_instance.client_id)

View File

@ -15,6 +15,7 @@ import shutil
import warnings import warnings
import fsspec import fsspec
from opentelemetry.instrumentation.urllib3 import URLLib3Instrumentor
from .. import options from .. import options
from ..app import logger from ..app import logger
@ -41,6 +42,7 @@ warnings.filterwarnings("ignore", message="torch.utils._pytree._register_pytree_
warnings.filterwarnings("ignore", message="Torch was not compiled with flash attention.") warnings.filterwarnings("ignore", message="Torch was not compiled with flash attention.")
warnings.filterwarnings("ignore", message=".*Torch was not compiled with flash attention.*") warnings.filterwarnings("ignore", message=".*Torch was not compiled with flash attention.*")
warnings.filterwarnings('ignore', category=FutureWarning, message=r'`torch\.cuda\.amp\.custom_fwd.*') warnings.filterwarnings('ignore', category=FutureWarning, message=r'`torch\.cuda\.amp\.custom_fwd.*')
warnings.filterwarnings("ignore", category=UserWarning, message="Please use the new API settings to control TF32 behavior.*")
warnings.filterwarnings("ignore", message="Importing from timm.models.registry is deprecated, please import via timm.models", category=FutureWarning) warnings.filterwarnings("ignore", message="Importing from timm.models.registry is deprecated, please import via timm.models", category=FutureWarning)
warnings.filterwarnings("ignore", message="Importing from timm.models.layers is deprecated, please import via timm.layers", category=FutureWarning) warnings.filterwarnings("ignore", message="Importing from timm.models.layers is deprecated, please import via timm.layers", category=FutureWarning)
warnings.filterwarnings("ignore", message="Inheritance class _InstrumentedApplication from web.Application is discouraged", category=DeprecationWarning) warnings.filterwarnings("ignore", message="Inheritance class _InstrumentedApplication from web.Application is discouraged", category=DeprecationWarning)
@ -125,9 +127,11 @@ def _create_tracer():
from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor, SpanExporter from opentelemetry.sdk.trace.export import BatchSpanProcessor, SpanExporter
from opentelemetry.processor.baggage import BaggageSpanProcessor, ALLOW_ALL_BAGGAGE_KEYS
from opentelemetry.instrumentation.aiohttp_server import AioHttpServerInstrumentor
from opentelemetry.instrumentation.aiohttp_client import AioHttpClientInstrumentor
from ..tracing_compatibility import ProgressSpanSampler from ..tracing_compatibility import ProgressSpanSampler
from ..tracing_compatibility import patch_spanbuilder_set_channel from ..tracing_compatibility import patch_spanbuilder_set_channel
from ..vendor.aiohttp_server_instrumentation import AioHttpServerInstrumentor
resource = Resource.create({ resource = Resource.create({
service_attributes.SERVICE_NAME: args.otel_service_name, service_attributes.SERVICE_NAME: args.otel_service_name,
@ -141,18 +145,24 @@ def _create_tracer():
has_endpoint = args.otel_exporter_otlp_endpoint is not None has_endpoint = args.otel_exporter_otlp_endpoint is not None
if has_endpoint: if has_endpoint:
otlp_exporter = OTLPSpanExporter() exporter = OTLPSpanExporter()
else: else:
otlp_exporter = SpanExporter() exporter = SpanExporter()
processor = BatchSpanProcessor(otlp_exporter) processor = BatchSpanProcessor(exporter)
provider.add_span_processor(processor) provider.add_span_processor(processor)
# enable instrumentation # enable instrumentation
patch_spanbuilder_set_channel() patch_spanbuilder_set_channel()
AioPikaInstrumentor().instrument() AioPikaInstrumentor().instrument()
AioHttpServerInstrumentor().instrument() AioHttpServerInstrumentor().instrument()
AioHttpClientInstrumentor().instrument()
RequestsInstrumentor().instrument() RequestsInstrumentor().instrument()
URLLib3Instrumentor().instrument()
provider.add_span_processor(BaggageSpanProcessor(ALLOW_ALL_BAGGAGE_KEYS))
# makes this behave better as a library # makes this behave better as a library
return trace.get_tracer(args.otel_service_name, tracer_provider=provider) return trace.get_tracer(args.otel_service_name, tracer_provider=provider)

View File

@ -116,6 +116,24 @@ async def compress_body(request: web.Request, handler):
return response return response
@web.middleware
async def opentelemetry_middleware(request: web.Request, handler):
"""Middleware to extract and propagate OpenTelemetry context from request headers"""
from opentelemetry import propagate, context
# Extract OpenTelemetry context from headers
carrier = dict(request.headers)
ctx = propagate.extract(carrier)
# Attach context and execute handler
token = context.attach(ctx)
try:
response = await handler(request)
return response
finally:
context.detach(token)
def create_cors_middleware(allowed_origin: str): def create_cors_middleware(allowed_origin: str):
@web.middleware @web.middleware
async def cors_middleware(request: web.Request, handler): async def cors_middleware(request: web.Request, handler):
@ -127,7 +145,7 @@ def create_cors_middleware(allowed_origin: str):
response.headers['Access-Control-Allow-Origin'] = allowed_origin response.headers['Access-Control-Allow-Origin'] = allowed_origin
response.headers['Access-Control-Allow-Methods'] = 'POST, GET, DELETE, PUT, OPTIONS' response.headers['Access-Control-Allow-Methods'] = 'POST, GET, DELETE, PUT, OPTIONS'
response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization' response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization, traceparent, tracestate'
response.headers['Access-Control-Allow-Credentials'] = 'true' response.headers['Access-Control-Allow-Credentials'] = 'true'
return response return response
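The middleware above extracts W3C trace context from incoming headers, and the CORS change lets browsers send `traceparent`/`tracestate`. On the caller side this pairs with standard injection; a hedged sketch (the URL is a placeholder), noting that the `AioHttpClientInstrumentor` enabled elsewhere in this commit already injects these headers automatically, so the manual call only makes the handshake explicit:

```python
import aiohttp
from opentelemetry import propagate, trace

tracer = trace.get_tracer(__name__)

async def post_with_trace_context(session: aiohttp.ClientSession, prompt_json: str):
    # Inject the current span context into outgoing headers; the server's
    # opentelemetry_middleware extracts it and continues the same trace.
    with tracer.start_as_current_span("queue prompt"):
        headers = {"Content-Type": "application/json", "Accept": "application/json"}
        propagate.inject(headers)  # adds traceparent (and tracestate, if any)
        return await session.post("http://127.0.0.1:8188/api/v1/prompts",
                                  data=prompt_json, headers=headers)
```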
@ -224,7 +242,7 @@ class PromptServer(ExecutorToClientProgress):
self._external_address: Optional[str] = None self._external_address: Optional[str] = None
self.background_tasks: dict[str, Task] = dict() self.background_tasks: dict[str, Task] = dict()
middlewares = [cache_control, deprecation_warning] middlewares = [opentelemetry_middleware, cache_control, deprecation_warning]
if args.enable_compress_response_body: if args.enable_compress_response_body:
middlewares.append(compress_body) middlewares.append(compress_body)
@ -867,9 +885,19 @@ class PromptServer(ExecutorToClientProgress):
return web.json_response(status=404) return web.json_response(status=404)
elif prompt_id in history_items: elif prompt_id in history_items:
history_entry = history_items[prompt_id] history_entry = history_items[prompt_id]
# Check if execution resulted in an error
if "status" in history_entry:
status = history_entry["status"]
if isinstance(status, dict) and status.get("status_str") == "error":
# Return ExecutionStatusAsDict format with status 500, matching POST /api/v1/prompts behavior
return web.Response(
body=json.dumps(status),
status=500,
content_type="application/json"
)
return web.json_response(history_entry["outputs"]) return web.json_response(history_entry["outputs"])
else: else:
return web.json_response(status=500) return web.Response(status=404, reason="prompt not found in expected state")
@routes.post("/api/v1/prompts") @routes.post("/api/v1/prompts")
async def post_api_prompt(request: web.Request) -> web.Response | web.FileResponse: async def post_api_prompt(request: web.Request) -> web.Response | web.FileResponse:
@ -877,9 +905,13 @@ class PromptServer(ExecutorToClientProgress):
if accept == '*/*': if accept == '*/*':
accept = "application/json" accept = "application/json"
content_type = request.headers.get("content-type", "application/json") content_type = request.headers.get("content-type", "application/json")
preferences = request.headers.get("prefer", "") + request.query.get("prefer", "") + " " + content_type preferences = request.headers.get("prefer", "") + request.query.get("prefer", "") + " " + content_type + " " + accept
# handle media type parameters like "application/json+respond-async"
if "+" in content_type: if "+" in content_type:
content_type = content_type.split("+")[0] content_type = content_type.split("+")[0]
if "+" in accept:
accept = accept.split("+")[0]
wait = not "respond-async" in preferences wait = not "respond-async" in preferences
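A small sketch of how the widened preference handling above behaves: `respond-async` can now arrive via the `Prefer` header or as a media-type suffix on either `Content-Type` or `Accept` (values illustrative):

```python
# Illustrative reimplementation of the preference parsing above.
def wants_async(prefer: str, content_type: str, accept: str) -> bool:
    preferences = prefer + " " + content_type + " " + accept
    return "respond-async" in preferences

def strip_media_suffix(value: str) -> str:
    # "application/json+respond-async" -> "application/json"
    return value.split("+")[0] if "+" in value else value

assert wants_async("", "application/json", "application/json+respond-async")
assert wants_async("respond-async", "application/json", "application/json")
assert not wants_async("", "application/json", "application/json")
assert strip_media_suffix("application/json+respond-async") == "application/json"
```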
@ -965,7 +997,8 @@ class PromptServer(ExecutorToClientProgress):
return web.Response(body=str(ex), status=500) return web.Response(body=str(ex), status=500)
if result.status is not None and result.status.status_str == "error": if result.status is not None and result.status.status_str == "error":
return web.Response(body=json.dumps(result.status._asdict()), status=500, content_type="application/json") status_dict = result.status.as_dict(error_details=result.error_details)
return web.Response(body=json.dumps(status_dict), status=500, content_type="application/json")
# find images and read them # find images and read them
output_images: List[FileOutput] = [] output_images: List[FileOutput] = []
for node_id, node in result.outputs.items(): for node_id, node in result.outputs.items():

View File

@ -199,9 +199,8 @@ class ValidationErrorExtraInfoDict(TypedDict, total=False):
input_config: NotRequired[Dict[str, InputTypeSpec]] input_config: NotRequired[Dict[str, InputTypeSpec]]
received_value: NotRequired[Any] received_value: NotRequired[Any]
linked_node: NotRequired[str] linked_node: NotRequired[str]
traceback: NotRequired[list[str]]
exception_message: NotRequired[str] exception_message: NotRequired[str]
exception_type: NotRequired[str] node_errors: NotRequired[Dict[str, 'NodeErrorsDictValue']]
class ValidationErrorDict(TypedDict): class ValidationErrorDict(TypedDict):

View File

@ -18,6 +18,7 @@ class TaskInvocation(NamedTuple):
item_id: int | str item_id: int | str
outputs: OutputsDict outputs: OutputsDict
status: Optional[ExecutionStatus] status: Optional[ExecutionStatus]
error_details: Optional['ExecutionErrorMessage'] = None
class ExecutionStatus(NamedTuple): class ExecutionStatus(NamedTuple):
@ -25,12 +26,15 @@ class ExecutionStatus(NamedTuple):
completed: bool completed: bool
messages: List[str] messages: List[str]
def as_dict(self) -> ExecutionStatusAsDict: def as_dict(self, error_details: Optional['ExecutionErrorMessage'] = None) -> ExecutionStatusAsDict:
return { result: ExecutionStatusAsDict = {
"status_str": self.status_str, "status_str": self.status_str,
"completed": self.completed, "completed": self.completed,
"messages": copy.copy(self.messages), "messages": copy.copy(self.messages),
} }
if error_details is not None:
result["error_details"] = error_details
return result
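A quick sketch of the resulting dictionary, with an invented error payload, to show what the optional key adds; the import path and the `ExecutionErrorMessage` fields shown are assumptions:

```python
from comfy.component_model.queue_types import ExecutionStatus  # module path assumed

status = ExecutionStatus(status_str="error", completed=False,
                         messages=["execution_error: something failed"])

print(status.as_dict())
# {'status_str': 'error', 'completed': False, 'messages': [...]}

# With error_details (keys invented for illustration), the same dict gains
# an 'error_details' entry carrying the worker's execution_error data.
print(status.as_dict(error_details={"node_id": "4", "exception_message": "something failed"}))
```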
class ExecutionError(RuntimeError): class ExecutionError(RuntimeError):

View File

@ -162,7 +162,7 @@ class DistributedPromptQueue(AbstractPromptQueue, AsyncAbstractPromptQueue):
return item, item[1] return item, item[1]
def task_done(self, item_id: int, outputs: dict, status: Optional[ExecutionStatus]): def task_done(self, item_id: int, outputs: dict, status: Optional[ExecutionStatus], error_details: Optional['ExecutionErrorMessage'] = None):
# callee: executed on the worker thread # callee: executed on the worker thread
if "outputs" in outputs: if "outputs" in outputs:
outputs: HistoryResultDict outputs: HistoryResultDict
@ -173,7 +173,7 @@ class DistributedPromptQueue(AbstractPromptQueue, AsyncAbstractPromptQueue):
assert pending.completed is not None assert pending.completed is not None
assert not pending.completed.done() assert not pending.completed.done()
# finish the task. status will transmit the errors in comfy's domain-specific way # finish the task. status will transmit the errors in comfy's domain-specific way
pending.completed.set_result(TaskInvocation(item_id=item_id, outputs=outputs, status=status)) pending.completed.set_result(TaskInvocation(item_id=item_id, outputs=outputs, status=status, error_details=error_details))
# todo: the caller is responsible for sending a websocket message right now that the UI expects for updates # todo: the caller is responsible for sending a websocket message right now that the UI expects for updates
def get_current_queue(self) -> Tuple[List[QueueTuple], List[QueueTuple]]: def get_current_queue(self) -> Tuple[List[QueueTuple], List[QueueTuple]]:

View File

@ -3,7 +3,7 @@ from typing import Optional, Sequence
from aio_pika.abc import AbstractChannel from aio_pika.abc import AbstractChannel
from opentelemetry.context import Context from opentelemetry.context import Context
from opentelemetry.sdk.trace.sampling import Sampler, SamplingResult, Decision from opentelemetry.sdk.trace.sampling import Sampler, SamplingResult, Decision
from opentelemetry.semconv.trace import SpanAttributes from opentelemetry.semconv.attributes.network_attributes import NETWORK_PEER_ADDRESS, NETWORK_PEER_PORT
from opentelemetry.trace import SpanKind, Link, TraceState from opentelemetry.trace import SpanKind, Link, TraceState
from opentelemetry.util.types import Attributes from opentelemetry.util.types import Attributes
@ -22,8 +22,8 @@ def patch_spanbuilder_set_channel() -> None:
port = url.port or 5672 port = url.port or 5672
self._attributes.update( self._attributes.update(
{ {
SpanAttributes.NET_PEER_NAME: url.host, NETWORK_PEER_ADDRESS: url.host,
SpanAttributes.NET_PEER_PORT: port, NETWORK_PEER_PORT: port,
} }
) )

View File

@ -1,271 +0,0 @@
# Copyright 2020, OpenTelemetry Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import urllib
from timeit import default_timer
from typing import Dict, List, Tuple, Union
from aiohttp import web
from multidict import CIMultiDictProxy
from opentelemetry import metrics, trace
_instruments = ("aiohttp ~= 3.0",)
__version__ = "0.49b0.dev"
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
from opentelemetry.instrumentation.utils import (
http_status_to_status_code,
is_http_instrumentation_enabled,
)
from opentelemetry.propagate import extract
from opentelemetry.propagators.textmap import Getter
from opentelemetry.semconv.metrics import MetricInstruments
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.trace.status import Status, StatusCode
from opentelemetry.util.http import get_excluded_urls, remove_url_credentials
_duration_attrs = [
SpanAttributes.HTTP_METHOD,
SpanAttributes.HTTP_HOST,
SpanAttributes.HTTP_SCHEME,
SpanAttributes.HTTP_STATUS_CODE,
SpanAttributes.HTTP_FLAVOR,
SpanAttributes.HTTP_SERVER_NAME,
SpanAttributes.NET_HOST_NAME,
SpanAttributes.NET_HOST_PORT,
SpanAttributes.HTTP_ROUTE,
]
_active_requests_count_attrs = [
SpanAttributes.HTTP_METHOD,
SpanAttributes.HTTP_HOST,
SpanAttributes.HTTP_SCHEME,
SpanAttributes.HTTP_FLAVOR,
SpanAttributes.HTTP_SERVER_NAME,
]
tracer = trace.get_tracer(__name__)
meter = metrics.get_meter(__name__, __version__)
_excluded_urls = get_excluded_urls("AIOHTTP_SERVER")
def _parse_duration_attrs(req_attrs):
duration_attrs = {}
for attr_key in _duration_attrs:
if req_attrs.get(attr_key) is not None:
duration_attrs[attr_key] = req_attrs[attr_key]
return duration_attrs
def _parse_active_request_count_attrs(req_attrs):
active_requests_count_attrs = {}
for attr_key in _active_requests_count_attrs:
if req_attrs.get(attr_key) is not None:
active_requests_count_attrs[attr_key] = req_attrs[attr_key]
return active_requests_count_attrs
def get_default_span_details(request: web.Request) -> Tuple[str, dict]:
"""Default implementation for get_default_span_details
Args:
request: the request object itself.
Returns:
a tuple of the span name, and any attributes to attach to the span.
"""
span_name = request.path.strip() or f"HTTP {request.method}"
return span_name, {}
def _get_view_func(request: web.Request) -> str:
"""Returns the name of the request handler.
Args:
request: the request object itself.
Returns:
a string containing the name of the handler function
"""
try:
return request.match_info.handler.__name__
except AttributeError:
return "unknown"
def collect_request_attributes(request: web.Request) -> Dict:
"""Collects HTTP request attributes from the ASGI scope and returns a
dictionary to be used as span creation attributes."""
server_host, port, http_url = (
request.url.host,
request.url.port,
str(request.url),
)
query_string = request.query_string
if query_string and http_url:
if isinstance(query_string, bytes):
query_string = query_string.decode("utf8")
http_url += "?" + urllib.parse.unquote(query_string)
result = {
SpanAttributes.HTTP_SCHEME: request.scheme,
SpanAttributes.HTTP_HOST: server_host,
SpanAttributes.NET_HOST_PORT: port,
SpanAttributes.HTTP_ROUTE: _get_view_func(request),
SpanAttributes.HTTP_FLAVOR: f"{request.version.major}.{request.version.minor}",
SpanAttributes.HTTP_TARGET: request.path,
SpanAttributes.HTTP_URL: remove_url_credentials(http_url),
}
http_method = request.method
if http_method:
result[SpanAttributes.HTTP_METHOD] = http_method
http_host_value_list = (
[request.host] if not isinstance(request.host, list) else request.host
)
if http_host_value_list:
result[SpanAttributes.HTTP_SERVER_NAME] = ",".join(
http_host_value_list
)
http_user_agent = request.headers.get("user-agent")
if http_user_agent:
result[SpanAttributes.HTTP_USER_AGENT] = http_user_agent
# remove None values
result = {k: v for k, v in result.items() if v is not None}
return result
def set_status_code(span, status_code: int) -> None:
"""Adds HTTP response attributes to span using the status_code argument."""
try:
status_code = int(status_code)
except ValueError:
span.set_status(
Status(
StatusCode.ERROR,
"Non-integer HTTP status: " + repr(status_code),
)
)
else:
span.set_attribute(SpanAttributes.HTTP_STATUS_CODE, status_code)
span.set_status(
Status(http_status_to_status_code(status_code, server_span=True))
)
class AiohttpGetter(Getter):
"""Extract current trace from headers"""
def get(self, carrier, key: str) -> Union[List, None]:
"""Getter implementation to retrieve an HTTP header value from the ASGI
scope.
Args:
carrier: ASGI scope object
key: header name in scope
Returns:
A list of all header values matching the key, or None if the key
does not match any header.
"""
headers: CIMultiDictProxy = carrier.headers
if not headers:
return None
return headers.getall(key, None)
def keys(self, carrier: Dict) -> List:
return list(carrier.keys())
getter = AiohttpGetter()
@web.middleware
async def middleware(request, handler):
"""Middleware for aiohttp implementing tracing logic"""
if not is_http_instrumentation_enabled() or _excluded_urls.url_disabled(
request.url.path
):
return await handler(request)
span_name, additional_attributes = get_default_span_details(request)
req_attrs = collect_request_attributes(request)
duration_attrs = _parse_duration_attrs(req_attrs)
active_requests_count_attrs = _parse_active_request_count_attrs(req_attrs)
duration_histogram = meter.create_histogram(
name=MetricInstruments.HTTP_SERVER_DURATION,
unit="ms",
description="Measures the duration of inbound HTTP requests.",
)
active_requests_counter = meter.create_up_down_counter(
name=MetricInstruments.HTTP_SERVER_ACTIVE_REQUESTS,
unit="requests",
description="measures the number of concurrent HTTP requests those are currently in flight",
)
with tracer.start_as_current_span(
span_name,
context=extract(request, getter=getter),
kind=trace.SpanKind.SERVER,
) as span:
attributes = collect_request_attributes(request)
attributes.update(additional_attributes)
span.set_attributes(attributes)
start = default_timer()
active_requests_counter.add(1, active_requests_count_attrs)
try:
resp = await handler(request)
set_status_code(span, resp.status)
except web.HTTPException as ex:
set_status_code(span, ex.status_code)
raise
except AttributeError:
# No response was returned or a NoneType response was returned, handle gracefully
set_status_code(span, 500)
raise
finally:
duration = max((default_timer() - start) * 1000, 0)
duration_histogram.record(duration, duration_attrs)
active_requests_counter.add(-1, active_requests_count_attrs)
return resp
class _InstrumentedApplication(web.Application):
"""Insert tracing middleware"""
def __init__(self, *args, **kwargs):
middlewares = kwargs.pop("middlewares", [])
middlewares.insert(0, middleware)
kwargs["middlewares"] = middlewares
super().__init__(*args, **kwargs)
class AioHttpServerInstrumentor(BaseInstrumentor):
# pylint: disable=protected-access,attribute-defined-outside-init
"""An instrumentor for aiohttp.web.Application
See `BaseInstrumentor`
"""
def _instrument(self, **kwargs):
self._original_app = web.Application
setattr(web, "Application", _InstrumentedApplication)
def _uninstrument(self, **kwargs):
setattr(web, "Application", self._original_app)
def instrumentation_dependencies(self):
return _instruments

View File

View File

@ -0,0 +1,242 @@
/**
* Uses code adapted from https://github.com/yorkane/ComfyUI-KYNode
*
* MIT License
*
* Copyright (c) 2024 Kevin Yuan
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
import { app } from "../../scripts/app.js";
// Load Ace editor using script tag for Safari compatibility
// The noconflict build includes AMD loader that works in all browsers
let ace;
const aceLoadPromise = new Promise((resolve) => {
if (window.ace) {
ace = window.ace;
resolve();
} else {
const script = document.createElement("script");
script.src = "https://cdn.jsdelivr.net/npm/ace-builds@1.43.4/src-noconflict/ace.js";
script.onload = () => {
ace = window.ace;
ace.config.set("basePath", "https://cdn.jsdelivr.net/npm/ace-builds@1.43.4/src-noconflict");
resolve();
};
document.head.appendChild(script);
}
});
// todo: do we really want to do this here?
await aceLoadPromise;
const findWidget = (node, value, attr = "name", func = "find") => {
return node?.widgets ? node.widgets[func]((w) => (Array.isArray(value) ? value.includes(w[attr]) : w[attr] === value)) : null;
};
const makeElement = (tag, attrs = {}) => {
if (!tag) tag = "div";
const element = document.createElement(tag);
Object.keys(attrs).forEach((key) => {
const currValue = attrs[key];
if (key === "class") {
if (Array.isArray(currValue)) {
element.classList.add(...currValue);
} else if (currValue instanceof String || typeof currValue === "string") {
element.className = currValue;
}
} else if (key === "dataset") {
try {
if (Array.isArray(currValue)) {
currValue.forEach((datasetArr) => {
const [prop, propval] = Object.entries(datasetArr)[0];
element.dataset[prop] = propval;
});
} else {
Object.entries(currValue).forEach((datasetArr) => {
const [prop, propval] = datasetArr;
element.dataset[prop] = propval;
});
}
} catch (err) {
// todo: what is this trying to do?
}
} else if (key === "style") {
if (typeof currValue === "object" && !Array.isArray(currValue) && Object.keys(currValue).length) {
Object.assign(element[key], currValue);
} else if (typeof currValue === "object" && Array.isArray(currValue) && currValue.length) {
element[key] = [...currValue];
} else if (currValue instanceof String || typeof currValue === "string") {
element[key] = currValue;
}
} else if (["for"].includes(key)) {
element.setAttribute(key, currValue);
} else if (key === "children") {
element.append(...(currValue instanceof Array ? currValue : [currValue]));
} else if (key === "parent") {
currValue.append(element);
} else {
element[key] = currValue;
}
});
return element;
};
const getPosition = (node, ctx, w_width, y, n_height) => {
const margin = 5;
const rect = ctx.canvas.getBoundingClientRect();
const transform = ctx.getTransform();
const scale = app.canvas.ds.scale;
// The context is already transformed to draw at the widget position
// transform.e and transform.f give us the canvas coordinates (in canvas pixels)
// We need to convert these to screen pixels by accounting for the canvas scale
// rect gives us the canvas element's position on the page
// The transform matrix has scale baked in (transform.a = transform.d = scale)
// transform.e and transform.f are the translation in canvas-pixel space
const canvasPixelToScreenPixel = rect.width / ctx.canvas.width;
const x = transform.e * canvasPixelToScreenPixel + rect.left;
const y_pos = transform.f * canvasPixelToScreenPixel + rect.top;
// Convert widget dimensions from canvas coordinates to screen pixels
const scaledWidth = w_width * scale;
const scaledHeight = (n_height - y - 15) * scale;
const scaledMargin = margin * scale;
const scaledY = y * scale;
return {
left: `${x + scaledMargin}px`,
top: `${y_pos + scaledY + scaledMargin}px`,
width: `${scaledWidth - scaledMargin * 2}px`,
maxWidth: `${scaledWidth - scaledMargin * 2}px`,
height: `${scaledHeight - scaledMargin * 2}px`,
maxHeight: `${scaledHeight - scaledMargin * 2}px`,
position: "absolute",
scrollbarColor: "var(--descrip-text) var(--bg-color)",
scrollbarWidth: "thin",
zIndex: app.graph._nodes.indexOf(node),
};
};
// Create code editor widget
const codeEditor = (node, inputName, inputData) => {
const widget = {
type: "code_block_python",
name: inputName,
options: { hideOnZoom: true },
value: inputData[1]?.default || "",
draw(ctx, node, widgetWidth, y) {
const hidden = node.flags?.collapsed || (!!this.options.hideOnZoom && app.canvas.ds.scale < 0.5) || this.type === "converted-widget" || this.type === "hidden" || this.type === "converted-widget";
this.codeElement.hidden = hidden;
if (hidden) {
this.options.onHide?.(this);
return;
}
Object.assign(this.codeElement.style, getPosition(node, ctx, widgetWidth, y, node.size[1]));
},
computeSize() {
return [500, 250];
},
};
widget.codeElement = makeElement("pre", {
innerHTML: widget.value,
});
widget.editor = ace.edit(widget.codeElement);
widget.editor.setTheme("ace/theme/monokai");
widget.editor.session.setMode("ace/mode/python");
widget.editor.setOptions({
enableAutoIndent: true,
enableLiveAutocompletion: true,
enableBasicAutocompletion: true,
fontFamily: "monospace",
});
widget.codeElement.hidden = true;
document.body.appendChild(widget.codeElement);
const originalCollapse = node.collapse;
node.collapse = function () {
originalCollapse.apply(this, arguments);
widget.codeElement.hidden = !!this.flags?.collapsed;
};
return widget;
};
// Trigger workflow change tracking
const markWorkflowChanged = () => {
app?.extensionManager?.workflow?.activeWorkflow?.changeTracker?.checkState();
};
// Register extensions
app.registerExtension({
name: "Comfy.EvalPython",
getCustomWidgets(app) {
return {
CODE_BLOCK_PYTHON: (node, inputName, inputData) => {
const widget = codeEditor(node, inputName, inputData);
widget.editor.getSession().on("change", () => {
widget.value = widget.editor.getValue();
markWorkflowChanged();
});
node.onRemoved = function () {
for (const w of this.widgets) {
if (w?.codeElement) {
w.codeElement.remove();
}
}
};
node.addCustomWidget(widget);
return widget;
},
};
},
async beforeRegisterNodeDef(nodeType, nodeData) {
// Handle all EvalPython node variants
if (nodeData.name.startsWith("EvalPython")) {
const originalOnConfigure = nodeType.prototype.onConfigure;
nodeType.prototype.onConfigure = function (info) {
originalOnConfigure?.apply(this, arguments);
if (info?.widgets_values?.length) {
const widgetCodeIndex = findWidget(this, "code_block_python", "type", "findIndex");
const editor = this.widgets[widgetCodeIndex]?.editor;
if (editor) {
editor.setValue(info.widgets_values[widgetCodeIndex]);
editor.clearSelection();
}
}
};
}
},
});

View File

@ -0,0 +1,119 @@
import logging
from comfy.comfy_types import IO
from comfy.execution_context import current_execution_context
from comfy.node_helpers import export_package_as_web_directory, export_custom_nodes
from comfy.nodes.package_typing import CustomNode
logger = logging.getLogger(__name__)
def eval_python(inputs=5, outputs=5, name=None, input_is_list=None, output_is_list=None):
"""
Factory function to create EvalPython node classes with configurable input/output counts.
Args:
inputs: Number of input value slots (default: 5)
outputs: Number of output item slots (default: 5)
name: Class name (default: f"EvalPython_{inputs}_{outputs}")
input_is_list: Optional list of bools indicating which inputs accept lists (default: None, meaning all scalar)
output_is_list: Optional tuple of bools indicating which outputs return lists (default: None, meaning all scalar)
Returns:
A CustomNode subclass configured with the specified inputs/outputs
"""
if name is None:
name = f"EvalPython_{inputs}_{outputs}"
default_code = f"""
print("Hello World!")
return {", ".join([f"value{i}" for i in range(inputs)])}
"""
class EvalPythonNode(CustomNode):
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"pycode": (
"CODE_BLOCK_PYTHON",
{
"default": default_code
},
),
},
"optional": {f"value{i}": (IO.ANY, {}) for i in range(inputs)},
}
RETURN_TYPES = tuple(IO.ANY for _ in range(outputs))
RETURN_NAMES = tuple(f"item{i}" for i in range(outputs))
FUNCTION = "exec_py"
DESCRIPTION = ""
CATEGORY = "eval"
@classmethod
def VALIDATE_INPUTS(cls, *args, **kwargs):
ctx = current_execution_context()
return ctx.configuration.enable_eval
def exec_py(self, pycode, **kwargs):
ctx = current_execution_context()
kwargs = {
**{f"value{i}": None for i in range(inputs)},
**kwargs,
}
def print(*args):
ctx.server.send_progress_text(" ".join(map(str, args)), ctx.node_id)
if not ctx.configuration.enable_eval:
raise ValueError("Python eval is disabled")
value_args = [kwargs.pop(f"value{i}") for i in range(inputs)]
arg_names = ", ".join(f"value{i}=None" for i in range(inputs))
wrapped_code = f"def _eval_func({arg_names}):\n"
for line in pycode.splitlines():
wrapped_code += " " + line + "\n"
globals_for_eval = {
**kwargs,
"logger": logger,
"print": print,
}
exec(wrapped_code, globals_for_eval)
results = globals_for_eval["_eval_func"](*value_args)
if not isinstance(results, tuple):
results = (results,)
if len(results) < outputs:
results += (None,) * (outputs - len(results))
elif len(results) > outputs:
results = results[:outputs]
return results
# todo: interact better with the weird comfyui machinery for this
if input_is_list is not None:
setattr(EvalPythonNode, "INPUT_IS_LIST", input_is_list)
if output_is_list is not None:
setattr(EvalPythonNode, "OUTPUT_IS_LIST", output_is_list)
EvalPythonNode.__name__ = name
EvalPythonNode.__qualname__ = name
return EvalPythonNode
EvalPython_1_1 = eval_python(inputs=1, outputs=1, name="EvalPython_1_1")
EvalPython_5_5 = eval_python(inputs=5, outputs=5, name="EvalPython_5_5")
EvalPython_List_1 = eval_python(inputs=1, outputs=1, name="EvalPython_List_1", input_is_list=True, output_is_list=None)
EvalPython_1_List = eval_python(inputs=1, outputs=1, name="EvalPython_1_List", input_is_list=None, output_is_list=(True,))
EvalPython_List_List = eval_python(inputs=1, outputs=1, name="EvalPython_List_List", input_is_list=True, output_is_list=(True,))
export_custom_nodes()
export_package_as_web_directory("comfy_extras.eval_web")
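The core trick in `exec_py` is wrapping the user's code block in a generated function so that a trailing `return` works and the `value{i}` inputs become parameters. A standalone sketch of just that wrapping, independent of the node machinery (names illustrative):

```python
# Standalone illustration of the code-wrapping technique used by exec_py above.
def run_user_code(pycode: str, inputs: int, *value_args):
    arg_names = ", ".join(f"value{i}=None" for i in range(inputs))
    wrapped_code = f"def _eval_func({arg_names}):\n"
    for line in pycode.splitlines():
        wrapped_code += "    " + line + "\n"
    globals_for_eval = {"print": print}  # the node swaps print for progress-text reporting
    exec(wrapped_code, globals_for_eval)
    return globals_for_eval["_eval_func"](*value_args)

# Example with a one-line body: returns its input plus one.
print(run_user_code("return value0 + 1", 1, 41))  # -> 42
```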

View File

@ -65,13 +65,18 @@ dependencies = [
"natsort", "natsort",
"OpenEXR", "OpenEXR",
"opentelemetry-distro", "opentelemetry-distro",
"opentelemetry-sdk<1.34.0", "opentelemetry-sdk",
"opentelemetry-exporter-otlp<=1.27.0", "opentelemetry-exporter-otlp",
"opentelemetry-propagator-jaeger", "opentelemetry-propagator-jaeger",
"opentelemetry-instrumentation", "opentelemetry-instrumentation",
"opentelemetry-util-http", "opentelemetry-util-http",
"opentelemetry-instrumentation-aio-pika", "opentelemetry-instrumentation-aio-pika",
"opentelemetry-instrumentation-requests", "opentelemetry-instrumentation-requests",
"opentelemetry-instrumentation-aiohttp-server",
"opentelemetry-instrumentation-aiohttp-client",
"opentelemetry-instrumentation-asyncio",
"opentelemetry-instrumentation-urllib3",
"opentelemetry-processor-baggage",
"opentelemetry-semantic-conventions", "opentelemetry-semantic-conventions",
"wrapt>=1.16.0", "wrapt>=1.16.0",
"certifi", "certifi",

View File

@ -3,8 +3,6 @@ import multiprocessing
import os import os
import pathlib import pathlib
import subprocess import subprocess
import sys
import time
import urllib import urllib
from contextvars import ContextVar from contextvars import ContextVar
from multiprocessing import Process from multiprocessing import Process
@ -12,9 +10,9 @@ from typing import List, Any, Generator
import pytest import pytest
import requests import requests
import sys
import time
from comfy.cli_args import default_configuration
from comfy.execution_context import context_configuration
os.environ['OTEL_METRICS_EXPORTER'] = 'none' os.environ['OTEL_METRICS_EXPORTER'] = 'none'
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
@ -22,6 +20,7 @@ os.environ["HF_XET_HIGH_PERFORMANCE"] = "True"
# fixes issues with running the testcontainers rabbitmqcontainer on Windows # fixes issues with running the testcontainers rabbitmqcontainer on Windows
os.environ["TC_HOST"] = "localhost" os.environ["TC_HOST"] = "localhost"
from comfy.cli_args import default_configuration
from comfy.cli_args_types import Configuration from comfy.cli_args_types import Configuration
logging.getLogger("pika").setLevel(logging.CRITICAL + 1) logging.getLogger("pika").setLevel(logging.CRITICAL + 1)
@ -91,9 +90,17 @@ def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory, num_workers
params = rabbitmq.get_connection_params() params = rabbitmq.get_connection_params()
connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}" connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"
# Check if OTEL endpoint is configured for integration testing
otel_endpoint = os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT")
env = os.environ.copy()
if otel_endpoint:
env["OTEL_EXPORTER_OTLP_ENDPOINT"] = otel_endpoint
logging.info(f"Configuring services to export traces to: {otel_endpoint}")
frontend_command = [ frontend_command = [
"comfyui", "comfyui",
"--listen=127.0.0.1", "--listen=0.0.0.0",
"--port=19001", "--port=19001",
"--cpu", "--cpu",
"--distributed-queue-frontend", "--distributed-queue-frontend",
@ -101,7 +108,7 @@ def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory, num_workers
f"--distributed-queue-connection-uri={connection_uri}", f"--distributed-queue-connection-uri={connection_uri}",
] ]
processes_to_close.append(subprocess.Popen(frontend_command, stdout=sys.stdout, stderr=sys.stderr)) processes_to_close.append(subprocess.Popen(frontend_command, stdout=sys.stdout, stderr=sys.stderr, env=env))
# Start multiple workers # Start multiple workers
for i in range(num_workers): for i in range(num_workers):
@ -112,7 +119,7 @@ def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory, num_workers
f"--distributed-queue-connection-uri={connection_uri}", f"--distributed-queue-connection-uri={connection_uri}",
f"--executor-factory={executor_factory}" f"--executor-factory={executor_factory}"
] ]
processes_to_close.append(subprocess.Popen(backend_command, stdout=sys.stdout, stderr=sys.stderr)) processes_to_close.append(subprocess.Popen(backend_command, stdout=sys.stdout, stderr=sys.stderr, env=env))
try: try:
server_address = f"http://127.0.0.1:19001" server_address = f"http://127.0.0.1:19001"

View File

@ -243,3 +243,368 @@ async def test_two_workers_distinct_requests():
all_workflows.update(worker.processed_workflows) all_workflows.update(worker.processed_workflows)
assert len(all_workflows) == 2, f"Expected 2 distinct workflows, but got {len(all_workflows)}" assert len(all_workflows) == 2, f"Expected 2 distinct workflows, but got {len(all_workflows)}"
@pytest.mark.asyncio
async def test_api_error_reporting_blocking_request(frontend_backend_worker_with_rabbitmq):
"""Test error reporting with blocking request (no async preference)"""
async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
# Create an invalid prompt that will cause a validation error
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
# Make the prompt invalid by referencing a non-existent checkpoint
prompt["4"]["inputs"]["ckpt_name"] = "nonexistent_checkpoint.safetensors"
# Post with blocking behavior (no prefer header for async)
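# Reuse the client's internal JSON encoder (via its name-mangled attribute) so this
# hand-rolled POST body matches what the client's own helper methods would send.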
prompt_json = client._AsyncRemoteComfyClient__json_encoder.encode(prompt)
async with client.session.post(
f"{frontend_backend_worker_with_rabbitmq}/api/v1/prompts",
data=prompt_json,
headers={'Content-Type': 'application/json', 'Accept': 'application/json'}
) as response:
# Should return 400 for validation error (invalid checkpoint)
assert response.status == 400, f"Expected 400, got {response.status}"
error_body = await response.json()
# Verify ValidationErrorDict structure per OpenAPI spec
assert "type" in error_body, "Missing 'type' field in error response"
assert "message" in error_body, "Missing 'message' field in error response"
assert "details" in error_body, "Missing 'details' field in error response"
assert "extra_info" in error_body, "Missing 'extra_info' field in error response"
@pytest.mark.asyncio
async def test_api_error_reporting_async_prefer_header(frontend_backend_worker_with_rabbitmq):
"""Test error reporting with Prefer: respond-async header"""
async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
# Create a valid prompt structure but with invalid checkpoint
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
prompt["4"]["inputs"]["ckpt_name"] = "nonexistent.safetensors"
# Post with Prefer: respond-async header
prompt_json = client._AsyncRemoteComfyClient__json_encoder.encode(prompt)
async with client.session.post(
f"{frontend_backend_worker_with_rabbitmq}/api/v1/prompts",
data=prompt_json,
headers={
'Content-Type': 'application/json',
'Accept': 'application/json',
'Prefer': 'respond-async'
}
) as response:
# Should return 400 immediately for validation error
assert response.status == 400, f"Expected 400 for validation error, got {response.status}"
error_body = await response.json()
assert "type" in error_body
@pytest.mark.asyncio
async def test_api_error_reporting_async_accept_mimetype(frontend_backend_worker_with_rabbitmq):
"""Test error reporting with +respond-async in Accept mimetype"""
async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
# Create a prompt with validation error
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
prompt["4"]["inputs"]["ckpt_name"] = "invalid_model.safetensors"
# Post with +respond-async in Accept header
prompt_json = client._AsyncRemoteComfyClient__json_encoder.encode(prompt)
async with client.session.post(
f"{frontend_backend_worker_with_rabbitmq}/api/v1/prompts",
data=prompt_json,
headers={
'Content-Type': 'application/json',
'Accept': 'application/json+respond-async'
}
) as response:
# Should return 400 for validation error (happens before queuing)
assert response.status == 400, f"Expected 400, got {response.status}"
error_body = await response.json()
assert "type" in error_body
@pytest.mark.asyncio
async def test_api_get_prompt_status_success(frontend_backend_worker_with_rabbitmq):
"""Test GET /api/v1/prompts/{prompt_id} returns 200 with Outputs on success"""
async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
# Create a valid prompt
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
# Queue async to get prompt_id
task_id = await client.queue_and_forget_prompt_api(prompt, prefer_header="respond-async")
assert task_id is not None
# Poll until done
status_code, result = await client.poll_prompt_until_done(task_id, max_attempts=60, poll_interval=1.0)
# For a valid prompt, should get 200
assert status_code == 200, f"Expected 200 for successful execution, got {status_code}"
assert result is not None
# Verify it returns outputs structure (dict with node IDs)
assert isinstance(result, dict)
assert len(result) > 0, "Expected non-empty outputs"
@pytest.mark.asyncio
async def test_api_get_prompt_status_404(frontend_backend_worker_with_rabbitmq):
"""Test GET /api/v1/prompts/{prompt_id} returns 404 for non-existent prompt"""
async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
# Request a non-existent prompt ID
fake_prompt_id = str(uuid.uuid4())
async with await client.get_prompt_status(fake_prompt_id) as response:
assert response.status == 404, f"Expected 404 for non-existent prompt, got {response.status}"
@pytest.mark.asyncio
async def test_api_get_prompt_status_204_in_progress(frontend_backend_worker_with_rabbitmq):
"""Test GET /api/v1/prompts/{prompt_id} returns 204 while prompt is in progress"""
async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
# Create a prompt that takes some time to execute
prompt = sdxl_workflow_with_refiner("test", inference_steps=10, refiner_steps=10)
# Queue async
task_id = await client.queue_and_forget_prompt_api(prompt, prefer_header="respond-async")
# Immediately check status (should be 204 or 200 if very fast)
async with await client.get_prompt_status(task_id) as response:
# Should be either 204 (in progress) or 200 (completed very fast)
assert response.status in [200, 204], f"Expected 200 or 204, got {response.status}"
if response.status == 204:
# No content for in-progress
content = await response.read()
assert len(content) == 0 or content == b'', "Expected no content for 204 response"
@pytest.mark.asyncio
async def test_api_async_workflow_both_methods(frontend_backend_worker_with_rabbitmq):
"""Test full async workflow: queue with respond-async, then poll for completion"""
async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
# Create a valid prompt
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
# Method 1: Prefer header
task_id_1 = await client.queue_and_forget_prompt_api(prompt, prefer_header="respond-async")
assert task_id_1 is not None
# Method 2: +respond-async in Accept header
task_id_2 = await client.queue_and_forget_prompt_api(
prompt, prefer_header=None, accept_header="application/json+respond-async"
)
assert task_id_2 is not None
# Poll both until done
status_1, result_1 = await client.poll_prompt_until_done(task_id_1, max_attempts=60, poll_interval=1.0)
status_2, result_2 = await client.poll_prompt_until_done(task_id_2, max_attempts=60, poll_interval=1.0)
# Both should succeed
assert status_1 == 200, f"Task 1 failed with status {status_1}"
assert status_2 == 200, f"Task 2 failed with status {status_2}"
# Both should have outputs
assert result_1 is not None and len(result_1) > 0
assert result_2 is not None and len(result_2) > 0
@pytest.mark.asyncio
async def test_api_validation_error_structure(frontend_backend_worker_with_rabbitmq):
"""Test that validation errors return proper ValidationErrorDict structure"""
async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
# Create an invalid prompt (invalid checkpoint name)
prompt = sdxl_workflow_with_refiner("test", "", 1, refiner_steps=1)
prompt["4"]["inputs"]["ckpt_name"] = "fake.safetensors"
prompt_json = client._AsyncRemoteComfyClient__json_encoder.encode(prompt)
async with client.session.post(
f"{frontend_backend_worker_with_rabbitmq}/api/v1/prompts",
data=prompt_json,
headers={'Content-Type': 'application/json', 'Accept': 'application/json'}
) as response:
assert response.status == 400, f"Expected 400, got {response.status}"
error_body = await response.json()
# Verify ValidationErrorDict structure per OpenAPI spec
assert "type" in error_body, "Missing 'type'"
assert "message" in error_body, "Missing 'message'"
assert "details" in error_body, "Missing 'details'"
assert "extra_info" in error_body, "Missing 'extra_info'"
assert error_body["type"] == "prompt_outputs_failed_validation", "unexpected type"
# extra_info should have exception_type and traceback
assert "exception_type" in error_body["extra_info"], "Missing 'exception_type' in extra_info"
assert "traceback" in error_body["extra_info"], "Missing 'traceback' in extra_info"
assert isinstance(error_body["extra_info"]["traceback"], list), "traceback should be a list"
# extra_info should have node_errors with detailed validation information
assert "node_errors" in error_body["extra_info"], "Missing 'node_errors' in extra_info"
node_errors = error_body["extra_info"]["node_errors"]
assert isinstance(node_errors, dict), "node_errors should be a dict"
assert len(node_errors) > 0, "node_errors should contain at least one node"
# Verify node_errors structure for node "4" (CheckpointLoaderSimple with invalid ckpt_name)
assert "4" in node_errors, "Node '4' should have validation errors"
node_4_errors = node_errors["4"]
assert "errors" in node_4_errors, "Node '4' should have 'errors' field"
assert "class_type" in node_4_errors, "Node '4' should have 'class_type' field"
assert "dependent_outputs" in node_4_errors, "Node '4' should have 'dependent_outputs' field"
assert node_4_errors["class_type"] == "CheckpointLoaderSimple", "Node '4' class_type should be CheckpointLoaderSimple"
assert len(node_4_errors["errors"]) > 0, "Node '4' should have at least one error"
# Verify the error details include the validation error type and message
first_error = node_4_errors["errors"][0]
assert "type" in first_error, "Error should have 'type' field"
assert "message" in first_error, "Error should have 'message' field"
assert "details" in first_error, "Error should have 'details' field"
assert first_error["type"] == "value_not_in_list", f"Expected 'value_not_in_list' error, got {first_error['type']}"
assert "fake.safetensors" in first_error["details"], "Error details should mention 'fake.safetensors'"
@pytest.mark.asyncio
async def test_api_success_response_contract(frontend_backend_worker_with_rabbitmq):
"""Test that successful execution returns proper response structure"""
async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
# Create a valid prompt
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
# Queue and wait for blocking response
prompt_json = client._AsyncRemoteComfyClient__json_encoder.encode(prompt)
async with client.session.post(
f"{frontend_backend_worker_with_rabbitmq}/api/v1/prompts",
data=prompt_json,
headers={'Content-Type': 'application/json', 'Accept': 'application/json'}
) as response:
assert response.status == 200, f"Expected 200, got {response.status}"
result = await response.json()
# Should have 'outputs' key (and deprecated 'urls' key)
assert "outputs" in result, "Missing 'outputs' in response"
# outputs should be a dict with node IDs as keys
outputs = result["outputs"]
assert isinstance(outputs, dict), "outputs should be a dict"
assert len(outputs) > 0, "outputs should not be empty"
# Each output should follow the Output schema
for node_id, output in outputs.items():
assert isinstance(output, dict), f"Output for node {node_id} should be a dict"
# Should have images or other output types
if "images" in output:
assert isinstance(output["images"], list), f"images for node {node_id} should be a list"
for image in output["images"]:
assert "filename" in image, f"image missing 'filename' in node {node_id}"
assert "subfolder" in image, f"image missing 'subfolder' in node {node_id}"
assert "type" in image, f"image missing 'type' in node {node_id}"
@pytest.mark.asyncio
async def test_api_get_prompt_returns_outputs_directly(frontend_backend_worker_with_rabbitmq):
"""Test GET /api/v1/prompts/{prompt_id} returns Outputs directly (not wrapped in history entry)"""
async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
# Create and queue a prompt
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
task_id = await client.queue_and_forget_prompt_api(prompt)
# Poll until done
status_code, result = await client.poll_prompt_until_done(task_id, max_attempts=60, poll_interval=1.0)
assert status_code == 200, f"Expected 200, got {status_code}"
assert result is not None, "Result should not be None"
# Per OpenAPI spec, GET should return Outputs directly, not wrapped
# result should be a dict with node IDs as keys
assert isinstance(result, dict), "Result should be a dict (Outputs)"
# Should NOT have 'prompt', 'outputs', 'status' keys (those are in history entry)
# Should have node IDs directly
for key in result.keys():
# Node IDs are typically numeric strings like "4", "13", etc.
# Should not be "prompt", "outputs", "status"
assert key not in ["prompt", "status"], \
f"GET endpoint should return Outputs directly, not history entry. Found key: {key}"
@pytest.mark.asyncio
async def test_api_execution_error_blocking_mode(frontend_backend_worker_with_rabbitmq):
"""Test that execution errors (not validation) return proper error structure in blocking mode"""
from comfy_execution.graph_utils import GraphBuilder
async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
# Create a prompt that will fail during execution (not validation)
# Use Regex with a group name that doesn't exist - validation passes but execution fails
g = GraphBuilder()
regex_match = g.node("Regex", pattern="hello", string="hello world")
# Request a non-existent group name - this will pass validation but fail during execution
match_group = g.node("RegexMatchGroupByName", match=regex_match.out(0), name="nonexistent_group")
g.node("SaveString", value=match_group.out(0), filename_prefix="test")
prompt = g.finalize()
prompt_json = client._AsyncRemoteComfyClient__json_encoder.encode(prompt)
async with client.session.post(
f"{frontend_backend_worker_with_rabbitmq}/api/v1/prompts",
data=prompt_json,
headers={'Content-Type': 'application/json', 'Accept': 'application/json'}
) as response:
# Execution errors return 500
assert response.status == 500, f"Expected 500 for execution error, got {response.status}"
error_body = await response.json()
# Verify ExecutionStatus structure
assert "status_str" in error_body, "Missing 'status_str'"
assert "completed" in error_body, "Missing 'completed'"
assert "messages" in error_body, "Missing 'messages'"
assert error_body["status_str"] == "error", f"Expected 'error', got {error_body['status_str']}"
assert error_body["completed"] == False, "completed should be False for errors"
assert isinstance(error_body["messages"], list), "messages should be a list"
assert len(error_body["messages"]) > 0, "messages should contain error details"
@pytest.mark.asyncio
async def test_api_execution_error_async_mode(frontend_backend_worker_with_rabbitmq):
"""Test that execution errors return proper error structure in respond-async mode"""
from comfy_execution.graph_utils import GraphBuilder
async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
# Create a prompt that will fail during execution (not validation)
# Use Regex with a group name that doesn't exist - validation passes but execution fails
g = GraphBuilder()
regex_match = g.node("Regex", pattern="hello", string="hello world")
# Request a non-existent group name - this will pass validation but fail during execution
match_group = g.node("RegexMatchGroupByName", match=regex_match.out(0), name="nonexistent_group")
g.node("SaveString", value=match_group.out(0), filename_prefix="test")
prompt = g.finalize()
# Queue with respond-async
task_id = await client.queue_and_forget_prompt_api(prompt, prefer_header="respond-async")
assert task_id is not None, "Should get task_id in async mode"
# Poll for completion
status_code, result = await client.poll_prompt_until_done(task_id, max_attempts=60, poll_interval=1.0)
# In async mode with polling, errors come back as 200 with error in the response body
# because the prompt was accepted (202) and we're just retrieving the completed result
assert status_code in (200, 500), f"Expected 200 or 500, got {status_code}"
if status_code == 500:
# Error returned directly - should be ExecutionStatus
assert "status_str" in result, "Missing 'status_str'"
assert "completed" in result, "Missing 'completed'"
assert "messages" in result, "Missing 'messages'"
assert result["status_str"] == "error"
assert result["completed"] == False
assert len(result["messages"]) > 0
else:
# Error in successful response - result might be ExecutionStatus or empty outputs
# If it's a dict with status info, verify it
if "status_str" in result:
assert result["status_str"] == "error"
assert result["completed"] == False
assert len(result["messages"]) > 0

View File

@ -0,0 +1,128 @@
import asyncio
import logging
logging.basicConfig(level=logging.ERROR)
import uuid
import pytest
from testcontainers.rabbitmq import RabbitMqContainer
from opentelemetry import trace, propagate, context
from opentelemetry.trace import SpanKind
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner
from comfy.component_model.make_mutable import make_mutable
from comfy.component_model.queue_types import QueueItem, QueueTuple, ExecutionStatus
from comfy.distributed.server_stub import ServerStub
async def create_test_prompt() -> QueueItem:
from comfy.cmd.execution import validate_prompt
prompt = make_mutable(sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1))
item_id = str(uuid.uuid4())
validation_tuple = await validate_prompt(item_id, prompt)
queue_tuple: QueueTuple = (0, item_id, prompt, {}, validation_tuple[2])
return QueueItem(queue_tuple, None)
@pytest.mark.asyncio
async def test_rabbitmq_message_properties_contain_trace_context():
with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
params = rabbitmq.get_connection_params()
connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"
from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
import aio_pika
exporter = InMemorySpanExporter()
provider = TracerProvider()
provider.add_span_processor(SimpleSpanProcessor(exporter))
tracer = trace.get_tracer(__name__, tracer_provider=provider)
with tracer.start_as_current_span("test_message_headers", kind=SpanKind.PRODUCER):
async with DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=connection_uri) as frontend:
async with DistributedPromptQueue(ServerStub(), is_callee=True, is_caller=False, connection_uri=connection_uri) as worker:
queue_item = await create_test_prompt()
put_task = asyncio.create_task(frontend.put_async(queue_item))
incoming, incoming_prompt_id = await worker.get_async(timeout=5.0)
assert incoming is not None, "Worker should receive message"
worker.task_done(incoming_prompt_id, {}, ExecutionStatus("success", True, []))
result = await put_task
assert result is not None, "Frontend should get result"
# Now inspect the RabbitMQ queue directly to see message structure
connection = await aio_pika.connect_robust(connection_uri)
channel = await connection.channel()
# Declare a test queue to inspect message format
test_queue = await channel.declare_queue("test_inspection_queue", durable=False, auto_delete=True)
# Publish a test message with trace context
carrier = {}
propagate.inject(carrier)
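# inject() writes the active span context into the carrier dict as W3C headers
# (traceparent, plus tracestate/baggage when present).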
test_message = aio_pika.Message(
body=b"test",
headers=carrier
)
await channel.default_exchange.publish(
test_message,
routing_key=test_queue.name
)
# Get and inspect the message
received = await test_queue.get(timeout=2, fail=False)
if received:
headers = received.headers or {}
# Document what trace headers should be present
# OpenTelemetry uses 'traceparent' header for W3C Trace Context
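# traceparent format: "00-<32 hex trace-id>-<16 hex parent-id>-<2 hex trace-flags>"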
has_traceparent = "traceparent" in headers
assert has_traceparent
await received.ack()
await connection.close()
@pytest.mark.asyncio
async def test_distributed_queue_uses_async_interface():
"""
Test that demonstrates the correct way to use DistributedPromptQueue in async context.
The synchronous get() method cannot be used in async tests due to event loop assertions.
"""
with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
params = rabbitmq.get_connection_params()
connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"
from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
async with DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=connection_uri) as frontend:
async with DistributedPromptQueue(ServerStub(), is_callee=True, is_caller=False, connection_uri=connection_uri) as worker:
queue_item = await create_test_prompt()
# Start consuming in background
result_future = asyncio.create_task(frontend.put_async(queue_item))
# Worker gets item asynchronously (not using blocking get())
incoming, incoming_prompt_id = await worker.get_async(timeout=5.0)
assert incoming is not None, "Should receive a queue item"
# Complete the work
worker.task_done(incoming_prompt_id, {}, ExecutionStatus("success", True, []))
# Wait for frontend to complete
result = await result_future
assert result is not None, "Should get result from worker"
assert result.status.status_str == "success"

View File

@ -0,0 +1,569 @@
"""
Integration tests for distributed tracing across RabbitMQ and services.
These tests validate that trace context propagates correctly from frontend
to backend workers through RabbitMQ, and that Jaeger can reconstruct the
full distributed trace.
"""
import asyncio
import logging
import os
import tempfile
import time
import uuid
import pytest
import requests
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.instrumentation.aiohttp_client import AioHttpClientInstrumentor
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.semconv.attributes import service_attributes
from testcontainers.core.container import DockerContainer
from testcontainers.core.waiting_utils import wait_for_logs
from testcontainers.nginx import NginxContainer
from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class JaegerContainer(DockerContainer):
"""Testcontainer for Jaeger all-in-one with OTLP support."""
def __init__(self, image: str = "jaegertracing/all-in-one:latest"):
super().__init__(image)
self.with_exposed_ports(16686, 4318, 14268) # UI, OTLP HTTP, Jaeger HTTP
self.with_env("COLLECTOR_OTLP_ENABLED", "true")
def get_query_url(self) -> str:
"""Get Jaeger Query API URL."""
host = self.get_container_host_ip()
port = self.get_exposed_port(16686)
return f"http://{host}:{port}"
def get_otlp_endpoint(self) -> str:
"""Get OTLP HTTP endpoint for sending traces."""
host = self.get_container_host_ip()
port = self.get_exposed_port(4318)
return f"http://{host}:{port}"
def start(self):
super().start()
wait_for_logs(self, ".*Starting GRPC server.*", timeout=30)
return self
@pytest.fixture(scope="function")
def nginx_proxy(frontend_backend_worker_with_rabbitmq):
"""
Provide an nginx proxy in front of the ComfyUI frontend.
This lets the tracing tests verify whether nginx blocks W3C trace context propagation.
"""
import socket
import subprocess
# Extract host and port from frontend address
frontend_url = frontend_backend_worker_with_rabbitmq
# frontend_url is like "http://127.0.0.1:19001"
import re
match = re.match(r'http://([^:]+):(\d+)', frontend_url)
if not match:
raise ValueError(f"Could not parse frontend URL: {frontend_url}")
frontend_host = match.group(1)
frontend_port = match.group(2)
nginx_port = 8085
# Get the Docker bridge gateway IP (this is how containers reach the host on Linux)
# Try to get the default Docker bridge gateway
try:
result = subprocess.run(
["docker", "network", "inspect", "bridge", "-f", "{{range .IPAM.Config}}{{.Gateway}}{{end}}"],
capture_output=True,
text=True,
check=True
)
docker_gateway = result.stdout.strip()
logger.info(f"Using Docker gateway IP: {docker_gateway}")
except Exception as e:
# Fall back to the default Docker bridge gateway on Linux
docker_gateway = "172.17.0.1"
logger.warning(f"Could not detect Docker gateway ({e}), using default: {docker_gateway}")
# Create nginx config that proxies to the frontend and passes trace headers
nginx_conf = f"""
events {{
worker_connections 1024;
}}
http {{
upstream backend {{
server {docker_gateway}:{frontend_port};
}}
server {{
listen {nginx_port};
location / {{
proxy_pass http://backend;
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
}}
}}
}}
"""
# Write config to a temporary file
with tempfile.NamedTemporaryFile(mode='w', suffix='.conf', delete=False) as f:
f.write(nginx_conf)
nginx_conf_path = f.name
try:
# Start nginx container with the config
nginx = NginxContainer(port=nginx_port)
nginx.with_volume_mapping(nginx_conf_path, "/etc/nginx/nginx.conf")
nginx.start()
# Get the nginx URL
host = nginx.get_container_host_ip()
port = nginx.get_exposed_port(nginx_port)
nginx_url = f"http://{host}:{port}"
logger.info(f"Nginx proxy started at {nginx_url} -> {frontend_url}")
# Wait for nginx to be ready
for _ in range(30):
try:
response = requests.get(nginx_url, timeout=1)
if response.status_code:
break
except Exception:
pass
time.sleep(0.5)
yield nginx_url
finally:
nginx.stop()
os.unlink(nginx_conf_path)
@pytest.fixture(scope="module")
def jaeger_container():
"""
Provide a Jaeger container for collecting traces.
This fixture automatically sets OTEL_EXPORTER_OTLP_ENDPOINT to point to the
Jaeger container, and cleans it up when the container stops.
"""
container = JaegerContainer()
container.start()
# Wait for Jaeger to be fully ready
query_url = container.get_query_url()
otlp_endpoint = container.get_otlp_endpoint()
for _ in range(30):
try:
response = requests.get(f"{query_url}/api/services")
if response.status_code == 200:
logger.info(f"Jaeger ready at {query_url}")
logger.info(f"OTLP endpoint: {otlp_endpoint}")
break
except Exception:
pass
time.sleep(1)
# Set OTEL_EXPORTER_OTLP_ENDPOINT for the duration of the test
old_endpoint = os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT")
os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"] = otlp_endpoint
logger.info(f"Set OTEL_EXPORTER_OTLP_ENDPOINT={otlp_endpoint}")
try:
yield container
finally:
# Restore original OTEL_EXPORTER_OTLP_ENDPOINT
if old_endpoint is not None:
os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"] = old_endpoint
logger.info(f"Restored OTEL_EXPORTER_OTLP_ENDPOINT={old_endpoint}")
else:
os.environ.pop("OTEL_EXPORTER_OTLP_ENDPOINT", None)
logger.info("Removed OTEL_EXPORTER_OTLP_ENDPOINT")
container.stop()
def query_jaeger_traces(jaeger_url: str, service: str, operation: str = None,
lookback: str = "1h", limit: int = 100) -> dict:
"""
Query Jaeger for traces.
Args:
jaeger_url: Base URL of Jaeger query service
service: Service name to query
operation: Optional operation name filter
lookback: Lookback period (e.g., "1h", "30m")
limit: Maximum number of traces to return
Returns:
JSON response from Jaeger API
"""
params = {
"service": service,
"lookback": lookback,
"limit": limit
}
if operation:
params["operation"] = operation
response = requests.get(f"{jaeger_url}/api/traces", params=params)
response.raise_for_status()
return response.json()
def find_trace_by_operation(traces_response: dict, operation_name: str) -> dict:
"""Find a specific trace by operation name."""
for trace in traces_response.get("data", []):
for span in trace.get("spans", []):
if span.get("operationName") == operation_name:
return trace
return None
def verify_trace_continuity(trace: dict, expected_services: list[str]) -> bool:
"""
Verify that a trace spans multiple services and maintains parent-child relationships.
Args:
trace: Jaeger trace object
expected_services: List of service names expected in the trace
Returns:
True if trace shows proper distributed tracing across services
"""
if not trace:
return False
spans = trace.get("spans", [])
if not spans:
return False
# Check that all expected services are present
trace_services = set()
for span in spans:
process_id = span.get("processID")
if process_id:
process = trace.get("processes", {}).get(process_id, {})
service_name = process.get("serviceName")
if service_name:
trace_services.add(service_name)
logger.info(f"Trace contains services: {trace_services}")
logger.info(f"Expected services: {set(expected_services)}")
# Verify all expected services are present
for service in expected_services:
if service not in trace_services:
logger.warning(f"Expected service '{service}' not found in trace")
return False
# Verify all spans share the same trace ID
trace_ids = set(span.get("traceID") for span in spans)
if len(trace_ids) != 1:
logger.warning(f"Multiple trace IDs found: {trace_ids}")
return False
# Verify parent-child relationships exist
span_ids = {span.get("spanID") for span in spans}
has_parent_refs = False
for span in spans:
references = span.get("references", [])
for ref in references:
if ref.get("refType") == "CHILD_OF":
parent_span_id = ref.get("spanID")
if parent_span_id in span_ids:
has_parent_refs = True
logger.info(f"Found parent-child relationship: {parent_span_id} -> {span.get('spanID')}")
if not has_parent_refs:
logger.warning("No parent-child relationships found in trace")
return False
return True
# fixture order matters: jaeger_container must start first so OTEL_EXPORTER_OTLP_ENDPOINT
# is set before nginx_proxy spins up the frontend and worker processes
@pytest.mark.asyncio
async def test_tracing_integration(jaeger_container, nginx_proxy):
"""
Integration test for distributed tracing across services with nginx proxy.
This test:
1. Starts ComfyUI frontend and worker with RabbitMQ
2. Starts nginx proxy in front of the frontend to test trace context propagation through nginx
3. Configures OTLP export to Jaeger testcontainer
4. Submits a workflow through the nginx proxy
5. Queries Jaeger to verify trace propagation
6. Validates that the trace spans multiple services with proper relationships
This specifically checks whether nginx blocks W3C trace context (traceparent/tracestate headers).
"""
server_address = nginx_proxy
jaeger_url = jaeger_container.get_query_url()
otlp_endpoint = jaeger_container.get_otlp_endpoint()
logger.info(f"Frontend server: {server_address}")
logger.info(f"Jaeger UI: {jaeger_url}")
logger.info(f"OTLP endpoint: {otlp_endpoint}")
# Set up tracing for the async HTTP client
resource = Resource.create({
service_attributes.SERVICE_NAME: "comfyui-client",
})
provider = TracerProvider(resource=resource)
exporter = OTLPSpanExporter(endpoint=f"{otlp_endpoint}/v1/traces")
processor = BatchSpanProcessor(exporter)
provider.add_span_processor(processor)
from opentelemetry import trace
trace.set_tracer_provider(provider)
# Instrument aiohttp client
AioHttpClientInstrumentor().instrument()
# we have to call this very late, so that the instrumentation isn't initialized too early
from comfy.client.aio_client import AsyncRemoteComfyClient
# Note: the ComfyUI services are pointed at this Jaeger instance via the
# OTEL_EXPORTER_OTLP_ENDPOINT variable set by the jaeger_container fixture and
# forwarded to the frontend/worker subprocesses by the conftest fixture.
# Create a unique prompt to identify our trace
test_id = str(uuid.uuid4())[:8]
prompt = sdxl_workflow_with_refiner(f"test_trace_{test_id}", inference_steps=1, refiner_steps=1)
# Get the tracer for the client
client_tracer = trace.get_tracer("test_tracing_integration")
# Submit the workflow - wrap in a span to capture the trace ID
with client_tracer.start_as_current_span("submit_workflow") as workflow_span:
trace_id = format(workflow_span.get_span_context().trace_id, '032x')
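# Render the integer trace id as 32-char lowercase hex so it can be compared with
# the traceID strings returned by the Jaeger query API below.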
logger.info(f"Started trace with trace_id: {trace_id}")
async with AsyncRemoteComfyClient(server_address=server_address) as client:
logger.info(f"Submitting workflow with test_id: {test_id}")
# Queue the prompt with async response
task_id = await client.queue_and_forget_prompt_api(prompt, prefer_header="respond-async")
assert task_id is not None, "Failed to get task ID"
logger.info(f"Queued task: {task_id}")
# Poll for completion
status_code, result = await client.poll_prompt_until_done(task_id, max_attempts=60, poll_interval=1.0)
assert status_code == 200, f"Task failed with status {status_code}"
logger.info("Task completed successfully")
# Give Jaeger time to receive and process spans
await asyncio.sleep(5)
# Query Jaeger for traces
# Note: the service names depend on how the services are configured; here the
# backend reports as "comfyui" and the instrumented test client as "comfyui-client"
expected_services = ["comfyui", "comfyui-client"] # Adjust based on actual service names
logger.info(f"Querying Jaeger for traces with trace_id: {trace_id}...")
# First, try to find our specific trace by trace_id from the client service
our_trace = None
for service in expected_services:
try:
traces_response = query_jaeger_traces(jaeger_url, service, lookback="5m")
if traces_response.get("data"):
logger.info(f"Found {len(traces_response['data'])} traces for service '{service}'")
for candidate in traces_response["data"]:
if candidate.get("traceID") == trace_id:
our_trace = candidate
logger.info(f"Found our trace in service '{service}'")
break
if our_trace:
break
except Exception as e:
logger.warning(f"Could not query traces for service '{service}': {e}")
# Assert we can find the trace we just created
assert our_trace is not None, (
f"Could not find trace with trace_id {trace_id} in Jaeger. "
f"This indicates that spans from comfyui-client are not being exported correctly."
)
logger.info(f"Successfully found trace with trace_id {trace_id}")
# Extract services from the trace
trace_services = set()
for span in our_trace.get("spans", []):
process_id = span.get("processID")
if process_id:
process = our_trace.get("processes", {}).get(process_id, {})
service_name = process.get("serviceName")
if service_name:
trace_services.add(service_name)
logger.info(f"Services found in trace: {trace_services}")
# Assert that comfyui-client service is present (since we instrumented it)
assert "comfyui-client" in trace_services, (
f"Expected 'comfyui-client' service in trace, but found only: {trace_services}. "
f"This indicates the client instrumentation is not working."
)
# Validate trace structure
logger.info(f"Analyzing trace with {len(our_trace.get('spans', []))} spans")
# Log all spans for debugging
for span in our_trace.get("spans", []):
process_id = span.get("processID")
process = our_trace.get("processes", {}).get(process_id, {})
service_name = process.get("serviceName", "unknown")
operation = span.get("operationName", "unknown")
logger.info(f" Span: {service_name}.{operation}")
# Verify trace continuity - only if both services are present
assert "comfyui" in trace_services
is_continuous = verify_trace_continuity(our_trace, expected_services)
# This assertion documents what SHOULD happen when distributed tracing works
assert is_continuous, (
"Trace does not show proper distributed tracing. "
"Expected to see spans from multiple services with parent-child relationships. "
"This indicates that trace context is not being propagated correctly through RabbitMQ."
)
@pytest.mark.asyncio
async def test_trace_context_in_http_headers(frontend_backend_worker_with_rabbitmq):
"""
Test that HTTP requests include traceparent headers.
This validates that the HTTP layer is properly instrumented for tracing.
"""
server_address = frontend_backend_worker_with_rabbitmq
# Make a simple HTTP request and check for trace headers
# Note: We're checking the server's response headers to see if it's trace-aware
response = requests.get(f"{server_address}/system_stats")
logger.info(f"Response headers: {dict(response.headers)}")
# The server should be instrumented and may include trace context in responses
# or at minimum, should accept traceparent headers in requests
# Test sending a traceparent header
test_traceparent = "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01"
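# version "00", 16-byte trace id, 8-byte parent span id, flags "01" (sampled)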
response_with_trace = requests.get(
f"{server_address}/system_stats",
headers={"traceparent": test_traceparent}
)
# Should not error when traceparent is provided
assert response_with_trace.status_code == 200, "Server should accept traceparent header"
logger.info("✓ Server accepts traceparent headers in HTTP requests")
@pytest.mark.asyncio
async def test_multiple_requests_different_traces(frontend_backend_worker_with_rabbitmq, jaeger_container):
"""
Test that multiple independent requests create separate traces.
This validates that trace context is properly scoped per request.
"""
server_address = frontend_backend_worker_with_rabbitmq
# Submit multiple workflows
task_ids = []
from comfy.client.aio_client import AsyncRemoteComfyClient
async with AsyncRemoteComfyClient(server_address=server_address) as client:
for i in range(3):
prompt = sdxl_workflow_with_refiner(f"test_{i}", inference_steps=1, refiner_steps=1)
task_id = await client.queue_and_forget_prompt_api(prompt, prefer_header="respond-async")
task_ids.append(task_id)
logger.info(f"Queued task {i}: {task_id}")
# Wait for all to complete
for i, task_id in enumerate(task_ids):
status_code, result = await client.poll_prompt_until_done(task_id, max_attempts=60, poll_interval=1.0)
assert status_code == 200, f"Task {i} failed"
logger.info(f"Task {i} completed")
# Give Jaeger time to receive spans
await asyncio.sleep(5)
# Query Jaeger and verify we have multiple distinct traces
jaeger_url = jaeger_container.get_query_url()
traces_response = query_jaeger_traces(jaeger_url, "comfyui", lookback="5m", limit=10)
traces = traces_response.get("data", [])
assert len(traces) >= 2
# Get trace IDs
trace_ids = [trace.get("traceID") for trace in traces]
unique_trace_ids = set(trace_ids)
logger.info(f"Found {len(unique_trace_ids)} unique traces")
# Verify we have multiple distinct traces
assert len(unique_trace_ids) >= 2, (
f"Expected at least 2 distinct traces, found {len(unique_trace_ids)}. "
"Each request should create its own trace."
)
logger.info("✓ Multiple requests created distinct traces")
@pytest.mark.asyncio
@pytest.mark.skip(reason="RabbitMQ likely needs additional observability configuration before publish/consume spans appear")
async def test_trace_contains_rabbitmq_operations(frontend_backend_worker_with_rabbitmq, jaeger_container):
"""
Test that traces include RabbitMQ publish/consume operations.
This is critical for distributed tracing - the RabbitMQ operations
are what link the frontend and backend spans together.
"""
server_address = frontend_backend_worker_with_rabbitmq
jaeger_url = jaeger_container.get_query_url()
# Submit a workflow
from comfy.client.aio_client import AsyncRemoteComfyClient
async with AsyncRemoteComfyClient(server_address=server_address) as client:
prompt = sdxl_workflow_with_refiner("test_rmq", inference_steps=1, refiner_steps=1)
task_id = await client.queue_and_forget_prompt_api(prompt)
status_code, result = await client.poll_prompt_until_done(task_id, max_attempts=60)
assert status_code == 200
await asyncio.sleep(5)
traces_response = query_jaeger_traces(jaeger_url, "comfyui", lookback="5m")
traces = traces_response.get("data", [])
# Look for RabbitMQ-related operations in any trace
rabbitmq_operations = [
"publish", "consume", "amq_queue_publish", "amq_queue_consume",
"amq.basic.publish", "amq.basic.consume", "send", "receive"
]
found_rabbitmq_ops = []
for trace in traces:
for span in trace.get("spans", []):
op_name = span.get("operationName", "").lower()
for rmq_op in rabbitmq_operations:
if rmq_op in op_name:
found_rabbitmq_ops.append(op_name)
assert found_rabbitmq_ops, "No RabbitMQ-related operations found in traces"

View File

@ -0,0 +1,683 @@
import pytest
from unittest.mock import Mock, patch
from comfy.cli_args import default_configuration
from comfy.execution_context import context_configuration
from comfy_extras.nodes.nodes_eval import (
eval_python,
EvalPython_5_5,
EvalPython_List_1,
EvalPython_1_List,
EvalPython_List_List,
)
@pytest.fixture
def eval_context():
"""Fixture that sets up execution context with eval enabled"""
config = default_configuration()
config.enable_eval = True
with context_configuration(config):
yield
def test_eval_python_basic_return(eval_context):
"""Test basic return statement with single value"""
node = EvalPython_5_5()
result = node.exec_py(pycode="return 42", value0=0, value1=1, value2=2, value3=3, value4=4)
assert result == (42, None, None, None, None)
def test_eval_python_multiple_returns(eval_context):
"""Test return statement with tuple of values"""
node = EvalPython_5_5()
result = node.exec_py(
pycode="return 1, 2, 3",
value0=0, value1=0, value2=0, value3=0, value4=0
)
assert result == (1, 2, 3, None, None)
def test_eval_python_all_five_returns(eval_context):
"""Test return statement with all five values"""
node = EvalPython_5_5()
result = node.exec_py(
pycode="return 'a', 'b', 'c', 'd', 'e'",
value0=0, value1=0, value2=0, value3=0, value4=0
)
assert result == ('a', 'b', 'c', 'd', 'e')
def test_eval_python_excess_returns(eval_context):
"""Test that excess return values are truncated to 5"""
node = EvalPython_5_5()
result = node.exec_py(
pycode="return 1, 2, 3, 4, 5, 6, 7",
value0=0, value1=0, value2=0, value3=0, value4=0
)
assert result == (1, 2, 3, 4, 5)
def test_eval_python_use_value_args(eval_context):
"""Test that value arguments are accessible in pycode"""
node = EvalPython_5_5()
result = node.exec_py(
pycode="return value0 + value1 + value2",
value0=10, value1=20, value2=30, value3=0, value4=0
)
assert result == (60, None, None, None, None)
def test_eval_python_all_value_args(eval_context):
"""Test all value arguments are accessible"""
node = EvalPython_5_5()
result = node.exec_py(
pycode="return value0, value1, value2, value3, value4",
value0=1, value1=2, value2=3, value3=4, value4=5
)
assert result == (1, 2, 3, 4, 5)
def test_eval_python_computation(eval_context):
"""Test computation with value arguments"""
node = EvalPython_5_5()
code = """
x = value0 * 2
y = value1 * 3
z = x + y
return z
"""
result = node.exec_py(
pycode=code,
value0=5, value1=10, value2=0, value3=0, value4=0
)
assert result == (40, None, None, None, None)
def test_eval_python_multiline(eval_context):
"""Test multiline code with conditionals"""
node = EvalPython_5_5()
code = """
if value0 > 10:
result = "large"
else:
result = "small"
return result, value0
"""
result = node.exec_py(
pycode=code,
value0=15, value1=0, value2=0, value3=0, value4=0
)
assert result == ("large", 15, None, None, None)
def test_eval_python_list_comprehension(eval_context):
"""Test list comprehension and iteration"""
node = EvalPython_5_5()
code = """
numbers = [value0, value1, value2]
doubled = [x * 2 for x in numbers]
return sum(doubled)
"""
result = node.exec_py(
pycode=code,
value0=1, value1=2, value2=3, value3=0, value4=0
)
assert result == (12, None, None, None, None)
def test_eval_python_string_operations(eval_context):
"""Test string operations"""
node = EvalPython_5_5()
code = """
s1 = str(value0)
s2 = str(value1)
return s1 + s2, len(s1 + s2)
"""
result = node.exec_py(
pycode=code,
value0=123, value1=456, value2=0, value3=0, value4=0
)
assert result == ("123456", 6, None, None, None)
def test_eval_python_type_mixing(eval_context):
"""Test mixing different types"""
node = EvalPython_5_5()
code = """
return value0, str(value1), float(value2), bool(value3)
"""
result = node.exec_py(
pycode=code,
value0=42, value1=100, value2=3, value3=1, value4=0
)
assert result == (42, "100", 3.0, True, None)
def test_eval_python_logger_available(eval_context):
"""Test that logger is available in eval context"""
node = EvalPython_5_5()
code = """
logger.info("test log")
return "success"
"""
result = node.exec_py(
pycode=code,
value0=0, value1=0, value2=0, value3=0, value4=0
)
assert result == ("success", None, None, None, None)
def test_eval_python_print_available(eval_context):
"""Test that print function is available"""
node = EvalPython_5_5()
code = """
print("Hello World!")
return "printed"
"""
result = node.exec_py(
pycode=code,
value0=0, value1=0, value2=0, value3=0, value4=0
)
assert result == ("printed", None, None, None, None)
def test_eval_python_print_is_called(eval_context):
"""Test that print function is called and receives correct arguments"""
node = EvalPython_5_5()
# Track print calls
print_calls = []
code = """
print("Hello", "World")
print("Line 2")
return "done"
"""
# Mock exec to capture the globals dict and verify print is there
original_exec = exec
captured_globals = {}
def mock_exec(code_str, globals_dict, *args, **kwargs):
# Capture the globals dict
captured_globals.update(globals_dict)
# Wrap the print function to track calls
original_print = globals_dict.get('print')
if original_print:
def tracked_print(*args):
print_calls.append(args)
return original_print(*args)
globals_dict['print'] = tracked_print
# Run the original exec
return original_exec(code_str, globals_dict, *args, **kwargs)
with patch('builtins.exec', side_effect=mock_exec):
result = node.exec_py(
pycode=code,
value0=0, value1=0, value2=0, value3=0, value4=0
)
# Verify the result
assert result == ("done", None, None, None, None)
# Verify print was in the globals
assert 'print' in captured_globals
# Verify print was called twice with correct arguments
assert len(print_calls) == 2
assert print_calls[0] == ("Hello", "World")
assert print_calls[1] == ("Line 2",)
def test_eval_python_print_sends_to_server(eval_context):
"""Test that print sends messages to PromptServer via context"""
from comfy.execution_context import current_execution_context
node = EvalPython_5_5()
ctx = current_execution_context()
# Mock the server's send_progress_text method
original_send = ctx.server.send_progress_text if hasattr(ctx.server, 'send_progress_text') else None
mock_send = Mock()
ctx.server.send_progress_text = mock_send
code = """
print("Hello", "World")
print("Value:", value0)
return "done"
"""
try:
result = node.exec_py(
pycode=code,
value0=42, value1=0, value2=0, value3=0, value4=0
)
# Verify the result
assert result == ("done", None, None, None, None)
# Verify print messages were sent to server
assert mock_send.call_count == 2
# Verify the messages sent
calls = mock_send.call_args_list
assert calls[0][0][0] == "Hello World"
assert calls[0][0][1] == ctx.node_id
assert calls[1][0][0] == "Value: 42"
assert calls[1][0][1] == ctx.node_id
finally:
# Restore original
if original_send:
ctx.server.send_progress_text = original_send
def test_eval_python_config_disabled_raises():
"""Test that enable_eval=False raises an error"""
node = EvalPython_5_5()
config = default_configuration()
config.enable_eval = False
with context_configuration(config):
with pytest.raises(ValueError, match="Python eval is disabled"):
node.exec_py(
pycode="return 42",
value0=0, value1=0, value2=0, value3=0, value4=0
)
def test_eval_python_config_enabled_works(eval_context):
"""Test that enable_eval=True allows execution"""
node = EvalPython_5_5()
result = node.exec_py(
pycode="return 42",
value0=0, value1=0, value2=0, value3=0, value4=0
)
assert result == (42, None, None, None, None)
def test_eval_python_default_code(eval_context):
"""Test the default code example works"""
node = EvalPython_5_5()
# Get the default code from INPUT_TYPES
default_code = EvalPython_5_5.INPUT_TYPES()["required"]["pycode"][1]["default"]
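# INPUT_TYPES()["required"]["pycode"] is a (type, options) pair; the options dict
# carries the prefilled editor text under "default".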
result = node.exec_py(
pycode=default_code,
value0=1, value1=2, value2=3, value3=4, value4=5
)
# Default code prints and returns the values
assert result == (1, 2, 3, 4, 5)
def test_eval_python_function_definition(eval_context):
"""Test defining and using functions"""
node = EvalPython_5_5()
code = """
def multiply(a, b):
return a * b
result = multiply(value0, value1)
return result
"""
result = node.exec_py(
pycode=code,
value0=7, value1=8, value2=0, value3=0, value4=0
)
assert result == (56, None, None, None, None)
def test_eval_python_nested_functions(eval_context):
"""Test nested function definitions"""
node = EvalPython_5_5()
code = """
def outer(x):
def inner(y):
return y * 2
return inner(x) + 10
result = outer(value0)
return result
"""
result = node.exec_py(
pycode=code,
value0=5, value1=0, value2=0, value3=0, value4=0
)
assert result == (20, None, None, None, None)
def test_eval_python_dict_operations(eval_context):
"""Test dictionary creation and operations"""
node = EvalPython_5_5()
code = """
data = {
'a': value0,
'b': value1,
'c': value2
}
return sum(data.values()), len(data)
"""
result = node.exec_py(
pycode=code,
value0=10, value1=20, value2=30, value3=0, value4=0
)
assert result == (60, 3, None, None, None)
def test_eval_python_list_operations(eval_context):
"""Test list creation and operations"""
node = EvalPython_5_5()
code = """
items = [value0, value1, value2, value3, value4]
filtered = [x for x in items if x > 5]
return len(filtered), sum(filtered)
"""
result = node.exec_py(
pycode=code,
value0=1, value1=10, value2=3, value3=15, value4=2
)
assert result == (2, 25, None, None, None)
def test_eval_python_early_return(eval_context):
"""Test early return in conditional"""
node = EvalPython_5_5()
code = """
if value0 > 100:
return "large"
return "small"
"""
result = node.exec_py(
pycode=code,
value0=150, value1=0, value2=0, value3=0, value4=0
)
assert result == ("large", None, None, None, None)
def test_eval_python_loop_with_return(eval_context):
"""Test loop with return statement"""
node = EvalPython_5_5()
code = """
total = 0
for i in range(value0):
total += i
return total
"""
result = node.exec_py(
pycode=code,
value0=10, value1=0, value2=0, value3=0, value4=0
)
assert result == (45, None, None, None, None)
def test_eval_python_exception_handling(eval_context):
"""Test try/except blocks"""
node = EvalPython_5_5()
code = """
try:
result = value0 / value1
except ZeroDivisionError:
result = float('inf')
return result
"""
result = node.exec_py(
pycode=code,
value0=10, value1=0, value2=0, value3=0, value4=0
)
assert result == (float('inf'), None, None, None, None)
def test_eval_python_none_values(eval_context):
"""Test handling None values in inputs"""
node = EvalPython_5_5()
code = """
return value0, value1 is None, value2 is None
"""
result = node.exec_py(
pycode=code,
value0=42, value1=None, value2=None, value3=0, value4=0
)
assert result == (42, True, True, None, None)
def test_eval_python_input_types():
"""Test that INPUT_TYPES returns correct structure"""
input_types = EvalPython_5_5.INPUT_TYPES()
assert "required" in input_types
assert "optional" in input_types
assert "pycode" in input_types["required"]
assert input_types["required"]["pycode"][0] == "CODE_BLOCK_PYTHON"
# Check optional inputs
for i in range(5):
assert f"value{i}" in input_types["optional"]
def test_eval_python_metadata():
"""Test node metadata"""
assert EvalPython_5_5.FUNCTION == "exec_py"
assert EvalPython_5_5.CATEGORY == "eval"
assert len(EvalPython_5_5.RETURN_TYPES) == 5
assert len(EvalPython_5_5.RETURN_NAMES) == 5
assert all(name.startswith("item") for name in EvalPython_5_5.RETURN_NAMES)
def test_eval_python_factory_custom_inputs_outputs(eval_context):
"""Test creating nodes with custom input/output counts"""
# Create a node with 3 inputs and 2 outputs
CustomNode = eval_python(inputs=3, outputs=2)
node = CustomNode()
# Verify INPUT_TYPES has correct number of inputs
input_types = CustomNode.INPUT_TYPES()
assert len(input_types["optional"]) == 3
assert "value0" in input_types["optional"]
assert "value1" in input_types["optional"]
assert "value2" in input_types["optional"]
assert "value3" not in input_types["optional"]
# Verify RETURN_TYPES has correct number of outputs
assert len(CustomNode.RETURN_TYPES) == 2
assert len(CustomNode.RETURN_NAMES) == 2
# Test execution
result = node.exec_py(
pycode="return value0 + value1 + value2, value0 * 2",
value0=1, value1=2, value2=3
)
assert result == (6, 2)
def test_eval_python_factory_custom_name(eval_context):
"""Test creating nodes with custom names"""
CustomNode = eval_python(inputs=2, outputs=2, name="MyCustomEval")
assert CustomNode.__name__ == "MyCustomEval"
assert CustomNode.__qualname__ == "MyCustomEval"
def test_eval_python_factory_default_name(eval_context):
"""Test that default name follows pattern"""
CustomNode = eval_python(inputs=3, outputs=4)
assert CustomNode.__name__ == "EvalPython_3_4"
assert CustomNode.__qualname__ == "EvalPython_3_4"
def test_eval_python_factory_single_output(eval_context):
"""Test node with single output"""
SingleOutputNode = eval_python(inputs=2, outputs=1)
node = SingleOutputNode()
result = node.exec_py(
pycode="return value0 + value1",
value0=10, value1=20
)
assert result == (30,)
def test_eval_python_factory_many_outputs(eval_context):
"""Test node with many outputs"""
ManyOutputNode = eval_python(inputs=1, outputs=10)
node = ManyOutputNode()
result = node.exec_py(
pycode="return tuple(range(10))",
value0=0
)
assert result == (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)
def test_eval_python_factory_fewer_returns_than_outputs(eval_context):
"""Test that fewer returns are padded with None"""
Node = eval_python(inputs=2, outputs=5)
node = Node()
result = node.exec_py(
pycode="return value0, value1",
value0=1, value1=2
)
assert result == (1, 2, None, None, None)
def test_eval_python_factory_more_returns_than_outputs(eval_context):
"""Test that excess returns are truncated"""
Node = eval_python(inputs=2, outputs=3)
node = Node()
result = node.exec_py(
pycode="return 1, 2, 3, 4, 5",
value0=0, value1=0
)
assert result == (1, 2, 3)
def test_eval_python_list_1_input_is_list(eval_context):
"""Test EvalPython_List_1 with list input"""
node = EvalPython_List_1()
# Verify INPUT_IS_LIST is set
assert EvalPython_List_1.INPUT_IS_LIST is True
# Test that value0 receives a list
result = node.exec_py(
pycode="return sum(value0)",
value0=[1, 2, 3, 4, 5]
)
assert result == (15,)
def test_eval_python_list_1_iterate_list(eval_context):
"""Test EvalPython_List_1 iterating over list input"""
node = EvalPython_List_1()
result = node.exec_py(
pycode="return [x * 2 for x in value0]",
value0=[1, 2, 3]
)
assert result == ([2, 4, 6],)
def test_eval_python_1_list_output_is_list(eval_context):
"""Test EvalPython_1_List with list output"""
node = EvalPython_1_List()
# Verify OUTPUT_IS_LIST is set
assert EvalPython_1_List.OUTPUT_IS_LIST == (True,)
# Test that returns a list
result = node.exec_py(
pycode="return list(range(value0))",
value0=5
)
assert result == ([0, 1, 2, 3, 4],)
def test_eval_python_1_list_multiple_items(eval_context):
"""Test EvalPython_1_List returning multiple items in list"""
node = EvalPython_1_List()
result = node.exec_py(
pycode="return ['a', 'b', 'c']",
value0=0
)
assert result == (['a', 'b', 'c'],)
def test_eval_python_list_list_both(eval_context):
"""Test EvalPython_List_List with both list input and output"""
node = EvalPython_List_List()
# Verify both are set
assert EvalPython_List_List.INPUT_IS_LIST is True
assert EvalPython_List_List.OUTPUT_IS_LIST == (True,)
# Test processing list input and returning list output
result = node.exec_py(
pycode="return [x ** 2 for x in value0]",
value0=[1, 2, 3, 4]
)
assert result == ([1, 4, 9, 16],)
def test_eval_python_list_list_filter(eval_context):
"""Test EvalPython_List_List filtering a list"""
node = EvalPython_List_List()
result = node.exec_py(
pycode="return [x for x in value0 if x > 5]",
value0=[1, 3, 5, 7, 9, 11]
)
assert result == ([7, 9, 11],)
def test_eval_python_list_list_transform(eval_context):
"""Test EvalPython_List_List transforming list elements"""
node = EvalPython_List_List()
result = node.exec_py(
pycode="return [str(x).upper() for x in value0]",
value0=['hello', 'world', 'python']
)
assert result == (['HELLO', 'WORLD', 'PYTHON'],)
def test_eval_python_factory_with_list_flags(eval_context):
"""Test factory function with custom list flags"""
# Create node with input as list but output scalar
ListInputNode = eval_python(inputs=1, outputs=1, input_is_list=True, output_is_list=None)
assert ListInputNode.INPUT_IS_LIST is True
node = ListInputNode()
result = node.exec_py(
pycode="return len(value0)",
value0=[1, 2, 3, 4, 5]
)
assert result == (5,)
def test_eval_python_factory_scalar_output_list(eval_context):
"""Test factory function with scalar input and list output"""
ScalarToListNode = eval_python(inputs=1, outputs=1, input_is_list=None, output_is_list=(True,))
assert ScalarToListNode.OUTPUT_IS_LIST == (True,)
node = ScalarToListNode()
result = node.exec_py(
pycode="return [value0] * 3",
value0='x'
)
assert result == (['x', 'x', 'x'],)
def test_eval_python_list_empty_list(eval_context):
"""Test list nodes with empty lists"""
node = EvalPython_List_List()
result = node.exec_py(
pycode="return []",
value0=[]
)
assert result == ([],)