Commit

chore: fix code style issue
laishzh committed Jun 12, 2024
1 parent 25f7783 commit b81fb8a
Showing 1 changed file with 22 additions and 10 deletions.
vllm/model_executor/models/bert_embedding.py (22 additions, 10 deletions)
@@ -8,8 +8,7 @@
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.linear import QKVParallelLinear
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
@@ -69,6 +68,7 @@ def _fix_key(key):
if "gamma" in key:
return key.replace("gamma", "weight")
return key

stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "query", "q"),
@@ -82,7 +82,7 @@ def _fix_key(key):
if name.startswith('cls.'):
continue

name = name[len(_prefix) :] if name.startswith(_prefix) else name
name = name[len(_prefix):] if name.startswith(_prefix) else name
name = _fix_key(name)

# use Pooler instead.
@@ -114,6 +114,7 @@ def _fix_key(key):


class BertModel(nn.Module):

def __init__(
self,
config: BertConfig,
@@ -140,6 +141,7 @@ def forward(


class BertEmbedding(nn.Module):

def __init__(self, config: BertConfig):
super().__init__()
self.size = config.hidden_size
@@ -151,7 +153,8 @@ def __init__(self, config: BertConfig):
config.hidden_size)
self.token_type_embeddings = nn.Embedding(config.type_vocab_size,
config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.LayerNorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)

self.position_embedding_type = config.position_embedding_type
@@ -180,14 +183,16 @@ def forward(

# position embeddings
if position_ids is None:
position_ids = torch.arange(seq_length, dtype=torch.long,
position_ids = torch.arange(seq_length,
dtype=torch.long,
device=device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
position_embeddings = self.position_embeddings(position_ids)

# token type embeddings
if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long,
token_type_ids = torch.zeros(input_shape,
dtype=torch.long,
device=device)
token_type_embeddings = self.token_type_embeddings(token_type_ids)

@@ -199,6 +204,7 @@ def forward(


class BertEncoder(nn.Module):

def __init__(
self,
config: BertConfig,
@@ -230,6 +236,7 @@ def forward(


class BertLayer(nn.Module):

def __init__(
self,
config: BertConfig,
@@ -265,6 +272,7 @@ def feed_forward(self, attention_output):


class BertAttention(nn.Module):

def __init__(
self,
config: BertConfig,
@@ -319,8 +327,7 @@ def __init__(
total_num_heads=self.total_num_heads,
total_num_kv_heads=self.total_num_kv_heads,
bias=True,
quant_config=quant_config
)
quant_config=quant_config)

self.attn = Attention(
num_heads=self.num_heads,
@@ -344,10 +351,12 @@ def forward(


class BertSelfOutput(nn.Module):

def __init__(self, config: BertConfig):
super(BertSelfOutput, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.LayerNorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)

def forward(self, hidden_states, input_tensor):
@@ -358,6 +367,7 @@ def forward(self, hidden_states, input_tensor):


class BertIntermediate(nn.Module):

def __init__(self, config: BertConfig):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
@@ -370,10 +380,12 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:


class BertOutput(nn.Module):

def __init__(self, config: BertConfig):
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.LayerNorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)

def forward(
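The whole diff is mechanical formatting: long constructor calls are wrapped so continuation arguments sit on their own lines, a blank line is added after each class statement, and the stray space in the slice name[len(_prefix) :] is dropped. A minimal standalone sketch of the same before/after pattern, using plain PyTorch rather than the vLLM classes (hidden_size, layer_norm_eps, and _prefix here are illustrative stand-ins for the BertConfig fields and prefix string used in the diff):

import torch
import torch.nn as nn

# Illustrative stand-ins for config.hidden_size / config.layer_norm_eps.
hidden_size = 16
layer_norm_eps = 1e-12

# Before: the whole call on one line (the layout the commit removes).
norm_one_line = nn.LayerNorm(hidden_size, eps=layer_norm_eps)

# After: keyword arguments wrapped onto their own lines, as in the diff.
norm_wrapped = nn.LayerNorm(hidden_size,
                            eps=layer_norm_eps)

# Slice spacing: "name[len(_prefix) :]" becomes "name[len(_prefix):]".
_prefix = "bert."
name = "bert.encoder.layer.0.attention.self.query.weight"
name = name[len(_prefix):] if name.startswith(_prefix) else name

# The two LayerNorm constructions behave identically; only layout changed.
x = torch.randn(2, hidden_size)
assert torch.allclose(norm_one_line(x), norm_wrapped(x))
print(name)  # encoder.layer.0.attention.self.query.weight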