Fix transformers 5 compatibility

doctorpangloss 2025-12-12 12:11:25 -08:00
parent ed4f92279c
commit b309e0424e
23 changed files with 421 additions and 56 deletions

View File

@ -306,19 +306,28 @@ choco install -y vcredist2010 vcredist2013 vcredist140
Then, visit [NVIDIA.com's CUDA Toolkit Download Page](https://developer.nvidia.com/cuda-12-6-0-download-archive?target_os=Windows&target_arch=x86_64&target_version=Server2022&target_type=exe_network) and download and install the CUDA Toolkit. Verify it is correctly installed by running `nvcc --version`.
You are now ready to install Sage Attention 2 and Flash Attention.
You are now ready to install Sage Attention 2.
### Linux
```shell
uv pip install --no-build-isolation "sageattention@git+https://github.com/thu-ml/SageAttention.git"
uv pip install --no-build-isolation flash_attn
```
### Windows
Run this PowerShell script to install the correct version of Sage Attention for your installed PyTorch version:
```powershell
uv pip install --find-links https://raw.githubusercontent.com/hiddenswitch/ComfyUI/main/pypi/sageattention_index.html sageattention
# Anchor the match so packages such as open-clip-torch or torchvision are not picked up
$torch_version = (uv pip freeze | Select-String "^torch==").ToString().Trim()
$cuda_version = $torch_version -replace ".*(cu\d+).*", "`$1"
if ($torch_version -match "\+cu") {
    # Strip the package name and the local "+cuXXX" suffix so the wheel specifier contains a single "+" segment
    $v = ($torch_version -replace "torch==", "") -replace "\+.*", ""
    $package_specifier = "sageattention==2.2.0+$($cuda_version)torch$v"
    uv pip install --find-links https://raw.githubusercontent.com/hiddenswitch/ComfyUI/main/pypi/sageattention_index.html $package_specifier
} else {
    Write-Host "Could not determine CUDA version from torch version: $torch_version"
}
```
To start ComfyUI with it:
@ -333,6 +342,26 @@ uv run comfyui --use-sage-attention
![with_pytorch_attention](./docs/assets/with_pytorch_attention.webp)
**With PyTorch Attention**
## Flash Attention
Flash Attention 2 is supported on Linux only.
```shell
uv pip install --no-build-isolation flash_attn
```
To start ComfyUI with it:
```shell
uv run comfyui --use-flash-attention
```
![with_sage_attention.webp](./docs/assets/with_sage_attention.webp)
**With SageAttention**
![with_pytorch_attention](./docs/assets/with_pytorch_attention.webp)
**With PyTorch Attention**
## Cosmos Prompt Upsampling
The Cosmos prompt "upsampler," a fine-tune of Mistral-Nemo-12b, rewrites Cosmos prompts in the narrative style that NVIDIA's captioner used for Cosmos's training data, significantly improving generation results.

View File

@ -32,6 +32,7 @@ from ..distributed.executors import ContextVarExecutor
from ..distributed.history import History
from ..distributed.process_pool_executor import ProcessPoolExecutor
from ..distributed.server_stub import ServerStub
from ..component_model.configuration import MODEL_MANAGEMENT_ARGS, requires_process_pool_executor
_prompt_executor = threading.local()
@ -147,50 +148,181 @@ def _cleanup(invalidate_nodes=True):
class Comfy:
"""
This manages a single-threaded executor to run long-running or blocking workflows
asynchronously without blocking the asyncio event loop. It initializes a PromptExecutor
in a dedicated thread for executing prompts and handling server-stub communications.
Example usage:
A client for running ComfyUI workflows within a Python application.
Asynchronous (non-blocking) usage with async-await:
```
# Write a workflow, or enable Dev Mode in the UI settings, then Save (API Format) to get the workflow in your
# workspace.
This client allows you to execute ComfyUI workflows (in API JSON format) programmatically.
It manages the execution environment, including model loading and resource cleanup.
### Configuration and Executors
ComfyUI relies on global state for model management (e.g., loaded models in VRAM). To handle this safely, `Comfy`
executes workflows using one of two strategies based on your `configuration`:
1. **ContextVarExecutor (Default)**: Runs in a thread pool within the current process.
- **Pros**: Efficient, low overhead.
- **Cons**: Modifies global state in the current process.
- **Use Case**: Standard workflows where you are happy with the default ComfyUI settings or with sharing state in the current process.
2. **ProcessPoolExecutor**: Runs in a separate process.
- **Pros**: Complete isolation. Configuration changes (like `lowvram`) do not affect the main process.
- **Cons**: Higher overhead (process startup).
- **Use Case**: Required when `configuration` overrides arguments that affect global model management state.
These arguments include: `lowvram`, `highvram`, `cpu`, `gpu_only`, `deterministic`, `directml`,
various `fp8`/`fp16`/`bf16` settings, and attention optimizations (e.g., `use_flash_attention`).
The client automatically selects `ProcessPoolExecutor` if you provide a `configuration` that modifies any of these
global settings, unless you explicitly pass an `executor`.
### Parameters
- **configuration** (`Optional[Configuration]`): A dictionary of arguments to override defaults.
See `comfy.cli_args_types.Configuration`.
Example: `{"lowvram": True}` or `{"gpu_only": True}`.
- **progress_handler** (`Optional[ExecutorToClientProgress]`): callback handler for progress updates and previews.
- **max_workers** (`int`): Maximum number of concurrent workflows (default: 1).
- **executor** (`Optional[Union[Executor, str]]`): Explicitly define the executor to use (see the sketch after this list).
- Pass an instance of `ProcessPoolExecutor` or `ContextVarExecutor`.
- Pass the string `"ProcessPoolExecutor"` or `"ContextVarExecutor"` to force initialization of that type.
- If `None` (default), the best executor is chosen based on `configuration`.
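For example, a minimal sketch of forcing process isolation explicitly (using the `prompt_dict` defined in Example 1 below; the configuration itself would not require it):
```python
# Force an isolated worker process via the string form of `executor`.
async with Comfy(executor="ProcessPoolExecutor") as client:
    outputs = await client.queue_prompt(prompt_dict)
```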
### Examples
#### 1. Running a Workflow (Basic)
This example executes a simple workflow and prints the path of the saved image.
```python
import asyncio
from comfy.client.embedded_comfy_client import Comfy
# A simple API format workflow (simplified for brevity)
prompt_dict = {
    "1": {"class_type": "KSamplerAdvanced", ...}
    ...
    "3": {
        "class_type": "KSampler",
        "inputs": {
            "seed": 8566257, "steps": 20, "cfg": 8, "sampler_name": "euler",
            "scheduler": "normal", "denoise": 1,
            "model": ["4", 0], "positive": ["6", 0], "negative": ["7", 0],
            "latent_image": ["5", 0]
        }
    },
    "4": {"class_type": "CheckpointLoaderSimple", "inputs": {"ckpt_name": "v1-5-pruned-emaonly.safetensors"}},
    "5": {"class_type": "EmptyLatentImage", "inputs": {"width": 512, "height": 512, "batch_size": 1}},
    "6": {"class_type": "CLIPTextEncode", "inputs": {"text": "masterpiece best quality girl", "clip": ["4", 1]}},
    "7": {"class_type": "CLIPTextEncode", "inputs": {"text": "bad hands", "clip": ["4", 1]}},
    "8": {"class_type": "VAEDecode", "inputs": {"samples": ["3", 0], "vae": ["4", 2]}},
    "9": {"class_type": "SaveImage", "inputs": {"filename_prefix": "ComfyUI_API", "images": ["8", 0]}}
}
# Validate your workflow (the prompt)
from comfy.api.components.schema.prompt import Prompt
prompt = Prompt.validate(prompt_dict)
# Then use the client to run your workflow. This will start, then stop, a local ComfyUI workflow executor.
# It does not connect to a remote server.
async def main():
    async with EmbeddedComfyClient() as client:
        outputs = await client.queue_prompt(prompt)
        print(outputs)
    print("Now that we've exited the with statement, all your VRAM has been cleared from ComfyUI")

if __name__ == "__main__":
    asyncio.run(main())

    # Using default configuration (runs in-process)
    async with Comfy() as client:
        # Queue the prompt and await the result
        outputs = await client.queue_prompt(prompt_dict)

        # Retrieve the output path from the SaveImage node (Node ID "9")
        image_path = outputs["9"]["images"][0]["abs_path"]
        print(f"Image saved to: {image_path}")

# asyncio.run(main())
```
To call this from synchronous (blocking) code, wrap the coroutine with `asyncio.run` or an event loop of your choosing; see the sketch below.
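A minimal sketch, assuming the `prompt` validated above:
```python
import asyncio

def generate_blocking(prompt):
    # Drives the async client to completion from synchronous code.
    async def _run():
        async with Comfy() as client:
            return await client.queue_prompt(prompt)

    return asyncio.run(_run())
```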
#### 2. Using Custom Configuration (Isolated Process)
To run with specific settings like `lowvram`, pass the configuration. This implies `ProcessPoolExecutor`.
```python
async def run_lowvram():
    # This will spawn a new process with lowvram enabled
    async with Comfy(configuration={"lowvram": True}) as client:
        outputs = await client.queue_prompt(prompt_dict)
        print("Finished lowvram generation")
```
#### 3. Programmatically Building Workflows
You can use `GraphBuilder` to construct workflows with a more Pythonic API.
```python
from comfy_execution.graph_utils import GraphBuilder

def build_graph():
    builder = GraphBuilder()
    checkpoint = builder.node("CheckpointLoaderSimple", ckpt_name="v1-5-pruned-emaonly.safetensors")
    latent = builder.node("EmptyLatentImage", width=512, height=512, batch_size=1)
    pos = builder.node("CLIPTextEncode", text="masterpiece", clip=checkpoint.out(1))
    neg = builder.node("CLIPTextEncode", text="bad quality", clip=checkpoint.out(1))
    sampler = builder.node("KSampler",
                           seed=42, steps=20, cfg=8, sampler_name="euler", scheduler="normal", denoise=1,
                           model=checkpoint.out(0), positive=pos.out(0), negative=neg.out(0), latent_image=latent.out(0)
                           )
    vae = builder.node("VAEDecode", samples=sampler.out(0), vae=checkpoint.out(2))
    builder.node("SaveImage", filename_prefix="Generated", images=vae.out(0))
    return builder.finalize()

async def run_builder():
    prompt = build_graph()
    async with Comfy() as client:
        await client.queue_prompt(prompt)
#### 4. Streaming Progress and Previews
To receive real-time progress updates and preview images (for example, step-by-step previews during decoding), queue the prompt with `queue_with_progress` and iterate its progress notifications:
```python
from comfy.component_model.queue_types import BinaryEventTypes

async def run_streaming():
    async with Comfy() as client:
        # Get a task that supports progress iteration
        task = client.queue_with_progress(prompt_dict)

        async for notification in task.progress():
            if notification.event == BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA:
                # 'data' contains the PIL Image and metadata
                image, metadata = notification.data
                print(f"Received preview: {image.size}")
            elif notification.event == "progress":
                print(f"Step: {notification.data['value']}/{notification.data['max']}")

        # Await final result
        result = await task.get()
"""
def __init__(self, configuration: Optional[Configuration] = None, progress_handler: Optional[ExecutorToClientProgress] = None, max_workers: int = 1, executor: ProcessPoolExecutor | ContextVarExecutor | Literal["ProcessPoolExecutor", "ContextVarExecutor"] = None):
self._progress_handler = progress_handler or ServerStub()
self._owns_executor = executor is None or isinstance(executor, str)
if self._owns_executor:
if isinstance(executor, str):
if executor == "ProcessPoolExecutor":
self._executor = ProcessPoolExecutor(max_workers=max_workers)
else:
self._executor = ContextVarExecutor(max_workers=max_workers)
else:
assert not isinstance(executor, str)
self._executor = executor
        self._default_configuration = default_configuration()
        self._configuration = configuration
        need_process_pool = requires_process_pool_executor(configuration)

        if executor is None:
            if need_process_pool:
                self._executor = ProcessPoolExecutor(max_workers=max_workers)
                self._owns_executor = True
            else:
                self._executor = ContextVarExecutor(max_workers=max_workers)
                self._owns_executor = True
        elif isinstance(executor, str):
            self._owns_executor = True
            if executor == "ProcessPoolExecutor":
                self._executor = ProcessPoolExecutor(max_workers=max_workers)
            elif executor == "ContextVarExecutor":
                if need_process_pool:
                    raise ValueError(f"Configuration requires ProcessPoolExecutor but ContextVarExecutor was requested. Configuration keys causing this: {[k for k in MODEL_MANAGEMENT_ARGS if configuration.get(k) != self._default_configuration.get(k)]}")
                self._executor = ContextVarExecutor(max_workers=max_workers)
            else:
                raise ValueError(f"Unknown executor type string: {executor}")
        else:
            # Executor instance passed
            self._owns_executor = False
            self._executor = executor
            if need_process_pool and not isinstance(executor, ProcessPoolExecutor):
                raise ValueError(f"Configuration requires ProcessPoolExecutor but {type(executor).__name__} was passed. Configuration keys causing this: {[k for k in MODEL_MANAGEMENT_ARGS if configuration.get(k) != self._default_configuration.get(k)]}")

        self._is_running = False
        self._task_count_lock = RLock()
        self._task_count = 0

View File

@ -0,0 +1,66 @@
from __future__ import annotations
from ..cli_args_types import Configuration
from ..cli_args import default_configuration
MODEL_MANAGEMENT_ARGS = {
"deterministic",
"directml",
"cpu",
"disable_xformers",
# todo: this is the default, so it will be omitted
# "use_pytorch_cross_attention",
"use_split_cross_attention",
"use_quad_cross_attention",
"supports_fp8_compute",
"fast",
"lowvram",
"novram",
"highvram",
"gpu_only",
"force_fp32",
"force_fp16",
"force_bf16",
"reserve_vram",
"disable_smart_memory",
"disable_ipex_optimize",
"disable_pinned_memory",
"async_offload",
"disable_async_offload",
"force_non_blocking",
"force_channels_last",
"fp32_unet",
"fp64_unet",
"bf16_unet",
"fp16_unet",
"fp8_e4m3fn_unet",
"fp8_e5m2_unet",
"fp8_e8m0fnu_unet",
"fp8_e4m3fn_text_enc",
"fp8_e5m2_text_enc",
"fp16_text_enc",
"bf16_text_enc",
"fp32_text_enc",
"cpu_vae",
"fp16_vae",
"bf16_vae",
"fp32_vae",
"force_upcast_attention",
"use_sage_attention",
"use_flash_attention",
}
def requires_process_pool_executor(configuration: Configuration | None) -> bool:
    if configuration is None:
        return False

    default = default_configuration()
    for key in MODEL_MANAGEMENT_ARGS:
        # Check if key is in configuration and differs from default
        if key in configuration:
            val = configuration[key]
            # Use equality check, handling potential missing keys in default (though default should have them)
            if key not in default or val != default[key]:
                return True
    return False
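A quick illustration of the intended behavior, assuming the stock defaults leave these flags disabled (the unrelated key mirrors the one used in the tests further below):
```python
# Overriding a model-management flag forces process isolation; unrelated keys and None do not.
assert requires_process_pool_executor({"lowvram": True}) is True
assert requires_process_pool_executor({"some_random_ui_setting": "value"}) is False
assert requires_process_pool_executor(None) is False
```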

View File

@ -987,7 +987,13 @@ def gguf_tekken_tokenizer_loader(path, temb_shape):
    logger.info("Attempting to recreate tekken tokenizer from GGUF file metadata...")
    import json
    import base64
    from transformers.convert_slow_tokenizer import bytes_to_unicode
    try:
        from transformers.convert_slow_tokenizer import bytes_to_unicode
    except ImportError:
        try:
            from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
        except ImportError:
            raise NotImplementedError("Could not import bytes_to_unicode from transformers")
    reader = gguf.GGUFReader(path)
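For context, a rough sketch of how `bytes_to_unicode` is typically used when rebuilding a byte-level BPE vocabulary from base64-encoded tekken tokens (the variable names here are illustrative, not taken from this function):
```python
byte_encoder = bytes_to_unicode()           # maps raw byte values 0-255 to printable unicode characters
token_bytes = base64.b64decode(b64_token)   # b64_token: one vocab entry read from the GGUF metadata
token_str = "".join(byte_encoder[b] for b in token_bytes)
```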

View File

@ -12,7 +12,10 @@ from pathlib import Path
from typing import Tuple, Sequence, TypeVar, Callable, Optional, Union
import torch
from transformers import CLIPTokenizer, PreTrainedTokenizerBase
try:
    from transformers import CLIPTokenizer, PreTrainedTokenizerBase
except ImportError:
    # Only the slow CLIPTokenizer may be unavailable; PreTrainedTokenizerBase still comes from transformers
    from transformers import PreTrainedTokenizerBase
    from .transformers_compat import CLIPTokenizerFast as CLIPTokenizer
from . import clip_model
from . import model_management

View File

@ -1,4 +1,4 @@
from transformers import T5TokenizerFast
from ..transformers_compat import T5TokenizerFast
from .t5 import T5
from .. import sd1_clip

View File

@ -1,7 +1,7 @@
import copy
import torch
from transformers import T5TokenizerFast, LlamaTokenizerFast
from ..transformers_compat import T5TokenizerFast, LlamaTokenizerFast
from .sd3_clip import T5XXLModel
from .. import sd1_clip, model_management

View File

@ -1,4 +1,4 @@
from transformers import T5TokenizerFast
from ..transformers_compat import T5TokenizerFast
from .. import sd1_clip
from . import sd3_clip

View File

@ -1,6 +1,6 @@
import re
from transformers import ByT5Tokenizer
from ..transformers_compat import ByT5Tokenizer
from .llama import Qwen25_7BVLI
from .qwen_image import QwenImageTokenizer, QwenImageTEModel

View File

@ -1,6 +1,6 @@
import torch
import numbers
from transformers import LlamaTokenizerFast
from ..transformers_compat import LlamaTokenizerFast
from .hunyuan_image import HunyuanImageTokenizer
from .llama import Llama2

View File

@ -1,7 +1,7 @@
import copy
import torch
from transformers import BertTokenizer
from ..transformers_compat import BertTokenizerFast as BertTokenizer
from .bert import BertModel
from .spiece_tokenizer import SPieceTokenizer

View File

@ -1,4 +1,4 @@
from transformers import T5TokenizerFast
from ..transformers_compat import T5TokenizerFast
from .genmo import mochi_te
from .. import sd1_clip

View File

@ -1,4 +1,4 @@
from transformers import Qwen2Tokenizer
from ..transformers_compat import Qwen2Tokenizer
from .llama import Qwen25_3B
from .. import sd1_clip

View File

@ -1,8 +1,7 @@
import numbers
import torch
from transformers import Qwen2Tokenizer
from ..transformers_compat import Qwen2Tokenizer
from . import llama
from .. import sd1_clip
from ..component_model import files

View File

@ -1,4 +1,4 @@
from transformers import T5TokenizerFast
from ..transformers_compat import T5TokenizerFast
from . import sd3_clip
from .. import sd1_clip
@ -40,7 +40,6 @@ class PixArtTokenizer(sd1_clip.SD1Tokenizer):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
def pixart_te(dtype_t5=None, t5_quantization_metadata=None):
class PixArtTEModel_(PixArtT5XXL):
def __init__(self, device="cpu", dtype=None, model_options=None):

View File

@ -1,6 +1,6 @@
import numbers
import torch
from transformers import Qwen2Tokenizer
from ..transformers_compat import Qwen2Tokenizer
from .llama import Qwen25_7BVLI
from .. import sd1_clip

View File

@ -1,4 +1,4 @@
from transformers import T5TokenizerFast
from ..transformers_compat import T5TokenizerFast
from .t5 import T5
from .. import sd1_clip

View File

@ -3,7 +3,7 @@ import logging
import comfy.utils
import torch
from transformers import T5TokenizerFast
from ..transformers_compat import T5TokenizerFast
from .t5 import T5
from .. import sd1_clip, model_management

View File

@ -1,4 +1,4 @@
from transformers import Qwen2Tokenizer
from ..transformers_compat import Qwen2Tokenizer
from . import llama
from .. import sd1_clip

View File

@ -0,0 +1,56 @@
try:
    from transformers import T5TokenizerFast
except (ImportError, ModuleNotFoundError):
    from transformers import T5Tokenizer as T5TokenizerFast

try:
    from transformers import LlamaTokenizerFast
except (ImportError, ModuleNotFoundError):
    from transformers import LlamaTokenizer as LlamaTokenizerFast

try:
    from transformers import CLIPTokenizerFast
except (ImportError, ModuleNotFoundError):
    from transformers import CLIPTokenizer as CLIPTokenizerFast

try:
    from transformers import GPT2TokenizerFast
except (ImportError, ModuleNotFoundError):
    from transformers import GPT2Tokenizer as GPT2TokenizerFast

try:
    from transformers import BertTokenizerFast
except (ImportError, ModuleNotFoundError):
    from transformers import BertTokenizer as BertTokenizerFast

try:
    from transformers import Qwen2TokenizerFast
except (ImportError, ModuleNotFoundError):
    try:
        from transformers import Qwen2Tokenizer as Qwen2TokenizerFast
    except (ImportError, ModuleNotFoundError):
        # Fallback if neither exists, primarily for earlier versions or specific environments
        Qwen2TokenizerFast = None

# Alias Qwen2Tokenizer to the "Fast" version we found/aliased, as we might use either name
Qwen2Tokenizer = Qwen2TokenizerFast

try:
    from transformers import ByT5TokenizerFast
except ImportError:
    try:
        from transformers import ByT5Tokenizer as ByT5TokenizerFast
    except (ImportError, ModuleNotFoundError):
        ByT5TokenizerFast = None

ByT5Tokenizer = ByT5TokenizerFast

__all__ = [
    "T5TokenizerFast",
    "LlamaTokenizerFast",
    "CLIPTokenizerFast",
    "GPT2TokenizerFast",
    "BertTokenizerFast",
    "Qwen2Tokenizer",
    "ByT5Tokenizer",
]
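The repeated `try`/`except` pattern above could also be written with a small helper; a sketch of that alternative (the `_first_available` name is hypothetical, not part of this module):
```python
import importlib

def _first_available(*names):
    """Return the first transformers class that exists under one of the given names, else None."""
    transformers = importlib.import_module("transformers")
    for name in names:
        cls = getattr(transformers, name, None)
        if cls is not None:
            return cls
    return None

# e.g. T5TokenizerFast = _first_available("T5TokenizerFast", "T5Tokenizer")
```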

View File

@ -27,7 +27,7 @@ dependencies = [
"torchsde>=0.2.6",
"einops>=0.6.0",
"open-clip-torch>=2.24.0",
"transformers>=4.57.1,<5",
"transformers>=4.57.3",
"tokenizers>=0.13.3",
"sentencepiece",
"peft>=0.10.0",

View File

@ -0,0 +1,74 @@
import pytest
from comfy.client.embedded_comfy_client import Comfy
from comfy.distributed.process_pool_executor import ProcessPoolExecutor
from comfy.distributed.executors import ContextVarExecutor
@pytest.mark.asyncio
async def test_comfy_defaults():
    """Test that default initialization uses ContextVarExecutor."""
    client = Comfy()
    assert isinstance(client._executor, ContextVarExecutor)
    assert client._owns_executor


@pytest.mark.asyncio
async def test_comfy_config_triggers_process_pool():
    """Test that configurations affecting global state trigger ProcessPoolExecutor."""
    # "lowvram" is in MODEL_MANAGEMENT_ARGS
    client = Comfy(configuration={"lowvram": True})
    assert isinstance(client._executor, ProcessPoolExecutor)
    assert client._owns_executor


@pytest.mark.asyncio
async def test_comfy_config_unrelated_does_not_trigger():
    """Test that unrelated configuration keys do NOT trigger ProcessPoolExecutor."""
    # "some_random_ui_setting" is likely not in MODEL_MANAGEMENT_ARGS
    client = Comfy(configuration={"some_random_ui_setting": "value"})
    assert isinstance(client._executor, ContextVarExecutor)


@pytest.mark.asyncio
async def test_comfy_explicit_executor_string():
    """Test explicitly requesting an executor by string."""
    client = Comfy(executor="ProcessPoolExecutor")
    assert isinstance(client._executor, ProcessPoolExecutor)
    assert client._owns_executor

    client2 = Comfy(executor="ContextVarExecutor")
    assert isinstance(client2._executor, ContextVarExecutor)
    assert client2._owns_executor


@pytest.mark.asyncio
async def test_comfy_explicit_executor_instance():
    """Test passing an executor instance."""
    executor = ContextVarExecutor(max_workers=1)
    client = Comfy(executor=executor)
    assert client._executor is executor
    assert not client._owns_executor


@pytest.mark.asyncio
async def test_comfy_mismatch_string_raises():
    """Test that valid config requiring ProcessPoolExecutor raises error if ContextVarExecutor is forced via string."""
    with pytest.raises(ValueError, match="Configuration requires ProcessPoolExecutor"):
        Comfy(configuration={"lowvram": True}, executor="ContextVarExecutor")


@pytest.mark.asyncio
async def test_comfy_mismatch_instance_raises():
    """Test that valid config requiring ProcessPoolExecutor raises error if ContextVarExecutor instance is passed."""
    executor = ContextVarExecutor(max_workers=1)
    with pytest.raises(ValueError, match="Configuration requires ProcessPoolExecutor"):
        Comfy(configuration={"lowvram": True}, executor=executor)


@pytest.mark.asyncio
async def test_comfy_context_manager():
    """Test the async context manager behavior."""
    async with Comfy() as client:
        assert client.is_running
        assert isinstance(client._executor, ContextVarExecutor)
    assert not client.is_running

View File

@ -6,6 +6,7 @@ import pytest
import torch
import pebble.common.types
from comfy.cli_args import default_configuration
from comfy.cli_args_types import Configuration
from comfy.client.embedded_comfy_client import Comfy
from comfy.component_model.make_mutable import make_mutable
@ -143,7 +144,7 @@ def create_failing_workflow():
async def test_panic_on_exception_with_executor(executor_cls, executor_kwargs):
"""Test panic behavior with different executor types"""
# Create configuration with our test exception in panic_when
config = Configuration()
config = default_configuration()
config.panic_when = [f"{__name__}.UnrecoverableError"]
# Initialize the specific executor
@ -179,7 +180,7 @@ async def test_no_panic_when_disabled_with_executor(executor_cls, executor_kwarg
"""Test no-panic behavior with different executor types"""
# Create configuration without the exception in panic_when
config = Configuration()
config = default_configuration()
# Initialize the specific executor
executor = executor_cls(**executor_kwargs)