diff --git a/python/sglang/srt/lora/backend/base_backend.py b/python/sglang/srt/lora/backend/base_backend.py index f1c41f1204b..d6c72a14e73 100644 --- a/python/sglang/srt/lora/backend/base_backend.py +++ b/python/sglang/srt/lora/backend/base_backend.py @@ -1,8 +1,26 @@ +from typing import Tuple, Union + import torch from sglang.srt.lora.lora import LoraBatchInfo +def get_fuse_output_scaling_add_from_name(name: str) -> bool: + mapping = { + "triton": True, + "flashinfer": False, + } + return mapping.get(name, False) + + +def get_fuse_qkv_lora_b_from_name(name: str) -> bool: + mapping = { + "triton": True, + "flashinfer": False, + } + return mapping.get(name, False) + + class BaseLoraBackend: """Base class for different Lora backends. Each backend has its own implementation of Lora kernels. @@ -10,33 +28,41 @@ class BaseLoraBackend: Args: name: name of backend batch_info: information of current batch for use + fuse_output_scaling_add: if set to True, the output buffer for storing result will be passed in when doing lora_b forward, + and the operation of scaling and adding will be fused into kernel """ def __init__(self, name: str, batch_info: LoraBatchInfo = None): self.name = name self.batch_info = batch_info + self.fuse_output_scaling_add = get_fuse_output_scaling_add_from_name(name) + self.fuse_qkv_lora_b = get_fuse_qkv_lora_b_from_name(name) - def run_lora_a_sgemm(self, x: torch.Tensor, weights: torch.Tensor) -> torch.Tensor: + def run_lora_a_sgemm( + self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs + ) -> torch.Tensor: """Run segment Gemm of lora a modules with current backend. The definition of segment Gemm can be referred to https://docs.flashinfer.ai/api/gemm.html. Args: x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths weights: a set of lora weights with shape (num_lora, r, input_dim), here r is lora rank - Usually input_dim is much larger than r + usually input_dim is much larger than r Returns: result with shape (s, r) """ pass - def run_lora_b_sgemm(self, x: torch.Tensor, weights: torch.Tensor) -> torch.Tensor: + def run_lora_b_sgemm( + self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs + ) -> torch.Tensor: """Run segment Gemm of lora b modules with current backend. The definition of segment Gemm can be referred to https://docs.flashinfer.ai/api/gemm.html. Args: x: input matrix with shape (s, r), here s is the sum of all sequence lengths, r is lora rank weights: a set of lora weights with shape (num_lora, output_dim, r) - Usually output_dim is much larger than r + usually output_dim is much larger than r Returns: result with shape (s, output_dim) """ @@ -46,17 +72,20 @@ def run_qkv_lora( self, x: torch.Tensor, qkv_lora_a: torch.Tensor, - q_lora_b: torch.Tensor, - kv_lora_b: torch.Tensor, + qkv_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]], + *args, + **kwargs ) -> torch.Tensor: """Run the lora pass for QKV Layer. Args: x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths qkv_lora_a: lora_a module for qkv, with shape (num_lora, 3 * r, input_dim) - q_lora_b: lora_b module for q, with shape (1, num_lora, output_dim_q, r) - kv_lora_b: lora_b module for kv, with shape (2, num_lora, output_dim_kv, r) - + qkv_lora_b: lora_b module for qkv. 
+ If passed in as a tensor, its shape should be (num_lora,output_dim_q + 2 * output_dim_kv, r) + If passed in as a tuple of two tensors containing: + a lora_b module for q, with shape (1, num_lora, output_dim_q, r) + and a combined lora_b module for kv, with shape (2, num_lora, output_dim_kv, r) Returns: result with shape (s, output_dim_q + 2 * output_dim_kv) """ diff --git a/python/sglang/srt/lora/backend/flashinfer_backend.py b/python/sglang/srt/lora/backend/flashinfer_backend.py index 88136555ad4..5374a3e0a6b 100644 --- a/python/sglang/srt/lora/backend/flashinfer_backend.py +++ b/python/sglang/srt/lora/backend/flashinfer_backend.py @@ -1,3 +1,5 @@ +from typing import Tuple + import torch from flashinfer import SegmentGEMMWrapper @@ -15,7 +17,9 @@ def __init__(self, name: str, batch_info: LoraBatchInfo = None): workspace_buffer = torch.empty(1 * 1024 * 1024, dtype=torch.int8, device="cuda") self.segment_gemm = SegmentGEMMWrapper(workspace_buffer) - def run_lora_a_sgemm(self, x: torch.Tensor, weights: torch.Tensor) -> torch.Tensor: + def run_lora_a_sgemm( + self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs + ) -> torch.Tensor: return self.segment_gemm.run( x=x, @@ -26,7 +30,9 @@ def run_lora_a_sgemm(self, x: torch.Tensor, weights: torch.Tensor) -> torch.Tens weight_indices=self.batch_info.weight_indices, ) - def run_lora_b_sgemm(self, x: torch.Tensor, weights: torch.Tensor) -> torch.Tensor: + def run_lora_b_sgemm( + self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs + ) -> torch.Tensor: return self.segment_gemm.run( x=x, @@ -41,13 +47,15 @@ def run_qkv_lora( self, x: torch.Tensor, qkv_lora_a: torch.Tensor, - q_lora_b: torch.Tensor, - kv_lora_b: torch.Tensor, + qkv_lora_b: Tuple[torch.Tensor], + *args, + **kwargs, ) -> torch.Tensor: # Shape of lora_a_output: (s, 3 * r) lora_a_output = self.run_lora_a_sgemm(x=x, weights=qkv_lora_a) + q_lora_b, kv_lora_b = qkv_lora_b lora_rank = kv_lora_b.shape[-1] output_dim_q = q_lora_b.shape[-2] output_dim_kv = kv_lora_b.shape[-2] @@ -57,16 +65,17 @@ def run_qkv_lora( dtype=x.dtype, ) - # FIXME parallelize qkv # q lora_output[:, :output_dim_q] = self.run_lora_b_sgemm( x=lora_a_output[:, :lora_rank].contiguous(), weights=q_lora_b[0] ) # kv - lora_output[:, output_dim_q : output_dim_q + output_dim_kv] = self.run_lora_b_sgemm( - x=lora_a_output[:, lora_rank : 2 * lora_rank].contiguous(), - weights=kv_lora_b[0], + lora_output[:, output_dim_q : output_dim_q + output_dim_kv] = ( + self.run_lora_b_sgemm( + x=lora_a_output[:, lora_rank : 2 * lora_rank].contiguous(), + weights=kv_lora_b[0], + ) ) lora_output[ diff --git a/python/sglang/srt/lora/backend/triton_backend.py b/python/sglang/srt/lora/backend/triton_backend.py index 34db604e8ad..357040bf9d9 100644 --- a/python/sglang/srt/lora/backend/triton_backend.py +++ b/python/sglang/srt/lora/backend/triton_backend.py @@ -1,6 +1,4 @@ import torch -import triton -import triton.language as tl from sglang.srt.lora.backend import BaseLoraBackend from sglang.srt.lora.lora import LoraBatchInfo @@ -16,27 +14,48 @@ class TritonLoraBackend(BaseLoraBackend): def __init__(self, name: str, batch_info: LoraBatchInfo = None): super().__init__(name, batch_info) - def run_lora_a_sgemm(self, x: torch.Tensor, weights: torch.Tensor) -> torch.Tensor: + def run_lora_a_sgemm( + self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs + ) -> torch.Tensor: return sgemm_lora_a_fwd(x, weights, self.batch_info) - def run_lora_b_sgemm(self, x: torch.Tensor, weights: torch.Tensor) -> torch.Tensor: - return 
sgemm_lora_b_fwd(x, weights, self.batch_info) + def run_lora_b_sgemm( + self, + x: torch.Tensor, + weights: torch.Tensor, + base_output: torch.Tensor = None, + scaling: float = 1.0, + *args, + **kwargs + ) -> torch.Tensor: + return sgemm_lora_b_fwd(x, weights, self.batch_info, base_output, scaling) def run_qkv_lora( self, x: torch.Tensor, qkv_lora_a: torch.Tensor, - q_lora_b: torch.Tensor, - kv_lora_b: torch.Tensor, + qkv_lora_b: torch.Tensor, + output_offset: torch.Tensor, + max_qkv_out_dim: int, + base_output: torch.Tensor = None, + scaling: float = 1.0, + *args, + **kwargs ) -> torch.Tensor: # x: (s, input_dim) # qkv_lora_a: (num_lora, 3 * r, input_dim) - # q_lora_b: (1, num_lora, output_dim_q, r) - # kv_lora_b: (2, num_lora, output_dim_kv, r) + # qkv_lora_b: (num_lora, output_dim_q + 2 * output_dim_kv, r) + assert isinstance(qkv_lora_b, torch.Tensor) lora_a_output = sgemm_lora_a_fwd(x, qkv_lora_a, self.batch_info) lora_output = qkv_lora_b_fwd( - lora_a_output, q_lora_b, kv_lora_b, self.batch_info + lora_a_output, + qkv_lora_b, + self.batch_info, + output_offset, + max_qkv_out_dim, + base_output, + scaling, ) return lora_output diff --git a/python/sglang/srt/lora/lora.py b/python/sglang/srt/lora/lora.py index b1db04e7f65..494a385306c 100644 --- a/python/sglang/srt/lora/lora.py +++ b/python/sglang/srt/lora/lora.py @@ -153,17 +153,54 @@ def set_lora_info( ): self.set_lora = True self.A_buffer_qkv = A_buffer_qkv - self.B_buffer_q = B_buffer_q - self.B_buffer_kv = B_buffer_kv + + if self.lora_backend.fuse_qkv_lora_b: + assert ( + B_buffer_q.shape[-1] == B_buffer_kv.shape[-1] + ), "The lora rank of q and kv should be the same when enabling fusion of qkv lora_b" + output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2] + + # B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r) + self.B_buffer_qkv = torch.cat( + (B_buffer_q[0], B_buffer_kv[0], B_buffer_kv[1]), dim=-2 + ).contiguous() + + # Offsets of q/k/v in output dimension + self.output_offset = torch.tensor( + [ + 0, + output_dim_q, + output_dim_q + output_dim_kv, + output_dim_q + 2 * output_dim_kv, + ], + dtype=torch.int32, + device=B_buffer_q.device, + ) + # For computing number of launched blocks + self.max_qkv_out_dim = max(output_dim_q, output_dim_kv) + else: + self.B_buffer_qkv = ( + B_buffer_q, + B_buffer_kv, + ) + self.output_offset = None + self.max_qkv_out_dim = None def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor: lora_output = self.lora_backend.run_qkv_lora( - x=x, - qkv_lora_a=self.A_buffer_qkv, - q_lora_b=self.B_buffer_q, - kv_lora_b=self.B_buffer_kv, + x, + self.A_buffer_qkv, + self.B_buffer_qkv, + output_offset=self.output_offset, + max_qkv_out_dim=self.max_qkv_out_dim, + base_output=base_output, + scaling=self.scaling, + ) + return ( + lora_output + if self.lora_backend.fuse_output_scaling_add + else base_output + lora_output * self.scaling ) - return base_output + lora_output * self.scaling class RowParallelLinearWithLoRA(BaseLayerWithLoRA): @@ -178,11 +215,18 @@ def set_lora_info(self, A_buffer, B_buffer): self.B_buffer = B_buffer def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor: - lora_a_output = self.lora_backend.run_lora_a_sgemm(x=x, weights=self.A_buffer) + lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer) lora_output = self.lora_backend.run_lora_b_sgemm( - x=lora_a_output, weights=self.B_buffer[0] + lora_a_output, + self.B_buffer[0], + base_output=base_output, + scaling=self.scaling, + ) + return ( + 
lora_output + if self.lora_backend.fuse_output_scaling_add + else base_output + lora_output * self.scaling ) - return base_output + lora_output * self.scaling def forward(self, input_): # duplicate the logic in RowParallelLinear diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index 0d0cc9fc6b3..404f3f50700 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -84,7 +84,9 @@ def get_backend_from_name(name): if name in backend_mapping: return backend_mapping[name] - raise Exception(f"No supported lora backend called {name}.") + raise Exception( + f"No supported lora backend called {name}. It should be one of {list(backend_mapping.keys())}" + ) def get_layer_id(name): diff --git a/python/sglang/srt/lora/triton_ops/qkv_lora_b.py b/python/sglang/srt/lora/triton_ops/qkv_lora_b.py index 14bbcd10592..3e090f4dc37 100644 --- a/python/sglang/srt/lora/triton_ops/qkv_lora_b.py +++ b/python/sglang/srt/lora/triton_ops/qkv_lora_b.py @@ -13,8 +13,7 @@ def _qkv_lora_b_kernel( output, # Parameters of size K, # K = R - N_Q, - N_KV, + max_qkv_out_dim, # max(output_q_dim, output_kv_dim) # Strides x_stride_0, x_stride_1, @@ -28,11 +27,14 @@ def _qkv_lora_b_kernel( seg_indptr, weight_indices, # Offsets of q/k/v slice on output dimension - n_indptr, + n_offs, # Meta parameters BLOCK_S: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + # For fused output scaling and adding + fuse_scaling_add, + scaling, ): # This kernel packs 3 sgemms (q/k/v) into a single kernel. @@ -50,11 +52,11 @@ def _qkv_lora_b_kernel( seg_len = tl.load(seg_lens + batch_id) w_index = tl.load(weight_indices + batch_id) seg_start = tl.load(seg_indptr + batch_id) - n_start = tl.load(n_indptr + qkv_id) - n_size = tl.load(n_indptr + qkv_id + 1) - n_start + n_start = tl.load(n_offs + qkv_id) + n_size = tl.load(n_offs + qkv_id + 1) - n_start # The tile in output matrix will have (pid_s, pid_n) as id - num_pid_n = tl.cdiv(max(N_Q, N_KV), BLOCK_N) + num_pid_n = tl.cdiv(max_qkv_out_dim, BLOCK_N) pid_s = pid // num_pid_n pid_n = pid % num_pid_n @@ -92,28 +94,36 @@ def _qkv_lora_b_kernel( w_ptrs += BLOCK_K * w_stride_2 # Store result to output matrix + partial_sum *= scaling partial_sum = partial_sum.to(x.dtype.element_ty) output_ptr = (output + seg_start * output_stride_0 + n_start * output_stride_1) + ( s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1 ) output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < n_size) + if fuse_scaling_add: + partial_sum += tl.load(output_ptr, mask=output_mask) tl.store(output_ptr, partial_sum, mask=output_mask) def qkv_lora_b_fwd( x: torch.Tensor, - q_lora_b: torch.Tensor, - kv_lora_b: torch.Tensor, + qkv_lora_b: torch.Tensor, batch_info: LoraBatchInfo, + output_offset: torch.Tensor, + max_qkv_out_dim: int, + base_output: torch.Tensor = None, + scaling: float = 1.0, ) -> torch.Tensor: # x: (s, 3 * r) - # q_lora_b: (1, num_lora, output_dim_q, r) - # kv_lora_b: (2, num_lora, output_dim_kv, r) + # qkv_lora_b: (num_lora, output_dim_q + 2 * output_dim_kv, r) + # output_offset = [0, output_dim_q, output_dim_q + output_dim_kv, + # output_dim_q + 2 * output_dim_kv] + # max_qkv_out_dim = max(output_dim_q, output_dim_kv) # output: (s, output_dim_q + 2 * output_dim_kv) # Compute lora_output with shape (s, output_dim) as follows: - # lora_output[:, :output_dim_q] = sgemm(lora_output_a[:, :r], q_lora_b[0]) + # lora_output[:, :output_dim_q] = sgemm(lora_output_a[:, :r], ) # lora_output[:, 
output_dim_q: output_dim_q + output_dim_kv] # = sgemm(lora_output_a[:, r: 2 * r], kv_lora_b[0]) # lora_output[:, output_dim_q + output_dim_kv: ] @@ -122,59 +132,51 @@ def qkv_lora_b_fwd( # Get dims s = x.shape[0] input_dim = x.shape[1] - r = q_lora_b.shape[-1] - output_dim_q = q_lora_b.shape[-2] - output_dim_kv = kv_lora_b.shape[-2] - output_dim = output_dim_q + 2 * output_dim_kv + r = qkv_lora_b.shape[-1] + output_dim = qkv_lora_b.shape[-2] assert input_dim == 3 * r + assert output_offset.shape[0] == 4 - # FIXME: replace with autotune BLOCK_S = 16 BLOCK_R = 16 BLOCK_OUT = 64 grid_b = ( triton.cdiv(batch_info.max_len, BLOCK_S) - * triton.cdiv(max(output_dim_q, output_dim_kv), BLOCK_OUT), + * triton.cdiv(max_qkv_out_dim, BLOCK_OUT), 3, # this dimension decides current block computes on q, k or v batch_info.bs, ) - # w_lora_b with shape (num_lora, output_dim_q + 2 * output_dim_kv, r) is passed to kernel - w_lora_b = torch.cat((q_lora_b[0], kv_lora_b[0], kv_lora_b[1]), dim=-2).contiguous() - lora_output = torch.empty((s, output_dim), device=x.device, dtype=x.dtype) - n_indptr = torch.tensor( - [ - 0, - output_dim_q, - output_dim_q + output_dim_kv, - output_dim_q + 2 * output_dim_kv, - ], - dtype=torch.int32, - device=x.device, - ) + if base_output is None: + output = torch.empty((s, output_dim), device=x.device, dtype=x.dtype) + fuse_scaling_add = False + else: + output = base_output + fuse_scaling_add = True _qkv_lora_b_kernel[grid_b]( x, - w_lora_b, - lora_output, + qkv_lora_b, + output, r, - output_dim_q, - output_dim_kv, + max_qkv_out_dim, x.stride(0), x.stride(1), - w_lora_b.stride(0), - w_lora_b.stride(1), - w_lora_b.stride(2), - lora_output.stride(0), - lora_output.stride(1), + qkv_lora_b.stride(0), + qkv_lora_b.stride(1), + qkv_lora_b.stride(2), + output.stride(0), + output.stride(1), batch_info.seg_lens, batch_info.seg_indptr, batch_info.weight_indices, - n_indptr, + output_offset, BLOCK_S, BLOCK_OUT, BLOCK_R, + fuse_scaling_add, + scaling, ) - return lora_output + return output diff --git a/python/sglang/srt/lora/triton_ops/sgemm_lora_a.py b/python/sglang/srt/lora/triton_ops/sgemm_lora_a.py index fec73f7679e..305bb8c5f0e 100644 --- a/python/sglang/srt/lora/triton_ops/sgemm_lora_a.py +++ b/python/sglang/srt/lora/triton_ops/sgemm_lora_a.py @@ -110,7 +110,6 @@ def sgemm_lora_a_fwd( assert x.shape[-1] == K # Block shapes - # FIXME: Add autotune BLOCK_S = 16 BLOCK_K = 256 BLOCK_R = 16 diff --git a/python/sglang/srt/lora/triton_ops/sgemm_lora_b.py b/python/sglang/srt/lora/triton_ops/sgemm_lora_b.py index d2961eb0532..c0bc913630c 100644 --- a/python/sglang/srt/lora/triton_ops/sgemm_lora_b.py +++ b/python/sglang/srt/lora/triton_ops/sgemm_lora_b.py @@ -30,6 +30,9 @@ def _sgemm_lora_b_kernel( BLOCK_S: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + # For fused output scaling and adding + fuse_scaling_add, + scaling, ): # x: (s, K), s is the sum of sequence lengths # weights: (num_lora, N, K) @@ -81,16 +84,23 @@ def _sgemm_lora_b_kernel( w_ptrs += BLOCK_K * w_stride_2 # Store result to output matrix + partial_sum *= scaling partial_sum = partial_sum.to(x.dtype.element_ty) output_ptr = (output + seg_start * output_stride_0) + ( s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1 ) output_mask = s_offset[:, None] < seg_len + if fuse_scaling_add: + partial_sum += tl.load(output_ptr, mask=output_mask) tl.store(output_ptr, partial_sum, mask=output_mask) def sgemm_lora_b_fwd( - x: torch.Tensor, weights: torch.Tensor, batch_info: LoraBatchInfo + x: torch.Tensor, + 
weights: torch.Tensor, + batch_info: LoraBatchInfo, + base_output: torch.Tensor = None, + scaling: float = 1.0, ) -> torch.Tensor: # x: (s, r) # weights: (num_lora, output_dim, r) @@ -108,7 +118,6 @@ def sgemm_lora_b_fwd( assert x.shape[-1] == R # Block shapes - # FIXME: Add autotune BLOCK_S = 16 BLOCK_R = 16 BLOCK_N = 256 @@ -118,7 +127,13 @@ def sgemm_lora_b_fwd( batch_info.bs, ) - output = torch.empty((S, N), device=x.device, dtype=x.dtype) + if base_output is None: + output = torch.empty((S, N), device=x.device, dtype=x.dtype) + fuse_scaling_add = False + else: + output = base_output + fuse_scaling_add = True + _sgemm_lora_b_kernel[grid]( x, weights, @@ -138,5 +153,7 @@ def sgemm_lora_b_fwd( BLOCK_S, BLOCK_N, BLOCK_R, + fuse_scaling_add, + scaling, ) return output
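
A minimal pure-PyTorch sketch (not part of this diff) of what the fused Triton path computes. It mirrors qkv_lora_b_fwd after this change: the q/k/v lora_b weights are concatenated into one (num_lora, output_dim_q + 2 * output_dim_kv, r) buffer, output_offset marks the q/k/v column ranges, and the scaling plus the optional add into base_output happen in the same pass. The standalone helper name and its flat argument list (taking seg_indptr and weight_indices directly instead of a LoraBatchInfo) are illustrative only, and it accumulates in the input dtype rather than fp32 as the kernel does.

import torch


def qkv_lora_b_reference(
    x: torch.Tensor,               # (s, 3 * r): output of the qkv lora_a GEMM
    qkv_lora_b: torch.Tensor,      # (num_lora, output_dim_q + 2 * output_dim_kv, r)
    seg_indptr: torch.Tensor,      # (bs + 1,): row boundaries of each sequence
    weight_indices: torch.Tensor,  # (bs,): LoRA adapter index per sequence
    output_offset: torch.Tensor,   # [0, dq, dq + dkv, dq + 2 * dkv]
    base_output: torch.Tensor = None,
    scaling: float = 1.0,
) -> torch.Tensor:
    s = x.shape[0]
    r = qkv_lora_b.shape[-1]
    output_dim = qkv_lora_b.shape[-2]
    # With no base_output, start from zeros (the kernel writes a fresh buffer);
    # otherwise accumulate into base_output, which is what fuse_scaling_add does.
    out = (
        torch.zeros((s, output_dim), device=x.device, dtype=x.dtype)
        if base_output is None
        else base_output
    )
    for i in range(seg_indptr.shape[0] - 1):
        rows = slice(int(seg_indptr[i]), int(seg_indptr[i + 1]))
        w = int(weight_indices[i])
        for qkv_id in range(3):  # 0: q, 1: k, 2: v
            n_start = int(output_offset[qkv_id])
            n_end = int(output_offset[qkv_id + 1])
            x_slice = x[rows, qkv_id * r : (qkv_id + 1) * r]  # (seg_len, r)
            w_slice = qkv_lora_b[w, n_start:n_end, :]         # (n_size, r)
            # fused scaling, plus add when base_output is provided
            out[rows, n_start:n_end] += scaling * (x_slice @ w_slice.t())
    return out

Folding the scaling and the residual add into the kernel is what lets apply_lora return lora_output directly when fuse_output_scaling_add is True: the separate elementwise base_output + lora_output * self.scaling pass, and its intermediate (s, output_dim) buffer, are skipped.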