From dd82ba3dd7b740d56bc930f19d7dd730f39ed4c0 Mon Sep 17 00:00:00 2001 From: Jin Shang Date: Sun, 18 Feb 2024 14:15:37 +0800 Subject: [PATCH 01/11] t5-small --- .clang-format | 4 + benchmarks/benchmark_throughput.py | 10 +- csrc/attention/attention_kernels.cu | 123 ++-- csrc/ops.h | 6 +- test.py | 17 + vllm/engine/llm_engine.py | 5 +- vllm/entrypoints/llm.py | 4 + vllm/model_executor/input_metadata.py | 1 + vllm/model_executor/layers/attention.py | 79 ++- .../layers/enc_dec_attention.py | 283 ++++++++ vllm/model_executor/layers/sampler.py | 36 +- vllm/model_executor/models/__init__.py | 1 + vllm/model_executor/models/gpt2.py | 52 +- vllm/model_executor/models/t5.py | 628 ++++++++++++++++++ vllm/sequence.py | 15 +- vllm/worker/model_runner.py | 110 ++- 16 files changed, 1257 insertions(+), 117 deletions(-) create mode 100644 .clang-format create mode 100644 test.py create mode 100644 vllm/model_executor/layers/enc_dec_attention.py create mode 100644 vllm/model_executor/models/t5.py diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000000000..f9dde9e2b531b --- /dev/null +++ b/.clang-format @@ -0,0 +1,4 @@ +# Use the Google style in this project. +BasedOnStyle: Google + +ColumnLimit: 120 diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 1ad502526c97c..34b1e8fdaad0b 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -6,7 +6,7 @@ from typing import List, Optional, Tuple import torch -from transformers import (AutoModelForCausalLM, AutoTokenizer, +from transformers import (AutoModelForCausalLM, T5ForConditionalGeneration, AutoTokenizer, PreTrainedTokenizerBase) from tqdm import tqdm @@ -123,8 +123,12 @@ def run_hf( trust_remote_code: bool, ) -> float: assert not use_beam_search - llm = AutoModelForCausalLM.from_pretrained( - model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code) + if "t5" in model: + llm = T5ForConditionalGeneration.from_pretrained( + model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code) + else: + llm = AutoModelForCausalLM.from_pretrained( + model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code) if llm.config.model_type == "llama": # To enable padding in the HF backend. tokenizer.pad_token = tokenizer.eos_token diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index b5be3befa07e2..17a0423e69a1e 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -1,5 +1,6 @@ /* - * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp * Copyright (c) 2023, The vLLM team. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * @@ -19,9 +20,12 @@ #include #endif -#include #include #include +#include +#include + +#include #include "attention_dtypes.h" #include "attention_utils.cuh" @@ -29,8 +33,6 @@ #include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh" #endif -#include - #ifndef USE_ROCM #define WARP_SIZE 32 #else @@ -38,12 +40,12 @@ #endif #define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? 
(a) : (b)) -#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) +#define DIVIDE_ROUND_UP(a, b) (((a) + (b)-1) / (b)) namespace vllm { // Utility function for attention softmax. -template +template inline __device__ float block_sum(float* red_smem, float sum) { // Decompose the thread index into warp / lane. int warp = threadIdx.x / WARP_SIZE; @@ -101,6 +103,7 @@ __device__ void paged_attention_kernel( const int* __restrict__ context_lens, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] + const float* __restrict__ custom_bias, // [num_seqs, num_heads, 1, max_seq_len] const int q_stride, const int kv_block_stride, const int kv_head_stride) { @@ -121,14 +124,15 @@ __device__ void paged_attention_kernel( const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0; const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_context_blocks); const int num_blocks = end_block_idx - start_block_idx; - + // printf("start_block_idx: %d, end_block_idx: %d, num_blocks: %d\n", start_block_idx, end_block_idx, num_blocks); // [start_token_idx, end_token_idx) is the range of tokens to process. const int start_token_idx = start_block_idx * BLOCK_SIZE; const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len); const int num_tokens = end_token_idx - start_token_idx; - + // printf("start_token_idx: %d, end_token_idx: %d, num_tokens: %d\n", start_token_idx, end_token_idx, num_tokens); constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); - constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS + constexpr int NUM_THREAD_GROUPS = + NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS assert(NUM_THREADS % THREAD_GROUP_SIZE == 0); constexpr int NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE); constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; @@ -141,6 +145,10 @@ __device__ void paged_attention_kernel( const int num_queries_per_kv = num_heads / num_kv_heads; const int kv_head_idx = head_idx / num_queries_per_kv; const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; + const float* custom_bias_vec = custom_bias == nullptr + ? nullptr + : custom_bias + seq_idx * num_kv_heads * num_context_blocks * BLOCK_SIZE + + kv_head_idx * num_context_blocks * BLOCK_SIZE; // A vector type to store a part of a key or a query. // The vector size is configured in such a way that the threads in a thread group @@ -173,7 +181,7 @@ __device__ void paged_attention_kernel( const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; q_vecs[thread_group_offset][i] = *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); } - __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs + __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs // Memory planning. extern __shared__ char shared_mem[]; @@ -232,8 +240,12 @@ __device__ void paged_attention_kernel( // Compute dot product. // This includes a reduction across the threads in the same thread group. float qk = scale * Qk_dot::dot(q_vecs[thread_group_offset], k_vecs); - // Add the ALiBi bias if slopes are given. - qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0; + // Add the custom or ALiBi bias if given. + qk += (custom_bias_vec != nullptr) ? 
custom_bias_vec[token_idx] + : (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) + : 0; + // printf("kv_head_idx: %d, token_idx: %d, qk: %f\n", kv_head_idx, token_idx, qk); + // qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0; if (thread_group_offset == 0) { // Store the partial reductions to shared memory. @@ -286,13 +298,11 @@ __device__ void paged_attention_kernel( // If partitioning is enabled, store the max logit and exp_sum. if (USE_PARTITIONING && thread_idx == 0) { - float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions - + head_idx * max_num_partitions - + partition_idx; + float* max_logits_ptr = + max_logits + seq_idx * num_heads * max_num_partitions + head_idx * max_num_partitions + partition_idx; *max_logits_ptr = qk_max; - float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions - + head_idx * max_num_partitions - + partition_idx; + float* exp_sums_ptr = + exp_sums + seq_idx * num_heads * max_num_partitions + head_idx * max_num_partitions + partition_idx; *exp_sums_ptr = exp_sum; } @@ -411,9 +421,8 @@ __device__ void paged_attention_kernel( // Write the final output. if (warp_idx == 0) { - scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE - + head_idx * max_num_partitions * HEAD_SIZE - + partition_idx * HEAD_SIZE; + scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE; #pragma unroll for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; @@ -443,13 +452,14 @@ __global__ void paged_attention_v1_kernel( const int* __restrict__ context_lens, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] + const float* __restrict__ custom_bias, // [num_seqs, num_heads, 1, seq_len] const int q_stride, const int kv_block_stride, const int kv_head_stride) { paged_attention_kernel( /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens, - max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride); + max_num_blocks_per_seq, alibi_slopes, custom_bias, q_stride, kv_block_stride, kv_head_stride); } // Grid: (num_heads, num_seqs, max_num_partitions). @@ -474,28 +484,26 @@ __global__ void paged_attention_v2_kernel( const int* __restrict__ context_lens, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] + const float* __restrict__ custom_bias, // [num_seqs, num_heads, 1, seq_len] const int q_stride, const int kv_block_stride, const int kv_head_stride) { paged_attention_kernel( exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes, - q_stride, kv_block_stride, kv_head_stride); + custom_bias, q_stride, kv_block_stride, kv_head_stride); } // Grid: (num_heads, num_seqs). 
-template< - typename scalar_t, - int HEAD_SIZE, - int NUM_THREADS, - int PARTITION_SIZE> +template __global__ void paged_attention_v2_reduce_kernel( - scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] - const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] - const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] - const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] - const int* __restrict__ context_lens, // [num_seqs] - const int max_num_partitions) { + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] + const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_partitions) { const int num_heads = gridDim.x; const int head_idx = blockIdx.x; const int seq_idx = blockIdx.y; @@ -504,8 +512,8 @@ __global__ void paged_attention_v2_reduce_kernel( if (num_partitions == 1) { // No need to reduce. Only copy tmp_out to out. scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; - const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE - + head_idx * max_num_partitions * HEAD_SIZE; + const scalar_t* tmp_out_ptr = + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + head_idx * max_num_partitions * HEAD_SIZE; for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) { out_ptr[i] = tmp_out_ptr[i]; } @@ -524,8 +532,7 @@ __global__ void paged_attention_v2_reduce_kernel( // Load max logits to shared memory. float* shared_max_logits = reinterpret_cast(shared_mem); - const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions - + head_idx * max_num_partitions; + const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions + head_idx * max_num_partitions; float max_logit = -FLT_MAX; for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { const float l = max_logits_ptr[i]; @@ -555,8 +562,7 @@ __global__ void paged_attention_v2_reduce_kernel( // Load rescaled exp sums to shared memory. float* shared_exp_sums = reinterpret_cast(shared_mem + sizeof(float) * num_partitions); - const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions - + head_idx * max_num_partitions; + const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions + head_idx * max_num_partitions; float global_exp_sum = 0.0f; for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { float l = shared_max_logits[i]; @@ -569,8 +575,8 @@ __global__ void paged_attention_v2_reduce_kernel( const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f); // Aggregate tmp_out to out. 
- const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE - + head_idx * max_num_partitions * HEAD_SIZE; + const scalar_t* tmp_out_ptr = + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + head_idx * max_num_partitions * HEAD_SIZE; scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; #pragma unroll for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) { @@ -582,7 +588,7 @@ __global__ void paged_attention_v2_reduce_kernel( } } -} // namespace vllm +} // namespace vllm #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ @@ -600,6 +606,7 @@ __global__ void paged_attention_v2_reduce_kernel( context_lens_ptr, \ max_num_blocks_per_seq, \ alibi_slopes_ptr, \ + custom_bias_ptr, \ q_stride, \ kv_block_stride, \ kv_head_stride); @@ -621,7 +628,8 @@ void paged_attention_v1_launcher( torch::Tensor& block_tables, torch::Tensor& context_lens, int max_context_len, - const c10::optional& alibi_slopes) { + const c10::optional& alibi_slopes, + const c10::optional& custom_bias) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -634,9 +642,11 @@ void paged_attention_v1_launcher( assert(head_size % thread_group_size == 0); // NOTE: alibi_slopes is optional. - const float* alibi_slopes_ptr = alibi_slopes ? - reinterpret_cast(alibi_slopes.value().data_ptr()) - : nullptr; + const float* alibi_slopes_ptr = + alibi_slopes ? reinterpret_cast(alibi_slopes.value().data_ptr()) : nullptr; + + // NOTE: alibi_slopes is optional. + const float* custom_bias_ptr = custom_bias ? reinterpret_cast(custom_bias.value().data_ptr()) : nullptr; T* out_ptr = reinterpret_cast(out.data_ptr()); T* query_ptr = reinterpret_cast(query.data_ptr()); @@ -696,7 +706,8 @@ void paged_attention_v1_launcher( block_tables, \ context_lens, \ max_context_len, \ - alibi_slopes); + alibi_slopes, \ + custom_bias); // NOTE(woosuk): To reduce the compilation time, we omitted block sizes // 1, 2, 4, 64, 128, 256. @@ -728,6 +739,7 @@ void paged_attention_v1( int block_size, int max_context_len, const c10::optional& alibi_slopes, + const c10::optional& custom_bias, const std::string& kv_cache_dtype) { if (kv_cache_dtype == "auto") { if (query.dtype() == at::ScalarType::Float) { @@ -770,6 +782,7 @@ void paged_attention_v1( context_lens_ptr, \ max_num_blocks_per_seq, \ alibi_slopes_ptr, \ + custom_bias_ptr, \ q_stride, \ kv_block_stride, \ kv_head_stride); \ @@ -802,7 +815,8 @@ void paged_attention_v2_launcher( torch::Tensor& block_tables, torch::Tensor& context_lens, int max_context_len, - const c10::optional& alibi_slopes) { + const c10::optional& alibi_slopes, + const c10::optional& custom_bias) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -815,9 +829,10 @@ void paged_attention_v2_launcher( assert(head_size % thread_group_size == 0); // NOTE: alibi_slopes is optional. - const float* alibi_slopes_ptr = alibi_slopes ? - reinterpret_cast(alibi_slopes.value().data_ptr()) - : nullptr; + const float* alibi_slopes_ptr = + alibi_slopes ? reinterpret_cast(alibi_slopes.value().data_ptr()) : nullptr; + + const float* custom_bias_ptr = custom_bias ? 
reinterpret_cast(custom_bias.value().data_ptr()) : nullptr; T* out_ptr = reinterpret_cast(out.data_ptr()); float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); @@ -886,7 +901,8 @@ void paged_attention_v2_launcher( block_tables, \ context_lens, \ max_context_len, \ - alibi_slopes); + alibi_slopes, \ + custom_bias); // NOTE(woosuk): To reduce the compilation time, we omitted block sizes // 1, 2, 4, 64, 128, 256. @@ -921,6 +937,7 @@ void paged_attention_v2( int block_size, int max_context_len, const c10::optional& alibi_slopes, + const c10::optional& custom_bias, const std::string& kv_cache_dtype) { if (kv_cache_dtype == "auto") { if (query.dtype() == at::ScalarType::Float) { diff --git a/csrc/ops.h b/csrc/ops.h index dbdd2c2c57945..699ca3b2935f1 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -14,7 +14,8 @@ void paged_attention_v1( int block_size, int max_context_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype); + const std::string& kv_cache_dtype, + const c10::optional& custom_bias); void paged_attention_v2( torch::Tensor& out, @@ -31,7 +32,8 @@ void paged_attention_v2( int block_size, int max_context_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype); + const std::string& kv_cache_dtype, + const c10::optional& custom_bias); void rms_norm( torch::Tensor& out, diff --git a/test.py b/test.py new file mode 100644 index 0000000000000..447d38b20ba53 --- /dev/null +++ b/test.py @@ -0,0 +1,17 @@ +from vllm import LLM, SamplingParams + +model = LLM("t5-large", enforce_eager=True, dtype="float16", gpu_memory_utilization=0.5) +# model = LLM("gpt2", enforce_eager=True, dtype="float16") +sampling_params = SamplingParams(max_tokens=100, temperature=0) + +outputs = model.generate( + [ + "Who is Hilter?", + "Who is Hilter?", + "How do you like your egg made", + "How do you like your egg made", + ], + sampling_params=sampling_params, +) + +print(outputs) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index f0fd7efdef813..7b2496d82a5e4 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -153,6 +153,9 @@ def _dispatch_worker(self): Worker = imported_worker.Worker return Worker + self.is_encoder_decoder = getattr(self.model_config.hf_config, + "is_encoder_decoder", False) + def _init_workers(self): # Lazy import the Worker to avoid importing torch.cuda/xformers # before CUDA_VISIBLE_DEVICES is set in the Worker @@ -476,7 +479,7 @@ def add_request( block_size = self.cache_config.block_size seq_id = next(self.seq_counter) seq = Sequence(seq_id, prompt, prompt_token_ids, block_size, - lora_request) + self.is_encoder_decoder, lora_request) # Check whether the input specifies prefix prefix = self.scheduler.prefix_pool.add_or_get_prefix( diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index fc82018d18eb6..2f475dd7924ae 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -166,6 +166,10 @@ def generate( # Use default sampling params. sampling_params = SamplingParams() + if self.llm_engine.is_encoder_decoder: + assert (prefix_pos is None + ), "Encoder decoder models do not support Prefix Cache yet" + # Add requests to the engine. 
num_requests = len(prompts) if prompts is not None else len( prompt_token_ids) diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py index f0a88ac8e27f8..79ada53e252ad 100644 --- a/vllm/model_executor/input_metadata.py +++ b/vllm/model_executor/input_metadata.py @@ -46,6 +46,7 @@ def __init__( def __repr__(self) -> str: return ("InputMetadata(" f"is_prompt={self.is_prompt}, " + f"prompt_lens={self.prompt_lens}, " f"max_context_len={self.max_context_len}, " f"slot_mapping={self.slot_mapping}, " f"context_lens={self.context_lens}, " diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 2a82325b80213..aeb5fe5f47eb8 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -12,7 +12,7 @@ from vllm._C import cache_ops from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.triton_kernel.prefix_prefill import ( - context_attention_fwd) + context_attention_fwd, ) from vllm.utils import is_hip _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256] @@ -195,6 +195,9 @@ def forward( key = key.unflatten(0, (batch_size, seq_len)) value = value.unflatten(0, (batch_size, seq_len)) + print("query.shape: ", query.shape) + print("key.shape: ", key.shape) + print("value.shape: ", value.shape) out = xops.memory_efficient_attention_forward( query, key, @@ -226,14 +229,18 @@ def forward( else: # Decoding run. - output = _paged_attention( + output = paged_attention( query, key_cache, value_cache, - input_metadata, + input_metadata.block_tables, + input_metadata.context_lens, + input_metadata.max_context_len, self.num_kv_heads, self.scale, self.alibi_slopes, + None, + input_metadata.kv_cache_dtype ) # Reshape the output tensor. @@ -274,22 +281,26 @@ def _make_alibi_bias( return attn_bias -def _paged_attention( +def paged_attention( query: torch.Tensor, key_cache: torch.Tensor, value_cache: torch.Tensor, - input_metadata: InputMetadata, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + max_context_len: int, num_kv_heads: int, scale: float, alibi_slopes: Optional[torch.Tensor], + custom_bias: Optional[torch.Tensor], + kv_cache_dtype: torch.dtype, ) -> torch.Tensor: output = torch.empty_like(query) block_size = value_cache.shape[3] num_seqs, num_heads, head_size = query.shape - max_num_partitions = ( - (input_metadata.max_context_len + _PARTITION_SIZE - 1) // - _PARTITION_SIZE) + max_num_partitions = ((max_context_len + _PARTITION_SIZE - 1) // + _PARTITION_SIZE) + # print("max_num_partitions: ", max_num_partitions) # NOTE(woosuk): We use a simple heuristic to decide whether to use # PagedAttention V1 or V2. If the number of partitions is 1, we use # V1 to avoid the overhead of reduction. Also, if the number of @@ -297,9 +308,28 @@ def _paged_attention( # to parallelize. # TODO(woosuk): Tune this heuristic. # For context len > 8192, use V2 kernel to avoid shared memory shortage. 
- use_v1 = input_metadata.max_context_len <= 8192 and ( - max_num_partitions == 1 or num_seqs * num_heads > 512) + use_v1 = max_context_len <= 8192 and (max_num_partitions == 1 + or num_seqs * num_heads > 512) if use_v1: + # print("v1") + # print("output: ", output) + # print("query: ", query) + # print("num_kv_heads: ", num_kv_heads) + # print("scale: ", scale) + # print("block_tables: ", block_tables) + # print("context_lens: ", context_lens) + # print("block_size: ", block_size) + # print("max_context_len: ", max_context_len) + # print("alibi_slopes: ", alibi_slopes) + # print("custom_bias: ", custom_bias) + # print("key_cache shape: ", key_cache.shape) + # print("value_cache shape: ", value_cache.shape) + # for block_table in block_tables: + # for block in block_table: + # print(f"key_cache at {block} shape: ", key_cache[block].shape) + # print(f"key_cache at {block}: ", key_cache[block]) + # print(f"value_cache at {block} shape: ", value_cache[block].shape) + # print(f"value_cache at {block}: ", value_cache[block]) # Run PagedAttention V1. ops.paged_attention_v1( output, @@ -308,14 +338,16 @@ def _paged_attention( value_cache, num_kv_heads, scale, - input_metadata.block_tables, - input_metadata.context_lens, + block_tables, + context_lens, block_size, - input_metadata.max_context_len, + max_context_len, alibi_slopes, - input_metadata.kv_cache_dtype, + custom_bias, + kv_cache_dtype, ) else: + # print("v2") # Run PagedAttention V2. assert _PARTITION_SIZE % block_size == 0 tmp_output = torch.empty( @@ -329,6 +361,16 @@ def _paged_attention( device=output.device, ) max_logits = torch.empty_like(exp_sums) + # print("output: ", output) + # print("query: ", query) + # print("num_kv_heads: ", num_kv_heads) + # print("scale: ", scale) + # print("block_tables: ", block_tables) + # print("context_lens: ", context_lens) + # print("block_size: ", block_size) + # print("max_context_len: ", max_context_len) + # print("alibi_slopes: ", alibi_slopes) + # print("custom_bias: ", custom_bias) ops.paged_attention_v2( output, exp_sums, @@ -339,11 +381,12 @@ def _paged_attention( value_cache, num_kv_heads, scale, - input_metadata.block_tables, - input_metadata.context_lens, + block_tables, + context_lens, block_size, - input_metadata.max_context_len, + max_context_len, alibi_slopes, - input_metadata.kv_cache_dtype, + custom_bias, + kv_cache_dtype, ) return output diff --git a/vllm/model_executor/layers/enc_dec_attention.py b/vllm/model_executor/layers/enc_dec_attention.py new file mode 100644 index 0000000000000..1978a40262f70 --- /dev/null +++ b/vllm/model_executor/layers/enc_dec_attention.py @@ -0,0 +1,283 @@ +"""Multi-head attention for encoder-decoder models.""" +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from xformers import ops as xops +from xformers.ops.fmha.attn_bias import ( + BlockDiagonalCausalMask, + LowerTriangularMaskWithTensorBias, +) +from vllm._C import cache_ops +from vllm.model_executor.input_metadata import InputMetadata +from vllm.utils import is_hip +from vllm.model_executor.layers.attention import paged_attention + +_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256] +# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. 
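+# (In the v2 kernels the context is split into partitions of this many tokens.)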
+_PARTITION_SIZE = 512 + + +class EncDecAttention(nn.Module): + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + ) -> None: + super().__init__() + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + + if self.head_size not in _SUPPORTED_HEAD_SIZES: + raise ValueError( + f"head_size ({self.head_size}) is not supported. " + f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}." + ) + + +class EncoderAttention(EncDecAttention): + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + ) -> None: + super().__init__(num_heads, head_size, scale) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + input_metadata: InputMetadata, + ) -> torch.Tensor: + """Encoder attention forward pass. + + Args: + query: Query tensor. + key: Key tensor. + value: Value tensor. + custom_bias: Custom bias tensor. + + Returns: + Output tensor. + """ + # query: [batch_size, seq_len, num_heads * head_size] + # key: [batch_size, seq_len, num_heads * head_size] + # value: [batch_size, seq_len, num_heads * head_size] + # custom_bias: [batch_size, seq_len, seq_len] + # output: [batch_size, seq_len, num_heads * head_size] + + assert input_metadata.is_prompt + batch_size, seq_len, hidden_size = query.shape + # Reshape the query, key, and value tensors. + query = query.view(batch_size, seq_len, self.num_heads, self.head_size) + key = key.view(batch_size, seq_len, self.num_heads, self.head_size) + value = value.view(batch_size, seq_len, self.num_heads, self.head_size) + # print("query shape: ", query.shape) + if input_metadata.attn_bias is None: + input_metadata.attn_bias = BlockDiagonalCausalMask.from_seqlens( + [seq_len] * batch_size + ) + # When using custom attention bias, xformers requires the bias to + # be sliced from a tensor whose length is a multiple of 8. + # padded_len = (seq_len + 7) // 8 * 8 + # pad_len = padded_len - seq_len + # input_metadata.attn_bias = F.pad(input_metadata.attn_bias, (0, pad_len)) + # print("attention bias padded shape: ", input_metadata.attn_bias.shape) + + input_metadata.attn_bias = input_metadata.attn_bias[:, :, :, :seq_len] + + # print("attention bias shape: ", input_metadata.attn_bias.shape) + # Normal attention + out = xops.memory_efficient_attention_forward( + query, + key, + value, + attn_bias=input_metadata.attn_bias, + p=0.0, + scale=self.scale, + op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] + if (is_hip()) + else None, + ) + output = out.view(batch_size, seq_len, hidden_size) + return output + + +class DecoderAttention(EncDecAttention): + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + ) -> None: + super().__init__(num_heads, head_size, scale) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + key_cache: Optional[torch.Tensor], + value_cache: Optional[torch.Tensor], + input_metadata: InputMetadata, + ): + """Decoder attention forward pass. + + Args: + query: Query tensor. + key: Key tensor. + value: Value tensor. + key_cache: Key cache tensor. + value_cache: Value cache tensor. + custom_bias: Custom bias tensor. + + Returns: + Output tensor. + """ + + # print("key shape pre view: ", key.shape) + # print("value shape pre view: ", value.shape) + + batch_size, seq_len, hidden_size = query.shape + # Reshape the query, key, and value tensors. 
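+        # Flatten [batch_size, seq_len, num_heads * head_size] ->
+        # [batch_size * seq_len, num_heads, head_size] for the cache and kernel calls.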
+ query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_heads, self.head_size) + value = value.view(-1, self.num_heads, self.head_size) + # print("key shape: ", key.shape) + # print("key: ", key) + # print("value shape: ", value.shape) + # print("value: ", value) + # print("slot mapping: ", input_metadata.slot_mapping[:, -1].flatten()) + # Reshape the keys and values and store them in the cache. + # If key_cache and value_cache are not provided, the new key and value + # vectors will not be cached. This happens during the initial memory + # profiling run. + if key_cache is not None and value_cache is not None: + # print("key_cache before: ", key_cache) + # print("value_cache before: ", value_cache) + + cache_ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + input_metadata.slot_mapping[:, -1].flatten().contiguous() + ) + + # print("key_cache after: ", key_cache) + # print("value_cache after: ", value_cache) + + max_prompt_len = input_metadata.prompt_lens.max().item() + block_size = value_cache.shape[3] + prompt_table_len = (max_prompt_len + block_size - 1) // block_size + block_tables = input_metadata.block_tables[:, prompt_table_len:].contiguous() + # print("decoder self attention block_tables", block_tables) + output = paged_attention( + query=query, + key_cache=key_cache, + value_cache=value_cache, + block_tables=block_tables, + context_lens=input_metadata.context_lens, + max_context_len=input_metadata.max_context_len, + num_kv_heads=self.num_heads, + scale=self.scale, + alibi_slopes=None, + custom_bias=input_metadata.attn_bias.to(torch.float32), + ) + return output.view(batch_size, seq_len, hidden_size) + + +class CrossAttention(EncDecAttention): + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + ) -> None: + super().__init__(num_heads, head_size, scale) + + def forward( + self, + query: torch.Tensor, + key: Optional[torch.Tensor], + value: Optional[torch.Tensor], + key_cache: Optional[torch.Tensor], + value_cache: Optional[torch.Tensor], + input_metadata: InputMetadata, + ): + """Cross attention forward pass. + Args: + query: Query tensor. + key_cache: Key cache tensor. + value_cache: Value cache tensor. + input_metadata: Input metadata. + key: Key tensor. Only needed in the first pass. + value: Value tensor. Only needed in the first pass. + custom_bias: Custom bias tensor. + Returns: + Output tensor. + """ + batch_size, seq_len, hidden_size = query.shape + # Reshape the query, key, and value tensors. + query = query.view(-1, self.num_heads, self.head_size) + if key is not None: + # print("key shape pre view: ", key.shape) + key = key.view(-1, self.num_heads, self.head_size) + # print("key_shape: ", key.shape) + # print("key sum", key.sum((1, 2))) + if value is not None: + # print("value shape pre view: ", value.shape) + value = value.view(-1, self.num_heads, self.head_size) + # print("value_shape: ", value.shape) + # print("value sum", value.sum((1, 2))) + + # print("slot mapping: ", input_metadata.slot_mapping[:, :-1].flatten().shape) + # print("slot mapping: ", input_metadata.slot_mapping[:, :-1].flatten()) + # Reshape the keys and values and store them in the cache. + # It only happens during the first pass. 
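+        # Later decode steps pass key=None / value=None and simply read the
+        # encoder keys/values back from the cache, since the cross-attention
+        # context never grows.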
+ if ( + input_metadata.is_prompt + and key_cache is not None + and value_cache is not None + ): + assert key is not None and value is not None + cache_ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + input_metadata.slot_mapping[:, :-1].flatten().contiguous(), + ) + + # for slot in input_metadata.slot_mapping[:, :-1].flatten(): + # if slot != -1: + # block_number = slot//16; + # block_offset = slot%16; + # print(f"key_cache sum at {slot}: ", key_cache[block_number, :, :, block_offset, :].sum()) + # print(f"value_cache sum at {slot}: ", value_cache[block_number, :, :, block_offset].sum()) + max_prompt_len = input_metadata.prompt_lens.int().max().item() + # print("max_prompt_len: ", max_prompt_len) + block_size = value_cache.shape[3] + prompt_table_len = (max_prompt_len + block_size - 1) // block_size + block_tables = input_metadata.block_tables[:, :prompt_table_len].contiguous() + + output = paged_attention( + query=query, + key_cache=key_cache, + value_cache=value_cache, + block_tables=block_tables, + context_lens=input_metadata.prompt_lens.int(), + max_context_len=max_prompt_len, + num_kv_heads=self.num_heads, + scale=self.scale, + alibi_slopes=None, + custom_bias=None, + ) + + return output.view(batch_size, seq_len, hidden_size) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 71655b216fb3d..25baac43fac71 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -57,6 +57,9 @@ def forward( sampling_metadata: SamplingMetadata, embedding_bias: Optional[torch.Tensor] = None, ) -> Optional[SamplerOutput]: + # print("hidden_states shape: ", hidden_states.shape) + # print("hidden_states: ", hidden_states) + # Get the hidden states that we use for sampling. if self.logits_as_hidden_states: logits = hidden_states @@ -67,6 +70,9 @@ def forward( # Get the logits for the next tokens. logits = self._get_logits(hidden_states, embedding, embedding_bias) + # print("Logits shape: ", logits.shape) + # print("Logits: ", logits) + # Only perform sampling in the driver worker. # Note: `_get_logits` is still distributed across TP workers because # the `embedding` weight is distributed across TP workers. @@ -77,9 +83,12 @@ def forward( assert logits is not None _, vocab_size = logits.shape + # print("Logits shape: ", logits.shape) + # print("Logits: ", logits) # Apply logits processors (if any). logits = _apply_logits_processors(logits, sampling_metadata) - + # print("Logits shape: ", logits.shape) + # print("Logits: ", logits) # Prepare sampling tensors with pinned memory to avoid blocking. (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p) = SamplingTensors.from_sampling_metadata( @@ -92,18 +101,23 @@ def forward( sampling_tensors.presence_penalties, sampling_tensors.frequency_penalties, sampling_tensors.repetition_penalties) - + # print("Logits shape: ", logits.shape) + # print("Logits: ", logits) # Apply temperature scaling. # Use in-place division to avoid creating a new tensor. logits.div_(sampling_tensors.temperatures.unsqueeze_(dim=1)) - + # print("Logits shape: ", logits.shape) + # print("Logits: ", logits) if do_top_p_top_k: logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps, sampling_tensors.top_ks) - + # print("Logits shape: ", logits.shape) + # print("Logits: ", logits) if do_min_p: logits = _apply_min_p(logits, sampling_tensors.min_ps) + # print("Logits shape: ", logits.shape) + # print("Logits: ", logits) # We use float32 for probabilities and log probabilities. 
# Compute the probabilities. probs = torch.softmax(logits, dim=-1, dtype=torch.float) @@ -112,10 +126,18 @@ def forward( logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float) # Sample the next tokens. + # print("Probs shape: ", probs.shape) + # print("Probs: ", probs) + # print("Logprobs shape: ", logprobs.shape) + # print("Logprobs: ", logprobs) + sample_results = _sample(probs, logprobs, sampling_metadata) # Get the logprobs query results. + # print("Sample results: ", sample_results) prompt_logprobs, sample_logprobs = _get_logprobs( logprobs, sampling_metadata, sample_results) + # print("Prompt logprobs: ", prompt_logprobs) + # print("Sample logprobs: ", sample_logprobs) return _build_sampler_output(sample_results, sampling_metadata, prompt_logprobs, sample_logprobs) @@ -378,6 +400,8 @@ def _sample( logprobs: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> List[Tuple[List[int], List[int]]]: + # print("probs: ", probs) + # print("logprobs: ", logprobs) categorized_seq_group_ids = {t: [] for t in SamplingType} categorized_sample_indices = sampling_metadata.categorized_sample_indices for i, seq_group in enumerate(sampling_metadata.seq_groups): @@ -393,11 +417,15 @@ def _sample( # The first loop can run without waiting on GPU<->CPU sync. for sampling_type in SamplingType: sample_indices = categorized_sample_indices[sampling_type] + # print("sampling_type: ", sampling_type) + # print("sample_indices: ", sample_indices) num_tokens = len(sample_indices) if num_tokens == 0: continue seq_group_ids = categorized_seq_group_ids[sampling_type] + # print("seq_group_ids: ", seq_group_ids) seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids] + # print("seq_groups: ", seq_groups) is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids] sample_metadata[sampling_type] = (seq_group_ids, seq_groups, is_prompts, sample_indices) diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index e4f3a785cd99a..ee0f85ad901c9 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -45,6 +45,7 @@ "RWForCausalLM": ("falcon", "FalconForCausalLM"), "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"), "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"), + "T5ForConditionalGeneration": ("t5", "T5ForConditionalGeneration"), } # Models not supported by ROCm. 
diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 661da0fe0434e..fbc3d01252761 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -26,18 +26,21 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.attention import PagedAttention -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearMethodBase, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + LinearMethodBase, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) + get_tensor_model_parallel_world_size, ) from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) +from vllm.model_executor.weight_utils import ( + default_weight_loader, + hf_model_weights_iterator, +) from vllm.sequence import SamplerOutput KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -53,8 +56,8 @@ def __init__( super().__init__() self.hidden_size = config.hidden_size total_num_heads = config.num_attention_heads - tensor_model_parallel_world_size = ( - get_tensor_model_parallel_world_size()) + tensor_model_parallel_world_size = get_tensor_model_parallel_world_size( + ) assert total_num_heads % tensor_model_parallel_world_size == 0 self.num_heads = total_num_heads // tensor_model_parallel_world_size self.head_dim = self.hidden_size // total_num_heads @@ -134,8 +137,7 @@ def __init__( ): super().__init__() hidden_size = config.hidden_size - inner_dim = (config.n_inner if config.n_inner is not None else 4 * - hidden_size) + inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.attn = GPT2Attention(config, linear_method) @@ -149,6 +151,7 @@ def forward( input_metadata: InputMetadata, ) -> torch.Tensor: residual = hidden_states + print("Block input: ", hidden_states) hidden_states = self.ln_1(hidden_states) attn_output = self.attn( hidden_states=hidden_states, @@ -163,6 +166,7 @@ def forward( feed_forward_hidden_states = self.mlp(hidden_states) # residual connection hidden_states = residual + feed_forward_hidden_states + print("Block output: ", hidden_states) return hidden_states @@ -195,12 +199,17 @@ def forward( input_metadata: InputMetadata, ) -> torch.Tensor: inputs_embeds = self.wte(input_ids) + print("input_embeds: ", inputs_embeds) position_embeds = self.wpe(position_ids) hidden_states = inputs_embeds + position_embeds - + print("hidden_states: ", hidden_states) for i in range(len(self.h)): layer = self.h[i] hidden_states = layer(hidden_states, kv_caches[i], input_metadata) + if i == 0 and kv_caches[0][0] is not None: + print("hidden_states shape: ", hidden_states.shape) + print("kv cache shape", kv_caches[i][0].shape) + print("Input metadata", input_metadata) hidden_states = self.ln_f(hidden_states) return hidden_states @@ -227,6 +236,9 @@ def forward( kv_caches: List[KVCache], input_metadata: InputMetadata, ) -> torch.Tensor: + if kv_caches[0][0] is not None: + 
print("input_ids: ", input_ids) + print("slot mapping shape: ", input_metadata.slot_mapping.shape) hidden_states = self.transformer(input_ids, positions, kv_caches, input_metadata) return hidden_states @@ -240,11 +252,13 @@ def sample( sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights( + self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None, + ): params_dict = dict(self.named_parameters(remove_duplicate=False)) for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): diff --git a/vllm/model_executor/models/t5.py b/vllm/model_executor/models/t5.py new file mode 100644 index 0000000000000..e1a100060f703 --- /dev/null +++ b/vllm/model_executor/models/t5.py @@ -0,0 +1,628 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/t5/modeling_t5.py +# Copyright 2023 The vLLM team. +# Copyright 2020 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch T5 model.""" +from typing import List, Optional, Tuple + +import math, copy + +import torch +from torch import nn +from transformers import T5Config + +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.enc_dec_attention import ( + EncoderAttention, + DecoderAttention, + CrossAttention, +) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + LinearMethodBase, + QKVParallelLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_world_size, +) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.weight_utils import ( + default_weight_loader, + hf_model_weights_iterator, +) +from vllm.sequence import SamplerOutput + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class T5LayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Construct a layernorm module in the T5 style. No bias and no subtraction of mean. + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean + # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated + # w/o mean and there is no bias. 
Additionally we want to make sure that the accumulation for + # half-precision inputs is done in fp32 + + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states + + +class T5DenseActDense(nn.Module): + def __init__(self, config: T5Config): + super().__init__() + self.wi = ColumnParallelLinear(config.d_model, config.d_ff, bias=False) + self.wo = RowParallelLinear(config.d_ff, config.d_model, bias=False) + self.act = get_act_fn(config.dense_act_fn) + + def forward(self, hidden_states): + hidden_states, _ = self.wi(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states, _ = self.wo(hidden_states) + return hidden_states + + +class T5DenseGatedActDense(nn.Module): + def __init__(self, config: T5Config): + super().__init__() + self.wi_0 = ColumnParallelLinear(config.d_model, config.d_ff, bias=False) + self.wi_1 = ColumnParallelLinear(config.d_model, config.d_ff, bias=False) + self.wo = RowParallelLinear(config.d_ff, config.d_model, bias=False) + self.act = get_act_fn(config.dense_act_fn) + + def forward(self, hidden_states): + hidden_gelu = self.act(self.wi_0(hidden_states)[0]) + hidden_linear, _ = self.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states, _ = self.wo(hidden_states) + return hidden_states + + +class T5LayerFF(nn.Module): + def __init__(self, config: T5Config): + super().__init__() + if config.is_gated_act: + self.DenseReluDense = T5DenseGatedActDense(config) + else: + self.DenseReluDense = T5DenseActDense(config) + + self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + + def forward(self, hidden_states): + forwarded_states = self.layer_norm(hidden_states) + forwarded_states = self.DenseReluDense(forwarded_states) + hidden_states = hidden_states + forwarded_states + return hidden_states + + +class T5Attention(nn.Module): + def __init__( + self, + config: T5Config, + is_cross: bool, + has_relative_attention_bias: bool, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.is_decoder = config.is_decoder + self.relative_attention_num_buckets = config.relative_attention_num_buckets + self.relative_attention_max_distance = config.relative_attention_max_distance + self.d_model = config.d_model + self.key_value_proj_dim = config.d_kv + total_num_heads = config.num_heads + tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() + assert total_num_heads % tensor_model_parallel_world_size == 0 + self.n_heads = total_num_heads // tensor_model_parallel_world_size + self.inner_dim = self.n_heads * self.key_value_proj_dim + + self.q = ColumnParallelLinear(self.d_model, self.inner_dim, bias=False) + self.k = ColumnParallelLinear(self.d_model, self.inner_dim, bias=False) + self.v = ColumnParallelLinear(self.d_model, self.inner_dim, bias=False) + self.o = RowParallelLinear(self.inner_dim, self.d_model, bias=False) + + if has_relative_attention_bias: + self.relative_attention_bias = nn.Embedding( + self.relative_attention_num_buckets, self.n_heads + ) + + self.is_cross = is_cross + if self.is_decoder: + if self.is_cross: + self.attn = CrossAttention(self.n_heads, self.key_value_proj_dim, 1) + else: + self.attn = DecoderAttention(self.n_heads, self.key_value_proj_dim, 1) + else: + 
self.attn = EncoderAttention(self.n_heads, self.key_value_proj_dim, 1) + + @staticmethod + def _relative_position_bucket( + relative_position, bidirectional=True, num_buckets=32, max_distance=128 + ): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. + This should allow for more graceful generalization to longer sequences than the model has been trained on + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min( + relative_position, torch.zeros_like(relative_position) + ) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_position_if_large = torch.min( + relative_position_if_large, + torch.full_like(relative_position_if_large, num_buckets - 1), + ) + + relative_buckets += torch.where( + is_small, relative_position, relative_position_if_large + ) + return relative_buckets + + def compute_bias(self, query_length, key_length): + """Compute binned relative position bias""" + context_position = torch.arange(query_length, dtype=torch.long, device="cuda")[ + :, None + ] + memory_position = torch.arange(key_length, dtype=torch.long, device="cuda")[ + None, : + ] + relative_position = ( + memory_position - context_position + ) # shape (query_length, key_length) + relative_position_bucket = self._relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=(not self.is_decoder), + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + # shape (query_length, key_length, num_heads) + values = self.relative_attention_bias(relative_position_bucket) + # shape (1, num_heads, query_length, key_length) + values = values.permute([2, 0, 1]).unsqueeze(0) + return values + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: Optional[KVCache], + input_metadata: InputMetadata, + encoder_hidden_states: Optional[torch.Tensor], + ) -> torch.Tensor: + # print("hidden_states shape", hidden_states.shape) + # print("hidden_states", hidden_states) + q, _ = 
self.q(hidden_states) + + # print("q shape", q.shape) + # print("q", q) + batch_size = hidden_states.shape[0] + seq_len = hidden_states.shape[1] + prompt_len = input_metadata.prompt_lens.max().item() + context_len = input_metadata.context_lens.max().item() + context_len = max(context_len, 1) + # print("batch_size", batch_size) + # print("seq_len", seq_len) + # print("prompt_len", prompt_len) + # print("context_len", context_len) + + block_size = 16 + + if not self.is_decoder: + # print("encoder self attention!") + assert kv_cache is None + # Encoder self attention, no cache operations + k, _ = self.k(hidden_states) + v, _ = self.v(hidden_states) + + + + if input_metadata.attn_bias is None: + input_metadata.attn_bias = self.compute_bias( + prompt_len, (prompt_len + block_size - 1) // block_size * block_size + ).repeat(batch_size, 1, 1, 1) + for i in range(batch_size): + input_metadata.attn_bias[ + i, + :, + :, + input_metadata.prompt_lens[i] :, + ] = torch.finfo(input_metadata.attn_bias.dtype).min + + # print("input_metadata.attn_bias shape", input_metadata.attn_bias.shape) + # print("input_metadata.attn_bias", input_metadata.attn_bias) + attn_output = self.attn(q, k, v, input_metadata) + + elif not self.is_cross: + # print("decoder self attention!") + # Decoder self attention + k, _ = self.k(hidden_states) + v, _ = self.v(hidden_states) + + if input_metadata.attn_bias is None: + position_bias = self.compute_bias( + 1 if input_metadata.is_prompt else context_len, + (context_len + block_size - 1) // block_size * block_size + ).repeat(batch_size, 1, 1, 1) + # print("position_bias shape", position_bias.shape) + # print("position_bias", position_bias) + input_metadata.attn_bias = position_bias[:, :, -seq_len:, :].contiguous() + # print("input_metadata.attn_bias shape", input_metadata.attn_bias.shape) + # print("input_metadata.attn_bias", input_metadata.attn_bias) + + key_cache, value_cache = kv_cache + + attn_output = self.attn(q, k, v, key_cache, value_cache, input_metadata) + + else: + # print("cross attention!") + # Cross attention + + key_cache, value_cache = kv_cache + if input_metadata.is_prompt: + assert encoder_hidden_states is not None + k, _ = self.k(encoder_hidden_states) + v, _ = self.v(encoder_hidden_states) + # print("k shape", k.shape) + # for i in range(k.shape[0]): + # for j in range(k.shape[1]): + # print(f"key at batch {i} and pos {j}: ", k[i, j, :].reshape(1, 8, 64)) + attn_output = self.attn(q, k, v, key_cache, value_cache, input_metadata) + else: + attn_output = self.attn( + q, None, None, key_cache, value_cache, input_metadata + ) + + attn_output, _ = self.o(attn_output) + return attn_output + + +class T5LayerSelfAttention(nn.Module): + def __init__( + self, + config, + has_relative_attention_bias, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.SelfAttention = T5Attention( + config, + is_cross=False, + has_relative_attention_bias=has_relative_attention_bias, + linear_method=linear_method, + ) + self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + ) -> torch.Tensor: + normed_hidden_states = self.layer_norm(hidden_states) + # print("self attention input shape: ", normed_hidden_states.shape) + # print("self_attention input: ", normed_hidden_states) + attention_output = self.SelfAttention( + hidden_states=normed_hidden_states, + kv_cache=kv_cache, + input_metadata=input_metadata, + 
encoder_hidden_states=None, + ) + # print("self attention output shape: ", attention_output.shape) + # print("self_attention output: ", attention_output) + hidden_states = hidden_states + attention_output + return hidden_states + + +class T5LayerCrossAttention(nn.Module): + def __init__( + self, + config, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.EncDecAttention = T5Attention( + config, + is_cross=True, + has_relative_attention_bias=False, + linear_method=linear_method, + ) + self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: Optional[KVCache], + input_metadata: InputMetadata, + encoder_hidden_states: Optional[torch.Tensor], + ) -> torch.Tensor: + normed_hidden_states = self.layer_norm(hidden_states) + # print("cross attention input shape: ", normed_hidden_states.shape) + # print("cross_attention input: ", normed_hidden_states) + attention_output = self.EncDecAttention( + hidden_states=normed_hidden_states, + kv_cache=kv_cache, + input_metadata=input_metadata, + encoder_hidden_states=encoder_hidden_states, + ) + # print("cross attention output shape: ", attention_output.shape) + # print("cross_attention output: ", attention_output) + hidden_states = hidden_states + attention_output + return hidden_states + + +class T5Block(nn.Module): + def __init__( + self, + config, + has_relative_attention_bias: bool, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.is_decoder = config.is_decoder + self.layer = nn.ModuleList() + self.layer.append( + T5LayerSelfAttention( + config, + has_relative_attention_bias=has_relative_attention_bias, + linear_method=linear_method, + ) + ) + if self.is_decoder: + self.layer.append( + T5LayerCrossAttention(config, linear_method=linear_method) + ) + + self.layer.append(T5LayerFF(config)) + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: Optional[KVCache], + input_metadata: InputMetadata, + encoder_hidden_states: Optional[torch.Tensor], + ): + hidden_states = self.layer[0]( + hidden_states=hidden_states, + kv_cache=kv_cache, + input_metadata=input_metadata, + ) + + if hidden_states.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + torch.finfo(hidden_states.dtype).max, + ) + hidden_states = torch.clamp( + hidden_states, min=-clamp_value, max=clamp_value + ) + + if self.is_decoder: + hidden_states = self.layer[1]( + hidden_states, + kv_cache=kv_cache, + input_metadata=input_metadata, + encoder_hidden_states=encoder_hidden_states, + ) + if hidden_states.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + torch.finfo(hidden_states.dtype).max, + ) + hidden_states = torch.clamp( + hidden_states, min=-clamp_value, max=clamp_value + ) + + # Apply Feed Forward layer + hidden_states = self.layer[-1](hidden_states) + + return hidden_states + + +class T5Stack(nn.Module): + def __init__( + self, + config: T5Config, + embed_tokens: torch.Tensor, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.is_decoder = config.is_decoder + self.embed_tokens = embed_tokens + + self.block = nn.ModuleList( + [ + T5Block( + config, + has_relative_attention_bias=(i == 0), + linear_method=linear_method, + ) + for i in range(config.num_layers) + ] + ) + + self.final_layer_norm = T5LayerNorm( + config.d_model, 
eps=config.layer_norm_epsilon + ) + + def forward( + self, + input_ids: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + encoder_hidden_states: Optional[torch.Tensor], + ) -> torch.Tensor: + # print("input_ids: ", input_ids) + hidden_states = self.embed_tokens(input_ids) + + # print("hidden_states shape: ", hidden_states.shape) + # print("hidden_states: ", hidden_states) + for i, layer_module in enumerate(self.block): + kv_cache = kv_caches[i] if self.is_decoder else None + + layer_outputs = layer_module( + hidden_states, + kv_cache=kv_cache, + input_metadata=input_metadata, + encoder_hidden_states=encoder_hidden_states, + ) + + hidden_states = layer_outputs + + hidden_states = self.final_layer_norm(hidden_states) + # if encoder_hidden_states is not None: + # print("hidden_states shape:" , hidden_states.shape) + # print("encoder_hidden_states shape:" , encoder_hidden_states.shape) + # # Attach encoder hidden states + # hidden_states = torch.cat( + # [encoder_hidden_states, hidden_states], dim=1 + # ) + # print("final_hidden_states shape: ", hidden_states.shape) + # print("final_hidden_states: ", hidden_states) + return hidden_states + + +class T5ForConditionalGeneration(nn.Module): + def __init__( + self, config: T5Config, linear_method: Optional[LinearMethodBase] = None + ): + super().__init__() + self.config = config + self.model_dim = config.d_model + + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + self.encoder = T5Stack(encoder_config, self.shared, linear_method) + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + self.decoder = T5Stack(decoder_config, self.shared, linear_method) + + self.sampler = Sampler(config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + # print("input_ids shape: ", input_ids.shape) + # print("input_ids: ", input_ids) + # print("input_metadata: ", input_metadata) + if input_metadata.is_prompt: + # prompt run, need to run encoder once + hidden_states = self.encoder(input_ids, kv_caches, input_metadata, None) + # Clear the attention bias + input_metadata.attn_bias = None + batch_size = input_ids.shape[0] + input_ids = ( + torch.ones(batch_size, 1, dtype=torch.long) + * self.config.decoder_start_token_id + ).cuda() + + else: + hidden_states = None + + if kv_caches[0][0] is not None: # Skip decoder for profiling run + hidden_states = self.decoder( + input_ids, kv_caches, input_metadata, hidden_states + ) + + if self.config.tie_word_embeddings: + # Rescale output before projecting on vocab + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + hidden_states = hidden_states * (self.model_dim**-0.5) + + return hidden_states + + def sample(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata): + # logger.info(f"decoder_outputs: {decoder_outputs}") + next_tokens = self.sampler(self.shared.weight, hidden_states, sampling_metadata) + # logger.info(f"next_tokens: {next_tokens}") + return next_tokens + + def load_weights( + self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None, + ): + params_dict = dict(self.named_parameters(remove_duplicate=False)) + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, 
cache_dir, load_format, revision + ): + if "EncDecAttention.relative_attention_bias" in name: + continue + + assert name in params_dict, f"{name} not in params_dict" + param = params_dict[name] + assert param.shape == loaded_weight.shape, ( + f"{name} shape mismatch between model and checkpoint: " + f"{param.shape} != {loaded_weight.shape}" + ) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/sequence.py b/vllm/sequence.py index 040e9756e15c6..f601837b40915 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -135,6 +135,7 @@ def __init__( prompt: str, prompt_token_ids: List[int], block_size: int, + is_decoder_encoder: bool, lora_request: Optional[LoRARequest] = None, ) -> None: self.seq_id = seq_id @@ -147,8 +148,20 @@ def __init__( self.output_text = "" self.logical_token_blocks: List[LogicalTokenBlock] = [] + initial_token_ids = prompt_token_ids + if is_decoder_encoder: + # We need to seperate the prompt and generated tokens for encoder-decoder models. + num_prompt_blocks = (len(prompt_token_ids) + block_size - + 1) // block_size + padded_prompt_len = num_prompt_blocks * block_size + initial_token_ids = prompt_token_ids + [0] * ( + padded_prompt_len - len(prompt_token_ids)) + # Also need to append decoder_start_token_id + initial_token_ids.append(0) + # Initialize the logical token blocks with the prompt token ids. - self._append_tokens_to_blocks(prompt_token_ids) + self._append_tokens_to_blocks(initial_token_ids) + self.status = SequenceStatus.WAITING # Used for incremental detokenization diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index efe570778fb43..28a97f3e85b99 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -83,6 +83,8 @@ def __init__( # Set enforce_eager to True for Neuron backend, to avoid capturing graph if self.device_config.is_neuron: self.model_config.enforce_eager = True + self.is_encoder_decoder = getattr(self.model_config.hf_config, + "is_encoder_decoder", False) def load_model(self) -> None: self.model = get_model(self.model_config, @@ -135,6 +137,8 @@ def _prepare_prompt( context_lens: List[int] = [] subquery_lens: List[int] = [] prefix_block_tables: List[List[int]] = [] + block_tables: List[List[int]] = [] + max_block_table_len = 0 for seq_group_metadata in seq_group_metadata_list: assert seq_group_metadata.is_prompt seq_ids = list(seq_group_metadata.seq_data.keys()) @@ -145,16 +149,20 @@ def _prepare_prompt( prompt_tokens = seq_data.get_token_ids() prompt_len = len(prompt_tokens) prompt_lens.append(prompt_len) + prefix_len = 0 - prefix = seq_group_metadata.prefix - if prefix is not None and prefix.computed: - prefix_len = prefix.get_length() - prompt_tokens = prompt_tokens[prefix_len:] - prefix_block_tables.append(prefix.get_block_numbers()) + if self.is_encoder_decoder: + context_lens.append(1) else: - prefix_block_tables.append([]) - # actual prompt lens - context_lens.append(prefix_len) + prefix = seq_group_metadata.prefix + if prefix is not None and prefix.computed: + prefix_len = prefix.get_length() + prompt_tokens = prompt_tokens[prefix_len:] + prefix_block_tables.append(prefix.get_block_numbers()) + else: + prefix_block_tables.append([]) + # actual prompt lens + context_lens.append(prefix_len) subquery_lens.append(prompt_len - prefix_len) input_tokens.append(prompt_tokens) @@ -204,6 +212,12 @@ def _prepare_prompt( slot = block_number * self.block_size + block_offset slot_mapping[-1].append(slot) + if 
self.is_encoder_decoder: + block_tables.append(block_table) + max_block_table_len = max(max_block_table_len, + len(block_table)) + # print("slot_mapping: ", slot_mapping) + # print("block_tables: ", block_tables) max_prompt_len = max(subquery_lens) input_tokens = _make_tensor_with_pad(input_tokens, max_prompt_len, @@ -215,8 +229,16 @@ def _prepare_prompt( pad=0, dtype=torch.long, device=self.device) + if self.is_encoder_decoder and len(block_tables) > 0: + # Pad the slot mapping to the same length and add decoder_start_id + for i in range(len(slot_mapping)): + slot_mapping[i] += [_PAD_SLOT_ID + ] * (max_prompt_len - len(slot_mapping[i])) + slot_mapping[i].append(block_tables[i][-1] * self.block_size) + + max_slot_mapping_len = max_prompt_len + self.is_encoder_decoder slot_mapping = _make_tensor_with_pad(slot_mapping, - max_prompt_len, + max_slot_mapping_len, pad=_PAD_SLOT_ID, dtype=torch.long, device=self.device) @@ -224,6 +246,7 @@ def _prepare_prompt( _pad_to_max(mapping, max_prompt_len, pad=0) for mapping in lora_index_mapping ] + context_lens_tensor = torch.tensor(context_lens, dtype=torch.int, device=self.device) @@ -236,6 +259,28 @@ def _prepare_prompt( dtype=torch.int, device=self.device, ) + if self.is_encoder_decoder: + padded_block_tables = [] + # Pad the encoder block tables to the same length and then add a decoder block table in the end + for block_table in block_tables: + block_table = block_table[:-1] + [0] * ( + max_block_table_len - len(block_table)) + block_table[-1:] + padded_block_tables.append(block_table) + + block_tables_tensor = _make_tensor_with_pad( + padded_block_tables, + max_len=max_block_table_len, + pad=0, + dtype=torch.int) + else: + max_prompt_block_table_len = max( + len(t) for t in prefix_block_tables) + block_tables_tensor = _make_tensor_with_pad( + prefix_block_tables, + max_len=max_prompt_block_table_len, + pad=0, + dtype=torch.int, + ) start_loc_tensor = torch.arange(0, len(prompt_lens) * max_prompt_len, max_prompt_len, @@ -251,9 +296,9 @@ def _prepare_prompt( prompt_lens=prompt_lens_tensor, max_seq_len=max_prompt_len, start_loc=start_loc_tensor, - max_context_len=None, + max_context_len=max(context_lens), context_lens=context_lens_tensor, - block_tables=block_tables, + block_tables=block_tables_tensor, use_cuda_graph=False, kv_cache_dtype=self.kv_cache_dtype, ) @@ -270,12 +315,14 @@ def _prepare_decode( input_tokens: List[List[int]] = [] input_positions: List[List[int]] = [] slot_mapping: List[List[int]] = [] + prompt_lens: List[int] = [] context_lens: List[int] = [] block_tables: List[List[int]] = [] lora_index_mapping: List[int] = [] lora_prompt_mapping: List[int] = [] lora_requests: Set[LoRARequest] = set() + max_block_table_len = 0 for seq_group_metadata in seq_group_metadata_list: assert not seq_group_metadata.is_prompt @@ -294,11 +341,27 @@ def _prepare_decode( position = seq_len - 1 input_positions.append([position]) - context_len = seq_len if self.sliding_window is None else min( - seq_len, self.sliding_window) + prompt_len = len(seq_data.prompt_token_ids) + prompt_lens.append(prompt_len) + + if self.is_encoder_decoder: + # Encoder-decoder model stores prompt and generation tokens separately, + # so we need to adjust to the pad. 
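The adjustment described in the comment above amounts to block rounding plus one reserved slot: the prompt is padded up to a whole number of KV-cache blocks, and one extra position is skipped for the decoder start token. A standalone sketch of the same arithmetic, with made-up sizes (block_size, prompt_len and the generated-token count are illustrative only; the real values come from the scheduler):

    # Illustrative numbers only.
    block_size = 16
    prompt_len = 21        # encoder prompt tokens
    num_generated = 3      # decoder tokens emitted so far
    seq_len = prompt_len + num_generated

    position = seq_len - 1  # logical position before adjustment

    # Round the prompt up to whole blocks, as in the code below.
    prompt_blocks_num = (prompt_len + block_size - 1) // block_size
    prompt_pad = prompt_blocks_num * block_size - prompt_len

    # Skip the padding plus the slot reserved for decoder_start_token_id.
    position += prompt_pad + 1
    print(position)  # 35: prompt 0-20, pad 21-31, decoder start 32, outputs 33-35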
+ prompt_blocks_num = (prompt_len + self.block_size - + 1) // self.block_size + prompt_pad = prompt_blocks_num * self.block_size - prompt_len + position += prompt_pad + 1 # One extra for decoder_start_id + + if self.is_encoder_decoder: + context_len = seq_len - prompt_len + 1 + elif self.sliding_window is not None: + context_len = min(seq_len, self.sliding_window) + else: + context_len = seq_len context_lens.append(context_len) block_table = seq_group_metadata.block_tables[seq_id] + max_block_table_len = max(max_block_table_len, len(block_table)) block_number = block_table[position // self.block_size] block_offset = position % self.block_size slot = block_number * self.block_size + block_offset @@ -311,6 +374,15 @@ def _prepare_decode( self.block_size) block_table = block_table[-sliding_window_blocks:] block_tables.append(block_table) + if self.is_encoder_decoder: + padded_block_tables = [] + # Pad the encoder block tables to the same length and then add a decoder block table in the end + for block_table in block_tables: + block_table = block_table[:-1] + [0] * ( + max_block_table_len - len(block_table)) + block_table[-1:] + padded_block_tables.append(block_table) + + block_tables = padded_block_tables batch_size = len(input_tokens) max_context_len = max(context_lens) @@ -350,6 +422,7 @@ def _prepare_decode( dtype=torch.int, device=self.device) + prompt_lens = torch.tensor(prompt_lens, dtype=torch.int, device=self.device) if use_captured_graph: # The shape of graph_block_tables is # [max batch size, max context len // block size]. @@ -376,7 +449,7 @@ def _prepare_decode( input_metadata = InputMetadata( is_prompt=False, slot_mapping=slot_mapping, - prompt_lens=None, + prompt_lens=prompt_lens, max_seq_len=None, start_loc=None, max_context_len=max_context_len, @@ -408,7 +481,7 @@ def _prepare_sample( sampling_params = seq_group_metadata.sampling_params seq_groups.append((seq_ids, sampling_params)) - if seq_group_metadata.is_prompt: + if seq_group_metadata.is_prompt and not self.is_encoder_decoder: assert len(seq_ids) == 1 assert subquery_lens is not None subquery_len = subquery_lens[i] @@ -452,6 +525,7 @@ def _prepare_sample( dtype=torch.long, target_device=self.device, pin_memory=pin_memory) + # print("selected_token_indices: ", selected_token_indices) categorized_sample_indices = { t: _async_h2d(seq_ids, dtype=torch.int, @@ -459,11 +533,12 @@ def _prepare_sample( pin_memory=pin_memory) for t, seq_ids in categorized_sample_indices.items() } + # print("categorized_sample_indices: ", categorized_sample_indices) seq_data: Dict[int, SequenceData] = {} for seq_group_metadata in seq_group_metadata_list: seq_data.update(seq_group_metadata.seq_data) - + # print("selected_token_indices: ", selected_token_indices) sampling_metadata = SamplingMetadata( seq_groups=seq_groups, seq_data=seq_data, @@ -585,6 +660,8 @@ def execute_model( kv_caches=kv_caches, input_metadata=input_metadata, ) + # print("hidden_states shape: ", hidden_states.shape) + # print("hidden_states: ", hidden_states) # Sample the next token. 
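For reference, the block-table padding used in both _prepare_prompt and _prepare_decode above keeps the encoder (prompt) blocks at the front, zero-pads every row to a common width, and keeps the single decoder block as the last column. A minimal sketch with invented block numbers:

    _PAD_BLOCK = 0  # pad value used above

    def pad_enc_dec_block_table(block_table, max_len):
        # Encoder blocks stay in front, the decoder block stays last;
        # zeros fill the gap so all rows have the same width.
        return block_table[:-1] + [_PAD_BLOCK] * (max_len - len(block_table)) + block_table[-1:]

    # The last id in each row is the decoder block (block numbers are made up).
    tables = [[7, 8, 3], [5, 9]]
    width = max(len(t) for t in tables)
    padded = [pad_enc_dec_block_table(t, width) for t in tables]
    print(padded)  # [[7, 8, 3], [5, 0, 9]]

    # A decode-time slot then addresses the trailing decoder block:
    block_size, position_in_block = 16, 2
    slot = padded[1][-1] * block_size + position_in_block  # 9 * 16 + 2 = 146

This layout is what lets cross-attention slice the leading prompt columns of the block tables while decoder self-attention uses the columns after them, mirroring the slot_mapping split into [:, :-1] and [:, -1] in enc_dec_attention.py.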
output = self.model.sample( @@ -700,6 +777,7 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: slot_mapping = torch.empty(max_batch_size, 1, dtype=torch.long).cuda() slot_mapping.fill_(_PAD_SLOT_ID) context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda() + prompt_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda() block_tables = torch.from_numpy(self.graph_block_tables).cuda() graph_batch_size = _get_graph_batch_size( From f2fd5793c4a1396d9af7bc190daba7b61295ab75 Mon Sep 17 00:00:00 2001 From: Jin Shang Date: Thu, 29 Feb 2024 11:29:58 +0800 Subject: [PATCH 02/11] fix --- csrc/attention/attention_kernels.cu | 61 ++++++++++--------- csrc/ops.h | 8 +-- test.py | 7 ++- vllm/engine/llm_engine.py | 5 +- vllm/model_executor/layers/attention.py | 8 +-- .../layers/enc_dec_attention.py | 9 +-- vllm/model_executor/models/gpt2.py | 52 ++++++---------- vllm/model_executor/models/t5.py | 20 +++--- vllm/worker/model_runner.py | 13 +--- 9 files changed, 85 insertions(+), 98 deletions(-) diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 17a0423e69a1e..0f561ab3155b7 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -124,15 +124,14 @@ __device__ void paged_attention_kernel( const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0; const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_context_blocks); const int num_blocks = end_block_idx - start_block_idx; - // printf("start_block_idx: %d, end_block_idx: %d, num_blocks: %d\n", start_block_idx, end_block_idx, num_blocks); + // [start_token_idx, end_token_idx) is the range of tokens to process. const int start_token_idx = start_block_idx * BLOCK_SIZE; const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len); const int num_tokens = end_token_idx - start_token_idx; - // printf("start_token_idx: %d, end_token_idx: %d, num_tokens: %d\n", start_token_idx, end_token_idx, num_tokens); + constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); - constexpr int NUM_THREAD_GROUPS = - NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS + constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS assert(NUM_THREADS % THREAD_GROUP_SIZE == 0); constexpr int NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE); constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; @@ -181,7 +180,7 @@ __device__ void paged_attention_kernel( const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; q_vecs[thread_group_offset][i] = *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); } - __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs + __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs // Memory planning. extern __shared__ char shared_mem[]; @@ -244,8 +243,6 @@ __device__ void paged_attention_kernel( qk += (custom_bias_vec != nullptr) ? custom_bias_vec[token_idx] : (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0; - // printf("kv_head_idx: %d, token_idx: %d, qk: %f\n", kv_head_idx, token_idx, qk); - // qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0; if (thread_group_offset == 0) { // Store the partial reductions to shared memory. 
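A note on the hunk above: the new custom_bias path adds a precomputed per-key-token value to the scaled QK logit, in the same spot where the ALiBi term is otherwise applied; in this series that bias carries the T5 relative-position bias produced in t5.py. A rough PyTorch reference of the math for a single query token and a single head (shapes and values are illustrative, this is not the kernel):

    import torch

    torch.manual_seed(0)
    head_size, context_len = 64, 10
    scale = head_size ** -0.5

    q = torch.randn(head_size)               # query of the token being decoded
    k = torch.randn(context_len, head_size)  # cached keys for one head
    v = torch.randn(context_len, head_size)  # cached values for one head
    bias = torch.randn(context_len)          # custom bias, e.g. relative-position bias

    logits = scale * (k @ q) + bias          # qk += custom_bias_vec[token_idx]
    probs = torch.softmax(logits, dim=-1)
    out = probs @ v                          # attention output for this token

When custom_bias is null the kernel falls back to alibi_slope * (token_idx - context_len + 1), as before.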
@@ -298,11 +295,13 @@ __device__ void paged_attention_kernel( // If partitioning is enabled, store the max logit and exp_sum. if (USE_PARTITIONING && thread_idx == 0) { - float* max_logits_ptr = - max_logits + seq_idx * num_heads * max_num_partitions + head_idx * max_num_partitions + partition_idx; + float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + + partition_idx; *max_logits_ptr = qk_max; - float* exp_sums_ptr = - exp_sums + seq_idx * num_heads * max_num_partitions + head_idx * max_num_partitions + partition_idx; + float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + + partition_idx; *exp_sums_ptr = exp_sum; } @@ -421,8 +420,9 @@ __device__ void paged_attention_kernel( // Write the final output. if (warp_idx == 0) { - scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + - head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE; + scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE + + partition_idx * HEAD_SIZE; #pragma unroll for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; @@ -495,15 +495,18 @@ __global__ void paged_attention_v2_kernel( } // Grid: (num_heads, num_seqs). -template +template< + typename scalar_t, + int HEAD_SIZE, + int NUM_THREADS, + int PARTITION_SIZE> __global__ void paged_attention_v2_reduce_kernel( - scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] - const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] - const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] - const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] - const int* __restrict__ context_lens, // [num_seqs] - const int max_num_partitions) { + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] + const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_partitions) { const int num_heads = gridDim.x; const int head_idx = blockIdx.x; const int seq_idx = blockIdx.y; @@ -512,8 +515,8 @@ __global__ void paged_attention_v2_reduce_kernel( if (num_partitions == 1) { // No need to reduce. Only copy tmp_out to out. scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; - const scalar_t* tmp_out_ptr = - tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + head_idx * max_num_partitions * HEAD_SIZE; + const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE; for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) { out_ptr[i] = tmp_out_ptr[i]; } @@ -532,7 +535,8 @@ __global__ void paged_attention_v2_reduce_kernel( // Load max logits to shared memory. 
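For readers following the reduce kernel below: each partition records its own (max_logit, exp_sum, tmp_out), and the exact softmax over the full context is recovered by rescaling every partition by exp(max_i - global_max). A small PyTorch check of that identity (sizes are arbitrary):

    import torch

    torch.manual_seed(0)
    context_len, head_size, num_partitions = 96, 8, 3
    logits = torch.randn(context_len)
    values = torch.randn(context_len, head_size)

    # Reference: softmax over the whole context at once.
    ref = torch.softmax(logits, dim=0) @ values

    # Per-partition pass, as paged_attention_v2 produces it.
    max_logits, exp_sums, tmp_outs = [], [], []
    for part_l, part_v in zip(logits.chunk(num_partitions), values.chunk(num_partitions)):
        m = part_l.max()
        e = torch.exp(part_l - m)
        max_logits.append(m)
        exp_sums.append(e.sum())
        tmp_outs.append((e / e.sum()) @ part_v)  # partition-local attention output

    # Reduction, mirroring paged_attention_v2_reduce_kernel.
    global_max = torch.stack(max_logits).max()
    rescaled = torch.stack(exp_sums) * torch.exp(torch.stack(max_logits) - global_max)
    out = sum(w * o for w, o in zip(rescaled, tmp_outs)) / rescaled.sum()

    print(torch.allclose(out, ref, atol=1e-5))  # True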
float* shared_max_logits = reinterpret_cast(shared_mem); - const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions + head_idx * max_num_partitions; + const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; float max_logit = -FLT_MAX; for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { const float l = max_logits_ptr[i]; @@ -562,7 +566,8 @@ __global__ void paged_attention_v2_reduce_kernel( // Load rescaled exp sums to shared memory. float* shared_exp_sums = reinterpret_cast(shared_mem + sizeof(float) * num_partitions); - const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions + head_idx * max_num_partitions; + const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; float global_exp_sum = 0.0f; for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { float l = shared_max_logits[i]; @@ -575,8 +580,8 @@ __global__ void paged_attention_v2_reduce_kernel( const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f); // Aggregate tmp_out to out. - const scalar_t* tmp_out_ptr = - tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + head_idx * max_num_partitions * HEAD_SIZE; + const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE; scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; #pragma unroll for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) { @@ -588,7 +593,7 @@ __global__ void paged_attention_v2_reduce_kernel( } } -} // namespace vllm +} // namespace vllm #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ diff --git a/csrc/ops.h b/csrc/ops.h index 699ca3b2935f1..ce3a73cc182c3 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -14,8 +14,8 @@ void paged_attention_v1( int block_size, int max_context_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, - const c10::optional& custom_bias); + const c10::optional& custom_bias, + const std::string& kv_cache_dtype); void paged_attention_v2( torch::Tensor& out, @@ -32,8 +32,8 @@ void paged_attention_v2( int block_size, int max_context_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, - const c10::optional& custom_bias); + const c10::optional& custom_bias, + const std::string& kv_cache_dtype); void rms_norm( torch::Tensor& out, diff --git a/test.py b/test.py index 447d38b20ba53..f0f28c83fe331 100644 --- a/test.py +++ b/test.py @@ -1,13 +1,16 @@ from vllm import LLM, SamplingParams +# print proxies + + model = LLM("t5-large", enforce_eager=True, dtype="float16", gpu_memory_utilization=0.5) # model = LLM("gpt2", enforce_eager=True, dtype="float16") sampling_params = SamplingParams(max_tokens=100, temperature=0) outputs = model.generate( [ - "Who is Hilter?", - "Who is Hilter?", + "Who are you?", + "Who are you?", "How do you like your egg made", "How do you like your egg made", ], diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 7b2496d82a5e4..b8214e6567416 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -143,6 +143,8 @@ def __init__( if USE_RAY_COMPILED_DAG: self.forward_dag = self._compiled_ray_dag() + self.is_encoder_decoder = getattr(self.model_config.hf_config, + "is_encoder_decoder", False) def get_tokenizer_for_seq(self, sequence: Sequence): return 
self.tokenizer.get_lora_tokenizer(sequence.lora_request) @@ -153,9 +155,6 @@ def _dispatch_worker(self): Worker = imported_worker.Worker return Worker - self.is_encoder_decoder = getattr(self.model_config.hf_config, - "is_encoder_decoder", False) - def _init_workers(self): # Lazy import the Worker to avoid importing torch.cuda/xformers # before CUDA_VISIBLE_DEVICES is set in the Worker diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index aeb5fe5f47eb8..508b545e0e9f8 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -12,7 +12,7 @@ from vllm._C import cache_ops from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.triton_kernel.prefix_prefill import ( - context_attention_fwd, ) + context_attention_fwd) from vllm.utils import is_hip _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256] @@ -195,9 +195,9 @@ def forward( key = key.unflatten(0, (batch_size, seq_len)) value = value.unflatten(0, (batch_size, seq_len)) - print("query.shape: ", query.shape) - print("key.shape: ", key.shape) - print("value.shape: ", value.shape) + # print("query.shape: ", query.shape) + # print("key.shape: ", key.shape) + # print("value.shape: ", value.shape) out = xops.memory_efficient_attention_forward( query, key, diff --git a/vllm/model_executor/layers/enc_dec_attention.py b/vllm/model_executor/layers/enc_dec_attention.py index 1978a40262f70..a42b9ec73699f 100644 --- a/vllm/model_executor/layers/enc_dec_attention.py +++ b/vllm/model_executor/layers/enc_dec_attention.py @@ -8,7 +8,6 @@ from xformers import ops as xops from xformers.ops.fmha.attn_bias import ( BlockDiagonalCausalMask, - LowerTriangularMaskWithTensorBias, ) from vllm._C import cache_ops from vllm.model_executor.input_metadata import InputMetadata @@ -16,8 +15,6 @@ from vllm.model_executor.layers.attention import paged_attention _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256] -# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. 
-_PARTITION_SIZE = 512 class EncDecAttention(nn.Module): @@ -167,7 +164,8 @@ def forward( value, key_cache, value_cache, - input_metadata.slot_mapping[:, -1].flatten().contiguous() + input_metadata.slot_mapping[:, -1].flatten().contiguous(), + input_metadata.kv_cache_dtype ) # print("key_cache after: ", key_cache) @@ -189,6 +187,7 @@ def forward( scale=self.scale, alibi_slopes=None, custom_bias=input_metadata.attn_bias.to(torch.float32), + kv_cache_dtype=input_metadata.kv_cache_dtype, ) return output.view(batch_size, seq_len, hidden_size) @@ -253,6 +252,7 @@ def forward( key_cache, value_cache, input_metadata.slot_mapping[:, :-1].flatten().contiguous(), + input_metadata.kv_cache_dtype, ) # for slot in input_metadata.slot_mapping[:, :-1].flatten(): @@ -278,6 +278,7 @@ def forward( scale=self.scale, alibi_slopes=None, custom_bias=None, + kv_cache_dtype=input_metadata.kv_cache_dtype, ) return output.view(batch_size, seq_len, hidden_size) diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index fbc3d01252761..661da0fe0434e 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -26,21 +26,18 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.attention import PagedAttention -from vllm.model_executor.layers.linear import ( - ColumnParallelLinear, - LinearMethodBase, - QKVParallelLinear, - RowParallelLinear, -) +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + LinearMethodBase, + QKVParallelLinear, + RowParallelLinear) from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size, ) + get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import ( - default_weight_loader, - hf_model_weights_iterator, -) +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) from vllm.sequence import SamplerOutput KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -56,8 +53,8 @@ def __init__( super().__init__() self.hidden_size = config.hidden_size total_num_heads = config.num_attention_heads - tensor_model_parallel_world_size = get_tensor_model_parallel_world_size( - ) + tensor_model_parallel_world_size = ( + get_tensor_model_parallel_world_size()) assert total_num_heads % tensor_model_parallel_world_size == 0 self.num_heads = total_num_heads // tensor_model_parallel_world_size self.head_dim = self.hidden_size // total_num_heads @@ -137,7 +134,8 @@ def __init__( ): super().__init__() hidden_size = config.hidden_size - inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size + inner_dim = (config.n_inner if config.n_inner is not None else 4 * + hidden_size) self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.attn = GPT2Attention(config, linear_method) @@ -151,7 +149,6 @@ def forward( input_metadata: InputMetadata, ) -> torch.Tensor: residual = hidden_states - print("Block input: ", hidden_states) hidden_states = self.ln_1(hidden_states) attn_output = self.attn( hidden_states=hidden_states, @@ -166,7 +163,6 @@ def forward( feed_forward_hidden_states = self.mlp(hidden_states) # 
residual connection hidden_states = residual + feed_forward_hidden_states - print("Block output: ", hidden_states) return hidden_states @@ -199,17 +195,12 @@ def forward( input_metadata: InputMetadata, ) -> torch.Tensor: inputs_embeds = self.wte(input_ids) - print("input_embeds: ", inputs_embeds) position_embeds = self.wpe(position_ids) hidden_states = inputs_embeds + position_embeds - print("hidden_states: ", hidden_states) + for i in range(len(self.h)): layer = self.h[i] hidden_states = layer(hidden_states, kv_caches[i], input_metadata) - if i == 0 and kv_caches[0][0] is not None: - print("hidden_states shape: ", hidden_states.shape) - print("kv cache shape", kv_caches[i][0].shape) - print("Input metadata", input_metadata) hidden_states = self.ln_f(hidden_states) return hidden_states @@ -236,9 +227,6 @@ def forward( kv_caches: List[KVCache], input_metadata: InputMetadata, ) -> torch.Tensor: - if kv_caches[0][0] is not None: - print("input_ids: ", input_ids) - print("slot mapping shape: ", input_metadata.slot_mapping.shape) hidden_states = self.transformer(input_ids, positions, kv_caches, input_metadata) return hidden_states @@ -252,13 +240,11 @@ def sample( sampling_metadata) return next_tokens - def load_weights( - self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None, - ): + def load_weights(self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None): params_dict = dict(self.named_parameters(remove_duplicate=False)) for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): diff --git a/vllm/model_executor/models/t5.py b/vllm/model_executor/models/t5.py index e1a100060f703..ef1ac52d69e80 100644 --- a/vllm/model_executor/models/t5.py +++ b/vllm/model_executor/models/t5.py @@ -358,16 +358,16 @@ def forward( input_metadata: InputMetadata, ) -> torch.Tensor: normed_hidden_states = self.layer_norm(hidden_states) - # print("self attention input shape: ", normed_hidden_states.shape) - # print("self_attention input: ", normed_hidden_states) + print("self attention input shape: ", normed_hidden_states.shape) + print("self_attention input: ", normed_hidden_states) attention_output = self.SelfAttention( hidden_states=normed_hidden_states, kv_cache=kv_cache, input_metadata=input_metadata, encoder_hidden_states=None, ) - # print("self attention output shape: ", attention_output.shape) - # print("self_attention output: ", attention_output) + print("self attention output shape: ", attention_output.shape) + print("self_attention output: ", attention_output) hidden_states = hidden_states + attention_output return hidden_states @@ -395,16 +395,16 @@ def forward( encoder_hidden_states: Optional[torch.Tensor], ) -> torch.Tensor: normed_hidden_states = self.layer_norm(hidden_states) - # print("cross attention input shape: ", normed_hidden_states.shape) - # print("cross_attention input: ", normed_hidden_states) + print("cross attention input shape: ", normed_hidden_states.shape) + print("cross_attention input: ", normed_hidden_states) attention_output = self.EncDecAttention( hidden_states=normed_hidden_states, kv_cache=kv_cache, input_metadata=input_metadata, encoder_hidden_states=encoder_hidden_states, ) - # print("cross attention output shape: ", attention_output.shape) - # print("cross_attention output: ", attention_output) + print("cross attention output shape: ", attention_output.shape) + 
print("cross_attention output: ", attention_output) hidden_states = hidden_states + attention_output return hidden_states @@ -537,8 +537,8 @@ def forward( # hidden_states = torch.cat( # [encoder_hidden_states, hidden_states], dim=1 # ) - # print("final_hidden_states shape: ", hidden_states.shape) - # print("final_hidden_states: ", hidden_states) + print("final_hidden_states shape: ", hidden_states.shape) + print("final_hidden_states: ", hidden_states) return hidden_states diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 28a97f3e85b99..d17d5b080487a 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -250,15 +250,6 @@ def _prepare_prompt( context_lens_tensor = torch.tensor(context_lens, dtype=torch.int, device=self.device) - # Prepare prefix block tables - max_prompt_block_table_len = max(len(t) for t in prefix_block_tables) - block_tables = _make_tensor_with_pad( - prefix_block_tables, - max_len=max_prompt_block_table_len, - pad=0, - dtype=torch.int, - device=self.device, - ) if self.is_encoder_decoder: padded_block_tables = [] # Pad the encoder block tables to the same length and then add a decoder block table in the end @@ -271,8 +262,10 @@ def _prepare_prompt( padded_block_tables, max_len=max_block_table_len, pad=0, - dtype=torch.int) + dtype=torch.int, + device = self.device) else: + # Prepare prefix block tables max_prompt_block_table_len = max( len(t) for t in prefix_block_tables) block_tables_tensor = _make_tensor_with_pad( From 2fb690521e481c340e73acd88f1dd858daa79bc3 Mon Sep 17 00:00:00 2001 From: Jin Shang Date: Thu, 29 Feb 2024 23:46:38 +0800 Subject: [PATCH 03/11] lint --- benchmarks/benchmark_throughput.py | 12 +- test.py | 6 +- vllm/engine/llm_engine.py | 3 +- vllm/model_executor/layers/attention.py | 16 +- .../layers/enc_dec_attention.py | 46 ++-- vllm/model_executor/models/t5.py | 209 +++++++++--------- vllm/worker/model_runner.py | 9 +- 7 files changed, 155 insertions(+), 146 deletions(-) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 34b1e8fdaad0b..d80cf8ef5409b 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -6,8 +6,8 @@ from typing import List, Optional, Tuple import torch -from transformers import (AutoModelForCausalLM, T5ForConditionalGeneration, AutoTokenizer, - PreTrainedTokenizerBase) +from transformers import (AutoModelForCausalLM, T5ForConditionalGeneration, + AutoTokenizer, PreTrainedTokenizerBase) from tqdm import tqdm @@ -125,10 +125,14 @@ def run_hf( assert not use_beam_search if "t5" in model: llm = T5ForConditionalGeneration.from_pretrained( - model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code) + model, + torch_dtype=torch.float16, + trust_remote_code=trust_remote_code) else: llm = AutoModelForCausalLM.from_pretrained( - model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code) + model, + torch_dtype=torch.float16, + trust_remote_code=trust_remote_code) if llm.config.model_type == "llama": # To enable padding in the HF backend. 
tokenizer.pad_token = tokenizer.eos_token diff --git a/test.py b/test.py index f0f28c83fe331..cdf49ebad2b12 100644 --- a/test.py +++ b/test.py @@ -2,8 +2,10 @@ # print proxies - -model = LLM("t5-large", enforce_eager=True, dtype="float16", gpu_memory_utilization=0.5) +model = LLM("t5-large", + enforce_eager=True, + dtype="float16", + gpu_memory_utilization=0.5) # model = LLM("gpt2", enforce_eager=True, dtype="float16") sampling_params = SamplingParams(max_tokens=100, temperature=0) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index b8214e6567416..e925c02629054 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -144,7 +144,8 @@ def __init__( self.forward_dag = self._compiled_ray_dag() self.is_encoder_decoder = getattr(self.model_config.hf_config, - "is_encoder_decoder", False) + "is_encoder_decoder", False) + def get_tokenizer_for_seq(self, sequence: Sequence): return self.tokenizer.get_lora_tokenizer(sequence.lora_request) diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 508b545e0e9f8..c82d5aff83a3c 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -230,18 +230,10 @@ def forward( else: # Decoding run. output = paged_attention( - query, - key_cache, - value_cache, - input_metadata.block_tables, - input_metadata.context_lens, - input_metadata.max_context_len, - self.num_kv_heads, - self.scale, - self.alibi_slopes, - None, - input_metadata.kv_cache_dtype - ) + query, key_cache, value_cache, input_metadata.block_tables, + input_metadata.context_lens, input_metadata.max_context_len, + self.num_kv_heads, self.scale, self.alibi_slopes, None, + input_metadata.kv_cache_dtype) # Reshape the output tensor. return output.view(batch_size, seq_len, hidden_size) diff --git a/vllm/model_executor/layers/enc_dec_attention.py b/vllm/model_executor/layers/enc_dec_attention.py index a42b9ec73699f..9d63fbe7ae36d 100644 --- a/vllm/model_executor/layers/enc_dec_attention.py +++ b/vllm/model_executor/layers/enc_dec_attention.py @@ -7,8 +7,7 @@ from xformers import ops as xops from xformers.ops.fmha.attn_bias import ( - BlockDiagonalCausalMask, -) + BlockDiagonalCausalMask, ) from vllm._C import cache_ops from vllm.model_executor.input_metadata import InputMetadata from vllm.utils import is_hip @@ -18,6 +17,7 @@ class EncDecAttention(nn.Module): + def __init__( self, num_heads: int, @@ -30,13 +30,12 @@ def __init__( self.scale = float(scale) if self.head_size not in _SUPPORTED_HEAD_SIZES: - raise ValueError( - f"head_size ({self.head_size}) is not supported. " - f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}." - ) + raise ValueError(f"head_size ({self.head_size}) is not supported. " + f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.") class EncoderAttention(EncDecAttention): + def __init__( self, num_heads: int, @@ -78,8 +77,7 @@ def forward( # print("query shape: ", query.shape) if input_metadata.attn_bias is None: input_metadata.attn_bias = BlockDiagonalCausalMask.from_seqlens( - [seq_len] * batch_size - ) + [seq_len] * batch_size) # When using custom attention bias, xformers requires the bias to # be sliced from a tensor whose length is a multiple of 8. 
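As background for the mask built just above: BlockDiagonalCausalMask.from_seqlens([seq_len] * batch_size) restricts attention to positions within the same flattened prompt, and within a prompt only to non-future positions. Its dense equivalent is an additive bias along the lines of this pure-PyTorch sketch (tiny made-up lengths; the real code never materializes such a matrix):

    import torch

    seq_lens = [3, 2]  # two flattened prompts, as passed to from_seqlens
    total = sum(seq_lens)
    bias = torch.full((total, total), float("-inf"))

    start = 0
    for n in seq_lens:
        # Within one prompt, position i may attend to positions j <= i.
        causal = torch.triu(torch.full((n, n), float("-inf")), diagonal=1)
        bias[start:start + n, start:start + n] = causal
        start += n

    print(bias)  # cross-prompt entries stay -inf; each block is lower-triangular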
# padded_len = (seq_len + 7) // 8 * 8 @@ -98,15 +96,15 @@ def forward( attn_bias=input_metadata.attn_bias, p=0.0, scale=self.scale, - op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] - if (is_hip()) - else None, + op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if + (is_hip()) else None, ) output = out.view(batch_size, seq_len, hidden_size) return output class DecoderAttention(EncDecAttention): + def __init__( self, num_heads: int, @@ -160,13 +158,9 @@ def forward( # print("value_cache before: ", value_cache) cache_ops.reshape_and_cache( - key, - value, - key_cache, - value_cache, + key, value, key_cache, value_cache, input_metadata.slot_mapping[:, -1].flatten().contiguous(), - input_metadata.kv_cache_dtype - ) + input_metadata.kv_cache_dtype) # print("key_cache after: ", key_cache) # print("value_cache after: ", value_cache) @@ -174,7 +168,9 @@ def forward( max_prompt_len = input_metadata.prompt_lens.max().item() block_size = value_cache.shape[3] prompt_table_len = (max_prompt_len + block_size - 1) // block_size - block_tables = input_metadata.block_tables[:, prompt_table_len:].contiguous() + block_tables = input_metadata.block_tables[:, + prompt_table_len:].contiguous( + ) # print("decoder self attention block_tables", block_tables) output = paged_attention( query=query, @@ -193,6 +189,7 @@ def forward( class CrossAttention(EncDecAttention): + def __init__( self, num_heads: int, @@ -240,11 +237,8 @@ def forward( # print("slot mapping: ", input_metadata.slot_mapping[:, :-1].flatten()) # Reshape the keys and values and store them in the cache. # It only happens during the first pass. - if ( - input_metadata.is_prompt - and key_cache is not None - and value_cache is not None - ): + if (input_metadata.is_prompt and key_cache is not None + and value_cache is not None): assert key is not None and value is not None cache_ops.reshape_and_cache( key, @@ -254,7 +248,7 @@ def forward( input_metadata.slot_mapping[:, :-1].flatten().contiguous(), input_metadata.kv_cache_dtype, ) - + # for slot in input_metadata.slot_mapping[:, :-1].flatten(): # if slot != -1: # block_number = slot//16; @@ -265,7 +259,9 @@ def forward( # print("max_prompt_len: ", max_prompt_len) block_size = value_cache.shape[3] prompt_table_len = (max_prompt_len + block_size - 1) // block_size - block_tables = input_metadata.block_tables[:, :prompt_table_len].contiguous() + block_tables = input_metadata.block_tables[:, : + prompt_table_len].contiguous( + ) output = paged_attention( query=query, diff --git a/vllm/model_executor/models/t5.py b/vllm/model_executor/models/t5.py index ef1ac52d69e80..f3403926d3e6c 100644 --- a/vllm/model_executor/models/t5.py +++ b/vllm/model_executor/models/t5.py @@ -40,8 +40,7 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size, -) + get_tensor_model_parallel_world_size, ) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import ( default_weight_loader, @@ -53,6 +52,7 @@ class T5LayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): """ Construct a layernorm module in the T5 style. No bias and no subtraction of mean. @@ -67,8 +67,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # w/o mean and there is no bias. 
Additionally we want to make sure that the accumulation for # half-precision inputs is done in fp32 - variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + variance = hidden_states.to(torch.float32).pow(2).mean(-1, + keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + + self.variance_epsilon) # convert into half-precision if necessary if self.weight.dtype in [torch.float16, torch.bfloat16]: @@ -78,6 +80,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class T5DenseActDense(nn.Module): + def __init__(self, config: T5Config): super().__init__() self.wi = ColumnParallelLinear(config.d_model, config.d_ff, bias=False) @@ -92,10 +95,15 @@ def forward(self, hidden_states): class T5DenseGatedActDense(nn.Module): + def __init__(self, config: T5Config): super().__init__() - self.wi_0 = ColumnParallelLinear(config.d_model, config.d_ff, bias=False) - self.wi_1 = ColumnParallelLinear(config.d_model, config.d_ff, bias=False) + self.wi_0 = ColumnParallelLinear(config.d_model, + config.d_ff, + bias=False) + self.wi_1 = ColumnParallelLinear(config.d_model, + config.d_ff, + bias=False) self.wo = RowParallelLinear(config.d_ff, config.d_model, bias=False) self.act = get_act_fn(config.dense_act_fn) @@ -108,6 +116,7 @@ def forward(self, hidden_states): class T5LayerFF(nn.Module): + def __init__(self, config: T5Config): super().__init__() if config.is_gated_act: @@ -115,7 +124,8 @@ def __init__(self, config: T5Config): else: self.DenseReluDense = T5DenseActDense(config) - self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.layer_norm = T5LayerNorm(config.d_model, + eps=config.layer_norm_epsilon) def forward(self, hidden_states): forwarded_states = self.layer_norm(hidden_states) @@ -125,6 +135,7 @@ def forward(self, hidden_states): class T5Attention(nn.Module): + def __init__( self, config: T5Config, @@ -139,7 +150,8 @@ def __init__( self.d_model = config.d_model self.key_value_proj_dim = config.d_kv total_num_heads = config.num_heads - tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() + tensor_model_parallel_world_size = get_tensor_model_parallel_world_size( + ) assert total_num_heads % tensor_model_parallel_world_size == 0 self.n_heads = total_num_heads // tensor_model_parallel_world_size self.inner_dim = self.n_heads * self.key_value_proj_dim @@ -151,22 +163,25 @@ def __init__( if has_relative_attention_bias: self.relative_attention_bias = nn.Embedding( - self.relative_attention_num_buckets, self.n_heads - ) + self.relative_attention_num_buckets, self.n_heads) self.is_cross = is_cross if self.is_decoder: if self.is_cross: - self.attn = CrossAttention(self.n_heads, self.key_value_proj_dim, 1) + self.attn = CrossAttention(self.n_heads, + self.key_value_proj_dim, 1) else: - self.attn = DecoderAttention(self.n_heads, self.key_value_proj_dim, 1) + self.attn = DecoderAttention(self.n_heads, + self.key_value_proj_dim, 1) else: - self.attn = EncoderAttention(self.n_heads, self.key_value_proj_dim, 1) + self.attn = EncoderAttention(self.n_heads, self.key_value_proj_dim, + 1) @staticmethod - def _relative_position_bucket( - relative_position, bidirectional=True, num_buckets=32, max_distance=128 - ): + def _relative_position_bucket(relative_position, + bidirectional=True, + num_buckets=32, + max_distance=128): """ Adapted from Mesh Tensorflow: 
https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 @@ -190,12 +205,12 @@ def _relative_position_bucket( relative_buckets = 0 if bidirectional: num_buckets //= 2 - relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_buckets += (relative_position > 0).to( + torch.long) * num_buckets relative_position = torch.abs(relative_position) else: - relative_position = -torch.min( - relative_position, torch.zeros_like(relative_position) - ) + relative_position = -torch.min(relative_position, + torch.zeros_like(relative_position)) # now relative_position is in the range [0, inf) # half of the buckets are for exact increments in positions @@ -204,31 +219,28 @@ def _relative_position_bucket( # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance relative_position_if_large = max_exact + ( - torch.log(relative_position.float() / max_exact) - / math.log(max_distance / max_exact) - * (num_buckets - max_exact) - ).to(torch.long) + torch.log(relative_position.float() / max_exact) / + math.log(max_distance / max_exact) * + (num_buckets - max_exact)).to(torch.long) relative_position_if_large = torch.min( relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1), ) - relative_buckets += torch.where( - is_small, relative_position, relative_position_if_large - ) + relative_buckets += torch.where(is_small, relative_position, + relative_position_if_large) return relative_buckets def compute_bias(self, query_length, key_length): """Compute binned relative position bias""" - context_position = torch.arange(query_length, dtype=torch.long, device="cuda")[ - :, None - ] - memory_position = torch.arange(key_length, dtype=torch.long, device="cuda")[ - None, : - ] - relative_position = ( - memory_position - context_position - ) # shape (query_length, key_length) + context_position = torch.arange(query_length, + dtype=torch.long, + device="cuda")[:, None] + memory_position = torch.arange(key_length, + dtype=torch.long, + device="cuda")[None, :] + relative_position = (memory_position - context_position + ) # shape (query_length, key_length) relative_position_bucket = self._relative_position_bucket( relative_position, # shape (query_length, key_length) bidirectional=(not self.is_decoder), @@ -273,19 +285,15 @@ def forward( k, _ = self.k(hidden_states) v, _ = self.v(hidden_states) - - if input_metadata.attn_bias is None: input_metadata.attn_bias = self.compute_bias( - prompt_len, (prompt_len + block_size - 1) // block_size * block_size - ).repeat(batch_size, 1, 1, 1) + prompt_len, (prompt_len + block_size - 1) // block_size * + block_size).repeat(batch_size, 1, 1, 1) for i in range(batch_size): input_metadata.attn_bias[ - i, - :, - :, - input_metadata.prompt_lens[i] :, - ] = torch.finfo(input_metadata.attn_bias.dtype).min + i, :, :, + input_metadata.prompt_lens[i]:, ] = torch.finfo( + input_metadata.attn_bias.dtype).min # print("input_metadata.attn_bias shape", input_metadata.attn_bias.shape) # print("input_metadata.attn_bias", input_metadata.attn_bias) @@ -300,17 +308,20 @@ def forward( if input_metadata.attn_bias is None: position_bias = self.compute_bias( 1 if input_metadata.is_prompt else context_len, - (context_len + block_size - 1) // block_size * block_size - ).repeat(batch_size, 1, 1, 1) + (context_len + block_size - 1) // block_size * + block_size).repeat(batch_size, 1, 1, 1) # print("position_bias shape", position_bias.shape) # 
print("position_bias", position_bias) - input_metadata.attn_bias = position_bias[:, :, -seq_len:, :].contiguous() + input_metadata.attn_bias = position_bias[:, :, + -seq_len:, :].contiguous( + ) # print("input_metadata.attn_bias shape", input_metadata.attn_bias.shape) # print("input_metadata.attn_bias", input_metadata.attn_bias) key_cache, value_cache = kv_cache - attn_output = self.attn(q, k, v, key_cache, value_cache, input_metadata) + attn_output = self.attn(q, k, v, key_cache, value_cache, + input_metadata) else: # print("cross attention!") @@ -324,18 +335,19 @@ def forward( # print("k shape", k.shape) # for i in range(k.shape[0]): # for j in range(k.shape[1]): - # print(f"key at batch {i} and pos {j}: ", k[i, j, :].reshape(1, 8, 64)) - attn_output = self.attn(q, k, v, key_cache, value_cache, input_metadata) + # print(f"key at batch {i} and pos {j}: ", k[i, j, :].reshape(1, 8, 64)) + attn_output = self.attn(q, k, v, key_cache, value_cache, + input_metadata) else: - attn_output = self.attn( - q, None, None, key_cache, value_cache, input_metadata - ) + attn_output = self.attn(q, None, None, key_cache, value_cache, + input_metadata) attn_output, _ = self.o(attn_output) return attn_output class T5LayerSelfAttention(nn.Module): + def __init__( self, config, @@ -349,7 +361,8 @@ def __init__( has_relative_attention_bias=has_relative_attention_bias, linear_method=linear_method, ) - self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.layer_norm = T5LayerNorm(config.d_model, + eps=config.layer_norm_epsilon) def forward( self, @@ -373,6 +386,7 @@ def forward( class T5LayerCrossAttention(nn.Module): + def __init__( self, config, @@ -385,7 +399,8 @@ def __init__( has_relative_attention_bias=False, linear_method=linear_method, ) - self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.layer_norm = T5LayerNorm(config.d_model, + eps=config.layer_norm_epsilon) def forward( self, @@ -410,6 +425,7 @@ def forward( class T5Block(nn.Module): + def __init__( self, config, @@ -424,12 +440,10 @@ def __init__( config, has_relative_attention_bias=has_relative_attention_bias, linear_method=linear_method, - ) - ) + )) if self.is_decoder: self.layer.append( - T5LayerCrossAttention(config, linear_method=linear_method) - ) + T5LayerCrossAttention(config, linear_method=linear_method)) self.layer.append(T5LayerFF(config)) @@ -452,9 +466,9 @@ def forward( torch.finfo(hidden_states.dtype).max - 1000, torch.finfo(hidden_states.dtype).max, ) - hidden_states = torch.clamp( - hidden_states, min=-clamp_value, max=clamp_value - ) + hidden_states = torch.clamp(hidden_states, + min=-clamp_value, + max=clamp_value) if self.is_decoder: hidden_states = self.layer[1]( @@ -469,9 +483,9 @@ def forward( torch.finfo(hidden_states.dtype).max - 1000, torch.finfo(hidden_states.dtype).max, ) - hidden_states = torch.clamp( - hidden_states, min=-clamp_value, max=clamp_value - ) + hidden_states = torch.clamp(hidden_states, + min=-clamp_value, + max=clamp_value) # Apply Feed Forward layer hidden_states = self.layer[-1](hidden_states) @@ -480,6 +494,7 @@ def forward( class T5Stack(nn.Module): + def __init__( self, config: T5Config, @@ -490,20 +505,16 @@ def __init__( self.is_decoder = config.is_decoder self.embed_tokens = embed_tokens - self.block = nn.ModuleList( - [ - T5Block( - config, - has_relative_attention_bias=(i == 0), - linear_method=linear_method, - ) - for i in range(config.num_layers) - ] - ) + self.block = nn.ModuleList([ + T5Block( + config, + 
has_relative_attention_bias=(i == 0), + linear_method=linear_method, + ) for i in range(config.num_layers) + ]) - self.final_layer_norm = T5LayerNorm( - config.d_model, eps=config.layer_norm_epsilon - ) + self.final_layer_norm = T5LayerNorm(config.d_model, + eps=config.layer_norm_epsilon) def forward( self, @@ -531,8 +542,8 @@ def forward( hidden_states = self.final_layer_norm(hidden_states) # if encoder_hidden_states is not None: - # print("hidden_states shape:" , hidden_states.shape) - # print("encoder_hidden_states shape:" , encoder_hidden_states.shape) + # print("hidden_states shape:" , hidden_states.shape) + # print("encoder_hidden_states shape:" , encoder_hidden_states.shape) # # Attach encoder hidden states # hidden_states = torch.cat( # [encoder_hidden_states, hidden_states], dim=1 @@ -543,9 +554,10 @@ def forward( class T5ForConditionalGeneration(nn.Module): - def __init__( - self, config: T5Config, linear_method: Optional[LinearMethodBase] = None - ): + + def __init__(self, + config: T5Config, + linear_method: Optional[LinearMethodBase] = None): super().__init__() self.config = config self.model_dim = config.d_model @@ -574,22 +586,20 @@ def forward( # print("input_metadata: ", input_metadata) if input_metadata.is_prompt: # prompt run, need to run encoder once - hidden_states = self.encoder(input_ids, kv_caches, input_metadata, None) + hidden_states = self.encoder(input_ids, kv_caches, input_metadata, + None) # Clear the attention bias input_metadata.attn_bias = None batch_size = input_ids.shape[0] - input_ids = ( - torch.ones(batch_size, 1, dtype=torch.long) - * self.config.decoder_start_token_id - ).cuda() + input_ids = (torch.ones(batch_size, 1, dtype=torch.long) * + self.config.decoder_start_token_id).cuda() else: hidden_states = None if kv_caches[0][0] is not None: # Skip decoder for profiling run - hidden_states = self.decoder( - input_ids, kv_caches, input_metadata, hidden_states - ) + hidden_states = self.decoder(input_ids, kv_caches, input_metadata, + hidden_states) if self.config.tie_word_embeddings: # Rescale output before projecting on vocab @@ -598,9 +608,11 @@ def forward( return hidden_states - def sample(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata): + def sample(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata): # logger.info(f"decoder_outputs: {decoder_outputs}") - next_tokens = self.sampler(self.shared.weight, hidden_states, sampling_metadata) + next_tokens = self.sampler(self.shared.weight, hidden_states, + sampling_metadata) # logger.info(f"next_tokens: {next_tokens}") return next_tokens @@ -613,8 +625,7 @@ def load_weights( ): params_dict = dict(self.named_parameters(remove_duplicate=False)) for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision - ): + model_name_or_path, cache_dir, load_format, revision): if "EncDecAttention.relative_attention_bias" in name: continue @@ -622,7 +633,7 @@ def load_weights( param = params_dict[name] assert param.shape == loaded_weight.shape, ( f"{name} shape mismatch between model and checkpoint: " - f"{param.shape} != {loaded_weight.shape}" - ) - weight_loader = getattr(param, "weight_loader", default_weight_loader) + f"{param.shape} != {loaded_weight.shape}") + weight_loader = getattr(param, "weight_loader", + default_weight_loader) weight_loader(param, loaded_weight) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index d17d5b080487a..0e2cf1650ba84 100644 --- a/vllm/worker/model_runner.py +++ 
b/vllm/worker/model_runner.py @@ -263,7 +263,7 @@ def _prepare_prompt( max_len=max_block_table_len, pad=0, dtype=torch.int, - device = self.device) + device=self.device) else: # Prepare prefix block tables max_prompt_block_table_len = max( @@ -354,7 +354,8 @@ def _prepare_decode( context_lens.append(context_len) block_table = seq_group_metadata.block_tables[seq_id] - max_block_table_len = max(max_block_table_len, len(block_table)) + max_block_table_len = max(max_block_table_len, + len(block_table)) block_number = block_table[position // self.block_size] block_offset = position % self.block_size slot = block_number * self.block_size + block_offset @@ -415,7 +416,9 @@ def _prepare_decode( dtype=torch.int, device=self.device) - prompt_lens = torch.tensor(prompt_lens, dtype=torch.int, device=self.device) + prompt_lens = torch.tensor(prompt_lens, + dtype=torch.int, + device=self.device) if use_captured_graph: # The shape of graph_block_tables is # [max batch size, max context len // block size]. From be58c3b95af0e7c9c3f58e3ae56726f42b140076 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 1 Mar 2024 15:35:13 -0500 Subject: [PATCH 04/11] T5 enc/dec example file; linting/formatting --- examples/offline_inference_enc_dec.py | 38 +++++++++++++++++++ test.py | 22 ----------- .../layers/enc_dec_attention.py | 1 - vllm/model_executor/models/t5.py | 8 ++-- vllm/sequence.py | 2 +- vllm/worker/model_runner.py | 1 - 6 files changed, 42 insertions(+), 30 deletions(-) create mode 100644 examples/offline_inference_enc_dec.py delete mode 100644 test.py diff --git a/examples/offline_inference_enc_dec.py b/examples/offline_inference_enc_dec.py new file mode 100644 index 0000000000000..b2a47865b1cd2 --- /dev/null +++ b/examples/offline_inference_enc_dec.py @@ -0,0 +1,38 @@ +''' +Affirm T5 model outputs match between vLLM and native PyTorch + +Scenarios: +* t5-small, t5-large +* float16, float32, bfloat16, bfloat32 +* Custom prompts & num. prompts +''' + +from vllm import LLM, SamplingParams + +hf_model_id="t5-small" +dtype="float16" +prompts=[ + "Who are you?", + "Who are you?", + "How do you like your egg made", + "How do you like your egg made", +] + + +model = LLM(hf_model_id, + enforce_eager=True, + dtype=dtype, + gpu_memory_utilization=0.5) + +sampling_params = SamplingParams(max_tokens=100, temperature=0) + +outputs = model.generate( + prompts, + sampling_params=sampling_params, +) + +# Print the vLLM outputs. 
+for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/test.py b/test.py deleted file mode 100644 index cdf49ebad2b12..0000000000000 --- a/test.py +++ /dev/null @@ -1,22 +0,0 @@ -from vllm import LLM, SamplingParams - -# print proxies - -model = LLM("t5-large", - enforce_eager=True, - dtype="float16", - gpu_memory_utilization=0.5) -# model = LLM("gpt2", enforce_eager=True, dtype="float16") -sampling_params = SamplingParams(max_tokens=100, temperature=0) - -outputs = model.generate( - [ - "Who are you?", - "Who are you?", - "How do you like your egg made", - "How do you like your egg made", - ], - sampling_params=sampling_params, -) - -print(outputs) diff --git a/vllm/model_executor/layers/enc_dec_attention.py b/vllm/model_executor/layers/enc_dec_attention.py index 9d63fbe7ae36d..fa14e5d09ff49 100644 --- a/vllm/model_executor/layers/enc_dec_attention.py +++ b/vllm/model_executor/layers/enc_dec_attention.py @@ -3,7 +3,6 @@ import torch import torch.nn as nn -import torch.nn.functional as F from xformers import ops as xops from xformers.ops.fmha.attn_bias import ( diff --git a/vllm/model_executor/models/t5.py b/vllm/model_executor/models/t5.py index f3403926d3e6c..c16577b50e37c 100644 --- a/vllm/model_executor/models/t5.py +++ b/vllm/model_executor/models/t5.py @@ -18,7 +18,8 @@ """ PyTorch T5 model.""" from typing import List, Optional, Tuple -import math, copy +import math +import copy import torch from torch import nn @@ -34,11 +35,9 @@ from vllm.model_executor.layers.linear import ( ColumnParallelLinear, LinearMethodBase, - QKVParallelLinear, RowParallelLinear, ) from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_world_size, ) from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -46,7 +45,6 @@ default_weight_loader, hf_model_weights_iterator, ) -from vllm.sequence import SamplerOutput KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -63,7 +61,7 @@ def __init__(self, hidden_size, eps=1e-6): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean - # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated + # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for # half-precision inputs is done in fp32 diff --git a/vllm/sequence.py b/vllm/sequence.py index f601837b40915..9f74e74d20b04 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -150,7 +150,7 @@ def __init__( self.logical_token_blocks: List[LogicalTokenBlock] = [] initial_token_ids = prompt_token_ids if is_decoder_encoder: - # We need to seperate the prompt and generated tokens for encoder-decoder models. + # We need to separate the prompt and generated tokens for encoder-decoder models. 
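+            # The prompt is padded to a whole number of logical blocks
+            # (ceil(prompt_len / block_size) blocks), so generated decoder
+            # tokens always begin at a fresh block boundary.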
num_prompt_blocks = (len(prompt_token_ids) + block_size - 1) // block_size padded_prompt_len = num_prompt_blocks * block_size diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 0e2cf1650ba84..43e0a0b9074a5 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -773,7 +773,6 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: slot_mapping = torch.empty(max_batch_size, 1, dtype=torch.long).cuda() slot_mapping.fill_(_PAD_SLOT_ID) context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda() - prompt_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda() block_tables = torch.from_numpy(self.graph_block_tables).cuda() graph_batch_size = _get_graph_batch_size( From 70837fd046929ce60ac95084e0edf639a02f2b71 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 1 Mar 2024 16:37:08 -0500 Subject: [PATCH 05/11] native/vllm t5 comparison test --- examples/offline_inference_enc_dec.py | 52 ++++++++++++++++++++++----- 1 file changed, 43 insertions(+), 9 deletions(-) diff --git a/examples/offline_inference_enc_dec.py b/examples/offline_inference_enc_dec.py index b2a47865b1cd2..f554302fefb8c 100644 --- a/examples/offline_inference_enc_dec.py +++ b/examples/offline_inference_enc_dec.py @@ -5,20 +5,48 @@ * t5-small, t5-large * float16, float32, bfloat16, bfloat32 * Custom prompts & num. prompts + +Output: for several prompts, compare native PyTorch & vLLM prompt completions ''' +import torch from vllm import LLM, SamplingParams +from transformers import T5Tokenizer, T5ForConditionalGeneration -hf_model_id="t5-small" -dtype="float16" -prompts=[ +hf_model_id = "t5-small" +dtype = "bfloat16" +prompts = [ "Who are you?", "Who are you?", "How do you like your egg made", "How do you like your egg made", ] +dtype_obj = getattr(torch, dtype) + +# Native PyTorch test + +# - Model and tokenizer initialization +tokenizer = T5Tokenizer.from_pretrained(hf_model_id) +model = T5ForConditionalGeneration.from_pretrained(hf_model_id).to( + dtype=dtype_obj) + +# - Assume 'dtype' is already defined, e.g., dtype=torch.float32 +# - Tokenizing the prompts list with specified data type +input_ids = tokenizer(prompts, + return_tensors="pt", + padding=True, + truncation=True).input_ids + +# - If using GPU, also send input_ids to the same device as the model +if torch.cuda.is_available(): + model = model.cuda() # Move model to GPU + input_ids = input_ids.cuda() # Move input_ids to GPU + +# - Generating outputs for all tokenized prompts +native_outputs = model.generate(input_ids).cpu() +# vLLM test model = LLM(hf_model_id, enforce_eager=True, dtype=dtype, @@ -26,13 +54,19 @@ sampling_params = SamplingParams(max_tokens=100, temperature=0) -outputs = model.generate( +vllm_outputs = model.generate( prompts, sampling_params=sampling_params, ) -# Print the vLLM outputs. 
-for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") +# Print native & vLLM outputs +i = 0 +for native_output, vllm_output in zip(native_outputs, vllm_outputs): + prompt = prompts[i] # Get the corresponding prompt for this output + native_generated_text = tokenizer.decode( + native_output, skip_special_tokens=True) # Decode the generated text + vllm_generated_text = vllm_output.outputs[0].text + print( + f"Prompt: {prompt!r}, Native PyTorch generated text: {native_generated_text!r}, vLLM generated text: {vllm_generated_text!r}" + ) + i += 1 From 43e920e233e4da5658a24fea9b2c19787a7fde3e Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Sat, 2 Mar 2024 00:05:28 -0500 Subject: [PATCH 06/11] remove debug print statements --- vllm/model_executor/models/t5.py | 44 -------------------------------- 1 file changed, 44 deletions(-) diff --git a/vllm/model_executor/models/t5.py b/vllm/model_executor/models/t5.py index c16577b50e37c..d6c27ed5b4903 100644 --- a/vllm/model_executor/models/t5.py +++ b/vllm/model_executor/models/t5.py @@ -258,26 +258,17 @@ def forward( input_metadata: InputMetadata, encoder_hidden_states: Optional[torch.Tensor], ) -> torch.Tensor: - # print("hidden_states shape", hidden_states.shape) - # print("hidden_states", hidden_states) q, _ = self.q(hidden_states) - # print("q shape", q.shape) - # print("q", q) batch_size = hidden_states.shape[0] seq_len = hidden_states.shape[1] prompt_len = input_metadata.prompt_lens.max().item() context_len = input_metadata.context_lens.max().item() context_len = max(context_len, 1) - # print("batch_size", batch_size) - # print("seq_len", seq_len) - # print("prompt_len", prompt_len) - # print("context_len", context_len) block_size = 16 if not self.is_decoder: - # print("encoder self attention!") assert kv_cache is None # Encoder self attention, no cache operations k, _ = self.k(hidden_states) @@ -293,12 +284,9 @@ def forward( input_metadata.prompt_lens[i]:, ] = torch.finfo( input_metadata.attn_bias.dtype).min - # print("input_metadata.attn_bias shape", input_metadata.attn_bias.shape) - # print("input_metadata.attn_bias", input_metadata.attn_bias) attn_output = self.attn(q, k, v, input_metadata) elif not self.is_cross: - # print("decoder self attention!") # Decoder self attention k, _ = self.k(hidden_states) v, _ = self.v(hidden_states) @@ -308,13 +296,9 @@ def forward( 1 if input_metadata.is_prompt else context_len, (context_len + block_size - 1) // block_size * block_size).repeat(batch_size, 1, 1, 1) - # print("position_bias shape", position_bias.shape) - # print("position_bias", position_bias) input_metadata.attn_bias = position_bias[:, :, -seq_len:, :].contiguous( ) - # print("input_metadata.attn_bias shape", input_metadata.attn_bias.shape) - # print("input_metadata.attn_bias", input_metadata.attn_bias) key_cache, value_cache = kv_cache @@ -322,7 +306,6 @@ def forward( input_metadata) else: - # print("cross attention!") # Cross attention key_cache, value_cache = kv_cache @@ -330,10 +313,6 @@ def forward( assert encoder_hidden_states is not None k, _ = self.k(encoder_hidden_states) v, _ = self.v(encoder_hidden_states) - # print("k shape", k.shape) - # for i in range(k.shape[0]): - # for j in range(k.shape[1]): - # print(f"key at batch {i} and pos {j}: ", k[i, j, :].reshape(1, 8, 64)) attn_output = self.attn(q, k, v, key_cache, value_cache, input_metadata) else: @@ -369,16 +348,12 @@ def forward( input_metadata: InputMetadata, ) -> 
torch.Tensor: normed_hidden_states = self.layer_norm(hidden_states) - print("self attention input shape: ", normed_hidden_states.shape) - print("self_attention input: ", normed_hidden_states) attention_output = self.SelfAttention( hidden_states=normed_hidden_states, kv_cache=kv_cache, input_metadata=input_metadata, encoder_hidden_states=None, ) - print("self attention output shape: ", attention_output.shape) - print("self_attention output: ", attention_output) hidden_states = hidden_states + attention_output return hidden_states @@ -408,16 +383,12 @@ def forward( encoder_hidden_states: Optional[torch.Tensor], ) -> torch.Tensor: normed_hidden_states = self.layer_norm(hidden_states) - print("cross attention input shape: ", normed_hidden_states.shape) - print("cross_attention input: ", normed_hidden_states) attention_output = self.EncDecAttention( hidden_states=normed_hidden_states, kv_cache=kv_cache, input_metadata=input_metadata, encoder_hidden_states=encoder_hidden_states, ) - print("cross attention output shape: ", attention_output.shape) - print("cross_attention output: ", attention_output) hidden_states = hidden_states + attention_output return hidden_states @@ -521,11 +492,8 @@ def forward( input_metadata: InputMetadata, encoder_hidden_states: Optional[torch.Tensor], ) -> torch.Tensor: - # print("input_ids: ", input_ids) hidden_states = self.embed_tokens(input_ids) - # print("hidden_states shape: ", hidden_states.shape) - # print("hidden_states: ", hidden_states) for i, layer_module in enumerate(self.block): kv_cache = kv_caches[i] if self.is_decoder else None @@ -539,15 +507,6 @@ def forward( hidden_states = layer_outputs hidden_states = self.final_layer_norm(hidden_states) - # if encoder_hidden_states is not None: - # print("hidden_states shape:" , hidden_states.shape) - # print("encoder_hidden_states shape:" , encoder_hidden_states.shape) - # # Attach encoder hidden states - # hidden_states = torch.cat( - # [encoder_hidden_states, hidden_states], dim=1 - # ) - print("final_hidden_states shape: ", hidden_states.shape) - print("final_hidden_states: ", hidden_states) return hidden_states @@ -579,9 +538,6 @@ def forward( kv_caches: List[KVCache], input_metadata: InputMetadata, ) -> torch.Tensor: - # print("input_ids shape: ", input_ids.shape) - # print("input_ids: ", input_ids) - # print("input_metadata: ", input_metadata) if input_metadata.is_prompt: # prompt run, need to run encoder once hidden_states = self.encoder(input_ids, kv_caches, input_metadata, From 431f0147521939c3cfa550c628a77ea26e849e04 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Sat, 2 Mar 2024 00:24:21 -0500 Subject: [PATCH 07/11] silence warning; legacy=False for tokenizer; lint/format --- examples/offline_inference_enc_dec.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/examples/offline_inference_enc_dec.py b/examples/offline_inference_enc_dec.py index f554302fefb8c..83c9d4d0377ee 100644 --- a/examples/offline_inference_enc_dec.py +++ b/examples/offline_inference_enc_dec.py @@ -8,11 +8,15 @@ Output: for several prompts, compare native PyTorch & vLLM prompt completions ''' - +import warnings import torch from vllm import LLM, SamplingParams from transformers import T5Tokenizer, T5ForConditionalGeneration +warnings.filterwarnings("ignore", + category=UserWarning, + module="transformers.generation.utils.*") + hf_model_id = "t5-small" dtype = "bfloat16" prompts = [ @@ -27,7 +31,7 @@ # Native PyTorch test # - Model and tokenizer initialization -tokenizer = 
T5Tokenizer.from_pretrained(hf_model_id) +tokenizer = T5Tokenizer.from_pretrained(hf_model_id, legacy=False) model = T5ForConditionalGeneration.from_pretrained(hf_model_id).to( dtype=dtype_obj) From 8a5060fc458029bf2194de10862ac5ae2cc94ab1 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 5 Mar 2024 07:39:13 -0500 Subject: [PATCH 08/11] fix _make_tensor_with_pad args change which broke decoder scenario --- vllm/worker/model_runner.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 43e0a0b9074a5..6f868c09e45f6 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -273,6 +273,7 @@ def _prepare_prompt( max_len=max_prompt_block_table_len, pad=0, dtype=torch.int, + device=self.device ) start_loc_tensor = torch.arange(0, len(prompt_lens) * max_prompt_len, From 29d6f4467a768105d561d6391bdb86c9f71f4f72 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 5 Mar 2024 12:52:22 -0500 Subject: [PATCH 09/11] fixed bug caused by non-handling of self.model_config is None in model_runner.py --- vllm/worker/model_runner.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 6f868c09e45f6..cfd21411ffba7 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -83,8 +83,13 @@ def __init__( # Set enforce_eager to True for Neuron backend, to avoid capturing graph if self.device_config.is_neuron: self.model_config.enforce_eager = True - self.is_encoder_decoder = getattr(self.model_config.hf_config, - "is_encoder_decoder", False) + + # Unpack HF is_encoder_decoder config attribute + # NOTE: must handle "self.model_config is None" case imposed by certain tests i.e. test_prepare_prompt() + # In the None case, default to is_encoder_decoder == False since vLLM decoder-only mode is known to handle + # the None case correctly. + self.is_encoder_decoder = False if self.model_config is None else \ + getattr(self.model_config.hf_config, "is_encoder_decoder", False) def load_model(self) -> None: self.model = get_model(self.model_config, From a4950baad6d3b296750587b7053b5483f615c31f Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 5 Mar 2024 13:11:49 -0500 Subject: [PATCH 10/11] remove commented-out print statements --- vllm/model_executor/layers/attention.py | 34 ----------------- .../layers/enc_dec_attention.py | 37 ------------------- vllm/model_executor/layers/sampler.py | 36 ++---------------- vllm/worker/model_runner.py | 11 +----- 4 files changed, 6 insertions(+), 112 deletions(-) diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index c82d5aff83a3c..24b4e7d2851f4 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -195,9 +195,6 @@ def forward( key = key.unflatten(0, (batch_size, seq_len)) value = value.unflatten(0, (batch_size, seq_len)) - # print("query.shape: ", query.shape) - # print("key.shape: ", key.shape) - # print("value.shape: ", value.shape) out = xops.memory_efficient_attention_forward( query, key, @@ -292,7 +289,6 @@ def paged_attention( num_seqs, num_heads, head_size = query.shape max_num_partitions = ((max_context_len + _PARTITION_SIZE - 1) // _PARTITION_SIZE) - # print("max_num_partitions: ", max_num_partitions) # NOTE(woosuk): We use a simple heuristic to decide whether to use # PagedAttention V1 or V2. If the number of partitions is 1, we use # V1 to avoid the overhead of reduction. 
Also, if the number of @@ -303,25 +299,6 @@ def paged_attention( use_v1 = max_context_len <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512) if use_v1: - # print("v1") - # print("output: ", output) - # print("query: ", query) - # print("num_kv_heads: ", num_kv_heads) - # print("scale: ", scale) - # print("block_tables: ", block_tables) - # print("context_lens: ", context_lens) - # print("block_size: ", block_size) - # print("max_context_len: ", max_context_len) - # print("alibi_slopes: ", alibi_slopes) - # print("custom_bias: ", custom_bias) - # print("key_cache shape: ", key_cache.shape) - # print("value_cache shape: ", value_cache.shape) - # for block_table in block_tables: - # for block in block_table: - # print(f"key_cache at {block} shape: ", key_cache[block].shape) - # print(f"key_cache at {block}: ", key_cache[block]) - # print(f"value_cache at {block} shape: ", value_cache[block].shape) - # print(f"value_cache at {block}: ", value_cache[block]) # Run PagedAttention V1. ops.paged_attention_v1( output, @@ -339,7 +316,6 @@ def paged_attention( kv_cache_dtype, ) else: - # print("v2") # Run PagedAttention V2. assert _PARTITION_SIZE % block_size == 0 tmp_output = torch.empty( @@ -353,16 +329,6 @@ def paged_attention( device=output.device, ) max_logits = torch.empty_like(exp_sums) - # print("output: ", output) - # print("query: ", query) - # print("num_kv_heads: ", num_kv_heads) - # print("scale: ", scale) - # print("block_tables: ", block_tables) - # print("context_lens: ", context_lens) - # print("block_size: ", block_size) - # print("max_context_len: ", max_context_len) - # print("alibi_slopes: ", alibi_slopes) - # print("custom_bias: ", custom_bias) ops.paged_attention_v2( output, exp_sums, diff --git a/vllm/model_executor/layers/enc_dec_attention.py b/vllm/model_executor/layers/enc_dec_attention.py index fa14e5d09ff49..bbdeee8e5e343 100644 --- a/vllm/model_executor/layers/enc_dec_attention.py +++ b/vllm/model_executor/layers/enc_dec_attention.py @@ -73,20 +73,12 @@ def forward( query = query.view(batch_size, seq_len, self.num_heads, self.head_size) key = key.view(batch_size, seq_len, self.num_heads, self.head_size) value = value.view(batch_size, seq_len, self.num_heads, self.head_size) - # print("query shape: ", query.shape) if input_metadata.attn_bias is None: input_metadata.attn_bias = BlockDiagonalCausalMask.from_seqlens( [seq_len] * batch_size) - # When using custom attention bias, xformers requires the bias to - # be sliced from a tensor whose length is a multiple of 8. - # padded_len = (seq_len + 7) // 8 * 8 - # pad_len = padded_len - seq_len - # input_metadata.attn_bias = F.pad(input_metadata.attn_bias, (0, pad_len)) - # print("attention bias padded shape: ", input_metadata.attn_bias.shape) input_metadata.attn_bias = input_metadata.attn_bias[:, :, :, :seq_len] - # print("attention bias shape: ", input_metadata.attn_bias.shape) # Normal attention out = xops.memory_efficient_attention_forward( query, @@ -135,42 +127,28 @@ def forward( Output tensor. """ - # print("key shape pre view: ", key.shape) - # print("value shape pre view: ", value.shape) - batch_size, seq_len, hidden_size = query.shape # Reshape the query, key, and value tensors. 
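+        # view(-1, num_heads, head_size) folds batch and sequence length into a
+        # single token dimension, the layout expected by the paged attention call below.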
query = query.view(-1, self.num_heads, self.head_size) key = key.view(-1, self.num_heads, self.head_size) value = value.view(-1, self.num_heads, self.head_size) - # print("key shape: ", key.shape) - # print("key: ", key) - # print("value shape: ", value.shape) - # print("value: ", value) - # print("slot mapping: ", input_metadata.slot_mapping[:, -1].flatten()) # Reshape the keys and values and store them in the cache. # If key_cache and value_cache are not provided, the new key and value # vectors will not be cached. This happens during the initial memory # profiling run. if key_cache is not None and value_cache is not None: - # print("key_cache before: ", key_cache) - # print("value_cache before: ", value_cache) cache_ops.reshape_and_cache( key, value, key_cache, value_cache, input_metadata.slot_mapping[:, -1].flatten().contiguous(), input_metadata.kv_cache_dtype) - # print("key_cache after: ", key_cache) - # print("value_cache after: ", value_cache) - max_prompt_len = input_metadata.prompt_lens.max().item() block_size = value_cache.shape[3] prompt_table_len = (max_prompt_len + block_size - 1) // block_size block_tables = input_metadata.block_tables[:, prompt_table_len:].contiguous( ) - # print("decoder self attention block_tables", block_tables) output = paged_attention( query=query, key_cache=key_cache, @@ -222,18 +200,10 @@ def forward( # Reshape the query, key, and value tensors. query = query.view(-1, self.num_heads, self.head_size) if key is not None: - # print("key shape pre view: ", key.shape) key = key.view(-1, self.num_heads, self.head_size) - # print("key_shape: ", key.shape) - # print("key sum", key.sum((1, 2))) if value is not None: - # print("value shape pre view: ", value.shape) value = value.view(-1, self.num_heads, self.head_size) - # print("value_shape: ", value.shape) - # print("value sum", value.sum((1, 2))) - # print("slot mapping: ", input_metadata.slot_mapping[:, :-1].flatten().shape) - # print("slot mapping: ", input_metadata.slot_mapping[:, :-1].flatten()) # Reshape the keys and values and store them in the cache. # It only happens during the first pass. if (input_metadata.is_prompt and key_cache is not None @@ -248,14 +218,7 @@ def forward( input_metadata.kv_cache_dtype, ) - # for slot in input_metadata.slot_mapping[:, :-1].flatten(): - # if slot != -1: - # block_number = slot//16; - # block_offset = slot%16; - # print(f"key_cache sum at {slot}: ", key_cache[block_number, :, :, block_offset, :].sum()) - # print(f"value_cache sum at {slot}: ", value_cache[block_number, :, :, block_offset].sum()) max_prompt_len = input_metadata.prompt_lens.int().max().item() - # print("max_prompt_len: ", max_prompt_len) block_size = value_cache.shape[3] prompt_table_len = (max_prompt_len + block_size - 1) // block_size block_tables = input_metadata.block_tables[:, : diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 25baac43fac71..71655b216fb3d 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -57,9 +57,6 @@ def forward( sampling_metadata: SamplingMetadata, embedding_bias: Optional[torch.Tensor] = None, ) -> Optional[SamplerOutput]: - # print("hidden_states shape: ", hidden_states.shape) - # print("hidden_states: ", hidden_states) - # Get the hidden states that we use for sampling. if self.logits_as_hidden_states: logits = hidden_states @@ -70,9 +67,6 @@ def forward( # Get the logits for the next tokens. 
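+        # The projection multiplies hidden states by the (possibly TP-sharded)
+        # embedding weight; see the note below about distribution across TP workers.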
logits = self._get_logits(hidden_states, embedding, embedding_bias) - # print("Logits shape: ", logits.shape) - # print("Logits: ", logits) - # Only perform sampling in the driver worker. # Note: `_get_logits` is still distributed across TP workers because # the `embedding` weight is distributed across TP workers. @@ -83,12 +77,9 @@ def forward( assert logits is not None _, vocab_size = logits.shape - # print("Logits shape: ", logits.shape) - # print("Logits: ", logits) # Apply logits processors (if any). logits = _apply_logits_processors(logits, sampling_metadata) - # print("Logits shape: ", logits.shape) - # print("Logits: ", logits) + # Prepare sampling tensors with pinned memory to avoid blocking. (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p) = SamplingTensors.from_sampling_metadata( @@ -101,23 +92,18 @@ def forward( sampling_tensors.presence_penalties, sampling_tensors.frequency_penalties, sampling_tensors.repetition_penalties) - # print("Logits shape: ", logits.shape) - # print("Logits: ", logits) + # Apply temperature scaling. # Use in-place division to avoid creating a new tensor. logits.div_(sampling_tensors.temperatures.unsqueeze_(dim=1)) - # print("Logits shape: ", logits.shape) - # print("Logits: ", logits) + if do_top_p_top_k: logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps, sampling_tensors.top_ks) - # print("Logits shape: ", logits.shape) - # print("Logits: ", logits) + if do_min_p: logits = _apply_min_p(logits, sampling_tensors.min_ps) - # print("Logits shape: ", logits.shape) - # print("Logits: ", logits) # We use float32 for probabilities and log probabilities. # Compute the probabilities. probs = torch.softmax(logits, dim=-1, dtype=torch.float) @@ -126,18 +112,10 @@ def forward( logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float) # Sample the next tokens. - # print("Probs shape: ", probs.shape) - # print("Probs: ", probs) - # print("Logprobs shape: ", logprobs.shape) - # print("Logprobs: ", logprobs) - sample_results = _sample(probs, logprobs, sampling_metadata) # Get the logprobs query results. - # print("Sample results: ", sample_results) prompt_logprobs, sample_logprobs = _get_logprobs( logprobs, sampling_metadata, sample_results) - # print("Prompt logprobs: ", prompt_logprobs) - # print("Sample logprobs: ", sample_logprobs) return _build_sampler_output(sample_results, sampling_metadata, prompt_logprobs, sample_logprobs) @@ -400,8 +378,6 @@ def _sample( logprobs: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> List[Tuple[List[int], List[int]]]: - # print("probs: ", probs) - # print("logprobs: ", logprobs) categorized_seq_group_ids = {t: [] for t in SamplingType} categorized_sample_indices = sampling_metadata.categorized_sample_indices for i, seq_group in enumerate(sampling_metadata.seq_groups): @@ -417,15 +393,11 @@ def _sample( # The first loop can run without waiting on GPU<->CPU sync. 
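+    # First pass: bucket sequence groups by sampling type (greedy / random /
+    # beam) on the CPU; the per-type sampling itself happens in a later loop.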
for sampling_type in SamplingType: sample_indices = categorized_sample_indices[sampling_type] - # print("sampling_type: ", sampling_type) - # print("sample_indices: ", sample_indices) num_tokens = len(sample_indices) if num_tokens == 0: continue seq_group_ids = categorized_seq_group_ids[sampling_type] - # print("seq_group_ids: ", seq_group_ids) seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids] - # print("seq_groups: ", seq_groups) is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids] sample_metadata[sampling_type] = (seq_group_ids, seq_groups, is_prompts, sample_indices) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index cfd21411ffba7..8d6da05ffc334 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -221,8 +221,6 @@ def _prepare_prompt( block_tables.append(block_table) max_block_table_len = max(max_block_table_len, len(block_table)) - # print("slot_mapping: ", slot_mapping) - # print("block_tables: ", block_tables) max_prompt_len = max(subquery_lens) input_tokens = _make_tensor_with_pad(input_tokens, max_prompt_len, @@ -278,8 +276,7 @@ def _prepare_prompt( max_len=max_prompt_block_table_len, pad=0, dtype=torch.int, - device=self.device - ) + device=self.device) start_loc_tensor = torch.arange(0, len(prompt_lens) * max_prompt_len, max_prompt_len, @@ -527,7 +524,6 @@ def _prepare_sample( dtype=torch.long, target_device=self.device, pin_memory=pin_memory) - # print("selected_token_indices: ", selected_token_indices) categorized_sample_indices = { t: _async_h2d(seq_ids, dtype=torch.int, @@ -535,12 +531,11 @@ def _prepare_sample( pin_memory=pin_memory) for t, seq_ids in categorized_sample_indices.items() } - # print("categorized_sample_indices: ", categorized_sample_indices) seq_data: Dict[int, SequenceData] = {} for seq_group_metadata in seq_group_metadata_list: seq_data.update(seq_group_metadata.seq_data) - # print("selected_token_indices: ", selected_token_indices) + sampling_metadata = SamplingMetadata( seq_groups=seq_groups, seq_data=seq_data, @@ -662,8 +657,6 @@ def execute_model( kv_caches=kv_caches, input_metadata=input_metadata, ) - # print("hidden_states shape: ", hidden_states.shape) - # print("hidden_states: ", hidden_states) # Sample the next token. output = self.model.sample( From 9c0376028065f9abd7493f64bf28949bd56da2de Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 5 Mar 2024 13:15:01 -0500 Subject: [PATCH 11/11] small cleanup --- vllm/model_executor/models/t5.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/model_executor/models/t5.py b/vllm/model_executor/models/t5.py index d6c27ed5b4903..e49c99e5d5315 100644 --- a/vllm/model_executor/models/t5.py +++ b/vllm/model_executor/models/t5.py @@ -564,10 +564,8 @@ def forward( def sample(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata): - # logger.info(f"decoder_outputs: {decoder_outputs}") next_tokens = self.sampler(self.shared.weight, hidden_states, sampling_metadata) - # logger.info(f"next_tokens: {next_tokens}") return next_tokens def load_weights(