Add Encoder-decoder model support and T5 Model support #3117

Closed · wants to merge 17 commits
4 changes: 4 additions & 0 deletions .clang-format
@@ -0,0 +1,4 @@
# Use the Google style in this project.
BasedOnStyle: Google

ColumnLimit: 120
16 changes: 12 additions & 4 deletions benchmarks/benchmark_throughput.py
@@ -6,8 +6,8 @@
from typing import List, Optional, Tuple

import torch
from transformers import (AutoModelForCausalLM, AutoTokenizer,
PreTrainedTokenizerBase)
from transformers import (AutoModelForCausalLM, T5ForConditionalGeneration,
AutoTokenizer, PreTrainedTokenizerBase)
from tqdm import tqdm


@@ -123,8 +123,16 @@ def run_hf(
    trust_remote_code: bool,
) -> float:
    assert not use_beam_search
    llm = AutoModelForCausalLM.from_pretrained(
        model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
    if "t5" in model:
        llm = T5ForConditionalGeneration.from_pretrained(
            model,
            torch_dtype=torch.float16,
            trust_remote_code=trust_remote_code)
    else:
        llm = AutoModelForCausalLM.from_pretrained(
            model,
            torch_dtype=torch.float16,
            trust_remote_code=trust_remote_code)
    if llm.config.model_type == "llama":
        # To enable padding in the HF backend.
        tokenizer.pad_token = tokenizer.eos_token
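For context, here is a minimal standalone sketch of the encoder-decoder generation path that the new `t5` branch above exercises. The model id, prompt, and generation length are illustrative choices and not part of this diff; it assumes a CUDA device and access to the Hugging Face Hub.

```python
# Sketch of the HF path taken by run_hf when "t5" appears in the model name.
import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("t5-small")
llm = T5ForConditionalGeneration.from_pretrained(
    "t5-small", torch_dtype=torch.float16).cuda()

input_ids = tokenizer(
    ["translate English to German: The house is wonderful."],
    return_tensors="pt", padding=True).input_ids.cuda()

# Unlike the decoder-only AutoModelForCausalLM path, T5 encodes the prompt
# once and then decodes autoregressively from the decoder start token.
output_ids = llm.generate(input_ids, max_new_tokens=32)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))
```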
62 changes: 42 additions & 20 deletions csrc/attention/attention_kernels.cu
@@ -1,5 +1,6 @@
/*
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* Adapted from
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
@@ -19,31 +20,32 @@
#include <hip/hip_runtime.h>
#endif

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <stdio.h>
#include <torch/extension.h>

#include <algorithm>

#include "attention_dtypes.h"
#include "attention_utils.cuh"
#ifdef ENABLE_FP8_E5M2
#include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh"
#endif

#include <algorithm>

#ifndef USE_ROCM
#define WARP_SIZE 32
#else
#define WARP_SIZE warpSize
#endif
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b)-1) / (b))

namespace vllm {

// Utility function for attention softmax.
template<int NUM_WARPS>
template <int NUM_WARPS>
inline __device__ float block_sum(float* red_smem, float sum) {
// Decompose the thread index into warp / lane.
int warp = threadIdx.x / WARP_SIZE;
@@ -101,6 +103,7 @@ __device__ void paged_attention_kernel(
const int* __restrict__ context_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const float* __restrict__ custom_bias, // [num_seqs, num_heads, 1, max_seq_len]
const int q_stride,
const int kv_block_stride,
const int kv_head_stride) {
@@ -141,6 +144,10 @@
const int num_queries_per_kv = num_heads / num_kv_heads;
const int kv_head_idx = head_idx / num_queries_per_kv;
const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
const float* custom_bias_vec = custom_bias == nullptr
? nullptr
: custom_bias + seq_idx * num_kv_heads * num_context_blocks * BLOCK_SIZE +
kv_head_idx * num_context_blocks * BLOCK_SIZE;

// A vector type to store a part of a key or a query.
// The vector size is configured in such a way that the threads in a thread group
@@ -232,8 +239,10 @@
// Compute dot product.
// This includes a reduction across the threads in the same thread group.
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
// Add the ALiBi bias if slopes are given.
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;
// Add the custom or ALiBi bias if given.
qk += (custom_bias_vec != nullptr) ? custom_bias_vec[token_idx]
: (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1)
: 0;

if (thread_group_offset == 0) {
// Store the partial reductions to shared memory.
@@ -443,13 +452,14 @@ __global__ void paged_attention_v1_kernel(
const int* __restrict__ context_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const float* __restrict__ custom_bias, // [num_seqs, num_heads, 1, seq_len]
const int q_stride,
const int kv_block_stride,
const int kv_head_stride) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_E5M2_KV_CACHE>(
/* exp_sums */ nullptr, /* max_logits */ nullptr,
out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens,
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride);
max_num_blocks_per_seq, alibi_slopes, custom_bias, q_stride, kv_block_stride, kv_head_stride);
}

// Grid: (num_heads, num_seqs, max_num_partitions).
@@ -474,13 +484,14 @@ __global__ void paged_attention_v2_kernel(
const int* __restrict__ context_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const float* __restrict__ custom_bias, // [num_seqs, num_heads, 1, seq_len]
const int q_stride,
const int kv_block_stride,
const int kv_head_stride) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_E5M2_KV_CACHE, PARTITION_SIZE>(
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
q_stride, kv_block_stride, kv_head_stride);
custom_bias, q_stride, kv_block_stride, kv_head_stride);
}

// Grid: (num_heads, num_seqs).
@@ -600,6 +611,7 @@ __global__ void paged_attention_v2_reduce_kernel(
context_lens_ptr, \
max_num_blocks_per_seq, \
alibi_slopes_ptr, \
custom_bias_ptr, \
q_stride, \
kv_block_stride, \
kv_head_stride);
@@ -621,7 +633,8 @@ void paged_attention_v1_launcher(
torch::Tensor& block_tables,
torch::Tensor& context_lens,
int max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes) {
const c10::optional<torch::Tensor>& alibi_slopes,
const c10::optional<torch::Tensor>& custom_bias) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
@@ -634,9 +647,11 @@
assert(head_size % thread_group_size == 0);

// NOTE: alibi_slopes is optional.
const float* alibi_slopes_ptr = alibi_slopes ?
reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr;
const float* alibi_slopes_ptr =
alibi_slopes ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr()) : nullptr;

// NOTE: custom_bias is optional.
const float* custom_bias_ptr = custom_bias ? reinterpret_cast<const float*>(custom_bias.value().data_ptr()) : nullptr;

T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
@@ -696,7 +711,8 @@ void paged_attention_v1_launcher(
block_tables, \
context_lens, \
max_context_len, \
alibi_slopes);
alibi_slopes, \
custom_bias);

// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
@@ -728,6 +744,7 @@ void paged_attention_v1(
int block_size,
int max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const c10::optional<torch::Tensor>& custom_bias,
const std::string& kv_cache_dtype) {
if (kv_cache_dtype == "auto") {
if (query.dtype() == at::ScalarType::Float) {
@@ -770,6 +787,7 @@
context_lens_ptr, \
max_num_blocks_per_seq, \
alibi_slopes_ptr, \
custom_bias_ptr, \
q_stride, \
kv_block_stride, \
kv_head_stride); \
@@ -802,7 +820,8 @@ void paged_attention_v2_launcher(
torch::Tensor& block_tables,
torch::Tensor& context_lens,
int max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes) {
const c10::optional<torch::Tensor>& alibi_slopes,
const c10::optional<torch::Tensor>& custom_bias) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
@@ -815,9 +834,10 @@
assert(head_size % thread_group_size == 0);

// NOTE: alibi_slopes is optional.
const float* alibi_slopes_ptr = alibi_slopes ?
reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr;
const float* alibi_slopes_ptr =
alibi_slopes ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr()) : nullptr;

const float* custom_bias_ptr = custom_bias ? reinterpret_cast<const float*>(custom_bias.value().data_ptr()) : nullptr;

T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
@@ -886,7 +906,8 @@ void paged_attention_v2_launcher(
block_tables, \
context_lens, \
max_context_len, \
alibi_slopes);
alibi_slopes, \
custom_bias);

// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
@@ -921,6 +942,7 @@ void paged_attention_v2(
int block_size,
int max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const c10::optional<torch::Tensor>& custom_bias,
const std::string& kv_cache_dtype) {
if (kv_cache_dtype == "auto") {
if (query.dtype() == at::ScalarType::Float) {
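A note on the new `custom_bias` argument threaded through the kernels above: the indexing added in `paged_attention_kernel` reads it as a contiguous float tensor holding one row of `num_context_blocks * BLOCK_SIZE` values per (sequence, kv head), consumed per token as `custom_bias_vec[token_idx]`. The sketch below shows one way such a tensor could be laid out on the Python side; the helper name and the simple distance-based bias are illustrative assumptions rather than code from this PR (T5 itself would supply its learned relative-position bias here).

```python
import torch

def build_custom_bias(context_lens, num_kv_heads, block_size, device="cuda"):
    # One contiguous row of num_context_blocks * block_size float32 values per
    # (sequence, kv head), matching the kernel's pointer arithmetic above.
    num_seqs = len(context_lens)
    num_context_blocks = (max(context_lens) + block_size - 1) // block_size
    padded_len = num_context_blocks * block_size

    bias = torch.zeros((num_seqs, num_kv_heads, padded_len),
                       dtype=torch.float32, device=device)
    for seq_idx, ctx_len in enumerate(context_lens):
        # Illustrative additive bias: penalize keys by distance from the
        # current (last) token; a real T5 bias would be learned, not linear.
        positions = torch.arange(ctx_len, dtype=torch.float32, device=device)
        bias[seq_idx, :, :ctx_len] = -(ctx_len - 1 - positions)
    # Padding slots stay at zero; positions past each context length are
    # assumed to be masked by the kernel's own bounds checks.
    return bias

# Example: two sequences of lengths 7 and 12, 8 kv heads, block size 16.
# bias = build_custom_bias([7, 12], num_kv_heads=8, block_size=16)
```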
2 changes: 2 additions & 0 deletions csrc/ops.h
@@ -14,6 +14,7 @@ void paged_attention_v1(
int block_size,
int max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const c10::optional<torch::Tensor>& custom_bias,
const std::string& kv_cache_dtype);

void paged_attention_v2(
@@ -31,6 +32,7 @@ void paged_attention_v2(
int block_size,
int max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const c10::optional<torch::Tensor>& custom_bias,
const std::string& kv_cache_dtype);

void rms_norm(
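For orientation, a hedged sketch of how the updated entry point would be invoked from Python once the extension is rebuilt. The leading arguments are elided in the header above and are assumed here from the kernel signatures in attention_kernels.cu; passing None for custom_bias keeps the existing ALiBi/no-bias behaviour, since the kernel prefers custom_bias over alibi_slopes when both are supplied.

```python
from typing import Optional
import torch
from vllm._C import ops  # compiled extension exposing the declarations in csrc/ops.h

def run_paged_attention_v1(out: torch.Tensor,
                           query: torch.Tensor,
                           key_cache: torch.Tensor,
                           value_cache: torch.Tensor,
                           num_kv_heads: int,
                           scale: float,
                           block_tables: torch.Tensor,
                           context_lens: torch.Tensor,
                           block_size: int,
                           max_context_len: int,
                           alibi_slopes: Optional[torch.Tensor],
                           custom_bias: Optional[torch.Tensor],
                           kv_cache_dtype: str = "auto") -> None:
    # custom_bias is the argument added by this PR; pass None to disable it.
    ops.paged_attention_v1(out, query, key_cache, value_cache, num_kv_heads,
                           scale, block_tables, context_lens, block_size,
                           max_context_len, alibi_slopes, custom_bias,
                           kv_cache_dtype)
```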
76 changes: 76 additions & 0 deletions examples/offline_inference_enc_dec.py
@@ -0,0 +1,76 @@
'''
Verify that T5 model outputs match between vLLM and native PyTorch.

Scenarios:
* t5-small, t5-large
* float16, float32, bfloat16
* Custom prompts & number of prompts

Output: for several prompts, compare native PyTorch & vLLM prompt completions
'''
import warnings
import torch
from vllm import LLM, SamplingParams
from transformers import T5Tokenizer, T5ForConditionalGeneration

warnings.filterwarnings("ignore",
category=UserWarning,
module="transformers.generation.utils.*")

hf_model_id = "t5-small"
dtype = "bfloat16"
prompts = [
"Who are you?",
"Who are you?",
"How do you like your egg made",
"How do you like your egg made",
]

dtype_obj = getattr(torch, dtype)

# Native PyTorch test

# - Model and tokenizer initialization
tokenizer = T5Tokenizer.from_pretrained(hf_model_id, legacy=False)
model = T5ForConditionalGeneration.from_pretrained(hf_model_id).to(
dtype=dtype_obj)

# - Tokenize the prompt list as a single padded batch
input_ids = tokenizer(prompts,
return_tensors="pt",
padding=True,
truncation=True).input_ids

# - If using GPU, also send input_ids to the same device as the model
if torch.cuda.is_available():
    model = model.cuda()  # Move model to GPU
    input_ids = input_ids.cuda()  # Move input_ids to GPU

# - Generating outputs for all tokenized prompts
native_outputs = model.generate(input_ids).cpu()

# vLLM test
model = LLM(hf_model_id,
enforce_eager=True,
dtype=dtype,
gpu_memory_utilization=0.5)

sampling_params = SamplingParams(max_tokens=100, temperature=0)

vllm_outputs = model.generate(
prompts,
sampling_params=sampling_params,
)

# Print native & vLLM outputs
i = 0
for native_output, vllm_output in zip(native_outputs, vllm_outputs):
    prompt = prompts[i]  # Get the corresponding prompt for this output
    native_generated_text = tokenizer.decode(
        native_output, skip_special_tokens=True)  # Decode the generated text
    vllm_generated_text = vllm_output.outputs[0].text
    print(
        f"Prompt: {prompt!r}, Native PyTorch generated text: {native_generated_text!r}, vLLM generated text: {vllm_generated_text!r}"
    )
    i += 1
5 changes: 4 additions & 1 deletion vllm/engine/llm_engine.py
@@ -144,6 +144,9 @@ def __init__(
if USE_RAY_COMPILED_DAG:
self.forward_dag = self._compiled_ray_dag()

self.is_encoder_decoder = getattr(self.model_config.hf_config,
"is_encoder_decoder", False)

def get_tokenizer_for_seq(self, sequence: Sequence):
return self.tokenizer.get_lora_tokenizer(sequence.lora_request)

@@ -477,7 +480,7 @@ def add_request(
block_size = self.cache_config.block_size
seq_id = next(self.seq_counter)
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
lora_request)
self.is_encoder_decoder, lora_request)

# Check whether the input specifies prefix
prefix = self.scheduler.prefix_pool.add_or_get_prefix(
4 changes: 4 additions & 0 deletions vllm/entrypoints/llm.py
@@ -166,6 +166,10 @@ def generate(
# Use default sampling params.
sampling_params = SamplingParams()

if self.llm_engine.is_encoder_decoder:
    assert (prefix_pos is None
            ), "Encoder decoder models do not support Prefix Cache yet"

# Add requests to the engine.
num_requests = len(prompts) if prompts is not None else len(
prompt_token_ids)
1 change: 1 addition & 0 deletions vllm/model_executor/input_metadata.py
@@ -46,6 +46,7 @@
def __repr__(self) -> str:
return ("InputMetadata("
f"is_prompt={self.is_prompt}, "
f"prompt_lens={self.prompt_lens}, "
f"max_context_len={self.max_context_len}, "
f"slot_mapping={self.slot_mapping}, "
f"context_lens={self.context_lens}, "