# Mirror of https://github.com/comfyanonymous/ComfyUI.git
# Synced 2026-01-09 13:50:49 +08:00 (569 lines, 21 KiB, Python)
"""
Integration tests for distributed tracing across RabbitMQ and services.

These tests validate that trace context propagates correctly from frontend
to backend workers through RabbitMQ, and that Jaeger can reconstruct the
full distributed trace.
"""
|
|
import asyncio
import logging
import os
import tempfile
import time
import uuid
from typing import Optional

import pytest
import requests
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.instrumentation.aiohttp_client import AioHttpClientInstrumentor
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.semconv.attributes import service_attributes
from testcontainers.core.container import DockerContainer
from testcontainers.core.waiting_utils import wait_for_logs
from testcontainers.nginx import NginxContainer

from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner
|
|
|
|
# Configure root logging at INFO level when this test module is imported.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by the fixtures, helpers and tests in this file.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class JaegerContainer(DockerContainer):
    """Jaeger all-in-one testcontainer with the OTLP collector enabled."""

    # Container-side ports that are published to the host.
    _QUERY_PORT = 16686  # UI
    _OTLP_HTTP_PORT = 4318  # OTLP HTTP
    _JAEGER_HTTP_PORT = 14268  # Jaeger HTTP

    def __init__(self, image: str = "jaegertracing/all-in-one:latest"):
        """Configure the exposed ports and switch on the OTLP collector.

        Args:
            image: Docker image reference to run.
        """
        super().__init__(image)
        self.with_exposed_ports(
            self._QUERY_PORT,
            self._OTLP_HTTP_PORT,
            self._JAEGER_HTTP_PORT,
        )
        self.with_env("COLLECTOR_OTLP_ENABLED", "true")

    def _jaeger_http_url(self, container_port: int) -> str:
        """Build ``http://host:port`` for the host mapping of *container_port*."""
        host_ip = self.get_container_host_ip()
        mapped_port = self.get_exposed_port(container_port)
        return f"http://{host_ip}:{mapped_port}"

    def get_query_url(self) -> str:
        """Return the base URL of the Jaeger Query API (mapped port 16686)."""
        return self._jaeger_http_url(self._QUERY_PORT)

    def get_otlp_endpoint(self) -> str:
        """Return the base URL of the OTLP HTTP receiver (mapped port 4318)."""
        return self._jaeger_http_url(self._OTLP_HTTP_PORT)

    def start(self):
        """Start the container, wait for the GRPC-server log line, return self."""
        super().start()
        # The wait is bounded to 30 seconds by the timeout argument.
        wait_for_logs(self, ".*Starting GRPC server.*", timeout=30)
        return self
|
|
|
|
|
|
@pytest.fixture(scope="function")
def nginx_proxy(frontend_backend_worker_with_rabbitmq):
    """
    Provide an nginx proxy in front of the ComfyUI frontend.

    This tests if nginx is blocking W3C trace context propagation.

    The fixture writes an nginx configuration to a temporary file, mounts it
    into an ``NginxContainer`` and yields the URL of the proxy. The upstream
    is the Docker bridge gateway address combined with the frontend's port;
    the host part of the frontend URL is not used because the container
    reaches the host through the gateway.

    Args:
        frontend_backend_worker_with_rabbitmq: Frontend base URL, e.g.
            ``"http://127.0.0.1:19001"``.

    Yields:
        The ``http://host:port`` URL at which the nginx proxy is reachable.

    Raises:
        ValueError: If the frontend URL does not look like ``http://host:port``.
    """
    import re
    import subprocess

    # frontend_url is like "http://127.0.0.1:19001"
    frontend_url = frontend_backend_worker_with_rabbitmq
    match = re.match(r'http://([^:]+):(\d+)', frontend_url)
    if not match:
        raise ValueError(f"Could not parse frontend URL: {frontend_url}")

    frontend_port = match.group(2)
    nginx_port = 8085

    # Get the Docker bridge gateway IP (this is how containers reach the host on Linux).
    docker_gateway = ""
    try:
        result = subprocess.run(
            ["docker", "network", "inspect", "bridge", "-f", "{{range .IPAM.Config}}{{.Gateway}}{{end}}"],
            capture_output=True,
            text=True,
            check=True,
        )
        docker_gateway = result.stdout.strip()
    except Exception:
        # Deliberately best-effort: any failure falls back to the default below.
        logger.debug("Docker gateway detection failed", exc_info=True)

    if docker_gateway:
        logger.info(f"Using Docker gateway IP: {docker_gateway}")
    else:
        # Also covers a successful command that printed nothing, which would
        # otherwise render an invalid "server :<port>;" upstream line.
        docker_gateway = "172.17.0.1"  # Default Docker bridge gateway on Linux
        logger.warning(f"Could not detect Docker gateway, using default: {docker_gateway}")

    # Create nginx config that proxies to the frontend and passes trace headers
    nginx_conf = f"""
events {{
    worker_connections 1024;
}}

http {{
    upstream backend {{
        server {docker_gateway}:{frontend_port};
    }}

    server {{
        listen {nginx_port};

        location / {{
            proxy_pass http://backend;
            proxy_set_header Host $host;
            proxy_set_header X-Real-IP $remote_addr;
            proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
            proxy_set_header X-Forwarded-Proto $scheme;
        }}
    }}
}}
"""

    # delete=False: the file must outlive this `with` so it can be mounted;
    # it is removed explicitly in the cleanup below.
    with tempfile.NamedTemporaryFile(mode='w', suffix='.conf', delete=False) as f:
        f.write(nginx_conf)
        nginx_conf_path = f.name

    # Bound before the try so cleanup never hits an unbound name if the
    # container constructor itself raises.
    nginx = None
    try:
        nginx = NginxContainer(port=nginx_port)
        nginx.with_volume_mapping(nginx_conf_path, "/etc/nginx/nginx.conf")
        nginx.start()

        host = nginx.get_container_host_ip()
        port = nginx.get_exposed_port(nginx_port)
        nginx_url = f"http://{host}:{port}"

        logger.info(f"Nginx proxy started at {nginx_url} -> {frontend_url}")

        # Wait for nginx to be ready: any HTTP response (whatever its status)
        # means the proxy is accepting connections.
        for _ in range(30):
            try:
                requests.get(nginx_url, timeout=1)
                break
            except requests.RequestException:
                time.sleep(0.5)
        else:
            logger.warning(f"Nginx proxy at {nginx_url} did not respond in time; continuing anyway")

        yield nginx_url
    finally:
        try:
            if nginx is not None:
                nginx.stop()
        finally:
            # Remove the config file even if stopping the container failed.
            os.unlink(nginx_conf_path)
|
|
|
|
|
|
@pytest.fixture(scope="module")
def jaeger_container():
    """
    Provide a Jaeger container for collecting traces.

    This fixture automatically sets OTEL_EXPORTER_OTLP_ENDPOINT to point to the
    Jaeger container, and restores the previous value (or removes the variable)
    before the container is stopped.

    Yields:
        The started ``JaegerContainer``.
    """
    container = JaegerContainer()
    try:
        # start() is inside the try so the container is stopped even when
        # start-up or the readiness probing below raises.
        container.start()

        query_url = container.get_query_url()
        otlp_endpoint = container.get_otlp_endpoint()

        # Wait for Jaeger to be fully ready (up to 30 attempts, 1 s apart).
        for _ in range(30):
            try:
                # A timeout keeps a half-open port from blocking the whole run.
                response = requests.get(f"{query_url}/api/services", timeout=5)
            except requests.RequestException:
                response = None
            if response is not None and response.status_code == 200:
                logger.info(f"Jaeger ready at {query_url}")
                logger.info(f"OTLP endpoint: {otlp_endpoint}")
                break
            time.sleep(1)
        else:
            # Stay best-effort like before, but leave a trace in the logs.
            logger.warning(f"Jaeger query API at {query_url} did not become ready; continuing anyway")

        # Set OTEL_EXPORTER_OTLP_ENDPOINT for the duration of the test
        old_endpoint = os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT")
        os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"] = otlp_endpoint
        logger.info(f"Set OTEL_EXPORTER_OTLP_ENDPOINT={otlp_endpoint}")

        try:
            yield container
        finally:
            # Restore original OTEL_EXPORTER_OTLP_ENDPOINT
            if old_endpoint is not None:
                os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"] = old_endpoint
                logger.info(f"Restored OTEL_EXPORTER_OTLP_ENDPOINT={old_endpoint}")
            else:
                os.environ.pop("OTEL_EXPORTER_OTLP_ENDPOINT", None)
                logger.info("Removed OTEL_EXPORTER_OTLP_ENDPOINT")
    finally:
        container.stop()
|
|
|
|
|
|
def query_jaeger_traces(jaeger_url: str, service: str, operation: Optional[str] = None,
                        lookback: str = "1h", limit: int = 100,
                        timeout: float = 10.0) -> dict:
    """
    Query Jaeger for traces.

    Sends ``GET {jaeger_url}/api/traces`` with the given filters.

    Args:
        jaeger_url: Base URL of Jaeger query service
        service: Service name to query
        operation: Optional operation name filter; omitted from the query
            when falsy
        lookback: Lookback period (e.g., "1h", "30m")
        limit: Maximum number of traces to return
        timeout: Seconds to wait for the HTTP request before giving up

    Returns:
        JSON response from Jaeger API

    Raises:
        requests.HTTPError: If Jaeger answers with an error status.
        requests.RequestException: On connection problems or timeout.
    """
    params = {
        "service": service,
        "lookback": lookback,
        "limit": limit,
    }
    if operation:
        params["operation"] = operation

    # Tolerate a trailing slash on the base URL so the path never contains "//".
    endpoint = f"{jaeger_url.rstrip('/')}/api/traces"
    # Without a timeout an unresponsive Jaeger would hang the test run.
    response = requests.get(endpoint, params=params, timeout=timeout)
    response.raise_for_status()
    return response.json()
|
|
|
|
|
|
def find_trace_by_operation(traces_response: dict, operation_name: str) -> Optional[dict]:
    """
    Find a specific trace by operation name.

    Args:
        traces_response: Parsed JSON body of a trace query; traces are read
            from its ``"data"`` list.
        operation_name: Value to match against each span's ``"operationName"``.

    Returns:
        The first trace containing a span with the given operation name, or
        ``None`` when there is no match (including when the response, its
        ``"data"`` or a trace's ``"spans"`` is missing or null).
    """
    if not traces_response:
        return None

    # `or []` also covers keys that are present but hold JSON null, which
    # `.get(key, [])` would pass through as None and fail to iterate.
    for trace in traces_response.get("data") or []:
        for span in trace.get("spans") or []:
            if span.get("operationName") == operation_name:
                return trace
    return None
|
|
|
|
|
|
def verify_trace_continuity(trace: dict, expected_services: list[str]) -> bool:
    """
    Verify that a trace spans multiple services and maintains parent-child relationships.

    Three checks are applied, in order:

    1. every name in ``expected_services`` appears as the ``serviceName`` of
       a process referenced by at least one span,
    2. all spans carry the same ``traceID``,
    3. at least one span has a ``CHILD_OF`` reference whose ``spanID`` is the
       ID of another span in this trace.

    Args:
        trace: Jaeger trace object
        expected_services: List of service names expected in the trace

    Returns:
        True if trace shows proper distributed tracing across services;
        False otherwise, including for a missing/empty trace or span list.
    """
    if not trace:
        return False

    # `or` fallbacks treat JSON null the same as a missing key.
    spans = trace.get("spans") or []
    if not spans:
        return False

    # Check that all expected services are present
    processes = trace.get("processes") or {}
    trace_services = set()
    for span in spans:
        process_id = span.get("processID")
        if not process_id:
            continue
        service_name = (processes.get(process_id) or {}).get("serviceName")
        if service_name:
            trace_services.add(service_name)

    logger.info(f"Trace contains services: {trace_services}")
    logger.info(f"Expected services: {set(expected_services)}")

    for service in expected_services:
        if service not in trace_services:
            logger.warning(f"Expected service '{service}' not found in trace")
            return False

    # Verify all spans share the same trace ID
    trace_ids = {span.get("traceID") for span in spans}
    if len(trace_ids) != 1:
        logger.warning(f"Multiple trace IDs found: {trace_ids}")
        return False

    # Verify parent-child relationships exist
    span_ids = {span.get("spanID") for span in spans}
    has_parent_refs = False
    for span in spans:
        for ref in span.get("references") or []:
            if ref.get("refType") != "CHILD_OF":
                continue
            parent_span_id = ref.get("spanID")
            if parent_span_id in span_ids:
                has_parent_refs = True
                # No early exit: every link found is logged for debugging.
                logger.info(f"Found parent-child relationship: {parent_span_id} -> {span.get('spanID')}")

    if not has_parent_refs:
        logger.warning("No parent-child relationships found in trace")
        return False

    return True
|
|
|
|
|
|
# order matters, execute jaeger_container first
@pytest.mark.asyncio
async def test_tracing_integration(jaeger_container, nginx_proxy):
    """
    Integration test for distributed tracing across services with nginx proxy.

    This test:
    1. Starts ComfyUI frontend and worker with RabbitMQ
    2. Starts nginx proxy in front of the frontend to test trace context propagation through nginx
    3. Configures OTLP export to Jaeger testcontainer
    4. Submits a workflow through the nginx proxy
    5. Queries Jaeger to verify trace propagation
    6. Validates that the trace spans multiple services with proper relationships

    This specifically tests if nginx is blocking W3C trace context (traceparent/tracestate headers).
    """
    server_address = nginx_proxy
    jaeger_url = jaeger_container.get_query_url()
    otlp_endpoint = jaeger_container.get_otlp_endpoint()

    logger.info(f"Frontend server: {server_address}")
    logger.info(f"Jaeger UI: {jaeger_url}")
    logger.info(f"OTLP endpoint: {otlp_endpoint}")

    # Set up tracing for the async HTTP client
    resource = Resource.create({
        service_attributes.SERVICE_NAME: "comfyui-client",
    })
    provider = TracerProvider(resource=resource)
    exporter = OTLPSpanExporter(endpoint=f"{otlp_endpoint}/v1/traces")
    processor = BatchSpanProcessor(exporter)
    provider.add_span_processor(processor)
    from opentelemetry import trace

    trace.set_tracer_provider(provider)

    # Instrument aiohttp client
    AioHttpClientInstrumentor().instrument()

    # we have to call this very late, so that the instrumentation isn't initialized too early
    from comfy.client.aio_client import AsyncRemoteComfyClient

    # Note: In a real integration test, you'd need to configure the ComfyUI
    # services to export traces to this Jaeger instance. For now, this test
    # documents the expected behavior.

    # Create a unique prompt to identify our trace
    test_id = str(uuid.uuid4())[:8]
    prompt = sdxl_workflow_with_refiner(f"test_trace_{test_id}", inference_steps=1, refiner_steps=1)

    # Get the tracer for the client
    client_tracer = trace.get_tracer("test_tracing_integration")

    # Submit the workflow - wrap in a span to capture the trace ID
    with client_tracer.start_as_current_span("submit_workflow") as workflow_span:
        # Jaeger reports trace IDs as 32-digit lowercase hex strings.
        trace_id = format(workflow_span.get_span_context().trace_id, '032x')
        logger.info(f"Started trace with trace_id: {trace_id}")

        async with AsyncRemoteComfyClient(server_address=server_address) as client:
            logger.info(f"Submitting workflow with test_id: {test_id}")

            # Queue the prompt with async response
            task_id = await client.queue_and_forget_prompt_api(prompt, prefer_header="respond-async")
            assert task_id is not None, "Failed to get task ID"

            logger.info(f"Queued task: {task_id}")

            # Poll for completion
            status_code, _ = await client.poll_prompt_until_done(task_id, max_attempts=60, poll_interval=1.0)
            assert status_code == 200, f"Task failed with status {status_code}"
            logger.info("Task completed successfully")

    # The batch processor exports on a schedule; push the client spans out now
    # so finding them does not depend on how the sleep below lines up with it.
    provider.force_flush()

    # Give Jaeger time to receive and process spans
    await asyncio.sleep(5)

    # Query Jaeger for traces
    # Note: The actual service names depend on how your services are configured
    # Common service names might be: "slack-bot", "comfyui-frontend", "comfyui-worker"
    expected_services = ["comfyui", "comfyui-client"]  # Adjust based on actual service names

    logger.info(f"Querying Jaeger for traces with trace_id: {trace_id}...")

    # First, try to find our specific trace by trace_id from the client service
    our_trace = None
    for service in expected_services:
        try:
            traces_response = query_jaeger_traces(jaeger_url, service, lookback="5m")
            if traces_response.get("data"):
                logger.info(f"Found {len(traces_response['data'])} traces for service '{service}'")
                # Named `candidate` rather than `trace` so the locally
                # imported opentelemetry `trace` module is not shadowed.
                for candidate in traces_response["data"]:
                    if candidate.get("traceID") == trace_id:
                        our_trace = candidate
                        logger.info(f"Found our trace in service '{service}'")
                        break
                if our_trace:
                    break
        except Exception as e:
            # Best-effort per service: a failed query is logged and the next
            # service is tried; the assertion below reports the overall miss.
            logger.warning(f"Could not query traces for service '{service}': {e}")

    # Assert we can find the trace we just created
    assert our_trace is not None, (
        f"Could not find trace with trace_id {trace_id} in Jaeger. "
        "This indicates that spans from comfyui-client are not being exported correctly."
    )

    logger.info(f"Successfully found trace with trace_id {trace_id}")

    # Extract services from the trace
    processes = our_trace.get("processes", {})
    trace_services = set()
    for span in our_trace.get("spans", []):
        process_id = span.get("processID")
        if process_id:
            service_name = processes.get(process_id, {}).get("serviceName")
            if service_name:
                trace_services.add(service_name)

    logger.info(f"Services found in trace: {trace_services}")

    # Assert that comfyui-client service is present (since we instrumented it)
    assert "comfyui-client" in trace_services, (
        f"Expected 'comfyui-client' service in trace, but found only: {trace_services}. "
        "This indicates the client instrumentation is not working."
    )

    # Validate trace structure
    logger.info(f"Analyzing trace with {len(our_trace.get('spans', []))} spans")

    # Log all spans for debugging
    for span in our_trace.get("spans", []):
        process = processes.get(span.get("processID"), {})
        service_name = process.get("serviceName", "unknown")
        operation = span.get("operationName", "unknown")
        logger.info(f"  Span: {service_name}.{operation}")

    # Verify trace continuity - only if both services are present
    assert "comfyui" in trace_services
    is_continuous = verify_trace_continuity(our_trace, expected_services)

    # This assertion documents what SHOULD happen when distributed tracing works
    assert is_continuous, (
        "Trace does not show proper distributed tracing. "
        "Expected to see spans from multiple services with parent-child relationships. "
        "This indicates that trace context is not being propagated correctly through RabbitMQ."
    )
|
|
|
|
@pytest.mark.asyncio
async def test_trace_context_in_http_headers(frontend_backend_worker_with_rabbitmq):
    """
    Test that the server accepts requests carrying a traceparent header.

    A plain request to ``/system_stats`` is made first and its response
    headers are logged for inspection; a second request with a W3C
    ``traceparent`` header must return HTTP 200.
    """
    server_address = frontend_backend_worker_with_rabbitmq
    # These are blocking calls inside an async test; a timeout stops an
    # unresponsive server from hanging the test indefinitely.
    request_timeout = 10

    # Make a simple HTTP request and check for trace headers
    # Note: We're checking the server's response headers to see if it's trace-aware
    response = requests.get(f"{server_address}/system_stats", timeout=request_timeout)

    logger.info(f"Response headers: {dict(response.headers)}")

    # The server should be instrumented and may include trace context in responses
    # or at minimum, should accept traceparent headers in requests

    # Test sending a traceparent header
    # (layout: version - 32-hex trace id - 16-hex parent span id - flags)
    test_traceparent = "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01"
    response_with_trace = requests.get(
        f"{server_address}/system_stats",
        headers={"traceparent": test_traceparent},
        timeout=request_timeout,
    )

    # Should not error when traceparent is provided
    assert response_with_trace.status_code == 200, "Server should accept traceparent header"

    logger.info("✓ Server accepts traceparent headers in HTTP requests")
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_multiple_requests_different_traces(frontend_backend_worker_with_rabbitmq, jaeger_container):
    """
    Check that independent requests end up in separate traces.

    Three workflows are queued and awaited, then Jaeger is queried for the
    "comfyui" service and at least two distinct trace IDs must be present.
    """
    from comfy.client.aio_client import AsyncRemoteComfyClient

    queued_task_ids = []

    async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
        # Queue every workflow first ...
        for index in range(3):
            workflow = sdxl_workflow_with_refiner(f"test_{index}", inference_steps=1, refiner_steps=1)
            queued_id = await client.queue_and_forget_prompt_api(workflow, prefer_header="respond-async")
            queued_task_ids.append(queued_id)
            logger.info(f"Queued task {index}: {queued_id}")

        # ... then wait for each of them in submission order.
        for index, queued_id in enumerate(queued_task_ids):
            status_code, _ = await client.poll_prompt_until_done(queued_id, max_attempts=60, poll_interval=1.0)
            assert status_code == 200, f"Task {index} failed"
            logger.info(f"Task {index} completed")

    # Fixed pause so Jaeger can ingest the spans before it is queried.
    await asyncio.sleep(5)

    jaeger_response = query_jaeger_traces(
        jaeger_container.get_query_url(),
        "comfyui",
        lookback="5m",
        limit=10,
    )
    found_traces = jaeger_response.get("data", [])

    assert len(found_traces) >= 2

    distinct_ids = {entry.get("traceID") for entry in found_traces}
    distinct_count = len(distinct_ids)

    logger.info(f"Found {distinct_count} unique traces")

    assert distinct_count >= 2, (
        f"Expected at least 2 distinct traces, found {distinct_count}. "
        "Each request should create its own trace."
    )

    logger.info("✓ Multiple requests created distinct traces")
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.skip(reason="rabbitmq has to be configured for observability?")
async def test_trace_contains_rabbitmq_operations(frontend_backend_worker_with_rabbitmq, jaeger_container):
    """
    Check that traces contain RabbitMQ publish/consume style operations.

    A workflow is run to completion, the "comfyui" traces are fetched from
    Jaeger, and at least one span's operation name must contain one of the
    known RabbitMQ operation markers. Currently skipped (see the skip reason).
    """
    from comfy.client.aio_client import AsyncRemoteComfyClient

    jaeger_url = jaeger_container.get_query_url()

    # Run one workflow end to end.
    async with AsyncRemoteComfyClient(server_address=frontend_backend_worker_with_rabbitmq) as client:
        workflow = sdxl_workflow_with_refiner("test_rmq", inference_steps=1, refiner_steps=1)
        queued_id = await client.queue_and_forget_prompt_api(workflow)
        status_code, _ = await client.poll_prompt_until_done(queued_id, max_attempts=60)
        assert status_code == 200

    # Fixed pause so Jaeger can ingest the spans before it is queried.
    await asyncio.sleep(5)

    fetched = query_jaeger_traces(jaeger_url, "comfyui", lookback="5m")

    # Substrings (lower-case) that mark an operation as RabbitMQ-related.
    rabbitmq_operations = [
        "publish", "consume", "amq_queue_publish", "amq_queue_consume",
        "amq.basic.publish", "amq.basic.consume", "send", "receive"
    ]

    # Lower-cased operation name of every span in every returned trace.
    span_names = (
        span.get("operationName", "").lower()
        for found_trace in fetched.get("data", [])
        for span in found_trace.get("spans", [])
    )
    # One entry per (span, marker) match, so a name can appear more than once.
    found_rabbitmq_ops = [
        name
        for name in span_names
        for marker in rabbitmq_operations
        if marker in name
    ]

    assert found_rabbitmq_ops, "No RabbitMQ-related operations found in traces"