base_config.yaml
general:
  project_name: 'grokking'
  run_name: '_time:' # default run name is the timestamp
  run_description: null
  out_dir: '$OUT_DIR/$project_name/$run_name'
  device_type: 'cuda' # 'cpu' or 'cuda'; auto-selected if blank
  distributed_backend: 'nccl' # nccl, gloo, mpi, horovod
  distributed_init_method: 'env://' # for horovod, use 'tcp://localhost:23456'
  torch_compile: true # not compiled if Python > 3.11, torch < 2.1.0, or on Windows
  seed: 42
  dtype: 'float32' # float32, float16, bfloat16
  enable_distributed: null # automatically set to true if WORLD_SIZE > 1
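# Note (assumption): the '_copy: <path>' values used below appear to be
# directives resolved by nanugpt's config loader, which substitutes the value
# found at the given config path (e.g. '/general/project_name'). This reading
# is inferred from the key names and comments in this file, not from the
# loader's documentation.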
_env: # creates env vars
  project_name: '_copy: /general/project_name'
  run_name: '_copy: /general/run_name'
logging:
  project_name: '_copy: /general/project_name'
  run_name: '_copy: /general/run_name'
  log_dir: '_copy: /general/out_dir'
  run_description: '_copy: /general/run_description'
  enable_wandb: false
  log_filename: 'log.txt'
  summaries_filename: 'summaries.txt'
  allow_overwrite_log: true
  metrics_type: 'classification'
  summaries_stdout: true
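# Note (assumption): each component section below follows the same pattern:
# 'module' names a factory function (e.g. 'nanugpt.models.tiny_transformer.get_model')
# that is presumably imported dynamically and called with 'module_kwargs'.
# This is inferred from the naming convention used throughout this file.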
model:
  module: 'nanugpt.models.tiny_transformer.get_model'
  module_kwargs:
    n_layer: 2
    n_embd: 128
    n_head: 4
    context_length: 5 # currently each input eq has [eos a op b =] which is 5 tokens
scaler:
  module: 'nanugpt.scalers.amp_grad_scaler.get_scaler'
  module_kwargs: {}
loss:
  module: 'nanugpt.losses.grokking_loss.get_loss_factory'
  module_kwargs: {}
data:
  module: 'nanugpt.data.grokking_data.get_data'
  module_kwargs:
    operation: 'x/y'
    training_fraction: 0.5
    val_fraction: null # if null, val fraction = 1 - training_fraction; test fraction is 0
    device_batch_size: '_copy: /training/device_batch_size'
    eval_batch_size: 32768 # 2^15 = 32768
    data_loader_seed: 8
    context_length: '_copy: /model/module_kwargs/context_length'
    prime: &prime null # to be set by the overriding config
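# '&prime' defines a YAML anchor and '*prime' below is an alias to it, so the
# data and tokenizer sections always share the same prime value. An overriding
# config is expected to replace the null, e.g. prime: 97 (illustrative value,
# not taken from this repo).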
tokenizer:
  module: 'nanugpt.tokenizers.grokking_tokenizer.get_tokenizer_factory'
  module_kwargs:
    prime: *prime
training:
  device_batch_size: 512
  max_steps: 3000 # 1e5 is not enough when weight_decay=0.0
  enable_train_log: false
  log_every: 20
  grad_clip: 0.0 # disabled if 0.0
  global_batch_size: 512 # will be automatically divided by GPU count
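# Note (inference, not a statement of the trainer's exact behavior): with
# device_batch_size == global_batch_size, a single-GPU run presumably uses no
# gradient accumulation; on N GPUs the per-device share would be
# global_batch_size / N per the comment above.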
optimizer:
  module: 'nanugpt.optimizers.adamw.get_optim'
  module_kwargs:
    learning_rate: 1.0E-3
    weight_decay: 0.1 # weight_decay=1 makes convergence much faster but then the original grokking curve is not reproducible
    beta1: 0.9
    beta2: 0.98
    eps: 1.0E-8 # pytorch default
  zero_stage: 0 # 0: off, 1: shard optimizer
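# Note (assumption): zero_stage: 1 would shard optimizer states across ranks
# (ZeRO stage 1), while stage 0 keeps a full optimizer copy per rank. Whether
# nanugpt implements this via PyTorch's ZeroRedundancyOptimizer is not stated
# in this file.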
scheduler:
  module: 'nanugpt.schedulers.constant.get_scheduler'
  module_kwargs:
    warmup_iters: 10
    cooldown_iters: 0
    max_iters: null
    end_factor: 1.0E-2 # factor to multiply at the end
    const_lr: '_copy: /optimizer/module_kwargs/learning_rate'
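# Assumed behavior of the constant scheduler: linear warmup for warmup_iters
# steps up to const_lr, then a flat learning rate; end_factor would only come
# into play if cooldown_iters > 0.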
eval:
  eval_every: 100
  eval_iters: null # number of samples to evaluate from the dataset
  save_checkpoint: false
  checkoint_after: 0 # start saving checkpoints after this many steps
  checkpoint_every_hr: 2 # take a checkpoint every this many hours
  checkpoint_keep_best: true # keep only the best checkpoint, otherwise keep all as _{step}.pt
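# Sketch of an overriding config (hypothetical file, values illustrative) that
# layers on top of this base and sets only the fields left null here:
#   data:
#     module_kwargs:
#       prime: 97
#   general:
#     run_name: 'grokking_div_p97'
# How the override is merged with this base (CLI flag, include key, etc.) is
# not specified in this file.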