ComfyUI/tests/distributed/test_tracing.py

import asyncio
import logging

logging.basicConfig(level=logging.ERROR)

import uuid

import pytest
from testcontainers.rabbitmq import RabbitMqContainer
from opentelemetry import trace, propagate
from opentelemetry.trace import SpanKind
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter

from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner
from comfy.component_model.make_mutable import make_mutable
from comfy.component_model.queue_types import QueueItem, QueueTuple, ExecutionStatus
from comfy.distributed.server_stub import ServerStub
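
# NOTE: these tests start a rabbitmq:latest container via testcontainers,
# so a running Docker daemon must be available to execute them.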


async def create_test_prompt() -> QueueItem:
    from comfy.cmd.execution import validate_prompt
    prompt = make_mutable(sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1))
    item_id = str(uuid.uuid4())
    validation_tuple = await validate_prompt(item_id, prompt)
    queue_tuple: QueueTuple = (0, item_id, prompt, {}, validation_tuple[2])
    return QueueItem(queue_tuple, None)


@pytest.mark.asyncio
async def test_rabbitmq_message_properties_contain_trace_context():
    with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
        params = rabbitmq.get_connection_params()
        connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"

        from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
        import aio_pika

        exporter = InMemorySpanExporter()
        provider = TracerProvider()
        provider.add_span_processor(SimpleSpanProcessor(exporter))
        tracer = trace.get_tracer(__name__, tracer_provider=provider)
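        # SpanKind.PRODUCER marks this span as the sending side of a messaging
        # operation, so a consumer-side CONSUMER span can be correlated with it.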
        with tracer.start_as_current_span("test_message_headers", kind=SpanKind.PRODUCER):
            async with DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=connection_uri) as frontend:
                async with DistributedPromptQueue(ServerStub(), is_callee=True, is_caller=False, connection_uri=connection_uri) as worker:
                    queue_item = await create_test_prompt()
                    put_task = asyncio.create_task(frontend.put_async(queue_item))

                    incoming, incoming_prompt_id = await worker.get_async(timeout=5.0)
                    assert incoming is not None, "Worker should receive message"

                    worker.task_done(incoming_prompt_id, {}, ExecutionStatus("success", True, []))
                    result = await put_task
                    assert result is not None, "Frontend should get result"

                    # Now inspect the RabbitMQ queue directly to see message structure
                    connection = await aio_pika.connect_robust(connection_uri)
                    channel = await connection.channel()

                    # Declare a test queue to inspect message format
                    test_queue = await channel.declare_queue("test_inspection_queue", durable=False, auto_delete=True)

                    # Publish a test message with trace context
                    carrier = {}
                    propagate.inject(carrier)
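                    # After inject(), the carrier holds the active span's W3C Trace Context,
                    # e.g. {"traceparent": "00-<32-hex trace id>-<16-hex span id>-01"}
                    # (illustrative value; a "tracestate" entry may also be present).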
                    test_message = aio_pika.Message(
                        body=b"test",
                        headers=carrier
                    )
                    await channel.default_exchange.publish(
                        test_message,
                        routing_key=test_queue.name
                    )

                    # Get and inspect the message
                    received = await test_queue.get(timeout=2, fail=False)
                    if received:
                        headers = received.headers or {}
                        # Document what trace headers should be present
                        # OpenTelemetry uses 'traceparent' header for W3C Trace Context
                        has_traceparent = "traceparent" in headers
                        assert has_traceparent
                        await received.ack()
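                        # Illustrative sketch (not part of DistributedPromptQueue itself):
                        # a consumer that wants to continue this trace could extract the
                        # context from the headers and open a CONSUMER span under it.
                        extracted_context = propagate.extract(headers)
                        with tracer.start_as_current_span("inspect_message", context=extracted_context, kind=SpanKind.CONSUMER):
                            pass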
                    await connection.close()


@pytest.mark.asyncio
async def test_distributed_queue_uses_async_interface():
    """
    Test that demonstrates the correct way to use DistributedPromptQueue in an async context.
    The synchronous get() method cannot be used in async tests due to event loop assertions.
    """
    with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
        params = rabbitmq.get_connection_params()
        connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"

        from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue

        async with DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=connection_uri) as frontend:
            async with DistributedPromptQueue(ServerStub(), is_callee=True, is_caller=False, connection_uri=connection_uri) as worker:
                queue_item = await create_test_prompt()

                # Submit the prompt in the background; put_async resolves once the worker reports completion
                result_future = asyncio.create_task(frontend.put_async(queue_item))

                # Worker gets the item asynchronously (not using the blocking get())
                incoming, incoming_prompt_id = await worker.get_async(timeout=5.0)
                assert incoming is not None, "Should receive a queue item"

                # Complete the work
                worker.task_done(incoming_prompt_id, {}, ExecutionStatus("success", True, []))

                # Wait for the frontend to complete
                result = await result_future
                assert result is not None, "Should get result from worker"
                assert result.status.status_str == "success"