Skip to content

Commit

Permalink
added model saving
Browse files Browse the repository at this point in the history
  • Loading branch information
BaderTim committed Jan 24, 2024
1 parent 9f2cfa4 commit 5eb70b3
Showing 1 changed file with 15 additions and 14 deletions.
29 changes: 15 additions & 14 deletions baselines/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,28 +61,25 @@ def __init__(self, patience=5):
"""
self.patience = patience
self.counter = 0
self.max_val_pos_mse = np.Inf
self.max_val_rot_mse = np.Inf
self.val_mse = np.Inf

def __call__(self, val_pos_mse, val_rot_mse):
def __call__(self, val_mse):
"""
Call the early stopping function.
Parameters:
- val_pos_mse: float, validation position MSE.
- val_rot_mse: float, validation rotation MSE.
- val_mse: float, validation MSE (mean of the position and rotation MSE).
Returns:
- early_stop: bool, whether to stop the training or not.
"""
if val_pos_mse >= self.max_val_pos_mse and val_rot_mse >= self.max_val_rot_mse:
if val_mse >= self.val_mse:
self.counter += 1
if self.counter >= self.patience:
return True
else:
self.counter = 0
self.max_val_pos_mse = val_pos_mse
self.max_val_rot_mse = val_rot_mse
self.val_mse = val_mse
return False


Expand Down Expand Up @@ -210,6 +207,7 @@ def train_model():
early_stopping = EarlyStopping(patience=10)

# loop over epochs
lowest_val_mse = np.Inf
for epoch in range(epochs):

# for graph models, we need to unpack one more element from the dataloader
Expand Down Expand Up @@ -296,12 +294,15 @@ def train_model():
"val_pos_mse": avg_pos_mse_val,
"val_rot_mse": avg_rot_mse_val
})
if early_stopping(avg_pos_mse_val, avg_rot_mse_val):
avg_val_mse = (avg_pos_mse_val + avg_rot_mse_val)/2
# save model only when the averaged validation MSE reaches a new minimum
if avg_val_mse < lowest_val_mse:
torch.save(model.state_dict(), f"models/{wandb.run.name}.pt")
lowest_val_mse = avg_val_mse
if early_stopping(avg_val_mse):
log.info("Early stopping")
break
scheduler.step((avg_pos_mse_val + avg_rot_mse_val) / 2)
# save model
torch.save(model.state_dict(), f"models/{wandb.run.name}.pt")
scheduler.step(avg_val_mse)


if __name__ == "__main__":
Expand All @@ -320,13 +321,13 @@ def train_model():
"metric": {"name": "val_mse", "goal": "maximize"},
"method": "grid",
"parameters": {
"model_name": {"values": ["ASTGCN", "TSViTcls", "CNN", "NTU"]},
"model_name": {"values": ["TSViTcls", "CNN", "NTU", "ASTGCN"]},
"graph_dataset_path": {"values": [args.graph_dataset_path]},
"image_dataset_path": {"values": [args.image_dataset_path]},
"point_dataset_path": {"values": [args.point_dataset_path]},
"epochs": {"values": [30]},
"batch_size": {"values": [8]},
"K": {"values": [2, 4, 8, 16]}
"K": {"values": [16, 8, 4, 2]}
}
}
sweep_id = wandb.sweep(sweep_configuration)
Expand Down

0 comments on commit 5eb70b3

Please sign in to comment.