[Encoder Decoder] Update Mllama to run with both FlashAttention and XFormers (vllm-project#9982)

Signed-off-by: Sourashis Roy <[email protected]>
Signed-off-by: Sumit Dubey <[email protected]>
sroy745 authored and sumitd2 committed Nov 14, 2024
1 parent 0bc082c commit 32fe9a7
Showing 5 changed files with 117 additions and 80 deletions.
9 changes: 8 additions & 1 deletion tests/encoder_decoder/test_e2e_correctness.py
@@ -7,7 +7,7 @@
import pytest
from transformers import AutoModelForSeq2SeqLM

from vllm.attention.selector import (_Backend,
from vllm.attention.selector import (_Backend, _cached_get_attn_backend,
global_force_attn_backend_context_manager)
from vllm.platforms import current_platform
from vllm.sequence import SampleLogprobs
@@ -34,6 +34,13 @@ def vllm_to_hf_output(
return output_ids, hf_output_str, out_logprobs


@pytest.fixture(autouse=True)
def clear_cache():
"""Fixture to clear backend cache before each test."""
_cached_get_attn_backend.cache_clear() # Clear the cache
yield # This allows the test to run


@pytest.mark.parametrize("model", ["facebook/bart-large-cnn"])
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
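A note on the fixture added above: the backend selector in vllm.attention.selector memoizes its result, so a backend forced by one parametrized test would otherwise leak into the next. The snippet below is a minimal, self-contained sketch of that interaction under assumed names (pick_backend and _FORCED_BACKEND are hypothetical stand-ins); only the cache_clear() call mirrors what the fixture does.

# Minimal sketch of why the fixture clears the selector cache: a cached
# chooser keeps returning whatever the first parametrized test selected.
# pick_backend and _FORCED_BACKEND are hypothetical stand-ins.
from functools import lru_cache

_FORCED_BACKEND = {"name": "XFORMERS"}

@lru_cache(maxsize=None)
def pick_backend() -> str:
    # Reads global state, so a stale cache entry hides later overrides.
    return _FORCED_BACKEND["name"]

assert pick_backend() == "XFORMERS"
_FORCED_BACKEND["name"] = "FLASH_ATTN"
assert pick_backend() == "XFORMERS"    # stale: the cached result wins
pick_backend.cache_clear()             # what the fixture does per test
assert pick_backend() == "FLASH_ATTN"  # fresh selection after clearing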
100 changes: 63 additions & 37 deletions tests/models/encoder_decoder/vision_language/test_mllama.py
@@ -4,6 +4,8 @@
from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer,
BatchEncoding)

from vllm.attention.selector import (_Backend, _cached_get_attn_backend,
global_force_attn_backend_context_manager)
from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs

@@ -14,6 +16,8 @@

_LIMIT_IMAGE_PER_PROMPT = 3

LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN]

HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"stop_sign":
"<|image|><|begin_of_text|>The meaning of the image is",
@@ -221,6 +225,13 @@ def process(hf_inputs: BatchEncoding, **kwargs):
)


@pytest.fixture(autouse=True)
def clear_cache():
"""Fixture to clear backend cache before each test."""
_cached_get_attn_backend.cache_clear() # Clear the cache
yield # This allows the test to run


@large_gpu_test(min_gb=48)
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
@@ -244,30 +255,37 @@ def process(hf_inputs: BatchEncoding, **kwargs):
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
def test_models_single_leading_image(hf_runner, vllm_runner, image_assets,
model, sizes, dtype, max_tokens,
num_logprobs) -> None:
run_test(
hf_runner,
vllm_runner,
image_assets,
model,
sizes=sizes,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)
num_logprobs,
attn_backend: _Backend) -> None:
with global_force_attn_backend_context_manager(attn_backend):
if attn_backend == _Backend.FLASH_ATTN:
# Flash Attention works only with bfloat16 data-type
dtype = 'bfloat16'
run_test(
hf_runner,
vllm_runner,
image_assets,
model,
sizes=sizes,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)


@large_gpu_test(min_gb=48)
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
def test_models_multi_leading_images(hf_runner, vllm_runner, image_assets,
model, dtype, max_tokens,
num_logprobs) -> None:
model, dtype, max_tokens, num_logprobs,
attn_backend: _Backend) -> None:

stop_sign = image_assets[0].pil_image
cherry_blossom = image_assets[1].pil_image
@@ -291,26 +309,31 @@ def test_models_multi_leading_images(hf_runner, vllm_runner, image_assets,
cherry_blossom.resize((512, 1024)),
],
])]

_run_test(
hf_runner,
vllm_runner,
inputs,
model,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)
with global_force_attn_backend_context_manager(attn_backend):
if attn_backend == _Backend.FLASH_ATTN:
# Flash Attention works only with bfloat16 data-type
dtype = 'bfloat16'
_run_test(
hf_runner,
vllm_runner,
inputs,
model,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)


@large_gpu_test(min_gb=48)
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model,
dtype, max_tokens, num_logprobs) -> None:
dtype, max_tokens, num_logprobs,
attn_backend: _Backend) -> None:

stop_sign = image_assets[0].pil_image
cherry_blossom = image_assets[1].pil_image
@@ -325,14 +348,17 @@ def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model,
[stop_sign],
[stop_sign, cherry_blossom],
])]

_run_test(
hf_runner,
vllm_runner,
inputs,
model,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)
with global_force_attn_backend_context_manager(attn_backend):
if attn_backend == _Backend.FLASH_ATTN:
# Flash Attention works only with bfloat16 data-type
dtype = 'bfloat16'
_run_test(
hf_runner,
vllm_runner,
inputs,
model,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)
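
The three tests above repeat the same guard (force the backend, then fall back to bfloat16 when FlashAttention is selected). As a side note, that guard could be factored into one helper along the lines sketched below; this is only a suggestion against the commit, and forced_backend is a hypothetical name, not something the PR adds.

# Hedged sketch: one helper for the guard the three tests above repeat.
# forced_backend is hypothetical and not part of the commit.
from contextlib import contextmanager

from vllm.attention.selector import (_Backend,
                                     global_force_attn_backend_context_manager)

@contextmanager
def forced_backend(attn_backend: _Backend, dtype: str):
    """Force attn_backend and yield the dtype it can run with."""
    if attn_backend == _Backend.FLASH_ATTN:
        # Flash Attention works only with bfloat16 in these tests.
        dtype = "bfloat16"
    with global_force_attn_backend_context_manager(attn_backend):
        yield dtype

# Usage inside a test body:
#     with forced_backend(attn_backend, dtype) as dtype:
#         run_test(..., dtype=dtype, ...)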
2 changes: 2 additions & 0 deletions tests/test_config.py
@@ -243,6 +243,8 @@ def test_rope_customization():
assert longchat_model_config.max_model_len == 4096


@pytest.mark.skipif(current_platform.is_rocm(),
reason="Encoder Decoder models not supported on ROCm.")
@pytest.mark.parametrize(("model_id", "is_encoder_decoder"), [
("facebook/opt-125m", False),
("facebook/bart-base", True),
52 changes: 38 additions & 14 deletions vllm/model_executor/models/mllama.py
@@ -32,6 +32,8 @@

import vllm.distributed.parallel_state as ps
from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.attention.backends.xformers import XFormersMetadata
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
@@ -799,12 +801,13 @@ def forward(
q = self.q_norm(q)

if attention_mask is not None:
output = self.attention_with_mask(q, k, v, kv_cache,
attention_mask,
kv_range_for_decode,
attn_metadata)
output = self._attention_with_mask(q, k, v, kv_cache,
attention_mask,
kv_range_for_decode,
attn_metadata)
else:
output = self.attn(q,
output = self.attn(q.view(-1,
self.num_local_heads * self.head_dim),
k,
v,
kv_cache,
Expand All @@ -813,7 +816,7 @@ def forward(
out, _ = self.o_proj(output)
return out

def attention_with_mask(
def _attention_with_mask(
self,
q: torch.Tensor,
k: torch.Tensor,
@@ -824,14 +827,35 @@ def attention_with_mask(
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
# Skip writing kv-cache for the initial profiling run.
if len(kv_cache.shape) == 3:
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_local_key_value_heads, self.head_dim)
cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
PagedAttention.write_to_paged_cache(
cached_k, cached_v, key_cache, value_cache,
attn_metadata.cross_slot_mapping, "auto", 1.0, 1.0)
if len(kv_cache.shape) > 1:
if isinstance(attn_metadata, FlashAttentionMetadata):
cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
torch.ops._C_cache_ops.reshape_and_cache_flash(
cached_k,
cached_v,
kv_cache[0],
kv_cache[1],
attn_metadata.
cross_slot_mapping, # type: ignore[union-attr]
"auto",
1.0,
1.0,
)
elif isinstance(attn_metadata, XFormersMetadata):
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_local_key_value_heads, self.head_dim)
cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
PagedAttention.write_to_paged_cache(
cached_k, cached_v, key_cache, value_cache,
attn_metadata.cross_slot_mapping, "auto", 1.0, 1.0)
else:
raise ValueError(
f"Unsupported AttentionMetadata {type(attn_metadata)} "
f"class found. Expected the AttentionMetadata to "
f"be either XFormersMetadata or FlashAttentionMetadata.")

# We have to call torch.sdpa for prefill when using a
# custom cross-attention mask. Because the mask is not a
# standard causal mask, neither a block diagonal mask which
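For context on the branch above: both paths first gather the freshly computed encoder key/value slices named by kv_range_for_decode and then write them into a backend-specific cache layout (FlashAttention indexes a stacked cache as kv_cache[0] / kv_cache[1], while XFormers splits a fused tensor via PagedAttention.split_kv_cache). The snippet below sketches only the shared gather step, self-contained and with an assumed [tokens, num_kv_heads, head_dim] layout; it is not the model code itself.

# Self-contained sketch of the shared gather step in the branch above.
# kv_range_for_decode holds per-sequence (start, end) token ranges; the
# [tokens, num_kv_heads, head_dim] layout used here is an assumption.
import torch

def gather_cached_kv(k, v, kv_range_for_decode):
    cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
    cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
    return cached_k, cached_v

k = torch.randn(8, 4, 16)   # 8 encoder tokens, 4 kv-heads, head_dim 16
v = torch.randn(8, 4, 16)
cached_k, cached_v = gather_cached_kv(k, v, [(0, 3), (5, 7)])
assert cached_k.shape == (5, 4, 16)   # 3 + 2 gathered tokens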
34 changes: 6 additions & 28 deletions vllm/worker/enc_dec_model_runner.py
@@ -9,15 +9,13 @@
AttentionMetadata)
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.attention.selector import (_Backend, get_env_variable_attn_backend,
get_global_forced_attn_backend,
global_force_attn_backend)
from vllm.config import ModelConfig, VllmConfig
get_global_forced_attn_backend)
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.utils import get_architecture_class_name
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
MultiModalRegistry)
from vllm.sampling_params import SamplingParams
@@ -35,11 +33,6 @@

logger = init_logger(__name__)

# The Mllama model has PagedAttention specific logic because of which it
# can only be run with the XFORMERS backend
# TODO Make Mllama model work with Flash Attention backend.
_XFORMERS_ONLY_ENCODER_DECODER_ARCHS = ["MllamaForConditionalGeneration"]


@dataclasses.dataclass(frozen=True)
class EncoderDecoderModelInput(ModelInputForGPUWithSamplingMetadata):
@@ -97,7 +90,7 @@ def __init__(
models) but these arguments are present here for compatibility with
the base-class constructor.
'''
self._maybe_force_supported_attention_backend(vllm_config.model_config)
self._maybe_force_supported_attention_backend()

super().__init__(
vllm_config=vllm_config,
@@ -108,12 +101,7 @@ def __init__(
# Crash for unsupported encoder/scenarios
assert_enc_dec_mr_supported_scenario(self)

def _is_xformers_only_encoder_decoder_model(self,
model: ModelConfig) -> bool:
return get_architecture_class_name(
model) in _XFORMERS_ONLY_ENCODER_DECODER_ARCHS

def _maybe_force_supported_attention_backend(self, model: ModelConfig):
def _maybe_force_supported_attention_backend(self):
'''
Force vLLM to use the XFormers attention backend,
which is currently the only supported option.
@@ -128,23 +116,13 @@ def raise_backend_err():
maybe_global_forced_backend = get_global_forced_attn_backend()
is_forced_by_global = maybe_global_forced_backend is not None
is_forced_by_env_var = maybe_env_var_forced_backend is not None

if not (is_forced_by_global or is_forced_by_env_var) \
and self._is_xformers_only_encoder_decoder_model(model):
# The user has not already specified an attention backend
# override
logger.info(
"Encoder-Decoder Model Architecture %s requires XFormers "
"backend; overriding backend auto-selection and "
"forcing XFormers.", get_architecture_class_name(model))
global_force_attn_backend(_Backend.XFORMERS)
elif is_forced_by_global:
if is_forced_by_global: # noqa: SIM102
# Backend override enforced by global variable takes
# precedence over vLLM backend environment variable.
if maybe_global_forced_backend not in\
[_Backend.XFORMERS, _Backend.FLASH_ATTN]:
raise_backend_err()
elif is_forced_by_env_var:
elif is_forced_by_env_var: # noqa: SIM102
# Backend override enforced by vLLM backend
# environment variable
if maybe_env_var_forced_backend not in\
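With this change the runner no longer forces XFormers for Mllama; it only verifies that a backend forced by the user (through the global override or the environment variable) is one of the two supported options. Below is an illustrative sketch of that check; only _Backend comes from vLLM, the other names are hypothetical, and the real code raises through raise_backend_err().

# Illustrative sketch of the backend validation described above; only
# _Backend is a vLLM name, the rest are hypothetical stand-ins.
from typing import Optional

from vllm.attention.selector import _Backend

SUPPORTED_ENC_DEC_BACKENDS = {_Backend.XFORMERS, _Backend.FLASH_ATTN}

def check_forced_backend(forced: Optional[_Backend]) -> None:
    # None means no override, so vLLM is free to auto-select a backend.
    if forced is not None and forced not in SUPPORTED_ENC_DEC_BACKENDS:
        raise NotImplementedError(
            "Encoder/decoder models currently support only the XFormers "
            "and FlashAttention backends.")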
