[Feature] Define backends and add Triton backend for Lora #3161

Open · wants to merge 9 commits into base: main
14 changes: 10 additions & 4 deletions benchmark/lora/launch_server.py
@@ -1,10 +1,10 @@
import argparse
import os

-NUM_LORAS = 8
+NUM_LORAS = 4
LORA_PATH = {
"base": "mistralai/Mistral-7B-Instruct-v0.3",
"lora": "/home/ying/test_lora",
"base": "meta-llama/Llama-2-7b-hf",
"lora": "winddude/wizardLM-LlaMA-LoRA-7B",
}


@@ -21,7 +21,8 @@ def launch_server(args):
cmd += f"{lora_name}={lora_path} "
cmd += f"--disable-radix --disable-cuda-graph "
cmd += f"--max-loras-per-batch {args.max_loras_per_batch} "
cmd += f"--max-running-requests {args.max_running_requests}"
cmd += f"--max-running-requests {args.max_running_requests} "
cmd += f"--lora-backend {args.lora_backend}"
print(cmd)
os.system(cmd)

@@ -42,6 +43,11 @@ def launch_server(args):
type=int,
default=8,
)
parser.add_argument(
"--lora-backend",
type=str,
default="triton",
)
args = parser.parse_args()

launch_server(args)
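With the new `--lora-backend` argument in place, the same benchmark launcher can start the server with either backend. A minimal sketch of invoking it (assumes it is run from the repository root; every other value falls back to the script defaults):

```python
import subprocess

# Start the benchmark server with the Triton LoRA backend; all remaining
# arguments (model, adapters, request limits) come from the script defaults.
subprocess.run(
    [
        "python3",
        "benchmark/lora/launch_server.py",
        "--lora-backend", "triton",
        "--max-loras-per-batch", "4",
    ],
    check=True,
)
```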
5 changes: 5 additions & 0 deletions benchmark/lora/lora_bench.py
@@ -183,6 +183,7 @@ async def benchmark(
api_url=api_url,
prompt_len=test_prompt_len,
output_len=test_output_len,
lora_name="dummy", # the lora_name argument will not be used
extra_request_body=extra_request_body,
)
test_output = await request_func(request_func_input=test_input)
@@ -206,6 +207,7 @@ async def benchmark(
api_url=api_url,
prompt_len=prompt_len,
output_len=output_len,
lora_name="dummy",
extra_request_body=extra_request_body,
)
tasks.append(
@@ -255,6 +257,9 @@ async def benchmark(
"Output token throughput (tok/s):", metrics.output_throughput
)
)
print(
"{:<40} {:<10.2f}".format("Total throughput (tok/s):", metrics.total_throughput)
)
print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-"))
print(
"{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms)
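The new line reports a combined throughput next to the existing output-token throughput. The metric name suggests it counts prompt and completion tokens together over the benchmark duration; a hedged sketch of that calculation (the function and variable names here are illustrative, not taken from the diff):

```python
def total_throughput(total_input_tokens: int, total_output_tokens: int, duration_s: float) -> float:
    """Assumed definition: all prompt plus completion tokens processed per second."""
    return (total_input_tokens + total_output_tokens) / duration_s


# e.g. 120_000 input tokens + 30_000 output tokens over 60 s -> 2500.00 tok/s
print("{:<40} {:<10.2f}".format("Total throughput (tok/s):", total_throughput(120_000, 30_000, 60.0)))
```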
1 change: 1 addition & 0 deletions docs/backend/server_arguments.md
@@ -124,6 +124,7 @@ Please consult the documentation below to learn more about the parameters you may

* `lora_paths`: A list of adapters to apply to your model. Each batch element gets its model response with the corresponding LoRA adapter applied. Currently `cuda_graph` and `radix_attention` are not supported with this option, so you need to disable them manually. We are still working through these [issues](https://github.com/sgl-project/sglang/issues/2929).
* `max_loras_per_batch`: Maximum number of LoRAs in a running batch, including the base model.
* `lora_backend`: The backend that runs the GEMM kernels for LoRA modules; can be one of `triton` or `flashinfer`. Defaults to `triton`. A launch example is shown below.
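For example, a launch using these options might look like the following (a minimal sketch: the model, adapter name, and adapter path are placeholders, and the `name=path` form of `--lora-paths` follows the benchmark launcher in this PR):

```python
import os

# Illustrative launch only; adjust the model, adapter, and batch limits to your setup.
os.system(
    "python3 -m sglang.launch_server "
    "--model-path meta-llama/Llama-2-7b-hf "
    "--lora-paths lora0=winddude/wizardLM-LlaMA-LoRA-7B "
    "--disable-radix --disable-cuda-graph "
    "--max-loras-per-batch 4 "
    "--lora-backend triton"
)
```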

## Kernel backend

8 changes: 8 additions & 0 deletions python/sglang/srt/lora/backend/__init__.py
@@ -0,0 +1,8 @@
from .base_backend import BaseLoraBackend
from .flashinfer_backend import FlashInferLoraBackend
from .triton_backend import TritonLoraBackend

__all__ = [
    "BaseLoraBackend",
    "FlashInferLoraBackend",
    "TritonLoraBackend",
]
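The `--lora-backend` string presumably gets mapped to one of these classes by the LoRA manager; that dispatch is not part of this file, but a hedged sketch of what it could look like (the helper name is illustrative):

```python
from sglang.srt.lora.backend import BaseLoraBackend, FlashInferLoraBackend, TritonLoraBackend


def get_backend_from_name(name: str) -> type[BaseLoraBackend]:
    """Illustrative dispatch from the --lora-backend value to a backend class."""
    backends = {
        "triton": TritonLoraBackend,
        "flashinfer": FlashInferLoraBackend,
    }
    if name not in backends:
        raise ValueError(f"Invalid LoRA backend: {name}")
    return backends[name]
```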
95 changes: 95 additions & 0 deletions python/sglang/srt/lora/backend/base_backend.py
@@ -0,0 +1,95 @@
from typing import Tuple, Union

import torch

from sglang.srt.lora.lora import LoraBatchInfo


def get_fuse_output_scaling_add_from_name(name: str) -> bool:
mapping = {
"triton": True,
"flashinfer": False,
}
return mapping.get(name, False)


def get_fuse_qkv_lora_b_from_name(name: str) -> bool:
mapping = {
"triton": True,
"flashinfer": False,
}
return mapping.get(name, False)


class BaseLoraBackend:
"""Base class for different Lora backends.
Each backend has its own implementation of Lora kernels.
Args:
name: name of backend
batch_info: information of current batch for use
fuse_output_scaling_add: if set to True, the output buffer for storing result will be passed in when doing lora_b forward,
and the operation of scaling and adding will be fused into kernel
"""

def __init__(self, name: str, batch_info: LoraBatchInfo = None):
self.name = name
self.batch_info = batch_info
self.fuse_output_scaling_add = get_fuse_output_scaling_add_from_name(name)
self.fuse_qkv_lora_b = get_fuse_qkv_lora_b_from_name(name)

def run_lora_a_sgemm(
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
) -> torch.Tensor:
"""Run segment Gemm of lora a modules with current backend.
The definition of segment Gemm can be referred to https://docs.flashinfer.ai/api/gemm.html.
Args:
x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
weights: a set of lora weights with shape (num_lora, r, input_dim), here r is lora rank
usually input_dim is much larger than r
Returns:
result with shape (s, r)
"""
pass

def run_lora_b_sgemm(
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
) -> torch.Tensor:
"""Run segment Gemm of lora b modules with current backend.
The definition of segment Gemm can be referred to https://docs.flashinfer.ai/api/gemm.html.
Args:
x: input matrix with shape (s, r), here s is the sum of all sequence lengths, r is lora rank
weights: a set of lora weights with shape (num_lora, output_dim, r)
usually output_dim is much larger than r
Returns:
result with shape (s, output_dim)
"""
pass

def run_qkv_lora(
self,
x: torch.Tensor,
qkv_lora_a: torch.Tensor,
qkv_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]],
*args,
**kwargs
) -> torch.Tensor:
"""Run the lora pass for QKV Layer.
Args:
x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
qkv_lora_a: lora_a module for qkv, with shape (num_lora, 3 * r, input_dim)
qkv_lora_b: lora_b module for qkv.
If passed in as a tensor, its shape should be (num_lora,output_dim_q + 2 * output_dim_kv, r)
If passed in as a tuple of two tensors containing:
a lora_b module for q, with shape (1, num_lora, output_dim_q, r)
and a combined lora_b module for kv, with shape (2, num_lora, output_dim_kv, r)
Returns:
result with shape (s, output_dim_q + 2 * output_dim_kv)
"""
pass

def set_batch_info(self, batch_info: LoraBatchInfo):
self.batch_info = batch_info
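Taken together, a backend supplies the pieces of the usual LoRA update, `base_output + scaling * (x @ A^T) @ B^T`, evaluated segment-wise per adapter. A minimal sketch of how a caller could compose the two GEMM hooks for a single projection (the helper is illustrative, not part of this PR):

```python
import torch


def apply_lora_delta(
    backend: BaseLoraBackend,
    x: torch.Tensor,            # (s, input_dim), s = sum of all sequence lengths
    lora_a: torch.Tensor,       # (num_lora, r, input_dim)
    lora_b: torch.Tensor,       # (num_lora, output_dim, r)
    base_output: torch.Tensor,  # (s, output_dim) from the frozen base projection
    scaling: float = 1.0,
) -> torch.Tensor:
    # (s, r) intermediate from the lora_a segmented GEMM.
    a_out = backend.run_lora_a_sgemm(x, lora_a)
    if backend.fuse_output_scaling_add:
        # Backends such as Triton fuse the scaling and residual add into the lora_b kernel.
        return backend.run_lora_b_sgemm(a_out, lora_b, base_output=base_output, scaling=scaling)
    # Otherwise apply the scaling and residual add explicitly.
    return base_output + scaling * backend.run_lora_b_sgemm(a_out, lora_b)
```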
88 changes: 88 additions & 0 deletions python/sglang/srt/lora/backend/flashinfer_backend.py
@@ -0,0 +1,88 @@
from typing import Tuple

import torch
from flashinfer import SegmentGEMMWrapper

from sglang.srt.lora.backend import BaseLoraBackend
from sglang.srt.lora.lora import LoraBatchInfo


class FlashInferLoraBackend(BaseLoraBackend):

def __init__(self, name: str, batch_info: LoraBatchInfo = None):
super().__init__(name, batch_info)

        # Set up the segment GEMM wrapper from flashinfer.
        # FIXME: wait for the flashinfer segment GEMM update.
workspace_buffer = torch.empty(1 * 1024 * 1024, dtype=torch.int8, device="cuda")
self.segment_gemm = SegmentGEMMWrapper(workspace_buffer)

def run_lora_a_sgemm(
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
) -> torch.Tensor:

return self.segment_gemm.run(
x=x,
weights=weights,
batch_size=self.batch_info.bs,
weight_column_major=True,
seg_indptr=self.batch_info.seg_indptr,
weight_indices=self.batch_info.weight_indices,
)

def run_lora_b_sgemm(
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
) -> torch.Tensor:

return self.segment_gemm.run(
x=x,
weights=weights,
batch_size=self.batch_info.bs,
weight_column_major=True,
seg_indptr=self.batch_info.seg_indptr,
weight_indices=self.batch_info.weight_indices,
)

def run_qkv_lora(
self,
x: torch.Tensor,
qkv_lora_a: torch.Tensor,
qkv_lora_b: Tuple[torch.Tensor],
*args,
**kwargs,
) -> torch.Tensor:

# Shape of lora_a_output: (s, 3 * r)
lora_a_output = self.run_lora_a_sgemm(x=x, weights=qkv_lora_a)

q_lora_b, kv_lora_b = qkv_lora_b
lora_rank = kv_lora_b.shape[-1]
output_dim_q = q_lora_b.shape[-2]
output_dim_kv = kv_lora_b.shape[-2]
lora_output = torch.empty(
(x.shape[0], output_dim_q + 2 * output_dim_kv),
device=x.device,
dtype=x.dtype,
)

# q
lora_output[:, :output_dim_q] = self.run_lora_b_sgemm(
x=lora_a_output[:, :lora_rank].contiguous(), weights=q_lora_b[0]
)

# kv
lora_output[:, output_dim_q : output_dim_q + output_dim_kv] = (
self.run_lora_b_sgemm(
x=lora_a_output[:, lora_rank : 2 * lora_rank].contiguous(),
weights=kv_lora_b[0],
)
)

lora_output[
:, output_dim_q + output_dim_kv : output_dim_q + 2 * output_dim_kv
] = self.run_lora_b_sgemm(
x=lora_a_output[:, 2 * lora_rank : 3 * lora_rank].contiguous(),
weights=kv_lora_b[1],
)

return lora_output
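Both segmented GEMM calls above are driven by the `LoraBatchInfo` fields `bs`, `seg_indptr`, and `weight_indices`. A hedged sketch of what those fields could look like for a toy batch (the actual construction happens in the LoRA manager, outside this file):

```python
import torch

# Toy batch: three requests of 5, 2, and 7 tokens, mapped to adapter slots 0, 2, 1.
seq_lens = torch.tensor([5, 2, 7], dtype=torch.int32, device="cuda")
weight_indices = torch.tensor([0, 2, 1], dtype=torch.int32, device="cuda")
bs = seq_lens.numel()

# seg_indptr marks where each request's rows start in the flattened (s, dim) input,
# so it is [0, 5, 7, 14] here.
seg_indptr = torch.zeros(bs + 1, dtype=torch.int32, device="cuda")
seg_indptr[1:] = torch.cumsum(seq_lens, dim=0)
```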
61 changes: 61 additions & 0 deletions python/sglang/srt/lora/backend/triton_backend.py
@@ -0,0 +1,61 @@
import torch

from sglang.srt.lora.backend import BaseLoraBackend
from sglang.srt.lora.lora import LoraBatchInfo
from sglang.srt.lora.triton_ops import (
qkv_lora_b_fwd,
sgemm_lora_a_fwd,
sgemm_lora_b_fwd,
)


class TritonLoraBackend(BaseLoraBackend):

def __init__(self, name: str, batch_info: LoraBatchInfo = None):
super().__init__(name, batch_info)

def run_lora_a_sgemm(
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
) -> torch.Tensor:
return sgemm_lora_a_fwd(x, weights, self.batch_info)

def run_lora_b_sgemm(
self,
x: torch.Tensor,
weights: torch.Tensor,
base_output: torch.Tensor = None,
scaling: float = 1.0,
*args,
**kwargs
) -> torch.Tensor:
return sgemm_lora_b_fwd(x, weights, self.batch_info, base_output, scaling)

def run_qkv_lora(
self,
x: torch.Tensor,
qkv_lora_a: torch.Tensor,
qkv_lora_b: torch.Tensor,
output_offset: torch.Tensor,
max_qkv_out_dim: int,
base_output: torch.Tensor = None,
scaling: float = 1.0,
*args,
**kwargs
) -> torch.Tensor:

# x: (s, input_dim)
# qkv_lora_a: (num_lora, 3 * r, input_dim)
# qkv_lora_b: (num_lora, output_dim_q + 2 * output_dim_kv, r)
assert isinstance(qkv_lora_b, torch.Tensor)

lora_a_output = sgemm_lora_a_fwd(x, qkv_lora_a, self.batch_info)
lora_output = qkv_lora_b_fwd(
lora_a_output,
qkv_lora_b,
self.batch_info,
output_offset,
max_qkv_out_dim,
base_output,
scaling,
)
return lora_output
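Compared with the FlashInfer path, the fused QKV kernel takes two extra arguments, `output_offset` and `max_qkv_out_dim`. Their construction lives with the caller and in `qkv_lora_b_fwd`, neither of which is part of this diff; a hedged guess at what they hold:

```python
import torch

output_dim_q, output_dim_kv = 4096, 1024  # example projection sizes, not from the PR

# Presumably the column offsets of the q, k, and v slices inside the fused
# (s, output_dim_q + 2 * output_dim_kv) output buffer.
output_offset = torch.tensor(
    [0, output_dim_q, output_dim_q + output_dim_kv, output_dim_q + 2 * output_dim_kv],
    dtype=torch.int32,
    device="cuda",
)

# Presumably the widest single slice, used to bound the kernel's output tile size.
max_qkv_out_dim = max(output_dim_q, output_dim_kv)
```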