Split sgemm for lora_a and lora_b
Fridge003 committed Jan 31, 2025
1 parent a61658f · commit 6a6dadd
Showing 9 changed files with 528 additions and 354 deletions.
1 change: 1 addition & 0 deletions python/sglang/srt/lora/backend/__init__.py
@@ -1,3 +1,4 @@
from .base_backend import BaseLoraBackend
from .flashinfer_backend import FlashInferLoraBackend
from .triton_backend import TritonLoraBackend

20 changes: 17 additions & 3 deletions python/sglang/srt/lora/backend/base_backend.py
@@ -16,13 +16,27 @@ def __init__(self, name: str, batch_info: LoraBatchInfo = None):
self.name = name
self.batch_info = batch_info

def run_sgemm(self, x: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
"""Run segment Gemm with current backend.
def run_lora_a_sgemm(self, x: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
"""Run segment Gemm of lora a modules with current backend.
The definition of segment Gemm can be referred to https://docs.flashinfer.ai/api/gemm.html.
Args:
x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
weights: a set of lora weights with shape (num_lora, output_dim, input_dim)
weights: a set of lora weights with shape (num_lora, r, input_dim), here r is lora rank
Usually input_dim is much larger than r
Returns:
result with shape (s, r)
"""
pass

def run_lora_b_sgemm(self, x: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
"""Run segment Gemm of lora b modules with current backend.
The definition of segment Gemm can be referred to https://docs.flashinfer.ai/api/gemm.html.
Args:
x: input matrix with shape (s, r), here s is the sum of all sequence lengths, r is lora rank
weights: a set of lora weights with shape (num_lora, output_dim, r)
Usually output_dim is much larger than r
Returns:
result with shape (s, output_dim)
"""
23 changes: 17 additions & 6 deletions python/sglang/srt/lora/backend/flashinfer_backend.py
@@ -1,7 +1,7 @@
import torch
from flashinfer import SegmentGEMMWrapper

from sglang.srt.lora.backend.base_backend import BaseLoraBackend
from sglang.srt.lora.backend import BaseLoraBackend
from sglang.srt.lora.lora import LoraBatchInfo


@@ -15,7 +15,18 @@ def __init__(self, name: str, batch_info: LoraBatchInfo = None):
workspace_buffer = torch.empty(1 * 1024 * 1024, dtype=torch.int8, device="cuda")
self.segment_gemm = SegmentGEMMWrapper(workspace_buffer)

def run_sgemm(self, x: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
def run_lora_a_sgemm(self, x: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:

return self.segment_gemm.run(
x=x,
weights=weights,
batch_size=self.batch_info.bs,
weight_column_major=True,
seg_indptr=self.batch_info.seg_indptr,
weight_indices=self.batch_info.weight_indices,
)

def run_lora_b_sgemm(self, x: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:

return self.segment_gemm.run(
x=x,
@@ -35,7 +46,7 @@ def run_qkv_lora(
) -> torch.Tensor:

# Shape of lora_a_output: (s, 3 * r)
lora_a_output = self.run_sgemm(x=x, weights=qkv_lora_a)
lora_a_output = self.run_lora_a_sgemm(x=x, weights=qkv_lora_a)

lora_rank = kv_lora_b.shape[-1]
output_dim_q = q_lora_b.shape[-2]
@@ -48,19 +59,19 @@

# FIXME parallelize qkv
# q
lora_output[:, :output_dim_q] = self.run_sgemm(
lora_output[:, :output_dim_q] = self.run_lora_b_sgemm(
x=lora_a_output[:, :lora_rank].contiguous(), weights=q_lora_b[0]
)

# kv
lora_output[:, output_dim_q : output_dim_q + output_dim_kv] = self.run_sgemm(
lora_output[:, output_dim_q : output_dim_q + output_dim_kv] = self.run_lora_b_sgemm(
x=lora_a_output[:, lora_rank : 2 * lora_rank].contiguous(),
weights=kv_lora_b[0],
)

lora_output[
:, output_dim_q + output_dim_kv : output_dim_q + 2 * output_dim_kv
] = self.run_sgemm(
] = self.run_lora_b_sgemm(
x=lora_a_output[:, 2 * lora_rank : 3 * lora_rank].contiguous(),
weights=kv_lora_b[1],
)
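As a usage sketch (not part of this commit), a single LoRA projection composes the two split hooks back to back; apply_lora, backend, lora_a, lora_b, base_output, and scaling are hypothetical names, and the scaled residual add follows the standard LoRA formulation rather than code shown in this diff.

import torch

def apply_lora(backend, base_output: torch.Tensor, x: torch.Tensor,
               lora_a: torch.Tensor, lora_b: torch.Tensor, scaling: float = 1.0) -> torch.Tensor:
    # lora_a: (num_lora, r, input_dim), lora_b: (num_lora, output_dim, r)
    lora_a_output = backend.run_lora_a_sgemm(x=x, weights=lora_a)             # (s, r)
    lora_output = backend.run_lora_b_sgemm(x=lora_a_output, weights=lora_b)   # (s, output_dim)
    return base_output + scaling * lora_output

The run_qkv_lora path above follows the same pattern, but runs the lora_a GEMM once on the fused qkv_lora_a weights and then three lora_b GEMMs over rank-sized slices of lora_a_output.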