-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathEval-ImageRecg.py
192 lines (140 loc) · 6.07 KB
/
Eval-ImageRecg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
'''
Evaluation Script, report will be generated automatically.
** 'pred' is defined by the model dict
** the number of 'labels' >= the number of 'pred'
$$ precision = #(label=pred) / #pred
$$ recall = #(label=pred) / #label
input: txtfile format 'imgname'\t'label'\t'pred'\t'confidence'\n
output: [report-name].txt general report with P/R in each label
[PRline-name].line line file for P/R line drawings, confidence thresh default sets [0.0, 0.2, 0.4, 0.6, 0.8, 0.85, 0.9, 0.95, 0.98, 0.99]
'''
import cv2, os, sys
import argparse
import numpy as np
def PRstatistics(txtfile, conf_thresh=0.0, BIG_CLASS=None, FILTER_CLASS=None, \
                 LABEL_FILTER=True, BIG_CLASS_STAT=True):
    '''
    Tally per-class [#label, #pred, #hit] counts at a given confidence threshold.

    input : txtfile -- iterable of lines, each "imgname\\timgl\\tpred\\tconfidence"
            conf_thresh -- only predictions with confidence > conf_thresh are counted
            BIG_CLASS -- 2-char class prefixes aggregated as coarse classes
                         (default ['pl','pr','pa','pm','pw','ph'])
            FILTER_CLASS -- labels dropped entirely when LABEL_FILTER is True
                            (default ['noprohibit','po','background'])
            LABEL_FILTER -- if True, skip lines whose label is in FILTER_CLASS
            BIG_CLASS_STAT -- if True, also accumulate big-class summaries
    output: result_dict -- {class_name: [label_count, pred_count, hit_count]}
            precision = hit/pred, recall = hit/label per class.
    '''
    # Defaults are created per call to avoid the mutable-default-argument trap.
    if BIG_CLASS is None:
        BIG_CLASS = ['pl', 'pr', 'pa', 'pm', 'pw', 'ph']
    if FILTER_CLASS is None:
        FILTER_CLASS = ['noprohibit', 'po', 'background']
    result_dict = {}
    ## pre-seed big-class rows so they appear even with zero counts
    if BIG_CLASS_STAT:
        for bclass in BIG_CLASS:
            result_dict[bclass] = [0, 0, 0]
    ## statistics
    for line in txtfile:
        imgname, label, pred, confidence = line.strip().split('\t')
        confidence = float(confidence)
        ## label filter: drop ignored ground-truth classes entirely
        if LABEL_FILTER and label in FILTER_CLASS:
            continue
        ## fine-grained counts: label always counted ...
        result_dict.setdefault(label, [0, 0, 0])[0] += 1
        ## ... prediction/hit only counted over conf_thresh
        if confidence > conf_thresh:
            result_dict.setdefault(pred, [0, 0, 0])[1] += 1
            if pred == label:
                result_dict[label][2] += 1
        ## big-class (2-char prefix) summaries
        if BIG_CLASS_STAT:
            if label[:2] in BIG_CLASS:
                result_dict[label[:2]][0] += 1
            ## only over conf_thresh will be counted; the hit check is nested
            ## under the pred membership test so the key is guaranteed to exist
            if confidence > conf_thresh:
                if pred[:2] in BIG_CLASS:
                    result_dict[pred[:2]][1] += 1
                    if label[:2] == pred[:2]:
                        result_dict[label[:2]][2] += 1
    return result_dict
## CLI definition; all four flags are mandatory inputs of the script.
## `required=True` fails fast with a usage message instead of letting a
## missing flag surface later as `open(None)` raising TypeError.
argparser = argparse.ArgumentParser(description='Experiment Platform for Image Recognition')
argparser.add_argument('--input', required=True, help='input txtfile generated by `predict.py`')
argparser.add_argument('--outdir', required=True, help='outfile directory')
argparser.add_argument('--name', required=True, help='report name')
argparser.add_argument('--dict', required=True, help='dictionary of the model')
args = argparser.parse_args()
## report generator ##
# Read the whole prediction dump once; `with` guarantees the handle is
# closed (the original leaked the file object returned by open()).
with open(args.input) as _infile:
    txtfile = _infile.readlines()
## the first conf_thresh must be 0.0 to generate general_report
# alternative sweep: [0.0, 0.2, 0.4, 0.6, 0.8, 0.85, 0.9, 0.95, 0.98, 0.99, 0.995]
conf_thresh_sets = [0.0, 0.2, 0.4, 0.6, 0.8, 0.85, 0.9, 0.91, 0.92, 0.93, 0.94]
BIG_CLASS = ['pl', 'pr', 'pa', 'pm', 'pw', 'ph']
FILTER_CLASS = ['noprohibit', 'po', 'background']
## model class ##
# Ordered, de-duplicated list of classes the model dict declares.
# `with` closes the dict file; the seen-set makes the membership test O(1)
# instead of rescanning the list for every line.
model_class = []
_seen = set()
with open(args.dict) as _dictfile:
    for _cate in _dictfile:
        _cate = _cate.strip()
        if _cate not in _seen:
            _seen.add(_cate)
            model_class.append(_cate)
## line file: one "Confidence / Precision / Recall" row per threshold.
# Both output files are managed by `with` so they are closed (and flushed)
# even if an exception interrupts the sweep; the original left them open
# on any error. print(...) with a single pre-formatted string behaves
# identically under Python 2 and Python 3.
linefile_name = os.path.join(args.outdir, 'LINEAR_' + args.name + '.line')
print(" >>>>> Line File : {}".format(linefile_name))
with open(linefile_name, 'w') as linefile:
    for conf_thresh in conf_thresh_sets:
        result_dict = PRstatistics(txtfile, conf_thresh=conf_thresh, BIG_CLASS=BIG_CLASS, FILTER_CLASS=FILTER_CLASS, LABEL_FILTER=True, BIG_CLASS_STAT=True)
        ## Precision / Recall totals.
        ## Aggregation rule choices considered: rule-1 (sum classes except 'pl'),
        ## rule-2 (fine-grain categories only); rule-3 (model-dict classes only)
        ## is the one in effect below.
        sum_label_num = 0.
        sum_pred_num = 0.
        sum_hit_num = 0.
        for cate in sorted(result_dict.keys()):
            ## rule-3: model dict only
            if cate not in model_class:
                continue
            label_num, pred_num, hit_num = result_dict[cate]
            sum_label_num += label_num
            sum_pred_num += pred_num
            sum_hit_num += hit_num
        # Guard against division by zero when nothing was predicted/labelled.
        total_precision = round(sum_hit_num / sum_pred_num, 3) if sum_pred_num != 0 else 0
        total_recall = round(sum_hit_num / sum_label_num, 3) if sum_label_num != 0 else 0
        ## General Report: only for the mandatory first threshold [conf=0.0]
        if conf_thresh == 0.0:
            greport_name = os.path.join(args.outdir, 'General-Report_' + args.name + '.txt')
            print(" >>>>> General Report Generating: {}".format(greport_name))
            with open(greport_name, 'w') as greport:
                for cate in sorted(result_dict.keys()):
                    label_num, pred_num, hit_num = result_dict[cate]
                    precision = round(float(hit_num) / pred_num, 3) if pred_num != 0 else 0
                    recall = round(float(hit_num) / label_num, 3) if label_num != 0 else 0
                    greport.write('{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\n'. \
                        format('cate:', cate, '#label:', int(label_num), '#pred:', int(pred_num), '#hit:', int(hit_num), 'precision:', precision, 'recall:', recall))
                greport.write('\n\n')
                greport.write('General Statistics:\n')
                greport.write('Sum #label:{}\tSum #pred:{}\tSum #hit:{}\tTotal precision:{}\tTotal recall:{}\n\n'.\
                    format(int(sum_label_num), int(sum_pred_num), int(sum_hit_num), total_precision, total_recall))
            print(" >>>>> General Report Done")
        ## linefile recorder: one row per threshold for P/R curve drawing
        linefile.write('{}\t{}\t{}\t{}\t{}\t{}\n'.format('Confidence:', conf_thresh, 'Precision:', total_precision, 'Recall:', total_recall))
print(" >>>>> Line File Done")