forked from declare-lab/RelationPrompt
Commit ff68003 · 1 parent 252bf89
Showing 9 changed files with 3,025 additions and 0 deletions.
encoding.py
@@ -0,0 +1,203 @@
from pathlib import Path
from typing import Dict, List, Tuple

from fire import Fire
from pydantic import BaseModel
from tqdm import tqdm
from transformers import AutoTokenizer

from transformer_base import run_summarization
from utils import RelationData, RelationSentence


class Encoder(BaseModel):
    def encode_x(self, x: str) -> str:
        raise NotImplementedError

    def encode(self, sent: RelationSentence) -> Tuple[str, str]:
        raise NotImplementedError

    def decode(self, x: str, y: str) -> RelationSentence:
        raise NotImplementedError

    def decode_x(self, x: str) -> str:
        raise NotImplementedError

    def safe_decode(self, x: str, y: str) -> RelationSentence:
        # On malformed model output, return an empty triplet that records
        # the error and raw text instead of raising.
        text = self.decode_x(x)
        try:
            s = self.decode(x=x, y=y)
        except Exception as e:
            s = RelationSentence(
                tokens=text.split(), head=[], tail=[], label="", error=str(e), raw=y
            )
        return s

    def encode_to_line(self, sent: RelationSentence) -> str:
        raise NotImplementedError

    def decode_from_line(self, line: str) -> RelationSentence:
        raise NotImplementedError

    def parse_line(self, line: str) -> Tuple[str, str]:
        raise NotImplementedError


class GenerateEncoder(Encoder):
    def encode_x(self, r: str) -> str:
        return f"Relation : {r} ."

    def decode_x(self, text: str) -> str:
        return text.split("Relation : ")[-1][:-2]

    def encode_triplet(self, sent: RelationSentence) -> str:
        s, _, o = sent.as_tuple()
        return f"Context : {sent.text} Head Entity : {s} , Tail Entity : {o} ."

    def decode_triplet(self, text: str, label: str) -> RelationSentence:
        front, back = text.split(" Head Entity : ")
        _, context = front.split("Context : ")
        head, back = back.split(" , Tail Entity : ")
        tail = back[:-2]
        return RelationSentence.from_spans(context, head, tail, label)

    def encode_y(self, sent: RelationSentence) -> str:
        return self.encode_x(sent.label) + " " + self.encode_triplet(sent)

    def decode_y(self, text: str, label: str) -> RelationSentence:
        del label  # Unused: the label is re-parsed from the generated text.
        front, back = text.split(" . Context : ")
        label = self.decode_x(front + " .")
        return self.decode_triplet("Context : " + back, label)

    def decode(self, x: str, y: str) -> RelationSentence:
        r = self.decode_x(x)
        sent = self.decode_y(y, r)
        return sent

    def encode(self, sent: RelationSentence) -> Tuple[str, str]:
        x = self.encode_x(sent.label)
        y = self.encode_y(sent)
        return x, y

    def decode_from_line(self, line: str) -> RelationSentence:
        x, y = self.parse_line(line)
        return self.decode(x, y)

    def encode_to_line(self, sent: RelationSentence) -> str:
        # The relation prompt x is a prefix of y, so only y is written out.
        _, y = self.encode(sent)
        return y + "\n"

    def parse_line(self, line: str) -> Tuple[str, str]:
        return "", line.strip()


class ExtractEncoder(Encoder):
    def encode_x(self, text: str) -> str:
        return f"Context : {text}"

    def decode_x(self, x: str) -> str:
        return x.split("Context : ")[-1]

    def encode_y(self, sent: RelationSentence) -> str:
        s, r, o = sent.as_tuple()
        return f"Head Entity : {s} , Tail Entity : {o} , Relation : {r} ."

    def decode_y(self, x: str, y: str) -> RelationSentence:
        context = self.decode_x(x)
        front, label = y.split(" , Relation : ")
        label = label[:-2]
        front, tail = front.split(" , Tail Entity : ")
        _, head = front.split("Head Entity : ")
        return RelationSentence.from_spans(context, head, tail, label)

    def encode_entity_prompt(self, head: str, tail: str) -> str:
        return f"Head Entity : {head} , Tail Entity : {tail} , Relation :"

    def encode(self, sent: RelationSentence) -> Tuple[str, str]:
        x = self.encode_x(sent.text)
        y = self.encode_y(sent)
        return x, y

    def decode(self, x: str, y: str) -> RelationSentence:
        return self.decode_y(x, y)

    def encode_to_line(self, sent: RelationSentence) -> str:
        x, y = self.encode(sent)
        return run_summarization.encode_to_line(x, y)

    def decode_from_line(self, line: str) -> RelationSentence:
        x, y = self.parse_line(line)
        return self.decode(x, y)

    def parse_line(self, line: str) -> Tuple[str, str]:
        return run_summarization.decode_from_line(line)


def test_encoders(
    paths: List[str] = [
        "outputs/data/zsl/wiki/unseen_5_seed_0/train.jsonl",
        "outputs/data/zsl/fewrel/unseen_5_seed_0/train.jsonl",
    ],
    print_limit: int = 4,
    encoder_names: List[str] = ["generate", "extract"],
    limit: int = 1000,
):
    # Check that encoding followed by decoding recovers the gold triplets.
    encoders = {k: select_encoder(k) for k in encoder_names}

    for p in paths:
        data = RelationData.load(Path(p))
        _, data = data.train_test_split(min(limit, len(data.sents)), random_seed=0)

        for name, e in tqdm(list(encoders.items())):
            num_fail = 0
            print(dict(name=name, p=p))
            for s in data.sents:
                encoded = e.encode_to_line(s)
                x, y = e.parse_line(encoded)
                decoded: RelationSentence = e.safe_decode(x, y)

                if decoded.as_tuple() != s.as_tuple():
                    if num_fail < print_limit:
                        print(dict(gold=s.as_tuple(), text=s.text))
                        print(dict(pred=decoded.as_tuple(), text=decoded.text))
                        print(dict(x=x, y=y, e=decoded.error))
                        print()
                    num_fail += 1

            print(dict(success_rate=1 - (num_fail / len(data.sents))))
            print("#" * 80)


def select_encoder(name: str) -> Encoder:
    mapping: Dict[str, Encoder] = dict(
        extract=ExtractEncoder(),
        generate=GenerateEncoder(),
    )
    encoder = mapping[name]
    return encoder


def test_entity_prompts(
    path: str = "outputs/data/zsl/wiki/unseen_10_seed_0/test.jsonl", limit: int = 100
):
    def tokenize(text: str, tok) -> List[str]:
        return tok.convert_ids_to_tokens(tok(text, add_special_tokens=False).input_ids)

    # The entity prompt must be an exact token-level prefix of the target y,
    # otherwise it cannot be used to constrain the decoder.
    data = RelationData.load(Path(path))
    e = ExtractEncoder()
    tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
    print(tokenizer)
    for i, s in enumerate(tqdm(data.sents[:limit])):
        head, _, tail = s.as_tuple()
        _, y = e.encode(s)
        prompt = e.encode_entity_prompt(head, tail)
        tokens_y = tokenize(y, tokenizer)
        tokens_prompt = tokenize(prompt, tokenizer)
        assert tokens_y[: len(tokens_prompt)] == tokens_prompt
        if i < 3:
            print(tokens_y)


if __name__ == "__main__":
    Fire()
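For reference, a minimal round-trip sketch of the two encoders above. The sentence and triplet are invented; RelationSentence.from_spans and as_tuple are assumed to behave as they are used in the code above (whitespace-tokenized text, tuples in (head, label, tail) order):

from utils import RelationSentence

# Invented example triplet for illustration only.
sent = RelationSentence.from_spans(
    "Alan Turing was born in London .", "Alan Turing", "London", "place of birth"
)

# GenerateEncoder: x is the relation prompt, y = x + context + entity pair.
gen = GenerateEncoder()
x, y = gen.encode(sent)
# x == "Relation : place of birth ."
# y == "Relation : place of birth . Context : Alan Turing was born in London .
#       Head Entity : Alan Turing , Tail Entity : London ."
assert gen.decode(x, y).as_tuple() == sent.as_tuple()

# ExtractEncoder: x is the context, y lists head, tail, then relation.
ext = ExtractEncoder()
x, y = ext.encode(sent)
# y == "Head Entity : Alan Turing , Tail Entity : London , Relation : place of birth ."
assert ext.decode(x, y).as_tuple() == sent.as_tuple()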
generation.py
@@ -0,0 +1,193 @@
from typing import Dict, List, Optional, Tuple

import torch
from fire import Fire
from torch import Tensor
from transformers import PreTrainedModel, PreTrainedTokenizerFast

from encoding import ExtractEncoder
from utils import DynamicModel, RelationSentence, find_sublist_index


class TextGenerator(DynamicModel):
    model: PreTrainedModel
    tokenizer: PreTrainedTokenizerFast
    scores: Optional[List[Tensor]] = None
    max_length: int

    def tokenize(self, texts: List[str], **kwargs):
        return self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
            **kwargs,
        ).to(self.model.device)

    def run(
        self,
        texts: List[str],
        do_sample=True,
        top_k=50,
        temperature=1.0,
        num_return: int = 4,
        prompt: Optional[str] = None,
        prompt_ids: Optional[List[int]] = None,
        multi_prompt_ids: Optional[List[List[int]]] = None,
        decoder_input_ids: Optional[Tensor] = None,
        save_scores: bool = False,
        **kwargs,
    ) -> List[str]:
        # https://huggingface.co/transformers/v4.7.0/main_classes/model.html#generation
        tok = self.tokenizer
        eos, bos = tok.eos_token_id, tok.bos_token_id

        # Optionally force the decoder to continue from a given prompt;
        # the decoder sequence is prefixed with <eos><bos>.
        if prompt is not None:
            prompt_ids = self.tokenizer(prompt, add_special_tokens=False).input_ids
        if prompt_ids is not None:
            prompt_ids = [eos, bos] + prompt_ids
            decoder_input_ids = torch.tensor([prompt_ids])
        if multi_prompt_ids is not None:
            assert len(texts) == len(multi_prompt_ids)
            multi_prompt_ids = [[eos, bos] + lst for lst in multi_prompt_ids]
            decoder_input_ids = torch.tensor(multi_prompt_ids)
        if decoder_input_ids is not None:
            kwargs.update(decoder_input_ids=decoder_input_ids.to(self.model.device))

        outputs = self.model.generate(
            **self.tokenize(texts),
            do_sample=do_sample,
            top_k=top_k,
            temperature=temperature,
            num_return_sequences=num_return,
            return_dict_in_generate=True,
            output_scores=save_scores,
            max_length=self.max_length,
            **kwargs,
        )

        self.scores = None
        if save_scores:
            self.scores = list(torch.stack(outputs.scores, 1).cpu())
        return self.decode(outputs.sequences)

    def decode(self, outputs) -> List[str]:
        tok = self.tokenizer
        texts = tok.batch_decode(
            outputs, skip_special_tokens=False, clean_up_tokenization_spaces=False
        )

        # Manually remove <bos><eos><pad> in case we have custom special tokens
        special_tokens = [tok.eos_token, tok.bos_token, tok.pad_token]
        for i, t in enumerate(texts):
            for token in special_tokens:
                t = t.replace(token, "")
            texts[i] = t
        return texts


class LabelConstraint:
    def __init__(
        self,
        labels: List[str],
        tokenizer: PreTrainedTokenizerFast,
        prefix: str = " Relation :",
    ):
        self.prefix: List[int] = tokenizer(prefix, add_special_tokens=False).input_ids
        self.label_map: Dict[int, str] = {
            tokenizer(" " + x, add_special_tokens=False).input_ids[0]: x for x in labels
        }
        self.tokenizer = tokenizer

    def run(self, triplet: RelationSentence, scores: Tensor) -> RelationSentence:
        # Replace the predicted label with the known label whose first token
        # scores highest at the position right after the " Relation :" prefix.
        triplet = triplet.copy(deep=True)
        assert scores.ndim == 2
        token_ids = scores.argmax(dim=-1).int().tolist()
        i = find_sublist_index(token_ids, self.prefix)
        if i == -1:
            return triplet

        position = i + len(self.prefix)
        best = ""
        best_score = -1e9
        for j, label in self.label_map.items():
            score = scores[position, j].item()
            if score > best_score:
                best = label
                best_score = score

        if triplet.label in self.label_map.values():
            assert best == triplet.label

        assert len(best) > 0
        triplet.label = best
        triplet.score = best_score
        return triplet


class TripletSearchDecoder(DynamicModel):
    gen: TextGenerator
    constraint: LabelConstraint
    encoder: ExtractEncoder
    top_k: int = 4

    def generate(self, text: str, **kwargs) -> Tuple[str, Tensor]:
        outputs = self.gen.run(
            [text],
            do_sample=False,
            num_return=1,
            num_beams=1,
            save_scores=True,
            **kwargs,
        )

        assert len(outputs) == 1
        assert self.gen.scores is not None
        scores = torch.log_softmax(self.gen.scores[0], dim=-1)
        assert scores.ndim == 2
        return outputs[0], scores

    def find_prefix_end(self, token_ids: List[int], prefix: str) -> int:
        prefix_ids = self.gen.tokenizer(prefix, add_special_tokens=False).input_ids
        i = find_sublist_index(token_ids, prefix_ids)
        position = i + len(prefix_ids)
        return position

    def branch(
        self, text: str, prefix: str, prompt: Optional[str] = None, **kwargs
    ) -> List[Tuple[str, float]]:
        # Keep the top-k token candidates at the position right after the
        # prefix, yielding k alternative decoder prompts with their log-probs.
        _, scores = self.generate(text, prompt=prompt, **kwargs)
        token_ids = scores.argmax(dim=-1).int().tolist()
        i = self.find_prefix_end(token_ids, prefix)

        pairs = []
        for j in torch.argsort(scores[i])[-self.top_k :]:
            p = (prompt or "") + self.gen.decode([token_ids[:i] + [j]])[0]
            pairs.append((p, scores[i, j].item()))

        return pairs

    def run(self, text: str) -> List[RelationSentence]:
        x = self.encoder.encode_x(text)
        outputs = []

        # Branch over candidate heads, then candidate tails per head, then
        # score the relation; each triplet averages the three log-probs.
        for prompt_a, score_a in self.branch(x, prefix="Head Entity :"):
            for prompt_b, score_b in self.branch(
                x, prefix=" Tail Entity :", prompt=prompt_a
            ):
                output, scores = self.generate(x, prompt=prompt_b)
                token_ids = scores.argmax(dim=-1).int().tolist()
                i = self.find_prefix_end(token_ids, prefix=" Relation :")
                score_c = max(scores[i].tolist())
                s = self.encoder.safe_decode(x=x, y=output)
                s = self.constraint.run(s, scores)
                # score_c = s.score  # From LabelConstraint
                s.score = (score_a + score_b + score_c) / 3
                outputs.append(s)

        return outputs


if __name__ == "__main__":
    Fire()
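For reference, a hypothetical wiring of the classes above. The label list and input sentence are invented; a checkpoint fine-tuned on ExtractEncoder-formatted data is assumed (a vanilla facebook/bart-base will load but emit untrained, likely unparseable output), and DynamicModel is assumed to permit these field types, as the class definitions above require:

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Placeholder checkpoint: substitute a model fine-tuned for extraction.
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-base")
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")

gen = TextGenerator(model=model, tokenizer=tokenizer, max_length=128)
constraint = LabelConstraint(labels=["place of birth", "employer"], tokenizer=tokenizer)
decoder = TripletSearchDecoder(gen=gen, constraint=constraint, encoder=ExtractEncoder())

# branch() keeps top_k heads, then top_k tails per head, so run() returns
# up to top_k * top_k candidate triplets with averaged log-prob scores.
triplets = decoder.run("Alan Turing was born in London .")
for t in sorted(triplets, key=lambda t: t.score, reverse=True)[:3]:
    print(t.as_tuple(), t.score)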