mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-11 14:50:49 +08:00)

Support ProcessPoolExecutor to improve memory management

This commit is contained in:
parent c75b9964ab
commit ed33ab1e7d
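
In short, the commit threads an optional `executor` through `EmbeddedComfyClient` and the distributed worker so that prompt execution can run in a separate process (a pebble-backed `ProcessPoolExecutor`) instead of a thread, letting the process exit and release its CUDA memory in one block. A minimal usage sketch, assuming only the classes and import paths that appear in the diff below:

```python
# Sketch only; EmbeddedComfyClient and ProcessPoolExecutor are the classes changed/added in this diff.
import asyncio

from comfy.client.embedded_comfy_client import EmbeddedComfyClient
from comfy.distributed.process_pool_executor import ProcessPoolExecutor


async def main():
    prompt = {}  # a workflow exported in API format; left empty here as a placeholder
    # The prompt executes in a separate process; shutting the pool down on exit
    # releases that process's CUDA allocations all at once.
    async with EmbeddedComfyClient(executor=ProcessPoolExecutor(max_workers=1)) as client:
        outputs = await client.queue_prompt(prompt)
        print(outputs)


if __name__ == "__main__":
    asyncio.run(main())
```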
@@ -198,6 +198,13 @@ def _create_parser() -> EnhancedConfigArgParser:
         help="The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.",
     )
 
+    parser.add_argument(
+        "--executor-factory",
+        type=str,
+        default="ThreadPoolExecutor",
+        help="When running ComfyUI as a distributed worker, this specifies the kind of executor that should be used to run the actual ComfyUI workflow worker. A ThreadPoolExecutor is the default. A ProcessPoolExecutor results in better memory management, since the process will be closed and large, contiguous blocks of CUDA memory can be freed."
+    )
+
     # now give plugins a chance to add configuration
     for entry_point in entry_points().select(group='comfyui.custom_config'):
         try:
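
The flag is consumed by the distributed worker entry point further down in this diff, where the string selects between the two pool types. A sketch of that mapping (the helper name is illustrative, not part of the codebase):

```python
# Illustrative helper; mirrors the selection logic in the worker hunk later in this diff.
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor


def executor_from_flag(executor_factory: str):
    if executor_factory == "ThreadPoolExecutor":
        return ThreadPoolExecutor(max_workers=1)
    return ProcessPoolExecutor(max_workers=1)
```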
@@ -109,6 +109,7 @@ class Configuration(dict):
         otel_exporter_otlp_endpoint (Optional[str]): A base endpoint URL for any signal type, with an optionally-specified port number. Helpful for when you're sending more than one signal to the same endpoint and want one environment variable to control the endpoint.
         force_channels_last (bool): Force channels last format when inferencing the models.
         force_hf_local_dir_mode (bool): Download repos from huggingface.co to the models/huggingface directory with the "local_dir" argument instead of models/huggingface_cache with the "cache_dir" argument, recreating the traditional file structure.
+        executor_factory (str): Either ThreadPoolExecutor or ProcessPoolExecutor, defaulting to ThreadPoolExecutor
     """
 
     def __init__(self, **kwargs):
@@ -194,6 +195,8 @@ class Configuration(dict):
         for key, value in kwargs.items():
             self[key] = value
 
+        self.executor_factory: str = "ThreadPoolExecutor"
+
     def __getattr__(self, item):
        if item not in self:
            return None
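
A small sketch of what the two Configuration hunks above amount to at runtime (nothing here beyond the dict-backed attribute behaviour they show):

```python
from comfy.cli_args_types import Configuration  # import path as used elsewhere in this diff

config = Configuration()
print(config.executor_factory)  # "ThreadPoolExecutor", the default set in __init__
print(config.unknown_option)    # None, per the __getattr__ shown above
```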
@@ -1,151 +1,159 @@
 from __future__ import annotations
 
 import asyncio
+import contextvars
 import gc
 import json
 import uuid
 from asyncio import get_event_loop
 from concurrent.futures import ThreadPoolExecutor
+from multiprocessing import RLock
 from typing import Optional
 
-from opentelemetry import context
-from opentelemetry.trace import Span, Status, StatusCode
+from opentelemetry import context, propagate
+from opentelemetry.context import Context, attach, detach
+from opentelemetry.trace import Status, StatusCode
 
 from ..api.components.schema.prompt import PromptDict
 from ..cli_args_types import Configuration
 from ..cmd.main_pre import tracer
-from ..component_model.executor_types import ExecutorToClientProgress
+from ..component_model.executor_types import ExecutorToClientProgress, Executor
 from ..component_model.make_mutable import make_mutable
+from ..distributed.process_pool_executor import ProcessPoolExecutor
 from ..distributed.server_stub import ServerStub
 
-_server_stub_instance = ServerStub()
+_prompt_executor = contextvars.ContextVar('prompt_executor')
+
+
+def _execute_prompt(
+        prompt: dict,
+        prompt_id: str,
+        client_id: str,
+        span_context: dict,
+        progress_handler: ExecutorToClientProgress | None,
+        configuration: Configuration | None) -> dict:
+    span_context: Context = propagate.extract(span_context)
+    token = attach(span_context)
+    try:
+        return __execute_prompt(prompt, prompt_id, client_id, span_context, progress_handler, configuration)
+    finally:
+        detach(token)
+
+
+def __execute_prompt(
+        prompt: dict,
+        prompt_id: str,
+        client_id: str,
+        span_context: Context,
+        progress_handler: ExecutorToClientProgress | None,
+        configuration: Configuration | None) -> dict:
+    from .. import options
+    progress_handler = progress_handler or ServerStub()
+
+    try:
+        prompt_executor = _prompt_executor.get()
+    except LookupError:
+        if configuration is None:
+            options.enable_args_parsing()
+        else:
+            from ..cmd.main_pre import args
+            args.clear()
+            args.update(configuration)
+
+        from ..cmd.execution import PromptExecutor
+        with tracer.start_as_current_span("Initialize Prompt Executor", context=span_context) as span:
+            prompt_executor = PromptExecutor(progress_handler, lru_size=configuration.cache_lru if configuration is not None else 0)
+            prompt_executor.raise_exceptions = True
+            _prompt_executor.set(prompt_executor)
+
+    with tracer.start_as_current_span("Execute Prompt", context=span_context) as span:
+        try:
+            prompt_mut = make_mutable(prompt)
+            from ..cmd.execution import validate_prompt
+            validation_tuple = validate_prompt(prompt_mut)
+            if not validation_tuple.valid:
+                validation_error_dict = {"message": "Unknown", "details": ""} if not validation_tuple.node_errors or len(validation_tuple.node_errors) == 0 else validation_tuple.node_errors
+                raise ValueError(json.dumps(validation_error_dict))
+
+            if client_id is None:
+                prompt_executor.server = ServerStub()
+            else:
+                prompt_executor.server = progress_handler
+
+            prompt_executor.execute(prompt_mut, prompt_id, {"client_id": client_id},
+                                    execute_outputs=validation_tuple.good_output_node_ids)
+            return prompt_executor.outputs_ui
+        except Exception as exc_info:
+            span.set_status(Status(StatusCode.ERROR))
+            span.record_exception(exc_info)
+            raise exc_info
+
+
+def _cleanup():
+    from .. import model_management
+    model_management.unload_all_models()
+    gc.collect()
+    try:
+        model_management.soft_empty_cache()
+    except:
+        pass
 
 
 class EmbeddedComfyClient:
-    """
-    Embedded client for comfy executing prompts as a library.
-
-    This client manages a single-threaded executor to run long-running or blocking tasks
-    asynchronously without blocking the asyncio event loop. It initializes a PromptExecutor
-    in a dedicated thread for executing prompts and handling server-stub communications.
-
-    Example usage:
-
-    Asynchronous (non-blocking) usage with async-await:
-    ```
-    # Write a workflow, or enable Dev Mode in the UI settings, then Save (API Format) to get the workflow in your
-    # workspace.
-    prompt_dict = {
-        "1": {"class_type": "KSamplerAdvanced", ...}
-        ...
-    }
-
-    # Validate your workflow (the prompt)
-    from comfy.api.components.schema.prompt import Prompt
-    prompt = Prompt.validate(prompt_dict)
-    # Then use the client to run your workflow. This will start, then stop, a local ComfyUI workflow executor.
-    # It does not connect to a remote server.
-    async def main():
-        async with EmbeddedComfyClient() as client:
-            outputs = await client.queue_prompt(prompt)
-            print(outputs)
-        print("Now that we've exited the with statement, all your VRAM has been cleared from ComfyUI")
-
-    if __name__ == "__main__"
-        asyncio.run(main())
-    ```
-
-    In order to use this in blocking methods, learn more about asyncio online.
-    """
-
-    def __init__(self, configuration: Optional[Configuration] = None, progress_handler: Optional[ExecutorToClientProgress] = None, max_workers: int = 1):
+    def __init__(self, configuration: Optional[Configuration] = None, progress_handler: Optional[ExecutorToClientProgress] = None, max_workers: int = 1, executor: Executor = None):
         self._progress_handler = progress_handler or ServerStub()
-        self._executor = ThreadPoolExecutor(max_workers=max_workers)
+        self._executor = executor or ThreadPoolExecutor(max_workers=max_workers)
         self._configuration = configuration
-        # we don't want to import the executor yet
-        self._prompt_executor: Optional["comfy.cmd.execution.PromptExecutor"] = None
         self._is_running = False
+        self._task_count_lock = RLock()
+        self._task_count = 0
 
     @property
     def is_running(self) -> bool:
         return self._is_running
 
+    @property
+    def task_count(self) -> int:
+        return self._task_count
+
     async def __aenter__(self):
-        await self._initialize_prompt_executor()
         self._is_running = True
         return self
 
     async def __aexit__(self, *args):
-        # Perform cleanup here
-        def cleanup():
-            from .. import model_management
-            model_management.unload_all_models()
-            gc.collect()
-            try:
-                model_management.soft_empty_cache()
-            except:
-                pass
-
-        # wait until the queue is done
-        while self._executor._work_queue.qsize() > 0:
+        while self.task_count > 0:
             await asyncio.sleep(0.1)
 
-        await get_event_loop().run_in_executor(self._executor, cleanup)
+        await get_event_loop().run_in_executor(self._executor, _cleanup)
 
         self._executor.shutdown(wait=True)
         self._is_running = False
 
-    async def _initialize_prompt_executor(self):
-        # This method must be async since it's used in __aenter__
-        def create_executor_in_thread():
-            from .. import options
-            if self._configuration is None:
-                options.enable_args_parsing()
-            else:
-                from ..cmd.main_pre import args
-                args.clear()
-                args.update(self._configuration)
-
-            from ..cmd.execution import PromptExecutor
-
-            self._prompt_executor = PromptExecutor(self._progress_handler, lru_size=self._configuration.cache_lru if self._configuration is not None else 0)
-            self._prompt_executor.raise_exceptions = True
-
-        await get_event_loop().run_in_executor(self._executor, create_executor_in_thread)
-
     @tracer.start_as_current_span("Queue Prompt")
     async def queue_prompt(self,
                            prompt: PromptDict | dict,
                            prompt_id: Optional[str] = None,
                            client_id: Optional[str] = None) -> dict:
+        with self._task_count_lock:
+            self._task_count += 1
         prompt_id = prompt_id or str(uuid.uuid4())
         client_id = client_id or self._progress_handler.client_id or None
         span_context = context.get_current()
-
-        def execute_prompt() -> dict:
-            spam: Span
-            with tracer.start_as_current_span("Execute Prompt", context=span_context) as span:
-                from ..cmd.execution import PromptExecutor, validate_prompt
-                try:
-                    prompt_mut = make_mutable(prompt)
-                    validation_tuple = validate_prompt(prompt_mut)
-                    if not validation_tuple.valid:
-                        validation_error_dict = {"message": "Unknown", "details": ""} if not validation_tuple.node_errors or len(validation_tuple.node_errors) == 0 else validation_tuple.node_errors
-                        raise ValueError(json.dumps(validation_error_dict))
-
-                    prompt_executor: PromptExecutor = self._prompt_executor
-
-                    if client_id is None:
-                        prompt_executor.server = _server_stub_instance
-                    else:
-                        prompt_executor.server = self._progress_handler
-
-                    prompt_executor.execute(prompt_mut, prompt_id, {"client_id": client_id},
-                                            execute_outputs=validation_tuple.good_output_node_ids)
-                    return prompt_executor.outputs_ui
-                except Exception as exc_info:
-                    span.set_status(Status(StatusCode.ERROR))
-                    span.record_exception(exc_info)
-                    raise exc_info
-
-        return await get_event_loop().run_in_executor(self._executor, execute_prompt)
+        carrier = {}
+        propagate.inject(carrier, span_context)
+        try:
+            return await get_event_loop().run_in_executor(
+                self._executor,
+                _execute_prompt,
+                make_mutable(prompt),
+                prompt_id,
+                client_id,
+                carrier,
+                # todo: a proxy object or something more sophisticated will have to be done here to restore progress notifications for ProcessPoolExecutors
+                None if isinstance(self._executor, ProcessPoolExecutor) else self._progress_handler,
+                self._configuration,
+            )
+        finally:
+            with self._task_count_lock:
+                self._task_count -= 1
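
Two constraints drive the rewrite above: the callable handed to the executor must be a picklable module-level function (a requirement of process pools), and an OpenTelemetry context cannot cross the process boundary as-is, so it is serialized into a plain dict carrier with `propagate.inject` and restored with `propagate.extract`. A standalone sketch of that round trip, using only the OpenTelemetry propagation API seen above:

```python
# Standalone sketch of the carrier round trip; no ComfyUI imports required.
from opentelemetry import context, propagate
from opentelemetry.context import attach, detach


def child_task(carrier: dict) -> None:
    # In the worker thread or process: rebuild the context and make it current.
    token = attach(propagate.extract(carrier))
    try:
        pass  # spans started here are parented to the caller's span
    finally:
        detach(token)


carrier: dict = {}
propagate.inject(carrier, context.get_current())  # serialize the current span context
child_task(carrier)  # the carrier is an ordinary dict, so it pickles cleanly
```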
@@ -1,7 +1,8 @@
 import asyncio
 import itertools
-import os
 import logging
+import os
+from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
 
 from .extra_model_paths import load_extra_path_config
 from .main_pre import args
@@ -41,7 +42,8 @@ async def main():
 
     from ..distributed.distributed_prompt_worker import DistributedPromptWorker
     async with DistributedPromptWorker(connection_uri=args.distributed_queue_connection_uri,
-                                       queue_name=args.distributed_queue_name):
+                                       queue_name=args.distributed_queue_name,
+                                       executor=ThreadPoolExecutor(max_workers=1) if args.executor_factory == "ThreadPoolExecutor" else ProcessPoolExecutor(max_workers=1)):
         stop = asyncio.Event()
         try:
             await stop.wait()
@@ -1,11 +1,12 @@
 from __future__ import annotations  # for Python 3.7-3.9
 
+import concurrent.futures
 import typing
 from enum import Enum
 from typing import Optional, Literal, Protocol, Union, NamedTuple, List
 
 import PIL.Image
-from typing_extensions import NotRequired, TypedDict
+from typing_extensions import NotRequired, TypedDict, runtime_checkable
 
 from .queue_types import BinaryEventTypes
 from ..nodes.package_typing import InputTypeSpec
@@ -193,3 +194,11 @@ class NodeInputError(Exception):
 
 class NodeNotFoundError(Exception):
     pass
+
+
+class Executor(Protocol):
+    def submit(self, fn, /, *args, **kwargs) -> concurrent.futures.Future:
+        ...
+
+    def shutdown(self, wait=True, *, cancel_futures=False):
+        ...
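
Because `Executor` is a structural `Protocol`, both `concurrent.futures.ThreadPoolExecutor` and the pebble-backed `ProcessPoolExecutor` added later in this diff can be passed to call sites typed against it. A small sketch of typing a call site against the protocol:

```python
# Sketch: anything exposing submit()/shutdown() can be used where Executor is expected.
from concurrent.futures import ThreadPoolExecutor

from comfy.component_model.executor_types import Executor


def run_once(executor: Executor, fn, *args):
    try:
        return executor.submit(fn, *args).result()
    finally:
        executor.shutdown(wait=True)


print(run_once(ThreadPoolExecutor(max_workers=1), pow, 2, 10))  # 1024
```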
@@ -40,7 +40,6 @@ class DistributedExecutorToClientProgress(ExecutorToClientProgress):
         self.receive_all_progress_notifications = receive_all_progress_notifications
 
     async def send(self, event: SendSyncEvent, data: SendSyncData, user_id: Optional[str]) -> None:
-        # for now, do not send binary data this way, since it cannot be json serialized / it's impractical
         if event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE:
             from ..cmd.latent_preview_image_encoding import encode_preview_image
 
@@ -14,6 +14,7 @@ from .distributed_progress import DistributedExecutorToClientProgress
 from .distributed_types import RpcRequest, RpcReply
 from ..client.embedded_comfy_client import EmbeddedComfyClient
 from ..cmd.main_pre import tracer
+from ..component_model.executor_types import Executor
 from ..component_model.queue_types import ExecutionStatus
 
 
@@ -26,13 +27,15 @@ class DistributedPromptWorker:
                  connection_uri: str = "amqp://localhost:5672/",
                  queue_name: str = "comfyui",
                  health_check_port: int = 9090,
-                 loop: Optional[AbstractEventLoop] = None):
+                 loop: Optional[AbstractEventLoop] = None,
+                 executor: Optional[Executor] = None):
         self._rpc = None
         self._channel = None
         self._exit_stack = AsyncExitStack()
         self._queue_name = queue_name
         self._connection_uri = connection_uri
         self._loop = loop or asyncio.get_event_loop()
+        self._executor = executor
         self._embedded_comfy_client = embedded_comfy_client
         self._health_check_port = health_check_port
         self._health_check_site: Optional[web.TCPSite] = None
@@ -94,7 +97,7 @@ class DistributedPromptWorker:
         self._rpc = await JsonRPC.create(channel=self._channel, auto_delete=True, durable=False)
 
         if self._embedded_comfy_client is None:
-            self._embedded_comfy_client = EmbeddedComfyClient(progress_handler=DistributedExecutorToClientProgress(self._rpc, self._queue_name, self._loop))
+            self._embedded_comfy_client = EmbeddedComfyClient(progress_handler=DistributedExecutorToClientProgress(self._rpc, self._queue_name, self._loop), executor=self._executor)
         if not self._embedded_comfy_client.is_running:
             await self._exit_stack.enter_async_context(self._embedded_comfy_client)
 
@@ -1,10 +1,12 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
-from typing import Tuple, Literal, List
+from typing import Tuple, Literal, List, Callable
 
 from ..api.components.schema.prompt import PromptDict, Prompt
 from ..auth.permissions import ComfyJwt, jwt_decode
+from ..cli_args_types import Configuration
+from ..component_model.executor_types import ExecutorToClientProgress
 from ..component_model.queue_types import NamedQueueTuple, TaskInvocation, ExecutionStatus
 
 
comfy/distributed/process_pool_executor.py (new file, 19 lines)
@@ -0,0 +1,19 @@
+import concurrent.futures
+
+from pebble import ProcessPool
+
+from ..component_model.executor_types import Executor
+
+
+class ProcessPoolExecutor(ProcessPool, Executor):
+    def shutdown(self, wait=True, *, cancel_futures=False):
+        if cancel_futures:
+            raise NotImplementedError("cannot cancel futures in this implementation")
+        if wait:
+            self.close()
+        else:
+            self.stop()
+        return
+
+    def submit(self, fn, /, *args, **kwargs) -> concurrent.futures.Future:
+        return self.schedule(fn, args=list(args), kwargs=kwargs, timeout=None)
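
pebble's `ProcessPool.schedule` returns a future compatible with `concurrent.futures.Future`, so the adapter above can stand in wherever the code expects a standard executor. A usage sketch (the submitted function must be defined at module level so it can be pickled into the worker process):

```python
# Sketch: the pebble-backed pool used like a standard executor.
from comfy.distributed.process_pool_executor import ProcessPoolExecutor


def square(x: int) -> int:  # module-level so it pickles into the worker process
    return x * x


if __name__ == "__main__":
    pool = ProcessPoolExecutor(max_workers=1)
    try:
        print(pool.submit(square, 7).result())  # 49
    finally:
        pool.shutdown(wait=True)  # maps to pebble's close(), per the adapter above
```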
@@ -1,17 +1,25 @@
-import platform
 try:
     import bitsandbytes as bnb
     from bitsandbytes.nn.modules import Params4bit, QuantState
 
     has_bitsandbytes = True
 except (ImportError, ModuleNotFoundError):
-    bnb = {}
-    Params4bit = {}
-    QuantState = {}
+    class bnb:
+        pass
+
+
+    class Params4bit:
+        pass
+
+
+    class QuantState:
+        pass
+
+
     has_bitsandbytes = False
 
 import torch
 
 import comfy.ops
 import comfy.sd
 from comfy.cmd.folder_paths import get_folder_paths
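
The likely motivation for swapping the dict placeholders for empty classes: the names are used in type positions, and a dict raises `TypeError` as soon as it is used in, say, an `isinstance` check, while an empty class simply never matches. A generic sketch of the difference:

```python
# Generic sketch, not ComfyUI-specific: why class stubs are safer fallbacks than dict stubs.
DictStub = {}


class ClassStub:
    pass


print(isinstance(object(), ClassStub))  # False, no error
try:
    isinstance(object(), DictStub)      # raises: arg 2 must be a type or tuple of types
except TypeError as e:
    print(f"dict stub fails: {e}")
```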
@@ -64,4 +64,5 @@ spandrel_extra_arches
 ml_dtypes
 diffusers>=0.30.1
 vtracer
 skia-python
+pebble>=5.0.7
@@ -86,8 +86,8 @@ def has_gpu() -> bool:
     yield has_gpu
 
 
-@pytest.fixture(scope="module", autouse=False)
-def frontend_backend_worker_with_rabbitmq(tmp_path_factory) -> str:
+@pytest.fixture(scope="module", autouse=False, params=["ThreadPoolExecutor", "ProcessPoolExecutor"])
+def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory) -> str:
     """
     populates the cache with the sdxl checkpoints, starts a frontend and backend worker against a started rabbitmq, and yields the address of the frontend
     :return:
@@ -97,6 +97,7 @@ def frontend_backend_worker_with_rabbitmq(tmp_path_factory) -> str:
     hf_hub_download("stabilityai/stable-diffusion-xl-refiner-1.0", "sd_xl_refiner_1.0.safetensors")
 
     tmp_path = tmp_path_factory.mktemp("comfy_background_server")
+    executor_factory = request.param
     processes_to_close: List[subprocess.Popen] = []
     from testcontainers.rabbitmq import RabbitMqContainer
     with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
@@ -119,6 +120,7 @@ def frontend_backend_worker_with_rabbitmq(tmp_path_factory) -> str:
             "--port=9002",
             f"-w={str(tmp_path)}",
             f"--distributed-queue-connection-uri={connection_uri}",
+            f"--executor-factory={executor_factory}"
         ]
 
         processes_to_close.append(subprocess.Popen(backend_command, stdout=sys.stdout, stderr=sys.stderr))
@@ -1,6 +1,7 @@
 import asyncio
 import uuid
 from concurrent.futures import ThreadPoolExecutor
+from typing import Callable
 
 import jwt
 import pytest
@@ -10,9 +11,11 @@ from testcontainers.rabbitmq import RabbitMqContainer
 from comfy.client.aio_client import AsyncRemoteComfyClient
 from comfy.client.embedded_comfy_client import EmbeddedComfyClient
 from comfy.client.sdxl_with_refiner_workflow import sdxl_workflow_with_refiner
+from comfy.component_model.executor_types import Executor
 from comfy.component_model.make_mutable import make_mutable
 from comfy.component_model.queue_types import QueueItem, QueueTuple, TaskInvocation, NamedQueueTuple, ExecutionStatus
 from comfy.distributed.distributed_prompt_worker import DistributedPromptWorker
+from comfy.distributed.process_pool_executor import ProcessPoolExecutor
 from comfy.distributed.server_stub import ServerStub
 
 
@@ -35,12 +38,11 @@ async def test_sign_jwt_auth_none():
 
 
 @pytest.mark.asyncio
-async def test_basic_queue_worker() -> None:
-    # there are lots of side effects from importing that we have to deal with
-
+@pytest.mark.parametrize("executor_factory", (ThreadPoolExecutor, ProcessPoolExecutor,))
+async def test_basic_queue_worker(executor_factory: Callable[..., Executor]) -> None:
     with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
         params = rabbitmq.get_connection_params()
-        async with DistributedPromptWorker(connection_uri=f"amqp://guest:guest@127.0.0.1:{params.port}"):
+        async with DistributedPromptWorker(connection_uri=f"amqp://guest:guest@127.0.0.1:{params.port}", executor=executor_factory(max_workers=1)):
             # this unfortunately does a bunch of initialization on the test thread
             from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
             # now submit some jobs
@@ -125,14 +127,15 @@ async def check_health(url: str, max_retries: int = 5, retry_delay: float = 1.0):
 
 
 @pytest.mark.asyncio
-async def test_basic_queue_worker_with_health_check():
+@pytest.mark.parametrize("executor_factory", (ThreadPoolExecutor, ProcessPoolExecutor,))
+async def test_basic_queue_worker_with_health_check(executor_factory):
     with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
         params = rabbitmq.get_connection_params()
         connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"
         health_check_port = 9090
 
-        async with DistributedPromptWorker(connection_uri=connection_uri, health_check_port=health_check_port) as worker:
+        async with DistributedPromptWorker(connection_uri=connection_uri, health_check_port=health_check_port, executor=executor_factory(max_workers=1)) as worker:
             health_check_url = f"http://localhost:{health_check_port}/health"
 
             health_check_ok = await check_health(health_check_url)
             assert health_check_ok, "Health check server did not start properly"