Merge branch 'vllm-project:main' into main
lk-chen authored Feb 6, 2025
2 parents 5254415 + 467a96a commit 6077919
Showing 157 changed files with 4,545 additions and 1,955 deletions.
@@ -0,0 +1,11 @@
# bash ./run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM -b "auto" -t 2
model_name: "nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM"
tasks:
- name: "gsm8k"
metrics:
- name: "exact_match,strict-match"
value: 0.6353
- name: "exact_match,flexible-extract"
value: 0.637
limit: null
num_fewshot: null
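
This new YAML records the GSM8K baseline scores produced by the `run-lm-eval-gsm-vllm-baseline.sh` command in the header comment. Below is a minimal sketch of how such a recorded baseline could be compared against a fresh run; the helper name and the 0.02 tolerance are illustrative assumptions, not part of this commit.

```python
# Minimal sketch (not part of this commit): check a recorded lm-eval baseline
# against freshly measured scores. The helper name and the 0.02 tolerance are
# illustrative assumptions.
import yaml

CONFIG = """
model_name: "nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM"
tasks:
- name: "gsm8k"
  metrics:
  - name: "exact_match,strict-match"
    value: 0.6353
  - name: "exact_match,flexible-extract"
    value: 0.637
"""

def check_baseline(cfg: dict, measured: dict, tol: float = 0.02) -> None:
    """Assert every recorded metric is within tol of the measured value."""
    for task in cfg["tasks"]:
        for metric in task["metrics"]:
            key = f"{task['name']}/{metric['name']}"
            expected, got = metric["value"], measured[key]
            assert abs(got - expected) <= tol, f"{key}: {got} vs expected {expected}"

check_baseline(
    yaml.safe_load(CONFIG),
    {"gsm8k/exact_match,strict-match": 0.64,
     "gsm8k/exact_match,flexible-extract": 0.64},
)
```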
2 changes: 2 additions & 0 deletions .buildkite/test-pipeline.yaml
@@ -128,6 +128,7 @@ steps:
- tests/spec_decode/e2e/test_integration_dist_tp4
- tests/compile
- examples/offline_inference/rlhf.py
- examples/offline_inference/ray_placement.py
commands:
- pytest -v -s distributed/test_utils.py
- pytest -v -s compile/test_basic_correctness.py
@@ -136,6 +137,7 @@
# TODO: create a dedicated test section for multi-GPU example tests
# when we have multiple distributed example tests
- python3 ../examples/offline_inference/rlhf.py
- RAY_DEDUP_LOGS=0 python3 ../examples/offline_inference/ray_placement.py

- label: Metrics, Tracing Test # 10min
num_gpus: 2
8 changes: 6 additions & 2 deletions .github/workflows/reminder_comment.yml
@@ -2,7 +2,6 @@ name: PR Reminder Comment Bot
on:
pull_request_target:
types: [opened]

jobs:
pr_reminder:
runs-on: ubuntu-latest
@@ -15,7 +14,12 @@ jobs:
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: context.issue.number,
body: '👋 Hi! Thank you for contributing to the vLLM project.\n Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run `fastcheck` CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your `fastcheck` build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping `simon-mo` or `khluu` to add you in our Buildkite org. \n\nOnce the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.\n\n To run CI, PR reviewers can do one of these:\n- Add `ready` label to the PR\n- Enable auto-merge.\n\n🚀'
body: '👋 Hi! Thank you for contributing to the vLLM project.\n\n' +
'💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.\n\n' +
'Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run `fastcheck` CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your `fastcheck` build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping `simon-mo` or `khluu` to add you in our Buildkite org.\n\n' +
'Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.\n\n' +
'To run CI, PR reviewers can either: Add `ready` label to the PR or enable auto-merge.\n\n' +
'🚀'
})
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
2 changes: 1 addition & 1 deletion README.md
@@ -17,7 +17,7 @@ Easy, fast, and cheap LLM serving for everyone

*Latest News* 🔥
- [2025/01] We are excited to announce the alpha release of vLLM V1: A major architectural upgrade with 1.7x speedup! Clean code, optimized execution loop, zero-overhead prefix caching, enhanced multimodal support, and more. Please check out our blog post [here](https://blog.vllm.ai/2025/01/27/v1-alpha-release.html).
- [2025/01] We hosted [the eighth vLLM meetup](https://lu.ma/zep56hui) with Google Cloud! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1epVkt4Zu8Jz_S5OhEHPc798emsYh2BwYfRuDDVEF7u4/edit?usp=sharing).
- [2025/01] We hosted [the eighth vLLM meetup](https://lu.ma/zep56hui) with Google Cloud! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1epVkt4Zu8Jz_S5OhEHPc798emsYh2BwYfRuDDVEF7u4/edit?usp=sharing), and Google Cloud team [here](https://drive.google.com/file/d/1h24pHewANyRL11xy5dXUbvRC9F9Kkjix/view?usp=sharing).
- [2024/12] vLLM joins [pytorch ecosystem](https://pytorch.org/blog/vllm-joins-pytorch)! Easy, Fast, and Cheap LLM Serving for Everyone!
- [2024/11] We hosted [the seventh vLLM meetup](https://lu.ma/h0qvrajz) with Snowflake! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1e3CxQBV3JsfGp30SwyvS3eM_tW-ghOhJ9PAJGK6KR54/edit?usp=sharing), and Snowflake team [here](https://docs.google.com/presentation/d/1qF3RkDAbOULwz9WK5TOltt2fE9t6uIc_hVNLFAaQX6A/edit?usp=sharing).
- [2024/10] We have just created a developer slack ([slack.vllm.ai](https://slack.vllm.ai)) focusing on coordinating contributions and discussing features. Please feel free to join us there!
3 changes: 3 additions & 0 deletions csrc/cache.h
@@ -15,6 +15,9 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
std::vector<torch::Tensor> const& value_caches,
const torch::Tensor& block_mapping);

void copy_blocks_mla(std::vector<torch::Tensor> const& kv_caches,
const torch::Tensor& block_mapping);

void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache, torch::Tensor& value_cache,
torch::Tensor& slot_mapping,
82 changes: 70 additions & 12 deletions csrc/cache_kernels.cu
@@ -46,7 +46,10 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
char* src_ptr = static_cast<char*>(src.data_ptr());
char* dst_ptr = static_cast<char*>(dst.data_ptr());

const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
// We use the stride instead of numel in case the cache is padded for memory
// alignment reasons, we assume the blocks data (inclusive of any padding)
// is contiguous in memory
const int64_t block_size_in_bytes = src.element_size() * src.stride(0);
const at::cuda::OptionalCUDAGuard device_guard(
src_device.is_cuda() ? src_device : dst_device);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
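
The change above computes `block_size_in_bytes` from `src.stride(0)` rather than `src[0].numel()`, so any alignment padding inside a block is included when blocks are swapped between devices. A small PyTorch illustration of the difference, with made-up shapes:

```python
# Illustrative sketch: why stride(0), not a single block's numel(), gives the
# per-block footprint when the cache is padded. All shapes are made up.
import torch

num_blocks, block_size, entry, padded_entry = 4, 16, 576, 640

storage = torch.zeros(num_blocks, block_size, padded_entry)
cache = storage[..., :entry]      # logical view of a physically padded cache

print(cache[0].numel())           # 16 * 576 = 9216  (real data only)
print(cache.stride(0))            # 16 * 640 = 10240 (distance to the next block)

# element_size() * stride(0) therefore covers a whole padded block, keeping the
# byte offsets of src and dst blocks aligned; element_size() * numel() would
# compute the wrong per-block offsets once the cache is padded.
```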
@@ -93,6 +96,24 @@ __global__ void copy_blocks_kernel(int64_t* key_cache_ptrs,
}
}

// Kernel for MLA, which works on a single joint kv_cache
// Grid: (num_layers, num_pairs)
template <typename scalar_t>
__global__ void copy_blocks_mla_kernel(
int64_t* cache_ptrs, const int64_t* __restrict__ block_mapping,
const int mem_footprint_per_block) {
const int layer_idx = blockIdx.x;
const int pair_idx = blockIdx.y;
scalar_t* cache = reinterpret_cast<scalar_t*>(cache_ptrs[layer_idx]);
int64_t src_block = block_mapping[2 * pair_idx];
int64_t dst_block = block_mapping[2 * pair_idx + 1];
int64_t src_offset = src_block * mem_footprint_per_block;
int64_t dst_offset = dst_block * mem_footprint_per_block;
for (int i = threadIdx.x; i < mem_footprint_per_block; i += blockDim.x) {
cache[dst_offset + i] = cache[src_offset + i];
}
}

} // namespace vllm

// Note: the key_caches and value_caches vectors are constant but
@@ -147,6 +168,42 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
}));
}

// copy blocks kernel for MLA (assumes a joint KV-cache)
void copy_blocks_mla(std::vector<torch::Tensor> const& kv_caches,
const torch::Tensor& block_mapping) {
int num_layers = kv_caches.size();
if (num_layers == 0) {
return;
}
torch::Device cache_device = kv_caches[0].device();
TORCH_CHECK(cache_device.is_cuda(), "kv_cache must be on CUDA");

std::vector<int64_t> cache_ptrs(num_layers);
for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
cache_ptrs[layer_idx] =
reinterpret_cast<int64_t>(kv_caches[layer_idx].data_ptr());
}
torch::Tensor cache_ptrs_tensor =
torch::from_blob(cache_ptrs.data(), {num_layers}, torch::kInt64)
.to(cache_device);

int num_pairs = block_mapping.size(0);
// We use the stride instead of numel in case the cache is padded for memory
// alignment reasons, we assume the blocks data (inclusive of any padding)
// is contiguous in memory
int mem_footprint_per_block = kv_caches[0].stride(0);
dim3 grid(num_layers, num_pairs);
dim3 block(std::min(1024, mem_footprint_per_block));
const at::cuda::OptionalCUDAGuard device_guard(cache_device);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
kv_caches[0].scalar_type(), "copy_blocks_mla_kernel", ([&] {
vllm::copy_blocks_mla_kernel<scalar_t><<<grid, block, 0, stream>>>(
cache_ptrs_tensor.data_ptr<int64_t>(),
block_mapping.data_ptr<int64_t>(), mem_footprint_per_block);
}));
}

namespace vllm {

template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
@@ -254,6 +311,7 @@ __global__ void concat_and_cache_mla_kernel(
// + pe_dim)]
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int block_stride, //
const int entry_stride, //
const int kv_c_stride, //
const int k_pe_stride, //
const int kv_lora_rank, //
@@ -274,9 +332,8 @@ __global__ void concat_and_cache_mla_kernel(
int src_stride, int dst_stride, int size, int offset) {
for (int i = threadIdx.x; i < size; i += blockDim.x) {
const int64_t src_idx = token_idx * src_stride + i;
const int64_t dst_idx = block_idx * block_stride +
block_offset * (kv_lora_rank + pe_dim) + i +
offset;
const int64_t dst_idx =
block_idx * block_stride + block_offset * entry_stride + i + offset;
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
dst[dst_idx] = src[src_idx];
} else {
@@ -391,14 +448,14 @@ void reshape_and_cache_flash(
// KV_T is the stored data type of kv-cache.
// CACHE_T is the data type of key and value tensors.
// KV_DTYPE is the real data type of kv-cache.
#define CALL_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE) \
vllm::concat_and_cache_mla_kernel<KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(kv_c.data_ptr()), \
reinterpret_cast<KV_T*>(k_pe.data_ptr()), \
reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), block_stride, kv_c_stride, \
k_pe_stride, kv_lora_rank, pe_dim, block_size, \
#define CALL_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE) \
vllm::concat_and_cache_mla_kernel<KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(kv_c.data_ptr()), \
reinterpret_cast<KV_T*>(k_pe.data_ptr()), \
reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), block_stride, entry_stride, \
kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \
reinterpret_cast<const float*>(scale.data_ptr()));

void concat_and_cache_mla(
@@ -428,6 +485,7 @@ void concat_and_cache_mla(
int kv_c_stride = kv_c.stride(0);
int k_pe_stride = k_pe.stride(0);
int block_stride = kv_cache.stride(0);
int entry_stride = kv_cache.stride(1);

dim3 grid(num_tokens);
dim3 block(std::min(kv_lora_rank, 512));
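
For reference, a rough PyTorch sketch of what the new `copy_blocks_mla` op does with the joint MLA cache: one tensor per layer and a `[num_pairs, 2]` tensor of `(src, dst)` block indices. The shapes are illustrative; this mirrors the kernel's behavior rather than reproducing it.

```python
# Illustrative PyTorch equivalent of copy_blocks_mla (not from this commit):
# for each layer's joint kv_cache, copy whole blocks according to the
# (src, dst) pairs in block_mapping.
import torch

def copy_blocks_mla_ref(kv_caches: list, block_mapping: torch.Tensor) -> None:
    # kv_caches: one [num_blocks, block_size, entry_size] tensor per layer
    # block_mapping: int64 tensor of shape [num_pairs, 2]
    for cache in kv_caches:                      # kernel grid.x = num_layers
        for src, dst in block_mapping.tolist():  # kernel grid.y = num_pairs
            cache[dst].copy_(cache[src])         # threads stride over the block footprint

caches = [torch.randn(8, 16, 576) for _ in range(2)]
mapping = torch.tensor([[0, 3], [1, 2]], dtype=torch.int64)
copy_blocks_mla_ref(caches, mapping)
```

The `concat_and_cache_mla` change in the same file follows the same principle: the destination index now uses `entry_stride = kv_cache.stride(1)` instead of assuming a dense entry width of `kv_lora_rank + pe_dim`, so a padded cache layout is indexed correctly.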
4 changes: 2 additions & 2 deletions csrc/moe/moe_align_sum_kernels.cu
@@ -207,8 +207,8 @@ __global__ void sgl_moe_align_block_size_kernel(
__shared__ int32_t shared_counts[32][8];
__shared__ int32_t local_offsets[256];

const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
const int warp_id = threadIdx.x / 32;
const int lane_id = threadIdx.x % 32;
const int experts_per_warp = 8;
const int my_expert_start = warp_id * experts_per_warp;

103 changes: 89 additions & 14 deletions csrc/pos_encoding_kernels.cu
@@ -124,18 +124,54 @@ __global__ void batched_rotary_embedding_kernel(
void rotary_embedding(
torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens]
torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or
// [num_tokens, num_heads * head_size]
// [num_tokens, num_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size]
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or
// [num_tokens, num_kv_heads * head_size]
// [num_tokens, num_kv_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size]
int64_t head_size,
torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
bool is_neox) {
int64_t num_tokens = query.numel() / query.size(-1);
// num_tokens = batch_size * seq_len
int64_t num_tokens = positions.numel();
int positions_ndim = positions.dim();

// Make sure num_tokens dim is consistent across positions, query, and key.
TORCH_CHECK(
positions_ndim == 1 || positions_ndim == 2,
"positions must have shape [num_tokens] or [batch_size, seq_len]");
if (positions_ndim == 1) {
TORCH_CHECK(
query.size(0) == positions.size(0) && key.size(0) == positions.size(0),
"query, key and positions must have the same number of tokens");
}
if (positions_ndim == 2) {
TORCH_CHECK(
query.size(0) == positions.size(0) &&
key.size(0) == positions.size(0) &&
query.size(1) == positions.size(1) &&
key.size(1) == positions.size(1),
"query, key and positions must have the same batch_size and seq_len");
}

// Make sure head_size is valid for query and key
// hidden_size = num_heads * head_size
int query_hidden_size = query.numel() / num_tokens;
int key_hidden_size = key.numel() / num_tokens;
TORCH_CHECK(query_hidden_size % head_size == 0);
TORCH_CHECK(key_hidden_size % head_size == 0);

// Make sure query and key have consistent number of heads
int num_heads = query_hidden_size / head_size;
int num_kv_heads = key_hidden_size / head_size;
TORCH_CHECK(num_heads % num_kv_heads == 0);

int rot_dim = cos_sin_cache.size(1);
int num_heads = query.size(-1) / head_size;
int num_kv_heads = key.size(-1) / head_size;
int64_t query_stride = query.stride(-2);
int64_t key_stride = key.stride(-2);
int seq_dim_idx = positions_ndim - 1;
int64_t query_stride = query.stride(seq_dim_idx);
int64_t key_stride = key.stride(seq_dim_idx);

dim3 grid(num_tokens);
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
@@ -165,19 +201,58 @@ and process in batched manner.
void batched_rotary_embedding(
torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens]
torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or
// [num_tokens, num_heads * head_size]
// [num_tokens, num_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size]
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or
// [num_tokens, num_kv_heads * head_size]
// [num_tokens, num_kv_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size]
int64_t head_size,
torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
bool is_neox, int64_t rot_dim,
torch::Tensor& cos_sin_cache_offsets // [num_tokens]
torch::Tensor& cos_sin_cache_offsets // [num_tokens] or [batch_size]
) {
// num_tokens = batch_size * seq_len
int64_t num_tokens = cos_sin_cache_offsets.size(0);
int num_heads = query.size(-1) / head_size;
int num_kv_heads = key.size(-1) / head_size;
int64_t query_stride = query.stride(-2);
int64_t key_stride = key.stride(-2);
TORCH_CHECK(
positions.size(0) == num_tokens || positions.numel() == num_tokens,
"positions must have the same num_tokens or batch_size as "
"cos_sin_cache_offsets");

int positions_ndim = positions.dim();
// Make sure num_tokens dim is consistent across positions, query, and key.
TORCH_CHECK(
positions_ndim == 1 || positions_ndim == 2,
"positions must have shape [num_tokens] or [batch_size, seq_len]");
if (positions_ndim == 1) {
TORCH_CHECK(
query.size(0) == positions.size(0) && key.size(0) == positions.size(0),
"query, key and positions must have the same number of tokens");
}
if (positions_ndim == 2) {
TORCH_CHECK(
query.size(0) == positions.size(0) &&
key.size(0) == positions.size(0) &&
query.size(1) == positions.size(1) &&
key.size(1) == positions.size(1),
"query, key and positions must have the same batch_size and seq_len");
}

// Make sure head_size is valid for query and key
int query_hidden_size = query.numel() / num_tokens;
int key_hidden_size = key.numel() / num_tokens;
TORCH_CHECK(query_hidden_size % head_size == 0);
TORCH_CHECK(key_hidden_size % head_size == 0);

// Make sure query and key have consistent number of heads
int num_heads = query_hidden_size / head_size;
int num_kv_heads = key_hidden_size / head_size;
TORCH_CHECK(num_heads % num_kv_heads == 0);

int seq_dim_idx = positions_ndim - 1;
int64_t query_stride = query.stride(seq_dim_idx);
int64_t key_stride = key.stride(seq_dim_idx);

dim3 grid(num_tokens);
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
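
Both `rotary_embedding` and `batched_rotary_embedding` now accept flattened `[num_tokens, num_heads * head_size]` and unflattened `[num_tokens, num_heads, head_size]` query/key layouts, deriving `num_tokens` from `positions` and the token stride from the sequence dimension. A sketch of the same bookkeeping in PyTorch, with made-up shapes and not taken from this commit:

```python
# Illustrative sketch of the shape bookkeeping the new checks perform; it
# mirrors the C++ logic with made-up sizes, it is not the binding itself.
import torch

def rope_shape_info(positions, query, key, head_size):
    num_tokens = positions.numel()              # batch_size * seq_len
    assert positions.dim() in (1, 2)

    query_hidden = query.numel() // num_tokens  # works for flattened or per-head layouts
    key_hidden = key.numel() // num_tokens
    assert query_hidden % head_size == 0 and key_hidden % head_size == 0

    num_heads = query_hidden // head_size
    num_kv_heads = key_hidden // head_size
    assert num_heads % num_kv_heads == 0        # GQA/MQA-compatible

    seq_dim = positions.dim() - 1               # token stride comes from the seq dim,
    return (num_heads, num_kv_heads,            # not from the last dim as before
            query.stride(seq_dim), key.stride(seq_dim))

# Unflattened layout: [num_tokens, num_heads, head_size]
q, k = torch.randn(10, 32, 128), torch.randn(10, 8, 128)
print(rope_shape_info(torch.arange(10), q, k, head_size=128))  # (32, 8, 4096, 1024)
```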
4 changes: 4 additions & 0 deletions csrc/torch_bindings.cpp
@@ -450,6 +450,10 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
"Tensor block_mapping) -> ()");
cache_ops.impl("copy_blocks", torch::kCUDA, &copy_blocks);

cache_ops.def(
"copy_blocks_mla(Tensor(a!)[] kv_caches, Tensor block_mapping) -> ()");
cache_ops.impl("copy_blocks_mla", torch::kCUDA, &copy_blocks_mla);

// Reshape the key and value tensors and cache them.
cache_ops.def(
"reshape_and_cache(Tensor key, Tensor value,"
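
Once registered, the new `copy_blocks_mla` op can be invoked from Python through the dispatcher. A hedged sketch, assuming the cache ops are exposed under `torch.ops._C_cache_ops` (as the existing `copy_blocks` binding is) and that a CUDA build of the extension is importable:

```python
# Hedged sketch: calling the newly registered op from Python. It assumes the
# cache ops are exposed as torch.ops._C_cache_ops (as the existing copy_blocks
# binding is in vLLM) and that a CUDA build of the extension can be imported.
import torch
import vllm._custom_ops  # noqa: F401  (loads the compiled extension)

kv_caches = [torch.zeros(8, 16, 576, device="cuda") for _ in range(2)]
block_mapping = torch.tensor([[0, 3], [1, 2]], dtype=torch.int64, device="cuda")

torch.ops._C_cache_ops.copy_blocks_mla(kv_caches, block_mapping)
```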
1 change: 0 additions & 1 deletion docs/source/conf.py
@@ -37,7 +37,6 @@
# ones.
extensions = [
"sphinx.ext.napoleon",
"sphinx.ext.viewcode",
"sphinx.ext.linkcode",
"sphinx.ext.intersphinx",
"sphinx_copybutton",
6 changes: 0 additions & 6 deletions docs/source/features/quantization/auto_awq.md
@@ -2,12 +2,6 @@

# AutoAWQ

:::{warning}
Please note that AWQ support in vLLM is under-optimized at the moment. We would recommend using the unquantized version of the model for better
accuracy and higher throughput. Currently, you can use AWQ as a way to reduce memory footprint. As of now, it is more suitable for low latency
inference with small number of concurrent requests. vLLM's AWQ implementation have lower throughput than unquantized version.
:::

To create a new 4-bit quantized model, you can leverage [AutoAWQ](https://github.com/casper-hansen/AutoAWQ).
Quantizing reduces the model's precision from FP16 to INT4 which effectively reduces the file size by ~70%.
The main benefits are lower latency and memory usage.
2 changes: 1 addition & 1 deletion docs/source/getting_started/installation/cpu/build.inc.md
@@ -10,7 +10,7 @@ Second, install Python packages for vLLM CPU backend building:

```console
pip install --upgrade pip
pip install cmake>=3.26 wheel packaging ninja "setuptools-scm>=8" numpy
pip install "cmake>=3.26" wheel packaging ninja "setuptools-scm>=8" numpy
pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu
```
