Save checkpoints to dedicated directory
garrettgibo committed Mar 18, 2021
1 parent ae363aa commit 7029bcc
Showing 4 changed files with 25 additions and 13 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -2,6 +2,9 @@
 *.mp3
 *.wav
 
+# PyTorch checkpoints
+*.pt
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
7 changes: 4 additions & 3 deletions config/train.json
@@ -19,11 +19,12 @@
     },
     "train_cfg": {
         "model_name": "wavenet",
-        "num_epochs": 5,
+        "num_epochs": 100,
         "learning_rate": 0.001,
         "resume": false,
-        "checkpoint_path": null,
-        "save_every": 1
+        "checkpoint_load_path": null,
+        "checkpoint_save_path": "checkpoints",
+        "save_every": 5
     },
     "generator_cfg": {}
 }
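
These train_cfg keys line up one-to-one with the keyword arguments of train() in the next file, so the config can plausibly be expanded straight into the call. A minimal sketch under that assumption; the truncated hunk below does not show whether train() takes additional required arguments (e.g. a dataset), so treat this as illustrative only:

import json

from wavenet.utils.train import train

with open("config/train.json") as f:
    cfg = json.load(f)

# "checkpoint_save_path": "checkpoints" routes every .pt file into a
# dedicated directory, which the .gitignore change above already excludes.
train(**cfg["train_cfg"])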
17 changes: 9 additions & 8 deletions wavenet/utils/train.py
@@ -20,7 +20,8 @@ def train(
     num_epochs: int,
     learning_rate: float,
     resume: bool = False,
-    checkpoint_path: str = None,
+    checkpoint_load_path: str = None,
+    checkpoint_save_path: str = None,
     save_every: int = 1,
     log_level: int = 30,
 ):
@@ -37,7 +38,7 @@ def train(
     """
     logger = utils.new_logger("Train", level=log_level)
 
-    current_time = time.strftime("%m_%d_%y_%H_%M_%S", time.localtime())
+    current_time = time.strftime("_%m_%d_%y_%H_%M_%S", time.localtime())
     model_name = model_name + current_time
 
     criterion = nn.CrossEntropyLoss()
@@ -52,8 +53,8 @@ def train(
     # load checkpoint if needed/ wanted
     start_epoch = 0
     if resume:
-        model, optim, start_epoch = utils.load_model(model, optim, checkpoint_path)
-        logger.info("Loaded checkpoint from: %s", checkpoint_path)
+        model, optim, start_epoch = utils.load_model(model, optim, checkpoint_load_path)
+        logger.info("Loaded checkpoint from: %s", checkpoint_load_path)
 
     # Main training loop
     for epoch in range(start_epoch, start_epoch + num_epochs):
@@ -93,11 +94,11 @@ def train(
         # maybe do a test pass every N=1 epochs
         if epoch % save_every == save_every - 1:
            ckpt_name = f"{model_name}_epoch_{epoch}.pt"
-           utils.save_model(model, optim, epoch, ckpt_name)
-           logger.info("Saved model at epoch: %d", epoch)
+           utils.save_model(
+               model, optim, epoch, checkpoint_save_path, ckpt_name, logger
+           )
 
     ckpt_name = f"{model_name}_fin.pt"
-    utils.save_model(model, optim, epoch, ckpt_name)
-    logger.info("Saved Final Model")
+    utils.save_model(model, optim, epoch, checkpoint_save_path, ckpt_name, logger)
 
     return ckpt_name
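
The guard epoch % save_every == save_every - 1 fires on the last epoch of each save_every-long window, so with the new config values (num_epochs 100, save_every 5) periodic checkpoints land at epochs 4, 9, ..., 99, followed by one final "_fin.pt" save after the loop. A quick standalone check of that cadence:

# Which epochs trigger a periodic save with the new config defaults?
num_epochs = 100
save_every = 5

save_epochs = [e for e in range(num_epochs) if e % save_every == save_every - 1]
print(save_epochs[:3], "...", save_epochs[-1])  # [4, 9, 14] ... 99
assert len(save_epochs) == num_epochs // save_every  # 20 periodic saves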
11 changes: 9 additions & 2 deletions wavenet/utils/utils.py
@@ -93,22 +93,29 @@ def load_model(model, optimizer, path: str):
     return model, optimizer, epoch
 
 
-def save_model(model, optimizer, epoch: int, name: str) -> None:
+def save_model(
+    model, optimizer, epoch: int, checkpoint_dir: str, name: str, logger
+) -> None:
     """Wrapper around saving a PyTorch model.
 
     Args:
         model: The model that is saved.
         optimizer: The optimizer being used on the provided model
         epoch: Current epoch number
+        checkpoint_dir: Path to directory of checkpoints
         name: Name of checkpoint
     """
+    if not os.path.isdir(checkpoint_dir):
+        os.mkdir(checkpoint_dir)
+
     checkpoint = {
         "epoch": epoch,
         "net": model.state_dict(),
         "optim": optimizer.state_dict(),
     }
-    torch.save(checkpoint, name)
+    torch.save(checkpoint, f"{checkpoint_dir}/{name}")
+    logger.info("Saved %s", name)
 
 
 def get_device():
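One caveat with the new helper: os.mkdir creates only a single directory level, and the f-string join hard-codes a POSIX separator. A sketch of a slightly more defensive variant; this is not what the commit ships, just one way it could be hardened:

import os

import torch


def save_model(model, optimizer, epoch: int, checkpoint_dir: str, name: str, logger) -> None:
    # makedirs handles nested paths such as "runs/wavenet/checkpoints",
    # and exist_ok=True removes the isdir check (and its check-then-create race).
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint = {
        "epoch": epoch,
        "net": model.state_dict(),
        "optim": optimizer.state_dict(),
    }
    # os.path.join picks the correct separator on every platform
    torch.save(checkpoint, os.path.join(checkpoint_dir, name))
    logger.info("Saved %s", name)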
