mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-12 15:20:51 +08:00
83 lines
3.4 KiB
Python
83 lines
3.4 KiB
Python
import torch
|
|
import triton
|
|
import triton.language as tl
|
|
|
|
@triton.jit
def quant_per_block_int8_kernel(Input, Output, Scale, L,
                                stride_iz, stride_ih, stride_in,
                                stride_oz, stride_oh, stride_on,
                                stride_sz, stride_sh,
                                sm_scale,
                                C: tl.constexpr, BLK: tl.constexpr):
    """Quantize one (BLK x C) tile of ``Input`` to INT8 with a single scale.

    Each program instance handles the tile selected by
    ``(program_id(0), program_id(1), program_id(2))`` = (block, head, batch).
    The tile is multiplied by ``sm_scale``, its scale is computed as
    ``max(|x|) / 127`` and the values are divided by that scale, rounded and
    written to ``Output``; the scale itself is written to ``Scale``.

    Args:
        Input, Output, Scale: base pointers of the source tensor, the INT8
            destination and the per-block scale buffer.
        L: number of valid rows along the sequence axis; rows at or beyond
            ``L`` are neither read nor written.
        stride_iz, stride_ih, stride_in: batch / head / row strides of Input.
        stride_oz, stride_oh, stride_on: batch / head / row strides of Output.
        stride_sz, stride_sh: batch / head strides of Scale.
        sm_scale: factor applied to the values before quantization.
        C: number of channels per row (compile-time constant).
        BLK: number of rows per block (compile-time constant).

    Channels are addressed with an implicit stride of 1 in both Input and
    Output, and the block index is added to the Scale pointer without a
    stride, so callers must pass buffers laid out accordingly.
    """
    off_blk = tl.program_id(0)
    off_h = tl.program_id(1)
    off_b = tl.program_id(2)

    offs_n = off_blk * BLK + tl.arange(0, BLK)
    offs_k = tl.arange(0, C)

    # Rows past the end of the sequence (last, partial block) are masked out.
    row_mask = offs_n[:, None] < L

    input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :]
    output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :]
    scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk

    # `other=0.0` gives masked-out rows a defined value; without it their
    # contents are unspecified and could leak into the max-reduction below.
    x = tl.load(input_ptrs, mask=row_mask, other=0.0)
    x = x.to(tl.float32)
    x *= sm_scale

    scale = tl.max(tl.abs(x)) / 127.
    # An all-zero block yields scale == 0; dividing by it would produce NaN.
    # Divide by 1 in that case (every quantized value is 0 anyway) while still
    # storing the true scale so the stored value is unchanged.
    divisor = tl.where(scale > 0, scale, 1.0)
    x_int8 = x / divisor
    # Adding +/-0.5 before the integer cast rounds half away from zero,
    # assuming the float -> int8 cast truncates toward zero.
    x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1)
    x_int8 = x_int8.to(tl.int8)

    tl.store(output_ptrs, x_int8, mask=row_mask)
    tl.store(scale_ptrs, scale)
|
|
|
|
def _quantize_per_block(x, batch, num_heads, seq_len, head_dim,
                        head_axis, seq_axis, block, sm_scale):
    """Run the per-block INT8 quantization kernel over one 4-D tensor.

    Returns ``(x_int8, x_scale)`` where ``x_int8`` has the shape of ``x`` and
    ``x_scale`` has shape ``(batch, num_heads, ceil(seq_len / block), 1)``.
    """
    # The kernel walks the channel axis with an implicit stride of 1, so a
    # tensor whose last dimension is strided must be compacted first.
    if x.stride(3) != 1:
        x = x.contiguous()

    num_blocks = (seq_len + block - 1) // block
    x_int8 = torch.empty(x.shape, dtype=torch.int8, device=x.device)
    x_scale = torch.empty((batch, num_heads, num_blocks, 1), device=x.device, dtype=torch.float32)

    grid = (num_blocks, num_heads, batch)
    quant_per_block_int8_kernel[grid](
        x, x_int8, x_scale, seq_len,
        x.stride(0), x.stride(head_axis), x.stride(seq_axis),
        x_int8.stride(0), x_int8.stride(head_axis), x_int8.stride(seq_axis),
        x_scale.stride(0), x_scale.stride(1),
        sm_scale=sm_scale,
        C=head_dim, BLK=block
    )
    return x_int8, x_scale


def per_block_int8(q, k, BLKQ=32, BLKK=16, sm_scale=None, tensor_layout="HND"):
    """Quantize ``q`` and ``k`` to INT8 with one scale per block of rows.

    Args:
        q: 4-D tensor, ``(batch, heads, seq, head_dim)`` for ``"HND"`` or
            ``(batch, seq, heads, head_dim)`` for ``"NHD"``.
        k: 4-D tensor in the same layout, with the same batch size and
            head_dim as ``q`` (head count and sequence length may differ).
        BLKQ: rows per quantization block for ``q``.
        BLKK: rows per quantization block for ``k``.
        sm_scale: scale folded into ``q`` before quantization; defaults to
            ``head_dim ** -0.5``.
        tensor_layout: ``"HND"`` or ``"NHD"``.

    Returns:
        ``(q_int8, q_scale, k_int8, k_scale)``. The INT8 tensors have the
        shapes of their inputs; the scale tensors are float32 with shape
        ``(batch, heads, ceil(seq / BLK), 1)``.

    Raises:
        ValueError: if ``tensor_layout`` is unknown, an input is not 4-D, or
            ``q`` and ``k`` disagree on batch size or head_dim.
    """
    if tensor_layout == "HND":
        head_axis, seq_axis = 1, 2
    elif tensor_layout == "NHD":
        head_axis, seq_axis = 2, 1
    else:
        raise ValueError(f"Unknown tensor layout: {tensor_layout}")

    if q.dim() != 4 or k.dim() != 4:
        raise ValueError(
            f"q and k must be 4-D tensors, got {q.dim()}-D and {k.dim()}-D"
        )

    b, head_dim = q.shape[0], q.shape[3]
    h_qo, qo_len = q.shape[head_axis], q.shape[seq_axis]
    h_kv, kv_len = k.shape[head_axis], k.shape[seq_axis]

    # k is launched with q's batch size and head_dim; a mismatch would make
    # the kernel read out of bounds or leave part of k_int8 unwritten.
    if k.shape[0] != b or k.shape[3] != head_dim:
        raise ValueError(
            "q and k must share batch size and head_dim, got "
            f"q.shape={tuple(q.shape)} and k.shape={tuple(k.shape)}"
        )

    if sm_scale is None:
        sm_scale = head_dim**-0.5

    # 1.44269504 ~= log2(e). Presumably folded into q so that the consuming
    # attention kernel can use exp2 instead of exp -- confirm against caller.
    q_int8, q_scale = _quantize_per_block(
        q, b, h_qo, qo_len, head_dim, head_axis, seq_axis,
        BLKQ, sm_scale * 1.44269504,
    )
    # k is quantized as-is (no extra scaling).
    k_int8, k_scale = _quantize_per_block(
        k, b, h_kv, kv_len, head_dim, head_axis, seq_axis,
        BLKK, 1.0,
    )

    return q_int8, q_scale, k_int8, k_scale
|