Commit

Optimized CART and Extra Tree memory and node splitting
andrewdalpino committed Apr 12, 2021
1 parent 737ec09 commit 9c58eae
Showing 16 changed files with 106 additions and 111 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -41,6 +41,7 @@
- Optimized CART binary categorical node splitting
- Interval Discretizer outputs numeric string categories
- Renamed Random Hot Deck Imputer
- Changed order of decision tree hyper-parameters

- 0.4.1
- Optimized CART node splitting for low variance continuous features
6 changes: 3 additions & 3 deletions docs/classifiers/classification-tree.md
@@ -12,14 +12,14 @@ A binary tree-based learner that greedily constructs a decision map for classifi
|---|---|---|---|---|
| 1 | maxHeight | PHP_INT_MAX | int | The maximum height of the tree. |
| 2 | maxLeafSize | 3 | int | The max number of samples that a leaf node can contain. |
| 3 | maxFeatures | Auto | int | The max number of feature columns to consider when determining a best split. |
| 4 | minPurityIncrease | 1e-7 | float | The minimum increase in purity necessary for a node *not* to be post pruned during tree growth. |
| 3 | minPurityIncrease | 1e-7 | float | The minimum increase in purity necessary to continue splitting a subtree. |
| 4 | maxFeatures | Auto | int | The max number of feature columns to consider when determining a best split. |

## Example
```php
use Rubix\ML\Classifiers\ClassificationTree;

$estimator = new ClassificationTree(10, 7, 4, 0.01);
$estimator = new ClassificationTree(10, 5, 0.001, null);
```
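
Because this commit swaps the positions of `minPurityIncrease` and `maxFeatures`, positional calls written against the old order need to be updated. The following is a minimal sketch, assuming PHP 8.0+ named arguments are available; `TreeParams` is a hypothetical stand-in that mirrors only the new constructor signature shown in this commit and is not a library class.

```php
<?php

declare(strict_types=1);

// Hypothetical stand-in that copies the new parameter order from this commit.
// It is not part of the library; it only illustrates argument passing.
final class TreeParams
{
    public int $maxHeight;
    public int $maxLeafSize;
    public float $minPurityIncrease;
    public ?int $maxFeatures;

    public function __construct(
        int $maxHeight = PHP_INT_MAX,
        int $maxLeafSize = 3,
        float $minPurityIncrease = 1e-7,
        ?int $maxFeatures = null
    ) {
        $this->maxHeight = $maxHeight;
        $this->maxLeafSize = $maxLeafSize;
        $this->minPurityIncrease = $minPurityIncrease;
        $this->maxFeatures = $maxFeatures;
    }
}

// Positional call using the new order: height, leaf size, purity, features.
$positional = new TreeParams(10, 5, 0.001, null);

// Named arguments (PHP 8.0+) do not depend on the positional order.
$named = new TreeParams(
    maxFeatures: null,
    minPurityIncrease: 0.001,
    maxLeafSize: 5,
    maxHeight: 10
);
```

Named arguments keep a call site stable if the positional order changes again, at the cost of coupling it to the parameter names.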

## Additional Methods
6 changes: 3 additions & 3 deletions docs/classifiers/extra-tree-classifier.md
@@ -12,14 +12,14 @@ An *Extremely Randomized* Classification Tree that recursively chooses node spli
|---|---|---|---|---|
| 1 | maxHeight | PHP_INT_MAX | int | The maximum height of the tree. |
| 2 | maxLeafSize | 3 | int | The max number of samples that a leaf node can contain. |
| 3 | maxFeatures | Auto | int | The max number of feature columns to consider when determining a best split. |
| 4 | minPurityIncrease | 1e-7 | float | The minimum increase in purity necessary for a node *not* to be post pruned during tree growth. |
| 3 | minPurityIncrease | 1e-7 | float | The minimum increase in purity necessary to continue splitting a subtree. |
| 4 | maxFeatures | Auto | int | The max number of feature columns to consider when determining a best split. |

## Example
```php
use Rubix\ML\Classifiers\ExtraTreeClassifier;

$estimator = new ExtraTreeClassifier(50, 3, 4, 1e-7);
$estimator = new ExtraTreeClassifier(50, 3, 1e-7, 10);
```

## Additional Methods
6 changes: 3 additions & 3 deletions docs/regressors/extra-tree-regressor.md
@@ -12,14 +12,14 @@
|---|---|---|---|---|
| 1 | maxHeight | PHP_INT_MAX | int | The maximum height of the tree. |
| 2 | maxLeafSize | 3 | int | The max number of samples that a leaf node can contain. |
| 3 | maxFeatures | Auto | int | The max number of feature columns to consider when determining a best split. |
| 4 | minPurityIncrease | 1e-7 | float | The minimum increase in purity necessary for a node *not* to be post pruned during tree growth. |
| 3 | minPurityIncrease | 1e-7 | float | The minimum increase in purity necessary to continue splitting a subtree. |
| 4 | maxFeatures | Auto | int | The max number of feature columns to consider when determining a best split. |

## Example
```php
use Rubix\ML\Regressors\ExtraTreeRegressor;

$estimator = new ExtraTreeRegressor(30, 3, 20, 0.05);
$estimator = new ExtraTreeRegressor(30, 5, 0.05, null);
```

## Additional Methods
6 changes: 3 additions & 3 deletions docs/regressors/regression-tree.md
@@ -12,14 +12,14 @@ A decision tree based on the CART (*Classification and Regression Tree*) learnin
|---|---|---|---|---|
| 1 | maxHeight | PHP_INT_MAX | int | The maximum height of the tree. |
| 2 | maxLeafSize | 3 | int | The max number of samples that a leaf node can contain. |
| 3 | maxFeatures | Auto | int | The max number of feature columns to consider when determining a best split. |
| 4 | minPurityIncrease | 1e-7 | float | The minimum increase in purity necessary for a node *not* to be post pruned during tree growth. |
| 3 | minPurityIncrease | 1e-7 | float | The minimum increase in purity necessary to continue splitting a subtree. |
| 4 | maxFeatures | Auto | int | The max number of feature columns to consider when determining a best split. |

## Example
```php
use Rubix\ML\Regressors\RegressionTree;

$estimator = new RegressionTree(20, 2, null, 1e-3);
$estimator = new RegressionTree(20, 2, 1e-3, 10);
```

## Additional Methods
36 changes: 22 additions & 14 deletions src/Classifiers/ClassificationTree.php
@@ -24,6 +24,12 @@
use Rubix\ML\Exceptions\RuntimeException;

use function Rubix\ML\argmax;
use function count;
use function array_fill;
use function array_combine;
use function array_replace;
use function array_count_values;
use function array_map;

/**
* Classification Tree
@@ -46,9 +52,9 @@ class ClassificationTree extends CART implements Estimator, Learner, Probabilist
use AutotrackRevisions;

/**
* The zero vector for the possible class outcomes.
* The list of possible class outcomes.
*
* @var float[]
* @var list<string>
*/
protected $classes = [
//
@@ -57,16 +63,16 @@ class ClassificationTree extends CART implements Estimator, Learner, Probabilist
/**
* @param int $maxHeight
* @param int $maxLeafSize
* @param int|null $maxFeatures
* @param float $minPurityIncrease
* @param int|null $maxFeatures
*/
public function __construct(
int $maxHeight = PHP_INT_MAX,
int $maxLeafSize = 3,
?int $maxFeatures = null,
float $minPurityIncrease = 1e-7
float $minPurityIncrease = 1e-7,
?int $maxFeatures = null
) {
parent::__construct($maxHeight, $maxLeafSize, $maxFeatures, $minPurityIncrease);
parent::__construct($maxHeight, $maxLeafSize, $minPurityIncrease, $maxFeatures);
}

/**
@@ -108,8 +114,8 @@ public function params() : array
return [
'max height' => $this->maxHeight,
'max leaf size' => $this->maxLeafSize,
'max features' => $this->maxFeatures,
'min purity increase' => $this->minPurityIncrease,
'max features' => $this->maxFeatures,
];
}

@@ -137,7 +143,7 @@ public function train(Dataset $dataset) : void
new LabelsAreCompatibleWithLearner($dataset, $this),
])->check();

$this->classes = array_fill_keys($dataset->possibleOutcomes(), 0.0);
$this->classes = $dataset->possibleOutcomes();

$this->grow($dataset);
}
@@ -207,7 +213,9 @@ public function probaSample(array $sample) : array
/** @var \Rubix\ML\Graph\Nodes\Best $node */
$node = $this->search($sample);

return array_replace($this->classes, $node->probabilities()) ?? [];
$template = array_combine($this->classes, array_fill(0, count($this->classes), 0.0)) ?: [];

return array_replace($template, $node->probabilities());
}
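
The new `probaSample()` body builds a zero-filled template from the class list on each call, rather than storing the zero vector on the estimator. A minimal sketch of that pattern in isolation, assuming the leaf node exposes its probabilities as an array keyed by class label; the helper name `fillProbabilities` is hypothetical and not part of the library.

```php
<?php

declare(strict_types=1);

/**
 * Hypothetical helper mirroring the pattern used in probaSample().
 *
 * @param list<string> $classes
 * @param array<string,float> $leafProbabilities
 * @return array<string,float>
 */
function fillProbabilities(array $classes, array $leafProbabilities): array
{
    // Build a zero vector keyed by class label. The `?: []` fallback also
    // covers an empty class list, for which array_combine() yields [].
    $template = array_combine($classes, array_fill(0, count($classes), 0.0)) ?: [];

    // Overwrite the zeros with the probabilities stored at the leaf while
    // keeping the key order of the template.
    return array_replace($template, $leafProbabilities);
}

$probabilities = fillProbabilities(
    ['cat', 'dog', 'frog'],
    ['dog' => 0.75, 'cat' => 0.25]
);
```

Classes that the leaf never saw keep a probability of 0.0, so every prediction carries the same set of keys in the same order.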

/**
@@ -237,20 +245,20 @@ protected function terminate(Labeled $dataset) : Best
}

/**
* Compute the gini impurity of a labeled dataset.
* Calculate the impurity of a set of labels.
*
* @param \Rubix\ML\Datasets\Labeled $dataset
* @param list<string|int> $labels
* @return float
*/
protected function impurity(Labeled $dataset) : float
protected function impurity(array $labels) : float
{
$n = $dataset->numRows();
$n = count($labels);

if ($n <= 1) {
return 0.0;
}

$counts = array_count_values($dataset->labels());
$counts = array_count_values($labels);

$gini = 0.0;
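
The hunk above ends before the loop that accumulates the impurity. As a rough illustration of what computing Gini impurity from a plain list of labels can look like, here is a sketch that assumes the textbook definition, 1 minus the sum of squared class proportions. The function name `giniImpurity` is hypothetical, and the remainder of the library method may differ.

```php
<?php

declare(strict_types=1);

/**
 * Hypothetical standalone version of a label-based Gini impurity.
 *
 * @param list<string|int> $labels
 */
function giniImpurity(array $labels): float
{
    $n = count($labels);

    // A set with zero or one label is pure by definition.
    if ($n <= 1) {
        return 0.0;
    }

    // Map each distinct label to the number of times it occurs.
    $counts = array_count_values($labels);

    $sumOfSquares = 0.0;

    foreach ($counts as $count) {
        $proportion = $count / $n;

        $sumOfSquares += $proportion * $proportion;
    }

    return 1.0 - $sumOfSquares;
}

$mixed = giniImpurity(['a', 'a', 'b', 'b']);
$pure = giniImpurity(['a', 'a', 'a']);
```

Taking an array of labels rather than a `Labeled` dataset lets the caller score a candidate split without first building dataset objects for each side.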

34 changes: 21 additions & 13 deletions src/Classifiers/ExtraTreeClassifier.php
@@ -24,6 +24,12 @@
use Rubix\ML\Exceptions\RuntimeException;

use function Rubix\ML\argmax;
use function count;
use function array_fill;
use function array_combine;
use function array_replace;
use function array_count_values;
use function array_map;

/**
* Extra Tree Classifier
@@ -47,9 +53,9 @@ class ExtraTreeClassifier extends ExtraTree implements Estimator, Learner, Proba
use AutotrackRevisions;

/**
* The zero vector for the possible class outcomes.
* The list of possible class outcomes.
*
* @var float[]
* @var string[]
*/
protected $classes = [
//
@@ -58,16 +64,16 @@ class ExtraTreeClassifier extends ExtraTree implements Estimator, Learner, Proba
/**
* @param int $maxHeight
* @param int $maxLeafSize
* @param int|null $maxFeatures
* @param float $minPurityIncrease
* @param int|null $maxFeatures
*/
public function __construct(
int $maxHeight = PHP_INT_MAX,
int $maxLeafSize = 3,
?int $maxFeatures = null,
float $minPurityIncrease = 1e-7
float $minPurityIncrease = 1e-7,
?int $maxFeatures = null
) {
parent::__construct($maxHeight, $maxLeafSize, $maxFeatures, $minPurityIncrease);
parent::__construct($maxHeight, $maxLeafSize, $minPurityIncrease, $maxFeatures);
}

/**
@@ -138,7 +144,7 @@ public function train(Dataset $dataset) : void
new LabelsAreCompatibleWithLearner($dataset, $this),
])->check();

$this->classes = array_fill_keys($dataset->possibleOutcomes(), 0.0);
$this->classes = $dataset->possibleOutcomes();

$this->grow($dataset);
}
@@ -208,7 +214,9 @@ public function probaSample(array $sample) : array
/** @var \Rubix\ML\Graph\Nodes\Best $node */
$node = $this->search($sample);

return array_replace($this->classes, $node->probabilities()) ?? [];
$template = array_combine($this->classes, array_fill(0, count($this->classes), 0.0)) ?: [];

return array_replace($template, $node->probabilities());
}

/**
@@ -240,20 +248,20 @@ protected function terminate(Labeled $dataset) : Best
}

/**
* Compute the entropy of a labeled dataset.
* Calculate the impurity of a set of labels.
*
* @param \Rubix\ML\Datasets\Labeled $dataset
* @param list<string|int> $labels
* @return float
*/
protected function impurity(Labeled $dataset) : float
protected function impurity(array $labels) : float
{
$n = $dataset->numRows();
$n = count($labels);

if ($n <= 1) {
return 0.0;
}

$counts = array_count_values($dataset->labels());
$counts = array_count_values($labels);

$entropy = 0.0;
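
This hunk is also cut off before the accumulation loop. A sketch of an entropy-based impurity over a list of labels, assuming the standard Shannon entropy with the natural logarithm; the log base and the function name `shannonEntropy` are assumptions, and the rest of the library method is not shown in this diff.

```php
<?php

declare(strict_types=1);

/**
 * Hypothetical standalone version of a label-based entropy.
 *
 * @param list<string|int> $labels
 */
function shannonEntropy(array $labels): float
{
    $n = count($labels);

    // A set with zero or one label carries no uncertainty.
    if ($n <= 1) {
        return 0.0;
    }

    $entropy = 0.0;

    // Every count is at least 1, so log() never receives zero.
    foreach (array_count_values($labels) as $count) {
        $proportion = $count / $n;

        $entropy -= $proportion * log($proportion);
    }

    return $entropy;
}

$balanced = shannonEntropy(['a', 'a', 'b', 'b']);
$single = shannonEntropy(['a', 'a', 'a']);
```

With two equally likely classes the result is the natural log of 2, and a set containing a single class yields zero.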
