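"""Selects a SageAttention kernel implementation based on the GPU's CUDA compute capability."""
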
from typing import Optional, Any

import torch

# only imported when sage attention is enabled
from sageattention import *  # pylint: disable=import-error


def get_cuda_arch_versions():
    cuda_archs = []
    for i in range(torch.cuda.device_count()):
        major, minor = torch.cuda.get_device_capability(i)
        cuda_archs.append(f"sm{major}{minor}")
    return cuda_archs


def sageattn(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    tensor_layout: str = "HND",
    is_causal: bool = False,
    sm_scale: Optional[float] = None,
    return_lse: bool = False,
    **kwargs: Any,
):
    """
    Automatically selects the appropriate implementation of the SageAttention kernel based on the GPU compute capability.

    Parameters
    ----------
    q : torch.Tensor
        The query tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.

    k : torch.Tensor
        The key tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.

    v : torch.Tensor
        The value tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.

    tensor_layout : str
        The tensor layout, either "HND" or "NHD".
        Default: "HND".

    is_causal : bool
        Whether to apply a causal mask to the attention matrix. Only applicable when ``qo_len == kv_len``.
        Default: False.

    sm_scale : Optional[float]
        The scale used in the softmax. If not provided, it is set to ``1.0 / sqrt(head_dim)``.

    return_lse : bool
        Whether to return the log-sum-exp of the attention weights, as needed for cases like Ring Attention.
        Default: False.

    Returns
    -------
    torch.Tensor
        The output tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.

    torch.Tensor
        The logsumexp of each row of the matrix ``QK^T * scaling`` (i.e., the log of the softmax normalization factor).
        Shape: ``[batch_size, num_qo_heads, qo_len]``.
        Only returned if `return_lse` is True.

    Note
    ----
    - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
    - The tensors `q`, `k`, and `v` must have dtype ``torch.float16`` or ``torch.bfloat16``.
    - All tensors must be on the same CUDA device.
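
    Example
    -------
    A minimal illustrative call; it assumes the ``sageattention`` package is
    installed, a CUDA GPU with a supported architecture is available, and the
    shapes below are arbitrary:

    >>> q = torch.randn(1, 8, 128, 64, dtype=torch.float16, device="cuda")
    >>> k = torch.randn(1, 8, 128, 64, dtype=torch.float16, device="cuda")
    >>> v = torch.randn(1, 8, 128, 64, dtype=torch.float16, device="cuda")
    >>> out = sageattn(q, k, v, tensor_layout="HND", is_causal=True)
    >>> out.shape
    torch.Size([1, 8, 128, 64])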
    """

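    # Approximate GPU families for the architectures handled below (informational):
    # sm80/sm86 ~ Ampere, sm89 ~ Ada Lovelace, sm90 ~ Hopper, sm120 ~ consumer Blackwell.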
    arch = get_cuda_arch_versions()[q.device.index]
    if arch in ("sm80", "sm86"):
        return sageattn_qk_int8_pv_fp16_cuda(q, k, v, tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32")
    # todo: the triton kernel is broken on ampere, so disable it
    # elif arch == "sm86":
    #     return sageattn_qk_int8_pv_fp16_triton(q, k, v, tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse)
    elif arch == "sm89":
        return sageattn_qk_int8_pv_fp8_cuda(q, k, v, tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32+fp16")
    elif arch == "sm90":
        return sageattn_qk_int8_pv_fp8_cuda_sm90(q, k, v, tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32+fp32")
    elif arch == "sm120":
        # sm120 has an accurate fp32 accumulator for fp8 mma, and the triton kernel is currently not usable on sm120.
        return sageattn_qk_int8_pv_fp8_cuda(q, k, v, tensor_layout=tensor_layout, is_causal=is_causal, qk_quant_gran="per_warp", sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32+fp16")
    else:
        raise ValueError(f"Unsupported CUDA architecture: {arch}")