Commit: fixes
diego-escobedo committed May 9, 2022
1 parent 3f51ac8 commit 43d6926
Showing 7 changed files with 45 additions and 29 deletions.
Binary file modified __pycache__/parse.cpython-38.pyc
Binary file modified __pycache__/replay_utils.cpython-38.pyc
Binary file modified gymEnv/wikiGame/envs/__pycache__/wikiGame.cpython-38.pyc
18 changes: 11 additions & 7 deletions gymEnv/wikiGame/envs/wikiGame.py
@@ -5,11 +5,14 @@
import numpy as np
import pandas as pd
import networkx as nx
import pickle

def create_wiki_graph(graph_source):
df = pd.read_csv(graph_source, sep='\t', header=0)
df = df[df['page_id_from'] != df['page_id_to']]
g = nx.from_pandas_edgelist(df, source='page_title_from', target='page_title_to', create_using=nx.DiGraph)
remove = (node for node in list(g) if g.out_degree(node) == 0)
g.remove_nodes_from(remove)
return g

class wikiGame(gym.Env):
@@ -22,15 +25,17 @@ class wikiGame(gym.Env):
"""
metadata = {'render.modes': ['human', 'graph', 'interactive']}

def __init__(self, has_fixed_dest_node=False, fixed_dest_node='Statistical Theory', wiki_year=2006):
graph_file_txt = f"gymEnv/wikiGame/envs/wikiGraph_{wiki_year}.xml.gz"
def __init__(self, has_fixed_dest_node=False, fixed_dest_node='Massachusetts Institute of Technology', wiki_year=2006):
graph_file_txt = f"gymEnv/wikiGame/envs/wikiGraph_{wiki_year}.gpickle"
graph_file = Path(graph_file_txt)
if graph_file.is_file():
self.graph = nx.read_graphml(graph_file_txt)
print("loading graph file")
self.graph = nx.read_gpickle(graph_file_txt)
else:
print("creating graph file")
graph_source_text = f"gymEnv/wikiGame/envs/enwiki.wikilink_graph.{wiki_year}-03-01.csv.gz"
self.graph = create_wiki_graph(graph_source_text)
nx.write_graphml(self.graph, graph_file_txt)
nx.write_gpickle(self.graph, graph_file_txt)

self.current_vertex, self.goal_vertex = None, None
self.has_fixed_dest_node = has_fixed_dest_node
@@ -56,13 +61,12 @@ def step(self, action):
if self.goal_vertex == self.current_vertex:
reward = 1
done = 1
return None, reward, done, {"next_neighbors": self.graph.successors(self.current_vertex)} #no observations, this is an MDP not POMDP
return None, reward, done, {} #no observations, this is an MDP not POMDP

def reset(self):
self.goal_vertex = self.fixed_dest_node if self.has_fixed_dest_node else np.random.choice(self.graph.nodes(), 1)[0]
self.current_vertex = np.random.choice(self.graph.nodes(), 1)[0]
while self.current_vertex == self.goal_vertex:
self.current_vertex = np.random.choice(self.graph.nodes(), 1)[0]
return self.current_vertex, \
self.goal_vertex, \
self.graph.successors(self.current_vertex)
self.goal_vertex
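
For context, a minimal sketch of the caching pattern this commit switches wikiGame.__init__ to: build the wikilink graph once from the CSV dump, prune pages with no outgoing links, and cache the result as a gpickle so later runs skip the expensive construction. load_or_build_graph is a hypothetical helper that mirrors the diff and assumes networkx 2.x, where read_gpickle/write_gpickle are available.

from pathlib import Path

import networkx as nx
import pandas as pd

def load_or_build_graph(year=2006):
    cache = Path(f"gymEnv/wikiGame/envs/wikiGraph_{year}.gpickle")
    if cache.is_file():
        return nx.read_gpickle(cache)  # fast path: reuse the cached graph
    source = f"gymEnv/wikiGame/envs/enwiki.wikilink_graph.{year}-03-01.csv.gz"
    df = pd.read_csv(source, sep="\t", header=0)
    df = df[df["page_id_from"] != df["page_id_to"]]  # drop self-links
    g = nx.from_pandas_edgelist(df, source="page_title_from",
                                target="page_title_to", create_using=nx.DiGraph)
    g.remove_nodes_from([n for n in list(g) if g.out_degree(n) == 0])  # drop dead-end pages
    nx.write_gpickle(g, cache)  # cache for subsequent runs
    return g
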
48 changes: 30 additions & 18 deletions main.py
@@ -72,7 +72,7 @@ def evaluate_expected_rewards(policy_net, state, goal_state_embedding, possible_
rewards[i] = x
return rewards

def select_action(policy_net, state, goal_state_embedding, possible_actions, eps_threshold):
def select_action(policy_net, state, goal_state_embedding, eps_threshold, possible_actions):
sample = random.random()
if sample > eps_threshold:
with torch.no_grad():
@@ -81,23 +81,29 @@ def select_action(policy_net, state, goal_state_embedding, possible_actions, eps
out = possible_actions[max_reward_ix]
return out
else:
out = np.random.choice(list(possible_actions), 1)
out = np.random.choice(list(possible_actions), 1)[0]
return out

def optimize_model(args, memory, policy_net, target_net, optimizer):
def optimize_model(env, args, memory, policy_net, target_net, optimizer):
if len(memory) < args.batch_size:
return
transitions = memory.sample(args.batch_size)

losses = []
loss_fn = nn.SmoothL1Loss()
for state, cur_possible_actions, _, next_state, next_possible_actions, reward, goal_state_embedding in transitions:
cur_reward_vector, _ = evaluate_expected_rewards(policy_net, state, goal_state_embedding, cur_possible_actions)
expected_reward_vector, _ = evaluate_expected_rewards(target_net, next_state, goal_state_embedding, next_possible_actions)
for state, _, next_state, reward, goal_state_embedding in transitions:
cur_possible_actions = list(env.graph.successors(state))
cur_reward_vector = evaluate_expected_rewards(policy_net, state, goal_state_embedding, cur_possible_actions)
next_possible_actions = list(env.graph.successors(next_state))
expected_reward_vector = evaluate_expected_rewards(target_net, next_state, goal_state_embedding, next_possible_actions)
try:
future_val = reward + args.gamma * expected_reward_vector.max()
temporal_diff = loss_fn(cur_reward_vector.max(), future_val)
losses.append(temporal_diff)
except:
print(state, type(cur_possible_actions), len(cur_possible_actions),next_state, type(next_possible_actions), len(next_possible_actions))
continue

future_val = reward + args.gamma * expected_reward_vector.max()
temporal_diff = loss_fn(cur_reward_vector.max(), future_val)
losses.append(temporal_diff)
loss = sum(losses)

# Optimize the model
@@ -112,31 +118,37 @@ def train(args, env, memory, policy_net, target_net, optimizer):
steps_done = 0
for i_episode in tqdm(range(args.num_episodes)):
# Initialize the environment and state
state, goal_state, possible_actions = env.reset()
state, goal_state = env.reset()
goal_state_embedding = get_neural_embedding(goal_state)
for t in tqdm(range(args.max_ep_length), position=0, leave=True):
# Select and perform an action
eps_threshold = args.eps_end + (args.eps_start - args.eps_end) * math.exp(-1. * steps_done / args.eps_decay)
action = select_action(policy_net, state, goal_state_embedding, possible_actions, eps_threshold)
possible_actions = list(env.graph.successors(state))

if sum(1 for _ in possible_actions) == 0:
episode_durations.append(25)
plot_durations(episode_durations)
break
action = select_action(policy_net, state, goal_state_embedding, eps_threshold, possible_actions)

steps_done += 1
_, reward, done, info_dict = env.step(action)
possible_actions = info_dict['next_neighbors']
_, reward, done, _ = env.step(action)


next_state = action #by virtue of deterministic observed transitions

reward = torch.tensor([reward], device=device)

# Store the transition in memory
#print("PUSH TO MEM", state, next_state, goal_state, reward,)
next_state = action #by virtue of deterministic observed transitions
next_possible_actions = env.graph.successors(next_state)
memory.push(state, tuple(possible_actions), action, next_sstate, tuple(next_possible_actions), reward, goal_state_embedding)

memory.push(state, action, next_state, reward, goal_state_embedding)

# Move to the next state
state = next_state

# Perform one step of the optimization (on the policy network)
# print("about to optimize model", flush=True)
optimize_model(args, memory, policy_net, target_net, optimizer)
optimize_model(env, args, memory, policy_net, target_net, optimizer)
if done or t == args.max_ep_length-1:
episode_durations.append(t + 1)
plot_durations(episode_durations)
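
A rough sketch of the training-loop change in main.py: the replay buffer now stores only (state, action, next_state, reward, goal_state_embedding), and the legal moves are re-derived from the static wiki graph when a transition is replayed rather than being stored alongside it. epsilon_greedy and q_values are hypothetical stand-ins for select_action and evaluate_expected_rewards (q_values is assumed to return a NumPy array with one predicted return per candidate page); the [0] in the random branch mirrors the scalar-vs-array fix in select_action.

import numpy as np

def epsilon_greedy(q_values, state, goal_emb, actions, eps):
    # Return a single page title; np.random.choice(..., 1) alone would return a 1-element array.
    if np.random.random() > eps:
        return actions[int(np.argmax(q_values(state, goal_emb, actions)))]
    return np.random.choice(list(actions), 1)[0]

def td_error(env, policy_q, target_q, transition, gamma):
    state, action, next_state, reward, goal_emb = transition
    cur_actions = list(env.graph.successors(state))        # recomputed from the graph, not stored
    next_actions = list(env.graph.successors(next_state))  # likewise
    cur = policy_q(state, goal_emb, cur_actions).max()
    target = reward + gamma * target_q(next_state, goal_emb, next_actions).max()
    return cur - target  # fed to a Huber/SmoothL1 loss in the real code
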
6 changes: 3 additions & 3 deletions parse.py
@@ -5,7 +5,7 @@
parser.add_argument("--save", type=str, default="../results/tmp")
parser.add_argument("--max_ep_length", type=int, default=25)
parser.add_argument("--num_episodes", type=int, default=300)
parser.add_argument("--buffer_capacity", type=int, default=1000)
parser.add_argument("--buffer_capacity", type=int, default=10000)
parser.add_argument("--wiki_year", type=int, default=2006)

#Hyperparams
@@ -17,7 +17,7 @@
parser.add_argument("--target_update", type=int, default=10)

#QNetwork Params
parser.add_argument("--state_size", type=int, default=1024*3)
parser.add_argument("--state_size", type=int, default=768*3)
parser.add_argument("--fc1_units", type=int, default=1024)
parser.add_argument("--fc2_units", type=int, default=256)

@@ -26,5 +26,5 @@

#gym params
parser.add_argument("--has_fixed_dest_node", type=bool, default=False)
parser.add_argument("--fixed_dest_node", type=str, default="Applied Statistics")
parser.add_argument("--fixed_dest_node", type=str, default="Massachusetts Institute of Technology")
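
The new state_size default of 768*3 is consistent with concatenating three 768-dimensional text embeddings (for example, current page, candidate next page, and goal page from a BERT-base-style encoder). That reading is an assumption about get_neural_embedding in main.py, not something stated in this diff; a sketch under that assumption:

import torch

def q_network_input(embed, current_page, candidate_page, goal_embedding):
    cur = embed(current_page)       # assumed shape: (768,)
    cand = embed(candidate_page)    # assumed shape: (768,)
    return torch.cat([cur, cand, goal_embedding])  # shape: (768 * 3,) == state_size
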

2 changes: 1 addition & 1 deletion replay_utils.py
@@ -3,7 +3,7 @@


Transition = namedtuple('Transition',
('state', 'state_possible_actions' 'action', 'next_state', 'next_state_possible_actions', 'reward', 'goal_state_embedding'))
('state', 'action', 'next_state', 'reward', 'goal_state_embedding'))


class ReplayMemory(object):
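
A minimal usage sketch of the slimmed-down Transition with ReplayMemory (push one tuple per step, sample a batch in optimize_model). The ring-buffer behaviour shown here is the usual DQN-tutorial pattern that replay_utils.py appears to follow; the actual class body is not part of this diff.

import random
from collections import namedtuple

Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward', 'goal_state_embedding'))

class ReplayMemory:
    def __init__(self, capacity):
        self.capacity, self.memory, self.position = capacity, [], 0

    def push(self, *args):
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = Transition(*args)        # store one transition
        self.position = (self.position + 1) % self.capacity   # overwrite the oldest entry

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

# Example: memory.push(state, action, next_state, reward, goal_state_embedding)
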
