[Feature] Define backends and add Triton backend for Lora #3161
Motivation
Current LoRA modules rely on the SGEMM kernels provided by FlashInfer to do the computation. However, FlashInfer is not well optimized for the tall and thin matrices of LoRA modules. What's more, the way `LoraManager` manages the segment indices and weight indices of an input batch is inefficient. All of these issues make LoRA run slowly with SGLang.

Modifications
To improve the efficiency of LoRA, this PR makes the following modifications on the basis of draft PR #1728:

- Adds `BaseLoraBackend`, `FlashInferLoraBackend` and `TritonLoraBackend` classes, which disentangle the GEMM implementation of each backend from the forward logic of the LoRA modules. A new server argument `lora-backend` is added for controlling the backend.
- Adds a `BatchInfo` class that packs `bs`, `seg_lens`, `seg_indptr`, `max_len` and `weight_indices` together. By attaching it to the LoRA backend, it only needs to be set once at every batch forward. (A minimal sketch of this structure is given after this list.)
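To make the shape of these changes concrete, here is a minimal sketch of how the backend classes and `BatchInfo` could fit together. The class and field names follow the description above; the method signatures, the dataclass, and all implementation details are assumptions for illustration, not the actual SGLang code.

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class BatchInfo:
    # Per-batch metadata, computed once per batch forward and shared by
    # every LoRA layer through the backend object.
    bs: int                       # batch size
    seg_lens: torch.Tensor        # token count of each request's segment
    seg_indptr: torch.Tensor      # prefix-sum offsets of the segments
    max_len: int                  # longest segment, used to size the kernel grid
    weight_indices: torch.Tensor  # which LoRA adapter each segment uses


class BaseLoraBackend:
    """Common interface; concrete backends supply the segmented GEMMs."""

    def __init__(self, name: str):
        self.name = name
        self.batch_info: Optional[BatchInfo] = None

    def set_batch_info(self, batch_info: BatchInfo) -> None:
        # Called once per batch so individual LoRA modules do not have to
        # recompute segment and weight indices in every layer.
        self.batch_info = batch_info

    def run_lora_a_sgemm(self, x: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    def run_lora_b_sgemm(self, x: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError


class TritonLoraBackend(BaseLoraBackend):
    def run_lora_a_sgemm(self, x: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
        # Would launch a Triton segmented-GEMM kernel driven by self.batch_info.
        raise NotImplementedError
```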
Usage

A new argument `lora-backend` is added to the server arguments. It can be either `triton` or `flashinfer`, indicating the backend to be used. Its default value is `triton`.
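For example, the backend can be selected when launching the server. The model and adapter paths below are placeholders; `--model-path` and `--lora-paths` are existing SGLang flags, and `--lora-backend` is the argument added by this PR:

```bash
python3 -m sglang.launch_server \
    --model-path <path-or-hf-id-of-base-model> \
    --lora-paths <path-to-lora-adapter> \
    --lora-backend triton
```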
Accuracy Test

Accuracy test can be run with the command sketched below.
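One plausible invocation, assuming the LoRA unit test lives at `test/srt/models/test_lora.py` (the path is an assumption and may differ in the repository):

```bash
python3 -m pytest test/srt/models/test_lora.py -s
```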
The code passes the accuracy test on both H100 and A6000 machines.
Benchmarking result
To benchmark LoRA, run this command to launch the server:
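A sketch of such a launch command; the model and adapter paths are placeholders, and the extra flags are assumptions about a typical LoRA serving setup rather than the exact configuration used in this PR:

```bash
python3 -m sglang.launch_server \
    --model-path <path-or-hf-id-of-base-model> \
    --lora-paths <path-to-lora-adapter> \
    --lora-backend triton \
    --disable-radix-cache
```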
Then run this command to send test requests from the client:
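A sketch of the client side, using SGLang's built-in serving benchmark; the dataset choice and length settings are placeholder assumptions:

```bash
python3 -m sglang.bench_serving \
    --backend sglang \
    --dataset-name random \
    --num-prompts 200 \
    --random-input-len 256 \
    --random-output-len 32
```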
Benchmark configurations:
Further Optimization
There are two main bottlenecks of LoRA with the current Triton backend:

- The payoff from autotuning is small, since the SGEMM on LoRA modules has low arithmetic intensity (see the back-of-the-envelope estimate after this list); the current kernels without autotuning are already fast enough.
- The best way to optimize the LoRA kernels further is to add a CUDA/Cutlass backend, so that the Triton compilation time can be saved.
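To illustrate why autotuning brings little benefit, here is a rough arithmetic-intensity estimate for a tall-and-thin LoRA GEMM. The shapes (4096 tokens, hidden size 4096, rank 16, fp16) and the quoted H100 ridge point are illustrative assumptions, not measurements from this PR:

```python
# Arithmetic intensity of the LoRA-A projection: (s x h) @ (h x r).
s, h, r = 4096, 4096, 16      # tokens, hidden size, LoRA rank (illustrative)
bytes_per_elem = 2            # fp16

flops = 2 * s * h * r                                   # multiply-adds
bytes_moved = bytes_per_elem * (s * h + h * r + s * r)  # read A, read B, write C

print(flops / bytes_moved)    # ~16 FLOP/byte, far below the roughly 300 FLOP/byte
                              # fp16 ridge point of an H100, so the kernel is
                              # memory-bandwidth bound and tile-size tuning helps little.
```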
Checklist