diff --git a/comfy/customzluda/sa/attn_qk_int8_per_block.py b/comfy/customzluda/sa/attn_qk_int8_per_block.py
new file mode 100644
index 000000000..0ad12ff20
--- /dev/null
+++ b/comfy/customzluda/sa/attn_qk_int8_per_block.py
@@ -0,0 +1,164 @@
+import torch, math
+import triton
+import triton.language as tl
+
+# Autotune Here
+configs = [
+    triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN, 'STAGE': S, 'waves_per_eu': wpe}, num_warps=nw, num_stages=ns) \
+    for BM in [32]\
+    for BN in [16]\
+    for nw in [2, 4]\
+    for ns in [1]\
+    for S in [1]\
+    for wpe in [3, 4]
+]
+
+def keep(conf):
+    BLOCK_M = conf.kwargs["BLOCK_M"]
+    BLOCK_N = conf.kwargs["BLOCK_N"]
+    BLOCK_AREA = BLOCK_M * BLOCK_N
+
+    # do not keep too high a block area; anything higher doesn't seem to help for navi21
+    if (BLOCK_AREA > 1024):
+        return False
+
+    # do not keep 'mirror image' configs (i.e. keep [64, 32] and discard [32, 64])
+    if (BLOCK_M < BLOCK_N):
+        return False
+
+    # do not keep skinny sizes for now
+    if (BLOCK_M // BLOCK_N >= 8):
+        return False
+
+    # do not keep configs where num_warps is too high or too low
+    if (BLOCK_AREA >= 1024 and conf.num_warps != 2):
+        return False
+    if (BLOCK_AREA >= 2048 and conf.num_warps != 4):
+        return False
+
+    return True
+
+
+@triton.jit
+def _attn_fwd_inner(acc, l_i, m_i, q, q_scale, kv_len,
+                    K_ptrs, K_scale_ptr, V_ptrs, stride_kn, stride_vn,
+                    start_m,
+                    BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr,
+                    STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr,
+                    ):
+    lo, hi = 0, kv_len
+    for start_n in range(lo, hi, BLOCK_N):
+        start_n = tl.multiple_of(start_n, BLOCK_N)
+        k_mask = offs_n[None, :] < (kv_len - start_n)
+        k = tl.load(K_ptrs, mask=k_mask)
+        k_scale = tl.load(K_scale_ptr)
+        qk = tl.dot(q, k).to(tl.float32) * q_scale * k_scale
+        m_ij = tl.maximum(m_i, tl.max(qk, 1))
+        qk = qk - m_ij[:, None]
+        p = tl.math.exp2(qk)
+        l_ij = tl.sum(p, 1)
+
+        alpha = tl.math.exp2(m_i - m_ij)
+        l_i = l_i * alpha + l_ij
+
+        acc = acc * alpha[:, None]
+
+        v = tl.load(V_ptrs, mask=offs_n[:, None] < (kv_len - start_n))
+        p = p.to(tl.float16)
+
+        acc += tl.dot(p, v, out_dtype=tl.float32)
+        m_i = m_ij
+        K_ptrs += BLOCK_N * stride_kn
+        K_scale_ptr += 1
+        V_ptrs += BLOCK_N * stride_vn
+    return acc, l_i
+
+@triton.autotune(
+    list(filter(keep, configs)),
+    key=['qo_len', 'kv_len', 'H']
+)
+@triton.jit
+def _attn_fwd(Q, K, V, Q_scale, K_scale, Out,
+              stride_qz, stride_qh, stride_qn,
+              stride_kz, stride_kh, stride_kn,
+              stride_vz, stride_vh, stride_vn,
+              stride_oz, stride_oh, stride_on,
+              qo_len, kv_len, H: tl.constexpr, num_kv_groups: tl.constexpr,
+              HEAD_DIM: tl.constexpr,
+              BLOCK_M: tl.constexpr,
+              BLOCK_N: tl.constexpr,
+              STAGE: tl.constexpr
+              ):
+    start_m = tl.program_id(0)
+
+    off_z = tl.program_id(2).to(tl.int64)
+    off_h = tl.program_id(1).to(tl.int64)
+
+    q_scale_offset = (off_z * H + off_h) * tl.cdiv(qo_len, BLOCK_M)
+    k_scale_offset = (off_z * (H // num_kv_groups) + off_h // num_kv_groups) * tl.cdiv(kv_len, BLOCK_N)
+
+    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    offs_n = tl.arange(0, BLOCK_N)
+    offs_k = tl.arange(0, HEAD_DIM)
+    Q_ptrs = Q + (off_z * stride_qz + off_h * stride_qh) + offs_m[:, None] * stride_qn + offs_k[None, :]
+    Q_scale_ptr = Q_scale + q_scale_offset + start_m
+    K_ptrs = K + (off_z * stride_kz + (off_h // num_kv_groups) * stride_kh) + offs_n[None, :] * stride_kn + offs_k[:, None]
+    K_scale_ptr = K_scale + k_scale_offset
+    V_ptrs = V + (off_z * stride_vz + (off_h // num_kv_groups) * stride_vh) + offs_n[:, None] * stride_vn + offs_k[None, :]
+    O_block_ptr = Out + (off_z * stride_oz + off_h * stride_oh) + offs_m[:, None] * stride_on + offs_k[None, :]
+
+    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
+    l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
+    acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
+
+    q = tl.load(Q_ptrs, mask=offs_m[:, None] < qo_len)
+    q_scale = tl.load(Q_scale_ptr)
+    acc, l_i = _attn_fwd_inner(acc, l_i, m_i, q, q_scale, kv_len, K_ptrs, K_scale_ptr, V_ptrs, stride_kn, stride_vn,
+                               start_m,
+                               BLOCK_M, HEAD_DIM, BLOCK_N,
+                               4 - STAGE, offs_m, offs_n
+                               )
+    acc = acc / l_i[:, None]
+    tl.store(O_block_ptr, acc.to(Out.type.element_ty), mask=(offs_m[:, None] < qo_len))
+
+def forward(q, k, v, q_scale, k_scale, tensor_layout="HND", output_dtype=torch.float16):
+    BLOCK_M = 128
+    BLOCK_N = 64
+    stage = 1
+
+    o = torch.empty(q.shape, dtype=output_dtype, device=q.device)
+
+    if tensor_layout == "HND":
+        b, h_qo, qo_len, head_dim = q.shape
+        _, h_kv, kv_len, _ = k.shape
+
+        stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(1), q.stride(2)
+        stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(1), k.stride(2)
+        stride_bz_v, stride_h_v, stride_seq_v = v.stride(0), v.stride(1), v.stride(2)
+        stride_bz_o, stride_h_o, stride_seq_o = o.stride(0), o.stride(1), o.stride(2)
+    elif tensor_layout == "NHD":
+        b, qo_len, h_qo, head_dim = q.shape
+        _, kv_len, h_kv, _ = k.shape
+
+        stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(2), q.stride(1)
+        stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(2), k.stride(1)
+        stride_bz_v, stride_h_v, stride_seq_v = v.stride(0), v.stride(2), v.stride(1)
+        stride_bz_o, stride_h_o, stride_seq_o = o.stride(0), o.stride(2), o.stride(1)
+    else:
+        raise ValueError(f"tensor_layout {tensor_layout} not supported")
+
+    HEAD_DIM_K = head_dim
+    num_kv_groups = h_qo // h_kv
+
+    grid = lambda META: (triton.cdiv(qo_len, META['BLOCK_M']), h_qo, b)
+    _attn_fwd[grid](
+        q, k, v, q_scale, k_scale, o,
+        stride_bz_q, stride_h_q, stride_seq_q,
+        stride_bz_k, stride_h_k, stride_seq_k,
+        stride_bz_v, stride_h_v, stride_seq_v,
+        stride_bz_o, stride_h_o, stride_seq_o,
+        qo_len, kv_len,
+        h_qo, num_kv_groups,
+        HEAD_DIM=HEAD_DIM_K,
+    )
+    return o
\ No newline at end of file
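(Illustrative, not part of the patch.) Both kernels in this patch expect q and k already quantized to int8, with one float scale per BLOCK_M-sized query block and one per BLOCK_N-sized key block, laid out contiguously as (batch, heads, num_blocks); that is the layout the q_scale_offset / k_scale_offset arithmetic indexes into. Because the kernels apply exp2() to dot(q_int8, k_int8) * q_scale * k_scale with no further scaling, the 1/sqrt(head_dim) softmax scale and the log2(e) conversion factor have to be folded into the scales by the caller. Below is a minimal quantizer sketch under those assumptions, using the BLOCK_M=32 / BLOCK_N=16 block sizes launched here; the helper names are hypothetical, not code from this repository.

# Illustrative sketch only (not part of the patch); assumes the "HND" layout.
import math
import torch
import torch.nn.functional as F

def quantize_per_block(x, block_size, fold_in=1.0):
    # Returns an int8 tensor of the same shape plus one float32 scale per block
    # of `block_size` rows, shaped (batch, heads, n_blocks), matching the flat
    # indexing used by q_scale_offset / k_scale_offset in the kernels.
    b, h, seq, d = x.shape
    n_blocks = (seq + block_size - 1) // block_size
    pad = n_blocks * block_size - seq
    if pad:
        x = F.pad(x, (0, 0, 0, pad))  # pad the sequence dim up to a block multiple
    xb = x.reshape(b, h, n_blocks, block_size, d).float()
    absmax = xb.abs().amax(dim=(-1, -2)).clamp(min=1e-8)   # (b, h, n_blocks)
    scale = absmax / 127.0
    x_i8 = (xb / scale[..., None, None]).round_().clamp_(-127, 127).to(torch.int8)
    x_i8 = x_i8.reshape(b, h, n_blocks * block_size, d)[:, :, :seq].contiguous()
    return x_i8, (scale * fold_in).contiguous()

def quantize_qk(q, k, BLOCK_M=32, BLOCK_N=16):
    # Block sizes match the configs above (BM=32, BN=16); the softmax scale
    # and log2(e) are folded into q_scale only.
    head_dim = q.shape[-1]
    q_i8, q_scale = quantize_per_block(q, BLOCK_M, fold_in=1.4426950408889634 / math.sqrt(head_dim))
    k_i8, k_scale = quantize_per_block(k, BLOCK_N)
    return q_i8, q_scale, k_i8, k_scale
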
diff --git a/comfy/customzluda/sa/attn_qk_int8_per_block_causal.py b/comfy/customzluda/sa/attn_qk_int8_per_block_causal.py
new file mode 100644
index 000000000..4d3be8cb0
--- /dev/null
+++ b/comfy/customzluda/sa/attn_qk_int8_per_block_causal.py
@@ -0,0 +1,148 @@
+import torch, math
+import triton
+import triton.language as tl
+
+@triton.jit
+def _attn_fwd_inner(acc, l_i, m_i, q, q_scale, kv_len,
+                    K_ptrs, K_scale_ptr, V_ptrs, stride_kn, stride_vn,
+                    start_m,
+                    BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr,
+                    STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr,
+                    ):
+    if STAGE == 1:
+        lo, hi = 0, start_m * BLOCK_M
+    elif STAGE == 2:
+        lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
+        lo = tl.multiple_of(lo, BLOCK_M)
+    K_scale_ptr += lo // BLOCK_N
+    K_ptrs += stride_kn * lo
+    V_ptrs += stride_vn * lo
+    for start_n in range(lo, hi, BLOCK_N):
+        start_n = tl.multiple_of(start_n, BLOCK_N)
+        k_mask = offs_n[None, :] < (kv_len - start_n)
+        k = tl.load(K_ptrs, mask=k_mask)
+        k_scale = tl.load(K_scale_ptr)
+        qk = tl.dot(q, k).to(tl.float32) * q_scale * k_scale
+
+        if STAGE == 2:
+            mask = offs_m[:, None] >= (start_n + offs_n[None, :])
+            qk = qk + tl.where(mask, 0, -1.0e6)
+            m_ij = tl.maximum(m_i, tl.max(qk, 1))
+            qk -= m_ij[:, None]
+        else:
+            m_ij = tl.maximum(m_i, tl.max(qk, 1))
+            qk = qk - m_ij[:, None]
+
+        p = tl.math.exp2(qk)
+        l_ij = tl.sum(p, 1)
+
+        alpha = tl.math.exp2(m_i - m_ij)
+        l_i = l_i * alpha + l_ij
+
+        acc = acc * alpha[:, None]
+
+        v = tl.load(V_ptrs, mask=offs_n[:, None] < (kv_len - start_n))
+        p = p.to(tl.float16)
+
+        acc += tl.dot(p, v, out_dtype=tl.float32)  # zlp
+        m_i = m_ij
+        K_ptrs += BLOCK_N * stride_kn
+        K_scale_ptr += 1
+        V_ptrs += BLOCK_N * stride_vn
+    return acc, l_i, m_i
+
+@triton.jit
+def _attn_fwd(Q, K, V, Q_scale, K_scale, Out,
+              stride_qz, stride_qh, stride_qn,
+              stride_kz, stride_kh, stride_kn,
+              stride_vz, stride_vh, stride_vn,
+              stride_oz, stride_oh, stride_on,
+              qo_len, kv_len, H: tl.constexpr, num_kv_groups: tl.constexpr,
+              HEAD_DIM: tl.constexpr,
+              BLOCK_M: tl.constexpr,
+              BLOCK_N: tl.constexpr,
+              STAGE: tl.constexpr
+              ):
+    start_m = tl.program_id(0)
+
+    off_z = tl.program_id(2).to(tl.int64)
+    off_h = tl.program_id(1).to(tl.int64)
+
+    q_scale_offset = (off_z * H + off_h) * tl.cdiv(qo_len, BLOCK_M)
+    k_scale_offset = (off_z * (H // num_kv_groups) + off_h // num_kv_groups) * tl.cdiv(kv_len, BLOCK_N)
+
+    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    offs_n = tl.arange(0, BLOCK_N)
+    offs_k = tl.arange(0, HEAD_DIM)
+    Q_ptrs = Q + (off_z * stride_qz + off_h * stride_qh) + offs_m[:, None] * stride_qn + offs_k[None, :]
+    Q_scale_ptr = Q_scale + q_scale_offset + start_m
+    K_ptrs = K + (off_z * stride_kz + (off_h // num_kv_groups) * stride_kh) + offs_n[None, :] * stride_kn + offs_k[:, None]
+    K_scale_ptr = K_scale + k_scale_offset
+    V_ptrs = V + (off_z * stride_vz + (off_h // num_kv_groups) * stride_vh) + offs_n[:, None] * stride_vn + offs_k[None, :]
+    O_block_ptr = Out + (off_z * stride_oz + off_h * stride_oh) + offs_m[:, None] * stride_on + offs_k[None, :]
+
+    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
+    l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
+    acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
+
+    q = tl.load(Q_ptrs, mask=offs_m[:, None] < qo_len)
+    q_scale = tl.load(Q_scale_ptr)
+    acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, q_scale, kv_len, K_ptrs, K_scale_ptr, V_ptrs, stride_kn, stride_vn,
+                                    start_m,
+                                    BLOCK_M, HEAD_DIM, BLOCK_N,
+                                    4 - STAGE, offs_m, offs_n
+                                    )
+
+    acc, l_i, _ = _attn_fwd_inner(acc, l_i, m_i, q, q_scale, kv_len, K_ptrs, K_scale_ptr, V_ptrs, stride_kn, stride_vn,
+                                  start_m,
+                                  BLOCK_M, HEAD_DIM, BLOCK_N,
+                                  2, offs_m, offs_n
+                                  )
+    acc = acc / l_i[:, None]
+    tl.store(O_block_ptr, acc.to(Out.type.element_ty), mask=(offs_m[:, None] < qo_len))
+
+def forward(q, k, v, q_scale, k_scale, tensor_layout="HND", output_dtype=torch.float16):
+    BLOCK_M = 32  # zlp
+    BLOCK_N = 16  # zlp
+    stage = 3
+
+    o = torch.empty(q.shape, dtype=output_dtype, device=q.device)
+
+    if tensor_layout == "HND":
+        b, h_qo, qo_len, head_dim = q.shape
+        _, h_kv, kv_len, _ = k.shape
+
+        stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(1), q.stride(2)
+        stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(1), k.stride(2)
+        stride_bz_v, stride_h_v, stride_seq_v = v.stride(0), v.stride(1), v.stride(2)
+        stride_bz_o, stride_h_o, stride_seq_o = o.stride(0), o.stride(1), o.stride(2)
+    elif tensor_layout == "NHD":
+        b, qo_len, h_qo, head_dim = q.shape
+        _, kv_len, h_kv, _ = k.shape
+
+        stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(2), q.stride(1)
+        stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(2), k.stride(1)
+        stride_bz_v, stride_h_v, stride_seq_v = v.stride(0), v.stride(2), v.stride(1)
+        stride_bz_o, stride_h_o, stride_seq_o = o.stride(0), o.stride(2), o.stride(1)
+    else:
+        raise ValueError(f"tensor_layout {tensor_layout} not supported")
+
+    assert qo_len == kv_len, "qo_len and kv_len must be equal for causal attention"
+
+    HEAD_DIM_K = head_dim
+    num_kv_groups = h_qo // h_kv
+
+    grid = (triton.cdiv(qo_len, BLOCK_M), h_qo, b)
+    _attn_fwd[grid](
+        q, k, v, q_scale, k_scale, o,
+        stride_bz_q, stride_h_q, stride_seq_q,
+        stride_bz_k, stride_h_k, stride_seq_k,
+        stride_bz_v, stride_h_v, stride_seq_v,
+        stride_bz_o, stride_h_o, stride_seq_o,
+        qo_len, kv_len,
+        h_qo, num_kv_groups,
+        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, HEAD_DIM=HEAD_DIM_K,
+        STAGE=stage,
+        num_warps=4 if head_dim == 64 else 8,
+        num_stages=4)
+    return o
\ No newline at end of file
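
(Illustrative, not part of the patch.) A rough smoke test of both entry points under the assumptions above; the import paths and the quantize_qk helper from the earlier sketch are assumptions rather than code shipped in this diff, and the printed errors only indicate int8 quantization noise, not exact equality with the reference.

# Illustrative smoke test only; import paths are assumed, not part of the patch.
import torch
import torch.nn.functional as F
from comfy.customzluda.sa import attn_qk_int8_per_block as sa_plain
from comfy.customzluda.sa import attn_qk_int8_per_block_causal as sa_causal

b, h, seq, d = 1, 8, 1024, 64
q, k, v = (torch.randn(b, h, seq, d, device="cuda", dtype=torch.float16) for _ in range(3))

# quantize_qk is the helper sketched between the two files above (hypothetical).
q_i8, q_scale, k_i8, k_scale = quantize_qk(q, k)

out = sa_plain.forward(q_i8, k_i8, v, q_scale, k_scale, tensor_layout="HND")
ref = F.scaled_dot_product_attention(q, k, v)
print((out - ref).abs().mean())      # expect a small, quantization-sized error

out_c = sa_causal.forward(q_i8, k_i8, v, q_scale, k_scale, tensor_layout="HND")
ref_c = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print((out_c - ref_c).abs().mean())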