Skip to content

Commit

Permalink
bugfixes
Browse files Browse the repository at this point in the history
  • Loading branch information
ritaank committed May 11, 2022
1 parent 7f0d94e commit d65a8a5
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 18 deletions.
25 changes: 15 additions & 10 deletions eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@

eval_parser = argparse.ArgumentParser("WALDO evaluation")
eval_parser.add_argument('-p','--path', type=str, help='path to trained model (.pt) file', required=True)
eval_parser.add_argument('--num_tests', type=int, help='how many tests to run?', default=5)
eval_parser.add_argument('--dist_levels', type=list, help='what levels to run tests at', default=[2])
eval_parser.add_argument('--num_tests', type=int, help='how many tests to run?', default=201)
eval_parser.add_argument('--dist_levels', type=list, help='what levels to run tests at', default=[1,2,3,4])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# PATH = '/content/wikigame/models/2006_fixednode-True_2022_05_10-12_07_38_PM.pt'
Expand Down Expand Up @@ -77,27 +77,26 @@ def plot_durations(episode_durations):
display.clear_output(wait=True)
display.display(plt.gcf())

def evaluate(qnet, env, args, potential_start_nodes):
def evaluate(qnet, env, args, potential_start_nodes, level):

episode_durations = []
distance_ratios = []
cos_sims = []

steps_done = 0
env = wikiGame(args)
wins = fails = reached_sink = 0

for _ in tqdm(range(args.num_tests)):
# print(f"i {i}")

source = np.random.choice(potential_start_nodes, 1)[0]
print("our source is ", source)
# print("our source is ", source)
state, goal_state = env.reset(evalMode=True, node=source)
path_taken = [state]
# print(f"starting at {state} and going to {goal_state}")

best_path_list = nx.shortest_path(env.graph, source=state, target=goal_state, weight=None, method='dijkstra')
print("best", best_path_list)
# print("best", best_path_list)
best_len = len(best_path_list)

goal_state_embedding = get_neural_embedding(goal_state)
Expand Down Expand Up @@ -127,7 +126,7 @@ def evaluate(qnet, env, args, potential_start_nodes):
cos_sim = cosine_similarity(get_neural_embedding(state).cpu().unsqueeze(0), eval_goal_state_embedding)
cos_sims.append(cos_sim - initial_cos_sim)

print(path_taken, "\n")
# print(path_taken, "\n")

fails +=1
reached_sink += 1
Expand All @@ -145,7 +144,7 @@ def evaluate(qnet, env, args, potential_start_nodes):
distance_ratios.append(0/best_len)
cos_sims.append(1 - initial_cos_sim)

print(path_taken + [goal_state], "\n")
# print(path_taken + [goal_state], "\n")

wins +=1
break
Expand All @@ -164,19 +163,21 @@ def evaluate(qnet, env, args, potential_start_nodes):
cos_sim = cosine_similarity(get_neural_embedding(state).cpu().unsqueeze(0), eval_goal_state_embedding)
cos_sims.append(cos_sim - initial_cos_sim)

print(path_taken, "\n")
# print(path_taken, "\n")

fails +=1
break

assert wins + fails == args.num_tests, f"w:{wins} and f:{fails} big bug"
print("FOR LEVEL", level)
print("settings\t", args)
print("success rate:\t", wins/args.num_tests)
print("rate we reached a dead end:\t", reached_sink/args.num_tests)
# print("distance ratios, lower is better:\n", distance_ratios)
print("average distance ratio:\t",sum(distance_ratios)/len(distance_ratios))
# print("cos sims, higher is better:\n", cos_sims)
print("avg improvement in cos sim:\t",sum(cos_sims)/len(cos_sims))
print("-----------------------")
# plot_durations(episode_durations)


Expand All @@ -199,9 +200,13 @@ def main(eval_args):
with torch.no_grad():
trained_net.eval()
env = wikiGame(args)
print(eval_args.dist_levels, type(eval_args.dist_levels), "MEEP")
nodes_by_dist = env.get_nodes_by_distances(eval_args.dist_levels) #args.tiers should be a list
for level in eval_args.dist_levels:
evaluate(trained_net, env, args, nodes_by_dist[level])
if len(nodes_by_dist[level]) == 0:
print("skipping level, no nodes, for level: ", level)
else:
evaluate(trained_net, env, args, nodes_by_dist[level], level)

env.close()

Expand Down
37 changes: 29 additions & 8 deletions gymEnv/wikiGame/envs/wikiGame.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(self, args):
self.max_episodes = args.num_episodes
self.full_graph = self.graph
self.last_trim_call_params = None
self.reset()
# self.reset(evalMode=True)


def render(self, mode='human'):
Expand All @@ -83,6 +83,7 @@ def step(self, action):
def reset(self, evalMode=False, node=None):
self.goal_vertex = self.bfs_center_node if self.has_fixed_dest_node else np.random.choice(self.graph.nodes(), 1)[0]
if self.expanding_bfs and not evalMode:
print("graph is expanding")
curr_bfs_dist = self.calc_bfs_dist_schedule()
if curr_bfs_dist == self.max_bfs_dist:#WE ALREADY REACHED FULL SIZE, DONT BOTHER RECALCULATING
pass
Expand Down Expand Up @@ -115,15 +116,32 @@ def trim_graph(self, graph, target, cutoff):
ret_graph = graph.subgraph(desired_nodes)
return ret_graph

def get_nodes_by_distances(self, tier_values):
lengths = nx.single_target_shortest_path_length(self.graph, self.bfs_center_node, cutoff=max(tier_values))
def get_nodes_by_distances(self, tier_values): #[2]
lengths = dict(nx.single_target_shortest_path_length(self.graph, self.bfs_center_node, cutoff=max(tier_values)+1))
tiers = {}
for key, value in lengths:
print("start")

# print(list(lengths))
# # print(type(tier_values))
# lengths = list(lengths)
# print(type(lengths))
# print("lengths again", lengths)
# print(len(lengths))
# print(lengths[0])
for key,value in lengths.items():
# print("hello there")
# print("inside")
# print(tup, type(tup))
# key,value = tup[0], tup[1]
# tiers.setdefault(value, []).extend([key])
if value in tiers and value in tier_values:
tiers[value].append(key)
else:
tiers[value]=[key]
# print(key, value)
# print(tier_values)
# print('bye')
if value in tier_values:
if value in tiers:
tiers[value].append(key)
else:
tiers[value]=[key]

# unwanted = set(tiers) - set(tier_values)
# for unwanted_key in unwanted:
Expand All @@ -135,5 +153,8 @@ def get_nodes_by_distances(self, tier_values):
print("we verify the best path between these is ")
print(nx.shortest_path(self.graph, source=tiers[test_key][0], target=self.bfs_center_node, weight=None, method='dijkstra'))

print("stats")
for level, amt in tiers.items():
print("level ", level, "nodecount ", len(amt))
return tiers

0 comments on commit d65a8a5

Please sign in to comment.