Skip to content

Commit

Permalink
Commit the Colab Demo as .ipynb file
Browse files Browse the repository at this point in the history
  • Loading branch information
chiayewken authored Apr 25, 2022
1 parent b583e2a commit 733937e
Showing 1 changed file with 381 additions and 0 deletions.
381 changes: 381 additions & 0 deletions demo.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,381 @@
{
"cells": [
{
"cell_type": "markdown",
"source": [
"### RelationPrompt: Leveraging Prompts to Generate Synthetic Data for Zero-Shot Relation Triplet Extraction\n",
"\n",
"GitHub: https://github.com/declare-lab/RelationPrompt"
],
"metadata": {
"id": "qm5jvHp3vpKT"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "VP-z5ENrR7S3",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "a7f24d70-77b4-4ede-b7a8-0e3e2a38c10e"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"fatal: destination path 'RelationPrompt' already exists and is not an empty directory.\n",
"HEAD is now at 8ce3656 Upgrade torch version 1.9.0 -> 1.10.0\n",
"File ‘model_fewrel_unseen_10_seed_0.tar’ already there; not retrieving.\n",
"\n"
]
}
],
"source": [
"!git clone https://github.com/declare-lab/RelationPrompt.git\n",
"!cd RelationPrompt && git checkout 8ce3656\n",
"!cp -a RelationPrompt/* .\n",
"!wget -q -nc https://github.com/declare-lab/RelationPrompt/releases/download/v1.0.0/zero_rte_data.zip\n",
"!wget -nc https://github.com/declare-lab/RelationPrompt/releases/download/v1.0.0/model_fewrel_unseen_10_seed_0.tar\n",
"!tar -xf model_fewrel_unseen_10_seed_0.tar\n",
"# !wget -nc https://github.com/declare-lab/RelationPrompt/releases/download/v1.0.0/model_wiki_unseen_10_seed_0.tar\n",
"!unzip -nq zero_rte_data.zip\n",
"!pip install -q -r requirements.txt"
]
},
{
"cell_type": "code",
"source": [
"#@title Data Parameters\n",
"data_name = \"fewrel\" #@param [\"fewrel\", \"wiki\"]\n",
"num_unseen_labels = 10 #@param [5,10,15]\n",
"random_seed = 0 #@param [0,1,2,3,4]\n",
"data_limit = 5000 #@param {type:\"number\"}\n",
"data_dir = f\"outputs/data/splits/zero_rte/{data_name}/unseen_{num_unseen_labels}_seed_{random_seed}\"\n",
"print(dict(data_dir=data_dir))"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Gj8sD5YNdgLJ",
"outputId": "08ead310-7f39-46d7-d62d-dbe8602c7dc1"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"{'data_dir': 'outputs/data/splits/zero_rte/fewrel/unseen_10_seed_0'}\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Data Setup\n",
"import json\n",
"import random\n",
"from pathlib import Path\n",
"from wrapper import Generator, Extractor, Dataset\n",
"\n",
"def truncate_data(path:str, limit:int, path_out:str):\n",
" # Use a subset of data for quick demo on Colab\n",
" data = Dataset.load(path)\n",
" random.seed(0)\n",
" random.shuffle(data.sents)\n",
" data.sents = data.sents[:limit]\n",
" data.save(path_out)\n",
"\n",
"path_train = \"train.jsonl\"\n",
"path_dev = \"dev.jsonl\"\n",
"path_test = \"test.jsonl\"\n",
"truncate_data(f\"{data_dir}/train.jsonl\", limit=data_limit, path_out=path_train)\n",
"truncate_data(f\"{data_dir}/dev.jsonl\", limit=data_limit // 10, path_out=path_dev)\n",
"truncate_data(f\"{data_dir}/test.jsonl\", limit=data_limit // 10, path_out=path_test)"
],
"metadata": {
"id": "iO--Mb9nHgGG"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Data Exploration\n",
"\n",
"def explore_data(path: str):\n",
" data = Dataset.load(path)\n",
" print(\"labels:\", data.get_labels())\n",
" print()\n",
" for s in random.sample(data.sents, k=3):\n",
" print(\"tokens:\", s.tokens)\n",
" for t in s.triplets:\n",
" print(\"head:\", t.head)\n",
" print(\"tail:\", t.tail)\n",
" print(\"relation:\", t.label)\n",
" print()\n",
"\n",
"explore_data(path_train)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "vw3NlKDddMIP",
"outputId": "d100f0aa-58fa-454a-bedf-8166ab7150ed"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"labels: ['after a work by', 'applies to jurisdiction', 'architect', 'characters', 'child', 'constellation', 'contains administrative territorial entity', 'country', 'country of citizenship', 'country of origin', 'crosses', 'developer', 'director', 'distributed by', 'father', 'field of work', 'followed by', 'follows', 'genre', 'has part', 'head of government', 'headquarters location', 'heritage designation', 'instance of', 'instrument', 'language of work or name', 'league', 'licensed to broadcast to', 'located in or next to body of water', 'located in the administrative territorial entity', 'located on terrain feature', 'location of formation', 'manufacturer', 'member of', 'military branch', 'military rank', 'mother', 'mountain range', 'mouth of the watercourse', 'movement', 'notable work', 'occupant', 'occupation', 'operator', 'original language of film or TV show', 'part of', 'participant', 'participating team', 'performer', 'place served by transport hub', 'publisher', 'record label', 'residence', 'said to be the same as', 'screenwriter', 'sibling', 'sport', 'sports season of league or competition', 'spouse', 'subsidiary', 'successful candidate', 'tributary', 'voice type', 'winner', 'work location']\n",
"\n",
"tokens: ['In', 'the', 'Ulster', 'Cycle', 'of', 'Irish', 'mythology', ',', 'Lugaid', 'mac', 'Con', 'Roí', 'was', 'the', 'son', 'of', 'Cú', 'Roí', 'mac', 'Dáire', '.']\n",
"head: [2, 3]\n",
"tail: [16, 17]\n",
"relation: characters\n",
"\n",
"tokens: ['Wanandi', 'was', 'a', 'leading', 'student', 'activist', 'during', 'the', '1965', '-', '66', 'in', 'Indonesia', 'when', ',', 'over', 'time', ',', 'president', 'Sukarno', 'was', 'removed', 'from', 'power', 'and', 'Soeharto', 'became', 'the', 'second', 'president', 'of', 'Indonesia', '.']\n",
"head: [25]\n",
"tail: [12]\n",
"relation: country of citizenship\n",
"\n",
"tokens: ['The', 'Temple', 'of', 'Proserpina', 'or', 'Temple', 'of', 'ProserpineSome', 'theories', 'suggest', 'that', 'the', 'temple', 'was', 'a', 'Greek', 'Temple', 'dedicated', 'to', 'Persephone', ',', 'the', 'Greek', 'equivalent', 'to', 'the', 'Roman', 'Goddess', 'Proserpina', '.']\n",
"head: [19]\n",
"tail: [3]\n",
"relation: said to be the same as\n",
"\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Use Pretrained Model for Generation\n",
"model = Generator(load_dir=\"gpt2\", save_dir=\"outputs/wrapper/fewrel/unseen_10_seed_0/generator\")\n",
"model.generate(labels=[\"location\", \"religion\"], path_out=\"synthetic.jsonl\")\n",
"explore_data(path=\"synthetic.jsonl\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "tUFis82oGUAS",
"outputId": "23e65d65-fdce-4f38-bc9a-ddf8f9cb6fc6"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"labels: ['location', 'religion']\n",
"\n",
"tokens: ['In', '2007', ',', 'he', 'joined', 'a', 'group', 'of', 'artists', 'known', 'as', 'the', 'Moth', 'Boys', ',', 'an', 'annual', 'neo', '-', 'pop', 'quartet', 'that', 'plays', 'in', 'several', 'venues', 'around', 'the', 'country', 'in', 'Las', 'Vegas', '.']\n",
"head: [12, 13]\n",
"tail: [30, 31]\n",
"relation: location\n",
"\n",
"tokens: ['There', 'is', 'a', 'section', 'of', 'the', 'town', 'under', '\"', 'the', 'Graziano', '\"', 'River', ',', 'a', 'channel', 'flowing', 'the', 'river', 'in', 'southwestern', 'Italy', 'from', 'the', 'island', 'of', 'Sardinia', 'to', 'Italy', '.']\n",
"head: [10]\n",
"tail: [26]\n",
"relation: location\n",
"\n",
"tokens: ['In', 'August', '2012', ',', 'the', 'station', 'opened', 'on', 'its', 'regular', 'schedule', 'between', 'Minto', 'Plaza', 'in', 'Osaka', 'and', 'Keito', 'Station', 'in', 'the', 'city', 'of', 'Nara', 'in', 'Japan', '.']\n",
"head: [17, 18]\n",
"tail: [15]\n",
"relation: location\n",
"\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Use Pretrained Model for Extraction\n",
"model = Extractor(load_dir=\"facebook/bart-base\", save_dir=\"outputs/wrapper/fewrel/unseen_10_seed_0/extractor_final\")\n",
"model.predict(path_in=path_test, path_out=\"pred.jsonl\")\n",
"explore_data(path=\"pred.jsonl\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "eGxP3vVmID9W",
"outputId": "c9ba9b59-faf5-4fec-c54e-2d4878176a2c"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"{'select_model': NewRelationExtractor(model_dir='outputs/wrapper/fewrel/unseen_10_seed_0/extractor_final/model', data_dir='outputs/wrapper/fewrel/unseen_10_seed_0/extractor_final/data', model_name='facebook/bart-base', do_pretrain=False, encoder_name='extract', pipe_name='summarization', batch_size=64, grad_accumulation=2, random_seed=42, warmup_ratio=0.2, lr_pretrain=0.0003, lr_finetune=3e-05, epochs_pretrain=3, epochs_finetune=5, train_fp16=True, max_source_length=128, max_target_length=128)}\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"100%|██████████| 8/8 [00:07<00:00, 1.11it/s]"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"labels: ['', 'competition class', 'location', 'member of political party', 'nominated for', 'operating system', 'original broadcaster', 'owned by', 'position played on team / speciality', 'religion']\n",
"\n",
"tokens: ['The', 'South', 'Bank', 'Show', 'is', 'a', 'television', 'arts', 'magazine', 'show', 'that', 'was', 'produced', 'by', 'ITV', 'between', '1978', 'and', '2010', ',', 'and', 'by', 'Sky', 'Arts', 'from', '2012', '.']\n",
"head: [1, 2, 3]\n",
"tail: [14]\n",
"relation: original broadcaster\n",
"\n",
"tokens: ['Then', 'Senator', 'Neptali', 'Gonzales', ',', 'whom', 'Maceda', 'helped', ',', 'was', 'installed', 'as', 'Senate', 'President', 'from', '1992', '-', '1993', 'and', '1995', '-', '1996', 'succeeded', 'him', '.']\n",
"head: [2, 3]\n",
"tail: [12, 13]\n",
"relation: position played on team / speciality\n",
"\n",
"tokens: ['In', '1908', 'she', 'won', 'the', 'singles', 'title', 'at', 'the', 'Welsh', 'Championships', 'in', 'Newport', 'and', 'successfully', 'defended', 'it', 'in', '1909', '.', 'she', 'also', 'won', 'the', 'Scottish', 'Championships', 'singles', 'title', 'twice', '1908', 'to', '1909', '.']\n",
"head: [9, 10]\n",
"tail: [5]\n",
"relation: competition class\n",
"\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Full Training\n",
"save_dir = f\"outputs/wrapper/{data_name}/unseen_{num_unseen_labels}_seed_{random_seed}\"\n",
"print(dict(save_dir=save_dir))\n",
"model_kwargs = dict(batch_size=32, grad_accumulation=4) # For lower memory on Colab\n",
"\n",
"generator = Generator(\n",
" load_dir=\"gpt2\",\n",
" save_dir=str(Path(save_dir) / \"generator\"),\n",
" model_kwargs=model_kwargs,\n",
")\n",
"extractor = Extractor(\n",
" load_dir=\"facebook/bart-base\",\n",
" save_dir=str(Path(save_dir) / \"extractor\"),\n",
" model_kwargs=model_kwargs,\n",
")\n",
"\n",
"generator.fit(path_train, path_dev)\n",
"extractor.fit(path_train, path_dev)\n",
"path_synthetic = str(Path(save_dir) / \"synthetic.jsonl\")\n",
"labels_dev = Dataset.load(path_dev).get_labels()\n",
"labels_test = Dataset.load(path_test).get_labels()\n",
"generator.generate(labels_dev + labels_test, path_out=path_synthetic)\n",
"\n",
"extractor_final = Extractor(\n",
" load_dir=str(Path(save_dir) / \"extractor\" / \"model\"),\n",
" save_dir=str(Path(save_dir) / \"extractor_final\"),\n",
" model_kwargs=model_kwargs,\n",
")\n",
"extractor_final.fit(path_synthetic, path_dev)\n",
"\n",
"path_pred = str(Path(save_dir) / \"pred.jsonl\")\n",
"extractor_final.predict(path_in=path_test, path_out=path_pred)\n",
"results = extractor_final.score(path_pred, path_test)\n",
"print(json.dumps(results, indent=2))"
],
"metadata": {
"id": "qi5PAW5ocjfj",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "633b3678-319d-4f1d-9296-b3e47de80c47"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"{'save_dir': 'outputs/wrapper/fewrel/unseen_10_seed_0'}\n",
"{'select_model': RelationGenerator(model_dir='outputs/wrapper/fewrel/unseen_10_seed_0/generator/model', data_dir='outputs/wrapper/fewrel/unseen_10_seed_0/generator/data', model_name='gpt2', do_pretrain=False, encoder_name='generate', pipe_name='text-generation', batch_size=32, grad_accumulation=4, random_seed=42, warmup_ratio=0.2, lr_pretrain=0.0003, lr_finetune=3e-05, epochs_pretrain=3, epochs_finetune=5, train_fp16=True, block_size=128)}\n",
"{'select_model': NewRelationExtractor(model_dir='outputs/wrapper/fewrel/unseen_10_seed_0/extractor/model', data_dir='outputs/wrapper/fewrel/unseen_10_seed_0/extractor/data', model_name='facebook/bart-base', do_pretrain=False, encoder_name='extract', pipe_name='summarization', batch_size=32, grad_accumulation=4, random_seed=42, warmup_ratio=0.2, lr_pretrain=0.0003, lr_finetune=3e-05, epochs_pretrain=3, epochs_finetune=5, train_fp16=True, max_source_length=128, max_target_length=128)}\n",
"{'select_model': NewRelationExtractor(model_dir='outputs/wrapper/fewrel/unseen_10_seed_0/extractor_final/model', data_dir='outputs/wrapper/fewrel/unseen_10_seed_0/extractor_final/data', model_name='outputs/wrapper/fewrel/unseen_10_seed_0/extractor/model', do_pretrain=False, encoder_name='extract', pipe_name='summarization', batch_size=32, grad_accumulation=4, random_seed=42, warmup_ratio=0.2, lr_pretrain=0.0003, lr_finetune=3e-05, epochs_pretrain=3, epochs_finetune=5, train_fp16=True, max_source_length=128, max_target_length=128)}\n",
"{'select_model': NewRelationExtractor(model_dir='outputs/wrapper/fewrel/unseen_10_seed_0/extractor_final/model', data_dir='outputs/wrapper/fewrel/unseen_10_seed_0/extractor_final/data', model_name='outputs/wrapper/fewrel/unseen_10_seed_0/extractor/model', do_pretrain=False, encoder_name='extract', pipe_name='summarization', batch_size=32, grad_accumulation=4, random_seed=42, warmup_ratio=0.2, lr_pretrain=0.0003, lr_finetune=3e-05, epochs_pretrain=3, epochs_finetune=5, train_fp16=True, max_source_length=128, max_target_length=128)}\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"100%|██████████| 16/16 [00:09<00:00, 1.72it/s]"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"{\n",
" \"path_pred\": \"outputs/wrapper/fewrel/unseen_10_seed_0/pred.jsonl\",\n",
" \"path_gold\": \"test.jsonl\",\n",
" \"precision\": 0.328,\n",
" \"recall\": 0.3215686274509804,\n",
" \"score\": 0.32475247524752476\n",
"}\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"\n"
]
}
]
},
{
"cell_type": "code",
"source": [
""
],
"metadata": {
"id": "-63QPp0xuDOA"
},
"execution_count": null,
"outputs": []
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"name": "RelationPrompt Demo.ipynb",
"provenance": [],
"collapsed_sections": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}

0 comments on commit 733937e

Please sign in to comment.