# Provenance: this module was introduced by commit
# 44cac886c4505e7a421a28e62f7495e82aad613f ("Create quant_per_block.py",
# author patientx, 2025-05-15) as comfy/customzluda/sa/quant_per_block.py.
"""Per-block symmetric INT8 quantization of 4-D Q/K tensors using Triton."""

import torch
import triton
import triton.language as tl


@triton.jit
def quant_per_block_int8_kernel(Input, Output, Scale, L,
                                stride_iz, stride_ih, stride_in,
                                stride_oz, stride_oh, stride_on,
                                stride_sz, stride_sh,
                                sm_scale,
                                C: tl.constexpr, BLK: tl.constexpr):
    """Quantize one (BLK x C) tile of ``Input`` to INT8 with a single scale.

    Program ids 0/1/2 select the row-block, the head and the batch entry.
    The tile is multiplied by ``sm_scale``, its absolute maximum is mapped to
    127, and the quantized tile plus the per-tile scale are written out.

    Args:
        Input, Output: base pointers of the source and INT8 destination.
            The last (channel) axis is addressed with an implicit stride of
            1, so callers must pass tensors whose last stride is 1.
        Scale: base pointer of the per-tile scale buffer; the row-block axis
            is addressed with an implicit stride of 1.
        L: number of valid rows along the blocked axis.
        stride_iz, stride_ih, stride_in: Input strides (batch, head, row).
        stride_oz, stride_oh, stride_on: Output strides (batch, head, row).
        stride_sz, stride_sh: Scale strides (batch, head).
        sm_scale: multiplier applied to the values before quantization.
        C: number of channels per row (compile-time constant).
        BLK: number of rows per tile (compile-time constant).
    """
    off_blk = tl.program_id(0)
    off_h = tl.program_id(1)
    off_b = tl.program_id(2)

    offs_n = off_blk * BLK + tl.arange(0, BLK)
    offs_k = tl.arange(0, C)
    # Rows at or beyond L belong to the padding of the last, partial tile.
    row_mask = offs_n[:, None] < L

    input_ptrs = (Input + off_b * stride_iz + off_h * stride_ih
                  + offs_n[:, None] * stride_in + offs_k[None, :])
    output_ptrs = (Output + off_b * stride_oz + off_h * stride_oh
                   + offs_n[:, None] * stride_on + offs_k[None, :])
    scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk

    # other=0.0: without it the masked-out rows hold unspecified values that
    # would leak into the abs-max reduction below and corrupt the tile scale.
    x = tl.load(input_ptrs, mask=row_mask, other=0.0)
    x = x.to(tl.float32)
    x *= sm_scale

    # Floor the scale so an all-zero tile yields zeros instead of 0/0 = NaN.
    # Tiles with a non-degenerate abs-max are unaffected by the floor.
    scale = tl.maximum(tl.max(tl.abs(x)) / 127., 1e-20)

    x_int8 = x / scale
    # Add +/-0.5 before the integer cast so values are rounded to the nearest
    # integer rather than truncated (assumes the cast truncates toward zero).
    x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1)
    x_int8 = x_int8.to(tl.int8)

    tl.store(output_ptrs, x_int8, mask=row_mask)
    tl.store(scale_ptrs, scale)


def _quantize_blocks(x, blk, sm_scale, head_axis, seq_axis):
    """Quantize one 4-D tensor to INT8 with one scale per ``blk`` rows."""
    # The kernel walks the channel axis with a hard-coded stride of 1.
    if x.stride(-1) != 1:
        x = x.contiguous()

    batch = x.shape[0]
    heads = x.shape[head_axis]
    seq_len = x.shape[seq_axis]
    head_dim = x.shape[3]
    n_blocks = (seq_len + blk - 1) // blk

    x_int8 = torch.empty(x.shape, dtype=torch.int8, device=x.device)
    x_scale = torch.empty((batch, heads, n_blocks, 1),
                          device=x.device, dtype=torch.float32)

    # Nothing to quantize: avoid launching a kernel over an empty grid.
    if x_int8.numel() == 0:
        return x_int8, x_scale

    grid = (n_blocks, heads, batch)
    quant_per_block_int8_kernel[grid](
        x, x_int8, x_scale, seq_len,
        x.stride(0), x.stride(head_axis), x.stride(seq_axis),
        x_int8.stride(0), x_int8.stride(head_axis), x_int8.stride(seq_axis),
        x_scale.stride(0), x_scale.stride(1),
        sm_scale=sm_scale,
        C=head_dim, BLK=blk,
    )
    return x_int8, x_scale


def per_block_int8(q, k, BLKQ=32, BLKK=16, sm_scale=None, tensor_layout="HND"):
    """Quantize ``q`` and ``k`` to INT8 with one scale per block of rows.

    Args:
        q: 4-D tensor, ``(batch, heads, seq, head_dim)`` for ``"HND"`` or
            ``(batch, seq, heads, head_dim)`` for ``"NHD"``.
        k: 4-D tensor in the same layout as ``q``.
        BLKQ: rows per quantization block for ``q``.
        BLKK: rows per quantization block for ``k``.
        sm_scale: scale folded into ``q`` before quantization; defaults to
            ``head_dim ** -0.5`` (``head_dim`` taken from ``q``).
        tensor_layout: ``"HND"`` or ``"NHD"``.

    Returns:
        ``(q_int8, q_scale, k_int8, k_scale)``. The INT8 tensors have the
        shapes of ``q`` and ``k``; the float32 scale tensors have shape
        ``(batch, heads, ceil(seq / BLK), 1)``.

    Raises:
        ValueError: if ``tensor_layout`` is neither ``"HND"`` nor ``"NHD"``.
    """
    if tensor_layout == "HND":
        head_axis, seq_axis = 1, 2
    elif tensor_layout == "NHD":
        head_axis, seq_axis = 2, 1
    else:
        raise ValueError(f"Unknown tensor layout: {tensor_layout}")

    if sm_scale is None:
        sm_scale = q.shape[3]**-0.5

    # 1.44269504 ~= log2(e) is folded into q together with sm_scale.
    # NOTE(review): presumably the consumer of q_int8/q_scale evaluates the
    # softmax with exp2 instead of exp - confirm against the attention kernel.
    q_int8, q_scale = _quantize_blocks(
        q, BLKQ, sm_scale * 1.44269504, head_axis, seq_axis)
    # k is quantized as-is (no extra scaling).
    k_int8, k_scale = _quantize_blocks(k, BLKK, 1.0, head_axis, seq_axis)

    return q_int8, q_scale, k_int8, k_scale