-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_llama.py
258 lines (225 loc) · 10.1 KB
/
train_llama.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
# Some code based on https://github.com/epfml/landmark-attention
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import math
from dataclasses import dataclass, field
from functools import partial
from typing import Dict, Optional
import torch
import transformers
from transformers import Trainer, DataCollatorForLanguageModeling
from peft import LoraConfig, get_peft_model
from datasets import load_dataset
# Label value ignored by the loss (not referenced elsewhere in this file).
IGNORE_INDEX = -100
# Fallback special tokens, added only when the loaded tokenizer lacks them (see train()).
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"
import sys

# Put this script's directory and its parent on the import path so the local-only
# imports inside train() (llama_attn_replace, models.llama.adape) can presumably be
# resolved regardless of the working directory — confirm the repo layout.
current_directory = os.path.dirname(os.path.abspath(__file__))
parent_directory = os.path.dirname(current_directory)
sys.path.extend([current_directory, parent_directory])
@dataclass
class ModelArguments:
    """Model-selection options parsed from the command line by HfArgumentParser."""

    # Hub id or local path given to AutoConfig / model / AutoTokenizer .from_pretrained.
    # NOTE(review): the default looks like a Pythia (GPT-NeoX) checkpoint, while train()
    # only accepts model_type == "llama" and targets q_proj/k_proj/v_proj/o_proj with
    # LoRA — confirm the intended default.
    model_name_or_path: Optional[str] = "EleutherAI/pythia-1.4b-deduped"
    # Architecture family; train() rejects anything other than "llama".
    model_type: Optional[str] = "llama"
@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """Command-line training options: transformers.TrainingArguments plus script-specific fields.

    NOTE(review): ``optim`` and ``resume_from_checkpoint`` appear to re-declare fields that
    the parent dataclass already defines (``resume_from_checkpoint`` here as a bool) —
    confirm against the installed transformers version.
    """
    # Cache directory forwarded to the AutoConfig / model / AutoTokenizer
    # .from_pretrained calls and to load_dataset in train().
    cache_dir: Optional[str] = field(default=None)
    # Optimizer name; presumably consumed by the Trainer through these arguments.
    optim: str = field(default="adamw_torch")
    # Training context length (default 32768). train() passes it to the tokenizer as
    # model_max_length (which tokenize_fn uses as the packed row length), embeds it in the
    # tokenized-data cache path, and compares it with the model's native context to decide
    # whether to enable linear RoPE scaling.
    # NOTE(review): the help text mentions truncation, but tokenize_fn tokenizes with
    # truncation=False and packs rows instead.
    model_max_length: int = field(
        default=8192 * 4,
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
    )
    # Consulted by train() when loading the model (e.g. the 'adape' config flag and the
    # replace_llama_attn call for 'longlora').
    use_flash_attn: bool = field(
        default=True,
        metadata={"help": "Whether use flash attention for training."},
    )
    # Only forwarded to replace_llama_attn in the 'longlora' branch of train().
    use_full_attn: bool = field(
        default=False,
        metadata={"help": "Whether to use plain, full-attention for training."},
    )
    # Fine-tuning method. train() handles 'adape', 'longlora', 'thetalora', 'lora' and
    # 'ft' and raises NotImplementedError for anything else; every value containing
    # 'lora' gets LoRA adapters on q_proj/k_proj/v_proj/o_proj.
    peft_type: str = field(
        default='lora',
        metadata={"help": "Use low rank adaptation or other methods for finetuning."},
    )
    # Comma-separated substrings matched against parameter names; matching parameters
    # are made trainable after the LoRA / AdaPE freezing step in train().
    trainable_params: str = field(
        default="embed,norm",
        metadata={"help": "Additional trainable parameters except LoRA weights, if low rank training."},
    )
    # Passed straight to trainer.train(resume_from_checkpoint=...); per the help text,
    # True resumes from a checkpoint found in output_dir.
    resume_from_checkpoint: bool = field(
        default=False,
        metadata={"help": "resume from checkpoint of outputdir"},
    )
def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
):
    """Add special tokens to ``tokenizer`` and resize ``model``'s embeddings to match.

    The model is always resized to ``len(tokenizer)``. If any tokens were actually added,
    the new trailing rows of both the input and output embedding matrices are set to the
    mean of that matrix's pre-existing rows instead of keeping their fresh initialisation.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    added = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))
    if added <= 0:
        return

    # Grab both weight tables and compute every mean before writing anything, so the
    # result is the same even when the two matrices share storage.
    tables = [
        module.weight.data
        for module in (model.get_input_embeddings(), model.get_output_embeddings())
    ]
    row_means = [table[:-added].mean(dim=0, keepdim=True) for table in tables]
    for table, row_mean in zip(tables, row_means):
        table[-added:] = row_mean
def tokenize_fn(tokenizer, example):
    """Pack a batch of documents into fixed-length rows of token ids.

    All ``example["text"]`` strings are joined into one stream separated by the
    tokenizer's EOS token. The stream is tokenized without truncation and padded up to a
    multiple of ``tokenizer.model_max_length``, then reshaped into rows of exactly that
    length.

    Returns:
        ``{"input_ids": tensor}`` with shape ``(num_rows, tokenizer.model_max_length)``.
    """
    chunk_len = tokenizer.model_max_length
    stream = tokenizer.eos_token.join(example["text"])
    encoded = tokenizer(
        stream,
        padding=True,
        pad_to_multiple_of=chunk_len,
        return_tensors="pt",
        truncation=False,
    )
    token_ids = encoded["input_ids"]
    return {"input_ids": token_ids.view(-1, chunk_len)}
def _apply_rope_scaling(config, model_max_length):
    """Switch ``config`` to linear RoPE scaling when training beyond its native context.

    The native context length is ``config.max_position_embeddings`` multiplied by any
    ``factor`` already present in ``config.rope_scaling`` (1 when absent). If
    ``model_max_length`` exceeds it, ``config.rope_scaling`` is replaced by
    ``{"type": "linear", "factor": float(ceil(model_max_length / native))}``.
    Configs without a truthy ``max_position_embeddings`` are left unchanged.
    """
    orig_rope_scaling = getattr(config, "rope_scaling", None)
    if orig_rope_scaling is None:
        orig_rope_scaling = {}
    orig_ctx_len = getattr(config, "max_position_embeddings", None)
    if not orig_ctx_len:
        return
    orig_ctx_len *= orig_rope_scaling.get("factor", 1)
    if model_max_length > orig_ctx_len:
        scaling_factor = float(math.ceil(model_max_length / orig_ctx_len))
        config.rope_scaling = {"type": "linear", "factor": scaling_factor}


def _load_model(model_args, training_args, config):
    """Instantiate the bfloat16 causal LM required by ``training_args.peft_type``.

    Raises:
        NotImplementedError: if ``peft_type`` is not one of
            'adape', 'longlora', 'thetalora', 'lora', 'ft'.
    """
    peft_type = training_args.peft_type
    common_kwargs = dict(
        config=config,
        cache_dir=training_args.cache_dir,
        torch_dtype=torch.bfloat16,
    )
    if peft_type == "adape":
        from models.llama.adape import MyLlamaForCausalLM

        # Hyperparameter: size of the AdaPE position representation.
        config.position_size = 4 * config.num_attention_heads
        config.use_flash_attention_2 = "flash" if training_args.use_flash_attn else "eager"
        return MyLlamaForCausalLM.from_pretrained(model_args.model_name_or_path, **common_kwargs)

    if peft_type == "longlora":
        from llama_attn_replace import replace_llama_attn

        # replace_llama_attn presumably installs its own attention implementation, which is
        # why the built-in flash-attention path is disabled here — confirm.
        replace_llama_attn(training_args.use_flash_attn, training_args.use_full_attn)
        return transformers.AutoModelForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            use_flash_attention_2=False,
            **common_kwargs,
        )

    if peft_type == "thetalora":
        config.rope_theta = 1_000_000
    elif peft_type not in ("lora", "ft"):
        raise NotImplementedError(f"Unsupported peft_type: {peft_type!r}")
    # Previously hard-coded to True; now honours --use_flash_attn (default True).
    return transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        use_flash_attention_2=training_args.use_flash_attn,
        **common_kwargs,
    )


def _enable_trainable_params(model, trainable_params):
    """Set ``requires_grad`` on every parameter whose name contains one of the keys.

    ``trainable_params`` is a comma-separated list of name substrings. Blank entries are
    ignored, because an empty key is a substring of every name and would otherwise
    unfreeze the entire model (e.g. for ``--trainable_params ""``).
    """
    keys = [key.strip() for key in trainable_params.split(",") if key.strip()]
    if not keys:
        return
    for name, param in model.named_parameters():
        if any(key in name for key in keys):
            param.requires_grad_()


def train():
    """Fine-tune a LLaMA causal LM on the RedPajama-1T sample.

    Arguments are parsed from the command line into ``ModelArguments`` and
    ``TrainingArguments``. The model is loaded according to ``--peft_type``, with linear
    RoPE scaling when ``--model_max_length`` exceeds the native context. The dataset is
    tokenized and packed into ``model_max_length`` rows, with an on-disk cache. Training
    runs through ``transformers.Trainer`` with gradient checkpointing, and the final state
    and model are saved to ``output_dir``.

    Raises:
        ValueError: if ``--model_type`` is not ``"llama"``.
        NotImplementedError: if ``--peft_type`` is not a supported method.
    """
    parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
    model_args, training_args = parser.parse_args_into_dataclasses()
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if model_args.model_type != "llama":
        raise ValueError("Only support llama now")

    # Set RoPE scaling factor
    config = transformers.AutoConfig.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
    )
    _apply_rope_scaling(config, training_args.model_max_length)

    # Load model and tokenizer
    model = _load_model(model_args, training_args, config)
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=True,
    )
    # Only add the special tokens the tokenizer is missing.
    special_tokens_dict = {}
    if tokenizer.pad_token is None:
        special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN
    if tokenizer.eos_token is None:
        special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN
    if tokenizer.bos_token is None:
        special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN
    if tokenizer.unk_token is None:
        special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN
    smart_tokenizer_and_embedding_resize(
        special_tokens_dict=special_tokens_dict,
        tokenizer=tokenizer,
        model=model,
    )

    # NOTE: relative path — the cache location depends on the current working directory.
    data_cache_dir = f"../data/tokenized_redpajama/{training_args.model_max_length}"
    os.makedirs(data_cache_dir, exist_ok=True)
    dataset = load_dataset("togethercomputer/RedPajama-Data-1T-Sample", cache_dir=training_args.cache_dir)
    dataset = dataset.map(
        partial(tokenize_fn, tokenizer),
        batched=True,
        num_proc=48,
        remove_columns=["text", "meta"],
        cache_file_names={"train": f"{data_cache_dir}/train.arrow"},
        load_from_cache_file=True,
    )
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    peft_type = training_args.peft_type
    # Substring test: 'lora', 'longlora' and 'thetalora' all get LoRA adapters.
    if "lora" in peft_type:
        lora_config = LoraConfig(
            r=8,
            lora_alpha=16,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
            lora_dropout=0,
            bias="none",
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(model, lora_config)
        # Extra trainable params on top of LoRA (default: embed,norm).
        _enable_trainable_params(model, training_args.trainable_params)
    elif peft_type == "adape":
        # Freeze every parameter whose name contains neither 'pe' nor
        # 'post_attention_linears' (presumably the AdaPE modules).
        # NOTE(review): 'pe' is a broad substring; confirm it matches only AdaPE weights.
        for name, param in model.named_parameters():
            if not any(key in name for key in ("pe", "post_attention_linears")):
                param.requires_grad = False
        # Same extra trainable params as the LoRA variants.
        _enable_trainable_params(model, training_args.trainable_params)

    # Parameters sharing storage (same data_ptr) are counted once.
    all_parameters = sum({p.data_ptr(): p.numel() for p in model.parameters()}.values())
    print(f"Finetuning Model Size={all_parameters / 1e9:.2f}B parameters")
    trainable_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Trainable Parameter Size={trainable_parameters / 1e9:.2f}B parameters")
    print(f"Successfully added {peft_type} adapters")

    model.config.use_cache = False  # required for gradient checkpointing
    model.enable_input_require_grads()  # required for gradient checkpointing
    model.gradient_checkpointing_enable()  # enable gradient checkpointing

    trainer = Trainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=None,
        data_collator=data_collator,
    )
    trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
    trainer.save_state()
    trainer.save_model(output_dir=training_args.output_dir)


if __name__ == "__main__":
    train()