From b81fb8a015330eca25a5a3a6d97ef2384777550a Mon Sep 17 00:00:00 2001
From: laishzh
Date: Wed, 12 Jun 2024 18:00:56 +0800
Subject: [PATCH] chore: fix code style issue

---
 vllm/model_executor/models/bert_embedding.py | 32 ++++++++++++++------
 1 file changed, 22 insertions(+), 10 deletions(-)

diff --git a/vllm/model_executor/models/bert_embedding.py b/vllm/model_executor/models/bert_embedding.py
index 9bb87aebf7a2a..93c4068e8b1c2 100644
--- a/vllm/model_executor/models/bert_embedding.py
+++ b/vllm/model_executor/models/bert_embedding.py
@@ -8,8 +8,7 @@
 from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
-from vllm.model_executor.layers.linear import (QKVParallelLinear,
-                                               RowParallelLinear)
+from vllm.model_executor.layers.linear import QKVParallelLinear
 from vllm.model_executor.layers.pooler import Pooler, PoolingType
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
@@ -69,6 +68,7 @@ def _fix_key(key):
             if "gamma" in key:
                 return key.replace("gamma", "weight")
             return key
+
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "query", "q"),
@@ -82,7 +82,7 @@ def _fix_key(key):
             if name.startswith('cls.'):
                 continue
 
-            name = name[len(_prefix) :] if name.startswith(_prefix) else name
+            name = name[len(_prefix):] if name.startswith(_prefix) else name
             name = _fix_key(name)
 
             # use Pooler instead.
@@ -114,6 +114,7 @@ def _fix_key(key):
 
 
 class BertModel(nn.Module):
+
     def __init__(
         self,
         config: BertConfig,
@@ -140,6 +141,7 @@ def forward(
 
 
 class BertEmbedding(nn.Module):
+
     def __init__(self, config: BertConfig):
         super().__init__()
         self.size = config.hidden_size
@@ -151,7 +153,8 @@ def __init__(self, config: BertConfig):
                                                 config.hidden_size)
         self.token_type_embeddings = nn.Embedding(config.type_vocab_size,
                                                   config.hidden_size)
-        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.LayerNorm = nn.LayerNorm(config.hidden_size,
+                                      eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
 
         self.position_embedding_type = config.position_embedding_type
@@ -180,14 +183,16 @@ def forward(
 
         # position embeddings
         if position_ids is None:
-            position_ids = torch.arange(seq_length, dtype=torch.long,
+            position_ids = torch.arange(seq_length,
+                                        dtype=torch.long,
                                         device=device)
             position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
         position_embeddings = self.position_embeddings(position_ids)
 
         # token type embeddings
         if token_type_ids is None:
-            token_type_ids = torch.zeros(input_shape, dtype=torch.long,
+            token_type_ids = torch.zeros(input_shape,
+                                         dtype=torch.long,
                                          device=device)
         token_type_embeddings = self.token_type_embeddings(token_type_ids)
 
@@ -199,6 +204,7 @@ def forward(
 
 
 class BertEncoder(nn.Module):
+
     def __init__(
         self,
         config: BertConfig,
@@ -230,6 +236,7 @@ def forward(
 
 
 class BertLayer(nn.Module):
+
     def __init__(
         self,
         config: BertConfig,
@@ -265,6 +272,7 @@ def feed_forward(self, attention_output):
 
 
 class BertAttention(nn.Module):
+
     def __init__(
         self,
         config: BertConfig,
@@ -319,8 +327,7 @@ def __init__(
             total_num_heads=self.total_num_heads,
             total_num_kv_heads=self.total_num_kv_heads,
             bias=True,
-            quant_config=quant_config
-        )
+            quant_config=quant_config)
 
         self.attn = Attention(
             num_heads=self.num_heads,
@@ -344,10 +351,12 @@ def forward(
 
 
 class BertSelfOutput(nn.Module):
+
     def __init__(self, config: BertConfig):
         super(BertSelfOutput, self).__init__()
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
-        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.LayerNorm = nn.LayerNorm(config.hidden_size,
+                                      eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
 
     def forward(self, hidden_states, input_tensor):
@@ -358,6 +367,7 @@ def forward(self, hidden_states, input_tensor):
 
 
 class BertIntermediate(nn.Module):
+
     def __init__(self, config: BertConfig):
         super().__init__()
         self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
@@ -370,10 +380,12 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
 
 class BertOutput(nn.Module):
+
     def __init__(self, config: BertConfig):
         super().__init__()
         self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
-        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.LayerNorm = nn.LayerNorm(config.hidden_size,
+                                      eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
 
     def forward(