from typing import Optional, Any

import torch

# only imported when sage attention is enabled
from sageattention import sageattn_qk_int8_pv_fp16_cuda, sageattn_qk_int8_pv_fp8_cuda, sageattn_qk_int8_pv_fp8_cuda_sm90  # pylint: disable=import-error


def get_cuda_arch_versions():
    """Return the compute capability string (e.g. "sm89") of each visible CUDA device."""
    cuda_archs = []
    for i in range(torch.cuda.device_count()):
        major, minor = torch.cuda.get_device_capability(i)
        cuda_archs.append(f"sm{major}{minor}")
    return cuda_archs


def sageattn(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    tensor_layout: str = "HND",
    is_causal: bool = False,
    sm_scale: Optional[float] = None,
    return_lse: bool = False,
    **kwargs: Any,
):
    """
    Automatically selects the appropriate implementation of the SageAttention kernel
    based on the GPU compute capability.

    Parameters
    ----------
    q : torch.Tensor
        The query tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
    k : torch.Tensor
        The key tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
    v : torch.Tensor
        The value tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
    tensor_layout : str
        The tensor layout, either "HND" or "NHD". Default: "HND".
    is_causal : bool
        Whether to apply a causal mask to the attention matrix.
        Only applicable when qo_len == kv_len. Default: False.
    sm_scale : Optional[float]
        The scale used in softmax; if not provided, it will be set to ``1.0 / sqrt(head_dim)``.
    return_lse : bool
        Whether to return the log sum of the exponentiated attention weights.
        Used for cases like Ring Attention. Default: False.

    Returns
    -------
    torch.Tensor
        The output tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
    torch.Tensor
        The logsumexp of each row of the matrix QK^T * scaling (i.e., the log of the
        softmax normalization factor). Shape: ``[batch_size, num_qo_heads, qo_len]``.
        Only returned if `return_lse` is True.

    Note
    ----
    - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
    - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``.
    - All tensors must be on the same CUDA device.
""" arch = get_cuda_arch_versions()[q.device.index] if arch in ("sm80", "sm86"): return sageattn_qk_int8_pv_fp16_cuda(q, k, v, tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32") # todo: the triton kernel is broken on ampere, so disable it # elif arch == "sm86": # return sageattn_qk_int8_pv_fp16_triton(q, k, v, tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse) elif arch == "sm89": return sageattn_qk_int8_pv_fp8_cuda(q, k, v, tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32+fp16") elif arch == "sm90": return sageattn_qk_int8_pv_fp8_cuda_sm90(q, k, v, tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32+fp32") elif arch == "sm120": return sageattn_qk_int8_pv_fp8_cuda(q, k, v, tensor_layout=tensor_layout, is_causal=is_causal, qk_quant_gran="per_warp", sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32+fp16") # sm120 has accurate fp32 accumulator for fp8 mma and triton kernel is currently not usable on sm120. else: raise ValueError(f"Unsupported CUDA architecture: {arch}")