diff --git a/comfy/vendor/aiohttp_server_instrumentation.py b/comfy/vendor/aiohttp_server_instrumentation.py index 12fb801e5..61b355237 100644 --- a/comfy/vendor/aiohttp_server_instrumentation.py +++ b/comfy/vendor/aiohttp_server_instrumentation.py @@ -18,12 +18,15 @@ from typing import Dict, List, Tuple, Union from aiohttp import web from multidict import CIMultiDictProxy -from opentelemetry import context, metrics, trace -from opentelemetry.context import _SUPPRESS_HTTP_INSTRUMENTATION_KEY + +from opentelemetry import metrics, trace _instruments = ("aiohttp ~= 3.0",) -__version__ = "0.46b0.dev" +__version__ = "0.49b0.dev" from opentelemetry.instrumentation.instrumentor import BaseInstrumentor -from opentelemetry.instrumentation.utils import http_status_to_status_code +from opentelemetry.instrumentation.utils import ( + http_status_to_status_code, + is_http_instrumentation_enabled, +) from opentelemetry.propagate import extract from opentelemetry.propagators.textmap import Getter from opentelemetry.semconv.metrics import MetricInstruments @@ -152,7 +155,7 @@ def set_status_code(span, status_code: int) -> None: Status( StatusCode.ERROR, "Non-integer HTTP status: " + repr(status_code), - ) + ) ) else: span.set_attribute(SpanAttributes.HTTP_STATUS_CODE, status_code) @@ -190,10 +193,8 @@ getter = AiohttpGetter() @web.middleware async def middleware(request, handler): """Middleware for aiohttp implementing tracing logic""" - if ( - context.get_value("suppress_instrumentation") - or context.get_value(_SUPPRESS_HTTP_INSTRUMENTATION_KEY) - or _excluded_urls.url_disabled(request.url.path) + if not is_http_instrumentation_enabled() or _excluded_urls.url_disabled( + request.url.path ): return await handler(request) @@ -206,7 +207,7 @@ async def middleware(request, handler): duration_histogram = meter.create_histogram( name=MetricInstruments.HTTP_SERVER_DURATION, unit="ms", - description="Duration of HTTP client requests.", + description="Measures the duration of inbound HTTP requests.", ) active_requests_counter = meter.create_up_down_counter( diff --git a/tests/asyncio/test_aio_http_bugs.py b/tests/asyncio/test_aio_http_bugs.py new file mode 100644 index 000000000..0ea7a1745 --- /dev/null +++ b/tests/asyncio/test_aio_http_bugs.py @@ -0,0 +1,52 @@ +import asyncio +import pytest +import subprocess +from aiohttp import web + +async def health_check(request): + return web.Response(text="HEALTHY") + +@web.middleware +async def middleware(request, handler): + # Access request.url.path to trigger the potential error + print(f"Accessing path: {request.url.path}") + response = await handler(request) + return response + +async def run_server(): + app = web.Application(middlewares=[middleware]) + app.router.add_get('/health', health_check) + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, 'localhost', 9090) + await site.start() + print("Server started on http://localhost:9090") + return runner + +@pytest.mark.asyncio +async def test_health_check(): + runner = await run_server() + try: + # Use asyncio.create_subprocess_exec to run curl command + proc = await asyncio.create_subprocess_exec( + 'curl', '-s', 'http://localhost:9090/health', + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE) + + try: + stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=5.0) + except asyncio.TimeoutError: + print("Curl request timed out") + proc.kill() + await proc.wait() + return + + if proc.returncode != 0: + print(f"Curl failed with return code {proc.returncode}") + print(f"stderr: {stderr.decode()}") + else: + response = stdout.decode().strip() + assert response == "HEALTHY", f"Unexpected response: {response}" + print("Test passed: Received 'HEALTHY' response") + finally: + await runner.cleanup() \ No newline at end of file