""" Integration tests for distributed tracing across RabbitMQ and services. These tests validate that trace context propagates correctly from frontend to backend workers through RabbitMQ, and that Jaeger can reconstruct the full distributed trace. """ import asyncio import logging import os import tempfile import time import uuid import pytest import requests from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter from opentelemetry.instrumentation.aiohttp_client import AioHttpClientInstrumentor from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor from opentelemetry.semconv.attributes import service_attributes from testcontainers.core.container import DockerContainer from testcontainers.core.waiting_utils import wait_for_logs from testcontainers.nginx import NginxContainer from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) class JaegerContainer(DockerContainer): """Testcontainer for Jaeger all-in-one with OTLP support.""" def __init__(self, image: str = "jaegertracing/all-in-one:latest"): super().__init__(image) self.with_exposed_ports(16686, 4318, 14268) # UI, OTLP HTTP, Jaeger HTTP self.with_env("COLLECTOR_OTLP_ENABLED", "true") def get_query_url(self) -> str: """Get Jaeger Query API URL.""" host = self.get_container_host_ip() port = self.get_exposed_port(16686) return f"http://{host}:{port}" def get_otlp_endpoint(self) -> str: """Get OTLP HTTP endpoint for sending traces.""" host = self.get_container_host_ip() port = self.get_exposed_port(4318) return f"http://{host}:{port}" def start(self): super().start() wait_for_logs(self, ".*Starting GRPC server.*", timeout=30) return self @pytest.fixture(scope="function") def nginx_proxy(frontend_backend_worker_with_rabbitmq): """ Provide an nginx proxy in front of the ComfyUI frontend. This tests if nginx is blocking W3C trace context propagation. 
""" import socket import subprocess # Extract host and port from frontend address frontend_url = frontend_backend_worker_with_rabbitmq # frontend_url is like "http://127.0.0.1:19001" import re match = re.match(r'http://([^:]+):(\d+)', frontend_url) if not match: raise ValueError(f"Could not parse frontend URL: {frontend_url}") frontend_host = match.group(1) frontend_port = match.group(2) nginx_port = 8085 # Get the Docker bridge gateway IP (this is how containers reach the host on Linux) # Try to get the default Docker bridge gateway try: result = subprocess.run( ["docker", "network", "inspect", "bridge", "-f", "{{range .IPAM.Config}}{{.Gateway}}{{end}}"], capture_output=True, text=True, check=True ) docker_gateway = result.stdout.strip() logger.info(f"Using Docker gateway IP: {docker_gateway}") except Exception as e: # Fallback: try common gateway IPs docker_gateway = "172.17.0.1" # Default Docker bridge gateway on Linux logger.warning(f"Could not detect Docker gateway, using default: {docker_gateway}") # Create nginx config that proxies to the frontend and passes trace headers nginx_conf = f""" events {{ worker_connections 1024; }} http {{ upstream backend {{ server {docker_gateway}:{frontend_port}; }} server {{ listen {nginx_port}; location / {{ proxy_pass http://backend; proxy_set_header Host $host; proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-Proto $scheme; }} }} }} """ # Write config to a temporary file with tempfile.NamedTemporaryFile(mode='w', suffix='.conf', delete=False) as f: f.write(nginx_conf) nginx_conf_path = f.name try: # Start nginx container with the config nginx = NginxContainer(port=nginx_port) nginx.with_volume_mapping(nginx_conf_path, "/etc/nginx/nginx.conf") nginx.start() # Get the nginx URL host = nginx.get_container_host_ip() port = nginx.get_exposed_port(nginx_port) nginx_url = f"http://{host}:{port}" logger.info(f"Nginx proxy started at {nginx_url} -> {frontend_url}") # Wait for nginx to be ready for _ in range(30): try: response = requests.get(nginx_url, timeout=1) if response.status_code: break except Exception: pass time.sleep(0.5) yield nginx_url finally: nginx.stop() os.unlink(nginx_conf_path) @pytest.fixture(scope="module") def jaeger_container(): """ Provide a Jaeger container for collecting traces. This fixture automatically sets OTEL_EXPORTER_OTLP_ENDPOINT to point to the Jaeger container, and cleans it up when the container stops. 
""" container = JaegerContainer() container.start() # Wait for Jaeger to be fully ready query_url = container.get_query_url() otlp_endpoint = container.get_otlp_endpoint() for _ in range(30): try: response = requests.get(f"{query_url}/api/services") if response.status_code == 200: logger.info(f"Jaeger ready at {query_url}") logger.info(f"OTLP endpoint: {otlp_endpoint}") break except Exception: pass time.sleep(1) # Set OTEL_EXPORTER_OTLP_ENDPOINT for the duration of the test old_endpoint = os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT") os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"] = otlp_endpoint logger.info(f"Set OTEL_EXPORTER_OTLP_ENDPOINT={otlp_endpoint}") try: yield container finally: # Restore original OTEL_EXPORTER_OTLP_ENDPOINT if old_endpoint is not None: os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"] = old_endpoint logger.info(f"Restored OTEL_EXPORTER_OTLP_ENDPOINT={old_endpoint}") else: os.environ.pop("OTEL_EXPORTER_OTLP_ENDPOINT", None) logger.info("Removed OTEL_EXPORTER_OTLP_ENDPOINT") container.stop() def query_jaeger_traces(jaeger_url: str, service: str, operation: str = None, lookback: str = "1h", limit: int = 100) -> dict: """ Query Jaeger for traces. Args: jaeger_url: Base URL of Jaeger query service service: Service name to query operation: Optional operation name filter lookback: Lookback period (e.g., "1h", "30m") limit: Maximum number of traces to return Returns: JSON response from Jaeger API """ params = { "service": service, "lookback": lookback, "limit": limit } if operation: params["operation"] = operation response = requests.get(f"{jaeger_url}/api/traces", params=params) response.raise_for_status() return response.json() def find_trace_by_operation(traces_response: dict, operation_name: str) -> dict: """Find a specific trace by operation name.""" for trace in traces_response.get("data", []): for span in trace.get("spans", []): if span.get("operationName") == operation_name: return trace return None def verify_trace_continuity(trace: dict, expected_services: list[str]) -> bool: """ Verify that a trace spans multiple services and maintains parent-child relationships. 
    Args:
        trace: Jaeger trace object
        expected_services: List of service names expected in the trace

    Returns:
        True if the trace shows proper distributed tracing across services
    """
    if not trace:
        return False

    spans = trace.get("spans", [])
    if not spans:
        return False

    # Check that all expected services are present
    trace_services = set()
    for span in spans:
        process_id = span.get("processID")
        if process_id:
            process = trace.get("processes", {}).get(process_id, {})
            service_name = process.get("serviceName")
            if service_name:
                trace_services.add(service_name)

    logger.info(f"Trace contains services: {trace_services}")
    logger.info(f"Expected services: {set(expected_services)}")

    # Verify all expected services are present
    for service in expected_services:
        if service not in trace_services:
            logger.warning(f"Expected service '{service}' not found in trace")
            return False

    # Verify all spans share the same trace ID
    trace_ids = set(span.get("traceID") for span in spans)
    if len(trace_ids) != 1:
        logger.warning(f"Multiple trace IDs found: {trace_ids}")
        return False

    # Verify parent-child relationships exist
    span_ids = {span.get("spanID") for span in spans}
    has_parent_refs = False
    for span in spans:
        references = span.get("references", [])
        for ref in references:
            if ref.get("refType") == "CHILD_OF":
                parent_span_id = ref.get("spanID")
                if parent_span_id in span_ids:
                    has_parent_refs = True
                    logger.info(f"Found parent-child relationship: {parent_span_id} -> {span.get('spanID')}")

    if not has_parent_refs:
        logger.warning("No parent-child relationships found in trace")
        return False

    return True


# Order matters: jaeger_container must be set up before nginx_proxy
@pytest.mark.asyncio
async def test_tracing_integration(jaeger_container, nginx_proxy):
    """
    Integration test for distributed tracing across services with an nginx proxy.

    This test:
    1. Starts the ComfyUI frontend and worker with RabbitMQ
    2. Starts an nginx proxy in front of the frontend to test trace context
       propagation through nginx
    3. Configures OTLP export to the Jaeger testcontainer
    4. Submits a workflow through the nginx proxy
    5. Queries Jaeger to verify trace propagation
    6. Validates that the trace spans multiple services with proper relationships

    This specifically tests whether nginx is blocking W3C trace context
    (traceparent/tracestate headers).
    """
    server_address = nginx_proxy
    jaeger_url = jaeger_container.get_query_url()
    otlp_endpoint = jaeger_container.get_otlp_endpoint()

    logger.info(f"Frontend server: {server_address}")
    logger.info(f"Jaeger UI: {jaeger_url}")
    logger.info(f"OTLP endpoint: {otlp_endpoint}")

    # Set up tracing for the async HTTP client
    resource = Resource.create({
        service_attributes.SERVICE_NAME: "comfyui-client",
    })
    provider = TracerProvider(resource=resource)
    exporter = OTLPSpanExporter(endpoint=f"{otlp_endpoint}/v1/traces")
    processor = BatchSpanProcessor(exporter)
    provider.add_span_processor(processor)

    from opentelemetry import trace
    trace.set_tracer_provider(provider)

    # Instrument the aiohttp client
    AioHttpClientInstrumentor().instrument()

    # We have to import this very late, so that the instrumentation isn't initialized too early
    from comfy.client.aio_client import AsyncRemoteComfyClient

    # Note: In a real integration test, you'd need to configure the ComfyUI
    # services to export traces to this Jaeger instance. For now, this test
    # documents the expected behavior.
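
    # Hedged aside (an assumption about the mechanism, not something this test
    # asserts directly): AioHttpClientInstrumentor propagates context via the
    # W3C `traceparent` header, e.g. "00-<32 hex trace id>-<16 hex span id>-01",
    # which it injects into every outgoing aiohttp request. The same header can
    # be produced by hand with the standard propagation API:
    #
    #     from opentelemetry.propagate import inject
    #     headers: dict[str, str] = {}
    #     inject(headers)  # fills headers["traceparent"] from the active span
    #
    # nginx must forward this header unchanged for the backend spans to join
    # the client's trace.
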
    # Create a unique prompt to identify our trace
    test_id = str(uuid.uuid4())[:8]
    prompt = sdxl_workflow_with_refiner(f"test_trace_{test_id}", inference_steps=1, refiner_steps=1)

    # Get the tracer for the client
    client_tracer = trace.get_tracer("test_tracing_integration")

    # Submit the workflow - wrap it in a span to capture the trace ID
    with client_tracer.start_as_current_span("submit_workflow") as workflow_span:
        trace_id = format(workflow_span.get_span_context().trace_id, '032x')
        logger.info(f"Started trace with trace_id: {trace_id}")

        async with AsyncRemoteComfyClient(server_address=server_address) as client:
            logger.info(f"Submitting workflow with test_id: {test_id}")

            # Queue the prompt with an async response
            task_id = await client.queue_and_forget_prompt_api(prompt, prefer_header="respond-async")
            assert task_id is not None, "Failed to get task ID"
            logger.info(f"Queued task: {task_id}")

            # Poll for completion
            status_code, result = await client.poll_prompt_until_done(task_id, max_attempts=60, poll_interval=1.0)
            assert status_code == 200, f"Task failed with status {status_code}"
            logger.info("Task completed successfully")

    # Give Jaeger time to receive and process spans
    await asyncio.sleep(5)

    # Query Jaeger for traces
    # Note: The actual service names depend on how your services are configured.
    # Common service names might be: "slack-bot", "comfyui-frontend", "comfyui-worker"
    expected_services = ["comfyui", "comfyui-client"]  # Adjust based on actual service names

    logger.info(f"Querying Jaeger for traces with trace_id: {trace_id}...")

    # First, try to find our specific trace by trace_id from the client service
    our_trace = None
    for service in expected_services:
        try:
            traces_response = query_jaeger_traces(jaeger_url, service, lookback="5m")
            if traces_response.get("data"):
                logger.info(f"Found {len(traces_response['data'])} traces for service '{service}'")
                for trace_data in traces_response["data"]:
                    if trace_data.get("traceID") == trace_id:
                        our_trace = trace_data
                        logger.info(f"Found our trace in service '{service}'")
                        break
            if our_trace:
                break
        except Exception as e:
            logger.warning(f"Could not query traces for service '{service}': {e}")

    # Assert we can find the trace we just created
    assert our_trace is not None, (
        f"Could not find trace with trace_id {trace_id} in Jaeger. "
        f"This indicates that spans from comfyui-client are not being exported correctly."
    )

    logger.info(f"Successfully found trace with trace_id {trace_id}")

    # Extract services from the trace
    trace_services = set()
    for span in our_trace.get("spans", []):
        process_id = span.get("processID")
        if process_id:
            process = our_trace.get("processes", {}).get(process_id, {})
            service_name = process.get("serviceName")
            if service_name:
                trace_services.add(service_name)

    logger.info(f"Services found in trace: {trace_services}")

    # Assert that the comfyui-client service is present (since we instrumented it)
    assert "comfyui-client" in trace_services, (
        f"Expected 'comfyui-client' service in trace, but found only: {trace_services}. "
        f"This indicates the client instrumentation is not working."
    )

    # Validate trace structure
    logger.info(f"Analyzing trace with {len(our_trace.get('spans', []))} spans")

    # Log all spans for debugging
    for span in our_trace.get("spans", []):
        process_id = span.get("processID")
        process = our_trace.get("processes", {}).get(process_id, {})
        service_name = process.get("serviceName", "unknown")
        operation = span.get("operationName", "unknown")
        logger.info(f"  Span: {service_name}.{operation}")

    # Verify trace continuity - this is only meaningful if both services are present,
    # so require the backend "comfyui" service to have joined the trace first
    assert "comfyui" in trace_services, (
        f"Expected 'comfyui' service in trace, but found only: {trace_services}"
    )

    is_continuous = verify_trace_continuity(our_trace, expected_services)

    # This assertion documents what SHOULD happen when distributed tracing works
    assert is_continuous, (
        "Trace does not show proper distributed tracing. "
        "Expected to see spans from multiple services with parent-child relationships. "
        "This indicates that trace context is not being propagated correctly through RabbitMQ."
    )


@pytest.mark.asyncio
async def test_trace_context_in_http_headers(frontend_backend_worker_with_rabbitmq):
    """
    Test that HTTP requests include traceparent headers.

    This validates that the HTTP layer is properly instrumented for tracing.
    """
    server_address = frontend_backend_worker_with_rabbitmq

    # Make a simple HTTP request and check for trace headers
    # Note: We're checking the server's response headers to see if it's trace-aware
    response = requests.get(f"{server_address}/system_stats")
    logger.info(f"Response headers: {dict(response.headers)}")

    # The server should be instrumented and may include trace context in responses,
    # or at minimum should accept traceparent headers in requests

    # Test sending a traceparent header
    test_traceparent = "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01"
    response_with_trace = requests.get(
        f"{server_address}/system_stats",
        headers={"traceparent": test_traceparent}
    )

    # Should not error when a traceparent is provided
    assert response_with_trace.status_code == 200, "Server should accept traceparent header"
    logger.info("✓ Server accepts traceparent headers in HTTP requests")


@pytest.mark.asyncio
async def test_multiple_requests_different_traces(frontend_backend_worker_with_rabbitmq, jaeger_container):
    """
    Test that multiple independent requests create separate traces.

    This validates that trace context is properly scoped per request.
""" server_address = frontend_backend_worker_with_rabbitmq # Submit multiple workflows task_ids = [] from comfy.client.aio_client import AsyncRemoteComfyClient async with AsyncRemoteComfyClient(server_address=server_address) as client: for i in range(3): prompt = sdxl_workflow_with_refiner(f"test_{i}", inference_steps=1, refiner_steps=1) task_id = await client.queue_and_forget_prompt_api(prompt, prefer_header="respond-async") task_ids.append(task_id) logger.info(f"Queued task {i}: {task_id}") # Wait for all to complete for i, task_id in enumerate(task_ids): status_code, result = await client.poll_prompt_until_done(task_id, max_attempts=60, poll_interval=1.0) assert status_code == 200, f"Task {i} failed" logger.info(f"Task {i} completed") # Give Jaeger time to receive spans await asyncio.sleep(5) # Query Jaeger and verify we have multiple distinct traces jaeger_url = jaeger_container.get_query_url() traces_response = query_jaeger_traces(jaeger_url, "comfyui", lookback="5m", limit=10) traces = traces_response.get("data", []) assert len(traces) >= 2 # Get trace IDs trace_ids = [trace.get("traceID") for trace in traces] unique_trace_ids = set(trace_ids) logger.info(f"Found {len(unique_trace_ids)} unique traces") # Verify we have multiple distinct traces assert len(unique_trace_ids) >= 2, ( f"Expected at least 2 distinct traces, found {len(unique_trace_ids)}. " "Each request should create its own trace." ) logger.info("✓ Multiple requests created distinct traces") @pytest.mark.asyncio @pytest.mark.skip(reason="rabbitmq has to be configured for observability?") async def test_trace_contains_rabbitmq_operations(frontend_backend_worker_with_rabbitmq, jaeger_container): """ Test that traces include RabbitMQ publish/consume operations. This is critical for distributed tracing - the RabbitMQ operations are what link the frontend and backend spans together. """ server_address = frontend_backend_worker_with_rabbitmq jaeger_url = jaeger_container.get_query_url() # Submit a workflow from comfy.client.aio_client import AsyncRemoteComfyClient async with AsyncRemoteComfyClient(server_address=server_address) as client: prompt = sdxl_workflow_with_refiner("test_rmq", inference_steps=1, refiner_steps=1) task_id = await client.queue_and_forget_prompt_api(prompt) status_code, result = await client.poll_prompt_until_done(task_id, max_attempts=60) assert status_code == 200 await asyncio.sleep(5) traces_response = query_jaeger_traces(jaeger_url, "comfyui", lookback="5m") traces = traces_response.get("data", []) # Look for RabbitMQ-related operations in any trace rabbitmq_operations = [ "publish", "consume", "amq_queue_publish", "amq_queue_consume", "amq.basic.publish", "amq.basic.consume", "send", "receive" ] found_rabbitmq_ops = [] for trace in traces: for span in trace.get("spans", []): op_name = span.get("operationName", "").lower() for rmq_op in rabbitmq_operations: if rmq_op in op_name: found_rabbitmq_ops.append(op_name) assert found_rabbitmq_ops, "No RabbitMQ-related operations found in traces"