Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-11 06:40:48 +08:00)

Support ProcessPoolExecutor to improve memory management

This commit is contained in:
parent c75b9964ab
commit ed33ab1e7d
@@ -198,6 +198,13 @@ def _create_parser() -> EnhancedConfigArgParser:
        help="The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.",
    )

    parser.add_argument(
        "--executor-factory",
        type=str,
        default="ThreadPoolExecutor",
        help="When running ComfyUI as a distributed worker, this specifies the kind of executor that should be used to run the actual ComfyUI workflow worker. A ThreadPoolExecutor is the default. A ProcessPoolExecutor results in better memory management, since the process will be closed and large, contiguous blocks of CUDA memory can be freed."
    )

    # now give plugins a chance to add configuration
    for entry_point in entry_points().select(group='comfyui.custom_config'):
        try:
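For orientation, the worker entry point later in this diff (the cmd/main.py hunk) turns the value of this flag into an executor instance. A minimal sketch of that selection logic, assuming the parsed arguments expose executor_factory as a string; make_executor is a hypothetical helper, not part of the commit:

```python
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor


def make_executor(executor_factory: str, max_workers: int = 1):
    # Anything other than "ThreadPoolExecutor" falls through to a process pool,
    # mirroring the selection in the cmd/main.py hunk further down.
    if executor_factory == "ThreadPoolExecutor":
        return ThreadPoolExecutor(max_workers=max_workers)
    return ProcessPoolExecutor(max_workers=max_workers)
```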
@@ -109,6 +109,7 @@ class Configuration(dict):
        otel_exporter_otlp_endpoint (Optional[str]): A base endpoint URL for any signal type, with an optionally-specified port number. Helpful for when you're sending more than one signal to the same endpoint and want one environment variable to control the endpoint.
        force_channels_last (bool): Force channels last format when inferencing the models.
        force_hf_local_dir_mode (bool): Download repos from huggingface.co to the models/huggingface directory with the "local_dir" argument instead of models/huggingface_cache with the "cache_dir" argument, recreating the traditional file structure.
        executor_factory (str): Either ThreadPoolExecutor or ProcessPoolExecutor, defaulting to ThreadPoolExecutor
    """

    def __init__(self, **kwargs):
@@ -194,6 +195,8 @@ class Configuration(dict):
        for key, value in kwargs.items():
            self[key] = value

        self.executor_factory: str = "ThreadPoolExecutor"

    def __getattr__(self, item):
        if item not in self:
            return None
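A small illustrative sketch of how the new field behaves, assuming the package path comfy.cli_args_types and the dict-backed attribute access shown above; this is not part of the commit:

```python
from comfy.cli_args_types import Configuration

config = Configuration()
print(config.executor_factory)                    # "ThreadPoolExecutor" by default

config.executor_factory = "ProcessPoolExecutor"   # opt in to process isolation for distributed workers
print(config.executor_factory)                    # "ProcessPoolExecutor"
```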
@@ -1,151 +1,159 @@
from __future__ import annotations

import asyncio
import contextvars
import gc
import json
import uuid
from asyncio import get_event_loop
from concurrent.futures import ThreadPoolExecutor
from multiprocessing import RLock
from typing import Optional

from opentelemetry import context
from opentelemetry.trace import Span, Status, StatusCode
from opentelemetry import context, propagate
from opentelemetry.context import Context, attach, detach
from opentelemetry.trace import Status, StatusCode

from ..api.components.schema.prompt import PromptDict
from ..cli_args_types import Configuration
from ..cmd.main_pre import tracer
from ..component_model.executor_types import ExecutorToClientProgress
from ..component_model.executor_types import ExecutorToClientProgress, Executor
from ..component_model.make_mutable import make_mutable
from ..distributed.process_pool_executor import ProcessPoolExecutor
from ..distributed.server_stub import ServerStub

_server_stub_instance = ServerStub()
_prompt_executor = contextvars.ContextVar('prompt_executor')


def _execute_prompt(
        prompt: dict,
        prompt_id: str,
        client_id: str,
        span_context: dict,
        progress_handler: ExecutorToClientProgress | None,
        configuration: Configuration | None) -> dict:
    span_context: Context = propagate.extract(span_context)
    token = attach(span_context)
    try:
        return __execute_prompt(prompt, prompt_id, client_id, span_context, progress_handler, configuration)
    finally:
        detach(token)


def __execute_prompt(
        prompt: dict,
        prompt_id: str,
        client_id: str,
        span_context: Context,
        progress_handler: ExecutorToClientProgress | None,
        configuration: Configuration | None) -> dict:
    from .. import options
    progress_handler = progress_handler or ServerStub()

    try:
        prompt_executor = _prompt_executor.get()
    except LookupError:
        if configuration is None:
            options.enable_args_parsing()
        else:
            from ..cmd.main_pre import args
            args.clear()
            args.update(configuration)

        from ..cmd.execution import PromptExecutor
        with tracer.start_as_current_span("Initialize Prompt Executor", context=span_context) as span:
            prompt_executor = PromptExecutor(progress_handler, lru_size=configuration.cache_lru if configuration is not None else 0)
            prompt_executor.raise_exceptions = True
            _prompt_executor.set(prompt_executor)

    with tracer.start_as_current_span("Execute Prompt", context=span_context) as span:
        try:
            prompt_mut = make_mutable(prompt)
            from ..cmd.execution import validate_prompt
            validation_tuple = validate_prompt(prompt_mut)
            if not validation_tuple.valid:
                validation_error_dict = {"message": "Unknown", "details": ""} if not validation_tuple.node_errors or len(validation_tuple.node_errors) == 0 else validation_tuple.node_errors
                raise ValueError(json.dumps(validation_error_dict))

            if client_id is None:
                prompt_executor.server = ServerStub()
            else:
                prompt_executor.server = progress_handler

            prompt_executor.execute(prompt_mut, prompt_id, {"client_id": client_id},
                                    execute_outputs=validation_tuple.good_output_node_ids)
            return prompt_executor.outputs_ui
        except Exception as exc_info:
            span.set_status(Status(StatusCode.ERROR))
            span.record_exception(exc_info)
            raise exc_info


def _cleanup():
    from .. import model_management
    model_management.unload_all_models()
    gc.collect()
    try:
        model_management.soft_empty_cache()
    except:
        pass


class EmbeddedComfyClient:
    """
    Embedded client for comfy executing prompts as a library.

    This client manages a single-threaded executor to run long-running or blocking tasks
    asynchronously without blocking the asyncio event loop. It initializes a PromptExecutor
    in a dedicated thread for executing prompts and handling server-stub communications.

    Example usage:

    Asynchronous (non-blocking) usage with async-await:
    ```
    # Write a workflow, or enable Dev Mode in the UI settings, then Save (API Format) to get the workflow in your
    # workspace.
    prompt_dict = {
        "1": {"class_type": "KSamplerAdvanced", ...}
        ...
    }

    # Validate your workflow (the prompt)
    from comfy.api.components.schema.prompt import Prompt
    prompt = Prompt.validate(prompt_dict)
    # Then use the client to run your workflow. This will start, then stop, a local ComfyUI workflow executor.
    # It does not connect to a remote server.
    async def main():
        async with EmbeddedComfyClient() as client:
            outputs = await client.queue_prompt(prompt)
            print(outputs)
        print("Now that we've exited the with statement, all your VRAM has been cleared from ComfyUI")

    if __name__ == "__main__":
        asyncio.run(main())
    ```

    In order to use this in blocking methods, learn more about asyncio online.
    """

    def __init__(self, configuration: Optional[Configuration] = None, progress_handler: Optional[ExecutorToClientProgress] = None, max_workers: int = 1):
    def __init__(self, configuration: Optional[Configuration] = None, progress_handler: Optional[ExecutorToClientProgress] = None, max_workers: int = 1, executor: Executor = None):
        self._progress_handler = progress_handler or ServerStub()
        self._executor = ThreadPoolExecutor(max_workers=max_workers)
        self._executor = executor or ThreadPoolExecutor(max_workers=max_workers)
        self._configuration = configuration
        # we don't want to import the executor yet
        self._prompt_executor: Optional["comfy.cmd.execution.PromptExecutor"] = None
        self._is_running = False
        self._task_count_lock = RLock()
        self._task_count = 0

    @property
    def is_running(self) -> bool:
        return self._is_running

    @property
    def task_count(self) -> int:
        return self._task_count

    async def __aenter__(self):
        await self._initialize_prompt_executor()
        self._is_running = True
        return self

    async def __aexit__(self, *args):
        # Perform cleanup here
        def cleanup():
            from .. import model_management
            model_management.unload_all_models()
            gc.collect()
            try:
                model_management.soft_empty_cache()
            except:
                pass

        # wait until the queue is done
        while self._executor._work_queue.qsize() > 0:
        while self.task_count > 0:
            await asyncio.sleep(0.1)

        await get_event_loop().run_in_executor(self._executor, cleanup)
        await get_event_loop().run_in_executor(self._executor, _cleanup)

        self._executor.shutdown(wait=True)
        self._is_running = False

    async def _initialize_prompt_executor(self):
        # This method must be async since it's used in __aenter__
        def create_executor_in_thread():
            from .. import options
            if self._configuration is None:
                options.enable_args_parsing()
            else:
                from ..cmd.main_pre import args
                args.clear()
                args.update(self._configuration)

            from ..cmd.execution import PromptExecutor

            self._prompt_executor = PromptExecutor(self._progress_handler, lru_size=self._configuration.cache_lru if self._configuration is not None else 0)
            self._prompt_executor.raise_exceptions = True

        await get_event_loop().run_in_executor(self._executor, create_executor_in_thread)

    @tracer.start_as_current_span("Queue Prompt")
    async def queue_prompt(self,
                           prompt: PromptDict | dict,
                           prompt_id: Optional[str] = None,
                           client_id: Optional[str] = None) -> dict:
        with self._task_count_lock:
            self._task_count += 1
        prompt_id = prompt_id or str(uuid.uuid4())
        client_id = client_id or self._progress_handler.client_id or None
        span_context = context.get_current()

        def execute_prompt() -> dict:
            spam: Span
            with tracer.start_as_current_span("Execute Prompt", context=span_context) as span:
                from ..cmd.execution import PromptExecutor, validate_prompt
                try:
                    prompt_mut = make_mutable(prompt)
                    validation_tuple = validate_prompt(prompt_mut)
                    if not validation_tuple.valid:
                        validation_error_dict = {"message": "Unknown", "details": ""} if not validation_tuple.node_errors or len(validation_tuple.node_errors) == 0 else validation_tuple.node_errors
                        raise ValueError(json.dumps(validation_error_dict))

                    prompt_executor: PromptExecutor = self._prompt_executor

                    if client_id is None:
                        prompt_executor.server = _server_stub_instance
                    else:
                        prompt_executor.server = self._progress_handler

                    prompt_executor.execute(prompt_mut, prompt_id, {"client_id": client_id},
                                            execute_outputs=validation_tuple.good_output_node_ids)
                    return prompt_executor.outputs_ui
                except Exception as exc_info:
                    span.set_status(Status(StatusCode.ERROR))
                    span.record_exception(exc_info)
                    raise exc_info

        return await get_event_loop().run_in_executor(self._executor, execute_prompt)
        carrier = {}
        propagate.inject(carrier, span_context)
        try:
            return await get_event_loop().run_in_executor(
                self._executor,
                _execute_prompt,
                make_mutable(prompt),
                prompt_id,
                client_id,
                carrier,
                # todo: a proxy object or something more sophisticated will have to be done here to restore progress notifications for ProcessPoolExecutors
                None if isinstance(self._executor, ProcessPoolExecutor) else self._progress_handler,
                self._configuration,
            )
        finally:
            with self._task_count_lock:
                self._task_count -= 1

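A hedged usage sketch of the embedded client with the new executor argument, based on the docstring example above and the constructor change in this hunk. The workflow dict is elided; note that, per the todo in queue_prompt, progress notifications are not forwarded when a ProcessPoolExecutor is used, and the tracing context crosses the process boundary as a plain carrier dict via propagate.inject and propagate.extract rather than as a Span object:

```python
import asyncio

from comfy.api.components.schema.prompt import Prompt
from comfy.client.embedded_comfy_client import EmbeddedComfyClient
from comfy.distributed.process_pool_executor import ProcessPoolExecutor

# prompt_dict is your API-format workflow, elided here.
prompt = Prompt.validate(prompt_dict)


async def main():
    # Run prompts in a separate process so the worker process, and with it large
    # contiguous CUDA allocations, can be torn down when the pool shuts down.
    async with EmbeddedComfyClient(executor=ProcessPoolExecutor(max_workers=1)) as client:
        outputs = await client.queue_prompt(prompt)
        print(outputs)


if __name__ == "__main__":
    asyncio.run(main())
```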
@@ -1,7 +1,8 @@
import asyncio
import itertools
import os
import logging
import os
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor

from .extra_model_paths import load_extra_path_config
from .main_pre import args
@@ -41,7 +42,8 @@ async def main():

    from ..distributed.distributed_prompt_worker import DistributedPromptWorker
    async with DistributedPromptWorker(connection_uri=args.distributed_queue_connection_uri,
                                       queue_name=args.distributed_queue_name):
                                       queue_name=args.distributed_queue_name,
                                       executor=ThreadPoolExecutor(max_workers=1) if args.executor_factory == "ThreadPoolExecutor" else ProcessPoolExecutor(max_workers=1)):
        stop = asyncio.Event()
        try:
            await stop.wait()
@@ -1,11 +1,12 @@
from __future__ import annotations  # for Python 3.7-3.9

import concurrent.futures
import typing
from enum import Enum
from typing import Optional, Literal, Protocol, Union, NamedTuple, List

import PIL.Image
from typing_extensions import NotRequired, TypedDict
from typing_extensions import NotRequired, TypedDict, runtime_checkable

from .queue_types import BinaryEventTypes
from ..nodes.package_typing import InputTypeSpec
@@ -193,3 +194,11 @@ class NodeInputError(Exception):

class NodeNotFoundError(Exception):
    pass


class Executor(Protocol):
    def submit(self, fn, /, *args, **kwargs) -> concurrent.futures.Future:
        ...

    def shutdown(self, wait=True, *, cancel_futures=False):
        ...
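Because Executor is a structural Protocol, anything exposing submit() and shutdown() with these shapes satisfies it; concurrent.futures.ThreadPoolExecutor and the pebble-backed wrapper added below both qualify. A minimal sketch for illustration only; the InlineExecutor name is hypothetical and not part of the commit:

```python
import concurrent.futures


class InlineExecutor:
    """Hypothetical executor that runs work synchronously; satisfies the Executor protocol by shape."""

    def submit(self, fn, /, *args, **kwargs) -> concurrent.futures.Future:
        future: concurrent.futures.Future = concurrent.futures.Future()
        try:
            future.set_result(fn(*args, **kwargs))
        except BaseException as exc:  # propagate the error through the future
            future.set_exception(exc)
        return future

    def shutdown(self, wait=True, *, cancel_futures=False):
        pass  # nothing to clean up
```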
@@ -40,7 +40,6 @@ class DistributedExecutorToClientProgress(ExecutorToClientProgress):
        self.receive_all_progress_notifications = receive_all_progress_notifications

    async def send(self, event: SendSyncEvent, data: SendSyncData, user_id: Optional[str]) -> None:
        # for now, do not send binary data this way, since it cannot be json serialized / it's impractical
        if event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE:
            from ..cmd.latent_preview_image_encoding import encode_preview_image
@@ -14,6 +14,7 @@ from .distributed_progress import DistributedExecutorToClientProgress
from .distributed_types import RpcRequest, RpcReply
from ..client.embedded_comfy_client import EmbeddedComfyClient
from ..cmd.main_pre import tracer
from ..component_model.executor_types import Executor
from ..component_model.queue_types import ExecutionStatus

@@ -26,13 +27,15 @@ class DistributedPromptWorker:
                 connection_uri: str = "amqp://localhost:5672/",
                 queue_name: str = "comfyui",
                 health_check_port: int = 9090,
                 loop: Optional[AbstractEventLoop] = None):
                 loop: Optional[AbstractEventLoop] = None,
                 executor: Optional[Executor] = None):
        self._rpc = None
        self._channel = None
        self._exit_stack = AsyncExitStack()
        self._queue_name = queue_name
        self._connection_uri = connection_uri
        self._loop = loop or asyncio.get_event_loop()
        self._executor = executor
        self._embedded_comfy_client = embedded_comfy_client
        self._health_check_port = health_check_port
        self._health_check_site: Optional[web.TCPSite] = None

@@ -94,7 +97,7 @@ class DistributedPromptWorker:
        self._rpc = await JsonRPC.create(channel=self._channel, auto_delete=True, durable=False)

        if self._embedded_comfy_client is None:
            self._embedded_comfy_client = EmbeddedComfyClient(progress_handler=DistributedExecutorToClientProgress(self._rpc, self._queue_name, self._loop))
            self._embedded_comfy_client = EmbeddedComfyClient(progress_handler=DistributedExecutorToClientProgress(self._rpc, self._queue_name, self._loop), executor=self._executor)
        if not self._embedded_comfy_client.is_running:
            await self._exit_stack.enter_async_context(self._embedded_comfy_client)
@@ -1,10 +1,12 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Tuple, Literal, List
from typing import Tuple, Literal, List, Callable

from ..api.components.schema.prompt import PromptDict, Prompt
from ..auth.permissions import ComfyJwt, jwt_decode
from ..cli_args_types import Configuration
from ..component_model.executor_types import ExecutorToClientProgress
from ..component_model.queue_types import NamedQueueTuple, TaskInvocation, ExecutionStatus
comfy/distributed/process_pool_executor.py (new file, 19 lines)
@@ -0,0 +1,19 @@
import concurrent.futures

from pebble import ProcessPool

from ..component_model.executor_types import Executor


class ProcessPoolExecutor(ProcessPool, Executor):
    def shutdown(self, wait=True, *, cancel_futures=False):
        if cancel_futures:
            raise NotImplementedError("cannot cancel futures in this implementation")
        if wait:
            self.close()
        else:
            self.stop()
        return

    def submit(self, fn, /, *args, **kwargs) -> concurrent.futures.Future:
        return self.schedule(fn, args=list(args), kwargs=kwargs, timeout=None)
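For reference, pebble's ProcessPool.schedule returns a ProcessFuture, which is compatible with concurrent.futures.Future; that is what lets this wrapper present the Executor shape to run_in_executor. A hedged usage sketch, assuming pebble is installed and remembering that work sent to a process pool must be picklable:

```python
from comfy.distributed.process_pool_executor import ProcessPoolExecutor


def render(x: int) -> int:
    # Top-level function with picklable arguments, as required by process pools.
    return x * x


if __name__ == "__main__":
    pool = ProcessPoolExecutor(max_workers=1)
    try:
        future = pool.submit(render, 7)   # routed through pebble's schedule()
        print(future.result())            # 49
    finally:
        pool.shutdown(wait=True)          # maps to pebble's close()
        pool.join()                       # wait for the worker processes to exit
```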
@@ -1,17 +1,25 @@
import platform

try:
    import bitsandbytes as bnb
    from bitsandbytes.nn.modules import Params4bit, QuantState

    has_bitsandbytes = True
except (ImportError, ModuleNotFoundError):
    bnb = {}
    Params4bit = {}
    QuantState = {}
    class bnb:
        pass


    class Params4bit:
        pass


    class QuantState:
        pass


    has_bitsandbytes = False

import torch

import comfy.ops
import comfy.sd
from comfy.cmd.folder_paths import get_folder_paths
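The fallback placeholders change here from empty dicts to empty classes, presumably so that type checks and annotations keep working when bitsandbytes is not installed. A small illustration of the difference, not part of the commit:

```python
class QuantState:  # class stand-in, as in the new fallback
    pass


value = object()
print(isinstance(value, QuantState))   # False, but the check is legal

QuantStateDict = {}  # the old dict stand-in
try:
    isinstance(value, QuantStateDict)
except TypeError as exc:
    # isinstance() requires a type as its second argument, so the dict placeholder raises
    print(f"dict placeholder fails: {exc}")
```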
@@ -64,4 +64,5 @@ spandrel_extra_arches
ml_dtypes
diffusers>=0.30.1
vtracer
skia-python
skia-python
pebble>=5.0.7
@@ -86,8 +86,8 @@ def has_gpu() -> bool:
    yield has_gpu


@pytest.fixture(scope="module", autouse=False)
def frontend_backend_worker_with_rabbitmq(tmp_path_factory) -> str:
@pytest.fixture(scope="module", autouse=False, params=["ThreadPoolExecutor", "ProcessPoolExecutor"])
def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory) -> str:
    """
    populates the cache with the sdxl checkpoints, starts a frontend and backend worker against a started rabbitmq, and yields the address of the frontend
    :return:

@@ -97,6 +97,7 @@ def frontend_backend_worker_with_rabbitmq(tmp_path_factory) -> str:
    hf_hub_download("stabilityai/stable-diffusion-xl-refiner-1.0", "sd_xl_refiner_1.0.safetensors")

    tmp_path = tmp_path_factory.mktemp("comfy_background_server")
    executor_factory = request.param
    processes_to_close: List[subprocess.Popen] = []
    from testcontainers.rabbitmq import RabbitMqContainer
    with RabbitMqContainer("rabbitmq:latest") as rabbitmq:

@@ -119,6 +120,7 @@ def frontend_backend_worker_with_rabbitmq(tmp_path_factory) -> str:
        "--port=9002",
        f"-w={str(tmp_path)}",
        f"--distributed-queue-connection-uri={connection_uri}",
        f"--executor-factory={executor_factory}"
    ]

    processes_to_close.append(subprocess.Popen(backend_command, stdout=sys.stdout, stderr=sys.stderr))
@@ -1,6 +1,7 @@
import asyncio
import uuid
from concurrent.futures import ThreadPoolExecutor
from typing import Callable

import jwt
import pytest

@@ -10,9 +11,11 @@ from testcontainers.rabbitmq import RabbitMqContainer
from comfy.client.aio_client import AsyncRemoteComfyClient
from comfy.client.embedded_comfy_client import EmbeddedComfyClient
from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner
from comfy.component_model.executor_types import Executor
from comfy.component_model.make_mutable import make_mutable
from comfy.component_model.queue_types import QueueItem, QueueTuple, TaskInvocation, NamedQueueTuple, ExecutionStatus
from comfy.distributed.distributed_prompt_worker import DistributedPromptWorker
from comfy.distributed.process_pool_executor import ProcessPoolExecutor
from comfy.distributed.server_stub import ServerStub

@@ -35,12 +38,11 @@ async def test_sign_jwt_auth_none():

@pytest.mark.asyncio
async def test_basic_queue_worker() -> None:
    # there are lots of side effects from importing that we have to deal with
@pytest.mark.parametrize("executor_factory", (ThreadPoolExecutor, ProcessPoolExecutor,))
async def test_basic_queue_worker(executor_factory: Callable[..., Executor]) -> None:
    with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
        params = rabbitmq.get_connection_params()
        async with DistributedPromptWorker(connection_uri=f"amqp://guest:guest@127.0.0.1:{params.port}"):
        async with DistributedPromptWorker(connection_uri=f"amqp://guest:guest@127.0.0.1:{params.port}", executor=executor_factory(max_workers=1)):
            # this unfortunately does a bunch of initialization on the test thread
            from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
            # now submit some jobs

@@ -125,14 +127,15 @@ async def check_health(url: str, max_retries: int = 5, retry_delay: float = 1.0)

@pytest.mark.asyncio
async def test_basic_queue_worker_with_health_check():
@pytest.mark.parametrize("executor_factory", (ThreadPoolExecutor, ProcessPoolExecutor,))
async def test_basic_queue_worker_with_health_check(executor_factory):
    with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
        params = rabbitmq.get_connection_params()
        connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"
        health_check_port = 9090

        async with DistributedPromptWorker(connection_uri=connection_uri, health_check_port=health_check_port) as worker:
        async with DistributedPromptWorker(connection_uri=connection_uri, health_check_port=health_check_port, executor=executor_factory(max_workers=1)) as worker:
            health_check_url = f"http://localhost:{health_check_port}/health"

            health_check_ok = await check_health(health_check_url)
            assert health_check_ok, "Health check server did not start properly"
            assert health_check_ok, "Health check server did not start properly"