diff --git a/benchmark/lora/launch_server.py b/benchmark/lora/launch_server.py index f139f0df6fe..418155dbf5b 100644 --- a/benchmark/lora/launch_server.py +++ b/benchmark/lora/launch_server.py @@ -1,10 +1,10 @@ import argparse import os -NUM_LORAS = 8 +NUM_LORAS = 4 LORA_PATH = { - "base": "mistralai/Mistral-7B-Instruct-v0.3", - "lora": "/home/ying/test_lora", + "base": "meta-llama/Llama-2-7b-hf", + "lora": "winddude/wizardLM-LlaMA-LoRA-7B", } @@ -21,7 +21,8 @@ def launch_server(args): cmd += f"{lora_name}={lora_path} " cmd += f"--disable-radix --disable-cuda-graph " cmd += f"--max-loras-per-batch {args.max_loras_per_batch} " - cmd += f"--max-running-requests {args.max_running_requests}" + cmd += f"--max-running-requests {args.max_running_requests} " + cmd += f"--lora-backend {args.lora_backend}" print(cmd) os.system(cmd) @@ -42,6 +43,11 @@ def launch_server(args): type=int, default=8, ) + parser.add_argument( + "--lora-backend", + type=str, + default="triton", + ) args = parser.parse_args() launch_server(args) diff --git a/benchmark/lora/lora_bench.py b/benchmark/lora/lora_bench.py index 713cbbf76ca..b5af65a7dd7 100644 --- a/benchmark/lora/lora_bench.py +++ b/benchmark/lora/lora_bench.py @@ -183,6 +183,7 @@ async def benchmark( api_url=api_url, prompt_len=test_prompt_len, output_len=test_output_len, + lora_name="dummy", # the lora_name argument will not be used extra_request_body=extra_request_body, ) test_output = await request_func(request_func_input=test_input) @@ -206,6 +207,7 @@ async def benchmark( api_url=api_url, prompt_len=prompt_len, output_len=output_len, + lora_name="dummy", extra_request_body=extra_request_body, ) tasks.append( @@ -255,6 +257,9 @@ async def benchmark( "Output token throughput (tok/s):", metrics.output_throughput ) ) + print( + "{:<40} {:<10.2f}".format("Total throughput (tok/s):", metrics.total_throughput) + ) print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-")) print( "{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms) diff --git a/docs/backend/server_arguments.md b/docs/backend/server_arguments.md index 7e8f4ca0a54..337da1373df 100644 --- a/docs/backend/server_arguments.md +++ b/docs/backend/server_arguments.md @@ -124,6 +124,7 @@ Please consult the documentation below to learn more about the parameters you ma * `lora_paths`: You may provide a list of adapters to your model as a list. Each batch element will get model response with the corresponding lora adapter applied. Currently `cuda_graph` and `radix_attention` are not supportet with this option so you need to disable them manually. We are still working on through these [issues](https://github.com/sgl-project/sglang/issues/2929). * `max_loras_per_batch`: Maximum number of LoRAs in a running batch including base model. +* `lora_backend`: The backend of running GEMM kernels for Lora modules, can be one of `triton` or `flashinfer`. Defaults to be `triton`. ## Kernel backend diff --git a/python/sglang/srt/lora/backend/__init__.py b/python/sglang/srt/lora/backend/__init__.py new file mode 100644 index 00000000000..ed377b4b4ad --- /dev/null +++ b/python/sglang/srt/lora/backend/__init__.py @@ -0,0 +1,8 @@ +from .base_backend import BaseLoraBackend +from .flashinfer_backend import FlashInferLoraBackend +from .triton_backend import TritonLoraBackend + +__all__ = [ + "FlashInferLoraBackend", + "TritonLoraBackend", +] diff --git a/python/sglang/srt/lora/backend/base_backend.py b/python/sglang/srt/lora/backend/base_backend.py new file mode 100644 index 00000000000..d6c72a14e73 --- /dev/null +++ b/python/sglang/srt/lora/backend/base_backend.py @@ -0,0 +1,95 @@ +from typing import Tuple, Union + +import torch + +from sglang.srt.lora.lora import LoraBatchInfo + + +def get_fuse_output_scaling_add_from_name(name: str) -> bool: + mapping = { + "triton": True, + "flashinfer": False, + } + return mapping.get(name, False) + + +def get_fuse_qkv_lora_b_from_name(name: str) -> bool: + mapping = { + "triton": True, + "flashinfer": False, + } + return mapping.get(name, False) + + +class BaseLoraBackend: + """Base class for different Lora backends. + Each backend has its own implementation of Lora kernels. + + Args: + name: name of backend + batch_info: information of current batch for use + fuse_output_scaling_add: if set to True, the output buffer for storing result will be passed in when doing lora_b forward, + and the operation of scaling and adding will be fused into kernel + """ + + def __init__(self, name: str, batch_info: LoraBatchInfo = None): + self.name = name + self.batch_info = batch_info + self.fuse_output_scaling_add = get_fuse_output_scaling_add_from_name(name) + self.fuse_qkv_lora_b = get_fuse_qkv_lora_b_from_name(name) + + def run_lora_a_sgemm( + self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs + ) -> torch.Tensor: + """Run segment Gemm of lora a modules with current backend. + The definition of segment Gemm can be referred to https://docs.flashinfer.ai/api/gemm.html. + + Args: + x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths + weights: a set of lora weights with shape (num_lora, r, input_dim), here r is lora rank + usually input_dim is much larger than r + Returns: + result with shape (s, r) + """ + pass + + def run_lora_b_sgemm( + self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs + ) -> torch.Tensor: + """Run segment Gemm of lora b modules with current backend. + The definition of segment Gemm can be referred to https://docs.flashinfer.ai/api/gemm.html. + + Args: + x: input matrix with shape (s, r), here s is the sum of all sequence lengths, r is lora rank + weights: a set of lora weights with shape (num_lora, output_dim, r) + usually output_dim is much larger than r + Returns: + result with shape (s, output_dim) + """ + pass + + def run_qkv_lora( + self, + x: torch.Tensor, + qkv_lora_a: torch.Tensor, + qkv_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]], + *args, + **kwargs + ) -> torch.Tensor: + """Run the lora pass for QKV Layer. + + Args: + x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths + qkv_lora_a: lora_a module for qkv, with shape (num_lora, 3 * r, input_dim) + qkv_lora_b: lora_b module for qkv. + If passed in as a tensor, its shape should be (num_lora,output_dim_q + 2 * output_dim_kv, r) + If passed in as a tuple of two tensors containing: + a lora_b module for q, with shape (1, num_lora, output_dim_q, r) + and a combined lora_b module for kv, with shape (2, num_lora, output_dim_kv, r) + Returns: + result with shape (s, output_dim_q + 2 * output_dim_kv) + """ + pass + + def set_batch_info(self, batch_info: LoraBatchInfo): + self.batch_info = batch_info diff --git a/python/sglang/srt/lora/backend/flashinfer_backend.py b/python/sglang/srt/lora/backend/flashinfer_backend.py new file mode 100644 index 00000000000..5374a3e0a6b --- /dev/null +++ b/python/sglang/srt/lora/backend/flashinfer_backend.py @@ -0,0 +1,88 @@ +from typing import Tuple + +import torch +from flashinfer import SegmentGEMMWrapper + +from sglang.srt.lora.backend import BaseLoraBackend +from sglang.srt.lora.lora import LoraBatchInfo + + +class FlashInferLoraBackend(BaseLoraBackend): + + def __init__(self, name: str, batch_info: LoraBatchInfo = None): + super().__init__(name, batch_info) + + # Set up SGemm Wrapper from flashinfer + # FIXME wait for flashinfer segment gemm update + workspace_buffer = torch.empty(1 * 1024 * 1024, dtype=torch.int8, device="cuda") + self.segment_gemm = SegmentGEMMWrapper(workspace_buffer) + + def run_lora_a_sgemm( + self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs + ) -> torch.Tensor: + + return self.segment_gemm.run( + x=x, + weights=weights, + batch_size=self.batch_info.bs, + weight_column_major=True, + seg_indptr=self.batch_info.seg_indptr, + weight_indices=self.batch_info.weight_indices, + ) + + def run_lora_b_sgemm( + self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs + ) -> torch.Tensor: + + return self.segment_gemm.run( + x=x, + weights=weights, + batch_size=self.batch_info.bs, + weight_column_major=True, + seg_indptr=self.batch_info.seg_indptr, + weight_indices=self.batch_info.weight_indices, + ) + + def run_qkv_lora( + self, + x: torch.Tensor, + qkv_lora_a: torch.Tensor, + qkv_lora_b: Tuple[torch.Tensor], + *args, + **kwargs, + ) -> torch.Tensor: + + # Shape of lora_a_output: (s, 3 * r) + lora_a_output = self.run_lora_a_sgemm(x=x, weights=qkv_lora_a) + + q_lora_b, kv_lora_b = qkv_lora_b + lora_rank = kv_lora_b.shape[-1] + output_dim_q = q_lora_b.shape[-2] + output_dim_kv = kv_lora_b.shape[-2] + lora_output = torch.empty( + (x.shape[0], output_dim_q + 2 * output_dim_kv), + device=x.device, + dtype=x.dtype, + ) + + # q + lora_output[:, :output_dim_q] = self.run_lora_b_sgemm( + x=lora_a_output[:, :lora_rank].contiguous(), weights=q_lora_b[0] + ) + + # kv + lora_output[:, output_dim_q : output_dim_q + output_dim_kv] = ( + self.run_lora_b_sgemm( + x=lora_a_output[:, lora_rank : 2 * lora_rank].contiguous(), + weights=kv_lora_b[0], + ) + ) + + lora_output[ + :, output_dim_q + output_dim_kv : output_dim_q + 2 * output_dim_kv + ] = self.run_lora_b_sgemm( + x=lora_a_output[:, 2 * lora_rank : 3 * lora_rank].contiguous(), + weights=kv_lora_b[1], + ) + + return lora_output diff --git a/python/sglang/srt/lora/backend/triton_backend.py b/python/sglang/srt/lora/backend/triton_backend.py new file mode 100644 index 00000000000..357040bf9d9 --- /dev/null +++ b/python/sglang/srt/lora/backend/triton_backend.py @@ -0,0 +1,61 @@ +import torch + +from sglang.srt.lora.backend import BaseLoraBackend +from sglang.srt.lora.lora import LoraBatchInfo +from sglang.srt.lora.triton_ops import ( + qkv_lora_b_fwd, + sgemm_lora_a_fwd, + sgemm_lora_b_fwd, +) + + +class TritonLoraBackend(BaseLoraBackend): + + def __init__(self, name: str, batch_info: LoraBatchInfo = None): + super().__init__(name, batch_info) + + def run_lora_a_sgemm( + self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs + ) -> torch.Tensor: + return sgemm_lora_a_fwd(x, weights, self.batch_info) + + def run_lora_b_sgemm( + self, + x: torch.Tensor, + weights: torch.Tensor, + base_output: torch.Tensor = None, + scaling: float = 1.0, + *args, + **kwargs + ) -> torch.Tensor: + return sgemm_lora_b_fwd(x, weights, self.batch_info, base_output, scaling) + + def run_qkv_lora( + self, + x: torch.Tensor, + qkv_lora_a: torch.Tensor, + qkv_lora_b: torch.Tensor, + output_offset: torch.Tensor, + max_qkv_out_dim: int, + base_output: torch.Tensor = None, + scaling: float = 1.0, + *args, + **kwargs + ) -> torch.Tensor: + + # x: (s, input_dim) + # qkv_lora_a: (num_lora, 3 * r, input_dim) + # qkv_lora_b: (num_lora, output_dim_q + 2 * output_dim_kv, r) + assert isinstance(qkv_lora_b, torch.Tensor) + + lora_a_output = sgemm_lora_a_fwd(x, qkv_lora_a, self.batch_info) + lora_output = qkv_lora_b_fwd( + lora_a_output, + qkv_lora_b, + self.batch_info, + output_offset, + max_qkv_out_dim, + base_output, + scaling, + ) + return lora_output diff --git a/python/sglang/srt/lora/lora.py b/python/sglang/srt/lora/lora.py index c8cbe36602b..494a385306c 100644 --- a/python/sglang/srt/lora/lora.py +++ b/python/sglang/srt/lora/lora.py @@ -18,8 +18,8 @@ # LoRA layers class inheritance adapted from: # https://github.com/vllm-project/vllm/blob/4abf6336ec65c270343eb895e7b18786e9274176/vllm/lora/layers.py - import re +from dataclasses import dataclass import torch from torch import nn @@ -34,14 +34,32 @@ from sglang.srt.model_loader.loader import DefaultModelLoader +@dataclass +class LoraBatchInfo: + # Batch size + bs: int + + # Lengths of each sequence in shape (bs,) + seg_lens: torch.Tensor + + # Indice pointers of each sequence in shape (bs + 1, ) + seg_indptr: torch.Tensor + + # Maximum sequence length of current batch + max_len: int + + # The index of lora adapter used by each sequence, in shape (bs,) + weight_indices: torch.Tensor + + class BaseLayerWithLoRA(nn.Module): - def __init__(self, base_layer, segment_gemm, lora_rank, scaling): + def __init__(self, base_layer, lora_rank, scaling, lora_backend): super().__init__() self.base_layer = base_layer - self.segment_gemm = segment_gemm self.lora_rank = lora_rank self.scaling = scaling self.set_lora = False + self.lora_backend = lora_backend def forward(self, x: torch.Tensor): return self.base_layer.forward(x) @@ -52,17 +70,17 @@ def set_lora_info(self, *args): class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): def __init__( - self, base_layer: VocabParallelEmbedding, segment_gemm, lora_rank, scaling + self, base_layer: VocabParallelEmbedding, lora_rank, scaling, lora_backend ) -> None: - super().__init__(base_layer, segment_gemm, lora_rank, scaling) + super().__init__(base_layer, lora_rank, scaling, lora_backend) self.weight = base_layer.weight class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): def __init__( - self, base_layer: ColumnParallelLinear, segment_gemm, lora_rank, scaling + self, base_layer: ColumnParallelLinear, lora_rank, scaling, lora_backend ) -> None: - super().__init__(base_layer, segment_gemm, lora_rank, scaling) + super().__init__(base_layer, lora_rank, scaling, lora_backend) def apply_lora(self, output: torch.Tensor, x: torch.Tensor) -> torch.Tensor: # TODO @@ -88,136 +106,127 @@ def forward(self, input_: torch.Tensor): class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): def __init__( - self, base_layer: MergedColumnParallelLinear, segment_gemm, lora_rank, scaling + self, base_layer: MergedColumnParallelLinear, lora_rank, scaling, lora_backend ) -> None: - super().__init__(base_layer, segment_gemm, lora_rank, scaling) + super().__init__(base_layer, lora_rank, scaling, lora_backend) - def set_lora_info(self, A_buffer, B_buffer, bs, seg_indptr, weight_indices): + def set_lora_info( + self, + A_buffer, + B_buffer, + ): self.set_lora = True self.A_buffer = A_buffer self.B_buffer = B_buffer - self.bs = bs - self.seg_indptr = seg_indptr - self.weight_indices = weight_indices def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor: - lora_a_output = self.segment_gemm.run( - x=x, - weights=self.A_buffer, - batch_size=self.bs, - weight_column_major=True, - seg_indptr=self.seg_indptr, - weight_indices=self.weight_indices, - ) - # FIXME + lora_a_output = self.lora_backend.run_lora_a_sgemm(x=x, weights=self.A_buffer) + + output_dim = base_output.shape[-1] lora_output = torch.empty_like(base_output) - output_dim = lora_output.shape[-1] // 2 - for i in range(2): - left = output_dim * i - right = left + output_dim - lora_output[:, left:right] = self.segment_gemm.run( - x=lora_a_output[ - :, self.lora_rank * i : self.lora_rank * (i + 1) - ].contiguous(), - weights=self.B_buffer[:, left:right, :].contiguous(), - batch_size=self.bs, - weight_column_major=True, - seg_indptr=self.seg_indptr, - weight_indices=self.weight_indices, + lora_output[:, :output_dim] = self.lora_backend.run_lora_b_sgemm( + x=lora_a_output[:, 0 : self.lora_rank].contiguous(), + weights=self.B_buffer[0], + ) + + lora_output[:, output_dim : 2 * output_dim] = ( + self.lora_backend.run_lora_b_sgemm( + x=lora_a_output[:, self.lora_rank : 2 * self.lora_rank].contiguous(), + weights=self.B_buffer[1], ) + ) + return base_output + lora_output * self.scaling class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): - def __init__( - self, base_layer: QKVParallelLinear, segment_gemm, lora_rank, scaling + def init__( + self, base_layer: QKVParallelLinear, lora_rank, scaling, lora_backend ) -> None: - super().__init__(base_layer, segment_gemm, lora_rank, scaling) + super().__init__(base_layer, lora_rank, scaling, lora_backend) def set_lora_info( - self, A_buffer_qkv, B_buffer_q, B_buffer_kv, bs, seg_indptr, weight_indices + self, + A_buffer_qkv, + B_buffer_q, + B_buffer_kv, ): self.set_lora = True self.A_buffer_qkv = A_buffer_qkv - self.B_buffer_q = B_buffer_q - self.B_buffer_kv = B_buffer_kv - self.bs = bs - self.seg_indptr = seg_indptr - self.weight_indices = weight_indices + + if self.lora_backend.fuse_qkv_lora_b: + assert ( + B_buffer_q.shape[-1] == B_buffer_kv.shape[-1] + ), "The lora rank of q and kv should be the same when enabling fusion of qkv lora_b" + output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2] + + # B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r) + self.B_buffer_qkv = torch.cat( + (B_buffer_q[0], B_buffer_kv[0], B_buffer_kv[1]), dim=-2 + ).contiguous() + + # Offsets of q/k/v in output dimension + self.output_offset = torch.tensor( + [ + 0, + output_dim_q, + output_dim_q + output_dim_kv, + output_dim_q + 2 * output_dim_kv, + ], + dtype=torch.int32, + device=B_buffer_q.device, + ) + # For computing number of launched blocks + self.max_qkv_out_dim = max(output_dim_q, output_dim_kv) + else: + self.B_buffer_qkv = ( + B_buffer_q, + B_buffer_kv, + ) + self.output_offset = None + self.max_qkv_out_dim = None def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor: - lora_a_output = self.segment_gemm.run( - x=x, - weights=self.A_buffer_qkv, - batch_size=self.bs, - weight_column_major=True, - seg_indptr=self.seg_indptr, - weight_indices=self.weight_indices, + lora_output = self.lora_backend.run_qkv_lora( + x, + self.A_buffer_qkv, + self.B_buffer_qkv, + output_offset=self.output_offset, + max_qkv_out_dim=self.max_qkv_out_dim, + base_output=base_output, + scaling=self.scaling, ) - # FIXME parallelize qkv - lora_output = torch.empty_like(base_output) - # q - output_dim_q = self.B_buffer_q.shape[-2] - lora_output[:, :output_dim_q] = self.segment_gemm.run( - x=lora_a_output[:, : self.lora_rank].contiguous(), - weights=self.B_buffer_q, - batch_size=self.bs, - weight_column_major=True, - seg_indptr=self.seg_indptr, - weight_indices=self.weight_indices, + return ( + lora_output + if self.lora_backend.fuse_output_scaling_add + else base_output + lora_output * self.scaling ) - # kv - output_dim_kv = self.B_buffer_kv.shape[-2] // 2 - for i in range(2): - left = output_dim_kv * i - right = left + output_dim_kv - lora_output[:, output_dim_q + left : output_dim_q + right] = ( - self.segment_gemm.run( - x=lora_a_output[ - :, self.lora_rank * (i + 1) : self.lora_rank * (i + 2) - ].contiguous(), - weights=self.B_buffer_kv[:, left:right, :].contiguous(), - batch_size=self.bs, - weight_column_major=True, - seg_indptr=self.seg_indptr, - weight_indices=self.weight_indices, - ) - ) - return base_output + lora_output * self.scaling class RowParallelLinearWithLoRA(BaseLayerWithLoRA): def __init__( - self, base_layer: RowParallelLinear, segment_gemm, lora_rank, scaling + self, base_layer: RowParallelLinear, lora_rank, scaling, lora_backend ) -> None: - super().__init__(base_layer, segment_gemm, lora_rank, scaling) + super().__init__(base_layer, lora_rank, scaling, lora_backend) - def set_lora_info(self, A_buffer, B_buffer, bs, seg_indptr, weight_indices): + def set_lora_info(self, A_buffer, B_buffer): self.set_lora = True self.A_buffer = A_buffer self.B_buffer = B_buffer - self.bs = bs - self.seg_indptr = seg_indptr - self.weight_indices = weight_indices def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor: - lora_output = self.segment_gemm.run( - x=x, - weights=self.A_buffer, - batch_size=self.bs, - weight_column_major=True, - seg_indptr=self.seg_indptr, - weight_indices=self.weight_indices, + lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer) + lora_output = self.lora_backend.run_lora_b_sgemm( + lora_a_output, + self.B_buffer[0], + base_output=base_output, + scaling=self.scaling, ) - lora_output = self.segment_gemm.run( - x=lora_output, - weights=self.B_buffer, - batch_size=self.bs, - weight_column_major=True, - seg_indptr=self.seg_indptr, - weight_indices=self.weight_indices, + return ( + lora_output + if self.lora_backend.fuse_output_scaling_add + else base_output + lora_output * self.scaling ) - return base_output + lora_output * self.scaling def forward(self, input_): # duplicate the logic in RowParallelLinear @@ -255,7 +264,7 @@ def forward(self, input_): def get_lora_layer( - layer: nn.Module, segment_gemm, lora_rank, scaling + layer: nn.Module, lora_rank, scaling, lora_backend ) -> BaseLayerWithLoRA: supported_layer_types = { # the order matters @@ -267,7 +276,7 @@ def get_lora_layer( } for src_layer_type, lora_layer_type in supported_layer_types.items(): if isinstance(layer, src_layer_type): # pylint: disable=unidiomatic-typecheck - ret = lora_layer_type(layer, segment_gemm, lora_rank, scaling) + ret = lora_layer_type(layer, lora_rank, scaling, lora_backend) return ret raise Exception(f"No corresponding LoRA layer supported for {type(layer)}.") @@ -297,13 +306,14 @@ def offload_from_gpu(self): class LoRAAdapter(nn.Module): - def __init__(self, uid, config, base_hf_config, load_config): + def __init__(self, uid, config, base_hf_config, load_config, lora_backend): super().__init__() self.uid = uid self.config = config assert self.config.hf_config["peft_type"].lower() == "lora" self.base_hf_config = base_hf_config self.load_config = load_config + self.lora_backend = lora_backend self.scaling = self.config.lora_alpha / self.config.r self.layers = nn.ModuleList( @@ -376,20 +386,25 @@ def initialize_weights(self): layer.weights.pop(weight_name) layer.weights.pop(v_name) else: - layer.weights[kv_name] = torch.cat( - ( + layer.weights[kv_name] = torch.stack( + [ layer.weights[weight_name], layer.weights[v_name], - ), - 0, + ], + dim=0, ) layer.weights.pop(weight_name) layer.weights.pop(v_name) elif "gate_proj" in weight_name: up_name = weight_name.replace("gate_proj", "up_proj") gate_up_name = weight_name.replace("gate_proj", "gate_up_proj") - layer.weights[gate_up_name] = torch.cat( - (layer.weights[weight_name], layer.weights[up_name]), 0 - ) + if "lora_A" in weight_name: + layer.weights[gate_up_name] = torch.cat( + (layer.weights[weight_name], layer.weights[up_name]), 0 + ) + else: + layer.weights[gate_up_name] = torch.stack( + [layer.weights[weight_name], layer.weights[up_name]], dim=0 + ) layer.weights.pop(weight_name) layer.weights.pop(up_name) diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index 0449e252453..404f3f50700 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -20,16 +20,14 @@ import torch -from sglang.srt.lora.lora import LoRAAdapter, get_lora_layer +from sglang.srt.lora.backend import FlashInferLoraBackend, TritonLoraBackend +from sglang.srt.lora.lora import LoRAAdapter, LoraBatchInfo, get_lora_layer from sglang.srt.lora.lora_config import LoRAConfig from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.utils import is_flashinfer_available, replace_submodule logger = logging.getLogger(__name__) -if is_flashinfer_available(): - from flashinfer import SegmentGEMMWrapper - def get_module_name(name): # Fallback solution of mapping from config module name to module name in model class. @@ -77,6 +75,20 @@ def get_stacked_name(name): return params_mapping.get(name, (name, name)) +def get_backend_from_name(name): + backend_mapping = { + "triton": TritonLoraBackend, + "flashinfer": FlashInferLoraBackend, + } + + if name in backend_mapping: + return backend_mapping[name] + + raise Exception( + f"No supported lora backend called {name}. It should be one of {list(backend_mapping.keys())}" + ) + + def get_layer_id(name): match = re.search(r"layers\.(\d+)\.", name) if match is None: @@ -93,6 +105,7 @@ def __init__( max_loras_per_batch, load_config, dtype, + lora_backend, ): self.base_model = base_model self.lora_paths = lora_paths @@ -101,8 +114,9 @@ def __init__( self.load_config = load_config self.dtype = dtype - workspace_buffer = torch.empty(1 * 1024 * 1024, dtype=torch.int8, device="cuda") - self.segment_gemm = SegmentGEMMWrapper(workspace_buffer) + logger.info(f"Using {lora_backend} as backend of Lora kernels.") + backend_type = get_backend_from_name(lora_backend) + self.lora_backend = backend_type(lora_backend) self.init_loras() self.init_lora_memory_pool() @@ -123,7 +137,7 @@ def get_target_modules(self): def set_lora_module(self, module_name, module): lora_module = get_lora_layer( - module, self.segment_gemm, self.max_lora_dim, self.scaling + module, self.max_lora_dim, self.scaling, self.lora_backend ) replace_submodule(self.base_model, module_name, lora_module) return lora_module @@ -162,7 +176,11 @@ def init_loras(self): self.lora_id[name] = len(self.loras) self.loras.append( LoRAAdapter( - name, self.configs[name], self.base_hf_config, self.load_config + name, + self.configs[name], + self.base_hf_config, + self.load_config, + self.lora_backend, ) ) self.loras[-1].initialize_weights() @@ -226,8 +244,9 @@ def init_lora_memory_pool(self): self.B_buffer[module_B] = [ torch.empty( ( + c, self.max_loras_per_batch, - hidden_dim_B * c, + hidden_dim_B, self.max_lora_dim, ), dtype=self.dtype, @@ -263,7 +282,16 @@ def load_lora(self, uid, buffer_id): else: lora_weight_name = self.get_weight_name(name, 1) if lora_weight_name: - self.B_buffer[lora_weight_name][i][buffer_id].copy_(weights) + c = self.loras[-1].get_stacked_multiply(lora_weight_name) + if c > 1: + for j in range(c): + self.B_buffer[lora_weight_name][i][j][buffer_id].copy_( + weights[j] + ) + else: + self.B_buffer[lora_weight_name][i][0][buffer_id].copy_( + weights + ) def prepare_lora_batch(self, forward_batch: ForwardBatch): # load active loras into lora memory pool @@ -292,20 +320,30 @@ def prepare_lora_batch(self, forward_batch: ForwardBatch): if cur_uids == set([None]): return - # setup lora in forward modules + # set up batch info shared by all lora moruldes bs = forward_batch.batch_size seg_lens = ( forward_batch.extend_seq_lens if forward_batch.forward_mode.is_extend() else torch.ones(bs, device="cuda") ) - # FIXME: reuse the data rather than recompute seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda") seg_indptr[1:] = torch.cumsum(seg_lens, dim=0) + max_len = int(torch.max(seg_lens)) weight_indices = torch.empty((bs,), dtype=torch.int64, device="cuda") for i, lora_path in enumerate(forward_batch.lora_paths): weight_indices[i] = self.buffer_id[lora_path] + batch_info = LoraBatchInfo( + bs=bs, + seg_lens=seg_lens, + seg_indptr=seg_indptr, + max_len=max_len, + weight_indices=weight_indices, + ) + self.lora_backend.set_batch_info(batch_info) + + # call set_lora_info for each lora modules for module_name, module in self.lora_modules: layer_id = get_layer_id(module_name) @@ -314,16 +352,10 @@ def prepare_lora_batch(self, forward_batch: ForwardBatch): module.set_lora_info( self.A_buffer[weight_name][layer_id], self.B_buffer[weight_name][layer_id], - bs, - seg_indptr, - weight_indices, ) else: module.set_lora_info( self.A_buffer["qkv_proj"][layer_id], self.B_buffer["q_proj"][layer_id], self.B_buffer["kv_proj"][layer_id], - bs, - seg_indptr, - weight_indices, ) diff --git a/python/sglang/srt/lora/triton_ops/__init__.py b/python/sglang/srt/lora/triton_ops/__init__.py new file mode 100644 index 00000000000..efc76bb8b47 --- /dev/null +++ b/python/sglang/srt/lora/triton_ops/__init__.py @@ -0,0 +1,5 @@ +from .qkv_lora_b import qkv_lora_b_fwd +from .sgemm_lora_a import sgemm_lora_a_fwd +from .sgemm_lora_b import sgemm_lora_b_fwd + +__all__ = ["qkv_lora_b_fwd", "sgemm_lora_a_fwd", "sgemm_lora_b_fwd"] diff --git a/python/sglang/srt/lora/triton_ops/qkv_lora_b.py b/python/sglang/srt/lora/triton_ops/qkv_lora_b.py new file mode 100644 index 00000000000..3e090f4dc37 --- /dev/null +++ b/python/sglang/srt/lora/triton_ops/qkv_lora_b.py @@ -0,0 +1,182 @@ +import torch +import triton +import triton.language as tl + +from sglang.srt.lora.lora import LoraBatchInfo + + +@triton.jit +def _qkv_lora_b_kernel( + # Pointers to matrices + x, + weights, + output, + # Parameters of size + K, # K = R + max_qkv_out_dim, # max(output_q_dim, output_kv_dim) + # Strides + x_stride_0, + x_stride_1, + w_stride_0, + w_stride_1, + w_stride_2, + output_stride_0, + output_stride_1, + # Information on sequence lengths and weight id + seg_lens, + seg_indptr, + weight_indices, + # Offsets of q/k/v slice on output dimension + n_offs, + # Meta parameters + BLOCK_S: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + # For fused output scaling and adding + fuse_scaling_add, + scaling, +): + # This kernel packs 3 sgemms (q/k/v) into a single kernel. + + # x: (s, 3 * K), s is the sum of sequence lengths, K equals to lora rank + # weights: (num_lora, N_Q + 2 * N_KV, K) + # output: (s, N_Q + 2 * N_KV) + # N_Q >> K, N_KV >> K + + # Current block computes sequence with batch_id, + # which starts from row seg_start of x with length seg_len. + # qkv_id decides which of q,k,v to compute (0: q, 1: k, 2: v) + batch_id = tl.program_id(axis=2) + qkv_id = tl.program_id(axis=1) + pid = tl.program_id(axis=0) + seg_len = tl.load(seg_lens + batch_id) + w_index = tl.load(weight_indices + batch_id) + seg_start = tl.load(seg_indptr + batch_id) + n_start = tl.load(n_offs + qkv_id) + n_size = tl.load(n_offs + qkv_id + 1) - n_start + + # The tile in output matrix will have (pid_s, pid_n) as id + num_pid_n = tl.cdiv(max_qkv_out_dim, BLOCK_N) + pid_s = pid // num_pid_n + pid_n = pid % num_pid_n + + # Create pointers for the first block of x and weights[batch_id][n_start: n_end][:] + # The pointers will be advanced as we move in the K direction + # and accumulate + s_offset = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S + n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + k_offset = tl.arange(0, BLOCK_K) + + x_ptrs = (x + seg_start * x_stride_0 + (qkv_id * K) * x_stride_1) + ( + s_offset[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1 + ) + w_ptrs = (weights + w_index * w_stride_0 + n_start * w_stride_1) + ( + k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1 + ) + + # Iteate to compute the block in output matrix + partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_K)): + x_tile = tl.load( + x_ptrs, + mask=(s_offset[:, None] < seg_len) + and (k_offset[None, :] < K - k * BLOCK_K), + other=0.0, + ) + w_tile = tl.load( + w_ptrs, + mask=(k_offset[:, None] < K - k * BLOCK_K) and (n_offset[None, :] < n_size), + other=0.0, + ) + partial_sum += tl.dot(x_tile, w_tile) + + x_ptrs += BLOCK_K * x_stride_1 + w_ptrs += BLOCK_K * w_stride_2 + + # Store result to output matrix + partial_sum *= scaling + partial_sum = partial_sum.to(x.dtype.element_ty) + output_ptr = (output + seg_start * output_stride_0 + n_start * output_stride_1) + ( + s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1 + ) + output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < n_size) + if fuse_scaling_add: + partial_sum += tl.load(output_ptr, mask=output_mask) + tl.store(output_ptr, partial_sum, mask=output_mask) + + +def qkv_lora_b_fwd( + x: torch.Tensor, + qkv_lora_b: torch.Tensor, + batch_info: LoraBatchInfo, + output_offset: torch.Tensor, + max_qkv_out_dim: int, + base_output: torch.Tensor = None, + scaling: float = 1.0, +) -> torch.Tensor: + + # x: (s, 3 * r) + # qkv_lora_b: (num_lora, output_dim_q + 2 * output_dim_kv, r) + # output_offset = [0, output_dim_q, output_dim_q + output_dim_kv, + # output_dim_q + 2 * output_dim_kv] + # max_qkv_out_dim = max(output_dim_q, output_dim_kv) + # output: (s, output_dim_q + 2 * output_dim_kv) + + # Compute lora_output with shape (s, output_dim) as follows: + # lora_output[:, :output_dim_q] = sgemm(lora_output_a[:, :r], ) + # lora_output[:, output_dim_q: output_dim_q + output_dim_kv] + # = sgemm(lora_output_a[:, r: 2 * r], kv_lora_b[0]) + # lora_output[:, output_dim_q + output_dim_kv: ] + # = sgemm(lora_output_a[:, 2 * r: 3 * r], kv_lora_b[1]) + + # Get dims + s = x.shape[0] + input_dim = x.shape[1] + r = qkv_lora_b.shape[-1] + output_dim = qkv_lora_b.shape[-2] + assert input_dim == 3 * r + assert output_offset.shape[0] == 4 + + BLOCK_S = 16 + BLOCK_R = 16 + BLOCK_OUT = 64 + + grid_b = ( + triton.cdiv(batch_info.max_len, BLOCK_S) + * triton.cdiv(max_qkv_out_dim, BLOCK_OUT), + 3, # this dimension decides current block computes on q, k or v + batch_info.bs, + ) + + if base_output is None: + output = torch.empty((s, output_dim), device=x.device, dtype=x.dtype) + fuse_scaling_add = False + else: + output = base_output + fuse_scaling_add = True + + _qkv_lora_b_kernel[grid_b]( + x, + qkv_lora_b, + output, + r, + max_qkv_out_dim, + x.stride(0), + x.stride(1), + qkv_lora_b.stride(0), + qkv_lora_b.stride(1), + qkv_lora_b.stride(2), + output.stride(0), + output.stride(1), + batch_info.seg_lens, + batch_info.seg_indptr, + batch_info.weight_indices, + output_offset, + BLOCK_S, + BLOCK_OUT, + BLOCK_R, + fuse_scaling_add, + scaling, + ) + + return output diff --git a/python/sglang/srt/lora/triton_ops/sgemm_lora_a.py b/python/sglang/srt/lora/triton_ops/sgemm_lora_a.py new file mode 100644 index 00000000000..305bb8c5f0e --- /dev/null +++ b/python/sglang/srt/lora/triton_ops/sgemm_lora_a.py @@ -0,0 +1,143 @@ +import torch +import triton +import triton.language as tl + +from sglang.srt.lora.lora import LoraBatchInfo + + +@triton.jit +def _sgemm_lora_a_kernel( + # Pointers to matrices + x, + weights, + output, + # Matrix dimensions + N, # r + K, # input_dim + # Strides + x_stride_0, + x_stride_1, + w_stride_0, + w_stride_1, + w_stride_2, + output_stride_0, + output_stride_1, + # Information on sequence lengths and weight id + seg_lens, + seg_indptr, + weight_indices, + # Meta parameters + BLOCK_S: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +): + + # x: (s, K), s is the sum of sequence lengths + # weights: (num_lora, N, K) + # output: (s, N) + + # Current block computes sequence with batch_id, + # which starts from row seg_start of x with length seg_len + batch_id = tl.program_id(axis=1) + pid = tl.program_id(axis=0) + seg_len = tl.load(seg_lens + batch_id) + w_index = tl.load(weight_indices + batch_id) + seg_start = tl.load(seg_indptr + batch_id) + + # The tile in output matrix will have (pid_s, pid_n) as id + num_pid_n = tl.cdiv(N, BLOCK_N) + pid_s = pid // num_pid_n + pid_n = pid % num_pid_n + + # Create pointers for the first block of x and weights[batch_id] + # The pointers will be advanced as we move in the K direction + # and accumulate + s_offset = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S + n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + k_offset = tl.arange(0, BLOCK_K) + x_ptrs = (x + seg_start * x_stride_0) + ( + s_offset[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1 + ) + w_ptrs = (weights + w_index * w_stride_0) + ( + k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1 + ) + + # Iteate to compute the block in output matrix + partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_K)): + x_tile = tl.load( + x_ptrs, + mask=(s_offset[:, None] < seg_len) + and (k_offset[None, :] < K - k * BLOCK_K), + other=0.0, + ) + w_tile = tl.load( + w_ptrs, + mask=(k_offset[:, None] < K - k * BLOCK_K) and (n_offset[None, :] < N), + other=0.0, + ) + partial_sum += tl.dot(x_tile, w_tile) + + x_ptrs += BLOCK_K * x_stride_1 + w_ptrs += BLOCK_K * w_stride_2 + + # Store result to output matrix + partial_sum = partial_sum.to(x.dtype.element_ty) + output_ptr = (output + seg_start * output_stride_0) + ( + s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1 + ) + output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < N) + tl.store(output_ptr, partial_sum, mask=output_mask) + + +def sgemm_lora_a_fwd( + x: torch.Tensor, weights: torch.Tensor, batch_info: LoraBatchInfo +) -> torch.Tensor: + # x: (s, input_dim) + # weights: (num_lora, r, input_dim) + # output: (s, r) + # when called by run_qkv_lora, the weights.shape[-2] will be 3 * r + # input_dim is much larger than r + + assert x.is_contiguous() + assert weights.is_contiguous() + assert len(x.shape) == 2 + assert len(weights.shape) == 3 + + S = x.shape[0] + R = weights.shape[-2] + K = weights.shape[-1] + assert x.shape[-1] == K + + # Block shapes + BLOCK_S = 16 + BLOCK_K = 256 + BLOCK_R = 16 + + grid = ( + triton.cdiv(batch_info.max_len, BLOCK_S) * triton.cdiv(R, BLOCK_R), + batch_info.bs, + ) + + output = torch.empty((S, R), device=x.device, dtype=x.dtype) + _sgemm_lora_a_kernel[grid]( + x, + weights, + output, + R, + K, + x.stride(0), + x.stride(1), + weights.stride(0), + weights.stride(1), + weights.stride(2), + output.stride(0), + output.stride(1), + batch_info.seg_lens, + batch_info.seg_indptr, + batch_info.weight_indices, + BLOCK_S, + BLOCK_R, + BLOCK_K, + ) + return output diff --git a/python/sglang/srt/lora/triton_ops/sgemm_lora_b.py b/python/sglang/srt/lora/triton_ops/sgemm_lora_b.py new file mode 100644 index 00000000000..c0bc913630c --- /dev/null +++ b/python/sglang/srt/lora/triton_ops/sgemm_lora_b.py @@ -0,0 +1,159 @@ +import torch +import triton +import triton.language as tl + +from sglang.srt.lora.lora import LoraBatchInfo + + +@triton.jit +def _sgemm_lora_b_kernel( + # Pointers to matrices + x, + weights, + output, + # Matrix dimensions + N, # output_dim + K, # r + # Strides + x_stride_0, + x_stride_1, + w_stride_0, + w_stride_1, + w_stride_2, + output_stride_0, + output_stride_1, + # Information on sequence lengths and weight id + seg_lens, + seg_indptr, + weight_indices, + # Meta parameters + BLOCK_S: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + # For fused output scaling and adding + fuse_scaling_add, + scaling, +): + # x: (s, K), s is the sum of sequence lengths + # weights: (num_lora, N, K) + # output: (s, N) + + # Current block computes sequence with batch_id, + # which starts from row seg_start of x with length seg_len + batch_id = tl.program_id(axis=1) + pid = tl.program_id(axis=0) + seg_len = tl.load(seg_lens + batch_id) + w_index = tl.load(weight_indices + batch_id) + seg_start = tl.load(seg_indptr + batch_id) + + # The tile in output matrix will have (pid_s, pid_n) as id + num_pid_n = tl.cdiv(N, BLOCK_N) + pid_s = pid // num_pid_n + pid_n = pid % num_pid_n + + # Create pointers for the first block of x and weights[batch_id] + # The pointers will be advanced as we move in the K direction + # and accumulate + s_offset = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S + n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + k_offset = tl.arange(0, BLOCK_K) + x_ptrs = (x + seg_start * x_stride_0) + ( + s_offset[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1 + ) + w_ptrs = (weights + w_index * w_stride_0) + ( + k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1 + ) + + # Iteate to compute the block in output matrix + partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_K)): + x_tile = tl.load( + x_ptrs, + mask=(s_offset[:, None] < seg_len) + and (k_offset[None, :] < K - k * BLOCK_K), + other=0.0, + ) + w_tile = tl.load( + w_ptrs, + mask=(k_offset[:, None] < K - k * BLOCK_K), + other=0.0, + ) + partial_sum += tl.dot(x_tile, w_tile) + + x_ptrs += BLOCK_K * x_stride_1 + w_ptrs += BLOCK_K * w_stride_2 + + # Store result to output matrix + partial_sum *= scaling + partial_sum = partial_sum.to(x.dtype.element_ty) + output_ptr = (output + seg_start * output_stride_0) + ( + s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1 + ) + output_mask = s_offset[:, None] < seg_len + if fuse_scaling_add: + partial_sum += tl.load(output_ptr, mask=output_mask) + tl.store(output_ptr, partial_sum, mask=output_mask) + + +def sgemm_lora_b_fwd( + x: torch.Tensor, + weights: torch.Tensor, + batch_info: LoraBatchInfo, + base_output: torch.Tensor = None, + scaling: float = 1.0, +) -> torch.Tensor: + # x: (s, r) + # weights: (num_lora, output_dim, r) + # output: (s, output_dim) + # output_dim is much larger than r + + assert x.is_contiguous() + assert weights.is_contiguous() + assert len(x.shape) == 2 + assert len(weights.shape) == 3 + + S = x.shape[0] + N = weights.shape[-2] + R = weights.shape[-1] + assert x.shape[-1] == R + + # Block shapes + BLOCK_S = 16 + BLOCK_R = 16 + BLOCK_N = 256 + + grid = ( + triton.cdiv(batch_info.max_len, BLOCK_S) * triton.cdiv(N, BLOCK_N), + batch_info.bs, + ) + + if base_output is None: + output = torch.empty((S, N), device=x.device, dtype=x.dtype) + fuse_scaling_add = False + else: + output = base_output + fuse_scaling_add = True + + _sgemm_lora_b_kernel[grid]( + x, + weights, + output, + N, + R, + x.stride(0), + x.stride(1), + weights.stride(0), + weights.stride(1), + weights.stride(2), + output.stride(0), + output.stride(1), + batch_info.seg_lens, + batch_info.seg_indptr, + batch_info.weight_indices, + BLOCK_S, + BLOCK_N, + BLOCK_R, + fuse_scaling_add, + scaling, + ) + return output diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 6fa1429dc2c..4933ad1cb62 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -529,6 +529,7 @@ def init_lora_manager(self): max_loras_per_batch=self.server_args.max_loras_per_batch, load_config=self.load_config, dtype=self.dtype, + lora_backend=self.server_args.lora_backend, ) logger.info("LoRA manager ready.") diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index f9340e47764..4f1c4df8b25 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -113,6 +113,7 @@ class ServerArgs: # LoRA lora_paths: Optional[List[str]] = None max_loras_per_batch: int = 8 + lora_backend: str = "triton" # Kernel backend attention_backend: Optional[str] = None @@ -649,13 +650,19 @@ def add_cli_args(parser: argparse.ArgumentParser): nargs="*", default=None, action=LoRAPathAction, - help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}", + help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}.", ) parser.add_argument( "--max-loras-per-batch", type=int, default=8, - help="Maximum number of adapters for a running batch, include base-only request", + help="Maximum number of adapters for a running batch, include base-only request.", + ) + parser.add_argument( + "--lora-backend", + type=str, + default="triton", + help="Choose the kernel backend for multi-LoRA serving.", ) # Kernel backend diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index bae0fcf2a49..6486b2550da 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -272,6 +272,7 @@ def __init__( port: int = DEFAULT_PORT_FOR_SRT_TEST_RUNNER, lora_paths: List[str] = None, max_loras_per_batch: int = 4, + lora_backend: str = "triton", disable_cuda_graph: bool = False, disable_radix_cache: bool = False, ): @@ -287,6 +288,7 @@ def __init__( is_embedding=not self.is_generation, lora_paths=lora_paths, max_loras_per_batch=max_loras_per_batch, + lora_backend=lora_backend, disable_cuda_graph=disable_cuda_graph, disable_radix_cache=disable_radix_cache, ) diff --git a/test/srt/models/test_lora_backend.py b/test/srt/models/test_lora_backend.py new file mode 100644 index 00000000000..6d61633004c --- /dev/null +++ b/test/srt/models/test_lora_backend.py @@ -0,0 +1,183 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import multiprocessing as mp +import unittest + +import torch + +from sglang.test.runners import HFRunner, SRTRunner +from sglang.test.test_utils import calculate_rouge_l + +LORA_SETS = [ + {"base": "meta-llama/Llama-2-7b-hf", "loras": ["winddude/wizardLM-LlaMA-LoRA-7B"]}, + # {"base": "meta-llama/Llama-2-7b-hf", "loras": ["RuterNorway/Llama-2-7b-chat-norwegian-LoRa"]} +] +TORCH_DTYPES = [torch.float16] + +PROMPTS = [ + "AI is a field of computer science focused on", + """ + ### Instruction: + Tell me about llamas and alpacas + ### Response: + Llamas are large, long-necked animals with a woolly coat. They have two toes on each foot instead of three like other camelids (camels, dromedaries). Llamas live in the Andean mountains of South America where they graze on grasses and shrubs. Alpaca is another name for domesticated llama. The word "alpaca" comes from an Incan language meaning "golden fleece." Alpacas look very similar to llamas but are smaller than their wild relatives. Both species were used by ancient people as pack animals and for meat. Today both llamas and alpacas are raised primarily for their fiber which can be spun into yarn or knitted into clothing. + ### Question 2: + What do you know about llamas? + ### Answer: + """, +] + +BACKENDS = ["triton", "flashinfer"] + +prefill_tolerance: float = 5e-2 +decode_tolerance: float = 5e-2 +rouge_l_tolerance: float = 1 + + +class TestLoRABackend(unittest.TestCase): + + def run_backend( + self, prompts, lora_set, tp_size, torch_dtype, max_new_tokens, backend + ): + print(f"=================== testing {backend} backend =======================") + base_path = lora_set["base"] + all_lora_paths = lora_set["loras"] + batch_lora_paths = [] + i = 0 + for _ in range(len(prompts)): + batch_lora_paths.append(all_lora_paths[i]) + i = (i + 1) % len(all_lora_paths) + print(f"batch lora paths={batch_lora_paths}") + with SRTRunner( + base_path, + torch_dtype=torch_dtype, + model_type="generation", + tp_size=tp_size, + lora_paths=all_lora_paths, + max_loras_per_batch=3, + lora_backend=backend, + disable_cuda_graph=True, + disable_radix_cache=True, + ) as srt_runner: + srt_outputs = srt_runner.forward( + prompts, max_new_tokens=max_new_tokens, lora_paths=batch_lora_paths + ) + + with HFRunner( + base_path, torch_dtype=torch_dtype, model_type="generation" + ) as hf_runner: + hf_outputs = hf_runner.forward( + prompts, max_new_tokens=max_new_tokens, lora_paths=batch_lora_paths + ) + + with SRTRunner( + base_path, + tp_size=tp_size, + torch_dtype=torch_dtype, + model_type="generation", + ) as srt_runner: + srt_no_lora_outputs = srt_runner.forward( + prompts, max_new_tokens=max_new_tokens + ) + + with HFRunner( + base_path, + torch_dtype=torch_dtype, + model_type="generation", + ) as hf_runner: + hf_no_lora_outputs = hf_runner.forward( + prompts, max_new_tokens=max_new_tokens + ) + + for i in range(len(prompts)): + print(f"Prompt {i} with lora path {batch_lora_paths[i]}:") + + # compare input logprobs + hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i]) + srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i]) + hf_no_lora_logprobs = torch.Tensor(hf_no_lora_outputs.top_input_logprobs[i]) + srt_no_lora_logprobs = torch.Tensor( + srt_no_lora_outputs.top_input_logprobs[i] + ) + print( + "max input diff between hf_lora and srt_lora", + torch.max(abs(hf_logprobs - srt_logprobs)), + ) + print( + "max input diff between srt_base and srt_lora", + torch.max(abs(srt_no_lora_logprobs - srt_logprobs)), + ) + print( + "max input diff between srt_base and hf_base", + torch.max(abs(srt_no_lora_logprobs - hf_no_lora_logprobs)), + ) + print( + "max input diff between hf_lora and hf_base", + torch.max(abs(hf_logprobs - hf_no_lora_logprobs)), + ) + if hf_logprobs.shape[0] <= 100: + assert torch.all(abs(hf_logprobs - srt_logprobs) < prefill_tolerance), ( + f"prefill logprobs are not all close with model_path={base_path}," + f"lora_path={batch_lora_paths[i]}, backend={backend}, prompt={prompts[i]}" + f"prefill_tolerance={prefill_tolerance}." + f"{hf_logprobs=}, {srt_logprobs=}" + ) + + # compare output logprobs + hf_logprobs = torch.Tensor(hf_outputs.top_output_logprobs[i]) + srt_logprobs = torch.Tensor(srt_outputs.top_output_logprobs[i]) + print( + "max output diff between hf_lora and srt_lora", + torch.max(abs(hf_logprobs - srt_logprobs)), + "\n", + ) + if hf_logprobs.shape[0] <= 100: + assert torch.all(abs(hf_logprobs - srt_logprobs) < decode_tolerance), ( + f"decode logprobs are not all close with model_path={base_path}," + f"lora_path={batch_lora_paths[i]}, backend={backend}, prompt={prompts[i]}" + f"decode_tolerance={decode_tolerance}." + f"{hf_logprobs=}, {srt_logprobs=}" + ) + + # compare output strings + srt_output_str = srt_outputs.output_strs[i].strip(" ") + hf_output_str = hf_outputs.output_strs[i] + print(f"srt_output_str={srt_output_str}") + print(f"hf_output_str={hf_output_str}") + rouge_l_scores = calculate_rouge_l([srt_output_str], [hf_output_str]) + print(f"{rouge_l_scores=}") + assert ( + rouge_l_scores[0] >= rouge_l_tolerance + ), f"ROUGE-L scores of prompt {i} outputs are greater than rouge_l_tolerance={rouge_l_tolerance}" + + def test_all(self): + for lora_set in LORA_SETS: + print(f"Testing lora set {lora_set}: ") + for torch_dtype in TORCH_DTYPES: + tp_size = 1 + max_new_tokens = 32 + for backend in BACKENDS: + self.run_backend( + PROMPTS, lora_set, tp_size, torch_dtype, max_new_tokens, backend + ) + + +if __name__ == "__main__": + try: + mp.set_start_method("spawn") + except RuntimeError: + pass + + unittest.main(warnings="ignore") diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 603bab957bd..1fbb7f92f2e 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -8,6 +8,7 @@ "models/test_embedding_models.py", "models/test_generation_models.py", "models/test_lora.py", + "models/test_lora_backend.py", "models/test_qwen_models.py", "models/test_reward_models.py", "sampling/penaltylib",