-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathnet_metrics.m
102 lines (81 loc) · 2.6 KB
/
net_metrics.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
% Script settings: start from an empty workspace and silence all warnings
clear('all')
warning('off')
% Ask which trained network to evaluate. Pressing Enter without typing
% anything makes input return [], which selects the default (SLP).
netChoice = input('Choose network:\n 0 = SLP (default)\n 1 = BiLSTM\n');
if isempty(netChoice)
    netChoice = 0;
end
% Translate the numeric choice into the suffix of the model file name
if netChoice == 0
    str = 'slp';
elseif netChoice == 1
    str = 'bilstm';
else
    error('Invalid choice');
end
% Load the trained network (models/gravity_<str>.mat). fullfile inserts the
% platform's file separator, so the path resolves on Windows, Linux and macOS.
NET = load(fullfile('models', ['gravity_', str, '.mat']));
% Close every figure (including hidden ones) once before gathering data;
% loading the dataset files below does not open figures, so this does not
% need to be repeated per file
close all force
% Gather datasets. File 1 supplies the true labels, the pattern-index
% partition and the dataset sizes; files 2-4 each supply one feature set,
% stored in x_true{1..3}.
x_true = cell(1, 3);
for datas = 1 : 4
    % Load dataset (the file must contain a cell array named DATASET)
    load(fullfile('dataset', ['DatasGravityFeatures', int2str(datas)]), 'DATASET');
    if datas == 1
        % Save common dataset info
        % True labels for all patterns
        y_true = DATASET{2};
        % Pattern-index partition; 'fold' selects the row used downstream
        datasetFolder = DATASET{3};
        fold = 1;
        % Dataset sizes
        totalSize = DATASET{5};
        trainValidationSize = DATASET{4};
    else
        % Store dataset patterns
        x_true{datas - 1} = DATASET{1};
    end
    % Clear used dataset so a file lacking DATASET cannot silently reuse
    % the previous file's contents
    clear DATASET
end
% Number of instances per split: the first trainValidationSize patterns of
% a fold are divided 90/10 into training/validation sizes, and the patterns
% from trainValidationSize + 1 to totalSize form the test set
trainSize = floor(trainValidationSize * 0.9);
valSize = trainValidationSize - trainSize;
testSize = totalSize - trainSize - valSize;
% Get test set indexes and labels
testPatternIndexes = datasetFolder(fold, trainValidationSize + 1 : totalSize);
y_fold_test = y_true(testPatternIndexes);
% Fail early with a clear message instead of letting classify choke on an
% undefined or empty test set
if isempty(testPatternIndexes)
    error('No test patterns: totalSize must exceed trainValidationSize');
end
nTest = numel(testPatternIndexes);
% Create test set (testSequences is always rebuilt from scratch here)
if netChoice == 0
    % SLP: one row per test pattern holding the three feature sets back to
    % back. Rows are collected in a cell and stacked once, instead of
    % growing the matrix inside the loop.
    patternRows = cell(nTest, 1);
    for k = 1 : nTest
        patIdx = testPatternIndexes(k);
        sequence = [x_true{1}{patIdx}';
                    x_true{2}{patIdx}';
                    x_true{3}{patIdx}'];
        % Flatten in linear (column-major) order into a single row
        patternRows{k} = reshape(sequence, 1, []);
    end
    testSequences = vertcat(patternRows{:});
else
    % BiLSTM: one cell per test pattern (1-by-N cell array) containing the
    % three feature sets stacked vertically and then transposed
    testSequences = cell(1, nTest);
    for k = 1 : nTest
        patIdx = testPatternIndexes(k);
        sequence = [x_true{1}{patIdx};
                    x_true{2}{patIdx};
                    x_true{3}{patIdx}];
        testSequences{k} = sequence';
    end
end
% Classifying test patterns. Per the classify documentation, score is
% N-by-K: one row per test pattern, one column per class.
[outclass, score{fold}] = classify(NET.netTransfer, testSequences);
% Predicted class index = column of the highest score in each row
[~, b] = max(score{fold}, [], 2);
% Get accuracy (fraction of correctly matched labels in the test set).
% Both operands are flattened to columns so the comparison stays
% element-wise whatever the orientation of y_true; a row-vs-column
% mismatch would otherwise expand into an N-by-N matrix.
% NOTE(review): assumes the labels are the integers 1..K in the same order
% as the score columns - confirm against how the network was trained.
acc(fold) = mean(b(:) == y_fold_test(:));
% Compute confusion matrix together with the class order of its rows/columns
[confMat, classOrder] = confusionmat(categorical(y_fold_test), categorical(outclass));
% Display confusion matrix labelled with the class names
cm = confusionchart(confMat, classOrder);
% Print test accuracy of the evaluated fold
fprintf('Test accuracy: %.4f\n', acc(fold));