Skip to content

Commit

Permalink
x
Browse files Browse the repository at this point in the history
Signed-off-by: SumanthRH <[email protected]>
  • Loading branch information
SumanthRH committed Feb 7, 2025
1 parent 24961c0 commit b726e1f
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions skythought/skythought_evals/util/metrics.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
import math

from collections import defaultdict
import numpy as np


Expand All @@ -10,17 +10,19 @@ def pass_at_k(N, temp_to_scores):
pass_values = {} # temp -> value
for temp in temp_to_scores:
scores = temp_to_scores[temp] # dict mapping idx -> list of scores
k = N
final_passk_scores = {}
while k > 0:
new_scores = []
for _, sample_scores in scores.items():
k_to_passk_scores = defaultdict(list) # k -> list of scores
for _, sample_scores in scores.items():
k = N
while k > 0:
# calculate pass @ k
num_correct = np.sum(sample_scores)
pass_k = 1 - (math.comb(N - num_correct, k) / math.comb(N, k))
new_scores.append(pass_k)
final_passk_scores[f"{k=}"] = round(np.mean(new_scores) * 100, 3)
k = k // 2
k_to_passk_scores[k].append(pass_k)
k = k // 2

for k in k_to_passk_scores:
final_passk_scores[f"{k=}"] = round(np.mean(k_to_passk_scores[k]) * 100, 3)

# print("Final pass @ k:")
for k, s in final_passk_scores.items():
Expand Down

0 comments on commit b726e1f

Please sign in to comment.