vivienfanghua 2025-10-28 19:31:23 +08:00
parent 897d2662be
commit 1c5a503715


@@ -628,35 +628,18 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
     return out
 
-try:
-    @torch.library.custom_op("aiter_attention::aiter_flash_attn", mutates_args=())
-    def aiter_flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
-                                 dropout_p: float = 0.0, softmax_scale: Optional[float] = None,
-                                 causal: bool = False, window_size: tuple = (-1, -1),
-                                 bias: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None,
-                                 deterministic: bool = False) -> torch.Tensor:
-        """Wrapper for aiter.flash_attn_func to handle its specific parameters"""
-        return aiter.flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=softmax_scale,
-                                     causal=causal, window_size=window_size, bias=bias,
-                                     alibi_slopes=alibi_slopes, deterministic=deterministic,
-                                     return_lse=False, return_attn_probs=False,
-                                     cu_seqlens_q=None, cu_seqlens_kv=None)
-
-    @aiter_flash_attn_wrapper.register_fake
-    def aiter_flash_attn_fake(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False,
-                              window_size=(-1, -1), bias=None, alibi_slopes=None, deterministic=False):
-        # Output shape is the same as q
-        return q.new_empty(q.shape)
-except AttributeError as error:
-    AITER_ATTN_ERROR = error
-    def aiter_flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
-                                 dropout_p: float = 0.0, softmax_scale: Optional[float] = None,
-                                 causal: bool = False, window_size: tuple = (-1, -1),
-                                 bias: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None,
-                                 deterministic: bool = False) -> torch.Tensor:
-        assert False, f"Could not define aiter_flash_attn_wrapper: {AITER_ATTN_ERROR}"
+def aiter_flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+                             dropout_p: float = 0.0, softmax_scale: Optional[float] = None,
+                             causal: bool = False, window_size: tuple = (-1, -1),
+                             bias: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None,
+                             deterministic: bool = False) -> torch.Tensor:
+    return aiter.flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=softmax_scale,
+                                 causal=causal, window_size=window_size, bias=bias,
+                                 alibi_slopes=alibi_slopes, deterministic=deterministic,
+                                 return_lse=False, return_attn_probs=False,
+                                 cu_seqlens_q=None, cu_seqlens_kv=None)
 
 @wrap_attn
 def attention_aiter(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
     if skip_reshape:
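
For reference, the removed code used PyTorch's torch.library.custom_op API, which registers a Python function as a custom operator and pairs it with a "fake" (meta) implementation so that torch.compile and FakeTensor tracing can infer output shapes without running the real kernel. The following is a minimal, self-contained sketch of that pattern, assuming PyTorch >= 2.4; the operator name my_ops::scaled_add and the function bodies are illustrative only and are not part of this repository.

import torch

# Hypothetical operator, shown only to illustrate the custom_op / register_fake
# pattern that the removed aiter wrapper followed.
@torch.library.custom_op("my_ops::scaled_add", mutates_args=())
def scaled_add(x: torch.Tensor, y: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
    # Real implementation: executed on actual tensors.
    return x + scale * y

@scaled_add.register_fake
def _scaled_add_fake(x, y, scale=1.0):
    # Fake (meta) implementation: only describes the output's shape/dtype/device,
    # which is what torch.compile uses during tracing instead of the real kernel.
    return x.new_empty(x.shape)

if __name__ == "__main__":
    a = torch.randn(4, 8)
    b = torch.randn(4, 8)
    out = scaled_add(a, b, scale=0.5)  # eager call dispatches to the real implementation
    print(out.shape)

The commit drops this registration for the aiter path and calls aiter.flash_attn_func through a plain Python function instead, so neither a fake implementation nor the try/except fallback for missing torch.library attributes is needed.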