Skip to content

Commit

Permalink
Revert numerical instability prevention
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewdalpino committed Jul 5, 2021
1 parent 63ab043 commit abd5e5c
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 10 deletions.
1 change: 0 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
- 1.1.1
- Fix Gradient Boost subsampling and importance scores
- Prevent Random Forest and AdaBoost sample weight underflow

- 1.1.0
- Update to Scienide Tensor 3.0
Expand Down
7 changes: 2 additions & 5 deletions src/Classifiers/AdaBoost.php
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
use Rubix\ML\Verbose;
use Rubix\ML\Estimator;
use Rubix\ML\Persistable;
use Rubix\ML\Helpers\CPU;
use Rubix\ML\Probabilistic;
use Rubix\ML\EstimatorType;
use Rubix\ML\Helpers\Params;
Expand Down Expand Up @@ -316,9 +315,7 @@ public function train(Dataset $dataset) : void

$p = max(self::MIN_SUBSAMPLE, (int) round($this->ratio * $m));

$epsilon = 2.0 * CPU::epsilon();

$weights = array_fill(0, $m, max($epsilon, 1.0 / $m));
$weights = array_fill(0, $m, 1.0 / $m);

$this->classes = array_fill_keys($classes, 0.0);
$this->featureCount = $n;
Expand Down Expand Up @@ -407,7 +404,7 @@ public function train(Dataset $dataset) : void
$total = array_sum($weights) ?: EPSILON;

foreach ($weights as &$weight) {
$weight = max($epsilon, $weight / $total);
$weight /= $total;
}
}

Expand Down
5 changes: 1 addition & 4 deletions src/Classifiers/RandomForest.php
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
use Rubix\ML\Parallel;
use Rubix\ML\Estimator;
use Rubix\ML\Persistable;
use Rubix\ML\Helpers\CPU;
use Rubix\ML\Probabilistic;
use Rubix\ML\RanksFeatures;
use Rubix\ML\EstimatorType;
Expand Down Expand Up @@ -224,12 +223,10 @@ public function train(Dataset $dataset) : void

$min = min($counts);

$epsilon = CPU::epsilon();

$weights = [];

foreach ($dataset->labels() as $label) {
$weights[] = max($epsilon, $min / $counts[$label]);
$weights[] = $min / $counts[$label];
}
}

Expand Down

0 comments on commit abd5e5c

Please sign in to comment.