fix registry, determine max_batch_size in MambaCache
Signed-off-by: Tyler Michael Smith <[email protected]>
tlrmchlsmth committed Feb 14, 2025
1 parent 9cfd012 commit 201696b
Showing 7 changed files with 27 additions and 41 deletions.
14 changes: 8 additions & 6 deletions tests/models/decoder_only/language/test_mamba.py
@@ -180,20 +180,22 @@ def test_parallel_sampling(
     max_tokens: int,
 ) -> None:
 
+    # Numerical differences produce slightly different output for these
+    if 'state-spaces' in model:
+        example_prompts.pop(0)
+        example_prompts.pop(0)
+        example_prompts.pop(0)
+
     with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model:
         for_loop_outputs = []
         for _ in range(10):
             for_loop_outputs.append(
-                # using example_prompts index 1 instead of 0 since with 0 the
-                # logprobs get really close and the test doesn't pass
-                vllm_model.generate_greedy([example_prompts[1]], max_tokens)
-                [0])
+                vllm_model.generate_greedy(example_prompts, max_tokens)[0])
         sampling_params = SamplingParams(n=10,
                                          temperature=0.001,
                                          seed=0,
                                          max_tokens=max_tokens)
-        n_lt_1_outputs = vllm_model.generate([example_prompts[1]],
-                                             sampling_params)
+        n_lt_1_outputs = vllm_model.generate(example_prompts, sampling_params)
         token_ids, texts = n_lt_1_outputs[0]
         n_lt_1_outputs = [(token_id, text)
                           for token_id, text in zip(token_ids, texts)]
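For context, a toy sketch of the property this test now exercises over all of example_prompts (the comparison between the two output lists happens in the portion of the test hidden below this hunk). The toy_greedy and toy_sample_n functions are hypothetical stand-ins, not vLLM APIs; the real test drives vllm_runner, generate_greedy, and SamplingParams.

# Toy sketch (hypothetical stand-ins, not vLLM APIs) of the invariant behind
# test_parallel_sampling: sampling n outputs at near-zero temperature should
# match n independent greedy generations of the same prompt.
from typing import List


def toy_greedy(prompt: str, max_tokens: int) -> str:
    # Deterministic stand-in for vllm_model.generate_greedy(...).
    return prompt + " " + " ".join(f"tok{i}" for i in range(max_tokens))


def toy_sample_n(prompt: str, n: int, temperature: float,
                 max_tokens: int) -> List[str]:
    # Stand-in for vllm_model.generate(...) with SamplingParams(n=n, ...):
    # with temperature ~0, every sample collapses onto the greedy output.
    assert temperature <= 0.001
    return [toy_greedy(prompt, max_tokens) for _ in range(n)]


prompt, max_tokens = "Hello", 4
for_loop_outputs = [toy_greedy(prompt, max_tokens) for _ in range(10)]
parallel_outputs = toy_sample_n(prompt, n=10, temperature=0.001,
                                max_tokens=max_tokens)
assert for_loop_outputs == parallel_outputs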
1 change: 1 addition & 0 deletions tests/models/registry.py
@@ -145,6 +145,7 @@ def check_available_online(
     "LLaMAForCausalLM": _HfExamplesInfo("decapoda-research/llama-7b-hf",
                                         is_available_online=False),
     "MambaForCausalLM": _HfExamplesInfo("state-spaces/mamba-130m-hf"),
+    "Mamba2ForCausalLM": _HfExamplesInfo("mistralai/Mamba-Codestral-7B-v0.1"),
     "FalconMambaForCausalLM": _HfExamplesInfo("tiiuae/falcon-mamba-7b-instruct"),  # noqa: E501
     "MiniCPMForCausalLM": _HfExamplesInfo("openbmb/MiniCPM-2B-sft-bf16",
                                           trust_remote_code=True),
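The "registry" half of the commit title is this one-line addition: Mamba2ForCausalLM gets an example Hugging Face checkpoint (mistralai/Mamba-Codestral-7B-v0.1) so the availability check can cover it. A pared-down sketch of the pattern follows; HfExamplesInfoSketch is an illustrative stand-in, not the real _HfExamplesInfo from tests/models/registry.py, which carries more fields.

# Pared-down sketch of the example-model registry (illustrative only).
from dataclasses import dataclass
from typing import Dict


@dataclass
class HfExamplesInfoSketch:
    default: str                      # example HF checkpoint for the architecture
    is_available_online: bool = True  # False if the checkpoint can no longer be downloaded
    trust_remote_code: bool = False   # whether loading needs remote code


_EXAMPLE_MODELS: Dict[str, HfExamplesInfoSketch] = {
    "MambaForCausalLM": HfExamplesInfoSketch("state-spaces/mamba-130m-hf"),
    # Added by this commit:
    "Mamba2ForCausalLM": HfExamplesInfoSketch("mistralai/Mamba-Codestral-7B-v0.1"),
}

print(_EXAMPLE_MODELS["Mamba2ForCausalLM"].default)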
10 changes: 2 additions & 8 deletions vllm/model_executor/models/bamba.py
@@ -440,12 +440,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
-        # Determine max batch size to set size of MambaCache
-        self.max_batch_size = self.scheduler_config.max_num_seqs
-        if not self.model_config.enforce_eager:
-            self.max_batch_size = vllm_config.pad_for_cudagraph(
-                self.max_batch_size)
-
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
 
@@ -463,8 +457,8 @@ def forward(self,
             self.vllm_config.parallel_config, LayerBlockType.mamba)
 
         self.mamba_cache = MambaCacheManager(
-            self.lm_head.weight.dtype, num_mamba_layers,
-            self.max_batch_size, *self._get_mamba_cache_shape())
+            self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
+            *self._get_mamba_cache_shape())
         mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    attn_metadata, mamba_cache_params,
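The same two hunks are repeated in jamba.py, mamba.py, and mamba2.py below; only the receiver (self.model vs. self.backbone) differs. A hedged summary of the call-site migration, paraphrased from these hunks:

# Before: each model precomputed and stored a padded batch size in __init__,
# then forwarded it as an integer:
#
#     self.mamba_cache = MambaCacheManager(
#         self.lm_head.weight.dtype, num_mamba_layers,
#         self.max_batch_size, *self._get_mamba_cache_shape())
#
# After: the model hands over its VllmConfig and the manager derives the
# padded batch size itself (see the mamba_cache.py hunk at the end):
#
#     self.mamba_cache = MambaCacheManager(
#         self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
#         *self._get_mamba_cache_shape())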
10 changes: 2 additions & 8 deletions vllm/model_executor/models/jamba.py
@@ -427,12 +427,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
-        # Determine max batch size to set size of MambaCache
-        self.max_batch_size = self.scheduler_config.max_num_seqs
-        if not self.model_config.enforce_eager:
-            self.max_batch_size = vllm_config.pad_for_cudagraph(
-                self.max_batch_size)
-
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
 
@@ -448,8 +442,8 @@ def forward(self,
         num_mamba_layers = self.model_config.get_num_layers_by_block_type(
             self.vllm_config.parallel_config, LayerBlockType.mamba)
         self.mamba_cache = MambaCacheManager(
-            self.lm_head.weight.dtype, num_mamba_layers,
-            self.max_batch_size, *self._get_mamba_cache_shape())
+            self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
+            *self._get_mamba_cache_shape())
 
         mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
 
10 changes: 2 additions & 8 deletions vllm/model_executor/models/mamba.py
@@ -202,12 +202,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.backbone.make_empty_intermediate_tensors)
 
-        # Determine max batch size to set size of MambaCache
-        self.max_batch_size = self.scheduler_config.max_num_seqs
-        if not self.model_config.enforce_eager:
-            self.max_batch_size = vllm_config.pad_for_cudagraph(
-                self.max_batch_size)
-
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.backbone.get_input_embeddings(input_ids)
 
@@ -223,8 +217,8 @@ def forward(self,
         num_mamba_layers = self.model_config.get_num_layers_by_block_type(
             self.vllm_config.parallel_config, LayerBlockType.mamba)
         self.mamba_cache = MambaCacheManager(
-            self.lm_head.weight.dtype, num_mamba_layers,
-            self.max_batch_size, *self._get_mamba_cache_shape())
+            self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
+            *self._get_mamba_cache_shape())
 
         mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
 
10 changes: 2 additions & 8 deletions vllm/model_executor/models/mamba2.py
@@ -218,12 +218,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.backbone.make_empty_intermediate_tensors)
 
-        # Determine max batch size to set size of MambaCache
-        self.max_batch_size = self.scheduler_config.max_num_seqs
-        if not self.model_config.enforce_eager:
-            self.max_batch_size = vllm_config.pad_for_cudagraph(
-                self.max_batch_size)
-
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.backbone.get_input_embeddings(input_ids)
 
@@ -239,8 +233,8 @@ def forward(self,
         num_mamba_layers = self.model_config.get_num_layers_by_block_type(
             self.vllm_config.parallel_config, LayerBlockType.mamba)
         self.mamba_cache = MambaCacheManager(
-            self.lm_head.weight.dtype, num_mamba_layers,
-            self.max_batch_size, *self._get_mamba_cache_shape())
+            self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
+            *self._get_mamba_cache_shape())
 
         mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
 
13 changes: 10 additions & 3 deletions vllm/model_executor/models/mamba_cache.py
@@ -1,11 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from dataclasses import dataclass
-from typing import Dict, List
+from typing import Dict, List, Tuple
 
 import torch
 
 from vllm.attention.backends.utils import PAD_SLOT_ID
+from vllm.config import VllmConfig
 
 
 @dataclass
@@ -22,8 +23,14 @@ def at_layer_idx(self, layer_idx):
 
 class MambaCacheManager:
 
-    def __init__(self, dtype, num_mamba_layers, max_batch_size,
-                 conv_state_shape, temporal_state_shape):
+    def __init__(self, vllm_config: VllmConfig, dtype: torch.dtype,
+                 num_mamba_layers: int, conv_state_shape: Tuple[int, int],
+                 temporal_state_shape: Tuple[int, int]):
+
+        # Determine max batch size to set size of MambaCache
+        max_batch_size = vllm_config.scheduler_config.max_num_seqs
+        if not vllm_config.model_config.enforce_eager:
+            max_batch_size = vllm_config.pad_for_cudagraph(max_batch_size)
 
         conv_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
                                  conv_state_shape,
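To make the new ownership concrete, here is a minimal standalone sketch (not vLLM code) of the pattern the constructor now follows: the cache manager derives its batch dimension from the scheduler's max_num_seqs and, unless eager mode is forced, pads it up to a CUDA-graph capture size. ToyConfig, ToyMambaCache, and the pad_for_cudagraph helper below are hypothetical stand-ins for VllmConfig, MambaCacheManager, and VllmConfig.pad_for_cudagraph; the padding rule shown is illustrative only.

# Standalone sketch only; names and the padding rule are illustrative.
from dataclasses import dataclass
from typing import Tuple

import torch


def pad_for_cudagraph(batch_size: int,
                      capture_sizes=(1, 2, 4, 8, 16, 32, 64, 128, 256)) -> int:
    # Hypothetical stand-in: round up to the smallest captured graph size.
    return next(s for s in capture_sizes if s >= batch_size)


@dataclass
class ToyConfig:
    max_num_seqs: int     # plays the role of scheduler_config.max_num_seqs
    enforce_eager: bool   # plays the role of model_config.enforce_eager


class ToyMambaCache:
    def __init__(self, cfg: ToyConfig, dtype: torch.dtype, num_layers: int,
                 conv_state_shape: Tuple[int, int],
                 temporal_state_shape: Tuple[int, int]):
        # The manager, not the caller, decides the batch dimension.
        max_batch_size = cfg.max_num_seqs
        if not cfg.enforce_eager:
            max_batch_size = pad_for_cudagraph(max_batch_size)

        self.conv_state = torch.empty(
            (num_layers, max_batch_size) + conv_state_shape, dtype=dtype)
        self.temporal_state = torch.empty(
            (num_layers, max_batch_size) + temporal_state_shape, dtype=dtype)


cache = ToyMambaCache(ToyConfig(max_num_seqs=24, enforce_eager=False),
                      torch.float16, num_layers=2,
                      conv_state_shape=(8, 4), temporal_state_shape=(8, 16))
print(cache.conv_state.shape)  # torch.Size([2, 32, 8, 4])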
