diff --git a/comfy/cmd/main_pre.py b/comfy/cmd/main_pre.py index c3cdd7c9a..0c84a866b 100644 --- a/comfy/cmd/main_pre.py +++ b/comfy/cmd/main_pre.py @@ -15,7 +15,6 @@ import warnings from opentelemetry import trace from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter from opentelemetry.instrumentation.aio_pika import AioPikaInstrumentor -from opentelemetry.instrumentation.aiohttp_server import AioHttpServerInstrumentor from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter, SpanExporter @@ -24,6 +23,7 @@ from opentelemetry.semconv.resource import ResourceAttributes as ResAttrs from .. import options from ..tracing_compatibility import ProgressSpanSampler from ..tracing_compatibility import patch_spanbuilder_set_channel +from ..vendor.aiohttp_server_instrumentation import AioHttpServerInstrumentor options.enable_args_parsing() if os.name == "nt": diff --git a/comfy/vendor/aiohttp_server_instrumentation.py b/comfy/vendor/aiohttp_server_instrumentation.py new file mode 100644 index 000000000..12fb801e5 --- /dev/null +++ b/comfy/vendor/aiohttp_server_instrumentation.py @@ -0,0 +1,266 @@ +# Copyright 2020, OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import urllib +from timeit import default_timer +from typing import Dict, List, Tuple, Union + +from aiohttp import web +from multidict import CIMultiDictProxy +from opentelemetry import context, metrics, trace +from opentelemetry.context import _SUPPRESS_HTTP_INSTRUMENTATION_KEY +_instruments = ("aiohttp ~= 3.0",) +__version__ = "0.46b0.dev" +from opentelemetry.instrumentation.instrumentor import BaseInstrumentor +from opentelemetry.instrumentation.utils import http_status_to_status_code +from opentelemetry.propagate import extract +from opentelemetry.propagators.textmap import Getter +from opentelemetry.semconv.metrics import MetricInstruments +from opentelemetry.semconv.trace import SpanAttributes +from opentelemetry.trace.status import Status, StatusCode +from opentelemetry.util.http import get_excluded_urls, remove_url_credentials + +_duration_attrs = [ + SpanAttributes.HTTP_METHOD, + SpanAttributes.HTTP_HOST, + SpanAttributes.HTTP_SCHEME, + SpanAttributes.HTTP_STATUS_CODE, + SpanAttributes.HTTP_FLAVOR, + SpanAttributes.HTTP_SERVER_NAME, + SpanAttributes.NET_HOST_NAME, + SpanAttributes.NET_HOST_PORT, + SpanAttributes.HTTP_ROUTE, +] + +_active_requests_count_attrs = [ + SpanAttributes.HTTP_METHOD, + SpanAttributes.HTTP_HOST, + SpanAttributes.HTTP_SCHEME, + SpanAttributes.HTTP_FLAVOR, + SpanAttributes.HTTP_SERVER_NAME, +] + +tracer = trace.get_tracer(__name__) +meter = metrics.get_meter(__name__, __version__) +_excluded_urls = get_excluded_urls("AIOHTTP_SERVER") + + +def _parse_duration_attrs(req_attrs): + duration_attrs = {} + for attr_key in _duration_attrs: + if req_attrs.get(attr_key) is not None: + duration_attrs[attr_key] = req_attrs[attr_key] + return duration_attrs + + +def _parse_active_request_count_attrs(req_attrs): + active_requests_count_attrs = {} + for attr_key in _active_requests_count_attrs: + if req_attrs.get(attr_key) is not None: + active_requests_count_attrs[attr_key] = req_attrs[attr_key] + return active_requests_count_attrs + + +def get_default_span_details(request: web.Request) -> Tuple[str, dict]: + """Default implementation for get_default_span_details + Args: + request: the request object itself. + Returns: + a tuple of the span name, and any attributes to attach to the span. + """ + span_name = request.path.strip() or f"HTTP {request.method}" + return span_name, {} + + +def _get_view_func(request: web.Request) -> str: + """Returns the name of the request handler. + Args: + request: the request object itself. + Returns: + a string containing the name of the handler function + """ + try: + return request.match_info.handler.__name__ + except AttributeError: + return "unknown" + + +def collect_request_attributes(request: web.Request) -> Dict: + """Collects HTTP request attributes from the ASGI scope and returns a + dictionary to be used as span creation attributes.""" + + server_host, port, http_url = ( + request.url.host, + request.url.port, + str(request.url), + ) + query_string = request.query_string + if query_string and http_url: + if isinstance(query_string, bytes): + query_string = query_string.decode("utf8") + http_url += "?" + urllib.parse.unquote(query_string) + + result = { + SpanAttributes.HTTP_SCHEME: request.scheme, + SpanAttributes.HTTP_HOST: server_host, + SpanAttributes.NET_HOST_PORT: port, + SpanAttributes.HTTP_ROUTE: _get_view_func(request), + SpanAttributes.HTTP_FLAVOR: f"{request.version.major}.{request.version.minor}", + SpanAttributes.HTTP_TARGET: request.path, + SpanAttributes.HTTP_URL: remove_url_credentials(http_url), + } + + http_method = request.method + if http_method: + result[SpanAttributes.HTTP_METHOD] = http_method + + http_host_value_list = ( + [request.host] if not isinstance(request.host, list) else request.host + ) + if http_host_value_list: + result[SpanAttributes.HTTP_SERVER_NAME] = ",".join( + http_host_value_list + ) + http_user_agent = request.headers.get("user-agent") + if http_user_agent: + result[SpanAttributes.HTTP_USER_AGENT] = http_user_agent + + # remove None values + result = {k: v for k, v in result.items() if v is not None} + + return result + + +def set_status_code(span, status_code: int) -> None: + """Adds HTTP response attributes to span using the status_code argument.""" + + try: + status_code = int(status_code) + except ValueError: + span.set_status( + Status( + StatusCode.ERROR, + "Non-integer HTTP status: " + repr(status_code), + ) + ) + else: + span.set_attribute(SpanAttributes.HTTP_STATUS_CODE, status_code) + span.set_status( + Status(http_status_to_status_code(status_code, server_span=True)) + ) + + +class AiohttpGetter(Getter): + """Extract current trace from headers""" + + def get(self, carrier, key: str) -> Union[List, None]: + """Getter implementation to retrieve an HTTP header value from the ASGI + scope. + + Args: + carrier: ASGI scope object + key: header name in scope + Returns: + A list of all header values matching the key, or None if the key + does not match any header. + """ + headers: CIMultiDictProxy = carrier.headers + if not headers: + return None + return headers.getall(key, None) + + def keys(self, carrier: Dict) -> List: + return list(carrier.keys()) + + +getter = AiohttpGetter() + + +@web.middleware +async def middleware(request, handler): + """Middleware for aiohttp implementing tracing logic""" + if ( + context.get_value("suppress_instrumentation") + or context.get_value(_SUPPRESS_HTTP_INSTRUMENTATION_KEY) + or _excluded_urls.url_disabled(request.url.path) + ): + return await handler(request) + + span_name, additional_attributes = get_default_span_details(request) + + req_attrs = collect_request_attributes(request) + duration_attrs = _parse_duration_attrs(req_attrs) + active_requests_count_attrs = _parse_active_request_count_attrs(req_attrs) + + duration_histogram = meter.create_histogram( + name=MetricInstruments.HTTP_SERVER_DURATION, + unit="ms", + description="Duration of HTTP client requests.", + ) + + active_requests_counter = meter.create_up_down_counter( + name=MetricInstruments.HTTP_SERVER_ACTIVE_REQUESTS, + unit="requests", + description="measures the number of concurrent HTTP requests those are currently in flight", + ) + + with tracer.start_as_current_span( + span_name, + context=extract(request, getter=getter), + kind=trace.SpanKind.SERVER, + ) as span: + attributes = collect_request_attributes(request) + attributes.update(additional_attributes) + span.set_attributes(attributes) + start = default_timer() + active_requests_counter.add(1, active_requests_count_attrs) + try: + resp = await handler(request) + set_status_code(span, resp.status) + except web.HTTPException as ex: + set_status_code(span, ex.status_code) + raise + finally: + duration = max((default_timer() - start) * 1000, 0) + duration_histogram.record(duration, duration_attrs) + active_requests_counter.add(-1, active_requests_count_attrs) + return resp + + +class _InstrumentedApplication(web.Application): + """Insert tracing middleware""" + + def __init__(self, *args, **kwargs): + middlewares = kwargs.pop("middlewares", []) + middlewares.insert(0, middleware) + kwargs["middlewares"] = middlewares + super().__init__(*args, **kwargs) + + +class AioHttpServerInstrumentor(BaseInstrumentor): + # pylint: disable=protected-access,attribute-defined-outside-init + """An instrumentor for aiohttp.web.Application + + See `BaseInstrumentor` + """ + + def _instrument(self, **kwargs): + self._original_app = web.Application + setattr(web, "Application", _InstrumentedApplication) + + def _uninstrument(self, **kwargs): + setattr(web, "Application", self._original_app) + + def instrumentation_dependencies(self): + return _instruments diff --git a/custom_nodes/websocket_image_save.py b/custom_nodes/websocket_image_save.py deleted file mode 100644 index 5aa573642..000000000 --- a/custom_nodes/websocket_image_save.py +++ /dev/null @@ -1,45 +0,0 @@ -from PIL import Image, ImageOps -from io import BytesIO -import numpy as np -import struct -import comfy.utils -import time - -#You can use this node to save full size images through the websocket, the -#images will be sent in exactly the same format as the image previews: as -#binary images on the websocket with a 8 byte header indicating the type -#of binary message (first 4 bytes) and the image format (next 4 bytes). - -#Note that no metadata will be put in the images saved with this node. - -class SaveImageWebsocket: - @classmethod - def INPUT_TYPES(s): - return {"required": - {"images": ("IMAGE", ),} - } - - RETURN_TYPES = () - FUNCTION = "save_images" - - OUTPUT_NODE = True - - CATEGORY = "api/image" - - def save_images(self, images): - pbar = comfy.utils.ProgressBar(images.shape[0]) - step = 0 - for image in images: - i = 255. * image.cpu().numpy() - img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) - pbar.update_absolute(step, images.shape[0], ("PNG", img, None)) - step += 1 - - return {} - - def IS_CHANGED(s, images): - return time.time() - -NODE_CLASS_MAPPINGS = { - "SaveImageWebsocket": SaveImageWebsocket, -} diff --git a/requirements.txt b/requirements.txt index ae2cb825d..9b1ba4120 100644 --- a/requirements.txt +++ b/requirements.txt @@ -42,5 +42,8 @@ opentelemetry-exporter-otlp opentelemetry-exporter-otlp-proto-http opentelemetry-exporter-otlp-proto-grpc opentelemetry-propagator-jaeger +opentelemetry-instrumentation +opentelemetry-util-http opentelemetry-instrumentation-aio-pika -opentelemetry-instrumentation-aiohttp-server @ git+https://github.com/open-telemetry/opentelemetry-python-contrib.git@v0.45b0#subdirectory=instrumentation/opentelemetry-instrumentation-aiohttp-server \ No newline at end of file +opentelemetry-semantic-conventions +wrapt>=1.16.0 \ No newline at end of file diff --git a/script_examples/basic_api_example.py b/script_examples/basic_api_example.py index 9b6aa0e92..34bf9ce08 100644 --- a/script_examples/basic_api_example.py +++ b/script_examples/basic_api_example.py @@ -1,4 +1,5 @@ # This sample shows how to execute a ComfyUI workflow, saving an image file to the location you specify. +# It does not require the server to be run. It runs ComfyUI embedded, as a library. No process is started. # # This script does not need to run within a ComfyUI directory. Instead, this can be used inside your own # Python application or located elsewhere. It should **not** be in the Git repository directory. diff --git a/script_examples/websockets_api_example_ws_images.py b/script_examples/websockets_api_example_ws_images.py deleted file mode 100644 index 737488621..000000000 --- a/script_examples/websockets_api_example_ws_images.py +++ /dev/null @@ -1,159 +0,0 @@ -#This is an example that uses the websockets api and the SaveImageWebsocket node to get images directly without -#them being saved to disk - -import websocket #NOTE: websocket-client (https://github.com/websocket-client/websocket-client) -import uuid -import json -import urllib.request -import urllib.parse - -server_address = "127.0.0.1:8188" -client_id = str(uuid.uuid4()) - -def queue_prompt(prompt): - p = {"prompt": prompt, "client_id": client_id} - data = json.dumps(p).encode('utf-8') - req = urllib.request.Request("http://{}/prompt".format(server_address), data=data) - return json.loads(urllib.request.urlopen(req).read()) - -def get_image(filename, subfolder, folder_type): - data = {"filename": filename, "subfolder": subfolder, "type": folder_type} - url_values = urllib.parse.urlencode(data) - with urllib.request.urlopen("http://{}/view?{}".format(server_address, url_values)) as response: - return response.read() - -def get_history(prompt_id): - with urllib.request.urlopen("http://{}/history/{}".format(server_address, prompt_id)) as response: - return json.loads(response.read()) - -def get_images(ws, prompt): - prompt_id = queue_prompt(prompt)['prompt_id'] - output_images = {} - current_node = "" - while True: - out = ws.recv() - if isinstance(out, str): - message = json.loads(out) - if message['type'] == 'executing': - data = message['data'] - if data['prompt_id'] == prompt_id: - if data['node'] is None: - break #Execution is done - else: - current_node = data['node'] - else: - if current_node == 'save_image_websocket_node': - images_output = output_images.get(current_node, []) - images_output.append(out[8:]) - output_images[current_node] = images_output - - return output_images - -prompt_text = """ -{ - "3": { - "class_type": "KSampler", - "inputs": { - "cfg": 8, - "denoise": 1, - "latent_image": [ - "5", - 0 - ], - "model": [ - "4", - 0 - ], - "negative": [ - "7", - 0 - ], - "positive": [ - "6", - 0 - ], - "sampler_name": "euler", - "scheduler": "normal", - "seed": 8566257, - "steps": 20 - } - }, - "4": { - "class_type": "CheckpointLoaderSimple", - "inputs": { - "ckpt_name": "v1-5-pruned-emaonly.ckpt" - } - }, - "5": { - "class_type": "EmptyLatentImage", - "inputs": { - "batch_size": 1, - "height": 512, - "width": 512 - } - }, - "6": { - "class_type": "CLIPTextEncode", - "inputs": { - "clip": [ - "4", - 1 - ], - "text": "masterpiece best quality girl" - } - }, - "7": { - "class_type": "CLIPTextEncode", - "inputs": { - "clip": [ - "4", - 1 - ], - "text": "bad hands" - } - }, - "8": { - "class_type": "VAEDecode", - "inputs": { - "samples": [ - "3", - 0 - ], - "vae": [ - "4", - 2 - ] - } - }, - "save_image_websocket_node": { - "class_type": "SaveImageWebsocket", - "inputs": { - "images": [ - "8", - 0 - ] - } - } -} -""" - -prompt = json.loads(prompt_text) -#set the text prompt for our positive CLIPTextEncode -prompt["6"]["inputs"]["text"] = "masterpiece best quality man" - -#set the seed for our KSampler node -prompt["3"]["inputs"]["seed"] = 5 - -ws = websocket.WebSocket() -ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id)) -images = get_images(ws, prompt) - -#Commented out code to display the output images: - -# for node_id in images: -# for image_data in images[node_id]: -# from PIL import Image -# import io -# image = Image.open(io.BytesIO(image_data)) -# image.show() -