From 0668455fde91f97c1d0cab5b049bbe87396575aa Mon Sep 17 00:00:00 2001
From: Fridge003
Date: Thu, 30 Jan 2025 12:50:18 -0800
Subject: [PATCH] Implement fused kernel for qkv's lora_b

---
 .../sglang/srt/lora/backend/triton_backend.py | 216 ++++++++++++++++--
 1 file changed, 195 insertions(+), 21 deletions(-)

diff --git a/python/sglang/srt/lora/backend/triton_backend.py b/python/sglang/srt/lora/backend/triton_backend.py
index e4321a7a2e4..d61cb296b93 100644
--- a/python/sglang/srt/lora/backend/triton_backend.py
+++ b/python/sglang/srt/lora/backend/triton_backend.py
@@ -92,6 +92,100 @@ def _sgemm_kernel(
     tl.store(output_ptr, partial_sum, mask=output_mask)
 
 
+@triton.jit
+def _qkv_lora_b_kernel(
+    # Pointers to matrices
+    x,
+    weights,
+    output,
+    # Parameters of size
+    K,  # K = R
+    N_Q,
+    N_KV,
+    # Strides
+    x_stride_0,
+    x_stride_1,
+    w_stride_0,
+    w_stride_1,
+    w_stride_2,
+    output_stride_0,
+    output_stride_1,
+    # Information on sequence lengths and weight id
+    seg_lens,
+    seg_indptr,
+    weight_indices,
+    # Offsets of q/k/v slice on output dimension
+    n_indptr,
+    # Meta parameters
+    BLOCK_S: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+):
+    # This kernel packs 3 sgemms (q/k/v) into a single kernel.
+
+    # x: (s, 3 * K), s is the sum of sequence lengths, K equals the LoRA rank
+    # weights: (num_lora, N_Q + 2 * N_KV, K)
+    # output: (s, N_Q + 2 * N_KV)
+
+    # The current block computes the sequence with index batch_id,
+    # which starts from row seg_start of x and has length seg_len.
+    # qkv_id decides which of q, k, v to compute (0: q, 1: k, 2: v)
+    batch_id = tl.program_id(axis=2)
+    qkv_id = tl.program_id(axis=1)
+    pid = tl.program_id(axis=0)
+    seg_len = tl.load(seg_lens + batch_id)
+    w_index = tl.load(weight_indices + batch_id)
+    seg_start = tl.load(seg_indptr + batch_id)
+    n_start = tl.load(n_indptr + qkv_id)
+    n_size = tl.load(n_indptr + qkv_id + 1) - n_start
+
+    # The tile in the output matrix has id (pid_s, pid_n)
+    # FIXME: might be replaced by super-grouping
+    num_pid_n = tl.cdiv(max(N_Q, N_KV), BLOCK_N)
+    pid_s = pid // num_pid_n
+    pid_n = pid % num_pid_n
+
+    # Create pointers for the first blocks of x and weights[w_index][n_start: n_end][:]
+    # The pointers will be advanced as we move in the K direction
+    # and accumulate
+    s_offset = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S
+    n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
+    k_offset = tl.arange(0, BLOCK_K)
+    x_ptrs = (x + seg_start * x_stride_0 + (qkv_id * K) * x_stride_1) + (
+        s_offset[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1
+    )
+    w_ptrs = (weights + w_index * w_stride_0 + n_start * w_stride_1) + (
+        k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
+    )
+
+    # Iterate to compute the block in the output matrix
+    partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
+    for k in range(0, tl.cdiv(K, BLOCK_K)):
+        x_tile = tl.load(
+            x_ptrs,
+            mask=(s_offset[:, None] < seg_len)
+            and (k_offset[None, :] < K - k * BLOCK_K),
+            other=0.0,
+        )
+        w_tile = tl.load(
+            w_ptrs,
+            mask=(k_offset[:, None] < K - k * BLOCK_K) and (n_offset[None, :] < n_size),
+            other=0.0,
+        )
+        partial_sum += tl.dot(x_tile, w_tile)
+
+        x_ptrs += BLOCK_K * x_stride_1
+        w_ptrs += BLOCK_K * w_stride_2
+
+    # Store result to output matrix
+    partial_sum = partial_sum.to(x.dtype.element_ty)
+    output_ptr = (output + seg_start * output_stride_0 + n_start * output_stride_1) + (
+        s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
+    )
+    output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < n_size)
+    tl.store(output_ptr, partial_sum, mask=output_mask)
+
+
 class TritonLoraBackend(BaseLoraBackend):
 
     def __init__(self, name: str, batch_info: LoraBatchInfo = None):
@@ -123,7 +217,7 @@ def run_sgemm(self, x: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
             self.batch_info.bs,
         )
 
-        output = torch.zeros((S, N), device=x.device, dtype=x.dtype)
+        output = torch.empty((S, N), device=x.device, dtype=x.dtype)
         _sgemm_kernel[grid](
             x,
             weights,
@@ -159,35 +253,115 @@ def run_qkv_lora(
         # q_lora_b: (1, num_lora, output_dim_q, r)
         # kv_lora_b: (2, num_lora, output_dim_kv, r)
 
-        # Shape of lora_a_output: (s, 3 * r)
-        lora_a_output = self.run_sgemm(x=x, weights=qkv_lora_a)
+        assert x.is_contiguous()
+        assert qkv_lora_a.is_contiguous()
+        assert q_lora_b.is_contiguous()
+        assert kv_lora_b.is_contiguous()
+        assert len(x.shape) == 2
+        assert len(qkv_lora_a.shape) == 3
+        assert len(q_lora_b.shape) == 4
+        assert len(kv_lora_b.shape) == 4
 
-        lora_rank = kv_lora_b.shape[-1]
+        # Get dims
+        s = x.shape[0]
+        input_dim = x.shape[1]
         output_dim_q = q_lora_b.shape[-2]
         output_dim_kv = kv_lora_b.shape[-2]
-        lora_output = torch.empty(
-            (x.shape[0], output_dim_q + 2 * output_dim_kv),
-            device=x.device,
-            dtype=x.dtype,
+        r = q_lora_b.shape[-1]
+        middle_dim = 3 * r
+        output_dim = output_dim_q + 2 * output_dim_kv
+        assert qkv_lora_a.shape[-2] == middle_dim
+        assert qkv_lora_a.shape[-1] == input_dim
+        assert kv_lora_b.shape[-1] == r
+
+        # Compute lora_a_output = sgemm(x, qkv_lora_a)
+        # shape of lora_a_output: (s, middle_dim)
+        BLOCK_S = 16
+        BLOCK_IN = 64
+        BLOCK_R = 16
+
+        grid_a = (
+            triton.cdiv(self.batch_info.max_len, BLOCK_S)
+            * triton.cdiv(middle_dim, BLOCK_R),
+            self.batch_info.bs,
         )
 
-        # FIXME parallelize qkv
-        # q
-        lora_output[:, :output_dim_q] = self.run_sgemm(
-            x=lora_a_output[:, :lora_rank].contiguous(), weights=q_lora_b[0]
+        lora_a_output = torch.empty((s, middle_dim), device=x.device, dtype=x.dtype)
+        _sgemm_kernel[grid_a](
+            x,
+            qkv_lora_a,
+            lora_a_output,
+            middle_dim,
+            input_dim,
+            x.stride(0),
+            x.stride(1),
+            qkv_lora_a.stride(0),
+            qkv_lora_a.stride(1),
+            qkv_lora_a.stride(2),
+            lora_a_output.stride(0),
+            lora_a_output.stride(1),
+            self.batch_info.seg_lens,
+            self.batch_info.seg_indptr,
+            self.batch_info.weight_indices,
+            BLOCK_S,
+            BLOCK_R,
+            BLOCK_IN,
        )
 
-        # kv
-        lora_output[:, output_dim_q : output_dim_q + output_dim_kv] = self.run_sgemm(
-            x=lora_a_output[:, lora_rank : 2 * lora_rank].contiguous(),
-            weights=kv_lora_b[0],
+        # Compute lora_output with shape (s, output_dim) as follows:
+        # lora_output[:, :output_dim_q] = sgemm(lora_a_output[:, :r], q_lora_b[0])
+        # lora_output[:, output_dim_q: output_dim_q + output_dim_kv]
+        #     = sgemm(lora_a_output[:, r: 2 * r], kv_lora_b[0])
+        # lora_output[:, output_dim_q + output_dim_kv: ]
+        #     = sgemm(lora_a_output[:, 2 * r: 3 * r], kv_lora_b[1])
+        BLOCK_S = 16
+        BLOCK_R = 16
+        BLOCK_OUT = 64
+
+        grid_b = (
+            triton.cdiv(self.batch_info.max_len, BLOCK_S)
+            * triton.cdiv(max(output_dim_q, output_dim_kv), BLOCK_OUT),
+            3,  # this dimension decides whether the current block computes q, k or v
+            self.batch_info.bs,
        )
 
-        lora_output[
-            :, output_dim_q + output_dim_kv : output_dim_q + 2 * output_dim_kv
-        ] = self.run_sgemm(
-            x=lora_a_output[:, 2 * lora_rank : 3 * lora_rank].contiguous(),
-            weights=kv_lora_b[1],
+        # w_lora_b with shape (num_lora, output_dim_q + 2 * output_dim_kv, r) is passed to the kernel
+        w_lora_b = torch.cat(
+            (q_lora_b[0], kv_lora_b[0], kv_lora_b[1]), dim=-2
+        ).contiguous()
+        lora_output = torch.empty((s, output_dim), device=x.device, dtype=x.dtype)
+        n_indptr = torch.tensor(
+            [
+                0,
+                output_dim_q,
+                output_dim_q + output_dim_kv,
+                output_dim_q + 2 * output_dim_kv,
+            ],
+            dtype=torch.int32,
+            device=x.device,
+        )
+
+        _qkv_lora_b_kernel[grid_b](
+            lora_a_output,
+            w_lora_b,
+            lora_output,
+            r,
+            output_dim_q,
+            output_dim_kv,
+            lora_a_output.stride(0),
+            lora_a_output.stride(1),
+            w_lora_b.stride(0),
+            w_lora_b.stride(1),
+            w_lora_b.stride(2),
+            lora_output.stride(0),
+            lora_output.stride(1),
+            self.batch_info.seg_lens,
+            self.batch_info.seg_indptr,
+            self.batch_info.weight_indices,
+            n_indptr,
+            BLOCK_S,
+            BLOCK_OUT,
+            BLOCK_R,
        )
 
         return lora_output
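
Note for reviewers (not part of the patch): the fused _qkv_lora_b_kernel is meant to reproduce three segment-wise matmuls, one per q/k/v slice, selected by weight_indices per sequence. The sketch below is a minimal plain-PyTorch reference under that assumption; ref_qkv_lora_b is a hypothetical helper name, and the tensor layouts follow the kernel comments (lora_a_output of shape (s, 3 * r), concatenated w_lora_b of shape (num_lora, output_dim_q + 2 * output_dim_kv, r), and the seg_indptr / weight_indices / n_indptr tensors built in run_qkv_lora). It is only intended as a correctness reference to compare against lora_output.

import torch

def ref_qkv_lora_b(lora_a_output, w_lora_b, seg_indptr, weight_indices, n_indptr):
    # lora_a_output: (s, 3 * r), output of the qkv lora_a sgemm
    # w_lora_b: (num_lora, output_dim_q + 2 * output_dim_kv, r)
    # seg_indptr: (bs + 1,), row offsets of each sequence in the batch
    # weight_indices: (bs,), LoRA adapter index used by each sequence
    # n_indptr: (4,), column offsets [0, N_Q, N_Q + N_KV, N_Q + 2 * N_KV]
    s = lora_a_output.shape[0]
    r = lora_a_output.shape[1] // 3
    output_dim = int(n_indptr[-1])
    out = torch.zeros(
        (s, output_dim), device=lora_a_output.device, dtype=lora_a_output.dtype
    )
    for i in range(weight_indices.shape[0]):
        seg_start, seg_end = int(seg_indptr[i]), int(seg_indptr[i + 1])
        w = w_lora_b[int(weight_indices[i])]  # (output_dim, r) slice for this adapter
        for qkv_id in range(3):  # 0: q, 1: k, 2: v
            n_start, n_end = int(n_indptr[qkv_id]), int(n_indptr[qkv_id + 1])
            x_slice = lora_a_output[seg_start:seg_end, qkv_id * r : (qkv_id + 1) * r]
            # Matches the kernel's tl.dot(x_tile, w_tile), with weights indexed as (n, k)
            out[seg_start:seg_end, n_start:n_end] = x_slice @ w[n_start:n_end, :].T
    return out

On a small random batch, torch.testing.assert_close between this reference and the kernel's lora_output would also exercise the seg_len and n_size masking paths.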
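
To make the launch configuration of grid_b concrete, the arithmetic below uses illustrative numbers that are not taken from the patch (a hypothetical setup with output_dim_q = 4096, output_dim_kv = 1024, max_len = 128, bs = 4). Axis 0 enumerates output tiles, axis 1 selects the q/k/v slice, and axis 2 selects the sequence.

import triton  # only triton.cdiv is used here

BLOCK_S, BLOCK_OUT = 16, 64
max_len, bs = 128, 4                       # hypothetical batch
output_dim_q, output_dim_kv = 4096, 1024   # hypothetical projection sizes

grid_b = (
    triton.cdiv(max_len, BLOCK_S)
    * triton.cdiv(max(output_dim_q, output_dim_kv), BLOCK_OUT),
    3,   # qkv_id: 0 -> q, 1 -> k, 2 -> v
    bs,  # batch_id: one sequence per index
)
# grid_b == (8 * 64, 3, 4) == (512, 3, 4); each program computes one
# BLOCK_S x BLOCK_OUT tile of one q/k/v slice for one sequence.
# For the k/v slices only ceil(1024 / 64) = 16 of the 64 column tiles carry
# data; the remaining tiles are fully masked out by the n_size check in the kernel.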