diff --git a/src/Transformers/MaxAbsoluteScaler.php b/src/Transformers/MaxAbsoluteScaler.php index dfb3fab05..6adc6746a 100644 --- a/src/Transformers/MaxAbsoluteScaler.php +++ b/src/Transformers/MaxAbsoluteScaler.php @@ -100,7 +100,7 @@ public function update(Dataset $dataset) : void foreach ($this->maxabs as $column => $oldMax) { $values = $dataset->feature($column); - $max = max(array_map('abs', $values)); + $max = max(array_map('abs', array_filter($values, 'is_finite') ?: [0])); $max = max($oldMax, $max); diff --git a/src/Transformers/MinMaxNormalizer.php b/src/Transformers/MinMaxNormalizer.php index 3bf741ec4..38c17c3f6 100644 --- a/src/Transformers/MinMaxNormalizer.php +++ b/src/Transformers/MinMaxNormalizer.php @@ -138,10 +138,10 @@ public function fit(Dataset $dataset) : void $values = $dataset->feature($column); /** @var int|float $min */ - $min = min($values); + $min = min(array_filter($values, 'is_finite') ?: [0]); /** @var int|float $max */ - $max = max($values); + $max = max(array_filter($values, 'is_finite') ?: [0]); $scale = ($this->max - $this->min) / (($max - $min) ?: EPSILON); @@ -199,6 +199,10 @@ public function transform(array &$samples) : void foreach ($this->scales as $column => $scale) { $value = &$sample[$column]; + if (!is_finite($value)) { + continue; + } + $min = $this->minimums[$column]; $value *= $scale; @@ -224,6 +228,10 @@ public function reverseTransform(array &$samples) : void foreach ($this->scales as $column => $scale) { $value = &$sample[$column]; + if (!is_finite($value)) { + continue; + } + $min = $this->minimums[$column]; $value -= $this->min - $min * $scale; diff --git a/tests/Transformers/MaxAbsoluteScalerTest.php b/tests/Transformers/MaxAbsoluteScalerTest.php index a9923ad53..094230c2c 100644 --- a/tests/Transformers/MaxAbsoluteScalerTest.php +++ b/tests/Transformers/MaxAbsoluteScalerTest.php @@ -109,4 +109,15 @@ public function reverseTransformUnfitted() : void $this->transformer->reverseTransform($samples); } + + /** + * @test + */ + public function skipsNonFinite(): void + { + $samples = Unlabeled::build([[0.0, 3000.0, NAN, -6.0], [1.0, 30.0, NAN, 0.001]]); + $this->transformer->fit($samples); + $this->assertNan($samples[0][2]); + $this->assertNan($samples[1][2]); + } } diff --git a/tests/Transformers/MinMaxNormalizerTest.php b/tests/Transformers/MinMaxNormalizerTest.php index ee7624baf..8b8848e45 100644 --- a/tests/Transformers/MinMaxNormalizerTest.php +++ b/tests/Transformers/MinMaxNormalizerTest.php @@ -102,4 +102,15 @@ public function transformUnfitted() : void $this->transformer->transform($samples); } + + /** + * @test + */ + public function skipsNonFinite(): void + { + $samples = Unlabeled::build([[0.0, 3000.0, NAN, -6.0], [1.0, 30.0, NAN, 0.001]]); + $this->transformer->fit($samples); + $this->assertNan($samples[0][2]); + $this->assertNan($samples[1][2]); + } }