def _detect_docker_bridge_gateway(fallback: str = "host.docker.internal") -> str:
    """Return the address containers can use to reach services on the host.

    Asks the Docker CLI for the gateway of the default ``bridge`` network.
    Falls back to ``fallback`` when the CLI is missing, fails, times out, or
    prints an empty value.
    """
    try:
        result = subprocess.run(
            ["docker", "network", "inspect", "bridge", "-f", "{{(index .IPAM.Config 0).Gateway}}"],
            capture_output=True,
            text=True,
            check=True,
            timeout=5,
        )
    except (subprocess.CalledProcessError, subprocess.TimeoutExpired, FileNotFoundError):
        return fallback
    return result.stdout.strip() or fallback


def _read_container_logs(container) -> str:
    """Return the logs of ``container`` as text.

    ``get_logs()`` may hand back a tuple of two byte strings or a single
    bytes/str value; every shape is normalised to ``str`` here so callers
    never have to call ``.decode`` on a tuple.
    """
    raw = container.get_logs()
    if isinstance(raw, tuple):
        raw = raw[0] + raw[1]
    if isinstance(raw, bytes):
        # Diagnostics must never fail on a stray non-UTF-8 byte.
        return raw.decode("utf-8", errors="replace")
    return str(raw)


def _log_container_tail(container, label: str, title: str, max_lines: int) -> None:
    """Log ``title`` followed by the last ``max_lines`` log lines of ``container``.

    Log retrieval is best-effort: a failure is reported through the logger
    instead of being raised, so it cannot mask the test failure that
    triggered the dump.
    """
    logger.error(title)
    try:
        log_text = _read_container_logs(container)
    except Exception as e:  # best-effort diagnostics only
        logger.error(f"Could not retrieve {label} logs: {e}")
        return
    for line in log_text.split('\n')[-max_lines:]:
        logger.error(line)


async def _wait_for_http_200(url: str, attempts: int, description: str) -> bool:
    """Poll ``url`` until it answers HTTP 200; return whether it ever did.

    Waits two seconds between attempts. ``description`` prefixes the debug
    message emitted for every failed request.
    """
    for attempt in range(attempts):
        try:
            response = requests.get(url, timeout=2)
        except requests.RequestException as e:
            logger.debug(f"{description} attempt {attempt+1}: {e}")
        else:
            if response.status_code == 200:
                return True
        await asyncio.sleep(2)
    return False


def _comfyui_container_env(otlp_endpoint: str, rabbitmq_uri: str) -> dict:
    """Build the environment shared by the frontend and backend containers."""
    return {
        # Frontend and backends deliberately report under one service name.
        "OTEL_SERVICE_NAME": "comfyui",
        "OTEL_EXPORTER_OTLP_ENDPOINT": otlp_endpoint,
        "OTEL_METRICS_EXPORTER": "none",
        "OTEL_LOGS_EXPORTER": "none",
        # Configure BatchSpanProcessor to avoid silently dropping spans
        "OTEL_BSP_MAX_QUEUE_SIZE": "10000",  # Increased from default 2048
        "OTEL_BSP_SCHEDULE_DELAY": "500",  # Decreased from default 5000ms
        "OTEL_BSP_MAX_EXPORT_BATCH_SIZE": "512",
        "OTEL_BSP_EXPORT_TIMEOUT": "30000",
        "COMFYUI_DISTRIBUTED_QUEUE_CONNECTION_URI": rabbitmq_uri,
        "COMFYUI_LOGGING_LEVEL": "INFO",
    }


def _stop_container(container, message: str) -> None:
    """Log ``message`` and stop ``container`` (used as a cleanup callback)."""
    logger.info(message)
    container.stop()


def _span_service_name(trace_data: dict, span: dict) -> str:
    """Resolve the service name of ``span`` through the trace's process table."""
    process = trace_data.get("processes", {}).get(span.get("processID"), {})
    return process.get("serviceName", "unknown")


def _assert_trace_propagation(our_trace: dict, trace_id: str, jaeger_url: str) -> None:
    """Assert that one Jaeger trace links client -> server -> worker spans.

    ``our_trace`` is a single trace object as returned by the Jaeger query
    API (``spans`` list plus ``processes`` table). Raises ``AssertionError``
    with a diagnostic message when any link of the chain is missing.
    """
    spans = our_trace.get("spans", [])
    logger.info(f"\nAnalyzing trace {trace_id[:16]} with {len(spans)} spans")

    client_spans = []
    server_spans = []
    worker_spans = []
    rabbitmq_spans = []

    # Build a map of span_id -> span for parent lookup
    span_map = {span.get("spanID"): span for span in spans}

    worker_markers = ("Execute", "execute", "queue_prompt", "Load", "Sample")
    rabbitmq_markers = ("publish", "consume", "send", "receive")

    for span in spans:
        op_name = span.get("operationName", "")
        span_id = span.get("spanID") or ""
        service_name = _span_service_name(our_trace, span)

        logger.info(f"  Span: {service_name}.{op_name} (id={span_id[:8]})")

        # Categorize by operation and service; first match wins.
        if service_name == "test-client":
            client_spans.append(span)
        elif "/api/v1/prompts" in op_name:
            server_spans.append(span)
        elif any(marker in op_name for marker in worker_markers):
            worker_spans.append(span)
        elif any(marker in op_name.lower() for marker in rabbitmq_markers):
            rabbitmq_spans.append(span)

    logger.info("\nSpan summary:")
    logger.info(f"  Client spans: {len(client_spans)}")
    logger.info(f"  Server spans (/api/v1/prompts): {len(server_spans)}")
    logger.info(f"  Worker spans (Execute/Load/Sample): {len(worker_spans)}")
    logger.info(f"  RabbitMQ spans: {len(rabbitmq_spans)}")

    # CRITICAL CHECKS: Verify complete trace propagation

    # 1. Check that we have client spans
    assert len(client_spans) > 0, (
        f"NO CLIENT SPANS FOUND in trace {trace_id[:16]}!\n"
        f"Expected to see spans from 'test-client' service but found none.\n"
        f"This indicates the test client is not properly instrumented.\n"
        f"Jaeger UI: {jaeger_url}/trace/{trace_id}"
    )
    logger.info(f"✓ Found {len(client_spans)} client span(s)")

    # 2. Check that we have the server span /api/v1/prompts
    assert len(server_spans) > 0, (
        f"NO SERVER SPAN (/api/v1/prompts) FOUND in trace {trace_id[:16]}!\n"
        f"Expected to see the HTTP server span but found none.\n"
        f"This indicates the frontend is not properly instrumented or not in the same trace.\n"
        f"Jaeger UI: {jaeger_url}/trace/{trace_id}"
    )
    logger.info(f"✓ Found {len(server_spans)} server span(s) for /api/v1/prompts")

    # 3. Verify the server span is a CHILD of a client span
    server_references = server_spans[0].get("references", [])

    server_parent_found = False
    for ref in server_references:
        if ref.get("refType") != "CHILD_OF":
            continue
        parent_span = span_map.get(ref.get("spanID"))
        if parent_span is None:
            continue
        parent_service = _span_service_name(our_trace, parent_span)
        logger.info(f"  Server span parent: {parent_service}.{parent_span.get('operationName', 'unknown')}")
        if parent_service == "test-client":
            server_parent_found = True
            break

    assert server_parent_found, (
        f"SERVER SPAN IS NOT A CHILD OF CLIENT SPAN!\n"
        f"The /api/v1/prompts span exists but is not linked to the test-client request.\n"
        f"This indicates trace context (traceparent header) is not being propagated from client to server.\n"
        f"Server span references: {server_references}\n"
        f"Jaeger UI: {jaeger_url}/trace/{trace_id}"
    )
    logger.info("✓ Server span is correctly a child of client span")

    # 4. Check that we have worker spans (Execute Node, etc.)
    assert len(worker_spans) > 0, (
        f"NO WORKER SPANS FOUND in trace {trace_id[:16]}!\n"
        f"Expected to see worker spans like 'Execute Node', 'Load Checkpoint', etc.\n"
        f"Found only:\n"
        f"  - Client spans: {len(client_spans)}\n"
        f"  - Server spans: {len(server_spans)}\n"
        f"  - RabbitMQ spans: {len(rabbitmq_spans)}\n"
        f"\n"
        f"THIS IS THE PRODUCTION ISSUE!\n"
        f"Trace context is NOT propagating from frontend -> RabbitMQ -> worker.\n"
        f"\n"
        f"Possible causes:\n"
        f"  1. aio-pika is not instrumented on frontend or worker\n"
        f"  2. W3C trace context headers are not in AMQP message headers\n"
        f"  3. OpenTelemetry context propagation is not configured correctly\n"
        f"  4. OpenTelemetry Collector or Jaeger is dropping/corrupting spans\n"
        f"\n"
        f"Jaeger UI: {jaeger_url}/trace/{trace_id}"
    )
    logger.info(f"✓ Found {len(worker_spans)} worker span(s)")

    # 5. Log all worker spans found
    logger.info("\nWorker spans found:")
    for worker_span in worker_spans:
        logger.info(f"  - {worker_span.get('operationName', 'unknown')}")

    logger.info("\n✓✓✓ TRACE PROPAGATION SUCCESSFUL ✓✓✓")
    logger.info(f"Trace {trace_id[:16]} contains complete span hierarchy:")
    logger.info(f"  Client ({len(client_spans)}) -> Server ({len(server_spans)}) -> Worker ({len(worker_spans)})")
    logger.info(f"Jaeger UI: {jaeger_url}/trace/{trace_id}")


@pytest.mark.asyncio
@pytest.mark.parametrize("docker_image,otlp_endpoint,jaeger_url", [
    pytest.param(
        "ghcr.io/hiddenswitch/comfyui:latest",
        None,  # Will use jaeger_container
        None,  # Will use jaeger_container
        id="test-containers"
    ),
    pytest.param(
        "ghcr.io/hiddenswitch/comfyui:latest",
        "http://10.152.184.34:4318",  # otlp-collector IP
        "http://10.152.184.50:16686",  # jaeger-production-query IP
        id="production-infrastructure"
    ),
])
async def test_full_docker_stack_trace_propagation(
        jaeger_container,
        docker_image,
        otlp_endpoint,
        jaeger_url
):
    """
    Full integration test with frontend and backend running as Docker containers.

    This test mirrors the production setup to diagnose trace context propagation issues:
    1. Starts RabbitMQ container
    2. Uses Jaeger container OR production infrastructure (via parametrization)
    3. Starts backend worker container(s) with comfyui-worker
    4. Starts frontend container with comfyui
    5. Submits a workflow
    6. Verifies that trace context propagates from frontend -> RabbitMQ -> backend

    The test is parameterized to run with:
    - test-containers: Uses ephemeral Jaeger container
    - production-infrastructure: Uses cluster's Jaeger/OTLP services

    Backends run in CPU mode unless the ``USE_GPU`` environment variable is
    set to ``1``, ``true`` or ``yes``. Every started container is stopped on
    exit, including when a later container fails to start.
    """
    import os
    from contextlib import ExitStack

    use_production = otlp_endpoint is not None

    # Containers reach host-published ports (RabbitMQ, and Jaeger in
    # test-container mode) through the Docker bridge gateway.
    docker_host = _detect_docker_bridge_gateway()

    if use_production:
        # The production endpoints are used as-is from inside the containers.
        otlp_endpoint_container = otlp_endpoint
        banner = "Using PRODUCTION infrastructure:"
    else:
        # Use test container
        jaeger_url = jaeger_container.get_query_url()
        otlp_endpoint = jaeger_container.get_otlp_endpoint()
        otlp_port = jaeger_container.get_exposed_port(4318)
        otlp_endpoint_container = f"http://{docker_host}:{otlp_port}"
        banner = "Using TEST container infrastructure:"

    logger.info("=" * 80)
    logger.info(banner)
    logger.info(f"  Docker image: {docker_image}")
    logger.info(f"  OTLP endpoint: {otlp_endpoint}")
    logger.info(f"  Jaeger query URL: {jaeger_url}")
    logger.info("=" * 80)
    logger.info(f"Docker host for container-to-host communication: {docker_host}")

    num_backends = 2
    use_gpu = os.environ.get("USE_GPU", "").strip().lower() in ("1", "true", "yes")

    # ExitStack unwinds in reverse order (frontend, backends, RabbitMQ) and
    # keeps unwinding even if one stop() raises.
    with ExitStack() as cleanup:
        rabbitmq = cleanup.enter_context(RabbitMqContainer("rabbitmq:latest"))
        rabbitmq_params = rabbitmq.get_connection_params()
        rabbitmq_uri_container = f"amqp://guest:guest@{docker_host}:{rabbitmq_params.port}"
        logger.info(f"RabbitMQ URI (from containers): {rabbitmq_uri_container}")

        shared_env = _comfyui_container_env(otlp_endpoint_container, rabbitmq_uri_container)

        # Start backend workers (similar to production StatefulSet)
        backend_containers = []
        backend_env = dict(shared_env, COMFYUI_EXECUTOR_FACTORY="ThreadPoolExecutor")
        for i in range(num_backends):
            backend = DockerContainer(docker_image)
            backend.with_exposed_ports(9090)  # health check port
            for key, value in backend_env.items():
                backend.with_env(key, value)

            # ENV vars are automatically picked up by cli_args
            if use_gpu:
                logger.info(f"Configuring backend {i+1} with GPU support")
                backend.with_kwargs(
                    device_requests=[
                        {
                            "Driver": "nvidia",
                            "Count": -1,  # -1 means all GPUs
                            "Capabilities": [["gpu"]],
                        }
                    ]
                )
                # Use comfyui-worker command like in production
                backend.with_command("comfyui-worker")
            else:
                logger.info(f"Configuring backend {i+1} with CPU mode (set USE_GPU=true for GPU support)")
                backend.with_command("comfyui-worker --cpu")

            backend.start()
            # Register cleanup immediately so a failure starting the next
            # backend cannot leak this one.
            cleanup.callback(_stop_container, backend, f"Stopping backend {i+1}/{num_backends}...")
            backend_containers.append(backend)
            logger.info(f"Started backend worker {i+1}/{num_backends}")

        # Wait for backends to be ready
        await asyncio.sleep(5)

        # Verify backend health (advisory only: a failed check is logged, not fatal)
        for i, backend in enumerate(backend_containers):
            backend_host = backend.get_container_host_ip()
            backend_port = backend.get_exposed_port(9090)
            health_url = f"http://{backend_host}:{backend_port}/health"

            if await _wait_for_http_200(health_url, 10, f"Backend {i+1} health check"):
                logger.info(f"Backend {i+1} is healthy at {health_url}")
            else:
                logger.warning(f"Backend {i+1} health check failed, but continuing...")

        # Start frontend container (similar to production Deployment)
        frontend = DockerContainer(docker_image)
        frontend.with_exposed_ports(8188)
        frontend_env = dict(shared_env, COMFYUI_DISTRIBUTED_QUEUE_FRONTEND="1")
        for key, value in frontend_env.items():
            frontend.with_env(key, value)

        # Use comfyui command like in production
        # ENV vars are automatically picked up by cli_args
        frontend.with_command("comfyui --listen 0.0.0.0 --port 8188 --cpu")

        frontend.start()
        cleanup.callback(_stop_container, frontend, "Stopping frontend container...")
        logger.info("Started frontend container")

        frontend_host = frontend.get_container_host_ip()
        frontend_port = frontend.get_exposed_port(8188)
        frontend_url = f"http://{frontend_host}:{frontend_port}"
        logger.info(f"Frontend URL: {frontend_url}")

        # Wait for frontend to be ready
        connected = await _wait_for_http_200(frontend_url, 30, "Frontend connection")
        if connected:
            logger.info(f"Frontend is ready at {frontend_url}")
        else:
            # Capture frontend logs before failing
            logger.error("=" * 80)
            logger.error("FRONTEND FAILED TO START - Diagnostic Information:")
            logger.error("=" * 80)
            logger.error(f"Frontend URL: {frontend_url}")
            _log_container_tail(frontend, "frontend", "\n--- Frontend Container Logs ---", 200)
            logger.error("=" * 80)

        assert connected, f"Could not connect to frontend at {frontend_url}. Check logs above."

        # Set up tracing for the test client with properly configured BatchSpanProcessor
        from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
        from opentelemetry.instrumentation.aiohttp_client import AioHttpClientInstrumentor
        from opentelemetry.sdk.resources import Resource
        from opentelemetry.sdk.trace import TracerProvider
        from opentelemetry.sdk.trace.export import BatchSpanProcessor
        from opentelemetry.semconv.attributes import service_attributes
        from opentelemetry import trace

        resource = Resource.create({
            service_attributes.SERVICE_NAME: "test-client",
        })
        provider = TracerProvider(resource=resource)
        # The test client runs on the host, so it uses the host-side endpoint.
        exporter = OTLPSpanExporter(endpoint=f"{otlp_endpoint}/v1/traces")

        # Configure BatchSpanProcessor to avoid silently dropping spans
        # Default: maxQueueSize=2048, scheduleDelayMillis=5000
        # Problem: High span volume workflows (>1000 spans) fill the queue and spans are silently dropped
        # Solution: Increase queue size and decrease delay
        processor = BatchSpanProcessor(
            exporter,
            max_queue_size=10000,  # Increased from default 2048
            schedule_delay_millis=500,  # Decreased from default 5000ms
            max_export_batch_size=512,  # Keep default
            export_timeout_millis=30000,  # Keep default
        )
        provider.add_span_processor(processor)
        # NOTE(review): the global tracer provider can presumably only be set
        # once per process; confirm no earlier test in this session has
        # already installed one, or this call has no effect.
        trace.set_tracer_provider(provider)

        # Instrument aiohttp client
        AioHttpClientInstrumentor().instrument()

        # Import client AFTER instrumentation
        from comfy.client.aio_client import AsyncRemoteComfyClient

        test_id = str(uuid.uuid4())[:8]
        prompt = sdxl_workflow_with_refiner(f"docker_stack_{test_id}", inference_steps=1, refiner_steps=1)

        logger.info(f"Submitting workflow with test_id: {test_id}")

        # Get the tracer for the test client
        client_tracer = trace.get_tracer("test_full_docker_stack_trace_propagation")

        # Wrap the request in a span to capture the trace ID. The span is
        # closed before Jaeger is queried so that it can be exported too.
        with client_tracer.start_as_current_span("submit_workflow") as workflow_span:
            trace_id = format(workflow_span.get_span_context().trace_id, '032x')
            logger.info(f"Started trace with trace_id: {trace_id}")

            async with AsyncRemoteComfyClient(server_address=frontend_url) as client:
                task_id = await client.queue_and_forget_prompt_api(prompt, prefer_header="respond-async")
                assert task_id is not None, "Failed to get task ID"

                logger.info(f"Queued task: {task_id}")

                status_code, result = await client.poll_prompt_until_done(task_id, max_attempts=120, poll_interval=2.0)

        if status_code != 200:
            # Capture logs from all containers
            logger.error("=" * 80)
            logger.error("TASK FAILED - Diagnostic Information:")
            logger.error("=" * 80)
            logger.error(f"Task ID: {task_id}")
            logger.error(f"Status Code: {status_code}")
            logger.error(f"Result: {result}")

            _log_container_tail(
                frontend, "frontend", "\n--- Frontend Container Logs (last 100 lines) ---", 100
            )
            for i, backend in enumerate(backend_containers):
                _log_container_tail(
                    backend,
                    f"backend {i+1}",
                    f"\n--- Backend {i+1} Container Logs (last 100 lines) ---",
                    100,
                )

            logger.error("=" * 80)

        assert status_code == 200, f"Task failed with status {status_code}. Check logs above."
        logger.info("Task completed successfully")

        # Push any spans still queued in the test client's batch processor.
        provider.force_flush()

        # Give Jaeger time to receive and process spans
        logger.info("Waiting for Jaeger to collect spans...")
        await asyncio.sleep(10)

        # Query Jaeger for OUR specific trace by trace_id
        logger.info(f"Querying Jaeger for trace_id: {trace_id}...")

        # Query both services to find our trace
        test_client_traces = query_jaeger_traces(jaeger_url, "test-client", lookback="10m").get("data", [])
        comfyui_traces = query_jaeger_traces(jaeger_url, "comfyui", lookback="10m").get("data", [])

        logger.info(f"Found {len(test_client_traces)} traces from 'test-client' service")
        logger.info(f"Found {len(comfyui_traces)} traces from 'comfyui' service")

        # Find our specific trace. The loop variable is deliberately not named
        # ``trace`` so it does not rebind the opentelemetry module imported above.
        our_trace = None
        for candidate in test_client_traces + comfyui_traces:
            if candidate.get("traceID") == trace_id:
                our_trace = candidate
                logger.info(f"Found our trace {trace_id[:16]} with {len(candidate.get('spans', []))} spans")
                break

        assert our_trace is not None, (
            f"Could not find trace with trace_id {trace_id} in Jaeger. "
            f"This indicates that spans from test-client are not being exported correctly. "
            f"Jaeger UI: {jaeger_url}"
        )

        # Analyze our specific trace for the expected span hierarchy
        _assert_trace_propagation(our_trace, trace_id, jaeger_url)