Skip to content

Commit

Permalink
save when training complete
Browse files Browse the repository at this point in the history
  • Loading branch information
ritaank committed May 10, 2022
1 parent d5264f7 commit 5b84e84
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 6 deletions.
16 changes: 12 additions & 4 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,24 @@
*.rar
*.tar
*.zip
*.csv
*.gpickle
*.graphml
*.dat

# Logs and databases #
######################
*.log
*.sql
*.sqlite
*.csv
*.gpickle
*.graphml

# Models #
######################
*.pt

# Temporary Files #
######################
*.pyc
*.dat

# OS generated files #
######################
Expand Down
16 changes: 14 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import sys
import warnings
import time
import datetime
from pprint import pprint

import matplotlib
Expand Down Expand Up @@ -137,7 +138,7 @@ def train(args, env, memory, policy_net, target_net, optimizer):
reward = torch.tensor([reward], device=device)

# Store the transition in memory

memory.push(state, action, next_state, reward, goal_state_embedding.detach())

# Move to the next state
Expand All @@ -155,8 +156,19 @@ def train(args, env, memory, policy_net, target_net, optimizer):
target_net.load_state_dict(policy_net.state_dict())

print('Training Complete')
env.render()
# env.render()

save_dict = {'state_dict': target_net.state_dict(), 'args': args}
dest_path = f"models/{args.wiki_year}_fixednode-{args.has_fixed_dest_node}_{datetime.datetime.now().strftime('%Y_%m_%d-%I:%M:%S_%p')}.pt"
torch.save(save_dict, dest_path)
print('Model saved to location ', dest_path)

env.close()
# torch.save(model.state_dict(), filepath)

# #Later to restore:
# model.load_state_dict(torch.load(filepath))
# model.eval()

def plot_durations(episode_durations):
plt.figure(2)
Expand Down

0 comments on commit 5b84e84

Please sign in to comment.