Add support for flash attention 3

simonri 2025-11-16 16:36:48 +00:00
parent 7d6103325e
commit 603a721405

@@ -32,7 +32,7 @@ except ImportError as e:
 FLASH_ATTENTION_IS_AVAILABLE = False
 try:
-    from flash_attn import flash_attn_func
+    from flash_attn_interface import flash_attn_func
     FLASH_ATTENTION_IS_AVAILABLE = True
 except ImportError:
     if model_management.flash_attention_enabled():
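
Note: the flash-attn 3 (Hopper) build exposes flash_attn_func from the flash_attn_interface module rather than from flash_attn, which is what this hunk switches to. A hedged sketch of an import that prefers the new module and falls back to the flash-attn 2 package; the nested try/except is illustrative and not part of this commit:

FLASH_ATTENTION_IS_AVAILABLE = False
try:
    # flash-attn 3 exposes flash_attn_func via flash_attn_interface
    from flash_attn_interface import flash_attn_func
    FLASH_ATTENTION_IS_AVAILABLE = True
except ImportError:
    try:
        # fall back to the flash-attn 2 package used before this commit
        from flash_attn import flash_attn_func
        FLASH_ATTENTION_IS_AVAILABLE = True
    except ImportError:
        pass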
@@ -565,7 +565,7 @@ try:
     @torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
     def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                            dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
-        return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal)
+        return flash_attn_func(q, k, v, causal=causal)
     @flash_attn_wrapper.register_fake
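
The call drops dropout_p because the flash-attn 3 flash_attn_func does not take a dropout argument (FA3 does not support attention dropout), while the custom-op signature keeps the parameter so existing callers stay source-compatible. For context, a minimal sketch of the fake (meta) implementation that a torch.library custom op registers after the truncated @flash_attn_wrapper.register_fake line above, assuming the wrapper returns a single tensor shaped like q; the function name is illustrative:

@flash_attn_wrapper.register_fake
def flash_attn_fake(q, k, v, dropout_p=0.0, causal=False):
    # Meta implementation used for tracing/torch.compile: only shape and dtype
    # matter here, so return an empty tensor with the same shape as q.
    return q.new_empty(q.shape)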