Fix transformers 5 compatibility

This commit is contained in:
doctorpangloss 2025-12-12 12:11:25 -08:00
parent ed4f92279c
commit b309e0424e
23 changed files with 421 additions and 56 deletions

View File

@ -306,19 +306,28 @@ choco install -y vcredist2010 vcredist2013 vcredist140
Then, visit [NVIDIA.com's CUDA Toolkit Download Page](https://developer.nvidia.com/cuda-12-6-0-download-archive?target_os=Windows&target_arch=x86_64&target_version=Server2022&target_type=exe_network) and download and install the CUDA Toolkit. Verify it is correctly installed by running `nvcc --version`.
You are now ready to install Sage Attention 2.
### Linux
```shell
uv pip install --no-build-isolation "sageattention@git+https://github.com/thu-ml/SageAttention.git"
```
### Windows
Run this PowerShell script to install the correct version of Sage Attention for your installed PyTorch version:
```powershell
$torch_version = (uv pip freeze | Select-String "^torch==").ToString().Trim()
$cuda_version = $torch_version -replace ".*(cu\d+).*", "`$1"
if ($torch_version -match "\+cu") {
$v = $torch_version -replace "torch==", ""
$package_specifier = "sageattention==2.2.0+$($cuda_version)torch$v"
uv pip install --find-links https://raw.githubusercontent.com/hiddenswitch/ComfyUI/main/pypi/sageattention_index.html $package_specifier
} else {
Write-Host "Could not determine CUDA version from torch version: $torch_version"
}
```
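For example (hypothetical version numbers), if `uv pip freeze` reports `torch==2.8.0+cu129`, the script resolves `$cuda_version` to `cu129` and installs the wheel matching `sageattention==2.2.0+cu129torch2.8.0+cu129` from the index.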
To start ComfyUI with it:
@ -333,6 +342,26 @@ uv run comfyui --use-sage-attention
![with_pytorch_attention](./docs/assets/with_pytorch_attention.webp)
**With PyTorch Attention**
## Flash Attention
Flash Attention 2 is supported on Linux only.
```shell
uv pip install --no-build-isolation flash_attn
```
To start ComfyUI with it:
```shell
uv run comfyui --use-flash-attention
```
![with_sage_attention.webp](./docs/assets/with_sage_attention.webp)
**With SageAttention**
![with_pytorch_attention](./docs/assets/with_pytorch_attention.webp)
**With PyTorch Attention**
## Cosmos Prompt Upsampling
The Cosmos prompt "upsampler," a fine-tune of Mistral-Nemo-12b, correctly rewrites Cosmos prompts in the narrative style that NVIDIA's captioner used for the training data of Cosmos, improving generation results significantly.

View File

@ -32,6 +32,7 @@ from ..distributed.executors import ContextVarExecutor
from ..distributed.history import History
from ..distributed.process_pool_executor import ProcessPoolExecutor
from ..distributed.server_stub import ServerStub
from ..component_model.configuration import MODEL_MANAGEMENT_ARGS, requires_process_pool_executor
_prompt_executor = threading.local()
@ -147,50 +148,181 @@ def _cleanup(invalidate_nodes=True):
class Comfy:
"""
A client for running ComfyUI workflows within a Python application.
This client allows you to execute ComfyUI workflows (in API JSON format) programmatically.
It manages the execution environment, including model loading and resource cleanup.
### Configuration and Executors
ComfyUI relies on global state for model management (e.g., loaded models in VRAM). To handle this safely, `Comfy`
executes workflows using one of two strategies based on your `configuration`:
1. **ContextVarExecutor (Default)**: Runs in a thread pool within the current process.
- **Pros**: Efficient, low overhead.
- **Cons**: Modifies global state in the current process.
- **Use Case**: Standard workflows where the default ComfyUI settings are acceptable or sharing global state is not a concern.
2. **ProcessPoolExecutor**: Runs in a separate process.
- **Pros**: Complete isolation. Configuration changes (like `lowvram`) do not affect the main process.
- **Cons**: Higher overhead (process startup).
- **Use Case**: Required when `configuration` overrides arguments that affect global model management state.
These arguments include: `lowvram`, `highvram`, `cpu`, `gpu_only`, `deterministic`, `directml`,
various `fp8`/`fp16`/`bf16` settings, and attention optimizations (e.g., `use_flash_attention`).
The client automatically selects `ProcessPoolExecutor` if you provide a `configuration` that modifies any of these
global settings, unless you explicitly pass an `executor`.
### Parameters
- **configuration** (`Optional[Configuration]`): A dictionary of arguments to override defaults.
See `comfy.cli_args_types.Configuration`.
Example: `{"lowvram": True}` or `{"gpu_only": True}`.
- **progress_handler** (`Optional[ExecutorToClientProgress]`): callback handler for progress updates and previews.
- **max_workers** (`int`): Maximum number of concurrent workflows (default: 1).
- **executor** (`Optional[Union[Executor, str]]`): Explicitly define the executor to use.
- Pass an instance of `ProcessPoolExecutor` or `ContextVarExecutor`.
- Pass the string `"ProcessPoolExecutor"` or `"ContextVarExecutor"` to force initialization of that type.
- If `None` (default), the best executor is chosen based on `configuration`.
### Examples
#### 1. Running a Workflow (Basic)
This example executes a simple workflow and prints the path of the saved image.
```python
import asyncio
from comfy.client.embedded_comfy_client import Comfy
# A simple API format workflow (simplified for brevity)
prompt_dict = {
"3": {
"class_type": "KSampler",
"inputs": {
"seed": 8566257, "steps": 20, "cfg": 8, "sampler_name": "euler",
"scheduler": "normal", "denoise": 1,
"model": ["4", 0], "positive": ["6", 0], "negative": ["7", 0],
"latent_image": ["5", 0]
}
},
"4": {"class_type": "CheckpointLoaderSimple", "inputs": {"ckpt_name": "v1-5-pruned-emaonly.safetensors"}},
"5": {"class_type": "EmptyLatentImage", "inputs": {"width": 512, "height": 512, "batch_size": 1}},
"6": {"class_type": "CLIPTextEncode", "inputs": {"text": "masterpiece best quality girl", "clip": ["4", 1]}},
"7": {"class_type": "CLIPTextEncode", "inputs": {"text": "bad hands", "clip": ["4", 1]}},
"8": {"class_type": "VAEDecode", "inputs": {"samples": ["3", 0], "vae": ["4", 2]}},
"9": {"class_type": "SaveImage", "inputs": {"filename_prefix": "ComfyUI_API", "images": ["8", 0]}}
}
async def main():
# Using default configuration (runs in-process)
async with Comfy() as client:
# Queue the prompt and await the result
outputs = await client.queue_prompt(prompt_dict)
# Retrieve the output path from the SaveImage node (Node ID "9")
image_path = outputs["9"]["images"][0]["abs_path"]
print(f"Image saved to: {image_path}")
# asyncio.run(main())
```
#### 2. Using Custom Configuration (Isolated Process)
To run with specific settings like `lowvram`, pass the configuration. This implies `ProcessPoolExecutor`.
```python
async def run_lowvram():
# This will spawn a new process with lowvram enabled
async with Comfy(configuration={"lowvram": True}) as client:
outputs = await client.queue_prompt(prompt_dict)
print("Finished lowvram generation")
```
#### 3. Programmatically Building Workflows
You can use `GraphBuilder` to construct workflows with a more Pythonic API.
```python
from comfy_execution.graph_utils import GraphBuilder
def build_graph():
builder = GraphBuilder()
checkpoint = builder.node("CheckpointLoaderSimple", ckpt_name="v1-5-pruned-emaonly.safetensors")
latent = builder.node("EmptyLatentImage", width=512, height=512, batch_size=1)
pos = builder.node("CLIPTextEncode", text="masterpiece", clip=checkpoint.out(1))
neg = builder.node("CLIPTextEncode", text="bad quality", clip=checkpoint.out(1))
sampler = builder.node("KSampler",
seed=42, steps=20, cfg=8, sampler_name="euler", scheduler="normal", denoise=1,
model=checkpoint.out(0), positive=pos.out(0), negative=neg.out(0), latent_image=latent.out(0)
)
vae = builder.node("VAEDecode", samples=sampler.out(0), vae=checkpoint.out(2))
builder.node("SaveImage", filename_prefix="Generated", images=vae.out(0))
return builder.finalize()
async def run_builder():
prompt = build_graph()
async with Comfy() as client:
await client.queue_prompt(prompt)
```
#### 4. Streaming Progress and Previews
To receive real-time progress updates and preview images (e.g., step-by-step decoding), use `queue_with_progress`:
```python
from comfy.component_model.queue_types import BinaryEventTypes
async def run_streaming():
async with Comfy() as client:
# Get a task that supports progress iteration
task = client.queue_with_progress(prompt_dict)
async for notification in task.progress():
if notification.event == BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA:
# 'data' contains the PIL Image and metadata
image, metadata = notification.data
print(f"Received preview: {image.size}")
elif notification.event == "progress":
print(f"Step: {notification.data['value']}/{notification.data['max']}")
# Await final result
result = await task.get()
```
""" """
def __init__(self, configuration: Optional[Configuration] = None, progress_handler: Optional[ExecutorToClientProgress] = None, max_workers: int = 1, executor: ProcessPoolExecutor | ContextVarExecutor | Literal["ProcessPoolExecutor", "ContextVarExecutor"] = None):
self._progress_handler = progress_handler or ServerStub()
self._default_configuration = default_configuration()
self._configuration = configuration
need_process_pool = requires_process_pool_executor(configuration)
if executor is None:
if need_process_pool:
self._executor = ProcessPoolExecutor(max_workers=max_workers)
self._owns_executor = True
else:
self._executor = ContextVarExecutor(max_workers=max_workers)
self._owns_executor = True
elif isinstance(executor, str):
self._owns_executor = True
if executor == "ProcessPoolExecutor":
self._executor = ProcessPoolExecutor(max_workers=max_workers)
elif executor == "ContextVarExecutor":
if need_process_pool:
raise ValueError(f"Configuration requires ProcessPoolExecutor but ContextVarExecutor was requested. Configuration keys causing this: {[k for k in MODEL_MANAGEMENT_ARGS if configuration.get(k) != self._default_configuration.get(k)]}")
self._executor = ContextVarExecutor(max_workers=max_workers)
else:
raise ValueError(f"Unknown executor type string: {executor}")
else:
# Executor instance passed
self._owns_executor = False
self._executor = executor
if need_process_pool and not isinstance(executor, ProcessPoolExecutor):
raise ValueError(f"Configuration requires ProcessPoolExecutor but {type(executor).__name__} was passed. Configuration keys causing this: {[k for k in MODEL_MANAGEMENT_ARGS if configuration.get(k) != self._default_configuration.get(k)]}")
self._is_running = False
self._task_count_lock = RLock()
self._task_count = 0

View File

@ -0,0 +1,66 @@
from __future__ import annotations
from ..cli_args_types import Configuration
from ..cli_args import default_configuration
MODEL_MANAGEMENT_ARGS = {
"deterministic",
"directml",
"cpu",
"disable_xformers",
# todo: this is the default, so it will be omitted
# "use_pytorch_cross_attention",
"use_split_cross_attention",
"use_quad_cross_attention",
"supports_fp8_compute",
"fast",
"lowvram",
"novram",
"highvram",
"gpu_only",
"force_fp32",
"force_fp16",
"force_bf16",
"reserve_vram",
"disable_smart_memory",
"disable_ipex_optimize",
"disable_pinned_memory",
"async_offload",
"disable_async_offload",
"force_non_blocking",
"force_channels_last",
"fp32_unet",
"fp64_unet",
"bf16_unet",
"fp16_unet",
"fp8_e4m3fn_unet",
"fp8_e5m2_unet",
"fp8_e8m0fnu_unet",
"fp8_e4m3fn_text_enc",
"fp8_e5m2_text_enc",
"fp16_text_enc",
"bf16_text_enc",
"fp32_text_enc",
"cpu_vae",
"fp16_vae",
"bf16_vae",
"fp32_vae",
"force_upcast_attention",
"use_sage_attention",
"use_flash_attention",
}
def requires_process_pool_executor(configuration: Configuration | None) -> bool:
if configuration is None:
return False
default = default_configuration()
for key in MODEL_MANAGEMENT_ARGS:
# Check if key is in configuration and differs from default
if key in configuration:
val = configuration[key]
# Use equality check, handling potential missing keys in default (though default should have them)
if key not in default or val != default[key]:
return True
return False
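A minimal usage sketch of the helper above, assuming `lowvram` defaults to `False` in `default_configuration()`; the module path matches the `..component_model.configuration` import in the client hunk above:

```python
from comfy.component_model.configuration import requires_process_pool_executor

# No overrides supplied: safe to execute in-process.
assert not requires_process_pool_executor(None)
# Overriding a model-management argument forces process isolation.
assert requires_process_pool_executor({"lowvram": True})
# Keys outside MODEL_MANAGEMENT_ARGS never force a separate process.
assert not requires_process_pool_executor({"some_random_ui_setting": "value"})
```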

View File

@ -987,7 +987,13 @@ def gguf_tekken_tokenizer_loader(path, temb_shape):
logger.info("Attempting to recreate tekken tokenizer from GGUF file metadata...") logger.info("Attempting to recreate tekken tokenizer from GGUF file metadata...")
import json import json
import base64 import base64
from transformers.convert_slow_tokenizer import bytes_to_unicode try:
from transformers.convert_slow_tokenizer import bytes_to_unicode
except ImportError:
try:
from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
except ImportError:
raise NotImplementedError("Could not import bytes_to_unicode from transformers")
reader = gguf.GGUFReader(path)
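For reference, either import path above yields the same GPT-2-style byte-to-unicode table; a minimal sketch of what the helper provides, assuming one of the imports succeeds:

```python
# bytes_to_unicode() maps every byte value (0-255) to a printable unicode
# character; byte-level BPE vocabularies key their tokens on these
# characters, which is why recreating the tekken tokenizer needs it.
mapping = bytes_to_unicode()
assert len(mapping) == 256
assert all(isinstance(k, int) and isinstance(v, str) for k, v in mapping.items())
```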

View File

@ -12,7 +12,10 @@ from pathlib import Path
from typing import Tuple, Sequence, TypeVar, Callable, Optional, Union
import torch
try:
from transformers import CLIPTokenizer, PreTrainedTokenizerBase
except ImportError:
from .transformers_compat import CLIPTokenizerFast as CLIPTokenizer, PreTrainedTokenizerBase
from . import clip_model
from . import model_management

View File

@ -1,4 +1,4 @@
from ..transformers_compat import T5TokenizerFast
from .t5 import T5
from .. import sd1_clip

View File

@ -1,7 +1,7 @@
import copy
import torch
from ..transformers_compat import T5TokenizerFast, LlamaTokenizerFast
from .sd3_clip import T5XXLModel
from .. import sd1_clip, model_management

View File

@ -1,4 +1,4 @@
from ..transformers_compat import T5TokenizerFast
from .. import sd1_clip
from . import sd3_clip

View File

@ -1,6 +1,6 @@
import re
from ..transformers_compat import ByT5Tokenizer
from .llama import Qwen25_7BVLI
from .qwen_image import QwenImageTokenizer, QwenImageTEModel

View File

@ -1,6 +1,6 @@
import torch
import numbers
from ..transformers_compat import LlamaTokenizerFast
from .hunyuan_image import HunyuanImageTokenizer
from .llama import Llama2

View File

@ -1,7 +1,7 @@
import copy
import torch
from ..transformers_compat import BertTokenizerFast as BertTokenizer
from .bert import BertModel
from .spiece_tokenizer import SPieceTokenizer

View File

@ -1,4 +1,4 @@
from ..transformers_compat import T5TokenizerFast
from .genmo import mochi_te
from .. import sd1_clip

View File

@ -1,4 +1,4 @@
from ..transformers_compat import Qwen2Tokenizer
from .llama import Qwen25_3B
from .. import sd1_clip

View File

@ -1,8 +1,7 @@
import numbers
import torch
from ..transformers_compat import Qwen2Tokenizer
from . import llama
from .. import sd1_clip
from ..component_model import files

View File

@ -1,4 +1,4 @@
from ..transformers_compat import T5TokenizerFast
from . import sd3_clip
from .. import sd1_clip
@ -40,7 +40,6 @@ class PixArtTokenizer(sd1_clip.SD1Tokenizer):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
def pixart_te(dtype_t5=None, t5_quantization_metadata=None):
class PixArtTEModel_(PixArtT5XXL):
def __init__(self, device="cpu", dtype=None, model_options=None):

View File

@ -1,6 +1,6 @@
import numbers
import torch
from ..transformers_compat import Qwen2Tokenizer
from .llama import Qwen25_7BVLI
from .. import sd1_clip

View File

@ -1,4 +1,4 @@
from ..transformers_compat import T5TokenizerFast
from .t5 import T5
from .. import sd1_clip

View File

@ -3,7 +3,7 @@ import logging
import comfy.utils
import torch
from ..transformers_compat import T5TokenizerFast
from .t5 import T5
from .. import sd1_clip, model_management

View File

@ -1,4 +1,4 @@
from ..transformers_compat import Qwen2Tokenizer
from . import llama
from .. import sd1_clip

View File

@ -0,0 +1,56 @@
try:
from transformers import T5TokenizerFast
except (ImportError, ModuleNotFoundError):
from transformers import T5Tokenizer as T5TokenizerFast
try:
from transformers import LlamaTokenizerFast
except (ImportError, ModuleNotFoundError):
from transformers import LlamaTokenizer as LlamaTokenizerFast
try:
from transformers import CLIPTokenizerFast
except (ImportError, ModuleNotFoundError):
from transformers import CLIPTokenizer as CLIPTokenizerFast
try:
from transformers import GPT2TokenizerFast
except (ImportError, ModuleNotFoundError):
from transformers import GPT2Tokenizer as GPT2TokenizerFast
try:
from transformers import BertTokenizerFast
except (ImportError, ModuleNotFoundError):
from transformers import BertTokenizer as BertTokenizerFast
try:
from transformers import Qwen2TokenizerFast
except (ImportError, ModuleNotFoundError):
try:
from transformers import Qwen2Tokenizer as Qwen2TokenizerFast
except (ImportError, ModuleNotFoundError):
# Fallback if neither exists, primarily for earlier versions or specific environments
Qwen2TokenizerFast = None
# Alias Qwen2Tokenizer to the "Fast" version we found/aliased, as we might use either name
Qwen2Tokenizer = Qwen2TokenizerFast
try:
from transformers import ByT5TokenizerFast
except ImportError:
try:
from transformers import ByT5Tokenizer as ByT5TokenizerFast
except (ImportError, ModuleNotFoundError):
ByT5TokenizerFast = None
ByT5Tokenizer = ByT5TokenizerFast
__all__ = [
"T5TokenizerFast",
"LlamaTokenizerFast",
"CLIPTokenizerFast",
"GPT2TokenizerFast",
"BertTokenizerFast",
"Qwen2Tokenizer",
"ByT5Tokenizer",
]
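Downstream modules import tokenizer names from this shim instead of from `transformers` directly (see the `from ..transformers_compat import ...` hunks above). A minimal sketch of the consumption pattern, assuming the shim lives at `comfy/transformers_compat.py` as the relative imports suggest:

```python
# One import site that works on transformers 4.x (fast classes present)
# and transformers 5 (slow classes aliased to the fast names by the shim).
from comfy.transformers_compat import T5TokenizerFast, Qwen2Tokenizer

if Qwen2Tokenizer is None:
    # The shim falls back to None when neither the fast nor the slow
    # Qwen2 tokenizer exists in the installed transformers build.
    raise RuntimeError("Qwen2 tokenizer unavailable in installed transformers")
```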

View File

@ -27,7 +27,7 @@ dependencies = [
"torchsde>=0.2.6", "torchsde>=0.2.6",
"einops>=0.6.0", "einops>=0.6.0",
"open-clip-torch>=2.24.0", "open-clip-torch>=2.24.0",
"transformers>=4.57.1,<5", "transformers>=4.57.3",
"tokenizers>=0.13.3", "tokenizers>=0.13.3",
"sentencepiece", "sentencepiece",
"peft>=0.10.0", "peft>=0.10.0",

View File

@ -0,0 +1,74 @@
import pytest
from comfy.client.embedded_comfy_client import Comfy
from comfy.distributed.process_pool_executor import ProcessPoolExecutor
from comfy.distributed.executors import ContextVarExecutor
@pytest.mark.asyncio
async def test_comfy_defaults():
"""Test that default initialization uses ContextVarExecutor."""
client = Comfy()
assert isinstance(client._executor, ContextVarExecutor)
assert client._owns_executor
@pytest.mark.asyncio
async def test_comfy_config_triggers_process_pool():
"""Test that configurations affecting global state trigger ProcessPoolExecutor."""
# "lowvram" is in MODEL_MANAGEMENT_ARGS
client = Comfy(configuration={"lowvram": True})
assert isinstance(client._executor, ProcessPoolExecutor)
assert client._owns_executor
@pytest.mark.asyncio
async def test_comfy_config_unrelated_does_not_trigger():
"""Test that unrelated configuration keys do NOT trigger ProcessPoolExecutor."""
# "some_random_ui_setting" is likely not in MODEL_MANAGEMENT_ARGS
client = Comfy(configuration={"some_random_ui_setting": "value"})
assert isinstance(client._executor, ContextVarExecutor)
@pytest.mark.asyncio
async def test_comfy_explicit_executor_string():
"""Test explicitly requesting an executor by string."""
client = Comfy(executor="ProcessPoolExecutor")
assert isinstance(client._executor, ProcessPoolExecutor)
assert client._owns_executor
client2 = Comfy(executor="ContextVarExecutor")
assert isinstance(client2._executor, ContextVarExecutor)
assert client2._owns_executor
@pytest.mark.asyncio
async def test_comfy_explicit_executor_instance():
"""Test passing an executor instance."""
executor = ContextVarExecutor(max_workers=1)
client = Comfy(executor=executor)
assert client._executor is executor
assert not client._owns_executor
@pytest.mark.asyncio
async def test_comfy_mismatch_string_raises():
"""Test that valid config requiring ProcessPoolExecutor raises error if ContextVarExecutor is forced via string."""
with pytest.raises(ValueError, match="Configuration requires ProcessPoolExecutor"):
Comfy(configuration={"lowvram": True}, executor="ContextVarExecutor")
@pytest.mark.asyncio
async def test_comfy_mismatch_instance_raises():
"""Test that valid config requiring ProcessPoolExecutor raises error if ContextVarExecutor instance is passed."""
executor = ContextVarExecutor(max_workers=1)
with pytest.raises(ValueError, match="Configuration requires ProcessPoolExecutor"):
Comfy(configuration={"lowvram": True}, executor=executor)
@pytest.mark.asyncio
async def test_comfy_context_manager():
"""Test the async context manager behavior."""
async with Comfy() as client:
assert client.is_running
assert isinstance(client._executor, ContextVarExecutor)
assert not client.is_running

View File

@ -6,6 +6,7 @@ import pytest
import torch
import pebble.common.types
from comfy.cli_args import default_configuration
from comfy.cli_args_types import Configuration
from comfy.client.embedded_comfy_client import Comfy
from comfy.component_model.make_mutable import make_mutable
@ -143,7 +144,7 @@ def create_failing_workflow():
async def test_panic_on_exception_with_executor(executor_cls, executor_kwargs):
"""Test panic behavior with different executor types"""
# Create configuration with our test exception in panic_when
config = default_configuration()
config.panic_when = [f"{__name__}.UnrecoverableError"]
# Initialize the specific executor
@ -179,7 +180,7 @@ async def test_no_panic_when_disabled_with_executor(executor_cls, executor_kwarg
"""Test no-panic behavior with different executor types""" """Test no-panic behavior with different executor types"""
# Create configuration without the exception in panic_when # Create configuration without the exception in panic_when
config = Configuration() config = default_configuration()
# Initialize the specific executor # Initialize the specific executor
executor = executor_cls(**executor_kwargs) executor = executor_cls(**executor_kwargs)