Fix distributed tracing

This commit is contained in:
doctorpangloss 2025-11-18 11:04:09 -08:00
parent bdb3c8f039
commit 8108315b93
14 changed files with 291 additions and 27 deletions

View File

@ -1,5 +1,7 @@
from __future__ import annotations
from ..cmd.main_pre import tracer
import asyncio
import concurrent.futures
import copy
@ -15,8 +17,6 @@ from typing import Optional, Generator
from opentelemetry import context, propagate
from opentelemetry.context import Context, attach, detach
from opentelemetry.trace import Status, StatusCode
from ..cmd.main_pre import tracer
from .async_progress_iterable import _ProgressHandler, QueuePromptWithProgress
from .client_types import V1QueuePromptResponse
from ..api.components.schema.prompt import PromptDict

View File

@ -1,5 +1,7 @@
from __future__ import annotations
from .main_pre import tracer
import asyncio
import copy
import heapq
@ -16,8 +18,6 @@ from enum import Enum
from os import PathLike
from typing import List, Optional, Tuple, Literal
# order matters
from .main_pre import tracer
import torch
from opentelemetry.trace import get_current_span, StatusCode, Status

View File

@ -1,3 +1,5 @@
from .main_pre import tracer
import asyncio
import contextvars
import gc
@ -9,9 +11,6 @@ import threading
import time
from pathlib import Path
from typing import Optional
# main_pre must be the earliest import
from .main_pre import tracer
from ..cli_args_types import Configuration
from ..component_model.file_counter import cleanup_temp as fc_cleanup_temp
from ..execution_context import current_execution_context

View File

@ -118,8 +118,9 @@ def _fix_pytorch_240():
def _create_tracer():
from opentelemetry import trace
from opentelemetry import trace, metrics
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
from opentelemetry.instrumentation.aio_pika import AioPikaInstrumentor
from opentelemetry.instrumentation.requests import RequestsInstrumentor
from opentelemetry.semconv.attributes import service_attributes
@ -127,6 +128,8 @@ def _create_tracer():
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor, SpanExporter
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
from opentelemetry.processor.baggage import BaggageSpanProcessor, ALLOW_ALL_BAGGAGE_KEYS
from opentelemetry.instrumentation.aiohttp_server import AioHttpServerInstrumentor
from opentelemetry.instrumentation.aiohttp_client import AioHttpClientInstrumentor
@ -142,6 +145,10 @@ def _create_tracer():
sampler = ProgressSpanSampler()
provider = TracerProvider(resource=resource, sampler=sampler)
# Set the global tracer provider FIRST, before instrumenting
# This ensures instrumentors can access the provider
trace.set_tracer_provider(provider)
has_endpoint = args.otel_exporter_otlp_endpoint is not None
if has_endpoint:
@ -152,9 +159,21 @@ def _create_tracer():
processor = BatchSpanProcessor(exporter)
provider.add_span_processor(processor)
# enable instrumentation
# Set up metrics export to track dropped spans
# Only enable if OTEL_EXPORTER_OTLP_METRICS_ENDPOINT is set, since not all OTLP endpoints support metrics
metrics_endpoint = os.environ.get("OTEL_EXPORTER_OTLP_METRICS_ENDPOINT")
if metrics_endpoint:
metric_reader = PeriodicExportingMetricReader(
OTLPMetricExporter(endpoint=metrics_endpoint),
export_interval_millis=10000 # Export every 10 seconds
)
meter_provider = MeterProvider(resource=resource, metric_readers=[metric_reader])
metrics.set_meter_provider(meter_provider)
# enable instrumentation BEFORE any aio_pika imports
patch_spanbuilder_set_channel()
# Instrument aio_pika first since it's most likely to be imported early
AioPikaInstrumentor().instrument()
AioHttpServerInstrumentor().instrument()
AioHttpClientInstrumentor().instrument()

View File

@ -1,5 +1,7 @@
from __future__ import annotations
from ..cmd.main_pre import tracer
import asyncio
import time
import uuid
@ -13,7 +15,6 @@ from aio_pika import connect_robust
from aio_pika.abc import AbstractConnection, AbstractChannel
from aio_pika.patterns import JsonRPC
from ..cmd.main_pre import tracer
from .distributed_progress import ProgressHandlers
from .distributed_types import RpcRequest, RpcReply
from .history import History

View File

@ -1,3 +1,5 @@
from ..cmd.main_pre import tracer
import asyncio
import logging
from asyncio import AbstractEventLoop
@ -9,8 +11,6 @@ from aio_pika import connect_robust
from aio_pika.patterns import JsonRPC
from aiohttp import web
from aiormq import AMQPConnectionError
from ..cmd.main_pre import tracer
from .executors import ContextVarExecutor
from .distributed_progress import DistributedExecutorToClientProgress
from .distributed_types import RpcRequest, RpcReply

View File

@ -1,6 +1,6 @@
import asyncio
from ..cmd.main_pre import args
import asyncio
from ..component_model.file_counter import cleanup_temp
from ..component_model.entrypoints_common import configure_application_paths, executor_from_args

View File

@ -1,11 +1,11 @@
from ..cmd.main_pre import args
import asyncio
import json
import logging
from typing import Optional, Literal
import typer
from ..cmd.main_pre import args
from ..cli_args_types import Configuration
from ..component_model.asyncio_files import stream_json_objects
from ..client.embedded_comfy_client import Comfy

View File

@ -17,6 +17,8 @@
"""
from __future__ import annotations
from .cmd.main_pre import tracer
import gc
import logging
import platform
@ -32,8 +34,6 @@ from typing import List, Sequence, Final, Optional
import psutil
import torch
from opentelemetry.trace import get_current_span
from .cmd.main_pre import tracer
from . import interruption
from .cli_args import args, PerformanceFeature
from .component_model.deprecation import _deprecate_method

View File

@ -1,5 +1,7 @@
from __future__ import annotations
from ..cmd.main_pre import tracer
import importlib
import logging
import os
@ -12,8 +14,6 @@ from importlib.metadata import entry_points
from threading import RLock
from opentelemetry.trace import Span, Status, StatusCode
from ..cmd.main_pre import tracer
from comfy_api.internal import register_versions, ComfyAPIWithVersion
from comfy_api.version_list import supported_versions
from .comfyui_v3_package_imports import _comfy_entrypoint_upstream_v3_imports

View File

@ -31,6 +31,19 @@ def patch_spanbuilder_set_channel() -> None:
class ProgressSpanSampler(Sampler):
"""
Sampler which omits aio_pika messages destined/related to progress.
To debug which spans are being dropped, set environment variable:
OTEL_DEBUG_SAMPLER=1
"""
def __init__(self):
import logging
import os
self.logger = logging.getLogger(__name__)
self.debug = os.environ.get("OTEL_DEBUG_SAMPLER", "0") == "1"
def get_description(self) -> str:
return "Sampler which omits aio_pika messages destined/related to progress"
@ -44,9 +57,23 @@ class ProgressSpanSampler(Sampler):
links: Optional[Sequence["Link"]] = None,
trace_state: Optional["TraceState"] = None,
) -> "SamplingResult":
should_drop = False
drop_reason = None
if attributes is not None and "messaging.destination" in attributes and attributes["messaging.destination"].endswith("progress"):
return SamplingResult(Decision.DROP)
should_drop = True
drop_reason = f"messaging.destination={attributes['messaging.destination']}"
# the ephemeral reply channels are not required for correct span correlation
if name.startswith(",amq_") or name.startswith("amq"):
elif name.startswith(",amq_") or name.startswith("amq"):
should_drop = True
drop_reason = f"span name starts with amq: {name}"
if should_drop:
if self.debug:
self.logger.debug(f"[SAMPLER] DROPPED span '{name}' (kind={kind}): {drop_reason}")
return SamplingResult(Decision.DROP)
if self.debug:
self.logger.debug(f"[SAMPLER] RECORDING span '{name}' (kind={kind}, attributes={attributes})")
return SamplingResult(Decision.RECORD_AND_SAMPLE)

View File

@ -7,7 +7,7 @@ import uuid
import pytest
from testcontainers.rabbitmq import RabbitMqContainer
from opentelemetry import trace, propagate, context
from opentelemetry import trace, propagate
from opentelemetry.trace import SpanKind
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor

View File

@ -8,6 +8,7 @@ full distributed trace.
import asyncio
import logging
import os
import subprocess
import tempfile
import time
import uuid
@ -23,6 +24,7 @@ from opentelemetry.semconv.attributes import service_attributes
from testcontainers.core.container import DockerContainer
from testcontainers.core.waiting_utils import wait_for_logs
from testcontainers.nginx import NginxContainer
from testcontainers.rabbitmq import RabbitMqContainer
from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner
@ -528,7 +530,6 @@ async def test_multiple_requests_different_traces(frontend_backend_worker_with_r
@pytest.mark.asyncio
@pytest.mark.skip(reason="rabbitmq has to be configured for observability?")
async def test_trace_contains_rabbitmq_operations(frontend_backend_worker_with_rabbitmq, jaeger_container):
"""
Test that traces include RabbitMQ publish/consume operations.
@ -566,4 +567,202 @@ async def test_trace_contains_rabbitmq_operations(frontend_backend_worker_with_r
if rmq_op in op_name:
found_rabbitmq_ops.append(op_name)
assert found_rabbitmq_ops, "No RabbitMQ-related operations found in traces"
assert found_rabbitmq_ops, "No RabbitMQ-related operations found in traces"
@pytest.mark.asyncio
async def test_aiohttp_and_aio_pika_spans_with_docker_frontend(jaeger_container):
"""
Test that both aiohttp and aio_pika instrumentation work in the Docker image.
This test helps diagnose if there's a dependency issue in the Docker image preventing
instrumentation from working correctly by:
1. Starting the ComfyUI frontend in a Docker container
2. Starting a local worker process
3. Submitting a workflow
4. Querying Jaeger to verify both aiohttp and aio_pika spans are present
Set COMFYUI_IMAGE env var to override default image, e.g.:
COMFYUI_IMAGE=ghcr.io/hiddenswitch/comfyui:latest
"""
docker_image = os.environ.get("COMFYUI_IMAGE", "ghcr.io/hiddenswitch/comfyui:latest")
jaeger_url = jaeger_container.get_query_url()
otlp_endpoint = jaeger_container.get_otlp_endpoint()
otlp_port = jaeger_container.get_exposed_port(4318)
with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
params = rabbitmq.get_connection_params()
# Get Docker bridge gateway for container-to-host communication
try:
result = subprocess.run(
["docker", "network", "inspect", "bridge", "-f", "{{(index .IPAM.Config 0).Gateway}}"],
capture_output=True,
text=True,
check=True,
timeout=5
)
docker_host = result.stdout.strip()
if not docker_host:
docker_host = "host.docker.internal"
except (subprocess.CalledProcessError, subprocess.TimeoutExpired, FileNotFoundError):
docker_host = "host.docker.internal"
connection_uri_container = f"amqp://guest:guest@{docker_host}:{params.port}"
connection_uri_local = f"amqp://guest:guest@127.0.0.1:{params.port}"
# Start frontend in Docker container
frontend_container = DockerContainer(docker_image)
frontend_container.with_exposed_ports(8188)
otlp_endpoint_container = f"http://{docker_host}:{otlp_port}"
env_vars = {
"OTEL_SERVICE_NAME": "comfyui-docker-frontend",
"OTEL_EXPORTER_OTLP_ENDPOINT": otlp_endpoint_container,
}
for key, value in env_vars.items():
frontend_container.with_env(key, value)
frontend_container.with_command(
f"python -m comfy.cmd.main --listen 0.0.0.0 --port 8188 "
f"--cpu --distributed-queue-frontend "
f"--distributed-queue-connection-uri={connection_uri_container}"
)
frontend_container.start()
try:
frontend_host = frontend_container.get_container_host_ip()
frontend_port = frontend_container.get_exposed_port(8188)
frontend_url = f"http://{frontend_host}:{frontend_port}"
# Wait for frontend to be ready
connected = False
for _ in range(15):
try:
response = requests.get(frontend_url, timeout=1)
if response.status_code == 200:
connected = True
break
except Exception:
pass
time.sleep(1)
assert connected, f"Could not connect to Docker frontend at {frontend_url}"
# Start local worker
worker_env = os.environ.copy()
worker_env["OTEL_SERVICE_NAME"] = "comfyui-worker"
worker_env["OTEL_EXPORTER_OTLP_ENDPOINT"] = otlp_endpoint
worker_process = subprocess.Popen(
[
"comfyui-worker",
"--port=19099",
f"--distributed-queue-connection-uri={connection_uri_local}",
"--executor-factory=ThreadPoolExecutor"
],
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
env=worker_env,
text=True,
bufsize=1
)
try:
time.sleep(5)
from comfy.client.aio_client import AsyncRemoteComfyClient
test_id = str(uuid.uuid4())[:8]
prompt = sdxl_workflow_with_refiner(f"docker_test_{test_id}", inference_steps=1, refiner_steps=1)
async with AsyncRemoteComfyClient(server_address=frontend_url) as client:
task_id = await client.queue_and_forget_prompt_api(prompt, prefer_header="respond-async")
assert task_id is not None, "Failed to get task ID"
status_code, result = await client.poll_prompt_until_done(task_id, max_attempts=60, poll_interval=2.0)
if status_code != 200:
# Capture worker logs
worker_output = ""
if worker_process.stdout:
worker_output = worker_process.stdout.read()
# Get frontend container logs
frontend_logs = frontend_container.get_logs()
logger.error("=" * 80)
logger.error("TASK FAILED - Diagnostic Information:")
logger.error("=" * 80)
logger.error(f"Task ID: {task_id}")
logger.error(f"Status Code: {status_code}")
logger.error(f"Result: {result}")
logger.error("\n--- Frontend Container Logs (last 100 lines) ---")
frontend_log_lines = frontend_logs.decode('utf-8').split('\n')
for line in frontend_log_lines[-100:]:
logger.error(line)
logger.error("\n--- Worker Process Output ---")
for line in worker_output.split('\n')[-100:]:
logger.error(line)
logger.error("=" * 80)
assert status_code == 200, f"Task failed with status {status_code}. Check logs above for details."
await asyncio.sleep(5)
# Query Jaeger for traces from both services
frontend_traces = query_jaeger_traces(jaeger_url, "comfyui-docker-frontend", lookback="5m").get("data", [])
worker_traces = query_jaeger_traces(jaeger_url, "comfyui-worker", lookback="5m").get("data", [])
assert frontend_traces, (
f"No traces found in Jaeger for service 'comfyui-docker-frontend'. "
f"Check that OTEL export is working from Docker container. Jaeger UI: {jaeger_url}"
)
assert worker_traces, (
f"No traces found in Jaeger for service 'comfyui-worker'. "
f"Check that OTEL export is working from worker. Jaeger UI: {jaeger_url}"
)
# Analyze span types from both services
aiohttp_spans = []
aio_pika_frontend_spans = []
aio_pika_worker_spans = []
for trace_item in frontend_traces:
for span in trace_item.get("spans", []):
operation_name = span.get("operationName", "")
if any(http_op in operation_name.upper() for http_op in ["GET", "POST", "PUT", "DELETE", "PATCH"]):
aiohttp_spans.append(operation_name)
elif "publish" in operation_name.lower() or "send" in operation_name.lower():
aio_pika_frontend_spans.append(operation_name)
for trace_item in worker_traces:
for span in trace_item.get("spans", []):
operation_name = span.get("operationName", "")
if "consume" in operation_name.lower() or "receive" in operation_name.lower() or "publish" in operation_name.lower():
aio_pika_worker_spans.append(operation_name)
assert aiohttp_spans, (
f"No aiohttp server spans found in traces from Docker frontend. "
f"This indicates aiohttp server instrumentation is not working in the Docker image. "
f"Image: {docker_image}. Jaeger UI: {jaeger_url}"
)
total_aio_pika_spans = len(aio_pika_frontend_spans) + len(aio_pika_worker_spans)
assert total_aio_pika_spans > 0, (
f"No aio_pika spans found in traces. "
f"Frontend aio_pika spans: {len(aio_pika_frontend_spans)}, Worker aio_pika spans: {len(aio_pika_worker_spans)}. "
f"Expected messaging spans for distributed queue operations. "
f"This indicates aio_pika instrumentation is not working. Jaeger UI: {jaeger_url}"
)
finally:
worker_process.terminate()
worker_process.wait(timeout=10)
finally:
frontend_container.stop()

View File

@ -74,22 +74,41 @@ class MainPreImportOrderChecker(BaseChecker):
def _is_other_relevant_import(self, stmt: Union[nodes.Import, nodes.ImportFrom]) -> bool:
"""
Checks if an import is a relative import or an import from
the 'comfy' package family, and is not a 'main_pre' import.
Checks if an import should come after main_pre.
Returns True for:
- Any relative import
- Any import from the 'comfy' package family
- Any third-party library import (aio_pika, aiohttp, etc.)
Returns False for:
- main_pre imports
- __future__ imports (must always be first)
"""
if self._is_main_pre_import(stmt):
return False
if isinstance(stmt, nodes.ImportFrom):
# Allow __future__ imports (they must be first by Python rules)
if stmt.modname == '__future__':
return False
if stmt.level and stmt.level > 0: # Any relative import
return True
if stmt.modname and stmt.modname.startswith('comfy'):
return True
# Flag any third-party or stdlib imports that need instrumentation
# This includes aio_pika, aiohttp, requests, etc.
if stmt.modname:
return True
if isinstance(stmt, nodes.Import):
for name, _ in stmt.names:
if name.startswith('comfy'):
return True
# Flag any other absolute imports
return True
return False