Commit: Initial code release
chiayewken committed Apr 14, 2022
1 parent 252bf89 commit ff68003
Showing 9 changed files with 3,025 additions and 0 deletions.
203 changes: 203 additions & 0 deletions encoding.py
from pathlib import Path
from typing import Dict, List, Tuple

from fire import Fire
from pydantic import BaseModel
from tqdm import tqdm
from transformers import AutoTokenizer

from transformer_base import run_summarization
from utils import RelationData, RelationSentence


class Encoder(BaseModel):
def encode_x(self, x: str) -> str:
raise NotImplementedError

def encode(self, sent: RelationSentence) -> Tuple[str, str]:
raise NotImplementedError

def decode(self, x: str, y: str) -> RelationSentence:
raise NotImplementedError

def decode_x(self, x: str) -> str:
raise NotImplementedError

def safe_decode(self, x: str, y: str) -> RelationSentence:
text = self.decode_x(x)
try:
s = self.decode(x=x, y=y)
except Exception as e:
s = RelationSentence(
tokens=text.split(), head=[], tail=[], label="", error=str(e), raw=y
)
return s

def encode_to_line(self, sent: RelationSentence) -> str:
raise NotImplementedError

def decode_from_line(self, line: str) -> RelationSentence:
raise NotImplementedError

def parse_line(self, line: str) -> Tuple[str, str]:
raise NotImplementedError

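# Round-trip contract (an illustrative note, not from the original commit):
# for any concrete Encoder e and RelationSentence s, decoding an encoded
# pair should recover the same triplet:
#     x, y = e.encode(s)
#     assert e.decode(x, y).as_tuple() == s.as_tuple()
# test_encoders() below checks this property over real data via
# encode_to_line / parse_line / safe_decode.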

class GenerateEncoder(Encoder):
def encode_x(self, r: str) -> str:
return f"Relation : {r} ."

def decode_x(self, text: str) -> str:
return text.split("Relation : ")[-1][:-2]

def encode_triplet(self, sent: RelationSentence) -> str:
s, r, o = sent.as_tuple()
return f"Context : {sent.text} Head Entity : {s} , Tail Entity : {o} ."

def decode_triplet(self, text: str, label: str) -> RelationSentence:
front, back = text.split(" Head Entity : ")
_, context = front.split("Context : ")
head, back = back.split(" , Tail Entity : ")
tail = back[:-2]
return RelationSentence.from_spans(context, head, tail, label)

def encode_y(self, sent: RelationSentence) -> str:
return self.encode_x(sent.label) + " " + self.encode_triplet(sent)

def decode_y(self, text: str, label: str) -> RelationSentence:
del label
front, back = text.split(" . Context : ")
label = self.decode_x(front + " .")
return self.decode_triplet("Context : " + back, label)

def decode(self, x: str, y: str) -> RelationSentence:
r = self.decode_x(x)
sent = self.decode_y(y, r)
return sent

def encode(self, sent: RelationSentence) -> Tuple[str, str]:
x = self.encode_x(sent.label)
y = self.encode_y(sent)
return x, y

def decode_from_line(self, line: str) -> RelationSentence:
x, y = self.parse_line(line)
return self.decode(x, y)

    def encode_to_line(self, sent: RelationSentence) -> str:
        # Only the target sequence is written: the relation prompt is already
        # repeated inside y, so decode_from_line can recover the label from y.
        _, y = self.encode(sent)
        return y + "\n"

def parse_line(self, line: str) -> Tuple[str, str]:
return "", line.strip()

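# Format sketch for the generator task (entity names are made-up examples):
#   x: "Relation : place of birth ."
#   y: "Relation : place of birth . Context : Alan Turing was born in London .
#       Head Entity : Alan Turing , Tail Entity : London ."
# The relation prompt is repeated inside y, so parse_line() can return an
# empty x and decode() still recovers the label from y alone.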

class ExtractEncoder(Encoder):
def encode_x(self, text: str) -> str:
return f"Context : {text}"

def decode_x(self, x: str) -> str:
return x.split("Context : ")[-1]

def encode_y(self, sent: RelationSentence) -> str:
s, r, o = sent.as_tuple()
return f"Head Entity : {s} , Tail Entity : {o} , Relation : {r} ."

def decode_y(self, x: str, y: str) -> RelationSentence:
context = self.decode_x(x)
front, label = y.split(" , Relation : ")
label = label[:-2]
front, tail = front.split(" , Tail Entity : ")
_, head = front.split("Head Entity : ")
return RelationSentence.from_spans(context, head, tail, label)

def encode_entity_prompt(self, head: str, tail: str) -> str:
return f"Head Entity : {head} , Tail Entity : {tail} , Relation :"

def encode(self, sent: RelationSentence) -> Tuple[str, str]:
x = self.encode_x(sent.text)
y = self.encode_y(sent)
return x, y

def decode(self, x: str, y: str) -> RelationSentence:
return self.decode_y(x, y)

def encode_to_line(self, sent: RelationSentence) -> str:
x, y = self.encode(sent)
return run_summarization.encode_to_line(x, y)

def decode_from_line(self, line: str) -> RelationSentence:
x, y = self.parse_line(line)
return self.decode(x, y)

def parse_line(self, line: str) -> Tuple[str, str]:
return run_summarization.decode_from_line(line)

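# Format sketch for the extractor task (same made-up example):
#   x: "Context : Alan Turing was born in London ."
#   y: "Head Entity : Alan Turing , Tail Entity : London , Relation : place of birth ."
# encode_entity_prompt() yields the prefix of y up to "Relation :", which
# test_entity_prompts() verifies token-by-token below.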

def test_encoders(
paths: List[str] = [
"outputs/data/zsl/wiki/unseen_5_seed_0/train.jsonl",
"outputs/data/zsl/fewrel/unseen_5_seed_0/train.jsonl",
],
print_limit: int = 4,
encoder_names: List[str] = ["generate", "extract"],
limit: int = 1000,
):
encoders = {k: select_encoder(k) for k in encoder_names}

for p in paths:
data = RelationData.load(Path(p))
_, data = data.train_test_split(min(limit, len(data.sents)), random_seed=0)

for name, e in tqdm(list(encoders.items())):
num_fail = 0
print(dict(name=name, p=p))
for s in data.sents:
encoded = e.encode_to_line(s)
x, y = e.parse_line(encoded)
decoded: RelationSentence = e.safe_decode(x, y)

if decoded.as_tuple() != s.as_tuple():
if num_fail < print_limit:
print(dict(gold=s.as_tuple(), text=s.text))
print(dict(pred=decoded.as_tuple(), text=decoded.text))
print(dict(x=x, y=y, e=decoded.error))
print()
num_fail += 1

print(dict(success_rate=1 - (num_fail / len(data.sents))))
print("#" * 80)


def select_encoder(name: str) -> Encoder:
mapping: Dict[str, Encoder] = dict(
extract=ExtractEncoder(),
generate=GenerateEncoder(),
)
encoder = mapping[name]
return encoder


def test_entity_prompts(
path: str = "outputs/data/zsl/wiki/unseen_10_seed_0/test.jsonl", limit: int = 100
):
def tokenize(text: str, tok) -> List[str]:
return tok.convert_ids_to_tokens(tok(text, add_special_tokens=False).input_ids)

data = RelationData.load(Path(path))
e = ExtractEncoder()
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
print(tokenizer)
for i, s in enumerate(tqdm(data.sents[:limit])):
head, label, tail = s.as_tuple()
x, y = e.encode(s)
prompt = e.encode_entity_prompt(head, tail)
tokens_y = tokenize(y, tokenizer)
tokens_prompt = tokenize(prompt, tokenizer)
assert tokens_y[: len(tokens_prompt)] == tokens_prompt
if i < 3:
print(tokens_y)


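# Usage sketch (fire exposes the module-level functions; the jsonl paths in
# the defaults are assumed to exist locally):
#   python encoding.py test_encoders --limit 100
#   python encoding.py test_entity_prompts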
if __name__ == "__main__":
Fire()
193 changes: 193 additions & 0 deletions generation.py
from typing import Dict, List, Optional, Tuple

import torch
from fire import Fire
from torch import Tensor
from transformers import PreTrainedModel, PreTrainedTokenizerFast

from encoding import ExtractEncoder
from utils import DynamicModel, RelationSentence, find_sublist_index


class TextGenerator(DynamicModel):
model: PreTrainedModel
tokenizer: PreTrainedTokenizerFast
scores: Optional[List[Tensor]] = None
max_length: int

def tokenize(self, texts: List[str], **kwargs):
return self.tokenizer(
texts,
padding=True,
truncation=True,
max_length=self.max_length,
return_tensors="pt",
**kwargs,
).to(self.model.device)

def run(
self,
texts: List[str],
do_sample=True,
top_k=50,
temperature=1.0,
num_return: int = 4,
prompt: Optional[str] = None,
prompt_ids: Optional[List[int]] = None,
multi_prompt_ids: Optional[List[List[int]]] = None,
decoder_input_ids: Optional[Tensor] = None,
save_scores: bool = False,
**kwargs,
) -> List[str]:
# https://huggingface.co/transformers/v4.7.0/main_classes/model.html#generation
tok = self.tokenizer
eos, bos = tok.eos_token_id, tok.bos_token_id

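        # The prompt options below cascade, so the most specific one wins:
        # prompt -> prompt_ids -> multi_prompt_ids. [eos, bos] is prepended
        # because BART-style decoders start generation from "</s><s>".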
if prompt is not None:
prompt_ids = self.tokenizer(prompt, add_special_tokens=False).input_ids
if prompt_ids is not None:
prompt_ids = [eos, bos] + prompt_ids
decoder_input_ids = torch.tensor([prompt_ids])
if multi_prompt_ids is not None:
assert len(texts) == len(multi_prompt_ids)
multi_prompt_ids = [[eos, bos] + lst for lst in multi_prompt_ids]
decoder_input_ids = torch.tensor(multi_prompt_ids)
if decoder_input_ids is not None:
kwargs.update(decoder_input_ids=decoder_input_ids.to(self.model.device))

outputs = self.model.generate(
**self.tokenize(texts),
do_sample=do_sample,
top_k=top_k,
temperature=temperature,
num_return_sequences=num_return,
return_dict_in_generate=True,
output_scores=save_scores,
max_length=self.max_length,
**kwargs,
)

self.scores = None
if save_scores:
            # One (steps, vocab_size) score tensor per generated sequence.
            self.scores = list(torch.stack(outputs.scores, 1).cpu())
return self.decode(outputs.sequences)

def decode(self, outputs) -> List[str]:
tok = self.tokenizer
texts = tok.batch_decode(
outputs, skip_special_tokens=False, clean_up_tokenization_spaces=False
)

# Manually remove <bos><eos><pad> in case we have custom special tokens
special_tokens = [tok.eos_token, tok.bos_token, tok.pad_token]
for i, t in enumerate(texts):
for token in special_tokens:
t = t.replace(token, "")
texts[i] = t
return texts

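# Usage sketch (the checkpoint name is illustrative, not pinned by this file):
#   from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
#   model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-base")
#   tok = AutoTokenizer.from_pretrained("facebook/bart-base")
#   gen = TextGenerator(model=model, tokenizer=tok, max_length=128)
#   texts = gen.run(["Context : Alan Turing was born in London ."], num_return=2)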

class LabelConstraint:
def __init__(
self,
labels: List[str],
tokenizer: PreTrainedTokenizerFast,
prefix: str = " Relation :",
):
self.prefix: List[int] = tokenizer(prefix, add_special_tokens=False).input_ids
self.label_map: Dict[int, str] = {
tokenizer(" " + x, add_special_tokens=False).input_ids[0]: x for x in labels
}
self.tokenizer = tokenizer

def run(self, triplet: RelationSentence, scores: Tensor) -> RelationSentence:
triplet = triplet.copy(deep=True)
assert scores.ndim == 2
token_ids = scores.argmax(dim=-1).int().tolist()
i = find_sublist_index(token_ids, self.prefix)
if i == -1:
return triplet

position = i + len(self.prefix)
best = ""
best_score = -1e9
for j, label in self.label_map.items():
score = scores[position, j].item()
if score > best_score:
best = label
best_score = score

if triplet.label in self.label_map.values():
assert best == triplet.label

assert len(best) > 0
triplet.label = best
triplet.score = best_score
return triplet

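# Note (editorial summary): the raw decode may hallucinate a label outside the
# known set, so run() locates the " Relation :" prefix in the greedy token ids
# and rescores only the first sub-token of each candidate label at the next
# position. This assumes the labels are distinguishable by their first
# sub-token under the given tokenizer.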

class TripletSearchDecoder(DynamicModel):
gen: TextGenerator
constraint: LabelConstraint
encoder: ExtractEncoder
top_k: int = 4

def generate(self, text: str, **kwargs) -> Tuple[str, Tensor]:
outputs = self.gen.run(
[text],
do_sample=False,
num_return=1,
num_beams=1,
save_scores=True,
**kwargs,
)

assert len(outputs) == 1
assert self.gen.scores is not None
scores = torch.log_softmax(self.gen.scores[0], dim=-1)
assert scores.ndim == 2
return outputs[0], scores

    def find_prefix_end(self, token_ids: List[int], prefix: str) -> int:
prefix_ids = self.gen.tokenizer(prefix, add_special_tokens=False).input_ids
i = find_sublist_index(token_ids, prefix_ids)
position = i + len(prefix_ids)
return position

def branch(
self, text: str, prefix: str, prompt: Optional[str] = None, **kwargs
) -> List[Tuple[str, float]]:
_, scores = self.generate(text, prompt=prompt, **kwargs)
token_ids = scores.argmax(dim=-1).int().tolist()
i = self.find_prefix_end(token_ids, prefix)

pairs = []
for j in torch.argsort(scores[i])[-self.top_k :]:
p = (prompt or "") + self.gen.decode([token_ids[:i] + [j]])[0]
pairs.append((p, scores[i, j].item()))

return pairs

def run(self, text: str) -> List[RelationSentence]:
x = self.encoder.encode_x(text)
outputs = []

for prompt_a, score_a in self.branch(x, prefix="Head Entity :"):
for prompt_b, score_b in self.branch(
x, prefix=" Tail Entity :", prompt=prompt_a
):
output, scores = self.generate(x, prompt=prompt_b)
                token_ids = scores.argmax(dim=-1).int().tolist()
i = self.find_prefix_end(token_ids, prefix=" Relation :")
score_c = max(scores[i].tolist())
s = self.encoder.safe_decode(x=x, y=output)
s = self.constraint.run(s, scores)
# score_c = s.score # From LabelConstraint
s.score = (score_a + score_b + score_c) / 3
outputs.append(s)

return outputs

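# Usage sketch (wiring the pieces together; the label list is illustrative):
#   constraint = LabelConstraint(labels=["place of birth"], tokenizer=tok)
#   searcher = TripletSearchDecoder(
#       gen=gen, constraint=constraint, encoder=ExtractEncoder()
#   )
#   triplets = searcher.run("Alan Turing was born in London .")
# run() branches over the top_k head and tail candidates, then averages the
# three span scores into a single triplet score.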

if __name__ == "__main__":
Fire()