diff --git a/scripts/ceval/eval.py b/scripts/ceval/eval.py new file mode 100644 index 0000000..60b625c --- /dev/null +++ b/scripts/ceval/eval.py @@ -0,0 +1,120 @@ +# This code is modified from C-Eval Project: https://github.com/SJTU-LIT/ceval + +import os +import argparse +import pandas as pd +import torch +import json +from mixtral_evaluator import Mixtral_Evaluator + +import time +choices = ["A", "B", "C", "D"] + +def main(args, evaluator,take): + assert os.path.exists("subject_mapping.json"), "subject_mapping.json not found!" + with open("subject_mapping.json") as f: + subject_mapping = json.load(f) + filenames = os.listdir("data/val") + subject_list = [val_file.replace("_val.csv","") for val_file in filenames] + accuracy, summary = {}, {} + + run_date=time.strftime('%Y-%m-%d_%H-%M-%S',time.localtime(time.time())) + output_dir = args.output_dir + save_result_dir=os.path.join(output_dir,f"take{take}") + if not os.path.exists(save_result_dir): + os.makedirs(save_result_dir,exist_ok=True) + + all_answers = {} + for index,subject_name in enumerate(subject_list): + print(f"{index/len(subject_list)} Inference starts at {run_date} on {args.model_path} with subject of {subject_name}!") + val_file_path=os.path.join('data/val',f'{subject_name}_val.csv') + dev_file_path=os.path.join('data/dev',f'{subject_name}_dev.csv') + test_file_path=os.path.join('data/test',f'{subject_name}_test.csv') + + val_df=pd.read_csv(val_file_path) if args.do_test is False else pd.read_csv(test_file_path) + dev_df=pd.read_csv(dev_file_path) if args.few_shot else None + + correct_ratio, answers = evaluator.eval_subject(subject_name, val_df, dev_df, + save_result_dir=save_result_dir if args.do_save_csv else None, + few_shot=args.few_shot, + cot=args.cot, + with_prompt=args.with_prompt, + constrained_decoding=args.constrained_decoding, + do_test=args.do_test) + print(f"Subject: {subject_name}") + print(f"Acc: {correct_ratio}") + accuracy[subject_name] = correct_ratio + summary[subject_name] = {"score":correct_ratio, + "num":len(val_df), + "correct":correct_ratio*len(val_df)/100} + all_answers[subject_name] = answers + + json.dump(all_answers,open(save_result_dir+'/submission.json','w'),ensure_ascii=False,indent=4) + print("Accuracy:") + for k, v in accuracy.items(): + print(k, ": ", v) + + + total_num = 0 + total_correct = 0 + summary['grouped'] = { + "STEM": {"correct": 0.0, "num": 0}, + "Social Science": {"correct": 0.0, "num": 0}, + "Humanities": {"correct": 0.0, "num": 0}, + "Other": {"correct": 0.0, "num": 0} + } + for subj, info in subject_mapping.items(): + group = info[2] + summary['grouped'][group]["num"] += summary[subj]['num'] + summary['grouped'][group]["correct"] += summary[subj]['correct'] + for group, info in summary['grouped'].items(): + info['score'] = info["correct"] / info["num"] + total_num += info["num"] + total_correct += info["correct"] + summary['All'] = {"score": total_correct / total_num, "num": total_num, "correct": total_correct} + + json.dump(summary,open(save_result_dir+'/summary.json','w'),ensure_ascii=False,indent=2) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model_path", type=str) + parser.add_argument("--cot",choices=["False","True"], default="False") + parser.add_argument("--few_shot", choices=["False","True"], default="True") + parser.add_argument("--ntrain", "-k", type=int, default=5) + parser.add_argument("--with_prompt", choices=["False","True"], default="False") + parser.add_argument("--constrained_decoding", choices=["False","True"], default="True") + parser.add_argument("--temperature",type=float,default=0.2) + parser.add_argument("--n_times", default=1,type=int) + parser.add_argument("--do_save_csv", choices=["False","True"], default="False") + parser.add_argument("--output_dir", type=str) + parser.add_argument("--do_test", choices=["False","True"], default="False") + parser.add_argument("--verbose", action="store_true", help="Print detailed information of each example.") + parser.add_argument("--load_in_4bit", action="store_true", help="The model was loaded by 4-bit quantization") + parser.add_argument("--use_flash_attention_2", action="store_true", help="Use flash_attention2 to replace the mixtral attention") + args = parser.parse_args() + + args.cot = args.cot == "True" + args.few_shot = args.few_shot == "True" + args.with_prompt = args.with_prompt == "True" + args.constrained_decoding = args.constrained_decoding == "True" + args.do_test = args.do_test == "True" + args.do_save_csv = args.do_save_csv == "True" + if args.constrained_decoding is True: + args.n_times=max(args.n_times,1) + print(args) + + device = torch.device(0) + print(device) + evaluator=Mixtral_Evaluator( + choices=choices, + k=args.ntrain, + model_path=args.model_path, + device=device, + temperature=args.temperature, + load_in_4bit=args.load_in_4bit, + use_flash_attention_2=args.use_flash_attention_2, + verbose=args.verbose + ) + for i in range(args.n_times): + main(args,evaluator=evaluator,take=i) diff --git a/scripts/ceval/evaluator.py b/scripts/ceval/evaluator.py new file mode 100644 index 0000000..691af6f --- /dev/null +++ b/scripts/ceval/evaluator.py @@ -0,0 +1,47 @@ +# This code is modified from C-Eval Project: https://github.com/SJTU-LIT/ceval + +import string +class Evaluator: + def __init__(self, choices, model_name, k=-1): + self.choices = choices + self.model_name = model_name + self.k = k + self.puncs = list(string.punctuation) + + def format_example(self, line, include_answer=True): + example = line['question'] + for choice in self.choices: + example += f'\n{choice}. {line[f"{choice}"]}' + example += '\n答案:' + if include_answer: + example += f'{line["answer"]}\n\n' + return example + + def generate_few_shot_prompt(self, subject, dev_df): + prompt = f"以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n" + k = self.k + if self.k == -1: + k = dev_df.shape[0] + for i in range(k): + prompt += self.format_example(dev_df.iloc[i, :]) + return prompt + + def eval_subject(self, subject_name, test_df, dev_df=None, few_shot=False, save_result_dir=None): + pass + + def normalize_answer(self,s): + + def white_space_fix(text): + return ' '.join(text.split()) + + def remove_punc(text): + exclude=set(self.puncs) + return ''.join(ch for ch in text if ch not in exclude) + + def lower(text): + return text.lower() + + return white_space_fix(remove_punc(lower(s))) + + def exact_match(self,pred, target): + return self.normalize_answer(pred)==self.normalize_answer(target) diff --git a/scripts/ceval/mixtral_evaluator.py b/scripts/ceval/mixtral_evaluator.py new file mode 100644 index 0000000..bdf943b --- /dev/null +++ b/scripts/ceval/mixtral_evaluator.py @@ -0,0 +1,240 @@ +# This code is modified from C-Eval Project: https://github.com/SJTU-LIT/ceval + +import os +import re +from tqdm import tqdm +import random +import numpy as np +import torch +from transformers import AutoModelForCausalLM, LlamaTokenizer, BitsAndBytesConfig +from transformers import GenerationConfig +from evaluator import Evaluator + + +class Mixtral_Evaluator(Evaluator): + def __init__(self, choices, k, model_path, device, temperature=0.2, load_in_4bit=False, use_flash_attention_2=False, verbose=False): + super(Mixtral_Evaluator, self).__init__(choices, model_path, k) + load_type = torch.float16 + self.model_path = model_path + self.device = device + self.verbose = verbose + self.load_in_4bit = load_in_4bit + self.use_flash_attention_2 = use_flash_attention_2 + self.tokenizer = LlamaTokenizer.from_pretrained(model_path, legacy=True) + quantization_config = BitsAndBytesConfig( + load_in_4bit=True, + load_in_8bit=False, + bnb_4bit_compute_dtype=load_type, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4" + ) + self.model = AutoModelForCausalLM.from_pretrained( + model_path, + quantization_config=quantization_config if self.load_in_4bit else None, + torch_dtype=load_type, + low_cpu_mem_usage=True, + device_map='auto', + attn_implementation="flash_attention_2" if self.use_flash_attention_2 else "sdpa" + ) + self.generation_config = GenerationConfig( + temperature=temperature, + top_k=40, + top_p=0.9, + do_sample=True, + num_beams=1, + repetition_penalty=1.1, + max_new_tokens=20 + ) + + self.sA_id = self.tokenizer.encode("A", add_special_tokens=False)[0] + self.sB_id = self.tokenizer.encode("B", add_special_tokens=False)[0] + self.sC_id = self.tokenizer.encode("C", add_special_tokens=False)[0] + self.sD_id = self.tokenizer.encode("D", add_special_tokens=False)[0] + self.A_id = self.tokenizer.encode(":A")[-1] + self.B_id = self.tokenizer.encode(":B")[-1] + self.C_id = self.tokenizer.encode(":C")[-1] + self.D_id = self.tokenizer.encode(":D")[-1] + + + def eval_subject(self, subject_name, + test_df, + dev_df=None, + few_shot=False, + cot=False, + save_result_dir=None, + with_prompt=False, + constrained_decoding=False, + do_test=False): + all_answers = {} + if constrained_decoding is True: + self.generation_config.output_scores = True + self.generation_config.return_dict_in_generate = True + self.generation_config.max_new_tokens = 1 + self.generation_config.top_p = 1.0 + self.generation_config.top_k = 0 + + correct_num = 0 + if save_result_dir: + result = [] + score = [] + if few_shot: + if with_prompt: + history = self.generate_mixtral_inst_few_shot_prompt(subject_name, dev_df, cot=cot) + else: + history = self.generate_mixtral_few_shot_prompt(subject_name, dev_df, cot=cot) + else: + history = '' + answers = ['NA'] * len(test_df) if do_test is True else list(test_df['answer']) + for row_index, row in tqdm(test_df.iterrows(), total=len(test_df)): + question = self.format_example(row, include_answer=False, cot=cot,with_prompt=with_prompt) + instruction = question + if with_prompt: + prompt_template = ( + "[INST] {instruction} [/INST]" + ) + + instruction = prompt_template.format_map({'instruction': instruction}) + instruction = history + instruction + inputs = self.tokenizer(instruction, return_tensors="pt") + generation_output = self.model.generate( + input_ids = inputs["input_ids"].to(self.device), + attention_mask = inputs['attention_mask'].to(self.device), + eos_token_id=self.tokenizer.eos_token_id, + pad_token_id=self.tokenizer.eos_token_id, + generation_config = self.generation_config + ) + + _, length = inputs.input_ids.shape + if constrained_decoding is True: + logits = generation_output.scores[0][0] + + logits = logits.float().cpu().detach() + choices1_logits = logits[[self.sA_id,self.sB_id,self.sC_id,self.sD_id]] + choices2_logits = logits[[self.A_id,self.B_id,self.C_id,self.D_id]] + choicesAll_logits = (choices1_logits + choices2_logits).numpy() + assert not (np.any(np.isinf(choicesAll_logits)) or np.any(np.isnan(choicesAll_logits))) + ans = {0: "A", 1: "B", 2: "C", 3: "D"}[np.argmax(choicesAll_logits)] + response = self.tokenizer.decode([logits.argmax(-1).item()]) + else: + response = self.tokenizer.decode(generation_output[0, length:], skip_special_tokens=True) + ans, _ = self.extract_answer(row, response) + if ans == answers[row_index]: + correct_num += 1 + correct = 1 + else: + correct = 0 + if self.verbose is True: + print(f"\n======={str(row_index)}=======") + print(f"question: {question}\n") + print(f"response: {response}\n") + print(f"extracted answer: {ans}") + print(f"ground truth: {answers[row_index]} \n") + if save_result_dir: + result.append(response) + score.append(correct) + + all_answers[str(row_index)] = ans + + correct_ratio = 100*correct_num/len(answers) + + if save_result_dir: + test_df['model_output'] = result + test_df['correctness'] = score + test_df.to_csv(os.path.join(save_result_dir, f'{subject_name}_test.csv')) + + return correct_ratio, all_answers + + def format_example(self, line, include_answer=True, cot=False, with_prompt=False): + example = line['question'] + for choice in self.choices: + example += f'\n{choice}. {line[f"{choice}"]}' + if include_answer: + if cot: + example += "\n答案:让我们一步一步思考,\n" + \ + line["explanation"] + f"\n所以答案是{line['answer']}。\n\n" + else: + example += '\n答案:' + line["answer"] + '\n\n' + else: + if with_prompt is False: + if cot: + example += "\n答案:让我们一步一步思考,\n1." + else: + example += '\n答案:' + else: + if cot: + example += "\n答案是什么?让我们一步一步思考,\n1." + else: + example += '\n答案:' + return example + + def generate_mixtral_few_shot_prompt(self, subject, dev_df, cot=False): + prompt = f"以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n" + k = self.k + if self.k == -1: + k = dev_df.shape[0] + for i in range(k): + prompt += self.format_example( + dev_df.iloc[i, :], + include_answer=True, + cot=cot + ) + return prompt + + def generate_mixtral_inst_few_shot_prompt(self, subject, dev_df, cot=False): + prompt = f"以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n" + prompt_template = ( + "[INST] {instruction} [/INST]好的,我会结合{subject}相关知识回答" + ) + + prompt = prompt_template.format_map({'instruction':prompt, 'subject':subject}) + k = self.k + if self.k == -1: + k = dev_df.shape[0] + for i in range(k): + line = dev_df.iloc[i, :] + q=line['question'] + for choice in self.choices: + q += f'\n{choice}. {line[f"{choice}"]}' + + a = line['answer'] + prompt += "[INST] "+q+"\n答案: [/INST]"+a+"\n" + return prompt + + def extract_answer(self, line, gen_ans): + m = re.findall(r'所以答案是(.+?)。', gen_ans, re.M) + if len(m) > 0 and m[-1] in self.choices: + return m[-1], True + answer_patterns = [ + r'([ABCD])是正确的', + r'选项([ABCD])正确', + r'答案为([ABCD])', + r'答案是([ABCD])', + r'答案([ABCD])', + r'选择([ABCD])', + r'答案:([ABCD])', + r'选择答案([ABCD])' + ] + # RE extraction + for answer_pattern in answer_patterns: + m = re.search(answer_pattern, gen_ans, re.M) + if m: + answer = m.group(1) + return answer, False + # only containing one choice-character + m = re.findall(r'[ABCD]', gen_ans, re.M) + if len(m) >= 1: + answer = m[0] + return answer, False + # only containing one choice-context + choices_dict = {} + pattern = "" + for c in self.choices: + choices_dict[str(line[f'{c}'])] = c + pattern += re.escape(str(line[f'{c}']))+"|" + pattern = pattern[:-1] + m = re.findall(pattern, gen_ans, re.M) + print("w/ escape:",repr(pattern),gen_ans,(len(m)>=1)) + if len(m) >= 1: + answer = choices_dict[m[0]] + return answer, False + return random.sample('ABCD', 1)[0], False diff --git a/scripts/ceval/subject_mapping.json b/scripts/ceval/subject_mapping.json new file mode 100644 index 0000000..493c0f3 --- /dev/null +++ b/scripts/ceval/subject_mapping.json @@ -0,0 +1,262 @@ +{ + "computer_network": [ + "Computer Network", + "\u8ba1\u7b97\u673a\u7f51\u7edc", + "STEM" + ], + "operating_system": [ + "Operating System", + "\u64cd\u4f5c\u7cfb\u7edf", + "STEM" + ], + "computer_architecture": [ + "Computer Architecture", + "\u8ba1\u7b97\u673a\u7ec4\u6210", + "STEM" + ], + "college_programming": [ + "College Programming", + "\u5927\u5b66\u7f16\u7a0b", + "STEM" + ], + "college_physics": [ + "College Physics", + "\u5927\u5b66\u7269\u7406", + "STEM" + ], + "college_chemistry": [ + "College Chemistry", + "\u5927\u5b66\u5316\u5b66", + "STEM" + ], + "advanced_mathematics": [ + "Advanced Mathematics", + "\u9ad8\u7b49\u6570\u5b66", + "STEM" + ], + "probability_and_statistics": [ + "Probability and Statistics", + "\u6982\u7387\u7edf\u8ba1", + "STEM" + ], + "discrete_mathematics": [ + "Discrete Mathematics", + "\u79bb\u6563\u6570\u5b66", + "STEM" + ], + "electrical_engineer": [ + "Electrical Engineer", + "\u6ce8\u518c\u7535\u6c14\u5de5\u7a0b\u5e08", + "STEM" + ], + "metrology_engineer": [ + "Metrology Engineer", + "\u6ce8\u518c\u8ba1\u91cf\u5e08", + "STEM" + ], + "high_school_mathematics": [ + "High School Mathematics", + "\u9ad8\u4e2d\u6570\u5b66", + "STEM" + ], + "high_school_physics": [ + "High School Physics", + "\u9ad8\u4e2d\u7269\u7406", + "STEM" + ], + "high_school_chemistry": [ + "High School Chemistry", + "\u9ad8\u4e2d\u5316\u5b66", + "STEM" + ], + "high_school_biology": [ + "High School Biology", + "\u9ad8\u4e2d\u751f\u7269", + "STEM" + ], + "middle_school_mathematics": [ + "Middle School Mathematics", + "\u521d\u4e2d\u6570\u5b66", + "STEM" + ], + "middle_school_biology": [ + "Middle School Biology", + "\u521d\u4e2d\u751f\u7269", + "STEM" + ], + "middle_school_physics": [ + "Middle School Physics", + "\u521d\u4e2d\u7269\u7406", + "STEM" + ], + "middle_school_chemistry": [ + "Middle School Chemistry", + "\u521d\u4e2d\u5316\u5b66", + "STEM" + ], + "veterinary_medicine": [ + "Veterinary Medicine", + "\u517d\u533b\u5b66", + "STEM" + ], + "college_economics": [ + "College Economics", + "\u5927\u5b66\u7ecf\u6d4e\u5b66", + "Social Science" + ], + "business_administration": [ + "Business Administration", + "\u5de5\u5546\u7ba1\u7406", + "Social Science" + ], + "marxism": [ + "Marxism", + "\u9a6c\u514b\u601d\u4e3b\u4e49\u57fa\u672c\u539f\u7406", + "Social Science" + ], + "mao_zedong_thought": [ + "Mao Zedong Thought", + "\u6bdb\u6cfd\u4e1c\u601d\u60f3\u548c\u4e2d\u56fd\u7279\u8272\u793e\u4f1a\u4e3b\u4e49\u7406\u8bba\u4f53\u7cfb\u6982\u8bba", + "Social Science" + ], + "education_science": [ + "Education Science", + "\u6559\u80b2\u5b66", + "Social Science" + ], + "teacher_qualification": [ + "Teacher Qualification", + "\u6559\u5e08\u8d44\u683c", + "Social Science" + ], + "high_school_politics": [ + "High School Politics", + "\u9ad8\u4e2d\u653f\u6cbb", + "Social Science" + ], + "high_school_geography": [ + "High School Geography", + "\u9ad8\u4e2d\u5730\u7406", + "Social Science" + ], + "middle_school_politics": [ + "Middle School Politics", + "\u521d\u4e2d\u653f\u6cbb", + "Social Science" + ], + "middle_school_geography": [ + "Middle School Geography", + "\u521d\u4e2d\u5730\u7406", + "Social Science" + ], + "modern_chinese_history": [ + "Modern Chinese History", + "\u8fd1\u4ee3\u53f2\u7eb2\u8981", + "Humanities" + ], + "ideological_and_moral_cultivation": [ + "Ideological and Moral Cultivation", + "\u601d\u60f3\u9053\u5fb7\u4fee\u517b\u4e0e\u6cd5\u5f8b\u57fa\u7840", + "Humanities" + ], + "logic": [ + "Logic", + "\u903b\u8f91\u5b66", + "Humanities" + ], + "law": [ + "Law", + "\u6cd5\u5b66", + "Humanities" + ], + "chinese_language_and_literature": [ + "Chinese Language and Literature", + "\u4e2d\u56fd\u8bed\u8a00\u6587\u5b66", + "Humanities" + ], + "art_studies": [ + "Art Studies", + "\u827a\u672f\u5b66", + "Humanities" + ], + "professional_tour_guide": [ + "Professional Tour Guide", + "\u5bfc\u6e38\u8d44\u683c", + "Humanities" + ], + "legal_professional": [ + "Legal Professional", + "\u6cd5\u5f8b\u804c\u4e1a\u8d44\u683c", + "Humanities" + ], + "high_school_chinese": [ + "High School Chinese", + "\u9ad8\u4e2d\u8bed\u6587", + "Humanities" + ], + "high_school_history": [ + "High School History", + "\u9ad8\u4e2d\u5386\u53f2", + "Humanities" + ], + "middle_school_history": [ + "Middle School History", + "\u521d\u4e2d\u5386\u53f2", + "Humanities" + ], + "civil_servant": [ + "Civil Servant", + "\u516c\u52a1\u5458", + "Other" + ], + "sports_science": [ + "Sports Science", + "\u4f53\u80b2\u5b66", + "Other" + ], + "plant_protection": [ + "Plant Protection", + "\u690d\u7269\u4fdd\u62a4", + "Other" + ], + "basic_medicine": [ + "Basic Medicine", + "\u57fa\u7840\u533b\u5b66", + "Other" + ], + "clinical_medicine": [ + "Clinical Medicine", + "\u4e34\u5e8a\u533b\u5b66", + "Other" + ], + "urban_and_rural_planner": [ + "Urban and Rural Planner", + "\u6ce8\u518c\u57ce\u4e61\u89c4\u5212\u5e08", + "Other" + ], + "accountant": [ + "Accountant", + "\u6ce8\u518c\u4f1a\u8ba1\u5e08", + "Other" + ], + "fire_engineer": [ + "Fire Engineer", + "\u6ce8\u518c\u6d88\u9632\u5de5\u7a0b\u5e08", + "Other" + ], + "environmental_impact_assessment_engineer": [ + "Environmental Impact Assessment Engineer", + "\u73af\u5883\u5f71\u54cd\u8bc4\u4ef7\u5de5\u7a0b\u5e08", + "Other" + ], + "tax_accountant": [ + "Tax Accountant", + "\u7a0e\u52a1\u5e08", + "Other" + ], + "physician": [ + "Physician", + "\u533b\u5e08\u8d44\u683c", + "Other" + ] +} \ No newline at end of file diff --git a/scripts/cmmlu/categories.py b/scripts/cmmlu/categories.py new file mode 100644 index 0000000..aa7b3cb --- /dev/null +++ b/scripts/cmmlu/categories.py @@ -0,0 +1,148 @@ +# This code is modified from CMMLU Project: https://github.com/haonan-li/CMMLU +name_en2zh = { + "agronomy": "农学", + "anatomy": "解剖学", + "ancient_chinese": "古汉语", + "arts": "艺术学", + "astronomy": "天文学", + "business_ethics": "商业伦理", + "chinese_civil_service_exam": "中国公务员考试", + "chinese_driving_rule": "中国驾驶规则", + "chinese_food_culture": "中国饮食文化", + "chinese_foreign_policy": "中国外交政策", + "chinese_history":"中国历史", + "chinese_literature": "中国文学", + "chinese_teacher_qualification": "中国教师资格", + "clinical_knowledge": "临床知识", + "college_actuarial_science":"大学精算学", + "college_education":"大学教育学", + "college_engineering_hydrology": "大学工程水文学", + "college_law": "大学法律", + "college_mathematics": "大学数学", + "college_medical_statistics":"大学医学统计", + "college_medicine": "大学医学", + "computer_science": "计算机科学", + "computer_security": "计算机安全", + "conceptual_physics": "概念物理学", + "construction_project_management": "建设工程管理", + "economics": "经济学", + "education": "教育学", + "electrical_engineering": "电气工程", + "elementary_chinese":"小学语文", + "elementary_commonsense":"小学常识", + "elementary_information_and_technology": "小学信息技术", + "elementary_mathematics": "初等数学", + "ethnology": "民族学", + "food_science": "食品科学", + "genetics": "遗传学", + "global_facts": "全球事实", + "high_school_biology": "高中生物", + "high_school_chemistry": "高中化学", + "high_school_geography": "高中地理", + "high_school_mathematics": "高中数学", + "high_school_physics": "高中物理学", + "high_school_politics": "高中政治", + "human_sexuality": "人类性行为", + "international_law": "国际法学", + "journalism": "新闻学", + "jurisprudence": "法理学", + "legal_and_moral_basis": "法律与道德基础", + "logical": "逻辑学", + "machine_learning": "机器学习", + "management": "管理学", + "marketing": "市场营销", + "marxist_theory": "马克思主义理论", + "modern_chinese": "现代汉语", + "nutrition": "营养学", + "philosophy": "哲学", + "professional_accounting": "专业会计", + "professional_law": "专业法学", + "professional_medicine": "专业医学", + "professional_psychology": "专业心理学", + "public_relations": "公共关系", + "security_study":"安全研究", + "sociology": "社会学", + "sports_science": "体育学", + "traditional_chinese_medicine": "中医中药", + "virology": "病毒学", + "world_history":"世界历史", + "world_religions": "世界宗教", +} + +subcategories = { + "agronomy": ['other'], + "anatomy": ['biology'], + "ancient_chinese": ['linguistics','china specific'], + "arts": ['arts'], + "astronomy": ['physics'], + "business_ethics": ['business'], + "chinese_civil_service_exam": ['politics','china specific'], + "chinese_driving_rule": ['other','china specific'], + "chinese_food_culture": ['culture','china specific'], + "chinese_foreign_policy": ['politics','china specific'], + "chinese_history":['history','china specific'], + "chinese_literature": ['literature','china specific'], + "chinese_teacher_qualification": ['education','china specific'], + "college_actuarial_science":['math'], + "college_education":['education'], + "college_engineering_hydrology": ['engineering'], + "college_law": ['law'], + "college_mathematics": ['math'], + "college_medical_statistics":['statistics'], + "clinical_knowledge": ['other'], + "college_medicine": ['other'], + "computer_science": ['computer science'], + "computer_security": ['other'], + "conceptual_physics": ['physics'], + "construction_project_management": ['other','china specific'], + "economics": ['economics'], + "education": ['education'], + "elementary_chinese":['linguistics','china specific'], + "elementary_commonsense":['other','china specific'], + "elementary_information_and_technology": ['other'], + "electrical_engineering": ['engineering'], + "elementary_mathematics": ['math'], + "ethnology": ['culture','china specific'], + "food_science": ['other'], + "genetics": ['biology'], + "global_facts": ['global'], + "high_school_biology": ['biology'], + "high_school_chemistry": ['chemistry'], + "high_school_geography": ['geography'], + "high_school_mathematics": ['math'], + "high_school_physics": ['physics'], + "high_school_politics": ['politics','china specific'], + "human_sexuality": ['other'], + "international_law": ['law'], + "journalism": ['sociology'], + "jurisprudence": ['law'], + "legal_and_moral_basis": ['other'], + "logical": ['philosophy'], + "machine_learning": ['computer science'], + "management": ['business'], + "marketing": ['business'], + "marxist_theory": ['philosophy'], + "modern_chinese": ['linguistics','china specific'], + "nutrition": ['other'], + "philosophy": ['philosophy'], + "professional_accounting": ['business'], + "professional_law": ['law'], + "professional_medicine": ['other'], + "professional_psychology": ['psychology'], + "public_relations": ['politics'], + "security_study": ['politics'], + "sociology": ['culture'], + "sports_science": ['other'], + "traditional_chinese_medicine": ['other','china specific'], + "virology": ['biology'], + "world_history":['history'], + "world_religions": ['global'], +} + +categories = { + "STEM": ["physics", "chemistry", "biology", "computer science", "math", "engineering", "statistics"], + "Humanities": ["history", "philosophy", "law", "arts", "literature", "global"], + "Social Science": ['linguistics',"business", "politics", "culture", "economics", "geography", "psychology", "education", "sociology"], + "Other":["other"], + "China specific": ["china specific"], +} diff --git a/scripts/cmmlu/eval.py b/scripts/cmmlu/eval.py new file mode 100644 index 0000000..73786cf --- /dev/null +++ b/scripts/cmmlu/eval.py @@ -0,0 +1,134 @@ +# This code is modified from C-Eval Project: https://github.com/SJTU-LIT/ceval +import os +import argparse +import pandas as pd +import torch +import json +from mxitral_evaluator import Mixtral_Evaluator +from glob import glob +import time +from collections import defaultdict +from categories import name_en2zh, subcategories, categories +choices = ["A", "B", "C", "D"] + +category2subject = defaultdict(list) +for k,v in categories.items(): + for subject, subcat in subcategories.items(): + for c in subcat: + if c in v: + category2subject[k].append(subject) +category2subject_list = defaultdict(list) +for key,value in category2subject.items(): + for val in value: + category2subject_list[val]=[val,name_en2zh[val],key] +category2subject=category2subject_list +choices = ["A", "B", "C", "D"] + + +def main(args, evaluator,take): + + subject_mapping = category2subject #json.load(f) + filenames = [s.split('/')[-1] for s in glob(args.input_dir+"/test/*csv")] + subject_list = [val_file.replace(".csv","") for val_file in filenames] + accuracy, summary = {}, {} + + run_date=time.strftime('%Y-%m-%d_%H-%M-%S',time.localtime(time.time())) + output_dir = args.output_dir + save_result_dir=os.path.join(output_dir,f"take{take}") + if not os.path.exists(save_result_dir): + os.makedirs(save_result_dir,exist_ok=True) + + all_answers = {} + for index,subject_name in enumerate(subject_list): + print(f"{index/len(subject_list)} Inference starts at {run_date} on {args.model_path} with subject of {subject_name}!") + val_file_path=os.path.join(args.input_dir+'/test',f'{subject_name}.csv') + dev_file_path=os.path.join(args.input_dir+'/dev',f'{subject_name}.csv') + + val_df=pd.read_csv(val_file_path) + dev_df=pd.read_csv(dev_file_path) if args.few_shot else None + + correct_ratio, answers = evaluator.eval_subject(subject_name, val_df, dev_df, + save_result_dir=save_result_dir if args.do_save_csv else None, + few_shot=args.few_shot, + cot=args.cot, + with_prompt=args.with_prompt, + constrained_decoding=args.constrained_decoding, + do_test=False) + print(f"Subject: {subject_name}") + print(f"Acc: {correct_ratio}") + accuracy[subject_name] = correct_ratio + summary[subject_name] = {"score":correct_ratio, + "num":len(val_df), + "correct":correct_ratio*len(val_df)/100} + all_answers[subject_name] = answers + + json.dump(all_answers,open(save_result_dir+'/submission.json','w'),ensure_ascii=False,indent=4) + print("\n\nModel:",args.model_path) + print("Accuracy:") + for k, v in accuracy.items(): + print(k, ": ", v) + + total_num = 0 + total_correct = 0 + summary['grouped'] = { + "China specific": {"correct": 0.0, "num": 0}, + "STEM": {"correct": 0.0, "num": 0}, + "Social Science": {"correct": 0.0, "num": 0}, + "Humanities": {"correct": 0.0, "num": 0}, + "Other": {"correct": 0.0, "num": 0} + } + for subj, info in subject_mapping.items(): + group = info[2] + summary['grouped'][group]["num"] += summary[subj]['num'] + summary['grouped'][group]["correct"] += summary[subj]['correct'] + for group, info in summary['grouped'].items(): + info['score'] = info["correct"] / info["num"] + total_num += info["num"] + total_correct += info["correct"] + summary['All'] = {"score": total_correct / total_num, "num": total_num, "correct": total_correct} + + json.dump(summary,open(save_result_dir+'/summary.json','w'),ensure_ascii=False,indent=2) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--ntrain", "-k", type=int, default=5) + parser.add_argument("--model_path", type=str) + parser.add_argument("--cot",choices=["False","True"], default="False") + parser.add_argument("--few_shot", choices=["False","True"], default="True") + parser.add_argument("--with_prompt", choices=["False","True"], default="False") + parser.add_argument("--constrained_decoding", choices=["False","True"], default="False") + parser.add_argument("--temperature",type=float,default=0.2) + parser.add_argument("--n_times", default=1,type=int) + parser.add_argument("--do_save_csv", choices=["False","True"], default="False") + parser.add_argument("--output_dir", type=str) + parser.add_argument("--input_dir", type=str) + parser.add_argument("--verbose", action="store_true", help="Print detailed information of each example.") + parser.add_argument("--load_in_4bit", action="store_true", help="The model was loaded by 4-bit quantization") + parser.add_argument("--use_flash_attention_2", action="store_true", help="Use flash_attention2 to replace the mixtral attention") + + args = parser.parse_args() + + args.cot = args.cot == "True" + args.few_shot = args.few_shot == "True" + args.with_prompt = args.with_prompt == "True" + args.do_save_csv = args.do_save_csv == "True" + args.constrained_decoding = args.constrained_decoding == "True" + if args.constrained_decoding is True: + args.n_times=max(args.n_times,1) + print(args) + + device = torch.device(0) + print(device) + evaluator=Mixtral_Evaluator( + choices=choices, + k=args.ntrain, + model_path=args.model_path, + device=device, + temperature=args.temperature, + load_in_4bit=args.load_in_4bit, + use_flash_attention_2=args.use_flash_attention_2, + verbose=args.verbose + ) + for i in range(args.n_times): + main(args,evaluator=evaluator,take=i) diff --git a/scripts/cmmlu/evaluator.py b/scripts/cmmlu/evaluator.py new file mode 100644 index 0000000..32f5d98 --- /dev/null +++ b/scripts/cmmlu/evaluator.py @@ -0,0 +1,47 @@ +# This code is modified from C-Eval Project: https://github.com/SJTU-LIT/ceval +import string +class Evaluator: + def __init__(self, choices, model_path, k=-1): + self.choices = choices + self.model_path = model_path + self.k = k + self.puncs = list(string.punctuation) + + def format_example(self, line, include_answer=True): + example = line['question'] + # print(example) + for choice in self.choices: + example += f'\n{choice}. {line[f"{choice}"]}' + example += '\n答案:' + if include_answer: + example += f'{line["answer"]}\n\n' + return example + + def generate_few_shot_prompt(self, subject, dev_df): + prompt = f"以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n" + k = self.k + if self.k == -1: + k = dev_df.shape[0] + for i in range(k): + prompt += self.format_example(dev_df.iloc[i, :]) + return prompt + + def eval_subject(self, subject_name, test_df, dev_df=None, few_shot=False, save_result_dir=None): + pass + + def normalize_answer(self,s): + + def white_space_fix(text): + return ' '.join(text.split()) + + def remove_punc(text): + exclude=set(self.puncs) + return ''.join(ch for ch in text if ch not in exclude) + + def lower(text): + return text.lower() + + return white_space_fix(remove_punc(lower(s))) + + def exact_match(self,pred, target): + return self.normalize_answer(pred)==self.normalize_answer(target) diff --git a/scripts/cmmlu/mixtral_evaluator.py b/scripts/cmmlu/mixtral_evaluator.py new file mode 100644 index 0000000..b825ffa --- /dev/null +++ b/scripts/cmmlu/mixtral_evaluator.py @@ -0,0 +1,240 @@ +# This code is modified from C-Eval Project: https://github.com/SJTU-LIT/ceval + +import os +import re +from tqdm import tqdm +import random +import numpy as np +import torch +from transformers import AutoModelForCausalLM, LlamaTokenizer, BitsAndBytesConfig +from transformers import GenerationConfig +from evaluator import Evaluator + + +class Mixtral_Evaluator(Evaluator): + def __init__(self, choices, k, model_path, device, temperature=0.2, load_in_4bit=False, use_flash_attention_2=False, verbose=False): + super(Mixtral_Evaluator, self).__init__(choices, model_path, k) + load_type = torch.float16 + self.model_path = model_path + self.device = device + self.verbose = verbose + self.load_in_4bit = load_in_4bit + self.use_flash_attention_2 = use_flash_attention_2 + self.tokenizer = LlamaTokenizer.from_pretrained(model_path, legacy=True) + quantization_config = BitsAndBytesConfig( + load_in_4bit=True, + load_in_8bit=False, + bnb_4bit_compute_dtype=load_type, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4" + ) + self.model = AutoModelForCausalLM.from_pretrained( + model_path, + quantization_config=quantization_config if self.load_in_4bit else None, + torch_dtype=load_type, + low_cpu_mem_usage=True, + device_map='auto', + attn_implementation="flash_attention_2" if self.use_flash_attention_2 else "sdpa" + ) + self.generation_config = GenerationConfig( + temperature=temperature, + top_k=40, + top_p=0.9, + do_sample=True, + num_beams=1, + repetition_penalty=1.1, + max_new_tokens=20 + ) + + self.sA_id = self.tokenizer.encode("A", add_special_tokens=False)[0] + self.sB_id = self.tokenizer.encode("B", add_special_tokens=False)[0] + self.sC_id = self.tokenizer.encode("C", add_special_tokens=False)[0] + self.sD_id = self.tokenizer.encode("D", add_special_tokens=False)[0] + self.A_id = self.tokenizer.encode(":A")[-1] + self.B_id = self.tokenizer.encode(":B")[-1] + self.C_id = self.tokenizer.encode(":C")[-1] + self.D_id = self.tokenizer.encode(":D")[-1] + + def eval_subject(self, subject_name, + test_df, + dev_df=None, + few_shot=False, + cot=False, + save_result_dir=None, + with_prompt=False, + constrained_decoding=False, + do_test=False): + all_answers = {} + if constrained_decoding is True: + self.generation_config.output_scores = True + self.generation_config.return_dict_in_generate = True + self.generation_config.max_new_tokens = 1 + self.generation_config.top_p = 1.0 + self.generation_config.top_k = 0 + + correct_num = 0 + if save_result_dir: + result = [] + score = [] + if few_shot: + if with_prompt: + history = self.generate_few_shot_prompt(subject_name, dev_df, cot=cot) + else: + history = self.generate_few_shot_noprompt(subject_name, dev_df, cot=cot) + else: + history = '' + answers = ['NA'] * len(test_df) if do_test is True else list(test_df['Answer']) + for row_index, row in tqdm(test_df.iterrows(), total=len(test_df)): + question = self.format_example(row, include_answer=False, cot=cot,with_prompt=with_prompt) + instruction = question + if with_prompt: + prompt_template = ( + "[INST] {instruction} [/INST]" + ) + + instruction = prompt_template.format_map({'instruction': instruction}) + instruction=history+instruction + + inputs = self.tokenizer(instruction, return_tensors="pt") + generation_output = self.model.generate( + input_ids = inputs["input_ids"].to(self.device), + attention_mask = inputs['attention_mask'].to(self.device), + eos_token_id=self.tokenizer.eos_token_id, + pad_token_id=self.tokenizer.eos_token_id, + generation_config = self.generation_config + ) + + _, length = inputs.input_ids.shape + if constrained_decoding is True: + logits = generation_output.scores[0][0] + + logits = logits.float().cpu().detach() + choices1_logits = logits[[self.sA_id,self.sB_id,self.sC_id,self.sD_id]] + choices2_logits = logits[[self.A_id,self.B_id,self.C_id,self.D_id]] + choicesAll_logits = (choices1_logits + choices2_logits).numpy() + assert not (np.any(np.isinf(choicesAll_logits)) or np.any(np.isnan(choicesAll_logits))) + ans = {0: "A", 1: "B", 2: "C", 3: "D"}[np.argmax(choicesAll_logits)] + response = self.tokenizer.decode([logits.argmax(-1).item()]) + else: + response = self.tokenizer.decode(generation_output[0, length:], skip_special_tokens=True) + ans, _ = self.extract_answer(row, response) + if ans == answers[row_index]: + correct_num += 1 + correct = 1 + else: + correct = 0 + if self.verbose is True: + print(f"\n======={str(row_index)}=======") + print(f"question: {question}\n") + print(f"response: {response}\n") + print(f"extracted answer: {ans}") + print(f"ground truth: {answers[row_index]} \n") + if save_result_dir: + result.append(response) + score.append(correct) + + all_answers[str(row_index)] = ans + + correct_ratio = 100*correct_num/len(answers) + + if save_result_dir: + test_df['model_output'] = result + test_df['correctness'] = score + test_df.to_csv(os.path.join(save_result_dir, f'{subject_name}_test.csv')) + + return correct_ratio, all_answers + + def format_example(self, line, include_answer=True, cot=False, with_prompt=False): + example = line['Question'] + suffix = "" + for choice in self.choices: + example += f'\n{choice}. {line[f"{choice}"]}' + if include_answer: + if cot: + example += "\n答案:让我们一步一步思考,\n" + \ + line["explanation"] + f"\n所以答案是{line['Answer']}。\n\n" + else: + example += '\n答案:' + suffix + line["Answer"] + '\n\n' + else: + if with_prompt is False: + if cot: + example += "\n答案:让我们一步一步思考,\n1." + else: + example += '\n答案:' + suffix + else: + if cot: + example += "\n答案是什么?让我们一步一步思考,\n1." + else: + example += '\n答案:' + return example + + def generate_few_shot_noprompt(self, subject, dev_df, cot=False): + prompt = f"以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n" + k = self.k + if self.k == -1: + k = dev_df.shape[0] + for i in range(k): + prompt += self.format_example( + dev_df.iloc[i, :], + include_answer=True, + cot=cot + ) + return prompt + + def generate_few_shot_prompt(self, subject, dev_df, cot=False): + prompt = f"以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n" + prompt_template = ( + "[INST] {instruction} [/INST]好的,我会结合{subject}相关知识回答" + ) + prompt = prompt_template.format_map({'instruction':prompt, "subject":subject}) + k = self.k + if self.k == -1: + k = dev_df.shape[0] + for i in range(k): + line=dev_df.iloc[i, :] + q=line['Question'] + for choice in self.choices: + q += f'\n{choice}. {line[f"{choice}"]}' + + a=line['Answer'] + prompt+="[INST] "+q+"\n答案: [/INST]"+a+"\n" + + return prompt + + def extract_answer(self, line, gen_ans): + m = re.findall(r'所以答案是(.+?)。', gen_ans, re.M) + if len(m) > 0 and m[-1] in self.choices: + return m[-1], True + answer_patterns = [ + r'([ABCD])是正确的', + r'选项([ABCD])正确', + r'答案为([ABCD])', + r'答案是([ABCD])', + r'答案([ABCD])', + r'选择([ABCD])', + r'答案:([ABCD])', + r'选择答案([ABCD])' + ] + # RE extraction + for answer_pattern in answer_patterns: + m = re.search(answer_pattern, gen_ans, re.M) + if m: + answer = m.group(1) + return answer, False + # only containing one choice-character + m = re.findall(r'[ABCD]', gen_ans, re.M) + if len(m) >= 1: + answer = m[0] + return answer, False + choices_dict = {} + pattern = "" + for c in self.choices: + choices_dict[str(line[f'{c}'])] = c + pattern += re.escape(str(line[f'{c}']))+"|" + pattern = pattern[:-1] + m = re.findall(pattern, gen_ans, re.M) + print("w/ escape:",repr(pattern),gen_ans,(len(m)>=1)) + if len(m) >= 1: + answer = choices_dict[m[0]] + return answer, False + return random.sample('ABCD', 1)[0], False diff --git a/scripts/longbench/config/dataset2maxlen.json b/scripts/longbench/config/dataset2maxlen.json new file mode 100644 index 0000000..79d0d99 --- /dev/null +++ b/scripts/longbench/config/dataset2maxlen.json @@ -0,0 +1,23 @@ +{ + "narrativeqa": 128, + "qasper": 128, + "multifieldqa_en": 64, + "multifieldqa_zh": 64, + "hotpotqa": 32, + "2wikimqa": 32, + "musique": 32, + "dureader": 128, + "gov_report": 512, + "qmsum": 512, + "multi_news": 512, + "vcsum": 512, + "trec": 64, + "triviaqa": 32, + "samsum": 128, + "lsht": 64, + "passage_count": 32, + "passage_retrieval_en": 32, + "passage_retrieval_zh": 32, + "lcc": 64, + "repobench-p": 64 +} \ No newline at end of file diff --git a/scripts/longbench/config/dataset2prompt.json b/scripts/longbench/config/dataset2prompt.json new file mode 100644 index 0000000..1c85f6b --- /dev/null +++ b/scripts/longbench/config/dataset2prompt.json @@ -0,0 +1,23 @@ +{ + "narrativeqa": "You are given a story, which can be either a novel or a movie script, and a question. Answer the question asconcisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nStory: {context}\n\nNow, answer the question based on the story asconcisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:", + "qasper": "You are given a scientific article and a question. Answer the question as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write \"unanswerable\". If the question is a yes/no question, answer \"yes\", \"no\", or \"unanswerable\". Do not provide any explanation.\n\nArticle: {context}\n\n Answer the question based on the above article as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write \"unanswerable\". If the question is a yes/no question, answer \"yes\", \"no\", or \"unanswerable\". Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:", + "multifieldqa_en": "Read the following text and answer briefly.\n\n{context}\n\nNow, answer the following question based on the above text, only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", + "multifieldqa_zh": "阅读以下文字并用中文简短回答:\n\n{context}\n\n现在请基于上面的文章回答下面的问题,只告诉我答案,不要输出任何其他字词。\n\n问题:{input}\n回答:", + "hotpotqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", + "2wikimqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", + "musique": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", + "dureader": "请基于给定的文章回答下述问题。\n\n文章:{context}\n\n请基于上述文章回答下面的问题。\n\n问题:{input}\n回答:", + "gov_report": "You are given a report by a government agency. Write a one-page summary of the report.\n\nReport:\n{context}\n\nNow, write a one-page summary of the report.\n\nSummary:", + "qmsum": "You are given a meeting transcript and a query containing a question or instruction. Answer the query in one or more sentences.\n\nTranscript:\n{context}\n\nNow, answer the query based on the above meeting transcript in one or more sentences.\n\nQuery: {input}\nAnswer:", + "multi_news": "You are given several news passages. Write a one-page summary of all news. \n\nNews:\n{context}\n\nNow, write a one-page summary of all the news.\n\nSummary:", + "vcsum": "下面有一段会议记录,请你阅读后,写一段总结,总结会议的内容。\n会议记录:\n{context}\n\n会议总结:", + "trec": "Please determine the type of the question below. Here are some examples of questions.\n\n{context}\n{input}", + "triviaqa": "Answer the question based on the given passage. Only give me the answer and do not output any other words. The following are some examples.\n\n{context}\n\n{input}", + "samsum": "Summarize the dialogue into a few short sentences. The following are some examples.\n\n{context}\n\n{input}", + "lsht": "请判断给定新闻的类别,下面是一些例子。\n\n{context}\n{input}", + "passage_count": "There are some paragraphs below sourced from Wikipedia. Some of them may be duplicates. Please carefully read these paragraphs and determine how many unique paragraphs there are after removing duplicates. In other words, how many non-repeating paragraphs are there in total?\n\n{context}\n\nPlease enter the final count of unique paragraphs after removing duplicates. The output format should only contain the number, such as 1, 2, 3, and so on.\n\nThe final answer is: ", + "passage_retrieval_en": "Here are 30 paragraphs from Wikipedia, along with an abstract. Please determine which paragraph the abstract is from.\n\n{context}\n\nThe following is an abstract.\n\n{input}\n\nPlease enter the number of the paragraph that the abstract is from. The answer format must be like \"Paragraph 1\", \"Paragraph 2\", etc.\n\nThe answer is: ", + "passage_retrieval_zh": "以下是若干段落文字,以及其中一个段落的摘要。请确定给定的摘要出自哪一段。\n\n{context}\n\n下面是一个摘要\n\n{input}\n\n请输入摘要所属段落的编号。答案格式必须是\"段落1\",\"段落2\"等格式\n\n答案是:", + "lcc": "Please complete the code given below. \n{context}Next line of code:\n", + "repobench-p": "Please complete the code given below. \n{context}{input}Next line of code:\n" +} \ No newline at end of file diff --git a/scripts/longbench/eval.py b/scripts/longbench/eval.py new file mode 100644 index 0000000..d769d6b --- /dev/null +++ b/scripts/longbench/eval.py @@ -0,0 +1,115 @@ +# The script is from https://github.com/THUDM/LongBench +import os +import json +import argparse +import numpy as np + +from metrics import ( + qa_f1_score, + rouge_zh_score, + qa_f1_zh_score, + rouge_score, + classification_score, + retrieval_score, + retrieval_zh_score, + count_score, + code_sim_score, +) + +dataset2metric = { + "narrativeqa": qa_f1_score, + "qasper": qa_f1_score, + "multifieldqa_en": qa_f1_score, + "multifieldqa_zh": qa_f1_zh_score, + "hotpotqa": qa_f1_score, + "2wikimqa": qa_f1_score, + "musique": qa_f1_score, + "dureader": rouge_zh_score, + "gov_report": rouge_score, + "qmsum": rouge_score, + "multi_news": rouge_score, + "vcsum": rouge_zh_score, + "trec": classification_score, + "triviaqa": qa_f1_score, + "samsum": rouge_score, + "lsht": classification_score, + "passage_retrieval_en": retrieval_score, + "passage_count": count_score, + "passage_retrieval_zh": retrieval_zh_score, + "lcc": code_sim_score, + "repobench-p": code_sim_score, +} + + +def parse_args(args=None): + parser = argparse.ArgumentParser() + parser.add_argument('--output_dir') + parser.add_argument('--e', action='store_true', help="Evaluate on LongBench-E") + return parser.parse_args(args) + + +def scorer_e(dataset, predictions, answers, lengths, all_classes): + scores = {"0-4k": [], "4-8k": [], "8k+": []} + for (prediction, ground_truths, length) in zip(predictions, answers, lengths): + score = 0. + if dataset in ["trec", "triviaqa", "samsum", "lsht"]: + prediction = prediction.lstrip('\n').split('\n')[0] + for ground_truth in ground_truths: + score = max(score, dataset2metric[dataset](prediction, ground_truth, all_classes=all_classes)) + if length < 4000: + scores["0-4k"].append(score) + elif length < 8000: + scores["4-8k"].append(score) + else: + scores["8k+"].append(score) + for key in scores.keys(): + scores[key] = round(100 * np.mean(scores[key]), 2) + return scores + + +def scorer(dataset, predictions, answers, all_classes): + total_score = 0. + for (prediction, ground_truths) in zip(predictions, answers): + score = 0. + if dataset in ["trec", "triviaqa", "samsum", "lsht"]: + prediction = prediction.lstrip('\n').split('\n')[0] + for ground_truth in ground_truths: + score = max(score, dataset2metric[dataset](prediction, ground_truth, all_classes=all_classes)) + total_score += score + return round(100 * total_score / len(predictions), 2) + + +if __name__ == '__main__': + args = parse_args() + scores = dict() + if args.e: + path = f"{args.output_dir}/pred_e/" + else: + path = f"{args.output_dir}/pred/" + all_files = os.listdir(path) + print("Evaluating on:", all_files) + for filename in all_files: + if not filename.endswith("jsonl"): + continue + predictions, answers, lengths = [], [], [] + dataset = filename.split('.')[0] + with open(f"{path}{filename}", "r", encoding="utf-8") as f: + print(filename) + for line in f: + data = json.loads(line) + predictions.append(data["pred"]) + answers.append(data["answers"]) + all_classes = data["all_classes"] + if "length" in data: + lengths.append(data["length"]) + if args.e: + score = scorer_e(dataset, predictions, answers, lengths, all_classes) + else: + score = scorer(dataset, predictions, answers, all_classes) + scores[dataset] = score + if args.e: + out_path = f"{args.output_dir}/pred_e/result.json" + else: + out_path = f"{args.output_dir}/pred/result.json" + with open(out_path, "w") as f: + json.dump(scores, f, ensure_ascii=False, indent=4) diff --git a/scripts/longbench/metrics.py b/scripts/longbench/metrics.py new file mode 100644 index 0000000..e60531e --- /dev/null +++ b/scripts/longbench/metrics.py @@ -0,0 +1,154 @@ +# The script is from https://github.com/THUDM/LongBench +import re +import string + +import jieba +from fuzzywuzzy import fuzz +import difflib + +from collections import Counter +from rouge import Rouge + +def normalize_answer(s): + """Lower text and remove punctuation, articles and extra whitespace.""" + + def remove_articles(text): + return re.sub(r"\b(a|an|the)\b", " ", text) + + def white_space_fix(text): + return " ".join(text.split()) + + def remove_punc(text): + exclude = set(string.punctuation) + return "".join(ch for ch in text if ch not in exclude) + + def lower(text): + return text.lower() + + return white_space_fix(remove_articles(remove_punc(lower(s)))) + + +def normalize_zh_answer(s): + """Lower text and remove punctuation, extra whitespace.""" + + def white_space_fix(text): + return "".join(text.split()) + + def remove_punc(text): + cn_punctuation = "!?。。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏." + all_punctuation = set(string.punctuation + cn_punctuation) + return "".join(ch for ch in text if ch not in all_punctuation) + + def lower(text): + return text.lower() + + return white_space_fix(remove_punc(lower(s))) + +def count_score(prediction, ground_truth, **kwargs): + numbers = re.findall(r"\d+", prediction) + right_num = 0 + for number in numbers: + if str(number) == str(ground_truth): + right_num += 1 + final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers) + return float(final_score) + +def retrieval_score(prediction, ground_truth, **kwargs): + pattern = r'Paragraph (\d+)' + matches = re.findall(pattern, ground_truth) + ground_truth_id = matches[0] + numbers = re.findall(r"\d+", prediction) + right_num = 0 + for number in numbers: + if str(number) == str(ground_truth_id): + right_num += 1 + final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers) + return float(final_score) + +def retrieval_zh_score(prediction, ground_truth, **kwargs): + pattern = r'段落(\d+)' + matches = re.findall(pattern, ground_truth) + ground_truth_id = matches[0] + numbers = re.findall(r"\d+", prediction) + right_num = 0 + for number in numbers: + if str(number) == str(ground_truth_id): + right_num += 1 + final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers) + return float(final_score) + +def code_sim_score(prediction, ground_truth, **kwargs): + all_lines = prediction.lstrip('\n').split('\n') + prediction = "" + for line in all_lines: + if ('`' not in line) and ('#' not in line) and ('//' not in line): + prediction = line + break + return (fuzz.ratio(prediction, ground_truth) / 100) + +def classification_score(prediction, ground_truth, **kwargs): + em_match_list = [] + all_classes = kwargs["all_classes"] + for class_name in all_classes: + if class_name in prediction: + em_match_list.append(class_name) + for match_term in em_match_list: + if match_term in ground_truth and match_term != ground_truth: + em_match_list.remove(match_term) + if em_match_list != 0: + if ground_truth in em_match_list: + score = (1.0 / len(em_match_list)) + else: + score = 0.0 + else: + best_match = None + highest_similarity = 0 + for string in all_classes: + similarity = difflib.SequenceMatcher(None, string, prediction).ratio() + if similarity > highest_similarity: + highest_similarity = similarity + best_match = string + score = float(best_match == ground_truth) + return score + +def rouge_score(prediction, ground_truth, **kwargs): + rouge = Rouge() + try: + scores = rouge.get_scores([prediction], [ground_truth], avg=True) + except Exception: + return 0.0 + return scores["rouge-l"]["f"] + +def rouge_zh_score(prediction, ground_truth, **kwargs): + prediction = " ".join(list(jieba.cut(prediction, cut_all=False))) + ground_truth = " ".join(list(jieba.cut(ground_truth, cut_all=False))) + score = rouge_score(prediction, ground_truth) + return score + +def f1_score(prediction, ground_truth, **kwargs): + common = Counter(prediction) & Counter(ground_truth) + num_same = sum(common.values()) + if num_same == 0: + return 0 + precision = 1.0 * num_same / len(prediction) + recall = 1.0 * num_same / len(ground_truth) + f1 = (2 * precision * recall) / (precision + recall) + return f1 + +def qa_f1_score(prediction, ground_truth, **kwargs): + normalized_prediction = normalize_answer(prediction) + normalized_ground_truth = normalize_answer(ground_truth) + + prediction_tokens = normalized_prediction.split() + ground_truth_tokens = normalized_ground_truth.split() + return f1_score(prediction_tokens, ground_truth_tokens) + + +def qa_f1_zh_score(prediction, ground_truth, **kwargs): + prediction_tokens = list(jieba.cut(prediction, cut_all=False)) + ground_truth_tokens = list(jieba.cut(ground_truth, cut_all=False)) + prediction_tokens = [normalize_zh_answer(token) for token in prediction_tokens] + ground_truth_tokens = [normalize_zh_answer(token) for token in ground_truth_tokens] + prediction_tokens = [token for token in prediction_tokens if len(token) > 0] + ground_truth_tokens = [token for token in ground_truth_tokens if len(token) > 0] + return f1_score(prediction_tokens, ground_truth_tokens) diff --git a/scripts/longbench/pred_mixtral.py b/scripts/longbench/pred_mixtral.py new file mode 100644 index 0000000..79b7e53 --- /dev/null +++ b/scripts/longbench/pred_mixtral.py @@ -0,0 +1,184 @@ +# The script is modified from https://github.com/THUDM/LongBench/blob/main/pred.py +from datasets import load_dataset +import torch +import random +import numpy as np +import json +from transformers import LlamaTokenizer, AutoModelForCausalLM +from transformers import BitsAndBytesConfig +from tqdm import tqdm +import os +import argparse +import sys +parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +sys.path.append(parent_dir) +dir_path = os.path.dirname(os.path.realpath(__file__)) + +parser = argparse.ArgumentParser() +parser.add_argument('--model_path', type=str) +parser.add_argument('--load_in_4bit',action='store_true') +parser.add_argument('--load_in_8bit',action='store_true') +parser.add_argument('--predict_on',type=str, default='zh') +parser.add_argument('--output_dir',type=str, default='pred') +parser.add_argument('--gpus',type=str, default=None) +parser.add_argument('--max_length',type=int, default=4096-512) +parser.add_argument('--e', action='store_true', help="Evaluate on LongBench-E") +parser.add_argument('--use_flash_attention_2', action='store_true', help="Use flash attention to replace the mixtral attention") + +args = parser.parse_args() + +model_path = args.model_path +load_in_4bit = args.load_in_4bit +load_in_8bit = args.load_in_8bit +predict_on = args.predict_on +output_dir = args.output_dir +gpus=args.gpus +max_length = args.max_length + +DO_SAMPLE =True +TEMPERATURE = 0.2 +REPETITION_PENALTY = 1.1 +TOP_P = 0.95 +TOP_K = 40 + +if gpus is not None: + os.environ["CUDA_VISIBLE_DEVICES"] = gpus + + +def get_pred(model, tokenizer, data, max_length, max_gen, prompt_format, dataset, device): + preds = [] + for json_obj in tqdm(data): + prompt = prompt_format.format(**json_obj) + # truncate to fit max_length (we suggest truncate in the middle, since the left and right side may contain crucial instructions) + tokenized_prompt = tokenizer(prompt, truncation=False, return_tensors="pt").input_ids[0] + if len(tokenized_prompt) > max_length: + half = int(max_length/2) + prompt = tokenizer.decode(tokenized_prompt[:half], skip_special_tokens=True)+tokenizer.decode(tokenized_prompt[-half:], skip_special_tokens=True) + + input_data = tokenizer(prompt, truncation=False, return_tensors="pt").to(device) + context_length = input_data.input_ids.shape[-1] + if dataset == "samsum": # prevent illegal output on samsum (model endlessly repeat "\nDialogue"), might be a prompting issue + output = model.generate( + **input_data, + max_new_tokens=max_gen, + num_beams=1, + do_sample=DO_SAMPLE, + repetition_penalty = REPETITION_PENALTY, + top_p = TOP_P, + top_k = TOP_K, + temperature=TEMPERATURE, + min_length=context_length+1, + eos_token_id=[tokenizer.eos_token_id, tokenizer.encode("\n", add_special_tokens=False)[-1]], + pad_token_id=tokenizer.eos_token_id + )[0] + else: + output = model.generate( + **input_data, + max_new_tokens=max_gen, + num_beams=1, + do_sample=DO_SAMPLE, + repetition_penalty = REPETITION_PENALTY, + top_p = TOP_P, + top_k = TOP_K, + temperature=TEMPERATURE, + pad_token_id=tokenizer.eos_token_id + )[0] + pred = tokenizer.decode(output[context_length:], skip_special_tokens=True) + #print(pred) + preds.append({"pred": pred, "answers": json_obj["answers"], "all_classes": json_obj["all_classes"], "length": json_obj["length"]}) + return preds + +def seed_everything(seed): + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + np.random.seed(seed) + random.seed(seed) + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + torch.cuda.manual_seed_all(seed) + + +if __name__ == '__main__': + seed_everything(42) + load_type = torch.float16 + if torch.cuda.is_available(): + device = torch.device(0) + else: + device = torch.device('cpu') + + if args.e: + en_datasets = [ "hotpotqa","2wikimqa", + "qasper", "multifieldqa_en", "gov_report", + "trec", "samsum", "triviaqa", + "passage_count", "passage_retrieval_en", "multi_news"] + zh_datasets = [] + code_datasets = [ "lcc", "repobench-p" ] + if not os.path.exists(f"{output_dir}/pred_e"): + os.makedirs(f"{output_dir}/pred_e") + else: + en_datasets = [ "hotpotqa","2wikimqa", "musique", "narrativeqa", + "qasper", "multifieldqa_en", "gov_report", + "qmsum", "trec", "samsum", "triviaqa", + "passage_count", "passage_retrieval_en", "multi_news"] + zh_datasets = [ "dureader", "multifieldqa_zh", + "vcsum","lsht", "passage_retrieval_zh"] + code_datasets = [ "lcc", "repobench-p" ] + + if not os.path.exists(f"{output_dir}/pred"): + os.makedirs(f"{output_dir}/pred") + + datasets = [] + for data_type in predict_on.split(','): + if data_type == 'zh': + datasets += zh_datasets + elif data_type == 'en': + datasets += en_datasets + elif data_type == 'code': + datasets += code_datasets + print(datasets) + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + tokenizer = LlamaTokenizer.from_pretrained(model_path, legacy=True) + model = None + if args.load_in_4bit or args.load_in_8bit: + quantization_config = BitsAndBytesConfig( + load_in_4bit=args.load_in_4bit, + load_in_8bit=args.load_in_8bit, + bnb_4bit_compute_dtype=load_type, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4" + ) + model = AutoModelForCausalLM.from_pretrained( + model_path, + torch_dtype=load_type, + low_cpu_mem_usage=True, + device_map='auto', + quantization_config=quantization_config if (args.load_in_4bit or args.load_in_8bit) else None, + attn_implementation="flash_attention_2" if args.use_flash_attention_2 else "sdpa" + ) + model = model.eval() + model_vocab_size = model.get_input_embeddings().weight.size(0) + print(f"Vocab of the base model: {model_vocab_size}") + tokenizer_vocab_size = len(tokenizer) + print(f"Vocab of the tokenizer: {tokenizer_vocab_size}") + + # we design specific prompt format and max generation length for each task, feel free to modify them to optimize model output + dataset2prompt = json.load(open(dir_path + "/config/dataset2prompt.json", "r")) + dataset2maxlen = json.load(open(dir_path + "/config/dataset2maxlen.json", "r")) + # predict on each dataset + for dataset in datasets: + print(f"Loading dataset {dataset}") + if args.e: + data = load_dataset('THUDM/LongBench', dataset+'_e', split='test') + output_path = f"{output_dir}/pred_e/{dataset}.jsonl" + else: + data = load_dataset('THUDM/LongBench', dataset, split='test') + output_path = f"{output_dir}/pred/{dataset}.jsonl" + prompt_format = dataset2prompt[dataset] + max_gen = dataset2maxlen[dataset] + preds = get_pred(model, tokenizer, data, max_length, max_gen, prompt_format, dataset, device) + with open(output_path, "w", encoding="utf-8") as f: + for pred in preds: + json.dump(pred, f, ensure_ascii=False) + f.write('\n') diff --git a/scripts/mmlu/categories.py b/scripts/mmlu/categories.py new file mode 100644 index 0000000..dd6621a --- /dev/null +++ b/scripts/mmlu/categories.py @@ -0,0 +1,66 @@ +subcategories = { + "abstract_algebra": ["math"], + "anatomy": ["health"], + "astronomy": ["physics"], + "business_ethics": ["business"], + "clinical_knowledge": ["health"], + "college_biology": ["biology"], + "college_chemistry": ["chemistry"], + "college_computer_science": ["computer science"], + "college_mathematics": ["math"], + "college_medicine": ["health"], + "college_physics": ["physics"], + "computer_security": ["computer science"], + "conceptual_physics": ["physics"], + "econometrics": ["economics"], + "electrical_engineering": ["engineering"], + "elementary_mathematics": ["math"], + "formal_logic": ["philosophy"], + "global_facts": ["other"], + "high_school_biology": ["biology"], + "high_school_chemistry": ["chemistry"], + "high_school_computer_science": ["computer science"], + "high_school_european_history": ["history"], + "high_school_geography": ["geography"], + "high_school_government_and_politics": ["politics"], + "high_school_macroeconomics": ["economics"], + "high_school_mathematics": ["math"], + "high_school_microeconomics": ["economics"], + "high_school_physics": ["physics"], + "high_school_psychology": ["psychology"], + "high_school_statistics": ["math"], + "high_school_us_history": ["history"], + "high_school_world_history": ["history"], + "human_aging": ["health"], + "human_sexuality": ["culture"], + "international_law": ["law"], + "jurisprudence": ["law"], + "logical_fallacies": ["philosophy"], + "machine_learning": ["computer science"], + "management": ["business"], + "marketing": ["business"], + "medical_genetics": ["health"], + "miscellaneous": ["other"], + "moral_disputes": ["philosophy"], + "moral_scenarios": ["philosophy"], + "nutrition": ["health"], + "philosophy": ["philosophy"], + "prehistory": ["history"], + "professional_accounting": ["other"], + "professional_law": ["law"], + "professional_medicine": ["health"], + "professional_psychology": ["psychology"], + "public_relations": ["politics"], + "security_studies": ["politics"], + "sociology": ["culture"], + "us_foreign_policy": ["politics"], + "virology": ["health"], + "world_religions": ["philosophy"], +} + +categories = { + "STEM": ["physics", "chemistry", "biology", "computer science", "math", "engineering"], + "humanities": ["history", "philosophy", "law"], + "social sciences": ["politics", "culture", "economics", "geography", "psychology"], + "other (business, health, misc.)": ["other", "business", "health"], +} \ No newline at end of file diff --git a/scripts/mmlu/eval.py b/scripts/mmlu/eval.py new file mode 100644 index 0000000..7158c2c --- /dev/null +++ b/scripts/mmlu/eval.py @@ -0,0 +1,197 @@ +# modified from https://github.com/baichuan-inc/Baichuan-7B/blob/main/evaluation/evaluate_mmlu.py +import argparse +import os +import torch +import numpy as np +import pandas as pd +from categories import subcategories, categories +from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig +choices = ["A", "B", "C", "D"] + + +def format_subject(subject): + line = subject.split("_") + s = "" + for entry in line: + s += " " + entry + return s + + +def format_example(df, idx, include_answer=True): + prompt = df.iloc[idx, 0] + k = df.shape[1] - 2 + for j in range(k): + prompt += "\n{}. {}".format(choices[j], df.iloc[idx, j + 1]) + prompt += "\nAnswer:" + if include_answer: + prompt += " {}\n\n".format(df.iloc[idx, k + 1]) + return prompt + + +def gen_prompt(train_df, subject, k=-1): + prompt = "The following are multiple choice questions (with answers) about {}.\n\n".format( + format_subject(subject) + ) + if k == -1: + k = train_df.shape[0] + for i in range(k): + prompt += format_example(train_df, i) + return prompt + + +@torch.no_grad() +def mmlu_eval(args, subject, model, tokenizer, dev_df, test_df): + cors = [] + all_probs = [] + + for i in range(test_df.shape[0]): + # get prompt and make sure it fits + k = args.ntrain + prompt_end = format_example(test_df, i, include_answer=False) + train_prompt = gen_prompt(dev_df, subject, k) + prompt = train_prompt + prompt_end + input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda() + + label = test_df.iloc[i, test_df.shape[1] - 1] + + logits = model( + input_ids=input_ids, + ).logits[:,-1].flatten() + + probs = ( + torch.nn.functional.softmax( + torch.tensor( + [ + logits[tokenizer("A").input_ids[-1]], + logits[tokenizer("B").input_ids[-1]], + logits[tokenizer("C").input_ids[-1]], + logits[tokenizer("D").input_ids[-1]], + ] + ), + dim=0, + ) + .detach() + .cpu() + .to(torch.float32) + .numpy() + ) + pred = {0: "A", 1: "B", 2: "C", 3: "D"}[np.argmax(probs)] + + cor = pred == label + cors.append(cor) + all_probs.append(probs) + + acc = np.mean(cors) + cors = np.array(cors) + + all_probs = np.array(all_probs) + print("Average accuracy {:.3f} - {}".format(acc, subject)) + + return cors, acc, all_probs + + +def main(args): + tokenizer = AutoTokenizer.from_pretrained(args.model_path, use_fast=False) + quantization_config = BitsAndBytesConfig( + load_in_4bit=True, + load_in_8bit=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4" + ) + model = AutoModelForCausalLM.from_pretrained( + args.model_path, + quantization_config=quantization_config if args.load_in_4bit else None, + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + device_map='auto', + attn_implementation="flash_attention_2" if args.use_flash_attention_2 else "sdpa" + ).eval() + subjects = sorted( + [ + f.split("_test.csv")[0] + for f in os.listdir(os.path.join(args.data_dir, "test")) + if "_test.csv" in f + ] + ) + + if not os.path.exists(args.save_dir): + os.makedirs(args.save_dir) + if not os.path.exists(os.path.join(args.save_dir, "results")): + os.makedirs(os.path.join(args.save_dir, "results")) + + all_cors = [] + subcat_cors = { + subcat: [] for subcat_lists in subcategories.values() for subcat in subcat_lists + } + cat_cors = {cat: [] for cat in categories} + + for subject in subjects: + dev_df = pd.read_csv( + os.path.join(args.data_dir, "dev", subject + "_dev.csv"), header=None + )[: args.ntrain] + if args.do_test: + test_df = pd.read_csv( + os.path.join(args.data_dir, "test", subject + "_test.csv"), header=None + ) + else: + test_df = pd.read_csv( + os.path.join(args.data_dir, "val", subject + "_val.csv"), header=None + ) + + cors, _, probs = mmlu_eval(args, subject, model, tokenizer, dev_df, test_df) + subcats = subcategories[subject] + for subcat in subcats: + subcat_cors[subcat].append(cors) + for key in categories.keys(): + if subcat in categories[key]: + cat_cors[key].append(cors) + all_cors.append(cors) + + test_df["correct"] = cors + for j in range(probs.shape[1]): + choice = choices[j] + test_df["choice{}_probs".format(choice)] = probs[:, j] + test_df.to_csv( + os.path.join( + args.save_dir, "results", f"{subject}.csv" + ), + index=None, + ) + + for subcat in subcat_cors: + subcat_acc = np.mean(np.concatenate(subcat_cors[subcat])) + print("Average accuracy {:.3f} - {}".format(subcat_acc, subcat)) + + for cat in cat_cors: + cat_acc = np.mean(np.concatenate(cat_cors[cat])) + print("Average accuracy {:.3f} - {}".format(cat_acc, cat)) + weighted_acc = np.mean(np.concatenate(all_cors)) + print("Average accuracy: {:.3f}".format(weighted_acc)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--ntrain", "-k", type=int, default=5) + parser.add_argument("--ngpu", "-g", type=int, default=8) + parser.add_argument("--data_dir", "-d", type=str, default="data") + parser.add_argument("--save_dir", "-s", type=str, default="results") + parser.add_argument( + "--model_path", + "-m", + type=str, + ) + parser.add_argument( + "--do_test", + action="store_true" + ) + parser.add_argument( + "--load_in_4bit", + action="store_true" + ) + parser.add_argument( + "--use_flash_attention_2", + action="store_true" + ) + args = parser.parse_args() + main(args) \ No newline at end of file