From f8fcfa6f0879394fea244fc13815ea36c3e1d629 Mon Sep 17 00:00:00 2001
From: doctorpangloss <@hiddenswitch.com>
Date: Tue, 7 May 2024 13:44:34 -0700
Subject: [PATCH] Improve tracing to propagate to backend workers correctly
 when using the API. Fix distributed tests.

---
 comfy/client/embedded_comfy_client.py         | 42 ++++++++------
 comfy/cmd/execution.py                        | 25 ++++++---
 comfy/cmd/server.py                           | 37 +++++++------
 .../component_model/abstract_prompt_queue.py  | 12 ++++-
 comfy/distributed/distributed_prompt_queue.py | 14 ++---
 .../distributed/distributed_prompt_worker.py  |  3 ++
 comfy/nodes/package.py                        | 54 +++++++++++--------
 comfy/utils.py                                | 14 ++---
 .../test_asyncio_remote_client.py             |  0
 tests/conftest.py                             |  2 +-
 tests/distributed/test_distributed_queue.py   | 22 ++++----
 tests/distributed/test_embedded_client.py     |  3 +-
 12 files changed, 134 insertions(+), 94 deletions(-)
 rename tests/{distributed => asyncio}/test_asyncio_remote_client.py (100%)

diff --git a/comfy/client/embedded_comfy_client.py b/comfy/client/embedded_comfy_client.py
index 34e75c04d..a0ffea630 100644
--- a/comfy/client/embedded_comfy_client.py
+++ b/comfy/client/embedded_comfy_client.py
@@ -7,14 +7,19 @@ from asyncio import AbstractEventLoop
 from concurrent.futures import ThreadPoolExecutor
 from typing import Optional

+from opentelemetry import context
+from opentelemetry.trace import Span, Status, StatusCode
+
 from ..api.components.schema.prompt import PromptDict
 from ..cli_args_types import Configuration
-from ..component_model.make_mutable import make_mutable
+from ..cmd.main_pre import tracer
 from ..component_model.executor_types import ExecutorToClientProgress
+from ..component_model.make_mutable import make_mutable
 from ..distributed.server_stub import ServerStub

 _server_stub_instance = ServerStub()

+
 class EmbeddedComfyClient:
     """
     Embedded client for comfy executing prompts as a library.
@@ -110,30 +115,37 @@ class EmbeddedComfyClient:
         await self._loop.run_in_executor(self._executor, create_executor_in_thread)

+    @tracer.start_as_current_span("Queue Prompt")
     async def queue_prompt(self,
                            prompt: PromptDict | dict,
                            prompt_id: Optional[str] = None,
                            client_id: Optional[str] = None) -> dict:
         prompt_id = prompt_id or str(uuid.uuid4())
         client_id = client_id or self._progress_handler.client_id or None
+        span_context = context.get_current()

         def execute_prompt() -> dict:
-            from ..cmd.execution import PromptExecutor, validate_prompt
-            prompt_mut = make_mutable(prompt)
-            validation_tuple = validate_prompt(prompt_mut)
+            span: Span
+            with tracer.start_as_current_span("Execute Prompt", context=span_context) as span:
+                from ..cmd.execution import PromptExecutor, validate_prompt
+                prompt_mut = make_mutable(prompt)
+                validation_tuple = validate_prompt(prompt_mut)

-            prompt_executor: PromptExecutor = self._prompt_executor
+                prompt_executor: PromptExecutor = self._prompt_executor

-            if client_id is None:
-                prompt_executor.server = _server_stub_instance
-            else:
-                prompt_executor.server = self._progress_handler
+                if client_id is None:
+                    prompt_executor.server = _server_stub_instance
+                else:
+                    prompt_executor.server = self._progress_handler

-            prompt_executor.execute(prompt_mut, prompt_id, {"client_id": client_id},
-                                    execute_outputs=validation_tuple[2])
-            if prompt_executor.success:
-                return prompt_executor.outputs_ui
-            else:
-                raise RuntimeError("\n".join(event for (event, data) in self._prompt_executor.status_messages))
+                prompt_executor.execute(prompt_mut, prompt_id, {"client_id": client_id},
+                                        execute_outputs=validation_tuple[2])
+                if prompt_executor.success:
+                    return prompt_executor.outputs_ui
+                else:
+                    span.set_status(Status(StatusCode.ERROR))
+                    error = RuntimeError("\n".join(event for (event, data) in self._prompt_executor.status_messages))
+                    span.record_exception(error)
+                    raise error

         return await self._loop.run_in_executor(self._executor, execute_prompt)

diff --git a/comfy/cmd/execution.py b/comfy/cmd/execution.py
index af76f1103..6c7c0afb6 100644
--- a/comfy/cmd/execution.py
+++ b/comfy/cmd/execution.py
@@ -9,22 +9,26 @@ import threading
 import traceback
 import typing
 from typing import List, Optional, Tuple, Union
-from typing_extensions import TypedDict

-import torch
-import lazy_object_proxy
+import lazy_object_proxy
+import torch
+from opentelemetry.trace import get_current_span, StatusCode, Status
+from typing_extensions import TypedDict
+
+from .main_pre import tracer

 from .. import interruption
-from ..component_model.abstract_prompt_queue import AbstractPromptQueue
-from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus
-from ..component_model.executor_types import ExecutorToClientProgress
 from .. import model_management
-from ..nodes.package_typing import ExportedNodes
+from ..component_model.abstract_prompt_queue import AbstractPromptQueue
+from ..component_model.executor_types import ExecutorToClientProgress
+from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus
 from ..nodes.package import import_all_nodes_in_workspace
+from ..nodes.package_typing import ExportedNodes
 # ideally this would be passed in from main, but the way this is authored, we can't easily pass nodes down to the
 # various functions that are declared here. It should have been a context in the first place.
 nodes: ExportedNodes = lazy_object_proxy.Proxy(import_all_nodes_in_workspace)

+
 def get_input_data(inputs, class_def, unique_id, outputs=None, prompt=None, extra_data=None):
     if extra_data is None:
         extra_data = {}
@@ -139,6 +143,7 @@ def format_value(x):
         return str(x)


+@tracer.start_as_current_span("Recursive Execute")
 def recursive_execute(server: ExecutorToClientProgress,
                       prompt,
                       outputs,
@@ -324,6 +329,10 @@ class PromptExecutor:
             self.server.send_sync(event, data, self.server.client_id)

     def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex):
+        current_span = get_current_span()
+        current_span.set_status(Status(StatusCode.ERROR))
+        current_span.record_exception(ex)
+
         node_id = error["node_id"]
         class_type = prompt[node_id]["class_type"]

@@ -430,7 +439,6 @@ class PromptExecutor:
             # This call shouldn't raise anything if there's an error deep in
             # the actual SD code, instead it will report the node where the
             # error was raised
-            # todo: if we're using a distributed queue, we must wrap the server instance to correctly communicate back to the client via the exchange
             self.success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data,
                                                         executed, prompt_id, self.outputs_ui, self.object_storage)

@@ -689,6 +697,7 @@ class ValidationErrorDict(TypedDict):
 ValidationTuple = typing.Tuple[bool, Optional[ValidationErrorDict], typing.List[str], Union[dict, list]]


+@tracer.start_as_current_span("Validate Prompt")
 def validate_prompt(prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple:
     outputs = set()
     for x in prompt:
diff --git a/comfy/cmd/server.py b/comfy/cmd/server.py
index 79b89bd4a..12f2ba0ab 100644
--- a/comfy/cmd/server.py
+++ b/comfy/cmd/server.py
@@ -13,17 +13,16 @@ import uuid
 from asyncio import Future, AbstractEventLoop
 from enum import Enum
 from io import BytesIO
+from posixpath import join as urljoin
 from typing import List, Optional, Dict
 from urllib.parse import quote, urlencode
-from posixpath import join as urljoin
-
-from can_ada import URL, parse as urlparse

 import aiofiles
 import aiohttp
 from PIL import Image
 from PIL.PngImagePlugin import PngInfo
 from aiohttp import web
+from can_ada import URL, parse as urlparse
 from pkg_resources import resource_filename
 from typing_extensions import NamedTuple

@@ -36,12 +35,13 @@ from ..cli_args import args
 from ..client.client_types import Output, FileOutput
 from ..cmd import execution
 from ..cmd import folder_paths
+from ..component_model.abstract_prompt_queue import AbstractPromptQueue, AsyncAbstractPromptQueue
 from ..component_model.executor_types import ExecutorToClientProgress
 from ..component_model.file_output_path import file_output_path
 from ..component_model.queue_types import QueueItem, HistoryEntry, BinaryEventTypes, TaskInvocation
 from ..digest import digest
-from ..nodes.package_typing import ExportedNodes
 from ..images import open_image
+from ..nodes.package_typing import ExportedNodes


 class HeuristicPath(NamedTuple):
@@ -95,7 +95,7 @@ class PromptServer(ExecutorToClientProgress):
         self.user_manager = UserManager()
         # todo: this is probably read by custom nodes elsewhere
         self.supports: List[str] = ["custom_nodes_from_web"]
-        self.prompt_queue: execution.AbstractPromptQueue | None = None
+        self.prompt_queue: AbstractPromptQueue | AsyncAbstractPromptQueue | None = None
         self.loop: AbstractEventLoop = loop
         self.messages: asyncio.Queue = asyncio.Queue()
         self.number: int = 0
@@ -599,22 +599,27 @@ class PromptServer(ExecutorToClientProgress):
             return web.Response(status=400,
                                 content_type="application/json",
                                 body=json.dumps(valid[1]))

             # convert a valid prompt to the queue tuple this expects
-            completed: Future[TaskInvocation | dict] = self.loop.create_future()
             number = self.number
             self.number += 1
-            self.prompt_queue.put(
-                QueueItem(queue_tuple=(number, str(uuid.uuid4()), prompt_dict, {}, valid[2]),
-                          completed=completed))
-            try:
-                await completed
-            except Exception as ex:
-                return web.Response(body=str(ex), status=503)
-            # expect a single image
-            result: TaskInvocation | dict = completed.result()
+            completed: Future[TaskInvocation | dict] = self.loop.create_future()
+            item = QueueItem(queue_tuple=(number, str(uuid.uuid4()), prompt_dict, {}, valid[2]), completed=completed)
+
+            if hasattr(self.prompt_queue, "put_async") or isinstance(self.prompt_queue, AsyncAbstractPromptQueue):
+                # awaiting put_async directly on the event loop keeps the caller's span current, so the trace propagates to the worker
+                result = await self.prompt_queue.put_async(item)
+                if result is None:
+                    return web.Response(body="the queue is shutting down", status=503)
+            else:
+                try:
+                    self.prompt_queue.put(item)
+                    await completed
+                except Exception as ex:
+                    return web.Response(body=str(ex), status=503)
+                # expect a single image
+                result: TaskInvocation | dict = completed.result()
             outputs_dict: Dict[str, Output] = result.outputs if isinstance(result, TaskInvocation) else result
             # find images and read them
-            output_images: List[FileOutput] = []
             for node_id, node in outputs_dict.items():
                 images: List[FileOutput] = []
diff --git a/comfy/component_model/abstract_prompt_queue.py b/comfy/component_model/abstract_prompt_queue.py
index aac095ca7..7b65d425f 100644
--- a/comfy/component_model/abstract_prompt_queue.py
+++ b/comfy/component_model/abstract_prompt_queue.py
@@ -3,7 +3,7 @@ from __future__ import annotations
 import typing
 from abc import ABCMeta, abstractmethod

-from .queue_types import QueueTuple, HistoryEntry, QueueItem, Flags, ExecutionStatus
+from .queue_types import QueueTuple, HistoryEntry, QueueItem, Flags, ExecutionStatus, TaskInvocation


 class AbstractPromptQueue(metaclass=ABCMeta):
@@ -118,3 +118,13 @@ class AbstractPromptQueue(metaclass=ABCMeta):
         :return:
         """
         pass
+
+
+class AsyncAbstractPromptQueue(AbstractPromptQueue):
+    @abstractmethod
+    async def put_async(self, queue_item: QueueItem) -> TaskInvocation | None:
+        pass
+
+    @abstractmethod
+    async def get_async(self, timeout: float | None = None) -> typing.Optional[typing.Tuple[QueueTuple, str]]:
+        pass
\ No newline at end of file
diff --git a/comfy/distributed/distributed_prompt_queue.py b/comfy/distributed/distributed_prompt_queue.py
index 613c8d867..e4c038332 100644
--- a/comfy/distributed/distributed_prompt_queue.py
+++ b/comfy/distributed/distributed_prompt_queue.py
@@ -18,13 +18,14 @@ from .distributed_types import RpcRequest, RpcReply
 from .history import History
 from .server_stub import ServerStub
 from ..auth.permissions import jwt_decode
+from ..cmd.main_pre import tracer
 from ..cmd.server import PromptServer
-from ..component_model.abstract_prompt_queue import AbstractPromptQueue
+from ..component_model.abstract_prompt_queue import AsyncAbstractPromptQueue
 from ..component_model.executor_types import ExecutorToClientProgress, SendSyncEvent, SendSyncData
 from ..component_model.queue_types import Flags, HistoryEntry, QueueTuple, QueueItem, ExecutionStatus, TaskInvocation


-class DistributedPromptQueue(AbstractPromptQueue):
+class DistributedPromptQueue(AsyncAbstractPromptQueue):
     """
     A distributed prompt queue for the ComfyUI web client and single-threaded worker.
""" @@ -39,12 +40,13 @@ class DistributedPromptQueue(AbstractPromptQueue): async def progress(self, event: SendSyncEvent, data: SendSyncData, sid: Optional[str]) -> None: self._caller_server.send_sync(event, data, sid=sid) - async def put_async(self, queue_item: QueueItem): + @tracer.start_as_current_span("Put Async") + async def put_async(self, queue_item: QueueItem) -> TaskInvocation | None: assert self._is_caller assert self._rpc is not None if self._closing: - return + return None self._caller_local_in_progress[queue_item.prompt_id] = queue_item if self._caller_server is not None: self._caller_server.queue_updated() @@ -88,9 +90,7 @@ class DistributedPromptQueue(AbstractPromptQueue): # if we have a completer, propoagate the exception to it if queue_item.completed is not None: queue_item.completed.set_exception(e) - else: - # otherwise, this should raise in the event loop, which I suppose isn't handled - raise e + raise e finally: self._caller_local_in_progress.pop(queue_item.prompt_id) if self._caller_server is not None: diff --git a/comfy/distributed/distributed_prompt_worker.py b/comfy/distributed/distributed_prompt_worker.py index 9969e6c61..a17dcb9d9 100644 --- a/comfy/distributed/distributed_prompt_worker.py +++ b/comfy/distributed/distributed_prompt_worker.py @@ -12,6 +12,7 @@ from aiormq import AMQPConnectionError from .distributed_progress import DistributedExecutorToClientProgress from .distributed_types import RpcRequest, RpcReply from ..client.embedded_comfy_client import EmbeddedComfyClient +from ..cmd.main_pre import tracer from ..component_model.queue_types import ExecutionStatus @@ -32,6 +33,7 @@ class DistributedPromptWorker: self._loop = loop or asyncio.get_event_loop() self._embedded_comfy_client = embedded_comfy_client + @tracer.start_as_current_span("Do Work Item") async def _do_work_item(self, request: dict) -> dict: await self.on_will_complete_work_item(request) try: @@ -55,6 +57,7 @@ class DistributedPromptWorker: await self.on_did_complete_work_item(request_obj, reply) return asdict(reply) + @tracer.start_as_current_span("Initialize Prompt Worker") async def init(self): await self._exit_stack.__aenter__() try: diff --git a/comfy/nodes/package.py b/comfy/nodes/package.py index 9a8f7e7f7..bc0841f61 100644 --- a/comfy/nodes/package.py +++ b/comfy/nodes/package.py @@ -9,8 +9,11 @@ import types from functools import reduce from importlib.metadata import entry_points +from opentelemetry.trace import Span, Status, StatusCode from pkg_resources import resource_filename + from .package_typing import ExportedNodes +from ..cmd.main_pre import tracer _comfy_nodes: ExportedNodes = ExportedNodes() @@ -46,29 +49,33 @@ def _import_and_enumerate_nodes_in_module(module: types.ModuleType, else: # Iterate through all the submodules for _, name, is_pkg in pkgutil.iter_modules(module.__path__): - full_name = module.__name__ + "." 
-            time_before = time.perf_counter()
-            success = True
-
-            if full_name.endswith(".disabled"):
-                continue
-            try:
-                submodule = importlib.import_module(full_name)
-                # Recursively call the function if it's a package
-                exported_nodes.update(
-                    _import_and_enumerate_nodes_in_module(submodule, print_import_times=print_import_times,
-                                                          depth=depth - 1))
-            except KeyboardInterrupt as interrupted:
-                raise interrupted
-            except Exception as x:
-                if isinstance(x, AttributeError):
-                    potential_path_error: AttributeError = x
-                    if potential_path_error.name == '__path__':
-                        continue
-                logging.error(f"{full_name} import failed", exc_info=x)
-                success = False
-                exceptions.append(x)
-            timings.append((time.perf_counter() - time_before, full_name, success))
+            span: Span
+            with tracer.start_as_current_span("Load Node") as span:
+                full_name = module.__name__ + "." + name
+                time_before = time.perf_counter()
+                success = True
+                span.set_attribute("full_name", full_name)
+                if full_name.endswith(".disabled"):
+                    continue
+                try:
+                    submodule = importlib.import_module(full_name)
+                    # Recursively call the function if it's a package
+                    new_nodes = _import_and_enumerate_nodes_in_module(submodule, print_import_times=print_import_times, depth=depth - 1)
+                    span.set_attribute("new_nodes.length", len(new_nodes))
+                    exported_nodes.update(new_nodes)
+                except KeyboardInterrupt as interrupted:
+                    raise interrupted
+                except Exception as x:
+                    if isinstance(x, AttributeError):
+                        potential_path_error: AttributeError = x
+                        if potential_path_error.name == '__path__':
+                            continue
+                    logging.error(f"{full_name} import failed", exc_info=x)
+                    success = False
+                    exceptions.append(x)
+                    span.set_status(Status(StatusCode.ERROR))
+                    span.record_exception(x)
+                timings.append((time.perf_counter() - time_before, full_name, success))

     if print_import_times and len(timings) > 0 or any(not success for (_, _, success) in timings):
         for (duration, module_name, success) in sorted(timings):
@@ -81,6 +88,7 @@ def _import_and_enumerate_nodes_in_module(module: types.ModuleType,
     return exported_nodes


+@tracer.start_as_current_span("Import All Nodes In Workspace")
 def import_all_nodes_in_workspace(vanilla_custom_nodes=True, raise_on_failure=False) -> ExportedNodes:
     # now actually import the nodes, to improve control of node loading order
     from comfy_extras import nodes as comfy_extras_nodes
diff --git a/comfy/utils.py b/comfy/utils.py
index acfaa7088..f26b8b834 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -1,4 +1,5 @@
 import os.path
+import threading
 from contextlib import contextmanager

 import torch
@@ -17,7 +18,7 @@ from .component_model.executor_types import ExecutorToClientProgress
 from .component_model.queue_types import BinaryEventTypes

 PROGRESS_BAR_ENABLED = True
-PROGRESS_BAR_HOOK = None
+_progress_bar_hook = threading.local()


@@ -479,7 +480,7 @@ def hijack_progress(server: ExecutorToClientProgress):
             if preview_image is not None:
                 server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id)

-    set_progress_bar_global_hook(hook)
+    _progress_bar_hook.hook = hook


 def set_progress_bar_enabled(enabled):
@@ -491,11 +492,6 @@ def get_progress_bar_enabled() -> bool:
     return PROGRESS_BAR_ENABLED


-def set_progress_bar_global_hook(function):
-    global PROGRESS_BAR_HOOK
-    PROGRESS_BAR_HOOK = function
-
-
 class _DisabledProgressBar:
     def __init__(self, *args, **kwargs):
         pass
@@ -509,10 +505,10 @@ class _DisabledProgressBar:

 class ProgressBar:
     def __init__(self, total: float):
-        global PROGRESS_BAR_HOOK
+        global _progress_bar_hook
         self.total: float = total
         self.current: float = 0.0
-        self.hook = PROGRESS_BAR_HOOK
+        self.hook = _progress_bar_hook.hook if hasattr(_progress_bar_hook, "hook") else None

     def update_absolute(self, value, total=None, preview=None):
         if total is not None:
diff --git a/tests/distributed/test_asyncio_remote_client.py b/tests/asyncio/test_asyncio_remote_client.py
similarity index 100%
rename from tests/distributed/test_asyncio_remote_client.py
rename to tests/asyncio/test_asyncio_remote_client.py
diff --git a/tests/conftest.py b/tests/conftest.py
index d407dc474..20e41497e 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -64,7 +64,7 @@ def comfy_background_server(args_pytest):
     if not success:
         raise Exception("Failed to start background server")
     yield
-    p.kill()
+    p.terminate()
     torch.cuda.empty_cache()


diff --git a/tests/distributed/test_distributed_queue.py b/tests/distributed/test_distributed_queue.py
index 1251cddf0..b7e76d98b 100644
--- a/tests/distributed/test_distributed_queue.py
+++ b/tests/distributed/test_distributed_queue.py
@@ -1,27 +1,26 @@
 import asyncio
 import logging
 import os
+import socket
+import subprocess
 import sys
 import time
 import uuid
-import subprocess
-import socket
-
 from concurrent.futures import ThreadPoolExecutor
 from typing import List

 import jwt
 import pytest
 import requests
+from testcontainers.rabbitmq import RabbitMqContainer

 from comfy.client.aio_client import AsyncRemoteComfyClient
 from comfy.client.embedded_comfy_client import EmbeddedComfyClient
-from comfy.distributed.server_stub import ServerStub
 from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner
 from comfy.component_model.make_mutable import make_mutable
 from comfy.component_model.queue_types import QueueItem, QueueTuple, TaskInvocation, NamedQueueTuple, ExecutionStatus
 from comfy.distributed.distributed_prompt_worker import DistributedPromptWorker
-from testcontainers.rabbitmq import RabbitMqContainer
+from comfy.distributed.server_stub import ServerStub

 # fixes issues with running the testcontainers rabbitmqcontainer on Windows
 os.environ["TC_HOST"] = "localhost"
@@ -55,8 +54,7 @@ async def test_basic_queue_worker() -> None:
     # this unfortunately does a bunch of initialization on the test thread
     from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
     # now submit some jobs
-    distributed_queue = DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True,
-                                               connection_uri=f"amqp://guest:guest@127.0.0.1:{params.port}")
+    distributed_queue = DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=f"amqp://guest:guest@127.0.0.1:{params.port}")
     await distributed_queue.init()
     queue_item = create_test_prompt()
     res: TaskInvocation = await distributed_queue.put_async(queue_item)
@@ -74,10 +72,8 @@ async def test_distributed_prompt_queues_same_process():
     connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"

     from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
-    async with DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True,
-                                      connection_uri=connection_uri) as frontend:
-        async with DistributedPromptQueue(ServerStub(), is_callee=True, is_caller=False,
-                                          connection_uri=f"amqp://guest:guest@127.0.0.1:{params.port}") as worker:
+    async with DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=connection_uri) as frontend:
+        async with DistributedPromptQueue(ServerStub(), is_callee=True, is_caller=False, connection_uri=connection_uri) as worker:
             test_prompt = create_test_prompt()
             test_prompt.completed = asyncio.Future()

@@ -117,6 +113,7 @@ async def test_frontend_backend_workers():
     frontend_command = [
         "comfyui",
         "--listen=0.0.0.0",
+        "--port=9001",
         "--cpu",
         "--distributed-queue-frontend",
         f"--distributed-queue-connection-uri={connection_uri}",
@@ -125,6 +122,7 @@
     processes_to_close.append(subprocess.Popen(frontend_command, stdout=sys.stdout, stderr=sys.stderr))
     backend_command = [
         "comfyui-worker",
+        "--port=9002",
         f"--distributed-queue-connection-uri={connection_uri}",
     ]
@@ -145,7 +143,7 @@ async def test_frontend_backend_workers():
         client = AsyncRemoteComfyClient(server_address=server_address)
         prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
         png_image_bytes = await client.queue_prompt(prompt)
-        assert len(png_image_bytes) > 1000
+        assert len(png_image_bytes) > 1000, "expected an image, but got nothing"
     finally:
         for process in processes_to_close:
             process.terminate()
diff --git a/tests/distributed/test_embedded_client.py b/tests/distributed/test_embedded_client.py
index fb43912e9..6be3b52d1 100644
--- a/tests/distributed/test_embedded_client.py
+++ b/tests/distributed/test_embedded_client.py
@@ -24,8 +24,7 @@ async def test_cuda_memory_usage():
         assert memory_after_workflow > starting_memory, "Expected CUDA memory to increase after running the workflow"

     ending_memory = torch.cuda.memory_allocated(device)
-    assert abs(
-        ending_memory - starting_memory) < 1e7, "Expected CUDA memory to return close to starting memory after cleanup"
+    assert abs(ending_memory - starting_memory) < 1e7, "Expected CUDA memory to return close to starting memory after cleanup"


 @pytest.mark.asyncio
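What follows is a minimal, self-contained sketch (not part of the patch) of the span-propagation pattern that queue_prompt now uses: capture the OpenTelemetry context on the event loop thread, then re-activate it inside the thread pool so the executor-side span parents under the caller's span instead of starting a new trace. The provider and exporter setup and all names below are illustrative assumptions for a standalone demo; in this repository the real tracer is created in comfy.cmd.main_pre. It assumes only that the opentelemetry-sdk package is installed.

import asyncio
from concurrent.futures import ThreadPoolExecutor

from opentelemetry import context, trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import ConsoleSpanExporter, SimpleSpanProcessor

# Illustrative provider setup only; in this repository the tracer comes from comfy.cmd.main_pre.
provider = TracerProvider()
provider.add_span_processor(SimpleSpanProcessor(ConsoleSpanExporter()))
trace.set_tracer_provider(provider)
tracer = trace.get_tracer(__name__)


async def main() -> None:
    loop = asyncio.get_running_loop()
    with tracer.start_as_current_span("Queue Prompt"):
        # capture the context while "Queue Prompt" is the current span ...
        span_context = context.get_current()

        def execute() -> str:
            # ... and re-activate it on the worker thread; without context=span_context
            # this span would start a brand-new trace, because threads spawned by the
            # executor carry no ambient OpenTelemetry context
            with tracer.start_as_current_span("Execute Prompt", context=span_context):
                return "ok"

        with ThreadPoolExecutor(max_workers=1) as executor:
            assert await loop.run_in_executor(executor, execute) == "ok"


asyncio.run(main())

When run, the console exporter prints both spans with the same trace_id, and "Execute Prompt" carries the span_id of "Queue Prompt" as its parent; that parent/child relationship is what the @tracer.start_as_current_span decorators and the explicit context= argument in this patch establish across the thread and queue boundaries.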