-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlora_ftg.py
572 lines (531 loc) · 20 KB
/
lora_ftg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
import json
import os
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
HfArgumentParser,
LlamaConfig,
LlamaForCausalLM,
LlamaTokenizer,
TrainingArguments,
set_seed,
Trainer,
)
from llm_arch.FtGForCausalLM import FtGForCausalLM
from transformers.deepspeed import is_deepspeed_zero3_enabled
from deepspeed import zero
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from deepspeed.runtime.engine import DeepSpeedEngine
from functools import partial
import math
from multiprocessing import cpu_count
from prompter import (
batch_grouped_sft_generate,
generate_and_tokenize_prompt_with_graph,
)
from datasets import load_dataset
import transformers
from transformers.trainer_pt_utils import torch_distributed_zero_first
from transformers.trainer_utils import get_last_checkpoint
import sys
import logging
from dataclasses import dataclass, field
from typing import Optional, Union
from peft import (
prepare_model_for_int8_training,
LoraConfig,
get_peft_model
)
from transformers.utils import add_start_docstrings
from ftg_trainer import FtGTrainer
# Borrowed from peft.utils.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(named_params, bias):
    """Collect LoRA adapter tensors (and optionally biases) as detached CPU copies.

    Adapted from ``peft.utils.get_peft_model_state_dict``. Every selected
    tensor is passed through ``maybe_zero_3`` so DeepSpeed ZeRO-3 partitioned
    parameters are gathered before being copied.

    Args:
        named_params: Iterable of ``(name, tensor)`` pairs such as
            ``model.named_parameters()``. It is consumed exactly once.
        bias: Which bias tensors to include:
            ``"none"``      - only tensors whose name contains ``"lora_"``;
            ``"all"``       - those plus every tensor whose name contains ``"bias"``;
            ``"lora_only"`` - those plus only the biases of the layers that
                              carry LoRA weights.

    Returns:
        dict mapping parameter name to a detached CPU clone of the tensor.

    Raises:
        NotImplementedError: if ``bias`` is not one of the values above.
    """
    if bias == "none":
        to_return = {k: t for k, t in named_params if "lora_" in k}
    elif bias == "all":
        to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        maybe_lora_bias = {}
        lora_bias_names = set()
        for k, t in named_params:
            if "lora_" in k:
                to_return[k] = t
                # The adapted layer's bias shares the name prefix that precedes
                # "lora_", e.g. "...q_proj.lora_A.weight" -> "...q_proj.bias".
                lora_bias_names.add(k.split("lora_")[0] + "bias")
            elif "bias" in k:
                maybe_lora_bias[k] = t
        # Iterate (name, tensor) pairs; iterating the dict itself would yield
        # only key strings. Keep a bias only if it belongs to a LoRA layer.
        for k, t in maybe_lora_bias.items():
            if k in lora_bias_names:
                to_return[k] = t
    else:
        raise NotImplementedError
    return {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
    """Collect every non-LoRA parameter as a detached CPU tensor.

    Args:
        named_params: Iterable of ``(name, tensor)`` pairs such as
            ``model.named_parameters()``.
        require_grad_only: When True (default), skip tensors whose
            ``requires_grad`` is False.

    Returns:
        dict mapping parameter name to a CPU copy produced by ``maybe_zero_3``
        (which gathers ZeRO-3 partitioned parameters first).
    """
    selected = {}
    for name, tensor in named_params:
        # LoRA weights are exported separately by get_peft_state_maybe_zero_3.
        if "lora_" in name:
            continue
        if require_grad_only and not tensor.requires_grad:
            continue
        selected[name] = tensor

    gathered = {}
    for name, tensor in selected.items():
        gathered[name] = maybe_zero_3(tensor, ignore_status=True).cpu()
    return gathered
def maybe_zero_3(param, ignore_status=False, name=None):
    """Return a detached CPU clone of ``param``, gathering it first under ZeRO-3.

    Args:
        param: A tensor or parameter. Objects carrying a ``ds_id`` attribute
            are treated as DeepSpeed ZeRO-3 managed parameters and are
            materialised inside ``zero.GatheredParameters`` before copying.
        ignore_status: When False, log a warning if a ZeRO-3 parameter's
            ``ds_status`` is ``ZeroParamStatus.NOT_AVAILABLE``.
        name: Optional parameter name, used only in that warning.

    Returns:
        A detached clone of the parameter's data on the CPU.
    """
    if not hasattr(param, "ds_id"):
        return param.detach().cpu().clone()

    if param.ds_status == ZeroParamStatus.NOT_AVAILABLE and not ignore_status:
        # The message now matches the condition that triggers it.
        logging.warning(
            f"{name}: param.ds_status == ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}"
        )
    # NOTE(review): GatheredParameters is presumably a collective under ZeRO-3
    # (every rank must enter it) - see the DeepSpeed docs.
    with zero.GatheredParameters([param]):
        return param.data.detach().cpu().clone()
@dataclass
class ModelArguments:
    """Command-line options selecting the model to load and how to build it.

    Parsed by ``HfArgumentParser`` in ``main()``.
    """

    # Local directory or hub identifier of the base checkpoint.
    model_name_or_path: Optional[str] = field(
        default=None,
        metadata=dict(help="Path to pretrained model or model identifier from huggingface.co/models"),
    )
    # main() forwards this as ``cache_dir`` to ``load_dataset``.
    cache_dir: Optional[str] = field(
        default=None,
        metadata=dict(help="Where do you want to store the pretrained models downloaded from huggingface.co"),
    )
    # 'auto' or None are passed through; any other value is resolved with
    # getattr(torch, value), e.g. "bfloat16".
    torch_dtype: Optional[str] = field(
        default=None,
        metadata=dict(help="torch dtype"),
    )
    # Copied into ``config._flash_attn_2_enabled`` by main().
    use_flash_attention: bool = field(
        default=False,
        metadata=dict(help="Whether to use flash attention"),
    )
    # main() builds the tokenizer and FtG model only on the Llama path.
    llama: bool = field(
        default=False,
        metadata=dict(help="Whether to use llama"),
    )
    # NOTE(review): the next three are not read directly in this file; they are
    # presumably consumed by initialize_graph_modules(model_args) - verify.
    pretrain_mm_mlp_adapter: Optional[str] = field(
        default=None,
        metadata=dict(help="Path to pretrain mm mlp adapter"),
    )
    graph_hidden_dim: Optional[int] = field(
        default=1000,
        metadata=dict(help="Graph hidden dim"),
    )
    mm_projector_type: Optional[str] = field(
        default="linear",
        metadata=dict(help="mm projector type"),
    )
@dataclass
class DataArguments:
    """Command-line options locating the training data and graph embeddings.

    Parsed by ``HfArgumentParser`` in ``main()``.
    """

    # JSON file read with load_dataset("json", ...).
    train_file: Optional[str] = field(
        default=None,
        metadata=dict(help="Path to training file"),
    )
    # JSON file used as the evaluation set.
    validation_file: Optional[str] = field(
        default=None,
        metadata=dict(help="Path to validation file"),
    )
    # True -> batch_grouped_sft_generate; False -> generate_and_tokenize_prompt_with_graph.
    group_sample: bool = field(
        default=False,
        metadata=dict(help="Whether to group sample"),
    )
    # torch.load-ed and handed to the tokenizer helper for the training split.
    train_graph_emb_path: Optional[str] = field(
        default="./FtG/data/processed/codex-m/train_graph_emb.pt",
        metadata=dict(help="Path to train graph emb"),
    )
    # torch.load-ed and handed to the tokenizer helper for the validation split.
    test_graph_emb_path: Optional[str] = field(
        default="./FtG/data/processed/codex-m/test_graph_emb.pt",
        metadata=dict(help="Path to test graph emb"),
    )
@dataclass
@add_start_docstrings(TrainingArguments.__doc__)
class TrainingArguments(TrainingArguments):
    """Project-specific extension of ``transformers.TrainingArguments``.

    The subclass deliberately reuses (shadows) the imported name, so
    ``HfArgumentParser`` in ``main()`` parses this class. Fields that share a
    name with the parent presumably re-declare it to change its default
    (NOTE(review): parent defaults not verified here).
    """

    # Maximum sequence length forwarded to the tokenization helpers in main().
    model_max_length: int = field(
        default=512,
        metadata={"help": "The maximum length of the model's input sequence"}
    )
    # When set, main() wraps the model with get_peft_model(LoraConfig(...)).
    use_lora: bool = field(
        default=False,
        metadata={"help": "Whether to use lora"}
    )
    # When set, main() loads the model with load_in_8bit=True and a device_map.
    use_int8_training: bool = field(
        default=False,
        metadata={"help": "Whether to use int8 training"}
    )
    # JSON file with keys lora_r, lora_alpha, lora_target_modules, lora_dropout.
    lora_config: Optional[str] = field(
        default=None,
        metadata={"help": "Path to lora config"}
    )
    ddp_find_unused_parameters: bool = field(
        default=False,
        metadata={"help": "Whether to use ddp find unused parameters"}
    )
    gradient_checkpointing: bool = field(
        default=False,
        metadata={"help": "Whether to use gradient checkpointing"}
    )
    # main() overwrites eval_steps/save_steps from the dataset size.
    evaluation_strategy: str = field(
        default="steps",
        metadata={"help": "evaluation strategy"}
    )
    save_total_limit: Optional[int] = field(
        default=3,
        metadata={"help": "save total limit"}
    )
    load_best_model_at_end: bool = field(
        default=True,
        metadata={"help": "load best model at end"}
    )
    # Default is None, so the annotation is Optional.
    report_to: Optional[str] = field(
        default=None,
        metadata={"help": "report to"}
    )
    # Path to a DeepSpeed JSON config; None disables DeepSpeed.
    deepspeed: Optional[str] = field(
        default=None,
        metadata={"help": "deepspeed, please pass the path to deepspeed json config file, e.g., ds_config.json"}
    )
    # Also gates last-checkpoint detection in main().
    do_train: bool = field(
        default=True,
        metadata={"help": "Whether to run training."}
    )
    # NOTE(review): not read by main() in this file.
    tune_mm_mlp_adapter: bool = field(
        default=True,
        metadata={"help": "Whether to tune mm mlp adapter"}
    )
    # Bias selection used when extracting LoRA weights for saving; accepted
    # values are "none", "all", "lora_only" (others raise NotImplementedError).
    lora_bias: str = field(
        default="none",
        metadata={"help": "lora bias"}
    )
    # NOTE(review): remove_unused_columns is left at the parent's default. If the
    # Trainer drops dataset columns that forward() does not accept, the graph
    # embedding columns may require setting it to False - confirm.
    # remove_unused_columns: bool = field(
    #     default=False,
    #     metadata={"help": "Whether to remove unused columns"}
    # )
def print_rank_0(msg, log_file, rank=0):
    """Print ``msg`` and append it to ``log_file``, on the main process only.

    Args:
        msg: Message to emit; converted with ``str()`` before writing.
        log_file: Text log to append to. Its parent directory is created if
            missing, because callers may log before anything has created it.
        rank: Caller's process rank. Ranks ``<= 0`` (``0``, or ``-1`` for a
            non-distributed run) log; all other ranks do nothing.
    """
    if rank > 0:
        return
    text = str(msg)
    print(text)
    log_dir = os.path.dirname(log_file)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
    # Explicit UTF-8 so non-ASCII prompts do not depend on the platform locale.
    with open(log_file, "a", encoding="utf-8") as f:
        f.write(text + "\n")
def get_model_param_count(
    model: "Union[DeepSpeedEngine, torch.nn.Module]", trainable_only=False
):
    """Return the number of parameters in ``model``.

    Parameters managed by DeepSpeed ZeRO-3 expose their full element count as
    ``ds_numel``. That value is used whenever present, checked per parameter,
    so the count is also right for an unwrapped module whose parameters were
    partitioned (``main()`` passes the module, not the engine).

    Args:
        model: Module or DeepSpeed engine whose ``parameters()`` are counted.
        trainable_only: If True, count only parameters with ``requires_grad``.

    Returns:
        int: the total parameter count.
    """
    def numel(p):
        ds_numel = getattr(p, "ds_numel", None)
        return ds_numel if ds_numel is not None else p.numel()

    return sum(
        numel(p) for p in model.parameters() if not trainable_only or p.requires_grad
    )
def main():
    """Fine-tune an FtG causal LM (optionally with LoRA) and save the results.

    Pipeline:
      1. Parse ModelArguments / DataArguments / TrainingArguments from the CLI.
      2. Configure logging and detect a checkpoint to resume from.
      3. Load the Llama tokenizer and either an 8-bit AutoModelForCausalLM
         (``--use_int8_training``) or ``FtGForCausalLM``.
      4. Optionally wrap the model with LoRA and initialise its graph modules.
      5. Tokenize the JSON train/validation files with the graph embeddings.
      6. Train with FtGTrainer, then write the adapter weights, the model
         config and the non-LoRA trainable weights (``mm_projector_final.bin``)
         to ``output_dir``.

    Raises:
        ValueError: if ``--llama`` is not set (only that path defines a
            tokenizer and an FtG model), or if ``output_dir`` is non-empty,
            holds no checkpoint and ``--overwrite_output_dir`` is not given.
        FileNotFoundError: if ``--train_file`` is missing or does not exist.
    """
    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    if not model_args.llama:
        # Only the Llama branch creates a tokenizer and an FtG model; without it
        # the script would otherwise fail much later on an unbound local.
        raise ValueError("Only Llama-based FtG models are supported: pass --llama.")

    world_size = int(os.environ.get("WORLD_SIZE", 1))
    # get_rank() raises when no process group exists (plain single-process run).
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        global_rank = torch.distributed.get_rank()
    else:
        global_rank = 0
    log_file = os.path.join(training_args.output_dir, "log.txt")

    # setup logging
    logger = logging.getLogger(__name__)
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    if training_args.should_log:
        transformers.utils.logging.set_verbosity_info()
    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()
    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, distributed training: {bool(training_args.local_rank != -1)}, fp16-bits training: {training_args.fp16}, bf16-bits training: {training_args.bf16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    # Detecting last checkpoint.
    last_checkpoint = None
    if (
        os.path.isdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif (
            last_checkpoint is not None and training_args.resume_from_checkpoint is None
        ):
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )
    # log_file lives in output_dir, which may not exist yet on a fresh run.
    os.makedirs(training_args.output_dir, exist_ok=True)

    # Set seed before initializing model.
    set_seed(training_args.seed)
    training_args._frozen = False
    training_args.data_seed = training_args.seed

    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
    print_rank_0(
        "Set the eos_token_id and bos_token_id of Llama model tokenizer",
        log_file,
        global_rank
    )
    tokenizer.add_special_tokens(
        {
            "bos_token": "<s>",
            "eos_token": "</s>",
            "unk_token": "<unk>",
            "pad_token": "<unk>",
        }
    )
    tokenizer.padding_side = "right"

    torch_dtype = (
        model_args.torch_dtype
        if model_args.torch_dtype in ['auto', None]
        else getattr(torch, model_args.torch_dtype)
    )
    # int8 is not compatible with DeepSpeed (require not to pass device_map)
    if training_args.use_int8_training:
        print_rank_0("int8 is not compatible with DeepSpeed", log_file, global_rank)
        device_map = (
            {"": int(os.environ.get("LOCAL_RANK") or 0)} if world_size != 1 else "auto"
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            load_in_8bit=True,
            device_map=device_map,
            torch_dtype=torch_dtype,
        )
    else:
        config = LlamaConfig.from_pretrained(model_args.model_name_or_path)
        config.vocab_size = tokenizer.vocab_size
        config.pad_token_id = tokenizer.pad_token_id
        config._flash_attn_2_enabled = model_args.use_flash_attention
        model = FtGForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            config=config,
            torch_dtype=torch_dtype,
        )
    print_rank_0(
        "tokenizer.eos_token_id = {}".format(tokenizer.eos_token_id),
        log_file,
        global_rank,
    )
    print_rank_0(
        "tokenizer.pad_token_id = {}".format(tokenizer.pad_token_id),
        log_file,
        global_rank,
    )
    print_rank_0(
        "tokenizer.bos_token_id = {}".format(tokenizer.bos_token_id),
        log_file,
        global_rank,
    )

    # Set peft model
    if training_args.use_lora:
        print_rank_0(
            "Loading lora config from {}".format(training_args.lora_config),
            log_file,
            global_rank,
        )
        with open(training_args.lora_config, encoding="utf-8") as f:
            lora_config = json.load(f)
        print_rank_0("Lora config: {}".format(lora_config), log_file, global_rank)
        if training_args.use_int8_training:
            print_rank_0(
                "training_args.use_int8_training !!! (int8 is not compatible with DeepSpeed)",
                log_file,
                global_rank
            )
            model = prepare_model_for_int8_training(model)
        peft_config = LoraConfig(
            r=lora_config["lora_r"],
            lora_alpha=lora_config["lora_alpha"],
            target_modules=lora_config["lora_target_modules"],
            lora_dropout=lora_config["lora_dropout"],
            # Same setting used when the adapter weights are extracted for saving.
            bias=training_args.lora_bias,
            task_type="CAUSAL_LM"
        )
        # Embedding outputs must require grad so gradients reach the adapters.
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        else:
            def make_inputs_require_grad(_module, _inputs, output):
                output.requires_grad_(True)

            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
        model = get_peft_model(model, peft_config)
        # graph_hidden_dim is a declared ModelArguments field, so always present.
        model.get_model().initialize_graph_modules(model_args)
        model.print_trainable_parameters()
    if training_args.gradient_checkpointing:
        print_rank_0("Use gradient checkpointing", log_file, global_rank)
        model.gradient_checkpointing_enable()

    if not data_args.train_file or not os.path.exists(data_args.train_file):
        raise FileNotFoundError(f"{data_args.train_file} file not exists")
    # The context manager expects the *local* rank: the local main process
    # loads/tokenizes (and caches) first, the others follow.
    with torch_distributed_zero_first(training_args.local_rank):
        train_data = load_dataset(
            "json", data_files=data_args.train_file, cache_dir=model_args.cache_dir
        )
        val_data = load_dataset(
            "json", data_files=data_args.validation_file, cache_dir=model_args.cache_dir
        )
        train_graph_emb = torch.load(data_args.train_graph_emb_path)
        test_graph_emb = torch.load(data_args.test_graph_emb_path)
        if data_args.group_sample:
            train_data = (
                train_data['train'].shuffle().map(
                    partial(
                        batch_grouped_sft_generate,
                        training_args.model_max_length,
                        tokenizer,
                    ),
                    batched=True,
                    desc=f"Grouping texts in chunks of {training_args.model_max_length}",
                    remove_columns=["id", "conversations"],
                    num_proc=cpu_count(),
                )
            )
            val_data = val_data['train'].map(
                partial(
                    batch_grouped_sft_generate,
                    training_args.model_max_length,
                    tokenizer,
                ),
                batched=True,
                desc=f"Grouping texts in chunks of {training_args.model_max_length}",
                remove_columns=["id", "conversations"],
                num_proc=cpu_count(),
            )
        else:
            train_data = (
                train_data['train'].shuffle().map(
                    partial(
                        generate_and_tokenize_prompt_with_graph,
                        training_args.model_max_length,
                        tokenizer,
                        padding_side="right",
                        train_graph_emb=train_graph_emb,
                    ),
                    num_proc=cpu_count(),
                )
            )
            # The helper's keyword is named train_graph_emb; here it receives
            # the validation-split embeddings.
            val_data = val_data['train'].map(
                partial(
                    generate_and_tokenize_prompt_with_graph,
                    training_args.model_max_length,
                    tokenizer,
                    padding_side="right",
                    train_graph_emb=test_graph_emb,
                ),
                num_proc=cpu_count(),
            )
    for i in range(min(2, len(val_data))):
        print_rank_0(
            "Eval tokenized example: {}".format(val_data[i]), log_file, global_rank
        )
    for i in range(min(2, len(train_data))):
        print_rank_0(
            "Train tokenized example: {}".format(train_data[i]), log_file, global_rank
        )

    training_nums = len(train_data)
    num_gpus = torch.cuda.device_count()
    batch_size = (
        training_args.per_device_train_batch_size
        * training_args.world_size
        * training_args.gradient_accumulation_steps
    )
    # train steps (num_train_epochs is a float; keep the step count integral)
    t_total = math.ceil(
        math.ceil(training_nums / batch_size) * training_args.num_train_epochs
    )
    # eval steps: about four evaluations per epoch, never more often than every 5 steps
    training_args.eval_steps = max(
        int(t_total // (training_args.num_train_epochs * 4)), 5
    )
    # save steps (kept equal to eval_steps so load_best_model_at_end can pick a checkpoint)
    training_args.save_steps = training_args.eval_steps
    training_args.warmup_steps = (
        int(t_total * training_args.warmup_ratio)
        if training_args.warmup_ratio > 0.0
        else training_args.warmup_steps
    )
    print_rank_0(
        "num_gpus = {}, training_nums = {}, t_total = {}, warmup_steps = {}, eval_steps = {}, save_steps = {}".format(
            num_gpus,
            training_nums,
            t_total,
            training_args.warmup_steps,
            training_args.eval_steps,
            training_args.save_steps,
        ),
        log_file,
        global_rank,
    )
    print_rank_0(
        "val data nums = {}, training_nums = {}, batch_size = {}".format(
            len(val_data), training_nums, batch_size
        ),
        log_file,
        global_rank,
    )

    trainer = FtGTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=train_data,
        eval_dataset=val_data,
        data_collator=transformers.DataCollatorForSeq2Seq(
            tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
        ),
    )
    print_rank_0(
        f"Using {training_args.half_precision_backend} half precision backend", log_file, global_rank
    )

    # Train !
    # Build the dataloader once and reuse it for both statistics.
    train_dataloader = trainer.get_train_dataloader()
    len_dataloader = len(train_dataloader)
    num_update_steps_per_epoch = math.ceil(len_dataloader / training_args.gradient_accumulation_steps)
    total_train_batch_size = (
        training_args.train_batch_size
        * training_args.gradient_accumulation_steps
        * training_args.world_size
    )
    num_examples = trainer.num_examples(train_dataloader)
    num_train_samples = num_examples * training_args.num_train_epochs
    max_steps = math.ceil(training_args.num_train_epochs * num_update_steps_per_epoch)
    print_rank_0(f"  Num examples = {num_examples}", log_file, global_rank)
    print_rank_0(f"  Num train samples = {num_train_samples}", log_file, global_rank)
    print_rank_0(f"  world_size = {world_size}", log_file, global_rank)
    print_rank_0(
        f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}",
        log_file,
        global_rank,
    )
    print_rank_0(
        f"  Gradient Accumulation steps = {training_args.gradient_accumulation_steps}",
        log_file,
        global_rank,
    )
    print_rank_0(f"  Total optimization steps = {max_steps}", log_file, global_rank)
    print_rank_0(
        f"  Number of trainable parameters = {get_model_param_count(model, trainable_only=True)}",
        log_file,
        global_rank,
    )

    model.config.use_cache = False
    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
    elif last_checkpoint is not None:
        checkpoint = last_checkpoint
    trainer.train(resume_from_checkpoint=checkpoint)
    trainer.save_state()

    # Gather on every rank (ZeRO-3 gathering is presumably collective);
    # only the saving process writes files below.
    state_dict = get_peft_state_maybe_zero_3(
        model.named_parameters(), training_args.lora_bias
    )
    non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(
        model.named_parameters()
    )
    # should_save is True on exactly one process (or one per node when
    # save_on_each_node is set); local_rank == 0 would write once per node.
    if training_args.should_save:
        model.config.save_pretrained(training_args.output_dir)
        model.save_pretrained(training_args.output_dir, state_dict=state_dict)
        torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'mm_projector_final.bin'))
    print_rank_0("Training done", log_file, global_rank)
if __name__ == "__main__":
    # Run fine-tuning only when executed as a script, not when imported.
    main()