diff --git a/comfy/client/aio_client.py b/comfy/client/aio_client.py index f1e35e8ef..f79f44ac1 100644 --- a/comfy/client/aio_client.py +++ b/comfy/client/aio_client.py @@ -1,23 +1,23 @@ -from asyncio import AbstractEventLoop -from collections import defaultdict - -import aiohttp import asyncio import uuid -from aiohttp import WSMessage, ClientResponse, ClientTimeout -from pathlib import Path +from asyncio import AbstractEventLoop from typing import Optional, List from urllib.parse import urlparse, urljoin +import aiohttp +from aiohttp import WSMessage, ClientResponse, ClientTimeout +from opentelemetry import trace + from .client_types import V1QueuePromptResponse from ..api.api_client import JSONEncoder from ..api.components.schema.prompt import PromptDict from ..api.components.schema.prompt_request import PromptRequest from ..api.paths.history.get.responses.response_200.content.application_json.schema import Schema as GetHistoryDict from ..api.schemas import immutabledict -from ..component_model.file_output_path import file_output_path from ..component_model.outputs_types import OutputsDict +tracer = trace.get_tracer(__name__) + class AsyncRemoteComfyClient: """ @@ -57,6 +57,27 @@ class AsyncRemoteComfyClient: def session(self) -> aiohttp.ClientSession: return self._ensure_session() + def _build_headers(self, accept_header: str, prefer_header: Optional[str] = None, content_type: str = "application/json") -> dict: + """Build HTTP headers for requests.""" + headers = {'Content-Type': content_type, 'Accept': accept_header} + if prefer_header: + headers['Prefer'] = prefer_header + return headers + + @tracer.start_as_current_span("Post Prompt") + async def _post_prompt(self, prompt: PromptDict, endpoint: str, accept_header: str, prefer_header: Optional[str] = None) -> ClientResponse: + """ + Common method to POST a prompt to a given endpoint. 
+ :param prompt: The prompt to send + :param endpoint: The API endpoint (e.g., "/api/v1/prompts") + :param accept_header: The Accept header value + :param prefer_header: Optional Prefer header value + :return: The response object + """ + prompt_json = AsyncRemoteComfyClient.__json_encoder.encode(prompt) + headers = self._build_headers(accept_header, prefer_header) + return await self.session.post(urljoin(self.server_address, endpoint), data=prompt_json, headers=headers) + async def len_queue(self) -> int: async with self.session.get(urljoin(self.server_address, "/prompt"), headers={'Accept': 'application/json'}) as response: if response.status == 200: @@ -73,14 +94,7 @@ class AsyncRemoteComfyClient: :param accept_header: The Accept header value (e.g., "application/json", "application/json+respond-async") :return: the task ID """ - prompt_json = AsyncRemoteComfyClient.__json_encoder.encode(prompt) - response: ClientResponse - headers = {'Content-Type': 'application/json', 'Accept': accept_header} - if prefer_header: - headers['Prefer'] = prefer_header - async with self.session.post(urljoin(self.server_address, "/api/v1/prompts"), data=prompt_json, - headers=headers) as response: - + async with await self._post_prompt(prompt, "/api/v1/prompts", accept_header, prefer_header) as response: if 200 <= response.status < 400: response_json = await response.json() return response_json["prompt_id"] @@ -95,14 +109,7 @@ class AsyncRemoteComfyClient: :param accept_header: The Accept header value (e.g., "application/json", "application/json+respond-async") :return: the API response from the server containing URLs and the outputs for the UI (nodes with OUTPUT_NODE == true) """ - prompt_json = AsyncRemoteComfyClient.__json_encoder.encode(prompt) - response: ClientResponse - headers = {'Content-Type': 'application/json', 'Accept': accept_header} - if prefer_header: - headers['Prefer'] = prefer_header - async with self.session.post(urljoin(self.server_address, "/api/v1/prompts"), data=prompt_json, - headers=headers) as response: - + async with await self._post_prompt(prompt, "/api/v1/prompts", accept_header, prefer_header) as response: if 200 <= response.status < 400: return V1QueuePromptResponse(**(await response.json())) else: @@ -122,17 +129,13 @@ class AsyncRemoteComfyClient: :param prompt: :return: """ - prompt_json = AsyncRemoteComfyClient.__json_encoder.encode(prompt) - response: ClientResponse - headers = {'Content-Type': 'application/json', 'Accept': 'image/png'} - async with self.session.post(urljoin(self.server_address, "/api/v1/prompts"), data=prompt_json, - headers=headers) as response: - + async with await self._post_prompt(prompt, "/api/v1/prompts", "image/png") as response: if 200 <= response.status < 400: return await response.read() else: raise RuntimeError(f"could not prompt: {response.status}: {await response.text()}") + @tracer.start_as_current_span("Post Prompt (UI)") async def queue_prompt_ui(self, prompt: PromptDict) -> OutputsDict: """ Uses the comfyui UI API calls to retrieve the outputs dictionary @@ -179,6 +182,7 @@ class AsyncRemoteComfyClient: """ return await self.session.get(urljoin(self.server_address, f"/api/v1/prompts/{prompt_id}")) + @tracer.start_as_current_span("Poll Prompt Until Done") async def poll_prompt_until_done(self, prompt_id: str, max_attempts: int = 60, poll_interval: float = 1.0) -> tuple[int, dict | None]: """ Poll a prompt until it's done (200), errors (500), or times out. 
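For reviewers: a minimal usage sketch of the refactored client API (not part of the patch). It assumes a ComfyUI frontend at http://127.0.0.1:8188; the methods and the `Prefer: respond-async` flow are the ones this diff touches and the tests below exercise.

```python
import asyncio

from comfy.client.aio_client import AsyncRemoteComfyClient
from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner


async def main() -> None:
    # Assumes a running ComfyUI frontend; the address is illustrative.
    async with AsyncRemoteComfyClient(server_address="http://127.0.0.1:8188") as client:
        prompt = sdxl_workflow_with_refiner("a photo of a cat", inference_steps=1, refiner_steps=1)
        # "respond-async" returns a task ID immediately instead of blocking on execution.
        task_id = await client.queue_and_forget_prompt_api(prompt, prefer_header="respond-async")
        status_code, _result = await client.poll_prompt_until_done(task_id, max_attempts=60, poll_interval=1.0)
        assert status_code == 200, f"prompt failed with status {status_code}"


asyncio.run(main())
```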
@@ -187,6 +191,10 @@ class AsyncRemoteComfyClient:
         :param poll_interval: Time to wait between polls in seconds
         :return: Tuple of (status_code, response_json or None)
         """
+        span = trace.get_current_span()
+        span.set_attribute("prompt_id", prompt_id)
+        span.set_attribute("max_attempts", max_attempts)
+
         for _ in range(max_attempts):
             async with await self.get_prompt_status(prompt_id) as response:
                 if response.status == 200:
diff --git a/comfy/client/sdxl_with_refiner_workflow.py b/comfy/client/sdxl_with_refiner_workflow.py
index e08c877d4..d7dab309a 100644
--- a/comfy/client/sdxl_with_refiner_workflow.py
+++ b/comfy/client/sdxl_with_refiner_workflow.py
@@ -1,8 +1,6 @@
 import copy
 from typing import TypeAlias, Union
 
-from ..api.components.schema.prompt import PromptDict, Prompt
-
 JSON: TypeAlias = Union[dict[str, "JSON"], list["JSON"], str, int, float, bool, None]
 _BASE_PROMPT: JSON = {
     "4": {
diff --git a/comfy/cmd/main_pre.py b/comfy/cmd/main_pre.py
index 5a8d0f7f8..249482188 100644
--- a/comfy/cmd/main_pre.py
+++ b/comfy/cmd/main_pre.py
@@ -15,6 +15,7 @@ import shutil
 import warnings
 
 import fsspec
+from opentelemetry.instrumentation.urllib3 import URLLib3Instrumentor
 
 from .. import options
 from ..app import logger
@@ -125,9 +126,12 @@ def _create_tracer():
     from opentelemetry.sdk.resources import Resource
     from opentelemetry.sdk.trace import TracerProvider
     from opentelemetry.sdk.trace.export import BatchSpanProcessor, SpanExporter
+    from opentelemetry.processor.baggage import BaggageSpanProcessor, ALLOW_ALL_BAGGAGE_KEYS
+    from opentelemetry.instrumentation.aiohttp_server import AioHttpServerInstrumentor
+    from opentelemetry.instrumentation.aiohttp_client import AioHttpClientInstrumentor
+    from opentelemetry.instrumentation.aio_pika import AioPikaInstrumentor  # imported here so the instrument() call below resolves
     from ..tracing_compatibility import ProgressSpanSampler
     from ..tracing_compatibility import patch_spanbuilder_set_channel
-    from ..vendor.aiohttp_server_instrumentation import AioHttpServerInstrumentor
 
     resource = Resource.create({
         service_attributes.SERVICE_NAME: args.otel_service_name,
@@ -141,18 +144,23 @@ def _create_tracer():
     has_endpoint = args.otel_exporter_otlp_endpoint is not None
     if has_endpoint:
-        otlp_exporter = OTLPSpanExporter()
+        exporter = OTLPSpanExporter()
     else:
-        otlp_exporter = SpanExporter()
+        exporter = SpanExporter()
 
-    processor = BatchSpanProcessor(otlp_exporter)
+    processor = BatchSpanProcessor(exporter)
     provider.add_span_processor(processor)
 
     # enable instrumentation
     patch_spanbuilder_set_channel()
+    AioPikaInstrumentor().instrument()
     AioHttpServerInstrumentor().instrument()
+    AioHttpClientInstrumentor().instrument()
     RequestsInstrumentor().instrument()
+    URLLib3Instrumentor().instrument()
+
+    provider.add_span_processor(BaggageSpanProcessor(ALLOW_ALL_BAGGAGE_KEYS))
 
     # makes this behave better as a library
     return trace.get_tracer(args.otel_service_name, tracer_provider=provider)
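The `_create_tracer` changes above boil down to the following wiring. This is a condensed, illustrative sketch of the OTLP path only (the service name is a placeholder and the sampler and no-endpoint fallback are omitted), not a copy of the function:

```python
from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.instrumentation.aiohttp_client import AioHttpClientInstrumentor
from opentelemetry.processor.baggage import BaggageSpanProcessor, ALLOW_ALL_BAGGAGE_KEYS
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor

provider = TracerProvider(resource=Resource.create({"service.name": "comfyui"}))
# Batch spans and ship them to whatever OTEL_EXPORTER_OTLP_ENDPOINT points at.
provider.add_span_processor(BatchSpanProcessor(OTLPSpanExporter()))
# Copy all baggage entries onto every span so they survive cross-process hops.
provider.add_span_processor(BaggageSpanProcessor(ALLOW_ALL_BAGGAGE_KEYS))

# Instrument the aiohttp client so outgoing requests carry a traceparent header.
AioHttpClientInstrumentor().instrument()

# Returned rather than installed globally, which "makes this behave better as a library".
tracer = trace.get_tracer("comfyui", tracer_provider=provider)
```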
diff --git a/comfy/tracing_compatibility.py b/comfy/tracing_compatibility.py
index 6c0ef16f5..a77f1e79b 100644
--- a/comfy/tracing_compatibility.py
+++ b/comfy/tracing_compatibility.py
@@ -3,7 +3,7 @@ from typing import Optional, Sequence
 from aio_pika.abc import AbstractChannel
 from opentelemetry.context import Context
 from opentelemetry.sdk.trace.sampling import Sampler, SamplingResult, Decision
-from opentelemetry.semconv.trace import SpanAttributes
+from opentelemetry.semconv.attributes.network_attributes import NETWORK_PEER_ADDRESS, NETWORK_PEER_PORT
 from opentelemetry.trace import SpanKind, Link, TraceState
 from opentelemetry.util.types import Attributes
 
@@ -22,8 +22,8 @@ def patch_spanbuilder_set_channel() -> None:
         port = url.port or 5672
         self._attributes.update(
             {
-                SpanAttributes.NET_PEER_NAME: url.host,
-                SpanAttributes.NET_PEER_PORT: port,
+                NETWORK_PEER_ADDRESS: url.host,
+                NETWORK_PEER_PORT: port,
             }
         )
 
diff --git a/tests/conftest.py b/tests/conftest.py
index 33ae98f9d..931ec8deb 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -90,6 +90,14 @@ def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory, num_workers
     params = rabbitmq.get_connection_params()
     connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"
 
+    # Check if OTEL endpoint is configured for integration testing
+    otel_endpoint = os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT")
+
+    env = os.environ.copy()
+    if otel_endpoint:
+        env["OTEL_EXPORTER_OTLP_ENDPOINT"] = otel_endpoint
+        logging.info(f"Configuring services to export traces to: {otel_endpoint}")
+
     frontend_command = [
         "comfyui",
         "--listen=127.0.0.1",
@@ -100,7 +108,7 @@ def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory, num_workers
         f"--distributed-queue-connection-uri={connection_uri}",
     ]
 
-    processes_to_close.append(subprocess.Popen(frontend_command, stdout=sys.stdout, stderr=sys.stderr))
+    processes_to_close.append(subprocess.Popen(frontend_command, stdout=sys.stdout, stderr=sys.stderr, env=env))
 
     # Start multiple workers
     for i in range(num_workers):
@@ -111,7 +119,7 @@ def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory, num_workers
             f"--distributed-queue-connection-uri={connection_uri}",
             f"--executor-factory={executor_factory}"
         ]
-        processes_to_close.append(subprocess.Popen(backend_command, stdout=sys.stdout, stderr=sys.stderr))
+        processes_to_close.append(subprocess.Popen(backend_command, stdout=sys.stdout, stderr=sys.stderr, env=env))
 
     try:
         server_address = f"http://127.0.0.1:19001"
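The new test files below check that trace context survives the hop through RabbitMQ. For reference, this is the propagation pattern involved, hand-rolled with the W3C propagator; `AioPikaInstrumentor` automates the same inject/extract, and the queue name and handler here are illustrative:

```python
import aio_pika
from opentelemetry import propagate, trace
from opentelemetry.trace import SpanKind

tracer = trace.get_tracer(__name__)


async def publish_with_context(channel: aio_pika.abc.AbstractChannel, queue_name: str) -> None:
    # Producer side: serialize the current span context into the message headers.
    carrier: dict = {}
    propagate.inject(carrier)  # writes the W3C "traceparent" key
    message = aio_pika.Message(body=b"payload", headers=carrier)
    await channel.default_exchange.publish(message, routing_key=queue_name)


async def handle(message: aio_pika.abc.AbstractIncomingMessage) -> None:
    # Consumer side: rebuild the remote context and parent the local span on it.
    remote_ctx = propagate.extract(message.headers or {})
    with tracer.start_as_current_span("consume", context=remote_ctx, kind=SpanKind.CONSUMER):
        ...  # both sides now share one trace ID
```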
diff --git a/tests/distributed/test_tracing.py b/tests/distributed/test_tracing.py
index e69de29bb..8c156a9bf 100644
--- a/tests/distributed/test_tracing.py
+++ b/tests/distributed/test_tracing.py
@@ -0,0 +1,128 @@
+import asyncio
+import logging
+
+logging.basicConfig(level=logging.ERROR)
+
+import uuid
+
+import pytest
+from testcontainers.rabbitmq import RabbitMqContainer
+from opentelemetry import trace, propagate, context
+from opentelemetry.trace import SpanKind
+from opentelemetry.sdk.trace import TracerProvider
+from opentelemetry.sdk.trace.export import SimpleSpanProcessor
+from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
+
+from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner
+from comfy.component_model.make_mutable import make_mutable
+from comfy.component_model.queue_types import QueueItem, QueueTuple, ExecutionStatus
+from comfy.distributed.server_stub import ServerStub
+
+
+async def create_test_prompt() -> QueueItem:
+    from comfy.cmd.execution import validate_prompt
+
+    prompt = make_mutable(sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1))
+    item_id = str(uuid.uuid4())
+
+    validation_tuple = await validate_prompt(item_id, prompt)
+    queue_tuple: QueueTuple = (0, item_id, prompt, {}, validation_tuple[2])
+    return QueueItem(queue_tuple, None)
+
+
+@pytest.mark.asyncio
+async def test_rabbitmq_message_properties_contain_trace_context():
+    with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
+        params = rabbitmq.get_connection_params()
+        connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"
+
+        from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
+        import aio_pika
+
+        exporter = InMemorySpanExporter()
+        provider = TracerProvider()
+        provider.add_span_processor(SimpleSpanProcessor(exporter))
+        tracer = trace.get_tracer(__name__, tracer_provider=provider)
+
+        with tracer.start_as_current_span("test_message_headers", kind=SpanKind.PRODUCER):
+            async with DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=connection_uri) as frontend:
+                async with DistributedPromptQueue(ServerStub(), is_callee=True, is_caller=False, connection_uri=connection_uri) as worker:
+                    queue_item = await create_test_prompt()
+
+                    put_task = asyncio.create_task(frontend.put_async(queue_item))
+
+                    incoming, incoming_prompt_id = await worker.get_async(timeout=5.0)
+                    assert incoming is not None, "Worker should receive message"
+
+                    worker.task_done(incoming_prompt_id, {}, ExecutionStatus("success", True, []))
+
+                    result = await put_task
+                    assert result is not None, "Frontend should get result"
+
+        # Now inspect the RabbitMQ queue directly to see message structure
+        connection = await aio_pika.connect_robust(connection_uri)
+        channel = await connection.channel()
+
+        # Declare a test queue to inspect message format
+        test_queue = await channel.declare_queue("test_inspection_queue", durable=False, auto_delete=True)
+
+        # Publish a test message with trace context
+        carrier = {}
+        propagate.inject(carrier)
+
+        test_message = aio_pika.Message(
+            body=b"test",
+            headers=carrier
+        )
+
+        await channel.default_exchange.publish(
+            test_message,
+            routing_key=test_queue.name
+        )
+
+        # Get and inspect the message
+        received = await test_queue.get(timeout=2, fail=False)
+        if received:
+            headers = received.headers or {}
+
+            # Document what trace headers should be present
+            # OpenTelemetry uses 'traceparent' header for W3C Trace Context
+            has_traceparent = "traceparent" in headers
+
+            assert has_traceparent
+
+            await received.ack()
+
+        await connection.close()
+
+
+@pytest.mark.asyncio
+async def test_distributed_queue_uses_async_interface():
+    """
+    Test that demonstrates the correct way to use DistributedPromptQueue in async context.
+    The synchronous get() method cannot be used in async tests due to event loop assertions.
+ """ + with RabbitMqContainer("rabbitmq:latest") as rabbitmq: + params = rabbitmq.get_connection_params() + connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}" + + from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue + + async with DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=connection_uri) as frontend: + async with DistributedPromptQueue(ServerStub(), is_callee=True, is_caller=False, connection_uri=connection_uri) as worker: + queue_item = await create_test_prompt() + + # Start consuming in background + result_future = asyncio.create_task(frontend.put_async(queue_item)) + + # Worker gets item asynchronously (not using blocking get()) + incoming, incoming_prompt_id = await worker.get_async(timeout=5.0) + assert incoming is not None, "Should receive a queue item" + + # Complete the work + worker.task_done(incoming_prompt_id, {}, ExecutionStatus("success", True, [])) + + # Wait for frontend to complete + result = await result_future + assert result is not None, "Should get result from worker" + assert result.status.status_str == "success" diff --git a/tests/distributed/test_tracing_integration.py b/tests/distributed/test_tracing_integration.py new file mode 100644 index 000000000..412b14c02 --- /dev/null +++ b/tests/distributed/test_tracing_integration.py @@ -0,0 +1,497 @@ +""" +Integration tests for distributed tracing across RabbitMQ and services. + +These tests validate that trace context propagates correctly from frontend +to backend workers through RabbitMQ, and that Jaeger can reconstruct the +full distributed trace. +""" +import asyncio +import logging +import os +import time +import uuid + +import pytest +import requests +from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter +from opentelemetry.instrumentation.aiohttp_client import AioHttpClientInstrumentor +from opentelemetry.sdk.resources import Resource +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import BatchSpanProcessor +from opentelemetry.semconv.attributes import service_attributes +from testcontainers.core.container import DockerContainer +from testcontainers.core.waiting_utils import wait_for_logs + +from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +class JaegerContainer(DockerContainer): + """Testcontainer for Jaeger all-in-one with OTLP support.""" + + def __init__(self, image: str = "jaegertracing/all-in-one:latest"): + super().__init__(image) + self.with_exposed_ports(16686, 4318, 14268) # UI, OTLP HTTP, Jaeger HTTP + self.with_env("COLLECTOR_OTLP_ENABLED", "true") + + def get_query_url(self) -> str: + """Get Jaeger Query API URL.""" + host = self.get_container_host_ip() + port = self.get_exposed_port(16686) + return f"http://{host}:{port}" + + def get_otlp_endpoint(self) -> str: + """Get OTLP HTTP endpoint for sending traces.""" + host = self.get_container_host_ip() + port = self.get_exposed_port(4318) + return f"http://{host}:{port}" + + def start(self): + super().start() + wait_for_logs(self, ".*Starting GRPC server.*", timeout=30) + return self + + +@pytest.fixture(scope="module") +def jaeger_container(): + """ + Provide a Jaeger container for collecting traces. + + This fixture automatically sets OTEL_EXPORTER_OTLP_ENDPOINT to point to the + Jaeger container, and cleans it up when the container stops. 
+ """ + container = JaegerContainer() + container.start() + + # Wait for Jaeger to be fully ready + query_url = container.get_query_url() + otlp_endpoint = container.get_otlp_endpoint() + + for _ in range(30): + try: + response = requests.get(f"{query_url}/api/services") + if response.status_code == 200: + logger.info(f"Jaeger ready at {query_url}") + logger.info(f"OTLP endpoint: {otlp_endpoint}") + break + except Exception: + pass + time.sleep(1) + + # Set OTEL_EXPORTER_OTLP_ENDPOINT for the duration of the test + old_endpoint = os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT") + os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"] = otlp_endpoint + logger.info(f"Set OTEL_EXPORTER_OTLP_ENDPOINT={otlp_endpoint}") + + try: + yield container + finally: + # Restore original OTEL_EXPORTER_OTLP_ENDPOINT + if old_endpoint is not None: + os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"] = old_endpoint + logger.info(f"Restored OTEL_EXPORTER_OTLP_ENDPOINT={old_endpoint}") + else: + os.environ.pop("OTEL_EXPORTER_OTLP_ENDPOINT", None) + logger.info("Removed OTEL_EXPORTER_OTLP_ENDPOINT") + + container.stop() + + +def query_jaeger_traces(jaeger_url: str, service: str, operation: str = None, + lookback: str = "1h", limit: int = 100) -> dict: + """ + Query Jaeger for traces. + + Args: + jaeger_url: Base URL of Jaeger query service + service: Service name to query + operation: Optional operation name filter + lookback: Lookback period (e.g., "1h", "30m") + limit: Maximum number of traces to return + + Returns: + JSON response from Jaeger API + """ + params = { + "service": service, + "lookback": lookback, + "limit": limit + } + if operation: + params["operation"] = operation + + response = requests.get(f"{jaeger_url}/api/traces", params=params) + response.raise_for_status() + return response.json() + + +def find_trace_by_operation(traces_response: dict, operation_name: str) -> dict: + """Find a specific trace by operation name.""" + for trace in traces_response.get("data", []): + for span in trace.get("spans", []): + if span.get("operationName") == operation_name: + return trace + return None + + +def verify_trace_continuity(trace: dict, expected_services: list[str]) -> bool: + """ + Verify that a trace spans multiple services and maintains parent-child relationships. 
+ + Args: + trace: Jaeger trace object + expected_services: List of service names expected in the trace + + Returns: + True if trace shows proper distributed tracing across services + """ + if not trace: + return False + + spans = trace.get("spans", []) + if not spans: + return False + + # Check that all expected services are present + trace_services = set() + for span in spans: + process_id = span.get("processID") + if process_id: + process = trace.get("processes", {}).get(process_id, {}) + service_name = process.get("serviceName") + if service_name: + trace_services.add(service_name) + + logger.info(f"Trace contains services: {trace_services}") + logger.info(f"Expected services: {set(expected_services)}") + + # Verify all expected services are present + for service in expected_services: + if service not in trace_services: + logger.warning(f"Expected service '{service}' not found in trace") + return False + + # Verify all spans share the same trace ID + trace_ids = set(span.get("traceID") for span in spans) + if len(trace_ids) != 1: + logger.warning(f"Multiple trace IDs found: {trace_ids}") + return False + + # Verify parent-child relationships exist + span_ids = {span.get("spanID") for span in spans} + has_parent_refs = False + + for span in spans: + references = span.get("references", []) + for ref in references: + if ref.get("refType") == "CHILD_OF": + parent_span_id = ref.get("spanID") + if parent_span_id in span_ids: + has_parent_refs = True + logger.info(f"Found parent-child relationship: {parent_span_id} -> {span.get('spanID')}") + + if not has_parent_refs: + logger.warning("No parent-child relationships found in trace") + return False + + return True + + +# order matters, execute jaeger_container first +@pytest.mark.asyncio +async def test_tracing_integration(jaeger_container, frontend_backend_worker_with_rabbitmq): + """ + Integration test for distributed tracing across services. + + This test: + 1. Starts ComfyUI frontend and worker with RabbitMQ + 2. Configures OTLP export to Jaeger testcontainer + 3. Submits a workflow through the frontend + 4. Queries Jaeger to verify trace propagation + 5. Validates that the trace spans multiple services with proper relationships + + Note: The frontend_backend_worker_with_rabbitmq fixture is parameterized, + so this test will run with both ThreadPoolExecutor and ProcessPoolExecutor. + """ + server_address = frontend_backend_worker_with_rabbitmq + jaeger_url = jaeger_container.get_query_url() + otlp_endpoint = jaeger_container.get_otlp_endpoint() + + logger.info(f"Frontend server: {server_address}") + logger.info(f"Jaeger UI: {jaeger_url}") + logger.info(f"OTLP endpoint: {otlp_endpoint}") + + # Set up tracing for the async HTTP client + resource = Resource.create({ + service_attributes.SERVICE_NAME: "comfyui-client", + }) + provider = TracerProvider(resource=resource) + exporter = OTLPSpanExporter(endpoint=f"{otlp_endpoint}/v1/traces") + processor = BatchSpanProcessor(exporter) + provider.add_span_processor(processor) + from opentelemetry import trace + + trace.set_tracer_provider(provider) + + # Instrument aiohttp client + AioHttpClientInstrumentor().instrument() + + # we have to call this very late, so that the instrumentation isn't initialized too early + from comfy.client.aio_client import AsyncRemoteComfyClient + + # Note: In a real integration test, you'd need to configure the ComfyUI + # services to export traces to this Jaeger instance. For now, this test + # documents the expected behavior. 
+ + # Create a unique prompt to identify our trace + test_id = str(uuid.uuid4())[:8] + prompt = sdxl_workflow_with_refiner(f"test_trace_{test_id}", inference_steps=1, refiner_steps=1) + + # Get the tracer for the client + client_tracer = trace.get_tracer("test_tracing_integration") + + # Submit the workflow - wrap in a span to capture the trace ID + with client_tracer.start_as_current_span("submit_workflow") as workflow_span: + trace_id = format(workflow_span.get_span_context().trace_id, '032x') + logger.info(f"Started trace with trace_id: {trace_id}") + + async with AsyncRemoteComfyClient(server_address=server_address) as client: + logger.info(f"Submitting workflow with test_id: {test_id}") + + # Queue the prompt with async response + task_id = await client.queue_and_forget_prompt_api(prompt, prefer_header="respond-async") + assert task_id is not None, "Failed to get task ID" + + logger.info(f"Queued task: {task_id}") + + # Poll for completion + status_code, result = await client.poll_prompt_until_done(task_id, max_attempts=60, poll_interval=1.0) + assert status_code == 200, f"Task failed with status {status_code}" + logger.info("Task completed successfully") + + # Give Jaeger time to receive and process spans + await asyncio.sleep(5) + + # Query Jaeger for traces + # Note: The actual service names depend on how your services are configured + # Common service names might be: "slack-bot", "comfyui-frontend", "comfyui-worker" + + expected_services = ["comfyui", "comfyui-client"] # Adjust based on actual service names + + logger.info(f"Querying Jaeger for traces with trace_id: {trace_id}...") + + # First, try to find our specific trace by trace_id from the client service + our_trace = None + for service in expected_services: + try: + traces_response = query_jaeger_traces(jaeger_url, service, lookback="5m") + if traces_response.get("data"): + logger.info(f"Found {len(traces_response['data'])} traces for service '{service}'") + for trace in traces_response["data"]: + if trace.get("traceID") == trace_id: + our_trace = trace + logger.info(f"Found our trace in service '{service}'") + break + if our_trace: + break + except Exception as e: + logger.warning(f"Could not query traces for service '{service}': {e}") + + # Assert we can find the trace we just created + assert our_trace is not None, ( + f"Could not find trace with trace_id {trace_id} in Jaeger. " + f"This indicates that spans from comfyui-client are not being exported correctly." + ) + + logger.info(f"Successfully found trace with trace_id {trace_id}") + + # Extract services from the trace + trace_services = set() + for span in our_trace.get("spans", []): + process_id = span.get("processID") + if process_id: + process = our_trace.get("processes", {}).get(process_id, {}) + service_name = process.get("serviceName") + if service_name: + trace_services.add(service_name) + + logger.info(f"Services found in trace: {trace_services}") + + # Assert that comfyui-client service is present (since we instrumented it) + assert "comfyui-client" in trace_services, ( + f"Expected 'comfyui-client' service in trace, but found only: {trace_services}. " + f"This indicates the client instrumentation is not working." 
+ ) + + # Validate trace structure + logger.info(f"Analyzing trace with {len(our_trace.get('spans', []))} spans") + + # Log all spans for debugging + for span in our_trace.get("spans", []): + process_id = span.get("processID") + process = our_trace.get("processes", {}).get(process_id, {}) + service_name = process.get("serviceName", "unknown") + operation = span.get("operationName", "unknown") + logger.info(f" Span: {service_name}.{operation}") + + # Verify trace continuity - only if both services are present + assert "comfyui" in trace_services + is_continuous = verify_trace_continuity(our_trace, expected_services) + + # This assertion documents what SHOULD happen when distributed tracing works + assert is_continuous, ( + "Trace does not show proper distributed tracing. " + "Expected to see spans from multiple services with parent-child relationships. " + "This indicates that trace context is not being propagated correctly through RabbitMQ." + ) + +@pytest.mark.asyncio +async def test_trace_context_in_http_headers(frontend_backend_worker_with_rabbitmq): + """ + Test that HTTP requests include traceparent headers. + + This validates that the HTTP layer is properly instrumented for tracing. + """ + server_address = frontend_backend_worker_with_rabbitmq + + # Make a simple HTTP request and check for trace headers + # Note: We're checking the server's response headers to see if it's trace-aware + response = requests.get(f"{server_address}/system_stats") + + logger.info(f"Response headers: {dict(response.headers)}") + + # The server should be instrumented and may include trace context in responses + # or at minimum, should accept traceparent headers in requests + + # Test sending a traceparent header + test_traceparent = "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01" + response_with_trace = requests.get( + f"{server_address}/system_stats", + headers={"traceparent": test_traceparent} + ) + + # Should not error when traceparent is provided + assert response_with_trace.status_code == 200, "Server should accept traceparent header" + + logger.info("✓ Server accepts traceparent headers in HTTP requests") + + +@pytest.mark.asyncio +async def test_multiple_requests_different_traces(frontend_backend_worker_with_rabbitmq, jaeger_container): + """ + Test that multiple independent requests create separate traces. + + This validates that trace context is properly scoped per request. 
+ """ + server_address = frontend_backend_worker_with_rabbitmq + + # Submit multiple workflows + task_ids = [] + + from comfy.client.aio_client import AsyncRemoteComfyClient + async with AsyncRemoteComfyClient(server_address=server_address) as client: + for i in range(3): + prompt = sdxl_workflow_with_refiner(f"test_{i}", inference_steps=1, refiner_steps=1) + task_id = await client.queue_and_forget_prompt_api(prompt, prefer_header="respond-async") + task_ids.append(task_id) + logger.info(f"Queued task {i}: {task_id}") + + # Wait for all to complete + for i, task_id in enumerate(task_ids): + status_code, result = await client.poll_prompt_until_done(task_id, max_attempts=60, poll_interval=1.0) + assert status_code == 200, f"Task {i} failed" + logger.info(f"Task {i} completed") + + # Give Jaeger time to receive spans + await asyncio.sleep(5) + + # Query Jaeger and verify we have multiple distinct traces + jaeger_url = jaeger_container.get_query_url() + + try: + traces_response = query_jaeger_traces(jaeger_url, "comfyui", lookback="5m", limit=10) + traces = traces_response.get("data", []) + + if len(traces) >= 2: + # Get trace IDs + trace_ids = [trace.get("traceID") for trace in traces] + unique_trace_ids = set(trace_ids) + + logger.info(f"Found {len(unique_trace_ids)} unique traces") + + # Verify we have multiple distinct traces + assert len(unique_trace_ids) >= 2, ( + f"Expected at least 2 distinct traces, found {len(unique_trace_ids)}. " + "Each request should create its own trace." + ) + + logger.info("✓ Multiple requests created distinct traces") + else: + pytest.skip("Not enough traces to validate") + except Exception as e: + pytest.skip(f"Could not query Jaeger: {e}") + + +@pytest.mark.asyncio +async def test_trace_contains_rabbitmq_operations(frontend_backend_worker_with_rabbitmq, jaeger_container): + """ + Test that traces include RabbitMQ publish/consume operations. + + This is critical for distributed tracing - the RabbitMQ operations + are what link the frontend and backend spans together. + """ + server_address = frontend_backend_worker_with_rabbitmq + jaeger_url = jaeger_container.get_query_url() + + # Submit a workflow + from comfy.client.aio_client import AsyncRemoteComfyClient + async with AsyncRemoteComfyClient(server_address=server_address) as client: + prompt = sdxl_workflow_with_refiner("test_rmq", inference_steps=1, refiner_steps=1) + task_id = await client.queue_and_forget_prompt_api(prompt) + status_code, result = await client.poll_prompt_until_done(task_id, max_attempts=60) + assert status_code == 200 + + await asyncio.sleep(5) + + try: + traces_response = query_jaeger_traces(jaeger_url, "comfyui", lookback="5m") + traces = traces_response.get("data", []) + + if traces: + # Look for RabbitMQ-related operations in any trace + rabbitmq_operations = [ + "publish", "consume", "amq_queue_publish", "amq_queue_consume", + "amq.basic.publish", "amq.basic.consume", "send", "receive" + ] + + found_rabbitmq_ops = [] + for trace in traces: + for span in trace.get("spans", []): + op_name = span.get("operationName", "").lower() + for rmq_op in rabbitmq_operations: + if rmq_op in op_name: + found_rabbitmq_ops.append(op_name) + + if found_rabbitmq_ops: + logger.info(f"✓ Found RabbitMQ operations in traces: {set(found_rabbitmq_ops)}") + else: + logger.warning( + "No RabbitMQ operations found in traces. " + "This suggests that either:\n" + "1. AioPikaInstrumentor is not creating spans, or\n" + "2. The spans are being filtered out by the collector, or\n" + "3. 
+                )
+
+            # Log all operation names to help debug
+            all_ops = set()
+            for trace in traces[:3]:  # First 3 traces
+                for span in trace.get("spans", []):
+                    all_ops.add(span.get("operationName"))
+            logger.info(f"Sample operation names: {all_ops}")
+        else:
+            pytest.skip("No traces found")
+    except Exception as e:
+        pytest.skip(f"Could not query Jaeger: {e}")
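For completeness, a short sketch of how the Jaeger query helpers in this file compose outside a test; the URL, service, and operation names are placeholders:

```python
# Assumes a Jaeger all-in-one reachable at this (illustrative) address.
jaeger_url = "http://localhost:16686"

traces_response = query_jaeger_traces(jaeger_url, service="comfyui-client", lookback="5m")
found = find_trace_by_operation(traces_response, "submit_workflow")

if found is not None:
    ok = verify_trace_continuity(found, ["comfyui", "comfyui-client"])
    print(f"distributed trace intact: {ok}")
else:
    print("no trace containing 'submit_workflow' yet")
```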