Fix transformers 5 compatibility

This commit is contained in:
parent ed4f92279c
commit b309e0424e

35 README.md
@@ -306,19 +306,28 @@ choco install -y vcredist2010 vcredist2013 vcredist140

Then, visit [NVIDIA.com's CUDA Toolkit Download Page](https://developer.nvidia.com/cuda-12-6-0-download-archive?target_os=Windows&target_arch=x86_64&target_version=Server2022&target_type=exe_network) and download and install the CUDA Toolkit. Verify it is correctly installed by running `nvcc --version`.

You are now ready to install Sage Attention 2 and Flash Attention.
You are now ready to install Sage Attention 2.

### Linux

```shell
uv pip install --no-build-isolation "sageattention@git+https://github.com/thu-ml/SageAttention.git"
uv pip install --no-build-isolation flash_attn
```

### Windows

Run this PowerShell script to install the correct version of Sage Attention for your installed PyTorch version:

```powershell
uv pip install --find-links https://raw.githubusercontent.com/hiddenswitch/ComfyUI/main/pypi/sageattention_index.html sageattention
$torch_version = (uv pip freeze | Select-String "torch==").ToString().Trim()
$cuda_version = $torch_version -replace ".*(cu\d+).*", "`$1"
if ($torch_version -match "\+cu") {
    $v = $torch_version -replace "torch==", ""
    $package_specifier = "sageattention==2.2.0+$($cuda_version)torch$v"
    uv pip install --find-links https://raw.githubusercontent.com/hiddenswitch/ComfyUI/main/pypi/sageattention_index.html $package_specifier
} else {
    Write-Host "Could not determine CUDA version from torch version: $torch_version"
}
```

To start ComfyUI with it:
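
For example, assuming the `comfyui` entry point and the `--use-sage-attention` flag referenced in the following hunk:

```shell
uv run comfyui --use-sage-attention
```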

@@ -333,6 +342,26 @@ uv run comfyui --use-sage-attention

**With PyTorch Attention**

## Flash Attention

Flash Attention 2 is supported on Linux only.

```shell
uv pip install --no-build-isolation flash_attn
```

To start ComfyUI with it:

```shell
uv run comfyui --use-flash-attention
```

**With SageAttention**

**With PyTorch Attention**

## Cosmos Prompt Upsampling

The Cosmos prompt "upsampler," a fine-tune of Mistral-Nemo-12b, rewrites Cosmos prompts in the narrative style that NVIDIA's captioner used for the Cosmos training data, significantly improving generation results.

@@ -32,6 +32,7 @@ from ..distributed.executors import ContextVarExecutor
from ..distributed.history import History
from ..distributed.process_pool_executor import ProcessPoolExecutor
from ..distributed.server_stub import ServerStub
from ..component_model.configuration import MODEL_MANAGEMENT_ARGS, requires_process_pool_executor

_prompt_executor = threading.local()

@@ -147,50 +148,181 @@ def _cleanup(invalidate_nodes=True):

class Comfy:
"""
This manages a single-threaded executor to run long-running or blocking workflows
asynchronously without blocking the asyncio event loop. It initializes a PromptExecutor
in a dedicated thread for executing prompts and handling server-stub communications.
Example usage:
A client for running ComfyUI workflows within a Python application.

Asynchronous (non-blocking) usage with async-await:
```
# Write a workflow, or enable Dev Mode in the UI settings, then Save (API Format) to get the workflow in your
# workspace.
This client allows you to execute ComfyUI workflows (in API JSON format) programmatically.
It manages the execution environment, including model loading and resource cleanup.

### Configuration and Executors

ComfyUI relies on global state for model management (e.g., loaded models in VRAM). To handle this safely, `Comfy`
executes workflows using one of two strategies based on your `configuration`:

1. **ContextVarExecutor (Default)**: Runs in a thread pool within the current process.
   - **Pros**: Efficient, low overhead.
   - **Cons**: Modifies global state in the current process.
   - **Use Case**: Standard workflows where you are happy with the default ComfyUI settings or with sharing state.

2. **ProcessPoolExecutor**: Runs in a separate process.
   - **Pros**: Complete isolation. Configuration changes (like `lowvram`) do not affect the main process.
   - **Cons**: Higher overhead (process startup).
   - **Use Case**: Required when `configuration` overrides arguments that affect global model management state.
     These arguments include: `lowvram`, `highvram`, `cpu`, `gpu_only`, `deterministic`, `directml`,
     various `fp8`/`fp16`/`bf16` settings, and attention optimizations (e.g., `use_flash_attention`).

The client automatically selects `ProcessPoolExecutor` if you provide a `configuration` that modifies any of these
global settings, unless you explicitly pass an `executor`.

### Parameters

- **configuration** (`Optional[Configuration]`): A dictionary of arguments to override defaults.
  See `comfy.cli_args_types.Configuration`.
  Example: `{"lowvram": True}` or `{"gpu_only": True}`.
- **progress_handler** (`Optional[ExecutorToClientProgress]`): A callback handler for progress updates and previews.
- **max_workers** (`int`): Maximum number of concurrent workflows (default: 1).
- **executor** (`Optional[Union[Executor, str]]`): Explicitly define the executor to use.
  - Pass an instance of `ProcessPoolExecutor` or `ContextVarExecutor`.
  - Pass the string `"ProcessPoolExecutor"` or `"ContextVarExecutor"` to force initialization of that type.
  - If `None` (default), the best executor is chosen based on `configuration`.
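
For instance, to force a specific executor regardless of configuration (a minimal sketch; `prompt_dict` is an API-format workflow as in the examples below):

```python
# Run fully isolated even though the configuration is left at its defaults
async with Comfy(executor="ProcessPoolExecutor") as client:
    outputs = await client.queue_prompt(prompt_dict)
```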

### Examples

#### 1. Running a Workflow (Basic)

This example executes a simple workflow and prints the path of the saved image.

```python
import asyncio
from comfy.client.embedded_comfy_client import Comfy

# A simple API format workflow (simplified for brevity)
prompt_dict = {
    "1": {"class_type": "KSamplerAdvanced", ...}
    ...
    "3": {
        "class_type": "KSampler",
        "inputs": {
            "seed": 8566257, "steps": 20, "cfg": 8, "sampler_name": "euler",
            "scheduler": "normal", "denoise": 1,
            "model": ["4", 0], "positive": ["6", 0], "negative": ["7", 0],
            "latent_image": ["5", 0]
        }
    },
    "4": {"class_type": "CheckpointLoaderSimple", "inputs": {"ckpt_name": "v1-5-pruned-emaonly.safetensors"}},
    "5": {"class_type": "EmptyLatentImage", "inputs": {"width": 512, "height": 512, "batch_size": 1}},
    "6": {"class_type": "CLIPTextEncode", "inputs": {"text": "masterpiece best quality girl", "clip": ["4", 1]}},
    "7": {"class_type": "CLIPTextEncode", "inputs": {"text": "bad hands", "clip": ["4", 1]}},
    "8": {"class_type": "VAEDecode", "inputs": {"samples": ["3", 0], "vae": ["4", 2]}},
    "9": {"class_type": "SaveImage", "inputs": {"filename_prefix": "ComfyUI_API", "images": ["8", 0]}}
}
# Validate your workflow (the prompt)
from comfy.api.components.schema.prompt import Prompt
prompt = Prompt.validate(prompt_dict)
# Then use the client to run your workflow. This will start, then stop, a local ComfyUI workflow executor.
# It does not connect to a remote server.

async def main():
    async with EmbeddedComfyClient() as client:
        outputs = await client.queue_prompt(prompt)
        print(outputs)
    print("Now that we've exited the with statement, all your VRAM has been cleared from ComfyUI")
if __name__ == "__main__":
    asyncio.run(main())
    # Using default configuration (runs in-process)
    async with Comfy() as client:
        # Queue the prompt and await the result
        outputs = await client.queue_prompt(prompt_dict)

        # Retrieve the output path from the SaveImage node (Node ID "9")
        image_path = outputs["9"]["images"][0]["abs_path"]
        print(f"Image saved to: {image_path}")

# asyncio.run(main())
```

To use this from blocking (synchronous) code, see the asyncio documentation.

#### 2. Using Custom Configuration (Isolated Process)

To run with specific settings like `lowvram`, pass the configuration. This implies `ProcessPoolExecutor`.

```python
async def run_lowvram():
    # This will spawn a new process with lowvram enabled
    async with Comfy(configuration={"lowvram": True}) as client:
        outputs = await client.queue_prompt(prompt_dict)
        print("Finished lowvram generation")
```

#### 3. Programmatically Building Workflows

You can use `GraphBuilder` to construct workflows with a more Pythonic API.

```python
from comfy_execution.graph_utils import GraphBuilder

def build_graph():
    builder = GraphBuilder()
    checkpoint = builder.node("CheckpointLoaderSimple", ckpt_name="v1-5-pruned-emaonly.safetensors")
    latent = builder.node("EmptyLatentImage", width=512, height=512, batch_size=1)
    pos = builder.node("CLIPTextEncode", text="masterpiece", clip=checkpoint.out(1))
    neg = builder.node("CLIPTextEncode", text="bad quality", clip=checkpoint.out(1))

    sampler = builder.node("KSampler",
                           seed=42, steps=20, cfg=8, sampler_name="euler", scheduler="normal", denoise=1,
                           model=checkpoint.out(0), positive=pos.out(0), negative=neg.out(0), latent_image=latent.out(0)
                           )
    vae = builder.node("VAEDecode", samples=sampler.out(0), vae=checkpoint.out(2))
    builder.node("SaveImage", filename_prefix="Generated", images=vae.out(0))
    return builder.finalize()

async def run_builder():
    prompt = build_graph()
    async with Comfy() as client:
        await client.queue_prompt(prompt)
```

#### 4. Streaming Progress and Previews

To receive real-time progress updates and preview images (e.g., step-by-step decoding):

```python
from comfy.component_model.queue_types import BinaryEventTypes

async def run_streaming():
    async with Comfy() as client:
        # Get a task that supports progress iteration
        task = client.queue_with_progress(prompt_dict)

        async for notification in task.progress():
            if notification.event == BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA:
                # 'data' contains the PIL Image and metadata
                image, metadata = notification.data
                print(f"Received preview: {image.size}")
            elif notification.event == "progress":
                print(f"Step: {notification.data['value']}/{notification.data['max']}")

        # Await final result
        result = await task.get()
```
"""

    def __init__(self, configuration: Optional[Configuration] = None, progress_handler: Optional[ExecutorToClientProgress] = None, max_workers: int = 1, executor: ProcessPoolExecutor | ContextVarExecutor | Literal["ProcessPoolExecutor", "ContextVarExecutor"] = None):
        self._progress_handler = progress_handler or ServerStub()
        self._owns_executor = executor is None or isinstance(executor, str)
        if self._owns_executor:
            if isinstance(executor, str):
                if executor == "ProcessPoolExecutor":
                    self._executor = ProcessPoolExecutor(max_workers=max_workers)
                else:
                    self._executor = ContextVarExecutor(max_workers=max_workers)
        else:
            assert not isinstance(executor, str)
            self._executor = executor
        self._default_configuration = default_configuration()
        self._configuration = configuration

        need_process_pool = requires_process_pool_executor(configuration)

        if executor is None:
            if need_process_pool:
                self._executor = ProcessPoolExecutor(max_workers=max_workers)
                self._owns_executor = True
            else:
                self._executor = ContextVarExecutor(max_workers=max_workers)
                self._owns_executor = True
        elif isinstance(executor, str):
            self._owns_executor = True
            if executor == "ProcessPoolExecutor":
                self._executor = ProcessPoolExecutor(max_workers=max_workers)
            elif executor == "ContextVarExecutor":
                if need_process_pool:
                    raise ValueError(f"Configuration requires ProcessPoolExecutor but ContextVarExecutor was requested. Configuration keys causing this: {[k for k in MODEL_MANAGEMENT_ARGS if configuration.get(k) != self._default_configuration.get(k)]}")
                self._executor = ContextVarExecutor(max_workers=max_workers)
            else:
                raise ValueError(f"Unknown executor type string: {executor}")
        else:
            # Executor instance passed
            self._owns_executor = False
            self._executor = executor
            if need_process_pool and not isinstance(executor, ProcessPoolExecutor):
                raise ValueError(f"Configuration requires ProcessPoolExecutor but {type(executor).__name__} was passed. Configuration keys causing this: {[k for k in MODEL_MANAGEMENT_ARGS if configuration.get(k) != self._default_configuration.get(k)]}")

        self._is_running = False
        self._task_count_lock = RLock()
        self._task_count = 0

66 comfy/component_model/configuration.py Normal file
@@ -0,0 +1,66 @@
from __future__ import annotations

from ..cli_args_types import Configuration
from ..cli_args import default_configuration

MODEL_MANAGEMENT_ARGS = {
    "deterministic",
    "directml",
    "cpu",
    "disable_xformers",
    # todo: this is the default, so it will be omitted
    # "use_pytorch_cross_attention",
    "use_split_cross_attention",
    "use_quad_cross_attention",
    "supports_fp8_compute",
    "fast",
    "lowvram",
    "novram",
    "highvram",
    "gpu_only",
    "force_fp32",
    "force_fp16",
    "force_bf16",
    "reserve_vram",
    "disable_smart_memory",
    "disable_ipex_optimize",
    "disable_pinned_memory",
    "async_offload",
    "disable_async_offload",
    "force_non_blocking",
    "force_channels_last",
    "fp32_unet",
    "fp64_unet",
    "bf16_unet",
    "fp16_unet",
    "fp8_e4m3fn_unet",
    "fp8_e5m2_unet",
    "fp8_e8m0fnu_unet",
    "fp8_e4m3fn_text_enc",
    "fp8_e5m2_text_enc",
    "fp16_text_enc",
    "bf16_text_enc",
    "fp32_text_enc",
    "cpu_vae",
    "fp16_vae",
    "bf16_vae",
    "fp32_vae",
    "force_upcast_attention",
    "use_sage_attention",
    "use_flash_attention",
}


def requires_process_pool_executor(configuration: Configuration | None) -> bool:
    if configuration is None:
        return False

    default = default_configuration()
    for key in MODEL_MANAGEMENT_ARGS:
        # Check if key is in configuration and differs from default
        if key in configuration:
            val = configuration[key]
            # Use equality check, handling potential missing keys in default (though default should have them)
            if key not in default or val != default[key]:
                return True
    return False
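
A minimal usage sketch (assuming the module is importable as `comfy.component_model.configuration`, matching the import added to the client above):

```python
from comfy.component_model.configuration import requires_process_pool_executor

# Mirrors how the Comfy client decides which executor to build
assert requires_process_pool_executor(None) is False
assert requires_process_pool_executor({"lowvram": True}) is True  # model-management override
assert requires_process_pool_executor({"some_random_ui_setting": "value"}) is False  # unrelated key
```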

@@ -987,7 +987,13 @@ def gguf_tekken_tokenizer_loader(path, temb_shape):
    logger.info("Attempting to recreate tekken tokenizer from GGUF file metadata...")
    import json
    import base64
    from transformers.convert_slow_tokenizer import bytes_to_unicode
    try:
        from transformers.convert_slow_tokenizer import bytes_to_unicode
    except ImportError:
        try:
            from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
        except ImportError:
            raise NotImplementedError("Could not import bytes_to_unicode from transformers")

    reader = gguf.GGUFReader(path)

@@ -12,7 +12,10 @@ from pathlib import Path
from typing import Tuple, Sequence, TypeVar, Callable, Optional, Union

import torch
from transformers import CLIPTokenizer, PreTrainedTokenizerBase
try:
    from transformers import CLIPTokenizer, PreTrainedTokenizerBase
except ImportError:
    from .transformers_compat import CLIPTokenizerFast as CLIPTokenizer, PreTrainedTokenizerBase

from . import clip_model
from . import model_management

@@ -1,4 +1,4 @@
from transformers import T5TokenizerFast
from ..transformers_compat import T5TokenizerFast

from .t5 import T5
from .. import sd1_clip

@@ -1,7 +1,7 @@
import copy

import torch
from transformers import T5TokenizerFast, LlamaTokenizerFast
from ..transformers_compat import T5TokenizerFast, LlamaTokenizerFast

from .sd3_clip import T5XXLModel
from .. import sd1_clip, model_management

@@ -1,4 +1,4 @@
from transformers import T5TokenizerFast
from ..transformers_compat import T5TokenizerFast

from .. import sd1_clip
from . import sd3_clip

@@ -1,6 +1,6 @@
import re

from transformers import ByT5Tokenizer
from ..transformers_compat import ByT5Tokenizer

from .llama import Qwen25_7BVLI
from .qwen_image import QwenImageTokenizer, QwenImageTEModel

@@ -1,6 +1,6 @@
import torch
import numbers
from transformers import LlamaTokenizerFast
from ..transformers_compat import LlamaTokenizerFast

from .hunyuan_image import HunyuanImageTokenizer
from .llama import Llama2

@@ -1,7 +1,7 @@
import copy

import torch
from transformers import BertTokenizer
from ..transformers_compat import BertTokenizerFast as BertTokenizer

from .bert import BertModel
from .spiece_tokenizer import SPieceTokenizer

@@ -1,4 +1,4 @@
from transformers import T5TokenizerFast
from ..transformers_compat import T5TokenizerFast

from .genmo import mochi_te
from .. import sd1_clip

@@ -1,4 +1,4 @@
from transformers import Qwen2Tokenizer
from ..transformers_compat import Qwen2Tokenizer

from .llama import Qwen25_3B
from .. import sd1_clip

@@ -1,8 +1,7 @@
import numbers

import torch
from transformers import Qwen2Tokenizer

from ..transformers_compat import Qwen2Tokenizer
from . import llama
from .. import sd1_clip
from ..component_model import files

@@ -1,4 +1,4 @@
from transformers import T5TokenizerFast
from ..transformers_compat import T5TokenizerFast

from . import sd3_clip
from .. import sd1_clip

@@ -40,7 +40,6 @@ class PixArtTokenizer(sd1_clip.SD1Tokenizer):
        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)


def pixart_te(dtype_t5=None, t5_quantization_metadata=None):
    class PixArtTEModel_(PixArtT5XXL):
        def __init__(self, device="cpu", dtype=None, model_options=None):

@@ -1,6 +1,6 @@
import numbers
import torch
from transformers import Qwen2Tokenizer
from ..transformers_compat import Qwen2Tokenizer

from .llama import Qwen25_7BVLI
from .. import sd1_clip

@@ -1,4 +1,4 @@
from transformers import T5TokenizerFast
from ..transformers_compat import T5TokenizerFast

from .t5 import T5
from .. import sd1_clip

@@ -3,7 +3,7 @@ import logging
import comfy.utils

import torch
from transformers import T5TokenizerFast
from ..transformers_compat import T5TokenizerFast

from .t5 import T5
from .. import sd1_clip, model_management

@@ -1,4 +1,4 @@
from transformers import Qwen2Tokenizer
from ..transformers_compat import Qwen2Tokenizer

from . import llama
from .. import sd1_clip

56 comfy/transformers_compat.py Normal file
@@ -0,0 +1,56 @@
try:
    from transformers import T5TokenizerFast
except (ImportError, ModuleNotFoundError):
    from transformers import T5Tokenizer as T5TokenizerFast

try:
    from transformers import LlamaTokenizerFast
except (ImportError, ModuleNotFoundError):
    from transformers import LlamaTokenizer as LlamaTokenizerFast

try:
    from transformers import CLIPTokenizerFast
except (ImportError, ModuleNotFoundError):
    from transformers import CLIPTokenizer as CLIPTokenizerFast

try:
    from transformers import GPT2TokenizerFast
except (ImportError, ModuleNotFoundError):
    from transformers import GPT2Tokenizer as GPT2TokenizerFast

try:
    from transformers import BertTokenizerFast
except (ImportError, ModuleNotFoundError):
    from transformers import BertTokenizer as BertTokenizerFast

try:
    from transformers import Qwen2TokenizerFast
except (ImportError, ModuleNotFoundError):
    try:
        from transformers import Qwen2Tokenizer as Qwen2TokenizerFast
    except (ImportError, ModuleNotFoundError):
        # Fallback if neither exists, primarily for earlier versions or specific environments
        Qwen2TokenizerFast = None

# Alias Qwen2Tokenizer to the "Fast" version we found/aliased, as we might use either name
Qwen2Tokenizer = Qwen2TokenizerFast

try:
    from transformers import ByT5TokenizerFast
except ImportError:
    try:
        from transformers import ByT5Tokenizer as ByT5TokenizerFast
    except (ImportError, ModuleNotFoundError):
        ByT5TokenizerFast = None

ByT5Tokenizer = ByT5TokenizerFast

__all__ = [
    "T5TokenizerFast",
    "LlamaTokenizerFast",
    "CLIPTokenizerFast",
    "GPT2TokenizerFast",
    "BertTokenizerFast",
    "Qwen2Tokenizer",
    "ByT5Tokenizer",
]
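
Downstream modules then import tokenizer classes from this shim instead of from `transformers` directly, as the hunks above show. A minimal sketch of the pattern from a submodule one level below `comfy/` (the checkpoint id is illustrative only):

```python
# e.g. in a text encoder module under the comfy package
from ..transformers_compat import T5TokenizerFast

# Resolves to transformers.T5TokenizerFast when available, otherwise to the slow
# T5Tokenizer aliased under the same name, so call sites stay unchanged.
tokenizer = T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl")
```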

@@ -27,7 +27,7 @@ dependencies = [
    "torchsde>=0.2.6",
    "einops>=0.6.0",
    "open-clip-torch>=2.24.0",
    "transformers>=4.57.1,<5",
    "transformers>=4.57.3",
    "tokenizers>=0.13.3",
    "sentencepiece",
    "peft>=0.10.0",

74 tests/unit/test_comfy_instance.py Normal file
@@ -0,0 +1,74 @@
import pytest
from comfy.client.embedded_comfy_client import Comfy
from comfy.distributed.process_pool_executor import ProcessPoolExecutor
from comfy.distributed.executors import ContextVarExecutor


@pytest.mark.asyncio
async def test_comfy_defaults():
    """Test that default initialization uses ContextVarExecutor."""
    client = Comfy()
    assert isinstance(client._executor, ContextVarExecutor)
    assert client._owns_executor


@pytest.mark.asyncio
async def test_comfy_config_triggers_process_pool():
    """Test that configurations affecting global state trigger ProcessPoolExecutor."""
    # "lowvram" is in MODEL_MANAGEMENT_ARGS
    client = Comfy(configuration={"lowvram": True})
    assert isinstance(client._executor, ProcessPoolExecutor)
    assert client._owns_executor


@pytest.mark.asyncio
async def test_comfy_config_unrelated_does_not_trigger():
    """Test that unrelated configuration keys do NOT trigger ProcessPoolExecutor."""
    # "some_random_ui_setting" is likely not in MODEL_MANAGEMENT_ARGS
    client = Comfy(configuration={"some_random_ui_setting": "value"})
    assert isinstance(client._executor, ContextVarExecutor)


@pytest.mark.asyncio
async def test_comfy_explicit_executor_string():
    """Test explicitly requesting an executor by string."""
    client = Comfy(executor="ProcessPoolExecutor")
    assert isinstance(client._executor, ProcessPoolExecutor)
    assert client._owns_executor

    client2 = Comfy(executor="ContextVarExecutor")
    assert isinstance(client2._executor, ContextVarExecutor)
    assert client2._owns_executor


@pytest.mark.asyncio
async def test_comfy_explicit_executor_instance():
    """Test passing an executor instance."""
    executor = ContextVarExecutor(max_workers=1)
    client = Comfy(executor=executor)
    assert client._executor is executor
    assert not client._owns_executor


@pytest.mark.asyncio
async def test_comfy_mismatch_string_raises():
    """Test that a config requiring ProcessPoolExecutor raises an error if ContextVarExecutor is forced via string."""
    with pytest.raises(ValueError, match="Configuration requires ProcessPoolExecutor"):
        Comfy(configuration={"lowvram": True}, executor="ContextVarExecutor")


@pytest.mark.asyncio
async def test_comfy_mismatch_instance_raises():
    """Test that a config requiring ProcessPoolExecutor raises an error if a ContextVarExecutor instance is passed."""
    executor = ContextVarExecutor(max_workers=1)
    with pytest.raises(ValueError, match="Configuration requires ProcessPoolExecutor"):
        Comfy(configuration={"lowvram": True}, executor=executor)


@pytest.mark.asyncio
async def test_comfy_context_manager():
    """Test the async context manager behavior."""
    async with Comfy() as client:
        assert client.is_running
        assert isinstance(client._executor, ContextVarExecutor)
    assert not client.is_running
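
These are standard pytest-asyncio tests; one way to run just this file (assuming pytest and pytest-asyncio are installed in the environment):

```shell
uv run pytest tests/unit/test_comfy_instance.py -q
```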

@@ -6,6 +6,7 @@ import pytest
import torch
import pebble.common.types

from comfy.cli_args import default_configuration
from comfy.cli_args_types import Configuration
from comfy.client.embedded_comfy_client import Comfy
from comfy.component_model.make_mutable import make_mutable

@@ -143,7 +144,7 @@ def create_failing_workflow():
async def test_panic_on_exception_with_executor(executor_cls, executor_kwargs):
    """Test panic behavior with different executor types"""
    # Create configuration with our test exception in panic_when
    config = Configuration()
    config = default_configuration()
    config.panic_when = [f"{__name__}.UnrecoverableError"]

    # Initialize the specific executor

@@ -179,7 +180,7 @@ async def test_no_panic_when_disabled_with_executor(executor_cls, executor_kwargs):
    """Test no-panic behavior with different executor types"""

    # Create configuration without the exception in panic_when
    config = Configuration()
    config = default_configuration()

    # Initialize the specific executor
    executor = executor_cls(**executor_kwargs)