Reduce data movement overhead
Fridge003 committed Feb 1, 2025
1 parent 33623e5 commit 90a5123
Showing 8 changed files with 204 additions and 83 deletions.
47 changes: 38 additions & 9 deletions python/sglang/srt/lora/backend/base_backend.py
@@ -1,42 +1,68 @@
from typing import Tuple, Union

import torch

from sglang.srt.lora.lora import LoraBatchInfo


def get_fuse_output_scaling_add_from_name(name: str) -> bool:
mapping = {
"triton": True,
"flashinfer": False,
}
return mapping.get(name, False)


def get_fuse_qkv_lora_b_from_name(name: str) -> bool:
mapping = {
"triton": True,
"flashinfer": False,
}
return mapping.get(name, False)


class BaseLoraBackend:
"""Base class for different Lora backends.
Each backend has its own implementation of Lora kernels.
Args:
name: name of backend
batch_info: information of current batch for use
fuse_output_scaling_add: if set to True, the output buffer for storing result will be passed in when doing lora_b forward,
and the operation of scaling and adding will be fused into kernel
"""

def __init__(self, name: str, batch_info: LoraBatchInfo = None):
self.name = name
self.batch_info = batch_info
self.fuse_output_scaling_add = get_fuse_output_scaling_add_from_name(name)
self.fuse_qkv_lora_b = get_fuse_qkv_lora_b_from_name(name)

def run_lora_a_sgemm(self, x: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
def run_lora_a_sgemm(
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
) -> torch.Tensor:
"""Run segment Gemm of lora a modules with current backend.
The definition of segment Gemm can be referred to https://docs.flashinfer.ai/api/gemm.html.
Args:
x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
weights: a set of lora weights with shape (num_lora, r, input_dim), here r is lora rank
Usually input_dim is much larger than r
usually input_dim is much larger than r
Returns:
result with shape (s, r)
"""
pass

def run_lora_b_sgemm(self, x: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
def run_lora_b_sgemm(
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
) -> torch.Tensor:
"""Run segment Gemm of lora b modules with current backend.
The definition of segment Gemm can be referred to https://docs.flashinfer.ai/api/gemm.html.
Args:
x: input matrix with shape (s, r), here s is the sum of all sequence lengths, r is lora rank
weights: a set of lora weights with shape (num_lora, output_dim, r)
Usually output_dim is much larger than r
usually output_dim is much larger than r
Returns:
result with shape (s, output_dim)
"""
@@ -46,17 +72,20 @@ def run_qkv_lora(
self,
x: torch.Tensor,
qkv_lora_a: torch.Tensor,
q_lora_b: torch.Tensor,
kv_lora_b: torch.Tensor,
qkv_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]],
*args,
**kwargs
) -> torch.Tensor:
"""Run the lora pass for QKV Layer.
Args:
x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
qkv_lora_a: lora_a module for qkv, with shape (num_lora, 3 * r, input_dim)
q_lora_b: lora_b module for q, with shape (1, num_lora, output_dim_q, r)
kv_lora_b: lora_b module for kv, with shape (2, num_lora, output_dim_kv, r)
qkv_lora_b: lora_b module for qkv.
If passed in as a tensor, its shape should be (num_lora,output_dim_q + 2 * output_dim_kv, r)
If passed in as a tuple of two tensors containing:
a lora_b module for q, with shape (1, num_lora, output_dim_q, r)
and a combined lora_b module for kv, with shape (2, num_lora, output_dim_kv, r)
Returns:
result with shape (s, output_dim_q + 2 * output_dim_kv)
"""
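
The base interface now takes qkv_lora_b either as one fused tensor of shape (num_lora, output_dim_q + 2 * output_dim_kv, r) or as a (q_lora_b, kv_lora_b) tuple, as described in the docstring above. A minimal sketch, not part of this commit, of how a caller could normalize the two forms; the helper name split_qkv_lora_b and the toy shapes are assumptions for illustration:

import torch

def split_qkv_lora_b(qkv_lora_b, output_dim_q, output_dim_kv):
    # Fused form: a single (num_lora, output_dim_q + 2 * output_dim_kv, r) tensor.
    if isinstance(qkv_lora_b, torch.Tensor):
        q_b = qkv_lora_b[:, :output_dim_q, :]
        k_b = qkv_lora_b[:, output_dim_q : output_dim_q + output_dim_kv, :]
        v_b = qkv_lora_b[:, output_dim_q + output_dim_kv :, :]
        return q_b, k_b, v_b
    # Tuple form: (q_lora_b, kv_lora_b) with shapes
    # (1, num_lora, output_dim_q, r) and (2, num_lora, output_dim_kv, r).
    q_lora_b, kv_lora_b = qkv_lora_b
    return q_lora_b[0], kv_lora_b[0], kv_lora_b[1]

# Toy dimensions, for illustration only.
num_lora, r, dq, dkv = 2, 8, 64, 16
fused = torch.randn(num_lora, dq + 2 * dkv, r)
q_b, k_b, v_b = split_qkv_lora_b(fused, dq, dkv)
assert q_b.shape == (num_lora, dq, r) and k_b.shape == (num_lora, dkv, r)
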
25 changes: 17 additions & 8 deletions python/sglang/srt/lora/backend/flashinfer_backend.py
@@ -1,3 +1,5 @@
from typing import Tuple

import torch
from flashinfer import SegmentGEMMWrapper

@@ -15,7 +17,9 @@ def __init__(self, name: str, batch_info: LoraBatchInfo = None):
workspace_buffer = torch.empty(1 * 1024 * 1024, dtype=torch.int8, device="cuda")
self.segment_gemm = SegmentGEMMWrapper(workspace_buffer)

def run_lora_a_sgemm(self, x: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
def run_lora_a_sgemm(
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
) -> torch.Tensor:

return self.segment_gemm.run(
x=x,
@@ -26,7 +30,9 @@ def run_lora_a_sgemm(self, x: torch.Tensor, weights: torch.Tensor) -> torch.Tens
weight_indices=self.batch_info.weight_indices,
)

def run_lora_b_sgemm(self, x: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
def run_lora_b_sgemm(
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
) -> torch.Tensor:

return self.segment_gemm.run(
x=x,
@@ -41,13 +47,15 @@ def run_qkv_lora(
self,
x: torch.Tensor,
qkv_lora_a: torch.Tensor,
q_lora_b: torch.Tensor,
kv_lora_b: torch.Tensor,
qkv_lora_b: Tuple[torch.Tensor],
*args,
**kwargs,
) -> torch.Tensor:

# Shape of lora_a_output: (s, 3 * r)
lora_a_output = self.run_lora_a_sgemm(x=x, weights=qkv_lora_a)

q_lora_b, kv_lora_b = qkv_lora_b
lora_rank = kv_lora_b.shape[-1]
output_dim_q = q_lora_b.shape[-2]
output_dim_kv = kv_lora_b.shape[-2]
@@ -57,16 +65,17 @@ def run_qkv_lora(
dtype=x.dtype,
)

# FIXME parallelize qkv
# q
lora_output[:, :output_dim_q] = self.run_lora_b_sgemm(
x=lora_a_output[:, :lora_rank].contiguous(), weights=q_lora_b[0]
)

# kv
lora_output[:, output_dim_q : output_dim_q + output_dim_kv] = self.run_lora_b_sgemm(
x=lora_a_output[:, lora_rank : 2 * lora_rank].contiguous(),
weights=kv_lora_b[0],
lora_output[:, output_dim_q : output_dim_q + output_dim_kv] = (
self.run_lora_b_sgemm(
x=lora_a_output[:, lora_rank : 2 * lora_rank].contiguous(),
weights=kv_lora_b[0],
)
)

lora_output[
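
The FlashInfer path above keeps the q and kv lora_b weights separate and writes each projection into its slice of a single (s, output_dim_q + 2 * output_dim_kv) buffer. A minimal sketch of that layout with toy shapes; the variable names are illustrative, not taken from the commit:

import torch

s, dq, dkv = 4, 64, 16          # toy sizes: s tokens, q/kv output dims
q_out = torch.randn(s, dq)      # stands in for the q lora_b result
k_out = torch.randn(s, dkv)     # stands in for the k lora_b result
v_out = torch.randn(s, dkv)     # stands in for the v lora_b result

lora_output = torch.empty(s, dq + 2 * dkv)
lora_output[:, :dq] = q_out
lora_output[:, dq : dq + dkv] = k_out
lora_output[:, dq + dkv :] = v_out
assert lora_output.shape == (s, dq + 2 * dkv)
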
39 changes: 29 additions & 10 deletions python/sglang/srt/lora/backend/triton_backend.py
@@ -1,6 +1,4 @@
import torch
import triton
import triton.language as tl

from sglang.srt.lora.backend import BaseLoraBackend
from sglang.srt.lora.lora import LoraBatchInfo
@@ -16,27 +14,48 @@ class TritonLoraBackend(BaseLoraBackend):
def __init__(self, name: str, batch_info: LoraBatchInfo = None):
super().__init__(name, batch_info)

def run_lora_a_sgemm(self, x: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
def run_lora_a_sgemm(
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
) -> torch.Tensor:
return sgemm_lora_a_fwd(x, weights, self.batch_info)

def run_lora_b_sgemm(self, x: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
return sgemm_lora_b_fwd(x, weights, self.batch_info)
def run_lora_b_sgemm(
self,
x: torch.Tensor,
weights: torch.Tensor,
base_output: torch.Tensor = None,
scaling: float = 1.0,
*args,
**kwargs
) -> torch.Tensor:
return sgemm_lora_b_fwd(x, weights, self.batch_info, base_output, scaling)

def run_qkv_lora(
self,
x: torch.Tensor,
qkv_lora_a: torch.Tensor,
q_lora_b: torch.Tensor,
kv_lora_b: torch.Tensor,
qkv_lora_b: torch.Tensor,
output_offset: torch.Tensor,
max_qkv_out_dim: int,
base_output: torch.Tensor = None,
scaling: float = 1.0,
*args,
**kwargs
) -> torch.Tensor:

# x: (s, input_dim)
# qkv_lora_a: (num_lora, 3 * r, input_dim)
# q_lora_b: (1, num_lora, output_dim_q, r)
# kv_lora_b: (2, num_lora, output_dim_kv, r)
# qkv_lora_b: (num_lora, output_dim_q + 2 * output_dim_kv, r)
assert isinstance(qkv_lora_b, torch.Tensor)

lora_a_output = sgemm_lora_a_fwd(x, qkv_lora_a, self.batch_info)
lora_output = qkv_lora_b_fwd(
lora_a_output, q_lora_b, kv_lora_b, self.batch_info
lora_a_output,
qkv_lora_b,
self.batch_info,
output_offset,
max_qkv_out_dim,
base_output,
scaling,
)
return lora_output
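
The Triton backend sets fuse_output_scaling_add, so base_output and scaling are passed into the kernel and the scale-and-add epilogue happens there; other backends leave it to the caller. A small sketch of the selection performed by the layers in lora.py; the function name apply_epilogue is illustrative only:

import torch

def apply_epilogue(lora_output, base_output, scaling, fuse_output_scaling_add):
    if fuse_output_scaling_add:
        # The kernel already added base_output and applied scaling.
        return lora_output
    # Otherwise do the scale-and-add outside the kernel.
    return base_output + lora_output * scaling

base_out = torch.randn(4, 32)
lora_out = torch.randn(4, 32)
out = apply_epilogue(lora_out, base_out, scaling=0.5, fuse_output_scaling_add=False)
assert out.shape == (4, 32)
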
64 changes: 54 additions & 10 deletions python/sglang/srt/lora/lora.py
@@ -153,17 +153,54 @@ def set_lora_info(
):
self.set_lora = True
self.A_buffer_qkv = A_buffer_qkv
self.B_buffer_q = B_buffer_q
self.B_buffer_kv = B_buffer_kv

if self.lora_backend.fuse_qkv_lora_b:
assert (
B_buffer_q.shape[-1] == B_buffer_kv.shape[-1]
), "The lora rank of q and kv should be the same when enabling fusion of qkv lora_b"
output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2]

# B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r)
self.B_buffer_qkv = torch.cat(
(B_buffer_q[0], B_buffer_kv[0], B_buffer_kv[1]), dim=-2
).contiguous()

# Offsets of q/k/v in output dimension
self.output_offset = torch.tensor(
[
0,
output_dim_q,
output_dim_q + output_dim_kv,
output_dim_q + 2 * output_dim_kv,
],
dtype=torch.int32,
device=B_buffer_q.device,
)
# For computing number of launched blocks
self.max_qkv_out_dim = max(output_dim_q, output_dim_kv)
else:
self.B_buffer_qkv = (
B_buffer_q,
B_buffer_kv,
)
self.output_offset = None
self.max_qkv_out_dim = None

def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
lora_output = self.lora_backend.run_qkv_lora(
x=x,
qkv_lora_a=self.A_buffer_qkv,
q_lora_b=self.B_buffer_q,
kv_lora_b=self.B_buffer_kv,
x,
self.A_buffer_qkv,
self.B_buffer_qkv,
output_offset=self.output_offset,
max_qkv_out_dim=self.max_qkv_out_dim,
base_output=base_output,
scaling=self.scaling,
)
return (
lora_output
if self.lora_backend.fuse_output_scaling_add
else base_output + lora_output * self.scaling
)
return base_output + lora_output * self.scaling


class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
@@ -178,11 +215,18 @@ def set_lora_info(self, A_buffer, B_buffer):
self.B_buffer = B_buffer

def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
lora_a_output = self.lora_backend.run_lora_a_sgemm(x=x, weights=self.A_buffer)
lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
lora_output = self.lora_backend.run_lora_b_sgemm(
x=lora_a_output, weights=self.B_buffer[0]
lora_a_output,
self.B_buffer[0],
base_output=base_output,
scaling=self.scaling,
)
return (
lora_output
if self.lora_backend.fuse_output_scaling_add
else base_output + lora_output * self.scaling
)
return base_output + lora_output * self.scaling

def forward(self, input_):
# duplicate the logic in RowParallelLinear
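
This is the core data-movement saving: the q and kv lora_b buffers are concatenated once in set_lora_info, so the fused Triton kernel can write q/k/v at precomputed offsets instead of copying or concatenating on every forward pass. A shape check with toy dimensions, for illustration only:

import torch

num_lora, r, dq, dkv = 2, 8, 64, 16   # toy sizes
B_buffer_q = torch.randn(1, num_lora, dq, r)
B_buffer_kv = torch.randn(2, num_lora, dkv, r)

# Fused lora_b buffer, built once per adapter load: (num_lora, dq + 2 * dkv, r)
B_buffer_qkv = torch.cat(
    (B_buffer_q[0], B_buffer_kv[0], B_buffer_kv[1]), dim=-2
).contiguous()
assert B_buffer_qkv.shape == (num_lora, dq + 2 * dkv, r)

# Offsets of the q/k/v slices in the fused output dimension.
output_offset = torch.tensor([0, dq, dq + dkv, dq + 2 * dkv], dtype=torch.int32)
max_qkv_out_dim = max(dq, dkv)  # used to size the kernel launch grid
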
4 changes: 3 additions & 1 deletion python/sglang/srt/lora/lora_manager.py
@@ -84,7 +84,9 @@ def get_backend_from_name(name):
if name in backend_mapping:
return backend_mapping[name]

raise Exception(f"No supported lora backend called {name}.")
raise Exception(
f"No supported lora backend called {name}. It should be one of {list(backend_mapping.keys())}"
)


def get_layer_id(name):
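
A self-contained sketch of the improved error path; the mapping contents and the unsupported name below are assumptions for illustration ("triton" and "flashinfer" are the names recognized elsewhere in this commit), not the actual backend_mapping in lora_manager.py:

def get_backend_from_name_sketch(name, backend_mapping):
    # Same lookup-and-error pattern as get_backend_from_name above.
    if name in backend_mapping:
        return backend_mapping[name]
    raise Exception(
        f"No supported lora backend called {name}. "
        f"It should be one of {list(backend_mapping.keys())}"
    )

backends = {"triton": "TritonLoraBackend", "flashinfer": "FlashInferLoraBackend"}
try:
    get_backend_from_name_sketch("pytorch", backends)
except Exception as err:
    print(err)  # the message now lists the supported backend names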