-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathdet.go
188 lines (163 loc) · 6.58 KB
/
det.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
package mlpack
/*
#cgo CFLAGS: -I./capi -Wall
#cgo LDFLAGS: -L. -lmlpack_go_det
#include <capi/det.h>
#include <stdlib.h>
*/
import "C"
import "gonum.org/v1/gonum/mat"
// DetOptionalParam collects the optional arguments accepted by Det. A value
// pre-filled with the binding's defaults is returned by DetOptions; callers
// typically start from that and override only the fields they need.
type DetOptionalParam struct {
	// Folds is the number of cross-validation folds to perform for the
	// estimation (0 is LOOCV).
	Folds int
	// InputModel is a trained density estimation tree to load.
	InputModel *dTree
	// MaxLeafSize and MinLeafSize are the maximum and minimum size of a leaf
	// in the unpruned, fully grown DET.
	MaxLeafSize, MinLeafSize int
	// PathFormat is the format of path printing: "lr", "id-lr" or "lr-id".
	PathFormat string
	// SkipPruning bypasses the pruning process and outputs the unpruned tree
	// only.
	SkipPruning bool
	// Test is a set of test points to estimate the density of; Training is
	// the data set on which to build a density estimation tree.
	Test, Training *mat.Dense
	// Verbose displays informational messages and the full list of
	// parameters and timers at the end of execution.
	Verbose bool
}
// DetOptions returns a DetOptionalParam holding the default value of every
// optional Det argument: 10 folds, a maximum leaf size of 10, a minimum leaf
// size of 5 and the "lr" path format. All other fields keep their zero value
// (no input model, no test or training data, pruning enabled, verbose off).
func DetOptions() *DetOptionalParam {
	opts := new(DetOptionalParam)
	opts.Folds = 10
	opts.MaxLeafSize = 10
	opts.MinLeafSize = 5
	opts.PathFormat = "lr"
	// InputModel, Test and Training remain nil; SkipPruning and Verbose
	// remain false.
	return opts
}
/*
Det performs a number of functions related to Density Estimation Trees. The
optimal Density Estimation Tree (DET) can be trained on a set of data
(specified by "Training") using cross-validation (with the number of folds
specified by the "Folds" parameter). The trained density estimation tree is
returned as the outputModel result.

The variable importances (that is, the feature importance values for each
dimension) are returned as the vi result, and the density estimates for each
training point as the trainingSetEstimates result.

Enabling path printing for each node outputs the path from the root node to a
leaf for each entry in the test set, or training set (if a test set is not
provided). Strings like 'LRLRLR' (indicating that traversal went to the left
child, then the right child, then the left child, and so forth) will be
output. If 'lr-id' or 'id-lr' is given as the "PathFormat" parameter, then
the ID (tag) of every node along the path will be printed after or before the
L or R character indicating the direction of traversal, respectively.

Det can also provide density estimates for a set of test points, specified in
the "Test" parameter. The density estimation tree used for this task will be
the tree that was trained on the given training points, or a tree given as the
parameter "InputModel". The density estimates for the test points are returned
as the testSetEstimates result.

If param is nil, the defaults returned by DetOptions are used. A field is
forwarded to mlpack only when its value differs from that default.

Input parameters:

  - Folds (int): The number of folds of cross-validation to perform for
    the estimation (0 is LOOCV). Default value 10.
  - InputModel (*dTree): Trained density estimation tree to load.
  - MaxLeafSize (int): The maximum size of a leaf in the unpruned, fully
    grown DET. Default value 10.
  - MinLeafSize (int): The minimum size of a leaf in the unpruned, fully
    grown DET. Default value 5.
  - PathFormat (string): The format of path printing: 'lr', 'id-lr', or
    'lr-id'. Default value 'lr'.
  - SkipPruning (bool): Whether to bypass the pruning process and output
    the unpruned tree only.
  - Test (*mat.Dense): A set of test points to estimate the density of.
  - Training (*mat.Dense): The data set on which to build a density
    estimation tree.
  - Verbose (bool): Display informational messages and the full list of
    parameters and timers at the end of execution.

Return values, in order:

  - outputModel (dTree): The trained density estimation tree.
  - tagCountersFile (string): The file to output the number of points
    that went to each leaf. Default value ''.
  - tagFile (string): The file to output the tags (and possibly paths)
    for each sample in the test set. Default value ''.
  - testSetEstimates (*mat.Dense): The output estimates on the test set
    from the final optimally pruned tree.
  - trainingSetEstimates (*mat.Dense): The output density estimates on the
    training set from the final optimally pruned tree.
  - vi (*mat.Dense): The output variable importance values for each
    feature.
*/
func Det(param *DetOptionalParam) (dTree, string, string, *mat.Dense, *mat.Dense, *mat.Dense) {
	// Treat a nil options struct as "all defaults" instead of dereferencing
	// nil below.
	if param == nil {
		param = DetOptions()
	}

	params := getParams("det")
	timers := getTimers()

	disableBacktrace()
	disableVerbose()

	// Each option is set and marked as passed only when it differs from the
	// default value assigned by DetOptions.
	if param.Folds != 10 {
		setParamInt(params, "folds", param.Folds)
		setPassed(params, "folds")
	}
	if param.InputModel != nil {
		setDTree(params, "input_model", param.InputModel)
		setPassed(params, "input_model")
	}
	if param.MaxLeafSize != 10 {
		setParamInt(params, "max_leaf_size", param.MaxLeafSize)
		setPassed(params, "max_leaf_size")
	}
	if param.MinLeafSize != 5 {
		setParamInt(params, "min_leaf_size", param.MinLeafSize)
		setPassed(params, "min_leaf_size")
	}
	if param.PathFormat != "lr" {
		setParamString(params, "path_format", param.PathFormat)
		setPassed(params, "path_format")
	}
	if param.SkipPruning {
		setParamBool(params, "skip_pruning", param.SkipPruning)
		setPassed(params, "skip_pruning")
	}
	if param.Test != nil {
		gonumToArmaMat(params, "test", param.Test, false)
		setPassed(params, "test")
	}
	if param.Training != nil {
		gonumToArmaMat(params, "training", param.Training, false)
		setPassed(params, "training")
	}
	if param.Verbose {
		setParamBool(params, "verbose", param.Verbose)
		setPassed(params, "verbose")
		enableVerbose()
	}

	// Mark every output option as passed; each one is read back below.
	setPassed(params, "output_model")
	setPassed(params, "tag_counters_file")
	setPassed(params, "tag_file")
	setPassed(params, "test_set_estimates")
	setPassed(params, "training_set_estimates")
	setPassed(params, "vi")

	// Call the mlpack program.
	C.mlpackDet(params.mem, timers.mem)

	// Collect the outputs.
	var outputModel dTree
	outputModel.getDTree(params, "output_model")
	tagCountersFile := getParamString(params, "tag_counters_file")
	tagFile := getParamString(params, "tag_file")
	var testSetEstimatesPtr mlpackArma
	testSetEstimates := testSetEstimatesPtr.armaToGonumMat(params, "test_set_estimates")
	var trainingSetEstimatesPtr mlpackArma
	trainingSetEstimates := trainingSetEstimatesPtr.armaToGonumMat(params, "training_set_estimates")
	var viPtr mlpackArma
	vi := viPtr.armaToGonumMat(params, "vi")

	// Release the parameter and timer storage after all outputs were read.
	cleanParams(params)
	cleanTimers(timers)

	return outputModel, tagCountersFile, tagFile, testSetEstimates, trainingSetEstimates, vi
}