[Feature] Define backends and add Triton backend for Lora #3161
Conversation
@Ying1123 we don't have flashinfer yet on ROCm; I found this merge causes a break on AMD.
@HaiShaw AMD CIs are crucial for preventing such issues from a process perspective.
@HaiShaw Also, could you help fix the top of the main branch?
Yes, let me push on it!
```python
lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
lora_output = self.lora_backend.run_lora_b_sgemm(
    lora_a_output,
    self.B_buffer[0],
    base_output=base_output,
    scaling=self.scaling,
)
```
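For context on what these two calls compute, here is a dense single-adapter reference (a minimal sketch with assumed shapes; the real backends run segment GEMMs over a whole batch of adapters rather than a single dense matmul):

```python
import torch

def run_lora_a_sgemm_ref(x: torch.Tensor, a_buffer: torch.Tensor) -> torch.Tensor:
    # x: (num_tokens, input_dim), a_buffer: (rank, input_dim) -> (num_tokens, rank)
    return x @ a_buffer.T

def run_lora_b_sgemm_ref(
    lora_a_output: torch.Tensor,
    b_buffer: torch.Tensor,
    base_output: torch.Tensor,
    scaling: float,
) -> torch.Tensor:
    # b_buffer: (output_dim, rank); the scaled LoRA delta is accumulated
    # onto the base layer's output.
    return base_output + scaling * (lora_a_output @ b_buffer.T)
```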
Just curious, would there be a benefit in fusing these two ops?
Fusing two neighboring GEMMs would be really hard to implement, and the benefit is uncertain.
@Fridge003 Please check these.
I think it has been fixed by the AMD folks.
Motivation
The current LoRA modules rely on the SGEMM kernels provided by flashinfer to do the computation. However, flashinfer is not well optimized for the tall-and-thin matrices of LoRA modules. What's more, the way `LoraManager` manages the segment indices and weight indices of an input batch is inefficient. All of these issues make LoRA run slowly in SGLang.
Modifications
To improve the efficiency of LoRA, this PR makes the following modifications on the basis of PR draft #1728:

- Adds `BaseLoraBackend`, `FlashInferLoraBackend` and `TritonLoraBackend` classes, which decouple the GEMM implementation of each backend from the forward logic of the LoRA modules. A new server argument `lora-backend` is added for controlling the backend. A sketch of this abstraction follows the list.
- Adds a `BatchInfo` class that packs [`bs`, `seg_lens`, `seg_indptr`, `max_len`, `weight_indices`] together. By attaching it to the LoRA backend, it only needs to be set once per batch forward.
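A minimal sketch of this abstraction (class and field names are taken from the PR text; the exact signatures are assumptions inferred from the code excerpt in the review above):

```python
from dataclasses import dataclass
from typing import Optional

import torch

@dataclass
class BatchInfo:
    """Per-batch LoRA metadata, set once per batch forward."""
    bs: int                       # batch size
    seg_lens: torch.Tensor        # number of tokens in each request's segment
    seg_indptr: torch.Tensor      # prefix-sum offsets of segments into the token dim
    max_len: int                  # longest segment, useful for kernel grid sizing
    weight_indices: torch.Tensor  # which LoRA adapter each request maps to

class BaseLoraBackend:
    """Backend-agnostic interface: LoRA modules call these methods
    instead of a specific kernel library."""

    def __init__(self, name: str):
        self.name = name
        self.batch_info: Optional[BatchInfo] = None

    def set_batch_info(self, batch_info: BatchInfo) -> None:
        # Attach the batch metadata once, so individual modules don't
        # re-derive segment/weight indices on every call.
        self.batch_info = batch_info

    def run_lora_a_sgemm(self, x: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    def run_lora_b_sgemm(
        self,
        x: torch.Tensor,
        weights: torch.Tensor,
        base_output: Optional[torch.Tensor] = None,
        scaling: float = 1.0,
    ) -> torch.Tensor:
        raise NotImplementedError
```

Concrete `TritonLoraBackend` and `FlashInferLoraBackend` subclasses then implement the two GEMMs with their own kernels, and the manager calls `set_batch_info` once per forward instead of threading segment metadata through every LoRA module.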
Usage
A new argument `lora-backend` is added to the server arguments. It can be either `triton` or `flashinfer`, indicating the backend to be chosen; its default value is `triton`. An example invocation is shown below.
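(A hypothetical launch command; the model path and LoRA adapter path are placeholders, and all flags other than `--lora-backend` follow SGLang's usual CLI.)

```bash
python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-hf \
    --lora-paths lora0=/path/to/lora_adapter \
    --lora-backend triton
```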
Accuracy Test
The accuracy test can be run with:
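(The original command was not captured on this page; a plausible invocation, assuming the LoRA test lives at `test/srt/models/test_lora.py` in the repo, would be:)

```bash
# Hypothetical test path; adjust to wherever the LoRA tests live.
python3 test/srt/models/test_lora.py
```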
The code passes the accuracy test on both H100 and A6000 machines.
Benchmarking Result
To benchmark LoRA, run this command to launch the server:
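(The captured page omits the command; a hypothetical equivalent with placeholder paths is below. `--disable-radix-cache` is an assumption, since SGLang's LoRA serving has required the radix cache to be disabled.)

```bash
python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-hf \
    --lora-paths lora0=/path/to/lora_adapter \
    --lora-backend triton \
    --disable-radix-cache
```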
Then run this command to send test requests from the client:
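(Also hypothetical, using SGLang's serving benchmark script; the prompt count and request rate are placeholders.)

```bash
python3 -m sglang.bench_serving --backend sglang \
    --num-prompts 1000 --request-rate 8
```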
Benchmark configurations:
Further Optimization
There are two main bottlenecks of LoRA with the current Triton backend:

- Autotuning: the payoff of autotuning is poor, since SGEMM on LoRA modules has low arithmetic intensity; the current kernels without autotuning are already fast enough.
- Triton compilation time: the best way to optimize the LoRA kernels further is to add a CUDA/Cutlass backend, so that the Triton compilation time can be saved.
Checklist