Improve tracing so spans propagate correctly to backend workers when prompts are queued through the API. Fix the distributed tests.

This commit is contained in:
doctorpangloss 2024-05-07 13:44:34 -07:00
parent 75b63fce91
commit f8fcfa6f08
12 changed files with 133 additions and 94 deletions
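The gist of the tracing change is to carry the caller's OpenTelemetry context along with queued work so that worker-side spans join the same trace. A minimal sketch of the idea, assuming the queued message can carry a plain dict of propagation headers (the "headers" field and the function names here are illustrative stand-ins, not the repository's actual RpcRequest layout):

from opentelemetry import trace
from opentelemetry.propagate import extract, inject

tracer = trace.get_tracer(__name__)

def enqueue_work_item(payload: dict) -> dict:
    # Caller side: serialize the current span context into the message.
    carrier: dict = {}
    inject(carrier)  # writes W3C traceparent/tracestate keys into the dict
    return {"payload": payload, "headers": carrier}  # "headers" is an assumed field name

def handle_work_item(message: dict) -> None:
    # Worker side: restore the caller's context so new spans share its trace.
    ctx = extract(message.get("headers", {}))
    with tracer.start_as_current_span("Do Work Item", context=ctx):
        ...  # execute the prompt here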

View File

@@ -7,14 +7,19 @@ from asyncio import AbstractEventLoop
from concurrent.futures import ThreadPoolExecutor
from typing import Optional
from opentelemetry import context
from opentelemetry.trace import Span, Status, StatusCode
from ..api.components.schema.prompt import PromptDict
from ..cli_args_types import Configuration
from ..component_model.make_mutable import make_mutable
from ..cmd.main_pre import tracer
from ..component_model.executor_types import ExecutorToClientProgress
from ..component_model.make_mutable import make_mutable
from ..distributed.server_stub import ServerStub
_server_stub_instance = ServerStub()
class EmbeddedComfyClient:
"""
Embedded client for executing ComfyUI prompts as a library.
@@ -110,30 +115,36 @@ class EmbeddedComfyClient:
await self._loop.run_in_executor(self._executor, create_executor_in_thread)
@tracer.start_as_current_span("Queue Prompt")
async def queue_prompt(self,
prompt: PromptDict | dict,
prompt_id: Optional[str] = None,
client_id: Optional[str] = None) -> dict:
prompt_id = prompt_id or str(uuid.uuid4())
client_id = client_id or self._progress_handler.client_id or None
span_context = context.get_current()
def execute_prompt() -> dict:
from ..cmd.execution import PromptExecutor, validate_prompt
prompt_mut = make_mutable(prompt)
validation_tuple = validate_prompt(prompt_mut)
span: Span
with tracer.start_as_current_span("Execute Prompt", context=span_context) as span:
from ..cmd.execution import PromptExecutor, validate_prompt
prompt_mut = make_mutable(prompt)
validation_tuple = validate_prompt(prompt_mut)
prompt_executor: PromptExecutor = self._prompt_executor
prompt_executor: PromptExecutor = self._prompt_executor
if client_id is None:
prompt_executor.server = _server_stub_instance
else:
prompt_executor.server = self._progress_handler
if client_id is None:
prompt_executor.server = _server_stub_instance
else:
prompt_executor.server = self._progress_handler
prompt_executor.execute(prompt_mut, prompt_id, {"client_id": client_id},
execute_outputs=validation_tuple[2])
if prompt_executor.success:
return prompt_executor.outputs_ui
else:
raise RuntimeError("\n".join(event for (event, data) in self._prompt_executor.status_messages))
prompt_executor.execute(prompt_mut, prompt_id, {"client_id": client_id},
execute_outputs=validation_tuple[2])
if prompt_executor.success:
return prompt_executor.outputs_ui
else:
span.set_status(Status(StatusCode.ERROR))
error = RuntimeError("\n".join(event for (event, data) in self._prompt_executor.status_messages))
span.record_exception(error)
return await self._loop.run_in_executor(self._executor, execute_prompt)
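Note how queue_prompt snapshots context.get_current() on the event loop and hands it to start_as_current_span inside the thread pool job; that is what keeps the "Execute Prompt" span parented under "Queue Prompt" across the thread hop. The same pattern in isolation (only the OpenTelemetry calls are real API, the rest is illustrative):

import asyncio
from concurrent.futures import ThreadPoolExecutor
from opentelemetry import context, trace

tracer = trace.get_tracer(__name__)
executor = ThreadPoolExecutor(max_workers=1)

async def run_traced_job() -> None:
    with tracer.start_as_current_span("Queue Prompt"):
        parent_ctx = context.get_current()  # snapshot taken on the event-loop thread

        def job() -> None:
            # Reattach the snapshot so this span joins the caller's trace.
            with tracer.start_as_current_span("Execute Prompt", context=parent_ctx):
                ...  # blocking work runs here

        await asyncio.get_running_loop().run_in_executor(executor, job)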

View File

@@ -9,22 +9,26 @@ import threading
import traceback
import typing
from typing import List, Optional, Tuple, Union
from typing_extensions import TypedDict
import torch
import lazy_object_proxy
import lazy_object_proxy
import torch
from opentelemetry.trace import get_current_span, StatusCode, Status
from typing_extensions import TypedDict
from .main_pre import tracer
from .. import interruption
from ..component_model.abstract_prompt_queue import AbstractPromptQueue
from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus
from ..component_model.executor_types import ExecutorToClientProgress
from .. import model_management
from ..nodes.package_typing import ExportedNodes
from ..component_model.abstract_prompt_queue import AbstractPromptQueue
from ..component_model.executor_types import ExecutorToClientProgress
from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus
from ..nodes.package import import_all_nodes_in_workspace
from ..nodes.package_typing import ExportedNodes
# ideally this would be passed in from main, but the way this is authored, we can't easily pass nodes down to the
# various functions that are declared here. It should have been a context in the first place.
nodes: ExportedNodes = lazy_object_proxy.Proxy(import_all_nodes_in_workspace)
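As the comment above concedes, the node registry arguably belongs in a context rather than a module-level lazy proxy. A rough sketch of what a contextvars-based alternative could look like (hypothetical helpers, not what this commit does; the imports mirror the ones already in this file):

from contextvars import ContextVar

from ..nodes.package import import_all_nodes_in_workspace
from ..nodes.package_typing import ExportedNodes

# Hypothetical: callers install a registry for the current execution context,
# and the expensive workspace scan is only the fallback.
_nodes_var: ContextVar[ExportedNodes] = ContextVar("comfy_nodes")

def set_nodes(registry: ExportedNodes) -> None:
    _nodes_var.set(registry)

def get_nodes() -> ExportedNodes:
    try:
        return _nodes_var.get()
    except LookupError:
        return import_all_nodes_in_workspace()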
def get_input_data(inputs, class_def, unique_id, outputs=None, prompt=None, extra_data=None):
if extra_data is None:
extra_data = {}
@@ -139,6 +143,7 @@ def format_value(x):
return str(x)
@tracer.start_as_current_span("Recursive Execute")
def recursive_execute(server: ExecutorToClientProgress,
prompt,
outputs,
@@ -324,6 +329,10 @@ class PromptExecutor:
self.server.send_sync(event, data, self.server.client_id)
def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex):
current_span = get_current_span()
current_span.set_status(Status(StatusCode.ERROR))
current_span.record_exception(ex)
node_id = error["node_id"]
class_type = prompt[node_id]["class_type"]
@@ -430,7 +439,6 @@ class PromptExecutor:
# This call shouldn't raise anything if there's an error deep in
# the actual SD code, instead it will report the node where the
# error was raised
# todo: if we're using a distributed queue, we must wrap the server instance to correctly communicate back to the client via the exchange
self.success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id,
extra_data, executed, prompt_id, self.outputs_ui,
self.object_storage)
@@ -689,6 +697,7 @@ class ValidationErrorDict(TypedDict):
ValidationTuple = typing.Tuple[bool, Optional[ValidationErrorDict], typing.List[str], Union[dict, list]]
@tracer.start_as_current_span("Validate Prompt")
def validate_prompt(prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple:
outputs = set()
for x in prompt:

View File

@@ -13,17 +13,16 @@ import uuid
from asyncio import Future, AbstractEventLoop
from enum import Enum
from io import BytesIO
from posixpath import join as urljoin
from typing import List, Optional, Dict
from urllib.parse import quote, urlencode
from posixpath import join as urljoin
from can_ada import URL, parse as urlparse
import aiofiles
import aiohttp
from PIL import Image
from PIL.PngImagePlugin import PngInfo
from aiohttp import web
from can_ada import URL, parse as urlparse
from pkg_resources import resource_filename
from typing_extensions import NamedTuple
@@ -36,12 +35,13 @@ from ..cli_args import args
from ..client.client_types import Output, FileOutput
from ..cmd import execution
from ..cmd import folder_paths
from ..component_model.abstract_prompt_queue import AbstractPromptQueue, AsyncAbstractPromptQueue
from ..component_model.executor_types import ExecutorToClientProgress
from ..component_model.file_output_path import file_output_path
from ..component_model.queue_types import QueueItem, HistoryEntry, BinaryEventTypes, TaskInvocation
from ..digest import digest
from ..nodes.package_typing import ExportedNodes
from ..images import open_image
from ..nodes.package_typing import ExportedNodes
class HeuristicPath(NamedTuple):
@@ -95,7 +95,7 @@ class PromptServer(ExecutorToClientProgress):
self.user_manager = UserManager()
# todo: this is probably read by custom nodes elsewhere
self.supports: List[str] = ["custom_nodes_from_web"]
self.prompt_queue: execution.AbstractPromptQueue | None = None
self.prompt_queue: AbstractPromptQueue | AsyncAbstractPromptQueue | None = None
self.loop: AbstractEventLoop = loop
self.messages: asyncio.Queue = asyncio.Queue()
self.number: int = 0
@@ -599,22 +599,27 @@ class PromptServer(ExecutorToClientProgress):
return web.Response(status=400, content_type="application/json", body=json.dumps(valid[1]))
# convert a valid prompt to the queue tuple this expects
completed: Future[TaskInvocation | dict] = self.loop.create_future()
number = self.number
self.number += 1
self.prompt_queue.put(
QueueItem(queue_tuple=(number, str(uuid.uuid4()), prompt_dict, {}, valid[2]),
completed=completed))
try:
await completed
except Exception as ex:
return web.Response(body=str(ex), status=503)
# expect a single image
result: TaskInvocation | dict = completed.result()
completed: Future[TaskInvocation | dict] = self.loop.create_future()
item = QueueItem(queue_tuple=(number, str(uuid.uuid4()), prompt_dict, {}, valid[2]), completed=completed)
if hasattr(self.prompt_queue, "put_async") or isinstance(self.prompt_queue, AsyncAbstractPromptQueue):
# this enables span propagation seamlessly
result = await self.prompt_queue.put_async(item)
if result is None:
return web.Response(body="the queue is shutting down", status=503)
else:
try:
self.prompt_queue.put(item)
await completed
except Exception as ex:
return web.Response(body=str(ex), status=503)
# expect a single image
result: TaskInvocation | dict = completed.result()
outputs_dict: Dict[str, Output] = result.outputs if isinstance(result, TaskInvocation) else result
# find images and read them
output_images: List[FileOutput] = []
for node_id, node in outputs_dict.items():
images: List[FileOutput] = []

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
import typing
from abc import ABCMeta, abstractmethod
from .queue_types import QueueTuple, HistoryEntry, QueueItem, Flags, ExecutionStatus
from .queue_types import QueueTuple, HistoryEntry, QueueItem, Flags, ExecutionStatus, TaskInvocation
class AbstractPromptQueue(metaclass=ABCMeta):
@@ -118,3 +118,13 @@ class AbstractPromptQueue(metaclass=ABCMeta):
:return:
"""
pass
class AsyncAbstractPromptQueue(AbstractPromptQueue):
@abstractmethod
async def put_async(self, queue_item) -> TaskInvocation | None:
pass
@abstractmethod
async def get_async(self, timeout: float | None = None) -> typing.Optional[typing.Tuple[QueueTuple, str]]:
pass
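For a sense of the contract AsyncAbstractPromptQueue adds, here is an illustrative in-memory stand-in for just the two new async members; it is not part of the codebase, and a real subclass still has to implement the synchronous AbstractPromptQueue interface (put and the rest) used elsewhere in this diff:

import asyncio
import typing

class InMemoryAsyncPromptQueue:
    def __init__(self) -> None:
        self._items: asyncio.Queue = asyncio.Queue()

    async def put_async(self, queue_item) -> typing.Optional[dict]:
        # Resolves when a consumer reports the outputs; returning None would
        # signal shutdown, mirroring how server.py treats a None result.
        future: asyncio.Future = asyncio.get_running_loop().create_future()
        await self._items.put((queue_item, future))
        return await future

    async def get_async(self, timeout: typing.Optional[float] = None):
        try:
            return await asyncio.wait_for(self._items.get(), timeout)
        except asyncio.TimeoutError:
            return None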

View File

@@ -18,13 +18,14 @@ from .distributed_types import RpcRequest, RpcReply
from .history import History
from .server_stub import ServerStub
from ..auth.permissions import jwt_decode
from ..cmd.main_pre import tracer
from ..cmd.server import PromptServer
from ..component_model.abstract_prompt_queue import AbstractPromptQueue
from ..component_model.abstract_prompt_queue import AsyncAbstractPromptQueue
from ..component_model.executor_types import ExecutorToClientProgress, SendSyncEvent, SendSyncData
from ..component_model.queue_types import Flags, HistoryEntry, QueueTuple, QueueItem, ExecutionStatus, TaskInvocation
class DistributedPromptQueue(AbstractPromptQueue):
class DistributedPromptQueue(AsyncAbstractPromptQueue):
"""
A distributed prompt queue for the ComfyUI web client and single-threaded worker.
"""
@@ -39,12 +40,13 @@ class DistributedPromptQueue(AbstractPromptQueue):
async def progress(self, event: SendSyncEvent, data: SendSyncData, sid: Optional[str]) -> None:
self._caller_server.send_sync(event, data, sid=sid)
async def put_async(self, queue_item: QueueItem):
@tracer.start_as_current_span("Put Async")
async def put_async(self, queue_item: QueueItem) -> TaskInvocation | None:
assert self._is_caller
assert self._rpc is not None
if self._closing:
return
return None
self._caller_local_in_progress[queue_item.prompt_id] = queue_item
if self._caller_server is not None:
self._caller_server.queue_updated()
@@ -88,9 +90,7 @@ class DistributedPromptQueue(AbstractPromptQueue):
# if we have a completer, propagate the exception to it
if queue_item.completed is not None:
queue_item.completed.set_exception(e)
else:
# otherwise, this should raise in the event loop, which I suppose isn't handled
raise e
raise e
finally:
self._caller_local_in_progress.pop(queue_item.prompt_id)
if self._caller_server is not None:

View File

@@ -12,6 +12,7 @@ from aiormq import AMQPConnectionError
from .distributed_progress import DistributedExecutorToClientProgress
from .distributed_types import RpcRequest, RpcReply
from ..client.embedded_comfy_client import EmbeddedComfyClient
from ..cmd.main_pre import tracer
from ..component_model.queue_types import ExecutionStatus
@@ -32,6 +33,7 @@ class DistributedPromptWorker:
self._loop = loop or asyncio.get_event_loop()
self._embedded_comfy_client = embedded_comfy_client
@tracer.start_as_current_span("Do Work Item")
async def _do_work_item(self, request: dict) -> dict:
await self.on_will_complete_work_item(request)
try:
@@ -55,6 +57,7 @@ class DistributedPromptWorker:
await self.on_did_complete_work_item(request_obj, reply)
return asdict(reply)
@tracer.start_as_current_span("Initialize Prompt Worker")
async def init(self):
await self._exit_stack.__aenter__()
try:

View File

@@ -9,8 +9,11 @@ import types
from functools import reduce
from importlib.metadata import entry_points
from opentelemetry.trace import Span, Status, StatusCode
from pkg_resources import resource_filename
from .package_typing import ExportedNodes
from ..cmd.main_pre import tracer
_comfy_nodes: ExportedNodes = ExportedNodes()
@@ -46,29 +49,33 @@ def _import_and_enumerate_nodes_in_module(module: types.ModuleType,
else:
# Iterate through all the submodules
for _, name, is_pkg in pkgutil.iter_modules(module.__path__):
full_name = module.__name__ + "." + name
time_before = time.perf_counter()
success = True
if full_name.endswith(".disabled"):
continue
try:
submodule = importlib.import_module(full_name)
# Recursively call the function if it's a package
exported_nodes.update(
_import_and_enumerate_nodes_in_module(submodule, print_import_times=print_import_times,
depth=depth - 1))
except KeyboardInterrupt as interrupted:
raise interrupted
except Exception as x:
if isinstance(x, AttributeError):
potential_path_error: AttributeError = x
if potential_path_error.name == '__path__':
continue
logging.error(f"{full_name} import failed", exc_info=x)
success = False
exceptions.append(x)
timings.append((time.perf_counter() - time_before, full_name, success))
span: Span
with tracer.start_as_current_span("Load Node") as span:
full_name = module.__name__ + "." + name
time_before = time.perf_counter()
success = True
span.set_attribute("full_name", full_name)
if full_name.endswith(".disabled"):
continue
try:
submodule = importlib.import_module(full_name)
# Recursively call the function if it's a package
new_nodes = _import_and_enumerate_nodes_in_module(submodule, print_import_times=print_import_times, depth=depth - 1)
span.set_attribute("new_nodes.length", len(new_nodes))
exported_nodes.update(new_nodes)
except KeyboardInterrupt as interrupted:
raise interrupted
except Exception as x:
if isinstance(x, AttributeError):
potential_path_error: AttributeError = x
if potential_path_error.name == '__path__':
continue
logging.error(f"{full_name} import failed", exc_info=x)
success = False
exceptions.append(x)
span.set_status(Status(StatusCode.ERROR))
span.record_exception(x)
timings.append((time.perf_counter() - time_before, full_name, success))
if print_import_times and len(timings) > 0 or any(not success for (_, _, success) in timings):
for (duration, module_name, success) in sorted(timings):
@@ -81,6 +88,7 @@ def _import_and_enumerate_nodes_in_module(module: types.ModuleType,
return exported_nodes
@tracer.start_as_current_span("Import All Nodes In Workspace")
def import_all_nodes_in_workspace(vanilla_custom_nodes=True, raise_on_failure=False) -> ExportedNodes:
# now actually import the nodes, to improve control of node loading order
from comfy_extras import nodes as comfy_extras_nodes

View File

@@ -1,4 +1,5 @@
import os.path
import threading
from contextlib import contextmanager
import torch
@@ -17,7 +18,7 @@ from .component_model.executor_types import ExecutorToClientProgress
from .component_model.queue_types import BinaryEventTypes
PROGRESS_BAR_ENABLED = True
PROGRESS_BAR_HOOK = None
_progress_bar_hook = threading.local()
@@ -479,7 +480,7 @@ def hijack_progress(server: ExecutorToClientProgress):
if preview_image is not None:
server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id)
set_progress_bar_global_hook(hook)
_progress_bar_hook.hook = hook
def set_progress_bar_enabled(enabled):
@@ -491,11 +492,6 @@ def get_progress_bar_enabled() -> bool:
return PROGRESS_BAR_ENABLED
def set_progress_bar_global_hook(function):
global PROGRESS_BAR_HOOK
PROGRESS_BAR_HOOK = function
class _DisabledProgressBar:
def __init__(self, *args, **kwargs):
pass
@@ -509,10 +505,10 @@ class _DisabledProgressBar:
class ProgressBar:
def __init__(self, total: float):
global PROGRESS_BAR_HOOK
global _progress_bar_hook
self.total: float = total
self.current: float = 0.0
self.hook = PROGRESS_BAR_HOOK
self.hook = _progress_bar_hook.hook if hasattr(_progress_bar_hook, "hook") else None
def update_absolute(self, value, total=None, preview=None):
if total is not None:
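The switch from the module-global PROGRESS_BAR_HOOK to a threading.local() presumably lets each executor thread keep its own progress callback, so concurrent embedded clients in one process do not report through each other's hooks. A self-contained sketch of the same pattern (the function names are made up):

import threading

_progress_hook = threading.local()

def set_progress_hook(hook) -> None:
    # Each worker thread registers its own reporter.
    _progress_hook.hook = hook

def report_progress(value: float, total: float) -> None:
    hook = getattr(_progress_hook, "hook", None)
    if hook is not None:
        hook(value, total)

def worker(name: str) -> None:
    set_progress_hook(lambda value, total: print(f"{name}: {value}/{total}"))
    report_progress(1, 4)

# Two threads report through different hooks without clobbering each other.
threading.Thread(target=worker, args=("a",)).start()
threading.Thread(target=worker, args=("b",)).start()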

View File

@@ -64,7 +64,7 @@ def comfy_background_server(args_pytest):
if not success:
raise Exception("Failed to start background server")
yield
p.kill()
p.terminate()
torch.cuda.empty_cache()

View File

@@ -1,27 +1,26 @@
import asyncio
import logging
import os
import socket
import subprocess
import sys
import time
import uuid
import subprocess
import socket
from concurrent.futures import ThreadPoolExecutor
from typing import List
import jwt
import pytest
import requests
from testcontainers.rabbitmq import RabbitMqContainer
from comfy.client.aio_client import AsyncRemoteComfyClient
from comfy.client.embedded_comfy_client import EmbeddedComfyClient
from comfy.distributed.server_stub import ServerStub
from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner
from comfy.component_model.make_mutable import make_mutable
from comfy.component_model.queue_types import QueueItem, QueueTuple, TaskInvocation, NamedQueueTuple, ExecutionStatus
from comfy.distributed.distributed_prompt_worker import DistributedPromptWorker
from testcontainers.rabbitmq import RabbitMqContainer
from comfy.distributed.server_stub import ServerStub
# fixes issues with running the testcontainers rabbitmqcontainer on Windows
os.environ["TC_HOST"] = "localhost"
@@ -55,8 +54,7 @@ async def test_basic_queue_worker() -> None:
# this unfortunately does a bunch of initialization on the test thread
from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
# now submit some jobs
distributed_queue = DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True,
connection_uri=f"amqp://guest:guest@127.0.0.1:{params.port}")
distributed_queue = DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=f"amqp://guest:guest@127.0.0.1:{params.port}")
await distributed_queue.init()
queue_item = create_test_prompt()
res: TaskInvocation = await distributed_queue.put_async(queue_item)
@@ -74,10 +72,8 @@ async def test_distributed_prompt_queues_same_process():
connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"
from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
async with DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True,
connection_uri=connection_uri) as frontend:
async with DistributedPromptQueue(ServerStub(), is_callee=True, is_caller=False,
connection_uri=f"amqp://guest:guest@127.0.0.1:{params.port}") as worker:
async with DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=connection_uri) as frontend:
async with DistributedPromptQueue(ServerStub(), is_callee=True, is_caller=False, connection_uri=connection_uri) as worker:
test_prompt = create_test_prompt()
test_prompt.completed = asyncio.Future()
@@ -117,6 +113,7 @@ async def test_frontend_backend_workers():
frontend_command = [
"comfyui",
"--listen=0.0.0.0",
"--port=9001",
"--cpu",
"--distributed-queue-frontend",
f"--distributed-queue-connection-uri={connection_uri}",
@@ -125,6 +122,7 @@ async def test_frontend_backend_workers():
processes_to_close.append(subprocess.Popen(frontend_command, stdout=sys.stdout, stderr=sys.stderr))
backend_command = [
"comfyui-worker",
"--port=9002",
f"--distributed-queue-connection-uri={connection_uri}",
]
@@ -145,7 +143,7 @@ async def test_frontend_backend_workers():
client = AsyncRemoteComfyClient(server_address=server_address)
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
png_image_bytes = await client.queue_prompt(prompt)
assert len(png_image_bytes) > 1000
assert len(png_image_bytes) > 1000, "expected an image, but got nothing"
finally:
for process in processes_to_close:
process.terminate()

View File

@@ -24,8 +24,7 @@ async def test_cuda_memory_usage():
assert memory_after_workflow > starting_memory, "Expected CUDA memory to increase after running the workflow"
ending_memory = torch.cuda.memory_allocated(device)
assert abs(
ending_memory - starting_memory) < 1e7, "Expected CUDA memory to return close to starting memory after cleanup"
assert abs(ending_memory - starting_memory) < 1e7, "Expected CUDA memory to return close to starting memory after cleanup"
@pytest.mark.asyncio