Official Code for EMNLP 2023 Paper "KCTS: Knowledge-Constrained Tree Search Decoding with Token-Level Hallucination Detection" (https://arxiv.org/abs/2310.09044).
```bash
pip install -r requirements.txt
pip install -e .
```
- First, download the Wizard of Wikipedia (WoW) dataset through ParlAI.
- Then, preprocess it:
```bash
export WOW_PATH=<PATH to WOW DATASET>
sh scripts/shell/data_process/preprocess_wow.sh 20 $WOW_PATH
```
- Generate partial negative data:
```bash
bash scripts/shell/data_process/partial_neg_gen.sh 0 wow 16            # for WoW
bash scripts/shell/data_process/partial_neg_gen.sh 0 cnn_dailymail 16  # for CNN/DM
```
- Sample random negative data (WoW only):
```bash
bash scripts/shell/data_process/random_neg.sh wow
```
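- Mix the datasets to your liking, for example:
```python
from datasets import concatenate_datasets, load_from_disk

partial_data_path = "<CHANGE HERE>"  # e.g., output of partial_neg_gen.sh
random_data_path = "<CHANGE HERE>"   # e.g., output of random_neg.sh
SAVE_PATH = "<CHANGE HERE>"          # where to write the mixed dataset

partial_data = load_from_disk(partial_data_path)
random_data = load_from_disk(random_data_path)

# Concatenate both sources, then hold out 10% as a test split.
merged_dataset = concatenate_datasets([partial_data, random_data])
merged_dataset = merged_dataset.train_test_split(test_size=0.1)
merged_dataset.save_to_disk(SAVE_PATH)
```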
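"To your liking" leaves the mixing ratio open. The sketch below is a hypothetical helper (not part of this repo) that keeps every partial negative and subsamples the random negatives so that they make up roughly a chosen fraction `random_frac` of the mix; plain lists stand in for the datasets.

```python
import random
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def num_random_to_keep(n_partial: int, n_random: int, random_frac: float) -> int:
    """How many random negatives to keep so they form about `random_frac` of the mix."""
    if not 0.0 <= random_frac < 1.0:
        raise ValueError("random_frac must be in [0, 1)")
    # Solve k / (n_partial + k) = random_frac for k, capped by what is available.
    wanted = round(n_partial * random_frac / (1.0 - random_frac))
    return min(wanted, n_random)


def mix(partial: Sequence[T], random_neg: Sequence[T], random_frac: float, seed: int = 0) -> List[T]:
    """Keep every partial negative plus a seeded random subset of the random negatives."""
    k = num_random_to_keep(len(partial), len(random_neg), random_frac)
    rng = random.Random(seed)
    return list(partial) + rng.sample(list(random_neg), k)
```

On the actual `datasets.Dataset` objects, the same count `k` can be applied with `random_data.shuffle(seed=0).select(range(k))` before calling `concatenate_datasets`; `train_test_split` shuffles by default, so the concatenation order does not leak into the splits.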
```bash
# The numbers are positional command-line arguments of each training script; details are at the top of the script file.
sh scripts/shell/train/train_t5_token_classifier.sh 0 EOS 0 0 0 0       # train f
sh scripts/shell/train/train_t5_token_classifier.sh 0 RIPA 0 0 0 1      # fine-tune RIPA starting from f
sh scripts/shell/train/train_t5_token_classifier_cnn.sh 0 RIPA 0 0 0 0  # CNN/DM
```
```bash
# FUDGE
sh scripts/shell/guided_run.sh 0 fudge RAND wow 8 0 0 0 ''
# NADO
sh scripts/shell/guided_run.sh 0 nado ALL wow 8 1 0 0 ''
# KWD
sh scripts/shell/guided_run.sh 0 fudge RIPA wow 8 0 0 0 ''
# KCTS
sh scripts/shell/ppl_mcts_run.sh 0 RIPA '' wow 8 0 0 0 0 0
```
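The `fudge` runs presumably follow the weighted-decoding rule of FUDGE (Yang & Klein, 2021), which rescales the LM's next-token distribution by a discriminator's prediction on each candidate prefix: $p(x_t \mid x_{<t}) \propto p_{\mathrm{LM}}(x_t \mid x_{<t}) \cdot f(x_{\le t})$. The sketch below shows only that generic rule over the LM's top-k candidates, not this repo's implementation; `classifier` is a hypothetical stand-in for a token-level classifier such as $f$ that returns the probability that the prefix is still grounded.

```python
import math
from typing import Callable, Dict, List


def weighted_next_token_probs(
    lm_logprobs: Dict[str, float],
    prefix: List[str],
    classifier: Callable[[List[str]], float],
    top_k: int = 5,
) -> Dict[str, float]:
    """FUDGE-style reweighting of the LM's top-k next-token candidates.

    `classifier(tokens)` is a hypothetical stand-in returning P(grounded | tokens)
    in (0, 1]; a value of exactly 0 would make math.log fail.
    """
    # Rescore only the k most likely LM candidates.
    candidates = sorted(lm_logprobs, key=lm_logprobs.get, reverse=True)[:top_k]
    scores = {
        tok: lm_logprobs[tok] + math.log(classifier(prefix + [tok]))
        for tok in candidates
    }
    # Renormalize in a numerically stable way (subtract the max before exp).
    max_score = max(scores.values())
    z = sum(math.exp(s - max_score) for s in scores.values())
    return {tok: math.exp(s - max_score) / z for tok, s in scores.items()}
```

For example, with LM probabilities 0.5 / 0.4 / 0.1 for "Paris" / "London" / "cheese" and a classifier that gives 0.9 to "Paris" and 0.1 otherwise, the reweighted distribution becomes 0.9 / 0.08 / 0.02: the classifier sharpens toward the grounded continuation without ever promoting tokens outside the top-k.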
- Guiding OpenAI models requires a RIPA classifier trained on GPT-2; check out `scripts/shell/train/train_token_classifier_gpt.sh`.
```bash
export EXP_ROOT=<ROOT DIRECTORY FOR EXPERIMENT>
sh scripts/shell/openai_guided_run.sh 0 RIPA 4 $EXP_ROOT 0 0 3 0 0 0
```
We evaluate with UniEval (Zhong et al., 2022), MFMA (Lee et al., 2022; for summarization), and token-based metrics.
```bash
sh scripts/eval/unieval.sh
```
- One can also evaluate the confidence of $f$ using the `scripts/eval/class_prob.sh` script.
- Also see `scripts/eval/test_t5_token_classifier.sh` to evaluate the classifier performance.
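The token-based metrics are not spelled out here. As an illustration only, the sketch below computes unigram F1, a common token-overlap metric for knowledge-grounded generation (e.g., between a generated response and its knowledge snippet). Lowercasing and whitespace tokenization are assumptions; the repo's evaluation scripts may normalize text differently.

```python
from collections import Counter


def unigram_f1(hypothesis: str, reference: str) -> float:
    """Unigram F1 between two strings (lowercased, whitespace-tokenized)."""
    hyp = hypothesis.lower().split()
    ref = reference.lower().split()
    # Multiset intersection: each shared token counts min(count_in_hyp, count_in_ref) times.
    overlap = sum((Counter(hyp) & Counter(ref)).values())
    if overlap == 0:
        return 0.0  # also covers empty inputs, avoiding division by zero
    precision = overlap / len(hyp)
    recall = overlap / len(ref)
    return 2 * precision * recall / (precision + recall)
```

For instance, "the cat sat" against "the cat is here" shares two tokens, giving precision 2/3, recall 1/2, and F1 = 4/7.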