forked from declare-lab/RelationPrompt
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Commit the Colab Demo as .ipynb file
- Loading branch information
1 parent
b583e2a
commit 733937e
Showing
1 changed file
with
381 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,381 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"### RelationPrompt: Leveraging Prompts to Generate Synthetic Data for Zero-Shot Relation Triplet Extraction\n", | ||
"\n", | ||
"GitHub: https://github.com/declare-lab/RelationPrompt" | ||
], | ||
"metadata": { | ||
"id": "qm5jvHp3vpKT" | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "VP-z5ENrR7S3", | ||
"colab": { | ||
"base_uri": "https://localhost:8080/" | ||
}, | ||
"outputId": "a7f24d70-77b4-4ede-b7a8-0e3e2a38c10e" | ||
}, | ||
"outputs": [ | ||
{ | ||
"output_type": "stream", | ||
"name": "stdout", | ||
"text": [ | ||
"fatal: destination path 'RelationPrompt' already exists and is not an empty directory.\n", | ||
"HEAD is now at 8ce3656 Upgrade torch version 1.9.0 -> 1.10.0\n", | ||
"File ‘model_fewrel_unseen_10_seed_0.tar’ already there; not retrieving.\n", | ||
"\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"!git clone https://github.com/declare-lab/RelationPrompt.git\n", | ||
"!cd RelationPrompt && git checkout 8ce3656\n", | ||
"!cp -a RelationPrompt/* .\n", | ||
"!wget -q -nc https://github.com/declare-lab/RelationPrompt/releases/download/v1.0.0/zero_rte_data.zip\n", | ||
"!wget -nc https://github.com/declare-lab/RelationPrompt/releases/download/v1.0.0/model_fewrel_unseen_10_seed_0.tar\n", | ||
"!tar -xf model_fewrel_unseen_10_seed_0.tar\n", | ||
"# !wget -nc https://github.com/declare-lab/RelationPrompt/releases/download/v1.0.0/model_wiki_unseen_10_seed_0.tar\n", | ||
"!unzip -nq zero_rte_data.zip\n", | ||
"!pip install -q -r requirements.txt" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"#@title Data Parameters\n", | ||
"data_name = \"fewrel\" #@param [\"fewrel\", \"wiki\"]\n", | ||
"num_unseen_labels = 10 #@param [5,10,15]\n", | ||
"random_seed = 0 #@param [0,1,2,3,4]\n", | ||
"data_limit = 5000 #@param {type:\"number\"}\n", | ||
"data_dir = f\"outputs/data/splits/zero_rte/{data_name}/unseen_{num_unseen_labels}_seed_{random_seed}\"\n", | ||
"print(dict(data_dir=data_dir))" | ||
], | ||
"metadata": { | ||
"colab": { | ||
"base_uri": "https://localhost:8080/" | ||
}, | ||
"id": "Gj8sD5YNdgLJ", | ||
"outputId": "08ead310-7f39-46d7-d62d-dbe8602c7dc1" | ||
}, | ||
"execution_count": null, | ||
"outputs": [ | ||
{ | ||
"output_type": "stream", | ||
"name": "stdout", | ||
"text": [ | ||
"{'data_dir': 'outputs/data/splits/zero_rte/fewrel/unseen_10_seed_0'}\n" | ||
] | ||
} | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"# Data Setup\n", | ||
"import json\n", | ||
"import random\n", | ||
"from pathlib import Path\n", | ||
"from wrapper import Generator, Extractor, Dataset\n", | ||
"\n", | ||
"def truncate_data(path:str, limit:int, path_out:str):\n", | ||
" # Use a subset of data for quick demo on Colab\n", | ||
" data = Dataset.load(path)\n", | ||
" random.seed(0)\n", | ||
" random.shuffle(data.sents)\n", | ||
" data.sents = data.sents[:limit]\n", | ||
" data.save(path_out)\n", | ||
"\n", | ||
"path_train = \"train.jsonl\"\n", | ||
"path_dev = \"dev.jsonl\"\n", | ||
"path_test = \"test.jsonl\"\n", | ||
"truncate_data(f\"{data_dir}/train.jsonl\", limit=data_limit, path_out=path_train)\n", | ||
"truncate_data(f\"{data_dir}/dev.jsonl\", limit=data_limit // 10, path_out=path_dev)\n", | ||
"truncate_data(f\"{data_dir}/test.jsonl\", limit=data_limit // 10, path_out=path_test)" | ||
], | ||
"metadata": { | ||
"id": "iO--Mb9nHgGG" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"# Data Exploration\n", | ||
"\n", | ||
"def explore_data(path: str):\n", | ||
" data = Dataset.load(path)\n", | ||
" print(\"labels:\", data.get_labels())\n", | ||
" print()\n", | ||
" for s in random.sample(data.sents, k=3):\n", | ||
" print(\"tokens:\", s.tokens)\n", | ||
" for t in s.triplets:\n", | ||
" print(\"head:\", t.head)\n", | ||
" print(\"tail:\", t.tail)\n", | ||
" print(\"relation:\", t.label)\n", | ||
" print()\n", | ||
"\n", | ||
"explore_data(path_train)" | ||
], | ||
"metadata": { | ||
"colab": { | ||
"base_uri": "https://localhost:8080/" | ||
}, | ||
"id": "vw3NlKDddMIP", | ||
"outputId": "d100f0aa-58fa-454a-bedf-8166ab7150ed" | ||
}, | ||
"execution_count": null, | ||
"outputs": [ | ||
{ | ||
"output_type": "stream", | ||
"name": "stdout", | ||
"text": [ | ||
"labels: ['after a work by', 'applies to jurisdiction', 'architect', 'characters', 'child', 'constellation', 'contains administrative territorial entity', 'country', 'country of citizenship', 'country of origin', 'crosses', 'developer', 'director', 'distributed by', 'father', 'field of work', 'followed by', 'follows', 'genre', 'has part', 'head of government', 'headquarters location', 'heritage designation', 'instance of', 'instrument', 'language of work or name', 'league', 'licensed to broadcast to', 'located in or next to body of water', 'located in the administrative territorial entity', 'located on terrain feature', 'location of formation', 'manufacturer', 'member of', 'military branch', 'military rank', 'mother', 'mountain range', 'mouth of the watercourse', 'movement', 'notable work', 'occupant', 'occupation', 'operator', 'original language of film or TV show', 'part of', 'participant', 'participating team', 'performer', 'place served by transport hub', 'publisher', 'record label', 'residence', 'said to be the same as', 'screenwriter', 'sibling', 'sport', 'sports season of league or competition', 'spouse', 'subsidiary', 'successful candidate', 'tributary', 'voice type', 'winner', 'work location']\n", | ||
"\n", | ||
"tokens: ['In', 'the', 'Ulster', 'Cycle', 'of', 'Irish', 'mythology', ',', 'Lugaid', 'mac', 'Con', 'Roí', 'was', 'the', 'son', 'of', 'Cú', 'Roí', 'mac', 'Dáire', '.']\n", | ||
"head: [2, 3]\n", | ||
"tail: [16, 17]\n", | ||
"relation: characters\n", | ||
"\n", | ||
"tokens: ['Wanandi', 'was', 'a', 'leading', 'student', 'activist', 'during', 'the', '1965', '-', '66', 'in', 'Indonesia', 'when', ',', 'over', 'time', ',', 'president', 'Sukarno', 'was', 'removed', 'from', 'power', 'and', 'Soeharto', 'became', 'the', 'second', 'president', 'of', 'Indonesia', '.']\n", | ||
"head: [25]\n", | ||
"tail: [12]\n", | ||
"relation: country of citizenship\n", | ||
"\n", | ||
"tokens: ['The', 'Temple', 'of', 'Proserpina', 'or', 'Temple', 'of', 'ProserpineSome', 'theories', 'suggest', 'that', 'the', 'temple', 'was', 'a', 'Greek', 'Temple', 'dedicated', 'to', 'Persephone', ',', 'the', 'Greek', 'equivalent', 'to', 'the', 'Roman', 'Goddess', 'Proserpina', '.']\n", | ||
"head: [19]\n", | ||
"tail: [3]\n", | ||
"relation: said to be the same as\n", | ||
"\n" | ||
] | ||
} | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"# Use Pretrained Model for Generation\n", | ||
"model = Generator(load_dir=\"gpt2\", save_dir=\"outputs/wrapper/fewrel/unseen_10_seed_0/generator\")\n", | ||
"model.generate(labels=[\"location\", \"religion\"], path_out=\"synthetic.jsonl\")\n", | ||
"explore_data(path=\"synthetic.jsonl\")" | ||
], | ||
"metadata": { | ||
"colab": { | ||
"base_uri": "https://localhost:8080/" | ||
}, | ||
"id": "tUFis82oGUAS", | ||
"outputId": "23e65d65-fdce-4f38-bc9a-ddf8f9cb6fc6" | ||
}, | ||
"execution_count": null, | ||
"outputs": [ | ||
{ | ||
"output_type": "stream", | ||
"name": "stdout", | ||
"text": [ | ||
"labels: ['location', 'religion']\n", | ||
"\n", | ||
"tokens: ['In', '2007', ',', 'he', 'joined', 'a', 'group', 'of', 'artists', 'known', 'as', 'the', 'Moth', 'Boys', ',', 'an', 'annual', 'neo', '-', 'pop', 'quartet', 'that', 'plays', 'in', 'several', 'venues', 'around', 'the', 'country', 'in', 'Las', 'Vegas', '.']\n", | ||
"head: [12, 13]\n", | ||
"tail: [30, 31]\n", | ||
"relation: location\n", | ||
"\n", | ||
"tokens: ['There', 'is', 'a', 'section', 'of', 'the', 'town', 'under', '\"', 'the', 'Graziano', '\"', 'River', ',', 'a', 'channel', 'flowing', 'the', 'river', 'in', 'southwestern', 'Italy', 'from', 'the', 'island', 'of', 'Sardinia', 'to', 'Italy', '.']\n", | ||
"head: [10]\n", | ||
"tail: [26]\n", | ||
"relation: location\n", | ||
"\n", | ||
"tokens: ['In', 'August', '2012', ',', 'the', 'station', 'opened', 'on', 'its', 'regular', 'schedule', 'between', 'Minto', 'Plaza', 'in', 'Osaka', 'and', 'Keito', 'Station', 'in', 'the', 'city', 'of', 'Nara', 'in', 'Japan', '.']\n", | ||
"head: [17, 18]\n", | ||
"tail: [15]\n", | ||
"relation: location\n", | ||
"\n" | ||
] | ||
} | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"# Use Pretrained Model for Extraction\n", | ||
"model = Extractor(load_dir=\"facebook/bart-base\", save_dir=\"outputs/wrapper/fewrel/unseen_10_seed_0/extractor_final\")\n", | ||
"model.predict(path_in=path_test, path_out=\"pred.jsonl\")\n", | ||
"explore_data(path=\"pred.jsonl\")" | ||
], | ||
"metadata": { | ||
"colab": { | ||
"base_uri": "https://localhost:8080/" | ||
}, | ||
"id": "eGxP3vVmID9W", | ||
"outputId": "c9ba9b59-faf5-4fec-c54e-2d4878176a2c" | ||
}, | ||
"execution_count": null, | ||
"outputs": [ | ||
{ | ||
"output_type": "stream", | ||
"name": "stdout", | ||
"text": [ | ||
"{'select_model': NewRelationExtractor(model_dir='outputs/wrapper/fewrel/unseen_10_seed_0/extractor_final/model', data_dir='outputs/wrapper/fewrel/unseen_10_seed_0/extractor_final/data', model_name='facebook/bart-base', do_pretrain=False, encoder_name='extract', pipe_name='summarization', batch_size=64, grad_accumulation=2, random_seed=42, warmup_ratio=0.2, lr_pretrain=0.0003, lr_finetune=3e-05, epochs_pretrain=3, epochs_finetune=5, train_fp16=True, max_source_length=128, max_target_length=128)}\n" | ||
] | ||
}, | ||
{ | ||
"output_type": "stream", | ||
"name": "stderr", | ||
"text": [ | ||
"100%|██████████| 8/8 [00:07<00:00, 1.11it/s]" | ||
] | ||
}, | ||
{ | ||
"output_type": "stream", | ||
"name": "stdout", | ||
"text": [ | ||
"labels: ['', 'competition class', 'location', 'member of political party', 'nominated for', 'operating system', 'original broadcaster', 'owned by', 'position played on team / speciality', 'religion']\n", | ||
"\n", | ||
"tokens: ['The', 'South', 'Bank', 'Show', 'is', 'a', 'television', 'arts', 'magazine', 'show', 'that', 'was', 'produced', 'by', 'ITV', 'between', '1978', 'and', '2010', ',', 'and', 'by', 'Sky', 'Arts', 'from', '2012', '.']\n", | ||
"head: [1, 2, 3]\n", | ||
"tail: [14]\n", | ||
"relation: original broadcaster\n", | ||
"\n", | ||
"tokens: ['Then', 'Senator', 'Neptali', 'Gonzales', ',', 'whom', 'Maceda', 'helped', ',', 'was', 'installed', 'as', 'Senate', 'President', 'from', '1992', '-', '1993', 'and', '1995', '-', '1996', 'succeeded', 'him', '.']\n", | ||
"head: [2, 3]\n", | ||
"tail: [12, 13]\n", | ||
"relation: position played on team / speciality\n", | ||
"\n", | ||
"tokens: ['In', '1908', 'she', 'won', 'the', 'singles', 'title', 'at', 'the', 'Welsh', 'Championships', 'in', 'Newport', 'and', 'successfully', 'defended', 'it', 'in', '1909', '.', 'she', 'also', 'won', 'the', 'Scottish', 'Championships', 'singles', 'title', 'twice', '1908', 'to', '1909', '.']\n", | ||
"head: [9, 10]\n", | ||
"tail: [5]\n", | ||
"relation: competition class\n", | ||
"\n" | ||
] | ||
}, | ||
{ | ||
"output_type": "stream", | ||
"name": "stderr", | ||
"text": [ | ||
"\n" | ||
] | ||
} | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"# Full Training\n", | ||
"save_dir = f\"outputs/wrapper/{data_name}/unseen_{num_unseen_labels}_seed_{random_seed}\"\n", | ||
"print(dict(save_dir=save_dir))\n", | ||
"model_kwargs = dict(batch_size=32, grad_accumulation=4) # For lower memory on Colab\n", | ||
"\n", | ||
"generator = Generator(\n", | ||
" load_dir=\"gpt2\",\n", | ||
" save_dir=str(Path(save_dir) / \"generator\"),\n", | ||
" model_kwargs=model_kwargs,\n", | ||
")\n", | ||
"extractor = Extractor(\n", | ||
" load_dir=\"facebook/bart-base\",\n", | ||
" save_dir=str(Path(save_dir) / \"extractor\"),\n", | ||
" model_kwargs=model_kwargs,\n", | ||
")\n", | ||
"\n", | ||
"generator.fit(path_train, path_dev)\n", | ||
"extractor.fit(path_train, path_dev)\n", | ||
"path_synthetic = str(Path(save_dir) / \"synthetic.jsonl\")\n", | ||
"labels_dev = Dataset.load(path_dev).get_labels()\n", | ||
"labels_test = Dataset.load(path_test).get_labels()\n", | ||
"generator.generate(labels_dev + labels_test, path_out=path_synthetic)\n", | ||
"\n", | ||
"extractor_final = Extractor(\n", | ||
" load_dir=str(Path(save_dir) / \"extractor\" / \"model\"),\n", | ||
" save_dir=str(Path(save_dir) / \"extractor_final\"),\n", | ||
" model_kwargs=model_kwargs,\n", | ||
")\n", | ||
"extractor_final.fit(path_synthetic, path_dev)\n", | ||
"\n", | ||
"path_pred = str(Path(save_dir) / \"pred.jsonl\")\n", | ||
"extractor_final.predict(path_in=path_test, path_out=path_pred)\n", | ||
"results = extractor_final.score(path_pred, path_test)\n", | ||
"print(json.dumps(results, indent=2))" | ||
], | ||
"metadata": { | ||
"id": "qi5PAW5ocjfj", | ||
"colab": { | ||
"base_uri": "https://localhost:8080/" | ||
}, | ||
"outputId": "633b3678-319d-4f1d-9296-b3e47de80c47" | ||
}, | ||
"execution_count": null, | ||
"outputs": [ | ||
{ | ||
"output_type": "stream", | ||
"name": "stdout", | ||
"text": [ | ||
"{'save_dir': 'outputs/wrapper/fewrel/unseen_10_seed_0'}\n", | ||
"{'select_model': RelationGenerator(model_dir='outputs/wrapper/fewrel/unseen_10_seed_0/generator/model', data_dir='outputs/wrapper/fewrel/unseen_10_seed_0/generator/data', model_name='gpt2', do_pretrain=False, encoder_name='generate', pipe_name='text-generation', batch_size=32, grad_accumulation=4, random_seed=42, warmup_ratio=0.2, lr_pretrain=0.0003, lr_finetune=3e-05, epochs_pretrain=3, epochs_finetune=5, train_fp16=True, block_size=128)}\n", | ||
"{'select_model': NewRelationExtractor(model_dir='outputs/wrapper/fewrel/unseen_10_seed_0/extractor/model', data_dir='outputs/wrapper/fewrel/unseen_10_seed_0/extractor/data', model_name='facebook/bart-base', do_pretrain=False, encoder_name='extract', pipe_name='summarization', batch_size=32, grad_accumulation=4, random_seed=42, warmup_ratio=0.2, lr_pretrain=0.0003, lr_finetune=3e-05, epochs_pretrain=3, epochs_finetune=5, train_fp16=True, max_source_length=128, max_target_length=128)}\n", | ||
"{'select_model': NewRelationExtractor(model_dir='outputs/wrapper/fewrel/unseen_10_seed_0/extractor_final/model', data_dir='outputs/wrapper/fewrel/unseen_10_seed_0/extractor_final/data', model_name='outputs/wrapper/fewrel/unseen_10_seed_0/extractor/model', do_pretrain=False, encoder_name='extract', pipe_name='summarization', batch_size=32, grad_accumulation=4, random_seed=42, warmup_ratio=0.2, lr_pretrain=0.0003, lr_finetune=3e-05, epochs_pretrain=3, epochs_finetune=5, train_fp16=True, max_source_length=128, max_target_length=128)}\n", | ||
"{'select_model': NewRelationExtractor(model_dir='outputs/wrapper/fewrel/unseen_10_seed_0/extractor_final/model', data_dir='outputs/wrapper/fewrel/unseen_10_seed_0/extractor_final/data', model_name='outputs/wrapper/fewrel/unseen_10_seed_0/extractor/model', do_pretrain=False, encoder_name='extract', pipe_name='summarization', batch_size=32, grad_accumulation=4, random_seed=42, warmup_ratio=0.2, lr_pretrain=0.0003, lr_finetune=3e-05, epochs_pretrain=3, epochs_finetune=5, train_fp16=True, max_source_length=128, max_target_length=128)}\n" | ||
] | ||
}, | ||
{ | ||
"output_type": "stream", | ||
"name": "stderr", | ||
"text": [ | ||
"100%|██████████| 16/16 [00:09<00:00, 1.72it/s]" | ||
] | ||
}, | ||
{ | ||
"output_type": "stream", | ||
"name": "stdout", | ||
"text": [ | ||
"{\n", | ||
" \"path_pred\": \"outputs/wrapper/fewrel/unseen_10_seed_0/pred.jsonl\",\n", | ||
" \"path_gold\": \"test.jsonl\",\n", | ||
" \"precision\": 0.328,\n", | ||
" \"recall\": 0.3215686274509804,\n", | ||
" \"score\": 0.32475247524752476\n", | ||
"}\n" | ||
] | ||
}, | ||
{ | ||
"output_type": "stream", | ||
"name": "stderr", | ||
"text": [ | ||
"\n" | ||
] | ||
} | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"" | ||
], | ||
"metadata": { | ||
"id": "-63QPp0xuDOA" | ||
}, | ||
"execution_count": null, | ||
"outputs": [] | ||
} | ||
], | ||
"metadata": { | ||
"accelerator": "GPU", | ||
"colab": { | ||
"name": "RelationPrompt Demo.ipynb", | ||
"provenance": [], | ||
"collapsed_sections": [] | ||
}, | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"name": "python" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 0 | ||
} |