mirror of https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-09 13:50:49 +08:00

Tracing tests

parent 2f520a4cb4
commit 69d8f1b120

@@ -1,23 +1,23 @@
from asyncio import AbstractEventLoop
from collections import defaultdict

import aiohttp
import asyncio
import uuid
from aiohttp import WSMessage, ClientResponse, ClientTimeout
from pathlib import Path
from asyncio import AbstractEventLoop
from typing import Optional, List
from urllib.parse import urlparse, urljoin

import aiohttp
from aiohttp import WSMessage, ClientResponse, ClientTimeout
from opentelemetry import trace

from .client_types import V1QueuePromptResponse
from ..api.api_client import JSONEncoder
from ..api.components.schema.prompt import PromptDict
from ..api.components.schema.prompt_request import PromptRequest
from ..api.paths.history.get.responses.response_200.content.application_json.schema import Schema as GetHistoryDict
from ..api.schemas import immutabledict
from ..component_model.file_output_path import file_output_path
from ..component_model.outputs_types import OutputsDict

tracer = trace.get_tracer(__name__)


class AsyncRemoteComfyClient:
    """
@@ -57,6 +57,27 @@ class AsyncRemoteComfyClient:
    def session(self) -> aiohttp.ClientSession:
        return self._ensure_session()

    def _build_headers(self, accept_header: str, prefer_header: Optional[str] = None, content_type: str = "application/json") -> dict:
        """Build HTTP headers for requests."""
        headers = {'Content-Type': content_type, 'Accept': accept_header}
        if prefer_header:
            headers['Prefer'] = prefer_header
        return headers

    @tracer.start_as_current_span("Post Prompt")
    async def _post_prompt(self, prompt: PromptDict, endpoint: str, accept_header: str, prefer_header: Optional[str] = None) -> ClientResponse:
        """
        Common method to POST a prompt to a given endpoint.
        :param prompt: The prompt to send
        :param endpoint: The API endpoint (e.g., "/api/v1/prompts")
        :param accept_header: The Accept header value
        :param prefer_header: Optional Prefer header value
        :return: The response object
        """
        prompt_json = AsyncRemoteComfyClient.__json_encoder.encode(prompt)
        headers = self._build_headers(accept_header, prefer_header)
        return await self.session.post(urljoin(self.server_address, endpoint), data=prompt_json, headers=headers)

    async def len_queue(self) -> int:
        async with self.session.get(urljoin(self.server_address, "/prompt"), headers={'Accept': 'application/json'}) as response:
            if response.status == 200:
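A minimal sketch of how a caller delegates to the new _post_prompt helper; the method name here is hypothetical, mirroring the queue_prompt_api methods later in this diff:

    # Hypothetical caller: POST a prompt and decode the JSON body on success.
    async def queue_prompt_example(self, prompt: PromptDict) -> dict:
        async with await self._post_prompt(prompt, "/api/v1/prompts", "application/json") as response:
            if 200 <= response.status < 400:
                return await response.json()
            raise RuntimeError(f"could not prompt: {response.status}: {await response.text()}")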
@@ -73,14 +94,7 @@ class AsyncRemoteComfyClient:
        :param accept_header: The Accept header value (e.g., "application/json", "application/json+respond-async")
        :return: the task ID
        """
        prompt_json = AsyncRemoteComfyClient.__json_encoder.encode(prompt)
        response: ClientResponse
        headers = {'Content-Type': 'application/json', 'Accept': accept_header}
        if prefer_header:
            headers['Prefer'] = prefer_header
        async with self.session.post(urljoin(self.server_address, "/api/v1/prompts"), data=prompt_json,
                                     headers=headers) as response:

        async with await self._post_prompt(prompt, "/api/v1/prompts", accept_header, prefer_header) as response:
            if 200 <= response.status < 400:
                response_json = await response.json()
                return response_json["prompt_id"]
@@ -95,14 +109,7 @@ class AsyncRemoteComfyClient:
        :param accept_header: The Accept header value (e.g., "application/json", "application/json+respond-async")
        :return: the API response from the server containing URLs and the outputs for the UI (nodes with OUTPUT_NODE == true)
        """
        prompt_json = AsyncRemoteComfyClient.__json_encoder.encode(prompt)
        response: ClientResponse
        headers = {'Content-Type': 'application/json', 'Accept': accept_header}
        if prefer_header:
            headers['Prefer'] = prefer_header
        async with self.session.post(urljoin(self.server_address, "/api/v1/prompts"), data=prompt_json,
                                     headers=headers) as response:

        async with await self._post_prompt(prompt, "/api/v1/prompts", accept_header, prefer_header) as response:
            if 200 <= response.status < 400:
                return V1QueuePromptResponse(**(await response.json()))
            else:
@@ -122,17 +129,13 @@ class AsyncRemoteComfyClient:
        :param prompt:
        :return:
        """
        prompt_json = AsyncRemoteComfyClient.__json_encoder.encode(prompt)
        response: ClientResponse
        headers = {'Content-Type': 'application/json', 'Accept': 'image/png'}
        async with self.session.post(urljoin(self.server_address, "/api/v1/prompts"), data=prompt_json,
                                     headers=headers) as response:

        async with await self._post_prompt(prompt, "/api/v1/prompts", "image/png") as response:
            if 200 <= response.status < 400:
                return await response.read()
            else:
                raise RuntimeError(f"could not prompt: {response.status}: {await response.text()}")

    @tracer.start_as_current_span("Post Prompt (UI)")
    async def queue_prompt_ui(self, prompt: PromptDict) -> OutputsDict:
        """
        Uses the comfyui UI API calls to retrieve the outputs dictionary
@@ -179,6 +182,7 @@ class AsyncRemoteComfyClient:
        """
        return await self.session.get(urljoin(self.server_address, f"/api/v1/prompts/{prompt_id}"))

    @tracer.start_as_current_span("Poll Prompt Until Done")
    async def poll_prompt_until_done(self, prompt_id: str, max_attempts: int = 60, poll_interval: float = 1.0) -> tuple[int, dict | None]:
        """
        Poll a prompt until it's done (200), errors (500), or times out.
@@ -187,6 +191,10 @@ class AsyncRemoteComfyClient:
        :param poll_interval: Time to wait between polls in seconds
        :return: Tuple of (status_code, response_json or None)
        """
        span = trace.get_current_span()
        span.set_attribute("prompt_id", prompt_id)
        span.set_attribute("max_attempts", max_attempts)

        for _ in range(max_attempts):
            async with await self.get_prompt_status(prompt_id) as response:
                if response.status == 200:
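A usage sketch for the polling helper, assuming client is a connected AsyncRemoteComfyClient and task_id came from queue_and_forget_prompt_api:

    # Hypothetical call site: wait up to a minute for the queued prompt to finish.
    status_code, body = await client.poll_prompt_until_done(task_id, max_attempts=60, poll_interval=1.0)
    if status_code == 200:
        print("prompt finished:", body)
    else:
        print("prompt failed or timed out:", status_code)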
@@ -1,8 +1,6 @@
import copy
from typing import TypeAlias, Union

from ..api.components.schema.prompt import PromptDict, Prompt

JSON: TypeAlias = Union[dict[str, "JSON"], list["JSON"], str, int, float, bool, None]
_BASE_PROMPT: JSON = {
    "4": {
@@ -15,6 +15,7 @@ import shutil
import warnings

import fsspec
from opentelemetry.instrumentation.urllib3 import URLLib3Instrumentor

from .. import options
from ..app import logger
@@ -125,9 +126,11 @@ def _create_tracer():
    from opentelemetry.sdk.resources import Resource
    from opentelemetry.sdk.trace import TracerProvider
    from opentelemetry.sdk.trace.export import BatchSpanProcessor, SpanExporter
    from opentelemetry.processor.baggage import BaggageSpanProcessor, ALLOW_ALL_BAGGAGE_KEYS
    from opentelemetry.instrumentation.aiohttp_server import AioHttpServerInstrumentor
    from opentelemetry.instrumentation.aiohttp_client import AioHttpClientInstrumentor
    from ..tracing_compatibility import ProgressSpanSampler
    from ..tracing_compatibility import patch_spanbuilder_set_channel
    from ..vendor.aiohttp_server_instrumentation import AioHttpServerInstrumentor

    resource = Resource.create({
        service_attributes.SERVICE_NAME: args.otel_service_name,
@@ -141,18 +144,24 @@ def _create_tracer():
    has_endpoint = args.otel_exporter_otlp_endpoint is not None

    if has_endpoint:
        otlp_exporter = OTLPSpanExporter()
        exporter = OTLPSpanExporter()
    else:
        otlp_exporter = SpanExporter()
        exporter = SpanExporter()

    processor = BatchSpanProcessor(otlp_exporter)
    processor = BatchSpanProcessor(exporter)
    provider.add_span_processor(processor)

    # enable instrumentation
    patch_spanbuilder_set_channel()

    AioPikaInstrumentor().instrument()
    AioHttpServerInstrumentor().instrument()
    AioHttpClientInstrumentor().instrument()
    RequestsInstrumentor().instrument()
    URLLib3Instrumentor().instrument()


    provider.add_span_processor(BaggageSpanProcessor(ALLOW_ALL_BAGGAGE_KEYS))
    # makes this behave better as a library
    return trace.get_tracer(args.otel_service_name, tracer_provider=provider)
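For local debugging, a console exporter can stand in for the no-op fallback; a sketch against the same provider, not part of this commit:

    # Hypothetical debug variant: print finished spans to stdout instead of dropping them.
    from opentelemetry.sdk.trace.export import ConsoleSpanExporter, SimpleSpanProcessor
    provider.add_span_processor(SimpleSpanProcessor(ConsoleSpanExporter()))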
@@ -3,7 +3,7 @@ from typing import Optional, Sequence
from aio_pika.abc import AbstractChannel
from opentelemetry.context import Context
from opentelemetry.sdk.trace.sampling import Sampler, SamplingResult, Decision
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.semconv.attributes.network_attributes import NETWORK_PEER_ADDRESS, NETWORK_PEER_PORT
from opentelemetry.trace import SpanKind, Link, TraceState
from opentelemetry.util.types import Attributes

@@ -22,8 +22,8 @@ def patch_spanbuilder_set_channel() -> None:
        port = url.port or 5672
        self._attributes.update(
            {
                SpanAttributes.NET_PEER_NAME: url.host,
                SpanAttributes.NET_PEER_PORT: port,
                NETWORK_PEER_ADDRESS: url.host,
                NETWORK_PEER_PORT: port,
            }
        )
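The attribute swap follows the OpenTelemetry semantic-conventions migration away from the deprecated net.peer.* keys; the new constants are plain strings:

    # The imported constants resolve to the stable semantic-convention keys.
    from opentelemetry.semconv.attributes.network_attributes import NETWORK_PEER_ADDRESS, NETWORK_PEER_PORT
    assert NETWORK_PEER_ADDRESS == "network.peer.address"
    assert NETWORK_PEER_PORT == "network.peer.port"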
@@ -90,6 +90,14 @@ def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory, num_workers
    params = rabbitmq.get_connection_params()
    connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"

    # Check if OTEL endpoint is configured for integration testing
    otel_endpoint = os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT")

    env = os.environ.copy()
    if otel_endpoint:
        env["OTEL_EXPORTER_OTLP_ENDPOINT"] = otel_endpoint
        logging.info(f"Configuring services to export traces to: {otel_endpoint}")

    frontend_command = [
        "comfyui",
        "--listen=127.0.0.1",
@@ -100,7 +108,7 @@ def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory, num_workers
        f"--distributed-queue-connection-uri={connection_uri}",
    ]

    processes_to_close.append(subprocess.Popen(frontend_command, stdout=sys.stdout, stderr=sys.stderr))
    processes_to_close.append(subprocess.Popen(frontend_command, stdout=sys.stdout, stderr=sys.stderr, env=env))

    # Start multiple workers
    for i in range(num_workers):
@@ -111,7 +119,7 @@ def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory, num_workers
            f"--distributed-queue-connection-uri={connection_uri}",
            f"--executor-factory={executor_factory}"
        ]
        processes_to_close.append(subprocess.Popen(backend_command, stdout=sys.stdout, stderr=sys.stderr))
        processes_to_close.append(subprocess.Popen(backend_command, stdout=sys.stdout, stderr=sys.stderr, env=env))

    try:
        server_address = f"http://127.0.0.1:19001"
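The env passthrough works because the OTLP exporter resolves its target from the environment when no endpoint is passed explicitly; a sketch of what the child services effectively do on startup:

    # Hypothetical startup path inside a child process: with
    # OTEL_EXPORTER_OTLP_ENDPOINT present in env, no explicit endpoint is needed.
    from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
    exporter = OTLPSpanExporter()  # reads OTEL_EXPORTER_OTLP_ENDPOINT from the environment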
@@ -0,0 +1,128 @@
import asyncio
import logging

logging.basicConfig(level=logging.ERROR)

import uuid

import pytest
from testcontainers.rabbitmq import RabbitMqContainer
from opentelemetry import trace, propagate, context
from opentelemetry.trace import SpanKind
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter

from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner
from comfy.component_model.make_mutable import make_mutable
from comfy.component_model.queue_types import QueueItem, QueueTuple, ExecutionStatus
from comfy.distributed.server_stub import ServerStub


async def create_test_prompt() -> QueueItem:
    from comfy.cmd.execution import validate_prompt

    prompt = make_mutable(sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1))
    item_id = str(uuid.uuid4())

    validation_tuple = await validate_prompt(item_id, prompt)
    queue_tuple: QueueTuple = (0, item_id, prompt, {}, validation_tuple[2])
    return QueueItem(queue_tuple, None)


@pytest.mark.asyncio
async def test_rabbitmq_message_properties_contain_trace_context():
    with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
        params = rabbitmq.get_connection_params()
        connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"

        from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
        import aio_pika

        exporter = InMemorySpanExporter()
        provider = TracerProvider()
        provider.add_span_processor(SimpleSpanProcessor(exporter))
        tracer = trace.get_tracer(__name__, tracer_provider=provider)

        with tracer.start_as_current_span("test_message_headers", kind=SpanKind.PRODUCER):
            async with DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=connection_uri) as frontend:
                async with DistributedPromptQueue(ServerStub(), is_callee=True, is_caller=False, connection_uri=connection_uri) as worker:
                    queue_item = await create_test_prompt()

                    put_task = asyncio.create_task(frontend.put_async(queue_item))

                    incoming, incoming_prompt_id = await worker.get_async(timeout=5.0)
                    assert incoming is not None, "Worker should receive message"

                    worker.task_done(incoming_prompt_id, {}, ExecutionStatus("success", True, []))

                    result = await put_task
                    assert result is not None, "Frontend should get result"

                    # Now inspect the RabbitMQ queue directly to see message structure
                    connection = await aio_pika.connect_robust(connection_uri)
                    channel = await connection.channel()

                    # Declare a test queue to inspect message format
                    test_queue = await channel.declare_queue("test_inspection_queue", durable=False, auto_delete=True)

                    # Publish a test message with trace context
                    carrier = {}
                    propagate.inject(carrier)

                    test_message = aio_pika.Message(
                        body=b"test",
                        headers=carrier
                    )

                    await channel.default_exchange.publish(
                        test_message,
                        routing_key=test_queue.name
                    )

                    # Get and inspect the message
                    received = await test_queue.get(timeout=2, fail=False)
                    if received:
                        headers = received.headers or {}

                        # Document what trace headers should be present
                        # OpenTelemetry uses 'traceparent' header for W3C Trace Context
                        has_traceparent = "traceparent" in headers

                        assert has_traceparent

                        await received.ack()

                    await connection.close()


@pytest.mark.asyncio
async def test_distributed_queue_uses_async_interface():
    """
    Test that demonstrates the correct way to use DistributedPromptQueue in async context.
    The synchronous get() method cannot be used in async tests due to event loop assertions.
    """
    with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
        params = rabbitmq.get_connection_params()
        connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"

        from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue

        async with DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=connection_uri) as frontend:
            async with DistributedPromptQueue(ServerStub(), is_callee=True, is_caller=False, connection_uri=connection_uri) as worker:
                queue_item = await create_test_prompt()

                # Start consuming in background
                result_future = asyncio.create_task(frontend.put_async(queue_item))

                # Worker gets item asynchronously (not using blocking get())
                incoming, incoming_prompt_id = await worker.get_async(timeout=5.0)
                assert incoming is not None, "Should receive a queue item"

                # Complete the work
                worker.task_done(incoming_prompt_id, {}, ExecutionStatus("success", True, []))

                # Wait for frontend to complete
                result = await result_future
                assert result is not None, "Should get result from worker"
                assert result.status.status_str == "success"
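A minimal sketch of the W3C propagation round trip the first test exercises, using only the OpenTelemetry API:

    # The producer injects the current span context; a consumer extracts it
    # and starts spans that join the same trace.
    from opentelemetry import propagate, trace

    carrier = {}
    propagate.inject(carrier)  # writes e.g. carrier["traceparent"]
    remote_ctx = propagate.extract(carrier)
    with trace.get_tracer(__name__).start_as_current_span("consume", context=remote_ctx):
        pass  # this span is a child in the producer's trace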
497 tests/distributed/test_tracing_integration.py Normal file
@@ -0,0 +1,497 @@
"""
Integration tests for distributed tracing across RabbitMQ and services.

These tests validate that trace context propagates correctly from frontend
to backend workers through RabbitMQ, and that Jaeger can reconstruct the
full distributed trace.
"""
import asyncio
import logging
import os
import time
import uuid

import pytest
import requests
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.instrumentation.aiohttp_client import AioHttpClientInstrumentor
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.semconv.attributes import service_attributes
from testcontainers.core.container import DockerContainer
from testcontainers.core.waiting_utils import wait_for_logs

from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class JaegerContainer(DockerContainer):
    """Testcontainer for Jaeger all-in-one with OTLP support."""

    def __init__(self, image: str = "jaegertracing/all-in-one:latest"):
        super().__init__(image)
        self.with_exposed_ports(16686, 4318, 14268)  # UI, OTLP HTTP, Jaeger HTTP
        self.with_env("COLLECTOR_OTLP_ENABLED", "true")

    def get_query_url(self) -> str:
        """Get Jaeger Query API URL."""
        host = self.get_container_host_ip()
        port = self.get_exposed_port(16686)
        return f"http://{host}:{port}"

    def get_otlp_endpoint(self) -> str:
        """Get OTLP HTTP endpoint for sending traces."""
        host = self.get_container_host_ip()
        port = self.get_exposed_port(4318)
        return f"http://{host}:{port}"

    def start(self):
        super().start()
        wait_for_logs(self, ".*Starting GRPC server.*", timeout=30)
        return self
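A usage sketch for the helper class above; DockerContainer supports the context-manager protocol, so start and stop can be implicit:

    # Hypothetical standalone use of JaegerContainer.
    with JaegerContainer() as jaeger:
        print(jaeger.get_query_url())      # Jaeger UI / query API
        print(jaeger.get_otlp_endpoint())  # OTLP/HTTP ingest for spans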
@pytest.fixture(scope="module")
def jaeger_container():
    """
    Provide a Jaeger container for collecting traces.

    This fixture automatically sets OTEL_EXPORTER_OTLP_ENDPOINT to point to the
    Jaeger container, and cleans it up when the container stops.
    """
    container = JaegerContainer()
    container.start()

    # Wait for Jaeger to be fully ready
    query_url = container.get_query_url()
    otlp_endpoint = container.get_otlp_endpoint()

    for _ in range(30):
        try:
            response = requests.get(f"{query_url}/api/services")
            if response.status_code == 200:
                logger.info(f"Jaeger ready at {query_url}")
                logger.info(f"OTLP endpoint: {otlp_endpoint}")
                break
        except Exception:
            pass
        time.sleep(1)

    # Set OTEL_EXPORTER_OTLP_ENDPOINT for the duration of the test
    old_endpoint = os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT")
    os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"] = otlp_endpoint
    logger.info(f"Set OTEL_EXPORTER_OTLP_ENDPOINT={otlp_endpoint}")

    try:
        yield container
    finally:
        # Restore original OTEL_EXPORTER_OTLP_ENDPOINT
        if old_endpoint is not None:
            os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"] = old_endpoint
            logger.info(f"Restored OTEL_EXPORTER_OTLP_ENDPOINT={old_endpoint}")
        else:
            os.environ.pop("OTEL_EXPORTER_OTLP_ENDPOINT", None)
            logger.info("Removed OTEL_EXPORTER_OTLP_ENDPOINT")

        container.stop()


def query_jaeger_traces(jaeger_url: str, service: str, operation: str = None,
                        lookback: str = "1h", limit: int = 100) -> dict:
    """
    Query Jaeger for traces.

    Args:
        jaeger_url: Base URL of Jaeger query service
        service: Service name to query
        operation: Optional operation name filter
        lookback: Lookback period (e.g., "1h", "30m")
        limit: Maximum number of traces to return

    Returns:
        JSON response from Jaeger API
    """
    params = {
        "service": service,
        "lookback": lookback,
        "limit": limit
    }
    if operation:
        params["operation"] = operation

    response = requests.get(f"{jaeger_url}/api/traces", params=params)
    response.raise_for_status()
    return response.json()
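A usage sketch for the query helper, assuming a Jaeger instance reachable on a local port:

    # Hypothetical query: recent traces for the frontend service.
    recent = query_jaeger_traces("http://localhost:16686", "comfyui", lookback="15m", limit=20)
    print(len(recent.get("data", [])), "traces found")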
def find_trace_by_operation(traces_response: dict, operation_name: str) -> dict:
    """Find a specific trace by operation name."""
    for trace in traces_response.get("data", []):
        for span in trace.get("spans", []):
            if span.get("operationName") == operation_name:
                return trace
    return None


def verify_trace_continuity(trace: dict, expected_services: list[str]) -> bool:
    """
    Verify that a trace spans multiple services and maintains parent-child relationships.

    Args:
        trace: Jaeger trace object
        expected_services: List of service names expected in the trace

    Returns:
        True if trace shows proper distributed tracing across services
    """
    if not trace:
        return False

    spans = trace.get("spans", [])
    if not spans:
        return False

    # Check that all expected services are present
    trace_services = set()
    for span in spans:
        process_id = span.get("processID")
        if process_id:
            process = trace.get("processes", {}).get(process_id, {})
            service_name = process.get("serviceName")
            if service_name:
                trace_services.add(service_name)

    logger.info(f"Trace contains services: {trace_services}")
    logger.info(f"Expected services: {set(expected_services)}")

    # Verify all expected services are present
    for service in expected_services:
        if service not in trace_services:
            logger.warning(f"Expected service '{service}' not found in trace")
            return False

    # Verify all spans share the same trace ID
    trace_ids = set(span.get("traceID") for span in spans)
    if len(trace_ids) != 1:
        logger.warning(f"Multiple trace IDs found: {trace_ids}")
        return False

    # Verify parent-child relationships exist
    span_ids = {span.get("spanID") for span in spans}
    has_parent_refs = False

    for span in spans:
        references = span.get("references", [])
        for ref in references:
            if ref.get("refType") == "CHILD_OF":
                parent_span_id = ref.get("spanID")
                if parent_span_id in span_ids:
                    has_parent_refs = True
                    logger.info(f"Found parent-child relationship: {parent_span_id} -> {span.get('spanID')}")

    if not has_parent_refs:
        logger.warning("No parent-child relationships found in trace")
        return False

    return True
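The two helpers compose naturally; a hypothetical check (the operation name and service names here are illustrative):

    # Hypothetical composition of the helpers defined above.
    traces = query_jaeger_traces("http://localhost:16686", "comfyui", lookback="15m")
    candidate = find_trace_by_operation(traces, "submit_workflow")
    assert verify_trace_continuity(candidate, ["comfyui", "comfyui-client"])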
# order matters, execute jaeger_container first
@pytest.mark.asyncio
async def test_tracing_integration(jaeger_container, frontend_backend_worker_with_rabbitmq):
    """
    Integration test for distributed tracing across services.

    This test:
    1. Starts ComfyUI frontend and worker with RabbitMQ
    2. Configures OTLP export to Jaeger testcontainer
    3. Submits a workflow through the frontend
    4. Queries Jaeger to verify trace propagation
    5. Validates that the trace spans multiple services with proper relationships

    Note: The frontend_backend_worker_with_rabbitmq fixture is parameterized,
    so this test will run with both ThreadPoolExecutor and ProcessPoolExecutor.
    """
    server_address = frontend_backend_worker_with_rabbitmq
    jaeger_url = jaeger_container.get_query_url()
    otlp_endpoint = jaeger_container.get_otlp_endpoint()

    logger.info(f"Frontend server: {server_address}")
    logger.info(f"Jaeger UI: {jaeger_url}")
    logger.info(f"OTLP endpoint: {otlp_endpoint}")

    # Set up tracing for the async HTTP client
    resource = Resource.create({
        service_attributes.SERVICE_NAME: "comfyui-client",
    })
    provider = TracerProvider(resource=resource)
    exporter = OTLPSpanExporter(endpoint=f"{otlp_endpoint}/v1/traces")
    processor = BatchSpanProcessor(exporter)
    provider.add_span_processor(processor)
    from opentelemetry import trace

    trace.set_tracer_provider(provider)

    # Instrument aiohttp client
    AioHttpClientInstrumentor().instrument()

    # we have to call this very late, so that the instrumentation isn't initialized too early
    from comfy.client.aio_client import AsyncRemoteComfyClient

    # Note: In a real integration test, you'd need to configure the ComfyUI
    # services to export traces to this Jaeger instance. For now, this test
    # documents the expected behavior.

    # Create a unique prompt to identify our trace
    test_id = str(uuid.uuid4())[:8]
    prompt = sdxl_workflow_with_refiner(f"test_trace_{test_id}", inference_steps=1, refiner_steps=1)

    # Get the tracer for the client
    client_tracer = trace.get_tracer("test_tracing_integration")

    # Submit the workflow - wrap in a span to capture the trace ID
    with client_tracer.start_as_current_span("submit_workflow") as workflow_span:
        trace_id = format(workflow_span.get_span_context().trace_id, '032x')
        logger.info(f"Started trace with trace_id: {trace_id}")

        async with AsyncRemoteComfyClient(server_address=server_address) as client:
            logger.info(f"Submitting workflow with test_id: {test_id}")

            # Queue the prompt with async response
            task_id = await client.queue_and_forget_prompt_api(prompt, prefer_header="respond-async")
            assert task_id is not None, "Failed to get task ID"

            logger.info(f"Queued task: {task_id}")

            # Poll for completion
            status_code, result = await client.poll_prompt_until_done(task_id, max_attempts=60, poll_interval=1.0)
            assert status_code == 200, f"Task failed with status {status_code}"
            logger.info("Task completed successfully")

    # Give Jaeger time to receive and process spans
    await asyncio.sleep(5)

    # Query Jaeger for traces
    # Note: The actual service names depend on how your services are configured
    # Common service names might be: "slack-bot", "comfyui-frontend", "comfyui-worker"

    expected_services = ["comfyui", "comfyui-client"]  # Adjust based on actual service names

    logger.info(f"Querying Jaeger for traces with trace_id: {trace_id}...")

    # First, try to find our specific trace by trace_id from the client service
    our_trace = None
    for service in expected_services:
        try:
            traces_response = query_jaeger_traces(jaeger_url, service, lookback="5m")
            if traces_response.get("data"):
                logger.info(f"Found {len(traces_response['data'])} traces for service '{service}'")
                for trace in traces_response["data"]:
                    if trace.get("traceID") == trace_id:
                        our_trace = trace
                        logger.info(f"Found our trace in service '{service}'")
                        break
            if our_trace:
                break
        except Exception as e:
            logger.warning(f"Could not query traces for service '{service}': {e}")

    # Assert we can find the trace we just created
    assert our_trace is not None, (
        f"Could not find trace with trace_id {trace_id} in Jaeger. "
        f"This indicates that spans from comfyui-client are not being exported correctly."
    )

    logger.info(f"Successfully found trace with trace_id {trace_id}")

    # Extract services from the trace
    trace_services = set()
    for span in our_trace.get("spans", []):
        process_id = span.get("processID")
        if process_id:
            process = our_trace.get("processes", {}).get(process_id, {})
            service_name = process.get("serviceName")
            if service_name:
                trace_services.add(service_name)

    logger.info(f"Services found in trace: {trace_services}")

    # Assert that comfyui-client service is present (since we instrumented it)
    assert "comfyui-client" in trace_services, (
        f"Expected 'comfyui-client' service in trace, but found only: {trace_services}. "
        f"This indicates the client instrumentation is not working."
    )

    # Validate trace structure
    logger.info(f"Analyzing trace with {len(our_trace.get('spans', []))} spans")

    # Log all spans for debugging
    for span in our_trace.get("spans", []):
        process_id = span.get("processID")
        process = our_trace.get("processes", {}).get(process_id, {})
        service_name = process.get("serviceName", "unknown")
        operation = span.get("operationName", "unknown")
        logger.info(f"  Span: {service_name}.{operation}")

    # Verify trace continuity - only if both services are present
    assert "comfyui" in trace_services
    is_continuous = verify_trace_continuity(our_trace, expected_services)

    # This assertion documents what SHOULD happen when distributed tracing works
    assert is_continuous, (
        "Trace does not show proper distributed tracing. "
        "Expected to see spans from multiple services with parent-child relationships. "
        "This indicates that trace context is not being propagated correctly through RabbitMQ."
    )


@pytest.mark.asyncio
async def test_trace_context_in_http_headers(frontend_backend_worker_with_rabbitmq):
    """
    Test that HTTP requests include traceparent headers.

    This validates that the HTTP layer is properly instrumented for tracing.
    """
    server_address = frontend_backend_worker_with_rabbitmq

    # Make a simple HTTP request and check for trace headers
    # Note: We're checking the server's response headers to see if it's trace-aware
    response = requests.get(f"{server_address}/system_stats")

    logger.info(f"Response headers: {dict(response.headers)}")

    # The server should be instrumented and may include trace context in responses
    # or at minimum, should accept traceparent headers in requests

    # Test sending a traceparent header
    test_traceparent = "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01"
    response_with_trace = requests.get(
        f"{server_address}/system_stats",
        headers={"traceparent": test_traceparent}
    )

    # Should not error when traceparent is provided
    assert response_with_trace.status_code == 200, "Server should accept traceparent header"

    logger.info("✓ Server accepts traceparent headers in HTTP requests")
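For reference, the hard-coded traceparent in the test above decodes per the W3C Trace Context format:

    # 00                               - version
    # 4bf92f3577b34da6a3ce929d0e0e4736 - 16-byte trace-id (hex)
    # 00f067aa0ba902b7                 - 8-byte parent span-id (hex)
    # 01                               - trace-flags (sampled)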
@pytest.mark.asyncio
async def test_multiple_requests_different_traces(frontend_backend_worker_with_rabbitmq, jaeger_container):
    """
    Test that multiple independent requests create separate traces.

    This validates that trace context is properly scoped per request.
    """
    server_address = frontend_backend_worker_with_rabbitmq

    # Submit multiple workflows
    task_ids = []

    from comfy.client.aio_client import AsyncRemoteComfyClient
    async with AsyncRemoteComfyClient(server_address=server_address) as client:
        for i in range(3):
            prompt = sdxl_workflow_with_refiner(f"test_{i}", inference_steps=1, refiner_steps=1)
            task_id = await client.queue_and_forget_prompt_api(prompt, prefer_header="respond-async")
            task_ids.append(task_id)
            logger.info(f"Queued task {i}: {task_id}")

        # Wait for all to complete
        for i, task_id in enumerate(task_ids):
            status_code, result = await client.poll_prompt_until_done(task_id, max_attempts=60, poll_interval=1.0)
            assert status_code == 200, f"Task {i} failed"
            logger.info(f"Task {i} completed")

    # Give Jaeger time to receive spans
    await asyncio.sleep(5)

    # Query Jaeger and verify we have multiple distinct traces
    jaeger_url = jaeger_container.get_query_url()

    try:
        traces_response = query_jaeger_traces(jaeger_url, "comfyui", lookback="5m", limit=10)
        traces = traces_response.get("data", [])

        if len(traces) >= 2:
            # Get trace IDs
            trace_ids = [trace.get("traceID") for trace in traces]
            unique_trace_ids = set(trace_ids)

            logger.info(f"Found {len(unique_trace_ids)} unique traces")

            # Verify we have multiple distinct traces
            assert len(unique_trace_ids) >= 2, (
                f"Expected at least 2 distinct traces, found {len(unique_trace_ids)}. "
                "Each request should create its own trace."
            )

            logger.info("✓ Multiple requests created distinct traces")
        else:
            pytest.skip("Not enough traces to validate")
    except Exception as e:
        pytest.skip(f"Could not query Jaeger: {e}")


@pytest.mark.asyncio
async def test_trace_contains_rabbitmq_operations(frontend_backend_worker_with_rabbitmq, jaeger_container):
    """
    Test that traces include RabbitMQ publish/consume operations.

    This is critical for distributed tracing - the RabbitMQ operations
    are what link the frontend and backend spans together.
    """
    server_address = frontend_backend_worker_with_rabbitmq
    jaeger_url = jaeger_container.get_query_url()

    # Submit a workflow
    from comfy.client.aio_client import AsyncRemoteComfyClient
    async with AsyncRemoteComfyClient(server_address=server_address) as client:
        prompt = sdxl_workflow_with_refiner("test_rmq", inference_steps=1, refiner_steps=1)
        task_id = await client.queue_and_forget_prompt_api(prompt)
        status_code, result = await client.poll_prompt_until_done(task_id, max_attempts=60)
        assert status_code == 200

    await asyncio.sleep(5)

    try:
        traces_response = query_jaeger_traces(jaeger_url, "comfyui", lookback="5m")
        traces = traces_response.get("data", [])

        if traces:
            # Look for RabbitMQ-related operations in any trace
            rabbitmq_operations = [
                "publish", "consume", "amq_queue_publish", "amq_queue_consume",
                "amq.basic.publish", "amq.basic.consume", "send", "receive"
            ]

            found_rabbitmq_ops = []
            for trace in traces:
                for span in trace.get("spans", []):
                    op_name = span.get("operationName", "").lower()
                    for rmq_op in rabbitmq_operations:
                        if rmq_op in op_name:
                            found_rabbitmq_ops.append(op_name)

            if found_rabbitmq_ops:
                logger.info(f"✓ Found RabbitMQ operations in traces: {set(found_rabbitmq_ops)}")
            else:
                logger.warning(
                    "No RabbitMQ operations found in traces. "
                    "This suggests that either:\n"
                    "1. AioPikaInstrumentor is not creating spans, or\n"
                    "2. The spans are being filtered out by the collector, or\n"
                    "3. The spans exist but use different operation names"
                )

                # Log all operation names to help debug
                all_ops = set()
                for trace in traces[:3]:  # First 3 traces
                    for span in trace.get("spans", []):
                        all_ops.add(span.get("operationName"))
                logger.info(f"Sample operation names: {all_ops}")
        else:
            pytest.skip("No traces found")
    except Exception as e:
        pytest.skip(f"Could not query Jaeger: {e}")