[Feature] Define backends and add Triton backend for Lora #3161

Conversation

@Fridge003 (Contributor) commented on Jan 27, 2025

Motivation

The current Lora modules rely on the SGemm kernels provided by FlashInfer to do the computation. However, FlashInfer is not well optimized for the tall and thin matrices of Lora modules. What's more, the way LoraManager manages the segment indices and weight indices of the input batch is inefficient. Together, these issues make Lora run slowly with SGLang.

Modifications

To improve the efficiency of Lora, this PR makes the following modifications on the basis of draft PR #1728:

  1. Define BaseLoraBackend, FlashInferLoraBackend and TritonLoraBackend classes, which disentangle the GEMM implementation of each backend from the forward logic of the Lora modules. A new server argument lora-backend is added to control the backend.
  2. Define a BatchInfo class that packs [bs, seg_lens, seg_indptr, max_len, weight_indices] together. By attaching it to the lora backend, it only needs to be set once per batch forward.
  3. Add Triton kernels that run GEMM more efficiently, including an sgemm kernel for lora_a (large K, small N), an sgemm kernel for lora_b (large N, small K), and a fused kernel for the qkv lora_b modules. A sketch of how these pieces fit together is shown after this list.
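
Below is a minimal sketch of how these pieces fit together. The class names and the BatchInfo fields are taken from the list above; the method names, signatures, and the pure-PyTorch reference of the segmented GEMM are illustrative assumptions, not the PR's actual code.

# Illustrative sketch; names beyond BaseLoraBackend, BatchInfo and its fields are assumptions.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class BatchInfo:
    """Per-batch segment metadata, set once per batch forward."""
    bs: int                        # number of segments (sequences) in the batch
    seg_lens: torch.Tensor         # [bs], token count of each segment
    seg_indptr: torch.Tensor       # [bs + 1], prefix sums of seg_lens over the token dimension
    max_len: int                   # longest segment, useful as a kernel launch bound
    weight_indices: torch.Tensor   # [bs], which lora adapter each segment uses


class BaseLoraBackend:
    """Disentangles the GEMM implementation from the forward logic of Lora modules."""

    def __init__(self, name: str):
        self.name = name
        self.batch_info: Optional[BatchInfo] = None

    def set_batch_info(self, batch_info: BatchInfo):
        self.batch_info = batch_info

    def run_lora_a_sgemm(self, x: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    def run_lora_b_sgemm(self, x: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError


def segmented_lora_gemm_reference(x, weights, info):
    """Pure-PyTorch reference of the segmented GEMM that the backends implement.

    x: [total_tokens, K], weights: [num_loras, N, K], returns [total_tokens, N].
    """
    out = x.new_zeros(x.shape[0], weights.shape[1])
    for i in range(info.bs):
        start, end = info.seg_indptr[i].item(), info.seg_indptr[i + 1].item()
        w = weights[info.weight_indices[i]]      # [N, K] weights of this segment's adapter
        out[start:end] = x[start:end] @ w.t()    # tall-and-thin GEMM per segment
    return out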

Usage

A new argument lora-backend is added to the server arguments. It can be either triton or flashinfer, indicating which backend to use. Its default value is triton.
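
For example, the server can be launched as follows (the model and adapter paths are illustrative and the rest of the command follows SGLang's usual launch interface; only --lora-backend is new in this PR):

python -m sglang.launch_server \
  --model-path meta-llama/Llama-2-7b-hf \
  --lora-paths winddude/wizardLM-LlaMA-LoRA-7B \
  --lora-backend triton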

Accuracy Test

The accuracy test can be run with:

python test/srt/models/test_lora_backend.py

The code passes the accuracy test on both H100 and A6000 machines.

Benchmarking results

To benchmark lora, run one of these commands to launch the server:

# Triton backend
python benchmark/lora/launch_server.py --max-loras-per-batch 4 --lora-backend triton

# Flashinfer backend
# python benchmark/lora/launch_server.py --max-loras-per-batch 4 --lora-backend flashinfer

# Base model without lora
# python benchmark/lora/launch_server.py --base-only

Then run this command to send test requests from the client:

python benchmark/lora/lora_bench.py

Benchmark configurations:

  • base model: meta-llama/Llama-2-7b-hf
  • lora adapter: winddude/wizardLM-LlaMA-LoRA-7B
  • GPU: Nvidia H100
  • maximum number of serving loras: 4
  • metric: total throughput
  • number of requests: 50
  • input length: uniform random distribution on [1, 1024]
  • output length: uniform random distribution on [1, 128]
Backend      Total Throughput (tok/s)   Mean E2E Latency (ms)
Triton       2040.96                    7165.67
Flashinfer   1606.97                    9270.38
No Lora      3090.54                    4776.19

Further Optimization

There are two main bottlenecks for Lora with the current Triton backend:

  • On prefill batches with long sequences, the lora process has to wait for prior non-lora kernels to complete, which takes a long time. I tried using multiple CUDA streams, but the synchronization overhead was much larger than the time saved (see the sketch after this list).
  • The overhead of Triton's compilation process, which can only be avoided by replacing Triton.
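
For reference, the multi-stream idea from the first bullet looks roughly like the sketch below (illustrative PyTorch, not the PR's code); the cross-stream wait_stream calls are the synchronization whose overhead outweighed the overlap savings.

# Illustrative sketch of the multi-stream overlap that was tried; not the PR's code.
import torch

lora_stream = torch.cuda.Stream()

def forward_with_side_stream(base_out, x, lora_a, lora_b):
    # The side stream must wait until x is ready on the default stream.
    lora_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(lora_stream):
        # Launch the lora GEMMs on a separate stream so they can, in principle,
        # overlap with the base model's remaining kernels on the default stream.
        lora_out = (x @ lora_a.t()) @ lora_b.t()
    # The default stream must wait for the lora result before the residual add;
    # this cross-stream synchronization is where the extra overhead comes from.
    torch.cuda.current_stream().wait_stream(lora_stream)
    return base_out + lora_out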

Autotuning brings little benefit, since sgemm on lora modules has low arithmetic intensity; the current kernels are already fast enough without it.

The best way to optimize the lora kernels further is to add a Cuda/Cutlass backend, which would save the Triton compilation time.
