diff --git a/README.md b/README.md index af7f4701c..e99b5797b 100644 --- a/README.md +++ b/README.md @@ -306,19 +306,28 @@ choco install -y vcredist2010 vcredist2013 vcredist140 Then, visit [NVIDIA.com's CUDA Toolkit Download Page](https://developer.nvidia.com/cuda-12-6-0-download-archive?target_os=Windows&target_arch=x86_64&target_version=Server2022&target_type=exe_network) and download and install the CUDA Toolkit. Verify it is correctly installed by running `nvcc --version`. -You are now ready to install Sage Attention 2 and Flash Attention. +You are now ready to install Sage Attention 2. ### Linux ```shell uv pip install --no-build-isolation "sageattention@git+https://github.com/thu-ml/SageAttention.git" -uv pip install --no-build-isolation flash_attn ``` ### Windows +Run this PowerShell script to install the correct version of Sage Attention for your installed PyTorch version: + ```powershell -uv pip install --find-links https://raw.githubusercontent.com/hiddenswitch/ComfyUI/main/pypi/sageattention_index.html sageattention +$torch_version = (uv pip freeze | Select-String "torch==").ToString().Trim() +$cuda_version = $torch_version -replace ".*(cu\d+).*", "`$1" +if ($torch_version -match "\+cu") { + $v = $torch_version -replace "torch==", "" + $package_specifier = "sageattention==2.2.0+$($cuda_version)torch$v" + uv pip install --find-links https://raw.githubusercontent.com/hiddenswitch/ComfyUI/main/pypi/sageattention_index.html $package_specifier +} else { + Write-Host "Could not determine CUDA version from torch version: $torch_version" +} ``` To start ComfyUI with it: @@ -333,6 +342,26 @@ uv run comfyui --use-sage-attention ![with_pytorch_attention](./docs/assets/with_pytorch_attention.webp) **With PyTorch Attention** +## Flash Attention + +Flash Attention 2 is supported on Linux only. + +```shell +uv pip install --no-build-isolation flash_attn +``` + +To start ComfyUI with it: + +```shell +uv run comfyui --use-flash-attention +``` + +![with_sage_attention.webp](./docs/assets/with_sage_attention.webp) +**With SageAttention** + +![with_pytorch_attention](./docs/assets/with_pytorch_attention.webp) +**With PyTorch Attention** + ## Cosmos Prompt Upsampling The Cosmos prompt "upsampler," a fine tune of Mistral-Nemo-12b, correctly rewrites Cosmos prompts in the narrative style that NVIDIA's captioner used for the training data of Cosmos, improving generation results significantly. diff --git a/comfy/client/embedded_comfy_client.py b/comfy/client/embedded_comfy_client.py index ee178cea5..cde995430 100644 --- a/comfy/client/embedded_comfy_client.py +++ b/comfy/client/embedded_comfy_client.py @@ -32,6 +32,7 @@ from ..distributed.executors import ContextVarExecutor from ..distributed.history import History from ..distributed.process_pool_executor import ProcessPoolExecutor from ..distributed.server_stub import ServerStub +from ..component_model.configuration import MODEL_MANAGEMENT_ARGS, requires_process_pool_executor _prompt_executor = threading.local() @@ -147,50 +148,181 @@ def _cleanup(invalidate_nodes=True): class Comfy: """ - This manages a single-threaded executor to run long-running or blocking workflows - asynchronously without blocking the asyncio event loop. It initializes a PromptExecutor - in a dedicated thread for executing prompts and handling server-stub communications. - Example usage: + A client for running ComfyUI workflows within a Python application. - Asynchronous (non-blocking) usage with async-await: - ``` - # Write a workflow, or enable Dev Mode in the UI settings, then Save (API Format) to get the workflow in your - # workspace. + This client allows you to execute ComfyUI workflows (in API JSON format) programmatically. + It manages the execution environment, including model loading and resource cleanup. + + ### Configuration and Executors + + ComfyUI relies on global state for model management (e.g., loaded models in VRAM). To handle this safely, `Comfy` + executes workflows using one of two strategies based on your `configuration`: + + 1. **ContextVarExecutor (Default)**: Runs in a thread pool within the current process. + - **Pros**: Efficient, low overhead. + - **Cons**: Modifies global state in the current process. + - **Use Case**: Standard workflows where you are happy with the default ComfyUI settings or sharing state. + + 2. **ProcessPoolExecutor**: Runs in a separate process. + - **Pros**: Complete isolation. Configuration changes (like `lowvram`) do not affect the main process. + - **Cons**: Higher overhead (process startup). + - **Use Case**: Required when `configuration` overrides arguments that affect global model management state. + These arguments include: `lowvram`, `highvram`, `cpu`, `gpu_only`, `deterministic`, `directml`, + various `fp8`/`fp16`/`bf16` settings, and attention optimizations (e.g., `use_flash_attention`). + + The client automatically selects `ProcessPoolExecutor` if you provide a `configuration` that modifies any of these + global settings, unless you explicitly pass an `executor`. + + ### Parameters + + - **configuration** (`Optional[Configuration]`): A dictionary of arguments to override defaults. + See `comfy.cli_args_types.Configuration`. + Example: `{"lowvram": True}` or `{"gpu_only": True}`. + - **progress_handler** (`Optional[ExecutorToClientProgress]`): callback handler for progress updates and previews. + - **max_workers** (`int`): Maximum number of concurrent workflows (default: 1). + - **executor** (`Optional[Union[Executor, str]]`): Explicitly define the executor to use. + - Pass an instance of `ProcessPoolExecutor` or `ContextVarExecutor`. + - Pass the string `"ProcessPoolExecutor"` or `"ContextVarExecutor"` to force initialization of that type. + - If `None` (default), the best executor is chosen based on `configuration`. + + ### Examples + + #### 1. Running a Workflow (Basic) + + This example executes a simple workflow and prints the path of the saved image. + + ```python + import asyncio + from comfy.client.embedded_comfy_client import Comfy + + # A simple API format workflow (simplified for brevity) prompt_dict = { - "1": {"class_type": "KSamplerAdvanced", ...} - ... + "3": { + "class_type": "KSampler", + "inputs": { + "seed": 8566257, "steps": 20, "cfg": 8, "sampler_name": "euler", + "scheduler": "normal", "denoise": 1, + "model": ["4", 0], "positive": ["6", 0], "negative": ["7", 0], + "latent_image": ["5", 0] + } + }, + "4": {"class_type": "CheckpointLoaderSimple", "inputs": {"ckpt_name": "v1-5-pruned-emaonly.safetensors"}}, + "5": {"class_type": "EmptyLatentImage", "inputs": {"width": 512, "height": 512, "batch_size": 1}}, + "6": {"class_type": "CLIPTextEncode", "inputs": {"text": "masterpiece best quality girl", "clip": ["4", 1]}}, + "7": {"class_type": "CLIPTextEncode", "inputs": {"text": "bad hands", "clip": ["4", 1]}}, + "8": {"class_type": "VAEDecode", "inputs": {"samples": ["3", 0], "vae": ["4", 2]}}, + "9": {"class_type": "SaveImage", "inputs": {"filename_prefix": "ComfyUI_API", "images": ["8", 0]}} } - # Validate your workflow (the prompt) - from comfy.api.components.schema.prompt import Prompt - prompt = Prompt.validate(prompt_dict) - # Then use the client to run your workflow. This will start, then stop, a local ComfyUI workflow executor. - # It does not connect to a remote server. + async def main(): - async with EmbeddedComfyClient() as client: - outputs = await client.queue_prompt(prompt) - print(outputs) - print("Now that we've exited the with statement, all your VRAM has been cleared from ComfyUI") - if __name__ == "__main__" - asyncio.run(main()) + # Using default configuration (runs in-process) + async with Comfy() as client: + # Queue the prompt and await the result + outputs = await client.queue_prompt(prompt_dict) + + # Retrieve the output path from the SaveImage node (Node ID "9") + image_path = outputs["9"]["images"][0]["abs_path"] + print(f"Image saved to: {image_path}") + + # asyncio.run(main()) ``` - In order to use this in blocking methods, learn more about asyncio online. + #### 2. Using Custom Configuration (Isolated Process) + + To run with specific settings like `lowvram`, pass the configuration. This implies `ProcessPoolExecutor`. + + ```python + async def run_lowvram(): + # This will spawn a new process with lowvram enabled + async with Comfy(configuration={"lowvram": True}) as client: + outputs = await client.queue_prompt(prompt_dict) + print("Finished lowvram generation") + ``` + + #### 3. Programmatically Building Workflows + + You can use `GraphBuilder` constructing workflows with a more pythonic API. + + ```python + from comfy_execution.graph_utils import GraphBuilder + + def build_graph(): + builder = GraphBuilder() + checkpoint = builder.node("CheckpointLoaderSimple", ckpt_name="v1-5-pruned-emaonly.safetensors") + latent = builder.node("EmptyLatentImage", width=512, height=512, batch_size=1) + pos = builder.node("CLIPTextEncode", text="masterpiece", clip=checkpoint.out(1)) + neg = builder.node("CLIPTextEncode", text="bad quality", clip=checkpoint.out(1)) + + sampler = builder.node("KSampler", + seed=42, steps=20, cfg=8, sampler_name="euler", scheduler="normal", denoise=1, + model=checkpoint.out(0), positive=pos.out(0), negative=neg.out(0), latent_image=latent.out(0) + ) + vae = builder.node("VAEDecode", samples=sampler.out(0), vae=checkpoint.out(2)) + builder.node("SaveImage", filename_prefix="Generated", images=vae.out(0)) + return builder.finalize() + + async def run_builder(): + prompt = build_graph() + async with Comfy() as client: + await client.queue_prompt(prompt) + ``` + + #### 4. Streaming Progress and Previews + + To receive real-time progress updates and preview images (e.g., step-by-step decoding). + + ```python + from comfy.component_model.queue_types import BinaryEventTypes + + async def run_streaming(): + async with Comfy() as client: + # Get a task that supports progress iteration + task = client.queue_with_progress(prompt_dict) + + async for notification in task.progress(): + if notification.event == BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA: + # 'data' contains the PIL Image and metadata + image, metadata = notification.data + print(f"Received preview: {image.size}") + elif notification.event == "progress": + print(f"Step: {notification.data['value']}/{notification.data['max']}") + + # Await final result + result = await task.get() + ``` """ def __init__(self, configuration: Optional[Configuration] = None, progress_handler: Optional[ExecutorToClientProgress] = None, max_workers: int = 1, executor: ProcessPoolExecutor | ContextVarExecutor | Literal["ProcessPoolExecutor", "ContextVarExecutor"] = None): self._progress_handler = progress_handler or ServerStub() - self._owns_executor = executor is None or isinstance(executor, str) - if self._owns_executor: - if isinstance(executor, str): - if executor == "ProcessPoolExecutor": - self._executor = ProcessPoolExecutor(max_workers=max_workers) - else: - self._executor = ContextVarExecutor(max_workers=max_workers) - else: - assert not isinstance(executor, str) - self._executor = executor self._default_configuration = default_configuration() self._configuration = configuration + + need_process_pool = requires_process_pool_executor(configuration) + + if executor is None: + if need_process_pool: + self._executor = ProcessPoolExecutor(max_workers=max_workers) + self._owns_executor = True + else: + self._executor = ContextVarExecutor(max_workers=max_workers) + self._owns_executor = True + elif isinstance(executor, str): + self._owns_executor = True + if executor == "ProcessPoolExecutor": + self._executor = ProcessPoolExecutor(max_workers=max_workers) + elif executor == "ContextVarExecutor": + if need_process_pool: + raise ValueError(f"Configuration requires ProcessPoolExecutor but ContextVarExecutor was requested. Configuration keys causing this: {[k for k in MODEL_MANAGEMENT_ARGS if configuration.get(k) != self._default_configuration.get(k)]}") + self._executor = ContextVarExecutor(max_workers=max_workers) + else: + raise ValueError(f"Unknown executor type string: {executor}") + else: + # Executor instance passed + self._owns_executor = False + self._executor = executor + if need_process_pool and not isinstance(executor, ProcessPoolExecutor): + raise ValueError(f"Configuration requires ProcessPoolExecutor but {type(executor).__name__} was passed. Configuration keys causing this: {[k for k in MODEL_MANAGEMENT_ARGS if configuration.get(k) != self._default_configuration.get(k)]}") + self._is_running = False self._task_count_lock = RLock() self._task_count = 0 diff --git a/comfy/component_model/configuration.py b/comfy/component_model/configuration.py new file mode 100644 index 000000000..7dabfc6cf --- /dev/null +++ b/comfy/component_model/configuration.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +from ..cli_args_types import Configuration +from ..cli_args import default_configuration + +MODEL_MANAGEMENT_ARGS = { + "deterministic", + "directml", + "cpu", + "disable_xformers", + # todo: this is the default, so it will be omitted + # "use_pytorch_cross_attention", + "use_split_cross_attention", + "use_quad_cross_attention", + "supports_fp8_compute", + "fast", + "lowvram", + "novram", + "highvram", + "gpu_only", + "force_fp32", + "force_fp16", + "force_bf16", + "reserve_vram", + "disable_smart_memory", + "disable_ipex_optimize", + "disable_pinned_memory", + "async_offload", + "disable_async_offload", + "force_non_blocking", + "force_channels_last", + "fp32_unet", + "fp64_unet", + "bf16_unet", + "fp16_unet", + "fp8_e4m3fn_unet", + "fp8_e5m2_unet", + "fp8_e8m0fnu_unet", + "fp8_e4m3fn_text_enc", + "fp8_e5m2_text_enc", + "fp16_text_enc", + "bf16_text_enc", + "fp32_text_enc", + "cpu_vae", + "fp16_vae", + "bf16_vae", + "fp32_vae", + "force_upcast_attention", + "use_sage_attention", + "use_flash_attention", +} + + +def requires_process_pool_executor(configuration: Configuration | None) -> bool: + if configuration is None: + return False + + default = default_configuration() + for key in MODEL_MANAGEMENT_ARGS: + # Check if key is in configuration and differs from default + if key in configuration: + val = configuration[key] + # Use equality check, handling potential missing keys in default (though default should have them) + if key not in default or val != default[key]: + return True + return False diff --git a/comfy/gguf.py b/comfy/gguf.py index 985a828f8..ba49981c9 100644 --- a/comfy/gguf.py +++ b/comfy/gguf.py @@ -987,7 +987,13 @@ def gguf_tekken_tokenizer_loader(path, temb_shape): logger.info("Attempting to recreate tekken tokenizer from GGUF file metadata...") import json import base64 - from transformers.convert_slow_tokenizer import bytes_to_unicode + try: + from transformers.convert_slow_tokenizer import bytes_to_unicode + except ImportError: + try: + from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode + except ImportError: + raise NotImplementedError("Could not import bytes_to_unicode from transformers") reader = gguf.GGUFReader(path) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 8404c5a3f..36cb286da 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -12,7 +12,10 @@ from pathlib import Path from typing import Tuple, Sequence, TypeVar, Callable, Optional, Union import torch -from transformers import CLIPTokenizer, PreTrainedTokenizerBase +try: + from transformers import CLIPTokenizer, PreTrainedTokenizerBase +except ImportError: + from .transformers_compat import CLIPTokenizerFast as CLIPTokenizer, PreTrainedTokenizerBase from . import clip_model from . import model_management diff --git a/comfy/text_encoders/cosmos.py b/comfy/text_encoders/cosmos.py index 98a02a72b..464d5425a 100644 --- a/comfy/text_encoders/cosmos.py +++ b/comfy/text_encoders/cosmos.py @@ -1,4 +1,4 @@ -from transformers import T5TokenizerFast +from ..transformers_compat import T5TokenizerFast from .t5 import T5 from .. import sd1_clip diff --git a/comfy/text_encoders/flux.py b/comfy/text_encoders/flux.py index 5e93e0942..6ffeccf79 100644 --- a/comfy/text_encoders/flux.py +++ b/comfy/text_encoders/flux.py @@ -1,7 +1,7 @@ import copy import torch -from transformers import T5TokenizerFast, LlamaTokenizerFast +from ..transformers_compat import T5TokenizerFast, LlamaTokenizerFast from .sd3_clip import T5XXLModel from .. import sd1_clip, model_management diff --git a/comfy/text_encoders/genmo.py b/comfy/text_encoders/genmo.py index 019f1f512..b8dfedc6a 100644 --- a/comfy/text_encoders/genmo.py +++ b/comfy/text_encoders/genmo.py @@ -1,4 +1,4 @@ -from transformers import T5TokenizerFast +from ..transformers_compat import T5TokenizerFast from .. import sd1_clip from . import sd3_clip diff --git a/comfy/text_encoders/hunyuan_image.py b/comfy/text_encoders/hunyuan_image.py index ff93ee356..9f2783e86 100644 --- a/comfy/text_encoders/hunyuan_image.py +++ b/comfy/text_encoders/hunyuan_image.py @@ -1,6 +1,6 @@ import re -from transformers import ByT5Tokenizer +from ..transformers_compat import ByT5Tokenizer from .llama import Qwen25_7BVLI from .qwen_image import QwenImageTokenizer, QwenImageTEModel diff --git a/comfy/text_encoders/hunyuan_video.py b/comfy/text_encoders/hunyuan_video.py index fc2c60afb..00c0e2b83 100644 --- a/comfy/text_encoders/hunyuan_video.py +++ b/comfy/text_encoders/hunyuan_video.py @@ -1,6 +1,6 @@ import torch import numbers -from transformers import LlamaTokenizerFast +from ..transformers_compat import LlamaTokenizerFast from .hunyuan_image import HunyuanImageTokenizer from .llama import Llama2 diff --git a/comfy/text_encoders/hydit.py b/comfy/text_encoders/hydit.py index fbd57674b..017c0037c 100644 --- a/comfy/text_encoders/hydit.py +++ b/comfy/text_encoders/hydit.py @@ -1,7 +1,7 @@ import copy import torch -from transformers import BertTokenizer +from ..transformers_compat import BertTokenizerFast as BertTokenizer from .bert import BertModel from .spiece_tokenizer import SPieceTokenizer diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py index d9598b8df..199a85561 100644 --- a/comfy/text_encoders/lt.py +++ b/comfy/text_encoders/lt.py @@ -1,4 +1,4 @@ -from transformers import T5TokenizerFast +from ..transformers_compat import T5TokenizerFast from .genmo import mochi_te from .. import sd1_clip diff --git a/comfy/text_encoders/omnigen2.py b/comfy/text_encoders/omnigen2.py index dc9941488..6ecb02b33 100644 --- a/comfy/text_encoders/omnigen2.py +++ b/comfy/text_encoders/omnigen2.py @@ -1,4 +1,4 @@ -from transformers import Qwen2Tokenizer +from ..transformers_compat import Qwen2Tokenizer from .llama import Qwen25_3B from .. import sd1_clip diff --git a/comfy/text_encoders/ovis.py b/comfy/text_encoders/ovis.py index e56c21ccc..b604e926b 100644 --- a/comfy/text_encoders/ovis.py +++ b/comfy/text_encoders/ovis.py @@ -1,8 +1,7 @@ import numbers import torch -from transformers import Qwen2Tokenizer - +from ..transformers_compat import Qwen2Tokenizer from . import llama from .. import sd1_clip from ..component_model import files diff --git a/comfy/text_encoders/pixart_t5.py b/comfy/text_encoders/pixart_t5.py index 44adde026..f56057db3 100644 --- a/comfy/text_encoders/pixart_t5.py +++ b/comfy/text_encoders/pixart_t5.py @@ -1,4 +1,4 @@ -from transformers import T5TokenizerFast +from ..transformers_compat import T5TokenizerFast from . import sd3_clip from .. import sd1_clip @@ -40,7 +40,6 @@ class PixArtTokenizer(sd1_clip.SD1Tokenizer): super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer) - def pixart_te(dtype_t5=None, t5_quantization_metadata=None): class PixArtTEModel_(PixArtT5XXL): def __init__(self, device="cpu", dtype=None, model_options=None): diff --git a/comfy/text_encoders/qwen_image.py b/comfy/text_encoders/qwen_image.py index ce4773d27..c8498e21c 100644 --- a/comfy/text_encoders/qwen_image.py +++ b/comfy/text_encoders/qwen_image.py @@ -1,6 +1,6 @@ import numbers import torch -from transformers import Qwen2Tokenizer +from ..transformers_compat import Qwen2Tokenizer from .llama import Qwen25_7BVLI from .. import sd1_clip diff --git a/comfy/text_encoders/sa_t5.py b/comfy/text_encoders/sa_t5.py index 90f8a8b25..824265fe2 100644 --- a/comfy/text_encoders/sa_t5.py +++ b/comfy/text_encoders/sa_t5.py @@ -1,4 +1,4 @@ -from transformers import T5TokenizerFast +from ..transformers_compat import T5TokenizerFast from .t5 import T5 from .. import sd1_clip diff --git a/comfy/text_encoders/sd3_clip.py b/comfy/text_encoders/sd3_clip.py index 2511969e3..ac99e413d 100644 --- a/comfy/text_encoders/sd3_clip.py +++ b/comfy/text_encoders/sd3_clip.py @@ -3,7 +3,7 @@ import logging import comfy.utils import torch -from transformers import T5TokenizerFast +from ..transformers_compat import T5TokenizerFast from .t5 import T5 from .. import sd1_clip, model_management diff --git a/comfy/text_encoders/z_image.py b/comfy/text_encoders/z_image.py index 1c050ad7b..cd268cf14 100644 --- a/comfy/text_encoders/z_image.py +++ b/comfy/text_encoders/z_image.py @@ -1,4 +1,4 @@ -from transformers import Qwen2Tokenizer +from ..transformers_compat import Qwen2Tokenizer from . import llama from .. import sd1_clip diff --git a/comfy/transformers_compat.py b/comfy/transformers_compat.py new file mode 100644 index 000000000..ac60ff379 --- /dev/null +++ b/comfy/transformers_compat.py @@ -0,0 +1,56 @@ +try: + from transformers import T5TokenizerFast +except (ImportError, ModuleNotFoundError): + from transformers import T5Tokenizer as T5TokenizerFast + +try: + from transformers import LlamaTokenizerFast +except (ImportError, ModuleNotFoundError): + from transformers import LlamaTokenizer as LlamaTokenizerFast + +try: + from transformers import CLIPTokenizerFast +except (ImportError, ModuleNotFoundError): + from transformers import CLIPTokenizer as CLIPTokenizerFast + +try: + from transformers import GPT2TokenizerFast +except (ImportError, ModuleNotFoundError): + from transformers import GPT2Tokenizer as GPT2TokenizerFast + +try: + from transformers import BertTokenizerFast +except (ImportError, ModuleNotFoundError): + from transformers import BertTokenizer as BertTokenizerFast + +try: + from transformers import Qwen2TokenizerFast +except (ImportError, ModuleNotFoundError): + try: + from transformers import Qwen2Tokenizer as Qwen2TokenizerFast + except (ImportError, ModuleNotFoundError): + # Fallback if neither exists, primarily for earlier versions or specific environments + Qwen2TokenizerFast = None + +# Alias Qwen2Tokenizer to the "Fast" version we found/aliased, as we might use either name +Qwen2Tokenizer = Qwen2TokenizerFast + +try: + from transformers import ByT5TokenizerFast +except ImportError: + try: + from transformers import ByT5Tokenizer as ByT5TokenizerFast + except (ImportError, ModuleNotFoundError): + ByT5TokenizerFast = None + +ByT5Tokenizer = ByT5TokenizerFast + +__all__ = [ + "T5TokenizerFast", + "LlamaTokenizerFast", + "CLIPTokenizerFast", + "GPT2TokenizerFast", + "BertTokenizerFast", + "Qwen2Tokenizer", + "ByT5Tokenizer", +] diff --git a/pyproject.toml b/pyproject.toml index 59468c9d5..609f4d5df 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ dependencies = [ "torchsde>=0.2.6", "einops>=0.6.0", "open-clip-torch>=2.24.0", - "transformers>=4.57.1,<5", + "transformers>=4.57.3", "tokenizers>=0.13.3", "sentencepiece", "peft>=0.10.0", diff --git a/tests/unit/test_comfy_instance.py b/tests/unit/test_comfy_instance.py new file mode 100644 index 000000000..34b142bd4 --- /dev/null +++ b/tests/unit/test_comfy_instance.py @@ -0,0 +1,74 @@ +import pytest +from comfy.client.embedded_comfy_client import Comfy +from comfy.distributed.process_pool_executor import ProcessPoolExecutor +from comfy.distributed.executors import ContextVarExecutor + + +@pytest.mark.asyncio +async def test_comfy_defaults(): + """Test that default initialization uses ContextVarExecutor.""" + client = Comfy() + assert isinstance(client._executor, ContextVarExecutor) + assert client._owns_executor + + +@pytest.mark.asyncio +async def test_comfy_config_triggers_process_pool(): + """Test that configurations affecting global state trigger ProcessPoolExecutor.""" + # "lowvram" is in MODEL_MANAGEMENT_ARGS + client = Comfy(configuration={"lowvram": True}) + assert isinstance(client._executor, ProcessPoolExecutor) + assert client._owns_executor + + +@pytest.mark.asyncio +async def test_comfy_config_unrelated_does_not_trigger(): + """Test that unrelated configuration keys do NOT trigger ProcessPoolExecutor.""" + # "some_random_ui_setting" is likely not in MODEL_MANAGEMENT_ARGS + client = Comfy(configuration={"some_random_ui_setting": "value"}) + assert isinstance(client._executor, ContextVarExecutor) + + +@pytest.mark.asyncio +async def test_comfy_explicit_executor_string(): + """Test explicitly requesting an executor by string.""" + client = Comfy(executor="ProcessPoolExecutor") + assert isinstance(client._executor, ProcessPoolExecutor) + assert client._owns_executor + + client2 = Comfy(executor="ContextVarExecutor") + assert isinstance(client2._executor, ContextVarExecutor) + assert client2._owns_executor + + +@pytest.mark.asyncio +async def test_comfy_explicit_executor_instance(): + """Test passing an executor instance.""" + executor = ContextVarExecutor(max_workers=1) + client = Comfy(executor=executor) + assert client._executor is executor + assert not client._owns_executor + + +@pytest.mark.asyncio +async def test_comfy_mismatch_string_raises(): + """Test that valid config requiring ProcessPoolExecutor raises error if ContextVarExecutor is forced via string.""" + with pytest.raises(ValueError, match="Configuration requires ProcessPoolExecutor"): + Comfy(configuration={"lowvram": True}, executor="ContextVarExecutor") + + +@pytest.mark.asyncio +async def test_comfy_mismatch_instance_raises(): + """Test that valid config requiring ProcessPoolExecutor raises error if ContextVarExecutor instance is passed.""" + executor = ContextVarExecutor(max_workers=1) + with pytest.raises(ValueError, match="Configuration requires ProcessPoolExecutor"): + Comfy(configuration={"lowvram": True}, executor=executor) + + +@pytest.mark.asyncio +async def test_comfy_context_manager(): + """Test the async context manager behavior.""" + async with Comfy() as client: + assert client.is_running + assert isinstance(client._executor, ContextVarExecutor) + assert not client.is_running diff --git a/tests/unit/test_panics.py b/tests/unit/test_panics.py index 190484727..c95b731e9 100644 --- a/tests/unit/test_panics.py +++ b/tests/unit/test_panics.py @@ -6,6 +6,7 @@ import pytest import torch import pebble.common.types +from comfy.cli_args import default_configuration from comfy.cli_args_types import Configuration from comfy.client.embedded_comfy_client import Comfy from comfy.component_model.make_mutable import make_mutable @@ -143,7 +144,7 @@ def create_failing_workflow(): async def test_panic_on_exception_with_executor(executor_cls, executor_kwargs): """Test panic behavior with different executor types""" # Create configuration with our test exception in panic_when - config = Configuration() + config = default_configuration() config.panic_when = [f"{__name__}.UnrecoverableError"] # Initialize the specific executor @@ -179,7 +180,7 @@ async def test_no_panic_when_disabled_with_executor(executor_cls, executor_kwarg """Test no-panic behavior with different executor types""" # Create configuration without the exception in panic_when - config = Configuration() + config = default_configuration() # Initialize the specific executor executor = executor_cls(**executor_kwargs)