Tracing tests

doctorpangloss 2025-11-07 14:27:31 -08:00
parent 2f520a4cb4
commit 69d8f1b120
7 changed files with 688 additions and 40 deletions

View File

@@ -1,23 +1,23 @@
import asyncio
import uuid
from asyncio import AbstractEventLoop
from collections import defaultdict
from pathlib import Path
from typing import Optional, List
from urllib.parse import urlparse, urljoin
import aiohttp
from aiohttp import WSMessage, ClientResponse, ClientTimeout
from opentelemetry import trace
from .client_types import V1QueuePromptResponse
from ..api.api_client import JSONEncoder
from ..api.components.schema.prompt import PromptDict
from ..api.components.schema.prompt_request import PromptRequest
from ..api.paths.history.get.responses.response_200.content.application_json.schema import Schema as GetHistoryDict
from ..api.schemas import immutabledict
from ..component_model.file_output_path import file_output_path
from ..component_model.outputs_types import OutputsDict
tracer = trace.get_tracer(__name__)
class AsyncRemoteComfyClient:
"""
@@ -57,6 +57,27 @@ class AsyncRemoteComfyClient:
def session(self) -> aiohttp.ClientSession:
return self._ensure_session()
def _build_headers(self, accept_header: str, prefer_header: Optional[str] = None, content_type: str = "application/json") -> dict:
"""Build HTTP headers for requests."""
headers = {'Content-Type': content_type, 'Accept': accept_header}
if prefer_header:
headers['Prefer'] = prefer_header
return headers
@tracer.start_as_current_span("Post Prompt")
async def _post_prompt(self, prompt: PromptDict, endpoint: str, accept_header: str, prefer_header: Optional[str] = None) -> ClientResponse:
"""
Common method to POST a prompt to a given endpoint.
:param prompt: The prompt to send
:param endpoint: The API endpoint (e.g., "/api/v1/prompts")
:param accept_header: The Accept header value
:param prefer_header: Optional Prefer header value
:return: The response object
"""
prompt_json = AsyncRemoteComfyClient.__json_encoder.encode(prompt)
headers = self._build_headers(accept_header, prefer_header)
return await self.session.post(urljoin(self.server_address, endpoint), data=prompt_json, headers=headers)
async def len_queue(self) -> int:
async with self.session.get(urljoin(self.server_address, "/prompt"), headers={'Accept': 'application/json'}) as response:
if response.status == 200:
@@ -73,14 +94,7 @@ class AsyncRemoteComfyClient:
:param accept_header: The Accept header value (e.g., "application/json", "application/json+respond-async")
:return: the task ID
"""
async with await self._post_prompt(prompt, "/api/v1/prompts", accept_header, prefer_header) as response:
if 200 <= response.status < 400:
response_json = await response.json()
return response_json["prompt_id"]
@@ -95,14 +109,7 @@ class AsyncRemoteComfyClient:
:param accept_header: The Accept header value (e.g., "application/json", "application/json+respond-async")
:return: the API response from the server containing URLs and the outputs for the UI (nodes with OUTPUT_NODE == true)
"""
async with await self._post_prompt(prompt, "/api/v1/prompts", accept_header, prefer_header) as response:
if 200 <= response.status < 400:
return V1QueuePromptResponse(**(await response.json()))
else:
@@ -122,17 +129,13 @@ class AsyncRemoteComfyClient:
:param prompt:
:return:
"""
async with await self._post_prompt(prompt, "/api/v1/prompts", "image/png") as response:
if 200 <= response.status < 400:
return await response.read()
else:
raise RuntimeError(f"could not prompt: {response.status}: {await response.text()}")
@tracer.start_as_current_span("Post Prompt (UI)")
async def queue_prompt_ui(self, prompt: PromptDict) -> OutputsDict:
"""
Uses the ComfyUI UI API to retrieve the outputs dictionary
@@ -179,6 +182,7 @@ class AsyncRemoteComfyClient:
"""
return await self.session.get(urljoin(self.server_address, f"/api/v1/prompts/{prompt_id}"))
@tracer.start_as_current_span("Poll Prompt Until Done")
async def poll_prompt_until_done(self, prompt_id: str, max_attempts: int = 60, poll_interval: float = 1.0) -> tuple[int, dict | None]:
"""
Poll a prompt until it completes (200), errors (500), or times out.
@@ -187,6 +191,10 @@ class AsyncRemoteComfyClient:
:param poll_interval: Time to wait between polls in seconds
:return: Tuple of (status_code, response_json or None)
"""
span = trace.get_current_span()
span.set_attribute("prompt_id", prompt_id)
span.set_attribute("max_attempts", max_attempts)
for _ in range(max_attempts):
async with await self.get_prompt_status(prompt_id) as response:
if response.status == 200:
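Taken together, the refactor funnels every queue call through the shared _post_prompt helper, and poll_prompt_until_done now annotates its span with the prompt id and attempt budget. A minimal sketch of how a caller composes the two, assuming a frontend at http://127.0.0.1:19001 and an empty placeholder in place of a real PromptDict workflow:

import asyncio

from comfy.client.aio_client import AsyncRemoteComfyClient


async def main() -> None:
    async with AsyncRemoteComfyClient(server_address="http://127.0.0.1:19001") as client:
        prompt = {}  # placeholder: a real PromptDict workflow goes here
        # POSTs through the shared _post_prompt helper, under the "Post Prompt" span
        task_id = await client.queue_and_forget_prompt_api(prompt, prefer_header="respond-async")
        # records prompt_id and max_attempts on the "Poll Prompt Until Done" span
        status, body = await client.poll_prompt_until_done(task_id, max_attempts=60, poll_interval=1.0)
        print(status, body)


asyncio.run(main())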

View File

@@ -1,8 +1,6 @@
import copy
from typing import TypeAlias, Union
from ..api.components.schema.prompt import PromptDict, Prompt
JSON: TypeAlias = Union[dict[str, "JSON"], list["JSON"], str, int, float, bool, None]
_BASE_PROMPT: JSON = {
"4": {

View File

@@ -15,6 +15,7 @@ import shutil
import warnings
import fsspec
from opentelemetry.instrumentation.urllib3 import URLLib3Instrumentor
from .. import options
from ..app import logger
@@ -125,9 +126,11 @@ def _create_tracer():
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor, SpanExporter
from opentelemetry.processor.baggage import BaggageSpanProcessor, ALLOW_ALL_BAGGAGE_KEYS
from opentelemetry.instrumentation.aiohttp_client import AioHttpClientInstrumentor
from ..tracing_compatibility import ProgressSpanSampler
from ..tracing_compatibility import patch_spanbuilder_set_channel
from ..vendor.aiohttp_server_instrumentation import AioHttpServerInstrumentor
resource = Resource.create({
service_attributes.SERVICE_NAME: args.otel_service_name,
@@ -141,18 +144,24 @@ def _create_tracer():
has_endpoint = args.otel_exporter_otlp_endpoint is not None
if has_endpoint:
exporter = OTLPSpanExporter()
else:
exporter = SpanExporter()
processor = BatchSpanProcessor(exporter)
provider.add_span_processor(processor)
# enable instrumentation
patch_spanbuilder_set_channel()
AioPikaInstrumentor().instrument()
AioHttpServerInstrumentor().instrument()
AioHttpClientInstrumentor().instrument()
RequestsInstrumentor().instrument()
URLLib3Instrumentor().instrument()
provider.add_span_processor(BaggageSpanProcessor(ALLOW_ALL_BAGGAGE_KEYS))
# makes this behave better as a library
return trace.get_tracer(args.otel_service_name, tracer_provider=provider)
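The wiring order in _create_tracer matters: resource first, then the provider, then an exporter wrapped in a batch processor, and only then the instrumentors. A condensed, standalone sketch of the same pipeline, assuming OTEL_EXPORTER_OTLP_ENDPOINT points at an OTLP/HTTP collector:

from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor

resource = Resource.create({"service.name": "comfyui"})
provider = TracerProvider(resource=resource)
# BatchSpanProcessor buffers finished spans and exports them in batches
provider.add_span_processor(BatchSpanProcessor(OTLPSpanExporter()))
tracer = trace.get_tracer("comfyui", tracer_provider=provider)

with tracer.start_as_current_span("example"):
    pass  # spans created here are batched and shipped to the collector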

View File

@@ -3,7 +3,7 @@ from typing import Optional, Sequence
from aio_pika.abc import AbstractChannel
from opentelemetry.context import Context
from opentelemetry.sdk.trace.sampling import Sampler, SamplingResult, Decision
from opentelemetry.semconv.attributes.network_attributes import NETWORK_PEER_ADDRESS, NETWORK_PEER_PORT
from opentelemetry.trace import SpanKind, Link, TraceState
from opentelemetry.util.types import Attributes
@@ -22,8 +22,8 @@ def patch_spanbuilder_set_channel() -> None:
port = url.port or 5672
self._attributes.update(
{
NETWORK_PEER_ADDRESS: url.host,
NETWORK_PEER_PORT: port,
}
)
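The attribute swap replaces the deprecated SpanAttributes.NET_PEER_NAME and NET_PEER_PORT constants with the stable network attributes module; both are plain semantic-convention key strings, as this quick check shows:

from opentelemetry.semconv.attributes.network_attributes import (
    NETWORK_PEER_ADDRESS,
    NETWORK_PEER_PORT,
)

# the constants are the OpenTelemetry semantic-convention key strings
assert NETWORK_PEER_ADDRESS == "network.peer.address"
assert NETWORK_PEER_PORT == "network.peer.port"

attributes = {NETWORK_PEER_ADDRESS: "127.0.0.1", NETWORK_PEER_PORT: 5672}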

View File

@@ -90,6 +90,14 @@ def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory, num_workers
params = rabbitmq.get_connection_params()
connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"
# Check if OTEL endpoint is configured for integration testing
otel_endpoint = os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT")
env = os.environ.copy()
if otel_endpoint:
env["OTEL_EXPORTER_OTLP_ENDPOINT"] = otel_endpoint
logging.info(f"Configuring services to export traces to: {otel_endpoint}")
frontend_command = [
"comfyui",
"--listen=127.0.0.1",
@@ -100,7 +108,7 @@ def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory, num_workers
f"--distributed-queue-connection-uri={connection_uri}",
]
processes_to_close.append(subprocess.Popen(frontend_command, stdout=sys.stdout, stderr=sys.stderr, env=env))
# Start multiple workers
for i in range(num_workers):
@@ -111,7 +119,7 @@ def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory, num_workers
f"--distributed-queue-connection-uri={connection_uri}",
f"--executor-factory={executor_factory}"
]
processes_to_close.append(subprocess.Popen(backend_command, stdout=sys.stdout, stderr=sys.stderr, env=env))
try:
server_address = f"http://127.0.0.1:19001"

View File

@@ -0,0 +1,128 @@
import asyncio
import logging
logging.basicConfig(level=logging.ERROR)
import uuid
import pytest
from testcontainers.rabbitmq import RabbitMqContainer
from opentelemetry import trace, propagate, context
from opentelemetry.trace import SpanKind
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner
from comfy.component_model.make_mutable import make_mutable
from comfy.component_model.queue_types import QueueItem, QueueTuple, ExecutionStatus
from comfy.distributed.server_stub import ServerStub
async def create_test_prompt() -> QueueItem:
from comfy.cmd.execution import validate_prompt
prompt = make_mutable(sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1))
item_id = str(uuid.uuid4())
validation_tuple = await validate_prompt(item_id, prompt)
queue_tuple: QueueTuple = (0, item_id, prompt, {}, validation_tuple[2])
return QueueItem(queue_tuple, None)
@pytest.mark.asyncio
async def test_rabbitmq_message_properties_contain_trace_context():
with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
params = rabbitmq.get_connection_params()
connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"
from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
import aio_pika
exporter = InMemorySpanExporter()
provider = TracerProvider()
provider.add_span_processor(SimpleSpanProcessor(exporter))
tracer = trace.get_tracer(__name__, tracer_provider=provider)
with tracer.start_as_current_span("test_message_headers", kind=SpanKind.PRODUCER):
async with DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=connection_uri) as frontend:
async with DistributedPromptQueue(ServerStub(), is_callee=True, is_caller=False, connection_uri=connection_uri) as worker:
queue_item = await create_test_prompt()
put_task = asyncio.create_task(frontend.put_async(queue_item))
incoming, incoming_prompt_id = await worker.get_async(timeout=5.0)
assert incoming is not None, "Worker should receive message"
worker.task_done(incoming_prompt_id, {}, ExecutionStatus("success", True, []))
result = await put_task
assert result is not None, "Frontend should get result"
# Now inspect the RabbitMQ queue directly to see message structure
connection = await aio_pika.connect_robust(connection_uri)
channel = await connection.channel()
# Declare a test queue to inspect message format
test_queue = await channel.declare_queue("test_inspection_queue", durable=False, auto_delete=True)
# Publish a test message with trace context
carrier = {}
propagate.inject(carrier)
test_message = aio_pika.Message(
body=b"test",
headers=carrier
)
await channel.default_exchange.publish(
test_message,
routing_key=test_queue.name
)
# Get and inspect the message
received = await test_queue.get(timeout=2, fail=False)
if received:
headers = received.headers or {}
# Document what trace headers should be present
# OpenTelemetry uses 'traceparent' header for W3C Trace Context
has_traceparent = "traceparent" in headers
assert has_traceparent
await received.ack()
await connection.close()
@pytest.mark.asyncio
async def test_distributed_queue_uses_async_interface():
"""
Test that demonstrates the correct way to use DistributedPromptQueue in async context.
The synchronous get() method cannot be used in async tests due to event loop assertions.
"""
with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
params = rabbitmq.get_connection_params()
connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"
from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
async with DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=connection_uri) as frontend:
async with DistributedPromptQueue(ServerStub(), is_callee=True, is_caller=False, connection_uri=connection_uri) as worker:
queue_item = await create_test_prompt()
# Start consuming in background
result_future = asyncio.create_task(frontend.put_async(queue_item))
# Worker gets item asynchronously (not using blocking get())
incoming, incoming_prompt_id = await worker.get_async(timeout=5.0)
assert incoming is not None, "Should receive a queue item"
# Complete the work
worker.task_done(incoming_prompt_id, {}, ExecutionStatus("success", True, []))
# Wait for frontend to complete
result = await result_future
assert result is not None, "Should get result from worker"
assert result.status.status_str == "success"
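The inspection test relies on propagate.inject writing the W3C traceparent header into the message properties; a consumer restores the context with propagate.extract. A self-contained sketch of that round trip, using a local TracerProvider so it runs without a collector:

from opentelemetry import context, propagate, trace
from opentelemetry.sdk.trace import TracerProvider

provider = TracerProvider()
tracer = trace.get_tracer(__name__, tracer_provider=provider)

with tracer.start_as_current_span("producer"):
    carrier: dict[str, str] = {}
    propagate.inject(carrier)  # writes the W3C traceparent header
    assert "traceparent" in carrier

# on the consuming side, extract() rebuilds the remote context from the headers
ctx = propagate.extract(carrier)
token = context.attach(ctx)
try:
    with tracer.start_as_current_span("consumer") as span:
        # the consumer span shares the producer's trace id
        print(format(span.get_span_context().trace_id, "032x"))
finally:
    context.detach(token)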

View File

@@ -0,0 +1,497 @@
"""
Integration tests for distributed tracing across RabbitMQ and services.
These tests validate that trace context propagates correctly from frontend
to backend workers through RabbitMQ, and that Jaeger can reconstruct the
full distributed trace.
"""
import asyncio
import logging
import os
import time
import uuid
import pytest
import requests
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.instrumentation.aiohttp_client import AioHttpClientInstrumentor
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.semconv.attributes import service_attributes
from testcontainers.core.container import DockerContainer
from testcontainers.core.waiting_utils import wait_for_logs
from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class JaegerContainer(DockerContainer):
"""Testcontainer for Jaeger all-in-one with OTLP support."""
def __init__(self, image: str = "jaegertracing/all-in-one:latest"):
super().__init__(image)
self.with_exposed_ports(16686, 4318, 14268) # UI, OTLP HTTP, Jaeger HTTP
self.with_env("COLLECTOR_OTLP_ENABLED", "true")
def get_query_url(self) -> str:
"""Get Jaeger Query API URL."""
host = self.get_container_host_ip()
port = self.get_exposed_port(16686)
return f"http://{host}:{port}"
def get_otlp_endpoint(self) -> str:
"""Get OTLP HTTP endpoint for sending traces."""
host = self.get_container_host_ip()
port = self.get_exposed_port(4318)
return f"http://{host}:{port}"
def start(self):
super().start()
wait_for_logs(self, ".*Starting GRPC server.*", timeout=30)
return self
@pytest.fixture(scope="module")
def jaeger_container():
"""
Provide a Jaeger container for collecting traces.
This fixture automatically sets OTEL_EXPORTER_OTLP_ENDPOINT to point to the
Jaeger container, and restores the previous value when the container stops.
"""
container = JaegerContainer()
container.start()
# Wait for Jaeger to be fully ready
query_url = container.get_query_url()
otlp_endpoint = container.get_otlp_endpoint()
for _ in range(30):
try:
response = requests.get(f"{query_url}/api/services")
if response.status_code == 200:
logger.info(f"Jaeger ready at {query_url}")
logger.info(f"OTLP endpoint: {otlp_endpoint}")
break
except Exception:
pass
time.sleep(1)
# Set OTEL_EXPORTER_OTLP_ENDPOINT for the duration of the test
old_endpoint = os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT")
os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"] = otlp_endpoint
logger.info(f"Set OTEL_EXPORTER_OTLP_ENDPOINT={otlp_endpoint}")
try:
yield container
finally:
# Restore original OTEL_EXPORTER_OTLP_ENDPOINT
if old_endpoint is not None:
os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"] = old_endpoint
logger.info(f"Restored OTEL_EXPORTER_OTLP_ENDPOINT={old_endpoint}")
else:
os.environ.pop("OTEL_EXPORTER_OTLP_ENDPOINT", None)
logger.info("Removed OTEL_EXPORTER_OTLP_ENDPOINT")
container.stop()
def query_jaeger_traces(jaeger_url: str, service: str, operation: str | None = None,
lookback: str = "1h", limit: int = 100) -> dict:
"""
Query Jaeger for traces.
Args:
jaeger_url: Base URL of Jaeger query service
service: Service name to query
operation: Optional operation name filter
lookback: Lookback period (e.g., "1h", "30m")
limit: Maximum number of traces to return
Returns:
JSON response from Jaeger API
"""
params = {
"service": service,
"lookback": lookback,
"limit": limit
}
if operation:
params["operation"] = operation
response = requests.get(f"{jaeger_url}/api/traces", params=params)
response.raise_for_status()
return response.json()
def find_trace_by_operation(traces_response: dict, operation_name: str) -> dict | None:
"""Find a specific trace by operation name."""
for trace in traces_response.get("data", []):
for span in trace.get("spans", []):
if span.get("operationName") == operation_name:
return trace
return None
def verify_trace_continuity(trace: dict, expected_services: list[str]) -> bool:
"""
Verify that a trace spans multiple services and maintains parent-child relationships.
Args:
trace: Jaeger trace object
expected_services: List of service names expected in the trace
Returns:
True if trace shows proper distributed tracing across services
"""
if not trace:
return False
spans = trace.get("spans", [])
if not spans:
return False
# Check that all expected services are present
trace_services = set()
for span in spans:
process_id = span.get("processID")
if process_id:
process = trace.get("processes", {}).get(process_id, {})
service_name = process.get("serviceName")
if service_name:
trace_services.add(service_name)
logger.info(f"Trace contains services: {trace_services}")
logger.info(f"Expected services: {set(expected_services)}")
# Verify all expected services are present
for service in expected_services:
if service not in trace_services:
logger.warning(f"Expected service '{service}' not found in trace")
return False
# Verify all spans share the same trace ID
trace_ids = set(span.get("traceID") for span in spans)
if len(trace_ids) != 1:
logger.warning(f"Multiple trace IDs found: {trace_ids}")
return False
# Verify parent-child relationships exist
span_ids = {span.get("spanID") for span in spans}
has_parent_refs = False
for span in spans:
references = span.get("references", [])
for ref in references:
if ref.get("refType") == "CHILD_OF":
parent_span_id = ref.get("spanID")
if parent_span_id in span_ids:
has_parent_refs = True
logger.info(f"Found parent-child relationship: {parent_span_id} -> {span.get('spanID')}")
if not has_parent_refs:
logger.warning("No parent-child relationships found in trace")
return False
return True
# order matters, execute jaeger_container first
@pytest.mark.asyncio
async def test_tracing_integration(jaeger_container, frontend_backend_worker_with_rabbitmq):
"""
Integration test for distributed tracing across services.
This test:
1. Starts ComfyUI frontend and worker with RabbitMQ
2. Configures OTLP export to Jaeger testcontainer
3. Submits a workflow through the frontend
4. Queries Jaeger to verify trace propagation
5. Validates that the trace spans multiple services with proper relationships
Note: The frontend_backend_worker_with_rabbitmq fixture is parameterized,
so this test will run with both ThreadPoolExecutor and ProcessPoolExecutor.
"""
server_address = frontend_backend_worker_with_rabbitmq
jaeger_url = jaeger_container.get_query_url()
otlp_endpoint = jaeger_container.get_otlp_endpoint()
logger.info(f"Frontend server: {server_address}")
logger.info(f"Jaeger UI: {jaeger_url}")
logger.info(f"OTLP endpoint: {otlp_endpoint}")
# Set up tracing for the async HTTP client
resource = Resource.create({
service_attributes.SERVICE_NAME: "comfyui-client",
})
provider = TracerProvider(resource=resource)
exporter = OTLPSpanExporter(endpoint=f"{otlp_endpoint}/v1/traces")
processor = BatchSpanProcessor(exporter)
provider.add_span_processor(processor)
from opentelemetry import trace
trace.set_tracer_provider(provider)
# Instrument aiohttp client
AioHttpClientInstrumentor().instrument()
# we have to call this very late, so that the instrumentation isn't initialized too early
from comfy.client.aio_client import AsyncRemoteComfyClient
# Note: In a real integration test, you'd need to configure the ComfyUI
# services to export traces to this Jaeger instance. For now, this test
# documents the expected behavior.
# Create a unique prompt to identify our trace
test_id = str(uuid.uuid4())[:8]
prompt = sdxl_workflow_with_refiner(f"test_trace_{test_id}", inference_steps=1, refiner_steps=1)
# Get the tracer for the client
client_tracer = trace.get_tracer("test_tracing_integration")
# Submit the workflow - wrap in a span to capture the trace ID
with client_tracer.start_as_current_span("submit_workflow") as workflow_span:
trace_id = format(workflow_span.get_span_context().trace_id, '032x')
logger.info(f"Started trace with trace_id: {trace_id}")
async with AsyncRemoteComfyClient(server_address=server_address) as client:
logger.info(f"Submitting workflow with test_id: {test_id}")
# Queue the prompt with async response
task_id = await client.queue_and_forget_prompt_api(prompt, prefer_header="respond-async")
assert task_id is not None, "Failed to get task ID"
logger.info(f"Queued task: {task_id}")
# Poll for completion
status_code, result = await client.poll_prompt_until_done(task_id, max_attempts=60, poll_interval=1.0)
assert status_code == 200, f"Task failed with status {status_code}"
logger.info("Task completed successfully")
# Give Jaeger time to receive and process spans
await asyncio.sleep(5)
# Query Jaeger for traces
# Note: The actual service names depend on how your services are configured
# Common service names might be: "slack-bot", "comfyui-frontend", "comfyui-worker"
expected_services = ["comfyui", "comfyui-client"] # Adjust based on actual service names
logger.info(f"Querying Jaeger for traces with trace_id: {trace_id}...")
# First, try to find our specific trace by trace_id from the client service
our_trace = None
for service in expected_services:
try:
traces_response = query_jaeger_traces(jaeger_url, service, lookback="5m")
if traces_response.get("data"):
logger.info(f"Found {len(traces_response['data'])} traces for service '{service}'")
for trace in traces_response["data"]:
if trace.get("traceID") == trace_id:
our_trace = trace
logger.info(f"Found our trace in service '{service}'")
break
if our_trace:
break
except Exception as e:
logger.warning(f"Could not query traces for service '{service}': {e}")
# Assert we can find the trace we just created
assert our_trace is not None, (
f"Could not find trace with trace_id {trace_id} in Jaeger. "
f"This indicates that spans from comfyui-client are not being exported correctly."
)
logger.info(f"Successfully found trace with trace_id {trace_id}")
# Extract services from the trace
trace_services = set()
for span in our_trace.get("spans", []):
process_id = span.get("processID")
if process_id:
process = our_trace.get("processes", {}).get(process_id, {})
service_name = process.get("serviceName")
if service_name:
trace_services.add(service_name)
logger.info(f"Services found in trace: {trace_services}")
# Assert that comfyui-client service is present (since we instrumented it)
assert "comfyui-client" in trace_services, (
f"Expected 'comfyui-client' service in trace, but found only: {trace_services}. "
f"This indicates the client instrumentation is not working."
)
# Validate trace structure
logger.info(f"Analyzing trace with {len(our_trace.get('spans', []))} spans")
# Log all spans for debugging
for span in our_trace.get("spans", []):
process_id = span.get("processID")
process = our_trace.get("processes", {}).get(process_id, {})
service_name = process.get("serviceName", "unknown")
operation = span.get("operationName", "unknown")
logger.info(f" Span: {service_name}.{operation}")
# Verify trace continuity - only if both services are present
assert "comfyui" in trace_services
is_continuous = verify_trace_continuity(our_trace, expected_services)
# This assertion documents what SHOULD happen when distributed tracing works
assert is_continuous, (
"Trace does not show proper distributed tracing. "
"Expected to see spans from multiple services with parent-child relationships. "
"This indicates that trace context is not being propagated correctly through RabbitMQ."
)
@pytest.mark.asyncio
async def test_trace_context_in_http_headers(frontend_backend_worker_with_rabbitmq):
"""
Test that HTTP requests include traceparent headers.
This validates that the HTTP layer is properly instrumented for tracing.
"""
server_address = frontend_backend_worker_with_rabbitmq
# Make a simple HTTP request and check for trace headers
# Note: We're checking the server's response headers to see if it's trace-aware
response = requests.get(f"{server_address}/system_stats")
logger.info(f"Response headers: {dict(response.headers)}")
# The server should be instrumented and may include trace context in responses
# or at minimum, should accept traceparent headers in requests
# Test sending a traceparent header
test_traceparent = "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01"
response_with_trace = requests.get(
f"{server_address}/system_stats",
headers={"traceparent": test_traceparent}
)
# Should not error when traceparent is provided
assert response_with_trace.status_code == 200, "Server should accept traceparent header"
logger.info("✓ Server accepts traceparent headers in HTTP requests")
@pytest.mark.asyncio
async def test_multiple_requests_different_traces(frontend_backend_worker_with_rabbitmq, jaeger_container):
"""
Test that multiple independent requests create separate traces.
This validates that trace context is properly scoped per request.
"""
server_address = frontend_backend_worker_with_rabbitmq
# Submit multiple workflows
task_ids = []
from comfy.client.aio_client import AsyncRemoteComfyClient
async with AsyncRemoteComfyClient(server_address=server_address) as client:
for i in range(3):
prompt = sdxl_workflow_with_refiner(f"test_{i}", inference_steps=1, refiner_steps=1)
task_id = await client.queue_and_forget_prompt_api(prompt, prefer_header="respond-async")
task_ids.append(task_id)
logger.info(f"Queued task {i}: {task_id}")
# Wait for all to complete
for i, task_id in enumerate(task_ids):
status_code, result = await client.poll_prompt_until_done(task_id, max_attempts=60, poll_interval=1.0)
assert status_code == 200, f"Task {i} failed"
logger.info(f"Task {i} completed")
# Give Jaeger time to receive spans
await asyncio.sleep(5)
# Query Jaeger and verify we have multiple distinct traces
jaeger_url = jaeger_container.get_query_url()
try:
traces_response = query_jaeger_traces(jaeger_url, "comfyui", lookback="5m", limit=10)
traces = traces_response.get("data", [])
if len(traces) >= 2:
# Get trace IDs
trace_ids = [trace.get("traceID") for trace in traces]
unique_trace_ids = set(trace_ids)
logger.info(f"Found {len(unique_trace_ids)} unique traces")
# Verify we have multiple distinct traces
assert len(unique_trace_ids) >= 2, (
f"Expected at least 2 distinct traces, found {len(unique_trace_ids)}. "
"Each request should create its own trace."
)
logger.info("✓ Multiple requests created distinct traces")
else:
pytest.skip("Not enough traces to validate")
except Exception as e:
pytest.skip(f"Could not query Jaeger: {e}")
@pytest.mark.asyncio
async def test_trace_contains_rabbitmq_operations(frontend_backend_worker_with_rabbitmq, jaeger_container):
"""
Test that traces include RabbitMQ publish/consume operations.
This is critical for distributed tracing - the RabbitMQ operations
are what link the frontend and backend spans together.
"""
server_address = frontend_backend_worker_with_rabbitmq
jaeger_url = jaeger_container.get_query_url()
# Submit a workflow
from comfy.client.aio_client import AsyncRemoteComfyClient
async with AsyncRemoteComfyClient(server_address=server_address) as client:
prompt = sdxl_workflow_with_refiner("test_rmq", inference_steps=1, refiner_steps=1)
task_id = await client.queue_and_forget_prompt_api(prompt)
status_code, result = await client.poll_prompt_until_done(task_id, max_attempts=60)
assert status_code == 200
await asyncio.sleep(5)
try:
traces_response = query_jaeger_traces(jaeger_url, "comfyui", lookback="5m")
traces = traces_response.get("data", [])
if traces:
# Look for RabbitMQ-related operations in any trace
rabbitmq_operations = [
"publish", "consume", "amq_queue_publish", "amq_queue_consume",
"amq.basic.publish", "amq.basic.consume", "send", "receive"
]
found_rabbitmq_ops = []
for trace in traces:
for span in trace.get("spans", []):
op_name = span.get("operationName", "").lower()
for rmq_op in rabbitmq_operations:
if rmq_op in op_name:
found_rabbitmq_ops.append(op_name)
if found_rabbitmq_ops:
logger.info(f"✓ Found RabbitMQ operations in traces: {set(found_rabbitmq_ops)}")
else:
logger.warning(
"No RabbitMQ operations found in traces. "
"This suggests that either:\n"
"1. AioPikaInstrumentor is not creating spans, or\n"
"2. The spans are being filtered out by the collector, or\n"
"3. The spans exist but use different operation names"
)
# Log all operation names to help debug
all_ops = set()
for trace in traces[:3]: # First 3 traces
for span in trace.get("spans", []):
all_ops.add(span.get("operationName"))
logger.info(f"Sample operation names: {all_ops}")
else:
pytest.skip("No traces found")
except Exception as e:
pytest.skip(f"Could not query Jaeger: {e}")