Support for LLaMA (#841)
* llama

* spm tokenizer

* pipeline

* llama to neox conversion script

* llama checkin

* weights script update and pp reversion

* revert for PR

* configs

* 7B-specific tweak

* LLaMA updates

* PR feedback

* initialize multiple_of

---------

Co-authored-by: Quentin-Anthony <[email protected]>
zphang and Quentin-Anthony authored May 2, 2023
1 parent 586f514 commit 299b68c
Showing 11 changed files with 731 additions and 22 deletions.
26 changes: 26 additions & 0 deletions configs/llama/13B.yml
@@ -0,0 +1,26 @@
{
"pipe-parallel-size": 1,
"model-parallel-size": 2,
"make_vocab_size_divisible_by": 1,

# model settings
"num-layers": 40,
"hidden-size": 5120,
"num-attention-heads": 40,
"seq-length": 2048,
"max-position-embeddings": 2048,
"pos-emb": "rotary",
"rotary-pct": 1,
"no-weight-tying": true,
"gpt-j-residual": false,
"output-layer-parallelism": "column",
"norm": "rmsnorm",
"rms_norm_epsilon": 1.0e-6,

"scaled-upper-triang-masked-softmax-fusion": true,
"bias-gelu-fusion": false,
"use_bias_in_norms": false,
"use_bias_in_attn_linear": false,
"mlp_type": "llama",
"activation": "silu",
}
26 changes: 26 additions & 0 deletions configs/llama/30B.yml
@@ -0,0 +1,26 @@
{
"pipe-parallel-size": 1,
"model-parallel-size": 4,
"make_vocab_size_divisible_by": 1,

# model settings
"num-layers": 60,
"hidden-size": 6656,
"num-attention-heads": 52,
"seq-length": 2048,
"max-position-embeddings": 2048,
"pos-emb": "rotary",
"rotary-pct": 1,
"no-weight-tying": true,
"gpt-j-residual": false,
"output-layer-parallelism": "column",
"norm": "rmsnorm",
"rms_norm_epsilon": 1.0e-6,

"scaled-upper-triang-masked-softmax-fusion": true,
"bias-gelu-fusion": false,
"use_bias_in_norms": false,
"use_bias_in_attn_linear": false,
"mlp_type": "llama",
"activation": "silu",
}
26 changes: 26 additions & 0 deletions configs/llama/65B.yml
@@ -0,0 +1,26 @@
{
"pipe-parallel-size": 1,
"model-parallel-size": 8,
"make_vocab_size_divisible_by": 1,

# model settings
"num-layers": 80,
"hidden-size": 8192,
"num-attention-heads": 64,
"seq-length": 2048,
"max-position-embeddings": 2048,
"pos-emb": "rotary",
"rotary-pct": 1,
"no-weight-tying": true,
"gpt-j-residual": false,
"output-layer-parallelism": "column",
"norm": "rmsnorm",
"rms_norm_epsilon": 1.0e-6,

"scaled-upper-triang-masked-softmax-fusion": true,
"bias-gelu-fusion": false,
"use_bias_in_norms": false,
"use_bias_in_attn_linear": false,
"mlp_type": "llama",
"activation": "silu",
}
26 changes: 26 additions & 0 deletions configs/llama/7B.yml
@@ -0,0 +1,26 @@
{
"pipe-parallel-size": 1,
"model-parallel-size": 1,
"make_vocab_size_divisible_by": 1,

# model settings
"num-layers": 32,
"hidden-size": 4096,
"num-attention-heads": 32,
"seq-length": 2048,
"max-position-embeddings": 2048,
"pos-emb": "rotary",
"rotary-pct": 1,
"no-weight-tying": true,
"gpt-j-residual": false,
"output-layer-parallelism": "column",
"norm": "rmsnorm",
"rms_norm_epsilon": 1.0e-6,

"scaled-upper-triang-masked-softmax-fusion": true,
"bias-gelu-fusion": false,
"use_bias_in_norms": false,
"use_bias_in_attn_linear": false,
"mlp_type": "llama",
"activation": "silu",
}
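As a quick sanity check on these settings, the 7B config reproduces LLaMA-7B's parameter count. The sketch below is an editor's illustration in plain Python, not part of the commit; the 32,000-token SentencePiece vocabulary and the 11008 feed-forward width are assumptions taken from the released LLaMA models (the width also follows from the multiple_of rounding in LLaMAParallelMLP further down).

# Rough parameter count for configs/llama/7B.yml (RMSNorm gains ignored).
hidden, layers = 4096, 32         # from the config above
vocab, ff_dim = 32000, 11008      # assumed: LLaMA's SentencePiece vocab size and MLP width
attn = 4 * hidden * hidden        # Q, K, V and output projections, no biases
mlp = 3 * hidden * ff_dim         # w1, w2, w3 of the gated MLP, no biases
embeddings = 2 * vocab * hidden   # untied input embedding and output projection
total = layers * (attn + mlp) + embeddings
print(f"~{total / 1e9:.2f}B parameters")   # ~6.74B, i.e. the "7B" model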
2 changes: 2 additions & 0 deletions megatron/model/activations.py
@@ -45,6 +45,8 @@ def get_activation(neox_args):
activation_func = swish
elif neox_args.activation == "mish":
activation_func = mish
elif neox_args.activation == "silu":
activation_func = F.silu
else:
raise ValueError(f"Activation function {neox_args.activation} not recognized")
return activation_func
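SiLU is x · sigmoid(x) (swish with β = 1), so the new branch simply reuses PyTorch's built-in F.silu. A minimal equivalence check, assuming only that torch is installed:

import torch
import torch.nn.functional as F

x = torch.randn(8)
assert torch.allclose(F.silu(x), x * torch.sigmoid(x))   # SiLU(x) = x * sigmoid(x)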
124 changes: 105 additions & 19 deletions megatron/model/transformer.py
@@ -29,8 +29,8 @@
from megatron.model.utils import exists, get_fusion_type
from megatron.model.positional_embeddings import (
RotaryEmbedding,
apply_rotary_pos_emb,
apply_rotary_pos_emb_torch,
apply_rotary_pos_emb,
AliBi,
)
from megatron.model.fused_bias_dropout import (
@@ -134,6 +134,65 @@ def forward(self, hidden_states):
return output, output_bias


class LLaMAParallelMLP(nn.Module):
"""LLaMA's MLP.
The MLP takes an input with h hidden units and projects it to a gated
intermediate dimension of roughly 8h/3, rounded up to a multiple of
multiple_of, applies a SiLU-gated nonlinearity, and projects back to h
hidden units. No bias terms and no dropout are used.
Note: multiple_of is used to compute the hidden dimension of the MLP
"""

def __init__(
self, neox_args, init_method, output_layer_init_method, parallel_output=False,
multiple_of=256,
):
super().__init__()

self.activation_func = get_activation(neox_args)
self.activation_type = neox_args.activation

self.multiple_of = multiple_of

ff_dim = int(2 * neox_args.hidden_size * 4 / 3)
ff_dim = self.multiple_of * ((ff_dim + multiple_of - 1) // multiple_of)
self.w1 = mpu.ColumnParallelLinear(
neox_args=neox_args,
input_size=neox_args.hidden_size,
output_size=ff_dim,
gather_output=False,
init_method=init_method,
skip_bias_add=True,
bias=False,
)
self.w3 = mpu.ColumnParallelLinear(
neox_args=neox_args,
input_size=neox_args.hidden_size,
output_size=ff_dim,
gather_output=False,
init_method=init_method,
skip_bias_add=True,
bias=False,
)
self.w2 = mpu.RowParallelLinear(
neox_args=neox_args,
input_size=ff_dim,
output_size=neox_args.hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method,
skip_bias_add=True,
parallel_output=parallel_output,
bias=False,
)

def forward(self, hidden_states):
w1_out, _ = self.w1(hidden_states)
w3_out, _ = self.w3(hidden_states)
return self.w2(self.activation_func(w1_out) * w3_out)


class ParallelLinear(nn.Module):
"""
A Parallel Linear Layer transforming the transformer outputs from hidden_size -> vocab_size
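Stripped of the model-parallel plumbing, LLaMAParallelMLP computes w2(silu(w1·x) * w3·x). The single-device sketch below is an editor's illustration of that gated-SiLU (SwiGLU) pattern, not the commit's code:

import torch.nn as nn
import torch.nn.functional as F

class GatedSiluMLP(nn.Module):
    # Single-GPU analogue of LLaMAParallelMLP: no biases, no dropout.
    def __init__(self, hidden_size, ff_dim):
        super().__init__()
        self.w1 = nn.Linear(hidden_size, ff_dim, bias=False)   # gate projection
        self.w3 = nn.Linear(hidden_size, ff_dim, bias=False)   # value projection
        self.w2 = nn.Linear(ff_dim, hidden_size, bias=False)   # output projection

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))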
@@ -224,6 +283,7 @@ def __init__(
output_size=3 * neox_args.hidden_size,
gather_output=False,
init_method=init_method,
bias=neox_args.use_bias_in_attn_linear,
)

coeff = None
@@ -310,6 +370,7 @@ def __init__(
init_method=output_layer_init_method,
skip_bias_add=True,
parallel_output=parallel_output,
bias=neox_args.use_bias_in_attn_linear,
)

def attention(
@@ -577,6 +638,7 @@ def forward(self, hidden_states, attention_mask, layer_past=None):
else:
# full rotary
query_rot, key_rot = query_layer, key_layer

apply_rotary_fn = (
apply_rotary_pos_emb_torch if self.bf16 else apply_rotary_pos_emb
)
@@ -673,6 +735,7 @@ def __init__(
self.bias_dropout_fusion = neox_args.bias_dropout_fusion
self.gpt_j_residual = neox_args.gpt_j_residual
self.gpt_j_tied = neox_args.gpt_j_tied
self.mlp_type = neox_args.mlp_type

if self.gpt_j_residual:
self.reduce = mpu.mappings.reduce_from_model_parallel_region
@@ -696,12 +759,22 @@ def __init__(
self.post_attention_layernorm = norm(neox_args.hidden_size, eps=eps)

# MLP
self.mlp = ParallelMLP(
neox_args=neox_args,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
parallel_output=self.gpt_j_residual,
)
if neox_args.mlp_type == "regular":
self.mlp = ParallelMLP(
neox_args=neox_args,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
parallel_output=self.gpt_j_residual,
)
elif neox_args.mlp_type == "llama":
self.mlp = LLaMAParallelMLP(
neox_args=neox_args,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
parallel_output=self.gpt_j_residual,
)
else:
raise KeyError(neox_args.mlp_type)

self.layer_past = None # used to cache k/v pairs in inference

@@ -779,24 +852,37 @@ def forward(self, x, attention_mask, layer_past=None):
attention_output, presents = attention_output
self.layer_past = presents
with torch.enable_grad():
attention_output = bias_dropout_fn(
attention_output,
bias=attention_bias.expand_as(residual),
residual=residual,
prob=self.hidden_dropout,
)
if attention_bias is not None:
# Use special bias_dropout_fn if we have a bias term from the above attention layer
attention_output = bias_dropout_fn(
attention_output,
bias=attention_bias.expand_as(residual),
residual=residual,
prob=self.hidden_dropout,
)
else:
# Otherwise just apply dropout + residual
attention_output = torch.nn.functional.dropout(
attention_output, p=self.hidden_dropout, training=self.training
) + residual

# output = x + mlp(ln2(x))
mlp_output, mlp_bias = self.mlp(
self.post_attention_layernorm(attention_output)
)

with torch.enable_grad():
output = bias_dropout_fn(
mlp_output,
bias=mlp_bias.expand_as(attention_output),
residual=attention_output,
prob=self.hidden_dropout,
)
if self.mlp_type == "llama":
# No dropout either
assert mlp_bias is None
output = mlp_output + attention_output
else:
output = bias_dropout_fn(
mlp_output,
bias=mlp_bias.expand_as(attention_output),
residual=attention_output,
prob=self.hidden_dropout,
)

return output

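Plugging the hidden sizes from the configs above into LLaMAParallelMLP's rounding reproduces the released LLaMA feed-forward widths. A quick check, assuming multiple_of keeps its default of 256:

multiple_of = 256
for name, hidden in [("7B", 4096), ("13B", 5120), ("30B", 6656), ("65B", 8192)]:
    ff_dim = int(2 * hidden * 4 / 3)
    ff_dim = multiple_of * ((ff_dim + multiple_of - 1) // multiple_of)
    print(name, ff_dim)   # 7B: 11008, 13B: 13824, 30B: 17920, 65B: 22016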
20 changes: 18 additions & 2 deletions megatron/neox_arguments/neox_args.py
@@ -228,9 +228,9 @@ class NeoXArgsModel(NeoXArgsTemplate):
Pad the vocab size to be divisible by this value. This is added for computational efficiency reasons.
"""

activation: Literal["gelu", "geglu", "relu", "softsign", "swish", "mish"] = "gelu"
activation: Literal["gelu", "geglu", "relu", "softsign", "swish", "mish", "silu"] = "gelu"
"""
Activation function to use - choose from ["gelu", "geglu", "relu", "softsign", "swish", "mish"]
Activation function to use - choose from ["gelu", "geglu", "relu", "softsign", "swish", "mish", "silu"]
"""

scaled_upper_triang_masked_softmax_fusion: bool = False
@@ -343,6 +343,22 @@ class NeoXArgsModel(NeoXArgsTemplate):
x = x + attn(y) + mlp(y)
"""

use_bias_in_norms: bool = True
"""
If false, norms (e.g. LayerNorm) will not have bias terms
"""
use_bias_in_attn_linear: bool = True
"""
If false, attn_linear (e.g. QKVO) will not have bias terms
"""

mlp_type: str = "regular"
"""
Types:
regular: Megatron implementation
llama: LLaMA MLP (SiLU-gated MLP)
"""

soft_prompt_tuning: dict = None
"""
Dictionary configuring the soft prompt tuning parameters.
1 change: 0 additions & 1 deletion megatron/tokenizer/tokenizer.py
@@ -227,7 +227,6 @@ class HFTokenizer(AbstractTokenizer):
def __init__(self, vocab_file):
name = "HFTokenizer"
super().__init__(name)

self.tokenizer = Tokenizer.from_file(vocab_file)
self.eod_id = self.tokenizer.token_to_id("<|endoftext|>")
self.pad_id = self.tokenizer.token_to_id("<|padding|>")
1 change: 1 addition & 0 deletions prepare_data.py
@@ -21,6 +21,7 @@
"GPT2BPETokenizer",
"CharLevelTokenizer",
"TiktokenTokenizer",
"SPMTokenizer",
]
DATASET_CHOICES = [i for i in DATA_DOWNLOADERS.keys() if i != "pass"]

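The SPMTokenizer choice added here is for LLaMA's SentencePiece vocabulary. For orientation, such a tokenizer essentially wraps the sentencepiece package; the snippet below is a hedged sketch, and the tokenizer.model path is a placeholder rather than a file shipped with this commit:

import sentencepiece as spm

sp = spm.SentencePieceProcessor(model_file="tokenizer.model")  # placeholder path
ids = sp.encode("Hello, LLaMA!", out_type=int)
print(ids)
print(sp.decode(ids))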