diff --git a/datasets/download_text_classification.sh b/datasets/download_text_classification.sh
index 3f654d19..6b3672ac 100755
--- a/datasets/download_text_classification.sh
+++ b/datasets/download_text_classification.sh
@@ -1,43 +1,50 @@
 #!/bin/sh
-DIR="./TextClassification"
+DIR="./datasets/TextClassification"
 mkdir $DIR
 cd $DIR
 
-rm -rf mnli
-wget --content-disposition https://cloud.tsinghua.edu.cn/f/33182c22cb594e88b49b/?dl=1
-tar -zxvf mnli.tar.gz
-rm -rf mnli.tar.gz
-
-rm -rf agnews
-wget --content-disposition https://cloud.tsinghua.edu.cn/f/0fb6af2a1e6647b79098/?dl=1
-tar -zxvf agnews.tar.gz
-rm -rf agnews.tar.gz
-
-rm -rf dbpedia
-wget --content-disposition https://cloud.tsinghua.edu.cn/f/362d3cdaa63b4692bafb/?dl=1
-tar -zxvf dbpedia.tar.gz
-rm -rf dbpedia.tar.gz
-
-rm -rf imdb
-wget --content-disposition https://cloud.tsinghua.edu.cn/f/37bd6cb978d342db87ed/?dl=1
-tar -zxvf imdb.tar.gz
-rm -rf imdb.tar.gz
-
-rm -rf SST-2
-wget --content-disposition https://cloud.tsinghua.edu.cn/f/bccfdb243eca404f8bf3/?dl=1
-tar -zxvf SST-2.tar.gz
-rm -rf SST-2.tar.gz
-
-rm -rf amazon
-wget --content-disposition https://cloud.tsinghua.edu.cn/f/e00a4c44aaf844cdb6c9/?dl=1
-tar -zxvf amazon.tar.gz
-mv datasets/amazon/ amazon
-rm -rf ./datasets
-rm -rf amazon.tar.gz
-
-rm -rf yahoo_answers_topics
-wget --content-disposition https://cloud.tsinghua.edu.cn/f/79257038afaa4730a03f/?dl=1
-tar -zxvf yahoo_answers_topics.tar.gz
-rm -rf yahoo_answers_topics.tar.gz
+# rm -rf mnli
+# wget --content-disposition https://cloud.tsinghua.edu.cn/f/33182c22cb594e88b49b/?dl=1
+# tar -zxvf mnli.tar.gz
+# rm -rf mnli.tar.gz
+
+# rm -rf agnews
+# wget --content-disposition https://cloud.tsinghua.edu.cn/f/0fb6af2a1e6647b79098/?dl=1
+# tar -zxvf agnews.tar.gz
+# rm -rf agnews.tar.gz
+
+# rm -rf dbpedia
+# wget --content-disposition https://cloud.tsinghua.edu.cn/f/362d3cdaa63b4692bafb/?dl=1
+# tar -zxvf dbpedia.tar.gz
+# rm -rf dbpedia.tar.gz
+
+# rm -rf imdb
+# wget --content-disposition https://cloud.tsinghua.edu.cn/f/37bd6cb978d342db87ed/?dl=1
+# tar -zxvf imdb.tar.gz
+# rm -rf imdb.tar.gz
+
+# rm -rf SST-2
+# wget --content-disposition https://cloud.tsinghua.edu.cn/f/bccfdb243eca404f8bf3/?dl=1
+# tar -zxvf SST-2.tar.gz
+# rm -rf SST-2.tar.gz
+
+# rm -rf amazon
+# wget --content-disposition https://cloud.tsinghua.edu.cn/f/e00a4c44aaf844cdb6c9/?dl=1
+# tar -zxvf amazon.tar.gz
+# mv datasets/amazon/ amazon
+# rm -rf ./datasets
+# rm -rf amazon.tar.gz
+
+# rm -rf yahoo_answers_topics
+# wget --content-disposition https://cloud.tsinghua.edu.cn/f/79257038afaa4730a03f/?dl=1
+# tar -zxvf yahoo_answers_topics.tar.gz
+# rm -rf yahoo_answers_topics.tar.gz
+
+rm -rf dwmw17
+FILEID="1FW_qQX8aubnuFy--y8cY8HW26CFixFei"
+FILENAME="dwmw17.zip"
+wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate "https://docs.google.com/uc?export=download&id=${FILEID}" -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=${FILEID}" -O ${FILENAME} && rm -rf /tmp/cookies.txt
+unzip dwmw17.zip
+rm dwmw17.zip
 
 cd ..
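Usage note (not part of the patch): with DIR now set to "./datasets/TextClassification", paths are resolved relative to the repository root, so the script is meant to be run from there. A minimal sketch, assuming wget and unzip are on PATH and the Google Drive file is still publicly shared:

# from the repository root
sh datasets/download_text_classification.sh
ls datasets/TextClassification/dwmw17   # expected to contain the per-split CSVs read by Dwmw17Processor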
diff --git a/experiments/classification_protoverb_dwmw17.yaml b/experiments/classification_protoverb_dwmw17.yaml
new file mode 100644
index 00000000..2b0deb3f
--- /dev/null
+++ b/experiments/classification_protoverb_dwmw17.yaml
@@ -0,0 +1,68 @@
+dataset:
+  name: dwmw17
+  path: datasets/TextClassification/dwmw17
+
+plm:
+  model_name: roberta
+  model_path: roberta-large
+  optimize:
+    freeze_para: False
+    lr: 0.00003
+    weight_decay: 0.01
+    scheduler:
+      type:
+      num_warmup_steps: 500
+
+checkpoint:
+  save_latest: False
+  save_best: False
+
+train:
+  batch_size: 2
+  num_epochs: 5
+  train_verblizer: post
+  clean: True
+
+test:
+  batch_size: 2
+
+template: manual_template
+verbalizer: proto_verbalizer
+
+manual_template:
+  choice: 0
+  file_path: scripts/TextClassification/dwmw17/manual_template.txt
+
+proto_verbalizer:
+  parent_config: verbalizer
+  choice: 0
+  file_path: scripts/TextClassification/dwmw17/icl_verbalizer.json
+  lr: 0.01
+  mid_dim: 128
+  epochs: 30
+  multi_verb: multi
+
+
+
+environment:
+  num_gpus: 1
+  cuda_visible_devices:
+    - 0
+  local_rank: 0
+
+learning_setting: few_shot
+
+few_shot:
+  parent_config: learning_setting
+  few_shot_sampling: sampling_from_train
+
+sampling_from_train:
+  parent_config: few_shot_sampling
+  num_examples_per_label: 1
+  also_sample_dev: True
+  num_examples_per_label_dev: 1
+  seed:
+    - 123
+
+reproduce:  # seed for reproduction
+  seed: 123  # a seed for all random parts
\ No newline at end of file
diff --git a/openprompt/data_utils/text_classification_dataset.py b/openprompt/data_utils/text_classification_dataset.py
index 1dde85c3..4cfe0fad 100644
--- a/openprompt/data_utils/text_classification_dataset.py
+++ b/openprompt/data_utils/text_classification_dataset.py
@@ -17,6 +17,7 @@
 
 import os
 import json, csv
+import pandas as pd
 from abc import ABC, abstractmethod
 from collections import defaultdict, Counter
 from typing import List, Dict, Callable
@@ -27,6 +28,44 @@
 from openprompt.data_utils.data_processor import DataProcessor
 
 
+class Dwmw17Processor(DataProcessor):
+    """
+    from openprompt.data_utils.text_classification_dataset import PROCESSORS
+    import os
+    # Get the absolute path of the parent of the current working directory
+    root_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
+
+    # Set the base path to the 'datasets' directory located in the parent directory
+    base_path = os.path.join(root_dir, 'datasets/TextClassification')
+
+
+    dataset_name = "dwmw17"
+    dataset_path = os.path.join(base_path, dataset_name)
+    processor = PROCESSORS[dataset_name.lower()]()
+    trainvalid_dataset = processor.get_train_examples(dataset_path)
+    print(trainvalid_dataset)
+    """
+    def __init__(self):
+        super().__init__()
+        self.labels = ["hate speech", "offensive language", "neither"]
+
+    def get_examples(self, data_dir, split):
+        path = os.path.join(data_dir, "{}.csv".format(split))
+        examples = []
+        with open(path, encoding='utf8') as f:
+            reader = csv.reader(f, delimiter=',')
+            # Skip the header row
+            next(reader)
+            for row in reader:
+                idx, _, _, _, _, label, tweet = row  # index, count, hate_speech, offensive_language, neither, class, tweet
+                text_a = tweet
+                example = InputExample(
+                    guid=str(idx), text_a=text_a, label=int(label))
+                examples.append(example)
+
+        return examples
+
+
 class MnliProcessor(DataProcessor):
     # TODO Test needed
     def __init__(self):
@@ -358,4 +397,5 @@ def get_examples(self, data_dir, split):
     "sst-2": SST2Processor,
     "mnli": MnliProcessor,
     "yahoo": YahooProcessor,
+    "dwmw17": Dwmw17Processor
 }
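With the processor registered under the "dwmw17" key, the experiment can be launched the usual OpenPrompt way; a sketch, assuming the repo's experiments/cli.py runner and its --config_yaml flag:

# from the repository root
python experiments/cli.py --config_yaml experiments/classification_protoverb_dwmw17.yaml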
diff --git a/scripts/TextClassification/dwmw17/icl_verbalizer.json b/scripts/TextClassification/dwmw17/icl_verbalizer.json
new file mode 100644
index 00000000..c4631d0e
--- /dev/null
+++ b/scripts/TextClassification/dwmw17/icl_verbalizer.json
@@ -0,0 +1,5 @@
+{
+    "hate speech": ["Hateful", "Malicious", "Malevolent", "Vicious", "Nefarious", "Sinister", "Discriminatory", "Harmful", "Abusive", "Prejudice"],
+    "offensive language": ["Offensive", "Insulting", "Rude", "Inappropriate", "Insensitive", "Controversial", "Obscenity", "Profanity"],
+    "neither": ["Harmless", "Innocent", "Benign", "Nonthreatening", "Inoffensive", "Amicable", "Acceptable", "Respectful", "Neutral"]
+}
\ No newline at end of file
diff --git a/scripts/TextClassification/dwmw17/manual_template.txt b/scripts/TextClassification/dwmw17/manual_template.txt
new file mode 100644
index 00000000..49341977
--- /dev/null
+++ b/scripts/TextClassification/dwmw17/manual_template.txt
@@ -0,0 +1,3 @@
+This tweet contains {"mask"} . {"placeholder": "text_a"}
+This tweet is {"mask"} . {"placeholder": "text_a"}
+A {"mask"} tweet : {"placeholder": "text_a"}
\ No newline at end of file
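A quick sanity-check sketch (not part of the patch): the verbalizer JSON should parse and list exactly the three classes Dwmw17Processor declares, in the same order, since OpenPrompt typically maps label words to classes positionally:

python - <<'EOF'
import json

# Expected order, copied from Dwmw17Processor.labels
labels = ["hate speech", "offensive language", "neither"]
with open("scripts/TextClassification/dwmw17/icl_verbalizer.json", encoding="utf8") as fp:
    words = json.load(fp)  # dict preserves the file's key order
assert list(words) == labels, "key order should match the processor's label order"
print({label: len(ws) for label, ws in words.items()})
EOF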