import pytest
import torch
import importlib
from unittest.mock import patch, MagicMock

# For version comparison
from packaging.version import parse as parse_version

# Module under test
import comfy.ops

TORCH_VERSION = parse_version(torch.__version__.split('+')[0])
CUDA_AVAILABLE = torch.cuda.is_available()
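
# The tests below assume comfy.ops selects its SDPA implementation once, at
# import time. A rough sketch of that selection logic, reconstructed from the
# assertions in this file (names and logger details are assumptions, not the
# actual comfy.ops source):
#
#     if torch.cuda.is_available():
#         try:
#             from torch.nn.attention import sdpa_kernel
#             if "set_priority" in inspect.signature(sdpa_kernel).parameters:
#                 scaled_dot_product_attention = _scaled_dot_product_attention_sdpa
#             else:
#                 logger.warning("Torch version too old to set sdpa backend "
#                                "priority, even though you are using CUDA")
#                 scaled_dot_product_attention = _scaled_dot_product_attention
#         except Exception:
#             scaled_dot_product_attention = _scaled_dot_product_attention
#     else:
#         scaled_dot_product_attention = _scaled_dot_product_attention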


@pytest.fixture(autouse=True)
def cleanup_module():
    """Reloads comfy.ops after each test to reset its state."""
    yield
    importlib.reload(comfy.ops)


def test_sdpa_no_cuda():
    """
    Tests that scaled_dot_product_attention falls back to the basic implementation
    when CUDA is not available.
    """
    with patch('torch.cuda.is_available', return_value=False):
        # Reload the module to apply the mock
        importlib.reload(comfy.ops)

        assert comfy.ops.scaled_dot_product_attention is comfy.ops._scaled_dot_product_attention

        # Test functionality
        q = torch.randn(2, 4, 8, 16)
        k = torch.randn(2, 4, 8, 16)
        v = torch.randn(2, 4, 8, 16)
        output = comfy.ops.scaled_dot_product_attention(q, k, v)
        assert output.shape == q.shape


def test_sdpa_old_torch_with_cuda():
    """
    Tests that scaled_dot_product_attention falls back and warns
    on older torch versions that have CUDA but lack 'set_priority' in sdpa_kernel.
    """
    # Mock signature object without 'set_priority'
    mock_signature = MagicMock()
    mock_signature.parameters = {}

    # Mock the logger to capture warnings
    mock_logger = MagicMock()

    # Mock the attention module to prevent import errors on non-CUDA builds
    mock_attention_module = MagicMock()
    mock_attention_module.sdpa_kernel = MagicMock()
    mock_attention_module.SDPBackend = MagicMock()

    with patch('torch.cuda.is_available', return_value=True), \
         patch('inspect.signature', return_value=mock_signature), \
         patch('logging.getLogger', return_value=mock_logger), \
         patch.dict('sys.modules', {'torch.nn.attention': mock_attention_module}):
        importlib.reload(comfy.ops)

        assert comfy.ops.scaled_dot_product_attention is comfy.ops._scaled_dot_product_attention
        mock_logger.warning.assert_called_once_with("Torch version too old to set sdpa backend priority, even though you are using CUDA")

        # Test functionality
        q = torch.randn(2, 4, 8, 16)
        k = torch.randn(2, 4, 8, 16)
        v = torch.randn(2, 4, 8, 16)
        output = comfy.ops.scaled_dot_product_attention(q, k, v)
        assert output.shape == q.shape


def test_sdpa_import_exception():
    """
    Tests that scaled_dot_product_attention falls back if an exception occurs
    during the SDPA setup.
    """
    mock_logger = MagicMock()
    with patch('torch.cuda.is_available', return_value=True), \
         patch('inspect.signature', side_effect=Exception("Test Exception")), \
         patch('logging.getLogger', return_value=mock_logger):
        # Mock the attention module to prevent import errors on non-CUDA builds
        mock_attention_module = MagicMock()
        mock_attention_module.sdpa_kernel = MagicMock()
        mock_attention_module.SDPBackend = MagicMock()
        with patch.dict('sys.modules', {'torch.nn.attention': mock_attention_module}):
            importlib.reload(comfy.ops)

            assert comfy.ops.scaled_dot_product_attention is comfy.ops._scaled_dot_product_attention

            # Test functionality
            q = torch.randn(2, 4, 8, 16)
            k = torch.randn(2, 4, 8, 16)
            v = torch.randn(2, 4, 8, 16)
            output = comfy.ops.scaled_dot_product_attention(q, k, v)
            assert output.shape == q.shape


@pytest.mark.skipif(not CUDA_AVAILABLE, reason="CUDA is not available")
@pytest.mark.skipif(TORCH_VERSION < parse_version("2.6.0"), reason="Requires torch version 2.6.0 or greater")
def test_sdpa_with_cuda_and_priority():
    """
    Tests that the prioritized SDPA implementation is used when CUDA is available
    and the torch version is new enough.
    This is a real test and does not use mocks.
    """
    # Reload to ensure the correct version is picked up based on the actual environment
    importlib.reload(comfy.ops)

    # Check that the correct function is assigned
    assert comfy.ops.scaled_dot_product_attention is not comfy.ops._scaled_dot_product_attention
    assert comfy.ops.scaled_dot_product_attention.__name__ == "_scaled_dot_product_attention_sdpa"

    # Create tensors on the CUDA device
    device = torch.device("cuda")
    q = torch.randn(2, 4, 8, 16, device=device, dtype=torch.float16)
    k = torch.randn(2, 4, 8, 16, device=device, dtype=torch.float16)
    v = torch.randn(2, 4, 8, 16, device=device, dtype=torch.float16)

    # Execute the function
    output = comfy.ops.scaled_dot_product_attention(q, k, v)

    # Assertions
    assert output.shape == q.shape
    assert output.device.type == device.type
    assert output.dtype == torch.float16
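

# An illustrative addition, not part of the original suite: both code paths
# compute softmax(q @ k^T / sqrt(d)) @ v and should differ only in backend
# selection, so the prioritized implementation ought to agree numerically with
# the basic fallback. The fp16 tolerances below are an assumption, not a spec.
@pytest.mark.skipif(not CUDA_AVAILABLE, reason="CUDA is not available")
@pytest.mark.skipif(TORCH_VERSION < parse_version("2.6.0"), reason="Requires torch version 2.6.0 or greater")
def test_sdpa_matches_fallback():
    importlib.reload(comfy.ops)
    device = torch.device("cuda")
    q = torch.randn(2, 4, 8, 16, device=device, dtype=torch.float16)
    k = torch.randn(2, 4, 8, 16, device=device, dtype=torch.float16)
    v = torch.randn(2, 4, 8, 16, device=device, dtype=torch.float16)
    fast = comfy.ops.scaled_dot_product_attention(q, k, v)
    reference = comfy.ops._scaled_dot_product_attention(q, k, v)
    torch.testing.assert_close(fast, reference, rtol=1e-3, atol=1e-3)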