vivienfanghua 2025-10-28 19:31:23 +08:00
parent 897d2662be
commit 1c5a503715

@@ -628,34 +628,17 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
     return out
 
-try:
-    @torch.library.custom_op("aiter_attention::aiter_flash_attn", mutates_args=())
-    def aiter_flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
-                                 dropout_p: float = 0.0, softmax_scale: Optional[float] = None,
-                                 causal: bool = False, window_size: tuple = (-1, -1),
-                                 bias: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None,
-                                 deterministic: bool = False) -> torch.Tensor:
-        return aiter.flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=softmax_scale,
-                                     causal=causal, window_size=window_size, bias=bias,
-                                     alibi_slopes=alibi_slopes, deterministic=deterministic,
-                                     return_lse=False, return_attn_probs=False,
-                                     cu_seqlens_q=None, cu_seqlens_kv=None)
-
-    @aiter_flash_attn_wrapper.register_fake
-    def aiter_flash_attn_fake(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False,
-                              window_size=(-1, -1), bias=None, alibi_slopes=None, deterministic=False):
-        # Output shape is the same as q
-        return q.new_empty(q.shape)
-except AttributeError as error:
-    AITER_ATTN_ERROR = error
-
-    def aiter_flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
-                                 dropout_p: float = 0.0, softmax_scale: Optional[float] = None,
-                                 causal: bool = False, window_size: tuple = (-1, -1),
-                                 bias: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None,
-                                 deterministic: bool = False) -> torch.Tensor:
-        assert False, f"Could not define aiter_flash_attn_wrapper: {AITER_ATTN_ERROR}"
-
+def aiter_flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+                             dropout_p: float = 0.0, softmax_scale: Optional[float] = None,
+                             causal: bool = False, window_size: tuple = (-1, -1),
+                             bias: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None,
+                             deterministic: bool = False) -> torch.Tensor:
+    """Wrapper for aiter.flash_attn_func to handle its specific parameters"""
+    return aiter.flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=softmax_scale,
+                                 causal=causal, window_size=window_size, bias=bias,
+                                 alibi_slopes=alibi_slopes, deterministic=deterministic,
+                                 return_lse=False, return_attn_probs=False,
+                                 cu_seqlens_q=None, cu_seqlens_kv=None)
+
 @wrap_attn
 def attention_aiter(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
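
Background on the torch.library.custom_op / register_fake pattern that appears in this hunk: custom_op registers an eager Python function as a PyTorch custom operator, so torch.compile treats it as a single opaque call, and register_fake supplies a shape-and-dtype-only implementation that the compiler uses when tracing with fake tensors, which is why aiter_flash_attn_fake simply returns q.new_empty(q.shape). The snippet below is a minimal, self-contained sketch of that pattern on a toy op, assuming PyTorch 2.4 or newer; the name toy_attention::scaled_matmul and everything in it are illustrative only and not part of this commit.

from typing import Optional

import torch


@torch.library.custom_op("toy_attention::scaled_matmul", mutates_args=())
def scaled_matmul(q: torch.Tensor, k: torch.Tensor, scale: Optional[float] = None) -> torch.Tensor:
    # Eager implementation; torch.compile treats this op as one opaque call.
    if scale is None:
        scale = q.shape[-1] ** -0.5
    return (q * scale) @ k.transpose(-2, -1)


@scaled_matmul.register_fake
def _(q, k, scale=None):
    # Shape/dtype-only implementation used during fake-tensor tracing,
    # mirroring the register_fake registration in this hunk.
    return q.new_empty(q.shape[:-1] + (k.shape[-2],))


q = torch.randn(2, 4, 8)
k = torch.randn(2, 4, 8)
print(scaled_matmul(q, k).shape)  # torch.Size([2, 4, 4])

Without the fake implementation, torch.compile cannot infer the output metadata of the opaque call; with it, the op can sit inside a compiled graph while the real kernel (a plain matmul here, aiter.flash_attn_func in the commit) still runs eagerly at execution time.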