From aad001bfbeb884de463827196926fa41261fa466 Mon Sep 17 00:00:00 2001 From: patientx Date: Thu, 1 May 2025 23:04:46 +0300 Subject: [PATCH] Add files via upload --- comfy/flash_attn_triton_amd/__init__.py | 0 comfy/flash_attn_triton_amd/bwd_prefill.py | 606 +++++++++++++++++ comfy/flash_attn_triton_amd/bwd_ref.py | 271 ++++++++ comfy/flash_attn_triton_amd/fwd_decode.py | 700 ++++++++++++++++++++ comfy/flash_attn_triton_amd/fwd_prefill.py | 634 ++++++++++++++++++ comfy/flash_attn_triton_amd/fwd_ref.py | 258 ++++++++ comfy/flash_attn_triton_amd/interface_fa.py | 394 +++++++++++ comfy/flash_attn_triton_amd/utils.py | 280 ++++++++ 8 files changed, 3143 insertions(+) create mode 100644 comfy/flash_attn_triton_amd/__init__.py create mode 100644 comfy/flash_attn_triton_amd/bwd_prefill.py create mode 100644 comfy/flash_attn_triton_amd/bwd_ref.py create mode 100644 comfy/flash_attn_triton_amd/fwd_decode.py create mode 100644 comfy/flash_attn_triton_amd/fwd_prefill.py create mode 100644 comfy/flash_attn_triton_amd/fwd_ref.py create mode 100644 comfy/flash_attn_triton_amd/interface_fa.py create mode 100644 comfy/flash_attn_triton_amd/utils.py diff --git a/comfy/flash_attn_triton_amd/__init__.py b/comfy/flash_attn_triton_amd/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/comfy/flash_attn_triton_amd/bwd_prefill.py b/comfy/flash_attn_triton_amd/bwd_prefill.py new file mode 100644 index 000000000..4ba9f0e39 --- /dev/null +++ b/comfy/flash_attn_triton_amd/bwd_prefill.py @@ -0,0 +1,606 @@ +import torch +import triton +import triton.language as tl +from comfy.flash_attn_triton_amd.utils import get_shape_from_layout, get_strides_from_layout + + +@triton.jit +def _bwd_preprocess_use_o( + Out, + DO, + Delta, + stride_oz, stride_oh, stride_om, stride_ok, + stride_doz, stride_doh, stride_dom, stride_dok, # pylint: disable=unused-argument + stride_deltaz, stride_deltah, stride_deltam, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + ACTUAL_BLOCK_DMODEL: tl.constexpr, + N_CTX_Q: tl.constexpr, + Z: tl.constexpr, # pylint: disable=unused-argument + H: tl.constexpr, + IS_VARLEN: tl.constexpr +): + pid_m = tl.program_id(0) + pid_bh = tl.program_id(1) + + # Compute batch and head indices + off_z = pid_bh // H + off_h = pid_bh % H + + if IS_VARLEN: + # Compute sequence lengths for the current batch + q_start = tl.load(cu_seqlens_q + off_z) + q_end = tl.load(cu_seqlens_q + off_z + 1) + k_start = tl.load(cu_seqlens_k + off_z) + k_end = tl.load(cu_seqlens_k + off_z + 1) + + # Compute actual sequence lengths + N_CTX_Q = q_end - q_start + N_CTX_K = k_end - k_start # pylint: disable=unused-variable + else: + q_start = 0 + k_start = 0 + N_CTX_Q = max_seqlen_q + N_CTX_K = max_seqlen_k # pylint: disable=unused-variable + + off_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_d = tl.arange(0, BLOCK_DMODEL) + + # create masks + mask_m = off_m < N_CTX_Q + mask_d = off_d < ACTUAL_BLOCK_DMODEL + + # compute offsets + o_offset = Out + off_z * stride_oz + off_h * stride_oh + q_start * stride_om + do_offset = DO + off_z * stride_oz + off_h * stride_oh + q_start * stride_om + + # compute pointers + out_ptrs = o_offset + off_m[:, None] * stride_om + off_d[None, :] * stride_ok + do_ptrs = do_offset + off_m[:, None] * stride_dom + off_d[None, :] * stride_dok + + # load + o = tl.load(out_ptrs, mask=mask_m[:, None] & mask_d[None, :], other=0.0).to(tl.float32) + do = tl.load(do_ptrs, mask=mask_m[:, None] & mask_d[None, :], other=0.0).to(tl.float32) + + # compute delta + delta = tl.sum(o * do, axis=1) + + # write-back delta + delta_offset = Delta + off_z * stride_deltaz + off_h * stride_deltah + q_start * stride_deltam + delta_ptrs = delta_offset + off_m * stride_deltam + tl.store(delta_ptrs, delta, mask=mask_m) + + +@triton.jit +def _bwd_kernel_one_col_block( + Q, + K, + V, + sm_scale, + Out, DO, DQ, DK, DV, L, D, # pylint: disable=unused-argument + q_offset, + k_offset, + v_offset, + do_offset, + dq_offset, + dk_offset, + dv_offset, + d_offset, + l_offset, + stride_dq_all, stride_qz, stride_qh, # pylint: disable=unused-argument + stride_qm, + stride_qk, + stride_kz, stride_kh, # pylint: disable=unused-argument + stride_kn, + stride_kk, + stride_vz, stride_vh, # pylint: disable=unused-argument + stride_vn, + stride_vk, + stride_deltaz, stride_deltah, # pylint: disable=unused-argument + stride_deltam, + Z, H, # pylint: disable=unused-argument + N_CTX_Q, + N_CTX_K, + off_h, off_z, off_hz, # pylint: disable=unused-argument + start_n, + num_block_m, + num_block_n, # pylint: disable=unused-argument + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + ACTUAL_BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + SEQUENCE_PARALLEL: tl.constexpr, + CAUSAL: tl.constexpr, + USE_EXP2: tl.constexpr, +): + if CAUSAL: + # TODO: Causal can skip more blocks with something like lo = start_m * BLOCK_M + lo = 0 + else: + lo = 0 + + # initialize col and head offsets + offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + + # masks + mask_n = offs_n < N_CTX_K + mask_d = offs_d < ACTUAL_BLOCK_DMODEL + kv_mask = mask_n[:, None] & mask_d[None, :] + + # initialize grad accumulators + dv = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32) + dk = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32) + + # load k and v once per column block + k_ptrs = k_offset + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk + v_ptrs = v_offset + offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vk + k = tl.load(k_ptrs, mask=kv_mask, other=0.0) + v = tl.load(v_ptrs, mask=kv_mask, other=0.0) + + # loop over rows + for start_m in range(lo, num_block_m * BLOCK_M, BLOCK_M): + offs_m = start_m + tl.arange(0, BLOCK_M) + q_ptrs = q_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk + dq_ptrs = dq_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk + do_ptrs = do_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk + + # update mask as row block changes + mask_m = offs_m < N_CTX_Q + q_mask = mask_m[:, None] & mask_d[None, :] + + # load q, k, v, do on-chip + q = tl.load(q_ptrs, mask=q_mask, other=0.0) + do = tl.load(do_ptrs, mask=q_mask, other=0.0) + + # recompute p = softmax(qk, dim=-1).T + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, tl.trans(k)) + + if CAUSAL: + col_offset = N_CTX_Q - N_CTX_K + causal_mask = offs_m[:, None] >= (col_offset + offs_n[None, :]) + qk = tl.where(causal_mask, qk, float("-inf")) + + l_ptrs = l_offset + offs_m * stride_deltam + l_i = tl.load(l_ptrs, mask=mask_m) + + # compute p + if USE_EXP2: + RCP_LN2: tl.constexpr = 1.4426950408889634 + qk *= sm_scale * RCP_LN2 + l_i *= RCP_LN2 + p = tl.math.exp2(qk - l_i[:, None]) + else: + qk *= sm_scale + p = tl.math.exp(qk - l_i[:, None]) + + # mask block in the cases where the data is smaller the block size + p_mask = mask_m[:, None] & mask_n[None, :] + p = tl.where(p_mask, p, 0.0) + + # compute dv + dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do) + + # compute dp + dp = tl.dot(do, tl.trans(v)) + + # compute ds , ds = p * (dp - delta[:, None]) + d_ptrs = d_offset + offs_m * stride_deltam + Di = tl.load(d_ptrs, mask=mask_m) + ds = (p * (dp - Di[:, None])) * sm_scale + ds = tl.where(p_mask, ds, 0.0).to(Q.dtype.element_ty) + + # compute dk = dot(ds.T, q) + dk += tl.dot(tl.trans(ds), q) + + # compute dq + if SEQUENCE_PARALLEL: + dq = tl.dot(ds, k) + else: + dq = tl.load(dq_ptrs, mask=q_mask, other=0.0) + dq += tl.dot(ds, k) + tl.store(dq_ptrs, dq.to(Q.dtype.element_ty), mask=q_mask) + + # write-back dv and dk + dk_ptrs = dk_offset + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk + dv_ptrs = dv_offset + offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vk + + # write-back + tl.store(dk_ptrs, dk.to(K.dtype.element_ty), mask=kv_mask) + tl.store(dv_ptrs, dv.to(V.dtype.element_ty), mask=kv_mask) + +@triton.jit +def _bwd_kernel( + Q, + K, + V, + sm_scale, + Out, + DO, + DQ, + DK, + DV, + L, + D, + stride_dq_all, + stride_qz, + stride_qh, + stride_qm, + stride_qk, + stride_kz, + stride_kh, + stride_kn, + stride_kk, + stride_vz, + stride_vh, + stride_vn, + stride_vk, + stride_deltaz, + stride_deltah, + stride_deltam, + Z, + H, + num_block_m, + num_block_n, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + ACTUAL_BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + SEQUENCE_PARALLEL: tl.constexpr, + CAUSAL: tl.constexpr, + USE_EXP2: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + # program ids + off_hz = tl.program_id(0) + if SEQUENCE_PARALLEL: + start_n = tl.program_id(1) + off_z = off_hz // H + off_h = off_hz % H + + if IS_VARLEN: + # Compute sequence lengths for the current batch + q_start = tl.load(cu_seqlens_q + off_z) + q_end = tl.load(cu_seqlens_q + off_z + 1) + k_start = tl.load(cu_seqlens_k + off_z) + k_end = tl.load(cu_seqlens_k + off_z + 1) + + # Compute actual sequence lengths + N_CTX_Q = q_end - q_start + N_CTX_K = k_end - k_start + else: + q_start = 0 + k_start = 0 + N_CTX_Q = max_seqlen_q + N_CTX_K = max_seqlen_k + + # input tensor offsets + q_offset = Q + off_z * stride_qz + off_h * stride_qh + q_start * stride_qm + k_offset = K + off_z * stride_kz + off_h * stride_kh + k_start * stride_kn + v_offset = V + off_z * stride_vz + off_h * stride_vh + k_start * stride_vn + do_offset = DO + off_z * stride_qz + off_h * stride_qh + q_start * stride_qm + l_offset = L + off_z * stride_deltaz + off_h * stride_deltah + q_start * stride_deltam + d_offset = D + off_z * stride_deltaz + off_h * stride_deltah + q_start * stride_deltam + + # output tensor offsets + dk_offset = DK + off_z * stride_kz + off_h * stride_kh + k_start * stride_kn + dv_offset = DV + off_z * stride_vz + off_h * stride_vh + k_start * stride_vn + if SEQUENCE_PARALLEL: + dq_offset = DQ + start_n * stride_dq_all + off_z * stride_qz + off_h * stride_qh + q_start * stride_qm + else: + dq_offset = DQ + off_z * stride_qz + off_h * stride_qh + q_start * stride_qm + + # inner loop + if SEQUENCE_PARALLEL: + _bwd_kernel_one_col_block( + Q, + K, + V, + sm_scale, + Out, + DO, + DQ, + DK, + DV, + L, + D, + q_offset, + k_offset, + v_offset, + do_offset, + dq_offset, + dk_offset, + dv_offset, + d_offset, + l_offset, + stride_dq_all, + stride_qz, + stride_qh, + stride_qm, + stride_qk, + stride_kz, + stride_kh, + stride_kn, + stride_kk, + stride_vz, + stride_vh, + stride_vn, + stride_vk, + stride_deltaz, + stride_deltah, + stride_deltam, + Z, + H, + N_CTX_Q, + N_CTX_K, + off_h, + off_z, + off_hz, + start_n, + num_block_m, + num_block_n, + BLOCK_M=BLOCK_M, + BLOCK_DMODEL=BLOCK_DMODEL, + ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL, + BLOCK_N=BLOCK_N, + SEQUENCE_PARALLEL=SEQUENCE_PARALLEL, + CAUSAL=CAUSAL, + USE_EXP2=USE_EXP2, + ) + else: + for start_n in range(0, num_block_n): + _bwd_kernel_one_col_block( + Q, + K, + V, + sm_scale, + Out, + DO, + DQ, + DK, + DV, + L, + D, + q_offset, + k_offset, + v_offset, + do_offset, + dq_offset, + dk_offset, + dv_offset, + d_offset, + l_offset, + stride_dq_all, + stride_qz, + stride_qh, + stride_qm, + stride_qk, + stride_kz, + stride_kh, + stride_kn, + stride_kk, + stride_vz, + stride_vh, + stride_vn, + stride_vk, + stride_deltaz, + stride_deltah, + stride_deltam, + Z, + H, + N_CTX_Q, + N_CTX_K, + off_h, + off_z, + off_hz, + start_n, + num_block_m, + num_block_n, + BLOCK_M=BLOCK_M, + BLOCK_DMODEL=BLOCK_DMODEL, + ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL, + BLOCK_N=BLOCK_N, + SEQUENCE_PARALLEL=SEQUENCE_PARALLEL, + CAUSAL=CAUSAL, + USE_EXP2=USE_EXP2, + ) + + +# NOTE: smaller blocks have lower accuracy. more accumlation error probably 128 * 128 seems good but leads to oom. 64 * 64 has accumlation errors but no oom. +def attention_prefill_backward_triton_impl( + do, + q, + k, + v, + o, + softmax_lse, + dq, + dk, + dv, + sm_scale: float, + alibi_slopes, # pylint: disable=unused-argument + causal, + layout: str, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q: int, + max_seqlen_k: int, + use_exp2: bool, + sequence_parallel = True, +): + # make contigious + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + softmax_lse = softmax_lse.contiguous() + + # get strides and shape + batch, nheads_q, nheads_k, head_size, max_seqlen_q, max_seqlen_k = get_shape_from_layout(q, k, layout, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k) # pylint: disable=unused-variable + q_strides, k_strides, v_strides, o_strides = get_strides_from_layout(q, k, v, o, layout) + stride_qz, stride_qh, stride_qm, stride_qk = q_strides + stride_kz, stride_kh, stride_kn, stride_kk = k_strides + stride_vz, stride_vh, stride_vn, stride_vk = v_strides + stride_oz, stride_oh, stride_om, stride_ok = o_strides + batch_headsize = batch * nheads_q + is_varlen = layout == "thd" + + # FIXME: some configs lead to oom for some reason when using 64 x 64 blocks + if max_seqlen_q <= 32 or max_seqlen_k <= 32: + BLOCK_M = 32 + BLOCK_N = 32 + else: + BLOCK_M = 64 + BLOCK_N = 64 + num_warps = 4 # NOTE: originial is 8. changing it to 1 caused issues be careful + num_stages = 1 + waves_per_eu = 1 + + # divide up the problem + num_blocks_m = triton.cdiv(max_seqlen_q, BLOCK_M) + num_blocks_n = triton.cdiv(max_seqlen_k, BLOCK_N) + + # get closest power of 2 over or equal to 32. + padded_d_model = 1 << (head_size - 1).bit_length() + padded_d_model = max(padded_d_model, 16) + BLOCK_DMODEL = padded_d_model + ACTUAL_BLOCK_DMODEL = head_size + + do = do.contiguous() + # NOTE: we might need to copy the output tensor if they are not continuous or have other issues + copy_back = {"dq": False, "dk": False, "dv": False} + + dq_og = None + # deal with dq + if dq is None: + if sequence_parallel: + dq = torch.zeros((num_blocks_n,) + q.shape, device=q.device, dtype=q.dtype) + else: + dq = torch.zeros(q.shape, device=q.device, dtype=q.dtype) + else: + dq_og = dq + if not dq.is_contiguous(): + dq = dq.contiguous() + copy_back["dq"] = True + + if sequence_parallel: + dq = torch.zeros((num_blocks_n,) + q.shape, device=q.device, dtype=q.dtype) + copy_back["dq"] = True + else: + # NOTE: the kernel does inplace accumlation so dq has to be zeros. This avoids the case where we are passed empty dq and it is not all zeros + dq.zero_() + stride_dq_all = dq.stride()[0] + + dk_og = None + dv_og = None + # deal with dk, dv + if (dk is None) or (dv is None): + dk = torch.empty_like(k) + dv = torch.empty_like(v) + else: + if not dk.is_contiguous(): + dk_og = dk + dk = dk.contiguous() + copy_back["dk"] = True + + if not dv.is_contiguous(): + dv_og = dv + dv = dv.contiguous() + copy_back["dv"] = True + + # assert contigious + assert do.is_contiguous() + assert q.is_contiguous() + assert k.is_contiguous() + assert v.is_contiguous() + assert o.is_contiguous() + assert softmax_lse.is_contiguous() + + # init delta + delta = torch.empty_like(softmax_lse) + if is_varlen: + stride_deltam, stride_deltah = delta.stride() + stride_deltaz = 0 + else: + stride_deltaz, stride_deltah, stride_deltam = delta.stride() + + _bwd_preprocess_use_o[(num_blocks_m, batch_headsize)]( + o, + do, + delta, + stride_oz, stride_oh, stride_om, stride_ok, + stride_oz, stride_oh, stride_om, stride_ok, + stride_deltaz, stride_deltah, stride_deltam, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + BLOCK_M=BLOCK_M, + BLOCK_DMODEL=BLOCK_DMODEL, + ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL, + N_CTX_Q=max_seqlen_q, + Z=batch, + H=nheads_q, + IS_VARLEN=is_varlen + ) + + _bwd_kernel[(batch_headsize, num_blocks_n if sequence_parallel else 1)]( + q, + k, + v, + sm_scale, + o, + do, + dq, + dk, + dv, + softmax_lse, + delta, + stride_dq_all, + stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, + stride_vz, stride_vh, stride_vn, stride_vk, + stride_deltaz, stride_deltah, stride_deltam, + batch, + nheads_q, + num_blocks_m, + num_blocks_n, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_DMODEL=BLOCK_DMODEL, + ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL, + SEQUENCE_PARALLEL=sequence_parallel, + CAUSAL=causal, + USE_EXP2=use_exp2, + num_warps=num_warps, + num_stages=num_stages, + waves_per_eu = waves_per_eu, + IS_VARLEN=is_varlen + ) + + if sequence_parallel: + dq = dq.sum(dim=0) + + if copy_back["dq"]: + dq_og.copy_(dq) + dq = dq_og + if copy_back["dk"]: + dk_og.copy_(dk) + dk = dk_og + if copy_back["dv"]: + dv_og.copy_(dv) + dv = dv_og + + return dq, dk, dv, delta, None, None diff --git a/comfy/flash_attn_triton_amd/bwd_ref.py b/comfy/flash_attn_triton_amd/bwd_ref.py new file mode 100644 index 000000000..4dd182f5d --- /dev/null +++ b/comfy/flash_attn_triton_amd/bwd_ref.py @@ -0,0 +1,271 @@ +import math +import torch + + +def attention_backward_core_ref_impl( + do, q, k, v, o, softmax_lse, sm_scale, causal, use_exp2 +): + # cast to float32 + do = do.to(torch.float32) + q = q.to(torch.float32) + k = k.to(torch.float32) + v = v.to(torch.float32) + o = o.to(torch.float32) + softmax_lse = softmax_lse.to(torch.float32) + + # recompute attention_scores. Make sure it matches the forward impl. i.e. It use float32 + attention_scores = torch.matmul(q.to(torch.float32), k.transpose(-2, -1).to(torch.float32)) + + # scale scores + attention_scaled_scores = sm_scale * attention_scores + + # Apply causal mask if necessary + if causal: + L_q, L_k = q.shape[1], k.shape[1] + row_idx = torch.arange(L_q, device=q.device).unsqueeze(1) + col_idx = torch.arange(L_k, device=q.device).unsqueeze(0) + col_offset = L_q-L_k + causal_mask = row_idx >= (col_offset + col_idx) + # set -inf to places the causal mask is false + attention_scaled_scores = attention_scaled_scores.masked_fill( + torch.logical_not(causal_mask.unsqueeze(0)), float('-inf') + ) + + # compute probabilities using softmax_lse + if use_exp2: + RCP_LN = 1 / math.log(2) + attention_scaled_scores_base2 = attention_scaled_scores * RCP_LN + softmax_lse_base2 = softmax_lse * RCP_LN + softmax_lse_3d = softmax_lse_base2.unsqueeze(-1) + p = torch.exp2(attention_scaled_scores_base2 - softmax_lse_3d) + else: + softmax_lse_3d = softmax_lse.unsqueeze(-1) + p = torch.exp(attention_scaled_scores - softmax_lse_3d) + + # compute gradient wrt v + dv = torch.matmul(p.transpose(-2, -1), do.to(torch.float32)) + + # compute dp + dp = torch.matmul(do, v.transpose(-2, -1)) + + # calculate ds using dp + delta = torch.sum(o * do, axis=-1).to(torch.float32) # what OAI kernel uses + delta_3d = delta.unsqueeze(-1) + ds = (p * (dp - delta_3d)) * sm_scale + + # compute gradient wrt k + dk = torch.matmul(ds.transpose(-2, -1), q.to(torch.float32)) + + # compute gradient wrt q + dq = torch.matmul(ds, k.to(torch.float32)) + + # cast back to original dtype + dq = dq.to(torch.float16) + dk = dk.to(torch.float16) + dv = dv.to(torch.float16) + + # remove d dim with size 1 + delta = delta_3d.squeeze(-1) + + return dq, dk, dv, delta + +def attention_varlen_backward_pytorch_ref_impl( + do, + q, + k, + v, + o, + softmax_lse, + sm_scale, + causal, + layout, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, max_seqlen_k, # pylint: disable=unused-argument + use_exp2, +): + # Ensure the layout is 'thd' + if layout != 'thd': + raise ValueError(f"Unsupported layout {layout}. Expected 'thd'.") + + batch_size = cu_seqlens_q.shape[0] - 1 + num_heads = q.shape[1] + head_dim = q.shape[2] # pylint: disable=unused-variable + + # Pre-allocate outputs + total_L_q = q.shape[0] + total_L_k = k.shape[0] # pylint: disable=unused-variable + + dq = torch.zeros_like(q) + dk = torch.zeros_like(k) + dv = torch.zeros_like(v) + # delta has the same shape as softmax_lse: [total_L_q, num_heads] + delta = torch.zeros((total_L_q, num_heads), dtype=torch.float32, device=o.device) + + for i in range(batch_size): + # Get the start and end indices for the current sequence + start_q = cu_seqlens_q[i].item() + end_q = cu_seqlens_q[i + 1].item() + start_k = cu_seqlens_k[i].item() + end_k = cu_seqlens_k[i + 1].item() + + # Extract q_i, k_i, v_i, do_i, o_i, softmax_lse_i + q_i = q[start_q:end_q, :, :] # [L_q_i, num_heads, head_dim] + k_i = k[start_k:end_k, :, :] # [L_k_i, num_heads, head_dim] + v_i = v[start_k:end_k, :, :] # [L_k_i, num_heads, head_dim] + do_i = do[start_q:end_q, :, :] # [L_q_i, num_heads, head_dim] + o_i = o[start_q:end_q, :, :] # [L_q_i, num_heads, head_dim] + # softmax_lse has shape [total_L_q, num_heads] + softmax_lse_i = softmax_lse[start_q:end_q, :] # [L_q_i, num_heads] + softmax_lse_i = softmax_lse_i.transpose(0, 1) # [num_heads, L_q_i] + + # Permute to [num_heads, L_q_i, head_dim] + q_i = q_i.permute(1, 0, 2) + k_i = k_i.permute(1, 0, 2) + v_i = v_i.permute(1, 0, 2) + do_i = do_i.permute(1, 0, 2) + o_i = o_i.permute(1, 0, 2) + # softmax_lse_i is already in [num_heads, L_q_i] + + # Call the core backward function for this sequence + dq_i, dk_i, dv_i, delta_i = attention_backward_core_ref_impl( + do_i, + q_i, + k_i, + v_i, + o_i, + softmax_lse_i, + sm_scale, + causal, + use_exp2 + ) + + # Convert back to 'thd' layout + dq_i = dq_i.permute(1, 0, 2) # [L_q_i, num_heads, head_dim] + dk_i = dk_i.permute(1, 0, 2) # [L_k_i, num_heads, head_dim] + dv_i = dv_i.permute(1, 0, 2) # [L_k_i, num_heads, head_dim] + + # Place outputs in pre-allocated tensors + dq[start_q:end_q, :, :] = dq_i + dk[start_k:end_k, :, :] += dk_i # Accumulate gradients for shared keys + dv[start_k:end_k, :, :] += dv_i # Accumulate gradients for shared values + # delta_i has shape [num_heads, L_q_i] + delta_i = delta_i.transpose(1, 0) # [L_q_i, num_heads] + delta[start_q:end_q, :] = delta_i + + return dq, dk, dv, delta + +def attention_vanilla_backward_pytorch_ref_impl( + do, + q, + k, + v, + o, + softmax_lse, + sm_scale, + causal, + layout, + use_exp2, +): + if layout == "bshd": + do = do.transpose(1, 2).contiguous() + q = q.transpose(1, 2).contiguous() + k = k.transpose(1, 2).contiguous() + v = v.transpose(1, 2).contiguous() + o = o.transpose(1, 2).contiguous() + elif layout == "bhsd": + pass + else: + raise ValueError(f"Unknown layout {layout}") + + # Prepare tensors in [batch_size * num_heads, seq_len, head_dim] format + batch_size, num_heads, seq_len_q, head_dim = q.shape + seq_len_k = k.shape[2] + + # Merge batch and heads dimensions + do = do.reshape(batch_size * num_heads, seq_len_q, head_dim) + q = q.reshape(batch_size * num_heads, seq_len_q, head_dim) + k = k.reshape(batch_size * num_heads, seq_len_k, head_dim) + v = v.reshape(batch_size * num_heads, seq_len_k, head_dim) + softmax_lse = softmax_lse.reshape(batch_size * num_heads, seq_len_q) + o = o.reshape(batch_size * num_heads, seq_len_q, head_dim) + + dq, dk, dv, delta = attention_backward_core_ref_impl( + do, + q, + k, + v, + o, + softmax_lse, + sm_scale, + causal, + use_exp2 + ) + + # Reshape outputs back to [batch_size, num_heads, seq_len, head_dim] + dq = dq.reshape(batch_size, num_heads, seq_len_q, head_dim) + dk = dk.reshape(batch_size, num_heads, seq_len_k, head_dim) + dv = dv.reshape(batch_size, num_heads, seq_len_k, head_dim) + delta = delta.reshape(batch_size, num_heads, seq_len_q) + + # Go back to original layout + if layout == "bshd": + dq = dq.transpose(1, 2) + dk = dk.transpose(1, 2) + dv = dv.transpose(1, 2) + elif layout == "bhsd": + pass + else: + raise ValueError(f"Unknown layout {layout}") + + return dq, dk, dv, delta + + +def attention_backward_pytorch_ref_impl( + do, + q, + k, + v, + o, + softmax_lse, + sm_scale, + causal, + layout, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + use_exp2 +): + if layout == "thd": + dq, dk, dv, delta = attention_varlen_backward_pytorch_ref_impl( + do, + q, + k, + v, + o, + softmax_lse, + sm_scale, + causal, + layout, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + use_exp2, + ) + else: + dq, dk, dv, delta = attention_vanilla_backward_pytorch_ref_impl( + do, + q, + k, + v, + o, + softmax_lse, + sm_scale, + causal, + layout, + use_exp2, + ) + + return dq, dk, dv, delta diff --git a/comfy/flash_attn_triton_amd/fwd_decode.py b/comfy/flash_attn_triton_amd/fwd_decode.py new file mode 100644 index 000000000..9360c583a --- /dev/null +++ b/comfy/flash_attn_triton_amd/fwd_decode.py @@ -0,0 +1,700 @@ +import torch +import triton +import triton.language as tl +from comfy.flash_attn_triton_amd.utils import _strides, get_padded_headsize + + +@triton.jit +def _fwd_kernel_splitK( + Q, + K, + V, + sm_scale, + Out_splitK, # [B, H, split_k, Mq, K] + Metadata, # [B, H, 2, split_k, M_ceil] contains [mi, li] + K_new, + V_new, + Cache_seqlens, + Cache_batch_idx, + Alibi_slopes, + stride_qz, + stride_qm, + stride_qg, + stride_qh, + stride_qd, + stride_kz, + stride_kn, + stride_kg, + stride_kh, + stride_kd, + stride_vz, + stride_vn, + stride_vg, + stride_vh, + stride_vd, + stride_osk_zhg, + stride_osk_s, + stride_osk_m, + stride_osk_d, # pylint: disable=unused-argument + stride_mzhg, + stride_m2, + stride_ms, + stride_mm, # pylint: disable=unused-argument + stride_kn_z, + stride_kn_n, + stride_kn_g, + stride_kn_h, + stride_kn_d, + stride_vn_z, + stride_vn_n, + stride_vn_g, + stride_vn_h, + stride_vn_d, + stride_az, + stride_ah, + Z, # pylint: disable=unused-argument + N_CTX_Q, + N_CTX_K, + N_CTX_NEW, + BLOCK_N_PER_SPLIT, + H_q: tl.constexpr, + H_kv: tl.constexpr, + G_q: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + ACTUAL_BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + BOUNDS_CHECKS_N: tl.constexpr, + USE_CACHE_SEQLENs: tl.constexpr, + USE_CACHE_BATCH_IDX: tl.constexpr, + NEW_KV: tl.constexpr, + IS_GQA: tl.constexpr, + IS_CAUSAL: tl.constexpr, + USE_ALIBI: tl.constexpr, +): + # Padding + PADDED_HEAD: tl.constexpr = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL + if PADDED_HEAD: + d_mask = tl.arange(0, BLOCK_DMODEL) < ACTUAL_BLOCK_DMODEL + + start_m = tl.program_id(0) + off_zhg = tl.program_id(1) + off_z = off_zhg // (H_q * G_q) + off_h_q = (off_zhg // G_q) % H_q + off_g_q = off_zhg % G_q + splitk_idx = tl.program_id(2) + + # pick batch index + if USE_CACHE_BATCH_IDX: + cache_batch_idx = tl.load(Cache_batch_idx + off_z) + else: + cache_batch_idx = off_z + + # Load ALiBi slope if enabled + if USE_ALIBI: + a_offset = off_z * stride_az + off_h_q * stride_ah + alibi_slope = tl.load(Alibi_slopes + a_offset) + else: + alibi_slope = None + + lo = splitk_idx * BLOCK_N_PER_SPLIT + if USE_CACHE_SEQLENs: + cache_seqlen_last_idx = tl.load(Cache_seqlens + off_z) + if NEW_KV: + kv_len = cache_seqlen_last_idx + N_CTX_NEW + else: + kv_len = cache_seqlen_last_idx + else: + kv_len = N_CTX_K + hi = tl.minimum((splitk_idx + 1) * BLOCK_N_PER_SPLIT, kv_len) + + HEAD_RATIO: tl.constexpr = H_q // H_kv + if IS_GQA: + k_head_idx = off_h_q // HEAD_RATIO + v_head_idx = k_head_idx + else: + k_head_idx = off_h_q + v_head_idx = off_h_q + + # calculate base offset + k_base = K + k_head_idx * stride_kh + cache_batch_idx * stride_kz + off_g_q * stride_kg + v_base = V + v_head_idx * stride_vh + cache_batch_idx * stride_vz + off_g_q * stride_vg + + # Copy new Keys and Values into Cache + if NEW_KV: + knew_base = K_new + k_head_idx * stride_kn_h + off_z * stride_kn_z + off_g_q * stride_kn_g + + # Determine the starting position for new data in the cache + if USE_CACHE_SEQLENs: + start_idx = tl.load(Cache_seqlens + off_z) + else: + start_idx = N_CTX_K - N_CTX_NEW + + # Copy new Keys + for i in range(0, N_CTX_NEW, BLOCK_N): + # Load from K_new + k_new_block = tl.load( + knew_base + + tl.arange(0, BLOCK_DMODEL)[:, None] * stride_kn_d + + (tl.arange(0, BLOCK_N) + i)[None, :] * stride_kn_n, + mask=(tl.arange(0, BLOCK_N)[None, :] + i < N_CTX_NEW) & + (tl.arange(0, BLOCK_DMODEL)[:, None] < ACTUAL_BLOCK_DMODEL), + other=0 + ) + + # Store to K + tl.store( + k_base + + tl.arange(0, BLOCK_DMODEL)[:, None] * stride_kd + + (tl.arange(0, BLOCK_N) + i + start_idx)[None, :] * stride_kn, + k_new_block, + mask=(tl.arange(0, BLOCK_N)[None, :] + i < N_CTX_NEW) & + (tl.arange(0, BLOCK_DMODEL)[:, None] < ACTUAL_BLOCK_DMODEL), + ) + + # Copy new Values + vnew_base = V_new + v_head_idx * stride_vn_h + off_z * stride_vn_z + off_g_q * stride_vn_g + for i in range(0, N_CTX_NEW, BLOCK_N): + # Load from V_new + v_new_block = tl.load( + vnew_base + + (tl.arange(0, BLOCK_N) + i)[:, None] * stride_vn_n + + tl.arange(0, BLOCK_DMODEL)[None, :] * stride_vn_d, + mask=(tl.arange(0, BLOCK_N)[:, None] + i < N_CTX_NEW) & + (tl.arange(0, BLOCK_DMODEL)[None, :] < ACTUAL_BLOCK_DMODEL), + other=0 + ) + + # Store to V + tl.store( + v_base + + (tl.arange(0, BLOCK_N) + i + start_idx)[:, None] * stride_vn + + tl.arange(0, BLOCK_DMODEL)[None, :] * stride_vd, + v_new_block, + mask=(tl.arange(0, BLOCK_N)[:, None] + i < N_CTX_NEW) & + (tl.arange(0, BLOCK_DMODEL)[None, :] < ACTUAL_BLOCK_DMODEL), + ) + + Q_block_ptr = tl.make_block_ptr( + base=Q + off_h_q * stride_qh + off_z * stride_qz + off_g_q * stride_qg, + shape=(N_CTX_Q, ACTUAL_BLOCK_DMODEL), + strides=(stride_qm, stride_qd), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + + K_block_ptr = tl.make_block_ptr( + base=k_base, + shape=(ACTUAL_BLOCK_DMODEL, hi), + strides=(stride_kd, stride_kn), + offsets=(0, lo), + block_shape=(BLOCK_DMODEL, BLOCK_N), + order=(0, 1), + ) + V_block_ptr = tl.make_block_ptr( + base=v_base, + shape=(hi, ACTUAL_BLOCK_DMODEL), + strides=(stride_vn, stride_vd), + offsets=(lo, 0), + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(1, 0), + ) + + K_scale_shift_block_ptr = None + V_scale_shift_block_ptr = None + + # initialize pointer to m and l + m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) # noqa: F821 + + # scale sm_scale by log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + qk_scale = sm_scale * 1.44269504 + # load q: it will stay in SRAM throughout + q = tl.load( # noqa: F821 + tl.advance(Q_block_ptr, (0, 0)), boundary_check=(0, )) + q = (q * qk_scale).to(q.dtype) + if PADDED_HEAD: + q = tl.where(d_mask[None, :], q, 0.0) + + # loop over k, v and update accumulator + for start_n in range(lo, hi, BLOCK_N): + k, v = load_k_v_group( + K_block_ptr, + V_block_ptr, + K_scale_shift_block_ptr, + V_scale_shift_block_ptr, + BOUNDS_CHECKS_N, + 1, + BLOCK_DMODEL, + ACTUAL_BLOCK_DMODEL, + Q.dtype.element_ty, + 0, + ) + if PADDED_HEAD: + k = tl.where(d_mask[:, None], k, 0.0) + v = tl.where(d_mask[None, :], v, 0.0) + + # -- compute qk --- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) # noqa: F821 + + if USE_ALIBI: + row_idx = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + col_idx = start_n + tl.arange(0, BLOCK_N) + + # Compute relative positions + relative_pos = row_idx[:, None] + kv_len - (N_CTX_Q + col_idx[None, :]) + relative_pos = tl.abs(relative_pos) + + # Compute ALiBi bias + alibi_bias = -1 * alibi_slope * relative_pos + qk += (alibi_bias * 1.44269504) + + # Apply causal mask if IS_CAUSAL is True + if IS_CAUSAL: + row_idx = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + col_idx = start_n + tl.arange(0, BLOCK_N) + + # create a N_CTX_Q x kv_len causal mask + col_offset = N_CTX_Q - kv_len + causal_mask = row_idx[:, None] >= (col_offset + col_idx[None, :]) + + # Apply the mask + qk = tl.where(causal_mask, qk, float("-inf")) + + # TODO: This is slow, and only needed at the last iteration. + # Maybe we can unroll the last iteration instead? + if BOUNDS_CHECKS_N: + qk = tl.where(tl.arange(0, BLOCK_N) < hi - start_n, qk, float("-inf")) + + # -- compute scaling constant --- + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + if IS_CAUSAL: + alpha = tl.math.exp2(tl.where(m_i > float("-inf"), m_i - m_i_new, float("-inf"))) + else: + alpha = tl.math.exp2(m_i - m_i_new) + # cause of nan because subtracting infs + if IS_CAUSAL: + qk = tl.where(qk > float("-inf"), qk - m_i_new[:, None], float("-inf")) + else: + qk = qk - m_i_new[:, None] + + p = tl.math.exp2(qk) + + # -- update m_i and l_i -- + l_i = l_i * alpha + tl.sum(p, 1) + m_i = m_i_new + p = p.to(Q.dtype.element_ty) + + # -- scale and update acc -- + acc *= alpha[:, None] + acc += tl.dot(p.to(v.dtype), v) + + # update pointers + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + + # write back O + O_block_ptr = tl.make_block_ptr( + base=Out_splitK + off_zhg * stride_osk_zhg + splitk_idx * stride_osk_s, + shape=(N_CTX_Q, BLOCK_DMODEL), + strides=(stride_osk_m, 1), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + tl.store( + tl.advance(O_block_ptr, (0, 0)), + acc, + boundary_check=(0, ), + ) + # Write metadata for split-K reduction + Metadata_ptr = (Metadata + off_zhg * stride_mzhg + splitk_idx * stride_ms + start_m * BLOCK_M + + tl.arange(0, BLOCK_M)) + tl.store(Metadata_ptr, m_i) + tl.store(Metadata_ptr + stride_m2, l_i) + + +@triton.jit +def load_k_v_group( + K_block_ptr, + V_block_ptr, + K_scale_shift_block_ptr, V_scale_shift_block_ptr, # pylint: disable=unused-argument + BOUNDS_CHECKS_N: tl.constexpr, + PACKED_PER_VAL: tl.constexpr, BLOCK_DMODEL: tl.constexpr, # pylint: disable=unused-argument + ACTUAL_BLOCK_DMODEL: tl.constexpr, + dtype: tl.constexpr, # pylint: disable=unused-argument + group_id: tl.constexpr, +): + # Load K/V for a given block + # Advance to the current quantization group + K_block_ptr = tl.advance(K_block_ptr, (ACTUAL_BLOCK_DMODEL * group_id, 0)) + V_block_ptr = tl.advance(V_block_ptr, (0, ACTUAL_BLOCK_DMODEL * group_id)) + + # -- load k, v -- + k = tl.load(K_block_ptr, boundary_check=(1, ) if BOUNDS_CHECKS_N else ()) + v = tl.load(V_block_ptr, boundary_check=(0, ) if BOUNDS_CHECKS_N else ()) + + return k, v + + +@triton.jit +def cast_uint32_to_half2(scale_shift): + # Extract two float16 packed into one int32 + scale = scale_shift & 0xFFFF + shift = scale_shift >> 16 + scale = scale.to(tl.uint16).to(tl.float16, bitcast=True) + shift = shift.to(tl.uint16).to(tl.float16, bitcast=True) + return scale, shift + + +@triton.jit +def dequantize( + x_, + scale, + shift, + PACKED_PER_VAL: tl.constexpr = 8, +): + # PACKED_PER_VAL is the number of values packed into + # each element x_. For example, for int4 quantization + #and x_ of type int32, PACKED_PER_VAL is 8. + + BLOCK_N: tl.constexpr = x_.shape[0] + BLOCK_DMODEL_PACKED: tl.constexpr = x_.shape[1] + offsets = tl.arange(0, PACKED_PER_VAL) * 4 + quant_offset = (x_[:, None, :] >> offsets[None, :, None]) # (BLOCK_N, PACKED_PER_VAL, D // PACKED_PER_VAL) + + quant_offset = tl.view(quant_offset, (BLOCK_N, BLOCK_DMODEL_PACKED * PACKED_PER_VAL)) + # Trick - instead of converting int4 to float16 we view it as float16 + # and then multiply by 32768 * 512 == 2**24 + quant_offset = (quant_offset & 0xF).to(tl.uint16).to(tl.float16, bitcast=True) + quant_offset = (quant_offset * 32768.0).to(tl.float16) + scale_512 = scale * 512 + + dequant = quant_offset * scale_512 + shift + return dequant + + +@triton.jit +def _splitK_reduce( + Out_splitK, # [B, H, split_k, Mq, K] + Metadata, # [B, H, 2, split_k, M_ceil] contains [mi, li] + Out, # [B, H, M, K] + LSE, # [B, H, M] + stride_osk_zhg, + stride_osk_s, + stride_osk_m, + stride_osk_k, + stride_mzhg, + stride_m2, + stride_ms, + stride_mm, + stride_oz, + stride_oh, + stride_og, + stride_om, + stride_ok, # pylint: disable=unused-argument + stride_lse_zhg, + stride_lse_m, M_ceil: tl.constexpr, # pylint: disable=unused-argument + BLOCK_SIZE: tl.constexpr, + H: tl.constexpr, + G: tl.constexpr, + split_k: tl.constexpr, + splitK_pow2: tl.constexpr, + use_mask: tl.constexpr, + IS_CAUSAL: tl.constexpr, +): + off_zhg = tl.program_id(0) + off_z = off_zhg // (H * G) + off_h = (off_zhg // G) % H + off_g = off_zhg % G + off_m = tl.program_id(1) + off_k = tl.program_id(2) + + # read chunk + spk_idx = tl.arange(0, splitK_pow2) + kidx = tl.arange(0, BLOCK_SIZE) + + Metadata_ptr = Metadata + stride_mzhg * off_zhg + spk_idx * stride_ms + off_m * stride_mm + + o_ptr = (Out_splitK + off_zhg * stride_osk_zhg + stride_osk_m * off_m + off_k * BLOCK_SIZE + + stride_osk_s * spk_idx[:, None] + kidx[None, :] * stride_osk_k) + + # read max values of each splitK + if use_mask: + spk_mask = spk_idx < split_k + l_m = tl.load(Metadata_ptr, mask=spk_mask, other=float("-inf")) + l_sum = tl.load(Metadata_ptr + stride_m2, mask=spk_mask, other=0.0) + acc = tl.load(o_ptr, mask=spk_mask[:, None], other=0.0) + else: + l_m = tl.load(Metadata_ptr) + l_sum = tl.load(Metadata_ptr + stride_m2) + acc = tl.load(o_ptr) + + g_m = tl.max(l_m, axis=0) + + if IS_CAUSAL: + l_m_offset = l_m - g_m + alpha = tl.where(l_m_offset > float("-inf"), tl.math.exp2(l_m_offset), 0.0) + else: + alpha = tl.math.exp2(l_m - g_m) + + # read sum + l_sum *= alpha + g_sum = tl.sum(l_sum, axis=0) + acc = acc * alpha[:, None] + + if IS_CAUSAL: + # Avoid division by zero + g_sum_safe = tl.where(g_sum > 0, g_sum, 1.0) + acc_out = tl.sum(acc, axis=0) / g_sum_safe + else: + acc_out = tl.sum(acc, axis=0) / g_sum + + # Store output + Out_ptr = (Out + stride_oz * off_z + stride_oh * off_h + stride_og * off_g + stride_om * off_m + + off_k * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)) + tl.store(Out_ptr, acc_out) + + # Store lse + l_ptrs = LSE + off_zhg * stride_lse_zhg + off_m + if IS_CAUSAL: + lse = tl.where(g_sum > 0, (g_m + tl.math.log2(g_sum)) / 1.44269504, g_m) + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, (g_m + tl.math.log2(g_sum)) / 1.44269504) + + +def quantize_kv_int4(k: torch.Tensor, num_groups: int = 1) -> torch.Tensor: + # Scale and shift are such that quantization linearly maps + # int4 values range [0..15] to input values range min(k)..max(k) + # individually for every row + k = k.reshape(*k.shape[:-1], num_groups, k.shape[-1] // num_groups) + max_vals = torch.max(k, dim=-1, keepdim=True).values + min_vals = torch.min(k, dim=-1, keepdim=True).values + scale_k: torch.Tensor = (max_vals - min_vals) / 15 + + shift_k = torch.min(k, dim=-1, keepdim=True).values + scale_k = scale_k.to(torch.float16) + shift_k = shift_k.to(torch.float16) + + in_bytes = ((k - shift_k.expand(k.shape)) / scale_k.expand(k.shape)) + 0.5 + in_bytes = in_bytes.to(torch.uint8) + in_int4 = in_bytes & 0xF + in_int4_packed = in_int4[..., ::2] + (in_int4[..., 1::2] << 4) + scale_shift = torch.concat([scale_k.view(torch.uint8), shift_k.view(torch.uint8)], dim=-1) + k_quant = torch.concat( + [ + scale_shift.flatten(start_dim=-2), + in_int4_packed.flatten(start_dim=-2), + ], + dim=-1, + ).view(torch.int16) + return k_quant + + +def dequantize_kv_fp16(quant_k: torch.Tensor, num_groups: int = 1) -> torch.Tensor: + k_i16 = quant_k.view(torch.int16) + k_ui8 = k_i16.view(torch.uint8) + + ss_size = num_groups * 4 + scale_shift_ui8 = k_ui8[..., 0:ss_size] + scale_shift_ui8 = scale_shift_ui8.reshape(*scale_shift_ui8.shape[:-1], num_groups, 4) + scale = scale_shift_ui8[..., 0:2].view(torch.float16) + shift = scale_shift_ui8[..., 2:4].view(torch.float16) + + kv_ui8 = k_ui8[..., ss_size:] + k_ui8 = kv_ui8.reshape(*kv_ui8.shape[:-1], num_groups, -1) + k1_i4 = k_ui8 & 0xF + k2_i4 = (k_ui8 & 0xF0) >> 4 + k_shape = k1_i4.shape + k1_f16 = k1_i4.to(torch.float16) * scale.expand(k_shape) + shift.expand(k_shape) + k2_f16 = k2_i4.to(torch.float16) * scale.expand(k_shape) + shift.expand(k_shape) + + out = torch.empty((*k1_f16.shape[:-1], k1_f16.shape[-1] * 2), dtype=torch.float16, device=quant_k.device) + out[..., ::2] = k1_f16 + out[..., 1::2] = k2_f16 + out = out.reshape(*k_shape[:-2], -1) + + return out + + +def get_split_k(B: int, G: int, H: int, Mk: int) -> int: + """Heuristic for the number of splits""" + bh = max(B * H, 1) # NOTE: Handle B*h=0 case + split_k = max(Mk, 1024) // bh + max_chunk_size = 64 + while split_k > 0 and Mk / split_k < max_chunk_size: + split_k = split_k // 2 + while B * H * G * split_k >= 1024: + split_k = split_k // 2 + split_k = min(split_k, 512) + split_k = max(split_k, 1) + return split_k + +def attention_decode_forward_triton_impl(q, k, v, sm_scale, causal, alibi_slopes, layout, cache_seqlens, cache_batch_idx, new_kv, k_new, v_new): + # kernel config + BLOCK_M = 16 + BLOCK_N = 64 + SPLIT_K = None + NUM_QUANT_GROUPS = 1 # pylint: disable=unused-variable + + # kernels expects "bsghd" + original_layout = layout + if layout == "bshd": + q = q.unsqueeze(2) + k = k.unsqueeze(2) + v = v.unsqueeze(2) + if new_kv: + k_new = k_new.unsqueeze(2) + v_new = v_new.unsqueeze(2) + layout = "bsghd" + elif layout == "bhsd": + q = q.permute(0, 2, 1, 3).unsqueeze(2) + k = k.permute(0, 2, 1, 3).unsqueeze(2) + v = v.permute(0, 2, 1, 3).unsqueeze(2) + if new_kv: + k_new = k_new.permute(0, 2, 1, 3).unsqueeze(2) + v_new = v_new.permute(0, 2, 1, 3).unsqueeze(2) + layout = "bsghd" + elif layout == "bsghd": + pass + elif layout is None: + raise ValueError("Layout not given") + assert layout == "bsghd" + + # get dims + batch_size, seqlen_q, n_group_q, heads_per_group_q, dim_q = q.shape + _, seqlen_k, n_group_k, heads_per_group_k, dim_k = k.shape # pylint: disable=unused-variable + _, seqlen_v, n_group_v, heads_per_group_v, dim_v = v.shape # pylint: disable=unused-variable + + assert dim_q == dim_k == dim_v, f"Dimensions must match: {dim_q}, {dim_k}, {dim_v}" + + # get padded size + dim_padded = get_padded_headsize(dim_k) + + # Handle MQA/GQA case + if heads_per_group_q > heads_per_group_k: + is_gqa = True + elif heads_per_group_q < heads_per_group_k: + raise ValueError("heads_per_group_q < heads_per_group_k") + else: + is_gqa = False + + assert dim_k == dim_q, f"Keys have head dim {dim_k} but queries have head dim {dim_q}" + + if SPLIT_K is not None: + split_k = SPLIT_K + else: + # Use heuristics + split_k = get_split_k(batch_size, n_group_q, heads_per_group_q, seqlen_k) # NOTE: should the split think about seqlens? + + seqlen_q_ceil = (seqlen_q + BLOCK_M - 1) // BLOCK_M * BLOCK_M + out_splitk = torch.empty([batch_size * n_group_q * heads_per_group_q, split_k, seqlen_q_ceil, dim_padded], dtype=torch.float32, device=q.device) + metadata = torch.empty([batch_size * n_group_q * heads_per_group_q, 2, split_k, seqlen_q_ceil], dtype=torch.float32, device=q.device) + lse = torch.empty((batch_size * n_group_q * heads_per_group_q, seqlen_q), device=q.device, dtype=torch.float32) + grid = (triton.cdiv(seqlen_q, BLOCK_M), batch_size * n_group_q * heads_per_group_q, split_k) + + num_warps = 1 + split_size = (seqlen_k + split_k - 1) // split_k + use_cache_seqlens = cache_seqlens is not None + + # TODO: enable quantization + _fwd_kernel_splitK[grid]( + Q=q, + K=k, + V=v, + sm_scale=sm_scale, + Out_splitK=out_splitk, + Metadata=metadata, + K_new = k_new, + V_new = v_new, + Cache_seqlens=cache_seqlens, + Cache_batch_idx=cache_batch_idx, + Alibi_slopes=alibi_slopes, + **_strides(q, "qz", "qm", "qg", "qh", "qd"), + **_strides(k, "kz", "kn", "kg", "kh", "kd"), + **_strides(v, "vz", "vn", "vg", "vh", "vd"), + **_strides(out_splitk, "osk_zhg", "osk_s", "osk_m", "osk_d"), + **_strides(metadata, "mzhg", "m2", "ms", "mm"), + **_strides(k_new, "kn_z", "kn_n", "kn_g", "kn_h", "kn_d"), + **_strides(v_new, "vn_z", "vn_n", "vn_g", "vn_h", "vn_d"), + **_strides(alibi_slopes, "az", "ah"), + Z=batch_size, + H_q=heads_per_group_q, + H_kv=heads_per_group_k, + G_q=n_group_q, + N_CTX_Q=seqlen_q, + N_CTX_K=seqlen_k, + N_CTX_NEW=k_new.shape[1] if new_kv else None, + BLOCK_N_PER_SPLIT=split_size, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_DMODEL=dim_padded, + ACTUAL_BLOCK_DMODEL=dim_k, + BOUNDS_CHECKS_N=(split_size % BLOCK_N) > 0 or use_cache_seqlens, + USE_CACHE_SEQLENs=use_cache_seqlens, + USE_CACHE_BATCH_IDX=cache_batch_idx is not None, + NEW_KV=new_kv, + IS_GQA=is_gqa, + IS_CAUSAL=causal, + USE_ALIBI=False if alibi_slopes is None else True, + num_warps=num_warps, + num_stages=1, + ) + + out = torch.empty((batch_size, seqlen_q, n_group_q, heads_per_group_q, dim_padded), device=q.device, dtype=q.dtype) + + # Merge together + splitK_pow2 = triton.next_power_of_2(split_k) + use_mask = splitK_pow2 > split_k + if batch_size * n_group_q * heads_per_group_q * seqlen_q >= 512: + k_block_num = 1 + else: + k_block_num = 2 + assert dim_padded % k_block_num == 0 + k_block_size = dim_padded // k_block_num + grid = (batch_size * n_group_q * heads_per_group_q, seqlen_q, k_block_num) + + _splitK_reduce[grid]( + out_splitk, + metadata, + out, + lse, + **_strides(out_splitk, "osk_zhg", "osk_s", "osk_m", "osk_k"), + **_strides(metadata, "mzhg", "m2", "ms", "mm"), + **_strides(out, "oz", "om", "og", "oh", "ok"), + **_strides(lse, "lse_zhg", "lse_m"), + M_ceil=seqlen_q_ceil, + BLOCK_SIZE=k_block_size, + G=n_group_q, + H=heads_per_group_q, + # TODO: Tune num_warps + split_k=split_k, + splitK_pow2=splitK_pow2, + use_mask=use_mask, + IS_CAUSAL=causal, + num_warps=4) + + lse = lse.reshape([batch_size, n_group_q, heads_per_group_q, seqlen_q]) + if q.ndim == 4: + # BMGHK -> BMHK + assert n_group_q == 1 + out = out[:, :, 0] + lse = lse[:, 0] + if seqlen_k == 0: + out.zero_() + out = out.reshape(batch_size, heads_per_group_q * n_group_q, -1, dim_padded).contiguous() + + # output is batch_size, heads_per_group_q * group_q, seqlen_q, dim_q + if original_layout == "bshd": + # out=out.transpose(1, 2).contiguous() # this screws up heads and data. + # the data is laid out properly. Just need to reshape dims + out = out.reshape(batch_size, seqlen_q, -1, dim_padded) + + return out.narrow(-1, 0, dim_k), lse diff --git a/comfy/flash_attn_triton_amd/fwd_prefill.py b/comfy/flash_attn_triton_amd/fwd_prefill.py new file mode 100644 index 000000000..865e13a8e --- /dev/null +++ b/comfy/flash_attn_triton_amd/fwd_prefill.py @@ -0,0 +1,634 @@ +import torch +import triton +import triton.language as tl +from comfy.flash_attn_triton_amd.utils import get_shape_from_layout, get_strides_from_layout, is_cdna, is_rdna, AUTOTUNE + + +@triton.jit +def cdiv_fn(x, y): + return (x + y - 1) // y + + +@triton.jit +def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride): # pylint: disable=unused-argument + ms = tl.arange(0, m) + ns = tl.arange(0, n) + return philox_offset + ms[:, None] * stride + ns[None, :] + + +@triton.jit +def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride): + rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride).to(tl.uint32) + # TODO: use tl.randint for better performance + return tl.rand(philox_seed, rng_offsets) + + +@triton.jit +def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride): + rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride) + rng_keep = rng_output > dropout_p + return rng_keep + + +# Convenience function to load with optional boundary checks. +# "First" is the major dim, "second" is the minor dim. +@triton.jit +def load_fn(ptrs, offset_first, offset_second, boundary_first, boundary_second): + if offset_first is not None and offset_second is not None: + mask = (offset_first[:, None] < boundary_first) & \ + (offset_second[None, :] < boundary_second) + tensor = tl.load(ptrs, mask=mask, other=0.0) + elif offset_first is not None: + mask = offset_first[:, None] < boundary_first + tensor = tl.load(ptrs, mask=mask, other=0.0) + elif offset_second is not None: + mask = offset_second[None, :] < boundary_second + tensor = tl.load(ptrs, mask=mask, other=0.0) + else: + tensor = tl.load(ptrs) + return tensor + + +@triton.jit +def compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpose=False): + # when seqlen_k and seqlen_q are different we want the diagonal to stick to the bottom right of the attention matrix + # for casual mask we want something like this where (1 is kept and 0 is masked) + # seqlen_q = 2 and seqlen_k = 5 + # 1 1 1 1 0 + # 1 1 1 1 1 + # seqlen_q = 5 and seqlen_k = 2 + # 0 0 + # 0 0 + # 0 0 + # 1 0 + # 1 1 + # for alibi the diagonal is 0 indicating no penalty for attending to that spot and increasing penalty for attending further from the diagonal + # e.g. alibi_slope = 1, seqlen_q = 2, seqlen_k = 5, offs_m = [0, 1, 2, 3], offs_n = [0, 1, 2, 3, 4], transpose = False + # 1. offs_m[:,None] = [[0], + # [1], + # 2. offs_m[:,None] + seqlen_k = [[5], + # [6], + # 3. offs_m[:,None] + seqlen_k - seqlen_q = [[3], + # [4], + # 4. offs_m[:,None] + seqlen_k - seqlen_q - offs_n[None,:] = [[3], - [[0, 1, 2, 3, 4]] = [[ 3, 2, 1, 0,-1], + # [4], [ 4, 3, 2, 1, 0]] + # 5. -1 * alibi_slope * tl.abs(relative_pos_block) = [[ -3, -2, -1, 0,-1], + # [ -4, -3, -2, -1, 0]], + relative_pos_block = offs_m[:, None] + seqlen_k - seqlen_q - offs_n[None, :] + alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block) + if transpose: + return alibi_block.T + else: + return alibi_block + + +@triton.jit +def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, start_m, + actual_seqlen_k, actual_seqlen_q, dropout_p, philox_seed, batch_philox_offset, exp_scores_ptrs, + block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, alibi_slope, score_ptrs, scores_scaled_shifted_ptrs, # pylint: disable=unused-argument + IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, + OFFS_M: tl.constexpr, OFFS_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, MASK_STEPS: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, PADDED_HEAD: tl.constexpr, + ACTUAL_BLOCK_DMODEL: tl.constexpr, SM_SCALE: tl.constexpr, USE_EXP2: tl.constexpr, + RETURN_SCORES: tl.constexpr): + if USE_EXP2: + RCP_LN2: tl.constexpr = 1.4426950408889634 + + # loop over k, v, and update accumulator + for start_n in range(block_min, block_max, BLOCK_N): + # For padded blocks, we will overrun the tensor size if + # we load all BLOCK_N. For others, the blocks are all within range. + if MASK_STEPS: + k_offs_n = start_n + tl.arange(0, BLOCK_N) + else: + k_offs_n = None + k_offs_k = None if not PADDED_HEAD else tl.arange(0, BLOCK_DMODEL) + k = load_fn(k_ptrs, k_offs_k, k_offs_n, ACTUAL_BLOCK_DMODEL, actual_seqlen_k) + if PRE_LOAD_V: + # We can use the same offsets as k, just with dims transposed. + v = load_fn(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, ACTUAL_BLOCK_DMODEL) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + # We start from end of seqlen_k so only the first iteration would need + # to be checked for padding if it is not a multiple of block_n + # TODO: This can be optimized to only be true for the padded block. + if MASK_STEPS: + # If this is the last block / iteration, we want to + # mask if the sequence length is not a multiple of block size + # a solution is to always do BLOCK_M // BLOCK_N + 1 steps if not is_modulo_mn. + # last step might get wasted but that is okay. check if this masking works For + # that case. + if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0): + boundary_m = tl.full([BLOCK_M], actual_seqlen_k, dtype=tl.int32) + size_n = start_n + OFFS_N[None, :] + mask = size_n < boundary_m[:, None] + qk = tl.where(mask, qk, float("-inf")) + + # -- compute qk ---- + qk += tl.dot(q, k) + qk_scaled = qk * SM_SCALE + if RETURN_SCORES: + score_mask = (OFFS_M[:, None] < actual_seqlen_q) & ((start_n + tl.arange(0, BLOCK_N))[None, :] < actual_seqlen_k) + tl.store(score_ptrs, qk_scaled, mask=score_mask) + + if IS_CAUSAL: + causal_boundary = start_n + offs_n_causal + causal_mask = OFFS_M[:, None] >= causal_boundary[None, :] + qk_scaled = tl.where(causal_mask, qk_scaled, float("-inf")) + if bias_ptrs is not None: + bias_offs_n = start_n + tl.arange(0, BLOCK_N) if MASK_STEPS else None + bias = load_fn(bias_ptrs, OFFS_M, bias_offs_n, actual_seqlen_q, actual_seqlen_k) + qk_scaled += bias + + if alibi_slope is not None: + # Compute the global position of each token within the sequence + global_m_positions = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + global_n_positions = start_n + tl.arange(0, BLOCK_N) + alibi_block = compute_alibi_block(alibi_slope, actual_seqlen_q, actual_seqlen_k, global_m_positions, + global_n_positions) + qk_scaled += alibi_block + # get max scores so far + m_ij = tl.maximum(m_i, tl.max(qk_scaled, 1)) + + # scale and subtract max + q_shifted = qk_scaled - m_ij[:, None] + if RETURN_SCORES: + # NOTE: the returned score is not the same as the reference because we need to adjust as we find new maxes per block. We are not doing that + scores_scaled_shifted_mask = (OFFS_M[:, None] < actual_seqlen_q) & ((start_n + tl.arange(0, BLOCK_N))[None, :] < actual_seqlen_k) + tl.store(scores_scaled_shifted_ptrs, q_shifted, mask=scores_scaled_shifted_mask) + + # Compute scaled QK and softmax probabilities + if USE_EXP2: + p = tl.math.exp2(q_shifted * RCP_LN2) + else: + p = tl.math.exp(q_shifted) + + # CAVEAT: Must update l_ij before applying dropout + l_ij = tl.sum(p, 1) + if ENABLE_DROPOUT: + philox_offset = batch_philox_offset + start_m * BLOCK_M * actual_seqlen_k + start_n - BLOCK_N + keep = dropout_mask(philox_seed, philox_offset, dropout_p, BLOCK_M, BLOCK_N, actual_seqlen_k) + if RETURN_SCORES: + # NOTE: the returned score is not the same as the reference because we need to adjust as we find new maxes per block. We are not doing that + exp_score_mask = (OFFS_M[:, None] < actual_seqlen_q) & ((start_n + tl.arange(0, BLOCK_N))[None, :] < actual_seqlen_k) + tl.store(exp_scores_ptrs, tl.where(keep, p, -p), mask=exp_score_mask) + p = tl.where(keep, p, 0.0) + elif RETURN_SCORES: + # NOTE: the returned score is not the same as the reference because we need to adjust as we find new maxes per block. We are not doing that + exp_score_mask = (OFFS_M[:, None] < actual_seqlen_q) & ((start_n + tl.arange(0, BLOCK_N))[None, :] < actual_seqlen_k) + tl.store(exp_scores_ptrs, p, mask=exp_score_mask) + + # -- update output accumulator -- + # alpha is an adjustment factor for acc and li as we loop and find new maxes + # store the diff in maxes to adjust acc and li as we discover new maxes + m_diff = m_i - m_ij + if USE_EXP2: + alpha = tl.math.exp2(m_diff * RCP_LN2) + else: + alpha = tl.math.exp(m_diff) + acc = acc * alpha[:, None] + v = None + if not PRE_LOAD_V: + v = load_fn(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, ACTUAL_BLOCK_DMODEL) + # -- update m_i and l_i + l_i = l_i * alpha + l_ij + # update m_i and l_i + m_i = m_ij + acc += tl.dot(p.to(v.type.element_ty), v) + k_ptrs += BLOCK_N * stride_kn + v_ptrs += BLOCK_N * stride_vk + if bias_ptrs is not None: + bias_ptrs += BLOCK_N * stride_bn + if RETURN_SCORES: + score_ptrs += BLOCK_N + scores_scaled_shifted_ptrs += BLOCK_N + exp_scores_ptrs += BLOCK_N + return acc, l_i, m_i + + +def get_cdna_autotune_configs(): + return [ + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + # Fall-back config. + triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK'] + + +def get_rdna_autotune_configs(): + return [ + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=2), + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=2), + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=2), + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=2), + triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=2), + triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=2), + # Fall-back config. + triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=2), + ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK'] + + +def get_autotune_configs(): + if AUTOTUNE: + if is_rdna(): + return get_rdna_autotune_configs() + elif is_cdna(): + return get_cdna_autotune_configs() + else: + raise ValueError("Unknown Device Type") + else: + return [ + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 1, "PRE_LOAD_V": False}, + num_stages=1, + num_warps=4, + ), + ], [ + "IS_CAUSAL", + "dropout_p", + "MAX_SEQLENS_Q", + "MAX_SEQLENS_K", + "ACTUAL_BLOCK_DMODEL", + "VARLEN", + "HQ", + "HK", + ] + + +autotune_configs, autotune_keys = get_autotune_configs() + +@triton.autotune( + configs=autotune_configs, + key=autotune_keys, + # use_cuda_graph=True, +) +@triton.jit +def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn, + stride_oz, stride_oh, stride_om, stride_on, stride_bz, stride_bh, stride_bm, stride_bn, stride_az, stride_ah, # pylint: disable=unused-argument + stride_sz, stride_sh, stride_sm, stride_sn, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k, + dropout_p, philox_seed, philox_offset_base, scores, scores_scaled_shifted, exp_scores, alibi_slopes, HQ: tl.constexpr, + HK: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr, + MAX_SEQLENS_K: tl.constexpr, VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, RETURN_SCORES: tl.constexpr, USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr): + start_m = tl.program_id(0) + off_h_q = tl.program_id(1) + off_z = tl.program_id(2) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + if VARLEN: + cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) + cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) + # print("cu_seqlens_q_start:", cu_seqlens_q_start) + + seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start + # We have a one-size-fits-all grid in id(0). Some seqlens might be too + # small for all start_m so for those we return early. + if start_m * BLOCK_M > seqlen_q: + return + cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) + cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) + seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start + else: + cu_seqlens_q_start = 0 + cu_seqlens_k_start = 0 + seqlen_q = MAX_SEQLENS_Q + seqlen_k = MAX_SEQLENS_K + + # Now we compute whether we need to exit early due to causal masking. + # This is because for seqlen_q > seqlen_k, M rows of the attn scores + # are completely masked, resulting in 0s written to the output, and + # inf written to LSE. We don't need to do any GEMMs in this case. + # This block of code determines what N is, and if this WG is operating + # on those M rows. + n_blocks = cdiv_fn(seqlen_k, BLOCK_N) + if IS_CAUSAL: + # If seqlen_q == seqlen_k, the attn scores are a square matrix. + # If seqlen_q != seqlen_k, attn scores are rectangular which means + # the causal mask boundary is bottom right aligned, and ends at either + # the top edge (seqlen_q < seqlen_k) or left edge. + # This captures the decrease in n_blocks if we have a rectangular attn matrix + n_blocks_seqlen = cdiv_fn((start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N) + # This is what adjusts the block_max for the current WG, only + # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks + n_blocks = min(n_blocks, n_blocks_seqlen) + # If we have no blocks after adjusting for seqlen deltas, this WG is part of + # the blocks that are all 0. We exit early. + if n_blocks <= 0: + o_offset = Out + off_z * stride_oz + off_h_q * stride_oh + cu_seqlens_q_start * stride_om + o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_on + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty) + o_ptrs_mask = offs_m[:, None] < seqlen_q + # We still need to write 0s to the result + tl.store(o_ptrs, acc, mask=o_ptrs_mask) + # The tensor allocated for L is based on MAX_SEQLENS_Q as that is + # statically known. + l_offset = LSE + off_z * stride_lse_z + off_h_q * stride_lse_h + cu_seqlens_q_start * stride_lse_m + l_ptrs = l_offset + offs_m * stride_lse_m + + l = tl.full([BLOCK_M], value=0.0, dtype=tl.float32) + + # mask_m_offsets = start_m + tl.arange(0, BLOCK_M) + # lse_mask = mask_m_offsets < causal_start_idx + # softmax_lse = tl.where(lse_mask, 0.0, softmax_lse) + l_ptrs_mask = offs_m < MAX_SEQLENS_Q + tl.store(l_ptrs, l, mask=l_ptrs_mask) + # TODO: Should dropout and return encoded softmax be handled here too? + return + + # If MQA / GQA, set the K and V head offsets appropriately. + GROUP_SIZE: tl.constexpr = HQ // HK + if GROUP_SIZE != 1: + off_h_k = off_h_q // GROUP_SIZE + else: + off_h_k = off_h_q + + n_extra_tokens = 0 + # print("n_extra_tokens:", n_extra_tokens) + # print("seqlen_k:", seqlen_k) + # print("BLOCK_N:", BLOCK_N) + # return + if seqlen_k < BLOCK_N: + n_extra_tokens = BLOCK_N - seqlen_k + elif seqlen_k % BLOCK_N: + n_extra_tokens = seqlen_k % BLOCK_N + PADDED_HEAD: tl.constexpr = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL + + # Compute pointers for all the tensors used in this kernel. + q_offset = Q + off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm + q_ptrs = q_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk + k_offset = K + off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn + k_ptrs = k_offset + offs_d[:, None] * stride_kk + offs_n[None, :] * stride_kn + v_offset = V + off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk + v_ptrs = v_offset + offs_n[:, None] * stride_vk + offs_d[None, :] * stride_vn + if USE_BIAS: + # Note: this might get large enough to overflow on some configs + bias_offset = off_h_q * stride_bh + bias_ptrs = bias + bias_offset + offs_m[:, None] * stride_bm + offs_n[None, :] * stride_bn + else: + bias_ptrs = None + + if USE_ALIBI: + a_offset = off_z * stride_az + off_h_q * stride_ah + alibi_slope = tl.load(alibi_slopes + a_offset) + else: + alibi_slope = None + + if RETURN_SCORES: + scores_offset = scores + off_z * stride_sz + off_h_q * stride_sh + cu_seqlens_q_start * stride_sm + score_ptrs = scores_offset + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn + + scores_scaled_shifted_offset = scores_scaled_shifted + off_z * stride_sz + off_h_q * stride_sh + cu_seqlens_q_start * stride_sm + scores_scaled_shifted_ptrs = scores_scaled_shifted_offset + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn + + exp_scores_offset = exp_scores + off_z * stride_sz + off_h_q * stride_sh + cu_seqlens_q_start * stride_sm + exp_scores_ptrs = exp_scores_offset + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn + else: + score_ptrs = None + scores_scaled_shifted_ptrs = None + exp_scores_ptrs = None + + if ENABLE_DROPOUT: + off_hz = off_z * HQ + off_h_q + batch_philox_offset = philox_offset_base + off_hz * seqlen_q * seqlen_k + else: + batch_philox_offset = 0 + # initialize pointer to m and l + m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # Q is loaded once at the beginning and shared by all N blocks. + q_ptrs_mask = offs_m[:, None] < seqlen_q + if PADDED_HEAD: + q_ptrs_mask = q_ptrs_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL) + q = tl.load(q_ptrs, mask=q_ptrs_mask, other=0.0) + + # Here we compute how many full and masked blocks we have. + padded_block_k = n_extra_tokens != 0 + is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0) + if IS_CAUSAL: + # There are always at least BLOCK_M // BLOCK_N masked blocks. + # Additionally there might be one more due to dissimilar seqlens. + masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn) + else: + # Padding on Q does not need to be masked in the FA loop. + masked_blocks = padded_block_k + # if IS_CAUSAL, not is_modulo_mn does not always result in an additional block. + # In this case we might exceed n_blocks so pick the min. + masked_blocks = min(masked_blocks, n_blocks) + n_full_blocks = n_blocks - masked_blocks + block_min = 0 + block_max = n_blocks * BLOCK_N + # Compute for full blocks. Here we set causal to false regardless of its actual + # value because there is no masking. Similarly we do not need padding. + if n_full_blocks > 0: + block_max = (n_blocks - masked_blocks) * BLOCK_N + acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, + start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, batch_philox_offset, + exp_scores_ptrs, + # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ + block_min, block_max, 0, 0, 0, alibi_slope, score_ptrs, scores_scaled_shifted_ptrs, + # IS_CAUSAL, .... + False, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n, + # _, MASK_STEPS, ... + PRE_LOAD_V, False, ENABLE_DROPOUT, PADDED_HEAD, + ACTUAL_BLOCK_DMODEL, SM_SCALE, USE_EXP2=USE_EXP2, RETURN_SCORES=RETURN_SCORES) + block_min = block_max + block_max = n_blocks * BLOCK_N + + tl.debug_barrier() + # Remaining blocks, if any, are full / not masked. + if masked_blocks > 0: + if IS_CAUSAL: + offs_n_causal = offs_n + (seqlen_q - seqlen_k) + else: + offs_n_causal = 0 + k_ptrs += n_full_blocks * BLOCK_N * stride_kn + v_ptrs += n_full_blocks * BLOCK_N * stride_vk + if USE_BIAS: + bias_ptrs += n_full_blocks * BLOCK_N * stride_bn + if RETURN_SCORES: + score_ptrs += n_full_blocks * BLOCK_N + scores_scaled_shifted_ptrs += n_full_blocks * BLOCK_N + exp_scores_ptrs += n_full_blocks * BLOCK_N + acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, + start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, batch_philox_offset, + exp_scores_ptrs, block_min, block_max, offs_n_causal, masked_blocks, + n_extra_tokens, alibi_slope, score_ptrs, scores_scaled_shifted_ptrs, + IS_CAUSAL, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n, + # _, MASK_STEPS, ... + PRE_LOAD_V, True, ENABLE_DROPOUT, PADDED_HEAD, + ACTUAL_BLOCK_DMODEL, SM_SCALE, USE_EXP2=USE_EXP2, RETURN_SCORES=RETURN_SCORES) + # epilogue + # This helps the compiler do Newton Raphson on l_i vs on acc which is much larger. + l_recip = 1 / l_i[:, None] + acc = acc * l_recip + if ENABLE_DROPOUT: + acc = acc / (1 - dropout_p) + # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M, + # then we have one block with a row of all NaNs which come from computing + # softmax over a row of all -infs (-inf - inf = NaN). We check for that here + # and store 0s where there are NaNs as these rows should've been zeroed out. + end_m_idx = (start_m + 1) * BLOCK_M + start_m_idx = start_m * BLOCK_M + causal_start_idx = seqlen_q - seqlen_k + acc = acc.to(Out.type.element_ty) + if IS_CAUSAL: + if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: + out_mask_boundary = tl.full((BLOCK_DMODEL, ), causal_start_idx, dtype=tl.int32) + mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) + out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :] + z: tl.tensor = 0.0 + acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) + + # write back LSE(Log Sum Exponents), the log of the normalization constant + l_offset = LSE + off_z * stride_lse_z + off_h_q * stride_lse_h + cu_seqlens_q_start * stride_lse_m + l_ptrs = l_offset + offs_m * stride_lse_m + if USE_EXP2: + RCP_LN2: tl.constexpr = 1.4426950408889634 + LN2: tl.constexpr = 0.6931471824645996 + # compute log-sum-exp in base 2 units + mi_base2 = m_i * RCP_LN2 + softmax_lse = mi_base2 + tl.math.log2(l_i) + # convert back to natural units + softmax_lse *= LN2 + else: + softmax_lse = m_i + tl.math.log(l_i) + + if IS_CAUSAL: + # zero out nans caused by -infs when doing causal + lse_mask = (start_m_idx + tl.arange(0, BLOCK_M)) < causal_start_idx + softmax_lse = tl.where(lse_mask, 0.0, softmax_lse) + + # If seqlen_q not multiple of BLOCK_M, we need to mask out the last few rows. + # This is only true for the last M block. For others, overflow_size will be -ve + overflow_size = end_m_idx - seqlen_q + if overflow_size > 0: + boundary = tl.full((BLOCK_M, ), BLOCK_M - overflow_size, dtype=tl.int32) + l_ptrs_mask = tl.arange(0, BLOCK_M) < boundary + tl.store(l_ptrs, softmax_lse, mask=l_ptrs_mask) # the log of the normalization constant + else: + tl.store(l_ptrs, softmax_lse) # the log of the normalization constant + + # write back O + o_offset = Out + off_z * stride_oz + off_h_q * stride_oh + cu_seqlens_q_start * stride_om + o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_on + o_ptrs_mask = tl.full([BLOCK_M, BLOCK_DMODEL], 1, dtype=tl.int1) + if overflow_size > 0: + o_ptrs_mask = o_ptrs_mask & (offs_m[:, None] < seqlen_q) + if PADDED_HEAD: + o_ptrs_mask = o_ptrs_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL) + tl.store(o_ptrs, acc.to(Out.dtype.element_ty), mask=o_ptrs_mask) + + +def attention_prefill_forward_triton_impl( + q, + k, + v, + o, + sm_scale, + alibi_slopes, + causal, + bias, + dropout_p, + layout, + cu_seqlens_q, + cu_seqlens_k, + max_seqlens_q, + max_seqlens_k, + return_scores, + use_exp2): + # check if varlen + is_varlen = layout == "thd" + + # NOTE: a large bias tensor leads to overflow during pointer arithmetic + if bias is not None: + assert bias.numel() < 2**31 + + batch, nheads_q, nheads_k, head_size, seqlen_q, seqlen_k = get_shape_from_layout(q, k, layout, cu_seqlens_q, cu_seqlens_k, max_seqlens_q, max_seqlens_k) # pylint: disable=unused-variable + q_strides, k_strides, v_strides, o_strides = get_strides_from_layout(q, k, v, o, layout) + + # Get closest power of 2 over or equal to 32. + padded_d_model = 1 << (head_size - 1).bit_length() + # Smallest head_dim supported is 16. If smaller, the tile in the + # kernel is padded - there is no padding in memory for any dims. + padded_d_model = max(padded_d_model, 16) + + grid = lambda META: (triton.cdiv(max_seqlens_q, META['BLOCK_M']), nheads_q, batch) # pylint: disable=unnecessary-lambda-assignment + + if return_scores: + scores = torch.zeros((batch, nheads_q, max_seqlens_q, max_seqlens_k), device=q.device, + dtype=torch.float32) + scores_scaled_shifted = torch.zeros((batch, nheads_q, max_seqlens_q, max_seqlens_k), device=q.device, + dtype=torch.float32) + scores_strides = (scores.stride(0), scores.stride(1), scores.stride(2), scores.stride(3)) + else: + scores = None + scores_scaled_shifted = None + scores_strides = (0, 0 , 0 , 0) + + # exp_scores is used to validate dropout behavior vs the PyTorch SDPA math backend reference. We zero this out + # to give a consistent starting point and then populate it with the output of softmax with the sign bit set according + # to the dropout mask. The resulting return allows this mask to be fed into the reference implementation for testing + # only. This return holds no useful output aside from debugging. + if return_scores: + exp_scores = torch.zeros((batch, nheads_q, max_seqlens_q, max_seqlens_k), device=q.device, + dtype=torch.float32) + else: + exp_scores = None + + # stores LSE the log of the normalization constant / sum of expoential score(unnormalzied probablities) + if is_varlen: + softmax_lse = torch.empty((q.shape[0], nheads_q), device=q.device, dtype=torch.float32) + stride_lse_m, stride_lse_h = softmax_lse.stride() + stride_lse_z = 0 + else: + softmax_lse = torch.empty((batch, nheads_q, max_seqlens_q), device=q.device, dtype=torch.float32) + stride_lse_z, stride_lse_h, stride_lse_m = softmax_lse.stride() + + # Seed the RNG so we get reproducible results for testing. + philox_seed = 0x1BF52 + philox_offset = 0x1D4B42 + + if bias is not None: + bias_strides = (bias.stride(0), bias.stride(1),bias.stride(2), + bias.stride(3)) + else: + bias_strides = (0, 0, 0, 0) + + if alibi_slopes is not None: + alibi_strides = (alibi_slopes.stride(0), alibi_slopes.stride(1)) + else: + alibi_strides = (0, 0) + + + attn_fwd[grid](q, k, v, bias, sm_scale, softmax_lse, o, *q_strides, *k_strides, *v_strides, *o_strides, + *bias_strides, *alibi_strides, *scores_strides, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k, + dropout_p=dropout_p, philox_seed=philox_seed, philox_offset_base=philox_offset, scores=scores, + scores_scaled_shifted=scores_scaled_shifted, exp_scores=exp_scores, alibi_slopes=alibi_slopes, + HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL=head_size, MAX_SEQLENS_Q=max_seqlens_q, + MAX_SEQLENS_K=max_seqlens_k, IS_CAUSAL=causal, VARLEN=is_varlen, + BLOCK_DMODEL=padded_d_model, USE_BIAS=False if bias is None else True, + USE_ALIBI=False if alibi_slopes is None else True, ENABLE_DROPOUT=dropout_p + > 0.0, USE_EXP2=use_exp2, RETURN_SCORES=return_scores) + + return o, softmax_lse, exp_scores, grid, head_size, philox_seed, philox_offset, scores, scores_scaled_shifted diff --git a/comfy/flash_attn_triton_amd/fwd_ref.py b/comfy/flash_attn_triton_amd/fwd_ref.py new file mode 100644 index 000000000..e2a25b11a --- /dev/null +++ b/comfy/flash_attn_triton_amd/fwd_ref.py @@ -0,0 +1,258 @@ +import math +import torch + + +def attention_forward_core_ref_impl(q, k, v, sm_scale, causal, use_exp2): + # Compute attention scores + attention_scores = torch.matmul(q.to(torch.float32), k.transpose(-2, -1).to(torch.float32)) + + # Scale scores + attention_scaled_scores = sm_scale * attention_scores + + # Apply causal mask if necessary + if causal: + L_q, L_k = q.shape[1], k.shape[1] + row_idx = torch.arange(L_q, device=q.device).unsqueeze(1) + col_idx = torch.arange(L_k, device=q.device).unsqueeze(0) + col_offset = L_q-L_k + causal_mask = row_idx >= (col_offset + col_idx) + # set -inf to places the causal mask is false + attention_scaled_scores = attention_scaled_scores.masked_fill( + torch.logical_not(causal_mask.unsqueeze(0)), float('-inf') + ) + + + # Compute max for numerical stability + max_scores = torch.max(attention_scaled_scores, dim=-1, keepdim=True)[0] + if causal: + # Replace -inf in max_scores with zeros to avoid NaN in subtraction + max_scores = torch.where( + torch.isinf(max_scores), torch.zeros_like(max_scores), max_scores + ) + + # Shift scores + attention_shifted_scaled_scores = attention_scaled_scores - max_scores + + # Exponentiate + if use_exp2: + RCP_LN = 1 / math.log(2) + exp_scores = torch.exp2(RCP_LN * attention_shifted_scaled_scores) + else: + exp_scores = torch.exp(attention_shifted_scaled_scores) + + # Sum of exponentials + sum_exp_scores = torch.sum(exp_scores, dim=-1, keepdim=True) + if causal: + # if sum of exp scores is 0.0 it means scores where -inf, we cannot compute softmax and softmax_lse. Setting to 1 deals with -inf case cleanly + sum_exp_scores = torch.where( + sum_exp_scores == 0, + torch.ones_like(sum_exp_scores), + sum_exp_scores + ) + + # Compute softmax probabilities + softmax = exp_scores / sum_exp_scores + + # Compute log-sum-exp + if use_exp2: + LN2 = math.log(2) + RCP_LN = 1 / math.log(2) + max_scores_base2 = max_scores * RCP_LN + softmax_lse_base2 = max_scores_base2 + torch.log2(sum_exp_scores) + softmax_lse = softmax_lse_base2 * LN2 + softmax_lse.squeeze_(-1) + else: + softmax_lse = max_scores + torch.log(sum_exp_scores) + softmax_lse = softmax_lse.squeeze(-1) + + # Compute output + o = torch.matmul(softmax, v.to(torch.float32)).to(torch.float16) + + return o, softmax_lse, exp_scores, softmax, attention_shifted_scaled_scores, attention_scaled_scores, attention_scores + +def attention_vanilla_forward_pytorch_ref_impl(q, k, v, sm_scale, causal, layout, use_exp2): + """Compute reference output and softmax_lse using PyTorch's built-in function""" + + # Ensure the layout is 'bhsd' + if layout == "bshd": + q = q.transpose(1, 2).contiguous() + k = k.transpose(1, 2).contiguous() + v = v.transpose(1, 2).contiguous() + elif layout != "bhsd": + raise ValueError(f"Unknown layout {layout}") + + # Prepare tensors in [batch_size * num_heads, seq_len, head_dim] format + batch_size, num_heads, seq_len_q, head_dim = q.shape + seq_len_k = k.shape[2] + + # Merge batch and heads dimensions + q = q.reshape(batch_size * num_heads, seq_len_q, head_dim) + k = k.reshape(batch_size * num_heads, seq_len_k, head_dim) + v = v.reshape(batch_size * num_heads, seq_len_k, head_dim) + + # Call the core attention function + o, softmax_lse, exp_scores, softmax, attention_shifted_scaled_scores, attention_scaled_scores, attention_scores = attention_forward_core_ref_impl( + q, k, v, sm_scale, causal, use_exp2 + ) + + # Reshape outputs back to [batch_size, num_heads, seq_len, head_dim] + o = o.reshape(batch_size, num_heads, seq_len_q, head_dim) + softmax_lse = softmax_lse.reshape(batch_size, num_heads, seq_len_q) + exp_scores = exp_scores.reshape(batch_size, num_heads, seq_len_q, seq_len_k) + softmax = softmax.reshape(batch_size, num_heads, seq_len_q, seq_len_k) + attention_shifted_scaled_scores = attention_shifted_scaled_scores.reshape(batch_size, num_heads, seq_len_q, seq_len_k) + attention_scaled_scores = attention_scaled_scores.reshape(batch_size, num_heads, seq_len_q, seq_len_k) + attention_scores = attention_scores.reshape(batch_size, num_heads, seq_len_q, seq_len_k) + + # Restore original layout if necessary + if layout == "bshd": + o = o.transpose(1, 2) + + return o, softmax_lse, exp_scores, softmax, attention_shifted_scaled_scores, attention_scaled_scores, attention_scores + +def attention_varlen_forward_pytorch_ref_impl( + q, + k, + v, + sm_scale, + causal, + layout, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, max_seqlen_k, # pylint: disable=unused-argument + use_exp2 +): + # Ensure the layout is 'thd' + if layout != 'thd': + raise ValueError(f"Unsupported layout {layout}. Expected 'thd'.") + + batch_size = cu_seqlens_q.shape[0] - 1 + num_heads = q.shape[1] + head_dim = q.shape[2] + + # Pre-allocate outputs + total_L_q = q.shape[0] + total_L_k = k.shape[0] # pylint: disable=unused-variable + + o = torch.empty((total_L_q, num_heads, head_dim), dtype=q.dtype, device=q.device) + softmax_lse = torch.empty((total_L_q, num_heads), dtype=torch.float32, device=q.device) + + for i in range(batch_size): + # Get the start and end indices for the current sequence + start_q = cu_seqlens_q[i].item() + end_q = cu_seqlens_q[i + 1].item() + start_k = cu_seqlens_k[i].item() + end_k = cu_seqlens_k[i + 1].item() + + # Extract q_i, k_i, v_i + q_i = q[start_q:end_q, :, :] # [L_q_i, num_heads, head_dim] + k_i = k[start_k:end_k, :, :] # [L_k_i, num_heads, head_dim] + v_i = v[start_k:end_k, :, :] # [L_k_i, num_heads, head_dim] + + # Permute to [num_heads, L_q_i, head_dim] + q_i = q_i.permute(1, 0, 2) + k_i = k_i.permute(1, 0, 2) + v_i = v_i.permute(1, 0, 2) + + # Call the core attention function for this sequence + ( + o_i, + softmax_lse_i, + exp_scores_i, + softmax_i, + attention_shifted_scaled_scores_i, + attention_scaled_scores_i, + attention_scores_i, + ) = attention_forward_core_ref_impl(q_i, k_i, v_i, sm_scale, causal, use_exp2) + + # Convert back to 'thd' layout and float16 + o_i = o_i.permute(1, 0, 2).to(torch.float16) # [L_q_i, num_heads, head_dim] + + # Place outputs in pre-allocated tensors + o[start_q:end_q, :, :] = o_i + softmax_lse[start_q:end_q, :] = softmax_lse_i.transpose(0, 1) # Transpose to [L_q_i, num_heads] + + # For variable-sized outputs, map them into the preallocated tensors + # exp_scores_i: [num_heads, L_q_i, L_k_i] -> [L_q_i, num_heads, L_k_i] + exp_scores_i = exp_scores_i.permute(1, 0, 2) + softmax_i = softmax_i.permute(1, 0, 2) + attention_shifted_scaled_scores_i = attention_shifted_scaled_scores_i.permute(1, 0, 2) + attention_scaled_scores_i = attention_scaled_scores_i.permute(1, 0, 2) + attention_scores_i = attention_scores_i.permute(1, 0, 2) + + return ( + o, + softmax_lse, + None, + None, + None, + None, + None, + ) + + +def attention_forward_pytorch_ref_impl( + q, + k, + v, + sm_scale, + causal, + layout, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + use_exp2 + ): + # compute reference + if layout == "thd": + ( + o_ref, + softmax_lse_ref, + exp_scores_ref, + softmax_ref, + attention_shifted_scaled_scores_ref, + attention_scaled_scores_ref, + attention_scores_ref, + ) = attention_varlen_forward_pytorch_ref_impl( + q.clone(), + k.clone(), + v.clone(), + sm_scale, + causal, + layout, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + use_exp2, + ) + else: + ( + o_ref, + softmax_lse_ref, + exp_scores_ref, + softmax_ref, + attention_shifted_scaled_scores_ref, + attention_scaled_scores_ref, + attention_scores_ref, + ) = attention_vanilla_forward_pytorch_ref_impl( + q.clone(), k.clone(), v.clone(), sm_scale, causal, layout, use_exp2 + ) + + return ( + o_ref, + softmax_lse_ref, + exp_scores_ref, + softmax_ref, + attention_shifted_scaled_scores_ref, + attention_scaled_scores_ref, + attention_scores_ref, + ) + + +def compute_alibi_tensor_ref(alibi_slopes, seqlen_q, seqlen_k): + q_idx = torch.arange(seqlen_q, dtype=torch.int32, device="cuda").unsqueeze(-1) # (N_CTX_Q, 1) + k_idx = torch.arange(seqlen_k, dtype=torch.int32, device="cuda").unsqueeze(0) # (1, N_CTX_K) + relative_pos = torch.abs(q_idx + seqlen_k - seqlen_q - k_idx) # (N_CTX_Q, N_CTX_K) + return -1 * alibi_slopes.unsqueeze(-1).unsqueeze(-1) * relative_pos # (Z, H, N_CTX_Q, N_CTX_K) diff --git a/comfy/flash_attn_triton_amd/interface_fa.py b/comfy/flash_attn_triton_amd/interface_fa.py new file mode 100644 index 000000000..763f496b5 --- /dev/null +++ b/comfy/flash_attn_triton_amd/interface_fa.py @@ -0,0 +1,394 @@ +import os +import torch +from comfy.flash_attn_triton_amd.fwd_prefill import attention_prefill_forward_triton_impl +from comfy.flash_attn_triton_amd.bwd_prefill import attention_prefill_backward_triton_impl +from comfy.flash_attn_triton_amd.fwd_decode import attention_decode_forward_triton_impl +from comfy.flash_attn_triton_amd.fwd_ref import attention_forward_pytorch_ref_impl +from comfy.flash_attn_triton_amd.bwd_ref import attention_backward_pytorch_ref_impl +from comfy.flash_attn_triton_amd.utils import MetaData, get_shape_from_layout + + +USE_REF = os.environ.get('FLASH_ATTENTION_TRITON_AMD_REF', '0').lower() in ('1', 'true', 'yes') + + +def fwd(q, + k, + v, + o, + alibi_slopes, + dropout_p, + softmax_scale, + causal, + window_size_left, window_size_right, softcap, # pylint: disable=unused-argument + return_softmax, + gen_ # pylint: disable=unused-argument +): + if dropout_p != 0.0: + raise ValueError("dropout is not supported on AMD's Triton Backend yet") + + if o is None: + o = torch.empty_like(q) + + # Setup metadata + metadata = MetaData(sm_scale=softmax_scale) + metadata.max_seqlens_q = q.shape[1] + metadata.max_seqlens_k = k.shape[1] + metadata.layout = "bshd" + if return_softmax: + metadata.return_scores = True + + batch, nheads_q, nheads_k, head_size, _, _ = get_shape_from_layout(q, k, metadata.layout) # pylint: disable=unused-variable + + if causal: + metadata.need_causal() + + if alibi_slopes is not None: + metadata.need_alibi(alibi_slopes, batch, nheads_q) + + if dropout_p > 0.0: + metadata.need_dropout(dropout_p, return_softmax) + + # Check arguments + metadata.check_args(q, k, v, o) + if USE_REF: + (output, + softmax_lse, + exp_scores, + _, + _, + _, + _) = attention_forward_pytorch_ref_impl( + q, + k, + v, + metadata.sm_scale, + metadata.causal, + metadata.layout, + metadata.cu_seqlens_q, + metadata.cu_seqlens_k, + metadata.max_seqlens_q, + metadata.max_seqlens_k, + metadata.use_exp2) + o.copy_(output) + else: + (_, + softmax_lse, + exp_scores, + _, + _, + _, + _, + _, + _) = attention_prefill_forward_triton_impl( + q, + k, + v, + o, + metadata.sm_scale, + metadata.alibi_slopes, + metadata.causal, + metadata.bias, + metadata.dropout_p, + metadata.layout, + metadata.cu_seqlens_q, + metadata.cu_seqlens_k, + metadata.max_seqlens_q, + metadata.max_seqlens_k, + metadata.return_scores, + metadata.use_exp2) + + return o, softmax_lse, exp_scores, None + + +def bwd( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + alibi_slopes, + dropout_p, + softmax_scale, + causal, + window_size_left, window_size_right, softcap, deterministic, gen_, rng_state, # pylint: disable=unused-argument +): + if dropout_p != 0.0: + raise ValueError("dropout is not supported on AMD yet") + + if USE_REF: + dq_ref, dk_ref, dv_ref, delta_ref = attention_backward_pytorch_ref_impl( + dout, + q, + k, + v, + out, + softmax_lse, + softmax_scale, + causal, + "bshd", + None, + None, + None, + None, + False, + ) + dq.copy_(dq_ref) + dk.copy_(dk_ref) + dv.copy_(dv_ref) + delta = delta_ref + else: + dq_triton, dk_triton, dv_triton, delta_triton, _, _ = attention_prefill_backward_triton_impl( # pylint: disable=unused-variable + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + softmax_scale, + alibi_slopes, + causal, + "bshd", + None, + None, + None, + None, + False, + ) + delta = delta_triton + + return dq, dk, dv, delta + + +def varlen_fwd( + q, + k, + v, + o, + cu_seqlens_q, + cu_seqlens_k, + seqused_k, leftpad_k, block_table_, # pylint: disable=unused-argument + alibi_slopes,\ + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + zero_tensors, # pylint: disable=unused-argument + causal, + window_size_left, window_size_right, softcap, # pylint: disable=unused-argument + return_softmax, + gen_ # pylint: disable=unused-argument +): + if dropout_p != 0.0: + raise ValueError("dropout is not supported on AMD's Triton Backend yet") + + if o is None: + o = torch.empty_like(q) + + # Setup metadata + metadata = MetaData(sm_scale=softmax_scale) + if return_softmax: + metadata.return_scores = True + metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k) # set layout to "thd" and other metdata + + # get shapes + batch, nheads_q, nheads_k, head_size , seqlen_q, seqlen_k = get_shape_from_layout(q, k, metadata.layout, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k) # pylint: disable=unused-variable + + if causal: + metadata.need_causal() + + if alibi_slopes is not None: + metadata.need_alibi(alibi_slopes, batch, nheads_q) + + if dropout_p > 0.0: + metadata.need_dropout(dropout_p, return_softmax) + + # Check arguments + metadata.check_args(q, k, v, o) + if o is None: + o = torch.empty_like(q, dtype=v.dtype) + + if USE_REF: + (output, + softmax_lse, + exp_scores, + _, + _, + _, + _) = attention_forward_pytorch_ref_impl( + q, + k, + v, + metadata.sm_scale, + metadata.causal, + metadata.layout, + metadata.cu_seqlens_q, + metadata.cu_seqlens_k, + metadata.max_seqlens_q, + metadata.max_seqlens_k, + metadata.use_exp2) + o.copy_(output) + else: + (_, + softmax_lse, + exp_scores, + _, + _, + _, + _, + _, + _) = attention_prefill_forward_triton_impl( + q, + k, + v, + o, + metadata.sm_scale, + metadata.alibi_slopes, + metadata.causal, + metadata.bias, + metadata.dropout_p, + metadata.layout, + metadata.cu_seqlens_q, + metadata.cu_seqlens_k, + metadata.max_seqlens_q, + metadata.max_seqlens_k, + metadata.return_scores, + metadata.use_exp2) + + return o, softmax_lse, exp_scores, None + + +def varlen_bwd( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + cu_seqlens_q, + cu_seqlens_k, + alibi_slopes, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + zero_tensors, # pylint: disable=unused-argument + causal, + window_size_left, window_size_right, softcap, deterministic, gen_, rng_state, # pylint: disable=unused-argument +): + if dropout_p != 0.0: + raise ValueError("dropout is not supported on AMD yet") + + if USE_REF: + dq_ref, dk_ref, dv_ref, delta_ref = attention_backward_pytorch_ref_impl( + dout, + q, + k, + v, + out, + softmax_lse, + softmax_scale, + causal, + "thd", + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + False, + ) + dq.copy_(dq_ref) + dk.copy_(dk_ref) + dv.copy_(dv_ref) + delta = delta_ref + else: + dq_triton, dk_triton, dv_triton, delta_triton, _, _ = attention_prefill_backward_triton_impl( # pylint: disable=unused-variable + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + softmax_scale, + alibi_slopes, + causal, + "thd", + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + False, + ) + delta = delta_triton + + return dq, dk, dv, delta + + +def fwd_kvcache( + q, + k_cache, + v_cache, + k, + v, + cache_seqlens, + rotary_cos, rotary_sin, # pylint: disable=unused-argument + cache_batch_idx, + cache_leftpad, block_table, # pylint: disable=unused-argument + alibi_slopes, + out, + softmax_scale, + causal, + window_size_left, window_size_right, softcap, rotary_interleaved, num_splits, # pylint: disable=unused-argument +): + if out is None: + out = torch.empty_like(q) + + # fill metadata + metadata = MetaData(sm_scale=softmax_scale) + metadata.layout = "bshd" + metadata.max_seqlens_q = q.shape[1] + metadata.max_seqlens_k = k_cache.shape[1] + metadata.cache_seqlens = cache_seqlens + metadata.cache_batch_idx = cache_batch_idx + + if k is not None and v is not None: + metadata.new_kv = True + metadata.seqlen_new = k.shape[1] + metadata.k_new = k + metadata.v_new = v + + if causal: + metadata.need_causal() + + if alibi_slopes is not None: + batch, _ , nheads_q, _= q.shape + metadata.need_alibi(alibi_slopes, batch, nheads_q) + + # launch kernel + # TODO: pass output as an arg. Maybe we are copying output which is causing slow down + output, softmax_lse = attention_decode_forward_triton_impl( + q, + k_cache, + v_cache, + metadata.sm_scale, + metadata.causal, + metadata.alibi_slopes, + metadata.layout, + metadata.cache_seqlens, + metadata.cache_batch_idx, + metadata.new_kv, + metadata.k_new, + metadata.v_new, + ) + return output, softmax_lse diff --git a/comfy/flash_attn_triton_amd/utils.py b/comfy/flash_attn_triton_amd/utils.py new file mode 100644 index 000000000..3a5432c68 --- /dev/null +++ b/comfy/flash_attn_triton_amd/utils.py @@ -0,0 +1,280 @@ +import os +import torch +import triton + + +AUTOTUNE = os.environ.get('FLASH_ATTENTION_TRITON_AMD_AUTOTUNE', '0').lower() in ('1', 'true', 'yes') +PERF = os.environ.get('FLASH_ATTENTION_TRITON_AMD_PERF', '0').lower() in ('1', 'true', 'yes') + + +class MetaData(): + cu_seqlens_q = None + cu_seqlens_k = None + max_seqlens_q = 0 + max_seqlens_k = 0 + bias = None + alibi_slopes = None + causal = False + num_contexts = 0 + varlen = False + layout = None + cache_seqlens = None + cache_batch_idx = None + new_kv = False + seqlen_new = None + k_new = None + v_new = None + dropout_p, return_scores= 0.0, False + # NOTE: scale sm_scale by log_2(e) and use 2^x in the loop as we do not have native e^x support in HW. + use_exp2 = False + + def __repr__(self) -> str: + return (f"MetaData(\n" + f" sm_scale={self.sm_scale},\n" + f" cu_seqlens_q={self.cu_seqlens_q},\n" + f" cu_seqlens_k={self.cu_seqlens_k},\n" + f" max_seqlens_q={self.max_seqlens_q},\n" + f" max_seqlens_k={self.max_seqlens_k},\n" + f" bias={self.bias},\n" + f" alibi_slopes={self.alibi_slopes},\n" + f" causal={self.causal},\n" + f" num_contexts={self.num_contexts},\n" + f" varlen={self.varlen},\n" + f" layout={self.layout},\n" + f" cache_seqlens={self.cache_seqlens},\n" + f" cache_batch_idx={self.cache_batch_idx},\n" + f" new_kv={self.new_kv},\n" + f" seqlen_new={self.seqlen_new},\n" + f" k_new={self.k_new},\n" + f" v_new={self.v_new},\n" + f" dropout_p={self.dropout_p},\n" + f" return_scores={self.return_scores}\n" + f")") + + def __init__(self, sm_scale=1.0): + self.sm_scale = sm_scale + + def set_varlen_params(self, cu_seqlens_q, cu_seqlens_k): + self.varlen = True + self.layout = 'thd' + self.cu_seqlens_q = cu_seqlens_q + self.cu_seqlens_k = cu_seqlens_k + # Without "varlen", there should still be one sequence. + assert len(cu_seqlens_q) >= 2 + assert len(cu_seqlens_q) == len(cu_seqlens_k) + self.num_contexts = len(cu_seqlens_q) - 1 + for i in range(0, self.num_contexts): + self.max_seqlens_q = max(cu_seqlens_q[i + 1].item() - cu_seqlens_q[i].item(), self.max_seqlens_q) + self.max_seqlens_k = max(cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item(), self.max_seqlens_k) + + def need_bias(self, bias, batch, nheads, seqlen_q, seqlen_k): # pylint: disable=unused-argument + assert bias.is_cuda + assert bias.dim() == 4 + assert bias.shape[0] == 1 + assert bias.shape[2:] == (seqlen_q, seqlen_k) + self.bias = bias + + def need_alibi(self, alibi_slopes, batch, nheads): + assert alibi_slopes.is_cuda + assert alibi_slopes.dim() == 2 + assert alibi_slopes.shape[0] == batch + assert alibi_slopes.shape[1] == nheads + self.alibi_slopes = alibi_slopes + + def need_causal(self): + self.causal = True + + def need_dropout(self, dropout_p, return_scores): + self.dropout_p = dropout_p + self.return_scores = return_scores + + def check_args(self, q, k, v, o): + assert q.dim() == k.dim() and q.dim() == v.dim() + + batch, nheads_q, nheads_k, head_size, _, _ = get_shape_from_layout(q, k, self.layout, self.cu_seqlens_q, self.cu_seqlens_k, self.max_seqlens_q, self.max_seqlens_k) # pylint: disable=unused-variable + if self.varlen: + assert q.dim() == 3 + assert self.cu_seqlens_q is not None + assert self.cu_seqlens_k is not None + assert len(self.cu_seqlens_q) == len(self.cu_seqlens_k) + # TODO: Remove once bias is supported with varlen + assert self.bias is None + # TODO:Remove once dropout is supported with varlen + assert self.dropout_p == 0.0 + # assert not self.return_scores + else: + assert q.dim() == 4 + assert self.max_seqlens_q > 0 and self.max_seqlens_k > 0 + assert self.cu_seqlens_q is None and self.cu_seqlens_k is None + assert k.shape == v.shape + assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1] + # TODO: Change assert if we support qkl f8 and v f16 + assert q.dtype == k.dtype and q.dtype == v.dtype + assert head_size <= 256 + assert o.shape == q.shape + assert (nheads_q % nheads_k) == 0 + assert self.layout is not None + assert self.layout == 'thd' or not self.varlen + +def input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout, device="cuda", DEBUG_INPUT=False): + torch.manual_seed(20) + + # Initialize q, k, v + if layout == 'bhsd': + q_tensor_shape = (Z, HQ, N_CTX_Q, D_HEAD) + k_tensor_shape = (Z, HK, N_CTX_K, D_HEAD) + elif layout == 'bshd': + q_tensor_shape = (Z, N_CTX_Q, HQ, D_HEAD) + k_tensor_shape = (Z, N_CTX_K, HK, D_HEAD) + else: + assert False, f'Got unsupported tensor layout: {layout}' + + q = None + k = None + v = None + + if DEBUG_INPUT: + if layout == "bhsd": + q = torch.arange(N_CTX_Q, dtype=dtype, device=device).view(1, 1, N_CTX_Q, 1).expand(*q_tensor_shape).contiguous().requires_grad_() + k = torch.arange(N_CTX_K, dtype=dtype, device=device).view(1, 1, N_CTX_K, 1).expand(*k_tensor_shape).contiguous().requires_grad_() + v = torch.arange(N_CTX_K, dtype=dtype, device=device).view(1, 1, N_CTX_K, 1).expand(*k_tensor_shape).contiguous().requires_grad_() + elif layout == "bshd": + q = torch.arange(N_CTX_Q, dtype=dtype, device=device).view(1, N_CTX_Q, 1, 1).expand(*q_tensor_shape).contiguous().requires_grad_() + k = torch.arange(N_CTX_K, dtype=dtype, device=device).view(1, N_CTX_K, 1, 1).expand(*k_tensor_shape).contiguous().requires_grad_() + v = torch.arange(N_CTX_K, dtype=dtype, device=device).view(1, N_CTX_K, 1, 1).expand(*k_tensor_shape).contiguous().requires_grad_() + else: + q = torch.randn(q_tensor_shape, dtype=dtype, device=device, requires_grad=True) + k = torch.randn(k_tensor_shape, dtype=dtype, device=device, requires_grad=True) + v = torch.randn(k_tensor_shape, dtype=dtype, device=device, requires_grad=True) + + if DEBUG_INPUT: + sm_scale = 1 + else: + sm_scale = D_HEAD**-0.5 + input_metadata = MetaData(sm_scale=sm_scale) + input_metadata.max_seqlens_q = N_CTX_Q + input_metadata.max_seqlens_k = N_CTX_K + input_metadata.layout = layout + return q, k, v, input_metadata + + +def varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, device="cuda", equal_seqlens=False, DEBUG_INPUT=False): + torch.manual_seed(20) + + # Random or equal sequence lengths based on 'equal_seqlens' flag + if not equal_seqlens: + max_seqlens_q = N_CTX_Q // Z + max_seqlens_k = N_CTX_K // Z + seqlens_q = torch.randint(1, max_seqlens_q + 1, (Z,), dtype=torch.int32) + seqlens_k = torch.randint(1, max_seqlens_k + 1, (Z,), dtype=torch.int32) + else: + seqlens_q = torch.full((Z,), N_CTX_Q // Z, dtype=torch.int32) + seqlens_k = torch.full((Z,), N_CTX_K // Z, dtype=torch.int32) + + # Calculate cumulative sequence lengths + cu_seqlens_q = torch.cat([torch.tensor([0], dtype=torch.int32), seqlens_q.cumsum(dim=0)]) + cu_seqlens_k = torch.cat([torch.tensor([0], dtype=torch.int32), seqlens_k.cumsum(dim=0)]) + cu_seqlens_q = cu_seqlens_q.to(device=device).to(torch.int32) + cu_seqlens_k = cu_seqlens_k.to(device=device).to(torch.int32) + + # Total lengths + total_q = cu_seqlens_q[-1].item() + total_k = cu_seqlens_k[-1].item() + + if DEBUG_INPUT: + # Initialize q, k, v with deterministic values + q = torch.arange(total_q, dtype=dtype, device=device).view(total_q, 1, 1) + q = q.expand(total_q, HQ, D_HEAD).contiguous().requires_grad_() + k = torch.arange(total_k, dtype=dtype, device=device).view(total_k, 1, 1) + k = k.expand(total_k, HK, D_HEAD).contiguous().requires_grad_() + v = torch.arange(total_k, dtype=dtype, device=device).view(total_k, 1, 1) + v = v.expand(total_k, HK, D_HEAD).contiguous().requires_grad_() + sm_scale = 1 + else: + # Initialize q, k, v with random values + q = torch.randn((total_q, HQ, D_HEAD), dtype=dtype, device=device).requires_grad_() + k = torch.randn((total_k, HK, D_HEAD), dtype=dtype, device=device).requires_grad_() + v = torch.randn((total_k, HK, D_HEAD), dtype=dtype, device=device).requires_grad_() + sm_scale = D_HEAD ** -0.5 + + input_metadata = MetaData(sm_scale=sm_scale) + input_metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k) + return q, k, v, input_metadata + + +def get_shape_from_layout(q, k, layout, cu_seqlens_q = None, cu_seqlens_k = None, max_seqlen_q=None, max_seqlen_k=None): + if layout == 'bhsd': + batch_q, nheads_q, max_seqlen_q, head_size_q = q.shape + batch_k, nheads_k, max_seqlen_k, head_size_k = k.shape + elif layout == 'bshd': + batch_q, max_seqlen_q, nheads_q, head_size_q = q.shape + batch_k, max_seqlen_k, nheads_k, head_size_k = k.shape + elif layout == 'thd': + batch_q, max_seqlen_q, nheads_q, head_size_q = len(cu_seqlens_q) - 1, max_seqlen_q, q.shape[1], q.shape[2] # pylint: disable=self-assigning-variable + batch_k, max_seqlen_k, nheads_k, head_size_k = len(cu_seqlens_k) - 1, max_seqlen_k, k.shape[1], k.shape[2] # pylint: disable=self-assigning-variable + else: + assert False, "Got unsupported layout." + + # assert + assert batch_q == batch_k + assert head_size_q == head_size_k + + return batch_q, nheads_q, nheads_k, head_size_q, max_seqlen_q, max_seqlen_k + + +def get_strides_from_layout(q, k, v, o, layout): + if layout == 'thd': + q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) + k_strides = (0, k.stride(1), k.stride(0), k.stride(2)) + v_strides = (0, v.stride(1), v.stride(0), v.stride(2)) + o_strides = (0, o.stride(1), o.stride(0), o.stride(2)) + elif layout == 'bhsd': + q_strides = (q.stride(0), q.stride(1), q.stride(2), q.stride(3)) + k_strides = (k.stride(0), k.stride(1), k.stride(2), k.stride(3)) + v_strides = (v.stride(0), v.stride(1), v.stride(2), v.stride(3)) + o_strides = (o.stride(0), o.stride(1), o.stride(2), o.stride(3)) + elif layout == 'bshd': + q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3)) + k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3)) + v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3)) + o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3)) + else: + assert False, 'Got unsupported layout.' + return q_strides, k_strides, v_strides, o_strides + + +def get_padded_headsize(size): + # Get closest power of 2 over or equal to 32. + padded_d_model = 1 << (size - 1).bit_length() + # Smallest head_dim supported is 16. If smaller, the tile in the + # kernel is padded - there is no padding in memory for any dims. + padded_d_model = max(padded_d_model, 16) + return padded_d_model + + +def _strides(x: torch.Tensor, *stride_names: str): + if x is None: + return {f"stride_{s}": 0 for i, s in enumerate(stride_names)} + + assert x.ndim == len(stride_names) + return {f"stride_{s}": x.stride(i) for i, s in enumerate(stride_names)} + + +def get_input_shapes(): + cases = [(max(1, 2**(16 - i)), 1, 2**i, 16, 1, 128) + for i in range(8, 18)] + [(max(1, 2**(16 - i)), 1, 2**i, 16, 2, 128) for i in range(8, 18)] + return cases + + +def is_hip(): + return triton.runtime.driver.active.get_current_target().backend == "hip" + + +def is_cdna(): + return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942', + 'gfx90a', 'gfx908') + + +def is_rdna(): + return is_hip() and triton.runtime.driver.active.get_current_target().arch in ("gfx1030", "gfx1100", "gfx1101", + "gfx1102", "gfx1200", "gfx1201")