diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py
index 66c865c272196..2608a27fd2679 100644
--- a/vllm/attention/backends/rocm_flash_attn.py
+++ b/vllm/attention/backends/rocm_flash_attn.py
@@ -136,7 +136,8 @@ def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
             slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
             seq_lens=self.seq_lens[:self.num_prefills],
             seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
-            num_orig_input_tokens_tensor=self.num_orig_input_tokens_tensor[:self.num_prefills],
+            num_orig_input_tokens_tensor=self.
+            num_orig_input_tokens_tensor[:self.num_prefills],
             max_query_len=self.max_query_len,
             max_prefill_seq_len=self.max_prefill_seq_len,
             max_decode_seq_len=0,
@@ -165,7 +166,8 @@ def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
             slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
             seq_lens=None,
             seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
-            num_orig_input_tokens_tensor=self.num_orig_input_tokens_tensor[:self.num_prefills],
+            num_orig_input_tokens_tensor=self.
+            num_orig_input_tokens_tensor[self.num_prefills:],
             max_query_len=None,
             max_prefill_seq_len=0,
             max_decode_seq_len=self.max_decode_seq_len,
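
For context on the slicing convention this patch touches: the batch is laid out prefill-first, so `prefill_metadata` takes the leading `[:self.num_prefills]` slice of each per-sequence tensor while `decode_metadata` takes the trailing `[self.num_prefills:]` slice (as the surrounding `slot_mapping` and `seq_lens_tensor` lines show). A minimal standalone sketch of that layout, assuming a per-sequence tensor ordered prefills-then-decodes (illustrative values only, not vLLM's actual metadata class):

```python
import torch

# Hypothetical batch: 2 prefill sequences followed by 3 decode sequences.
num_prefills = 2
num_orig_input_tokens_tensor = torch.tensor([7, 5, 1, 1, 1])

# Prefill metadata views the leading slice; decode views the trailing slice.
prefill_view = num_orig_input_tokens_tensor[:num_prefills]
decode_view = num_orig_input_tokens_tensor[num_prefills:]

assert prefill_view.tolist() == [7, 5]
assert decode_view.tolist() == [1, 1, 1]
```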