Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-10 06:10:50 +08:00)
fix sdpa priorities

commit 1e938f5feb (parent 306fcbaa0e)
comfy/ops.py (28 changed lines)
@@ -26,13 +26,20 @@ from .cli_args import args, PerformanceFeature
 from .execution_context import current_execution_context
 from .float import stochastic_rounding

 logger = logging.getLogger(__name__)

+scaled_dot_product_attention = None
+
+
+def _scaled_dot_product_attention(q, k, v, *args, **kwargs):
+    return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
+
+
 try:
     if torch.cuda.is_available():
         from torch.nn.attention import SDPBackend, sdpa_kernel
         import inspect

         if "set_priority" in inspect.signature(sdpa_kernel).parameters:
             SDPA_BACKEND_PRIORITY = [
                 SDPBackend.FLASH_ATTENTION,
@@ -42,20 +49,25 @@ try:
             SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION)

-            def scaled_dot_product_attention(q, k, v, *args, **kwargs):
+
+            def _scaled_dot_product_attention_sdpa(q, k, v, *args, **kwargs):
                 with sdpa_kernel(SDPA_BACKEND_PRIORITY, set_priority=True):
                     return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
-        else:
-            logging.warning("Torch version too old to set sdpa backend priority.")
-except (ModuleNotFoundError, TypeError):
-    logging.warning("Could not set sdpa backend priority.")
-
-    def scaled_dot_product_attention(q, k, v, *args, **kwargs):
-        return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
+
+            scaled_dot_product_attention = _scaled_dot_product_attention_sdpa
+        else:
+            logger.warning("Torch version too old to set sdpa backend priority, even though you are using CUDA")
+            scaled_dot_product_attention = _scaled_dot_product_attention
+    else:
+        scaled_dot_product_attention = _scaled_dot_product_attention
+except Exception as exc_info:
+    if torch.cuda.is_available():
+        logger.debug("Could not set sdpa backend priority.", exc_info=exc_info)
+    scaled_dot_product_attention = _scaled_dot_product_attention

 cast_to = model_management.cast_to # TODO: remove once no more references

 logger = logging.getLogger(__name__)

 def cast_to_input(weight, input, non_blocking=False, copy=True):
     return model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
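The core of the change is that the prioritized SDPA path is now selected only when torch.nn.attention.sdpa_kernel accepts set_priority, with _scaled_dot_product_attention kept as the fallback in every other branch. The snippet below is a minimal standalone sketch of that selection pattern, not ComfyUI code: the backend ordering is illustrative (the full SDPA_BACKEND_PRIORITY list is truncated in the hunk above), and it only assumes a torch build that exposes torch.nn.attention on versions new enough to support set_priority.

import inspect
import torch


def _sdpa_basic(q, k, v, *args, **kwargs):
    # Plain fallback: let torch choose its default SDPA backend order.
    return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)


sdpa = _sdpa_basic  # used unless a prioritized path can be set up below

try:
    if torch.cuda.is_available():
        from torch.nn.attention import SDPBackend, sdpa_kernel

        # Only torch builds whose sdpa_kernel accepts set_priority can honor an ordering.
        if "set_priority" in inspect.signature(sdpa_kernel).parameters:
            # Illustrative ordering; the real list in comfy/ops.py may differ.
            priority = [
                SDPBackend.CUDNN_ATTENTION,
                SDPBackend.FLASH_ATTENTION,
                SDPBackend.EFFICIENT_ATTENTION,
                SDPBackend.MATH,
            ]

            def _sdpa_prioritized(q, k, v, *args, **kwargs):
                # Ask torch to try the backends in the order given above.
                with sdpa_kernel(priority, set_priority=True):
                    return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)

            sdpa = _sdpa_prioritized
except Exception:
    # Any failure (missing module, older API, etc.) leaves the plain fallback in place.
    pass

if __name__ == "__main__":
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    q = k = v = torch.randn(1, 4, 8, 16, device=dev)
    print(sdpa(q, k, v).shape)  # torch.Size([1, 4, 8, 16])

This wrapper selection is what the new unit tests below assert is (or is not) in effect, depending on the environment.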
tests/unit/test_sdpa.py (new file, 133 lines)
@@ -0,0 +1,133 @@
import pytest
import torch
import importlib
import sys
from unittest.mock import patch, MagicMock

# For version comparison
from packaging.version import parse as parse_version

# Module under test
import comfy.ops

TORCH_VERSION = parse_version(torch.__version__.split('+')[0])
CUDA_AVAILABLE = torch.cuda.is_available()


@pytest.fixture(autouse=True)
def cleanup_module():
    """Reloads comfy.ops after each test to reset its state."""
    yield
    importlib.reload(comfy.ops)


def test_sdpa_no_cuda():
    """
    Tests that scaled_dot_product_attention falls back to the basic implementation
    when CUDA is not available.
    """
    with patch('torch.cuda.is_available', return_value=False):
        # Reload the module to apply the mock
        importlib.reload(comfy.ops)

        assert comfy.ops.scaled_dot_product_attention is comfy.ops._scaled_dot_product_attention

        # Test functionality
        q = torch.randn(2, 4, 8, 16)
        k = torch.randn(2, 4, 8, 16)
        v = torch.randn(2, 4, 8, 16)
        output = comfy.ops.scaled_dot_product_attention(q, k, v)
        assert output.shape == q.shape


def test_sdpa_old_torch_with_cuda():
    """
    Tests that scaled_dot_product_attention falls back and warns
    on older torch versions that have CUDA but lack 'set_priority' in sdpa_kernel.
    """
    # Mock signature object without 'set_priority'
    mock_signature = MagicMock()
    mock_signature.parameters = {}

    # Mock the logger to capture warnings
    mock_logger = MagicMock()

    # Mock the attention module to prevent import errors on non-CUDA builds
    mock_attention_module = MagicMock()
    mock_attention_module.sdpa_kernel = MagicMock()
    mock_attention_module.SDPBackend = MagicMock()

    with patch('torch.cuda.is_available', return_value=True), \
         patch('inspect.signature', return_value=mock_signature), \
         patch('logging.getLogger', return_value=mock_logger), \
         patch.dict('sys.modules', {'torch.nn.attention': mock_attention_module}):
        importlib.reload(comfy.ops)

        assert comfy.ops.scaled_dot_product_attention is comfy.ops._scaled_dot_product_attention
        mock_logger.warning.assert_called_once_with("Torch version too old to set sdpa backend priority, even though you are using CUDA")

        # Test functionality
        q = torch.randn(2, 4, 8, 16)
        k = torch.randn(2, 4, 8, 16)
        v = torch.randn(2, 4, 8, 16)
        output = comfy.ops.scaled_dot_product_attention(q, k, v)
        assert output.shape == q.shape


def test_sdpa_import_exception():
    """
    Tests that scaled_dot_product_attention falls back if an exception occurs
    during the SDPA setup.
    """
    mock_logger = MagicMock()
    with patch('torch.cuda.is_available', return_value=True), \
         patch('inspect.signature', side_effect=Exception("Test Exception")), \
         patch('logging.getLogger', return_value=mock_logger):
        # Mock the attention module to prevent import errors on non-CUDA builds
        mock_attention_module = MagicMock()
        mock_attention_module.sdpa_kernel = MagicMock()
        mock_attention_module.SDPBackend = MagicMock()
        with patch.dict('sys.modules', {'torch.nn.attention': mock_attention_module}):
            importlib.reload(comfy.ops)

            assert comfy.ops.scaled_dot_product_attention is comfy.ops._scaled_dot_product_attention
            mock_logger.debug.assert_called_once()
            # Check that the log message contains the exception info
            assert "Could not set sdpa backend priority." in mock_logger.debug.call_args[0][0]

            # Test functionality
            q = torch.randn(2, 4, 8, 16)
            k = torch.randn(2, 4, 8, 16)
            v = torch.randn(2, 4, 8, 16)
            output = comfy.ops.scaled_dot_product_attention(q, k, v)
            assert output.shape == q.shape


@pytest.mark.skipif(not CUDA_AVAILABLE, reason="CUDA is not available")
@pytest.mark.skipif(TORCH_VERSION < parse_version("2.6.0"), reason="Requires torch version 2.6.0 or greater")
def test_sdpa_with_cuda_and_priority():
    """
    Tests that the prioritized SDPA implementation is used when CUDA is available
    and the torch version is new enough.
    This is a real test and does not use mocks.
    """
    # Reload to ensure the correct version is picked up based on the actual environment
    importlib.reload(comfy.ops)

    # Check that the correct function is assigned
    assert comfy.ops.scaled_dot_product_attention is not comfy.ops._scaled_dot_product_attention
    assert comfy.ops.scaled_dot_product_attention.__name__ == "_scaled_dot_product_attention_sdpa"

    # Create tensors on CUDA device
    device = torch.device("cuda")
    q = torch.randn(2, 4, 8, 16, device=device, dtype=torch.float16)
    k = torch.randn(2, 4, 8, 16, device=device, dtype=torch.float16)
    v = torch.randn(2, 4, 8, 16, device=device, dtype=torch.float16)

    # Execute the function
    output = comfy.ops.scaled_dot_product_attention(q, k, v)

    # Assertions
    assert output.shape == q.shape
    assert output.device.type == device.type
    assert output.dtype == torch.float16
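Assuming pytest and packaging are installed (both are implied by the imports above), running pytest tests/unit/test_sdpa.py from the repository root should exercise all four tests: the first three use mocks and run on any machine, while test_sdpa_with_cuda_and_priority is skipped automatically unless CUDA is available and torch is at least 2.6.0.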