mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 06:10:50 +08:00
129 lines
5.5 KiB
Python
129 lines
5.5 KiB
Python
import asyncio
|
|
import logging
|
|
|
|
logging.basicConfig(level=logging.ERROR)
|
|
|
|
import uuid
|
|
|
|
import pytest
|
|
from testcontainers.rabbitmq import RabbitMqContainer
|
|
from opentelemetry import trace, propagate
|
|
from opentelemetry.trace import SpanKind
|
|
from opentelemetry.sdk.trace import TracerProvider
|
|
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
|
|
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
|
|
|
|
from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner
|
|
from comfy.component_model.make_mutable import make_mutable
|
|
from comfy.component_model.queue_types import QueueItem, QueueTuple, ExecutionStatus
|
|
from comfy.distributed.server_stub import ServerStub
|
|
|
|
|
|
async def create_test_prompt() -> QueueItem:
    """Build a small, validated SDXL-with-refiner queue item for the tests.

    The workflow is generated with one inference step and one refiner step and
    wrapped with ``make_mutable``. It is validated by ``validate_prompt`` under
    a freshly generated UUID4 prompt id, then packed into a ``QueueItem`` whose
    second constructor argument is ``None``.

    Returns:
        The ``QueueItem`` built from ``(0, prompt_id, workflow, {}, validation[2])``.
    """
    # Deferred import: the module itself never imports comfy.cmd.execution at load time.
    from comfy.cmd.execution import validate_prompt

    workflow = make_mutable(sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1))
    prompt_id = str(uuid.uuid4())

    validation = await validate_prompt(prompt_id, workflow)
    # Element 2 of the validation result is forwarded as the last QueueTuple field.
    return QueueItem((0, prompt_id, workflow, {}, validation[2]), None)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_rabbitmq_message_properties_contain_trace_context():
    """Check that OpenTelemetry trace context can be carried in AMQP headers.

    The test proceeds in three steps:

    1. Start a throwaway RabbitMQ container.
    2. While a PRODUCER span from a private tracer provider is active, run one
       prompt through a caller/callee pair of ``DistributedPromptQueue``
       instances.
    3. Publish a probe message whose headers are filled by
       ``propagate.inject``, read it back, and assert that the W3C
       ``traceparent`` header is present.
    """
    with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
        params = rabbitmq.get_connection_params()
        connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"

        from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
        import aio_pika

        # Private provider + in-memory exporter: spans are recorded locally and
        # the global tracer provider is not touched.
        exporter = InMemorySpanExporter()
        provider = TracerProvider()
        provider.add_span_processor(SimpleSpanProcessor(exporter))
        tracer = trace.get_tracer(__name__, tracer_provider=provider)

        with tracer.start_as_current_span("test_message_headers", kind=SpanKind.PRODUCER):
            async with DistributedPromptQueue(
                ServerStub(), is_callee=False, is_caller=True, connection_uri=connection_uri
            ) as frontend:
                async with DistributedPromptQueue(
                    ServerStub(), is_callee=True, is_caller=False, connection_uri=connection_uri
                ) as worker:
                    queue_item = await create_test_prompt()

                    put_task = asyncio.create_task(frontend.put_async(queue_item))
                    try:
                        incoming, incoming_prompt_id = await worker.get_async(timeout=5.0)
                        assert incoming is not None, "Worker should receive message"

                        worker.task_done(incoming_prompt_id, {}, ExecutionStatus("success", True, []))

                        result = await put_task
                        assert result is not None, "Frontend should get result"
                    finally:
                        # If the worker side failed, don't leave the publisher task
                        # pending (or its exception unretrieved) while the queues shut down.
                        if not put_task.done():
                            put_task.cancel()
                        await asyncio.gather(put_task, return_exceptions=True)

            # Still inside the PRODUCER span: propagate.inject reads the *current*
            # context, so this is what makes a valid traceparent available.
            # NOTE(review): this probe message is published by the test itself, not
            # by DistributedPromptQueue, so it checks header propagation through
            # RabbitMQ rather than the queue's own outgoing message properties.
            connection = await aio_pika.connect_robust(connection_uri)
            async with connection:  # closes the connection even if an assertion fails
                channel = await connection.channel()
                test_queue = await channel.declare_queue("test_inspection_queue", durable=False, auto_delete=True)

                carrier = {}
                propagate.inject(carrier)

                await channel.default_exchange.publish(
                    aio_pika.Message(body=b"test", headers=carrier),
                    routing_key=test_queue.name,
                )

                # With fail=False, get() returns None on an empty queue instead of
                # raising; that case must fail the test, not skip the check.
                received = await test_queue.get(timeout=2, fail=False)
                assert received is not None, "Probe message should be delivered back from RabbitMQ"

                headers = received.headers or {}
                # OpenTelemetry's default propagator writes the W3C 'traceparent' header.
                assert "traceparent" in headers, f"Missing W3C traceparent header; got {list(headers)}"

                await received.ack()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_distributed_queue_uses_async_interface():
    """
    Test that demonstrates the correct way to use DistributedPromptQueue in async context.

    The synchronous get() method cannot be used in async tests due to event loop assertions,
    so the worker side uses ``get_async`` while the frontend's ``put_async`` runs as a
    background task.
    """
    with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
        params = rabbitmq.get_connection_params()
        connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"

        from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue

        async with DistributedPromptQueue(
            ServerStub(), is_callee=False, is_caller=True, connection_uri=connection_uri
        ) as frontend:
            async with DistributedPromptQueue(
                ServerStub(), is_callee=True, is_caller=False, connection_uri=connection_uri
            ) as worker:
                queue_item = await create_test_prompt()

                # Publish in the background; put_async completes once the worker reports back.
                put_task = asyncio.create_task(frontend.put_async(queue_item))
                try:
                    # Worker gets item asynchronously (not using blocking get())
                    incoming, incoming_prompt_id = await worker.get_async(timeout=5.0)
                    assert incoming is not None, "Should receive a queue item"

                    # Complete the work
                    worker.task_done(incoming_prompt_id, {}, ExecutionStatus("success", True, []))

                    # Wait for frontend to complete
                    result = await put_task
                    assert result is not None, "Should get result from worker"
                    assert result.status.status_str == "success", (
                        f"Unexpected status: {result.status.status_str!r}"
                    )
                finally:
                    # On any failure above, cancel and reap the background task so it is
                    # not left pending (or its exception unretrieved) during teardown.
                    if not put_task.done():
                        put_task.cancel()
                    await asyncio.gather(put_task, return_exceptions=True)
|