diff --git a/comfy/client/embedded_comfy_client.py b/comfy/client/embedded_comfy_client.py
index 389fd2b6b..3e90f8e61 100644
--- a/comfy/client/embedded_comfy_client.py
+++ b/comfy/client/embedded_comfy_client.py
@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+from ..cmd.main_pre import tracer
+
 import asyncio
 import concurrent.futures
 import copy
@@ -15,8 +17,6 @@ from typing import Optional, Generator
 from opentelemetry import context, propagate
 from opentelemetry.context import Context, attach, detach
 from opentelemetry.trace import Status, StatusCode
-
-from ..cmd.main_pre import tracer
 from .async_progress_iterable import _ProgressHandler, QueuePromptWithProgress
 from .client_types import V1QueuePromptResponse
 from ..api.components.schema.prompt import PromptDict
diff --git a/comfy/cmd/execution.py b/comfy/cmd/execution.py
index 5ad5865c2..8d3ba84ae 100644
--- a/comfy/cmd/execution.py
+++ b/comfy/cmd/execution.py
@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+from .main_pre import tracer
+
 import asyncio
 import copy
 import heapq
@@ -16,8 +18,6 @@ from enum import Enum
 from os import PathLike
 from typing import List, Optional, Tuple, Literal
 
-# order matters
-from .main_pre import tracer
 import torch
 from opentelemetry.trace import get_current_span, StatusCode, Status
 
diff --git a/comfy/cmd/main.py b/comfy/cmd/main.py
index a921a8eed..cade11350 100644
--- a/comfy/cmd/main.py
+++ b/comfy/cmd/main.py
@@ -1,3 +1,5 @@
+from .main_pre import tracer
+
 import asyncio
 import contextvars
 import gc
@@ -9,9 +11,6 @@ import threading
 import time
 from pathlib import Path
 from typing import Optional
-
-# main_pre must be the earliest import
-from .main_pre import tracer
 from ..cli_args_types import Configuration
 from ..component_model.file_counter import cleanup_temp as fc_cleanup_temp
 from ..execution_context import current_execution_context
diff --git a/comfy/cmd/main_pre.py b/comfy/cmd/main_pre.py
index 7c06e4cee..af1a066d7 100644
--- a/comfy/cmd/main_pre.py
+++ b/comfy/cmd/main_pre.py
@@ -118,8 +118,9 @@ def _fix_pytorch_240():
 
 
 def _create_tracer():
-    from opentelemetry import trace
+    from opentelemetry import trace, metrics
     from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
+    from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
     from opentelemetry.instrumentation.aio_pika import AioPikaInstrumentor
     from opentelemetry.instrumentation.requests import RequestsInstrumentor
     from opentelemetry.semconv.attributes import service_attributes
@@ -127,6 +128,8 @@ def _create_tracer():
     from opentelemetry.sdk.resources import Resource
     from opentelemetry.sdk.trace import TracerProvider
     from opentelemetry.sdk.trace.export import BatchSpanProcessor, SpanExporter
+    from opentelemetry.sdk.metrics import MeterProvider
+    from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
     from opentelemetry.processor.baggage import BaggageSpanProcessor, ALLOW_ALL_BAGGAGE_KEYS
     from opentelemetry.instrumentation.aiohttp_server import AioHttpServerInstrumentor
     from opentelemetry.instrumentation.aiohttp_client import AioHttpClientInstrumentor
@@ -142,6 +145,10 @@ def _create_tracer():
     sampler = ProgressSpanSampler()
     provider = TracerProvider(resource=resource, sampler=sampler)
 
+    # Set the global tracer provider FIRST, before instrumenting
+    # This ensures instrumentors can access the provider
+    trace.set_tracer_provider(provider)
+
     has_endpoint = args.otel_exporter_otlp_endpoint is not None
 
     if has_endpoint:
@@ -152,9 +159,21 @@ def _create_tracer():
 
         processor = BatchSpanProcessor(exporter)
         provider.add_span_processor(processor)
 
-    # enable instrumentation
+    # Set up metrics export to track dropped spans
+    # Only enable if OTEL_EXPORTER_OTLP_METRICS_ENDPOINT is set, since not all OTLP endpoints support metrics
+    metrics_endpoint = os.environ.get("OTEL_EXPORTER_OTLP_METRICS_ENDPOINT")
+    if metrics_endpoint:
+        metric_reader = PeriodicExportingMetricReader(
+            OTLPMetricExporter(endpoint=metrics_endpoint),
+            export_interval_millis=10000  # Export every 10 seconds
+        )
+        meter_provider = MeterProvider(resource=resource, metric_readers=[metric_reader])
+        metrics.set_meter_provider(meter_provider)
+
+    # enable instrumentation BEFORE any aio_pika imports
     patch_spanbuilder_set_channel()
+    # Instrument aio_pika first since it's most likely to be imported early
     AioPikaInstrumentor().instrument()
     AioHttpServerInstrumentor().instrument()
     AioHttpClientInstrumentor().instrument()
diff --git a/comfy/distributed/distributed_prompt_queue.py b/comfy/distributed/distributed_prompt_queue.py
index 894921057..e3af6ef5c 100644
--- a/comfy/distributed/distributed_prompt_queue.py
+++ b/comfy/distributed/distributed_prompt_queue.py
@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+from ..cmd.main_pre import tracer
+
 import asyncio
 import time
 import uuid
@@ -13,7 +15,6 @@ from aio_pika import connect_robust
 from aio_pika.abc import AbstractConnection, AbstractChannel
 from aio_pika.patterns import JsonRPC
 
-from ..cmd.main_pre import tracer
 from .distributed_progress import ProgressHandlers
 from .distributed_types import RpcRequest, RpcReply
 from .history import History
diff --git a/comfy/distributed/distributed_prompt_worker.py b/comfy/distributed/distributed_prompt_worker.py
index d831ba734..b6bf3dad1 100644
--- a/comfy/distributed/distributed_prompt_worker.py
+++ b/comfy/distributed/distributed_prompt_worker.py
@@ -1,3 +1,5 @@
+from ..cmd.main_pre import tracer
+
 import asyncio
 import logging
 from asyncio import AbstractEventLoop
@@ -9,8 +11,6 @@ from aio_pika import connect_robust
 from aio_pika.patterns import JsonRPC
 from aiohttp import web
 from aiormq import AMQPConnectionError
-
-from ..cmd.main_pre import tracer
 from .executors import ContextVarExecutor
 from .distributed_progress import DistributedExecutorToClientProgress
 from .distributed_types import RpcRequest, RpcReply
diff --git a/comfy/entrypoints/worker.py b/comfy/entrypoints/worker.py
index 267c558a5..2139e4c81 100644
--- a/comfy/entrypoints/worker.py
+++ b/comfy/entrypoints/worker.py
@@ -1,6 +1,6 @@
-import asyncio
-
 from ..cmd.main_pre import args
+
+import asyncio
 from ..component_model.file_counter import cleanup_temp
 from ..component_model.entrypoints_common import configure_application_paths, executor_from_args
 
diff --git a/comfy/entrypoints/workflow.py b/comfy/entrypoints/workflow.py
index b8bdee4e4..74d5b9726 100644
--- a/comfy/entrypoints/workflow.py
+++ b/comfy/entrypoints/workflow.py
@@ -1,11 +1,11 @@
+from ..cmd.main_pre import args
+
 import asyncio
 import json
 import logging
 from typing import Optional, Literal
 
 import typer
-
-from ..cmd.main_pre import args
 from ..cli_args_types import Configuration
 from ..component_model.asyncio_files import stream_json_objects
 from ..client.embedded_comfy_client import Comfy
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 413b16b09..bf3aea14e 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -17,6 +17,8 @@ """
 from __future__ import annotations
 
+from .cmd.main_pre import tracer
+
 import gc
 import logging
 import platform
 
@@ -32,8 +34,6 @@ from typing import List, Sequence, Final, Optional
 import psutil
 import torch
 from opentelemetry.trace import get_current_span
-
-from .cmd.main_pre import tracer
 from . import interruption
 from .cli_args import args, PerformanceFeature
 from .component_model.deprecation import _deprecate_method
diff --git a/comfy/nodes/package.py b/comfy/nodes/package.py
index 500ef7286..28f11ee69 100644
--- a/comfy/nodes/package.py
+++ b/comfy/nodes/package.py
@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+from ..cmd.main_pre import tracer
+
 import importlib
 import logging
 import os
@@ -12,8 +14,6 @@ from importlib.metadata import entry_points
 from threading import RLock
 
 from opentelemetry.trace import Span, Status, StatusCode
-
-from ..cmd.main_pre import tracer
 from comfy_api.internal import register_versions, ComfyAPIWithVersion
 from comfy_api.version_list import supported_versions
 from .comfyui_v3_package_imports import _comfy_entrypoint_upstream_v3_imports
diff --git a/comfy/tracing_compatibility.py b/comfy/tracing_compatibility.py
index a77f1e79b..24139fff3 100644
--- a/comfy/tracing_compatibility.py
+++ b/comfy/tracing_compatibility.py
@@ -31,6 +31,19 @@ def patch_spanbuilder_set_channel() -> None:
 
 
 class ProgressSpanSampler(Sampler):
+    """
+    Sampler which omits aio_pika messages destined/related to progress.
+
+    To debug which spans are being dropped, set environment variable:
+    OTEL_DEBUG_SAMPLER=1
+    """
+
+    def __init__(self):
+        import logging
+        import os
+        self.logger = logging.getLogger(__name__)
+        self.debug = os.environ.get("OTEL_DEBUG_SAMPLER", "0") == "1"
+
     def get_description(self) -> str:
         return "Sampler which omits aio_pika messages destined/related to progress"
 
@@ -44,9 +57,23 @@ class ProgressSpanSampler(Sampler):
         links: Optional[Sequence["Link"]] = None,
         trace_state: Optional["TraceState"] = None,
     ) -> "SamplingResult":
+        should_drop = False
+        drop_reason = None
+
         if attributes is not None and "messaging.destination" in attributes and attributes["messaging.destination"].endswith("progress"):
-            return SamplingResult(Decision.DROP)
+            should_drop = True
+            drop_reason = f"messaging.destination={attributes['messaging.destination']}"
         # the ephemeral reply channels are not required for correct span correlation
-        if name.startswith(",amq_") or name.startswith("amq"):
+        elif name.startswith(",amq_") or name.startswith("amq"):
+            should_drop = True
+            drop_reason = f"span name starts with amq: {name}"
+
+        if should_drop:
+            if self.debug:
+                self.logger.debug(f"[SAMPLER] DROPPED span '{name}' (kind={kind}): {drop_reason}")
             return SamplingResult(Decision.DROP)
+
+        if self.debug:
+            self.logger.debug(f"[SAMPLER] RECORDING span '{name}' (kind={kind}, attributes={attributes})")
+
         return SamplingResult(Decision.RECORD_AND_SAMPLE)
diff --git a/tests/distributed/test_tracing.py b/tests/distributed/test_tracing.py
index 8c156a9bf..55c55d295 100644
--- a/tests/distributed/test_tracing.py
+++ b/tests/distributed/test_tracing.py
@@ -7,7 +7,7 @@ import uuid
 import pytest
 from testcontainers.rabbitmq import RabbitMqContainer
 
-from opentelemetry import trace, propagate, context
+from opentelemetry import trace, propagate
 from opentelemetry.trace import SpanKind
 from opentelemetry.sdk.trace import TracerProvider
 from opentelemetry.sdk.trace.export import SimpleSpanProcessor
diff --git a/tests/distributed/test_tracing_integration.py b/tests/distributed/test_tracing_integration.py
index 36a9fabd6..463a85935 100644
--- a/tests/distributed/test_tracing_integration.py
+++ b/tests/distributed/test_tracing_integration.py
@@ -8,6 +8,7 @@ full distributed trace.
 import asyncio
 import logging
 import os
+import subprocess
 import tempfile
 import time
 import uuid
@@ -23,6 +24,7 @@ from opentelemetry.semconv.attributes import service_attributes
 from testcontainers.core.container import DockerContainer
 from testcontainers.core.waiting_utils import wait_for_logs
 from testcontainers.nginx import NginxContainer
+from testcontainers.rabbitmq import RabbitMqContainer
 
 from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner
 
@@ -528,7 +530,6 @@ async def test_multiple_requests_different_traces(frontend_backend_worker_with_r
 
 
 @pytest.mark.asyncio
-@pytest.mark.skip(reason="rabbitmq has to be configured for observability?")
 async def test_trace_contains_rabbitmq_operations(frontend_backend_worker_with_rabbitmq, jaeger_container):
     """
     Test that traces include RabbitMQ publish/consume operations.
@@ -566,4 +567,202 @@
             if rmq_op in op_name:
                 found_rabbitmq_ops.append(op_name)
 
-    assert found_rabbitmq_ops, "No RabbitMQ-related operations found in traces"
\ No newline at end of file
+    assert found_rabbitmq_ops, "No RabbitMQ-related operations found in traces"
+
+
+@pytest.mark.asyncio
+async def test_aiohttp_and_aio_pika_spans_with_docker_frontend(jaeger_container):
+    """
+    Test that both aiohttp and aio_pika instrumentation work in the Docker image.
+
+    This test helps diagnose if there's a dependency issue in the Docker image preventing
+    instrumentation from working correctly by:
+    1. Starting the ComfyUI frontend in a Docker container
+    2. Starting a local worker process
+    3. Submitting a workflow
+    4. Querying Jaeger to verify both aiohttp and aio_pika spans are present
+
+    Set COMFYUI_IMAGE env var to override default image, e.g.:
+    COMFYUI_IMAGE=ghcr.io/hiddenswitch/comfyui:latest
+    """
+    docker_image = os.environ.get("COMFYUI_IMAGE", "ghcr.io/hiddenswitch/comfyui:latest")
+
+    jaeger_url = jaeger_container.get_query_url()
+    otlp_endpoint = jaeger_container.get_otlp_endpoint()
+    otlp_port = jaeger_container.get_exposed_port(4318)
+
+    with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
+        params = rabbitmq.get_connection_params()
+
+        # Get Docker bridge gateway for container-to-host communication
+        try:
+            result = subprocess.run(
+                ["docker", "network", "inspect", "bridge", "-f", "{{(index .IPAM.Config 0).Gateway}}"],
+                capture_output=True,
+                text=True,
+                check=True,
+                timeout=5
+            )
+            docker_host = result.stdout.strip()
+            if not docker_host:
+                docker_host = "host.docker.internal"
+        except (subprocess.CalledProcessError, subprocess.TimeoutExpired, FileNotFoundError):
+            docker_host = "host.docker.internal"
+
+        connection_uri_container = f"amqp://guest:guest@{docker_host}:{params.port}"
+        connection_uri_local = f"amqp://guest:guest@127.0.0.1:{params.port}"
+
+        # Start frontend in Docker container
+        frontend_container = DockerContainer(docker_image)
+        frontend_container.with_exposed_ports(8188)
+
+        otlp_endpoint_container = f"http://{docker_host}:{otlp_port}"
+        env_vars = {
+            "OTEL_SERVICE_NAME": "comfyui-docker-frontend",
+            "OTEL_EXPORTER_OTLP_ENDPOINT": otlp_endpoint_container,
+        }
+
+        for key, value in env_vars.items():
+            frontend_container.with_env(key, value)
+
+        frontend_container.with_command(
+            f"python -m comfy.cmd.main --listen 0.0.0.0 --port 8188 "
+            f"--cpu --distributed-queue-frontend "
+            f"--distributed-queue-connection-uri={connection_uri_container}"
+        )
+
+        frontend_container.start()
+
+        try:
+            frontend_host = frontend_container.get_container_host_ip()
+            frontend_port = frontend_container.get_exposed_port(8188)
+            frontend_url = f"http://{frontend_host}:{frontend_port}"
+
+            # Wait for frontend to be ready
+            connected = False
+            for _ in range(15):
+                try:
+                    response = requests.get(frontend_url, timeout=1)
+                    if response.status_code == 200:
+                        connected = True
+                        break
+                except Exception:
+                    pass
+                time.sleep(1)
+
+            assert connected, f"Could not connect to Docker frontend at {frontend_url}"
+
+            # Start local worker
+            worker_env = os.environ.copy()
+            worker_env["OTEL_SERVICE_NAME"] = "comfyui-worker"
+            worker_env["OTEL_EXPORTER_OTLP_ENDPOINT"] = otlp_endpoint
+
+            worker_process = subprocess.Popen(
+                [
+                    "comfyui-worker",
+                    "--port=19099",
+                    f"--distributed-queue-connection-uri={connection_uri_local}",
+                    "--executor-factory=ThreadPoolExecutor"
+                ],
+                stdout=subprocess.PIPE,
+                stderr=subprocess.STDOUT,
+                env=worker_env,
+                text=True,
+                bufsize=1
+            )
+
+            try:
+                time.sleep(5)
+
+                from comfy.client.aio_client import AsyncRemoteComfyClient
+
+                test_id = str(uuid.uuid4())[:8]
+                prompt = sdxl_workflow_with_refiner(f"docker_test_{test_id}", inference_steps=1, refiner_steps=1)
+
+                async with AsyncRemoteComfyClient(server_address=frontend_url) as client:
+                    task_id = await client.queue_and_forget_prompt_api(prompt, prefer_header="respond-async")
+                    assert task_id is not None, "Failed to get task ID"
+
+                    status_code, result = await client.poll_prompt_until_done(task_id, max_attempts=60, poll_interval=2.0)
+
+                    if status_code != 200:
+                        # Capture worker logs
+                        worker_output = ""
+                        if worker_process.stdout:
+                            worker_output = worker_process.stdout.read()
+
+                        # Get frontend container logs
+                        frontend_logs = frontend_container.get_logs()
+
+                        logger.error("=" * 80)
+                        logger.error("TASK FAILED - Diagnostic Information:")
+                        logger.error("=" * 80)
+                        logger.error(f"Task ID: {task_id}")
+                        logger.error(f"Status Code: {status_code}")
+                        logger.error(f"Result: {result}")
+                        logger.error("\n--- Frontend Container Logs (last 100 lines) ---")
+                        frontend_log_lines = frontend_logs.decode('utf-8').split('\n')
+                        for line in frontend_log_lines[-100:]:
+                            logger.error(line)
+                        logger.error("\n--- Worker Process Output ---")
+                        for line in worker_output.split('\n')[-100:]:
+                            logger.error(line)
+                        logger.error("=" * 80)
+
+                    assert status_code == 200, f"Task failed with status {status_code}. Check logs above for details."
+
+                    await asyncio.sleep(5)
+
+                    # Query Jaeger for traces from both services
+                    frontend_traces = query_jaeger_traces(jaeger_url, "comfyui-docker-frontend", lookback="5m").get("data", [])
+                    worker_traces = query_jaeger_traces(jaeger_url, "comfyui-worker", lookback="5m").get("data", [])
+
+                    assert frontend_traces, (
+                        f"No traces found in Jaeger for service 'comfyui-docker-frontend'. "
+                        f"Check that OTEL export is working from Docker container. Jaeger UI: {jaeger_url}"
+                    )
+
+                    assert worker_traces, (
+                        f"No traces found in Jaeger for service 'comfyui-worker'. "
+                        f"Check that OTEL export is working from worker. Jaeger UI: {jaeger_url}"
+                    )
+
+                    # Analyze span types from both services
+                    aiohttp_spans = []
+                    aio_pika_frontend_spans = []
+                    aio_pika_worker_spans = []
+
+                    for trace_item in frontend_traces:
+                        for span in trace_item.get("spans", []):
+                            operation_name = span.get("operationName", "")
+                            if any(http_op in operation_name.upper() for http_op in ["GET", "POST", "PUT", "DELETE", "PATCH"]):
+                                aiohttp_spans.append(operation_name)
+                            elif "publish" in operation_name.lower() or "send" in operation_name.lower():
+                                aio_pika_frontend_spans.append(operation_name)
+
+                    for trace_item in worker_traces:
+                        for span in trace_item.get("spans", []):
+                            operation_name = span.get("operationName", "")
+                            if "consume" in operation_name.lower() or "receive" in operation_name.lower() or "publish" in operation_name.lower():
+                                aio_pika_worker_spans.append(operation_name)
+
+                    assert aiohttp_spans, (
+                        f"No aiohttp server spans found in traces from Docker frontend. "
+                        f"This indicates aiohttp server instrumentation is not working in the Docker image. "
+                        f"Image: {docker_image}. Jaeger UI: {jaeger_url}"
+                    )
+
+                    total_aio_pika_spans = len(aio_pika_frontend_spans) + len(aio_pika_worker_spans)
+                    assert total_aio_pika_spans > 0, (
+                        f"No aio_pika spans found in traces. "
+                        f"Frontend aio_pika spans: {len(aio_pika_frontend_spans)}, Worker aio_pika spans: {len(aio_pika_worker_spans)}. "
+                        f"Expected messaging spans for distributed queue operations. "
+                        f"This indicates aio_pika instrumentation is not working. Jaeger UI: {jaeger_url}"
+                    )
+
+            finally:
+                worker_process.terminate()
+                worker_process.wait(timeout=10)
+
+        finally:
+            frontend_container.stop()
\ No newline at end of file
diff --git a/tests/main_pre_import_checker.py b/tests/main_pre_import_checker.py
index 645c08588..4b0d17179 100644
--- a/tests/main_pre_import_checker.py
+++ b/tests/main_pre_import_checker.py
@@ -74,22 +74,41 @@ class MainPreImportOrderChecker(BaseChecker):
 
     def _is_other_relevant_import(self, stmt: Union[nodes.Import, nodes.ImportFrom]) -> bool:
         """
-        Checks if an import is a relative import or an import from
-        the 'comfy' package family, and is not a 'main_pre' import.
+        Checks if an import should come after main_pre.
+
+        Returns True for:
+        - Any relative import
+        - Any import from the 'comfy' package family
+        - Any third-party library import (aio_pika, aiohttp, etc.)
+
+        Returns False for:
+        - main_pre imports
+        - __future__ imports (must always be first)
         """
         if self._is_main_pre_import(stmt):
            return False
 
        if isinstance(stmt, nodes.ImportFrom):
+            # Allow __future__ imports (they must be first by Python rules)
+            if stmt.modname == '__future__':
+                return False
+
            if stmt.level and stmt.level > 0:
                # Any relative import
                return True
            if stmt.modname and stmt.modname.startswith('comfy'):
                return True
 
+            # Flag any third-party or stdlib imports that need instrumentation
+            # This includes aio_pika, aiohttp, requests, etc.
+            if stmt.modname:
+                return True
+
        if isinstance(stmt, nodes.Import):
            for name, _ in stmt.names:
                if name.startswith('comfy'):
                    return True
+            # Flag any other absolute imports
+            return True
 
        return False