diff --git a/comfy/ops.py b/comfy/ops.py
index dbe910a4c..3bdf64a99 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -26,13 +26,20 @@
 from .cli_args import args, PerformanceFeature
 from .execution_context import current_execution_context
 from .float import stochastic_rounding
 
+logger = logging.getLogger(__name__)
+scaled_dot_product_attention = None
+
+
+def _scaled_dot_product_attention(q, k, v, *args, **kwargs):
+    return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
 
 
 try:
     if torch.cuda.is_available():
         from torch.nn.attention import SDPBackend, sdpa_kernel
         import inspect
+
         if "set_priority" in inspect.signature(sdpa_kernel).parameters:
             SDPA_BACKEND_PRIORITY = [
                 SDPBackend.FLASH_ATTENTION,
@@ -42,20 +49,25 @@ try:
 
             SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION)
 
-            def scaled_dot_product_attention(q, k, v, *args, **kwargs):
+
+            def _scaled_dot_product_attention_sdpa(q, k, v, *args, **kwargs):
                 with sdpa_kernel(SDPA_BACKEND_PRIORITY, set_priority=True):
                     return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
-        else:
-            logging.warning("Torch version too old to set sdpa backend priority.")
-except (ModuleNotFoundError, TypeError):
-    logging.warning("Could not set sdpa backend priority.")
-    def scaled_dot_product_attention(q, k, v, *args, **kwargs):
-        return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
+
+            scaled_dot_product_attention = _scaled_dot_product_attention_sdpa
+        else:
+            logger.warning("Torch version too old to set sdpa backend priority, even though you are using CUDA")
+            scaled_dot_product_attention = _scaled_dot_product_attention
+    else:
+        scaled_dot_product_attention = _scaled_dot_product_attention
+except Exception as exc_info:
+    if torch.cuda.is_available():
+        logger.debug("Could not set sdpa backend priority.", exc_info=exc_info)
+    scaled_dot_product_attention = _scaled_dot_product_attention
 
 
 cast_to = model_management.cast_to # TODO: remove once no more references
 
-logger = logging.getLogger(__name__)
 
 def cast_to_input(weight, input, non_blocking=False, copy=True):
     return model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
diff --git a/tests/unit/test_sdpa.py b/tests/unit/test_sdpa.py
new file mode 100644
index 000000000..926160c39
--- /dev/null
+++ b/tests/unit/test_sdpa.py
@@ -0,0 +1,133 @@
+import pytest
+import torch
+import importlib
+import sys
+from unittest.mock import patch, MagicMock
+
+# For version comparison
+from packaging.version import parse as parse_version
+
+# Module under test
+import comfy.ops
+
+TORCH_VERSION = parse_version(torch.__version__.split('+')[0])
+CUDA_AVAILABLE = torch.cuda.is_available()
+
+
+@pytest.fixture(autouse=True)
+def cleanup_module():
+    """Reloads comfy.ops after each test to reset its state."""
+    yield
+    importlib.reload(comfy.ops)
+
+
+def test_sdpa_no_cuda():
+    """
+    Tests that scaled_dot_product_attention falls back to the basic implementation
+    when CUDA is not available.
+    """
+    with patch('torch.cuda.is_available', return_value=False):
+        # Reload the module to apply the mock
+        importlib.reload(comfy.ops)
+
+        assert comfy.ops.scaled_dot_product_attention is comfy.ops._scaled_dot_product_attention
+
+        # Test functionality
+        q = torch.randn(2, 4, 8, 16)
+        k = torch.randn(2, 4, 8, 16)
+        v = torch.randn(2, 4, 8, 16)
+        output = comfy.ops.scaled_dot_product_attention(q, k, v)
+        assert output.shape == q.shape
+
+
+def test_sdpa_old_torch_with_cuda():
+    """
+    Tests that scaled_dot_product_attention falls back and warns
+    on older torch versions that have CUDA but lack 'set_priority' in sdpa_kernel.
+    """
+    # Mock signature object without 'set_priority'
+    mock_signature = MagicMock()
+    mock_signature.parameters = {}
+
+    # Mock the logger to capture warnings
+    mock_logger = MagicMock()
+
+    # Mock the attention module to prevent import errors on non-CUDA builds
+    mock_attention_module = MagicMock()
+    mock_attention_module.sdpa_kernel = MagicMock()
+    mock_attention_module.SDPBackend = MagicMock()
+
+    with patch('torch.cuda.is_available', return_value=True), \
+         patch('inspect.signature', return_value=mock_signature), \
+         patch('logging.getLogger', return_value=mock_logger), \
+         patch.dict('sys.modules', {'torch.nn.attention': mock_attention_module}):
+        importlib.reload(comfy.ops)
+
+        assert comfy.ops.scaled_dot_product_attention is comfy.ops._scaled_dot_product_attention
+        mock_logger.warning.assert_called_once_with("Torch version too old to set sdpa backend priority, even though you are using CUDA")
+
+        # Test functionality
+        q = torch.randn(2, 4, 8, 16)
+        k = torch.randn(2, 4, 8, 16)
+        v = torch.randn(2, 4, 8, 16)
+        output = comfy.ops.scaled_dot_product_attention(q, k, v)
+        assert output.shape == q.shape
+
+
+def test_sdpa_import_exception():
+    """
+    Tests that scaled_dot_product_attention falls back if an exception occurs
+    during the SDPA setup.
+    """
+    mock_logger = MagicMock()
+    with patch('torch.cuda.is_available', return_value=True), \
+         patch('inspect.signature', side_effect=Exception("Test Exception")), \
+         patch('logging.getLogger', return_value=mock_logger):
+        # Mock the attention module to prevent import errors on non-CUDA builds
+        mock_attention_module = MagicMock()
+        mock_attention_module.sdpa_kernel = MagicMock()
+        mock_attention_module.SDPBackend = MagicMock()
+        with patch.dict('sys.modules', {'torch.nn.attention': mock_attention_module}):
+            importlib.reload(comfy.ops)
+
+            assert comfy.ops.scaled_dot_product_attention is comfy.ops._scaled_dot_product_attention
+            mock_logger.debug.assert_called_once()
+            # Check that the log message contains the exception info
+            assert "Could not set sdpa backend priority." in mock_logger.debug.call_args[0][0]
+
+            # Test functionality
+            q = torch.randn(2, 4, 8, 16)
+            k = torch.randn(2, 4, 8, 16)
+            v = torch.randn(2, 4, 8, 16)
+            output = comfy.ops.scaled_dot_product_attention(q, k, v)
+            assert output.shape == q.shape
+
+
+@pytest.mark.skipif(not CUDA_AVAILABLE, reason="CUDA is not available")
+@pytest.mark.skipif(TORCH_VERSION < parse_version("2.6.0"), reason="Requires torch version 2.6.0 or greater")
+def test_sdpa_with_cuda_and_priority():
+    """
+    Tests that the prioritized SDPA implementation is used when CUDA is available
+    and the torch version is new enough.
+    This is a real test and does not use mocks.
+    """
+    # Reload to ensure the correct version is picked up based on the actual environment
+    importlib.reload(comfy.ops)
+
+    # Check that the correct function is assigned
+    assert comfy.ops.scaled_dot_product_attention is not comfy.ops._scaled_dot_product_attention
+    assert comfy.ops.scaled_dot_product_attention.__name__ == "_scaled_dot_product_attention_sdpa"
+
+    # Create tensors on CUDA device
+    device = torch.device("cuda")
+    q = torch.randn(2, 4, 8, 16, device=device, dtype=torch.float16)
+    k = torch.randn(2, 4, 8, 16, device=device, dtype=torch.float16)
+    v = torch.randn(2, 4, 8, 16, device=device, dtype=torch.float16)
+
+    # Execute the function
+    output = comfy.ops.scaled_dot_product_attention(q, k, v)
+
+    # Assertions
+    assert output.shape == q.shape
+    assert output.device.type == device.type
+    assert output.dtype == torch.float16