diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 40f523c7b..0fc38a9cf 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -198,6 +198,13 @@ def _create_parser() -> EnhancedConfigArgParser: help="The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.", ) + parser.add_argument( + "--executor-factory", + type=str, + default="ThreadPoolExecutor", + help="When running ComfyUI as a distributed worker, this specifies the kind of executor that should be used to run the actual ComfyUI workflow worker. A ThreadPoolExecutor is the default. A ProcessPoolExecutor results in better memory management, since the process will be closed and large, contiguous blocks of CUDA memory can be freed." + ) + # now give plugins a chance to add configuration for entry_point in entry_points().select(group='comfyui.custom_config'): try: diff --git a/comfy/cli_args_types.py b/comfy/cli_args_types.py index 729465a75..70af63f6f 100644 --- a/comfy/cli_args_types.py +++ b/comfy/cli_args_types.py @@ -109,6 +109,7 @@ class Configuration(dict): otel_exporter_otlp_endpoint (Optional[str]): A base endpoint URL for any signal type, with an optionally-specified port number. Helpful for when you're sending more than one signal to the same endpoint and want one environment variable to control the endpoint. force_channels_last (bool): Force channels last format when inferencing the models. force_hf_local_dir_mode (bool): Download repos from huggingface.co to the models/huggingface directory with the "local_dir" argument instead of models/huggingface_cache with the "cache_dir" argument, recreating the traditional file structure. + executor_factory (str): Either ThreadPoolExecutor or ProcessPoolExecutor, defaulting to ThreadPoolExecutor """ def __init__(self, **kwargs): @@ -194,6 +195,8 @@ class Configuration(dict): for key, value in kwargs.items(): self[key] = value + self.executor_factory: str = "ThreadPoolExecutor" + def __getattr__(self, item): if item not in self: return None diff --git a/comfy/client/embedded_comfy_client.py b/comfy/client/embedded_comfy_client.py index 577b9ac22..659e66fa6 100644 --- a/comfy/client/embedded_comfy_client.py +++ b/comfy/client/embedded_comfy_client.py @@ -1,151 +1,159 @@ from __future__ import annotations import asyncio +import contextvars import gc import json import uuid from asyncio import get_event_loop from concurrent.futures import ThreadPoolExecutor +from multiprocessing import RLock from typing import Optional -from opentelemetry import context -from opentelemetry.trace import Span, Status, StatusCode +from opentelemetry import context, propagate +from opentelemetry.context import Context, attach, detach +from opentelemetry.trace import Status, StatusCode from ..api.components.schema.prompt import PromptDict from ..cli_args_types import Configuration from ..cmd.main_pre import tracer -from ..component_model.executor_types import ExecutorToClientProgress +from ..component_model.executor_types import ExecutorToClientProgress, Executor from ..component_model.make_mutable import make_mutable +from ..distributed.process_pool_executor import ProcessPoolExecutor from ..distributed.server_stub import ServerStub -_server_stub_instance = ServerStub() +_prompt_executor = contextvars.ContextVar('prompt_executor') + + +def _execute_prompt( + prompt: dict, + prompt_id: str, + client_id: str, + span_context: dict, + progress_handler: ExecutorToClientProgress | None, + configuration: Configuration | None) -> dict: + span_context: Context = 
propagate.extract(span_context) + token = attach(span_context) + try: + return __execute_prompt(prompt, prompt_id, client_id, span_context, progress_handler, configuration) + finally: + detach(token) + + +def __execute_prompt( + prompt: dict, + prompt_id: str, + client_id: str, + span_context: Context, + progress_handler: ExecutorToClientProgress | None, + configuration: Configuration | None) -> dict: + from .. import options + progress_handler = progress_handler or ServerStub() + + try: + prompt_executor = _prompt_executor.get() + except LookupError: + if configuration is None: + options.enable_args_parsing() + else: + from ..cmd.main_pre import args + args.clear() + args.update(configuration) + + from ..cmd.execution import PromptExecutor + with tracer.start_as_current_span("Initialize Prompt Executor", context=span_context) as span: + prompt_executor = PromptExecutor(progress_handler, lru_size=configuration.cache_lru if configuration is not None else 0) + prompt_executor.raise_exceptions = True + _prompt_executor.set(prompt_executor) + + with tracer.start_as_current_span("Execute Prompt", context=span_context) as span: + try: + prompt_mut = make_mutable(prompt) + from ..cmd.execution import validate_prompt + validation_tuple = validate_prompt(prompt_mut) + if not validation_tuple.valid: + validation_error_dict = {"message": "Unknown", "details": ""} if not validation_tuple.node_errors or len(validation_tuple.node_errors) == 0 else validation_tuple.node_errors + raise ValueError(json.dumps(validation_error_dict)) + + if client_id is None: + prompt_executor.server = ServerStub() + else: + prompt_executor.server = progress_handler + + prompt_executor.execute(prompt_mut, prompt_id, {"client_id": client_id}, + execute_outputs=validation_tuple.good_output_node_ids) + return prompt_executor.outputs_ui + except Exception as exc_info: + span.set_status(Status(StatusCode.ERROR)) + span.record_exception(exc_info) + raise exc_info + + +def _cleanup(): + from .. import model_management + model_management.unload_all_models() + gc.collect() + try: + model_management.soft_empty_cache() + except: + pass class EmbeddedComfyClient: - """ - Embedded client for comfy executing prompts as a library. - - This client manages a single-threaded executor to run long-running or blocking tasks - asynchronously without blocking the asyncio event loop. It initializes a PromptExecutor - in a dedicated thread for executing prompts and handling server-stub communications. - - Example usage: - - Asynchronous (non-blocking) usage with async-await: - ``` - # Write a workflow, or enable Dev Mode in the UI settings, then Save (API Format) to get the workflow in your - # workspace. - prompt_dict = { - "1": {"class_type": "KSamplerAdvanced", ...} - ... - } - - # Validate your workflow (the prompt) - from comfy.api.components.schema.prompt import Prompt - prompt = Prompt.validate(prompt_dict) - # Then use the client to run your workflow. This will start, then stop, a local ComfyUI workflow executor. - # It does not connect to a remote server. - async def main(): - async with EmbeddedComfyClient() as client: - outputs = await client.queue_prompt(prompt) - print(outputs) - print("Now that we've exited the with statement, all your VRAM has been cleared from ComfyUI") - - if __name__ == "__main__" - asyncio.run(main()) - ``` - - In order to use this in blocking methods, learn more about asyncio online. 
- """ - - def __init__(self, configuration: Optional[Configuration] = None, progress_handler: Optional[ExecutorToClientProgress] = None, max_workers: int = 1): + def __init__(self, configuration: Optional[Configuration] = None, progress_handler: Optional[ExecutorToClientProgress] = None, max_workers: int = 1, executor: Executor = None): self._progress_handler = progress_handler or ServerStub() - self._executor = ThreadPoolExecutor(max_workers=max_workers) + self._executor = executor or ThreadPoolExecutor(max_workers=max_workers) self._configuration = configuration - # we don't want to import the executor yet - self._prompt_executor: Optional["comfy.cmd.execution.PromptExecutor"] = None self._is_running = False + self._task_count_lock = RLock() + self._task_count = 0 @property def is_running(self) -> bool: return self._is_running + @property + def task_count(self) -> int: + return self._task_count + async def __aenter__(self): - await self._initialize_prompt_executor() self._is_running = True return self async def __aexit__(self, *args): - # Perform cleanup here - def cleanup(): - from .. import model_management - model_management.unload_all_models() - gc.collect() - try: - model_management.soft_empty_cache() - except: - pass - # wait until the queue is done - while self._executor._work_queue.qsize() > 0: + while self.task_count > 0: await asyncio.sleep(0.1) - await get_event_loop().run_in_executor(self._executor, cleanup) + await get_event_loop().run_in_executor(self._executor, _cleanup) self._executor.shutdown(wait=True) self._is_running = False - async def _initialize_prompt_executor(self): - # This method must be async since it's used in __aenter__ - def create_executor_in_thread(): - from .. import options - if self._configuration is None: - options.enable_args_parsing() - else: - from ..cmd.main_pre import args - args.clear() - args.update(self._configuration) - - from ..cmd.execution import PromptExecutor - - self._prompt_executor = PromptExecutor(self._progress_handler, lru_size=self._configuration.cache_lru if self._configuration is not None else 0) - self._prompt_executor.raise_exceptions = True - - await get_event_loop().run_in_executor(self._executor, create_executor_in_thread) - @tracer.start_as_current_span("Queue Prompt") async def queue_prompt(self, prompt: PromptDict | dict, prompt_id: Optional[str] = None, client_id: Optional[str] = None) -> dict: + with self._task_count_lock: + self._task_count += 1 prompt_id = prompt_id or str(uuid.uuid4()) client_id = client_id or self._progress_handler.client_id or None span_context = context.get_current() - - def execute_prompt() -> dict: - spam: Span - with tracer.start_as_current_span("Execute Prompt", context=span_context) as span: - from ..cmd.execution import PromptExecutor, validate_prompt - try: - prompt_mut = make_mutable(prompt) - validation_tuple = validate_prompt(prompt_mut) - if not validation_tuple.valid: - validation_error_dict = {"message": "Unknown", "details": ""} if not validation_tuple.node_errors or len(validation_tuple.node_errors) == 0 else validation_tuple.node_errors - raise ValueError(json.dumps(validation_error_dict)) - - prompt_executor: PromptExecutor = self._prompt_executor - - if client_id is None: - prompt_executor.server = _server_stub_instance - else: - prompt_executor.server = self._progress_handler - - prompt_executor.execute(prompt_mut, prompt_id, {"client_id": client_id}, - execute_outputs=validation_tuple.good_output_node_ids) - return prompt_executor.outputs_ui - except Exception as exc_info: - 
span.set_status(Status(StatusCode.ERROR)) - span.record_exception(exc_info) - raise exc_info - - return await get_event_loop().run_in_executor(self._executor, execute_prompt) + carrier = {} + propagate.inject(carrier, span_context) + try: + return await get_event_loop().run_in_executor( + self._executor, + _execute_prompt, + make_mutable(prompt), + prompt_id, + client_id, + carrier, + # todo: a proxy object or something more sophisticated will have to be done here to restore progress notifications for ProcessPoolExecutors + None if isinstance(self._executor, ProcessPoolExecutor) else self._progress_handler, + self._configuration, + ) + finally: + with self._task_count_lock: + self._task_count -= 1 diff --git a/comfy/cmd/worker.py b/comfy/cmd/worker.py index b3a2f6cf7..37a795af6 100644 --- a/comfy/cmd/worker.py +++ b/comfy/cmd/worker.py @@ -1,7 +1,8 @@ import asyncio import itertools -import os import logging +import os +from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor from .extra_model_paths import load_extra_path_config from .main_pre import args @@ -41,7 +42,8 @@ async def main(): from ..distributed.distributed_prompt_worker import DistributedPromptWorker async with DistributedPromptWorker(connection_uri=args.distributed_queue_connection_uri, - queue_name=args.distributed_queue_name): + queue_name=args.distributed_queue_name, + executor=ThreadPoolExecutor(max_workers=1) if args.executor_factory == "ThreadPoolExecutor" else ProcessPoolExecutor(max_workers=1)): stop = asyncio.Event() try: await stop.wait() diff --git a/comfy/component_model/executor_types.py b/comfy/component_model/executor_types.py index 0459d4de9..612a5d7bf 100644 --- a/comfy/component_model/executor_types.py +++ b/comfy/component_model/executor_types.py @@ -1,11 +1,12 @@ from __future__ import annotations # for Python 3.7-3.9 +import concurrent.futures import typing from enum import Enum from typing import Optional, Literal, Protocol, Union, NamedTuple, List import PIL.Image -from typing_extensions import NotRequired, TypedDict +from typing_extensions import NotRequired, TypedDict, runtime_checkable from .queue_types import BinaryEventTypes from ..nodes.package_typing import InputTypeSpec @@ -193,3 +194,11 @@ class NodeInputError(Exception): class NodeNotFoundError(Exception): pass + + +class Executor(Protocol): + def submit(self, fn, /, *args, **kwargs) -> concurrent.futures.Future: + ... + + def shutdown(self, wait=True, *, cancel_futures=False): + ... 
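Because Executor is defined as a typing.Protocol, conformance is structural: the standard-library ThreadPoolExecutor and the pebble-backed ProcessPoolExecutor added later in this patch both match it without inheriting from it. A minimal illustration (not part of the diff) of code written against that shape; the run_job helper is hypothetical, used only to show the protocol's submit()/shutdown() contract.

import concurrent.futures
from concurrent.futures import ThreadPoolExecutor


def run_job(executor, fn, *args):
    # Accepts any object with the Executor protocol's submit()/shutdown() shape,
    # e.g. a concurrent.futures executor or the pebble adapter in this patch.
    future: concurrent.futures.Future = executor.submit(fn, *args)
    try:
        return future.result()
    finally:
        executor.shutdown(wait=True)


if __name__ == "__main__":
    print(run_job(ThreadPoolExecutor(max_workers=1), pow, 2, 10))  # prints 1024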
\ No newline at end of file diff --git a/comfy/distributed/distributed_progress.py b/comfy/distributed/distributed_progress.py index 348626348..2db22f540 100644 --- a/comfy/distributed/distributed_progress.py +++ b/comfy/distributed/distributed_progress.py @@ -40,7 +40,6 @@ class DistributedExecutorToClientProgress(ExecutorToClientProgress): self.receive_all_progress_notifications = receive_all_progress_notifications async def send(self, event: SendSyncEvent, data: SendSyncData, user_id: Optional[str]) -> None: - # for now, do not send binary data this way, since it cannot be json serialized / it's impractical if event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE: from ..cmd.latent_preview_image_encoding import encode_preview_image diff --git a/comfy/distributed/distributed_prompt_worker.py b/comfy/distributed/distributed_prompt_worker.py index 8416e453a..76a48911a 100644 --- a/comfy/distributed/distributed_prompt_worker.py +++ b/comfy/distributed/distributed_prompt_worker.py @@ -14,6 +14,7 @@ from .distributed_progress import DistributedExecutorToClientProgress from .distributed_types import RpcRequest, RpcReply from ..client.embedded_comfy_client import EmbeddedComfyClient from ..cmd.main_pre import tracer +from ..component_model.executor_types import Executor from ..component_model.queue_types import ExecutionStatus @@ -26,13 +27,15 @@ class DistributedPromptWorker: connection_uri: str = "amqp://localhost:5672/", queue_name: str = "comfyui", health_check_port: int = 9090, - loop: Optional[AbstractEventLoop] = None): + loop: Optional[AbstractEventLoop] = None, + executor: Optional[Executor] = None): self._rpc = None self._channel = None self._exit_stack = AsyncExitStack() self._queue_name = queue_name self._connection_uri = connection_uri self._loop = loop or asyncio.get_event_loop() + self._executor = executor self._embedded_comfy_client = embedded_comfy_client self._health_check_port = health_check_port self._health_check_site: Optional[web.TCPSite] = None @@ -94,7 +97,7 @@ class DistributedPromptWorker: self._rpc = await JsonRPC.create(channel=self._channel, auto_delete=True, durable=False) if self._embedded_comfy_client is None: - self._embedded_comfy_client = EmbeddedComfyClient(progress_handler=DistributedExecutorToClientProgress(self._rpc, self._queue_name, self._loop)) + self._embedded_comfy_client = EmbeddedComfyClient(progress_handler=DistributedExecutorToClientProgress(self._rpc, self._queue_name, self._loop), executor=self._executor) if not self._embedded_comfy_client.is_running: await self._exit_stack.enter_async_context(self._embedded_comfy_client) diff --git a/comfy/distributed/distributed_types.py b/comfy/distributed/distributed_types.py index 35c420d66..2ecde40f0 100644 --- a/comfy/distributed/distributed_types.py +++ b/comfy/distributed/distributed_types.py @@ -1,10 +1,12 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Tuple, Literal, List +from typing import Tuple, Literal, List, Callable from ..api.components.schema.prompt import PromptDict, Prompt from ..auth.permissions import ComfyJwt, jwt_decode +from ..cli_args_types import Configuration +from ..component_model.executor_types import ExecutorToClientProgress from ..component_model.queue_types import NamedQueueTuple, TaskInvocation, ExecutionStatus diff --git a/comfy/distributed/process_pool_executor.py b/comfy/distributed/process_pool_executor.py new file mode 100644 index 000000000..6a198f47f --- /dev/null +++ b/comfy/distributed/process_pool_executor.py @@ -0,0 +1,19 @@ 
+import concurrent.futures + +from pebble import ProcessPool + +from ..component_model.executor_types import Executor + + +class ProcessPoolExecutor(ProcessPool, Executor): + def shutdown(self, wait=True, *, cancel_futures=False): + if cancel_futures: + raise NotImplementedError("cannot cancel futures in this implementation") + if wait: + self.close() + else: + self.stop() + return + + def submit(self, fn, /, *args, **kwargs) -> concurrent.futures.Future: + return self.schedule(fn, args=list(args), kwargs=kwargs, timeout=None) \ No newline at end of file diff --git a/comfy_extras/nodes/nodes_nf4.py b/comfy_extras/nodes/nodes_nf4.py index cae80f71c..b551ff8bf 100644 --- a/comfy_extras/nodes/nodes_nf4.py +++ b/comfy_extras/nodes/nodes_nf4.py @@ -1,17 +1,25 @@ -import platform - try: import bitsandbytes as bnb from bitsandbytes.nn.modules import Params4bit, QuantState has_bitsandbytes = True except (ImportError, ModuleNotFoundError): - bnb = {} - Params4bit = {} - QuantState = {} + class bnb: + pass + + + class Params4bit: + pass + + + class QuantState: + pass + + has_bitsandbytes = False import torch + import comfy.ops import comfy.sd from comfy.cmd.folder_paths import get_folder_paths diff --git a/requirements.txt b/requirements.txt index 6c7df91eb..5c0eb7f7e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -64,4 +64,5 @@ spandrel_extra_arches ml_dtypes diffusers>=0.30.1 vtracer -skia-python \ No newline at end of file +skia-python +pebble>=5.0.7 \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index f36264f17..2224783f6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -86,8 +86,8 @@ def has_gpu() -> bool: yield has_gpu -@pytest.fixture(scope="module", autouse=False) -def frontend_backend_worker_with_rabbitmq(tmp_path_factory) -> str: +@pytest.fixture(scope="module", autouse=False, params=["ThreadPoolExecutor", "ProcessPoolExecutor"]) +def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory) -> str: """ populates the cache with the sdxl checkpoints, starts a frontend and backend worker against a started rabbitmq, and yields the address of the frontend :return: @@ -97,6 +97,7 @@ def frontend_backend_worker_with_rabbitmq(tmp_path_factory) -> str: hf_hub_download("stabilityai/stable-diffusion-xl-refiner-1.0", "sd_xl_refiner_1.0.safetensors") tmp_path = tmp_path_factory.mktemp("comfy_background_server") + executor_factory = request.param processes_to_close: List[subprocess.Popen] = [] from testcontainers.rabbitmq import RabbitMqContainer with RabbitMqContainer("rabbitmq:latest") as rabbitmq: @@ -119,6 +120,7 @@ def frontend_backend_worker_with_rabbitmq(tmp_path_factory) -> str: "--port=9002", f"-w={str(tmp_path)}", f"--distributed-queue-connection-uri={connection_uri}", + f"--executor-factory={executor_factory}" ] processes_to_close.append(subprocess.Popen(backend_command, stdout=sys.stdout, stderr=sys.stderr)) diff --git a/tests/distributed/test_distributed_queue.py b/tests/distributed/test_distributed_queue.py index 8924bb28f..4fb228362 100644 --- a/tests/distributed/test_distributed_queue.py +++ b/tests/distributed/test_distributed_queue.py @@ -1,6 +1,7 @@ import asyncio import uuid from concurrent.futures import ThreadPoolExecutor +from typing import Callable import jwt import pytest @@ -10,9 +11,11 @@ from testcontainers.rabbitmq import RabbitMqContainer from comfy.client.aio_client import AsyncRemoteComfyClient from comfy.client.embedded_comfy_client import EmbeddedComfyClient from comfy.client.sdxl_with_refiner_workflow import 
sdxl_workflow_with_refiner +from comfy.component_model.executor_types import Executor from comfy.component_model.make_mutable import make_mutable from comfy.component_model.queue_types import QueueItem, QueueTuple, TaskInvocation, NamedQueueTuple, ExecutionStatus from comfy.distributed.distributed_prompt_worker import DistributedPromptWorker +from comfy.distributed.process_pool_executor import ProcessPoolExecutor from comfy.distributed.server_stub import ServerStub @@ -35,12 +38,11 @@ async def test_sign_jwt_auth_none(): @pytest.mark.asyncio -async def test_basic_queue_worker() -> None: - # there are lots of side effects from importing that we have to deal with - +@pytest.mark.parametrize("executor_factory", (ThreadPoolExecutor, ProcessPoolExecutor,)) +async def test_basic_queue_worker(executor_factory: Callable[..., Executor]) -> None: with RabbitMqContainer("rabbitmq:latest") as rabbitmq: params = rabbitmq.get_connection_params() - async with DistributedPromptWorker(connection_uri=f"amqp://guest:guest@127.0.0.1:{params.port}"): + async with DistributedPromptWorker(connection_uri=f"amqp://guest:guest@127.0.0.1:{params.port}", executor=executor_factory(max_workers=1)): # this unfortunately does a bunch of initialization on the test thread from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue # now submit some jobs @@ -125,14 +127,15 @@ async def check_health(url: str, max_retries: int = 5, retry_delay: float = 1.0) @pytest.mark.asyncio -async def test_basic_queue_worker_with_health_check(): +@pytest.mark.parametrize("executor_factory", (ThreadPoolExecutor, ProcessPoolExecutor,)) +async def test_basic_queue_worker_with_health_check(executor_factory): with RabbitMqContainer("rabbitmq:latest") as rabbitmq: params = rabbitmq.get_connection_params() connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}" health_check_port = 9090 - async with DistributedPromptWorker(connection_uri=connection_uri, health_check_port=health_check_port) as worker: + async with DistributedPromptWorker(connection_uri=connection_uri, health_check_port=health_check_port, executor=executor_factory(max_workers=1)) as worker: health_check_url = f"http://localhost:{health_check_port}/health" health_check_ok = await check_health(health_check_url) - assert health_check_ok, "Health check server did not start properly" \ No newline at end of file + assert health_check_ok, "Health check server did not start properly"
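The patch exposes the executor choice in two places: on the CLI via --executor-factory=ProcessPoolExecutor for distributed workers, and programmatically via the new executor parameter on EmbeddedComfyClient. A hedged end-to-end sketch of the programmatic path follows; the prompt contents are placeholders, and only the executor keyword, the async context manager, and queue_prompt are taken from this diff.

import asyncio

from comfy.client.embedded_comfy_client import EmbeddedComfyClient
from comfy.distributed.process_pool_executor import ProcessPoolExecutor


async def main():
    # Export a real workflow from the UI with "Save (API Format)"; this single
    # node is only a placeholder, not a runnable graph.
    prompt_dict = {"1": {"class_type": "KSamplerAdvanced", "inputs": {}}}

    # A ProcessPoolExecutor runs the workflow in a child process, so large,
    # contiguous CUDA allocations can be freed when the pool shuts down; omit
    # executor= to keep the default in-process ThreadPoolExecutor behaviour.
    async with EmbeddedComfyClient(executor=ProcessPoolExecutor(max_workers=1)) as client:
        outputs = await client.queue_prompt(prompt_dict)
        print(outputs)


if __name__ == "__main__":
    asyncio.run(main())

As the todo in queue_prompt notes, the progress handler is not yet forwarded to ProcessPoolExecutor child processes, so per-node progress notifications are only available with the default ThreadPoolExecutor.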