-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathresults_plots.py
55 lines (48 loc) · 2.16 KB
/
results_plots.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from results_plotter import plot_results
import matplotlib as plt
def dict_based_name(d, name, default=None):
    """Return the value of the first key of *d* that occurs as a substring of *name*.

    Keys are tried in the dict's insertion order and the first hit wins, so
    callers must list more specific keys before shorter keys they contain
    (e.g. 'distilbert-base-uncased' before 'bert-base-uncased', since the
    latter is a substring of the former).

    Args:
        d: Mapping of substring -> value.
        name: String searched for each key.
        default: Value returned when no key occurs in *name*. Defaults to
            None, which is what this function has always returned then.

    Returns:
        The value of the first matching key, or *default*.
    """
    return next((value for key, value in d.items() if key in name), default)
def model_name(name):
    """Build a short legend label for the results entry *name*.

    The model is recognised by substring match against the table below.
    The first match wins, which is why 'distilbert-base-uncased' is listed
    before 'bert-base-uncased'. For names that do not start with
    'ensemble', a layer suffix such as '-last-2' is appended. The suffix is
    derived from the 'token-embedding-last-*' part of *name*.

    Args:
        name: Results identifier string (presumably a results file or run
            name produced elsewhere -- TODO confirm against results_plotter).

    Returns:
        A label such as 'DistilBERT-last-2' or 'ensemble-last-1'.
        If no known model occurs in *name*, *name* itself is returned.
        If a non-ensemble name has no recognised token-embedding part, the
        model label is returned without a suffix. Both cases previously
        raised TypeError (or returned None) by concatenating with None.
    """
    # Order matters: 'bert-base-uncased' is a substring of
    # 'distilbert-base-uncased', so DistilBERT must be checked first.
    models = {'distilbert-base-uncased': 'DistilBERT',
              'bert-base-uncased': 'BERT',
              'albert-xxlarge-v2': 'ALBERT',
              'ensemble-distil-1-albert-1': 'ensemble-last-1',
              'ensemble-distil-2-albert-2': 'ensemble-last-2',
              'ensemble-distil-3-albert-3': 'ensemble-last-3',
              'ensemble-distil-4-albert-4': 'ensemble-last-4'}
    layers = {'token-embedding-last-2': '-last-2',
              'token-embedding-last-3': '-last-3',
              'token-embedding-last-4': '-last-4',
              'token-embedding-last-1': '-last-1',
              'token-embedding-last-layer': '-last-1'}

    label = dict_based_name(models, name)
    if label is None:
        # Unknown model: show the raw name rather than crashing the plot.
        return name
    if name.startswith('ensemble'):
        # Ensemble labels already encode the layer count.
        return label
    # All suffix values are non-empty strings, so `or ''` only covers "no match".
    return label + (dict_based_name(layers, name) or '')
def linear(name):
    """Return the classifier-type legend label for *name*.

    A name containing the literal text '[768]' is labelled as a non-linear
    classifier; any other name is labelled as a linear classifier.
    ('[768]' presumably denotes a hidden-layer size in the run name --
    TODO confirm.)
    """
    has_hidden_layer = '[768]' in name
    return 'non-linear classifier' if has_hidden_layer else 'linear classifier'
def cls_token(name):
    """Return the token-choice legend label for *name*.

    'cls token' when *name* contains 'cls_token'; otherwise 'relevant token'.
    """
    if 'cls_token' not in name:
        return "relevant token"
    return "cls token"
def plot(fig_index):
    """Draw one of the predefined result figures with ``plot_results``.

    Each figure is described by three things:
    - a results location, passed to ``plot_results`` as its first argument;
    - a legend-label source;
    - a list of matplotlib format strings ('r-', 'r--', ...).
    The second argument to ``plot_results`` is always None here. Its meaning
    is defined in ``results_plotter``, which is not visible from this file.

    Args:
        fig_index: Which figure to draw:
            0 -- "results_lin", labels from model_name + linear;
            1 -- "results_positional", labels from model_name + cls_token;
            2 -- "results_gru", raw names as labels;
            3 -- "results_finetune", a fixed list of legend labels;
            4 -- "results_overall", raw names and default (None) styles.

    Raises:
        ValueError: if *fig_index* is not one of the indices above.
            Unknown indices used to be silently ignored, producing no
            figure and no error.
    """
    # Label sources are either a callable mapping a raw name to a label or,
    # for figure 3, an explicit list of labels. Presumably plot_results
    # accepts both forms -- TODO confirm against results_plotter.
    figures = {
        0: ("results_lin",
            lambda s: model_name(s) + " " + linear(s),
            ['r-', 'r--', 'b-', 'b--']),
        1: ("results_positional",
            lambda s: model_name(s) + " " + cls_token(s),
            ['r-', 'r--']),
        2: ("results_gru",
            lambda s: s,
            ['r-', 'r--']),
        3: ("results_finetune",
            ['train. loss: DistilBERT base model frozen',
             'train. loss: DistilBERT base model fine-tuned',
             'val. loss: DistilBERT base model frozen',
             'val. loss: DistilBERT base model fine-tuned'],
            ['k-', 'k--', 'r-', 'r--']),
        4: ("results_overall",
            lambda s: s,
            None),
    }
    try:
        results, labels, styles = figures[fig_index]
    except KeyError:
        raise ValueError(
            f"unknown fig_index {fig_index!r}; expected one of {sorted(figures)}"
        ) from None
    plot_results(results, None, labels, styles)
if __name__ == "__main__":
    # Figure to draw: the first command-line argument if given, otherwise 3
    # (the figure this script has always drawn by default).
    import sys

    plot(int(sys.argv[1]) if len(sys.argv) > 1 else 3)