mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 06:40:48 +08:00
Add files via upload
This commit is contained in:
parent
caa3597bb7
commit
aad001bfbe
0
comfy/flash_attn_triton_amd/__init__.py
Normal file
0
comfy/flash_attn_triton_amd/__init__.py
Normal file
606
comfy/flash_attn_triton_amd/bwd_prefill.py
Normal file
606
comfy/flash_attn_triton_amd/bwd_prefill.py
Normal file
@ -0,0 +1,606 @@
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from comfy.flash_attn_triton_amd.utils import get_shape_from_layout, get_strides_from_layout
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _bwd_preprocess_use_o(
|
||||
Out,
|
||||
DO,
|
||||
Delta,
|
||||
stride_oz, stride_oh, stride_om, stride_ok,
|
||||
stride_doz, stride_doh, stride_dom, stride_dok, # pylint: disable=unused-argument
|
||||
stride_deltaz, stride_deltah, stride_deltam,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
ACTUAL_BLOCK_DMODEL: tl.constexpr,
|
||||
N_CTX_Q: tl.constexpr,
|
||||
Z: tl.constexpr, # pylint: disable=unused-argument
|
||||
H: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr
|
||||
):
|
||||
pid_m = tl.program_id(0)
|
||||
pid_bh = tl.program_id(1)
|
||||
|
||||
# Compute batch and head indices
|
||||
off_z = pid_bh // H
|
||||
off_h = pid_bh % H
|
||||
|
||||
if IS_VARLEN:
|
||||
# Compute sequence lengths for the current batch
|
||||
q_start = tl.load(cu_seqlens_q + off_z)
|
||||
q_end = tl.load(cu_seqlens_q + off_z + 1)
|
||||
k_start = tl.load(cu_seqlens_k + off_z)
|
||||
k_end = tl.load(cu_seqlens_k + off_z + 1)
|
||||
|
||||
# Compute actual sequence lengths
|
||||
N_CTX_Q = q_end - q_start
|
||||
N_CTX_K = k_end - k_start # pylint: disable=unused-variable
|
||||
else:
|
||||
q_start = 0
|
||||
k_start = 0
|
||||
N_CTX_Q = max_seqlen_q
|
||||
N_CTX_K = max_seqlen_k # pylint: disable=unused-variable
|
||||
|
||||
off_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
off_d = tl.arange(0, BLOCK_DMODEL)
|
||||
|
||||
# create masks
|
||||
mask_m = off_m < N_CTX_Q
|
||||
mask_d = off_d < ACTUAL_BLOCK_DMODEL
|
||||
|
||||
# compute offsets
|
||||
o_offset = Out + off_z * stride_oz + off_h * stride_oh + q_start * stride_om
|
||||
do_offset = DO + off_z * stride_oz + off_h * stride_oh + q_start * stride_om
|
||||
|
||||
# compute pointers
|
||||
out_ptrs = o_offset + off_m[:, None] * stride_om + off_d[None, :] * stride_ok
|
||||
do_ptrs = do_offset + off_m[:, None] * stride_dom + off_d[None, :] * stride_dok
|
||||
|
||||
# load
|
||||
o = tl.load(out_ptrs, mask=mask_m[:, None] & mask_d[None, :], other=0.0).to(tl.float32)
|
||||
do = tl.load(do_ptrs, mask=mask_m[:, None] & mask_d[None, :], other=0.0).to(tl.float32)
|
||||
|
||||
# compute delta
|
||||
delta = tl.sum(o * do, axis=1)
|
||||
|
||||
# write-back delta
|
||||
delta_offset = Delta + off_z * stride_deltaz + off_h * stride_deltah + q_start * stride_deltam
|
||||
delta_ptrs = delta_offset + off_m * stride_deltam
|
||||
tl.store(delta_ptrs, delta, mask=mask_m)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _bwd_kernel_one_col_block(
|
||||
Q,
|
||||
K,
|
||||
V,
|
||||
sm_scale,
|
||||
Out, DO, DQ, DK, DV, L, D, # pylint: disable=unused-argument
|
||||
q_offset,
|
||||
k_offset,
|
||||
v_offset,
|
||||
do_offset,
|
||||
dq_offset,
|
||||
dk_offset,
|
||||
dv_offset,
|
||||
d_offset,
|
||||
l_offset,
|
||||
stride_dq_all, stride_qz, stride_qh, # pylint: disable=unused-argument
|
||||
stride_qm,
|
||||
stride_qk,
|
||||
stride_kz, stride_kh, # pylint: disable=unused-argument
|
||||
stride_kn,
|
||||
stride_kk,
|
||||
stride_vz, stride_vh, # pylint: disable=unused-argument
|
||||
stride_vn,
|
||||
stride_vk,
|
||||
stride_deltaz, stride_deltah, # pylint: disable=unused-argument
|
||||
stride_deltam,
|
||||
Z, H, # pylint: disable=unused-argument
|
||||
N_CTX_Q,
|
||||
N_CTX_K,
|
||||
off_h, off_z, off_hz, # pylint: disable=unused-argument
|
||||
start_n,
|
||||
num_block_m,
|
||||
num_block_n, # pylint: disable=unused-argument
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
ACTUAL_BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
SEQUENCE_PARALLEL: tl.constexpr,
|
||||
CAUSAL: tl.constexpr,
|
||||
USE_EXP2: tl.constexpr,
|
||||
):
|
||||
if CAUSAL:
|
||||
# TODO: Causal can skip more blocks with something like lo = start_m * BLOCK_M
|
||||
lo = 0
|
||||
else:
|
||||
lo = 0
|
||||
|
||||
# initialize col and head offsets
|
||||
offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
offs_d = tl.arange(0, BLOCK_DMODEL)
|
||||
|
||||
# masks
|
||||
mask_n = offs_n < N_CTX_K
|
||||
mask_d = offs_d < ACTUAL_BLOCK_DMODEL
|
||||
kv_mask = mask_n[:, None] & mask_d[None, :]
|
||||
|
||||
# initialize grad accumulators
|
||||
dv = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
|
||||
dk = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
|
||||
|
||||
# load k and v once per column block
|
||||
k_ptrs = k_offset + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk
|
||||
v_ptrs = v_offset + offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vk
|
||||
k = tl.load(k_ptrs, mask=kv_mask, other=0.0)
|
||||
v = tl.load(v_ptrs, mask=kv_mask, other=0.0)
|
||||
|
||||
# loop over rows
|
||||
for start_m in range(lo, num_block_m * BLOCK_M, BLOCK_M):
|
||||
offs_m = start_m + tl.arange(0, BLOCK_M)
|
||||
q_ptrs = q_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
|
||||
dq_ptrs = dq_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
|
||||
do_ptrs = do_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
|
||||
|
||||
# update mask as row block changes
|
||||
mask_m = offs_m < N_CTX_Q
|
||||
q_mask = mask_m[:, None] & mask_d[None, :]
|
||||
|
||||
# load q, k, v, do on-chip
|
||||
q = tl.load(q_ptrs, mask=q_mask, other=0.0)
|
||||
do = tl.load(do_ptrs, mask=q_mask, other=0.0)
|
||||
|
||||
# recompute p = softmax(qk, dim=-1).T
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
qk += tl.dot(q, tl.trans(k))
|
||||
|
||||
if CAUSAL:
|
||||
col_offset = N_CTX_Q - N_CTX_K
|
||||
causal_mask = offs_m[:, None] >= (col_offset + offs_n[None, :])
|
||||
qk = tl.where(causal_mask, qk, float("-inf"))
|
||||
|
||||
l_ptrs = l_offset + offs_m * stride_deltam
|
||||
l_i = tl.load(l_ptrs, mask=mask_m)
|
||||
|
||||
# compute p
|
||||
if USE_EXP2:
|
||||
RCP_LN2: tl.constexpr = 1.4426950408889634
|
||||
qk *= sm_scale * RCP_LN2
|
||||
l_i *= RCP_LN2
|
||||
p = tl.math.exp2(qk - l_i[:, None])
|
||||
else:
|
||||
qk *= sm_scale
|
||||
p = tl.math.exp(qk - l_i[:, None])
|
||||
|
||||
# mask block in the cases where the data is smaller the block size
|
||||
p_mask = mask_m[:, None] & mask_n[None, :]
|
||||
p = tl.where(p_mask, p, 0.0)
|
||||
|
||||
# compute dv
|
||||
dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)
|
||||
|
||||
# compute dp
|
||||
dp = tl.dot(do, tl.trans(v))
|
||||
|
||||
# compute ds , ds = p * (dp - delta[:, None])
|
||||
d_ptrs = d_offset + offs_m * stride_deltam
|
||||
Di = tl.load(d_ptrs, mask=mask_m)
|
||||
ds = (p * (dp - Di[:, None])) * sm_scale
|
||||
ds = tl.where(p_mask, ds, 0.0).to(Q.dtype.element_ty)
|
||||
|
||||
# compute dk = dot(ds.T, q)
|
||||
dk += tl.dot(tl.trans(ds), q)
|
||||
|
||||
# compute dq
|
||||
if SEQUENCE_PARALLEL:
|
||||
dq = tl.dot(ds, k)
|
||||
else:
|
||||
dq = tl.load(dq_ptrs, mask=q_mask, other=0.0)
|
||||
dq += tl.dot(ds, k)
|
||||
tl.store(dq_ptrs, dq.to(Q.dtype.element_ty), mask=q_mask)
|
||||
|
||||
# write-back dv and dk
|
||||
dk_ptrs = dk_offset + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk
|
||||
dv_ptrs = dv_offset + offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vk
|
||||
|
||||
# write-back
|
||||
tl.store(dk_ptrs, dk.to(K.dtype.element_ty), mask=kv_mask)
|
||||
tl.store(dv_ptrs, dv.to(V.dtype.element_ty), mask=kv_mask)
|
||||
|
||||
@triton.jit
|
||||
def _bwd_kernel(
|
||||
Q,
|
||||
K,
|
||||
V,
|
||||
sm_scale,
|
||||
Out,
|
||||
DO,
|
||||
DQ,
|
||||
DK,
|
||||
DV,
|
||||
L,
|
||||
D,
|
||||
stride_dq_all,
|
||||
stride_qz,
|
||||
stride_qh,
|
||||
stride_qm,
|
||||
stride_qk,
|
||||
stride_kz,
|
||||
stride_kh,
|
||||
stride_kn,
|
||||
stride_kk,
|
||||
stride_vz,
|
||||
stride_vh,
|
||||
stride_vn,
|
||||
stride_vk,
|
||||
stride_deltaz,
|
||||
stride_deltah,
|
||||
stride_deltam,
|
||||
Z,
|
||||
H,
|
||||
num_block_m,
|
||||
num_block_n,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
ACTUAL_BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
SEQUENCE_PARALLEL: tl.constexpr,
|
||||
CAUSAL: tl.constexpr,
|
||||
USE_EXP2: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
):
|
||||
# program ids
|
||||
off_hz = tl.program_id(0)
|
||||
if SEQUENCE_PARALLEL:
|
||||
start_n = tl.program_id(1)
|
||||
off_z = off_hz // H
|
||||
off_h = off_hz % H
|
||||
|
||||
if IS_VARLEN:
|
||||
# Compute sequence lengths for the current batch
|
||||
q_start = tl.load(cu_seqlens_q + off_z)
|
||||
q_end = tl.load(cu_seqlens_q + off_z + 1)
|
||||
k_start = tl.load(cu_seqlens_k + off_z)
|
||||
k_end = tl.load(cu_seqlens_k + off_z + 1)
|
||||
|
||||
# Compute actual sequence lengths
|
||||
N_CTX_Q = q_end - q_start
|
||||
N_CTX_K = k_end - k_start
|
||||
else:
|
||||
q_start = 0
|
||||
k_start = 0
|
||||
N_CTX_Q = max_seqlen_q
|
||||
N_CTX_K = max_seqlen_k
|
||||
|
||||
# input tensor offsets
|
||||
q_offset = Q + off_z * stride_qz + off_h * stride_qh + q_start * stride_qm
|
||||
k_offset = K + off_z * stride_kz + off_h * stride_kh + k_start * stride_kn
|
||||
v_offset = V + off_z * stride_vz + off_h * stride_vh + k_start * stride_vn
|
||||
do_offset = DO + off_z * stride_qz + off_h * stride_qh + q_start * stride_qm
|
||||
l_offset = L + off_z * stride_deltaz + off_h * stride_deltah + q_start * stride_deltam
|
||||
d_offset = D + off_z * stride_deltaz + off_h * stride_deltah + q_start * stride_deltam
|
||||
|
||||
# output tensor offsets
|
||||
dk_offset = DK + off_z * stride_kz + off_h * stride_kh + k_start * stride_kn
|
||||
dv_offset = DV + off_z * stride_vz + off_h * stride_vh + k_start * stride_vn
|
||||
if SEQUENCE_PARALLEL:
|
||||
dq_offset = DQ + start_n * stride_dq_all + off_z * stride_qz + off_h * stride_qh + q_start * stride_qm
|
||||
else:
|
||||
dq_offset = DQ + off_z * stride_qz + off_h * stride_qh + q_start * stride_qm
|
||||
|
||||
# inner loop
|
||||
if SEQUENCE_PARALLEL:
|
||||
_bwd_kernel_one_col_block(
|
||||
Q,
|
||||
K,
|
||||
V,
|
||||
sm_scale,
|
||||
Out,
|
||||
DO,
|
||||
DQ,
|
||||
DK,
|
||||
DV,
|
||||
L,
|
||||
D,
|
||||
q_offset,
|
||||
k_offset,
|
||||
v_offset,
|
||||
do_offset,
|
||||
dq_offset,
|
||||
dk_offset,
|
||||
dv_offset,
|
||||
d_offset,
|
||||
l_offset,
|
||||
stride_dq_all,
|
||||
stride_qz,
|
||||
stride_qh,
|
||||
stride_qm,
|
||||
stride_qk,
|
||||
stride_kz,
|
||||
stride_kh,
|
||||
stride_kn,
|
||||
stride_kk,
|
||||
stride_vz,
|
||||
stride_vh,
|
||||
stride_vn,
|
||||
stride_vk,
|
||||
stride_deltaz,
|
||||
stride_deltah,
|
||||
stride_deltam,
|
||||
Z,
|
||||
H,
|
||||
N_CTX_Q,
|
||||
N_CTX_K,
|
||||
off_h,
|
||||
off_z,
|
||||
off_hz,
|
||||
start_n,
|
||||
num_block_m,
|
||||
num_block_n,
|
||||
BLOCK_M=BLOCK_M,
|
||||
BLOCK_DMODEL=BLOCK_DMODEL,
|
||||
ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL,
|
||||
BLOCK_N=BLOCK_N,
|
||||
SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,
|
||||
CAUSAL=CAUSAL,
|
||||
USE_EXP2=USE_EXP2,
|
||||
)
|
||||
else:
|
||||
for start_n in range(0, num_block_n):
|
||||
_bwd_kernel_one_col_block(
|
||||
Q,
|
||||
K,
|
||||
V,
|
||||
sm_scale,
|
||||
Out,
|
||||
DO,
|
||||
DQ,
|
||||
DK,
|
||||
DV,
|
||||
L,
|
||||
D,
|
||||
q_offset,
|
||||
k_offset,
|
||||
v_offset,
|
||||
do_offset,
|
||||
dq_offset,
|
||||
dk_offset,
|
||||
dv_offset,
|
||||
d_offset,
|
||||
l_offset,
|
||||
stride_dq_all,
|
||||
stride_qz,
|
||||
stride_qh,
|
||||
stride_qm,
|
||||
stride_qk,
|
||||
stride_kz,
|
||||
stride_kh,
|
||||
stride_kn,
|
||||
stride_kk,
|
||||
stride_vz,
|
||||
stride_vh,
|
||||
stride_vn,
|
||||
stride_vk,
|
||||
stride_deltaz,
|
||||
stride_deltah,
|
||||
stride_deltam,
|
||||
Z,
|
||||
H,
|
||||
N_CTX_Q,
|
||||
N_CTX_K,
|
||||
off_h,
|
||||
off_z,
|
||||
off_hz,
|
||||
start_n,
|
||||
num_block_m,
|
||||
num_block_n,
|
||||
BLOCK_M=BLOCK_M,
|
||||
BLOCK_DMODEL=BLOCK_DMODEL,
|
||||
ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL,
|
||||
BLOCK_N=BLOCK_N,
|
||||
SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,
|
||||
CAUSAL=CAUSAL,
|
||||
USE_EXP2=USE_EXP2,
|
||||
)
|
||||
|
||||
|
||||
# NOTE: smaller blocks have lower accuracy. more accumlation error probably 128 * 128 seems good but leads to oom. 64 * 64 has accumlation errors but no oom.
|
||||
def attention_prefill_backward_triton_impl(
|
||||
do,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
o,
|
||||
softmax_lse,
|
||||
dq,
|
||||
dk,
|
||||
dv,
|
||||
sm_scale: float,
|
||||
alibi_slopes, # pylint: disable=unused-argument
|
||||
causal,
|
||||
layout: str,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q: int,
|
||||
max_seqlen_k: int,
|
||||
use_exp2: bool,
|
||||
sequence_parallel = True,
|
||||
):
|
||||
# make contigious
|
||||
q = q.contiguous()
|
||||
k = k.contiguous()
|
||||
v = v.contiguous()
|
||||
softmax_lse = softmax_lse.contiguous()
|
||||
|
||||
# get strides and shape
|
||||
batch, nheads_q, nheads_k, head_size, max_seqlen_q, max_seqlen_k = get_shape_from_layout(q, k, layout, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k) # pylint: disable=unused-variable
|
||||
q_strides, k_strides, v_strides, o_strides = get_strides_from_layout(q, k, v, o, layout)
|
||||
stride_qz, stride_qh, stride_qm, stride_qk = q_strides
|
||||
stride_kz, stride_kh, stride_kn, stride_kk = k_strides
|
||||
stride_vz, stride_vh, stride_vn, stride_vk = v_strides
|
||||
stride_oz, stride_oh, stride_om, stride_ok = o_strides
|
||||
batch_headsize = batch * nheads_q
|
||||
is_varlen = layout == "thd"
|
||||
|
||||
# FIXME: some configs lead to oom for some reason when using 64 x 64 blocks
|
||||
if max_seqlen_q <= 32 or max_seqlen_k <= 32:
|
||||
BLOCK_M = 32
|
||||
BLOCK_N = 32
|
||||
else:
|
||||
BLOCK_M = 64
|
||||
BLOCK_N = 64
|
||||
num_warps = 4 # NOTE: originial is 8. changing it to 1 caused issues be careful
|
||||
num_stages = 1
|
||||
waves_per_eu = 1
|
||||
|
||||
# divide up the problem
|
||||
num_blocks_m = triton.cdiv(max_seqlen_q, BLOCK_M)
|
||||
num_blocks_n = triton.cdiv(max_seqlen_k, BLOCK_N)
|
||||
|
||||
# get closest power of 2 over or equal to 32.
|
||||
padded_d_model = 1 << (head_size - 1).bit_length()
|
||||
padded_d_model = max(padded_d_model, 16)
|
||||
BLOCK_DMODEL = padded_d_model
|
||||
ACTUAL_BLOCK_DMODEL = head_size
|
||||
|
||||
do = do.contiguous()
|
||||
# NOTE: we might need to copy the output tensor if they are not continuous or have other issues
|
||||
copy_back = {"dq": False, "dk": False, "dv": False}
|
||||
|
||||
dq_og = None
|
||||
# deal with dq
|
||||
if dq is None:
|
||||
if sequence_parallel:
|
||||
dq = torch.zeros((num_blocks_n,) + q.shape, device=q.device, dtype=q.dtype)
|
||||
else:
|
||||
dq = torch.zeros(q.shape, device=q.device, dtype=q.dtype)
|
||||
else:
|
||||
dq_og = dq
|
||||
if not dq.is_contiguous():
|
||||
dq = dq.contiguous()
|
||||
copy_back["dq"] = True
|
||||
|
||||
if sequence_parallel:
|
||||
dq = torch.zeros((num_blocks_n,) + q.shape, device=q.device, dtype=q.dtype)
|
||||
copy_back["dq"] = True
|
||||
else:
|
||||
# NOTE: the kernel does inplace accumlation so dq has to be zeros. This avoids the case where we are passed empty dq and it is not all zeros
|
||||
dq.zero_()
|
||||
stride_dq_all = dq.stride()[0]
|
||||
|
||||
dk_og = None
|
||||
dv_og = None
|
||||
# deal with dk, dv
|
||||
if (dk is None) or (dv is None):
|
||||
dk = torch.empty_like(k)
|
||||
dv = torch.empty_like(v)
|
||||
else:
|
||||
if not dk.is_contiguous():
|
||||
dk_og = dk
|
||||
dk = dk.contiguous()
|
||||
copy_back["dk"] = True
|
||||
|
||||
if not dv.is_contiguous():
|
||||
dv_og = dv
|
||||
dv = dv.contiguous()
|
||||
copy_back["dv"] = True
|
||||
|
||||
# assert contigious
|
||||
assert do.is_contiguous()
|
||||
assert q.is_contiguous()
|
||||
assert k.is_contiguous()
|
||||
assert v.is_contiguous()
|
||||
assert o.is_contiguous()
|
||||
assert softmax_lse.is_contiguous()
|
||||
|
||||
# init delta
|
||||
delta = torch.empty_like(softmax_lse)
|
||||
if is_varlen:
|
||||
stride_deltam, stride_deltah = delta.stride()
|
||||
stride_deltaz = 0
|
||||
else:
|
||||
stride_deltaz, stride_deltah, stride_deltam = delta.stride()
|
||||
|
||||
_bwd_preprocess_use_o[(num_blocks_m, batch_headsize)](
|
||||
o,
|
||||
do,
|
||||
delta,
|
||||
stride_oz, stride_oh, stride_om, stride_ok,
|
||||
stride_oz, stride_oh, stride_om, stride_ok,
|
||||
stride_deltaz, stride_deltah, stride_deltam,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
BLOCK_M=BLOCK_M,
|
||||
BLOCK_DMODEL=BLOCK_DMODEL,
|
||||
ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL,
|
||||
N_CTX_Q=max_seqlen_q,
|
||||
Z=batch,
|
||||
H=nheads_q,
|
||||
IS_VARLEN=is_varlen
|
||||
)
|
||||
|
||||
_bwd_kernel[(batch_headsize, num_blocks_n if sequence_parallel else 1)](
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
sm_scale,
|
||||
o,
|
||||
do,
|
||||
dq,
|
||||
dk,
|
||||
dv,
|
||||
softmax_lse,
|
||||
delta,
|
||||
stride_dq_all,
|
||||
stride_qz, stride_qh, stride_qm, stride_qk,
|
||||
stride_kz, stride_kh, stride_kn, stride_kk,
|
||||
stride_vz, stride_vh, stride_vn, stride_vk,
|
||||
stride_deltaz, stride_deltah, stride_deltam,
|
||||
batch,
|
||||
nheads_q,
|
||||
num_blocks_m,
|
||||
num_blocks_n,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
BLOCK_M=BLOCK_M,
|
||||
BLOCK_N=BLOCK_N,
|
||||
BLOCK_DMODEL=BLOCK_DMODEL,
|
||||
ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL,
|
||||
SEQUENCE_PARALLEL=sequence_parallel,
|
||||
CAUSAL=causal,
|
||||
USE_EXP2=use_exp2,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
waves_per_eu = waves_per_eu,
|
||||
IS_VARLEN=is_varlen
|
||||
)
|
||||
|
||||
if sequence_parallel:
|
||||
dq = dq.sum(dim=0)
|
||||
|
||||
if copy_back["dq"]:
|
||||
dq_og.copy_(dq)
|
||||
dq = dq_og
|
||||
if copy_back["dk"]:
|
||||
dk_og.copy_(dk)
|
||||
dk = dk_og
|
||||
if copy_back["dv"]:
|
||||
dv_og.copy_(dv)
|
||||
dv = dv_og
|
||||
|
||||
return dq, dk, dv, delta, None, None
|
||||
271
comfy/flash_attn_triton_amd/bwd_ref.py
Normal file
271
comfy/flash_attn_triton_amd/bwd_ref.py
Normal file
@ -0,0 +1,271 @@
|
||||
import math
|
||||
import torch
|
||||
|
||||
|
||||
def attention_backward_core_ref_impl(
|
||||
do, q, k, v, o, softmax_lse, sm_scale, causal, use_exp2
|
||||
):
|
||||
# cast to float32
|
||||
do = do.to(torch.float32)
|
||||
q = q.to(torch.float32)
|
||||
k = k.to(torch.float32)
|
||||
v = v.to(torch.float32)
|
||||
o = o.to(torch.float32)
|
||||
softmax_lse = softmax_lse.to(torch.float32)
|
||||
|
||||
# recompute attention_scores. Make sure it matches the forward impl. i.e. It use float32
|
||||
attention_scores = torch.matmul(q.to(torch.float32), k.transpose(-2, -1).to(torch.float32))
|
||||
|
||||
# scale scores
|
||||
attention_scaled_scores = sm_scale * attention_scores
|
||||
|
||||
# Apply causal mask if necessary
|
||||
if causal:
|
||||
L_q, L_k = q.shape[1], k.shape[1]
|
||||
row_idx = torch.arange(L_q, device=q.device).unsqueeze(1)
|
||||
col_idx = torch.arange(L_k, device=q.device).unsqueeze(0)
|
||||
col_offset = L_q-L_k
|
||||
causal_mask = row_idx >= (col_offset + col_idx)
|
||||
# set -inf to places the causal mask is false
|
||||
attention_scaled_scores = attention_scaled_scores.masked_fill(
|
||||
torch.logical_not(causal_mask.unsqueeze(0)), float('-inf')
|
||||
)
|
||||
|
||||
# compute probabilities using softmax_lse
|
||||
if use_exp2:
|
||||
RCP_LN = 1 / math.log(2)
|
||||
attention_scaled_scores_base2 = attention_scaled_scores * RCP_LN
|
||||
softmax_lse_base2 = softmax_lse * RCP_LN
|
||||
softmax_lse_3d = softmax_lse_base2.unsqueeze(-1)
|
||||
p = torch.exp2(attention_scaled_scores_base2 - softmax_lse_3d)
|
||||
else:
|
||||
softmax_lse_3d = softmax_lse.unsqueeze(-1)
|
||||
p = torch.exp(attention_scaled_scores - softmax_lse_3d)
|
||||
|
||||
# compute gradient wrt v
|
||||
dv = torch.matmul(p.transpose(-2, -1), do.to(torch.float32))
|
||||
|
||||
# compute dp
|
||||
dp = torch.matmul(do, v.transpose(-2, -1))
|
||||
|
||||
# calculate ds using dp
|
||||
delta = torch.sum(o * do, axis=-1).to(torch.float32) # what OAI kernel uses
|
||||
delta_3d = delta.unsqueeze(-1)
|
||||
ds = (p * (dp - delta_3d)) * sm_scale
|
||||
|
||||
# compute gradient wrt k
|
||||
dk = torch.matmul(ds.transpose(-2, -1), q.to(torch.float32))
|
||||
|
||||
# compute gradient wrt q
|
||||
dq = torch.matmul(ds, k.to(torch.float32))
|
||||
|
||||
# cast back to original dtype
|
||||
dq = dq.to(torch.float16)
|
||||
dk = dk.to(torch.float16)
|
||||
dv = dv.to(torch.float16)
|
||||
|
||||
# remove d dim with size 1
|
||||
delta = delta_3d.squeeze(-1)
|
||||
|
||||
return dq, dk, dv, delta
|
||||
|
||||
def attention_varlen_backward_pytorch_ref_impl(
|
||||
do,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
o,
|
||||
softmax_lse,
|
||||
sm_scale,
|
||||
causal,
|
||||
layout,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q, max_seqlen_k, # pylint: disable=unused-argument
|
||||
use_exp2,
|
||||
):
|
||||
# Ensure the layout is 'thd'
|
||||
if layout != 'thd':
|
||||
raise ValueError(f"Unsupported layout {layout}. Expected 'thd'.")
|
||||
|
||||
batch_size = cu_seqlens_q.shape[0] - 1
|
||||
num_heads = q.shape[1]
|
||||
head_dim = q.shape[2] # pylint: disable=unused-variable
|
||||
|
||||
# Pre-allocate outputs
|
||||
total_L_q = q.shape[0]
|
||||
total_L_k = k.shape[0] # pylint: disable=unused-variable
|
||||
|
||||
dq = torch.zeros_like(q)
|
||||
dk = torch.zeros_like(k)
|
||||
dv = torch.zeros_like(v)
|
||||
# delta has the same shape as softmax_lse: [total_L_q, num_heads]
|
||||
delta = torch.zeros((total_L_q, num_heads), dtype=torch.float32, device=o.device)
|
||||
|
||||
for i in range(batch_size):
|
||||
# Get the start and end indices for the current sequence
|
||||
start_q = cu_seqlens_q[i].item()
|
||||
end_q = cu_seqlens_q[i + 1].item()
|
||||
start_k = cu_seqlens_k[i].item()
|
||||
end_k = cu_seqlens_k[i + 1].item()
|
||||
|
||||
# Extract q_i, k_i, v_i, do_i, o_i, softmax_lse_i
|
||||
q_i = q[start_q:end_q, :, :] # [L_q_i, num_heads, head_dim]
|
||||
k_i = k[start_k:end_k, :, :] # [L_k_i, num_heads, head_dim]
|
||||
v_i = v[start_k:end_k, :, :] # [L_k_i, num_heads, head_dim]
|
||||
do_i = do[start_q:end_q, :, :] # [L_q_i, num_heads, head_dim]
|
||||
o_i = o[start_q:end_q, :, :] # [L_q_i, num_heads, head_dim]
|
||||
# softmax_lse has shape [total_L_q, num_heads]
|
||||
softmax_lse_i = softmax_lse[start_q:end_q, :] # [L_q_i, num_heads]
|
||||
softmax_lse_i = softmax_lse_i.transpose(0, 1) # [num_heads, L_q_i]
|
||||
|
||||
# Permute to [num_heads, L_q_i, head_dim]
|
||||
q_i = q_i.permute(1, 0, 2)
|
||||
k_i = k_i.permute(1, 0, 2)
|
||||
v_i = v_i.permute(1, 0, 2)
|
||||
do_i = do_i.permute(1, 0, 2)
|
||||
o_i = o_i.permute(1, 0, 2)
|
||||
# softmax_lse_i is already in [num_heads, L_q_i]
|
||||
|
||||
# Call the core backward function for this sequence
|
||||
dq_i, dk_i, dv_i, delta_i = attention_backward_core_ref_impl(
|
||||
do_i,
|
||||
q_i,
|
||||
k_i,
|
||||
v_i,
|
||||
o_i,
|
||||
softmax_lse_i,
|
||||
sm_scale,
|
||||
causal,
|
||||
use_exp2
|
||||
)
|
||||
|
||||
# Convert back to 'thd' layout
|
||||
dq_i = dq_i.permute(1, 0, 2) # [L_q_i, num_heads, head_dim]
|
||||
dk_i = dk_i.permute(1, 0, 2) # [L_k_i, num_heads, head_dim]
|
||||
dv_i = dv_i.permute(1, 0, 2) # [L_k_i, num_heads, head_dim]
|
||||
|
||||
# Place outputs in pre-allocated tensors
|
||||
dq[start_q:end_q, :, :] = dq_i
|
||||
dk[start_k:end_k, :, :] += dk_i # Accumulate gradients for shared keys
|
||||
dv[start_k:end_k, :, :] += dv_i # Accumulate gradients for shared values
|
||||
# delta_i has shape [num_heads, L_q_i]
|
||||
delta_i = delta_i.transpose(1, 0) # [L_q_i, num_heads]
|
||||
delta[start_q:end_q, :] = delta_i
|
||||
|
||||
return dq, dk, dv, delta
|
||||
|
||||
def attention_vanilla_backward_pytorch_ref_impl(
|
||||
do,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
o,
|
||||
softmax_lse,
|
||||
sm_scale,
|
||||
causal,
|
||||
layout,
|
||||
use_exp2,
|
||||
):
|
||||
if layout == "bshd":
|
||||
do = do.transpose(1, 2).contiguous()
|
||||
q = q.transpose(1, 2).contiguous()
|
||||
k = k.transpose(1, 2).contiguous()
|
||||
v = v.transpose(1, 2).contiguous()
|
||||
o = o.transpose(1, 2).contiguous()
|
||||
elif layout == "bhsd":
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f"Unknown layout {layout}")
|
||||
|
||||
# Prepare tensors in [batch_size * num_heads, seq_len, head_dim] format
|
||||
batch_size, num_heads, seq_len_q, head_dim = q.shape
|
||||
seq_len_k = k.shape[2]
|
||||
|
||||
# Merge batch and heads dimensions
|
||||
do = do.reshape(batch_size * num_heads, seq_len_q, head_dim)
|
||||
q = q.reshape(batch_size * num_heads, seq_len_q, head_dim)
|
||||
k = k.reshape(batch_size * num_heads, seq_len_k, head_dim)
|
||||
v = v.reshape(batch_size * num_heads, seq_len_k, head_dim)
|
||||
softmax_lse = softmax_lse.reshape(batch_size * num_heads, seq_len_q)
|
||||
o = o.reshape(batch_size * num_heads, seq_len_q, head_dim)
|
||||
|
||||
dq, dk, dv, delta = attention_backward_core_ref_impl(
|
||||
do,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
o,
|
||||
softmax_lse,
|
||||
sm_scale,
|
||||
causal,
|
||||
use_exp2
|
||||
)
|
||||
|
||||
# Reshape outputs back to [batch_size, num_heads, seq_len, head_dim]
|
||||
dq = dq.reshape(batch_size, num_heads, seq_len_q, head_dim)
|
||||
dk = dk.reshape(batch_size, num_heads, seq_len_k, head_dim)
|
||||
dv = dv.reshape(batch_size, num_heads, seq_len_k, head_dim)
|
||||
delta = delta.reshape(batch_size, num_heads, seq_len_q)
|
||||
|
||||
# Go back to original layout
|
||||
if layout == "bshd":
|
||||
dq = dq.transpose(1, 2)
|
||||
dk = dk.transpose(1, 2)
|
||||
dv = dv.transpose(1, 2)
|
||||
elif layout == "bhsd":
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f"Unknown layout {layout}")
|
||||
|
||||
return dq, dk, dv, delta
|
||||
|
||||
|
||||
def attention_backward_pytorch_ref_impl(
|
||||
do,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
o,
|
||||
softmax_lse,
|
||||
sm_scale,
|
||||
causal,
|
||||
layout,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
use_exp2
|
||||
):
|
||||
if layout == "thd":
|
||||
dq, dk, dv, delta = attention_varlen_backward_pytorch_ref_impl(
|
||||
do,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
o,
|
||||
softmax_lse,
|
||||
sm_scale,
|
||||
causal,
|
||||
layout,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
use_exp2,
|
||||
)
|
||||
else:
|
||||
dq, dk, dv, delta = attention_vanilla_backward_pytorch_ref_impl(
|
||||
do,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
o,
|
||||
softmax_lse,
|
||||
sm_scale,
|
||||
causal,
|
||||
layout,
|
||||
use_exp2,
|
||||
)
|
||||
|
||||
return dq, dk, dv, delta
|
||||
700
comfy/flash_attn_triton_amd/fwd_decode.py
Normal file
700
comfy/flash_attn_triton_amd/fwd_decode.py
Normal file
@ -0,0 +1,700 @@
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from comfy.flash_attn_triton_amd.utils import _strides, get_padded_headsize
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _fwd_kernel_splitK(
|
||||
Q,
|
||||
K,
|
||||
V,
|
||||
sm_scale,
|
||||
Out_splitK, # [B, H, split_k, Mq, K]
|
||||
Metadata, # [B, H, 2, split_k, M_ceil] contains [mi, li]
|
||||
K_new,
|
||||
V_new,
|
||||
Cache_seqlens,
|
||||
Cache_batch_idx,
|
||||
Alibi_slopes,
|
||||
stride_qz,
|
||||
stride_qm,
|
||||
stride_qg,
|
||||
stride_qh,
|
||||
stride_qd,
|
||||
stride_kz,
|
||||
stride_kn,
|
||||
stride_kg,
|
||||
stride_kh,
|
||||
stride_kd,
|
||||
stride_vz,
|
||||
stride_vn,
|
||||
stride_vg,
|
||||
stride_vh,
|
||||
stride_vd,
|
||||
stride_osk_zhg,
|
||||
stride_osk_s,
|
||||
stride_osk_m,
|
||||
stride_osk_d, # pylint: disable=unused-argument
|
||||
stride_mzhg,
|
||||
stride_m2,
|
||||
stride_ms,
|
||||
stride_mm, # pylint: disable=unused-argument
|
||||
stride_kn_z,
|
||||
stride_kn_n,
|
||||
stride_kn_g,
|
||||
stride_kn_h,
|
||||
stride_kn_d,
|
||||
stride_vn_z,
|
||||
stride_vn_n,
|
||||
stride_vn_g,
|
||||
stride_vn_h,
|
||||
stride_vn_d,
|
||||
stride_az,
|
||||
stride_ah,
|
||||
Z, # pylint: disable=unused-argument
|
||||
N_CTX_Q,
|
||||
N_CTX_K,
|
||||
N_CTX_NEW,
|
||||
BLOCK_N_PER_SPLIT,
|
||||
H_q: tl.constexpr,
|
||||
H_kv: tl.constexpr,
|
||||
G_q: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
ACTUAL_BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
BOUNDS_CHECKS_N: tl.constexpr,
|
||||
USE_CACHE_SEQLENs: tl.constexpr,
|
||||
USE_CACHE_BATCH_IDX: tl.constexpr,
|
||||
NEW_KV: tl.constexpr,
|
||||
IS_GQA: tl.constexpr,
|
||||
IS_CAUSAL: tl.constexpr,
|
||||
USE_ALIBI: tl.constexpr,
|
||||
):
|
||||
# Padding
|
||||
PADDED_HEAD: tl.constexpr = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL
|
||||
if PADDED_HEAD:
|
||||
d_mask = tl.arange(0, BLOCK_DMODEL) < ACTUAL_BLOCK_DMODEL
|
||||
|
||||
start_m = tl.program_id(0)
|
||||
off_zhg = tl.program_id(1)
|
||||
off_z = off_zhg // (H_q * G_q)
|
||||
off_h_q = (off_zhg // G_q) % H_q
|
||||
off_g_q = off_zhg % G_q
|
||||
splitk_idx = tl.program_id(2)
|
||||
|
||||
# pick batch index
|
||||
if USE_CACHE_BATCH_IDX:
|
||||
cache_batch_idx = tl.load(Cache_batch_idx + off_z)
|
||||
else:
|
||||
cache_batch_idx = off_z
|
||||
|
||||
# Load ALiBi slope if enabled
|
||||
if USE_ALIBI:
|
||||
a_offset = off_z * stride_az + off_h_q * stride_ah
|
||||
alibi_slope = tl.load(Alibi_slopes + a_offset)
|
||||
else:
|
||||
alibi_slope = None
|
||||
|
||||
lo = splitk_idx * BLOCK_N_PER_SPLIT
|
||||
if USE_CACHE_SEQLENs:
|
||||
cache_seqlen_last_idx = tl.load(Cache_seqlens + off_z)
|
||||
if NEW_KV:
|
||||
kv_len = cache_seqlen_last_idx + N_CTX_NEW
|
||||
else:
|
||||
kv_len = cache_seqlen_last_idx
|
||||
else:
|
||||
kv_len = N_CTX_K
|
||||
hi = tl.minimum((splitk_idx + 1) * BLOCK_N_PER_SPLIT, kv_len)
|
||||
|
||||
HEAD_RATIO: tl.constexpr = H_q // H_kv
|
||||
if IS_GQA:
|
||||
k_head_idx = off_h_q // HEAD_RATIO
|
||||
v_head_idx = k_head_idx
|
||||
else:
|
||||
k_head_idx = off_h_q
|
||||
v_head_idx = off_h_q
|
||||
|
||||
# calculate base offset
|
||||
k_base = K + k_head_idx * stride_kh + cache_batch_idx * stride_kz + off_g_q * stride_kg
|
||||
v_base = V + v_head_idx * stride_vh + cache_batch_idx * stride_vz + off_g_q * stride_vg
|
||||
|
||||
# Copy new Keys and Values into Cache
|
||||
if NEW_KV:
|
||||
knew_base = K_new + k_head_idx * stride_kn_h + off_z * stride_kn_z + off_g_q * stride_kn_g
|
||||
|
||||
# Determine the starting position for new data in the cache
|
||||
if USE_CACHE_SEQLENs:
|
||||
start_idx = tl.load(Cache_seqlens + off_z)
|
||||
else:
|
||||
start_idx = N_CTX_K - N_CTX_NEW
|
||||
|
||||
# Copy new Keys
|
||||
for i in range(0, N_CTX_NEW, BLOCK_N):
|
||||
# Load from K_new
|
||||
k_new_block = tl.load(
|
||||
knew_base +
|
||||
tl.arange(0, BLOCK_DMODEL)[:, None] * stride_kn_d +
|
||||
(tl.arange(0, BLOCK_N) + i)[None, :] * stride_kn_n,
|
||||
mask=(tl.arange(0, BLOCK_N)[None, :] + i < N_CTX_NEW) &
|
||||
(tl.arange(0, BLOCK_DMODEL)[:, None] < ACTUAL_BLOCK_DMODEL),
|
||||
other=0
|
||||
)
|
||||
|
||||
# Store to K
|
||||
tl.store(
|
||||
k_base +
|
||||
tl.arange(0, BLOCK_DMODEL)[:, None] * stride_kd +
|
||||
(tl.arange(0, BLOCK_N) + i + start_idx)[None, :] * stride_kn,
|
||||
k_new_block,
|
||||
mask=(tl.arange(0, BLOCK_N)[None, :] + i < N_CTX_NEW) &
|
||||
(tl.arange(0, BLOCK_DMODEL)[:, None] < ACTUAL_BLOCK_DMODEL),
|
||||
)
|
||||
|
||||
# Copy new Values
|
||||
vnew_base = V_new + v_head_idx * stride_vn_h + off_z * stride_vn_z + off_g_q * stride_vn_g
|
||||
for i in range(0, N_CTX_NEW, BLOCK_N):
|
||||
# Load from V_new
|
||||
v_new_block = tl.load(
|
||||
vnew_base +
|
||||
(tl.arange(0, BLOCK_N) + i)[:, None] * stride_vn_n +
|
||||
tl.arange(0, BLOCK_DMODEL)[None, :] * stride_vn_d,
|
||||
mask=(tl.arange(0, BLOCK_N)[:, None] + i < N_CTX_NEW) &
|
||||
(tl.arange(0, BLOCK_DMODEL)[None, :] < ACTUAL_BLOCK_DMODEL),
|
||||
other=0
|
||||
)
|
||||
|
||||
# Store to V
|
||||
tl.store(
|
||||
v_base +
|
||||
(tl.arange(0, BLOCK_N) + i + start_idx)[:, None] * stride_vn +
|
||||
tl.arange(0, BLOCK_DMODEL)[None, :] * stride_vd,
|
||||
v_new_block,
|
||||
mask=(tl.arange(0, BLOCK_N)[:, None] + i < N_CTX_NEW) &
|
||||
(tl.arange(0, BLOCK_DMODEL)[None, :] < ACTUAL_BLOCK_DMODEL),
|
||||
)
|
||||
|
||||
Q_block_ptr = tl.make_block_ptr(
|
||||
base=Q + off_h_q * stride_qh + off_z * stride_qz + off_g_q * stride_qg,
|
||||
shape=(N_CTX_Q, ACTUAL_BLOCK_DMODEL),
|
||||
strides=(stride_qm, stride_qd),
|
||||
offsets=(start_m * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
|
||||
K_block_ptr = tl.make_block_ptr(
|
||||
base=k_base,
|
||||
shape=(ACTUAL_BLOCK_DMODEL, hi),
|
||||
strides=(stride_kd, stride_kn),
|
||||
offsets=(0, lo),
|
||||
block_shape=(BLOCK_DMODEL, BLOCK_N),
|
||||
order=(0, 1),
|
||||
)
|
||||
V_block_ptr = tl.make_block_ptr(
|
||||
base=v_base,
|
||||
shape=(hi, ACTUAL_BLOCK_DMODEL),
|
||||
strides=(stride_vn, stride_vd),
|
||||
offsets=(lo, 0),
|
||||
block_shape=(BLOCK_N, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
|
||||
K_scale_shift_block_ptr = None
|
||||
V_scale_shift_block_ptr = None
|
||||
|
||||
# initialize pointer to m and l
|
||||
m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
|
||||
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
||||
|
||||
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) # noqa: F821
|
||||
|
||||
# scale sm_scale by log_2(e) and use
|
||||
# 2^x instead of exp in the loop because CSE and LICM
|
||||
# don't work as expected with `exp` in the loop
|
||||
qk_scale = sm_scale * 1.44269504
|
||||
# load q: it will stay in SRAM throughout
|
||||
q = tl.load( # noqa: F821
|
||||
tl.advance(Q_block_ptr, (0, 0)), boundary_check=(0, ))
|
||||
q = (q * qk_scale).to(q.dtype)
|
||||
if PADDED_HEAD:
|
||||
q = tl.where(d_mask[None, :], q, 0.0)
|
||||
|
||||
# loop over k, v and update accumulator
|
||||
for start_n in range(lo, hi, BLOCK_N):
|
||||
k, v = load_k_v_group(
|
||||
K_block_ptr,
|
||||
V_block_ptr,
|
||||
K_scale_shift_block_ptr,
|
||||
V_scale_shift_block_ptr,
|
||||
BOUNDS_CHECKS_N,
|
||||
1,
|
||||
BLOCK_DMODEL,
|
||||
ACTUAL_BLOCK_DMODEL,
|
||||
Q.dtype.element_ty,
|
||||
0,
|
||||
)
|
||||
if PADDED_HEAD:
|
||||
k = tl.where(d_mask[:, None], k, 0.0)
|
||||
v = tl.where(d_mask[None, :], v, 0.0)
|
||||
|
||||
# -- compute qk ---
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
qk += tl.dot(q, k) # noqa: F821
|
||||
|
||||
if USE_ALIBI:
|
||||
row_idx = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
col_idx = start_n + tl.arange(0, BLOCK_N)
|
||||
|
||||
# Compute relative positions
|
||||
relative_pos = row_idx[:, None] + kv_len - (N_CTX_Q + col_idx[None, :])
|
||||
relative_pos = tl.abs(relative_pos)
|
||||
|
||||
# Compute ALiBi bias
|
||||
alibi_bias = -1 * alibi_slope * relative_pos
|
||||
qk += (alibi_bias * 1.44269504)
|
||||
|
||||
# Apply causal mask if IS_CAUSAL is True
|
||||
if IS_CAUSAL:
|
||||
row_idx = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
col_idx = start_n + tl.arange(0, BLOCK_N)
|
||||
|
||||
# create a N_CTX_Q x kv_len causal mask
|
||||
col_offset = N_CTX_Q - kv_len
|
||||
causal_mask = row_idx[:, None] >= (col_offset + col_idx[None, :])
|
||||
|
||||
# Apply the mask
|
||||
qk = tl.where(causal_mask, qk, float("-inf"))
|
||||
|
||||
# TODO: This is slow, and only needed at the last iteration.
|
||||
# Maybe we can unroll the last iteration instead?
|
||||
if BOUNDS_CHECKS_N:
|
||||
qk = tl.where(tl.arange(0, BLOCK_N) < hi - start_n, qk, float("-inf"))
|
||||
|
||||
# -- compute scaling constant ---
|
||||
m_i_new = tl.maximum(m_i, tl.max(qk, 1))
|
||||
if IS_CAUSAL:
|
||||
alpha = tl.math.exp2(tl.where(m_i > float("-inf"), m_i - m_i_new, float("-inf")))
|
||||
else:
|
||||
alpha = tl.math.exp2(m_i - m_i_new)
|
||||
# cause of nan because subtracting infs
|
||||
if IS_CAUSAL:
|
||||
qk = tl.where(qk > float("-inf"), qk - m_i_new[:, None], float("-inf"))
|
||||
else:
|
||||
qk = qk - m_i_new[:, None]
|
||||
|
||||
p = tl.math.exp2(qk)
|
||||
|
||||
# -- update m_i and l_i --
|
||||
l_i = l_i * alpha + tl.sum(p, 1)
|
||||
m_i = m_i_new
|
||||
p = p.to(Q.dtype.element_ty)
|
||||
|
||||
# -- scale and update acc --
|
||||
acc *= alpha[:, None]
|
||||
acc += tl.dot(p.to(v.dtype), v)
|
||||
|
||||
# update pointers
|
||||
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
|
||||
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
|
||||
|
||||
# write back O
|
||||
O_block_ptr = tl.make_block_ptr(
|
||||
base=Out_splitK + off_zhg * stride_osk_zhg + splitk_idx * stride_osk_s,
|
||||
shape=(N_CTX_Q, BLOCK_DMODEL),
|
||||
strides=(stride_osk_m, 1),
|
||||
offsets=(start_m * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
tl.store(
|
||||
tl.advance(O_block_ptr, (0, 0)),
|
||||
acc,
|
||||
boundary_check=(0, ),
|
||||
)
|
||||
# Write metadata for split-K reduction
|
||||
Metadata_ptr = (Metadata + off_zhg * stride_mzhg + splitk_idx * stride_ms + start_m * BLOCK_M +
|
||||
tl.arange(0, BLOCK_M))
|
||||
tl.store(Metadata_ptr, m_i)
|
||||
tl.store(Metadata_ptr + stride_m2, l_i)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def load_k_v_group(
|
||||
K_block_ptr,
|
||||
V_block_ptr,
|
||||
K_scale_shift_block_ptr, V_scale_shift_block_ptr, # pylint: disable=unused-argument
|
||||
BOUNDS_CHECKS_N: tl.constexpr,
|
||||
PACKED_PER_VAL: tl.constexpr, BLOCK_DMODEL: tl.constexpr, # pylint: disable=unused-argument
|
||||
ACTUAL_BLOCK_DMODEL: tl.constexpr,
|
||||
dtype: tl.constexpr, # pylint: disable=unused-argument
|
||||
group_id: tl.constexpr,
|
||||
):
|
||||
# Load K/V for a given block
|
||||
# Advance to the current quantization group
|
||||
K_block_ptr = tl.advance(K_block_ptr, (ACTUAL_BLOCK_DMODEL * group_id, 0))
|
||||
V_block_ptr = tl.advance(V_block_ptr, (0, ACTUAL_BLOCK_DMODEL * group_id))
|
||||
|
||||
# -- load k, v --
|
||||
k = tl.load(K_block_ptr, boundary_check=(1, ) if BOUNDS_CHECKS_N else ())
|
||||
v = tl.load(V_block_ptr, boundary_check=(0, ) if BOUNDS_CHECKS_N else ())
|
||||
|
||||
return k, v
|
||||
|
||||
|
||||
@triton.jit
|
||||
def cast_uint32_to_half2(scale_shift):
|
||||
# Extract two float16 packed into one int32
|
||||
scale = scale_shift & 0xFFFF
|
||||
shift = scale_shift >> 16
|
||||
scale = scale.to(tl.uint16).to(tl.float16, bitcast=True)
|
||||
shift = shift.to(tl.uint16).to(tl.float16, bitcast=True)
|
||||
return scale, shift
|
||||
|
||||
|
||||
@triton.jit
|
||||
def dequantize(
|
||||
x_,
|
||||
scale,
|
||||
shift,
|
||||
PACKED_PER_VAL: tl.constexpr = 8,
|
||||
):
|
||||
# PACKED_PER_VAL is the number of values packed into
|
||||
# each element x_. For example, for int4 quantization
|
||||
#and x_ of type int32, PACKED_PER_VAL is 8.
|
||||
|
||||
BLOCK_N: tl.constexpr = x_.shape[0]
|
||||
BLOCK_DMODEL_PACKED: tl.constexpr = x_.shape[1]
|
||||
offsets = tl.arange(0, PACKED_PER_VAL) * 4
|
||||
quant_offset = (x_[:, None, :] >> offsets[None, :, None]) # (BLOCK_N, PACKED_PER_VAL, D // PACKED_PER_VAL)
|
||||
|
||||
quant_offset = tl.view(quant_offset, (BLOCK_N, BLOCK_DMODEL_PACKED * PACKED_PER_VAL))
|
||||
# Trick - instead of converting int4 to float16 we view it as float16
|
||||
# and then multiply by 32768 * 512 == 2**24
|
||||
quant_offset = (quant_offset & 0xF).to(tl.uint16).to(tl.float16, bitcast=True)
|
||||
quant_offset = (quant_offset * 32768.0).to(tl.float16)
|
||||
scale_512 = scale * 512
|
||||
|
||||
dequant = quant_offset * scale_512 + shift
|
||||
return dequant
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _splitK_reduce(
|
||||
Out_splitK, # [B, H, split_k, Mq, K]
|
||||
Metadata, # [B, H, 2, split_k, M_ceil] contains [mi, li]
|
||||
Out, # [B, H, M, K]
|
||||
LSE, # [B, H, M]
|
||||
stride_osk_zhg,
|
||||
stride_osk_s,
|
||||
stride_osk_m,
|
||||
stride_osk_k,
|
||||
stride_mzhg,
|
||||
stride_m2,
|
||||
stride_ms,
|
||||
stride_mm,
|
||||
stride_oz,
|
||||
stride_oh,
|
||||
stride_og,
|
||||
stride_om,
|
||||
stride_ok, # pylint: disable=unused-argument
|
||||
stride_lse_zhg,
|
||||
stride_lse_m, M_ceil: tl.constexpr, # pylint: disable=unused-argument
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
H: tl.constexpr,
|
||||
G: tl.constexpr,
|
||||
split_k: tl.constexpr,
|
||||
splitK_pow2: tl.constexpr,
|
||||
use_mask: tl.constexpr,
|
||||
IS_CAUSAL: tl.constexpr,
|
||||
):
|
||||
off_zhg = tl.program_id(0)
|
||||
off_z = off_zhg // (H * G)
|
||||
off_h = (off_zhg // G) % H
|
||||
off_g = off_zhg % G
|
||||
off_m = tl.program_id(1)
|
||||
off_k = tl.program_id(2)
|
||||
|
||||
# read chunk
|
||||
spk_idx = tl.arange(0, splitK_pow2)
|
||||
kidx = tl.arange(0, BLOCK_SIZE)
|
||||
|
||||
Metadata_ptr = Metadata + stride_mzhg * off_zhg + spk_idx * stride_ms + off_m * stride_mm
|
||||
|
||||
o_ptr = (Out_splitK + off_zhg * stride_osk_zhg + stride_osk_m * off_m + off_k * BLOCK_SIZE +
|
||||
stride_osk_s * spk_idx[:, None] + kidx[None, :] * stride_osk_k)
|
||||
|
||||
# read max values of each splitK
|
||||
if use_mask:
|
||||
spk_mask = spk_idx < split_k
|
||||
l_m = tl.load(Metadata_ptr, mask=spk_mask, other=float("-inf"))
|
||||
l_sum = tl.load(Metadata_ptr + stride_m2, mask=spk_mask, other=0.0)
|
||||
acc = tl.load(o_ptr, mask=spk_mask[:, None], other=0.0)
|
||||
else:
|
||||
l_m = tl.load(Metadata_ptr)
|
||||
l_sum = tl.load(Metadata_ptr + stride_m2)
|
||||
acc = tl.load(o_ptr)
|
||||
|
||||
g_m = tl.max(l_m, axis=0)
|
||||
|
||||
if IS_CAUSAL:
|
||||
l_m_offset = l_m - g_m
|
||||
alpha = tl.where(l_m_offset > float("-inf"), tl.math.exp2(l_m_offset), 0.0)
|
||||
else:
|
||||
alpha = tl.math.exp2(l_m - g_m)
|
||||
|
||||
# read sum
|
||||
l_sum *= alpha
|
||||
g_sum = tl.sum(l_sum, axis=0)
|
||||
acc = acc * alpha[:, None]
|
||||
|
||||
if IS_CAUSAL:
|
||||
# Avoid division by zero
|
||||
g_sum_safe = tl.where(g_sum > 0, g_sum, 1.0)
|
||||
acc_out = tl.sum(acc, axis=0) / g_sum_safe
|
||||
else:
|
||||
acc_out = tl.sum(acc, axis=0) / g_sum
|
||||
|
||||
# Store output
|
||||
Out_ptr = (Out + stride_oz * off_z + stride_oh * off_h + stride_og * off_g + stride_om * off_m +
|
||||
off_k * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE))
|
||||
tl.store(Out_ptr, acc_out)
|
||||
|
||||
# Store lse
|
||||
l_ptrs = LSE + off_zhg * stride_lse_zhg + off_m
|
||||
if IS_CAUSAL:
|
||||
lse = tl.where(g_sum > 0, (g_m + tl.math.log2(g_sum)) / 1.44269504, g_m)
|
||||
tl.store(l_ptrs, lse)
|
||||
else:
|
||||
tl.store(l_ptrs, (g_m + tl.math.log2(g_sum)) / 1.44269504)
|
||||
|
||||
|
||||
def quantize_kv_int4(k: torch.Tensor, num_groups: int = 1) -> torch.Tensor:
|
||||
# Scale and shift are such that quantization linearly maps
|
||||
# int4 values range [0..15] to input values range min(k)..max(k)
|
||||
# individually for every row
|
||||
k = k.reshape(*k.shape[:-1], num_groups, k.shape[-1] // num_groups)
|
||||
max_vals = torch.max(k, dim=-1, keepdim=True).values
|
||||
min_vals = torch.min(k, dim=-1, keepdim=True).values
|
||||
scale_k: torch.Tensor = (max_vals - min_vals) / 15
|
||||
|
||||
shift_k = torch.min(k, dim=-1, keepdim=True).values
|
||||
scale_k = scale_k.to(torch.float16)
|
||||
shift_k = shift_k.to(torch.float16)
|
||||
|
||||
in_bytes = ((k - shift_k.expand(k.shape)) / scale_k.expand(k.shape)) + 0.5
|
||||
in_bytes = in_bytes.to(torch.uint8)
|
||||
in_int4 = in_bytes & 0xF
|
||||
in_int4_packed = in_int4[..., ::2] + (in_int4[..., 1::2] << 4)
|
||||
scale_shift = torch.concat([scale_k.view(torch.uint8), shift_k.view(torch.uint8)], dim=-1)
|
||||
k_quant = torch.concat(
|
||||
[
|
||||
scale_shift.flatten(start_dim=-2),
|
||||
in_int4_packed.flatten(start_dim=-2),
|
||||
],
|
||||
dim=-1,
|
||||
).view(torch.int16)
|
||||
return k_quant
|
||||
|
||||
|
||||
def dequantize_kv_fp16(quant_k: torch.Tensor, num_groups: int = 1) -> torch.Tensor:
|
||||
k_i16 = quant_k.view(torch.int16)
|
||||
k_ui8 = k_i16.view(torch.uint8)
|
||||
|
||||
ss_size = num_groups * 4
|
||||
scale_shift_ui8 = k_ui8[..., 0:ss_size]
|
||||
scale_shift_ui8 = scale_shift_ui8.reshape(*scale_shift_ui8.shape[:-1], num_groups, 4)
|
||||
scale = scale_shift_ui8[..., 0:2].view(torch.float16)
|
||||
shift = scale_shift_ui8[..., 2:4].view(torch.float16)
|
||||
|
||||
kv_ui8 = k_ui8[..., ss_size:]
|
||||
k_ui8 = kv_ui8.reshape(*kv_ui8.shape[:-1], num_groups, -1)
|
||||
k1_i4 = k_ui8 & 0xF
|
||||
k2_i4 = (k_ui8 & 0xF0) >> 4
|
||||
k_shape = k1_i4.shape
|
||||
k1_f16 = k1_i4.to(torch.float16) * scale.expand(k_shape) + shift.expand(k_shape)
|
||||
k2_f16 = k2_i4.to(torch.float16) * scale.expand(k_shape) + shift.expand(k_shape)
|
||||
|
||||
out = torch.empty((*k1_f16.shape[:-1], k1_f16.shape[-1] * 2), dtype=torch.float16, device=quant_k.device)
|
||||
out[..., ::2] = k1_f16
|
||||
out[..., 1::2] = k2_f16
|
||||
out = out.reshape(*k_shape[:-2], -1)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def get_split_k(B: int, G: int, H: int, Mk: int) -> int:
|
||||
"""Heuristic for the number of splits"""
|
||||
bh = max(B * H, 1) # NOTE: Handle B*h=0 case
|
||||
split_k = max(Mk, 1024) // bh
|
||||
max_chunk_size = 64
|
||||
while split_k > 0 and Mk / split_k < max_chunk_size:
|
||||
split_k = split_k // 2
|
||||
while B * H * G * split_k >= 1024:
|
||||
split_k = split_k // 2
|
||||
split_k = min(split_k, 512)
|
||||
split_k = max(split_k, 1)
|
||||
return split_k
|
||||
|
||||
def attention_decode_forward_triton_impl(q, k, v, sm_scale, causal, alibi_slopes, layout, cache_seqlens, cache_batch_idx, new_kv, k_new, v_new):
|
||||
# kernel config
|
||||
BLOCK_M = 16
|
||||
BLOCK_N = 64
|
||||
SPLIT_K = None
|
||||
NUM_QUANT_GROUPS = 1 # pylint: disable=unused-variable
|
||||
|
||||
# kernels expects "bsghd"
|
||||
original_layout = layout
|
||||
if layout == "bshd":
|
||||
q = q.unsqueeze(2)
|
||||
k = k.unsqueeze(2)
|
||||
v = v.unsqueeze(2)
|
||||
if new_kv:
|
||||
k_new = k_new.unsqueeze(2)
|
||||
v_new = v_new.unsqueeze(2)
|
||||
layout = "bsghd"
|
||||
elif layout == "bhsd":
|
||||
q = q.permute(0, 2, 1, 3).unsqueeze(2)
|
||||
k = k.permute(0, 2, 1, 3).unsqueeze(2)
|
||||
v = v.permute(0, 2, 1, 3).unsqueeze(2)
|
||||
if new_kv:
|
||||
k_new = k_new.permute(0, 2, 1, 3).unsqueeze(2)
|
||||
v_new = v_new.permute(0, 2, 1, 3).unsqueeze(2)
|
||||
layout = "bsghd"
|
||||
elif layout == "bsghd":
|
||||
pass
|
||||
elif layout is None:
|
||||
raise ValueError("Layout not given")
|
||||
assert layout == "bsghd"
|
||||
|
||||
# get dims
|
||||
batch_size, seqlen_q, n_group_q, heads_per_group_q, dim_q = q.shape
|
||||
_, seqlen_k, n_group_k, heads_per_group_k, dim_k = k.shape # pylint: disable=unused-variable
|
||||
_, seqlen_v, n_group_v, heads_per_group_v, dim_v = v.shape # pylint: disable=unused-variable
|
||||
|
||||
assert dim_q == dim_k == dim_v, f"Dimensions must match: {dim_q}, {dim_k}, {dim_v}"
|
||||
|
||||
# get padded size
|
||||
dim_padded = get_padded_headsize(dim_k)
|
||||
|
||||
# Handle MQA/GQA case
|
||||
if heads_per_group_q > heads_per_group_k:
|
||||
is_gqa = True
|
||||
elif heads_per_group_q < heads_per_group_k:
|
||||
raise ValueError("heads_per_group_q < heads_per_group_k")
|
||||
else:
|
||||
is_gqa = False
|
||||
|
||||
assert dim_k == dim_q, f"Keys have head dim {dim_k} but queries have head dim {dim_q}"
|
||||
|
||||
if SPLIT_K is not None:
|
||||
split_k = SPLIT_K
|
||||
else:
|
||||
# Use heuristics
|
||||
split_k = get_split_k(batch_size, n_group_q, heads_per_group_q, seqlen_k) # NOTE: should the split think about seqlens?
|
||||
|
||||
seqlen_q_ceil = (seqlen_q + BLOCK_M - 1) // BLOCK_M * BLOCK_M
|
||||
out_splitk = torch.empty([batch_size * n_group_q * heads_per_group_q, split_k, seqlen_q_ceil, dim_padded], dtype=torch.float32, device=q.device)
|
||||
metadata = torch.empty([batch_size * n_group_q * heads_per_group_q, 2, split_k, seqlen_q_ceil], dtype=torch.float32, device=q.device)
|
||||
lse = torch.empty((batch_size * n_group_q * heads_per_group_q, seqlen_q), device=q.device, dtype=torch.float32)
|
||||
grid = (triton.cdiv(seqlen_q, BLOCK_M), batch_size * n_group_q * heads_per_group_q, split_k)
|
||||
|
||||
num_warps = 1
|
||||
split_size = (seqlen_k + split_k - 1) // split_k
|
||||
use_cache_seqlens = cache_seqlens is not None
|
||||
|
||||
# TODO: enable quantization
|
||||
_fwd_kernel_splitK[grid](
|
||||
Q=q,
|
||||
K=k,
|
||||
V=v,
|
||||
sm_scale=sm_scale,
|
||||
Out_splitK=out_splitk,
|
||||
Metadata=metadata,
|
||||
K_new = k_new,
|
||||
V_new = v_new,
|
||||
Cache_seqlens=cache_seqlens,
|
||||
Cache_batch_idx=cache_batch_idx,
|
||||
Alibi_slopes=alibi_slopes,
|
||||
**_strides(q, "qz", "qm", "qg", "qh", "qd"),
|
||||
**_strides(k, "kz", "kn", "kg", "kh", "kd"),
|
||||
**_strides(v, "vz", "vn", "vg", "vh", "vd"),
|
||||
**_strides(out_splitk, "osk_zhg", "osk_s", "osk_m", "osk_d"),
|
||||
**_strides(metadata, "mzhg", "m2", "ms", "mm"),
|
||||
**_strides(k_new, "kn_z", "kn_n", "kn_g", "kn_h", "kn_d"),
|
||||
**_strides(v_new, "vn_z", "vn_n", "vn_g", "vn_h", "vn_d"),
|
||||
**_strides(alibi_slopes, "az", "ah"),
|
||||
Z=batch_size,
|
||||
H_q=heads_per_group_q,
|
||||
H_kv=heads_per_group_k,
|
||||
G_q=n_group_q,
|
||||
N_CTX_Q=seqlen_q,
|
||||
N_CTX_K=seqlen_k,
|
||||
N_CTX_NEW=k_new.shape[1] if new_kv else None,
|
||||
BLOCK_N_PER_SPLIT=split_size,
|
||||
BLOCK_M=BLOCK_M,
|
||||
BLOCK_N=BLOCK_N,
|
||||
BLOCK_DMODEL=dim_padded,
|
||||
ACTUAL_BLOCK_DMODEL=dim_k,
|
||||
BOUNDS_CHECKS_N=(split_size % BLOCK_N) > 0 or use_cache_seqlens,
|
||||
USE_CACHE_SEQLENs=use_cache_seqlens,
|
||||
USE_CACHE_BATCH_IDX=cache_batch_idx is not None,
|
||||
NEW_KV=new_kv,
|
||||
IS_GQA=is_gqa,
|
||||
IS_CAUSAL=causal,
|
||||
USE_ALIBI=False if alibi_slopes is None else True,
|
||||
num_warps=num_warps,
|
||||
num_stages=1,
|
||||
)
|
||||
|
||||
out = torch.empty((batch_size, seqlen_q, n_group_q, heads_per_group_q, dim_padded), device=q.device, dtype=q.dtype)
|
||||
|
||||
# Merge together
|
||||
splitK_pow2 = triton.next_power_of_2(split_k)
|
||||
use_mask = splitK_pow2 > split_k
|
||||
if batch_size * n_group_q * heads_per_group_q * seqlen_q >= 512:
|
||||
k_block_num = 1
|
||||
else:
|
||||
k_block_num = 2
|
||||
assert dim_padded % k_block_num == 0
|
||||
k_block_size = dim_padded // k_block_num
|
||||
grid = (batch_size * n_group_q * heads_per_group_q, seqlen_q, k_block_num)
|
||||
|
||||
_splitK_reduce[grid](
|
||||
out_splitk,
|
||||
metadata,
|
||||
out,
|
||||
lse,
|
||||
**_strides(out_splitk, "osk_zhg", "osk_s", "osk_m", "osk_k"),
|
||||
**_strides(metadata, "mzhg", "m2", "ms", "mm"),
|
||||
**_strides(out, "oz", "om", "og", "oh", "ok"),
|
||||
**_strides(lse, "lse_zhg", "lse_m"),
|
||||
M_ceil=seqlen_q_ceil,
|
||||
BLOCK_SIZE=k_block_size,
|
||||
G=n_group_q,
|
||||
H=heads_per_group_q,
|
||||
# TODO: Tune num_warps
|
||||
split_k=split_k,
|
||||
splitK_pow2=splitK_pow2,
|
||||
use_mask=use_mask,
|
||||
IS_CAUSAL=causal,
|
||||
num_warps=4)
|
||||
|
||||
lse = lse.reshape([batch_size, n_group_q, heads_per_group_q, seqlen_q])
|
||||
if q.ndim == 4:
|
||||
# BMGHK -> BMHK
|
||||
assert n_group_q == 1
|
||||
out = out[:, :, 0]
|
||||
lse = lse[:, 0]
|
||||
if seqlen_k == 0:
|
||||
out.zero_()
|
||||
out = out.reshape(batch_size, heads_per_group_q * n_group_q, -1, dim_padded).contiguous()
|
||||
|
||||
# output is batch_size, heads_per_group_q * group_q, seqlen_q, dim_q
|
||||
if original_layout == "bshd":
|
||||
# out=out.transpose(1, 2).contiguous() # this screws up heads and data.
|
||||
# the data is laid out properly. Just need to reshape dims
|
||||
out = out.reshape(batch_size, seqlen_q, -1, dim_padded)
|
||||
|
||||
return out.narrow(-1, 0, dim_k), lse
|
||||
634
comfy/flash_attn_triton_amd/fwd_prefill.py
Normal file
634
comfy/flash_attn_triton_amd/fwd_prefill.py
Normal file
@ -0,0 +1,634 @@
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from comfy.flash_attn_triton_amd.utils import get_shape_from_layout, get_strides_from_layout, is_cdna, is_rdna, AUTOTUNE
|
||||
|
||||
|
||||
@triton.jit
|
||||
def cdiv_fn(x, y):
|
||||
return (x + y - 1) // y
|
||||
|
||||
|
||||
@triton.jit
|
||||
def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride): # pylint: disable=unused-argument
|
||||
ms = tl.arange(0, m)
|
||||
ns = tl.arange(0, n)
|
||||
return philox_offset + ms[:, None] * stride + ns[None, :]
|
||||
|
||||
|
||||
@triton.jit
|
||||
def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride):
|
||||
rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride).to(tl.uint32)
|
||||
# TODO: use tl.randint for better performance
|
||||
return tl.rand(philox_seed, rng_offsets)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride):
|
||||
rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride)
|
||||
rng_keep = rng_output > dropout_p
|
||||
return rng_keep
|
||||
|
||||
|
||||
# Convenience function to load with optional boundary checks.
|
||||
# "First" is the major dim, "second" is the minor dim.
|
||||
@triton.jit
|
||||
def load_fn(ptrs, offset_first, offset_second, boundary_first, boundary_second):
|
||||
if offset_first is not None and offset_second is not None:
|
||||
mask = (offset_first[:, None] < boundary_first) & \
|
||||
(offset_second[None, :] < boundary_second)
|
||||
tensor = tl.load(ptrs, mask=mask, other=0.0)
|
||||
elif offset_first is not None:
|
||||
mask = offset_first[:, None] < boundary_first
|
||||
tensor = tl.load(ptrs, mask=mask, other=0.0)
|
||||
elif offset_second is not None:
|
||||
mask = offset_second[None, :] < boundary_second
|
||||
tensor = tl.load(ptrs, mask=mask, other=0.0)
|
||||
else:
|
||||
tensor = tl.load(ptrs)
|
||||
return tensor
|
||||
|
||||
|
||||
@triton.jit
|
||||
def compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpose=False):
|
||||
# when seqlen_k and seqlen_q are different we want the diagonal to stick to the bottom right of the attention matrix
|
||||
# for casual mask we want something like this where (1 is kept and 0 is masked)
|
||||
# seqlen_q = 2 and seqlen_k = 5
|
||||
# 1 1 1 1 0
|
||||
# 1 1 1 1 1
|
||||
# seqlen_q = 5 and seqlen_k = 2
|
||||
# 0 0
|
||||
# 0 0
|
||||
# 0 0
|
||||
# 1 0
|
||||
# 1 1
|
||||
# for alibi the diagonal is 0 indicating no penalty for attending to that spot and increasing penalty for attending further from the diagonal
|
||||
# e.g. alibi_slope = 1, seqlen_q = 2, seqlen_k = 5, offs_m = [0, 1, 2, 3], offs_n = [0, 1, 2, 3, 4], transpose = False
|
||||
# 1. offs_m[:,None] = [[0],
|
||||
# [1],
|
||||
# 2. offs_m[:,None] + seqlen_k = [[5],
|
||||
# [6],
|
||||
# 3. offs_m[:,None] + seqlen_k - seqlen_q = [[3],
|
||||
# [4],
|
||||
# 4. offs_m[:,None] + seqlen_k - seqlen_q - offs_n[None,:] = [[3], - [[0, 1, 2, 3, 4]] = [[ 3, 2, 1, 0,-1],
|
||||
# [4], [ 4, 3, 2, 1, 0]]
|
||||
# 5. -1 * alibi_slope * tl.abs(relative_pos_block) = [[ -3, -2, -1, 0,-1],
|
||||
# [ -4, -3, -2, -1, 0]],
|
||||
relative_pos_block = offs_m[:, None] + seqlen_k - seqlen_q - offs_n[None, :]
|
||||
alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block)
|
||||
if transpose:
|
||||
return alibi_block.T
|
||||
else:
|
||||
return alibi_block
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, start_m,
|
||||
actual_seqlen_k, actual_seqlen_q, dropout_p, philox_seed, batch_philox_offset, exp_scores_ptrs,
|
||||
block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, alibi_slope, score_ptrs, scores_scaled_shifted_ptrs, # pylint: disable=unused-argument
|
||||
IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
|
||||
OFFS_M: tl.constexpr, OFFS_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, MASK_STEPS: tl.constexpr,
|
||||
ENABLE_DROPOUT: tl.constexpr, PADDED_HEAD: tl.constexpr,
|
||||
ACTUAL_BLOCK_DMODEL: tl.constexpr, SM_SCALE: tl.constexpr, USE_EXP2: tl.constexpr,
|
||||
RETURN_SCORES: tl.constexpr):
|
||||
if USE_EXP2:
|
||||
RCP_LN2: tl.constexpr = 1.4426950408889634
|
||||
|
||||
# loop over k, v, and update accumulator
|
||||
for start_n in range(block_min, block_max, BLOCK_N):
|
||||
# For padded blocks, we will overrun the tensor size if
|
||||
# we load all BLOCK_N. For others, the blocks are all within range.
|
||||
if MASK_STEPS:
|
||||
k_offs_n = start_n + tl.arange(0, BLOCK_N)
|
||||
else:
|
||||
k_offs_n = None
|
||||
k_offs_k = None if not PADDED_HEAD else tl.arange(0, BLOCK_DMODEL)
|
||||
k = load_fn(k_ptrs, k_offs_k, k_offs_n, ACTUAL_BLOCK_DMODEL, actual_seqlen_k)
|
||||
if PRE_LOAD_V:
|
||||
# We can use the same offsets as k, just with dims transposed.
|
||||
v = load_fn(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, ACTUAL_BLOCK_DMODEL)
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
# We start from end of seqlen_k so only the first iteration would need
|
||||
# to be checked for padding if it is not a multiple of block_n
|
||||
# TODO: This can be optimized to only be true for the padded block.
|
||||
if MASK_STEPS:
|
||||
# If this is the last block / iteration, we want to
|
||||
# mask if the sequence length is not a multiple of block size
|
||||
# a solution is to always do BLOCK_M // BLOCK_N + 1 steps if not is_modulo_mn.
|
||||
# last step might get wasted but that is okay. check if this masking works For
|
||||
# that case.
|
||||
if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0):
|
||||
boundary_m = tl.full([BLOCK_M], actual_seqlen_k, dtype=tl.int32)
|
||||
size_n = start_n + OFFS_N[None, :]
|
||||
mask = size_n < boundary_m[:, None]
|
||||
qk = tl.where(mask, qk, float("-inf"))
|
||||
|
||||
# -- compute qk ----
|
||||
qk += tl.dot(q, k)
|
||||
qk_scaled = qk * SM_SCALE
|
||||
if RETURN_SCORES:
|
||||
score_mask = (OFFS_M[:, None] < actual_seqlen_q) & ((start_n + tl.arange(0, BLOCK_N))[None, :] < actual_seqlen_k)
|
||||
tl.store(score_ptrs, qk_scaled, mask=score_mask)
|
||||
|
||||
if IS_CAUSAL:
|
||||
causal_boundary = start_n + offs_n_causal
|
||||
causal_mask = OFFS_M[:, None] >= causal_boundary[None, :]
|
||||
qk_scaled = tl.where(causal_mask, qk_scaled, float("-inf"))
|
||||
if bias_ptrs is not None:
|
||||
bias_offs_n = start_n + tl.arange(0, BLOCK_N) if MASK_STEPS else None
|
||||
bias = load_fn(bias_ptrs, OFFS_M, bias_offs_n, actual_seqlen_q, actual_seqlen_k)
|
||||
qk_scaled += bias
|
||||
|
||||
if alibi_slope is not None:
|
||||
# Compute the global position of each token within the sequence
|
||||
global_m_positions = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
global_n_positions = start_n + tl.arange(0, BLOCK_N)
|
||||
alibi_block = compute_alibi_block(alibi_slope, actual_seqlen_q, actual_seqlen_k, global_m_positions,
|
||||
global_n_positions)
|
||||
qk_scaled += alibi_block
|
||||
# get max scores so far
|
||||
m_ij = tl.maximum(m_i, tl.max(qk_scaled, 1))
|
||||
|
||||
# scale and subtract max
|
||||
q_shifted = qk_scaled - m_ij[:, None]
|
||||
if RETURN_SCORES:
|
||||
# NOTE: the returned score is not the same as the reference because we need to adjust as we find new maxes per block. We are not doing that
|
||||
scores_scaled_shifted_mask = (OFFS_M[:, None] < actual_seqlen_q) & ((start_n + tl.arange(0, BLOCK_N))[None, :] < actual_seqlen_k)
|
||||
tl.store(scores_scaled_shifted_ptrs, q_shifted, mask=scores_scaled_shifted_mask)
|
||||
|
||||
# Compute scaled QK and softmax probabilities
|
||||
if USE_EXP2:
|
||||
p = tl.math.exp2(q_shifted * RCP_LN2)
|
||||
else:
|
||||
p = tl.math.exp(q_shifted)
|
||||
|
||||
# CAVEAT: Must update l_ij before applying dropout
|
||||
l_ij = tl.sum(p, 1)
|
||||
if ENABLE_DROPOUT:
|
||||
philox_offset = batch_philox_offset + start_m * BLOCK_M * actual_seqlen_k + start_n - BLOCK_N
|
||||
keep = dropout_mask(philox_seed, philox_offset, dropout_p, BLOCK_M, BLOCK_N, actual_seqlen_k)
|
||||
if RETURN_SCORES:
|
||||
# NOTE: the returned score is not the same as the reference because we need to adjust as we find new maxes per block. We are not doing that
|
||||
exp_score_mask = (OFFS_M[:, None] < actual_seqlen_q) & ((start_n + tl.arange(0, BLOCK_N))[None, :] < actual_seqlen_k)
|
||||
tl.store(exp_scores_ptrs, tl.where(keep, p, -p), mask=exp_score_mask)
|
||||
p = tl.where(keep, p, 0.0)
|
||||
elif RETURN_SCORES:
|
||||
# NOTE: the returned score is not the same as the reference because we need to adjust as we find new maxes per block. We are not doing that
|
||||
exp_score_mask = (OFFS_M[:, None] < actual_seqlen_q) & ((start_n + tl.arange(0, BLOCK_N))[None, :] < actual_seqlen_k)
|
||||
tl.store(exp_scores_ptrs, p, mask=exp_score_mask)
|
||||
|
||||
# -- update output accumulator --
|
||||
# alpha is an adjustment factor for acc and li as we loop and find new maxes
|
||||
# store the diff in maxes to adjust acc and li as we discover new maxes
|
||||
m_diff = m_i - m_ij
|
||||
if USE_EXP2:
|
||||
alpha = tl.math.exp2(m_diff * RCP_LN2)
|
||||
else:
|
||||
alpha = tl.math.exp(m_diff)
|
||||
acc = acc * alpha[:, None]
|
||||
v = None
|
||||
if not PRE_LOAD_V:
|
||||
v = load_fn(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, ACTUAL_BLOCK_DMODEL)
|
||||
# -- update m_i and l_i
|
||||
l_i = l_i * alpha + l_ij
|
||||
# update m_i and l_i
|
||||
m_i = m_ij
|
||||
acc += tl.dot(p.to(v.type.element_ty), v)
|
||||
k_ptrs += BLOCK_N * stride_kn
|
||||
v_ptrs += BLOCK_N * stride_vk
|
||||
if bias_ptrs is not None:
|
||||
bias_ptrs += BLOCK_N * stride_bn
|
||||
if RETURN_SCORES:
|
||||
score_ptrs += BLOCK_N
|
||||
scores_scaled_shifted_ptrs += BLOCK_N
|
||||
exp_scores_ptrs += BLOCK_N
|
||||
return acc, l_i, m_i
|
||||
|
||||
|
||||
def get_cdna_autotune_configs():
|
||||
return [
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
|
||||
num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
|
||||
num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': False}, num_stages=1,
|
||||
num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1,
|
||||
num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
|
||||
num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1,
|
||||
num_warps=4),
|
||||
# Fall-back config.
|
||||
triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1,
|
||||
num_warps=4),
|
||||
], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK']
|
||||
|
||||
|
||||
def get_rdna_autotune_configs():
|
||||
return [
|
||||
triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1,
|
||||
num_warps=2),
|
||||
triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
|
||||
num_warps=2),
|
||||
triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1,
|
||||
num_warps=2),
|
||||
triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
|
||||
num_warps=2),
|
||||
triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1,
|
||||
num_warps=2),
|
||||
triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
|
||||
num_warps=2),
|
||||
# Fall-back config.
|
||||
triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1,
|
||||
num_warps=2),
|
||||
], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK']
|
||||
|
||||
|
||||
def get_autotune_configs():
|
||||
if AUTOTUNE:
|
||||
if is_rdna():
|
||||
return get_rdna_autotune_configs()
|
||||
elif is_cdna():
|
||||
return get_cdna_autotune_configs()
|
||||
else:
|
||||
raise ValueError("Unknown Device Type")
|
||||
else:
|
||||
return [
|
||||
triton.Config(
|
||||
{"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 1, "PRE_LOAD_V": False},
|
||||
num_stages=1,
|
||||
num_warps=4,
|
||||
),
|
||||
], [
|
||||
"IS_CAUSAL",
|
||||
"dropout_p",
|
||||
"MAX_SEQLENS_Q",
|
||||
"MAX_SEQLENS_K",
|
||||
"ACTUAL_BLOCK_DMODEL",
|
||||
"VARLEN",
|
||||
"HQ",
|
||||
"HK",
|
||||
]
|
||||
|
||||
|
||||
autotune_configs, autotune_keys = get_autotune_configs()
|
||||
|
||||
@triton.autotune(
|
||||
configs=autotune_configs,
|
||||
key=autotune_keys,
|
||||
# use_cuda_graph=True,
|
||||
)
|
||||
@triton.jit
|
||||
def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_qh, stride_qm, stride_qk,
|
||||
stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn,
|
||||
stride_oz, stride_oh, stride_om, stride_on, stride_bz, stride_bh, stride_bm, stride_bn, stride_az, stride_ah, # pylint: disable=unused-argument
|
||||
stride_sz, stride_sh, stride_sm, stride_sn, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k,
|
||||
dropout_p, philox_seed, philox_offset_base, scores, scores_scaled_shifted, exp_scores, alibi_slopes, HQ: tl.constexpr,
|
||||
HK: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr,
|
||||
MAX_SEQLENS_K: tl.constexpr, VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr,
|
||||
BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr,
|
||||
ENABLE_DROPOUT: tl.constexpr, RETURN_SCORES: tl.constexpr, USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr):
|
||||
start_m = tl.program_id(0)
|
||||
off_h_q = tl.program_id(1)
|
||||
off_z = tl.program_id(2)
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
offs_n = tl.arange(0, BLOCK_N)
|
||||
offs_d = tl.arange(0, BLOCK_DMODEL)
|
||||
if VARLEN:
|
||||
cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)
|
||||
cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)
|
||||
# print("cu_seqlens_q_start:", cu_seqlens_q_start)
|
||||
|
||||
seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start
|
||||
# We have a one-size-fits-all grid in id(0). Some seqlens might be too
|
||||
# small for all start_m so for those we return early.
|
||||
if start_m * BLOCK_M > seqlen_q:
|
||||
return
|
||||
cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)
|
||||
cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)
|
||||
seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start
|
||||
else:
|
||||
cu_seqlens_q_start = 0
|
||||
cu_seqlens_k_start = 0
|
||||
seqlen_q = MAX_SEQLENS_Q
|
||||
seqlen_k = MAX_SEQLENS_K
|
||||
|
||||
# Now we compute whether we need to exit early due to causal masking.
|
||||
# This is because for seqlen_q > seqlen_k, M rows of the attn scores
|
||||
# are completely masked, resulting in 0s written to the output, and
|
||||
# inf written to LSE. We don't need to do any GEMMs in this case.
|
||||
# This block of code determines what N is, and if this WG is operating
|
||||
# on those M rows.
|
||||
n_blocks = cdiv_fn(seqlen_k, BLOCK_N)
|
||||
if IS_CAUSAL:
|
||||
# If seqlen_q == seqlen_k, the attn scores are a square matrix.
|
||||
# If seqlen_q != seqlen_k, attn scores are rectangular which means
|
||||
# the causal mask boundary is bottom right aligned, and ends at either
|
||||
# the top edge (seqlen_q < seqlen_k) or left edge.
|
||||
# This captures the decrease in n_blocks if we have a rectangular attn matrix
|
||||
n_blocks_seqlen = cdiv_fn((start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N)
|
||||
# This is what adjusts the block_max for the current WG, only
|
||||
# if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks
|
||||
n_blocks = min(n_blocks, n_blocks_seqlen)
|
||||
# If we have no blocks after adjusting for seqlen deltas, this WG is part of
|
||||
# the blocks that are all 0. We exit early.
|
||||
if n_blocks <= 0:
|
||||
o_offset = Out + off_z * stride_oz + off_h_q * stride_oh + cu_seqlens_q_start * stride_om
|
||||
o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_on
|
||||
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty)
|
||||
o_ptrs_mask = offs_m[:, None] < seqlen_q
|
||||
# We still need to write 0s to the result
|
||||
tl.store(o_ptrs, acc, mask=o_ptrs_mask)
|
||||
# The tensor allocated for L is based on MAX_SEQLENS_Q as that is
|
||||
# statically known.
|
||||
l_offset = LSE + off_z * stride_lse_z + off_h_q * stride_lse_h + cu_seqlens_q_start * stride_lse_m
|
||||
l_ptrs = l_offset + offs_m * stride_lse_m
|
||||
|
||||
l = tl.full([BLOCK_M], value=0.0, dtype=tl.float32)
|
||||
|
||||
# mask_m_offsets = start_m + tl.arange(0, BLOCK_M)
|
||||
# lse_mask = mask_m_offsets < causal_start_idx
|
||||
# softmax_lse = tl.where(lse_mask, 0.0, softmax_lse)
|
||||
l_ptrs_mask = offs_m < MAX_SEQLENS_Q
|
||||
tl.store(l_ptrs, l, mask=l_ptrs_mask)
|
||||
# TODO: Should dropout and return encoded softmax be handled here too?
|
||||
return
|
||||
|
||||
# If MQA / GQA, set the K and V head offsets appropriately.
|
||||
GROUP_SIZE: tl.constexpr = HQ // HK
|
||||
if GROUP_SIZE != 1:
|
||||
off_h_k = off_h_q // GROUP_SIZE
|
||||
else:
|
||||
off_h_k = off_h_q
|
||||
|
||||
n_extra_tokens = 0
|
||||
# print("n_extra_tokens:", n_extra_tokens)
|
||||
# print("seqlen_k:", seqlen_k)
|
||||
# print("BLOCK_N:", BLOCK_N)
|
||||
# return
|
||||
if seqlen_k < BLOCK_N:
|
||||
n_extra_tokens = BLOCK_N - seqlen_k
|
||||
elif seqlen_k % BLOCK_N:
|
||||
n_extra_tokens = seqlen_k % BLOCK_N
|
||||
PADDED_HEAD: tl.constexpr = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL
|
||||
|
||||
# Compute pointers for all the tensors used in this kernel.
|
||||
q_offset = Q + off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm
|
||||
q_ptrs = q_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
|
||||
k_offset = K + off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn
|
||||
k_ptrs = k_offset + offs_d[:, None] * stride_kk + offs_n[None, :] * stride_kn
|
||||
v_offset = V + off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk
|
||||
v_ptrs = v_offset + offs_n[:, None] * stride_vk + offs_d[None, :] * stride_vn
|
||||
if USE_BIAS:
|
||||
# Note: this might get large enough to overflow on some configs
|
||||
bias_offset = off_h_q * stride_bh
|
||||
bias_ptrs = bias + bias_offset + offs_m[:, None] * stride_bm + offs_n[None, :] * stride_bn
|
||||
else:
|
||||
bias_ptrs = None
|
||||
|
||||
if USE_ALIBI:
|
||||
a_offset = off_z * stride_az + off_h_q * stride_ah
|
||||
alibi_slope = tl.load(alibi_slopes + a_offset)
|
||||
else:
|
||||
alibi_slope = None
|
||||
|
||||
if RETURN_SCORES:
|
||||
scores_offset = scores + off_z * stride_sz + off_h_q * stride_sh + cu_seqlens_q_start * stride_sm
|
||||
score_ptrs = scores_offset + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn
|
||||
|
||||
scores_scaled_shifted_offset = scores_scaled_shifted + off_z * stride_sz + off_h_q * stride_sh + cu_seqlens_q_start * stride_sm
|
||||
scores_scaled_shifted_ptrs = scores_scaled_shifted_offset + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn
|
||||
|
||||
exp_scores_offset = exp_scores + off_z * stride_sz + off_h_q * stride_sh + cu_seqlens_q_start * stride_sm
|
||||
exp_scores_ptrs = exp_scores_offset + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn
|
||||
else:
|
||||
score_ptrs = None
|
||||
scores_scaled_shifted_ptrs = None
|
||||
exp_scores_ptrs = None
|
||||
|
||||
if ENABLE_DROPOUT:
|
||||
off_hz = off_z * HQ + off_h_q
|
||||
batch_philox_offset = philox_offset_base + off_hz * seqlen_q * seqlen_k
|
||||
else:
|
||||
batch_philox_offset = 0
|
||||
# initialize pointer to m and l
|
||||
m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
|
||||
l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||
# Q is loaded once at the beginning and shared by all N blocks.
|
||||
q_ptrs_mask = offs_m[:, None] < seqlen_q
|
||||
if PADDED_HEAD:
|
||||
q_ptrs_mask = q_ptrs_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL)
|
||||
q = tl.load(q_ptrs, mask=q_ptrs_mask, other=0.0)
|
||||
|
||||
# Here we compute how many full and masked blocks we have.
|
||||
padded_block_k = n_extra_tokens != 0
|
||||
is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0)
|
||||
if IS_CAUSAL:
|
||||
# There are always at least BLOCK_M // BLOCK_N masked blocks.
|
||||
# Additionally there might be one more due to dissimilar seqlens.
|
||||
masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn)
|
||||
else:
|
||||
# Padding on Q does not need to be masked in the FA loop.
|
||||
masked_blocks = padded_block_k
|
||||
# if IS_CAUSAL, not is_modulo_mn does not always result in an additional block.
|
||||
# In this case we might exceed n_blocks so pick the min.
|
||||
masked_blocks = min(masked_blocks, n_blocks)
|
||||
n_full_blocks = n_blocks - masked_blocks
|
||||
block_min = 0
|
||||
block_max = n_blocks * BLOCK_N
|
||||
# Compute for full blocks. Here we set causal to false regardless of its actual
|
||||
# value because there is no masking. Similarly we do not need padding.
|
||||
if n_full_blocks > 0:
|
||||
block_max = (n_blocks - masked_blocks) * BLOCK_N
|
||||
acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn,
|
||||
start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, batch_philox_offset,
|
||||
exp_scores_ptrs,
|
||||
# _, _, offs_n_causal, masked_blocks, n_extra_tokens, _
|
||||
block_min, block_max, 0, 0, 0, alibi_slope, score_ptrs, scores_scaled_shifted_ptrs,
|
||||
# IS_CAUSAL, ....
|
||||
False, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n,
|
||||
# _, MASK_STEPS, ...
|
||||
PRE_LOAD_V, False, ENABLE_DROPOUT, PADDED_HEAD,
|
||||
ACTUAL_BLOCK_DMODEL, SM_SCALE, USE_EXP2=USE_EXP2, RETURN_SCORES=RETURN_SCORES)
|
||||
block_min = block_max
|
||||
block_max = n_blocks * BLOCK_N
|
||||
|
||||
tl.debug_barrier()
|
||||
# Remaining blocks, if any, are full / not masked.
|
||||
if masked_blocks > 0:
|
||||
if IS_CAUSAL:
|
||||
offs_n_causal = offs_n + (seqlen_q - seqlen_k)
|
||||
else:
|
||||
offs_n_causal = 0
|
||||
k_ptrs += n_full_blocks * BLOCK_N * stride_kn
|
||||
v_ptrs += n_full_blocks * BLOCK_N * stride_vk
|
||||
if USE_BIAS:
|
||||
bias_ptrs += n_full_blocks * BLOCK_N * stride_bn
|
||||
if RETURN_SCORES:
|
||||
score_ptrs += n_full_blocks * BLOCK_N
|
||||
scores_scaled_shifted_ptrs += n_full_blocks * BLOCK_N
|
||||
exp_scores_ptrs += n_full_blocks * BLOCK_N
|
||||
acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn,
|
||||
start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, batch_philox_offset,
|
||||
exp_scores_ptrs, block_min, block_max, offs_n_causal, masked_blocks,
|
||||
n_extra_tokens, alibi_slope, score_ptrs, scores_scaled_shifted_ptrs,
|
||||
IS_CAUSAL, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n,
|
||||
# _, MASK_STEPS, ...
|
||||
PRE_LOAD_V, True, ENABLE_DROPOUT, PADDED_HEAD,
|
||||
ACTUAL_BLOCK_DMODEL, SM_SCALE, USE_EXP2=USE_EXP2, RETURN_SCORES=RETURN_SCORES)
|
||||
# epilogue
|
||||
# This helps the compiler do Newton Raphson on l_i vs on acc which is much larger.
|
||||
l_recip = 1 / l_i[:, None]
|
||||
acc = acc * l_recip
|
||||
if ENABLE_DROPOUT:
|
||||
acc = acc / (1 - dropout_p)
|
||||
# If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M,
|
||||
# then we have one block with a row of all NaNs which come from computing
|
||||
# softmax over a row of all -infs (-inf - inf = NaN). We check for that here
|
||||
# and store 0s where there are NaNs as these rows should've been zeroed out.
|
||||
end_m_idx = (start_m + 1) * BLOCK_M
|
||||
start_m_idx = start_m * BLOCK_M
|
||||
causal_start_idx = seqlen_q - seqlen_k
|
||||
acc = acc.to(Out.type.element_ty)
|
||||
if IS_CAUSAL:
|
||||
if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx:
|
||||
out_mask_boundary = tl.full((BLOCK_DMODEL, ), causal_start_idx, dtype=tl.int32)
|
||||
mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M)
|
||||
out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :]
|
||||
z: tl.tensor = 0.0
|
||||
acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))
|
||||
|
||||
# write back LSE(Log Sum Exponents), the log of the normalization constant
|
||||
l_offset = LSE + off_z * stride_lse_z + off_h_q * stride_lse_h + cu_seqlens_q_start * stride_lse_m
|
||||
l_ptrs = l_offset + offs_m * stride_lse_m
|
||||
if USE_EXP2:
|
||||
RCP_LN2: tl.constexpr = 1.4426950408889634
|
||||
LN2: tl.constexpr = 0.6931471824645996
|
||||
# compute log-sum-exp in base 2 units
|
||||
mi_base2 = m_i * RCP_LN2
|
||||
softmax_lse = mi_base2 + tl.math.log2(l_i)
|
||||
# convert back to natural units
|
||||
softmax_lse *= LN2
|
||||
else:
|
||||
softmax_lse = m_i + tl.math.log(l_i)
|
||||
|
||||
if IS_CAUSAL:
|
||||
# zero out nans caused by -infs when doing causal
|
||||
lse_mask = (start_m_idx + tl.arange(0, BLOCK_M)) < causal_start_idx
|
||||
softmax_lse = tl.where(lse_mask, 0.0, softmax_lse)
|
||||
|
||||
# If seqlen_q not multiple of BLOCK_M, we need to mask out the last few rows.
|
||||
# This is only true for the last M block. For others, overflow_size will be -ve
|
||||
overflow_size = end_m_idx - seqlen_q
|
||||
if overflow_size > 0:
|
||||
boundary = tl.full((BLOCK_M, ), BLOCK_M - overflow_size, dtype=tl.int32)
|
||||
l_ptrs_mask = tl.arange(0, BLOCK_M) < boundary
|
||||
tl.store(l_ptrs, softmax_lse, mask=l_ptrs_mask) # the log of the normalization constant
|
||||
else:
|
||||
tl.store(l_ptrs, softmax_lse) # the log of the normalization constant
|
||||
|
||||
# write back O
|
||||
o_offset = Out + off_z * stride_oz + off_h_q * stride_oh + cu_seqlens_q_start * stride_om
|
||||
o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_on
|
||||
o_ptrs_mask = tl.full([BLOCK_M, BLOCK_DMODEL], 1, dtype=tl.int1)
|
||||
if overflow_size > 0:
|
||||
o_ptrs_mask = o_ptrs_mask & (offs_m[:, None] < seqlen_q)
|
||||
if PADDED_HEAD:
|
||||
o_ptrs_mask = o_ptrs_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL)
|
||||
tl.store(o_ptrs, acc.to(Out.dtype.element_ty), mask=o_ptrs_mask)
|
||||
|
||||
|
||||
def attention_prefill_forward_triton_impl(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
o,
|
||||
sm_scale,
|
||||
alibi_slopes,
|
||||
causal,
|
||||
bias,
|
||||
dropout_p,
|
||||
layout,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlens_q,
|
||||
max_seqlens_k,
|
||||
return_scores,
|
||||
use_exp2):
|
||||
# check if varlen
|
||||
is_varlen = layout == "thd"
|
||||
|
||||
# NOTE: a large bias tensor leads to overflow during pointer arithmetic
|
||||
if bias is not None:
|
||||
assert bias.numel() < 2**31
|
||||
|
||||
batch, nheads_q, nheads_k, head_size, seqlen_q, seqlen_k = get_shape_from_layout(q, k, layout, cu_seqlens_q, cu_seqlens_k, max_seqlens_q, max_seqlens_k) # pylint: disable=unused-variable
|
||||
q_strides, k_strides, v_strides, o_strides = get_strides_from_layout(q, k, v, o, layout)
|
||||
|
||||
# Get closest power of 2 over or equal to 32.
|
||||
padded_d_model = 1 << (head_size - 1).bit_length()
|
||||
# Smallest head_dim supported is 16. If smaller, the tile in the
|
||||
# kernel is padded - there is no padding in memory for any dims.
|
||||
padded_d_model = max(padded_d_model, 16)
|
||||
|
||||
grid = lambda META: (triton.cdiv(max_seqlens_q, META['BLOCK_M']), nheads_q, batch) # pylint: disable=unnecessary-lambda-assignment
|
||||
|
||||
if return_scores:
|
||||
scores = torch.zeros((batch, nheads_q, max_seqlens_q, max_seqlens_k), device=q.device,
|
||||
dtype=torch.float32)
|
||||
scores_scaled_shifted = torch.zeros((batch, nheads_q, max_seqlens_q, max_seqlens_k), device=q.device,
|
||||
dtype=torch.float32)
|
||||
scores_strides = (scores.stride(0), scores.stride(1), scores.stride(2), scores.stride(3))
|
||||
else:
|
||||
scores = None
|
||||
scores_scaled_shifted = None
|
||||
scores_strides = (0, 0 , 0 , 0)
|
||||
|
||||
# exp_scores is used to validate dropout behavior vs the PyTorch SDPA math backend reference. We zero this out
|
||||
# to give a consistent starting point and then populate it with the output of softmax with the sign bit set according
|
||||
# to the dropout mask. The resulting return allows this mask to be fed into the reference implementation for testing
|
||||
# only. This return holds no useful output aside from debugging.
|
||||
if return_scores:
|
||||
exp_scores = torch.zeros((batch, nheads_q, max_seqlens_q, max_seqlens_k), device=q.device,
|
||||
dtype=torch.float32)
|
||||
else:
|
||||
exp_scores = None
|
||||
|
||||
# stores LSE the log of the normalization constant / sum of expoential score(unnormalzied probablities)
|
||||
if is_varlen:
|
||||
softmax_lse = torch.empty((q.shape[0], nheads_q), device=q.device, dtype=torch.float32)
|
||||
stride_lse_m, stride_lse_h = softmax_lse.stride()
|
||||
stride_lse_z = 0
|
||||
else:
|
||||
softmax_lse = torch.empty((batch, nheads_q, max_seqlens_q), device=q.device, dtype=torch.float32)
|
||||
stride_lse_z, stride_lse_h, stride_lse_m = softmax_lse.stride()
|
||||
|
||||
# Seed the RNG so we get reproducible results for testing.
|
||||
philox_seed = 0x1BF52
|
||||
philox_offset = 0x1D4B42
|
||||
|
||||
if bias is not None:
|
||||
bias_strides = (bias.stride(0), bias.stride(1),bias.stride(2),
|
||||
bias.stride(3))
|
||||
else:
|
||||
bias_strides = (0, 0, 0, 0)
|
||||
|
||||
if alibi_slopes is not None:
|
||||
alibi_strides = (alibi_slopes.stride(0), alibi_slopes.stride(1))
|
||||
else:
|
||||
alibi_strides = (0, 0)
|
||||
|
||||
|
||||
attn_fwd[grid](q, k, v, bias, sm_scale, softmax_lse, o, *q_strides, *k_strides, *v_strides, *o_strides,
|
||||
*bias_strides, *alibi_strides, *scores_strides, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k,
|
||||
dropout_p=dropout_p, philox_seed=philox_seed, philox_offset_base=philox_offset, scores=scores,
|
||||
scores_scaled_shifted=scores_scaled_shifted, exp_scores=exp_scores, alibi_slopes=alibi_slopes,
|
||||
HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL=head_size, MAX_SEQLENS_Q=max_seqlens_q,
|
||||
MAX_SEQLENS_K=max_seqlens_k, IS_CAUSAL=causal, VARLEN=is_varlen,
|
||||
BLOCK_DMODEL=padded_d_model, USE_BIAS=False if bias is None else True,
|
||||
USE_ALIBI=False if alibi_slopes is None else True, ENABLE_DROPOUT=dropout_p
|
||||
> 0.0, USE_EXP2=use_exp2, RETURN_SCORES=return_scores)
|
||||
|
||||
return o, softmax_lse, exp_scores, grid, head_size, philox_seed, philox_offset, scores, scores_scaled_shifted
|
||||
258
comfy/flash_attn_triton_amd/fwd_ref.py
Normal file
258
comfy/flash_attn_triton_amd/fwd_ref.py
Normal file
@ -0,0 +1,258 @@
|
||||
import math
|
||||
import torch
|
||||
|
||||
|
||||
def attention_forward_core_ref_impl(q, k, v, sm_scale, causal, use_exp2):
|
||||
# Compute attention scores
|
||||
attention_scores = torch.matmul(q.to(torch.float32), k.transpose(-2, -1).to(torch.float32))
|
||||
|
||||
# Scale scores
|
||||
attention_scaled_scores = sm_scale * attention_scores
|
||||
|
||||
# Apply causal mask if necessary
|
||||
if causal:
|
||||
L_q, L_k = q.shape[1], k.shape[1]
|
||||
row_idx = torch.arange(L_q, device=q.device).unsqueeze(1)
|
||||
col_idx = torch.arange(L_k, device=q.device).unsqueeze(0)
|
||||
col_offset = L_q-L_k
|
||||
causal_mask = row_idx >= (col_offset + col_idx)
|
||||
# set -inf to places the causal mask is false
|
||||
attention_scaled_scores = attention_scaled_scores.masked_fill(
|
||||
torch.logical_not(causal_mask.unsqueeze(0)), float('-inf')
|
||||
)
|
||||
|
||||
|
||||
# Compute max for numerical stability
|
||||
max_scores = torch.max(attention_scaled_scores, dim=-1, keepdim=True)[0]
|
||||
if causal:
|
||||
# Replace -inf in max_scores with zeros to avoid NaN in subtraction
|
||||
max_scores = torch.where(
|
||||
torch.isinf(max_scores), torch.zeros_like(max_scores), max_scores
|
||||
)
|
||||
|
||||
# Shift scores
|
||||
attention_shifted_scaled_scores = attention_scaled_scores - max_scores
|
||||
|
||||
# Exponentiate
|
||||
if use_exp2:
|
||||
RCP_LN = 1 / math.log(2)
|
||||
exp_scores = torch.exp2(RCP_LN * attention_shifted_scaled_scores)
|
||||
else:
|
||||
exp_scores = torch.exp(attention_shifted_scaled_scores)
|
||||
|
||||
# Sum of exponentials
|
||||
sum_exp_scores = torch.sum(exp_scores, dim=-1, keepdim=True)
|
||||
if causal:
|
||||
# if sum of exp scores is 0.0 it means scores where -inf, we cannot compute softmax and softmax_lse. Setting to 1 deals with -inf case cleanly
|
||||
sum_exp_scores = torch.where(
|
||||
sum_exp_scores == 0,
|
||||
torch.ones_like(sum_exp_scores),
|
||||
sum_exp_scores
|
||||
)
|
||||
|
||||
# Compute softmax probabilities
|
||||
softmax = exp_scores / sum_exp_scores
|
||||
|
||||
# Compute log-sum-exp
|
||||
if use_exp2:
|
||||
LN2 = math.log(2)
|
||||
RCP_LN = 1 / math.log(2)
|
||||
max_scores_base2 = max_scores * RCP_LN
|
||||
softmax_lse_base2 = max_scores_base2 + torch.log2(sum_exp_scores)
|
||||
softmax_lse = softmax_lse_base2 * LN2
|
||||
softmax_lse.squeeze_(-1)
|
||||
else:
|
||||
softmax_lse = max_scores + torch.log(sum_exp_scores)
|
||||
softmax_lse = softmax_lse.squeeze(-1)
|
||||
|
||||
# Compute output
|
||||
o = torch.matmul(softmax, v.to(torch.float32)).to(torch.float16)
|
||||
|
||||
return o, softmax_lse, exp_scores, softmax, attention_shifted_scaled_scores, attention_scaled_scores, attention_scores
|
||||
|
||||
def attention_vanilla_forward_pytorch_ref_impl(q, k, v, sm_scale, causal, layout, use_exp2):
|
||||
"""Compute reference output and softmax_lse using PyTorch's built-in function"""
|
||||
|
||||
# Ensure the layout is 'bhsd'
|
||||
if layout == "bshd":
|
||||
q = q.transpose(1, 2).contiguous()
|
||||
k = k.transpose(1, 2).contiguous()
|
||||
v = v.transpose(1, 2).contiguous()
|
||||
elif layout != "bhsd":
|
||||
raise ValueError(f"Unknown layout {layout}")
|
||||
|
||||
# Prepare tensors in [batch_size * num_heads, seq_len, head_dim] format
|
||||
batch_size, num_heads, seq_len_q, head_dim = q.shape
|
||||
seq_len_k = k.shape[2]
|
||||
|
||||
# Merge batch and heads dimensions
|
||||
q = q.reshape(batch_size * num_heads, seq_len_q, head_dim)
|
||||
k = k.reshape(batch_size * num_heads, seq_len_k, head_dim)
|
||||
v = v.reshape(batch_size * num_heads, seq_len_k, head_dim)
|
||||
|
||||
# Call the core attention function
|
||||
o, softmax_lse, exp_scores, softmax, attention_shifted_scaled_scores, attention_scaled_scores, attention_scores = attention_forward_core_ref_impl(
|
||||
q, k, v, sm_scale, causal, use_exp2
|
||||
)
|
||||
|
||||
# Reshape outputs back to [batch_size, num_heads, seq_len, head_dim]
|
||||
o = o.reshape(batch_size, num_heads, seq_len_q, head_dim)
|
||||
softmax_lse = softmax_lse.reshape(batch_size, num_heads, seq_len_q)
|
||||
exp_scores = exp_scores.reshape(batch_size, num_heads, seq_len_q, seq_len_k)
|
||||
softmax = softmax.reshape(batch_size, num_heads, seq_len_q, seq_len_k)
|
||||
attention_shifted_scaled_scores = attention_shifted_scaled_scores.reshape(batch_size, num_heads, seq_len_q, seq_len_k)
|
||||
attention_scaled_scores = attention_scaled_scores.reshape(batch_size, num_heads, seq_len_q, seq_len_k)
|
||||
attention_scores = attention_scores.reshape(batch_size, num_heads, seq_len_q, seq_len_k)
|
||||
|
||||
# Restore original layout if necessary
|
||||
if layout == "bshd":
|
||||
o = o.transpose(1, 2)
|
||||
|
||||
return o, softmax_lse, exp_scores, softmax, attention_shifted_scaled_scores, attention_scaled_scores, attention_scores
|
||||
|
||||
def attention_varlen_forward_pytorch_ref_impl(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
sm_scale,
|
||||
causal,
|
||||
layout,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q, max_seqlen_k, # pylint: disable=unused-argument
|
||||
use_exp2
|
||||
):
|
||||
# Ensure the layout is 'thd'
|
||||
if layout != 'thd':
|
||||
raise ValueError(f"Unsupported layout {layout}. Expected 'thd'.")
|
||||
|
||||
batch_size = cu_seqlens_q.shape[0] - 1
|
||||
num_heads = q.shape[1]
|
||||
head_dim = q.shape[2]
|
||||
|
||||
# Pre-allocate outputs
|
||||
total_L_q = q.shape[0]
|
||||
total_L_k = k.shape[0] # pylint: disable=unused-variable
|
||||
|
||||
o = torch.empty((total_L_q, num_heads, head_dim), dtype=q.dtype, device=q.device)
|
||||
softmax_lse = torch.empty((total_L_q, num_heads), dtype=torch.float32, device=q.device)
|
||||
|
||||
for i in range(batch_size):
|
||||
# Get the start and end indices for the current sequence
|
||||
start_q = cu_seqlens_q[i].item()
|
||||
end_q = cu_seqlens_q[i + 1].item()
|
||||
start_k = cu_seqlens_k[i].item()
|
||||
end_k = cu_seqlens_k[i + 1].item()
|
||||
|
||||
# Extract q_i, k_i, v_i
|
||||
q_i = q[start_q:end_q, :, :] # [L_q_i, num_heads, head_dim]
|
||||
k_i = k[start_k:end_k, :, :] # [L_k_i, num_heads, head_dim]
|
||||
v_i = v[start_k:end_k, :, :] # [L_k_i, num_heads, head_dim]
|
||||
|
||||
# Permute to [num_heads, L_q_i, head_dim]
|
||||
q_i = q_i.permute(1, 0, 2)
|
||||
k_i = k_i.permute(1, 0, 2)
|
||||
v_i = v_i.permute(1, 0, 2)
|
||||
|
||||
# Call the core attention function for this sequence
|
||||
(
|
||||
o_i,
|
||||
softmax_lse_i,
|
||||
exp_scores_i,
|
||||
softmax_i,
|
||||
attention_shifted_scaled_scores_i,
|
||||
attention_scaled_scores_i,
|
||||
attention_scores_i,
|
||||
) = attention_forward_core_ref_impl(q_i, k_i, v_i, sm_scale, causal, use_exp2)
|
||||
|
||||
# Convert back to 'thd' layout and float16
|
||||
o_i = o_i.permute(1, 0, 2).to(torch.float16) # [L_q_i, num_heads, head_dim]
|
||||
|
||||
# Place outputs in pre-allocated tensors
|
||||
o[start_q:end_q, :, :] = o_i
|
||||
softmax_lse[start_q:end_q, :] = softmax_lse_i.transpose(0, 1) # Transpose to [L_q_i, num_heads]
|
||||
|
||||
# For variable-sized outputs, map them into the preallocated tensors
|
||||
# exp_scores_i: [num_heads, L_q_i, L_k_i] -> [L_q_i, num_heads, L_k_i]
|
||||
exp_scores_i = exp_scores_i.permute(1, 0, 2)
|
||||
softmax_i = softmax_i.permute(1, 0, 2)
|
||||
attention_shifted_scaled_scores_i = attention_shifted_scaled_scores_i.permute(1, 0, 2)
|
||||
attention_scaled_scores_i = attention_scaled_scores_i.permute(1, 0, 2)
|
||||
attention_scores_i = attention_scores_i.permute(1, 0, 2)
|
||||
|
||||
return (
|
||||
o,
|
||||
softmax_lse,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
def attention_forward_pytorch_ref_impl(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
sm_scale,
|
||||
causal,
|
||||
layout,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
use_exp2
|
||||
):
|
||||
# compute reference
|
||||
if layout == "thd":
|
||||
(
|
||||
o_ref,
|
||||
softmax_lse_ref,
|
||||
exp_scores_ref,
|
||||
softmax_ref,
|
||||
attention_shifted_scaled_scores_ref,
|
||||
attention_scaled_scores_ref,
|
||||
attention_scores_ref,
|
||||
) = attention_varlen_forward_pytorch_ref_impl(
|
||||
q.clone(),
|
||||
k.clone(),
|
||||
v.clone(),
|
||||
sm_scale,
|
||||
causal,
|
||||
layout,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
use_exp2,
|
||||
)
|
||||
else:
|
||||
(
|
||||
o_ref,
|
||||
softmax_lse_ref,
|
||||
exp_scores_ref,
|
||||
softmax_ref,
|
||||
attention_shifted_scaled_scores_ref,
|
||||
attention_scaled_scores_ref,
|
||||
attention_scores_ref,
|
||||
) = attention_vanilla_forward_pytorch_ref_impl(
|
||||
q.clone(), k.clone(), v.clone(), sm_scale, causal, layout, use_exp2
|
||||
)
|
||||
|
||||
return (
|
||||
o_ref,
|
||||
softmax_lse_ref,
|
||||
exp_scores_ref,
|
||||
softmax_ref,
|
||||
attention_shifted_scaled_scores_ref,
|
||||
attention_scaled_scores_ref,
|
||||
attention_scores_ref,
|
||||
)
|
||||
|
||||
|
||||
def compute_alibi_tensor_ref(alibi_slopes, seqlen_q, seqlen_k):
|
||||
q_idx = torch.arange(seqlen_q, dtype=torch.int32, device="cuda").unsqueeze(-1) # (N_CTX_Q, 1)
|
||||
k_idx = torch.arange(seqlen_k, dtype=torch.int32, device="cuda").unsqueeze(0) # (1, N_CTX_K)
|
||||
relative_pos = torch.abs(q_idx + seqlen_k - seqlen_q - k_idx) # (N_CTX_Q, N_CTX_K)
|
||||
return -1 * alibi_slopes.unsqueeze(-1).unsqueeze(-1) * relative_pos # (Z, H, N_CTX_Q, N_CTX_K)
|
||||
394
comfy/flash_attn_triton_amd/interface_fa.py
Normal file
394
comfy/flash_attn_triton_amd/interface_fa.py
Normal file
@ -0,0 +1,394 @@
|
||||
import os
|
||||
import torch
|
||||
from comfy.flash_attn_triton_amd.fwd_prefill import attention_prefill_forward_triton_impl
|
||||
from comfy.flash_attn_triton_amd.bwd_prefill import attention_prefill_backward_triton_impl
|
||||
from comfy.flash_attn_triton_amd.fwd_decode import attention_decode_forward_triton_impl
|
||||
from comfy.flash_attn_triton_amd.fwd_ref import attention_forward_pytorch_ref_impl
|
||||
from comfy.flash_attn_triton_amd.bwd_ref import attention_backward_pytorch_ref_impl
|
||||
from comfy.flash_attn_triton_amd.utils import MetaData, get_shape_from_layout
|
||||
|
||||
|
||||
USE_REF = os.environ.get('FLASH_ATTENTION_TRITON_AMD_REF', '0').lower() in ('1', 'true', 'yes')
|
||||
|
||||
|
||||
def fwd(q,
|
||||
k,
|
||||
v,
|
||||
o,
|
||||
alibi_slopes,
|
||||
dropout_p,
|
||||
softmax_scale,
|
||||
causal,
|
||||
window_size_left, window_size_right, softcap, # pylint: disable=unused-argument
|
||||
return_softmax,
|
||||
gen_ # pylint: disable=unused-argument
|
||||
):
|
||||
if dropout_p != 0.0:
|
||||
raise ValueError("dropout is not supported on AMD's Triton Backend yet")
|
||||
|
||||
if o is None:
|
||||
o = torch.empty_like(q)
|
||||
|
||||
# Setup metadata
|
||||
metadata = MetaData(sm_scale=softmax_scale)
|
||||
metadata.max_seqlens_q = q.shape[1]
|
||||
metadata.max_seqlens_k = k.shape[1]
|
||||
metadata.layout = "bshd"
|
||||
if return_softmax:
|
||||
metadata.return_scores = True
|
||||
|
||||
batch, nheads_q, nheads_k, head_size, _, _ = get_shape_from_layout(q, k, metadata.layout) # pylint: disable=unused-variable
|
||||
|
||||
if causal:
|
||||
metadata.need_causal()
|
||||
|
||||
if alibi_slopes is not None:
|
||||
metadata.need_alibi(alibi_slopes, batch, nheads_q)
|
||||
|
||||
if dropout_p > 0.0:
|
||||
metadata.need_dropout(dropout_p, return_softmax)
|
||||
|
||||
# Check arguments
|
||||
metadata.check_args(q, k, v, o)
|
||||
if USE_REF:
|
||||
(output,
|
||||
softmax_lse,
|
||||
exp_scores,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
_) = attention_forward_pytorch_ref_impl(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
metadata.sm_scale,
|
||||
metadata.causal,
|
||||
metadata.layout,
|
||||
metadata.cu_seqlens_q,
|
||||
metadata.cu_seqlens_k,
|
||||
metadata.max_seqlens_q,
|
||||
metadata.max_seqlens_k,
|
||||
metadata.use_exp2)
|
||||
o.copy_(output)
|
||||
else:
|
||||
(_,
|
||||
softmax_lse,
|
||||
exp_scores,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
_) = attention_prefill_forward_triton_impl(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
o,
|
||||
metadata.sm_scale,
|
||||
metadata.alibi_slopes,
|
||||
metadata.causal,
|
||||
metadata.bias,
|
||||
metadata.dropout_p,
|
||||
metadata.layout,
|
||||
metadata.cu_seqlens_q,
|
||||
metadata.cu_seqlens_k,
|
||||
metadata.max_seqlens_q,
|
||||
metadata.max_seqlens_k,
|
||||
metadata.return_scores,
|
||||
metadata.use_exp2)
|
||||
|
||||
return o, softmax_lse, exp_scores, None
|
||||
|
||||
|
||||
def bwd(
|
||||
dout,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
out,
|
||||
softmax_lse,
|
||||
dq,
|
||||
dk,
|
||||
dv,
|
||||
alibi_slopes,
|
||||
dropout_p,
|
||||
softmax_scale,
|
||||
causal,
|
||||
window_size_left, window_size_right, softcap, deterministic, gen_, rng_state, # pylint: disable=unused-argument
|
||||
):
|
||||
if dropout_p != 0.0:
|
||||
raise ValueError("dropout is not supported on AMD yet")
|
||||
|
||||
if USE_REF:
|
||||
dq_ref, dk_ref, dv_ref, delta_ref = attention_backward_pytorch_ref_impl(
|
||||
dout,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
out,
|
||||
softmax_lse,
|
||||
softmax_scale,
|
||||
causal,
|
||||
"bshd",
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
False,
|
||||
)
|
||||
dq.copy_(dq_ref)
|
||||
dk.copy_(dk_ref)
|
||||
dv.copy_(dv_ref)
|
||||
delta = delta_ref
|
||||
else:
|
||||
dq_triton, dk_triton, dv_triton, delta_triton, _, _ = attention_prefill_backward_triton_impl( # pylint: disable=unused-variable
|
||||
dout,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
out,
|
||||
softmax_lse,
|
||||
dq,
|
||||
dk,
|
||||
dv,
|
||||
softmax_scale,
|
||||
alibi_slopes,
|
||||
causal,
|
||||
"bshd",
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
False,
|
||||
)
|
||||
delta = delta_triton
|
||||
|
||||
return dq, dk, dv, delta
|
||||
|
||||
|
||||
def varlen_fwd(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
o,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
seqused_k, leftpad_k, block_table_, # pylint: disable=unused-argument
|
||||
alibi_slopes,\
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
dropout_p,
|
||||
softmax_scale,
|
||||
zero_tensors, # pylint: disable=unused-argument
|
||||
causal,
|
||||
window_size_left, window_size_right, softcap, # pylint: disable=unused-argument
|
||||
return_softmax,
|
||||
gen_ # pylint: disable=unused-argument
|
||||
):
|
||||
if dropout_p != 0.0:
|
||||
raise ValueError("dropout is not supported on AMD's Triton Backend yet")
|
||||
|
||||
if o is None:
|
||||
o = torch.empty_like(q)
|
||||
|
||||
# Setup metadata
|
||||
metadata = MetaData(sm_scale=softmax_scale)
|
||||
if return_softmax:
|
||||
metadata.return_scores = True
|
||||
metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k) # set layout to "thd" and other metdata
|
||||
|
||||
# get shapes
|
||||
batch, nheads_q, nheads_k, head_size , seqlen_q, seqlen_k = get_shape_from_layout(q, k, metadata.layout, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k) # pylint: disable=unused-variable
|
||||
|
||||
if causal:
|
||||
metadata.need_causal()
|
||||
|
||||
if alibi_slopes is not None:
|
||||
metadata.need_alibi(alibi_slopes, batch, nheads_q)
|
||||
|
||||
if dropout_p > 0.0:
|
||||
metadata.need_dropout(dropout_p, return_softmax)
|
||||
|
||||
# Check arguments
|
||||
metadata.check_args(q, k, v, o)
|
||||
if o is None:
|
||||
o = torch.empty_like(q, dtype=v.dtype)
|
||||
|
||||
if USE_REF:
|
||||
(output,
|
||||
softmax_lse,
|
||||
exp_scores,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
_) = attention_forward_pytorch_ref_impl(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
metadata.sm_scale,
|
||||
metadata.causal,
|
||||
metadata.layout,
|
||||
metadata.cu_seqlens_q,
|
||||
metadata.cu_seqlens_k,
|
||||
metadata.max_seqlens_q,
|
||||
metadata.max_seqlens_k,
|
||||
metadata.use_exp2)
|
||||
o.copy_(output)
|
||||
else:
|
||||
(_,
|
||||
softmax_lse,
|
||||
exp_scores,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
_) = attention_prefill_forward_triton_impl(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
o,
|
||||
metadata.sm_scale,
|
||||
metadata.alibi_slopes,
|
||||
metadata.causal,
|
||||
metadata.bias,
|
||||
metadata.dropout_p,
|
||||
metadata.layout,
|
||||
metadata.cu_seqlens_q,
|
||||
metadata.cu_seqlens_k,
|
||||
metadata.max_seqlens_q,
|
||||
metadata.max_seqlens_k,
|
||||
metadata.return_scores,
|
||||
metadata.use_exp2)
|
||||
|
||||
return o, softmax_lse, exp_scores, None
|
||||
|
||||
|
||||
def varlen_bwd(
|
||||
dout,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
out,
|
||||
softmax_lse,
|
||||
dq,
|
||||
dk,
|
||||
dv,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
alibi_slopes,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
dropout_p,
|
||||
softmax_scale,
|
||||
zero_tensors, # pylint: disable=unused-argument
|
||||
causal,
|
||||
window_size_left, window_size_right, softcap, deterministic, gen_, rng_state, # pylint: disable=unused-argument
|
||||
):
|
||||
if dropout_p != 0.0:
|
||||
raise ValueError("dropout is not supported on AMD yet")
|
||||
|
||||
if USE_REF:
|
||||
dq_ref, dk_ref, dv_ref, delta_ref = attention_backward_pytorch_ref_impl(
|
||||
dout,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
out,
|
||||
softmax_lse,
|
||||
softmax_scale,
|
||||
causal,
|
||||
"thd",
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
False,
|
||||
)
|
||||
dq.copy_(dq_ref)
|
||||
dk.copy_(dk_ref)
|
||||
dv.copy_(dv_ref)
|
||||
delta = delta_ref
|
||||
else:
|
||||
dq_triton, dk_triton, dv_triton, delta_triton, _, _ = attention_prefill_backward_triton_impl( # pylint: disable=unused-variable
|
||||
dout,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
out,
|
||||
softmax_lse,
|
||||
dq,
|
||||
dk,
|
||||
dv,
|
||||
softmax_scale,
|
||||
alibi_slopes,
|
||||
causal,
|
||||
"thd",
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
False,
|
||||
)
|
||||
delta = delta_triton
|
||||
|
||||
return dq, dk, dv, delta
|
||||
|
||||
|
||||
def fwd_kvcache(
|
||||
q,
|
||||
k_cache,
|
||||
v_cache,
|
||||
k,
|
||||
v,
|
||||
cache_seqlens,
|
||||
rotary_cos, rotary_sin, # pylint: disable=unused-argument
|
||||
cache_batch_idx,
|
||||
cache_leftpad, block_table, # pylint: disable=unused-argument
|
||||
alibi_slopes,
|
||||
out,
|
||||
softmax_scale,
|
||||
causal,
|
||||
window_size_left, window_size_right, softcap, rotary_interleaved, num_splits, # pylint: disable=unused-argument
|
||||
):
|
||||
if out is None:
|
||||
out = torch.empty_like(q)
|
||||
|
||||
# fill metadata
|
||||
metadata = MetaData(sm_scale=softmax_scale)
|
||||
metadata.layout = "bshd"
|
||||
metadata.max_seqlens_q = q.shape[1]
|
||||
metadata.max_seqlens_k = k_cache.shape[1]
|
||||
metadata.cache_seqlens = cache_seqlens
|
||||
metadata.cache_batch_idx = cache_batch_idx
|
||||
|
||||
if k is not None and v is not None:
|
||||
metadata.new_kv = True
|
||||
metadata.seqlen_new = k.shape[1]
|
||||
metadata.k_new = k
|
||||
metadata.v_new = v
|
||||
|
||||
if causal:
|
||||
metadata.need_causal()
|
||||
|
||||
if alibi_slopes is not None:
|
||||
batch, _ , nheads_q, _= q.shape
|
||||
metadata.need_alibi(alibi_slopes, batch, nheads_q)
|
||||
|
||||
# launch kernel
|
||||
# TODO: pass output as an arg. Maybe we are copying output which is causing slow down
|
||||
output, softmax_lse = attention_decode_forward_triton_impl(
|
||||
q,
|
||||
k_cache,
|
||||
v_cache,
|
||||
metadata.sm_scale,
|
||||
metadata.causal,
|
||||
metadata.alibi_slopes,
|
||||
metadata.layout,
|
||||
metadata.cache_seqlens,
|
||||
metadata.cache_batch_idx,
|
||||
metadata.new_kv,
|
||||
metadata.k_new,
|
||||
metadata.v_new,
|
||||
)
|
||||
return output, softmax_lse
|
||||
280
comfy/flash_attn_triton_amd/utils.py
Normal file
280
comfy/flash_attn_triton_amd/utils.py
Normal file
@ -0,0 +1,280 @@
|
||||
import os
|
||||
import torch
|
||||
import triton
|
||||
|
||||
|
||||
AUTOTUNE = os.environ.get('FLASH_ATTENTION_TRITON_AMD_AUTOTUNE', '0').lower() in ('1', 'true', 'yes')
|
||||
PERF = os.environ.get('FLASH_ATTENTION_TRITON_AMD_PERF', '0').lower() in ('1', 'true', 'yes')
|
||||
|
||||
|
||||
class MetaData():
|
||||
cu_seqlens_q = None
|
||||
cu_seqlens_k = None
|
||||
max_seqlens_q = 0
|
||||
max_seqlens_k = 0
|
||||
bias = None
|
||||
alibi_slopes = None
|
||||
causal = False
|
||||
num_contexts = 0
|
||||
varlen = False
|
||||
layout = None
|
||||
cache_seqlens = None
|
||||
cache_batch_idx = None
|
||||
new_kv = False
|
||||
seqlen_new = None
|
||||
k_new = None
|
||||
v_new = None
|
||||
dropout_p, return_scores= 0.0, False
|
||||
# NOTE: scale sm_scale by log_2(e) and use 2^x in the loop as we do not have native e^x support in HW.
|
||||
use_exp2 = False
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"MetaData(\n"
|
||||
f" sm_scale={self.sm_scale},\n"
|
||||
f" cu_seqlens_q={self.cu_seqlens_q},\n"
|
||||
f" cu_seqlens_k={self.cu_seqlens_k},\n"
|
||||
f" max_seqlens_q={self.max_seqlens_q},\n"
|
||||
f" max_seqlens_k={self.max_seqlens_k},\n"
|
||||
f" bias={self.bias},\n"
|
||||
f" alibi_slopes={self.alibi_slopes},\n"
|
||||
f" causal={self.causal},\n"
|
||||
f" num_contexts={self.num_contexts},\n"
|
||||
f" varlen={self.varlen},\n"
|
||||
f" layout={self.layout},\n"
|
||||
f" cache_seqlens={self.cache_seqlens},\n"
|
||||
f" cache_batch_idx={self.cache_batch_idx},\n"
|
||||
f" new_kv={self.new_kv},\n"
|
||||
f" seqlen_new={self.seqlen_new},\n"
|
||||
f" k_new={self.k_new},\n"
|
||||
f" v_new={self.v_new},\n"
|
||||
f" dropout_p={self.dropout_p},\n"
|
||||
f" return_scores={self.return_scores}\n"
|
||||
f")")
|
||||
|
||||
def __init__(self, sm_scale=1.0):
|
||||
self.sm_scale = sm_scale
|
||||
|
||||
def set_varlen_params(self, cu_seqlens_q, cu_seqlens_k):
|
||||
self.varlen = True
|
||||
self.layout = 'thd'
|
||||
self.cu_seqlens_q = cu_seqlens_q
|
||||
self.cu_seqlens_k = cu_seqlens_k
|
||||
# Without "varlen", there should still be one sequence.
|
||||
assert len(cu_seqlens_q) >= 2
|
||||
assert len(cu_seqlens_q) == len(cu_seqlens_k)
|
||||
self.num_contexts = len(cu_seqlens_q) - 1
|
||||
for i in range(0, self.num_contexts):
|
||||
self.max_seqlens_q = max(cu_seqlens_q[i + 1].item() - cu_seqlens_q[i].item(), self.max_seqlens_q)
|
||||
self.max_seqlens_k = max(cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item(), self.max_seqlens_k)
|
||||
|
||||
def need_bias(self, bias, batch, nheads, seqlen_q, seqlen_k): # pylint: disable=unused-argument
|
||||
assert bias.is_cuda
|
||||
assert bias.dim() == 4
|
||||
assert bias.shape[0] == 1
|
||||
assert bias.shape[2:] == (seqlen_q, seqlen_k)
|
||||
self.bias = bias
|
||||
|
||||
def need_alibi(self, alibi_slopes, batch, nheads):
|
||||
assert alibi_slopes.is_cuda
|
||||
assert alibi_slopes.dim() == 2
|
||||
assert alibi_slopes.shape[0] == batch
|
||||
assert alibi_slopes.shape[1] == nheads
|
||||
self.alibi_slopes = alibi_slopes
|
||||
|
||||
def need_causal(self):
|
||||
self.causal = True
|
||||
|
||||
def need_dropout(self, dropout_p, return_scores):
|
||||
self.dropout_p = dropout_p
|
||||
self.return_scores = return_scores
|
||||
|
||||
def check_args(self, q, k, v, o):
|
||||
assert q.dim() == k.dim() and q.dim() == v.dim()
|
||||
|
||||
batch, nheads_q, nheads_k, head_size, _, _ = get_shape_from_layout(q, k, self.layout, self.cu_seqlens_q, self.cu_seqlens_k, self.max_seqlens_q, self.max_seqlens_k) # pylint: disable=unused-variable
|
||||
if self.varlen:
|
||||
assert q.dim() == 3
|
||||
assert self.cu_seqlens_q is not None
|
||||
assert self.cu_seqlens_k is not None
|
||||
assert len(self.cu_seqlens_q) == len(self.cu_seqlens_k)
|
||||
# TODO: Remove once bias is supported with varlen
|
||||
assert self.bias is None
|
||||
# TODO:Remove once dropout is supported with varlen
|
||||
assert self.dropout_p == 0.0
|
||||
# assert not self.return_scores
|
||||
else:
|
||||
assert q.dim() == 4
|
||||
assert self.max_seqlens_q > 0 and self.max_seqlens_k > 0
|
||||
assert self.cu_seqlens_q is None and self.cu_seqlens_k is None
|
||||
assert k.shape == v.shape
|
||||
assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1]
|
||||
# TODO: Change assert if we support qkl f8 and v f16
|
||||
assert q.dtype == k.dtype and q.dtype == v.dtype
|
||||
assert head_size <= 256
|
||||
assert o.shape == q.shape
|
||||
assert (nheads_q % nheads_k) == 0
|
||||
assert self.layout is not None
|
||||
assert self.layout == 'thd' or not self.varlen
|
||||
|
||||
def input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout, device="cuda", DEBUG_INPUT=False):
|
||||
torch.manual_seed(20)
|
||||
|
||||
# Initialize q, k, v
|
||||
if layout == 'bhsd':
|
||||
q_tensor_shape = (Z, HQ, N_CTX_Q, D_HEAD)
|
||||
k_tensor_shape = (Z, HK, N_CTX_K, D_HEAD)
|
||||
elif layout == 'bshd':
|
||||
q_tensor_shape = (Z, N_CTX_Q, HQ, D_HEAD)
|
||||
k_tensor_shape = (Z, N_CTX_K, HK, D_HEAD)
|
||||
else:
|
||||
assert False, f'Got unsupported tensor layout: {layout}'
|
||||
|
||||
q = None
|
||||
k = None
|
||||
v = None
|
||||
|
||||
if DEBUG_INPUT:
|
||||
if layout == "bhsd":
|
||||
q = torch.arange(N_CTX_Q, dtype=dtype, device=device).view(1, 1, N_CTX_Q, 1).expand(*q_tensor_shape).contiguous().requires_grad_()
|
||||
k = torch.arange(N_CTX_K, dtype=dtype, device=device).view(1, 1, N_CTX_K, 1).expand(*k_tensor_shape).contiguous().requires_grad_()
|
||||
v = torch.arange(N_CTX_K, dtype=dtype, device=device).view(1, 1, N_CTX_K, 1).expand(*k_tensor_shape).contiguous().requires_grad_()
|
||||
elif layout == "bshd":
|
||||
q = torch.arange(N_CTX_Q, dtype=dtype, device=device).view(1, N_CTX_Q, 1, 1).expand(*q_tensor_shape).contiguous().requires_grad_()
|
||||
k = torch.arange(N_CTX_K, dtype=dtype, device=device).view(1, N_CTX_K, 1, 1).expand(*k_tensor_shape).contiguous().requires_grad_()
|
||||
v = torch.arange(N_CTX_K, dtype=dtype, device=device).view(1, N_CTX_K, 1, 1).expand(*k_tensor_shape).contiguous().requires_grad_()
|
||||
else:
|
||||
q = torch.randn(q_tensor_shape, dtype=dtype, device=device, requires_grad=True)
|
||||
k = torch.randn(k_tensor_shape, dtype=dtype, device=device, requires_grad=True)
|
||||
v = torch.randn(k_tensor_shape, dtype=dtype, device=device, requires_grad=True)
|
||||
|
||||
if DEBUG_INPUT:
|
||||
sm_scale = 1
|
||||
else:
|
||||
sm_scale = D_HEAD**-0.5
|
||||
input_metadata = MetaData(sm_scale=sm_scale)
|
||||
input_metadata.max_seqlens_q = N_CTX_Q
|
||||
input_metadata.max_seqlens_k = N_CTX_K
|
||||
input_metadata.layout = layout
|
||||
return q, k, v, input_metadata
|
||||
|
||||
|
||||
def varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, device="cuda", equal_seqlens=False, DEBUG_INPUT=False):
|
||||
torch.manual_seed(20)
|
||||
|
||||
# Random or equal sequence lengths based on 'equal_seqlens' flag
|
||||
if not equal_seqlens:
|
||||
max_seqlens_q = N_CTX_Q // Z
|
||||
max_seqlens_k = N_CTX_K // Z
|
||||
seqlens_q = torch.randint(1, max_seqlens_q + 1, (Z,), dtype=torch.int32)
|
||||
seqlens_k = torch.randint(1, max_seqlens_k + 1, (Z,), dtype=torch.int32)
|
||||
else:
|
||||
seqlens_q = torch.full((Z,), N_CTX_Q // Z, dtype=torch.int32)
|
||||
seqlens_k = torch.full((Z,), N_CTX_K // Z, dtype=torch.int32)
|
||||
|
||||
# Calculate cumulative sequence lengths
|
||||
cu_seqlens_q = torch.cat([torch.tensor([0], dtype=torch.int32), seqlens_q.cumsum(dim=0)])
|
||||
cu_seqlens_k = torch.cat([torch.tensor([0], dtype=torch.int32), seqlens_k.cumsum(dim=0)])
|
||||
cu_seqlens_q = cu_seqlens_q.to(device=device).to(torch.int32)
|
||||
cu_seqlens_k = cu_seqlens_k.to(device=device).to(torch.int32)
|
||||
|
||||
# Total lengths
|
||||
total_q = cu_seqlens_q[-1].item()
|
||||
total_k = cu_seqlens_k[-1].item()
|
||||
|
||||
if DEBUG_INPUT:
|
||||
# Initialize q, k, v with deterministic values
|
||||
q = torch.arange(total_q, dtype=dtype, device=device).view(total_q, 1, 1)
|
||||
q = q.expand(total_q, HQ, D_HEAD).contiguous().requires_grad_()
|
||||
k = torch.arange(total_k, dtype=dtype, device=device).view(total_k, 1, 1)
|
||||
k = k.expand(total_k, HK, D_HEAD).contiguous().requires_grad_()
|
||||
v = torch.arange(total_k, dtype=dtype, device=device).view(total_k, 1, 1)
|
||||
v = v.expand(total_k, HK, D_HEAD).contiguous().requires_grad_()
|
||||
sm_scale = 1
|
||||
else:
|
||||
# Initialize q, k, v with random values
|
||||
q = torch.randn((total_q, HQ, D_HEAD), dtype=dtype, device=device).requires_grad_()
|
||||
k = torch.randn((total_k, HK, D_HEAD), dtype=dtype, device=device).requires_grad_()
|
||||
v = torch.randn((total_k, HK, D_HEAD), dtype=dtype, device=device).requires_grad_()
|
||||
sm_scale = D_HEAD ** -0.5
|
||||
|
||||
input_metadata = MetaData(sm_scale=sm_scale)
|
||||
input_metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k)
|
||||
return q, k, v, input_metadata
|
||||
|
||||
|
||||
def get_shape_from_layout(q, k, layout, cu_seqlens_q = None, cu_seqlens_k = None, max_seqlen_q=None, max_seqlen_k=None):
|
||||
if layout == 'bhsd':
|
||||
batch_q, nheads_q, max_seqlen_q, head_size_q = q.shape
|
||||
batch_k, nheads_k, max_seqlen_k, head_size_k = k.shape
|
||||
elif layout == 'bshd':
|
||||
batch_q, max_seqlen_q, nheads_q, head_size_q = q.shape
|
||||
batch_k, max_seqlen_k, nheads_k, head_size_k = k.shape
|
||||
elif layout == 'thd':
|
||||
batch_q, max_seqlen_q, nheads_q, head_size_q = len(cu_seqlens_q) - 1, max_seqlen_q, q.shape[1], q.shape[2] # pylint: disable=self-assigning-variable
|
||||
batch_k, max_seqlen_k, nheads_k, head_size_k = len(cu_seqlens_k) - 1, max_seqlen_k, k.shape[1], k.shape[2] # pylint: disable=self-assigning-variable
|
||||
else:
|
||||
assert False, "Got unsupported layout."
|
||||
|
||||
# assert
|
||||
assert batch_q == batch_k
|
||||
assert head_size_q == head_size_k
|
||||
|
||||
return batch_q, nheads_q, nheads_k, head_size_q, max_seqlen_q, max_seqlen_k
|
||||
|
||||
|
||||
def get_strides_from_layout(q, k, v, o, layout):
|
||||
if layout == 'thd':
|
||||
q_strides = (0, q.stride(1), q.stride(0), q.stride(2))
|
||||
k_strides = (0, k.stride(1), k.stride(0), k.stride(2))
|
||||
v_strides = (0, v.stride(1), v.stride(0), v.stride(2))
|
||||
o_strides = (0, o.stride(1), o.stride(0), o.stride(2))
|
||||
elif layout == 'bhsd':
|
||||
q_strides = (q.stride(0), q.stride(1), q.stride(2), q.stride(3))
|
||||
k_strides = (k.stride(0), k.stride(1), k.stride(2), k.stride(3))
|
||||
v_strides = (v.stride(0), v.stride(1), v.stride(2), v.stride(3))
|
||||
o_strides = (o.stride(0), o.stride(1), o.stride(2), o.stride(3))
|
||||
elif layout == 'bshd':
|
||||
q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3))
|
||||
k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3))
|
||||
v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3))
|
||||
o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3))
|
||||
else:
|
||||
assert False, 'Got unsupported layout.'
|
||||
return q_strides, k_strides, v_strides, o_strides
|
||||
|
||||
|
||||
def get_padded_headsize(size):
|
||||
# Get closest power of 2 over or equal to 32.
|
||||
padded_d_model = 1 << (size - 1).bit_length()
|
||||
# Smallest head_dim supported is 16. If smaller, the tile in the
|
||||
# kernel is padded - there is no padding in memory for any dims.
|
||||
padded_d_model = max(padded_d_model, 16)
|
||||
return padded_d_model
|
||||
|
||||
|
||||
def _strides(x: torch.Tensor, *stride_names: str):
|
||||
if x is None:
|
||||
return {f"stride_{s}": 0 for i, s in enumerate(stride_names)}
|
||||
|
||||
assert x.ndim == len(stride_names)
|
||||
return {f"stride_{s}": x.stride(i) for i, s in enumerate(stride_names)}
|
||||
|
||||
|
||||
def get_input_shapes():
|
||||
cases = [(max(1, 2**(16 - i)), 1, 2**i, 16, 1, 128)
|
||||
for i in range(8, 18)] + [(max(1, 2**(16 - i)), 1, 2**i, 16, 2, 128) for i in range(8, 18)]
|
||||
return cases
|
||||
|
||||
|
||||
def is_hip():
|
||||
return triton.runtime.driver.active.get_current_target().backend == "hip"
|
||||
|
||||
|
||||
def is_cdna():
|
||||
return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
|
||||
'gfx90a', 'gfx908')
|
||||
|
||||
|
||||
def is_rdna():
|
||||
return is_hip() and triton.runtime.driver.active.get_current_target().arch in ("gfx1030", "gfx1100", "gfx1101",
|
||||
"gfx1102", "gfx1200", "gfx1201")
|
||||
Loading…
Reference in New Issue
Block a user