From ff6800308825d499edafc1baa5e10df69580126a Mon Sep 17 00:00:00 2001 From: chiayewken Date: Fri, 15 Apr 2022 01:16:16 +0800 Subject: [PATCH] Initial code release --- encoding.py | 203 ++++++++ generation.py | 193 ++++++++ modeling.py | 300 ++++++++++++ requirements.txt | 11 + transformer_base/__init__.py | 0 transformer_base/run_clm.py | 591 ++++++++++++++++++++++ transformer_base/run_summarization.py | 679 ++++++++++++++++++++++++++ utils.py | 548 +++++++++++++++++++++ wrapper.py | 500 +++++++++++++++++++ 9 files changed, 3025 insertions(+) create mode 100644 encoding.py create mode 100644 generation.py create mode 100644 modeling.py create mode 100644 requirements.txt create mode 100644 transformer_base/__init__.py create mode 100644 transformer_base/run_clm.py create mode 100644 transformer_base/run_summarization.py create mode 100644 utils.py create mode 100644 wrapper.py diff --git a/encoding.py b/encoding.py new file mode 100644 index 0000000..c389cf8 --- /dev/null +++ b/encoding.py @@ -0,0 +1,203 @@ +from pathlib import Path +from typing import Dict, List, Tuple + +from fire import Fire +from pydantic import BaseModel +from tqdm import tqdm +from transformers import AutoTokenizer + +from transformer_base import run_summarization +from utils import RelationData, RelationSentence + + +class Encoder(BaseModel): + def encode_x(self, x: str) -> str: + raise NotImplementedError + + def encode(self, sent: RelationSentence) -> Tuple[str, str]: + raise NotImplementedError + + def decode(self, x: str, y: str) -> RelationSentence: + raise NotImplementedError + + def decode_x(self, x: str) -> str: + raise NotImplementedError + + def safe_decode(self, x: str, y: str) -> RelationSentence: + text = self.decode_x(x) + try: + s = self.decode(x=x, y=y) + except Exception as e: + s = RelationSentence( + tokens=text.split(), head=[], tail=[], label="", error=str(e), raw=y + ) + return s + + def encode_to_line(self, sent: RelationSentence) -> str: + raise NotImplementedError + + def decode_from_line(self, line: str) -> RelationSentence: + raise NotImplementedError + + def parse_line(self, line: str) -> Tuple[str, str]: + raise NotImplementedError + + +class GenerateEncoder(Encoder): + def encode_x(self, r: str) -> str: + return f"Relation : {r} ." + + def decode_x(self, text: str) -> str: + return text.split("Relation : ")[-1][:-2] + + def encode_triplet(self, sent: RelationSentence) -> str: + s, r, o = sent.as_tuple() + return f"Context : {sent.text} Head Entity : {s} , Tail Entity : {o} ." + + def decode_triplet(self, text: str, label: str) -> RelationSentence: + front, back = text.split(" Head Entity : ") + _, context = front.split("Context : ") + head, back = back.split(" , Tail Entity : ") + tail = back[:-2] + return RelationSentence.from_spans(context, head, tail, label) + + def encode_y(self, sent: RelationSentence) -> str: + return self.encode_x(sent.label) + " " + self.encode_triplet(sent) + + def decode_y(self, text: str, label: str) -> RelationSentence: + del label + front, back = text.split(" . 
Context : ") + label = self.decode_x(front + " .") + return self.decode_triplet("Context : " + back, label) + + def decode(self, x: str, y: str) -> RelationSentence: + r = self.decode_x(x) + sent = self.decode_y(y, r) + return sent + + def encode(self, sent: RelationSentence) -> Tuple[str, str]: + x = self.encode_x(sent.label) + y = self.encode_y(sent) + return x, y + + def decode_from_line(self, line: str) -> RelationSentence: + x, y = self.parse_line(line) + return self.decode(x, y) + + def encode_to_line(self, sent: RelationSentence) -> str: + x, y = self.encode(sent) + return y + "\n" + + def parse_line(self, line: str) -> Tuple[str, str]: + return "", line.strip() + + +class ExtractEncoder(Encoder): + def encode_x(self, text: str) -> str: + return f"Context : {text}" + + def decode_x(self, x: str) -> str: + return x.split("Context : ")[-1] + + def encode_y(self, sent: RelationSentence) -> str: + s, r, o = sent.as_tuple() + return f"Head Entity : {s} , Tail Entity : {o} , Relation : {r} ." + + def decode_y(self, x: str, y: str) -> RelationSentence: + context = self.decode_x(x) + front, label = y.split(" , Relation : ") + label = label[:-2] + front, tail = front.split(" , Tail Entity : ") + _, head = front.split("Head Entity : ") + return RelationSentence.from_spans(context, head, tail, label) + + def encode_entity_prompt(self, head: str, tail: str) -> str: + return f"Head Entity : {head} , Tail Entity : {tail} , Relation :" + + def encode(self, sent: RelationSentence) -> Tuple[str, str]: + x = self.encode_x(sent.text) + y = self.encode_y(sent) + return x, y + + def decode(self, x: str, y: str) -> RelationSentence: + return self.decode_y(x, y) + + def encode_to_line(self, sent: RelationSentence) -> str: + x, y = self.encode(sent) + return run_summarization.encode_to_line(x, y) + + def decode_from_line(self, line: str) -> RelationSentence: + x, y = self.parse_line(line) + return self.decode(x, y) + + def parse_line(self, line: str) -> Tuple[str, str]: + return run_summarization.decode_from_line(line) + + +def test_encoders( + paths: List[str] = [ + "outputs/data/zsl/wiki/unseen_5_seed_0/train.jsonl", + "outputs/data/zsl/fewrel/unseen_5_seed_0/train.jsonl", + ], + print_limit: int = 4, + encoder_names: List[str] = ["generate", "extract"], + limit: int = 1000, +): + encoders = {k: select_encoder(k) for k in encoder_names} + + for p in paths: + data = RelationData.load(Path(p)) + _, data = data.train_test_split(min(limit, len(data.sents)), random_seed=0) + + for name, e in tqdm(list(encoders.items())): + num_fail = 0 + print(dict(name=name, p=p)) + for s in data.sents: + encoded = e.encode_to_line(s) + x, y = e.parse_line(encoded) + decoded: RelationSentence = e.safe_decode(x, y) + + if decoded.as_tuple() != s.as_tuple(): + if num_fail < print_limit: + print(dict(gold=s.as_tuple(), text=s.text)) + print(dict(pred=decoded.as_tuple(), text=decoded.text)) + print(dict(x=x, y=y, e=decoded.error)) + print() + num_fail += 1 + + print(dict(success_rate=1 - (num_fail / len(data.sents)))) + print("#" * 80) + + +def select_encoder(name: str) -> Encoder: + mapping: Dict[str, Encoder] = dict( + extract=ExtractEncoder(), + generate=GenerateEncoder(), + ) + encoder = mapping[name] + return encoder + + +def test_entity_prompts( + path: str = "outputs/data/zsl/wiki/unseen_10_seed_0/test.jsonl", limit: int = 100 +): + def tokenize(text: str, tok) -> List[str]: + return tok.convert_ids_to_tokens(tok(text, add_special_tokens=False).input_ids) + + data = RelationData.load(Path(path)) + e = ExtractEncoder() + 
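+    # The target built by ExtractEncoder.encode_y always starts with the entity prompt
+    # "Head Entity : {head} , Tail Entity : {tail} , Relation :", so its first tokens
+    # must match encode_entity_prompt(head, tail); the assert below checks this, which
+    # is what allows decoding to be constrained with the entity prompt as a prefix.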
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base") + print(tokenizer) + for i, s in enumerate(tqdm(data.sents[:limit])): + head, label, tail = s.as_tuple() + x, y = e.encode(s) + prompt = e.encode_entity_prompt(head, tail) + tokens_y = tokenize(y, tokenizer) + tokens_prompt = tokenize(prompt, tokenizer) + assert tokens_y[: len(tokens_prompt)] == tokens_prompt + if i < 3: + print(tokens_y) + + +if __name__ == "__main__": + Fire() diff --git a/generation.py b/generation.py new file mode 100644 index 0000000..888bdf2 --- /dev/null +++ b/generation.py @@ -0,0 +1,193 @@ +from typing import Dict, List, Optional, Tuple + +import torch +from fire import Fire +from torch import Tensor +from transformers import PreTrainedModel, PreTrainedTokenizerFast + +from encoding import ExtractEncoder +from utils import DynamicModel, RelationSentence, find_sublist_index + + +class TextGenerator(DynamicModel): + model: PreTrainedModel + tokenizer: PreTrainedTokenizerFast + scores: Optional[List[Tensor]] = None + max_length: int + + def tokenize(self, texts: List[str], **kwargs): + return self.tokenizer( + texts, + padding=True, + truncation=True, + max_length=self.max_length, + return_tensors="pt", + **kwargs, + ).to(self.model.device) + + def run( + self, + texts: List[str], + do_sample=True, + top_k=50, + temperature=1.0, + num_return: int = 4, + prompt: Optional[str] = None, + prompt_ids: Optional[List[int]] = None, + multi_prompt_ids: Optional[List[List[int]]] = None, + decoder_input_ids: Optional[Tensor] = None, + save_scores: bool = False, + **kwargs, + ) -> List[str]: + # https://huggingface.co/transformers/v4.7.0/main_classes/model.html#generation + tok = self.tokenizer + eos, bos = tok.eos_token_id, tok.bos_token_id + + if prompt is not None: + prompt_ids = self.tokenizer(prompt, add_special_tokens=False).input_ids + if prompt_ids is not None: + prompt_ids = [eos, bos] + prompt_ids + decoder_input_ids = torch.tensor([prompt_ids]) + if multi_prompt_ids is not None: + assert len(texts) == len(multi_prompt_ids) + multi_prompt_ids = [[eos, bos] + lst for lst in multi_prompt_ids] + decoder_input_ids = torch.tensor(multi_prompt_ids) + if decoder_input_ids is not None: + kwargs.update(decoder_input_ids=decoder_input_ids.to(self.model.device)) + + outputs = self.model.generate( + **self.tokenize(texts), + do_sample=do_sample, + top_k=top_k, + temperature=temperature, + num_return_sequences=num_return, + return_dict_in_generate=True, + output_scores=save_scores, + max_length=self.max_length, + **kwargs, + ) + + self.scores = None + if save_scores: + self.scores = [_ for _ in torch.stack(outputs.scores, 1).cpu()] + return self.decode(outputs.sequences) + + def decode(self, outputs) -> List[str]: + tok = self.tokenizer + texts = tok.batch_decode( + outputs, skip_special_tokens=False, clean_up_tokenization_spaces=False + ) + + # Manually remove in case we have custom special tokens + special_tokens = [tok.eos_token, tok.bos_token, tok.pad_token] + for i, t in enumerate(texts): + for token in special_tokens: + t = t.replace(token, "") + texts[i] = t + return texts + + +class LabelConstraint: + def __init__( + self, + labels: List[str], + tokenizer: PreTrainedTokenizerFast, + prefix: str = " Relation :", + ): + self.prefix: List[int] = tokenizer(prefix, add_special_tokens=False).input_ids + self.label_map: Dict[int, str] = { + tokenizer(" " + x, add_special_tokens=False).input_ids[0]: x for x in labels + } + self.tokenizer = tokenizer + + def run(self, triplet: RelationSentence, scores: Tensor) -> 
RelationSentence:
+        triplet = triplet.copy(deep=True)
+        assert scores.ndim == 2
+        token_ids = scores.argmax(dim=-1).int().tolist()
+        i = find_sublist_index(token_ids, self.prefix)
+        if i == -1:
+            return triplet
+
+        position = i + len(self.prefix)
+        best = ""
+        best_score = -1e9
+        for j, label in self.label_map.items():
+            score = scores[position, j].item()
+            if score > best_score:
+                best = label
+                best_score = score
+
+        if triplet.label in self.label_map.values():
+            assert best == triplet.label
+
+        assert len(best) > 0
+        triplet.label = best
+        triplet.score = best_score
+        return triplet
+
+
+class TripletSearchDecoder(DynamicModel):
+    gen: TextGenerator
+    constraint: LabelConstraint
+    encoder: ExtractEncoder
+    top_k: int = 4
+
+    def generate(self, text: str, **kwargs) -> Tuple[str, Tensor]:
+        outputs = self.gen.run(
+            [text],
+            do_sample=False,
+            num_return=1,
+            num_beams=1,
+            save_scores=True,
+            **kwargs,
+        )
+
+        assert len(outputs) == 1
+        assert self.gen.scores is not None
+        scores = torch.log_softmax(self.gen.scores[0], dim=-1)
+        assert scores.ndim == 2
+        return outputs[0], scores
+
+    def find_prefix_end(self, token_ids: List[int], prefix: str) -> int:
+        prefix_ids = self.gen.tokenizer(prefix, add_special_tokens=False).input_ids
+        i = find_sublist_index(token_ids, prefix_ids)
+        position = i + len(prefix_ids)
+        return position
+
+    def branch(
+        self, text: str, prefix: str, prompt: Optional[str] = None, **kwargs
+    ) -> List[Tuple[str, float]]:
+        _, scores = self.generate(text, prompt=prompt, **kwargs)
+        token_ids = scores.argmax(dim=-1).int().tolist()
+        i = self.find_prefix_end(token_ids, prefix)
+
+        pairs = []
+        for j in torch.argsort(scores[i])[-self.top_k :]:
+            p = (prompt or "") + self.gen.decode([token_ids[:i] + [j]])[0]
+            pairs.append((p, scores[i, j].item()))
+
+        return pairs
+
+    def run(self, text: str) -> List[RelationSentence]:
+        x = self.encoder.encode_x(text)
+        outputs = []
+
+        for prompt_a, score_a in self.branch(x, prefix="Head Entity :"):
+            for prompt_b, score_b in self.branch(
+                x, prefix=" Tail Entity :", prompt=prompt_a
+            ):
+                output, scores = self.generate(x, prompt=prompt_b)
+                token_ids = scores.argmax(dim=-1).int().tolist()
+                i = self.find_prefix_end(token_ids, prefix=" Relation :")
+                score_c = max(scores[i].tolist())
+                s = self.encoder.safe_decode(x=x, y=output)
+                s = self.constraint.run(s, scores)
+                # score_c = s.score  # From LabelConstraint
+                s.score = (score_a + score_b + score_c) / 3
+                outputs.append(s)
+
+        return outputs
+
+
+if __name__ == "__main__":
+    Fire()
diff --git a/modeling.py b/modeling.py
new file mode 100644
index 0000000..78ca9cd
--- /dev/null
+++ b/modeling.py
@@ -0,0 +1,300 @@
+from pathlib import Path
+from typing import List, Optional, Tuple
+
+import torch
+from fire import Fire
+from tqdm import tqdm
+from transformers import (AutoModelForSeq2SeqLM, AutoTokenizer,
+                          IntervalStrategy, Pipeline, TrainingArguments,
+                          pipeline, set_seed)
+
+from encoding import select_encoder
+from generation import TextGenerator
+from transformer_base import run_clm, run_summarization
+from utils import DynamicModel, RelationData, RelationSentence
+
+
+class RelationModel(DynamicModel):
+    model_dir: str
+    data_dir: str
+    model_name: str
+    do_pretrain: bool
+    encoder_name: str
+    pipe_name: str
+    batch_size: int = 64
+    grad_accumulation: int = 2
+    random_seed: int = 42
+    warmup_ratio: float = 0.2
+    lr_pretrain: float = 3e-4
+    lr_finetune: float = 3e-5
+    epochs_pretrain: int = 3
+    epochs_finetune: int = 5
+    train_fp16: bool = True
+
+    def fit(self,
path_train: str, path_dev: Optional[str] = None): + raise NotImplementedError + + def run(self, *args, **kwargs): + raise NotImplementedError + + def decode(self, *args, **kwargs): + raise NotImplementedError + + def get_lr(self) -> float: + return self.lr_pretrain if self.do_pretrain else self.lr_finetune + + def get_epochs(self) -> int: + return self.epochs_pretrain if self.do_pretrain else self.epochs_finetune + + def make_pipe(self, **kwargs) -> Pipeline: + pipe = pipeline( + self.pipe_name, + model=self.model_dir, + tokenizer=self.model_name, + device=0 if torch.cuda.is_available() else -1, + **kwargs, + ) + return pipe + + def get_encoder(self): + return select_encoder(self.encoder_name) + + def get_train_args(self, do_eval: bool) -> TrainingArguments: + return TrainingArguments( + seed=self.random_seed, + do_train=True, + do_eval=do_eval or None, # False still becomes True after parsing + overwrite_output_dir=True, + per_device_train_batch_size=self.batch_size, + gradient_accumulation_steps=self.grad_accumulation, + warmup_ratio=self.warmup_ratio, + output_dir=self.model_dir, + save_strategy=IntervalStrategy.EPOCH, + evaluation_strategy=IntervalStrategy.EPOCH + if do_eval + else IntervalStrategy.NO, + learning_rate=self.get_lr(), + num_train_epochs=self.get_epochs(), + load_best_model_at_end=True, + fp16=self.train_fp16, + ) + + +class RelationGenerator(RelationModel): + model_name: str = "gpt2" + block_size: int = 128 + encoder_name: str = "gpt_new_generate" + pipe_name: str = "text-generation" + + def fit(self, path_train: str, path_dev: Optional[str] = None): + data_args = run_clm.DataTrainingArguments( + concat_texts=False, + train_file=path_train, + validation_file=path_dev, + overwrite_cache=True, + block_size=self.block_size, + ) + train_args = self.get_train_args(do_eval=path_dev is not None) + model_args = run_clm.ModelArguments(model_name_or_path=self.model_name) + run_clm.main( + model_args=model_args, training_args=train_args, data_args=data_args + ) + + def generate( + self, relation: str, num: int, pipe: Pipeline + ) -> Tuple[List[RelationSentence], List[str]]: + set_seed(self.random_seed) + encoder = self.get_encoder() + prompt = encoder.encode_x(relation) + sents, raw = [], [] + errors = set() + + while len(sents) < num: + outputs = pipe( + [prompt], + num_return_sequences=self.batch_size, + max_length=self.block_size, + ) + for o in outputs: + raw.append(o["generated_text"] + "\n") + x, y = encoder.parse_line(raw[-1]) + try: + s = encoder.decode(x=prompt, y=y) + if s.is_valid(): + sents.append(s) + except Exception as e: + errors.add(str(e)) + + print(dict(target=num, success=len(sents), raw=len(raw))) + + assert len(sents) >= num + print(dict(prompt=prompt, success_rate=len(sents) / len(raw), errors=errors)) + return sents[:num], raw + + def run( + self, + labels: List[str], + path_out: Path, + num_samples_per_relation: int, + device: torch.device = torch.device("cuda"), + ) -> RelationData: + pipe = self.make_pipe() + sents_all, raw_all = [], [] + for relation in tqdm(labels): + sents, raw = self.generate(relation, num_samples_per_relation, pipe=pipe) + sents_all.extend(sents) + raw_all.extend(raw) + + with open(path_out, "w") as f: + f.write("".join(raw_all)) + + data = RelationData(sents=sents_all) + return data + + def decode(self, *args, **kwargs): + pass + + +class NewRelationGenerator(RelationModel): + model_name: str = "facebook/bart-base" + max_source_length: int = 128 + max_target_length: int = 128 + encoder_name: str = "new_generate" + pipe_name: str = 
"summarization" + + def fit(self, path_train: str, path_dev: Optional[str] = None): + kwargs = {} + + data_args = run_summarization.DataTrainingArguments( + train_file=path_train, + validation_file=path_dev, + overwrite_cache=True, + max_target_length=self.max_target_length, + max_source_length=self.max_source_length, + **kwargs, + ) + train_args = self.get_train_args(do_eval=path_dev is not None) + kwargs = { + k: v for k, v in train_args.to_dict().items() if not k.startswith("_") + } + train_args = run_summarization.Seq2SeqTrainingArguments(**kwargs) + model_args = run_summarization.ModelArguments( + model_name_or_path=self.model_name + ) + run_summarization.main( + model_args=model_args, training_args=train_args, data_args=data_args + ) + + def load_generator(self, device: torch.device) -> TextGenerator: + gen = TextGenerator( + model=AutoModelForSeq2SeqLM.from_pretrained(self.model_dir), + tokenizer=AutoTokenizer.from_pretrained(self.model_dir), + max_length=self.max_target_length, + ) + gen.model = gen.model.to(device) + return gen + + def generate( + self, relation: str, num: int, gen: TextGenerator + ) -> Tuple[List[RelationSentence], List[str]]: + set_seed(self.random_seed) + encoder = self.get_encoder() + prompt = encoder.encode_x(relation) + sents, raw = [], [] + errors = set() + + while len(sents) < num: + outputs = gen.run([prompt], num_return=self.batch_size) + for o in outputs: + raw.append(run_summarization.encode_to_line(x=prompt, y=o)) + try: + s = encoder.decode(x=prompt, y=o) + if s.is_valid(): + sents.append(s) + except Exception as e: + errors.add(str(e)) + + print(dict(target=num, success=len(sents), raw=len(raw))) + + assert len(sents) >= num + print(dict(prompt=prompt, success_rate=len(sents) / len(raw), errors=errors)) + return sents[:num], raw + + def run( + self, + labels: List[str], + path_out: Path, + num_samples_per_relation: int, + device: torch.device = torch.device("cuda"), + ) -> RelationData: + gen = self.load_generator(device=device) + sents_all, raw_all = [], [] + for relation in tqdm(labels): + sents, raw = self.generate(relation, num_samples_per_relation, gen=gen) + sents_all.extend(sents) + raw_all.extend(raw) + + with open(path_out, "w") as f: + f.write("".join(raw_all)) + + data = RelationData(sents=sents_all) + return data + + def decode(self, *args, **kwargs): + pass + + +class NewRelationExtractor(NewRelationGenerator): + encoder_name: str = "new_extract" + + @staticmethod + def gen_texts(texts: List[str], gen: TextGenerator, **kwargs): + return gen.run(texts, do_sample=False, num_return=1, **kwargs) + + def run( + self, + texts: List[str], + path_out: Path, + batch_size: int = 512, + device: torch.device = torch.device("cuda"), + ): + set_seed(self.random_seed) + encoder = self.get_encoder() + prompts = [encoder.encode_x(t) for t in texts] + gen = self.load_generator(device=device) + preds = [] + + for i in tqdm(range(0, len(texts), batch_size), desc="RelationExtractor.run"): + batch = prompts[i : i + batch_size] + outputs = self.gen_texts(batch, gen) + preds.extend(outputs) + + path_out.parent.mkdir(exist_ok=True, parents=True) + with open(path_out, "w") as f: + for x, y in zip(prompts, preds): + f.write(run_summarization.encode_to_line(x=x, y=y)) + + def decode(self, path: Path) -> RelationData: + encoder = self.get_encoder() + with open(path) as f: + sents = [encoder.safe_decode(*encoder.parse_line(line)) for line in f] + + success_rate = len([s for s in sents if s.is_valid()]) / len(sents) + print(dict(success_rate=success_rate)) + data = 
RelationData(sents=sents) + return data + + +def select_model(name: str, **kwargs) -> RelationModel: + mapping = dict( + generate=RelationGenerator(**kwargs), + new_generate=NewRelationGenerator(**kwargs), + new_extract=NewRelationExtractor(**kwargs), + ) + model = mapping[name] + print(dict(select_model=model)) + return model + + +if __name__ == "__main__": + Fire() diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..c28ce8e --- /dev/null +++ b/requirements.txt @@ -0,0 +1,11 @@ +torch==1.9.0 +transformers==4.7.0 +datasets==1.11.0 +pandas==1.2.4 +pydantic==1.8.1 +fastavro==1.4.0 +fire==0.4.0 +nltk==3.6.2 +lxml==4.6.3 +editdistance==0.5.3 +seqeval==1.2.2 diff --git a/transformer_base/__init__.py b/transformer_base/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/transformer_base/run_clm.py b/transformer_base/run_clm.py new file mode 100644 index 0000000..bbd0227 --- /dev/null +++ b/transformer_base/run_clm.py @@ -0,0 +1,591 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2020 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset. +Here is the full list of checkpoints on the hub that can be fine-tuned by this script: +https://huggingface.co/models?filter=causal-lm +Adapted from: https://github.com/huggingface/transformers/blob/v4.7.0/examples/pytorch/language-modeling/run_clm.py +""" +# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments. + +import logging +import math +import os +import sys +from dataclasses import dataclass, field +from typing import Optional + +import transformers +from datasets import load_dataset +from transformers import (CONFIG_MAPPING, MODEL_FOR_CAUSAL_LM_MAPPING, + AutoConfig, AutoModelForCausalLM, AutoTokenizer, + HfArgumentParser, Trainer, TrainingArguments, + default_data_collator, set_seed) +from transformers.testing_utils import CaptureLogger +from transformers.trainer_utils import get_last_checkpoint +from transformers.utils import check_min_version + +# Will error if the minimal version of Transformers is not installed. Remove at your own risks. +check_min_version("4.7.0") + +logger = logging.getLogger(__name__) + + +MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys()) +MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch. + """ + + model_name_or_path: Optional[str] = field( + default=None, + metadata={ + "help": "The model checkpoint for weights initialization." + "Don't set if you want to train a model from scratch." 
+ }, + ) + model_type: Optional[str] = field( + default=None, + metadata={ + "help": "If training from scratch, pass a model type from the list: " + + ", ".join(MODEL_TYPES) + }, + ) + config_overrides: Optional[str] = field( + default=None, + metadata={ + "help": "Override some existing default config settings when a model is trained from scratch. Example: " + "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index" + }, + ) + config_name: Optional[str] = field( + default=None, + metadata={ + "help": "Pretrained config name or path if not the same as model_name" + }, + ) + tokenizer_name: Optional[str] = field( + default=None, + metadata={ + "help": "Pretrained tokenizer name or path if not the same as model_name" + }, + ) + cache_dir: Optional[str] = field( + default=None, + metadata={ + "help": "Where do you want to store the pretrained models downloaded from huggingface.co" + }, + ) + use_fast_tokenizer: bool = field( + default=True, + metadata={ + "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not." + }, + ) + model_revision: str = field( + default="main", + metadata={ + "help": "The specific model version to use (can be a branch name, tag name or commit id)." + }, + ) + use_auth_token: bool = field( + default=False, + metadata={ + "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " + "with private models)." + }, + ) + + def __post_init__(self): + if self.config_overrides is not None and ( + self.config_name is not None or self.model_name_or_path is not None + ): + raise ValueError( + "--config_overrides can't be used in combination with --config_name or --model_name_or_path" + ) + + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + """ + + dataset_name: Optional[str] = field( + default=None, + metadata={"help": "The name of the dataset to use (via the datasets library)."}, + ) + dataset_config_name: Optional[str] = field( + default=None, + metadata={ + "help": "The configuration name of the dataset to use (via the datasets library)." + }, + ) + train_file: Optional[str] = field( + default=None, metadata={"help": "The input training data file (a text file)."} + ) + validation_file: Optional[str] = field( + default=None, + metadata={ + "help": "An optional input evaluation data file to evaluate the perplexity on (a text file)." + }, + ) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + }, + ) + max_eval_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " + "value if set." + }, + ) + + block_size: Optional[int] = field( + default=None, + metadata={ + "help": "Optional input sequence length after tokenization. " + "The training dataset will be truncated in block of this size for training. " + "Default to the model max input length for single sentence inputs (take into account special tokens)." 
+ }, + ) + overwrite_cache: bool = field( + default=False, + metadata={"help": "Overwrite the cached training and evaluation sets"}, + ) + validation_split_percentage: Optional[int] = field( + default=5, + metadata={ + "help": "The percentage of the train set used as validation set in case there's no validation split" + }, + ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) + concat_texts: bool = field( + default=True, + metadata={"help": "Concatenate all lines from dataset and draw chunks"}, + ) + + tokenizer_kwargs: Optional[dict] = field( + default=None, + metadata={"help": "Extra keyword arguments to initialize tokenizer"}, + ) + + def __post_init__(self): + if ( + self.dataset_name is None + and self.train_file is None + and self.validation_file is None + ): + raise ValueError( + "Need either a dataset name or a training/validation file." + ) + else: + if self.train_file is not None: + extension = self.train_file.split(".")[-1] + assert extension in [ + "csv", + "json", + "txt", + ], "`train_file` should be a csv, a json or a txt file." + if self.validation_file is not None: + extension = self.validation_file.split(".")[-1] + assert extension in [ + "csv", + "json", + "txt", + ], "`validation_file` should be a csv, a json or a txt file." + + +def main( + model_args: ModelArguments = None, + data_args: DataTrainingArguments = None, + training_args: TrainingArguments = None, +): + # See all possible arguments in src/transformers/training_args.py + # or by passing the --help flag to this script. + # We now keep distinct sets of args, for a cleaner separation of concerns. + + if model_args is None or data_args is None or training_args is None: + print("Using HfArgumentParser") + parser = HfArgumentParser( + (ModelArguments, DataTrainingArguments, TrainingArguments) + ) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, data_args, training_args = parser.parse_json_file( + json_file=os.path.abspath(sys.argv[1]) + ) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + # Setup logging + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + logger.setLevel(logging.INFO if training_args.should_log else logging.WARN) + + # Log on each process the small summary: + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" + + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" + ) + # Set the verbosity to info of the Transformers logger (on main process only): + if training_args.should_log: + transformers.utils.logging.set_verbosity_info() + transformers.utils.logging.enable_default_handler() + transformers.utils.logging.enable_explicit_format() + logger.info(f"Training/evaluation parameters {training_args}") + + # Detecting last checkpoint. 
+ last_checkpoint = None + if ( + os.path.isdir(training_args.output_dir) + and training_args.do_train + and not training_args.overwrite_output_dir + ): + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. " + "Use --overwrite_output_dir to overcome." + ) + elif ( + last_checkpoint is not None and training_args.resume_from_checkpoint is None + ): + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) + + # Set seed before initializing model. + set_seed(training_args.seed) + + # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) + # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ + # (the dataset will be downloaded automatically from the datasets Hub). + # + # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called + # 'text' is found. You can easily tweak this behavior (see below). + # + # In distributed training, the load_dataset function guarantee that only one local process can concurrently + # download the dataset. + if data_args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + datasets = load_dataset( + data_args.dataset_name, + data_args.dataset_config_name, + cache_dir=model_args.cache_dir, + ) + if "validation" not in datasets.keys(): + datasets["validation"] = load_dataset( + data_args.dataset_name, + data_args.dataset_config_name, + split=f"train[:{data_args.validation_split_percentage}%]", + cache_dir=model_args.cache_dir, + ) + datasets["train"] = load_dataset( + data_args.dataset_name, + data_args.dataset_config_name, + split=f"train[{data_args.validation_split_percentage}%:]", + cache_dir=model_args.cache_dir, + ) + else: + data_files = {} + if data_args.train_file is not None: + data_files["train"] = data_args.train_file + if data_args.validation_file is not None: + data_files["validation"] = data_args.validation_file + extension = ( + data_args.train_file.split(".")[-1] + if data_args.train_file is not None + else data_args.validation_file.split(".")[-1] + ) + if extension == "txt": + extension = "text" + datasets = load_dataset( + extension, data_files=data_files, cache_dir=model_args.cache_dir + ) + # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at + # https://huggingface.co/docs/datasets/loading_datasets.html. + + # Load pretrained model and tokenizer + # + # Distributed training: + # The .from_pretrained methods guarantee that only one local process can concurrently + # download model & vocab. 
+ + config_kwargs = { + "cache_dir": model_args.cache_dir, + "revision": model_args.model_revision, + "use_auth_token": True if model_args.use_auth_token else None, + } + if model_args.config_name: + config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs) + elif model_args.model_name_or_path: + config = AutoConfig.from_pretrained( + model_args.model_name_or_path, **config_kwargs + ) + else: + config = CONFIG_MAPPING[model_args.model_type]() + logger.warning("You are instantiating a new config instance from scratch.") + if model_args.config_overrides is not None: + logger.info(f"Overriding config: {model_args.config_overrides}") + config.update_from_string(model_args.config_overrides) + + tokenizer_kwargs = { + "cache_dir": model_args.cache_dir, + "use_fast": model_args.use_fast_tokenizer, + "revision": model_args.model_revision, + "use_auth_token": True if model_args.use_auth_token else None, + } + if data_args.tokenizer_kwargs: + tokenizer_kwargs.update(**data_args.tokenizer_kwargs) + if model_args.tokenizer_name: + tokenizer = AutoTokenizer.from_pretrained( + model_args.tokenizer_name, **tokenizer_kwargs + ) + elif model_args.model_name_or_path: + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, **tokenizer_kwargs + ) + else: + raise ValueError( + "You are instantiating a new tokenizer from scratch. This is not supported by this script." + "You can do it from another script, save it, and load it from here, using --tokenizer_name." + ) + print(tokenizer) + + if model_args.model_name_or_path: + model = AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, + from_tf=bool(".ckpt" in model_args.model_name_or_path), + config=config, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + else: + model = AutoModelForCausalLM.from_config(config) + n_params = sum( + dict((p.data_ptr(), p.numel()) for p in model.parameters()).values() + ) + logger.info( + f"Training new model from scratch - Total size={n_params/2**20:.2f}M params" + ) + + model.resize_token_embeddings(len(tokenizer)) + + # Preprocessing the datasets. + # First we tokenize all the texts. + if training_args.do_train: + column_names = datasets["train"].column_names + else: + column_names = datasets["validation"].column_names + text_column_name = "text" if "text" in column_names else column_names[0] + + # since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function + tok_logger = transformers.utils.logging.get_logger( + "transformers.tokenization_utils_base" + ) + if not data_args.concat_texts: + tokenizer.pad_token = tokenizer.eos_token + + def tokenize_function(examples): + with CaptureLogger(tok_logger) as cl: + output = tokenizer(examples[text_column_name]) + + if not data_args.concat_texts: + output = tokenizer( + examples[text_column_name], + truncation=True, + padding="max_length", + max_length=data_args.block_size, + ) + + # clm input could be much much longer than block_size + if "Token indices sequence length is longer than the" in cl.out: + tok_logger.warning( + "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits before being passed to the model." 
+ ) + return output + + tokenized_datasets = datasets.map( + tokenize_function, + batched=True, + num_proc=data_args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not data_args.overwrite_cache, + ) + + if data_args.block_size is None: + block_size = tokenizer.model_max_length + if block_size > 1024: + logger.warning( + f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). " + "Picking 1024 instead. You can change that default value by passing --block_size xxx." + ) + block_size = 1024 + else: + if data_args.block_size > tokenizer.model_max_length: + logger.warning( + f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model" + f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}." + ) + block_size = min(data_args.block_size, tokenizer.model_max_length) + + # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size. + def group_texts(examples): + if not data_args.concat_texts: + result = examples + result["labels"] = result["input_ids"].copy() + return result + + # Concatenate all texts. + concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()} + total_length = len(concatenated_examples[list(examples.keys())[0]]) + # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can + # customize this part to your needs. + total_length = (total_length // block_size) * block_size + # Split by chunks of max_len. + result = { + k: [t[i : i + block_size] for i in range(0, total_length, block_size)] + for k, t in concatenated_examples.items() + } + result["labels"] = result["input_ids"].copy() + return result + + # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder + # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower + # to preprocess. + # + # To speed up this part, we use multiprocessing. See the documentation of the map method for more information: + # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map + + lm_datasets = tokenized_datasets.map( + group_texts, + batched=True, + num_proc=data_args.preprocessing_num_workers, + load_from_cache_file=not data_args.overwrite_cache, + ) + + if training_args.do_train: + if "train" not in tokenized_datasets: + raise ValueError("--do_train requires a train dataset") + train_dataset = lm_datasets["train"] + if data_args.max_train_samples is not None: + train_dataset = train_dataset.select(range(data_args.max_train_samples)) + + if training_args.do_eval: + if "validation" not in tokenized_datasets: + raise ValueError("--do_eval requires a validation dataset") + eval_dataset = lm_datasets["validation"] + if data_args.max_eval_samples is not None: + eval_dataset = eval_dataset.select(range(data_args.max_eval_samples)) + + # Initialize our Trainer + trainer = Trainer( + model=model, + args=training_args, + train_dataset=train_dataset if training_args.do_train else None, + eval_dataset=eval_dataset if training_args.do_eval else None, + tokenizer=tokenizer, + # Data collator will default to DataCollatorWithPadding, so we change it. 
+ data_collator=default_data_collator, + ) + + # Training + if training_args.do_train: + checkpoint = None + if training_args.resume_from_checkpoint is not None: + checkpoint = training_args.resume_from_checkpoint + elif last_checkpoint is not None: + checkpoint = last_checkpoint + train_result = trainer.train(resume_from_checkpoint=checkpoint) + trainer.save_model() # Saves the tokenizer too for easy upload + + metrics = train_result.metrics + + max_train_samples = ( + data_args.max_train_samples + if data_args.max_train_samples is not None + else len(train_dataset) + ) + metrics["train_samples"] = min(max_train_samples, len(train_dataset)) + + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + trainer.save_state() + + # Evaluation + if training_args.do_eval: + logger.info("*** Evaluate ***") + + metrics = trainer.evaluate() + + max_eval_samples = ( + data_args.max_eval_samples + if data_args.max_eval_samples is not None + else len(eval_dataset) + ) + metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) + try: + perplexity = math.exp(metrics["eval_loss"]) + except OverflowError: + perplexity = float("inf") + metrics["perplexity"] = perplexity + + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + + if training_args.push_to_hub: + kwargs = { + "finetuned_from": model_args.model_name_or_path, + "tasks": "text-generation", + } + if data_args.dataset_name is not None: + kwargs["dataset_tags"] = data_args.dataset_name + if data_args.dataset_config_name is not None: + kwargs["dataset_args"] = data_args.dataset_config_name + kwargs[ + "dataset" + ] = f"{data_args.dataset_name} {data_args.dataset_config_name}" + else: + kwargs["dataset"] = data_args.dataset_name + + trainer.push_to_hub(**kwargs) + + +def _mp_fn(index): + # For xla_spawn (TPUs) + main() + + +if __name__ == "__main__": + main() diff --git a/transformer_base/run_summarization.py b/transformer_base/run_summarization.py new file mode 100644 index 0000000..7bd5e09 --- /dev/null +++ b/transformer_base/run_summarization.py @@ -0,0 +1,679 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Fine-tuning the library models for sequence to sequence. +Adapted from: https://github.com/huggingface/transformers/blob/v4.7.0/examples/pytorch/summarization/run_summarization.py +""" +# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments. 
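+# Note: this adapted copy also exposes `tokenizer_kwargs` in DataTrainingArguments and
+# defines the `encode_to_line` / `decode_from_line` helpers imported by encoding.py and
+# modeling.py. modeling.py drives it programmatically, e.g. (illustrative sketch,
+# argument values are placeholders):
+#   model_args = ModelArguments(model_name_or_path="facebook/bart-base")
+#   data_args = DataTrainingArguments(train_file=..., validation_file=...)
+#   main(model_args=model_args, data_args=data_args, training_args=seq2seq_training_args)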
+ +import json +import logging +import os +import sys +from dataclasses import dataclass, field +from typing import Optional, Tuple + +import transformers +from datasets import load_dataset +from transformers import (AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer, + DataCollatorForSeq2Seq, HfArgumentParser, + Seq2SeqTrainer, Seq2SeqTrainingArguments, set_seed) +from transformers.trainer_utils import get_last_checkpoint +from transformers.utils import check_min_version + +# Will error if the minimal version of Transformers is not installed. Remove at your own risks. +check_min_version("4.7.0") + +logger = logging.getLogger(__name__) + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. + """ + + model_name_or_path: str = field( + metadata={ + "help": "Path to pretrained model or model identifier from huggingface.co/models" + } + ) + config_name: Optional[str] = field( + default=None, + metadata={ + "help": "Pretrained config name or path if not the same as model_name" + }, + ) + tokenizer_name: Optional[str] = field( + default=None, + metadata={ + "help": "Pretrained tokenizer name or path if not the same as model_name" + }, + ) + cache_dir: Optional[str] = field( + default=None, + metadata={ + "help": "Where to store the pretrained models downloaded from huggingface.co" + }, + ) + use_fast_tokenizer: bool = field( + default=True, + metadata={ + "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not." + }, + ) + model_revision: str = field( + default="main", + metadata={ + "help": "The specific model version to use (can be a branch name, tag name or commit id)." + }, + ) + use_auth_token: bool = field( + default=False, + metadata={ + "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " + "with private models)." + }, + ) + + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + """ + + dataset_name: Optional[str] = field( + default=None, + metadata={"help": "The name of the dataset to use (via the datasets library)."}, + ) + dataset_config_name: Optional[str] = field( + default=None, + metadata={ + "help": "The configuration name of the dataset to use (via the datasets library)." + }, + ) + text_column: Optional[str] = field( + default=None, + metadata={ + "help": "The name of the column in the datasets containing the full texts (for summarization)." + }, + ) + summary_column: Optional[str] = field( + default=None, + metadata={ + "help": "The name of the column in the datasets containing the summaries (for summarization)." + }, + ) + train_file: Optional[str] = field( + default=None, + metadata={"help": "The input training data file (a jsonlines or csv file)."}, + ) + validation_file: Optional[str] = field( + default=None, + metadata={ + "help": "An optional input evaluation data file to evaluate the metrics (rouge) on " + "(a jsonlines or csv file)." + }, + ) + test_file: Optional[str] = field( + default=None, + metadata={ + "help": "An optional input test data file to evaluate the metrics (rouge) on " + "(a jsonlines or csv file)." 
+ }, + ) + overwrite_cache: bool = field( + default=False, + metadata={"help": "Overwrite the cached training and evaluation sets"}, + ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) + max_source_length: Optional[int] = field( + default=1024, + metadata={ + "help": "The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + }, + ) + max_target_length: Optional[int] = field( + default=128, + metadata={ + "help": "The maximum total sequence length for target text after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + }, + ) + val_max_target_length: Optional[int] = field( + default=None, + metadata={ + "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`." + "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used " + "during ``evaluate`` and ``predict``." + }, + ) + pad_to_max_length: bool = field( + default=False, + metadata={ + "help": "Whether to pad all samples to model maximum sentence length. " + "If False, will pad the samples dynamically when batching to the maximum length in the batch. More " + "efficient on GPU but very bad for TPU." + }, + ) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + }, + ) + max_eval_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " + "value if set." + }, + ) + max_predict_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this " + "value if set." + }, + ) + num_beams: Optional[int] = field( + default=None, + metadata={ + "help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, " + "which is used during ``evaluate`` and ``predict``." + }, + ) + ignore_pad_token_for_loss: bool = field( + default=True, + metadata={ + "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not." + }, + ) + source_prefix: Optional[str] = field( + default=None, + metadata={ + "help": "A prefix to add before every source text (useful for T5 models)." + }, + ) + + tokenizer_kwargs: Optional[dict] = field( + default=None, + metadata={"help": "Extra keyword arguments to initialize tokenizer"}, + ) + + def __post_init__(self): + if ( + self.dataset_name is None + and self.train_file is None + and self.validation_file is None + ): + raise ValueError( + "Need either a dataset name or a training/validation file." + ) + else: + if self.train_file is not None: + extension = self.train_file.split(".")[-1] + assert extension in [ + "csv", + "json", + ], "`train_file` should be a csv or a json file." + if self.validation_file is not None: + extension = self.validation_file.split(".")[-1] + assert extension in [ + "csv", + "json", + ], "`validation_file` should be a csv or a json file." 
+ if self.val_max_target_length is None: + self.val_max_target_length = self.max_target_length + + +summarization_name_mapping = { + "amazon_reviews_multi": ("review_body", "review_title"), + "big_patent": ("description", "abstract"), + "cnn_dailymail": ("article", "highlights"), + "orange_sum": ("text", "summary"), + "pn_summary": ("article", "summary"), + "psc": ("extract_text", "summary_text"), + "samsum": ("dialogue", "summary"), + "thaisum": ("body", "summary"), + "xglue": ("news_body", "news_title"), + "xsum": ("document", "summary"), + "wiki_summary": ("article", "highlights"), +} + + +def main( + model_args: ModelArguments = None, + data_args: DataTrainingArguments = None, + training_args: Seq2SeqTrainingArguments = None, +): + # See all possible arguments in src/transformers/training_args.py + # or by passing the --help flag to this script. + # We now keep distinct sets of args, for a cleaner separation of concerns. + + if model_args is None or data_args is None or training_args is None: + print("Using HfArgumentParser") + parser = HfArgumentParser( + (ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments) + ) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, data_args, training_args = parser.parse_json_file( + json_file=os.path.abspath(sys.argv[1]) + ) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + # Setup logging + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + logger.setLevel(logging.INFO if training_args.should_log else logging.WARN) + + # Log on each process the small summary: + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" + + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" + ) + # Set the verbosity to info of the Transformers logger (on main process only): + if training_args.should_log: + transformers.utils.logging.set_verbosity_info() + logger.info(f"Training/evaluation parameters {training_args}") + + if data_args.source_prefix is None and model_args.model_name_or_path in [ + "t5-small", + "t5-base", + "t5-large", + "t5-3b", + "t5-11b", + ]: + logger.warning( + "You're running a t5 model but didn't provide a source prefix, which is the expected, e.g. with " + "`--source_prefix 'summarize: ' `" + ) + + # Detecting last checkpoint. + last_checkpoint = None + if ( + os.path.isdir(training_args.output_dir) + and training_args.do_train + and not training_args.overwrite_output_dir + ): + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. " + "Use --overwrite_output_dir to overcome." + ) + elif ( + last_checkpoint is not None and training_args.resume_from_checkpoint is None + ): + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) + + # Set seed before initializing model. 
+ set_seed(training_args.seed) + + # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below) + # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ + # (the dataset will be downloaded automatically from the datasets Hub). + # + # For CSV/JSON files this script will use the first column for the full texts and the second column for the + # summaries (unless you specify column names for this with the `text_column` and `summary_column` arguments). + # + # In distributed training, the load_dataset function guarantee that only one local process can concurrently + # download the dataset. + if data_args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + datasets = load_dataset( + data_args.dataset_name, + data_args.dataset_config_name, + cache_dir=model_args.cache_dir, + ) + else: + data_files = {} + if data_args.train_file is not None: + data_files["train"] = data_args.train_file + extension = data_args.train_file.split(".")[-1] + if data_args.validation_file is not None: + data_files["validation"] = data_args.validation_file + extension = data_args.validation_file.split(".")[-1] + if data_args.test_file is not None: + data_files["test"] = data_args.test_file + extension = data_args.test_file.split(".")[-1] + datasets = load_dataset( + extension, data_files=data_files, cache_dir=model_args.cache_dir + ) + # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at + # https://huggingface.co/docs/datasets/loading_datasets.html. + + # Load pretrained model and tokenizer + # + # Distributed training: + # The .from_pretrained methods guarantee that only one local process can concurrently + # download model & vocab. + config = AutoConfig.from_pretrained( + model_args.config_name + if model_args.config_name + else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + + tokenizer_kwargs = data_args.tokenizer_kwargs or {} + print(dict(tokenizer_kwargs=tokenizer_kwargs)) + tokenizer = AutoTokenizer.from_pretrained( + model_args.tokenizer_name + if model_args.tokenizer_name + else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + use_fast=model_args.use_fast_tokenizer, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + **tokenizer_kwargs, + ) + print(dict(tokenizer=tokenizer)) + + model = AutoModelForSeq2SeqLM.from_pretrained( + model_args.model_name_or_path, + from_tf=bool(".ckpt" in model_args.model_name_or_path), + config=config, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + + model.resize_token_embeddings(len(tokenizer)) + + if model.config.decoder_start_token_id is None: + raise ValueError( + "Make sure that `config.decoder_start_token_id` is correctly defined" + ) + + prefix = data_args.source_prefix if data_args.source_prefix is not None else "" + + # Preprocessing the datasets. + # We need to tokenize inputs and targets. + if training_args.do_train: + column_names = datasets["train"].column_names + elif training_args.do_eval: + column_names = datasets["validation"].column_names + elif training_args.do_predict: + column_names = datasets["test"].column_names + else: + logger.info( + "There is nothing to do. 
Please pass `do_train`, `do_eval` and/or `do_predict`." + ) + return + + # Get the column names for input/target. + dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None) + if data_args.text_column is None: + text_column = ( + dataset_columns[0] if dataset_columns is not None else column_names[0] + ) + else: + text_column = data_args.text_column + if text_column not in column_names: + raise ValueError( + f"--text_column' value '{data_args.text_column}' needs to be one of: {', '.join(column_names)}" + ) + if data_args.summary_column is None: + summary_column = ( + dataset_columns[1] if dataset_columns is not None else column_names[1] + ) + else: + summary_column = data_args.summary_column + if summary_column not in column_names: + raise ValueError( + f"--summary_column' value '{data_args.summary_column}' needs to be one of: {', '.join(column_names)}" + ) + + # Temporarily set max_target_length for training. + max_target_length = data_args.max_target_length + padding = "max_length" if data_args.pad_to_max_length else False + + if training_args.label_smoothing_factor > 0 and not hasattr( + model, "prepare_decoder_input_ids_from_labels" + ): + logger.warning( + "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for" + f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory" + ) + + def preprocess_function(examples): + inputs = examples[text_column] + targets = examples[summary_column] + inputs = [prefix + inp for inp in inputs] + model_inputs = tokenizer( + inputs, + max_length=data_args.max_source_length, + padding=padding, + truncation=True, + ) + + # Setup the tokenizer for targets + with tokenizer.as_target_tokenizer(): + labels = tokenizer( + targets, max_length=max_target_length, padding=padding, truncation=True + ) + + # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore + # padding in the loss. 
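+        # Illustrative example (made-up token ids, not from any particular tokenizer):
+        # with pad_token_id = 0, a padded label sequence [8774, 296, 1, 0, 0] becomes
+        # [8774, 296, 1, -100, -100]; -100 is the index ignored by the cross-entropy
+        # loss, so padded positions do not contribute to it.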
+ if padding == "max_length" and data_args.ignore_pad_token_for_loss: + labels["input_ids"] = [ + [(l if l != tokenizer.pad_token_id else -100) for l in label] + for label in labels["input_ids"] + ] + + model_inputs["labels"] = labels["input_ids"] + return model_inputs + + if training_args.do_train: + if "train" not in datasets: + raise ValueError("--do_train requires a train dataset") + train_dataset = datasets["train"] + if data_args.max_train_samples is not None: + train_dataset = train_dataset.select(range(data_args.max_train_samples)) + train_dataset = train_dataset.map( + preprocess_function, + batched=True, + num_proc=data_args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not data_args.overwrite_cache, + ) + + if training_args.do_eval: + max_target_length = data_args.val_max_target_length + if "validation" not in datasets: + raise ValueError("--do_eval requires a validation dataset") + eval_dataset = datasets["validation"] + if data_args.max_eval_samples is not None: + eval_dataset = eval_dataset.select(range(data_args.max_eval_samples)) + eval_dataset = eval_dataset.map( + preprocess_function, + batched=True, + num_proc=data_args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not data_args.overwrite_cache, + ) + + if training_args.do_predict: + max_target_length = data_args.val_max_target_length + if "test" not in datasets: + raise ValueError("--do_predict requires a test dataset") + predict_dataset = datasets["test"] + if data_args.max_predict_samples is not None: + predict_dataset = predict_dataset.select( + range(data_args.max_predict_samples) + ) + predict_dataset = predict_dataset.map( + preprocess_function, + batched=True, + num_proc=data_args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not data_args.overwrite_cache, + ) + + # Data collator + label_pad_token_id = ( + -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id + ) + data_collator = DataCollatorForSeq2Seq( + tokenizer, + model=model, + label_pad_token_id=label_pad_token_id, + pad_to_multiple_of=8 if training_args.fp16 else None, + ) + + # Initialize our Trainer + trainer = Seq2SeqTrainer( + model=model, + args=training_args, + train_dataset=train_dataset if training_args.do_train else None, + eval_dataset=eval_dataset if training_args.do_eval else None, + tokenizer=tokenizer, + data_collator=data_collator, + ) + + # Training + if training_args.do_train: + checkpoint = None + if training_args.resume_from_checkpoint is not None: + checkpoint = training_args.resume_from_checkpoint + elif last_checkpoint is not None: + checkpoint = last_checkpoint + train_result = trainer.train(resume_from_checkpoint=checkpoint) + trainer.save_model() # Saves the tokenizer too for easy upload + + metrics = train_result.metrics + max_train_samples = ( + data_args.max_train_samples + if data_args.max_train_samples is not None + else len(train_dataset) + ) + metrics["train_samples"] = min(max_train_samples, len(train_dataset)) + + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + trainer.save_state() + + # Evaluation + results = {} + if training_args.do_eval: + logger.info("*** Evaluate ***") + + metrics = trainer.evaluate( + max_length=data_args.val_max_target_length, + num_beams=data_args.num_beams, + metric_key_prefix="eval", + ) + max_eval_samples = ( + data_args.max_eval_samples + if data_args.max_eval_samples is not None + else len(eval_dataset) + ) + metrics["eval_samples"] = 
min(max_eval_samples, len(eval_dataset)) + + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + + if training_args.do_predict: + logger.info("*** Predict ***") + + predict_results = trainer.predict( + predict_dataset, + metric_key_prefix="predict", + max_length=data_args.val_max_target_length, + num_beams=data_args.num_beams, + ) + metrics = predict_results.metrics + max_predict_samples = ( + data_args.max_predict_samples + if data_args.max_predict_samples is not None + else len(predict_dataset) + ) + metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset)) + + trainer.log_metrics("predict", metrics) + trainer.save_metrics("predict", metrics) + + if trainer.is_world_process_zero(): + if training_args.predict_with_generate: + predictions = tokenizer.batch_decode( + predict_results.predictions, + skip_special_tokens=True, + clean_up_tokenization_spaces=True, + ) + predictions = [pred.strip() for pred in predictions] + output_prediction_file = os.path.join( + training_args.output_dir, "generated_predictions.txt" + ) + with open(output_prediction_file, "w") as writer: + writer.write("\n".join(predictions)) + + if training_args.push_to_hub: + kwargs = { + "finetuned_from": model_args.model_name_or_path, + "tasks": "summarization", + } + if data_args.dataset_name is not None: + kwargs["dataset_tags"] = data_args.dataset_name + if data_args.dataset_config_name is not None: + kwargs["dataset_args"] = data_args.dataset_config_name + kwargs[ + "dataset" + ] = f"{data_args.dataset_name} {data_args.dataset_config_name}" + else: + kwargs["dataset"] = data_args.dataset_name + + trainer.push_to_hub(**kwargs) + + return results + + +def _mp_fn(index): + # For xla_spawn (TPUs) + main() + + +def encode_to_line(x: str, y: str) -> str: + # Refer to original transformers readme + text = json.dumps(dict(text=x, summary=y)) + "\n" + assert decode_from_line(text) == (x, y) + return text + + +def decode_from_line(text: str) -> Tuple[str, str]: + d = json.loads(text) + return d["text"], d["summary"] + + +if __name__ == "__main__": + main() diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..5bb6098 --- /dev/null +++ b/utils.py @@ -0,0 +1,548 @@ +import hashlib +import json +import os +import shutil +import time +from collections import Counter +from pathlib import Path +from typing import Dict, List, Optional, Set, Tuple, Union + +import numpy as np +import pandas as pd +from fire import Fire +from pydantic import BaseModel +from pydantic.main import Extra +from tqdm import tqdm + +Span = Tuple[int, int] +BasicValue = Union[str, int, bool, float] + + +def train_test_split(*args, **kwargs) -> list: + raise NotImplementedError + + +def find_sublist_index(items: list, query: list): + length = len(query) + for i in range(len(items) - length + 1): + if items[i : i + length] == query: + return i + return -1 + + +def test_find_sublist_query(): + items = [1, 6, 3, 5, 7] + print(dict(items=items)) + for query in [[6], [7], [6, 3], [3, 5, 7], [7, 5]]: + print(dict(query=query, i=find_sublist_index(items, query))) + + +def find_sublist_indices(items: list, query: list) -> List[int]: + i = find_sublist_index(items, query) + if i == -1: + return [] + return list(range(i, i + len(query))) + + +def test_find_sublist_indices(): + items = [1, 6, 3, 5, 7] + assert find_sublist_indices(items, [6, 3, 5]) == [1, 2, 3] + print(dict(test_find_sublist_indices=True)) + + +class WikiProperty(BaseModel): + """ + # https://query.wikidata.org + # All properties with descriptions 
and aliases and types + + SELECT ?p ?pType ?pLabel ?pDescription ?pAltLabel WHERE { + ?p wikibase:propertyType ?pType . + SERVICE wikibase:label { bd:serviceParam wikibase:language "[AUTO_LANGUAGE],en". } + } + ORDER BY ASC(xsd:integer(STRAFTER(STR(?p), 'P'))) + """ + + p: str + pType: str + pLabel: str + pDescription: str + pAltLabel: str + + @property + def id(self) -> str: + return self.p.split("/")[-1] + + @property + def aliases(self) -> List[str]: + names = [n.strip() for n in self.pAltLabel.split(",")] + return sorted(set(names)) + + +def load_wiki_relation_map(path: str) -> Dict[str, WikiProperty]: + df = pd.read_csv(path) + props = [WikiProperty(**r) for r in df.to_dict(orient="records")] + return {p.id: p for p in props} + + +def load_label_to_properties( + path: str, use_alias: bool = True +) -> Dict[str, WikiProperty]: + relation_map = load_wiki_relation_map(path) + mapping = {} + for p in relation_map.values(): + if not p.pLabel in mapping.keys(): + mapping[p.pLabel] = p + if use_alias: + for p in relation_map.values(): + for a in p.aliases: + if a not in mapping.keys(): + mapping[a] = p + return mapping + + +def test_load_wiki(): + relation_map = load_wiki_relation_map("data/wiki_properties.csv") + for k, v in list(relation_map.items())[:3]: + print(dict(k=k, v=v, aliases=v.aliases)) + + +class DynamicModel(BaseModel): + class Config: + arbitrary_types_allowed = True + validate_assignment = True + + +class StrictModel(BaseModel): + class Config: + extra = Extra.forbid + frozen = True + validate_assignment = True + + +def compute_macro_PRF( + predicted_idx: np.ndarray, gold_idx: np.ndarray, i=-1, empty_label=None +) -> Tuple[float, float, float]: + # https://github.com/dinobby/ZS-BERT/blob/master/model/evaluation.py + """ + This evaluation function follows work from Sorokin and Gurevych(https://www.aclweb.org/anthology/D17-1188.pdf) + code borrowed from the following link: + https://github.com/UKPLab/emnlp2017-relation-extraction/blob/master/relation_extraction/evaluation/metrics.py + """ + if i == -1: + i = len(predicted_idx) + + complete_rel_set = set(gold_idx) - {empty_label} + avg_prec = 0.0 + avg_rec = 0.0 + + for r in complete_rel_set: + r_indices = predicted_idx[:i] == r + tp = len((predicted_idx[:i][r_indices] == gold_idx[:i][r_indices]).nonzero()[0]) + tp_fp = len(r_indices.nonzero()[0]) + tp_fn = len((gold_idx == r).nonzero()[0]) + prec = (tp / tp_fp) if tp_fp > 0 else 0 + rec = tp / tp_fn + avg_prec += prec + avg_rec += rec + f1 = 0.0 + avg_prec = avg_prec / len(set(predicted_idx[:i])) + avg_rec = avg_rec / len(complete_rel_set) + if (avg_rec + avg_prec) > 0: + f1 = 2.0 * avg_prec * avg_rec / (avg_prec + avg_rec) + + return avg_prec, avg_rec, f1 + + +def test_compute_prf(): + a = np.array([0, 0, 0, 0, 0]) + b = np.array([0, 0, 1, 1, 0]) + print(compute_macro_PRF(a, b)) + + +def glob_rmtree(folder: str, pattern: str, verbose=True): + for path in Path(folder).glob(pattern): + shutil.rmtree(path) + if verbose: + print(dict(rmtree=path)) + + +def test_glob_rmtree(): + folder = "tmp/test_glob_rmtree" + Path(folder).mkdir(exist_ok=False, parents=True) + glob_rmtree("tmp", "test_glob*") + + +def hash_text(x: str) -> str: + return hashlib.md5(x.encode()).hexdigest() + + +def check_overlap(a: Span, b: Span) -> bool: + # Assumes end in (start, end) is exclusive like python slicing + return ( + a[0] <= b[0] < a[1] + or a[0] <= b[1] - 1 < a[1] + or b[0] <= a[0] < b[1] + or b[0] <= a[1] - 1 < b[1] + ) + + +class RelationSentence(BaseModel): + tokens: List[str] + head: List[int] 
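+    # head and tail (below) hold token indices into `tokens` marking the two entity
+    # spans; label is the relation name. See as_tuple() and from_spans() further down.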
+ tail: List[int] + label: str + head_id: str = "" + tail_id: str = "" + label_id: str = "" + error: str = "" + raw: str = "" + score: float = 0.0 + zerorc_included: bool = True + + def as_tuple(self) -> Tuple[str, str, str]: + head = " ".join([self.tokens[i] for i in self.head]) + tail = " ".join([self.tokens[i] for i in self.tail]) + return head, self.label, tail + + def as_line(self) -> str: + return self.json() + "\n" + + def is_valid(self) -> bool: + for x in [self.tokens, self.head, self.tail, self.label]: + if len(x) == 0: + return False + for x in [self.head, self.tail]: + if -1 in x: + return False + return True + + @property + def text(self) -> str: + return " ".join(self.tokens) + + @classmethod + def from_spans(cls, text: str, head: str, tail: str, label: str, strict=True): + tokens = text.split() + sent = cls( + tokens=tokens, + head=find_span(head, tokens), + tail=find_span(tail, tokens), + label=label, + ) + if strict: + assert sent.is_valid(), (head, label, tail, text) + return sent + + def as_marked_text(self) -> str: + tokens = list(self.tokens) + for i, template in [ + (self.head[0], "[H {}"), + (self.head[-1], "{} ]"), + (self.tail[0], "[T {}"), + (self.tail[-1], "{} ]"), + ]: + tokens[i] = template.format(tokens[i]) + return " ".join(tokens) + + +def align_span_to_tokens(span: str, tokens: List[str]) -> Tuple[int, int]: + # Eg align("John R. Allen, Jr.", ['John', 'R.', 'Allen', ',', 'Jr.']) + char_word_map = {} + num_chars = 0 + for i, w in enumerate(tokens): + for _ in w: + char_word_map[num_chars] = i + num_chars += 1 + char_word_map[num_chars] = len(tokens) + + query = span.replace(" ", "") + text = "".join(tokens) + assert query in text + i = text.find(query) + start = char_word_map[i] + end = char_word_map[i + len(query) - 1] + assert 0 <= start <= end + return start, end + 1 + + +def test_align_span( + span: str = "John R. 
Allen, Jr.", + tokens=("The", "John", "R.", "Allen", ",", "Jr.", "is", "here"), +): + start, end = align_span_to_tokens(span, tokens) + print(dict(start=start, end=end, span=tokens[start:end])) + + +def find_span(span: str, tokens: List[str]) -> List[int]: + if span == "": + return [] + start = find_sublist_index(tokens, span.split()) + if start >= 0: + return [start + i for i in range(len(span.split()))] + else: + start, end = align_span_to_tokens(span, tokens) + return list(range(start, end)) + + +def test_find_span( + span: str = "Hohenzollern", + text: str = "Princess of Hohenzollern-Sigmaringen ( born 26 March 1949", +): + tokens = text.split() + indices = find_span(span, tokens) + print(dict(test_find_span=[tokens[i] for i in indices])) + + +class QualifierSentence(RelationSentence): + qualifier: str = "" + qualifier_id: str + value: List[int] + value_type: str + + def as_tuple(self) -> Tuple[str, str, str, str, str]: + head = " ".join([self.tokens[i] for i in self.head]) + tail = " ".join([self.tokens[i] for i in self.tail]) + value = " ".join([self.tokens[i] for i in self.value]) + return head, self.label, tail, self.qualifier, value + + +class RelationData(BaseModel): + sents: List[RelationSentence] + + @classmethod + def load(cls, path: Path): + with open(path) as f: + lines = f.readlines() + sents = [ + RelationSentence(**json.loads(x)) + for x in tqdm(lines, desc="RelationData.load") + ] + return cls(sents=sents) + + def save(self, path: Path): + path.parent.mkdir(exist_ok=True, parents=True) + with open(path, "w") as f: + f.write("".join([s.as_line() for s in self.sents])) + + @property + def unique_labels(self) -> List[str]: + return sorted(set([s.label for s in self.sents])) + + def train_test_split( + self, test_size: Union[int, float], random_seed: int, by_label: bool = False + ): + if by_label: + labels_train, labels_test = train_test_split( + self.unique_labels, test_size=test_size, random_state=random_seed + ) + train = [s for s in self.sents if s.label in labels_train] + test = [s for s in self.sents if s.label in labels_test] + else: + groups = self.to_sentence_groups() + keys_train, keys_test = train_test_split( + sorted(groups.keys()), test_size=test_size, random_state=random_seed + ) + train = [s for k in keys_train for s in groups[k]] + test = [s for k in keys_test for s in groups[k]] + + # Enforce no sentence overlap + texts_test = set([s.text for s in test]) + train = [s for s in train if s.text not in texts_test] + + data_train = RelationData(sents=train) + data_test = RelationData(sents=test) + if by_label: + assert len(data_test.unique_labels) == test_size + assert not set(data_train.unique_labels).intersection( + data_test.unique_labels + ) + + info = dict( + sents_train=len(data_train.sents), + sents_test=len(data_test.sents), + labels_train=len(data_train.unique_labels), + labels_test=len(data_test.unique_labels), + ) + print(json.dumps(info, indent=2)) + return data_train, data_test + + def to_sentence_groups(self) -> Dict[str, List[RelationSentence]]: + groups = {} + for s in self.sents: + groups.setdefault(s.text, []).append(s) + return groups + + def to_label_groups(self) -> Dict[str, List[RelationSentence]]: + groups = {} + for s in self.sents: + groups.setdefault(s.label, []).append(s) + return groups + + def filter_group_sizes(self, min_size: int = 0, max_size: int = 999): + groups = self.to_sentence_groups() + sents = [ + s + for k, lst in groups.items() + for s in lst + if min_size <= len(lst) <= max_size + ] + return RelationData(sents=sents) + + 
def filter_errors(self): + def check_valid_span(span: List[int]) -> bool: + start = sorted(span)[0] + end = sorted(span)[-1] + 1 + return span == list(range(start, end)) + + sents = [] + for s in self.sents: + if s.is_valid(): + if check_valid_span(s.head) and check_valid_span(s.tail): + sents.append(s) + + print(dict(filter_errors_success=len(sents) / len(self.sents))) + return RelationData(sents=sents) + + def analyze(self, header: Optional[str] = None): + labels = self.unique_labels + groups = self.to_sentence_groups() + spans = [] + words = [] + for s in self.sents: + head, label, tail = s.as_tuple() + spans.append(head) + spans.append(tail) + words.extend(s.tokens) + info = dict( + header=header, + sents=len(self.sents), + labels=str([len(labels), labels]), + unique_texts=len(groups.keys()), + unique_spans=len(set(spans)), + unique_words=len(set(words)), + group_sizes=str(Counter([len(lst) for lst in groups.values()])), + ) + print(json.dumps(info, indent=2)) + return info + + +def wiki_uri_to_id(uri: str) -> str: + i = uri.split("/")[-1] + if i[0] in "QP" and i[1:].isdigit(): + return i + else: + return "" + + +def split_common_prefix(texts: List[str]) -> Tuple[str, List[str]]: + end = 0 + i_max = min(map(len, texts)) + for i in range(i_max): + if len(set([t[i] for t in texts])) > 1: + break + end += 1 + + prefix = texts[0][:end] + texts = [t[end:] for t in texts] + return prefix, texts + + +def delete_checkpoints( + folder: str = ".", pattern="**/checkpoint*", delete: bool = True +): + for p in Path(folder).glob(pattern): + if (p.parent / "config.json").exists(): + print(p) + if delete: + if p.is_dir(): + shutil.rmtree(p) + elif p.is_file(): + os.remove(p) + else: + raise ValueError("Unknown Type") + + +class Timer(BaseModel): + name: str + start: float = 0 + + def __enter__(self): + self.start = time.time() + + def __exit__(self, exc_type, exc_val, exc_tb): + duration = round(time.time() - self.start, 3) + print(dict(name=self.name, duration=duration)) + + +def test_timer(interval: int = 2): + with Timer(name="test_timer"): + time.sleep(interval) + + +def sorted_glob(folder: str, pattern: str) -> List[Path]: + # Best practice to be deterministic and avoid weird behavior + return sorted(Path(folder).glob(pattern)) + + +def test_sorted_glob(): + for path in sorted_glob("outputs/data/zsl/wiki", "*/test.jsonl"): + print(path) + + +def mark_wiki_entity(edge): + e1 = edge["left"] + e2 = edge["right"] + return e1, e2 + + +def mark_fewrel_entity(edge): + e1 = edge["h"][2][0] + e2 = edge["t"][2][0] + return e1, e2 + + +class WikiDataset: + def __init__(self, mode, data, pid2vec, property2idx): + assert mode in ["train", "dev", "test"] + self.mode = mode + self.data = data + self.pid2vec = pid2vec + self.property2idx = property2idx + self.len = len(self.data) + + def load_edges( + self, i: int, label_ids: Optional[Set[str]] = None + ) -> List[RelationSentence]: + g = self.data[i] + tokens = g["tokens"] + sents = [] + for j in range(len(g["edgeSet"])): + property_id = g["edgeSet"][j]["kbID"] + edge = g["edgeSet"][j] + head, tail = mark_wiki_entity(edge) + if label_ids and property_id not in label_ids: + continue + s = RelationSentence( + tokens=tokens, head=head, tail=tail, label="", label_id=property_id + ) + sents.append(s) + return sents + + def __getitem__(self, item: int) -> RelationSentence: + # The ZS-BERT setting is throw away all except first edge + return self.load_edges(item)[0] + + def __len__(self): + return self.len + + +if __name__ == "__main__": + """ + python new_utils.py 
analyze_relation_data --path data/relations/trex/100000.jsonl + """ + test_find_sublist_query() + test_load_wiki() + test_compute_prf() + test_glob_rmtree() + test_find_sublist_indices() + Fire() diff --git a/wrapper.py b/wrapper.py new file mode 100644 index 0000000..5eecdc5 --- /dev/null +++ b/wrapper.py @@ -0,0 +1,500 @@ +import json +import random +from collections import Counter +from pathlib import Path +from typing import List + +import torch +from fire import Fire +from pydantic.main import BaseModel +from tqdm import tqdm + +from generation import LabelConstraint, TripletSearchDecoder +from modeling import (NewRelationExtractor, RelationGenerator, RelationModel, + select_model) +from utils import (RelationSentence, WikiDataset, delete_checkpoints, + load_wiki_relation_map, mark_fewrel_entity) + + +def safe_divide(a: float, b: float) -> float: + if a == 0 or b == 0: + return 0 + return a / b + + +class Sentence(BaseModel): + triplets: List[RelationSentence] + + @property + def tokens(self) -> List[str]: + return self.triplets[0].tokens + + @property + def text(self) -> str: + return " ".join(self.tokens) + + def assert_valid(self): + assert len(self.tokens) > 0 + for t in self.triplets: + assert t.text == self.text + assert len(t.head) > 0 + assert len(t.tail) > 0 + assert len(t.label) > 0 + + +class Dataset(BaseModel): + sents: List[Sentence] + + def get_labels(self) -> List[str]: + return sorted(set(t.label for s in self.sents for t in s.triplets)) + + @classmethod + def load(cls, path: str): + with open(path) as f: + sents = [Sentence(**json.loads(line)) for line in f] + return cls(sents=sents) + + def save(self, path: str): + Path(path).parent.mkdir(exist_ok=True, parents=True) + with open(path, "w") as f: + for s in self.sents: + f.write(s.json() + "\n") + + @classmethod + def load_fewrel(cls, path: str, path_properties: str = "data/wiki_properties.csv"): + relation_map = load_wiki_relation_map(path_properties) + groups = {} + + with open(path) as f: + for i, lst in tqdm(json.load(f).items()): + for raw in lst: + head, tail = mark_fewrel_entity(raw) + t = RelationSentence( + tokens=raw["tokens"], + head=head, + tail=tail, + label=relation_map[i].pLabel, + label_id=i, + ) + groups.setdefault(t.text, []).append(t) + + sents = [Sentence(triplets=lst) for lst in groups.values()] + return cls(sents=sents) + + @classmethod + def load_wiki(cls, path: str, path_properties: str = "data/wiki_properties.csv"): + relation_map = load_wiki_relation_map(path_properties) + sents = [] + with open(path) as f: + ds = WikiDataset( + mode="train", data=json.load(f), pid2vec=None, property2idx=None + ) + for i in tqdm(range(len(ds))): + triplets = ds.load_edges(i) + triplets = [t for t in triplets if t.label_id in relation_map.keys()] + for t in triplets: + t.label = relation_map[t.label_id].pLabel + if triplets: + # ZSBERT only includes first triplet in each sentence + for t in triplets: + t.zerorc_included = False + triplets[0].zerorc_included = True + + s = Sentence(triplets=triplets) + sents.append(s) + + data = cls(sents=sents) + counter = Counter(t.label for s in data.sents for t in s.triplets) + threshold = sorted(counter.values())[-113] # Based on ZSBERT data stats + labels = [k for k, v in counter.items() if v >= threshold] + data = data.filter_labels(labels) + return data + + def filter_labels(self, labels: List[str]): + label_set = set(labels) + sents = [] + for s in self.sents: + triplets = [t for t in s.triplets if t.label in label_set] + if triplets: + s = s.copy(deep=True) + 
s.triplets = triplets + sents.append(s) + return Dataset(sents=sents) + + def train_test_split(self, test_size: int, random_seed: int, by_label: bool): + random.seed(random_seed) + + if by_label: + labels = self.get_labels() + labels_test = random.sample(labels, k=test_size) + labels_train = sorted(set(labels) - set(labels_test)) + sents_train = self.filter_labels(labels_train).sents + sents_test = self.filter_labels(labels_test).sents + else: + sents_train = [s for s in self.sents] + sents_test = random.sample(self.sents, k=test_size) + + banned = set(s.text for s in sents_test) # Prevent sentence overlap + sents_train = [s for s in sents_train if s.text not in banned] + assert len(self.sents) == len(sents_train) + len(sents_test) + return Dataset(sents=sents_train), Dataset(sents=sents_test) + + def analyze(self): + info = dict( + sents=len(self.sents), + unique_texts=len(set(s.triplets[0].text for s in self.sents)), + lengths=str(Counter(len(s.triplets) for s in self.sents)), + labels=len(self.get_labels()), + ) + print(json.dumps(info, indent=2)) + + +def write_data_splits( + path_in: str, + mode: str, + folder_out: str = "outputs/data/splits/zero_rte", + num_dev_labels: int = 5, + num_test_labels: List[int] = [5, 10, 15], + seeds: List[int] = [0, 1, 2, 3, 4], +): + for n in num_test_labels: + for s in seeds: + if mode == "fewrel": + data = Dataset.load_fewrel(path_in) + elif mode == "wiki": + data = Dataset.load_wiki(path_in) + else: + raise ValueError() + + train, test = data.train_test_split( + test_size=n, random_seed=s, by_label=True + ) + train, dev = train.train_test_split( + test_size=num_dev_labels, random_seed=s, by_label=True + ) + del data + + for key, data in dict(train=train, dev=dev, test=test).items(): + name = f"unseen_{n}_seed_{s}" + path = Path(folder_out) / Path(path_in).stem / name / f"{key}.jsonl" + data.save(str(path)) + print(dict(key=key, labels=len(data.get_labels()), path=path)) + + +class Generator(BaseModel): + load_dir: str + save_dir: str + num_gen_per_label: int = 250 + model_name: str = "generate" + encoder_name: str = "generate" + model_kwargs: dict = {} + + def get_model(self) -> RelationModel: + model = select_model( + name=self.model_name, + encoder_name=self.encoder_name, + model_dir=str(Path(self.save_dir) / "model"), + model_name=self.load_dir, + data_dir=str(Path(self.save_dir) / "data"), + do_pretrain=False, + **self.model_kwargs, + ) + return model + + def write_data(self, data: Dataset, name: str) -> str: + model = self.get_model() + path_out = Path(model.data_dir) / f"{name}.txt" + path_out.parent.mkdir(exist_ok=True, parents=True) + encoder = model.get_encoder() + lines = [encoder.encode_to_line(t) for s in data.sents for t in s.triplets] + random.seed(model.random_seed) + random.shuffle(lines) + with open(path_out, "w") as f: + f.write("".join(lines)) + return str(path_out) + + def fit(self, path_train: str, path_dev: str): + model = self.get_model() + if Path(model.model_dir).exists(): + return + + data_train = Dataset.load(path_train) + data_dev = Dataset.load(path_dev) + path_train = self.write_data(data_train, "train") + path_dev = self.write_data(data_dev, "dev") + model.fit(path_train=path_train, path_dev=path_dev) + delete_checkpoints(model.model_dir) + + def generate(self, labels: List[str], path_out: str): + if Path(path_out).exists(): + return + + model = self.get_model() + pipe = model.make_pipe() + groups = {} + assert isinstance(model, RelationGenerator) + for relation in tqdm(labels): + triplets, raw = 
model.generate(relation, self.num_gen_per_label, pipe=pipe) + for t in triplets: + groups.setdefault(t.text, []).append(t) + + sents = [Sentence(triplets=lst) for lst in groups.values()] + data = Dataset(sents=sents) + data.save(path_out) + + +class Extractor(BaseModel): + load_dir: str + save_dir: str + model_name: str = "new_extract" + encoder_name: str = "extract" + search_threshold: float = -0.9906 + model_kwargs: dict = {} + + def get_model(self) -> RelationModel: + model = select_model( + name=self.model_name, + encoder_name=self.encoder_name, + model_dir=str(Path(self.save_dir) / "model"), + model_name=self.load_dir, + data_dir=str(Path(self.save_dir) / "data"), + do_pretrain=False, + **self.model_kwargs, + ) + return model + + def write_data(self, data: Dataset, name: str) -> str: + model = self.get_model() + path_out = Path(model.data_dir) / f"{name}.json" + path_out.parent.mkdir(exist_ok=True, parents=True) + encoder = model.get_encoder() + lines = [encoder.encode_to_line(t) for s in data.sents for t in s.triplets] + random.seed(model.random_seed) + random.shuffle(lines) + with open(path_out, "w") as f: + f.write("".join(lines)) + return str(path_out) + + def fit(self, path_train: str, path_dev: str): + model = self.get_model() + if Path(model.model_dir).exists(): + return + + data_train = Dataset.load(path_train) + data_train = Dataset.load(path_train) + data_dev = Dataset.load(path_dev) + path_train = self.write_data(data_train, "train") + path_dev = self.write_data(data_dev, "dev") + model.fit(path_train=path_train, path_dev=path_dev) + delete_checkpoints(model.model_dir) + + def predict(self, path_in: str, path_out: str, use_label_constraint: bool = True): + data = Dataset.load(path_in) + texts = [s.text for s in data.sents] + model = self.get_model() + assert isinstance(model, NewRelationExtractor) + gen = model.load_generator(torch.device("cuda")) + encoder = model.get_encoder() + constraint = LabelConstraint(labels=data.get_labels(), tokenizer=gen.tokenizer) + sents = [] + + for i in tqdm(range(0, len(texts), model.batch_size)): + batch = texts[i : i + model.batch_size] + x = [encoder.encode_x(t) for t in batch] + outputs = model.gen_texts( + x, gen, num_beams=1, save_scores=use_label_constraint + ) + assert len(outputs) == len(x) + + for i, raw in enumerate(outputs): + triplet = encoder.safe_decode(x[i], y=raw) + if use_label_constraint: + assert gen.scores is not None + triplet = constraint.run(triplet, gen.scores[i]) + sents.append(Sentence(triplets=[triplet])) + + Dataset(sents=sents).save(path_out) + + def predict_multi(self, path_in: str, path_out: str): + stem = Path(path_out).stem + path_raw = path_out.replace(stem, f"{stem}_raw") + print(dict(predict_multi=locals())) + data = Dataset.load(path_in) + model = self.get_model() + assert isinstance(model, NewRelationExtractor) + gen = model.load_generator(torch.device("cuda")) + constraint = LabelConstraint(labels=data.get_labels(), tokenizer=gen.tokenizer) + searcher = TripletSearchDecoder( + gen=gen, encoder=model.get_encoder(), constraint=constraint + ) + + sents = [ + Sentence(tokens=s.tokens, triplets=searcher.run(s.text)) + for s in tqdm(data.sents) + ] + Dataset(sents=sents).save(path_raw) + for s in sents: + s.triplets = [t for t in s.triplets if t.score > self.search_threshold] + Dataset(sents=sents).save(path_out) + + @staticmethod + def score(path_pred: str, path_gold: str) -> dict: + pred = Dataset.load(path_pred) + gold = Dataset.load(path_gold) + assert len(pred.sents) == len(gold.sents) + num_pred = 0 + 
num_gold = 0 + num_correct = 0 + + for i in range(len(gold.sents)): + num_pred += len(pred.sents[i].triplets) + num_gold += len(gold.sents[i].triplets) + for p in pred.sents[i].triplets: + for g in gold.sents[i].triplets: + if (p.head, p.tail, p.label) == (g.head, g.tail, g.label): + num_correct += 1 + + precision = safe_divide(num_correct, num_pred) + recall = safe_divide(num_correct, num_gold) + + info = dict( + path_pred=path_pred, + path_gold=path_gold, + precision=precision, + recall=recall, + score=safe_divide(2 * precision * recall, precision + recall), + ) + return info + + +def main( + path_train: str, + path_dev: str, + path_test: str, + save_dir: str, +): + print(dict(main=locals())) + generator = Generator( + load_dir="gpt2", + save_dir=str(Path(save_dir) / "generator"), + ) + extractor = Extractor( + load_dir="facebook/bart-base", + save_dir=str(Path(save_dir) / "extractor"), + ) + + generator.fit(path_train, path_dev) + extractor.fit(path_train, path_dev) + path_synthetic = str(Path(save_dir) / "synthetic.jsonl") + labels_dev = Dataset.load(path_dev).get_labels() + labels_test = Dataset.load(path_test).get_labels() + generator.generate(labels_dev + labels_test, path_out=path_synthetic) + + extractor_final = Extractor( + load_dir=str(Path(save_dir) / "extractor" / "model"), + save_dir=str(Path(save_dir) / "extractor_final"), + ) + extractor_final.fit(path_synthetic, path_dev) + + path_pred = str(Path(save_dir) / "pred.jsonl") + extractor_final.predict(path_in=path_test, path_out=path_pred) + results = extractor_final.score(path_pred, path_test) + print(json.dumps(results, indent=2)) + with open(Path(save_dir) / "results.json", "w") as f: + json.dump(results, f, indent=2) + return results + + +def main_many(data_dir_pattern: str, save_dir: str, **kwargs): + mode = Path(save_dir).name + assert mode in ["fewrel", "wiki"] + records = [] + + for path in tqdm(sorted(Path().glob(data_dir_pattern))): + path_train = path / "train.jsonl" + path_dev = path / "dev.jsonl" + path_test = path / "test.jsonl" + results = main( + path_train=str(path_train), + path_dev=str(path_dev), + path_test=str(path_test), + save_dir=str(Path(save_dir) / path.name), + **kwargs, + ) + records.append(results) + + avg_p = sum([r["precision"] for r in records]) / len(records) + avg_r = sum([r["recall"] for r in records]) / len(records) + avg_f = safe_divide(2 * avg_p * avg_r, avg_p + avg_r) + info = dict(avg_p=avg_p, avg_r=avg_r, avg_f=avg_f) + print(json.dumps(info, indent=2)) + + +def run_eval(path_model: str, path_test: str, mode: str, limit: int = 0): + print(dict(run_eval=locals())) + data = Dataset.load(path_test) + model = Extractor(load_dir=str(Path(path_model) / "model"), save_dir=path_model) + + if mode == "single": + data.sents = [s for s in data.sents if len(s.triplets) == 1] + elif mode == "multi": + data.sents = [s for s in data.sents if len(s.triplets) > 1] + else: + raise ValueError(f"mode must be single or multi") + + if limit > 0: + random.seed(0) + random.shuffle(data.sents) + data.sents = data.sents[:limit] + + path_in = str(Path(path_model) / f"pred_in_{mode}.jsonl") + path_out = str(Path(path_model) / f"pred_out_{mode}.jsonl") + data.save(path_in) + + if mode == "single": + model.predict(path_in, path_out) + else: + model.predict_multi(path_in, path_out) + + results = model.score(path_pred=path_out, path_gold=path_in) + path_results = str(Path(path_model) / f"results_{mode}.json") + results.update(mode=mode, limit=limit, path_results=path_results) + print(json.dumps(results, indent=2)) + 
with open(path_results, "w") as f: + json.dump(results, f, indent=2) + + +def run_eval_many(path_model_pattern: str, data_dir: str, **kwargs): + for path in tqdm(sorted(Path().glob(path_model_pattern))): + name = path.parts[-2] + path_test = Path(data_dir) / name / "test.jsonl" + assert path_test.exists() + run_eval(path_model=str(path), path_test=str(path_test), **kwargs) + + +""" +p wrapper.py write_data_splits data/data_zsbert/wiki.json --mode wiki --seeds [0,1,11,14,16] --folder_out temp/outputs/data/splits/zero_rte +p wrapper.py write_data_splits data/data_zsbert/fewrel.json --mode fewrel --seeds [9,1,0,8,4] --folder_out temp/outputs/data/splits/zero_rte + +p wrapper.py main \ +--path_train outputs/data/splits/zero_rte/fewrel/unseen_10_seed_0/train.jsonl \ +--path_dev outputs/data/splits/zero_rte/fewrel/unseen_10_seed_0/dev.jsonl \ +--path_test outputs/data/splits/zero_rte/fewrel/unseen_10_seed_0/test.jsonl \ +--save_dir outputs/wrapper/fewrel/unseen_10_seed_0 + +p wrapper.py main \ +--path_train outputs/data/splits/zero_rte/wiki/unseen_10_seed_0/train.jsonl \ +--path_dev outputs/data/splits/zero_rte/wiki/unseen_10_seed_0/dev.jsonl \ +--path_test outputs/data/splits/zero_rte/wiki/unseen_10_seed_0/test.jsonl \ +--save_dir outputs/wrapper/wiki/unseen_10_seed_0 + +p wrapper.py run_eval \ +--path_model outputs/wrapper/wiki/unseen_10_seed_0/extractor_final \ +--path_test outputs/data/splits/zero_rte/wiki/unseen_10_seed_0/test.jsonl \ +--mode single + +""" + + +if __name__ == "__main__": + Fire()
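+
+# Illustrative note on the data format expected by Dataset.load (the values below are
+# made-up examples, not shipped data): each line of a .jsonl file is one Sentence holding
+# a list of triplets whose head/tail fields are token-index spans, e.g.
+#
+#   {"triplets": [{"tokens": ["Alan", "Turing", "was", "born", "in", "London", "."],
+#                  "head": [0, 1], "tail": [5], "label": "place of birth"}]}
+#
+# A minimal programmatic sketch mirroring the CLI usage above (paths are placeholders):
+#
+#   data = Dataset.load("outputs/data/splits/zero_rte/wiki/unseen_10_seed_0/test.jsonl")
+#   print(len(data.sents), data.get_labels()[:5])
+#   print(Extractor.score(path_pred="path/to/pred.jsonl", path_gold="path/to/test.jsonl"))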