[ROCm][V1] Add initial ROCm support to V1 (vllm-project#12790)
SageMoore authored Feb 14, 2025
1 parent cbc4012 commit ba59b78
Showing 5 changed files with 236 additions and 18 deletions.
16 changes: 16 additions & 0 deletions requirements-rocm-build.txt
@@ -0,0 +1,16 @@
# Common dependencies
-r requirements-common.txt

--extra-index-url https://download.pytorch.org/whl/rocm6.2
torch==2.5.1
torchvision==0.20.1
torchaudio==2.5.1

cmake>=3.26
ninja
packaging
setuptools>=61
setuptools-scm>=8
wheel
jinja2
amdsmi==6.2.4
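A quick sanity check after installing these pinned wheels (a hedged sketch; it assumes a working ROCm 6.2 driver stack and the ROCm builds of PyTorch from the index above):

import torch

print(torch.__version__)          # e.g. "2.5.1+rocm6.2" when the ROCm wheel is installed
print(torch.version.hip)          # non-None only on ROCm builds of PyTorch
print(torch.cuda.is_available())  # True on ROCm as well, since HIP is exposed through the torch.cuda API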
6 changes: 4 additions & 2 deletions vllm/attention/ops/prefix_prefill.py
@@ -718,7 +718,8 @@ def context_attention_fwd(q,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
alibi_slopes=None,
sliding_window=None):
sliding_window=None,
sm_scale=None):

q_dtype_is_f32 = q.dtype is torch.float32
# need to reduce num. blocks when using fp32
@@ -759,7 +760,8 @@ def context_attention_fwd(q,
# round up Lk to a power of 2 - this is required for Triton block size
Lk_padded = triton.next_power_of_2(Lk)

sm_scale = 1.0 / (Lq**0.5)
if sm_scale is None:
sm_scale = 1.0 / (Lq**0.5)
batch, head = b_seq_len.shape[0], q.shape[1]
num_queries_per_kv = q.shape[1] // k.shape[1]

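The new optional sm_scale argument lets a caller pass its own softmax scale to context_attention_fwd while keeping the previous 1/sqrt(Lq) behavior as the default. A minimal sketch of the fallback (resolve_sm_scale is a hypothetical helper, not part of the diff):

import math

def resolve_sm_scale(sm_scale, Lq):
    # Mirror the fallback above: use the caller's scale when given,
    # otherwise default to 1 / sqrt(query head dimension).
    return 1.0 / math.sqrt(Lq) if sm_scale is None else sm_scale

assert resolve_sm_scale(None, 128) == 1.0 / math.sqrt(128)
assert resolve_sm_scale(0.125, 64) == 0.125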
45 changes: 30 additions & 15 deletions vllm/platforms/rocm.py
@@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0

import os
from functools import lru_cache
from typing import TYPE_CHECKING, Dict, List, Optional

@@ -29,12 +28,6 @@
except ImportError as e:
logger.warning("Failed to import from vllm._rocm_C with %r", e)

if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", None) in ["fork", None]:
logger.warning("`fork` method is not supported by ROCm. "
"VLLM_WORKER_MULTIPROC_METHOD is overridden to"
" `spawn` instead.")
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

# Models not supported by ROCm.
_ROCM_UNSUPPORTED_MODELS: List[str] = []

@@ -84,6 +77,9 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
return "vllm.attention.backends.triton_mla.TritonMLABackend"
selected_backend = (_Backend.ROCM_FLASH if selected_backend
== _Backend.FLASH_ATTN else selected_backend)
if envs.VLLM_USE_V1:
logger.info("Using ROCm Attention backend on V1 engine.")
return "vllm.v1.attention.backends.rocm_attn.ROCmAttentionBackend"
if selected_backend == _Backend.ROCM_FLASH:
if not cls.has_device_capability(90):
# not Instinct series GPUs.
@@ -102,7 +98,11 @@ def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
@classmethod
@lru_cache(maxsize=8)
def get_device_name(cls, device_id: int = 0) -> str:
return torch.cuda.get_device_name(device_id)
# NOTE: When using V1 this function is called when overriding the
# engine args. Calling torch.cuda.get_device_name(device_id) here
# will result in the ROCm context being initialized before other
# processes can be created.
return "AMD"

@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
@@ -129,15 +129,30 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
scheduler_config = vllm_config.scheduler_config
if parallel_config.worker_cls == "auto":
if scheduler_config.is_multi_step:
parallel_config.worker_cls = \
"vllm.worker.multi_step_worker.MultiStepWorker"
if envs.VLLM_USE_V1:
raise NotImplementedError(
"Multi-step scheduling is not supported (and not "
"needed) on VLLM V1. Please launch without "
"--num-scheduler-steps.")
else:
parallel_config.worker_cls = \
"vllm.worker.multi_step_worker.MultiStepWorker"
elif vllm_config.speculative_config:
parallel_config.worker_cls = \
"vllm.spec_decode.spec_decode_worker.create_spec_worker"
parallel_config.sd_worker_cls = \
"vllm.worker.worker.Worker"
if envs.VLLM_USE_V1:
raise NotImplementedError(
"Speculative decoding is not yet supported on VLLM V1."
)
else:
parallel_config.worker_cls = \
"vllm.spec_decode.spec_decode_worker.create_spec_worker"
parallel_config.sd_worker_cls = \
"vllm.worker.worker.Worker"
else:
parallel_config.worker_cls = "vllm.worker.worker.Worker"
if envs.VLLM_USE_V1:
parallel_config.worker_cls = \
"vllm.v1.worker.gpu_worker.Worker"
else:
parallel_config.worker_cls = "vllm.worker.worker.Worker"

@classmethod
def verify_model_arch(cls, model_arch: str) -> None:
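With the platform hooks above in place, a hedged usage sketch for trying the V1 engine on ROCm (the model name is illustrative; VLLM_USE_V1 must be set before vLLM is imported so check_and_update_config picks the V1 GPU worker and get_attn_backend_cls returns the new ROCm backend):

import os
os.environ["VLLM_USE_V1"] = "1"  # opt into the V1 engine before importing vLLM

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # illustrative model choice
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)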
5 changes: 4 additions & 1 deletion vllm/v1/attention/backends/flash_attn.py
@@ -12,8 +12,11 @@
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import get_flash_attn_version
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import cdiv
from vllm.vllm_flash_attn import flash_attn_varlen_func

if current_platform.is_cuda():
from vllm.vllm_flash_attn import flash_attn_varlen_func

logger = init_logger(__name__)

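Guarding the import keeps the CUDA-only vllm_flash_attn extension from being loaded on ROCm, where the new backend below is used instead. A sketch of the same pattern (the None fallback is illustrative, not part of the diff):

from vllm.platforms import current_platform

if current_platform.is_cuda():
    # Only CUDA builds ship the vllm_flash_attn extension; importing it on ROCm would fail.
    from vllm.vllm_flash_attn import flash_attn_varlen_func
else:
    flash_attn_varlen_func = None  # non-CUDA platforms never reach the FlashAttention code path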
182 changes: 182 additions & 0 deletions vllm/v1/attention/backends/rocm_attn.py
@@ -0,0 +1,182 @@
# SPDX-License-Identifier: Apache-2.0
"""Attention layer with PagedAttention on rocm"""
from typing import Any, Dict, List, Optional, Tuple, Type

import torch

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.attention.ops.prefix_prefill import context_attention_fwd
from vllm.logger import init_logger
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata

logger = init_logger(__name__)


class ROCmAttentionBackend(AttentionBackend):

accept_output_buffer: bool = True

@staticmethod
def get_supported_head_sizes() -> List[int]:
return [32, 64, 96, 128, 160, 192, 224, 256]

@staticmethod
def get_name() -> str:
return "ROCM_ATTN_VLLM_V1"

@staticmethod
def get_impl_cls() -> Type["ROCmAttentionImpl"]:
return ROCmAttentionImpl

@staticmethod
def get_metadata_cls() -> Type["AttentionMetadata"]:
return FlashAttentionMetadata

@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.")
return (2, num_blocks, block_size, num_kv_heads, head_size)

@staticmethod
def use_cascade_attention(*args, **kwargs) -> bool:
return False


class ROCmAttentionImpl(AttentionImpl):

def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: AttentionType = AttentionType.DECODER,
) -> None:
if blocksparse_params is not None:
raise ValueError(
"ROCmAttention does not support block-sparse attention.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
if sliding_window is None:
self.sliding_window = (-1, -1)
else:
self.sliding_window = (sliding_window - 1, 0)
self.kv_cache_dtype = kv_cache_dtype

assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads

support_head_sizes = ROCmAttentionBackend.get_supported_head_sizes()
if head_size not in support_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by ROCmAttention. "
f"Supported head sizes are: {support_head_sizes}.")

if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"ROCmAttentionImpl")

def forward(
self,
layer: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: FlashAttentionMetadata,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention.
Args:
query: shape = [num_tokens, num_heads, head_size]
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
assert output is not None, "Output tensor must be provided."

if attn_metadata is None:
# Profiling run.
return output

assert attn_metadata.use_cascade is False

# IMPORTANT!
# NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
# eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
# in this method. For example, `view` and `slice` (or `[:n]`) operations
# are surprisingly slow even in the case they do not invoke any GPU ops.
# Minimize the PyTorch ops in this method as much as possible.
# Whenever making a change in this method, please benchmark the
# performance to make sure it does not introduce any overhead.

num_actual_tokens = attn_metadata.num_actual_tokens
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)

# Reshape the input keys and values and store them in the cache.
PagedAttention.write_to_paged_cache(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)

# TODO(sage): Refactor the context_attention_fwd kernel so that this
# overhead can be removed
context_lens = torch.empty_like(attn_metadata.seq_lens)
batch_size = len(attn_metadata.query_start_loc) - 1
assert len(context_lens) == batch_size
for i in range(batch_size):
query_start = attn_metadata.query_start_loc[i]
query_end = attn_metadata.query_start_loc[i + 1]
context_lens[i] = attn_metadata.seq_lens[i] - (query_end -
query_start)

# Compute attention and update output up to `num_actual_tokens`.
context_attention_fwd(q=query[:num_actual_tokens],
k=key[:num_actual_tokens],
v=value[:num_actual_tokens],
o=output[:num_actual_tokens],
kv_cache_dtype=self.kv_cache_dtype,
k_cache=key_cache,
v_cache=value_cache,
b_loc=attn_metadata.block_table,
b_start_loc=attn_metadata.query_start_loc,
b_seq_len=attn_metadata.seq_lens,
b_ctx_len=context_lens,
max_input_len=attn_metadata.max_query_len,
k_scale=layer._k_scale,
v_scale=layer._v_scale,
alibi_slopes=self.alibi_slopes,
sliding_window=self.sliding_window[0],
sm_scale=self.scale)
return output
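The per-request loop above derives each context length as the total sequence length minus the number of new query tokens scheduled in this step, which is the overhead the TODO refers to. A hedged sketch of the same computation done with tensor ops instead of a Python loop (illustrative only; the committed code keeps the loop):

import torch

def compute_context_lens(seq_lens: torch.Tensor,
                         query_start_loc: torch.Tensor) -> torch.Tensor:
    # query_start_loc has batch_size + 1 entries; adjacent differences give the
    # number of query tokens per request, and context = sequence - new queries.
    query_lens = query_start_loc[1:] - query_start_loc[:-1]
    return seq_lens - query_lens

seq_lens = torch.tensor([10, 7, 5])
query_start_loc = torch.tensor([0, 3, 4, 9])            # 3, 1, and 5 new tokens
print(compute_context_lens(seq_lens, query_start_loc))  # tensor([7, 6, 0])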
