An example of implementing a dequantization GEMM:
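    # Shapes, dtypes, tile sizes (block_M, block_N, block_K), threads, num_stages,
    # num_bits and num_elems_per_byte are assumed to be defined in the enclosing scope.
    @T.prim_func
    def dequant_matmul(
        A: T.Buffer(A_shape, in_dtype),
        B: T.Buffer(B_shape, storage_dtype),
        Ct: T.Buffer((N, M), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
            B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
            B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
            Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype)

            # Zero the accumulator before the reduction over K.
            T.clear(Ct_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                # Stage the A tile and the packed B tile in shared memory.
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
                T.copy(B_shared, B_local)
                # Unpack two low-bit values per byte into in_dtype.
                for i, j in T.Parallel(block_N, block_K):
                    B_dequantize_local[i, j] = _tir_packed_to_unsigned_convert("int", 8)(
                        num_bits,
                        B_local[i, j // 2],
                        j % 2,
                        dtype=in_dtype,
                    )
                # Accumulate Ct_local += B_dequantize_local @ A_shared^T.
                T.gemm(B_dequantize_local, A_shared, Ct_local, transpose_B=True)
            # Write the finished (block_N, block_M) tile back to global memory.
            T.copy(Ct_local, Ct[bx * block_N, by * block_M])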
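To make the computation easier to check, the following is a minimal NumPy sketch of what the kernel above computes. It is not part of the library: the helper names `unpack_uint4` and `dequant_matmul_ref` are hypothetical, and it assumes unsigned 4-bit weights packed two per byte with the low nibble at position 0 and no scale or zero point.

```python
import numpy as np


def unpack_uint4(packed: np.ndarray) -> np.ndarray:
    """Unpack an (N, K // 2) uint8 array into an (N, K) array of 4-bit values.

    Assumption: element j lives in byte j // 2 at bit offset 4 * (j % 2),
    i.e. the low nibble comes first.
    """
    n, half_k = packed.shape
    out = np.empty((n, half_k * 2), dtype=np.uint8)
    out[:, 0::2] = packed & 0x0F         # j % 2 == 0: low nibble
    out[:, 1::2] = (packed >> 4) & 0x0F  # j % 2 == 1: high nibble
    return out


def dequant_matmul_ref(A: np.ndarray, B_packed: np.ndarray) -> np.ndarray:
    """Reference for Ct = dequant(B) @ A^T, with A of shape (M, K); returns (N, M)."""
    B = unpack_uint4(B_packed).astype(np.float32)
    return B @ A.astype(np.float32).T
```

A reference like this can be compared against the kernel output for small shapes, for example with `np.testing.assert_allclose`.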
Notes: A dequantize GEMM that uses magic layout transformations to reach optimal performance can be found in the BitBLAS project. Example kernels can be found at testing/python/kernel/. A detailed explanation and more examples are coming soon.