From 23d2abf770a627bdd0e1cbcb324e60a2ad8409b4 Mon Sep 17 00:00:00 2001
From: Andrew DalPino
Date: Mon, 5 Jul 2021 01:42:05 -0500
Subject: [PATCH] Optimize Gradient Boost

---
Review note: in the featureImportances() hunk the per-tree scores are held in
a separate local ($scores). Reusing $importances for them overwrote the running
totals on every iteration, so the method returned twice the last tree's scores
instead of the sum over the whole ensemble.

 src/Regressors/GradientBoost.php | 34 +++++++++++++-------------------
 1 file changed, 14 insertions(+), 20 deletions(-)

diff --git a/src/Regressors/GradientBoost.php b/src/Regressors/GradientBoost.php
index e12dc3bd8..bca323a76 100644
--- a/src/Regressors/GradientBoost.php
+++ b/src/Regressors/GradientBoost.php
@@ -83,7 +83,7 @@ class GradientBoost implements Estimator, Learner, RanksFeatures, Verbose, Persi
      *
      * @var int
      */
-    protected const MIN_SUBSAMPLE = 1;
+    protected const MIN_SUBSAMPLE = 2;
 
     /**
      * The regressor that will fix up the error residuals of the *weak* base learner.
@@ -392,12 +392,12 @@ public function train(Dataset $dataset) : void
 
         $this->base->train($training);
 
-        $out = $prevOut = $this->base->predict($training);
+        $out = $this->base->predict($training);
 
         $targets = $training->labels();
 
         if (!$testing->empty()) {
-            $prevOutTest = $this->base->predict($testing);
+            $outTest = $this->base->predict($testing);
         }
 
         $p = max(self::MIN_SUBSAMPLE, (int) round($this->ratio * $m));
@@ -432,16 +432,16 @@ public function train(Dataset $dataset) : void
 
             $predictions = $booster->predict($training);
 
-            $out = array_map([$this, 'updateOut'], $predictions, $prevOut);
+            $out = array_map([$this, 'updateOut'], $predictions, $out);
 
             $this->losses[$epoch] = $loss;
 
             $this->ensemble[] = $booster;
 
-            if (isset($prevOutTest)) {
+            if (isset($outTest)) {
                 $predictions = $booster->predict($testing);
 
-                $outTest = array_map([$this, 'updateOut'], $predictions, $prevOutTest);
+                $outTest = array_map([$this, 'updateOut'], $predictions, $outTest);
 
                 $score = $this->metric->score($outTest, $testing->labels());
 
@@ -470,18 +470,13 @@ public function train(Dataset $dataset) : void
                 if ($delta >= $this->window) {
                     break;
                 }
-
-                $prevOutTest = $outTest;
             }
 
             if (abs($prevLoss - $loss) < $this->minChange) {
                 break;
             }
 
-            if ($epoch < $this->estimators) {
-                $prevOut = $out;
-                $prevLoss = $loss;
-            }
+            $prevLoss = $loss;
         }
 
         if ($this->scores and end($this->scores) <= $bestScore) {
@@ -518,10 +513,7 @@ public function predict(Dataset $dataset) : array
         foreach ($this->ensemble as $estimator) {
             $predictions = $estimator->predict($dataset);
 
-            /** @var int $j */
-            foreach ($predictions as $j => $prediction) {
-                $out[$j] += $this->rate * $prediction;
-            }
+            $out = array_map([$this, 'updateOut'], $predictions, $out);
         }
 
         return $out;
@@ -542,7 +534,9 @@ public function featureImportances() : array
         $importances = array_fill(0, $this->featureCount, 0.0);
 
         foreach ($this->ensemble as $tree) {
-            foreach ($tree->featureImportances() as $column => $importance) {
+            $scores = $tree->featureImportances();
+
+            foreach ($scores as $column => $importance) {
                 $importances[$column] += $importance;
             }
         }
@@ -560,12 +554,12 @@ public function featureImportances() : array
      * Compute the output for an iteration.
      *
      * @param float $prediction
-     * @param float $prevOut
+     * @param float $out
      * @return float
      */
-    protected function updateOut(float $prediction, float $prevOut) : float
+    protected function updateOut(float $prediction, float $out) : float
     {
-        return $this->rate * $prediction + $prevOut;
+        return $this->rate * $prediction + $out;
     }
 
     /**