ComfyUI/comfy/ldm/modules/sage_attention_dispatcher.py

94 lines
4.0 KiB
Python

from typing import Optional, Any
import torch
# only imported when sage attention is enabled
from sageattention import * # pylint: disable=import-error
def get_cuda_arch_versions():
    """
    List the compute-capability tag of each CUDA device known to torch.

    Returns
    -------
    list of str
        One ``"sm<major><minor>"`` string per device, in device-index order
        (a capability of ``(8, 6)`` becomes ``"sm86"``). The list is empty
        when ``torch.cuda.device_count()`` reports no devices.
    """
    # Capabilities are queried lazily, in ascending device-index order.
    capabilities = map(
        torch.cuda.get_device_capability,
        range(torch.cuda.device_count()),
    )
    return [f"sm{major}{minor}" for major, minor in capabilities]
def sageattn(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    tensor_layout: str = "HND",
    is_causal: bool = False,
    sm_scale: Optional[float] = None,
    return_lse: bool = False,
    **kwargs: Any,
):
    """
    Automatically selects the appropriate implementation of the SageAttention
    kernel based on the compute capability of the GPU that holds `q`.

    Parameters
    ----------
    q : torch.Tensor
        The query tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.

    k : torch.Tensor
        The key tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.

    v : torch.Tensor
        The value tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.

    tensor_layout : str
        The tensor layout, either "HND" or "NHD".
        Default: "HND".

    is_causal : bool
        Whether to apply causal mask to the attention matrix. Only applicable
        when qo_len == kv_len.
        Default: False.

    sm_scale : Optional[float]
        The scale used in softmax, if not provided, will be set to
        ``1.0 / sqrt(head_dim)``.

    return_lse : bool
        Whether to return the log sum of the exponentiated attention weights.
        Used for cases like Ring Attention.
        Default: False.

    **kwargs : Any
        Accepted so callers may pass extra keyword arguments; they are
        ignored and are not forwarded to the selected kernel.

    Returns
    -------
    torch.Tensor
        The output tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.

    torch.Tensor
        The logsumexp of each row of the matrix QK^T * scaling (e.g., log of
        the softmax normalization factor).
        Shape: ``[batch_size, num_qo_heads, qo_len]``.
        Only returned if `return_lse` is True.

    Raises
    ------
    ValueError
        If `q` is not on a CUDA device, or if the compute capability of its
        device is not one of sm80, sm86, sm89, sm90 or sm120.

    Note
    ----
    - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
    - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``
    - All tensors must be on the same cuda device.
    """
    device = q.device
    # Indexing a per-device list with ``q.device.index`` breaks for non-CUDA
    # tensors: the index is None on CPU, and an index on another backend has
    # nothing to do with CUDA device numbering. Reject those up front.
    if device.type != "cuda":
        raise ValueError(f"SageAttention requires CUDA tensors, but q is on device: {device}")

    # Query only the device that holds ``q`` rather than every visible device.
    major, minor = torch.cuda.get_device_capability(device)
    arch = f"sm{major}{minor}"

    # Keyword arguments shared by every kernel variant.
    common = dict(
        tensor_layout=tensor_layout,
        is_causal=is_causal,
        sm_scale=sm_scale,
        return_lse=return_lse,
    )

    if arch in ("sm80", "sm86"):
        # NOTE: sm86 shares the CUDA fp16 kernel with sm80 on purpose; the
        # triton kernel is broken on Ampere, so it is not used here.
        return sageattn_qk_int8_pv_fp16_cuda(q, k, v, pv_accum_dtype="fp32", **common)
    if arch == "sm89":
        return sageattn_qk_int8_pv_fp8_cuda(q, k, v, pv_accum_dtype="fp32+fp16", **common)
    if arch == "sm90":
        return sageattn_qk_int8_pv_fp8_cuda_sm90(q, k, v, pv_accum_dtype="fp32+fp32", **common)
    if arch == "sm120":
        # sm120 has accurate fp32 accumulator for fp8 mma and triton kernel is
        # currently not usable on sm120.
        return sageattn_qk_int8_pv_fp8_cuda(
            q, k, v, qk_quant_gran="per_warp", pv_accum_dtype="fp32+fp16", **common
        )
    raise ValueError(f"Unsupported CUDA architecture: {arch}")