Improve integration test against our actual infra

doctorpangloss 2025-11-18 15:31:32 -08:00
parent 8108315b93
commit bfb1f2fe99


@@ -570,6 +570,493 @@ async def test_trace_contains_rabbitmq_operations(frontend_backend_worker_with_r
assert found_rabbitmq_ops, "No RabbitMQ-related operations found in traces"
@pytest.mark.asyncio
@pytest.mark.parametrize("docker_image,otlp_endpoint,jaeger_url", [
pytest.param(
"ghcr.io/hiddenswitch/comfyui:latest",
None, # Will use jaeger_container
None, # Will use jaeger_container
id="test-containers"
),
pytest.param(
"ghcr.io/hiddenswitch/comfyui:latest",
"http://10.152.184.34:4318", # otlp-collector IP
"http://10.152.184.50:16686", # jaeger-production-query IP
id="production-infrastructure"
),
])
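# Either parametrization can be selected by its id, e.g. `pytest -k test-containers`
# or `pytest -k production-infrastructure` (pytest's -k matches parametrize ids).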
async def test_full_docker_stack_trace_propagation(
jaeger_container,
docker_image,
otlp_endpoint,
jaeger_url
):
"""
Full integration test with frontend and backend running as Docker containers.
This test mirrors the production setup to diagnose trace context propagation issues:
1. Starts RabbitMQ container
2. Uses Jaeger container OR production infrastructure (via parametrization)
3. Starts backend worker container(s) with comfyui-worker
4. Starts frontend container with comfyui
5. Submits a workflow
6. Verifies that trace context propagates from frontend -> RabbitMQ -> backend
The test is parameterized to run with:
- test-containers: Uses ephemeral Jaeger container
- production-infrastructure: Uses cluster's Jaeger/OTLP services
"""
use_production = otlp_endpoint is not None
if use_production:
logger.info("=" * 80)
logger.info("Using PRODUCTION infrastructure:")
logger.info(f" Docker image: {docker_image}")
logger.info(f" OTLP endpoint: {otlp_endpoint}")
logger.info(f" Jaeger query URL: {jaeger_url}")
logger.info("=" * 80)
else:
# Use test container
jaeger_url = jaeger_container.get_query_url()
otlp_endpoint = jaeger_container.get_otlp_endpoint()
otlp_port = jaeger_container.get_exposed_port(4318)
logger.info("=" * 80)
logger.info("Using TEST container infrastructure:")
logger.info(f" Docker image: {docker_image}")
logger.info(f" OTLP endpoint: {otlp_endpoint}")
logger.info(f" Jaeger query URL: {jaeger_url}")
logger.info("=" * 80)
# Get Docker bridge gateway for container-to-host communication
if not use_production:
try:
result = subprocess.run(
["docker", "network", "inspect", "bridge", "-f", "{{(index .IPAM.Config 0).Gateway}}"],
capture_output=True,
text=True,
check=True,
timeout=5
)
docker_host = result.stdout.strip()
if not docker_host:
docker_host = "host.docker.internal"
except (subprocess.CalledProcessError, subprocess.TimeoutExpired, FileNotFoundError):
docker_host = "host.docker.internal"
logger.info(f"Docker host for container-to-host communication: {docker_host}")
# Set OTLP endpoint accessible from containers
otlp_endpoint_container = f"http://{docker_host}:{otlp_port}"
else:
# Production services are accessible directly by DNS
otlp_endpoint_container = otlp_endpoint
# Get docker_host for RabbitMQ connectivity (recomputed unconditionally: the
# lookup above only runs for test containers, but RabbitMQ always runs as a local container)
try:
result = subprocess.run(
["docker", "network", "inspect", "bridge", "-f", "{{(index .IPAM.Config 0).Gateway}}"],
capture_output=True,
text=True,
check=True,
timeout=5
)
docker_host = result.stdout.strip()
if not docker_host:
docker_host = "host.docker.internal"
except (subprocess.CalledProcessError, subprocess.TimeoutExpired, FileNotFoundError):
docker_host = "host.docker.internal"
with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
rabbitmq_params = rabbitmq.get_connection_params()
rabbitmq_uri_container = f"amqp://guest:guest@{docker_host}:{rabbitmq_params.port}"
logger.info(f"RabbitMQ URI (from containers): {rabbitmq_uri_container}")
# Start backend workers (similar to production StatefulSet)
backend_containers = []
num_backends = 2
use_gpu = False # Run in CPU mode for tests
for i in range(num_backends):
backend = DockerContainer(docker_image)
backend.with_exposed_ports(9090) # health check port
backend_env = {
"OTEL_SERVICE_NAME": "comfyui", # Use same service name as frontend
"OTEL_EXPORTER_OTLP_ENDPOINT": otlp_endpoint_container,
"OTEL_METRICS_EXPORTER": "none",
"OTEL_LOGS_EXPORTER": "none",
# Configure BatchSpanProcessor to avoid silently dropping spans
"OTEL_BSP_MAX_QUEUE_SIZE": "10000", # Increased from default 2048
"OTEL_BSP_SCHEDULE_DELAY": "500", # Decreased from default 5000ms
"OTEL_BSP_MAX_EXPORT_BATCH_SIZE": "512",
"OTEL_BSP_EXPORT_TIMEOUT": "30000",
"COMFYUI_DISTRIBUTED_QUEUE_CONNECTION_URI": rabbitmq_uri_container,
"COMFYUI_EXECUTOR_FACTORY": "ThreadPoolExecutor",
"COMFYUI_LOGGING_LEVEL": "INFO",
}
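# The OTEL_BSP_* variables are standard OpenTelemetry SDK settings read by
# BatchSpanProcessor (OTEL_BSP_MAX_QUEUE_SIZE -> max_queue_size,
# OTEL_BSP_SCHEDULE_DELAY -> schedule_delay_millis, and so on); the test
# client further down sets the same values programmatically.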
for key, value in backend_env.items():
backend.with_env(key, value)
if use_gpu:
logger.info(f"Configuring backend {i+1} with GPU support")
# Add GPU support for backends
backend.with_kwargs(
device_requests=[
{
"Driver": "nvidia",
"Count": -1, # -1 means all GPUs
"Capabilities": [["gpu"]],
}
]
)
# Use comfyui-worker command like in production
# ENV vars are automatically picked up by cli_args
backend.with_command("comfyui-worker")
else:
logger.info(f"Configuring backend {i+1} with CPU mode (set USE_GPU=true for GPU support)")
# Run in CPU mode without GPU
# ENV vars are automatically picked up by cli_args
backend.with_command("comfyui-worker --cpu")
backend.start()
backend_containers.append(backend)
logger.info(f"Started backend worker {i+1}/{num_backends}")
try:
# Wait for backends to be ready
time.sleep(5)
# Verify backend health
for i, backend in enumerate(backend_containers):
backend_host = backend.get_container_host_ip()
backend_port = backend.get_exposed_port(9090)
health_url = f"http://{backend_host}:{backend_port}/health"
healthy = False
for attempt in range(10):
try:
response = requests.get(health_url, timeout=2)
if response.status_code == 200:
healthy = True
logger.info(f"Backend {i+1} is healthy at {health_url}")
break
except Exception as e:
logger.debug(f"Backend {i+1} health check attempt {attempt+1}: {e}")
time.sleep(2)
if not healthy:
logger.warning(f"Backend {i+1} health check failed, but continuing...")
# Start frontend container (similar to production Deployment)
frontend = DockerContainer(docker_image)
frontend.with_exposed_ports(8188)
frontend_env = {
"OTEL_SERVICE_NAME": "comfyui", # Use same service name as backend
"OTEL_EXPORTER_OTLP_ENDPOINT": otlp_endpoint_container,
"OTEL_METRICS_EXPORTER": "none",
"OTEL_LOGS_EXPORTER": "none",
# Configure BatchSpanProcessor to avoid silently dropping spans
"OTEL_BSP_MAX_QUEUE_SIZE": "10000", # Increased from default 2048
"OTEL_BSP_SCHEDULE_DELAY": "500", # Decreased from default 5000ms
"OTEL_BSP_MAX_EXPORT_BATCH_SIZE": "512",
"OTEL_BSP_EXPORT_TIMEOUT": "30000",
"COMFYUI_DISTRIBUTED_QUEUE_CONNECTION_URI": rabbitmq_uri_container,
"COMFYUI_DISTRIBUTED_QUEUE_FRONTEND": "1",
"COMFYUI_LOGGING_LEVEL": "INFO",
}
for key, value in frontend_env.items():
frontend.with_env(key, value)
# Use comfyui command like in production
# ENV vars are automatically picked up by cli_args
frontend.with_command("comfyui --listen 0.0.0.0 --port 8188 --cpu")
frontend.start()
logger.info("Started frontend container")
try:
frontend_host = frontend.get_container_host_ip()
frontend_port = frontend.get_exposed_port(8188)
frontend_url = f"http://{frontend_host}:{frontend_port}"
logger.info(f"Frontend URL: {frontend_url}")
# Wait for frontend to be ready
connected = False
for attempt in range(30):
try:
response = requests.get(frontend_url, timeout=2)
if response.status_code == 200:
connected = True
logger.info(f"Frontend is ready at {frontend_url}")
break
except Exception as e:
logger.debug(f"Frontend connection attempt {attempt+1}: {e}")
time.sleep(2)
if not connected:
# Capture frontend logs before failing
logger.error("=" * 80)
logger.error("FRONTEND FAILED TO START - Diagnostic Information:")
logger.error("=" * 80)
logger.error(f"Frontend URL: {frontend_url}")
logger.error("\n--- Frontend Container Logs ---")
try:
frontend_logs = frontend.get_logs()
if isinstance(frontend_logs, tuple):
frontend_logs = frontend_logs[0] + frontend_logs[1]
log_text = frontend_logs.decode('utf-8') if isinstance(frontend_logs, bytes) else str(frontend_logs)
for line in log_text.split('\n')[-200:]:
logger.error(line)
except Exception as e:
logger.error(f"Could not retrieve frontend logs: {e}")
logger.error("=" * 80)
assert connected, f"Could not connect to frontend at {frontend_url}. Check logs above."
# Set up tracing for the test client with properly configured BatchSpanProcessor
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.instrumentation.aiohttp_client import AioHttpClientInstrumentor
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.semconv.attributes import service_attributes
from opentelemetry import trace
resource = Resource.create({
service_attributes.SERVICE_NAME: "test-client",
})
provider = TracerProvider(resource=resource)
exporter = OTLPSpanExporter(endpoint=f"{otlp_endpoint}/v1/traces")
# Configure BatchSpanProcessor to avoid silently dropping spans
# Default: maxQueueSize=2048, scheduleDelayMillis=5000
# Problem: High span volume workflows (>1000 spans) fill the queue and spans are silently dropped
# Solution: Increase queue size and decrease delay
processor = BatchSpanProcessor(
exporter,
max_queue_size=10000, # Increased from default 2048
schedule_delay_millis=500, # Decreased from default 5000ms
max_export_batch_size=512, # Keep default
export_timeout_millis=30000, # Keep default
)
provider.add_span_processor(processor)
trace.set_tracer_provider(provider)
# Instrument aiohttp client
AioHttpClientInstrumentor().instrument()
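# AioHttpClientInstrumentor patches aiohttp's ClientSession so each outbound
# request gets a CLIENT span and the current context is injected as a
# `traceparent` header -- which is what the /api/v1/prompts server span must
# pick up for the CHILD_OF check further down.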
# Import client AFTER instrumentation
from comfy.client.aio_client import AsyncRemoteComfyClient
test_id = str(uuid.uuid4())[:8]
prompt = sdxl_workflow_with_refiner(f"docker_stack_{test_id}", inference_steps=1, refiner_steps=1)
logger.info(f"Submitting workflow with test_id: {test_id}")
# Get the tracer for the test client
client_tracer = trace.get_tracer("test_full_docker_stack_trace_propagation")
# Wrap the request in a span to capture the trace ID
with client_tracer.start_as_current_span("submit_workflow") as workflow_span:
trace_id = format(workflow_span.get_span_context().trace_id, '032x')
logger.info(f"Started trace with trace_id: {trace_id}")
async with AsyncRemoteComfyClient(server_address=frontend_url) as client:
task_id = await client.queue_and_forget_prompt_api(prompt, prefer_header="respond-async")
assert task_id is not None, "Failed to get task ID"
logger.info(f"Queued task: {task_id}")
status_code, result = await client.poll_prompt_until_done(task_id, max_attempts=120, poll_interval=2.0)
if status_code != 200:
# Capture logs from all containers
logger.error("=" * 80)
logger.error("TASK FAILED - Diagnostic Information:")
logger.error("=" * 80)
logger.error(f"Task ID: {task_id}")
logger.error(f"Status Code: {status_code}")
logger.error(f"Result: {result}")
logger.error("\n--- Frontend Container Logs (last 100 lines) ---")
frontend_logs = frontend.get_logs()
if isinstance(frontend_logs, tuple):
frontend_logs = frontend_logs[0] + frontend_logs[1]
log_text = frontend_logs.decode('utf-8') if isinstance(frontend_logs, bytes) else str(frontend_logs)
for line in log_text.split('\n')[-100:]:
logger.error(line)
for i, backend in enumerate(backend_containers):
logger.error(f"\n--- Backend {i+1} Container Logs (last 100 lines) ---")
backend_logs = backend.get_logs()
if isinstance(backend_logs, tuple):
backend_logs = backend_logs[0] + backend_logs[1]
log_text = backend_logs.decode('utf-8') if isinstance(backend_logs, bytes) else str(backend_logs)
for line in log_text.split('\n')[-100:]:
logger.error(line)
logger.error("=" * 80)
assert status_code == 200, f"Task failed with status {status_code}. Check logs above."
logger.info("Task completed successfully")
# Give Jaeger time to receive and process spans
logger.info("Waiting for Jaeger to collect spans...")
await asyncio.sleep(10)
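# A fixed sleep works but is racy; flushing the local exporter first is more
# deterministic (force_flush is a real TracerProvider method):
#
#     provider.force_flush(timeout_millis=30000)
#
# The sleep still helps cover the collector-to-Jaeger hop in the
# production-infrastructure parametrization.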
# Query Jaeger for OUR specific trace by trace_id
logger.info(f"Querying Jaeger for trace_id: {trace_id}...")
# Query both services to find our trace
test_client_traces = query_jaeger_traces(jaeger_url, "test-client", lookback="10m").get("data", [])
comfyui_traces = query_jaeger_traces(jaeger_url, "comfyui", lookback="10m").get("data", [])
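# query_jaeger_traces is defined earlier in this module; it presumably wraps
# Jaeger's HTTP query API, roughly:
#
#     def query_jaeger_traces(jaeger_url, service, lookback="10m", limit=100):
#         resp = requests.get(
#             f"{jaeger_url}/api/traces",
#             params={"service": service, "lookback": lookback, "limit": limit},
#             timeout=10,
#         )
#         resp.raise_for_status()
#         return resp.json()  # {"data": [trace, ...]}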
logger.info(f"Found {len(test_client_traces)} traces from 'test-client' service")
logger.info(f"Found {len(comfyui_traces)} traces from 'comfyui' service")
# Find our specific trace
our_trace = None
for candidate in test_client_traces + comfyui_traces:  # `candidate`, not `trace`: avoid shadowing the opentelemetry `trace` module imported above
if candidate.get("traceID") == trace_id:
our_trace = candidate
logger.info(f"Found our trace {trace_id[:16]} with {len(candidate.get('spans', []))} spans")
break
assert our_trace is not None, (
f"Could not find trace with trace_id {trace_id} in Jaeger. "
f"This indicates that spans from test-client are not being exported correctly. "
f"Jaeger UI: {jaeger_url}"
)
# Analyze our specific trace for the expected span hierarchy
spans = our_trace.get("spans", [])
logger.info(f"\nAnalyzing trace {trace_id[:16]} with {len(spans)} spans")
# Categorize spans
client_spans = []
server_spans = []
worker_spans = []
rabbitmq_spans = []
# Build a map of span_id -> span for reference lookup
span_map = {span.get("spanID"): span for span in spans}
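# For reference, Jaeger's trace JSON carries "spans" (each with "spanID",
# "operationName", "processID", "references") plus a top-level "processes"
# map from processID to {"serviceName": ...}, which is why service names are
# resolved through `processes` below.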
for span in spans:
op_name = span.get("operationName", "")
span_id = span.get("spanID")
process_id = span.get("processID")
process = our_trace.get("processes", {}).get(process_id, {})
service_name = process.get("serviceName", "unknown")
logger.info(f" Span: {service_name}.{op_name} (id={span_id[:8]})")
# Categorize by operation and service
if service_name == "test-client":
client_spans.append(span)
elif "/api/v1/prompts" in op_name:
server_spans.append(span)
elif any(worker_op in op_name for worker_op in ["Execute", "execute", "queue_prompt", "Load", "Sample"]):
worker_spans.append(span)
elif any(rmq in op_name.lower() for rmq in ["publish", "consume", "send", "receive"]):
rabbitmq_spans.append(span)
logger.info(f"\nSpan summary:")
logger.info(f" Client spans: {len(client_spans)}")
logger.info(f" Server spans (/api/v1/prompts): {len(server_spans)}")
logger.info(f" Worker spans (Execute/Load/Sample): {len(worker_spans)}")
logger.info(f" RabbitMQ spans: {len(rabbitmq_spans)}")
# CRITICAL CHECKS: Verify complete trace propagation
# 1. Check that we have client spans
assert len(client_spans) > 0, (
f"NO CLIENT SPANS FOUND in trace {trace_id[:16]}!\n"
f"Expected to see spans from 'test-client' service but found none.\n"
f"This indicates the test client is not properly instrumented.\n"
f"Jaeger UI: {jaeger_url}/trace/{trace_id}"
)
logger.info(f"✓ Found {len(client_spans)} client span(s)")
# 2. Check that we have the server span /api/v1/prompts
assert len(server_spans) > 0, (
f"NO SERVER SPAN (/api/v1/prompts) FOUND in trace {trace_id[:16]}!\n"
f"Expected to see the HTTP server span but found none.\n"
f"This indicates the frontend is not properly instrumented or not in the same trace.\n"
f"Jaeger UI: {jaeger_url}/trace/{trace_id}"
)
logger.info(f"✓ Found {len(server_spans)} server span(s) for /api/v1/prompts")
# 3. Verify the server span is a CHILD of a client span
server_span = server_spans[0]
server_span_id = server_span.get("spanID")
server_references = server_span.get("references", [])
server_parent_found = False
for ref in server_references:
if ref.get("refType") == "CHILD_OF":
parent_span_id = ref.get("spanID")
if parent_span_id in span_map:
parent_span = span_map[parent_span_id]
parent_process_id = parent_span.get("processID")
parent_process = our_trace.get("processes", {}).get(parent_process_id, {})
parent_service = parent_process.get("serviceName", "unknown")
logger.info(f" Server span parent: {parent_service}.{parent_span.get('operationName', 'unknown')}")
if parent_service == "test-client":
server_parent_found = True
break
assert server_parent_found, (
f"SERVER SPAN IS NOT A CHILD OF CLIENT SPAN!\n"
f"The /api/v1/prompts span exists but is not linked to the test-client request.\n"
f"This indicates trace context (traceparent header) is not being propagated from client to server.\n"
f"Server span references: {server_references}\n"
f"Jaeger UI: {jaeger_url}/trace/{trace_id}"
)
logger.info("✓ Server span is correctly a child of client span")
# 4. Check that we have worker spans (Execute Node, etc.)
assert len(worker_spans) > 0, (
f"NO WORKER SPANS FOUND in trace {trace_id[:16]}!\n"
f"Expected to see worker spans like 'Execute Node', 'Load Checkpoint', etc.\n"
f"Found only:\n"
f" - Client spans: {len(client_spans)}\n"
f" - Server spans: {len(server_spans)}\n"
f" - RabbitMQ spans: {len(rabbitmq_spans)}\n"
f"\n"
f"THIS IS THE PRODUCTION ISSUE!\n"
f"Trace context is NOT propagating from frontend -> RabbitMQ -> worker.\n"
f"\n"
f"Possible causes:\n"
f" 1. aio-pika is not instrumented on frontend or worker\n"
f" 2. W3C trace context headers are not in AMQP message headers\n"
f" 3. OpenTelemetry context propagation is not configured correctly\n"
f" 4. OpenTelemetry Collector or Jaeger is dropping/corrupting spans\n"
f"\n"
f"Jaeger UI: {jaeger_url}/trace/{trace_id}"
)
logger.info(f"✓ Found {len(worker_spans)} worker span(s)")
# 5. Log all worker spans found
logger.info("\nWorker spans found:")
for worker_span in worker_spans:
logger.info(f" - {worker_span.get('operationName', 'unknown')}")
logger.info(f"\n✓✓✓ TRACE PROPAGATION SUCCESSFUL ✓✓✓")
logger.info(f"Trace {trace_id[:16]} contains complete span hierarchy:")
logger.info(f" Client ({len(client_spans)}) -> Server ({len(server_spans)}) -> Worker ({len(worker_spans)})")
logger.info(f"Jaeger UI: {jaeger_url}/trace/{trace_id}")
finally:
logger.info("Stopping frontend container...")
frontend.stop()
finally:
logger.info("Stopping backend containers...")
for i, backend in enumerate(backend_containers):
logger.info(f"Stopping backend {i+1}/{num_backends}...")
backend.stop()
@pytest.mark.asyncio
async def test_aiohttp_and_aio_pika_spans_with_docker_frontend(jaeger_container):
"""