[WIP] Async checkpointing support #12

Draft · wants to merge 2 commits into base: main
2 changes: 2 additions & 0 deletions megatron/arguments.py
@@ -1305,6 +1305,8 @@ def _add_data_args(parser):
help='Warm up mmap files.')
group.add_argument('--num-workers', type=int, default=2,
help="Dataloader number of workers.")
group.add_argument('--num-checkpoint-workers', type=int, default=2,
help="Number of checkpoint workers")
group.add_argument('--tokenizer-type', type=str,
default=None,
choices=['BertWordPieceLowerCase',
40 changes: 38 additions & 2 deletions megatron/checkpointing.py
@@ -39,6 +39,8 @@

_CHECKPOINT_VERSION = None

_CHECKPOINT_TASK_LIST = []  # background checkpoint processes still in flight
_CHECKPOINT_NUM_TASKS = 0   # number of checkpoint processes launched since the last wait
dlp = Profile("CHECKPOINT")
def set_checkpoint_version(value):
global _CHECKPOINT_VERSION
@@ -233,9 +235,11 @@ def get_rng_state():
rng_state_list = [rng_state]

return rng_state_list
from multiprocessing import Process


@dlp.log
def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
@dlp.log
def save_checkpoint_sync(iteration, model, optimizer, opt_param_scheduler):
"""Save a model checkpoint."""
args = get_args()
assert args is not None
@@ -339,6 +343,38 @@ def state_dict_for_save_checkpoint_deepspeed(destination=None, prefix='', keep_v
if torch.distributed.is_initialized():
torch.distributed.barrier()

def wait_checkpoint():
    """Block until all in-flight checkpoint worker processes have finished."""
    print_rank_0("waiting for previous checkpointing to finish")
    global _CHECKPOINT_TASK_LIST
    global _CHECKPOINT_NUM_TASKS
    for t in _CHECKPOINT_TASK_LIST:
        t.join()
    _CHECKPOINT_TASK_LIST = []
    _CHECKPOINT_NUM_TASKS = 0  # keep the in-flight counter consistent with the (now empty) task list

def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
    """Save a model checkpoint, asynchronously in a background process when
    --num-checkpoint-workers > 0, otherwise synchronously."""
    args = get_args()
    assert args is not None
    num_checkpoint_workers = args.num_checkpoint_workers
    global _CHECKPOINT_TASK_LIST
    global _CHECKPOINT_NUM_TASKS
    if num_checkpoint_workers > 0:
        print_rank_0("Async checkpointing")
        if _CHECKPOINT_NUM_TASKS >= num_checkpoint_workers:
            # All worker slots are busy: drain the in-flight checkpoints before launching another.
            wait_checkpoint()
        proc = Process(target=save_checkpoint_sync,
                       args=(iteration, model, optimizer, opt_param_scheduler))
        proc.start()
        _CHECKPOINT_TASK_LIST.append(proc)
        _CHECKPOINT_NUM_TASKS += 1
    else:
        save_checkpoint_sync(iteration, model, optimizer, opt_param_scheduler)


@dlp.log
def _transpose_first_dim(t, num_splits, num_splits_first, model):
input_shape = t.size()
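For reference, a minimal self-contained sketch of the pattern the new checkpointing.py code follows: launch each save in a child process, cap the number of in-flight saves, and join them before reusing the worker slots. The helper names (save_state, save_async, wait_all) and the /tmp path below are illustrative only, not part of this PR.

from multiprocessing import Process
import torch

_tasks = []  # in-flight checkpoint processes

def save_state(state_dict, path):
    # Runs in the child process; this is the slow disk write moved off the critical path.
    torch.save(state_dict, path)

def save_async(state_dict, path, num_workers=2):
    # Launch a background save; if num_workers saves are already running, drain them first.
    if len(_tasks) >= num_workers:
        wait_all()
    proc = Process(target=save_state, args=(state_dict, path))
    proc.start()
    _tasks.append(proc)

def wait_all():
    # Join every outstanding save so worker slots can be reused (or the program can exit) safely.
    for proc in _tasks:
        proc.join()
    _tasks.clear()

# Example usage (hypothetical state and path):
# save_async({"iteration": 100, "weights": torch.zeros(4)}, "/tmp/ckpt_100.pt")
# wait_all()

Depending on the multiprocessing start method, the child either shares the parent's memory copy-on-write (fork) or receives a pickled snapshot of the arguments (spawn); either way the disk write no longer blocks the training loop.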
6 changes: 6 additions & 0 deletions megatron/training.py
@@ -29,6 +29,7 @@
# from megatron import print_rank_last
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
from megatron.checkpointing import wait_checkpoint
from megatron.model import Float16Module
from megatron.model import GPTModel
from megatron.core.enums import ModelType
@@ -332,6 +333,7 @@ def pretrain(
else:
log.info("skipping training (--skip-train is on) ...")
iteration = args.iteration

config = core_transformer_config_from_args(args)
if args.do_valid:
prefix = f"iteration {iteration} on {args.eval_iters * args.global_batch_size}-sample draw from validation set"
@@ -360,6 +362,7 @@ def pretrain(
write_to_tensorboard=not args.skip_train,
test=True,
)
wait_checkpoint()
return model


@@ -1797,6 +1800,7 @@ def train(
iteration, model, optimizer, opt_param_scheduler
)
print_datetime("exiting program after receiving SIGTERM.")
wait_checkpoint()
sys.exit()
if args.save and args.save_interval and iteration % args.save_interval == 0:
save_checkpoint_and_time(iteration, model, optimizer, opt_param_scheduler)
@@ -1815,13 +1819,15 @@
iteration, model, optimizer, opt_param_scheduler
)
print_datetime("exiting program after {} minutes".format(train_time))
wait_checkpoint()
sys.exit()
# Exiting based on iterations
if args.exit_interval and iteration % args.exit_interval == 0:
if args.save and not saved_checkpoint:
save_checkpoint_and_time(
iteration, model, optimizer, opt_param_scheduler
)
wait_checkpoint()
torch.distributed.barrier()
print_datetime("exiting program at iteration {}".format(iteration))
sys.exit()
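For context, a sketch of the intended call pattern from the training loop when --num-checkpoint-workers > 0. The driver function and its interval values are hypothetical; only save_checkpoint and wait_checkpoint come from this PR.

from megatron.checkpointing import save_checkpoint, wait_checkpoint

def run_training(model, optimizer, opt_param_scheduler, train_iters=1000, save_interval=100):
    for iteration in range(1, train_iters + 1):
        # ... forward/backward/optimizer step elided ...
        if iteration % save_interval == 0:
            # Returns quickly: the actual write happens in a background worker process.
            save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
    # Drain any checkpoint still being written before the program exits,
    # mirroring the wait_checkpoint() calls added before sys.exit() above.
    wait_checkpoint()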