fix sdpa priorities

doctorpangloss 2025-08-26 14:33:00 -07:00
parent 306fcbaa0e
commit 1e938f5feb
2 changed files with 153 additions and 8 deletions

comfy/ops.py

@@ -26,13 +26,20 @@ from .cli_args import args, PerformanceFeature
 from .execution_context import current_execution_context
 from .float import stochastic_rounding
 
+logger = logging.getLogger(__name__)
+
 scaled_dot_product_attention = None
+
+
+def _scaled_dot_product_attention(q, k, v, *args, **kwargs):
+    return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
+
 try:
     if torch.cuda.is_available():
         from torch.nn.attention import SDPBackend, sdpa_kernel
         import inspect
 
         if "set_priority" in inspect.signature(sdpa_kernel).parameters:
             SDPA_BACKEND_PRIORITY = [
                 SDPBackend.FLASH_ATTENTION,
@@ -42,20 +49,25 @@ try:
                 SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION)
 
-            def scaled_dot_product_attention(q, k, v, *args, **kwargs):
+            def _scaled_dot_product_attention_sdpa(q, k, v, *args, **kwargs):
                 with sdpa_kernel(SDPA_BACKEND_PRIORITY, set_priority=True):
                     return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
-        else:
-            logging.warning("Torch version too old to set sdpa backend priority.")
-except (ModuleNotFoundError, TypeError):
-    logging.warning("Could not set sdpa backend priority.")
 
-    def scaled_dot_product_attention(q, k, v, *args, **kwargs):
-        return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
+            scaled_dot_product_attention = _scaled_dot_product_attention_sdpa
+        else:
+            logger.warning("Torch version too old to set sdpa backend priority, even though you are using CUDA")
+            scaled_dot_product_attention = _scaled_dot_product_attention
+    else:
+        scaled_dot_product_attention = _scaled_dot_product_attention
+except Exception as exc_info:
+    if torch.cuda.is_available():
+        logger.debug("Could not set sdpa backend priority.", exc_info=exc_info)
+    scaled_dot_product_attention = _scaled_dot_product_attention
 
 
 cast_to = model_management.cast_to  # TODO: remove once no more references
-logger = logging.getLogger(__name__)
 
 
 def cast_to_input(weight, input, non_blocking=False, copy=True):
     return model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
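For orientation: every branch above ends up binding the module-level name scaled_dot_product_attention to a callable with the same signature as torch.nn.functional.scaled_dot_product_attention, so call sites do not change. A minimal usage sketch (assuming comfy.ops imports cleanly in your environment; the tensor shapes are illustrative only):

import torch
import comfy.ops

# batch=2, heads=4, sequence length=8, head dim=16 -- arbitrary small shapes
q = torch.randn(2, 4, 8, 16)
k = torch.randn(2, 4, 8, 16)
v = torch.randn(2, 4, 8, 16)

# On CUDA builds where sdpa_kernel accepts set_priority (the tests below gate on
# torch >= 2.6.0), this runs inside sdpa_kernel(SDPA_BACKEND_PRIORITY, set_priority=True);
# everywhere else it is the plain functional fallback.
out = comfy.ops.scaled_dot_product_attention(q, k, v)
assert out.shape == q.shape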

tests/unit/test_sdpa.py (new file)

@@ -0,0 +1,133 @@
import pytest
import torch
import importlib
import sys
from unittest.mock import patch, MagicMock

# For version comparison
from packaging.version import parse as parse_version

# Module under test
import comfy.ops

TORCH_VERSION = parse_version(torch.__version__.split('+')[0])
CUDA_AVAILABLE = torch.cuda.is_available()


@pytest.fixture(autouse=True)
def cleanup_module():
    """Reloads comfy.ops after each test to reset its state."""
    yield
    importlib.reload(comfy.ops)


def test_sdpa_no_cuda():
    """
    Tests that scaled_dot_product_attention falls back to the basic implementation
    when CUDA is not available.
    """
    with patch('torch.cuda.is_available', return_value=False):
        # Reload the module to apply the mock
        importlib.reload(comfy.ops)
        assert comfy.ops.scaled_dot_product_attention is comfy.ops._scaled_dot_product_attention

        # Test functionality
        q = torch.randn(2, 4, 8, 16)
        k = torch.randn(2, 4, 8, 16)
        v = torch.randn(2, 4, 8, 16)
        output = comfy.ops.scaled_dot_product_attention(q, k, v)
        assert output.shape == q.shape


def test_sdpa_old_torch_with_cuda():
    """
    Tests that scaled_dot_product_attention falls back and warns
    on older torch versions that have CUDA but lack 'set_priority' in sdpa_kernel.
    """
    # Mock signature object without 'set_priority'
    mock_signature = MagicMock()
    mock_signature.parameters = {}

    # Mock the logger to capture warnings
    mock_logger = MagicMock()

    # Mock the attention module to prevent import errors on non-CUDA builds
    mock_attention_module = MagicMock()
    mock_attention_module.sdpa_kernel = MagicMock()
    mock_attention_module.SDPBackend = MagicMock()

    with patch('torch.cuda.is_available', return_value=True), \
            patch('inspect.signature', return_value=mock_signature), \
            patch('logging.getLogger', return_value=mock_logger), \
            patch.dict('sys.modules', {'torch.nn.attention': mock_attention_module}):
        importlib.reload(comfy.ops)
        assert comfy.ops.scaled_dot_product_attention is comfy.ops._scaled_dot_product_attention
        mock_logger.warning.assert_called_once_with("Torch version too old to set sdpa backend priority, even though you are using CUDA")

        # Test functionality
        q = torch.randn(2, 4, 8, 16)
        k = torch.randn(2, 4, 8, 16)
        v = torch.randn(2, 4, 8, 16)
        output = comfy.ops.scaled_dot_product_attention(q, k, v)
        assert output.shape == q.shape


def test_sdpa_import_exception():
    """
    Tests that scaled_dot_product_attention falls back if an exception occurs
    during the SDPA setup.
    """
    mock_logger = MagicMock()

    with patch('torch.cuda.is_available', return_value=True), \
            patch('inspect.signature', side_effect=Exception("Test Exception")), \
            patch('logging.getLogger', return_value=mock_logger):
        # Mock the attention module to prevent import errors on non-CUDA builds
        mock_attention_module = MagicMock()
        mock_attention_module.sdpa_kernel = MagicMock()
        mock_attention_module.SDPBackend = MagicMock()
        with patch.dict('sys.modules', {'torch.nn.attention': mock_attention_module}):
            importlib.reload(comfy.ops)
            assert comfy.ops.scaled_dot_product_attention is comfy.ops._scaled_dot_product_attention
            mock_logger.debug.assert_called_once()
            # Check that the log message contains the exception info
            assert "Could not set sdpa backend priority." in mock_logger.debug.call_args[0][0]

            # Test functionality
            q = torch.randn(2, 4, 8, 16)
            k = torch.randn(2, 4, 8, 16)
            v = torch.randn(2, 4, 8, 16)
            output = comfy.ops.scaled_dot_product_attention(q, k, v)
            assert output.shape == q.shape


@pytest.mark.skipif(not CUDA_AVAILABLE, reason="CUDA is not available")
@pytest.mark.skipif(TORCH_VERSION < parse_version("2.6.0"), reason="Requires torch version 2.6.0 or greater")
def test_sdpa_with_cuda_and_priority():
    """
    Tests that the prioritized SDPA implementation is used when CUDA is available
    and the torch version is new enough.
    This is a real test and does not use mocks.
    """
    # Reload to ensure the correct version is picked up based on the actual environment
    importlib.reload(comfy.ops)

    # Check that the correct function is assigned
    assert comfy.ops.scaled_dot_product_attention is not comfy.ops._scaled_dot_product_attention
    assert comfy.ops.scaled_dot_product_attention.__name__ == "_scaled_dot_product_attention_sdpa"

    # Create tensors on CUDA device
    device = torch.device("cuda")
    q = torch.randn(2, 4, 8, 16, device=device, dtype=torch.float16)
    k = torch.randn(2, 4, 8, 16, device=device, dtype=torch.float16)
    v = torch.randn(2, 4, 8, 16, device=device, dtype=torch.float16)

    # Execute the function
    output = comfy.ops.scaled_dot_product_attention(q, k, v)

    # Assertions
    assert output.shape == q.shape
    assert output.device.type == device.type
    assert output.dtype == torch.float16
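The new tests can be run in isolation (assuming pytest and the packaging package are installed) with: pytest tests/unit/test_sdpa.py -v. The three mock-based fallback tests run on any machine; test_sdpa_with_cuda_and_priority is skipped unless CUDA is available and torch is at least 2.6.0, per the skipif markers above.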