Support ProcessPoolExecutor to improve memory management

This commit is contained in:
doctorpangloss 2024-09-04 17:03:22 -07:00
parent c75b9964ab
commit ed33ab1e7d
13 changed files with 191 additions and 125 deletions

View File

@ -198,6 +198,13 @@ def _create_parser() -> EnhancedConfigArgParser:
help="The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.", help="The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.",
) )
parser.add_argument(
"--executor-factory",
type=str,
default="ThreadPoolExecutor",
help="When running ComfyUI as a distributed worker, this specifies the kind of executor that should be used to run the actual ComfyUI workflow worker. A ThreadPoolExecutor is the default. A ProcessPoolExecutor results in better memory management, since the process will be closed and large, contiguous blocks of CUDA memory can be freed."
)
# now give plugins a chance to add configuration # now give plugins a chance to add configuration
for entry_point in entry_points().select(group='comfyui.custom_config'): for entry_point in entry_points().select(group='comfyui.custom_config'):
try: try:

View File

@ -109,6 +109,7 @@ class Configuration(dict):
otel_exporter_otlp_endpoint (Optional[str]): A base endpoint URL for any signal type, with an optionally-specified port number. Helpful for when you're sending more than one signal to the same endpoint and want one environment variable to control the endpoint. otel_exporter_otlp_endpoint (Optional[str]): A base endpoint URL for any signal type, with an optionally-specified port number. Helpful for when you're sending more than one signal to the same endpoint and want one environment variable to control the endpoint.
force_channels_last (bool): Force channels last format when inferencing the models. force_channels_last (bool): Force channels last format when inferencing the models.
force_hf_local_dir_mode (bool): Download repos from huggingface.co to the models/huggingface directory with the "local_dir" argument instead of models/huggingface_cache with the "cache_dir" argument, recreating the traditional file structure. force_hf_local_dir_mode (bool): Download repos from huggingface.co to the models/huggingface directory with the "local_dir" argument instead of models/huggingface_cache with the "cache_dir" argument, recreating the traditional file structure.
executor_factory (str): Either ThreadPoolExecutor or ProcessPoolExecutor, defaulting to ThreadPoolExecutor
""" """
def __init__(self, **kwargs): def __init__(self, **kwargs):
@ -194,6 +195,8 @@ class Configuration(dict):
for key, value in kwargs.items(): for key, value in kwargs.items():
self[key] = value self[key] = value
self.executor_factory: str = "ThreadPoolExecutor"
def __getattr__(self, item): def __getattr__(self, item):
if item not in self: if item not in self:
return None return None

View File

@ -1,151 +1,159 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import contextvars
import gc import gc
import json import json
import uuid import uuid
from asyncio import get_event_loop from asyncio import get_event_loop
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from multiprocessing import RLock
from typing import Optional from typing import Optional
from opentelemetry import context from opentelemetry import context, propagate
from opentelemetry.trace import Span, Status, StatusCode from opentelemetry.context import Context, attach, detach
from opentelemetry.trace import Status, StatusCode
from ..api.components.schema.prompt import PromptDict from ..api.components.schema.prompt import PromptDict
from ..cli_args_types import Configuration from ..cli_args_types import Configuration
from ..cmd.main_pre import tracer from ..cmd.main_pre import tracer
from ..component_model.executor_types import ExecutorToClientProgress from ..component_model.executor_types import ExecutorToClientProgress, Executor
from ..component_model.make_mutable import make_mutable from ..component_model.make_mutable import make_mutable
from ..distributed.process_pool_executor import ProcessPoolExecutor
from ..distributed.server_stub import ServerStub from ..distributed.server_stub import ServerStub
_server_stub_instance = ServerStub() _prompt_executor = contextvars.ContextVar('prompt_executor')
def _execute_prompt(
prompt: dict,
prompt_id: str,
client_id: str,
span_context: dict,
progress_handler: ExecutorToClientProgress | None,
configuration: Configuration | None) -> dict:
span_context: Context = propagate.extract(span_context)
token = attach(span_context)
try:
return __execute_prompt(prompt, prompt_id, client_id, span_context, progress_handler, configuration)
finally:
detach(token)
def __execute_prompt(
prompt: dict,
prompt_id: str,
client_id: str,
span_context: Context,
progress_handler: ExecutorToClientProgress | None,
configuration: Configuration | None) -> dict:
from .. import options
progress_handler = progress_handler or ServerStub()
try:
prompt_executor = _prompt_executor.get()
except LookupError:
if configuration is None:
options.enable_args_parsing()
else:
from ..cmd.main_pre import args
args.clear()
args.update(configuration)
from ..cmd.execution import PromptExecutor
with tracer.start_as_current_span("Initialize Prompt Executor", context=span_context) as span:
prompt_executor = PromptExecutor(progress_handler, lru_size=configuration.cache_lru if configuration is not None else 0)
prompt_executor.raise_exceptions = True
_prompt_executor.set(prompt_executor)
with tracer.start_as_current_span("Execute Prompt", context=span_context) as span:
try:
prompt_mut = make_mutable(prompt)
from ..cmd.execution import validate_prompt
validation_tuple = validate_prompt(prompt_mut)
if not validation_tuple.valid:
validation_error_dict = {"message": "Unknown", "details": ""} if not validation_tuple.node_errors or len(validation_tuple.node_errors) == 0 else validation_tuple.node_errors
raise ValueError(json.dumps(validation_error_dict))
if client_id is None:
prompt_executor.server = ServerStub()
else:
prompt_executor.server = progress_handler
prompt_executor.execute(prompt_mut, prompt_id, {"client_id": client_id},
execute_outputs=validation_tuple.good_output_node_ids)
return prompt_executor.outputs_ui
except Exception as exc_info:
span.set_status(Status(StatusCode.ERROR))
span.record_exception(exc_info)
raise exc_info
def _cleanup():
from .. import model_management
model_management.unload_all_models()
gc.collect()
try:
model_management.soft_empty_cache()
except:
pass
class EmbeddedComfyClient: class EmbeddedComfyClient:
""" def __init__(self, configuration: Optional[Configuration] = None, progress_handler: Optional[ExecutorToClientProgress] = None, max_workers: int = 1, executor: Executor = None):
Embedded client for comfy executing prompts as a library.
This client manages a single-threaded executor to run long-running or blocking tasks
asynchronously without blocking the asyncio event loop. It initializes a PromptExecutor
in a dedicated thread for executing prompts and handling server-stub communications.
Example usage:
Asynchronous (non-blocking) usage with async-await:
```
# Write a workflow, or enable Dev Mode in the UI settings, then Save (API Format) to get the workflow in your
# workspace.
prompt_dict = {
"1": {"class_type": "KSamplerAdvanced", ...}
...
}
# Validate your workflow (the prompt)
from comfy.api.components.schema.prompt import Prompt
prompt = Prompt.validate(prompt_dict)
# Then use the client to run your workflow. This will start, then stop, a local ComfyUI workflow executor.
# It does not connect to a remote server.
async def main():
async with EmbeddedComfyClient() as client:
outputs = await client.queue_prompt(prompt)
print(outputs)
print("Now that we've exited the with statement, all your VRAM has been cleared from ComfyUI")
if __name__ == "__main__"
asyncio.run(main())
```
In order to use this in blocking methods, learn more about asyncio online.
"""
def __init__(self, configuration: Optional[Configuration] = None, progress_handler: Optional[ExecutorToClientProgress] = None, max_workers: int = 1):
self._progress_handler = progress_handler or ServerStub() self._progress_handler = progress_handler or ServerStub()
self._executor = ThreadPoolExecutor(max_workers=max_workers) self._executor = executor or ThreadPoolExecutor(max_workers=max_workers)
self._configuration = configuration self._configuration = configuration
# we don't want to import the executor yet
self._prompt_executor: Optional["comfy.cmd.execution.PromptExecutor"] = None
self._is_running = False self._is_running = False
self._task_count_lock = RLock()
self._task_count = 0
@property @property
def is_running(self) -> bool: def is_running(self) -> bool:
return self._is_running return self._is_running
@property
def task_count(self) -> int:
return self._task_count
async def __aenter__(self): async def __aenter__(self):
await self._initialize_prompt_executor()
self._is_running = True self._is_running = True
return self return self
async def __aexit__(self, *args): async def __aexit__(self, *args):
# Perform cleanup here
def cleanup():
from .. import model_management
model_management.unload_all_models()
gc.collect()
try:
model_management.soft_empty_cache()
except:
pass
# wait until the queue is done while self.task_count > 0:
while self._executor._work_queue.qsize() > 0:
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
await get_event_loop().run_in_executor(self._executor, cleanup) await get_event_loop().run_in_executor(self._executor, _cleanup)
self._executor.shutdown(wait=True) self._executor.shutdown(wait=True)
self._is_running = False self._is_running = False
async def _initialize_prompt_executor(self):
# This method must be async since it's used in __aenter__
def create_executor_in_thread():
from .. import options
if self._configuration is None:
options.enable_args_parsing()
else:
from ..cmd.main_pre import args
args.clear()
args.update(self._configuration)
from ..cmd.execution import PromptExecutor
self._prompt_executor = PromptExecutor(self._progress_handler, lru_size=self._configuration.cache_lru if self._configuration is not None else 0)
self._prompt_executor.raise_exceptions = True
await get_event_loop().run_in_executor(self._executor, create_executor_in_thread)
@tracer.start_as_current_span("Queue Prompt") @tracer.start_as_current_span("Queue Prompt")
async def queue_prompt(self, async def queue_prompt(self,
prompt: PromptDict | dict, prompt: PromptDict | dict,
prompt_id: Optional[str] = None, prompt_id: Optional[str] = None,
client_id: Optional[str] = None) -> dict: client_id: Optional[str] = None) -> dict:
with self._task_count_lock:
self._task_count += 1
prompt_id = prompt_id or str(uuid.uuid4()) prompt_id = prompt_id or str(uuid.uuid4())
client_id = client_id or self._progress_handler.client_id or None client_id = client_id or self._progress_handler.client_id or None
span_context = context.get_current() span_context = context.get_current()
carrier = {}
def execute_prompt() -> dict: propagate.inject(carrier, span_context)
spam: Span try:
with tracer.start_as_current_span("Execute Prompt", context=span_context) as span: return await get_event_loop().run_in_executor(
from ..cmd.execution import PromptExecutor, validate_prompt self._executor,
try: _execute_prompt,
prompt_mut = make_mutable(prompt) make_mutable(prompt),
validation_tuple = validate_prompt(prompt_mut) prompt_id,
if not validation_tuple.valid: client_id,
validation_error_dict = {"message": "Unknown", "details": ""} if not validation_tuple.node_errors or len(validation_tuple.node_errors) == 0 else validation_tuple.node_errors carrier,
raise ValueError(json.dumps(validation_error_dict)) # todo: a proxy object or something more sophisticated will have to be done here to restore progress notifications for ProcessPoolExecutors
None if isinstance(self._executor, ProcessPoolExecutor) else self._progress_handler,
prompt_executor: PromptExecutor = self._prompt_executor self._configuration,
)
if client_id is None: finally:
prompt_executor.server = _server_stub_instance with self._task_count_lock:
else: self._task_count -= 1
prompt_executor.server = self._progress_handler
prompt_executor.execute(prompt_mut, prompt_id, {"client_id": client_id},
execute_outputs=validation_tuple.good_output_node_ids)
return prompt_executor.outputs_ui
except Exception as exc_info:
span.set_status(Status(StatusCode.ERROR))
span.record_exception(exc_info)
raise exc_info
return await get_event_loop().run_in_executor(self._executor, execute_prompt)

View File

@ -1,7 +1,8 @@
import asyncio import asyncio
import itertools import itertools
import os
import logging import logging
import os
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
from .extra_model_paths import load_extra_path_config from .extra_model_paths import load_extra_path_config
from .main_pre import args from .main_pre import args
@ -41,7 +42,8 @@ async def main():
from ..distributed.distributed_prompt_worker import DistributedPromptWorker from ..distributed.distributed_prompt_worker import DistributedPromptWorker
async with DistributedPromptWorker(connection_uri=args.distributed_queue_connection_uri, async with DistributedPromptWorker(connection_uri=args.distributed_queue_connection_uri,
queue_name=args.distributed_queue_name): queue_name=args.distributed_queue_name,
executor=ThreadPoolExecutor(max_workers=1) if args.executor_factory == "ThreadPoolExecutor" else ProcessPoolExecutor(max_workers=1)):
stop = asyncio.Event() stop = asyncio.Event()
try: try:
await stop.wait() await stop.wait()

View File

@ -1,11 +1,12 @@
from __future__ import annotations # for Python 3.7-3.9 from __future__ import annotations # for Python 3.7-3.9
import concurrent.futures
import typing import typing
from enum import Enum from enum import Enum
from typing import Optional, Literal, Protocol, Union, NamedTuple, List from typing import Optional, Literal, Protocol, Union, NamedTuple, List
import PIL.Image import PIL.Image
from typing_extensions import NotRequired, TypedDict from typing_extensions import NotRequired, TypedDict, runtime_checkable
from .queue_types import BinaryEventTypes from .queue_types import BinaryEventTypes
from ..nodes.package_typing import InputTypeSpec from ..nodes.package_typing import InputTypeSpec
@ -193,3 +194,11 @@ class NodeInputError(Exception):
class NodeNotFoundError(Exception): class NodeNotFoundError(Exception):
pass pass
class Executor(Protocol):
def submit(self, fn, /, *args, **kwargs) -> concurrent.futures.Future:
...
def shutdown(self, wait=True, *, cancel_futures=False):
...

View File

@ -40,7 +40,6 @@ class DistributedExecutorToClientProgress(ExecutorToClientProgress):
self.receive_all_progress_notifications = receive_all_progress_notifications self.receive_all_progress_notifications = receive_all_progress_notifications
async def send(self, event: SendSyncEvent, data: SendSyncData, user_id: Optional[str]) -> None: async def send(self, event: SendSyncEvent, data: SendSyncData, user_id: Optional[str]) -> None:
# for now, do not send binary data this way, since it cannot be json serialized / it's impractical
if event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE: if event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE:
from ..cmd.latent_preview_image_encoding import encode_preview_image from ..cmd.latent_preview_image_encoding import encode_preview_image

View File

@ -14,6 +14,7 @@ from .distributed_progress import DistributedExecutorToClientProgress
from .distributed_types import RpcRequest, RpcReply from .distributed_types import RpcRequest, RpcReply
from ..client.embedded_comfy_client import EmbeddedComfyClient from ..client.embedded_comfy_client import EmbeddedComfyClient
from ..cmd.main_pre import tracer from ..cmd.main_pre import tracer
from ..component_model.executor_types import Executor
from ..component_model.queue_types import ExecutionStatus from ..component_model.queue_types import ExecutionStatus
@ -26,13 +27,15 @@ class DistributedPromptWorker:
connection_uri: str = "amqp://localhost:5672/", connection_uri: str = "amqp://localhost:5672/",
queue_name: str = "comfyui", queue_name: str = "comfyui",
health_check_port: int = 9090, health_check_port: int = 9090,
loop: Optional[AbstractEventLoop] = None): loop: Optional[AbstractEventLoop] = None,
executor: Optional[Executor] = None):
self._rpc = None self._rpc = None
self._channel = None self._channel = None
self._exit_stack = AsyncExitStack() self._exit_stack = AsyncExitStack()
self._queue_name = queue_name self._queue_name = queue_name
self._connection_uri = connection_uri self._connection_uri = connection_uri
self._loop = loop or asyncio.get_event_loop() self._loop = loop or asyncio.get_event_loop()
self._executor = executor
self._embedded_comfy_client = embedded_comfy_client self._embedded_comfy_client = embedded_comfy_client
self._health_check_port = health_check_port self._health_check_port = health_check_port
self._health_check_site: Optional[web.TCPSite] = None self._health_check_site: Optional[web.TCPSite] = None
@ -94,7 +97,7 @@ class DistributedPromptWorker:
self._rpc = await JsonRPC.create(channel=self._channel, auto_delete=True, durable=False) self._rpc = await JsonRPC.create(channel=self._channel, auto_delete=True, durable=False)
if self._embedded_comfy_client is None: if self._embedded_comfy_client is None:
self._embedded_comfy_client = EmbeddedComfyClient(progress_handler=DistributedExecutorToClientProgress(self._rpc, self._queue_name, self._loop)) self._embedded_comfy_client = EmbeddedComfyClient(progress_handler=DistributedExecutorToClientProgress(self._rpc, self._queue_name, self._loop), executor=self._executor)
if not self._embedded_comfy_client.is_running: if not self._embedded_comfy_client.is_running:
await self._exit_stack.enter_async_context(self._embedded_comfy_client) await self._exit_stack.enter_async_context(self._embedded_comfy_client)

View File

@ -1,10 +1,12 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from typing import Tuple, Literal, List from typing import Tuple, Literal, List, Callable
from ..api.components.schema.prompt import PromptDict, Prompt from ..api.components.schema.prompt import PromptDict, Prompt
from ..auth.permissions import ComfyJwt, jwt_decode from ..auth.permissions import ComfyJwt, jwt_decode
from ..cli_args_types import Configuration
from ..component_model.executor_types import ExecutorToClientProgress
from ..component_model.queue_types import NamedQueueTuple, TaskInvocation, ExecutionStatus from ..component_model.queue_types import NamedQueueTuple, TaskInvocation, ExecutionStatus

View File

@ -0,0 +1,19 @@
import concurrent.futures
from pebble import ProcessPool
from ..component_model.executor_types import Executor
class ProcessPoolExecutor(ProcessPool, Executor):
def shutdown(self, wait=True, *, cancel_futures=False):
if cancel_futures:
raise NotImplementedError("cannot cancel futures in this implementation")
if wait:
self.close()
else:
self.stop()
return
def submit(self, fn, /, *args, **kwargs) -> concurrent.futures.Future:
return self.schedule(fn, args=list(args), kwargs=kwargs, timeout=None)

View File

@ -1,17 +1,25 @@
import platform
try: try:
import bitsandbytes as bnb import bitsandbytes as bnb
from bitsandbytes.nn.modules import Params4bit, QuantState from bitsandbytes.nn.modules import Params4bit, QuantState
has_bitsandbytes = True has_bitsandbytes = True
except (ImportError, ModuleNotFoundError): except (ImportError, ModuleNotFoundError):
bnb = {} class bnb:
Params4bit = {} pass
QuantState = {}
class Params4bit:
pass
class QuantState:
pass
has_bitsandbytes = False has_bitsandbytes = False
import torch import torch
import comfy.ops import comfy.ops
import comfy.sd import comfy.sd
from comfy.cmd.folder_paths import get_folder_paths from comfy.cmd.folder_paths import get_folder_paths

View File

@ -64,4 +64,5 @@ spandrel_extra_arches
ml_dtypes ml_dtypes
diffusers>=0.30.1 diffusers>=0.30.1
vtracer vtracer
skia-python skia-python
pebble>=5.0.7

View File

@ -86,8 +86,8 @@ def has_gpu() -> bool:
yield has_gpu yield has_gpu
@pytest.fixture(scope="module", autouse=False) @pytest.fixture(scope="module", autouse=False, params=["ThreadPoolExecutor", "ProcessPoolExecutor"])
def frontend_backend_worker_with_rabbitmq(tmp_path_factory) -> str: def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory) -> str:
""" """
populates the cache with the sdxl checkpoints, starts a frontend and backend worker against a started rabbitmq, and yields the address of the frontend populates the cache with the sdxl checkpoints, starts a frontend and backend worker against a started rabbitmq, and yields the address of the frontend
:return: :return:
@ -97,6 +97,7 @@ def frontend_backend_worker_with_rabbitmq(tmp_path_factory) -> str:
hf_hub_download("stabilityai/stable-diffusion-xl-refiner-1.0", "sd_xl_refiner_1.0.safetensors") hf_hub_download("stabilityai/stable-diffusion-xl-refiner-1.0", "sd_xl_refiner_1.0.safetensors")
tmp_path = tmp_path_factory.mktemp("comfy_background_server") tmp_path = tmp_path_factory.mktemp("comfy_background_server")
executor_factory = request.param
processes_to_close: List[subprocess.Popen] = [] processes_to_close: List[subprocess.Popen] = []
from testcontainers.rabbitmq import RabbitMqContainer from testcontainers.rabbitmq import RabbitMqContainer
with RabbitMqContainer("rabbitmq:latest") as rabbitmq: with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
@ -119,6 +120,7 @@ def frontend_backend_worker_with_rabbitmq(tmp_path_factory) -> str:
"--port=9002", "--port=9002",
f"-w={str(tmp_path)}", f"-w={str(tmp_path)}",
f"--distributed-queue-connection-uri={connection_uri}", f"--distributed-queue-connection-uri={connection_uri}",
f"--executor-factory={executor_factory}"
] ]
processes_to_close.append(subprocess.Popen(backend_command, stdout=sys.stdout, stderr=sys.stderr)) processes_to_close.append(subprocess.Popen(backend_command, stdout=sys.stdout, stderr=sys.stderr))

View File

@ -1,6 +1,7 @@
import asyncio import asyncio
import uuid import uuid
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from typing import Callable
import jwt import jwt
import pytest import pytest
@ -10,9 +11,11 @@ from testcontainers.rabbitmq import RabbitMqContainer
from comfy.client.aio_client import AsyncRemoteComfyClient from comfy.client.aio_client import AsyncRemoteComfyClient
from comfy.client.embedded_comfy_client import EmbeddedComfyClient from comfy.client.embedded_comfy_client import EmbeddedComfyClient
from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner
from comfy.component_model.executor_types import Executor
from comfy.component_model.make_mutable import make_mutable from comfy.component_model.make_mutable import make_mutable
from comfy.component_model.queue_types import QueueItem, QueueTuple, TaskInvocation, NamedQueueTuple, ExecutionStatus from comfy.component_model.queue_types import QueueItem, QueueTuple, TaskInvocation, NamedQueueTuple, ExecutionStatus
from comfy.distributed.distributed_prompt_worker import DistributedPromptWorker from comfy.distributed.distributed_prompt_worker import DistributedPromptWorker
from comfy.distributed.process_pool_executor import ProcessPoolExecutor
from comfy.distributed.server_stub import ServerStub from comfy.distributed.server_stub import ServerStub
@ -35,12 +38,11 @@ async def test_sign_jwt_auth_none():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_basic_queue_worker() -> None: @pytest.mark.parametrize("executor_factory", (ThreadPoolExecutor, ProcessPoolExecutor,))
# there are lots of side effects from importing that we have to deal with async def test_basic_queue_worker(executor_factory: Callable[..., Executor]) -> None:
with RabbitMqContainer("rabbitmq:latest") as rabbitmq: with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
params = rabbitmq.get_connection_params() params = rabbitmq.get_connection_params()
async with DistributedPromptWorker(connection_uri=f"amqp://guest:guest@127.0.0.1:{params.port}"): async with DistributedPromptWorker(connection_uri=f"amqp://guest:guest@127.0.0.1:{params.port}", executor=executor_factory(max_workers=1)):
# this unfortunately does a bunch of initialization on the test thread # this unfortunately does a bunch of initialization on the test thread
from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
# now submit some jobs # now submit some jobs
@ -125,14 +127,15 @@ async def check_health(url: str, max_retries: int = 5, retry_delay: float = 1.0)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_basic_queue_worker_with_health_check(): @pytest.mark.parametrize("executor_factory", (ThreadPoolExecutor, ProcessPoolExecutor,))
async def test_basic_queue_worker_with_health_check(executor_factory):
with RabbitMqContainer("rabbitmq:latest") as rabbitmq: with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
params = rabbitmq.get_connection_params() params = rabbitmq.get_connection_params()
connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}" connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"
health_check_port = 9090 health_check_port = 9090
async with DistributedPromptWorker(connection_uri=connection_uri, health_check_port=health_check_port) as worker: async with DistributedPromptWorker(connection_uri=connection_uri, health_check_port=health_check_port, executor=executor_factory(max_workers=1)) as worker:
health_check_url = f"http://localhost:{health_check_port}/health" health_check_url = f"http://localhost:{health_check_port}/health"
health_check_ok = await check_health(health_check_url) health_check_ok = await check_health(health_check_url)
assert health_check_ok, "Health check server did not start properly" assert health_check_ok, "Health check server did not start properly"