diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000000000..f9dde9e2b531b --- /dev/null +++ b/.clang-format @@ -0,0 +1,4 @@ +# Use the Google style in this project. +BasedOnStyle: Google + +ColumnLimit: 120 diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 1ad502526c97c..d80cf8ef5409b 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -6,8 +6,8 @@ from typing import List, Optional, Tuple import torch -from transformers import (AutoModelForCausalLM, AutoTokenizer, - PreTrainedTokenizerBase) +from transformers import (AutoModelForCausalLM, T5ForConditionalGeneration, + AutoTokenizer, PreTrainedTokenizerBase) from tqdm import tqdm @@ -123,8 +123,16 @@ def run_hf( trust_remote_code: bool, ) -> float: assert not use_beam_search - llm = AutoModelForCausalLM.from_pretrained( - model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code) + if "t5" in model: + llm = T5ForConditionalGeneration.from_pretrained( + model, + torch_dtype=torch.float16, + trust_remote_code=trust_remote_code) + else: + llm = AutoModelForCausalLM.from_pretrained( + model, + torch_dtype=torch.float16, + trust_remote_code=trust_remote_code) if llm.config.model_type == "llama": # To enable padding in the HF backend. tokenizer.pad_token = tokenizer.eos_token diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index b5be3befa07e2..0f561ab3155b7 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -1,5 +1,6 @@ /* - * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp * Copyright (c) 2023, The vLLM team. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * @@ -19,9 +20,12 @@ #include #endif -#include #include #include +#include +#include + +#include #include "attention_dtypes.h" #include "attention_utils.cuh" @@ -29,8 +33,6 @@ #include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh" #endif -#include - #ifndef USE_ROCM #define WARP_SIZE 32 #else @@ -38,12 +40,12 @@ #endif #define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b)) -#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) +#define DIVIDE_ROUND_UP(a, b) (((a) + (b)-1) / (b)) namespace vllm { // Utility function for attention softmax. -template +template inline __device__ float block_sum(float* red_smem, float sum) { // Decompose the thread index into warp / lane. int warp = threadIdx.x / WARP_SIZE; @@ -101,6 +103,7 @@ __device__ void paged_attention_kernel( const int* __restrict__ context_lens, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] + const float* __restrict__ custom_bias, // [num_seqs, num_heads, 1, max_seq_len] const int q_stride, const int kv_block_stride, const int kv_head_stride) { @@ -141,6 +144,10 @@ __device__ void paged_attention_kernel( const int num_queries_per_kv = num_heads / num_kv_heads; const int kv_head_idx = head_idx / num_queries_per_kv; const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; + const float* custom_bias_vec = custom_bias == nullptr + ? 
nullptr + : custom_bias + seq_idx * num_kv_heads * num_context_blocks * BLOCK_SIZE + + kv_head_idx * num_context_blocks * BLOCK_SIZE; // A vector type to store a part of a key or a query. // The vector size is configured in such a way that the threads in a thread group @@ -232,8 +239,10 @@ __device__ void paged_attention_kernel( // Compute dot product. // This includes a reduction across the threads in the same thread group. float qk = scale * Qk_dot::dot(q_vecs[thread_group_offset], k_vecs); - // Add the ALiBi bias if slopes are given. - qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0; + // Add the custom or ALiBi bias if given. + qk += (custom_bias_vec != nullptr) ? custom_bias_vec[token_idx] + : (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) + : 0; if (thread_group_offset == 0) { // Store the partial reductions to shared memory. @@ -443,13 +452,14 @@ __global__ void paged_attention_v1_kernel( const int* __restrict__ context_lens, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] + const float* __restrict__ custom_bias, // [num_seqs, num_heads, 1, seq_len] const int q_stride, const int kv_block_stride, const int kv_head_stride) { paged_attention_kernel( /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens, - max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride); + max_num_blocks_per_seq, alibi_slopes, custom_bias, q_stride, kv_block_stride, kv_head_stride); } // Grid: (num_heads, num_seqs, max_num_partitions). @@ -474,13 +484,14 @@ __global__ void paged_attention_v2_kernel( const int* __restrict__ context_lens, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] + const float* __restrict__ custom_bias, // [num_seqs, num_heads, 1, seq_len] const int q_stride, const int kv_block_stride, const int kv_head_stride) { paged_attention_kernel( exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes, - q_stride, kv_block_stride, kv_head_stride); + custom_bias, q_stride, kv_block_stride, kv_head_stride); } // Grid: (num_heads, num_seqs). @@ -600,6 +611,7 @@ __global__ void paged_attention_v2_reduce_kernel( context_lens_ptr, \ max_num_blocks_per_seq, \ alibi_slopes_ptr, \ + custom_bias_ptr, \ q_stride, \ kv_block_stride, \ kv_head_stride); @@ -621,7 +633,8 @@ void paged_attention_v1_launcher( torch::Tensor& block_tables, torch::Tensor& context_lens, int max_context_len, - const c10::optional& alibi_slopes) { + const c10::optional& alibi_slopes, + const c10::optional& custom_bias) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -634,9 +647,11 @@ void paged_attention_v1_launcher( assert(head_size % thread_group_size == 0); // NOTE: alibi_slopes is optional. - const float* alibi_slopes_ptr = alibi_slopes ? - reinterpret_cast(alibi_slopes.value().data_ptr()) - : nullptr; + const float* alibi_slopes_ptr = + alibi_slopes ? reinterpret_cast(alibi_slopes.value().data_ptr()) : nullptr; + + // NOTE: alibi_slopes is optional. + const float* custom_bias_ptr = custom_bias ? 
reinterpret_cast(custom_bias.value().data_ptr()) : nullptr; T* out_ptr = reinterpret_cast(out.data_ptr()); T* query_ptr = reinterpret_cast(query.data_ptr()); @@ -696,7 +711,8 @@ void paged_attention_v1_launcher( block_tables, \ context_lens, \ max_context_len, \ - alibi_slopes); + alibi_slopes, \ + custom_bias); // NOTE(woosuk): To reduce the compilation time, we omitted block sizes // 1, 2, 4, 64, 128, 256. @@ -728,6 +744,7 @@ void paged_attention_v1( int block_size, int max_context_len, const c10::optional& alibi_slopes, + const c10::optional& custom_bias, const std::string& kv_cache_dtype) { if (kv_cache_dtype == "auto") { if (query.dtype() == at::ScalarType::Float) { @@ -770,6 +787,7 @@ void paged_attention_v1( context_lens_ptr, \ max_num_blocks_per_seq, \ alibi_slopes_ptr, \ + custom_bias_ptr, \ q_stride, \ kv_block_stride, \ kv_head_stride); \ @@ -802,7 +820,8 @@ void paged_attention_v2_launcher( torch::Tensor& block_tables, torch::Tensor& context_lens, int max_context_len, - const c10::optional& alibi_slopes) { + const c10::optional& alibi_slopes, + const c10::optional& custom_bias) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -815,9 +834,10 @@ void paged_attention_v2_launcher( assert(head_size % thread_group_size == 0); // NOTE: alibi_slopes is optional. - const float* alibi_slopes_ptr = alibi_slopes ? - reinterpret_cast(alibi_slopes.value().data_ptr()) - : nullptr; + const float* alibi_slopes_ptr = + alibi_slopes ? reinterpret_cast(alibi_slopes.value().data_ptr()) : nullptr; + + const float* custom_bias_ptr = custom_bias ? reinterpret_cast(custom_bias.value().data_ptr()) : nullptr; T* out_ptr = reinterpret_cast(out.data_ptr()); float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); @@ -886,7 +906,8 @@ void paged_attention_v2_launcher( block_tables, \ context_lens, \ max_context_len, \ - alibi_slopes); + alibi_slopes, \ + custom_bias); // NOTE(woosuk): To reduce the compilation time, we omitted block sizes // 1, 2, 4, 64, 128, 256. @@ -921,6 +942,7 @@ void paged_attention_v2( int block_size, int max_context_len, const c10::optional& alibi_slopes, + const c10::optional& custom_bias, const std::string& kv_cache_dtype) { if (kv_cache_dtype == "auto") { if (query.dtype() == at::ScalarType::Float) { diff --git a/csrc/ops.h b/csrc/ops.h index 249c7451bf73c..156174886cfc3 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -14,6 +14,7 @@ void paged_attention_v1( int block_size, int max_context_len, const c10::optional& alibi_slopes, + const c10::optional& custom_bias, const std::string& kv_cache_dtype); void paged_attention_v2( @@ -31,6 +32,7 @@ void paged_attention_v2( int block_size, int max_context_len, const c10::optional& alibi_slopes, + const c10::optional& custom_bias, const std::string& kv_cache_dtype); void rms_norm( diff --git a/examples/offline_inference_enc_dec.py b/examples/offline_inference_enc_dec.py new file mode 100644 index 0000000000000..83c9d4d0377ee --- /dev/null +++ b/examples/offline_inference_enc_dec.py @@ -0,0 +1,76 @@ +''' +Affirm T5 model outputs match between vLLM and native PyTorch + +Scenarios: +* t5-small, t5-large +* float16, float32, bfloat16, bfloat32 +* Custom prompts & num. 
prompts + +Output: for several prompts, compare native PyTorch & vLLM prompt completions +''' +import warnings +import torch +from vllm import LLM, SamplingParams +from transformers import T5Tokenizer, T5ForConditionalGeneration + +warnings.filterwarnings("ignore", + category=UserWarning, + module="transformers.generation.utils.*") + +hf_model_id = "t5-small" +dtype = "bfloat16" +prompts = [ + "Who are you?", + "Who are you?", + "How do you like your egg made", + "How do you like your egg made", +] + +dtype_obj = getattr(torch, dtype) + +# Native PyTorch test + +# - Model and tokenizer initialization +tokenizer = T5Tokenizer.from_pretrained(hf_model_id, legacy=False) +model = T5ForConditionalGeneration.from_pretrained(hf_model_id).to( + dtype=dtype_obj) + +# - Assume 'dtype' is already defined, e.g., dtype=torch.float32 +# - Tokenizing the prompts list with specified data type +input_ids = tokenizer(prompts, + return_tensors="pt", + padding=True, + truncation=True).input_ids + +# - If using GPU, also send input_ids to the same device as the model +if torch.cuda.is_available(): + model = model.cuda() # Move model to GPU + input_ids = input_ids.cuda() # Move input_ids to GPU + +# - Generating outputs for all tokenized prompts +native_outputs = model.generate(input_ids).cpu() + +# vLLM test +model = LLM(hf_model_id, + enforce_eager=True, + dtype=dtype, + gpu_memory_utilization=0.5) + +sampling_params = SamplingParams(max_tokens=100, temperature=0) + +vllm_outputs = model.generate( + prompts, + sampling_params=sampling_params, +) + +# Print native & vLLM outputs +i = 0 +for native_output, vllm_output in zip(native_outputs, vllm_outputs): + prompt = prompts[i] # Get the corresponding prompt for this output + native_generated_text = tokenizer.decode( + native_output, skip_special_tokens=True) # Decode the generated text + vllm_generated_text = vllm_output.outputs[0].text + print( + f"Prompt: {prompt!r}, Native PyTorch generated text: {native_generated_text!r}, vLLM generated text: {vllm_generated_text!r}" + ) + i += 1 diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index df4858a696530..3f63015c382fe 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -144,6 +144,9 @@ def __init__( if USE_RAY_COMPILED_DAG: self.forward_dag = self._compiled_ray_dag() + self.is_encoder_decoder = getattr(self.model_config.hf_config, + "is_encoder_decoder", False) + def get_tokenizer_for_seq(self, sequence: Sequence): return self.tokenizer.get_lora_tokenizer(sequence.lora_request) @@ -477,7 +480,7 @@ def add_request( block_size = self.cache_config.block_size seq_id = next(self.seq_counter) seq = Sequence(seq_id, prompt, prompt_token_ids, block_size, - lora_request) + self.is_encoder_decoder, lora_request) # Check whether the input specifies prefix prefix = self.scheduler.prefix_pool.add_or_get_prefix( diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index fc82018d18eb6..2f475dd7924ae 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -166,6 +166,10 @@ def generate( # Use default sampling params. sampling_params = SamplingParams() + if self.llm_engine.is_encoder_decoder: + assert (prefix_pos is None + ), "Encoder decoder models do not support Prefix Cache yet" + # Add requests to the engine. 
num_requests = len(prompts) if prompts is not None else len( prompt_token_ids) diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py index f0a88ac8e27f8..79ada53e252ad 100644 --- a/vllm/model_executor/input_metadata.py +++ b/vllm/model_executor/input_metadata.py @@ -46,6 +46,7 @@ def __init__( def __repr__(self) -> str: return ("InputMetadata(" f"is_prompt={self.is_prompt}, " + f"prompt_lens={self.prompt_lens}, " f"max_context_len={self.max_context_len}, " f"slot_mapping={self.slot_mapping}, " f"context_lens={self.context_lens}, " diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 2a82325b80213..24b4e7d2851f4 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -226,15 +226,11 @@ def forward( else: # Decoding run. - output = _paged_attention( - query, - key_cache, - value_cache, - input_metadata, - self.num_kv_heads, - self.scale, - self.alibi_slopes, - ) + output = paged_attention( + query, key_cache, value_cache, input_metadata.block_tables, + input_metadata.context_lens, input_metadata.max_context_len, + self.num_kv_heads, self.scale, self.alibi_slopes, None, + input_metadata.kv_cache_dtype) # Reshape the output tensor. return output.view(batch_size, seq_len, hidden_size) @@ -274,22 +270,25 @@ def _make_alibi_bias( return attn_bias -def _paged_attention( +def paged_attention( query: torch.Tensor, key_cache: torch.Tensor, value_cache: torch.Tensor, - input_metadata: InputMetadata, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + max_context_len: int, num_kv_heads: int, scale: float, alibi_slopes: Optional[torch.Tensor], + custom_bias: Optional[torch.Tensor], + kv_cache_dtype: torch.dtype, ) -> torch.Tensor: output = torch.empty_like(query) block_size = value_cache.shape[3] num_seqs, num_heads, head_size = query.shape - max_num_partitions = ( - (input_metadata.max_context_len + _PARTITION_SIZE - 1) // - _PARTITION_SIZE) + max_num_partitions = ((max_context_len + _PARTITION_SIZE - 1) // + _PARTITION_SIZE) # NOTE(woosuk): We use a simple heuristic to decide whether to use # PagedAttention V1 or V2. If the number of partitions is 1, we use # V1 to avoid the overhead of reduction. Also, if the number of @@ -297,8 +296,8 @@ def _paged_attention( # to parallelize. # TODO(woosuk): Tune this heuristic. # For context len > 8192, use V2 kernel to avoid shared memory shortage. - use_v1 = input_metadata.max_context_len <= 8192 and ( - max_num_partitions == 1 or num_seqs * num_heads > 512) + use_v1 = max_context_len <= 8192 and (max_num_partitions == 1 + or num_seqs * num_heads > 512) if use_v1: # Run PagedAttention V1. ops.paged_attention_v1( @@ -308,12 +307,13 @@ def _paged_attention( value_cache, num_kv_heads, scale, - input_metadata.block_tables, - input_metadata.context_lens, + block_tables, + context_lens, block_size, - input_metadata.max_context_len, + max_context_len, alibi_slopes, - input_metadata.kv_cache_dtype, + custom_bias, + kv_cache_dtype, ) else: # Run PagedAttention V2. 
@@ -339,11 +339,12 @@ def _paged_attention( value_cache, num_kv_heads, scale, - input_metadata.block_tables, - input_metadata.context_lens, + block_tables, + context_lens, block_size, - input_metadata.max_context_len, + max_context_len, alibi_slopes, - input_metadata.kv_cache_dtype, + custom_bias, + kv_cache_dtype, ) return output diff --git a/vllm/model_executor/layers/enc_dec_attention.py b/vllm/model_executor/layers/enc_dec_attention.py new file mode 100644 index 0000000000000..bbdeee8e5e343 --- /dev/null +++ b/vllm/model_executor/layers/enc_dec_attention.py @@ -0,0 +1,242 @@ +"""Multi-head attention for encoder-decoder models.""" +from typing import Optional + +import torch +import torch.nn as nn + +from xformers import ops as xops +from xformers.ops.fmha.attn_bias import ( + BlockDiagonalCausalMask, ) +from vllm._C import cache_ops +from vllm.model_executor.input_metadata import InputMetadata +from vllm.utils import is_hip +from vllm.model_executor.layers.attention import paged_attention + +_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256] + + +class EncDecAttention(nn.Module): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + ) -> None: + super().__init__() + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + + if self.head_size not in _SUPPORTED_HEAD_SIZES: + raise ValueError(f"head_size ({self.head_size}) is not supported. " + f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.") + + +class EncoderAttention(EncDecAttention): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + ) -> None: + super().__init__(num_heads, head_size, scale) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + input_metadata: InputMetadata, + ) -> torch.Tensor: + """Encoder attention forward pass. + + Args: + query: Query tensor. + key: Key tensor. + value: Value tensor. + custom_bias: Custom bias tensor. + + Returns: + Output tensor. + """ + # query: [batch_size, seq_len, num_heads * head_size] + # key: [batch_size, seq_len, num_heads * head_size] + # value: [batch_size, seq_len, num_heads * head_size] + # custom_bias: [batch_size, seq_len, seq_len] + # output: [batch_size, seq_len, num_heads * head_size] + + assert input_metadata.is_prompt + batch_size, seq_len, hidden_size = query.shape + # Reshape the query, key, and value tensors. 
+ query = query.view(batch_size, seq_len, self.num_heads, self.head_size) + key = key.view(batch_size, seq_len, self.num_heads, self.head_size) + value = value.view(batch_size, seq_len, self.num_heads, self.head_size) + if input_metadata.attn_bias is None: + input_metadata.attn_bias = BlockDiagonalCausalMask.from_seqlens( + [seq_len] * batch_size) + + input_metadata.attn_bias = input_metadata.attn_bias[:, :, :, :seq_len] + + # Normal attention + out = xops.memory_efficient_attention_forward( + query, + key, + value, + attn_bias=input_metadata.attn_bias, + p=0.0, + scale=self.scale, + op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if + (is_hip()) else None, + ) + output = out.view(batch_size, seq_len, hidden_size) + return output + + +class DecoderAttention(EncDecAttention): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + ) -> None: + super().__init__(num_heads, head_size, scale) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + key_cache: Optional[torch.Tensor], + value_cache: Optional[torch.Tensor], + input_metadata: InputMetadata, + ): + """Decoder attention forward pass. + + Args: + query: Query tensor. + key: Key tensor. + value: Value tensor. + key_cache: Key cache tensor. + value_cache: Value cache tensor. + custom_bias: Custom bias tensor. + + Returns: + Output tensor. + """ + + batch_size, seq_len, hidden_size = query.shape + # Reshape the query, key, and value tensors. + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_heads, self.head_size) + value = value.view(-1, self.num_heads, self.head_size) + # Reshape the keys and values and store them in the cache. + # If key_cache and value_cache are not provided, the new key and value + # vectors will not be cached. This happens during the initial memory + # profiling run. + if key_cache is not None and value_cache is not None: + + cache_ops.reshape_and_cache( + key, value, key_cache, value_cache, + input_metadata.slot_mapping[:, -1].flatten().contiguous(), + input_metadata.kv_cache_dtype) + + max_prompt_len = input_metadata.prompt_lens.max().item() + block_size = value_cache.shape[3] + prompt_table_len = (max_prompt_len + block_size - 1) // block_size + block_tables = input_metadata.block_tables[:, + prompt_table_len:].contiguous( + ) + output = paged_attention( + query=query, + key_cache=key_cache, + value_cache=value_cache, + block_tables=block_tables, + context_lens=input_metadata.context_lens, + max_context_len=input_metadata.max_context_len, + num_kv_heads=self.num_heads, + scale=self.scale, + alibi_slopes=None, + custom_bias=input_metadata.attn_bias.to(torch.float32), + kv_cache_dtype=input_metadata.kv_cache_dtype, + ) + return output.view(batch_size, seq_len, hidden_size) + + +class CrossAttention(EncDecAttention): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + ) -> None: + super().__init__(num_heads, head_size, scale) + + def forward( + self, + query: torch.Tensor, + key: Optional[torch.Tensor], + value: Optional[torch.Tensor], + key_cache: Optional[torch.Tensor], + value_cache: Optional[torch.Tensor], + input_metadata: InputMetadata, + ): + """Cross attention forward pass. + Args: + query: Query tensor. + key_cache: Key cache tensor. + value_cache: Value cache tensor. + input_metadata: Input metadata. + key: Key tensor. Only needed in the first pass. + value: Value tensor. Only needed in the first pass. + custom_bias: Custom bias tensor. + Returns: + Output tensor. 
+ """ + batch_size, seq_len, hidden_size = query.shape + # Reshape the query, key, and value tensors. + query = query.view(-1, self.num_heads, self.head_size) + if key is not None: + key = key.view(-1, self.num_heads, self.head_size) + if value is not None: + value = value.view(-1, self.num_heads, self.head_size) + + # Reshape the keys and values and store them in the cache. + # It only happens during the first pass. + if (input_metadata.is_prompt and key_cache is not None + and value_cache is not None): + assert key is not None and value is not None + cache_ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + input_metadata.slot_mapping[:, :-1].flatten().contiguous(), + input_metadata.kv_cache_dtype, + ) + + max_prompt_len = input_metadata.prompt_lens.int().max().item() + block_size = value_cache.shape[3] + prompt_table_len = (max_prompt_len + block_size - 1) // block_size + block_tables = input_metadata.block_tables[:, : + prompt_table_len].contiguous( + ) + + output = paged_attention( + query=query, + key_cache=key_cache, + value_cache=value_cache, + block_tables=block_tables, + context_lens=input_metadata.prompt_lens.int(), + max_context_len=max_prompt_len, + num_kv_heads=self.num_heads, + scale=self.scale, + alibi_slopes=None, + custom_bias=None, + kv_cache_dtype=input_metadata.kv_cache_dtype, + ) + + return output.view(batch_size, seq_len, hidden_size) diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 75c2ae1e9f48e..a7d68a7cb3c7a 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -45,6 +45,7 @@ "RWForCausalLM": ("falcon", "FalconForCausalLM"), "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"), "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"), + "T5ForConditionalGeneration": ("t5", "T5ForConditionalGeneration"), "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"), } diff --git a/vllm/model_executor/models/t5.py b/vllm/model_executor/models/t5.py new file mode 100644 index 0000000000000..e49c99e5d5315 --- /dev/null +++ b/vllm/model_executor/models/t5.py @@ -0,0 +1,591 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/t5/modeling_t5.py +# Copyright 2023 The vLLM team. +# Copyright 2020 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" PyTorch T5 model.""" +from typing import List, Optional, Tuple + +import math +import copy + +import torch +from torch import nn +from transformers import T5Config + +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.enc_dec_attention import ( + EncoderAttention, + DecoderAttention, + CrossAttention, +) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + LinearMethodBase, + RowParallelLinear, +) +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_world_size, ) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.weight_utils import ( + default_weight_loader, + hf_model_weights_iterator, +) + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class T5LayerNorm(nn.Module): + + def __init__(self, hidden_size, eps=1e-6): + """ + Construct a layernorm module in the T5 style. No bias and no subtraction of mean. + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean + # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated + # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for + # half-precision inputs is done in fp32 + + variance = hidden_states.to(torch.float32).pow(2).mean(-1, + keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states + + +class T5DenseActDense(nn.Module): + + def __init__(self, config: T5Config): + super().__init__() + self.wi = ColumnParallelLinear(config.d_model, config.d_ff, bias=False) + self.wo = RowParallelLinear(config.d_ff, config.d_model, bias=False) + self.act = get_act_fn(config.dense_act_fn) + + def forward(self, hidden_states): + hidden_states, _ = self.wi(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states, _ = self.wo(hidden_states) + return hidden_states + + +class T5DenseGatedActDense(nn.Module): + + def __init__(self, config: T5Config): + super().__init__() + self.wi_0 = ColumnParallelLinear(config.d_model, + config.d_ff, + bias=False) + self.wi_1 = ColumnParallelLinear(config.d_model, + config.d_ff, + bias=False) + self.wo = RowParallelLinear(config.d_ff, config.d_model, bias=False) + self.act = get_act_fn(config.dense_act_fn) + + def forward(self, hidden_states): + hidden_gelu = self.act(self.wi_0(hidden_states)[0]) + hidden_linear, _ = self.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states, _ = self.wo(hidden_states) + return hidden_states + + +class T5LayerFF(nn.Module): + + def __init__(self, config: T5Config): + super().__init__() + if config.is_gated_act: + self.DenseReluDense = T5DenseGatedActDense(config) + else: + self.DenseReluDense = T5DenseActDense(config) + + self.layer_norm = T5LayerNorm(config.d_model, + eps=config.layer_norm_epsilon) + + def forward(self, hidden_states): + forwarded_states = self.layer_norm(hidden_states) + forwarded_states = self.DenseReluDense(forwarded_states) + hidden_states = 
hidden_states + forwarded_states + return hidden_states + + +class T5Attention(nn.Module): + + def __init__( + self, + config: T5Config, + is_cross: bool, + has_relative_attention_bias: bool, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.is_decoder = config.is_decoder + self.relative_attention_num_buckets = config.relative_attention_num_buckets + self.relative_attention_max_distance = config.relative_attention_max_distance + self.d_model = config.d_model + self.key_value_proj_dim = config.d_kv + total_num_heads = config.num_heads + tensor_model_parallel_world_size = get_tensor_model_parallel_world_size( + ) + assert total_num_heads % tensor_model_parallel_world_size == 0 + self.n_heads = total_num_heads // tensor_model_parallel_world_size + self.inner_dim = self.n_heads * self.key_value_proj_dim + + self.q = ColumnParallelLinear(self.d_model, self.inner_dim, bias=False) + self.k = ColumnParallelLinear(self.d_model, self.inner_dim, bias=False) + self.v = ColumnParallelLinear(self.d_model, self.inner_dim, bias=False) + self.o = RowParallelLinear(self.inner_dim, self.d_model, bias=False) + + if has_relative_attention_bias: + self.relative_attention_bias = nn.Embedding( + self.relative_attention_num_buckets, self.n_heads) + + self.is_cross = is_cross + if self.is_decoder: + if self.is_cross: + self.attn = CrossAttention(self.n_heads, + self.key_value_proj_dim, 1) + else: + self.attn = DecoderAttention(self.n_heads, + self.key_value_proj_dim, 1) + else: + self.attn = EncoderAttention(self.n_heads, self.key_value_proj_dim, + 1) + + @staticmethod + def _relative_position_bucket(relative_position, + bidirectional=True, + num_buckets=32, + max_distance=128): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. 
+ This should allow for more graceful generalization to longer sequences than the model has been trained on + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to( + torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min(relative_position, + torch.zeros_like(relative_position)) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) / + math.log(max_distance / max_exact) * + (num_buckets - max_exact)).to(torch.long) + relative_position_if_large = torch.min( + relative_position_if_large, + torch.full_like(relative_position_if_large, num_buckets - 1), + ) + + relative_buckets += torch.where(is_small, relative_position, + relative_position_if_large) + return relative_buckets + + def compute_bias(self, query_length, key_length): + """Compute binned relative position bias""" + context_position = torch.arange(query_length, + dtype=torch.long, + device="cuda")[:, None] + memory_position = torch.arange(key_length, + dtype=torch.long, + device="cuda")[None, :] + relative_position = (memory_position - context_position + ) # shape (query_length, key_length) + relative_position_bucket = self._relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=(not self.is_decoder), + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + # shape (query_length, key_length, num_heads) + values = self.relative_attention_bias(relative_position_bucket) + # shape (1, num_heads, query_length, key_length) + values = values.permute([2, 0, 1]).unsqueeze(0) + return values + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: Optional[KVCache], + input_metadata: InputMetadata, + encoder_hidden_states: Optional[torch.Tensor], + ) -> torch.Tensor: + q, _ = self.q(hidden_states) + + batch_size = hidden_states.shape[0] + seq_len = hidden_states.shape[1] + prompt_len = input_metadata.prompt_lens.max().item() + context_len = input_metadata.context_lens.max().item() + context_len = max(context_len, 1) + + block_size = 16 + + if not self.is_decoder: + assert kv_cache is None + # Encoder self attention, no cache operations + k, _ = self.k(hidden_states) + v, _ = self.v(hidden_states) + + if input_metadata.attn_bias is None: + input_metadata.attn_bias = self.compute_bias( + prompt_len, (prompt_len + block_size - 1) // block_size * + block_size).repeat(batch_size, 1, 1, 1) + for i in range(batch_size): + input_metadata.attn_bias[ + i, :, :, + input_metadata.prompt_lens[i]:, ] = torch.finfo( + input_metadata.attn_bias.dtype).min + + attn_output = self.attn(q, k, v, input_metadata) + + elif not self.is_cross: + # Decoder self attention + k, _ = self.k(hidden_states) + v, _ = self.v(hidden_states) + + if input_metadata.attn_bias is None: + position_bias = self.compute_bias( + 1 if 
input_metadata.is_prompt else context_len, + (context_len + block_size - 1) // block_size * + block_size).repeat(batch_size, 1, 1, 1) + input_metadata.attn_bias = position_bias[:, :, + -seq_len:, :].contiguous( + ) + + key_cache, value_cache = kv_cache + + attn_output = self.attn(q, k, v, key_cache, value_cache, + input_metadata) + + else: + # Cross attention + + key_cache, value_cache = kv_cache + if input_metadata.is_prompt: + assert encoder_hidden_states is not None + k, _ = self.k(encoder_hidden_states) + v, _ = self.v(encoder_hidden_states) + attn_output = self.attn(q, k, v, key_cache, value_cache, + input_metadata) + else: + attn_output = self.attn(q, None, None, key_cache, value_cache, + input_metadata) + + attn_output, _ = self.o(attn_output) + return attn_output + + +class T5LayerSelfAttention(nn.Module): + + def __init__( + self, + config, + has_relative_attention_bias, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.SelfAttention = T5Attention( + config, + is_cross=False, + has_relative_attention_bias=has_relative_attention_bias, + linear_method=linear_method, + ) + self.layer_norm = T5LayerNorm(config.d_model, + eps=config.layer_norm_epsilon) + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + ) -> torch.Tensor: + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.SelfAttention( + hidden_states=normed_hidden_states, + kv_cache=kv_cache, + input_metadata=input_metadata, + encoder_hidden_states=None, + ) + hidden_states = hidden_states + attention_output + return hidden_states + + +class T5LayerCrossAttention(nn.Module): + + def __init__( + self, + config, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.EncDecAttention = T5Attention( + config, + is_cross=True, + has_relative_attention_bias=False, + linear_method=linear_method, + ) + self.layer_norm = T5LayerNorm(config.d_model, + eps=config.layer_norm_epsilon) + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: Optional[KVCache], + input_metadata: InputMetadata, + encoder_hidden_states: Optional[torch.Tensor], + ) -> torch.Tensor: + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.EncDecAttention( + hidden_states=normed_hidden_states, + kv_cache=kv_cache, + input_metadata=input_metadata, + encoder_hidden_states=encoder_hidden_states, + ) + hidden_states = hidden_states + attention_output + return hidden_states + + +class T5Block(nn.Module): + + def __init__( + self, + config, + has_relative_attention_bias: bool, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.is_decoder = config.is_decoder + self.layer = nn.ModuleList() + self.layer.append( + T5LayerSelfAttention( + config, + has_relative_attention_bias=has_relative_attention_bias, + linear_method=linear_method, + )) + if self.is_decoder: + self.layer.append( + T5LayerCrossAttention(config, linear_method=linear_method)) + + self.layer.append(T5LayerFF(config)) + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: Optional[KVCache], + input_metadata: InputMetadata, + encoder_hidden_states: Optional[torch.Tensor], + ): + hidden_states = self.layer[0]( + hidden_states=hidden_states, + kv_cache=kv_cache, + input_metadata=input_metadata, + ) + + if hidden_states.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + 
torch.finfo(hidden_states.dtype).max, + ) + hidden_states = torch.clamp(hidden_states, + min=-clamp_value, + max=clamp_value) + + if self.is_decoder: + hidden_states = self.layer[1]( + hidden_states, + kv_cache=kv_cache, + input_metadata=input_metadata, + encoder_hidden_states=encoder_hidden_states, + ) + if hidden_states.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + torch.finfo(hidden_states.dtype).max, + ) + hidden_states = torch.clamp(hidden_states, + min=-clamp_value, + max=clamp_value) + + # Apply Feed Forward layer + hidden_states = self.layer[-1](hidden_states) + + return hidden_states + + +class T5Stack(nn.Module): + + def __init__( + self, + config: T5Config, + embed_tokens: torch.Tensor, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.is_decoder = config.is_decoder + self.embed_tokens = embed_tokens + + self.block = nn.ModuleList([ + T5Block( + config, + has_relative_attention_bias=(i == 0), + linear_method=linear_method, + ) for i in range(config.num_layers) + ]) + + self.final_layer_norm = T5LayerNorm(config.d_model, + eps=config.layer_norm_epsilon) + + def forward( + self, + input_ids: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + encoder_hidden_states: Optional[torch.Tensor], + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + + for i, layer_module in enumerate(self.block): + kv_cache = kv_caches[i] if self.is_decoder else None + + layer_outputs = layer_module( + hidden_states, + kv_cache=kv_cache, + input_metadata=input_metadata, + encoder_hidden_states=encoder_hidden_states, + ) + + hidden_states = layer_outputs + + hidden_states = self.final_layer_norm(hidden_states) + return hidden_states + + +class T5ForConditionalGeneration(nn.Module): + + def __init__(self, + config: T5Config, + linear_method: Optional[LinearMethodBase] = None): + super().__init__() + self.config = config + self.model_dim = config.d_model + + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + self.encoder = T5Stack(encoder_config, self.shared, linear_method) + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + self.decoder = T5Stack(decoder_config, self.shared, linear_method) + + self.sampler = Sampler(config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + if input_metadata.is_prompt: + # prompt run, need to run encoder once + hidden_states = self.encoder(input_ids, kv_caches, input_metadata, + None) + # Clear the attention bias + input_metadata.attn_bias = None + batch_size = input_ids.shape[0] + input_ids = (torch.ones(batch_size, 1, dtype=torch.long) * + self.config.decoder_start_token_id).cuda() + + else: + hidden_states = None + + if kv_caches[0][0] is not None: # Skip decoder for profiling run + hidden_states = self.decoder(input_ids, kv_caches, input_metadata, + hidden_states) + + if self.config.tie_word_embeddings: + # Rescale output before projecting on vocab + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + hidden_states = hidden_states * (self.model_dim**-0.5) + + return hidden_states + + def sample(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata): + next_tokens = 
self.sampler(self.shared.weight, hidden_states, + sampling_metadata) + return next_tokens + + def load_weights( + self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None, + ): + params_dict = dict(self.named_parameters(remove_duplicate=False)) + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision): + if "EncDecAttention.relative_attention_bias" in name: + continue + + assert name in params_dict, f"{name} not in params_dict" + param = params_dict[name] + assert param.shape == loaded_weight.shape, ( + f"{name} shape mismatch between model and checkpoint: " + f"{param.shape} != {loaded_weight.shape}") + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/sequence.py b/vllm/sequence.py index 040e9756e15c6..9f74e74d20b04 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -135,6 +135,7 @@ def __init__( prompt: str, prompt_token_ids: List[int], block_size: int, + is_decoder_encoder: bool, lora_request: Optional[LoRARequest] = None, ) -> None: self.seq_id = seq_id @@ -147,8 +148,20 @@ def __init__( self.output_text = "" self.logical_token_blocks: List[LogicalTokenBlock] = [] + initial_token_ids = prompt_token_ids + if is_decoder_encoder: + # We need to separate the prompt and generated tokens for encoder-decoder models. + num_prompt_blocks = (len(prompt_token_ids) + block_size - + 1) // block_size + padded_prompt_len = num_prompt_blocks * block_size + initial_token_ids = prompt_token_ids + [0] * ( + padded_prompt_len - len(prompt_token_ids)) + # Also need to append decoder_start_token_id + initial_token_ids.append(0) + # Initialize the logical token blocks with the prompt token ids. - self._append_tokens_to_blocks(prompt_token_ids) + self._append_tokens_to_blocks(initial_token_ids) + self.status = SequenceStatus.WAITING # Used for incremental detokenization diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index efe570778fb43..8d6da05ffc334 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -84,6 +84,13 @@ def __init__( if self.device_config.is_neuron: self.model_config.enforce_eager = True + # Unpack HF is_encoder_decoder config attribute + # NOTE: must handle "self.model_config is None" case imposed by certain tests i.e. test_prepare_prompt() + # In the None case, default to is_encoder_decoder == False since vLLM decoder-only mode is known to handle + # the None case correctly. 
+ self.is_encoder_decoder = False if self.model_config is None else \ + getattr(self.model_config.hf_config, "is_encoder_decoder", False) + def load_model(self) -> None: self.model = get_model(self.model_config, self.device_config, @@ -135,6 +142,8 @@ def _prepare_prompt( context_lens: List[int] = [] subquery_lens: List[int] = [] prefix_block_tables: List[List[int]] = [] + block_tables: List[List[int]] = [] + max_block_table_len = 0 for seq_group_metadata in seq_group_metadata_list: assert seq_group_metadata.is_prompt seq_ids = list(seq_group_metadata.seq_data.keys()) @@ -145,16 +154,20 @@ def _prepare_prompt( prompt_tokens = seq_data.get_token_ids() prompt_len = len(prompt_tokens) prompt_lens.append(prompt_len) + prefix_len = 0 - prefix = seq_group_metadata.prefix - if prefix is not None and prefix.computed: - prefix_len = prefix.get_length() - prompt_tokens = prompt_tokens[prefix_len:] - prefix_block_tables.append(prefix.get_block_numbers()) + if self.is_encoder_decoder: + context_lens.append(1) else: - prefix_block_tables.append([]) - # actual prompt lens - context_lens.append(prefix_len) + prefix = seq_group_metadata.prefix + if prefix is not None and prefix.computed: + prefix_len = prefix.get_length() + prompt_tokens = prompt_tokens[prefix_len:] + prefix_block_tables.append(prefix.get_block_numbers()) + else: + prefix_block_tables.append([]) + # actual prompt lens + context_lens.append(prefix_len) subquery_lens.append(prompt_len - prefix_len) input_tokens.append(prompt_tokens) @@ -204,6 +217,10 @@ def _prepare_prompt( slot = block_number * self.block_size + block_offset slot_mapping[-1].append(slot) + if self.is_encoder_decoder: + block_tables.append(block_table) + max_block_table_len = max(max_block_table_len, + len(block_table)) max_prompt_len = max(subquery_lens) input_tokens = _make_tensor_with_pad(input_tokens, max_prompt_len, @@ -215,8 +232,16 @@ def _prepare_prompt( pad=0, dtype=torch.long, device=self.device) + if self.is_encoder_decoder and len(block_tables) > 0: + # Pad the slot mapping to the same length and add decoder_start_id + for i in range(len(slot_mapping)): + slot_mapping[i] += [_PAD_SLOT_ID + ] * (max_prompt_len - len(slot_mapping[i])) + slot_mapping[i].append(block_tables[i][-1] * self.block_size) + + max_slot_mapping_len = max_prompt_len + self.is_encoder_decoder slot_mapping = _make_tensor_with_pad(slot_mapping, - max_prompt_len, + max_slot_mapping_len, pad=_PAD_SLOT_ID, dtype=torch.long, device=self.device) @@ -224,18 +249,34 @@ def _prepare_prompt( _pad_to_max(mapping, max_prompt_len, pad=0) for mapping in lora_index_mapping ] + context_lens_tensor = torch.tensor(context_lens, dtype=torch.int, device=self.device) - # Prepare prefix block tables - max_prompt_block_table_len = max(len(t) for t in prefix_block_tables) - block_tables = _make_tensor_with_pad( - prefix_block_tables, - max_len=max_prompt_block_table_len, - pad=0, - dtype=torch.int, - device=self.device, - ) + if self.is_encoder_decoder: + padded_block_tables = [] + # Pad the encoder block tables to the same length and then add a decoder block table in the end + for block_table in block_tables: + block_table = block_table[:-1] + [0] * ( + max_block_table_len - len(block_table)) + block_table[-1:] + padded_block_tables.append(block_table) + + block_tables_tensor = _make_tensor_with_pad( + padded_block_tables, + max_len=max_block_table_len, + pad=0, + dtype=torch.int, + device=self.device) + else: + # Prepare prefix block tables + max_prompt_block_table_len = max( + len(t) for t in prefix_block_tables) 
+ block_tables_tensor = _make_tensor_with_pad( + prefix_block_tables, + max_len=max_prompt_block_table_len, + pad=0, + dtype=torch.int, + device=self.device) start_loc_tensor = torch.arange(0, len(prompt_lens) * max_prompt_len, max_prompt_len, @@ -251,9 +292,9 @@ def _prepare_prompt( prompt_lens=prompt_lens_tensor, max_seq_len=max_prompt_len, start_loc=start_loc_tensor, - max_context_len=None, + max_context_len=max(context_lens), context_lens=context_lens_tensor, - block_tables=block_tables, + block_tables=block_tables_tensor, use_cuda_graph=False, kv_cache_dtype=self.kv_cache_dtype, ) @@ -270,12 +311,14 @@ def _prepare_decode( input_tokens: List[List[int]] = [] input_positions: List[List[int]] = [] slot_mapping: List[List[int]] = [] + prompt_lens: List[int] = [] context_lens: List[int] = [] block_tables: List[List[int]] = [] lora_index_mapping: List[int] = [] lora_prompt_mapping: List[int] = [] lora_requests: Set[LoRARequest] = set() + max_block_table_len = 0 for seq_group_metadata in seq_group_metadata_list: assert not seq_group_metadata.is_prompt @@ -294,11 +337,28 @@ def _prepare_decode( position = seq_len - 1 input_positions.append([position]) - context_len = seq_len if self.sliding_window is None else min( - seq_len, self.sliding_window) + prompt_len = len(seq_data.prompt_token_ids) + prompt_lens.append(prompt_len) + + if self.is_encoder_decoder: + # Encoder-decoder model stores prompt and generation tokens separately, + # so we need to adjust to the pad. + prompt_blocks_num = (prompt_len + self.block_size - + 1) // self.block_size + prompt_pad = prompt_blocks_num * self.block_size - prompt_len + position += prompt_pad + 1 # One extra for decoder_start_id + + if self.is_encoder_decoder: + context_len = seq_len - prompt_len + 1 + elif self.sliding_window is not None: + context_len = min(seq_len, self.sliding_window) + else: + context_len = seq_len context_lens.append(context_len) block_table = seq_group_metadata.block_tables[seq_id] + max_block_table_len = max(max_block_table_len, + len(block_table)) block_number = block_table[position // self.block_size] block_offset = position % self.block_size slot = block_number * self.block_size + block_offset @@ -311,6 +371,15 @@ def _prepare_decode( self.block_size) block_table = block_table[-sliding_window_blocks:] block_tables.append(block_table) + if self.is_encoder_decoder: + padded_block_tables = [] + # Pad the encoder block tables to the same length and then add a decoder block table in the end + for block_table in block_tables: + block_table = block_table[:-1] + [0] * ( + max_block_table_len - len(block_table)) + block_table[-1:] + padded_block_tables.append(block_table) + + block_tables = padded_block_tables batch_size = len(input_tokens) max_context_len = max(context_lens) @@ -350,6 +419,9 @@ def _prepare_decode( dtype=torch.int, device=self.device) + prompt_lens = torch.tensor(prompt_lens, + dtype=torch.int, + device=self.device) if use_captured_graph: # The shape of graph_block_tables is # [max batch size, max context len // block size]. 
@@ -376,7 +448,7 @@ def _prepare_decode( input_metadata = InputMetadata( is_prompt=False, slot_mapping=slot_mapping, - prompt_lens=None, + prompt_lens=prompt_lens, max_seq_len=None, start_loc=None, max_context_len=max_context_len, @@ -408,7 +480,7 @@ def _prepare_sample( sampling_params = seq_group_metadata.sampling_params seq_groups.append((seq_ids, sampling_params)) - if seq_group_metadata.is_prompt: + if seq_group_metadata.is_prompt and not self.is_encoder_decoder: assert len(seq_ids) == 1 assert subquery_lens is not None subquery_len = subquery_lens[i]
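
For reference, here is a minimal PyTorch restatement (a sketch, not part of the patch) of the bias-selection rule that the kernel change introduces in `paged_attention_kernel`: when a `custom_bias` tensor is supplied it is added to the raw attention logit for each token, otherwise the existing ALiBi term is used, otherwise no bias is applied. The helper name, tensor shapes, and the toy values below are illustrative assumptions for demonstration only.

```python
from typing import Optional

import torch


def add_attention_bias(qk: torch.Tensor,
                       token_idx: torch.Tensor,
                       context_len: int,
                       alibi_slope: float = 0.0,
                       custom_bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Mirror of the kernel's per-token logit update for one (seq, head) pair.

    qk:          [num_tokens] raw attention logits.
    token_idx:   [num_tokens] token positions within the context.
    custom_bias: optional [max_seq_len] per-head bias slice
                 (cf. the kernel's custom_bias_vec).
    """
    if custom_bias is not None:
        # Custom bias takes precedence over ALiBi, as in the CUDA kernel.
        return qk + custom_bias[token_idx]
    if alibi_slope != 0.0:
        return qk + alibi_slope * (token_idx - context_len + 1)
    return qk


# Toy example: a per-token bias overrides the ALiBi path.
qk = torch.zeros(4)
token_idx = torch.arange(4)
bias = torch.tensor([0.0, -0.5, -1.0, -1.5])
print(add_attention_bias(qk, token_idx, context_len=4,
                         alibi_slope=0.1, custom_bias=bias))
```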