[Encoder Decoder] Add flash_attn kernel support for encoder-decoder models (vllm-project#9559)

Signed-off-by: Loc Huynh <jc1da.3011@gmail.com>
sroy745 authored and JC1DA committed Nov 11, 2024
1 parent f53054e commit 4e317f0
Showing 11 changed files with 716 additions and 317 deletions.
88 changes: 51 additions & 37 deletions tests/encoder_decoder/test_e2e_correctness.py
@@ -7,12 +7,18 @@
import pytest
from transformers import AutoModelForSeq2SeqLM

from vllm.attention.selector import (_Backend,
global_force_attn_backend_context_manager)
from vllm.platforms import current_platform
from vllm.sequence import SampleLogprobs

from ..conftest import DecoderPromptType
from ..models.utils import check_logprobs_close

LIST_ENC_DEC_SUPPORTED_BACKENDS = [
_Backend.XFORMERS, _Backend.FLASH_ATTN, None
]


def vllm_to_hf_output(
vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]],
@@ -29,7 +35,8 @@ def vllm_to_hf_output(


@pytest.mark.parametrize("model", ["facebook/bart-large-cnn"])
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType))
@@ -48,6 +55,7 @@ def test_encoder_decoder_e2e(
num_logprobs: int,
decoder_prompt_type: DecoderPromptType,
enforce_eager: bool,
attn_backend: _Backend,
) -> None:
'''
End-to-End (E2E) test for the encoder-decoder framework.
@@ -56,43 +64,49 @@ def test_encoder_decoder_e2e(
implementations to ensure that both implementations produce consistent
and correct results.
'''
test_case_prompts = example_encoder_decoder_prompts[decoder_prompt_type]
with global_force_attn_backend_context_manager(attn_backend):
if attn_backend == _Backend.FLASH_ATTN:
# Flash Attention works only with the bfloat16 data type
dtype = 'bfloat16'
test_case_prompts = example_encoder_decoder_prompts[
decoder_prompt_type]

# Configuration settings for HF baseline
hf_kwargs = {
"top_k": None,
"num_beams": 1,
"repetition_penalty": 1.0,
"top_p": 1.0,
"length_penalty": 1.0,
"early_stopping": False,
"no_repeat_ngram_size": None,
"min_length": 0
}
# Configuration settings for HF baseline
hf_kwargs = {
"top_k": None,
"num_beams": 1,
"repetition_penalty": 1.0,
"top_p": 1.0,
"length_penalty": 1.0,
"early_stopping": False,
"no_repeat_ngram_size": None,
"min_length": 0
}

with hf_runner(model, dtype=dtype,
auto_cls=AutoModelForSeq2SeqLM) as hf_model:
hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit(
test_case_prompts,
max_tokens,
num_logprobs,
**hf_kwargs,
))
with vllm_runner(model, dtype=dtype,
enforce_eager=enforce_eager) as vllm_model:
vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs(
test_case_prompts, max_tokens, num_logprobs)
with hf_runner(model, dtype=dtype,
auto_cls=AutoModelForSeq2SeqLM) as hf_model:
hf_outputs = (
hf_model.generate_encoder_decoder_greedy_logprobs_limit(
test_case_prompts,
max_tokens,
num_logprobs,
**hf_kwargs,
))
with vllm_runner(model, dtype=dtype,
enforce_eager=enforce_eager) as vllm_model:
vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs(
test_case_prompts, max_tokens, num_logprobs)

hf_skip_tokens = (1
if decoder_prompt_type == DecoderPromptType.NONE else 0)
hf_skip_tokens = (1 if decoder_prompt_type == DecoderPromptType.NONE
else 0)

check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=[
vllm_to_hf_output(vllm_output, decoder_prompt_type)
for vllm_output in vllm_outputs
],
name_0="hf",
name_1="vllm",
num_outputs_0_skip_tokens=hf_skip_tokens,
)
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=[
vllm_to_hf_output(vllm_output, decoder_prompt_type)
for vllm_output in vllm_outputs
],
name_0="hf",
name_1="vllm",
num_outputs_0_skip_tokens=hf_skip_tokens,
)
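For readers who want to exercise the new backend support outside the test harness, here is a minimal, hedged sketch (the prompt is illustrative, and it assumes the vllm.LLM entry point accepts a plain text prompt for this encoder-decoder model). It forces the FLASH_ATTN backend and switches to bfloat16, mirroring the test above.

from vllm import LLM, SamplingParams
from vllm.attention.selector import (_Backend,
                                     global_force_attn_backend_context_manager)

# Illustrative sketch only: force the Flash Attention backend for an
# encoder-decoder model. Flash Attention works only with bfloat16, so the
# dtype is overridden accordingly.
with global_force_attn_backend_context_manager(_Backend.FLASH_ATTN):
    llm = LLM(model="facebook/bart-large-cnn", dtype="bfloat16")
    outputs = llm.generate("A long news article to summarize ...",
                           SamplingParams(max_tokens=128))
    print(outputs[0].outputs[0].text)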
156 changes: 115 additions & 41 deletions tests/kernels/test_encoder_decoder_attn.py
@@ -16,13 +16,13 @@
from vllm.attention import (Attention, AttentionBackend, AttentionMetadata,
AttentionType)
from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP
from vllm.attention.selector import (_Backend,
from vllm.attention.selector import (_Backend, get_attn_backend,
global_force_attn_backend_context_manager)
from vllm.forward_context import set_forward_context
from vllm.platforms import current_platform

# List of supported backends for encoder/decoder models
LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS]

LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN]
HEAD_SIZES = [64, 256]

NUM_HEADS = [1, 16]
@@ -145,7 +145,8 @@ class that Attention will automatically select when it is constructed.
test_pt.num_heads,
test_pt.head_size,
test_pt.block_size,
device=CUDA_DEVICE)
device=CUDA_DEVICE,
backend=test_pt.backend_name)
return TestResources(scale, attn_backend, attn, kv_cache)


@@ -592,6 +593,7 @@ def _run_encoder_attention_test(
attn: Attention,
encoder_test_params: PhaseTestParameters,
attn_metadata: AttentionMetadata,
test_pt: TestPoint,
) -> torch.Tensor:
'''
Run encoder attention.
@@ -610,6 +612,8 @@ def _run_encoder_attention_test(
(number_of_tokens x num_heads x head_size)
query/key/value fields
* attn_metadata: attention metadata for encoder/decoder-self attention
* test_pt: The TestPoint object containing test details like number of
model heads, head size, name of the backend being used etc.
Returns:
* Attention.forward() applied to packed {query,key,value} and
@@ -619,20 +623,31 @@ def _run_encoder_attention_test(
attn_type = AttentionType.ENCODER
packed_qkv = encoder_test_params.packed_qkvo.packed_qkv
assert packed_qkv is not None
return attn.forward(packed_qkv.query,
packed_qkv.key,
packed_qkv.value,
torch.tensor([],
dtype=torch.float32,
device=packed_qkv.query.device),
attn_metadata,
attn_type=attn_type)
with set_forward_context(attn_metadata):
# In the test setup the shape of the query is
# [batch_size, seq_len, num_heads, head_size]. However
# the attention backend expects the shape to be
# [num_tokens, hidden_size]. Hence reshape the query before
# invoking the forward method.
# TODO - Update the way we construct the query so that it
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
reshaped_query = packed_qkv.query.view(
-1, test_pt.num_heads * test_pt.head_size)
return attn.forward(reshaped_query,
packed_qkv.key,
packed_qkv.value,
torch.tensor([],
dtype=torch.float32,
device=packed_qkv.query.device),
attn_metadata,
attn_type=attn_type)
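
The reshape described in the comment above is purely a layout change. A standalone sketch with assumed sizes:

import torch

# Assumed sizes for illustration only.
batch_size, seq_len, num_heads, head_size = 2, 8, 16, 64
query = torch.randn(batch_size, seq_len, num_heads, head_size)
# Collapse the batch/sequence dims into one token dim and fuse the heads
# into hidden_size, which is the layout the attention backends expect.
reshaped_query = query.view(-1, num_heads * head_size)
assert reshaped_query.shape == (batch_size * seq_len, num_heads * head_size)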


def _run_decoder_self_attention_test(
test_rsrcs: TestResources,
decoder_test_params: PhaseTestParameters,
attn_metadata: AttentionMetadata,
test_pt: TestPoint,
) -> torch.Tensor:
'''
Run decoder self-attention test.
@@ -650,6 +665,8 @@ def _run_decoder_self_attention_test(
query/key/value fields
* attn_metadata: attention metadata for decoder-self attention
(contains KV cache memory-mapping)
* test_pt: The TestPoint object containing test details like number of
model heads, head size, name of the backend being used etc.
Returns:
* Attention.forward() applied to packed_{query,key,value}, kv_cache
@@ -660,19 +677,30 @@ def _run_decoder_self_attention_test(
kv_cache = test_rsrcs.kv_cache
packed_qkv = decoder_test_params.packed_qkvo.packed_qkv
assert packed_qkv is not None
return attn.forward(packed_qkv.query,
packed_qkv.key,
packed_qkv.value,
kv_cache,
attn_metadata,
attn_type=attn_type)
with set_forward_context(attn_metadata):
# In the test setup the shape of the query is
# [batch_size, seq_len, num_heads, head_size]. However
# the attention backend expects the shape to be
# [num_tokens, hidden_size]. Hence reshape the query before
# invoking the forward method.
# TODO - Update the way we construct the query so that it
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
reshaped_query = packed_qkv.query.view(
-1, test_pt.num_heads * test_pt.head_size)
return attn.forward(reshaped_query,
packed_qkv.key,
packed_qkv.value,
kv_cache,
attn_metadata,
attn_type=attn_type)


def _run_encoder_decoder_cross_attention_test(
test_rsrcs: TestResources,
decoder_test_params: PhaseTestParameters,
cross_test_params: Optional[PhaseTestParameters],
attn_metadata: AttentionMetadata,
test_pt: TestPoint,
) -> torch.Tensor:
'''
Run encoder/decoder cross-attention test.
@@ -701,6 +729,8 @@ def _run_encoder_decoder_cross_attention_test(
(number_of_tokens x num_heads x head_size)
key/value fields
* attn_metadata: attention metadata for encoder/decoder-self attention
* test_pt: The TestPoint object containing test details like number of
model heads, head size, name of the backend being used etc.
Returns:
* Attention.forward() applied to packed_{query,key,value}, kv_cache
@@ -718,12 +748,37 @@ def _run_encoder_decoder_cross_attention_test(
cross_pckd_qkv = cross_test_params.packed_qkvo.packed_qkv
key = (None if cross_pckd_qkv is None else cross_pckd_qkv.key)
value = (None if cross_pckd_qkv is None else cross_pckd_qkv.value)
return attn.forward(decoder_test_params.packed_qkvo.packed_qkv.query,
key,
value,
kv_cache,
attn_metadata,
attn_type=attn_type)
with set_forward_context(attn_metadata):
# In the test setup the shape of the query is
# [batch_size, seq_len, num_heads, head_size]. However
# the attention backend expects the shape to be
# [num_tokens, hidden_size]. Hence reshape the query before
# invoking the forward method.
# TODO - Update the way we construct the query so that it
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
reshaped_query = decoder_test_params.packed_qkvo.packed_qkv.query.view(
-1, test_pt.num_heads * test_pt.head_size)
return attn.forward(reshaped_query,
key,
value,
kv_cache,
attn_metadata,
attn_type=attn_type)


@pytest.fixture(autouse=True)
def set_reset_environment(attn_backend):
# Set the default torch datatype to bfloat16 to enable
# testing of the Flash Attention backend. Also clear the
# cached value of the backend.
default_dtype = torch.get_default_dtype()
if attn_backend.name == 'FLASH_ATTN':
torch.set_default_dtype(torch.bfloat16)
get_attn_backend.cache_clear()
yield
# Reset the torch datatype to what it was before the test
# so as not to impact the remaining tests.
torch.set_default_dtype(default_dtype)
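
In isolation, the dtype save/restore pattern used by this fixture looks like the following sketch:

import torch

default_dtype = torch.get_default_dtype()
torch.set_default_dtype(torch.bfloat16)
try:
    x = torch.empty(2, 2)            # new tensors now default to bfloat16
    assert x.dtype == torch.bfloat16
finally:
    # Restore the previous default so later tests are unaffected.
    torch.set_default_dtype(default_dtype)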


@pytest.mark.skipif(current_platform.is_rocm(),
@@ -773,10 +828,8 @@ def test_encoder_only(
* max_dec_seq_len: max length of decoder input sequences
* max_enc_seq_len: max length of encoder input sequences
'''

# Force Attention wrapper backend
with global_force_attn_backend_context_manager(attn_backend):

# Note: KV cache size of 4096 is arbitrary & chosen intentionally
# to be more than necessary, since exceeding the kv cache size
# is not part of this test
@@ -807,10 +860,14 @@ def test_encoder_only(
# PREFILL: encoder attention

enc_pckd_act_out: torch.Tensor = (_run_encoder_attention_test(
test_rsrcs.attn, enc_test_params, prephase_attn_metadata))
test_rsrcs.attn,
enc_test_params,
prephase_attn_metadata,
test_pt=test_pt))

# - Is encoder attention result correct?
assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out)
assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out,
attn_backend.name)


@pytest.mark.skipif(current_platform.is_rocm(),
@@ -892,10 +949,8 @@ def test_e2e_enc_dec_attn(
* max_dec_seq_len: max length of decoder input sequences
* max_enc_seq_len: max length of encoder input sequences
'''

# Force Attention wrapper backend
with global_force_attn_backend_context_manager(attn_backend):

# Note: KV cache size of 4096 is arbitrary & chosen intentionally
# to be more than necessary, since exceeding the kv cache size
# is not part of this test
@@ -955,29 +1010,39 @@ def test_e2e_enc_dec_attn(

enc_pckd_act_out = _run_encoder_attention_test(test_rsrcs.attn,
enc_test_params,
prephase_attn_metadata)
prephase_attn_metadata,
test_pt=test_pt)

# - Is encoder attention result correct?
assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out)
assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out,
attn_backend.name)

# PREFILL: decoder self-attention test

prephase_dec_pckd_act_out = _run_decoder_self_attention_test(
test_rsrcs, prephase_dec_test_params, prephase_attn_metadata)
test_rsrcs,
prephase_dec_test_params,
prephase_attn_metadata,
test_pt=test_pt)

# - Is prefill decoder self-attention correct?
assert_actual_matches_ideal(prephase_dec_test_params,
prephase_dec_pckd_act_out)
prephase_dec_pckd_act_out,
attn_backend.name)

# PREFILL: encoder/decoder cross-attention test

prephase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test(
test_rsrcs, prephase_dec_test_params, prephase_cross_test_params,
prephase_attn_metadata)
test_rsrcs,
prephase_dec_test_params,
prephase_cross_test_params,
prephase_attn_metadata,
test_pt=test_pt)

# - Is prefill encoder/decoder cross-attention correct?
assert_actual_matches_ideal(prephase_cross_test_params,
prephase_cross_pckd_act_out)
prephase_cross_pckd_act_out,
attn_backend.name)

# DECODE: build decode-phase attention metadata

@@ -993,17 +1058,26 @@ def test_e2e_enc_dec_attn(
# DECODE: decoder self-attention test

decphase_dec_pckd_act_out = _run_decoder_self_attention_test(
test_rsrcs, decphase_dec_test_params, decphase_attn_metadata)
test_rsrcs,
decphase_dec_test_params,
decphase_attn_metadata,
test_pt=test_pt)

# - Is decode-phase decoder self-attention correct?
assert_actual_matches_ideal(decphase_dec_test_params,
decphase_dec_pckd_act_out)
decphase_dec_pckd_act_out,
attn_backend.name)

# DECODE: encoder/decoder cross-attention test

decphase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test(
test_rsrcs, decphase_dec_test_params, None, decphase_attn_metadata)
test_rsrcs,
decphase_dec_test_params,
None,
decphase_attn_metadata,
test_pt=test_pt)

# - Is decode-phase encoder/decoder cross-attention correct?
assert_actual_matches_ideal(decphase_cross_test_params,
decphase_cross_pckd_act_out)
decphase_cross_pckd_act_out,
attn_backend.name)
90 changes: 74 additions & 16 deletions tests/kernels/utils.py
@@ -13,8 +13,8 @@

from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL,
make_tensor_with_pad)
from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL,
STR_XFORMERS_ATTN_VAL, make_tensor_with_pad)

# For now, disable "test_aot_dispatch_dynamic" since there are some
# bugs related to this test in PyTorch 2.4.
@@ -525,17 +525,22 @@ def make_backend(backend_name: str) -> AttentionBackend:
if backend_name == STR_XFORMERS_ATTN_VAL:
# NOTE: xFormers backend cannot be imported for CPU and AMD GPUs.
from vllm.attention.backends.xformers import XFormersBackend

return XFormersBackend()
elif backend_name == STR_FLASH_ATTN_VAL:
from vllm.attention.backends.flash_attn import FlashAttentionBackend
return FlashAttentionBackend()

raise AssertionError(
f"Unrecognized backend_name {backend_name} for unit test")


def _make_metadata_tensors(
seq_lens: Optional[List[int]], context_lens: Optional[List[int]],
encoder_seq_lens: Optional[List[int]], device: Union[torch.device, str]
) -> Tuple[torch.Tensor, torch.Tensor, Any, Any, Optional[List[int]],
torch.Tensor, Optional[int]]:
seq_lens: Optional[List[int]],
context_lens: Optional[List[int]],
encoder_seq_lens: Optional[List[int]],
device: Union[torch.device, str],
) -> Tuple[torch.Tensor, torch.Tensor, Any, Any, Optional[torch.Tensor],
torch.Tensor, torch.Tensor, Optional[int]]:
'''
Build scalar & tensor values required to build attention metadata structure.
@@ -553,6 +558,8 @@ def _make_metadata_tensors(
* max_context_len: max(context_lens)
* max_seq_len: max(seq_lens)
* seq_start_loc: start idx of each sequence
* encoder_seq_lens_tensor: encoder seq_lens list, as tensor
* encoder_seq_start_loc: start idx of each encoder sequence
* max_encoder_seq_len: max(encoder_seq_lens)
'''
seq_lens_tensor = maybe_make_int_tensor(seq_lens, device)
@@ -566,15 +573,34 @@ def _make_metadata_tensors(

seq_start_loc = None

if seq_lens_tensor is not None:
seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=seq_lens_tensor.device)
torch.cumsum(seq_lens_tensor,
dim=0,
dtype=seq_start_loc.dtype,
out=seq_start_loc[1:])

encoder_seq_start_loc = torch.zeros(encoder_seq_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=encoder_seq_lens_tensor.device)
torch.cumsum(encoder_seq_lens_tensor,
dim=0,
dtype=encoder_seq_start_loc.dtype,
out=encoder_seq_start_loc[1:])

return (seq_lens_tensor, context_lens_tensor, max_context_len, max_seq_len,
seq_start_loc, encoder_seq_lens_tensor, max_encoder_seq_len)
seq_start_loc, encoder_seq_lens_tensor, encoder_seq_start_loc,
max_encoder_seq_len)
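
The start-location construction above is a cumulative sum over the sequence lengths; for lengths [4, 6] it yields [0, 4, 10]. A standalone sketch:

import torch

seq_lens = torch.tensor([4, 6], dtype=torch.int32)
seq_start_loc = torch.zeros(seq_lens.shape[0] + 1, dtype=torch.int32)
torch.cumsum(seq_lens, dim=0, dtype=seq_start_loc.dtype,
             out=seq_start_loc[1:])
print(seq_start_loc)  # tensor([ 0,  4, 10], dtype=torch.int32)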


def make_kv_cache(num_blocks: int,
num_heads: int,
head_size: int,
block_size: int,
device: Union[torch.device, str],
backend: str,
default_val: float = 0.0) -> torch.Tensor:
'''
Create a fake KV cache.
@@ -591,10 +617,20 @@ def make_kv_cache(num_blocks: int,
Returns:
* kv_cache: 2 x num_blocks x (block_size * num_heads * head_size)
* for backend 'XFORMERS'
* kv_cache: 2 x num_blocks x block_size x num_heads x head_size
* for backend 'FLASH_ATTN'
'''

kv_cache = torch.rand(
(2, num_blocks, block_size * num_heads * head_size)).to(device)
if backend == 'XFORMERS':
kv_cache = torch.rand(
(2, num_blocks, block_size * num_heads * head_size)).to(device)
elif backend == 'FLASH_ATTN':
kv_cache = torch.rand(
(2, num_blocks, block_size, num_heads, head_size)).to(device)
else:
raise ValueError(
f"Unknown backend value: '{backend}'. Expected 'XFORMERS' or "
f"'FLASH_ATTN'.")
if default_val is not None:
kv_cache[:, :, :] = default_val
return kv_cache
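
The two cache layouts described in the docstring differ only in how the per-block dimensions are arranged; a sketch with assumed sizes:

import torch

# Assumed sizes for illustration only.
num_blocks, block_size, num_heads, head_size = 4, 16, 2, 64
# XFORMERS: per-block storage is flattened.
xformers_cache = torch.rand(2, num_blocks, block_size * num_heads * head_size)
# FLASH_ATTN: block, head, and head-size dims are kept separate.
flash_attn_cache = torch.rand(2, num_blocks, block_size, num_heads, head_size)
print(xformers_cache.shape)    # torch.Size([2, 4, 2048])
print(flash_attn_cache.shape)  # torch.Size([2, 4, 16, 2, 64])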
@@ -858,8 +894,9 @@ def make_test_metadata(
context_lens_tensor,
_,
_,
_,
seq_start_loc,
encoder_seq_lens_tensor,
encoder_seq_start_loc,
max_encoder_seq_len,
) = _make_metadata_tensors(seq_lens,
context_lens,
@@ -874,6 +911,7 @@ def make_test_metadata(
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
seq_start_loc=seq_start_loc,
max_prefill_seq_len=None if seq_lens is None else max(seq_lens),
max_decode_seq_len=0,
context_lens_tensor=context_lens_tensor,
@@ -882,6 +920,7 @@ def make_test_metadata(
num_encoder_tokens=num_encoder_tokens,
encoder_seq_lens=encoder_seq_lens,
encoder_seq_lens_tensor=encoder_seq_lens_tensor,
encoder_seq_start_loc=encoder_seq_start_loc,
max_encoder_seq_len=max_encoder_seq_len,
cross_slot_mapping=(None if cross_kv_mmap is None else
cross_kv_mmap.slot_mapping),
@@ -904,8 +943,9 @@ def make_test_metadata(
context_lens_tensor,
_,
_,
_,
seq_start_loc,
encoder_seq_lens_tensor,
encoder_seq_start_loc,
max_encoder_seq_len,
) = _make_metadata_tensors(seq_lens,
context_lens,
@@ -920,14 +960,17 @@ def make_test_metadata(
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
seq_start_loc=seq_start_loc,
max_prefill_seq_len=0,
max_decode_seq_len=max(seq_lens),
max_decode_query_len=1,
context_lens_tensor=context_lens_tensor,
block_tables=kv_mmap.block_tables,
use_cuda_graph=False,
num_encoder_tokens=num_encoder_tokens,
encoder_seq_lens=encoder_seq_lens,
encoder_seq_lens_tensor=encoder_seq_lens_tensor,
encoder_seq_start_loc=encoder_seq_start_loc,
max_encoder_seq_len=max_encoder_seq_len,
cross_slot_mapping=(None if cross_kv_mmap is None else
cross_kv_mmap.slot_mapping),
@@ -936,7 +979,8 @@ def make_test_metadata(


def assert_actual_matches_ideal(test_params: PhaseTestParameters,
output_under_test: torch.Tensor) -> None:
output_under_test: torch.Tensor,
backend: str) -> None:
'''
Assert that observed output matches the ideal output
contained in the test parameters data structure.
@@ -947,8 +991,22 @@ def assert_actual_matches_ideal(test_params: PhaseTestParameters,
* output_under_test: actually observed output value
'''
ideal_output = test_params.packed_qkvo.ideal_output
torch.testing.assert_close(ideal_output,
output_under_test.view_as(ideal_output))
if backend == 'XFORMERS':
torch.testing.assert_close(ideal_output,
output_under_test.view_as(ideal_output))

elif backend == 'FLASH_ATTN':
# For FlashAttention, override the accuracy thresholds with non-default
# values, since we observe a larger difference between the ideal and
# actual output with this backend.
torch.testing.assert_close(ideal_output,
output_under_test.view_as(ideal_output),
atol=0.01,
rtol=0.016)
else:
raise ValueError(
f"Unknown backend value: '{backend}'. Expected 'XFORMERS' or "
f"'FLASH_ATTN'.")


# Copied/modified from torch._refs.__init__.py
@@ -85,7 +85,7 @@ def run_test(


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("dtype", ["float", "bfloat16"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, model, dtype, max_tokens,
364 changes: 278 additions & 86 deletions vllm/attention/backends/flash_attn.py

Large diffs are not rendered by default.

159 changes: 143 additions & 16 deletions vllm/attention/backends/utils.py
@@ -1,13 +1,14 @@
"""Attention backend utils"""
from collections import defaultdict
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Dict, List, Type, TypeVar, Union
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union

import numpy as np
import torch

from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder,
AttentionState)
from vllm.attention.backends.abstract import AttentionType
from vllm.multimodal import MultiModalPlaceholderMap
from vllm.utils import async_tensor_h2d, make_tensor_with_pad

@@ -336,11 +337,13 @@ def graph_capture_get_metadata_for_batch(
use_cuda_graph=True,
)
if is_encoder_decoder_model:
# The encoder decoder model works only with XFormers backend.
# Assert the same.
assert self.runner.attn_backend.get_name() == "XFORMERS", \
f"Expected attn_backend name to be 'XFORMERS', but "\
f" got '{self.runner.attn_backend.get_name()}'"
# The encoder decoder model works only with XFormers and
# Flash Attention backend. Assert the same.
assert self.runner.attn_backend.get_name() in\
["XFORMERS", "FLASH_ATTN"], \
f"Expected attn_backend name to be either 'XFORMERS' or " \
f"'FLASH_ATTN', but "\
f"got '{self.runner.attn_backend.get_name()}'"
self._update_captured_metadata_for_enc_dec_model(
batch_size=batch_size, attn_metadata=attn_metadata)

@@ -356,11 +359,13 @@ def get_graph_input_buffers(
"block_tables": attn_metadata.decode_metadata.block_tables,
}
if is_encoder_decoder_model:
# The encoder decoder model works only with XFormers backend.
# Assert the same.
assert self.runner.attn_backend.get_name() == "XFORMERS", \
f"Expected attn_backend name to be 'XFORMERS', but "\
f" got '{self.runner.attn_backend.get_name()}'"
# The encoder decoder model works only with XFormers and
# Flash Attention backend. Assert the same.
assert self.runner.attn_backend.get_name() in\
["XFORMERS", "FLASH_ATTN"], \
f"Expected attn_backend name to be either 'XFORMERS' or "\
f"'FLASH_ATTN', but "\
f"got '{self.runner.attn_backend.get_name()}'"
self._add_additonal_input_buffers_for_enc_dec_model(
attn_metadata=attn_metadata, input_buffers=input_buffers)
return input_buffers
@@ -375,11 +380,13 @@ def prepare_graph_input_buffers(
input_buffers["block_tables"].copy_(
attn_metadata.decode_metadata.block_tables, non_blocking=True)
if is_encoder_decoder_model:
# The encoder decoder model works only with XFormers backend.
# Assert the same.
assert self.runner.attn_backend.get_name() == "XFORMERS", \
f"Expected attn_backend name to be 'XFORMERS', but "\
f" got '{self.runner.attn_backend.get_name()}'"
# The encoder decoder model works only with XFormers and
# Flash Attention backend. Assert the same.
assert self.runner.attn_backend.get_name() in\
["XFORMERS", "FLASH_ATTN"], \
f"Expected attn_backend name to be either 'XFORMERS' or "\
f"'FLASH_ATTN', but "\
f"got '{self.runner.attn_backend.get_name()}'"
self._prepare_input_buffers_for_enc_dec_model(
attn_metadata, input_buffers)

@@ -411,6 +418,7 @@ def _update_captured_metadata_for_enc_dec_model(self, batch_size: int,
attn_metadata.encoder_seq_lens_tensor = torch.full(
(batch_size, ), 1, dtype=torch.int).cuda()
attn_metadata.max_encoder_seq_len = self.runner.max_seq_len_to_capture
attn_metadata.num_encoder_tokens = 0

def _add_additonal_input_buffers_for_enc_dec_model(
self, attn_metadata, input_buffers: Dict[str, Any]):
@@ -453,3 +461,122 @@ def _prepare_input_buffers_for_enc_dec_model(self, attn_metadata,
input_buffers["cross_block_tables"].copy_(
attn_metadata.decode_metadata.cross_block_tables,
non_blocking=True)


def is_all_encoder_attn_metadata_set(attn_metadata):
'''
All attention metadata required for encoder attention is set.
'''
return ((attn_metadata.encoder_seq_lens is not None)
and (attn_metadata.encoder_seq_lens_tensor is not None)
and (attn_metadata.max_encoder_seq_len is not None))


def is_all_cross_attn_metadata_set(attn_metadata):
'''
All attention metadata required for enc/dec cross-attention is set.
Superset of encoder attention required metadata.
'''
return (attn_metadata.is_all_encoder_attn_metadata_set
and (attn_metadata.cross_slot_mapping is not None)
and (attn_metadata.cross_block_tables is not None))


def get_seq_len_block_table_args(
attn_metadata,
is_prompt: bool,
attn_type: AttentionType,
) -> tuple:
'''
The particular choice of sequence-length- and block-table-related
attributes which should be extracted from attn_metadata is dependent
on the type of attention operation.
Decoder attn -> select entirely decoder self-attention-related fields
Encoder/decoder cross-attn -> select encoder sequence lengths &
cross-attn block-tables fields
Encoder attn -> select encoder sequence lengths fields & no block tables
Arguments:
* attn_metadata: Attention metadata structure associated with attention op
* is_prompt: True if prefill, False otherwise
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
Returns:
* Appropriate sequence-lengths tensor
* Appropriate max sequence-length scalar
* Appropriate block tables (or None)
'''

if attn_type == AttentionType.DECODER:
# Decoder self-attention
# Choose max_seq_len based on whether we are in prompt_run
if is_prompt:
max_seq_len = attn_metadata.max_prefill_seq_len
else:
max_seq_len = attn_metadata.max_decode_seq_len
return (attn_metadata.seq_lens_tensor, max_seq_len,
attn_metadata.block_tables)
elif attn_type == AttentionType.ENCODER_DECODER:
# Enc/dec cross-attention KVs match encoder sequence length;
# cross-attention utilizes special "cross" block tables
return (attn_metadata.encoder_seq_lens_tensor,
attn_metadata.max_encoder_seq_len,
attn_metadata.cross_block_tables)
elif attn_type == AttentionType.ENCODER:
# No block tables associated with encoder attention
return (attn_metadata.encoder_seq_lens_tensor,
attn_metadata.max_encoder_seq_len, None)
else:
raise AttributeError(f"Invalid attention type {str(attn_type)}")


def get_num_prefill_decode_query_kv_tokens(
attn_metadata,
attn_type: AttentionType,
) -> Tuple[int, int, int]:
"""
Calculate the number of prefill and decode tokens for query, key/value
based on the attention metadata and the specified attention type.
Args:
attn_metadata (FlashAttentionMetadata): Attention Metadata object.
attn_type (AttentionType): The type of attention being used.
Returns:
Tuple[int, int, int]: A tuple containing three integers:
- The number of prefill query tokens.
- The number of prefill key/value tokens.
- The number of decode query tokens.
Raises:
AssertionError: If the number of encoder tokens in `attn_metadata`
is `None` when required for the calculations.
"""
num_prefill_query_tokens = 0
num_decode_query_tokens = 0
num_prefill_kv_tokens = 0
if attn_type == AttentionType.ENCODER:
# Encoder attention is only invoked during prefill phase.
# The same input serves as both query and key.
assert attn_metadata.num_encoder_tokens is not None
num_prefill_query_tokens = attn_metadata.num_encoder_tokens
num_prefill_kv_tokens = attn_metadata.num_encoder_tokens
num_decode_query_tokens = 0
elif attn_type == AttentionType.ENCODER_DECODER:
assert attn_metadata.num_encoder_tokens is not None
num_prefill_query_tokens = attn_metadata.num_prefill_tokens
# For cross-attention, the key/value tokens come from the encoder.
num_prefill_kv_tokens = attn_metadata.num_encoder_tokens
num_decode_query_tokens = attn_metadata.num_decode_tokens
else: # attn_type == AttentionType.DECODER or
# attn_type == AttentionType.ENCODER_ONLY
num_prefill_query_tokens = attn_metadata.num_prefill_tokens
num_prefill_kv_tokens = attn_metadata.num_prefill_tokens
num_decode_query_tokens = attn_metadata.num_decode_tokens

return (num_prefill_query_tokens, num_prefill_kv_tokens,
num_decode_query_tokens)
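
A hypothetical illustration of how the helper above partitions token counts for cross-attention, using a minimal stand-in for the metadata object (not the real dataclass):

from types import SimpleNamespace

from vllm.attention.backends.abstract import AttentionType
from vllm.attention.backends.utils import (
    get_num_prefill_decode_query_kv_tokens)

# Minimal stand-in exposing only the fields the helper reads.
meta = SimpleNamespace(num_prefill_tokens=10,
                       num_decode_tokens=3,
                       num_encoder_tokens=7)
# Cross-attention: queries come from the decoder, keys/values from the encoder.
print(get_num_prefill_decode_query_kv_tokens(meta,
                                             AttentionType.ENCODER_DECODER))
# -> (10, 7, 3)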
131 changes: 26 additions & 105 deletions vllm/attention/backends/xformers.py
@@ -11,8 +11,10 @@

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import (CommonAttentionState,
CommonMetadataBuilder)
from vllm.attention.backends.utils import (
CommonAttentionState, CommonMetadataBuilder,
get_num_prefill_decode_query_kv_tokens, get_seq_len_block_table_args,
is_all_cross_attn_metadata_set, is_all_encoder_attn_metadata_set)
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
from vllm.logger import init_logger
@@ -135,6 +137,11 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
# Encoder sequence lengths representation
encoder_seq_lens: Optional[List[int]] = None
encoder_seq_lens_tensor: Optional[torch.Tensor] = None
# FIXME: It is for flash attn.
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into the sequences. E.g., if the sequence lengths are
# [4, 6], it is [0, 4, 10].
encoder_seq_start_loc: Optional[torch.Tensor] = None

# Maximum sequence length among encoder sequences
max_encoder_seq_len: Optional[int] = None
@@ -162,9 +169,7 @@ def is_all_encoder_attn_metadata_set(self):
'''
All attention metadata required for encoder attention is set.
'''
return ((self.encoder_seq_lens is not None)
and (self.encoder_seq_lens_tensor is not None)
and (self.max_encoder_seq_len is not None))
return is_all_encoder_attn_metadata_set(self)

@property
def is_all_cross_attn_metadata_set(self):
@@ -173,9 +178,7 @@ def is_all_cross_attn_metadata_set(self):
Superset of encoder attention required metadata.
'''
return (self.is_all_encoder_attn_metadata_set
and (self.cross_slot_mapping is not None)
and (self.cross_block_tables is not None))
return is_all_cross_attn_metadata_set(self)

@property
def prefill_metadata(self) -> Optional["XFormersMetadata"]:
@@ -329,64 +332,6 @@ def _set_attn_bias(
raise AttributeError(f"Invalid attention type {str(attn_type)}")


def _get_seq_len_block_table_args(
attn_metadata: XFormersMetadata,
is_prompt: bool,
attn_type: AttentionType,
) -> tuple:
'''
The particular choice of sequence-length- and block-table-related
attributes which should be extracted from attn_metadata is dependent
on the type of attention operation.
Decoder attn -> select entirely decoder self-attention-related fields
Encoder/decoder cross-attn -> select encoder sequence lengths &
cross-attn block-tables fields
Encoder attn -> select encoder sequence lengths fields & no block tables
Arguments:
* attn_metadata: Attention metadata structure associated with attention op
* is_prompt: True if prefill, False otherwise
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
Returns:
* Appropriate sequence-lengths tensor
* Appropriate max sequence-length scalar
* Appropriate block tables (or None)
'''

if attn_type == AttentionType.DECODER:
# Decoder self-attention
# Choose max_seq_len based on whether we are in prompt_run
if is_prompt:
max_seq_len = attn_metadata.max_prefill_seq_len
else:
max_seq_len = attn_metadata.max_decode_seq_len
return (attn_metadata.seq_lens_tensor, max_seq_len,
attn_metadata.block_tables)
elif attn_type == AttentionType.ENCODER_DECODER:
# Enc/dec cross-attention KVs match encoder sequence length;
# cross-attention utilizes special "cross" block tables
return (attn_metadata.encoder_seq_lens_tensor,
attn_metadata.max_encoder_seq_len,
attn_metadata.cross_block_tables)
elif attn_type == AttentionType.ENCODER:
# No block tables associated with encoder attention
return (attn_metadata.encoder_seq_lens_tensor,
attn_metadata.max_encoder_seq_len, None)
elif attn_type == AttentionType.ENCODER_ONLY:
assert is_prompt, "Should not have decode for encoder only model."

# No block tables associated with encoder attention
return (attn_metadata.seq_lens_tensor,
attn_metadata.max_prefill_seq_len, None)
else:
raise AttributeError(f"Invalid attention type {str(attn_type)}")


class XFormersMetadataBuilder(CommonMetadataBuilder[XFormersMetadata]):

_metadata_cls = XFormersMetadata
@@ -574,45 +519,21 @@ def forward(
updated_slot_mapping,
self.kv_cache_dtype,
k_scale, v_scale)

if attn_type == AttentionType.ENCODER:
# Encoder attention - chunked prefill is not applicable;
# derive token-count from query shape & and treat them
# as 100% prefill tokens
assert attn_metadata.num_encoder_tokens is not None
num_prefill_tokens = attn_metadata.num_encoder_tokens
num_encoder_tokens = attn_metadata.num_encoder_tokens
num_decode_tokens = 0
elif attn_type == AttentionType.DECODER:
# Decoder self-attention supports chunked prefill.
num_prefill_tokens = attn_metadata.num_prefill_tokens
num_encoder_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
# Only enforce this shape-constraint for decoder
# self-attention
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
else: # attn_type == AttentionType.ENCODER_DECODER
# Encoder/decoder cross-attention requires no chunked
# prefill (100% prefill or 100% decode tokens, no mix)
num_prefill_tokens = attn_metadata.num_prefill_tokens
if attn_metadata.num_encoder_tokens is not None:
num_encoder_tokens = attn_metadata.num_encoder_tokens
else:
num_encoder_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
(num_prefill_query_tokens, num_prefill_kv_tokens,
num_decode_query_tokens) = \
get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type)

output = torch.empty_like(query)
# Query for decode. KV is not needed because it is already cached.
decode_query = query[num_prefill_tokens:]
decode_query = query[num_prefill_query_tokens:]
# QKV for prefill.
query = query[:num_prefill_tokens]
query = query[:num_prefill_query_tokens]
if key is not None and value is not None:
key = key[:num_encoder_tokens]
value = value[:num_encoder_tokens]
key = key[:num_prefill_kv_tokens]
value = value[:num_prefill_kv_tokens]

assert query.shape[0] == num_prefill_tokens
assert decode_query.shape[0] == num_decode_tokens
assert query.shape[0] == num_prefill_query_tokens
assert decode_query.shape[0] == num_decode_query_tokens

if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
@@ -622,8 +543,8 @@ def forward(
# prefix.
out = self._run_memory_efficient_xformers_forward(
query, key, value, prefill_meta, attn_type=attn_type)
assert out.shape == output[:num_prefill_tokens].shape
output[:num_prefill_tokens] = out
assert out.shape == output[:num_prefill_query_tokens].shape
output[:num_prefill_query_tokens] = out
else:
assert attn_type != AttentionType.ENCODER_ONLY, (
"Encoder-only models should not have prefix attention.")
@@ -652,8 +573,8 @@ def forward(
k_scale,
v_scale,
)
assert output[:num_prefill_tokens].shape == out.shape
output[:num_prefill_tokens] = out
assert output[:num_prefill_query_tokens].shape == out.shape
output[:num_prefill_query_tokens] = out

if decode_meta := attn_metadata.decode_metadata:
assert attn_type != AttentionType.ENCODER_ONLY, (
@@ -663,9 +584,9 @@ def forward(
seq_lens_arg,
max_seq_len_arg,
block_tables_arg,
) = _get_seq_len_block_table_args(decode_meta, False, attn_type)
) = get_seq_len_block_table_args(decode_meta, False, attn_type)

output[num_prefill_tokens:] = PagedAttention.forward_decode(
output[num_prefill_query_tokens:] = PagedAttention.forward_decode(
decode_query,
key_cache,
value_cache,
2 changes: 1 addition & 1 deletion vllm/attention/selector.py
@@ -98,7 +98,6 @@ def get_attn_backend(
is_blocksparse: bool = False,
) -> Type[AttentionBackend]:
"""Selects which attention backend to use and lazily imports it."""

if is_blocksparse:
logger.info("Using BlocksparseFlashAttention backend.")
from vllm.attention.backends.blocksparse_attn import (
@@ -108,6 +107,7 @@ def get_attn_backend(
backend = which_attn_to_use(head_size, dtype, kv_cache_dtype, block_size,
is_attention_free)
if backend == _Backend.FLASH_ATTN:
logger.info("Using Flash Attention backend.")
from vllm.attention.backends.flash_attn import ( # noqa: F401
FlashAttentionBackend)
return FlashAttentionBackend
2 changes: 0 additions & 2 deletions vllm/model_executor/models/bart.py
@@ -624,8 +624,6 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
Decoder output torch.Tensor
"""
# retrieve input_ids and inputs_embeds

input_ids = input_ids.view(-1, input_ids.shape[-1])
inputs_embeds = self.embed_tokens(input_ids)

embed_pos = self.embed_positions(
4 changes: 2 additions & 2 deletions vllm/utils.py
@@ -80,8 +80,8 @@
"currently supported with encoder/"
"decoder models.")

STR_NOT_IMPL_ENC_DEC_BACKEND = ("XFormers is the only backend "
"currently supported with encoder/"
STR_NOT_IMPL_ENC_DEC_BACKEND = ("XFormers and Flash-Attention are the only "
"backends currently supported with encoder/"
"decoder models.")

STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER = ("Prompt adapters are not "
35 changes: 25 additions & 10 deletions vllm/worker/enc_dec_model_runner.py
@@ -19,6 +19,7 @@
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.utils import get_architecture_class_name
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalInputs,
MultiModalRegistry)
from vllm.sampling_params import SamplingParams
@@ -36,6 +37,11 @@

logger = init_logger(__name__)

# The Mllama model has PagedAttention-specific logic, because of which
# it can only be run with the XFORMERS backend.
# TODO: Make the Mllama model work with the Flash Attention backend.
_XFORMERS_ONLY_ENCODER_DECODER_ARCHS = ["MllamaForConditionalGeneration"]


@dataclasses.dataclass(frozen=True)
class EncoderDecoderModelInput(ModelInputForGPUWithSamplingMetadata):
@@ -101,9 +107,7 @@ def __init__(
models) but these arguments are present here for compatibility with
the base-class constructor.
'''

self._maybe_force_supported_attention_backend()

self._maybe_force_supported_attention_backend(model_config)
super().__init__(
model_config,
parallel_config,
@@ -119,7 +123,12 @@ def __init__(
# Crash for unsupported encoder/scenarios
assert_enc_dec_mr_supported_scenario(self)

def _maybe_force_supported_attention_backend(self):
def _is_xformers_only_encoder_decoder_model(self,
model: ModelConfig) -> bool:
return get_architecture_class_name(
model) in _XFORMERS_ONLY_ENCODER_DECODER_ARCHS

def _maybe_force_supported_attention_backend(self, model: ModelConfig):
'''
Force vLLM to use the XFormers attention backend,
which is currently the only supported option.
@@ -135,22 +144,26 @@ def raise_backend_err():
is_forced_by_global = maybe_global_forced_backend is not None
is_forced_by_env_var = maybe_env_var_forced_backend is not None

if not (is_forced_by_global or is_forced_by_env_var):
if not (is_forced_by_global or is_forced_by_env_var) \
and self._is_xformers_only_encoder_decoder_model(model):
# The user has not already specified an attention backend
# override
logger.info("EncoderDecoderModelRunner requires "
"XFormers backend; overriding backend "
"auto-selection and forcing XFormers.")
logger.info(
"Encoder-Decoder Model Architecture %s requires XFormers "
"backend; overriding backend auto-selection and "
"forcing XFormers.", get_architecture_class_name(model))
global_force_attn_backend(_Backend.XFORMERS)
elif is_forced_by_global:
# Backend override enforced by global variable takes
# precedence over vLLM backend environment variable.
if maybe_global_forced_backend != _Backend.XFORMERS:
if maybe_global_forced_backend not in\
[_Backend.XFORMERS, _Backend.FLASH_ATTN]:
raise_backend_err()
elif is_forced_by_env_var:
# Backend override enforced by vLLM backend
# environment variable
if maybe_env_var_forced_backend != _Backend.XFORMERS:
if maybe_env_var_forced_backend not in\
[_Backend.XFORMERS, _Backend.FLASH_ATTN]:
raise_backend_err()

def _list_to_int32_tensor(
@@ -532,13 +545,15 @@ def _prepare_encoder_model_input_tensors(
attn_metadata.encoder_seq_lens,
attn_metadata.encoder_seq_lens_tensor,
attn_metadata.max_encoder_seq_len,
attn_metadata.encoder_seq_start_loc,
attn_metadata.cross_slot_mapping,
attn_metadata.cross_block_tables,
) = (
sum(encoder_seq_lens),
encoder_seq_lens,
encoder_seq_lens_tensor,
max_encoder_seq_len,
encoder_seq_start_loc,
cross_slot_mapping_tensor,
cross_block_tables,
)
