[Encoder Decoder] Update Mllama to run with both FlashAttention and XFormers #9982

Merged
79 commits merged on Nov 12, 2024
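
For context on the change itself: encoder-decoder Mllama was previously pinned to the XFormers backend; after this PR either XFormers or FlashAttention can be used. A minimal sketch of forcing a backend with the same context manager the updated tests use (the model choice, dtype, and prompt are illustrative assumptions, not taken from the PR):

```python
from vllm import LLM, SamplingParams
from vllm.attention.selector import (_Backend,
                                     global_force_attn_backend_context_manager)

# Force FlashAttention for this run; swap in _Backend.XFORMERS to exercise the
# other path. facebook/bart-large-cnn is the encoder-decoder model the PR's
# e2e test parametrizes; any supported enc-dec model should behave the same
# way (my assumption, not a claim made in the PR).
with global_force_attn_backend_context_manager(_Backend.FLASH_ATTN):
    llm = LLM(model="facebook/bart-large-cnn", dtype="bfloat16")
    outputs = llm.generate(["vLLM is a high-throughput inference engine."],
                           SamplingParams(max_tokens=32))
    print(outputs[0].outputs[0].text)
```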
Commits
79 commits
5650b95
Merge pull request #1 from vllm-project/main
sroy745 May 29, 2024
8f36146
Merge branch 'vllm-project:main' into main
sroy745 Jun 3, 2024
9e75057
Merge branch 'vllm-project:main' into main
sroy745 Jun 3, 2024
db2c679
Merge branch 'vllm-project:main' into main
sroy745 Jun 7, 2024
8d7512c
Merge branch 'vllm-project:main' into main
sroy745 Jun 10, 2024
1473f74
Merge branch 'vllm-project:main' into main
sroy745 Jun 12, 2024
4013e1a
Merge branch 'vllm-project:main' into main
sroy745 Jun 14, 2024
2dbdd78
Merge branch 'vllm-project:main' into main
sroy745 Jun 17, 2024
b3575e9
Merge branch 'vllm-project:main' into main
sroy745 Jun 20, 2024
94b0d43
Merge branch 'vllm-project:main' into main
sroy745 Jun 24, 2024
fa8fedf
Merge branch 'vllm-project:main' into main
sroy745 Jun 27, 2024
6ed96b4
Merge branch 'vllm-project:main' into main
sroy745 Jun 27, 2024
b71c533
Merge branch 'vllm-project:main' into main
sroy745 Jun 28, 2024
57babef
Merge branch 'vllm-project:main' into main
sroy745 Jun 29, 2024
4b19bac
Merge branch 'vllm-project:main' into main
sroy745 Jul 1, 2024
eb7a1c4
Merge branch 'vllm-project:main' into main
sroy745 Jul 6, 2024
7e2c87e
Merge branch 'vllm-project:main' into main
sroy745 Jul 10, 2024
6212d5f
Merge branch 'vllm-project:main' into main
sroy745 Jul 15, 2024
5491438
Merge branch 'vllm-project:main' into main
sroy745 Jul 17, 2024
68e080a
Merge branch 'vllm-project:main' into main
sroy745 Jul 31, 2024
55e4332
Merge branch 'vllm-project:main' into main
sroy745 Aug 13, 2024
532eb48
Merge branch 'vllm-project:main' into main
sroy745 Aug 22, 2024
7cea056
Merge branch 'vllm-project:main' into main
sroy745 Aug 22, 2024
185e056
Merge branch 'vllm-project:main' into main
sroy745 Aug 24, 2024
e2be95f
Merge branch 'vllm-project:main' into main
sroy745 Aug 27, 2024
2ed5473
Merge branch 'vllm-project:main' into main
sroy745 Aug 28, 2024
efa4714
Merge branch 'vllm-project:main' into main
sroy745 Aug 29, 2024
fb87d34
Merge branch 'vllm-project:main' into main
sroy745 Aug 29, 2024
5419e49
Merge branch 'vllm-project:main' into main
sroy745 Aug 31, 2024
9ba12f8
Merge branch 'vllm-project:main' into main
sroy745 Sep 2, 2024
25cef3d
Merge branch 'vllm-project:main' into main
sroy745 Sep 3, 2024
9d4cd09
Merge branch 'vllm-project:main' into main
sroy745 Sep 4, 2024
c48cacb
Merge branch 'vllm-project:main' into main
sroy745 Sep 5, 2024
c42c399
Merge branch 'vllm-project:main' into main
sroy745 Sep 7, 2024
3d13e43
Merge branch 'vllm-project:main' into main
sroy745 Sep 9, 2024
7479775
Merge branch 'vllm-project:main' into main
sroy745 Sep 11, 2024
df9b966
Merge branch 'vllm-project:main' into main
sroy745 Sep 17, 2024
9a7ed92
Merge branch 'vllm-project:main' into main
sroy745 Sep 17, 2024
118e838
Merge branch 'vllm-project:main' into main
sroy745 Sep 19, 2024
e640c69
Merge branch 'vllm-project:main' into main
sroy745 Sep 20, 2024
89fb6cd
Merge branch 'vllm-project:main' into main
sroy745 Sep 23, 2024
5d886cc
Merge branch 'vllm-project:main' into main
sroy745 Sep 24, 2024
56f2065
Merge branch 'vllm-project:main' into main
sroy745 Sep 24, 2024
28e103e
Merge branch 'vllm-project:main' into main
sroy745 Sep 25, 2024
2fc1490
Merge branch 'vllm-project:main' into main
sroy745 Sep 25, 2024
8805750
Merge branch 'vllm-project:main' into main
sroy745 Sep 26, 2024
b30e5af
Merge branch 'vllm-project:main' into main
sroy745 Sep 28, 2024
92322f1
Merge branch 'vllm-project:main' into main
sroy745 Sep 30, 2024
85e9001
Merge branch 'vllm-project:main' into main
sroy745 Oct 1, 2024
cd4ff89
Merge branch 'vllm-project:main' into main
sroy745 Oct 1, 2024
0dd96ed
Merge branch 'vllm-project:main' into main
sroy745 Oct 1, 2024
9d4d969
Merge branch 'vllm-project:main' into main
sroy745 Oct 3, 2024
7d223b5
Merge branch 'vllm-project:main' into main
sroy745 Oct 5, 2024
f327d91
Merge branch 'vllm-project:main' into main
sroy745 Oct 5, 2024
b5adf28
Merge branch 'vllm-project:main' into main
sroy745 Oct 6, 2024
caf0d12
Merge branch 'vllm-project:main' into main
sroy745 Oct 7, 2024
28e77b1
Merge branch 'vllm-project:main' into main
sroy745 Oct 8, 2024
db7e46d
Merge branch 'vllm-project:main' into main
sroy745 Oct 9, 2024
59b35f0
Merge branch 'vllm-project:main' into main
sroy745 Oct 17, 2024
dd9affa
Merge branch 'vllm-project:main' into main
sroy745 Oct 17, 2024
f61a15d
Merge branch 'vllm-project:main' into main
sroy745 Oct 21, 2024
0569773
Merge branch 'vllm-project:main' into main
sroy745 Oct 27, 2024
a2090e0
Merge branch 'vllm-project:main' into main
sroy745 Oct 30, 2024
c9a3f00
Merge branch 'vllm-project:main' into main
sroy745 Nov 1, 2024
b59e6a8
Merge branch 'vllm-project:main' into main
sroy745 Nov 3, 2024
79a0138
Run mllama with both xFormers and FlashAttention.
sroy745 Nov 4, 2024
a567667
Remove prints
sroy745 Nov 4, 2024
53c8a72
Debug
sroy745 Nov 6, 2024
36e3abd
Debug lines
sroy745 Nov 6, 2024
e050052
Removing debug
sroy745 Nov 6, 2024
fd9fdff
Merge branch 'vllm-project:main' into main
sroy745 Nov 8, 2024
0b53354
Address comments
sroy745 Nov 8, 2024
838fee9
Merge branch 'sroy-encdec-flash' of https://github.com/sroy745/vllm i…
sroy745 Nov 8, 2024
e25a2ae
Merge branch 'main' into sroy-encdec-flash
sroy745 Nov 11, 2024
366cbf7
Merge branch 'vllm-project:main' into main
sroy745 Nov 11, 2024
a5868ee
Merge remote-tracking branch 'origin/main' into sroy-encdec-flash
sroy745 Nov 11, 2024
b11d637
Dummy
sroy745 Nov 12, 2024
c7df9d6
Dummy
sroy745 Nov 12, 2024
1a63f7a
Fix tests
sroy745 Nov 12, 2024
10 changes: 9 additions & 1 deletion tests/encoder_decoder/test_e2e_correctness.py
@@ -7,7 +7,7 @@
import pytest
from transformers import AutoModelForSeq2SeqLM

from vllm.attention.selector import (_Backend,
from vllm.attention.selector import (_Backend, get_attn_backend,
global_force_attn_backend_context_manager)
from vllm.platforms import current_platform
from vllm.sequence import SampleLogprobs
@@ -34,6 +34,14 @@ def vllm_to_hf_output(
return output_ids, hf_output_str, out_logprobs


@pytest.fixture(autouse=True)
def clear_cache():
"""Fixture to clear backend cache before each test."""
get_attn_backend.cache_clear() # Clear the cache
yield # This allows the test to run
# Optionally, you could also do some cleanup here after the test runs


@pytest.mark.parametrize("model", ["facebook/bart-large-cnn"])
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
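
A note on the clear_cache fixture added above: the test calls get_attn_backend.cache_clear(), which only exists because backend selection is memoized, so without the fixture the backend chosen by the first parametrized test would be reused by tests that force a different backend. A standalone sketch of the same pattern (the cached function below is a stand-in, not vLLM's actual selector):

```python
import functools

import pytest


@functools.lru_cache(maxsize=None)
def pick_backend(name: str) -> str:
    # Stand-in for get_attn_backend: an expensive selection whose result is
    # memoized for the lifetime of the process.
    return f"backend:{name}"


@pytest.fixture(autouse=True)
def clear_backend_cache():
    # Mirrors the PR's fixture: drop memoized results so each parametrized
    # test re-runs selection under the backend it forces.
    pick_backend.cache_clear()
    yield
```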
101 changes: 64 additions & 37 deletions tests/models/encoder_decoder/vision_language/test_mllama.py
@@ -4,6 +4,8 @@
from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer,
BatchEncoding)

from vllm.attention.selector import (_Backend, get_attn_backend,
global_force_attn_backend_context_manager)
from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs

@@ -14,6 +16,8 @@

_LIMIT_IMAGE_PER_PROMPT = 3

LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN]

HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"stop_sign":
"<|image|><|begin_of_text|>The meaning of the image is",
@@ -221,6 +225,14 @@ def process(hf_inputs: BatchEncoding, **kwargs):
)


@pytest.fixture(autouse=True)
def clear_cache():
"""Fixture to clear backend cache before each test."""
get_attn_backend.cache_clear() # Clear the cache
yield # This allows the test to run
# Optionally, you could also do some cleanup here after the test runs


@large_gpu_test(min_gb=48)
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
@@ -244,30 +256,37 @@ def process(hf_inputs: BatchEncoding, **kwargs):
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
def test_models_single_leading_image(hf_runner, vllm_runner, image_assets,
model, sizes, dtype, max_tokens,
num_logprobs) -> None:
run_test(
hf_runner,
vllm_runner,
image_assets,
model,
sizes=sizes,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)
num_logprobs,
attn_backend: _Backend) -> None:
with global_force_attn_backend_context_manager(attn_backend):
if attn_backend == _Backend.FLASH_ATTN:
# Flash Attention works only with bfloat16 data-type
dtype = 'bfloat16'
run_test(
hf_runner,
vllm_runner,
image_assets,
model,
sizes=sizes,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)


@large_gpu_test(min_gb=48)
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
def test_models_multi_leading_images(hf_runner, vllm_runner, image_assets,
model, dtype, max_tokens,
num_logprobs) -> None:
model, dtype, max_tokens, num_logprobs,
attn_backend: _Backend) -> None:

stop_sign = image_assets[0].pil_image
cherry_blossom = image_assets[1].pil_image
@@ -291,26 +310,31 @@ def test_models_multi_leading_images(hf_runner, vllm_runner, image_assets,
cherry_blossom.resize((512, 1024)),
],
])]

_run_test(
hf_runner,
vllm_runner,
inputs,
model,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)
with global_force_attn_backend_context_manager(attn_backend):
if attn_backend == _Backend.FLASH_ATTN:
# Flash Attention works only with bfloat16 data-type
dtype = 'bfloat16'
_run_test(
hf_runner,
vllm_runner,
inputs,
model,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)


@large_gpu_test(min_gb=48)
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model,
dtype, max_tokens, num_logprobs) -> None:
dtype, max_tokens, num_logprobs,
attn_backend: _Backend) -> None:

stop_sign = image_assets[0].pil_image
cherry_blossom = image_assets[1].pil_image
@@ -325,14 +349,17 @@ def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model,
[stop_sign],
[stop_sign, cherry_blossom],
])]

_run_test(
hf_runner,
vllm_runner,
inputs,
model,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)
with global_force_attn_backend_context_manager(attn_backend):
if attn_backend == _Backend.FLASH_ATTN:
# Flash Attention works only with bfloat16 data-type
dtype = 'bfloat16'
_run_test(
hf_runner,
vllm_runner,
inputs,
model,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)
53 changes: 39 additions & 14 deletions vllm/model_executor/models/mllama.py
@@ -33,6 +33,8 @@

import vllm.distributed.parallel_state as ps
from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.attention.backends.xformers import XFormersMetadata
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.config import CacheConfig, MultiModalConfig
from vllm.distributed import get_tensor_model_parallel_world_size
@@ -756,17 +758,19 @@ def forward(
dim=-1)
k = k.view(-1, self.num_local_key_value_heads, self.head_dim)
v = v.view(-1, self.num_local_key_value_heads, self.head_dim)

k = self.k_norm(k)
q = q.view(-1, self.num_local_heads, self.head_dim)
q = self.q_norm(q)

if attention_mask is not None:
output = self.attention_with_mask(q, k, v, kv_cache,
attention_mask,
kv_range_for_decode,
attn_metadata)
output = self._attention_with_mask(q, k, v, kv_cache,
attention_mask,
kv_range_for_decode,
attn_metadata)
else:
output = self.attn(q,
output = self.attn(q.view(-1,
self.num_local_heads * self.head_dim),
k,
v,
kv_cache,
Expand All @@ -775,7 +779,7 @@ def forward(
out, _ = self.o_proj(output)
return out

def attention_with_mask(
def _attention_with_mask(
self,
q: torch.Tensor,
k: torch.Tensor,
@@ -786,14 +790,35 @@ def attention_with_mask(
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
# Skip writing kv-cache for the initial profiling run.
if len(kv_cache.shape) == 3:
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_local_key_value_heads, self.head_dim)
cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
PagedAttention.write_to_paged_cache(
cached_k, cached_v, key_cache, value_cache,
attn_metadata.cross_slot_mapping, "auto", 1.0, 1.0)
if len(kv_cache.shape) > 1:
if isinstance(attn_metadata, FlashAttentionMetadata):
cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
torch.ops._C_cache_ops.reshape_and_cache_flash(
cached_k,
cached_v,
kv_cache[0],
kv_cache[1],
attn_metadata.
cross_slot_mapping, # type: ignore[union-attr]
"auto",
1.0,
1.0,
)
elif isinstance(attn_metadata, XFormersMetadata):
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_local_key_value_heads, self.head_dim)
cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
PagedAttention.write_to_paged_cache(
cached_k, cached_v, key_cache, value_cache,
attn_metadata.cross_slot_mapping, "auto", 1.0, 1.0)
else:
raise ValueError(
f"Unsupported AttentionMetadata {type(attn_metadata)} "
f"class found. Expected the AttentionMetadata to "
f"be either XFormersMetadata or FlashAttentionMetadata.")

# We have to call torch.sdpa for prefill when using a
# custom cross-attention mask. Because the mask is not a
# standard causal mask, neither a block diagonal mask which
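
The new _attention_with_mask dispatches on the metadata type because the two backends lay out the paged KV cache differently: the FlashAttention path writes keys and values into kv_cache[0] and kv_cache[1] via reshape_and_cache_flash, while the XFormers path first splits a fused cache with PagedAttention.split_kv_cache and then uses write_to_paged_cache. A toy sketch of what the flash-style write amounts to; the shapes and the slot-to-(block, offset) factoring are assumptions about vLLM's layout, not code from this PR:

```python
import torch

# Illustrative shapes only (assumed, not taken from the diff).
num_blocks, block_size, num_kv_heads, head_dim = 4, 16, 8, 64
num_tokens = 3

# Flash-style cache: index 0 holds the key cache, index 1 the value cache.
kv_cache = torch.zeros(2, num_blocks, block_size, num_kv_heads, head_dim)

k = torch.randn(num_tokens, num_kv_heads, head_dim)
v = torch.randn(num_tokens, num_kv_heads, head_dim)
# cross_slot_mapping analogue: one flat slot index per encoder token.
slot_mapping = torch.tensor([0, 1, 17])

# Roughly what reshape_and_cache_flash does: scatter each token's K/V into
# its (block, offset) slot.
blocks, offsets = slot_mapping // block_size, slot_mapping % block_size
kv_cache[0, blocks, offsets] = k
kv_cache[1, blocks, offsets] = v
```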
47 changes: 9 additions & 38 deletions vllm/worker/enc_dec_model_runner.py
@@ -9,15 +9,13 @@
AttentionMetadata)
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.attention.selector import (_Backend, get_env_variable_attn_backend,
get_global_forced_attn_backend,
global_force_attn_backend)
from vllm.config import ModelConfig, VllmConfig
get_global_forced_attn_backend)
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.utils import get_architecture_class_name
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalInputs,
MultiModalRegistry)
from vllm.sampling_params import SamplingParams
@@ -35,11 +33,6 @@

logger = init_logger(__name__)

# The Mllama model has PagedAttention specific logic because of which it
# can only be run with the XFORMERS backend
# TODO Make Mllama model work with Flash Attention backend.
_XFORMERS_ONLY_ENCODER_DECODER_ARCHS = ["MllamaForConditionalGeneration"]


@dataclasses.dataclass(frozen=True)
class EncoderDecoderModelInput(ModelInputForGPUWithSamplingMetadata):
@@ -97,7 +90,7 @@ def __init__(
models) but these arguments are present here for compatibility with
the base-class constructor.
'''
self._maybe_force_supported_attention_backend(vllm_config.model_config)
self._maybe_force_supported_attention_backend()

super().__init__(
vllm_config=vllm_config,
Expand All @@ -108,12 +101,7 @@ def __init__(
# Crash for unsupported encoder/scenarios
assert_enc_dec_mr_supported_scenario(self)

def _is_xformers_only_encoder_decoder_model(self,
model: ModelConfig) -> bool:
return get_architecture_class_name(
model) in _XFORMERS_ONLY_ENCODER_DECODER_ARCHS

def _maybe_force_supported_attention_backend(self, model: ModelConfig):
def _maybe_force_supported_attention_backend(self):
'''
Force vLLM to use the XFormers attention backend,
which is currently the only supported option.
@@ -128,28 +116,11 @@ def raise_backend_err():
maybe_global_forced_backend = get_global_forced_attn_backend()
is_forced_by_global = maybe_global_forced_backend is not None
is_forced_by_env_var = maybe_env_var_forced_backend is not None

if not (is_forced_by_global or is_forced_by_env_var) \
and self._is_xformers_only_encoder_decoder_model(model):
# The user has not already specified an attention backend
# override
logger.info(
"Encoder-Decoder Model Architecture %s requires XFormers "
"backend; overriding backend auto-selection and "
"forcing XFormers.", get_architecture_class_name(model))
global_force_attn_backend(_Backend.XFORMERS)
elif is_forced_by_global:
# Backend override enforced by global variable takes
# precedence over vLLM backend environment variable.
if maybe_global_forced_backend not in\
[_Backend.XFORMERS, _Backend.FLASH_ATTN]:
raise_backend_err()
elif is_forced_by_env_var:
# Backend override enforced by vLLM backend
# environment variable
if maybe_env_var_forced_backend not in\
[_Backend.XFORMERS, _Backend.FLASH_ATTN]:
raise_backend_err()
if (is_forced_by_global and maybe_global_forced_backend \
not in [_Backend.XFORMERS, _Backend.FLASH_ATTN]) or \
(is_forced_by_env_var and maybe_env_var_forced_backend \
not in [_Backend.XFORMERS, _Backend.FLASH_ATTN]):
raise_backend_err()

def _list_to_int32_tensor(
self,
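
With the Mllama-specific XFormers override removed, the runner now only validates that any forced backend is XFormers or FlashAttention. The usual environment-variable override should therefore work for either backend without code changes; a sketch, assuming the standard VLLM_ATTENTION_BACKEND setting (the model name is a placeholder):

```python
import os

# Must be set before vLLM selects an attention backend; "XFORMERS" also works.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN"

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/bart-large-cnn", dtype="bfloat16")
outputs = llm.generate(["An encoder-decoder smoke test."],
                       SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```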