# ComfyUI/comfy/client/embedded_comfy_client.py

from __future__ import annotations

from ..cmd.main_pre import tracer

import asyncio
import concurrent.futures
import contextlib
import copy
import gc
import json
import logging
import threading
import uuid
from asyncio import get_event_loop
from multiprocessing import RLock
from typing import Optional, Literal

from opentelemetry import context, propagate
from opentelemetry.context import Context, attach, detach
from opentelemetry.trace import Status, StatusCode

from .async_progress_iterable import QueuePromptWithProgress
from .client_types import V1QueuePromptResponse
from ..api.components.schema.prompt import PromptDict
from ..cli_args_types import Configuration
from ..cli_args import default_configuration
from ..cmd.folder_paths import init_default_paths  # pylint: disable=import-error
from ..component_model.configuration import MODEL_MANAGEMENT_ARGS, requires_process_pool_executor
from ..component_model.executor_types import ExecutorToClientProgress
from ..component_model.make_mutable import make_mutable
from ..component_model.queue_types import QueueItem, ExecutionStatus, TaskInvocation, QueueTuple, ExtraData
from ..distributed.executors import ContextVarExecutor
from ..distributed.history import History
from ..distributed.process_pool_executor import ProcessPoolExecutor
from ..distributed.server_stub import ServerStub
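
# Each worker thread (or spawned process) lazily caches its own PromptExecutor on this
# thread-local, so repeated prompts reuse loaded models instead of re-initializing per call.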
_prompt_executor = threading.local()
logger = logging.getLogger(__name__)
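

# Entry point that runs on the executor worker: it re-extracts the OpenTelemetry span
# context serialized by queue_prompt() and drives the prompt under asyncio.run.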
def _execute_prompt(
        prompt: dict,
        prompt_id: str,
        client_id: str,
        span_context: dict,
        progress_handler: ExecutorToClientProgress | None,
        configuration: Configuration | None,
        partial_execution_targets: Optional[list[str]] = None) -> dict:
    configuration = copy.deepcopy(configuration) if configuration is not None else None
    from ..execution_context import current_execution_context
    execution_context = current_execution_context()
    if len(execution_context.folder_names_and_paths) == 0 or configuration is not None:
        init_default_paths(execution_context.folder_names_and_paths, configuration, replace_existing=True)
    span_context: Context = propagate.extract(span_context)
    token = attach(span_context)
    try:
        # there is never an event loop running on a thread or process pool thread here
        # this also guarantees nodes will be able to successfully call await
        return asyncio.run(__execute_prompt(prompt, prompt_id, client_id, span_context, progress_handler, configuration, partial_execution_targets))
    finally:
        detach(token)
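

# Layers the caller's Configuration onto the execution context for the duration of this prompt.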
async def __execute_prompt(
        prompt: dict,
        prompt_id: str,
        client_id: str,
        span_context: Context,
        progress_handler: ExecutorToClientProgress | None,
        configuration: Configuration | None,
        partial_execution_targets: list[str] | None) -> dict:
    from ..execution_context import context_configuration
    with context_configuration(configuration):
        return await ___execute_prompt(prompt, prompt_id, client_id, span_context, progress_handler, partial_execution_targets)
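

# Obtains (or lazily constructs) the thread-local PromptExecutor, validates the prompt,
# then executes it, recording failures on the tracing span.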
async def ___execute_prompt(
        prompt: dict,
        prompt_id: str,
        client_id: str,
        span_context: Context,
        progress_handler: ExecutorToClientProgress | None,
        partial_execution_targets: list[str] | None) -> dict:
    from ..cmd.execution import PromptExecutor
    progress_handler = progress_handler or ServerStub()
    prompt_executor: Optional[PromptExecutor] = None
    try:
        prompt_executor = _prompt_executor.executor
    except (LookupError, AttributeError):
        with tracer.start_as_current_span("Initialize Prompt Executor", context=span_context):
            # todo: deal with new caching features
            prompt_executor = PromptExecutor(progress_handler)
            prompt_executor.raise_exceptions = True
            _prompt_executor.executor = prompt_executor
    with tracer.start_as_current_span("Execute Prompt", context=span_context) as span:
        try:
            prompt_mut = make_mutable(prompt)
            from ..cmd.execution import validate_prompt
            validation_tuple = await validate_prompt(prompt_id, prompt_mut, partial_execution_targets)
            if not validation_tuple.valid:
                if validation_tuple.node_errors is not None and len(validation_tuple.node_errors) > 0:
                    validation_error_dict = validation_tuple.node_errors
                elif validation_tuple.error is not None:
                    validation_error_dict = validation_tuple.error
                else:
                    validation_error_dict = {"message": "Unknown", "details": ""}
                raise ValueError(json.dumps(validation_error_dict))
            if client_id is None:
                prompt_executor.server = ServerStub()
            else:
                prompt_executor.server = progress_handler
            await prompt_executor.execute_async(prompt_mut, prompt_id, {"client_id": client_id},
                                                execute_outputs=validation_tuple.good_output_node_ids)
            return prompt_executor.outputs_ui
        except Exception as exc_info:
            span.set_status(Status(StatusCode.ERROR))
            span.record_exception(exc_info)
            raise exc_info
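

# Best-effort teardown used by clear_cache() and the context-manager exits: resets the
# executor, unloads models, and optionally invalidates the imported node registry.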
def _cleanup(invalidate_nodes=True):
    from ..cmd.execution import PromptExecutor
    from ..nodes_context import invalidate
    try:
        prompt_executor: PromptExecutor = _prompt_executor.executor
        # this should clear all references to output tensors and make it easier to collect back the memory
        prompt_executor.reset()
    except (LookupError, AttributeError):
        pass
    from .. import model_management
    model_management.unload_all_models()
    gc.collect()
    try:
        model_management.soft_empty_cache()
    except Exception:
        pass
    if invalidate_nodes:
        try:
            invalidate()
        except Exception:
            pass


class Comfy:
    """
    A client for running ComfyUI workflows within a Python application.

    This client allows you to execute ComfyUI workflows (in API JSON format) programmatically.
    It manages the execution environment, including model loading and resource cleanup.

    ### Configuration and Executors

    ComfyUI relies on global state for model management (e.g., loaded models in VRAM). To handle this safely, `Comfy`
    executes workflows using one of two strategies based on your `configuration`:

    1. **ContextVarExecutor (Default)**: Runs in a thread pool within the current process.
       - **Pros**: Efficient, low overhead.
       - **Cons**: Modifies global state in the current process.
       - **Use Case**: Standard workflows where the default ComfyUI settings suffice or sharing state is acceptable.
    2. **ProcessPoolExecutor**: Runs in a separate process.
       - **Pros**: Complete isolation. Configuration changes (like `lowvram`) do not affect the main process.
       - **Cons**: Higher overhead (process startup).
       - **Use Case**: Required when `configuration` overrides arguments that affect global model management state.
         These arguments include: `lowvram`, `highvram`, `cpu`, `gpu_only`, `deterministic`, `directml`,
         various `fp8`/`fp16`/`bf16` settings, and attention optimizations (e.g., `use_flash_attention`).

    The client automatically selects `ProcessPoolExecutor` if you provide a `configuration` that modifies any of these
    global settings, unless you explicitly pass an `executor`.
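
    You can check ahead of time which strategy a configuration implies; a minimal sketch using the same
    helper this module imports (assuming a plain dict is accepted wherever a `Configuration` is):

    ```python
    from comfy.component_model.configuration import requires_process_pool_executor

    # lowvram changes global model-management state, so process isolation is required
    print(requires_process_pool_executor({"lowvram": True}))  # expected: True
    print(requires_process_pool_executor(None))               # expected: False
    ```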

    ### Parameters

    - **configuration** (`Optional[Configuration]`): A dictionary of arguments to override defaults.
      See `comfy.cli_args_types.Configuration`.
      Example: `{"lowvram": True}` or `{"gpu_only": True}`.
    - **progress_handler** (`Optional[ExecutorToClientProgress]`): A callback handler for progress updates and previews.
    - **max_workers** (`int`): Maximum number of concurrent workflows (default: 1).
    - **executor** (`Optional[Union[Executor, str]]`): Explicitly define the executor to use.
      - Pass an instance of `ProcessPoolExecutor` or `ContextVarExecutor`.
      - Pass the string `"ProcessPoolExecutor"` or `"ContextVarExecutor"` to force initialization of that type.
      - If `None` (default), the best executor is chosen based on `configuration`.

    ### Examples

    #### 1. Running a Workflow (Basic)

    This example executes a simple workflow and prints the path of the saved image.

    ```python
    import asyncio

    from comfy.client.embedded_comfy_client import Comfy

    # A simple API format workflow (simplified for brevity)
    prompt_dict = {
        "3": {
            "class_type": "KSampler",
            "inputs": {
                "seed": 8566257, "steps": 20, "cfg": 8, "sampler_name": "euler",
                "scheduler": "normal", "denoise": 1,
                "model": ["4", 0], "positive": ["6", 0], "negative": ["7", 0],
                "latent_image": ["5", 0]
            }
        },
        "4": {"class_type": "CheckpointLoaderSimple", "inputs": {"ckpt_name": "v1-5-pruned-emaonly.safetensors"}},
        "5": {"class_type": "EmptyLatentImage", "inputs": {"width": 512, "height": 512, "batch_size": 1}},
        "6": {"class_type": "CLIPTextEncode", "inputs": {"text": "masterpiece best quality girl", "clip": ["4", 1]}},
        "7": {"class_type": "CLIPTextEncode", "inputs": {"text": "bad hands", "clip": ["4", 1]}},
        "8": {"class_type": "VAEDecode", "inputs": {"samples": ["3", 0], "vae": ["4", 2]}},
        "9": {"class_type": "SaveImage", "inputs": {"filename_prefix": "ComfyUI_API", "images": ["8", 0]}}
    }

    async def main():
        # Using default configuration (runs in-process)
        async with Comfy() as client:
            # Queue the prompt and await the result
            outputs = await client.queue_prompt(prompt_dict)

            # Retrieve the output path from the SaveImage node (Node ID "9")
            image_path = outputs["9"]["images"][0]["abs_path"]
            print(f"Image saved to: {image_path}")

    # asyncio.run(main())
    ```

    #### 2. Using Custom Configuration (Isolated Process)

    To run with specific settings like `lowvram`, pass the configuration. This implies `ProcessPoolExecutor`.

    ```python
    async def run_lowvram():
        # This will spawn a new process with lowvram enabled
        async with Comfy(configuration={"lowvram": True}) as client:
            outputs = await client.queue_prompt(prompt_dict)
            print("Finished lowvram generation")
    ```

    #### 3. Programmatically Building Workflows

    You can use `GraphBuilder` to construct workflows with a more Pythonic API.

    ```python
    from comfy_execution.graph_utils import GraphBuilder

    def build_graph():
        builder = GraphBuilder()
        checkpoint = builder.node("CheckpointLoaderSimple", ckpt_name="v1-5-pruned-emaonly.safetensors")
        latent = builder.node("EmptyLatentImage", width=512, height=512, batch_size=1)
        pos = builder.node("CLIPTextEncode", text="masterpiece", clip=checkpoint.out(1))
        neg = builder.node("CLIPTextEncode", text="bad quality", clip=checkpoint.out(1))
        sampler = builder.node("KSampler",
                               seed=42, steps=20, cfg=8, sampler_name="euler", scheduler="normal", denoise=1,
                               model=checkpoint.out(0), positive=pos.out(0), negative=neg.out(0),
                               latent_image=latent.out(0))
        vae = builder.node("VAEDecode", samples=sampler.out(0), vae=checkpoint.out(2))
        builder.node("SaveImage", filename_prefix="Generated", images=vae.out(0))
        return builder.finalize()

    async def run_builder():
        prompt = build_graph()
        async with Comfy() as client:
            await client.queue_prompt(prompt)
    ```

    #### 4. Streaming Progress and Previews

    To receive real-time progress updates and preview images (e.g., step-by-step previews), use
    `queue_with_progress` and iterate over the notifications.

    ```python
    from comfy.component_model.queue_types import BinaryEventTypes

    async def run_streaming():
        async with Comfy() as client:
            # Get a task that supports progress iteration
            task = client.queue_with_progress(prompt_dict)

            async for notification in task.progress():
                if notification.event == BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA:
                    # 'data' contains the PIL Image and metadata
                    image, metadata = notification.data
                    print(f"Received preview: {image.size}")
                elif notification.event == "progress":
                    print(f"Step: {notification.data['value']}/{notification.data['max']}")

            # Await final result
            result = await task.get()
    ```
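
    #### 5. Forcing a Specific Executor

    A minimal sketch of overriding automatic executor selection; note that requesting
    `"ContextVarExecutor"` raises `ValueError` if the configuration requires process isolation.

    ```python
    async def run_isolated():
        # Force process isolation even though the default configuration would not require it
        async with Comfy(executor="ProcessPoolExecutor") as client:
            outputs = await client.queue_prompt(prompt_dict)
            print("Finished isolated generation")
    ```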
"""

    def __init__(self,
                 configuration: Optional[Configuration] = None,
                 progress_handler: Optional[ExecutorToClientProgress] = None,
                 max_workers: int = 1,
                 executor: ProcessPoolExecutor | ContextVarExecutor | Literal["ProcessPoolExecutor", "ContextVarExecutor"] | None = None):
        self._progress_handler = progress_handler or ServerStub()
        self._default_configuration = default_configuration()
        self._configuration = configuration
        need_process_pool = requires_process_pool_executor(configuration)
        if executor is None:
            if need_process_pool:
                self._executor = ProcessPoolExecutor(max_workers=max_workers)
            else:
                self._executor = ContextVarExecutor(max_workers=max_workers)
            self._owns_executor = True
        elif isinstance(executor, str):
            self._owns_executor = True
            if executor == "ProcessPoolExecutor":
                self._executor = ProcessPoolExecutor(max_workers=max_workers)
            elif executor == "ContextVarExecutor":
                if need_process_pool:
                    raise ValueError(f"Configuration requires ProcessPoolExecutor but ContextVarExecutor was requested. Configuration keys causing this: {[k for k in MODEL_MANAGEMENT_ARGS if configuration.get(k) != self._default_configuration.get(k)]}")
                self._executor = ContextVarExecutor(max_workers=max_workers)
            else:
                raise ValueError(f"Unknown executor type string: {executor}")
        else:
            # Executor instance passed
            self._owns_executor = False
            self._executor = executor
            if need_process_pool and not isinstance(executor, ProcessPoolExecutor):
                raise ValueError(f"Configuration requires ProcessPoolExecutor but {type(executor).__name__} was passed. Configuration keys causing this: {[k for k in MODEL_MANAGEMENT_ARGS if configuration.get(k) != self._default_configuration.get(k)]}")
        self._is_running = False
        self._task_count_lock = RLock()
        self._task_count = 0
        self._history = History()
        self._exit_stack = None
        self._async_exit_stack = None

    @property
    def is_running(self) -> bool:
        return self._is_running

    @property
    def task_count(self) -> int:
        return self._task_count

    def __enter__(self):
        self._exit_stack = contextlib.ExitStack()
        self._is_running = True
        from ..execution_context import context_configuration
        cm = context_configuration(self._configuration)
        self._exit_stack.enter_context(cm)
        if self._owns_executor:
            self._exit_stack.enter_context(self._executor)
        return self

    @property
    def history(self) -> History:
        return self._history

    async def clear_cache(self):
        await get_event_loop().run_in_executor(self._executor, _cleanup, False)
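
    # note: the synchronous exit schedules cleanup on the event loop without awaiting it;
    # the async context manager (__aexit__) awaits cleanup and outstanding tasks instead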
    def __exit__(self, *args):
        get_event_loop().run_in_executor(self._executor, _cleanup)
        self._is_running = False
        self._exit_stack.__exit__(*args)

    async def __aenter__(self):
        self._async_exit_stack = contextlib.AsyncExitStack()
        self._is_running = True
        from ..execution_context import context_configuration
        cm = context_configuration(self._configuration)
        self._async_exit_stack.enter_context(cm)
        if self._owns_executor:
            self._async_exit_stack.enter_context(self._executor)
        return self

    async def __aexit__(self, *args):
        while self.task_count > 0:
            await asyncio.sleep(0.1)
        await get_event_loop().run_in_executor(self._executor, _cleanup)
        self._is_running = False
        await self._async_exit_stack.__aexit__(*args)

    async def queue_prompt_api(self,
                               prompt: PromptDict | str | dict,
                               progress_handler: Optional[ExecutorToClientProgress] = None) -> V1QueuePromptResponse:
        """
        Queues a prompt for execution, returning the output when it is complete.

        :param prompt: a PromptDict, string or dictionary containing a so-called Workflow API prompt
        :param progress_handler: an optional handler that receives progress updates for this prompt
        :return: a response of URLs for Save-related nodes and the node outputs
        """
        if isinstance(prompt, str):
            prompt = json.loads(prompt)
        if isinstance(prompt, dict):
            from ..api.components.schema.prompt import Prompt
            prompt = Prompt.validate(prompt)
        outputs = await self.queue_prompt(prompt, progress_handler=progress_handler)
        return V1QueuePromptResponse(urls=[], outputs=outputs)

    def queue_with_progress(self, prompt: PromptDict | str | dict) -> QueuePromptWithProgress:
        """
        Queues a prompt with progress notifications.

        >>> from comfy.client.embedded_comfy_client import Comfy
        >>> from comfy.client.client_types import ProgressNotification
        >>> async with Comfy() as comfy:
        ...     task = comfy.queue_with_progress({ ... })
        ...     # Raises an exception while iterating if execution fails
        ...     notification: ProgressNotification
        ...     async for notification in task.progress():
        ...         print(notification.data)
        ...     # If you get this far, no errors occurred.
        ...     result = await task.get()

        :param prompt: a PromptDict, string or dictionary containing a Workflow API prompt
        :return: a handle whose progress() yields notifications and whose get() awaits the final result
        """
        handler = QueuePromptWithProgress()
        task = asyncio.create_task(self.queue_prompt_api(prompt, progress_handler=handler.progress_handler))
        task.add_done_callback(handler.complete)
        return handler
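
    # note: with a ProcessPoolExecutor, a caller-supplied progress_handler must be pickleable;
    # this implementation currently drops it and passes None to the worker process instead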
    @tracer.start_as_current_span("Queue Prompt")
    async def queue_prompt(self,
                           prompt: PromptDict | dict,
                           prompt_id: Optional[str] = None,
                           client_id: Optional[str] = None,
                           partial_execution_targets: Optional[list[str]] = None,
                           progress_handler: Optional[ExecutorToClientProgress] = None) -> dict:
        if isinstance(self._executor, ProcessPoolExecutor) and progress_handler is not None:
            logger.debug(f"a progress_handler={progress_handler} was passed, it must be pickleable to support ProcessPoolExecutor")
        progress_handler = progress_handler or self._progress_handler
        with self._task_count_lock:
            self._task_count += 1
        prompt_id = prompt_id or str(uuid.uuid4())
        assert prompt_id is not None
        client_id = client_id or self._progress_handler.client_id or None
        span_context = context.get_current()
        carrier = {}
        propagate.inject(carrier, span_context)
        # setup history
        prompt = make_mutable(prompt)
        try:
            outputs = await get_event_loop().run_in_executor(
                self._executor,
                _execute_prompt,
                prompt,
                prompt_id,
                client_id,
                carrier,
                # todo: a proxy object or something more sophisticated will have to be done here to restore progress notifications for ProcessPoolExecutors
                None if isinstance(self._executor, ProcessPoolExecutor) else progress_handler,
                self._configuration,
                partial_execution_targets,
            )
            fut = concurrent.futures.Future()
            fut.set_result(TaskInvocation(prompt_id, copy.deepcopy(outputs), ExecutionStatus('success', True, [])))
            self._history.put(QueueItem(queue_tuple=QueueTuple(float(self._task_count), prompt_id, prompt, ExtraData(), [], {}), completed=fut), outputs, ExecutionStatus('success', True, []))
            return outputs
        except Exception as exc_info:
            fut = concurrent.futures.Future()
            fut.set_exception(exc_info)
            self._history.put(QueueItem(queue_tuple=QueueTuple(float(self._task_count), prompt_id, prompt, ExtraData(), [], {}), completed=fut), {}, ExecutionStatus('error', False, [str(exc_info)]))
            raise exc_info
        finally:
            with self._task_count_lock:
                self._task_count -= 1

    def __str__(self):
        diff = {k: v for k, v in (self._configuration or {}).items() if v != self._default_configuration.get(k)}
        return f"<Comfy task_count={self.task_count} configuration={diff} executor={self._executor}>"
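

# alias retained for callers that import the client under its older name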
EmbeddedComfyClient = Comfy