Improve tracing to propagate to backend workers correctly when using the API. Fix distributed tests.

doctorpangloss 2024-05-07 13:44:34 -07:00
parent 75b63fce91
commit f8fcfa6f08
12 changed files with 133 additions and 94 deletions

View File

@@ -7,14 +7,19 @@ from asyncio import AbstractEventLoop
 from concurrent.futures import ThreadPoolExecutor
 from typing import Optional

+from opentelemetry import context
+from opentelemetry.trace import Span, Status, StatusCode
 from ..api.components.schema.prompt import PromptDict
 from ..cli_args_types import Configuration
-from ..component_model.make_mutable import make_mutable
+from ..cmd.main_pre import tracer
 from ..component_model.executor_types import ExecutorToClientProgress
+from ..component_model.make_mutable import make_mutable
 from ..distributed.server_stub import ServerStub

 _server_stub_instance = ServerStub()


 class EmbeddedComfyClient:
     """
     Embedded client for comfy executing prompts as a library.
@@ -110,30 +115,36 @@ class EmbeddedComfyClient:
         await self._loop.run_in_executor(self._executor, create_executor_in_thread)

+    @tracer.start_as_current_span("Queue Prompt")
     async def queue_prompt(self,
                            prompt: PromptDict | dict,
                            prompt_id: Optional[str] = None,
                            client_id: Optional[str] = None) -> dict:
         prompt_id = prompt_id or str(uuid.uuid4())
         client_id = client_id or self._progress_handler.client_id or None
+        span_context = context.get_current()

         def execute_prompt() -> dict:
-            from ..cmd.execution import PromptExecutor, validate_prompt
-            prompt_mut = make_mutable(prompt)
-            validation_tuple = validate_prompt(prompt_mut)
-            prompt_executor: PromptExecutor = self._prompt_executor
-            if client_id is None:
-                prompt_executor.server = _server_stub_instance
-            else:
-                prompt_executor.server = self._progress_handler
-            prompt_executor.execute(prompt_mut, prompt_id, {"client_id": client_id},
-                                    execute_outputs=validation_tuple[2])
-            if prompt_executor.success:
-                return prompt_executor.outputs_ui
-            else:
-                raise RuntimeError("\n".join(event for (event, data) in self._prompt_executor.status_messages))
+            span: Span
+            with tracer.start_as_current_span("Execute Prompt", context=span_context) as span:
+                from ..cmd.execution import PromptExecutor, validate_prompt
+                prompt_mut = make_mutable(prompt)
+                validation_tuple = validate_prompt(prompt_mut)
+                prompt_executor: PromptExecutor = self._prompt_executor
+                if client_id is None:
+                    prompt_executor.server = _server_stub_instance
+                else:
+                    prompt_executor.server = self._progress_handler
+                prompt_executor.execute(prompt_mut, prompt_id, {"client_id": client_id},
+                                        execute_outputs=validation_tuple[2])
+                if prompt_executor.success:
+                    return prompt_executor.outputs_ui
+                else:
+                    span.set_status(Status(StatusCode.ERROR))
+                    error = RuntimeError("\n".join(event for (event, data) in self._prompt_executor.status_messages))
+                    span.record_exception(error)

         return await self._loop.run_in_executor(self._executor, execute_prompt)
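
Note: the change above captures the OpenTelemetry context on the event-loop thread and reattaches it inside the executor thread, so "Execute Prompt" stays parented to "Queue Prompt". A minimal sketch of that pattern, assuming only an already-configured tracer; traced_submit and do_work are illustrative names, not part of this codebase:

# Sketch only: propagate the caller's OpenTelemetry context into a ThreadPoolExecutor job.
import asyncio
from concurrent.futures import ThreadPoolExecutor

from opentelemetry import context, trace

tracer = trace.get_tracer(__name__)
_executor = ThreadPoolExecutor(max_workers=1)


async def traced_submit(do_work):
    # capture the context (and the active span) on the event-loop thread...
    parent_context = context.get_current()

    def run_in_thread():
        # ...and reattach it in the worker thread, where the ambient context is empty,
        # so this span becomes a child of the caller's span instead of a new root
        with tracer.start_as_current_span("Execute Prompt", context=parent_context):
            return do_work()

    return await asyncio.get_running_loop().run_in_executor(_executor, run_in_thread)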

View File

@@ -9,22 +9,26 @@ import threading
 import traceback
 import typing
 from typing import List, Optional, Tuple, Union

-from typing_extensions import TypedDict
-import torch
-import lazy_object_proxy
+import lazy_object_proxy
+import torch
+from opentelemetry.trace import get_current_span, StatusCode, Status
+from typing_extensions import TypedDict

+from .main_pre import tracer
 from .. import interruption
-from ..component_model.abstract_prompt_queue import AbstractPromptQueue
-from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus
-from ..component_model.executor_types import ExecutorToClientProgress
 from .. import model_management
-from ..nodes.package_typing import ExportedNodes
+from ..component_model.abstract_prompt_queue import AbstractPromptQueue
+from ..component_model.executor_types import ExecutorToClientProgress
+from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus
 from ..nodes.package import import_all_nodes_in_workspace
+from ..nodes.package_typing import ExportedNodes

 # ideally this would be passed in from main, but the way this is authored, we can't easily pass nodes down to the
 # various functions that are declared here. It should have been a context in the first place.
 nodes: ExportedNodes = lazy_object_proxy.Proxy(import_all_nodes_in_workspace)


 def get_input_data(inputs, class_def, unique_id, outputs=None, prompt=None, extra_data=None):
     if extra_data is None:
         extra_data = {}
@@ -139,6 +143,7 @@ def format_value(x):
     return str(x)


+@tracer.start_as_current_span("Recursive Execute")
 def recursive_execute(server: ExecutorToClientProgress,
                       prompt,
                       outputs,
@@ -324,6 +329,10 @@ class PromptExecutor:
         self.server.send_sync(event, data, self.server.client_id)

     def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex):
+        current_span = get_current_span()
+        current_span.set_status(Status(StatusCode.ERROR))
+        current_span.record_exception(ex)
         node_id = error["node_id"]
         class_type = prompt[node_id]["class_type"]
@@ -430,7 +439,6 @@ class PromptExecutor:
             # This call shouldn't raise anything if there's an error deep in
             # the actual SD code, instead it will report the node where the
             # error was raised
-            # todo: if we're using a distributed queue, we must wrap the server instance to correctly communicate back to the client via the exchange
             self.success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id,
                                                         extra_data, executed, prompt_id, self.outputs_ui,
                                                         self.object_storage)
@@ -689,6 +697,7 @@ class ValidationErrorDict(TypedDict):
 ValidationTuple = typing.Tuple[bool, Optional[ValidationErrorDict], typing.List[str], Union[dict, list]]


+@tracer.start_as_current_span("Validate Prompt")
 def validate_prompt(prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple:
     outputs = set()
     for x in prompt:
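
The decorators and get_current_span() calls above follow one pattern: the decorator opens a span around the function, and helpers deeper in the call stack mark whatever span is currently active as failed. A condensed sketch under that assumption; report_error and run are illustrative names, not functions from this codebase:

# Sketch only: record an error on whichever span is currently active.
from opentelemetry import trace
from opentelemetry.trace import Status, StatusCode, get_current_span

tracer = trace.get_tracer(__name__)


def report_error(ex: Exception) -> None:
    # marks the active span (here, the one opened by the decorator below)
    # as failed and attaches the exception as a span event
    span = get_current_span()
    span.set_status(Status(StatusCode.ERROR))
    span.record_exception(ex)


@tracer.start_as_current_span("Recursive Execute")
def run(step):
    try:
        step()
    except Exception as ex:
        report_error(ex)
        raise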

View File

@@ -13,17 +13,16 @@ import uuid
 from asyncio import Future, AbstractEventLoop
 from enum import Enum
 from io import BytesIO
+from posixpath import join as urljoin
 from typing import List, Optional, Dict
 from urllib.parse import quote, urlencode
-from posixpath import join as urljoin
-from can_ada import URL, parse as urlparse

 import aiofiles
 import aiohttp
 from PIL import Image
 from PIL.PngImagePlugin import PngInfo
 from aiohttp import web
+from can_ada import URL, parse as urlparse
 from pkg_resources import resource_filename
 from typing_extensions import NamedTuple
@@ -36,12 +35,13 @@ from ..cli_args import args
 from ..client.client_types import Output, FileOutput
 from ..cmd import execution
 from ..cmd import folder_paths
+from ..component_model.abstract_prompt_queue import AbstractPromptQueue, AsyncAbstractPromptQueue
 from ..component_model.executor_types import ExecutorToClientProgress
 from ..component_model.file_output_path import file_output_path
 from ..component_model.queue_types import QueueItem, HistoryEntry, BinaryEventTypes, TaskInvocation
 from ..digest import digest
-from ..nodes.package_typing import ExportedNodes
 from ..images import open_image
+from ..nodes.package_typing import ExportedNodes


 class HeuristicPath(NamedTuple):
@@ -95,7 +95,7 @@ class PromptServer(ExecutorToClientProgress):
         self.user_manager = UserManager()
         # todo: this is probably read by custom nodes elsewhere
         self.supports: List[str] = ["custom_nodes_from_web"]
-        self.prompt_queue: execution.AbstractPromptQueue | None = None
+        self.prompt_queue: AbstractPromptQueue | AsyncAbstractPromptQueue | None = None
         self.loop: AbstractEventLoop = loop
         self.messages: asyncio.Queue = asyncio.Queue()
         self.number: int = 0
@@ -599,22 +599,27 @@ class PromptServer(ExecutorToClientProgress):
            return web.Response(status=400, content_type="application/json", body=json.dumps(valid[1]))

        # convert a valid prompt to the queue tuple this expects
-       completed: Future[TaskInvocation | dict] = self.loop.create_future()
        number = self.number
        self.number += 1
-       self.prompt_queue.put(
-           QueueItem(queue_tuple=(number, str(uuid.uuid4()), prompt_dict, {}, valid[2]),
-                     completed=completed))
-       try:
-           await completed
-       except Exception as ex:
-           return web.Response(body=str(ex), status=503)
-       # expect a single image
-       result: TaskInvocation | dict = completed.result()
+       completed: Future[TaskInvocation | dict] = self.loop.create_future()
+       item = QueueItem(queue_tuple=(number, str(uuid.uuid4()), prompt_dict, {}, valid[2]), completed=completed)
+       if hasattr(self.prompt_queue, "put_async") or isinstance(self.prompt_queue, AsyncAbstractPromptQueue):
+           # this enables span propagation seamlessly
+           result = await self.prompt_queue.put_async(item)
+           if result is None:
+               return web.Response(body="the queue is shutting down", status=503)
+       else:
+           try:
+               self.prompt_queue.put(item)
+               await completed
+           except Exception as ex:
+               return web.Response(body=str(ex), status=503)
+           # expect a single image
+           result: TaskInvocation | dict = completed.result()

        outputs_dict: Dict[str, Output] = result.outputs if isinstance(result, TaskInvocation) else result
        # find images and read them
        output_images: List[FileOutput] = []
        for node_id, node in outputs_dict.items():
            images: List[FileOutput] = []

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
 import typing
 from abc import ABCMeta, abstractmethod

-from .queue_types import QueueTuple, HistoryEntry, QueueItem, Flags, ExecutionStatus
+from .queue_types import QueueTuple, HistoryEntry, QueueItem, Flags, ExecutionStatus, TaskInvocation


 class AbstractPromptQueue(metaclass=ABCMeta):
@@ -118,3 +118,13 @@ class AbstractPromptQueue(metaclass=ABCMeta):
         :return:
         """
         pass
+
+
+class AsyncAbstractPromptQueue(AbstractPromptQueue):
+    @abstractmethod
+    async def put_async(self, queue_item) -> TaskInvocation | None:
+        pass
+
+    @abstractmethod
+    async def get_async(self, timeout: float | None = None) -> typing.Optional[typing.Tuple[QueueTuple, str]]:
+        pass

View File

@@ -18,13 +18,14 @@ from .distributed_types import RpcRequest, RpcReply
 from .history import History
 from .server_stub import ServerStub
 from ..auth.permissions import jwt_decode
+from ..cmd.main_pre import tracer
 from ..cmd.server import PromptServer
-from ..component_model.abstract_prompt_queue import AbstractPromptQueue
+from ..component_model.abstract_prompt_queue import AsyncAbstractPromptQueue
 from ..component_model.executor_types import ExecutorToClientProgress, SendSyncEvent, SendSyncData
 from ..component_model.queue_types import Flags, HistoryEntry, QueueTuple, QueueItem, ExecutionStatus, TaskInvocation


-class DistributedPromptQueue(AbstractPromptQueue):
+class DistributedPromptQueue(AsyncAbstractPromptQueue):
     """
     A distributed prompt queue for the ComfyUI web client and single-threaded worker.
     """
@@ -39,12 +40,13 @@ class DistributedPromptQueue(AbstractPromptQueue):
     async def progress(self, event: SendSyncEvent, data: SendSyncData, sid: Optional[str]) -> None:
         self._caller_server.send_sync(event, data, sid=sid)

-    async def put_async(self, queue_item: QueueItem):
+    @tracer.start_as_current_span("Put Async")
+    async def put_async(self, queue_item: QueueItem) -> TaskInvocation | None:
         assert self._is_caller
         assert self._rpc is not None
         if self._closing:
-            return
+            return None
         self._caller_local_in_progress[queue_item.prompt_id] = queue_item
         if self._caller_server is not None:
             self._caller_server.queue_updated()
@@ -88,9 +90,7 @@ class DistributedPromptQueue(AbstractPromptQueue):
            # if we have a completer, propagate the exception to it
            if queue_item.completed is not None:
                queue_item.completed.set_exception(e)
-           else:
-               # otherwise, this should raise in the event loop, which I suppose isn't handled
-               raise e
+           raise e
        finally:
            self._caller_local_in_progress.pop(queue_item.prompt_id)
            if self._caller_server is not None:
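
The error path above now reports a remote failure twice on purpose: once on the item's completer future and once to whoever awaits put_async directly. A simplified sketch of that shape; CallerQueueSketch and rpc_call are illustrative, not the real class or RPC layer:

# Sketch only: surface a remote failure on the completer future and to the awaiter of put_async.
from typing import Any, Awaitable, Callable, Dict


class CallerQueueSketch:
    def __init__(self, rpc_call: Callable[[Any], Awaitable[Any]]):
        self._rpc_call = rpc_call
        self._in_progress: Dict[str, Any] = {}

    async def put_async(self, queue_item):
        self._in_progress[queue_item.prompt_id] = queue_item
        try:
            return await self._rpc_call(queue_item)  # stand-in for the RPC round trip
        except Exception as e:
            if queue_item.completed is not None:
                # let code that awaits the future see the failure...
                queue_item.completed.set_exception(e)
            # ...and re-raise so a caller awaiting put_async directly observes it too,
            # instead of it being left unhandled on the event loop
            raise
        finally:
            self._in_progress.pop(queue_item.prompt_id)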

View File

@@ -12,6 +12,7 @@ from aiormq import AMQPConnectionError
 from .distributed_progress import DistributedExecutorToClientProgress
 from .distributed_types import RpcRequest, RpcReply
 from ..client.embedded_comfy_client import EmbeddedComfyClient
+from ..cmd.main_pre import tracer
 from ..component_model.queue_types import ExecutionStatus
@@ -32,6 +33,7 @@ class DistributedPromptWorker:
         self._loop = loop or asyncio.get_event_loop()
         self._embedded_comfy_client = embedded_comfy_client

+    @tracer.start_as_current_span("Do Work Item")
     async def _do_work_item(self, request: dict) -> dict:
         await self.on_will_complete_work_item(request)
         try:
@@ -55,6 +57,7 @@ class DistributedPromptWorker:
         await self.on_did_complete_work_item(request_obj, reply)
         return asdict(reply)

+    @tracer.start_as_current_span("Initialize Prompt Worker")
     async def init(self):
         await self._exit_stack.__aenter__()
         try:

View File

@@ -9,8 +9,11 @@ import types
 from functools import reduce
 from importlib.metadata import entry_points

+from opentelemetry.trace import Span, Status, StatusCode
 from pkg_resources import resource_filename

 from .package_typing import ExportedNodes
+from ..cmd.main_pre import tracer

 _comfy_nodes: ExportedNodes = ExportedNodes()
@@ -46,29 +49,33 @@ def _import_and_enumerate_nodes_in_module(module: types.ModuleType,
     else:
         # Iterate through all the submodules
         for _, name, is_pkg in pkgutil.iter_modules(module.__path__):
-            full_name = module.__name__ + "." + name
-            time_before = time.perf_counter()
-            success = True
-            if full_name.endswith(".disabled"):
-                continue
-            try:
-                submodule = importlib.import_module(full_name)
-                # Recursively call the function if it's a package
-                exported_nodes.update(
-                    _import_and_enumerate_nodes_in_module(submodule, print_import_times=print_import_times,
-                                                          depth=depth - 1))
-            except KeyboardInterrupt as interrupted:
-                raise interrupted
-            except Exception as x:
-                if isinstance(x, AttributeError):
-                    potential_path_error: AttributeError = x
-                    if potential_path_error.name == '__path__':
-                        continue
-                logging.error(f"{full_name} import failed", exc_info=x)
-                success = False
-                exceptions.append(x)
-            timings.append((time.perf_counter() - time_before, full_name, success))
+            span: Span
+            with tracer.start_as_current_span("Load Node") as span:
+                full_name = module.__name__ + "." + name
+                time_before = time.perf_counter()
+                success = True
+                span.set_attribute("full_name", full_name)
+                if full_name.endswith(".disabled"):
+                    continue
+                try:
+                    submodule = importlib.import_module(full_name)
+                    # Recursively call the function if it's a package
+                    new_nodes = _import_and_enumerate_nodes_in_module(submodule, print_import_times=print_import_times, depth=depth - 1)
+                    span.set_attribute("new_nodes.length", len(new_nodes))
+                    exported_nodes.update(new_nodes)
+                except KeyboardInterrupt as interrupted:
+                    raise interrupted
+                except Exception as x:
+                    if isinstance(x, AttributeError):
+                        potential_path_error: AttributeError = x
+                        if potential_path_error.name == '__path__':
+                            continue
+                    logging.error(f"{full_name} import failed", exc_info=x)
+                    success = False
+                    exceptions.append(x)
+                    span.set_status(Status(StatusCode.ERROR))
+                    span.record_exception(x)
+                timings.append((time.perf_counter() - time_before, full_name, success))

 if print_import_times and len(timings) > 0 or any(not success for (_, _, success) in timings):
     for (duration, module_name, success) in sorted(timings):
@@ -81,6 +88,7 @@ def _import_and_enumerate_nodes_in_module(module: types.ModuleType,
     return exported_nodes


+@tracer.start_as_current_span("Import All Nodes In Workspace")
 def import_all_nodes_in_workspace(vanilla_custom_nodes=True, raise_on_failure=False) -> ExportedNodes:
     # now actually import the nodes, to improve control of node loading order
     from comfy_extras import nodes as comfy_extras_nodes

View File

@@ -1,4 +1,5 @@
 import os.path
+import threading
 from contextlib import contextmanager

 import torch
@@ -17,7 +18,7 @@ from .component_model.executor_types import ExecutorToClientProgress
 from .component_model.queue_types import BinaryEventTypes

 PROGRESS_BAR_ENABLED = True
-PROGRESS_BAR_HOOK = None
+_progress_bar_hook = threading.local()
@@ -479,7 +480,7 @@ def hijack_progress(server: ExecutorToClientProgress):
         if preview_image is not None:
             server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id)

-    set_progress_bar_global_hook(hook)
+    _progress_bar_hook.hook = hook


 def set_progress_bar_enabled(enabled):
@@ -491,11 +492,6 @@ def get_progress_bar_enabled() -> bool:
     return PROGRESS_BAR_ENABLED


-def set_progress_bar_global_hook(function):
-    global PROGRESS_BAR_HOOK
-    PROGRESS_BAR_HOOK = function
-
-
 class _DisabledProgressBar:
     def __init__(self, *args, **kwargs):
         pass
@@ -509,10 +505,10 @@ class _DisabledProgressBar:
 class ProgressBar:
     def __init__(self, total: float):
-        global PROGRESS_BAR_HOOK
+        global _progress_bar_hook
         self.total: float = total
         self.current: float = 0.0
-        self.hook = PROGRESS_BAR_HOOK
+        self.hook = _progress_bar_hook.hook if hasattr(_progress_bar_hook, "hook") else None

     def update_absolute(self, value, total=None, preview=None):
         if total is not None:
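
Replacing the module-level PROGRESS_BAR_HOOK with a threading.local() means each executor thread only sees the hook registered by its own hijack_progress call. A condensed sketch of the pattern, with simplified, illustrative function names:

# Sketch only: per-thread progress hooks so concurrent prompts don't overwrite a global callback.
import threading

_progress_bar_hook = threading.local()


def set_progress_hook_for_current_thread(hook):
    # only the registering thread will observe this hook
    _progress_bar_hook.hook = hook


def current_progress_hook():
    # getattr with a default covers threads that never registered a hook
    return getattr(_progress_bar_hook, "hook", None)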

View File

@@ -64,7 +64,7 @@ def comfy_background_server(args_pytest):
     if not success:
         raise Exception("Failed to start background server")
     yield
-    p.kill()
+    p.terminate()
     torch.cuda.empty_cache()

View File

@@ -1,27 +1,26 @@
 import asyncio
 import logging
 import os
+import socket
+import subprocess
 import sys
 import time
 import uuid
-import subprocess
-import socket
 from concurrent.futures import ThreadPoolExecutor
 from typing import List

 import jwt
 import pytest
 import requests
+from testcontainers.rabbitmq import RabbitMqContainer

 from comfy.client.aio_client import AsyncRemoteComfyClient
 from comfy.client.embedded_comfy_client import EmbeddedComfyClient
-from comfy.distributed.server_stub import ServerStub
 from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner
 from comfy.component_model.make_mutable import make_mutable
 from comfy.component_model.queue_types import QueueItem, QueueTuple, TaskInvocation, NamedQueueTuple, ExecutionStatus
 from comfy.distributed.distributed_prompt_worker import DistributedPromptWorker
-from testcontainers.rabbitmq import RabbitMqContainer
+from comfy.distributed.server_stub import ServerStub

 # fixes issues with running the testcontainers rabbitmqcontainer on Windows
 os.environ["TC_HOST"] = "localhost"
@@ -55,8 +54,7 @@ async def test_basic_queue_worker() -> None:
        # this unfortunately does a bunch of initialization on the test thread
        from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
        # now submit some jobs
-       distributed_queue = DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True,
-                                                  connection_uri=f"amqp://guest:guest@127.0.0.1:{params.port}")
+       distributed_queue = DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=f"amqp://guest:guest@127.0.0.1:{params.port}")
        await distributed_queue.init()
        queue_item = create_test_prompt()
        res: TaskInvocation = await distributed_queue.put_async(queue_item)
@@ -74,10 +72,8 @@ async def test_distributed_prompt_queues_same_process():
        connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"
        from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
-       async with DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True,
-                                         connection_uri=connection_uri) as frontend:
-           async with DistributedPromptQueue(ServerStub(), is_callee=True, is_caller=False,
-                                             connection_uri=f"amqp://guest:guest@127.0.0.1:{params.port}") as worker:
+       async with DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=connection_uri) as frontend:
+           async with DistributedPromptQueue(ServerStub(), is_callee=True, is_caller=False, connection_uri=connection_uri) as worker:
                test_prompt = create_test_prompt()
                test_prompt.completed = asyncio.Future()
@@ -117,6 +113,7 @@ async def test_frontend_backend_workers():
        frontend_command = [
            "comfyui",
            "--listen=0.0.0.0",
+           "--port=9001",
            "--cpu",
            "--distributed-queue-frontend",
            f"--distributed-queue-connection-uri={connection_uri}",
@@ -125,6 +122,7 @@ async def test_frontend_backend_workers():
        processes_to_close.append(subprocess.Popen(frontend_command, stdout=sys.stdout, stderr=sys.stderr))
        backend_command = [
            "comfyui-worker",
+           "--port=9002",
            f"--distributed-queue-connection-uri={connection_uri}",
        ]
@@ -145,7 +143,7 @@ async def test_frontend_backend_workers():
            client = AsyncRemoteComfyClient(server_address=server_address)
            prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1)
            png_image_bytes = await client.queue_prompt(prompt)
-           assert len(png_image_bytes) > 1000
+           assert len(png_image_bytes) > 1000, "expected an image, but got nothing"
        finally:
            for process in processes_to_close:
                process.terminate()

View File

@@ -24,8 +24,7 @@ async def test_cuda_memory_usage():
     assert memory_after_workflow > starting_memory, "Expected CUDA memory to increase after running the workflow"

     ending_memory = torch.cuda.memory_allocated(device)
-    assert abs(
-        ending_memory - starting_memory) < 1e7, "Expected CUDA memory to return close to starting memory after cleanup"
+    assert abs(ending_memory - starting_memory) < 1e7, "Expected CUDA memory to return close to starting memory after cleanup"


 @pytest.mark.asyncio