[Neuron] Support inference with transformers-neuronx (vllm-project#2569)
liangfu authored Feb 28, 2024
1 parent e46fa5d commit 3b7178c
Showing 18 changed files with 516 additions and 42 deletions.
33 changes: 33 additions & 0 deletions examples/offline_inference_neuron.py
@@ -0,0 +1,33 @@
from vllm import LLM, SamplingParams

# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Create an LLM.
llm = LLM(
    model="openlm-research/open_llama_3b",
    max_num_seqs=8,
    # The max_model_len and block_size arguments must be the same as the max
    # sequence length when targeting the Neuron device. This is currently a
    # known limitation of continuous-batching support in transformers-neuronx.
    # TODO(liangfu): Support paged-attention in transformers-neuronx.
    max_model_len=128,
    block_size=128,
    # The device is detected automatically when the AWS Neuron SDK is
    # installed; the device argument can be left unspecified for automatic
    # detection or assigned explicitly.
    device="neuron")
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
8 changes: 5 additions & 3 deletions tests/lora/conftest.py
@@ -131,9 +131,11 @@ def llama_2_7b_engine_extra_embeddings() -> nn.Module:
     cleanup()
     get_model_old = get_model

-    def get_model_patched(model_config, device_config, lora_config=None):
-        return get_model_old(model_config, device_config,
-                             LoRAConfig(max_loras=4, max_lora_rank=8))
+    def get_model_patched(model_config, device_config, **kwargs):
+        return get_model_old(model_config,
+                             device_config,
+                             lora_config=LoRAConfig(max_loras=4,
+                                                    max_lora_rank=8))

     with patch("vllm.worker.model_runner.get_model", get_model_patched):
         engine = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False)
41 changes: 35 additions & 6 deletions vllm/config.py
@@ -8,7 +8,7 @@

 from vllm.logger import init_logger
 from vllm.transformers_utils.config import get_config
-from vllm.utils import get_cpu_memory, is_hip, get_nvcc_cuda_version
+from vllm.utils import get_cpu_memory, is_hip, is_neuron, get_nvcc_cuda_version

 logger = init_logger(__name__)

@@ -380,13 +380,21 @@ def __init__(
         disable_custom_all_reduce: bool = False,
     ) -> None:
         self.pipeline_parallel_size = pipeline_parallel_size
-        self.tensor_parallel_size = tensor_parallel_size
+        if is_neuron():
+            # For Neuron device support, set TP=1 here so vLLM itself does not
+            # shard the model; transformers-neuronx uses the neuron_tp_degree
+            # attribute to distribute the workload across multiple NeuronCores.
+            self.tensor_parallel_size = 1
+            self.neuron_tp_degree = tensor_parallel_size
+        else:
+            self.tensor_parallel_size = tensor_parallel_size
         self.worker_use_ray = worker_use_ray
         self.max_parallel_loading_workers = max_parallel_loading_workers
         self.disable_custom_all_reduce = disable_custom_all_reduce

-        self.world_size = pipeline_parallel_size * tensor_parallel_size
-        if self.world_size > 1:
+        self.world_size = pipeline_parallel_size * self.tensor_parallel_size
+        # Ray workers are not supported for the Neuron backend.
+        if self.world_size > 1 and not is_neuron():
             self.worker_use_ray = True
         self._verify_args()

@@ -465,8 +473,29 @@ def _verify_args(self) -> None:

 class DeviceConfig:

-    def __init__(self, device: str = "cuda") -> None:
-        self.device = torch.device(device)
+    def __init__(self, device: str = "auto") -> None:
+        if device == "auto":
+            # Automated device type detection
+            if torch.cuda.is_available():
+                self.device_type = "cuda"
+            elif is_neuron():
+                self.device_type = "neuron"
+            else:
+                raise RuntimeError("No supported device detected.")
+        else:
+            # Device type is assigned explicitly
+            self.device_type = device
+
+        # Some device types require processing inputs on CPU
+        if self.device_type in ["neuron"]:
+            self.device = torch.device("cpu")
+        else:
+            # Set device with device type
+            self.device = torch.device(self.device_type)
+
+    @property
+    def is_neuron(self):
+        return self.device_type == "neuron"


 @dataclass
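A minimal usage sketch (not part of the diff) of how the new DeviceConfig resolution above is expected to behave on different hosts; it assumes vllm.config is importable as shown.

from vllm.config import DeviceConfig

cfg = DeviceConfig("auto")
# On a CUDA host: device_type == "cuda" and cfg.device is the CUDA device.
# On a Neuron host without CUDA: device_type == "neuron", but inputs are
# prepared on CPU, so cfg.device is torch.device("cpu").
print(cfg.device_type, cfg.device, cfg.is_neuron)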
16 changes: 7 additions & 9 deletions vllm/engine/arg_utils.py
@@ -44,7 +44,7 @@ class EngineArgs:
     lora_extra_vocab_size: int = 256
     lora_dtype = 'auto'
     max_cpu_loras: Optional[int] = None
-    device: str = 'cuda'
+    device: str = 'auto'

     def __post_init__(self):
         if self.tokenizer is None:
@@ -171,7 +171,7 @@ def add_cli_args(
         parser.add_argument('--block-size',
                             type=int,
                             default=EngineArgs.block_size,
-                            choices=[8, 16, 32],
+                            choices=[8, 16, 32, 128],
                             help='token block size')
         parser.add_argument('--seed',
                             type=int,
@@ -264,13 +264,11 @@ def add_cli_args(
             help=('Maximum number of LoRAs to store in CPU memory. '
                   'Must be >= than max_num_seqs. '
                   'Defaults to max_num_seqs.'))
-        parser.add_argument(
-            "--device",
-            type=str,
-            default=EngineArgs.device,
-            choices=["cuda"],
-            help=('Device type for vLLM execution. '
-                  'Currently, only CUDA-compatible devices are supported.'))
+        parser.add_argument("--device",
+                            type=str,
+                            default=EngineArgs.device,
+                            choices=["auto", "cuda", "neuron"],
+                            help='Device type for vLLM execution.')
         return parser

     @classmethod
21 changes: 18 additions & 3 deletions vllm/engine/llm_engine.py
@@ -3,6 +3,7 @@
 import os
 import time
 import pickle
+import importlib
 from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple,
                     Union)

@@ -20,7 +21,8 @@
                            SequenceGroupOutput, SequenceOutput, SequenceStatus)
 from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
                                                TokenizerGroup)
-from vllm.utils import Counter, set_cuda_visible_devices, get_ip, get_open_port, get_distributed_init_method
+from vllm.utils import (Counter, set_cuda_visible_devices, get_ip,
+                        get_open_port, get_distributed_init_method)

 if ray:
     from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
@@ -31,6 +33,12 @@
 logger = init_logger(__name__)
 _LOCAL_LOGGING_INTERVAL_SEC = 5

+# A map from the device type (in the device config) to its worker module.
+DEVICE_TO_WORKER_MODULE_MAP = {
+    "cuda": "vllm.worker.worker",
+    "neuron": "vllm.worker.neuron_worker",
+}
+
 # If the env var is set, it uses the Ray's compiled DAG API
 # which optimizes the control plane overhead.
 # Run VLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
@@ -138,10 +146,17 @@ def __init__(
     def get_tokenizer_for_seq(self, sequence: Sequence):
         return self.tokenizer.get_lora_tokenizer(sequence.lora_request)

+    def _dispatch_worker(self):
+        worker_module = DEVICE_TO_WORKER_MODULE_MAP[
+            self.device_config.device_type]
+        imported_worker = importlib.import_module(worker_module)
+        Worker = imported_worker.Worker
+        return Worker
+
     def _init_workers(self):
         # Lazy import the Worker to avoid importing torch.cuda/xformers
         # before CUDA_VISIBLE_DEVICES is set in the Worker
-        from vllm.worker.worker import Worker
+        Worker = self._dispatch_worker()

         assert self.parallel_config.world_size == 1, (
             "Ray is required if parallel_config.world_size > 1.")
@@ -243,7 +258,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",

         # Lazy import the Worker to avoid importing torch.cuda/xformers
         # before CUDA_VISIBLE_DEVICES is set in the Worker
-        from vllm.worker.worker import Worker
+        Worker = self._dispatch_worker()

         # Initialize torch distributed process group for the workers.
         model_config = copy.deepcopy(self.model_config)
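The dispatch pattern above also suggests how an additional backend could be wired in: register one worker module per device type and let importlib resolve the Worker class lazily, so unrelated backends are never imported. A hypothetical sketch (not part of this commit; the "cpu" entry, its module path, and the standalone function are invented for illustration):

import importlib

# Hypothetical extension of DEVICE_TO_WORKER_MODULE_MAP; the "cpu" entry and
# its module path are illustrative only and do not exist in this commit.
DEVICE_TO_WORKER_MODULE_MAP = {
    "cuda": "vllm.worker.worker",
    "neuron": "vllm.worker.neuron_worker",
    "cpu": "vllm.worker.cpu_worker",
}

def dispatch_worker(device_type: str):
    # Import the worker module lazily so only the selected backend is loaded.
    worker_module = DEVICE_TO_WORKER_MODULE_MAP[device_type]
    return importlib.import_module(worker_module).Worker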
4 changes: 4 additions & 0 deletions vllm/lora/layers.py
@@ -795,6 +795,10 @@ def __init__(
         self.dtype = dtype
         self.device = device

+    @property
+    def logits_as_hidden_states(self):
+        return self.base_layer.logits_as_hidden_states
+
     @property
     def vocab_size(self):
         return self.base_layer.vocab_size
3 changes: 1 addition & 2 deletions vllm/model_executor/__init__.py
@@ -1,7 +1,6 @@
 from vllm.model_executor.input_metadata import InputMetadata
-from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.model_executor.utils import set_random_seed
+from vllm.model_executor.utils import set_random_seed, get_model

 __all__ = [
     "InputMetadata",
18 changes: 13 additions & 5 deletions vllm/model_executor/layers/sampler.py
@@ -10,6 +10,7 @@
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import (PromptLogprobs, SampleLogprobs, SamplerOutput,
                            SequenceData, SequenceGroupOutput, SequenceOutput)
+from vllm.utils import is_neuron


 class Sampler(nn.Module):
@@ -32,6 +33,8 @@ def __init__(self,
                  org_vocab_size: Optional[int] = None) -> None:
         super().__init__()
         self.vocab_size = vocab_size
+        # Transformers-neuronx generates outputs as logits directly.
+        self.logits_as_hidden_states = is_neuron()
         # original vocabulary size (without LoRA).
         self.org_vocab_size = org_vocab_size or vocab_size

@@ -55,10 +58,14 @@ def forward(
         embedding_bias: Optional[torch.Tensor] = None,
     ) -> Optional[SamplerOutput]:
         # Get the hidden states that we use for sampling.
-        hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)
+        if self.logits_as_hidden_states:
+            logits = hidden_states
+        else:
+            hidden_states = _prune_hidden_states(hidden_states,
+                                                 sampling_metadata)

-        # Get the logits for the next tokens.
-        logits = self._get_logits(hidden_states, embedding, embedding_bias)
+            # Get the logits for the next tokens.
+            logits = self._get_logits(hidden_states, embedding, embedding_bias)

         # Only perform sampling in the driver worker.
         # Note: `_get_logits` is still distributed across TP workers because
@@ -395,7 +402,8 @@ def _sample(
         sample_metadata[sampling_type] = (seq_group_ids, seq_groups,
                                           is_prompts, sample_indices)
         if sampling_type == SamplingType.GREEDY:
-            greedy_samples = torch.argmax(logprobs[sample_indices], dim=-1)
+            greedy_samples = torch.argmax(logprobs[sample_indices.long()],
+                                          dim=-1)
         elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
             max_best_of = 1
             for seq_group, is_prompt in zip(seq_groups, is_prompts):
@@ -407,7 +415,7 @@
                 "generators": sampling_metadata.generators,
             }
             multinomial_samples[sampling_type] = _multinomial(
-                probs[sample_indices], max_best_of, **seeded_args)
+                probs[sample_indices.long()], max_best_of, **seeded_args)
         elif sampling_type == SamplingType.BEAM:
             beam_search_logprobs = logprobs[sample_indices]
         else:
10 changes: 5 additions & 5 deletions vllm/model_executor/model_loader.py
@@ -1,11 +1,11 @@
 """Utilities for selecting and loading models."""
 import contextlib
-from typing import Optional, Type
+from typing import Type

 import torch
 import torch.nn as nn

-from vllm.config import DeviceConfig, ModelConfig, LoRAConfig
+from vllm.config import DeviceConfig, ModelConfig
 from vllm.model_executor.models import ModelRegistry
 from vllm.model_executor.weight_utils import (get_quant_config,
                                               initialize_dummy_weights)
@@ -37,9 +37,9 @@ def _get_model_architecture(model_config: ModelConfig) -> Type[nn.Module]:
         f"Supported architectures: {ModelRegistry.get_supported_archs()}")


-def get_model(model_config: ModelConfig,
-              device_config: DeviceConfig,
-              lora_config: Optional[LoRAConfig] = None) -> nn.Module:
+def get_model(model_config: ModelConfig, device_config: DeviceConfig,
+              **kwargs) -> nn.Module:
+    lora_config = kwargs.get("lora_config", None)
     model_class = _get_model_architecture(model_config)

     # Get the (maybe quantized) linear method.
12 changes: 11 additions & 1 deletion vllm/model_executor/models/__init__.py
@@ -4,7 +4,7 @@
 import torch.nn as nn

 from vllm.logger import init_logger
-from vllm.utils import is_hip
+from vllm.utils import is_hip, is_neuron

 logger = init_logger(__name__)

@@ -61,6 +61,9 @@
     "Sliding window attention is not yet supported in ROCm's flash attention",
 }

+# Models supported by Neuron.
+_NEURON_SUPPORTED_MODELS = {"LlamaForCausalLM": "neuron.llama"}
+

 class ModelRegistry:

@@ -77,8 +80,15 @@ def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
                 logger.warning(
                     f"Model architecture {model_arch} is partially supported "
                     "by ROCm: " + _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch])
+        elif is_neuron():
+            if model_arch not in _NEURON_SUPPORTED_MODELS:
+                raise ValueError(
+                    f"Model architecture {model_arch} is not supported by "
+                    "Neuron for now.")

         module_name, model_cls_name = _MODELS[model_arch]
+        if is_neuron():
+            module_name = _NEURON_SUPPORTED_MODELS[model_arch]
         module = importlib.import_module(
             f"vllm.model_executor.models.{module_name}")
         return getattr(module, model_cls_name, None)
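A usage sketch (not from the diff) of how the registry change above is expected to resolve model classes on a Neuron host: for a supported architecture the module name is redirected to the Neuron-specific implementation, while unsupported architectures raise immediately. The resolved module path is an assumption based on the "neuron.llama" mapping introduced here.

from vllm.model_executor.models import ModelRegistry

# On a Neuron host, this is expected to import
# vllm.model_executor.models.neuron.llama and return its LlamaForCausalLM.
llama_cls = ModelRegistry.load_model_cls("LlamaForCausalLM")

# Architectures outside _NEURON_SUPPORTED_MODELS (e.g. "MistralForCausalLM")
# are expected to raise ValueError on a Neuron host.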
