From e83538fb3bf097766e7190b9199374a5b2a46019 Mon Sep 17 00:00:00 2001 From: Dheevatsa Date: Mon, 1 Jul 2019 15:05:35 -0700 Subject: [PATCH] added model saving, loading and checkpointing support to PyTorch (#8) * added model saving, loading and checkpointing support to PyTorch * minor fixes; updated README --- README.md | 12 +++ data_utils.py | 4 +- dlrm_data_caffe2.py | 182 ++++++++++++++++++++-------------------- dlrm_data_pytorch.py | 195 +++++++++++++++++++++---------------------- dlrm_s_caffe2.py | 21 +++-- dlrm_s_pytorch.py | 166 +++++++++++++++++++++++++++--------- 6 files changed, 340 insertions(+), 240 deletions(-) diff --git a/README.md b/README.md index d6984569..c96890d2 100644 --- a/README.md +++ b/README.md @@ -249,6 +249,18 @@ Benchmarking *NOTE: Benchmarking scripts accept extra arguments which will passed along, such as --num-batches=100 to limit the number of data samples* +Model saving/checkpointing and using saved models +------------------------------------------------- +During training, the model can be saved using --save-model= + +The model is saved every time the testing is done (as specified by --test-freq) and only if there is an improvement in test accuracy. + +A previously saved model can be loaded using --load-model= + +Once loaded the model can be used to continue training, with the saved model being a checkpoint. +Alternatively, the saved model can be used to evaluate only on the test data-set by specifying --inference-only option. + + Version ------- 0.1 : Initial release of the DLRM code diff --git a/data_utils.py b/data_utils.py index 1f316331..d9752b7f 100644 --- a/data_utils.py +++ b/data_utils.py @@ -81,9 +81,7 @@ def processKaggleCriteoAdData(split, d_path): break # process data if not all files exist - if idx < split + 1: - - # process data + if idx <= split: for i in range(1, split + 1): with np.load(str(d_path) + "kaggle_day_{0}.npz".format(i)) as data: diff --git a/dlrm_data_caffe2.py b/dlrm_data_caffe2.py index 20704ca3..94496add 100644 --- a/dlrm_data_caffe2.py +++ b/dlrm_data_caffe2.py @@ -41,6 +41,7 @@ def read_dataset( split=True, raw_data="", processed_data="", + inference_only=False, ): # load print("Loading %s dataset..." 
% dataset) @@ -68,63 +69,66 @@ def read_dataset( print("Sparse features = %d, Dense features = %d" % (n_emb, m_den)) # adjust parameters - lX = [] - lS = [] - lS_lengths = [] - lS_indices = [] - lT = [] - train_nsamples = len(y_train) - data_size = train_nsamples - nbatches = int(np.floor((data_size * 1.0) / mini_batch_size)) - print("Training data") - if num_batches != 0 and num_batches < nbatches: - print( - "Limiting to %d batches of the total % d batches" % (num_batches, nbatches) - ) - nbatches = num_batches - else: - print("Total number of batches %d" % nbatches) - - # training data main loop - for j in range(0, nbatches): - # number of data points in a batch - print("Reading in batch: %d / %d" % (j + 1, nbatches), end="\r") - n = min(mini_batch_size, data_size - (j * mini_batch_size)) - # dense feature - idx_start = j * mini_batch_size - # WARNING: X_int_train is a PyTorch tensor - lX.append((X_int_train[idx_start : (idx_start + n)]).numpy().astype(np.float32)) - # Training targets - outputs - # WARNING: y_train is a PyTorch tensor - lT.append( - (y_train[idx_start : idx_start + n]) - .numpy() - .reshape(-1, 1) - .astype(np.int32) - ) - # sparse feature (sparse indices) - lS_emb_indices = [] - # for each embedding generate a list of n lookups, - # where each lookup is composed of multiple sparse indices - for size in range(n_emb): - lS_batch_indices = [] - for _b in range(n): - # num of sparse indices to be used per embedding, e.g. for - # store lengths and indices - lS_batch_indices += ( - (X_cat_train[idx_start + _b][size].view(-1)) - .numpy() - .astype(np.int32) - ).tolist() - lS_emb_indices.append(lS_batch_indices) - lS_indices.append(lS_emb_indices) - # Criteo Kaggle data it is 1 because data is categorical - lS_lengths.append([(list(np.ones(n).astype(np.int32))) for _ in range(n_emb)]) - - lS = lS_indices.copy() + if not inference_only: + lX = [] + lS_lengths = [] + lS_indices = [] + lT = [] + train_nsamples = len(y_train) + data_size = train_nsamples + nbatches = int(np.floor((data_size * 1.0) / mini_batch_size)) + print("Training data") + if num_batches != 0 and num_batches < nbatches: + print( + "Limiting to %d batches of the total % d batches" + % (num_batches, nbatches) + ) + nbatches = num_batches + else: + print("Total number of batches %d" % nbatches) + + # training data main loop + for j in range(0, nbatches): + # number of data points in a batch + print("Reading in batch: %d / %d" % (j + 1, nbatches), end="\r") + n = min(mini_batch_size, data_size - (j * mini_batch_size)) + # dense feature + idx_start = j * mini_batch_size + # WARNING: X_int_train is a PyTorch tensor + lX.append( + (X_int_train[idx_start : (idx_start + n)]).numpy().astype(np.float32) + ) + # Training targets - outputs + # WARNING: y_train is a PyTorch tensor + lT.append( + (y_train[idx_start : idx_start + n]) + .numpy() + .reshape(-1, 1) + .astype(np.int32) + ) + # sparse feature (sparse indices) + lS_emb_indices = [] + # for each embedding generate a list of n lookups, + # where each lookup is composed of multiple sparse indices + for size in range(n_emb): + lS_batch_indices = [] + for _b in range(n): + # num of sparse indices to be used per embedding, e.g. 
for + # store lengths and indices + lS_batch_indices += ( + (X_cat_train[idx_start + _b][size].view(-1)) + .numpy() + .astype(np.int32) + ).tolist() + lS_emb_indices.append(lS_batch_indices) + lS_indices.append(lS_emb_indices) + # Criteo Kaggle data it is 1 because data is categorical + lS_lengths.append( + [(list(np.ones(n).astype(np.int32))) for _ in range(n_emb)] + ) + print("\n") # adjust parameters - print("\n") lX_test = [] lS_lengths_test = [] lS_indices_test = [] @@ -154,10 +158,7 @@ def read_dataset( # Training targets - outputs # WARNING: y_train is a PyTorch tensor lT.append( - (y_test[idx_start : idx_start + n]) - .numpy() - .reshape(-1, 1) - .astype(np.int32) + (y_test[idx_start : idx_start + n]).numpy().reshape(-1, 1).astype(np.int32) ) # sparse feature (sparse indices) lS_emb_indices = [] @@ -178,21 +179,36 @@ def read_dataset( [(list(np.ones(n).astype(np.int32))) for _ in range(n_emb)] ) - return ( - nbatches, - lX, - lS, - lS_lengths, - lS_indices, - lT, - nbatches_test, - lX_test, - lS_lengths_test, - lS_indices_test, - lT_test, - ln_emb, - m_den, - ) + if not inference_only: + return ( + nbatches, + lX, + lS_lengths, + lS_indices, + lT, + nbatches_test, + lX_test, + lS_lengths_test, + lS_indices_test, + lT_test, + ln_emb, + m_den, + ) + else: + return ( + nbatches_test, + lX_test, + lS_lengths_test, + lS_indices_test, + lT_test, + None, + None, + None, + None, + None, + ln_emb, + m_den, + ) # uniform ditribution (input data) @@ -214,7 +230,6 @@ def generate_random_input_data( # inputs and targets lX = [] - lS = [] lS_lengths = [] lS_indices = [] for j in range(0, nbatches): @@ -224,13 +239,11 @@ def generate_random_input_data( Xt = ra.rand(n, m_den).astype(np.float32) lX.append(Xt) # sparse feature (sparse indices) - lS_emb = [] lS_emb_lengths = [] lS_emb_indices = [] # for each embedding generate a list of n lookups, # where each lookup is composed of multiple sparse indices for size in ln_emb: - lS_batch = [] lS_batch_lengths = [] lS_batch_indices = [] for _ in range(n): @@ -249,17 +262,14 @@ def generate_random_input_data( # reset sparse_group_size in case some index duplicates were removed sparse_group_size = np.int32(sparse_group.size) # store lengths and indices - lS_batch.append(sparse_group.tolist()) lS_batch_lengths += [sparse_group_size] lS_batch_indices += sparse_group.tolist() - lS_emb.append(lS_batch) lS_emb_lengths.append(lS_batch_lengths) lS_emb_indices.append(lS_batch_indices) - lS.append(lS_emb) lS_lengths.append(lS_emb_lengths) lS_indices.append(lS_emb_indices) - return (nbatches, lX, lS, lS_lengths, lS_indices) + return (nbatches, lX, lS_lengths, lS_indices) # uniform distribution (output data) @@ -307,7 +317,6 @@ def generate_synthetic_input_data( # inputs and targets lX = [] - lS = [] lS_lengths = [] lS_indices = [] for j in range(0, nbatches): @@ -317,13 +326,11 @@ def generate_synthetic_input_data( Xt = ra.rand(n, m_den).astype(np.float32) lX.append(Xt) # sparse feature (sparse indices) - lS_emb = [] lS_emb_lengths = [] lS_emb_indices = [] # for each embedding generate a list of n lookups, # where each lookup is composed of multiple sparse indices for i, size in enumerate(ln_emb): - lS_batch = [] lS_batch_lengths = [] lS_batch_indices = [] for _ in range(n): @@ -369,17 +376,14 @@ def generate_synthetic_input_data( # reset sparse_group_size in case some index duplicates were removed sparse_group_size = np.int32(sparse_group.size) # store lengths and indices - lS_batch.append(sparse_group.tolist()) lS_batch_lengths += [sparse_group_size] 
lS_batch_indices += sparse_group.tolist() - lS_emb.append(lS_batch) lS_emb_lengths.append(lS_batch_lengths) lS_emb_indices.append(lS_batch_indices) - lS.append(lS_emb) lS_lengths.append(lS_emb_lengths) lS_indices.append(lS_emb_indices) - return (nbatches, lX, lS, lS_lengths, lS_indices) + return (nbatches, lX, lS_lengths, lS_indices) def generate_stack_distance(cumm_val, cumm_dist, max_i, i, enable_padding=False): @@ -467,7 +471,7 @@ def trace_profile(trace, enable_padding=False): try: # found # i = rstack.index(r) # WARNING: I believe below is the correct depth in terms of meaning of the - # algorithm, but that's not what seems to be in the paper alg. + # algorithm, but that is not what seems to be in the paper alg. # -1 can be subtracted if we defined the distance between # consecutive accesses (e.g. r, r) as 0 rather than 1. sd = l - i # - 1 diff --git a/dlrm_data_pytorch.py b/dlrm_data_pytorch.py index 945a9ea4..5b28c28b 100644 --- a/dlrm_data_pytorch.py +++ b/dlrm_data_pytorch.py @@ -44,6 +44,7 @@ def read_dataset( split=True, raw_data="", processed_data="", + inference_only=False, ): # load print("Loading %s dataset..." % dataset) @@ -71,75 +72,70 @@ def read_dataset( print("Sparse features = %d, Dense features = %d" % (n_emb, m_den)) # adjust parameters - lX = [] - lS = [] - lS_offsets = [] - lS_indices = [] - lT = [] - train_nsamples = len(y_train) - data_size = train_nsamples - nbatches = int(np.floor((data_size * 1.0) / mini_batch_size)) - print("Training data") - if num_batches != 0 and num_batches < nbatches: - print( - "Limiting to %d batches of the total % d batches" % (num_batches, nbatches) - ) - nbatches = num_batches - else: - print("Total number of batches %d" % nbatches) - - # training data main loop - for j in range(0, nbatches): - # number of data points in a batch - print("Reading in batch: %d / %d" % (j + 1, nbatches), end="\r") - n = min(mini_batch_size, data_size - (j * mini_batch_size)) - # dense feature - idx_start = j * mini_batch_size - # WARNING: X_int_train is a PyTorch tensor - lX.append( - torch.tensor( - (X_int_train[idx_start : (idx_start + n)]).numpy().astype(np.float32) + if not inference_only: + lX = [] + lS_offsets = [] + lS_indices = [] + lT = [] + train_nsamples = len(y_train) + data_size = train_nsamples + nbatches = int(np.floor((data_size * 1.0) / mini_batch_size)) + print("Training data") + if num_batches != 0 and num_batches < nbatches: + print( + "Limiting to %d batches of the total % d batches" + % (num_batches, nbatches) ) - ) - # Training targets - ouptuts - # WARNING: y_train is a PyTorch tensor - lT.append( - torch.tensor( - (y_train[idx_start : idx_start + n]) - .numpy() - .reshape(-1, 1) - .astype(np.float32) + nbatches = num_batches + else: + print("Total number of batches %d" % nbatches) + + # training data main loop + for j in range(0, nbatches): + # number of data points in a batch + print("Reading in batch: %d / %d" % (j + 1, nbatches), end="\r") + n = min(mini_batch_size, data_size - (j * mini_batch_size)) + # dense feature + idx_start = j * mini_batch_size + # WARNING: X_int_train is a PyTorch tensor + lX.append( + torch.tensor( + (X_int_train[idx_start : (idx_start + n)]) + .numpy() + .astype(np.float32) + ) ) - ) - # sparse feature (sparse indices) - lS_emb = [] - lS_emb_indices = [] - # for each embedding generate a list of n lookups, - # where each lookup is composed of multiple sparse indices - for size in range(n_emb): - lS_batch_indices = [] - for _b in range(n): - # WARNING: X_cat_train is a PyTorch tensor - # store 
lengths and indices - lS_batch_indices += ( - (X_cat_train[idx_start + _b][size].view(-1)) + # Training targets - ouptuts + # WARNING: y_train is a PyTorch tensor + lT.append( + torch.tensor( + (y_train[idx_start : idx_start + n]) .numpy() - .astype(np.int64) - ).tolist() - lS_emb_indices.append(torch.tensor(lS_batch_indices)) - lS_emb.append(lS_batch_indices) - lS_indices.append(lS_emb_indices) - lS.append(lS_emb) - # Criteo Kaggle data it is 1 because data is categorical - lS_offsets.append( - [ - torch.tensor(list(range(n))) - for _ in range(n_emb) - ] - ) + .reshape(-1, 1) + .astype(np.float32) + ) + ) + # sparse feature (sparse indices) + lS_emb_indices = [] + # for each embedding generate a list of n lookups, + # where each lookup is composed of multiple sparse indices + for size in range(n_emb): + lS_batch_indices = [] + for _b in range(n): + # WARNING: X_cat_train is a PyTorch tensor + # store lengths and indices + lS_batch_indices += ( + (X_cat_train[idx_start + _b][size].view(-1)) + .numpy() + .astype(np.int64) + ).tolist() + lS_emb_indices.append(torch.tensor(lS_batch_indices)) + lS_indices.append(lS_emb_indices) + # Criteo Kaggle data it is 1 because data is categorical + lS_offsets.append([torch.tensor(list(range(n))) for _ in range(n_emb)]) + print("\n") # adjust parameters - print("\n") lX_test = [] lS_offsets_test = [] lS_indices_test = [] @@ -195,28 +191,39 @@ def read_dataset( lS_emb_indices.append(torch.tensor(lS_batch_indices)) lS_indices_test.append(lS_emb_indices) # Criteo Kaggle data it is 1 because data is categorical - lS_offsets_test.append( - [ - torch.tensor(list(range(n))) - for _ in range(n_emb) - ] - ) + lS_offsets_test.append([torch.tensor(list(range(n))) for _ in range(n_emb)]) + print("\n") - return ( - nbatches, - lX, - lS, - lS_offsets, - lS_indices, - lT, - nbatches_test, - lX_test, - lS_offsets_test, - lS_indices_test, - lT_test, - ln_emb, - m_den, - ) + if not inference_only: + return ( + nbatches, + lX, + lS_offsets, + lS_indices, + lT, + nbatches_test, + lX_test, + lS_offsets_test, + lS_indices_test, + lT_test, + ln_emb, + m_den, + ) + else: + return ( + nbatches_test, + lX_test, + lS_offsets_test, + lS_indices_test, + lT_test, + None, + None, + None, + None, + None, + ln_emb, + m_den, + ) # uniform ditribution (input data) @@ -238,7 +245,6 @@ def generate_random_input_data( # inputs lX = [] - lS = [] lS_offsets = [] lS_indices = [] for j in range(0, nbatches): @@ -248,13 +254,11 @@ def generate_random_input_data( Xt = ra.rand(n, m_den).astype(np.float32) lX.append(torch.tensor(Xt)) # sparse feature (sparse indices) - lS_emb = [] lS_emb_offsets = [] lS_emb_indices = [] # for each embedding generate a list of n lookups, # where each lookup is composed of multiple sparse indices for size in ln_emb: - lS_batch = [] lS_batch_offsets = [] lS_batch_indices = [] offset = 0 @@ -274,19 +278,16 @@ def generate_random_input_data( # reset sparse_group_size in case some index duplicates were removed sparse_group_size = np.int64(sparse_group.size) # store lengths and indices - lS_batch.append(sparse_group.tolist()) lS_batch_offsets += [offset] lS_batch_indices += sparse_group.tolist() # update offset for next iteration offset += sparse_group_size - lS_emb.append(lS_batch) lS_emb_offsets.append(torch.tensor(lS_batch_offsets)) lS_emb_indices.append(torch.tensor(lS_batch_indices)) - lS.append(lS_emb) lS_offsets.append(lS_emb_offsets) lS_indices.append(lS_emb_indices) - return (nbatches, lX, lS, lS_offsets, lS_indices) + return (nbatches, lX, lS_offsets, lS_indices) 
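# Note (a minimal sketch, not from this patch, with made-up sizes): the
# offsets/indices pairs built above are the layout torch.nn.EmbeddingBag
# expects, i.e. one flat 1-D tensor of indices plus a per-bag offsets tensor.
# The nested per-lookup list lS duplicated this information, hence its removal
# from the return tuples.
#
#     emb = torch.nn.EmbeddingBag(10, 4, mode="sum")   # 10 rows, embedding dim 4
#     indices = torch.tensor([1, 3, 5, 2])             # two lookups: [1, 3, 5] and [2]
#     offsets = torch.tensor([0, 3])                    # start position of each lookup
#     emb(indices, offsets)                             # -> tensor of shape (2, 4)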
# uniform distribution (output data) @@ -334,7 +335,6 @@ def generate_synthetic_input_data( # inputs and targets lX = [] - lS = [] lS_offsets = [] lS_indices = [] for j in range(0, nbatches): @@ -344,13 +344,11 @@ def generate_synthetic_input_data( Xt = ra.rand(n, m_den).astype(np.float32) lX.append(torch.tensor(Xt)) # sparse feature (sparse indices) - lS_emb = [] lS_emb_offsets = [] lS_emb_indices = [] # for each embedding generate a list of n lookups, # where each lookup is composed of multiple sparse indices for i, size in enumerate(ln_emb): - lS_batch = [] lS_batch_offsets = [] lS_batch_indices = [] offset = 0 @@ -397,19 +395,16 @@ def generate_synthetic_input_data( # reset sparse_group_size in case some index duplicates were removed sparse_group_size = np.int64(sparse_group.size) # store lengths and indices - lS_batch.append(sparse_group.tolist()) lS_batch_offsets += [offset] lS_batch_indices += sparse_group.tolist() # update offset for next iteration offset += sparse_group_size - lS_emb.append(lS_batch) lS_emb_offsets.append(torch.tensor(lS_batch_offsets)) lS_emb_indices.append(torch.tensor(lS_batch_indices)) - lS.append(lS_emb) lS_offsets.append(lS_emb_offsets) lS_indices.append(lS_emb_indices) - return (nbatches, lX, lS, lS_offsets, lS_indices) + return (nbatches, lX, lS_offsets, lS_indices) def generate_stack_distance(cumm_val, cumm_dist, max_i, i, enable_padding=False): @@ -498,7 +493,7 @@ def trace_profile(trace, enable_padding=False): try: # found # i = rstack.index(r) # WARNING: I believe below is the correct depth in terms of meaning of the - # algorithm, but that"s not what seems to be in the paper alg. + # algorithm, but that is not what seems to be in the paper alg. # -1 can be subtracted if we defined the distance between # consecutive accesses (e.g. r, r) as 0 rather than 1. 
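# For example, under this convention the reference stream a, b, c, a gives the
# second access to a a stack distance of 3 (a, b and c are on the stack), while
# an immediate repeat r, r gives 1.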
sd = l - i # - 1 diff --git a/dlrm_s_caffe2.py b/dlrm_s_caffe2.py index e17a8aa6..11a0bb4b 100644 --- a/dlrm_s_caffe2.py +++ b/dlrm_s_caffe2.py @@ -839,7 +839,7 @@ def print_activations(self): ln_bot = np.fromstring(args.arch_mlp_bot, dtype=int, sep="-") if args.data_generation == "dataset": # input and target data - (nbatches, lX, lS, lS_l, lS_i, lT, + (nbatches, lX, lS_l, lS_i, lT, nbatches_test, lX_test, lS_l_test, lS_i_test, lT_test, ln_emb, m_den) = dc.read_dataset( args.data_set, args.mini_batch_size, args.data_randomize, args.num_batches, @@ -850,12 +850,12 @@ def print_activations(self): ln_emb = np.fromstring(args.arch_embedding_size, dtype=int, sep="-") m_den = ln_bot[0] if args.data_generation == "random": - (nbatches, lX, lS, lS_l, lS_i) = dc.generate_random_input_data( + (nbatches, lX, lS_l, lS_i) = dc.generate_random_input_data( args.data_size, args.num_batches, args.mini_batch_size, args.round_targets, args.num_indices_per_lookup, args.num_indices_per_lookup_fixed, m_den, ln_emb) elif args.data_generation == "synthetic": - (nbatches, lX, lS, lS_l, lS_i) = dc.generate_synthetic_input_data( + (nbatches, lX, lS_l, lS_i) = dc.generate_synthetic_input_data( args.data_size, args.num_batches, args.mini_batch_size, args.round_targets, args.num_indices_per_lookup, args.num_indices_per_lookup_fixed, m_den, ln_emb, @@ -925,9 +925,8 @@ def print_activations(self): for j in range(0, nbatches): print("mini-batch: %d" % j) print(lX[j]) - print(lS[j]) - # print(lS_l[j]) - # print(lS_i[j]) + print(lS_l[j]) + print(lS_i[j]) print(lT[j]) ### construct the neural network specified above ### @@ -990,8 +989,11 @@ def print_activations(self): total_loss = 0 total_accu = 0 total_iter = 0 - for k in range(0, args.nepochs): - for j in range(0, nbatches): + k = 0 + + while k < args.nepochs: + j = 0 + while j < nbatches: # forward and backward pass, where the latter runs only # when gradients and loss have been added to the net time1 = time.time() @@ -1027,6 +1029,9 @@ def print_activations(self): # debug prints # print(Z) # print(T) + + j += 1 # nbatches + k += 1 # nepochs # test prints if not args.inference_only and args.debug_mode: diff --git a/dlrm_s_pytorch.py b/dlrm_s_pytorch.py index 002311e7..a39a3467 100644 --- a/dlrm_s_pytorch.py +++ b/dlrm_s_pytorch.py @@ -56,6 +56,7 @@ # miscellaneous import bisect import builtins +import shutil import time # data generation @@ -441,6 +442,9 @@ def parallel_forward(self, dense_x, lS_o, lS_i): parser.add_argument("--debug-mode", action="store_true", default=False) parser.add_argument("--enable-profiling", action="store_true", default=False) parser.add_argument("--plot-compute-graph", action="store_true", default=False) + + parser.add_argument("--save-model", type=str, default="") + parser.add_argument("--load-model", type=str, default="") args = parser.parse_args() ### some basic setup ### @@ -468,7 +472,6 @@ def parallel_forward(self, dense_x, lS_o, lS_i): ( nbatches, lX, - lS, lS_o, lS_i, lT, @@ -487,6 +490,7 @@ def parallel_forward(self, dense_x, lS_o, lS_i): True, args.raw_data_file, args.processed_data_file, + args.inference_only, ) ln_bot[0] = m_den else: @@ -494,7 +498,7 @@ def parallel_forward(self, dense_x, lS_o, lS_i): ln_emb = np.fromstring(args.arch_embedding_size, dtype=int, sep="-") m_den = ln_bot[0] if args.data_generation == "random": - (nbatches, lX, lS, lS_o, lS_i) = dp.generate_random_input_data( + (nbatches, lX, lS_o, lS_i) = dp.generate_random_input_data( args.data_size, args.num_batches, args.mini_batch_size, @@ -505,7 +509,7 @@ def 
parallel_forward(self, dense_x, lS_o, lS_i): ln_emb, ) elif args.data_generation == "synthetic": - (nbatches, lX, lS, lS_o, lS_i) = dp.generate_synthetic_input_data( + (nbatches, lX, lS_o, lS_i) = dp.generate_synthetic_input_data( args.data_size, args.num_batches, args.mini_batch_size, @@ -611,7 +615,16 @@ def parallel_forward(self, dense_x, lS_o, lS_i): for j in range(0, nbatches): print("mini-batch: %d" % j) print(lX[j].detach().cpu().numpy()) - print(lS[j]) + # transform offsets to lengths when printing + print( + [ + np.diff( + S_o.detach().cpu().tolist() + list(lS_i[j][i].shape) + ).tolist() + for i, S_o in enumerate(lS_o[j]) + ] + ) + print([S_i.detach().cpu().tolist() for S_i in lS_i[j]]) print(lT[j].detach().cpu().numpy()) ### construct the neural network specified above ### @@ -642,18 +655,15 @@ def parallel_forward(self, dense_x, lS_o, lS_i): dlrm.ndevices = min(ngpus, args.mini_batch_size, num_fea - 1) dlrm = dlrm.to(device) # .cuda() - # add training loss if needed - if not args.inference_only: - # specify the loss function - if args.loss_function == "mse": - loss_fn = torch.nn.MSELoss(reduction="mean") - elif args.loss_function == "bce": - loss_fn = torch.nn.BCELoss(reduction="mean") - else: - sys.exit( - "ERROR: --loss-function=" + args.loss_function + " is not supported" - ) + # specify the loss function + if args.loss_function == "mse": + loss_fn = torch.nn.MSELoss(reduction="mean") + elif args.loss_function == "bce": + loss_fn = torch.nn.BCELoss(reduction="mean") + else: + sys.exit("ERROR: --loss-function=" + args.loss_function + " is not supported") + if not args.inference_only: # specify the optimizer algorithm optimizer = torch.optim.SGD(dlrm.parameters(), lr=args.learning_rate) @@ -679,29 +689,73 @@ def loss_fn_wrap(Z, T, use_gpu, device): else: return loss_fn(Z, T) - # training + # training + + best_gA_test = 0 + total_time = 0 + total_loss = 0 + total_accu = 0 + total_iter = 0 + k = 0 + + # Load model is specified + if not (args.load_model == ""): + print("Loading saved mode {}".format(args.load_model)) + ld_model = torch.load(args.load_model) + dlrm.load_state_dict(ld_model["state_dict"]) + ld_j = ld_model["iter"] + ld_k = ld_model["epoch"] + ld_nepochs = ld_model["nepochs"] + ld_nbatches = ld_model["nbatches"] + ld_nbatches_test = ld_model["nbatches_test"] + ld_gA = ld_model["train_acc"] + ld_gL = ld_model["train_loss"] + ld_total_loss = ld_model["total_loss"] + ld_total_accu = ld_model["total_accu"] + ld_gA_test = ld_model["test_acc"] + ld_gL_test = ld_model["test_loss"] + if not args.inference_only: + optimizer.load_state_dict(ld_model["opt_state_dict"]) + best_gA_test = ld_gA_test + total_loss = ld_total_loss + total_accu = ld_total_accu + k = ld_k # epochs + j = ld_j # batches + else: + args.print_freq = ld_nbatches + args.test_freq = 0 + print( + "Saved model Training state: epoch = {:d}/{:d}, batch = {:d}/{:d}, train loss = {:.6f}, train accuracy = {:3.3f} %".format( + ld_k, ld_nepochs, ld_j, ld_nbatches, ld_gL, ld_gA * 100 + ) + ) + print( + "Saved model Testing state: nbatches = {:d}, test loss = {:.6f}, test accuracy = {:3.3f} %".format( + ld_nbatches_test, ld_gL_test, ld_gA_test * 100 + ) + ) + print("time/loss/accuracy (if enabled):") with torch.autograd.profiler.profile(args.enable_profiling, use_gpu) as prof: - total_time = 0 - total_loss = 0 - total_accu = 0 - total_iter = 0 - for k in range(0, args.nepochs): - for j in range(0, nbatches): + while k < args.nepochs: + j = 0 + while j < nbatches: t1 = time_wrap(use_gpu) + # forward pass Z = 
dlrm_wrap(lX[j], lS_o[j], lS_i[j], use_gpu, device) - if not args.inference_only: - # loss - E = loss_fn_wrap(Z, lT[j], use_gpu, device) - # compute loss and accuracy - L = E.detach().cpu().numpy() # numpy array - S = Z.detach().cpu().numpy() # numpy array - T = lT[j].detach().cpu().numpy() # numpy array - mbs = T.shape[0] # = args.mini_batch_size except maybe for last - A = np.sum((np.round(S, 0) == T).astype(np.uint8)) / mbs + # loss + E = loss_fn_wrap(Z, lT[j], use_gpu, device) + + # compute loss and accuracy + L = E.detach().cpu().numpy() # numpy array + S = Z.detach().cpu().numpy() # numpy array + T = lT[j].detach().cpu().numpy() # numpy array + mbs = T.shape[0] # = args.mini_batch_size except maybe for last + A = np.sum((np.round(S, 0) == T).astype(np.uint8)) / mbs + if not args.inference_only: # scaled error gradient propagation # (where we do not accumulate gradients across mini-batches) optimizer.zero_grad() @@ -717,13 +771,20 @@ def loss_fn_wrap(Z, T, use_gpu, device): t2 = time_wrap(use_gpu) total_time += t2 - t1 - total_accu += 0 if args.inference_only else A - total_loss += 0 if args.inference_only else L + total_accu += A + total_loss += L total_iter += 1 - # print time, loss and accuracy print_tl = ((j + 1) % args.print_freq == 0) or (j + 1 == nbatches) - if print_tl: + print_ts = ( + (args.test_freq > 0) + and (args.data_generation == "dataset") + and (((j + 1) % args.test_freq == 0) or (j + 1 == nbatches)) + or (j == 0) + ) + + # print time, loss and accuracy + if print_tl or print_ts: gT = 1000.0 * total_time / total_iter if args.print_time else -1 total_time = 0 @@ -745,9 +806,7 @@ def loss_fn_wrap(Z, T, use_gpu, device): total_iter = 0 # testing - print_ts = ((args.test_freq > 0) and - (((j + 1) % args.test_freq == 0) or (j + 1 == nbatches))) - if print_ts: + if print_ts and not args.inference_only: test_accu = 0 test_loss = 0 @@ -781,13 +840,40 @@ def loss_fn_wrap(Z, T, use_gpu, device): gL_test = test_loss / nbatches_test gA_test = test_accu / nbatches_test + is_best = gA_test > best_gA_test + if is_best: + best_gA_test = gA_test + if not (args.save_model == ""): + print("Saving model to {}".format(args.save_model)) + torch.save( + { + "epoch": k, + "nepochs": args.nepochs, + "nbatches": nbatches, + "nbatches_test": nbatches_test, + "iter": j + 1, + "state_dict": dlrm.state_dict(), + "train_acc": gA, + "train_loss": gL, + "test_acc": gA_test, + "test_loss": gL_test, + "total_loss": total_loss, + "total_accu": total_accu, + "opt_state_dict": optimizer.state_dict(), + }, + args.save_model, + ) + print( "Testing at - {}/{} of epoch {}, ".format(j + 1, nbatches, 0) - + "loss {:.6f}, accuracy {:3.3f} %".format( - gL_test, gA_test * 100 + + "loss {:.6f}, accuracy {:3.3f} %, best {:3.3f} %".format( + gL_test, gA_test * 100, best_gA_test * 100 ) ) + j += 1 # nbatches + k += 1 # nepochs + # profiling if args.enable_profiling: with open("dlrm_s_pytorch.prof", "w") as prof_f:
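For completeness, a checkpoint written via --save-model holds the dictionary assembled in the torch.save(...) call above (model and optimizer state plus training and test statistics). The following is a minimal sketch of restoring such a checkpoint outside the training script; the function name and path argument are hypothetical, and the model passed in must be constructed with the same --arch-* options that were used when the checkpoint was saved.

    import torch

    def load_dlrm_checkpoint(model, path):
        # `model` is an already-constructed DLRM instance matching the saved
        # architecture; `path` is whatever was passed to --save-model.
        # Keys mirror the dictionary written by torch.save(...) above.
        ckpt = torch.load(path, map_location="cpu")
        model.load_state_dict(ckpt["state_dict"])
        print(
            "restored epoch {}/{} at iter {}, test accuracy {:3.3f} %".format(
                ckpt["epoch"], ckpt["nepochs"], ckpt["iter"], ckpt["test_acc"] * 100
            )
        )
        return model

Within dlrm_s_pytorch.py itself the same restore path is handled by the --load-model flag, which resumes training from the checkpoint, or evaluates on the test set only when --inference-only is also given.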