Skip to content

Commit

Permalink
x
Browse files Browse the repository at this point in the history
Signed-off-by: SumanthRH <[email protected]>
  • Loading branch information
SumanthRH committed Feb 6, 2025
1 parent 959869d commit b9c0260
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 15 deletions.
28 changes: 22 additions & 6 deletions skythought/skythought_evals/inference_and_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,14 @@ def inference(llm, conversations, max_tokens, temp, args):
return responses


def load_existing_results(result_file):
    """Return previously saved results loaded from ``result_file``.

    The file, when present, is expected to hold a JSON object mapping
    problem keys to result records. If the file does not exist yet
    (e.g. a fresh run), an empty dict is returned instead.
    """
    if os.path.exists(result_file):
        with open(result_file, "r", encoding="utf-8") as f:
            return json.load(f)
    return {}


def perform_inference_and_check(
handler: TaskHandler,
temperatures,
Expand All @@ -159,7 +167,7 @@ def perform_inference_and_check(
model_config,
args,
):
results = handler.load_existing_results(result_file)
results = load_existing_results(result_file)
print(f"Loaded {len(results)} existing results.")
train_data = handler.load_and_filter_dataset(
args.start,
Expand All @@ -170,6 +178,9 @@ def perform_inference_and_check(
args=args,
)
remaining_data = handler.process_remaining_data(train_data, results)
if not len(remaining_data):
print("All results saved. Exiting....")
return
conversations = handler.make_conversations(
remaining_data, model_config.system_prompt, model_config.user_template
)
Expand Down Expand Up @@ -229,7 +240,9 @@ def perform_inference_and_check(
prompt = conversations[idx][1]["content"]
results[problem_key]["prompt"] = prompt
results[problem_key]["input_conversation"] = conversations[idx]
temperature_to_scores[temp][idx] = [0 for _ in range(args.n)]
temperature_to_scores[temp][problem_key] = [
0 for _ in range(args.n)
]

if str(temp) not in results[problem_key]["responses"]:
results[problem_key]["responses"][str(temp)] = [
Expand All @@ -244,7 +257,7 @@ def perform_inference_and_check(
results[problem_key]["token_usages"][str(temp)] = token_usages[idx]

# update scores
temperature_to_scores[temp][idx][sample_idx] = response_entry[
temperature_to_scores[temp][problem_key][sample_idx] = response_entry[
"correctness"
]

Expand Down Expand Up @@ -310,7 +323,7 @@ def perform_inference_and_check(


def perform_check(handler: TaskHandler, temperatures, result_file, args):
results = handler.load_existing_results(result_file)
results = load_existing_results(result_file)
print(f"Loaded {len(results)} existing results.")

train_data = handler.load_and_filter_dataset(
Expand Down Expand Up @@ -412,7 +425,7 @@ def perform_inference_and_save(
model_config,
args,
):
results = handler.load_existing_results(result_file)
results = load_existing_results(result_file)
print(f"Loaded {len(results)} existing results.")
train_data = handler.load_and_filter_dataset(
args.start,
Expand All @@ -423,6 +436,9 @@ def perform_inference_and_save(
args=args,
)
remaining_data = handler.process_remaining_data(train_data, results)
if not len(remaining_data):
print("All results saved. Exiting...")
return
conversations = handler.make_conversations(
remaining_data, model_config.system_prompt, model_config.user_template
)
Expand Down Expand Up @@ -651,7 +667,7 @@ def main():
os.makedirs(args.result_dir)
temperature_str = ",".join(map(str, temperatures))
file_suffix = f"{model_config.name}_{args.task}_{args.split}_subset_{args.subset}_filter_{args.filter_difficulty}"
f"_s{args.start}_e{args.end}_t{temperature_str}"
f"_s{args.start}_e{args.end}_t{temperature_str}_n{args.n}"
if (
args.math_difficulty_lower_bound is not None
or args.math_difficulty_upper_bound is not None
Expand Down
9 changes: 0 additions & 9 deletions skythought/skythought_evals/tasks/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import json
import os
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional

Expand Down Expand Up @@ -98,13 +96,6 @@ def make_conversation_from_contents(
conversation.append({"role": "assistant", "content": content})
return conversation

def load_existing_results(self, result_file):
if not os.path.exists(result_file):
return {}
with open(result_file, "r", encoding="utf-8") as f:
records = json.load(f)
return records

def load_dataset(self, subset=None, split=None, **kwargs) -> HFDataset:
dataset = load_dataset(
path=self.task_config.dataset_path,
Expand Down

0 comments on commit b9c0260

Please sign in to comment.