-
Notifications
You must be signed in to change notification settings - Fork 28.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
c3700b0
commit a43e9e3
Showing
7 changed files
with
239 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
# Copyright 2025 The HuggingFace Team. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
from typing import TYPE_CHECKING | ||
|
||
from ...utils import _LazyModule | ||
from ...utils.import_utils import define_import_structure | ||
|
||
|
||
if TYPE_CHECKING: | ||
from .configuration_deepsek_v2 import * | ||
from .modeling_deepsek_v2 import * | ||
else: | ||
import sys | ||
|
||
_file = globals()["__file__"] | ||
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) |
201 changes: 201 additions & 0 deletions
201
src/transformers/models/deepseek_v2/modular_deepseek_v2.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,201 @@ | ||
# coding=utf-8 | ||
# Copyright 2025 Baidu Inc and The HuggingFace Inc. team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeModel, Qwen2MoeDecoderLayer, Qwen2MoeRMSNorm | ||
from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig | ||
from torch import nn | ||
|
||
|
||
class DeepseekV2Config(Qwen2MoeConfig): | ||
r""" | ||
This is the configuration class to store the configuration of a [`DeepseekV2Model`]. It is used to instantiate an DeepSeek | ||
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the | ||
defaults will yield a similar configuration to that of the DeepSeek-V2. | ||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the | ||
documentation from [`PretrainedConfig`] for more information. | ||
Args: | ||
vocab_size (`int`, *optional*, defaults to 102400): | ||
Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the | ||
`inputs_ids` passed when calling [`DeepseekV2Model`] | ||
hidden_size (`int`, *optional*, defaults to 4096): | ||
Dimension of the hidden representations. | ||
intermediate_size (`int`, *optional*, defaults to 11008): | ||
Dimension of the MLP representations. | ||
moe_intermediate_size (`int`, *optional*, defaults to 1407): | ||
Dimension of the MoE representations. | ||
num_hidden_layers (`int`, *optional*, defaults to 32): | ||
Number of hidden layers in the Transformer decoder. | ||
num_attention_heads (`int`, *optional*, defaults to 32): | ||
Number of attention heads for each attention layer in the Transformer decoder. | ||
n_shared_experts (`int`, *optional*, defaults to None): | ||
Number of shared experts, None means dense model. | ||
n_routed_experts (`int`, *optional*, defaults to None): | ||
Number of routed experts, None means dense model. | ||
routed_scaling_factor (`float`, *optional*, defaults to 1.0): | ||
Scaling factor or routed experts. | ||
topk_method (`str`, *optional*, defaults to `gready`): | ||
Topk method used in routed gate. | ||
n_group (`int`, *optional*, defaults to None): | ||
Number of groups for routed experts. | ||
topk_group (`int`, *optional*, defaults to None): | ||
Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups). | ||
num_experts_per_tok (`int`, *optional*, defaults to None): | ||
Number of selected experts, None means dense model. | ||
moe_layer_freq (`int`, *optional*, defaults to 1): | ||
The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers. | ||
first_k_dense_replace (`int`, *optional*, defaults to 0): | ||
Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head). | ||
\--k dense layers--/ | ||
norm_topk_prob (`bool`, *optional*, defaults to False): | ||
Whether to normalize the weights of the routed experts. | ||
aux_loss_alpha (`float`, *optional*, defaults to 0.001): | ||
Auxiliary loss weight coefficient. | ||
seq_aux = (`bool`, *optional*, defaults to True): | ||
Whether to compute the auxiliary loss for each individual sample. | ||
num_key_value_heads (`int`, *optional*): | ||
This is the number of key_value heads that should be used to implement Grouped Query Attention. If | ||
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if | ||
`num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When | ||
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed | ||
by meanpooling all the original heads within that group. For more details checkout [this | ||
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to | ||
`num_attention_heads`. | ||
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): | ||
The non-linear activation function (function or string) in the decoder. | ||
max_position_embeddings (`int`, *optional*, defaults to 2048): | ||
The maximum sequence length that this model might ever be used with. | ||
initializer_range (`float`, *optional*, defaults to 0.02): | ||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices. | ||
rms_norm_eps (`float`, *optional*, defaults to 1e-06): | ||
The epsilon used by the rms normalization layers. | ||
use_cache (`bool`, *optional*, defaults to `True`): | ||
Whether or not the model should return the last key/values attentions (not used by all models). Only | ||
relevant if `config.is_decoder=True`. | ||
pad_token_id (`int`, *optional*): | ||
Padding token id. | ||
bos_token_id (`int`, *optional*, defaults to 1): | ||
Beginning of stream token id. | ||
eos_token_id (`int`, *optional*, defaults to 2): | ||
End of stream token id. | ||
pretraining_tp (`int`, *optional*, defaults to 1): | ||
Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this | ||
document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is | ||
necessary to ensure exact reproducibility of the pretraining results. Please refer to [this | ||
issue](https://github.com/pytorch/pytorch/issues/76232). | ||
tie_word_embeddings (`bool`, *optional*, defaults to `False`): | ||
Whether to tie weight embeddings | ||
rope_theta (`float`, *optional*, defaults to 10000.0): | ||
The base period of the RoPE embeddings. | ||
rope_scaling (`Dict`, *optional*): | ||
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling | ||
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is | ||
`{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update | ||
`max_position_embeddings` to the expected new maximum. | ||
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): | ||
Whether to use a bias in the query, key, value and output projection layers during self-attention. | ||
attention_dropout (`float`, *optional*, defaults to 0.0): | ||
The dropout ratio for the attention probabilities. | ||
```python | ||
>>> from transformers import DeepseekV2Model, DeepseekV2Config | ||
>>> # Initializing a Deepseek-V2 style configuration | ||
>>> configuration = DeepseekV2Config() | ||
>>> # Accessing the model configuration | ||
>>> configuration = model.config | ||
```""" | ||
|
||
model_type = "deepseek_v2" | ||
keys_to_ignore_at_inference = ["past_key_values"] | ||
|
||
def __init__( | ||
self, | ||
attention_bias=False, | ||
aux_loss_alpha=0.001, | ||
bos_token_id=100000, | ||
eos_token_id=100001, | ||
first_k_dense_replace=0, | ||
kv_lora_rank=512, | ||
moe_layer_freq=1, | ||
n_group=None, | ||
n_routed_experts=None, | ||
n_shared_experts=None, | ||
pretraining_tp=1, | ||
qk_nope_head_dim=128, | ||
qk_rope_head_dim=64, | ||
routed_scaling_factor=1.0, | ||
seq_aux=True, | ||
topk_group=None, | ||
topk_method="greedy", | ||
v_head_dim=128, | ||
**super_kwargs): | ||
super().__init__(**super_kwargs) | ||
|
||
del self.use_sliding_window | ||
del self.sliding_window | ||
del self.max_window_layers | ||
del self.decoder_sparse_step | ||
del self.shared_expert_intermediate_size | ||
del self.output_router_logits | ||
del self.router_aux_loss_coef | ||
del self.mlp_only_layers | ||
|
||
self.attention_bias = attention_bias | ||
self.aux_loss_alpha = aux_loss_alpha | ||
self.bos_token_id = bos_token_id | ||
self.eos_token_id = eos_token_id | ||
self.first_k_dense_replace = first_k_dense_replace | ||
self.kv_lora_rank = kv_lora_rank | ||
self.moe_layer_freq = moe_layer_freq | ||
self.n_group = n_group | ||
self.n_routed_experts = n_routed_experts | ||
self.n_shared_experts = n_shared_experts | ||
self.pretraining_tp = pretraining_tp | ||
self.qk_nope_head_dim = qk_nope_head_dim | ||
self.qk_rope_head_dim = qk_rope_head_dim | ||
self.routed_scaling_factor = routed_scaling_factor | ||
self.seq_aux = seq_aux | ||
self.topk_group = topk_group | ||
self.topk_method = topk_method | ||
self.v_head_dim = v_head_dim | ||
|
||
|
||
class DeepseekV2DecoderLayer(Qwen2MoeDecoderLayer): | ||
pass | ||
|
||
|
||
class DeepseekV2RMSNorm(Qwen2MoeRMSNorm): | ||
pass | ||
|
||
|
||
class DeepseekV2Model(Qwen2MoeModel): | ||
""" | ||
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DeepseekV2DecoderLayer`] | ||
Args: | ||
config: DeepseekV2Config | ||
""" | ||
|
||
def __init__(self, config: DeepseekV2Config): | ||
super().__init__(config) | ||
self.padding_idx = config.pad_token_id | ||
self.vocab_size = config.vocab_size | ||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) | ||
self.layers = nn.ModuleList( | ||
[DeepseekV2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] | ||
) | ||
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" | ||
self.norm = DeepseekV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) | ||
|
||
self.gradient_checkpointing = False | ||
# Initialize weights and apply final processing | ||
self.post_init() |