Skip to content

Commit

Permalink
Allow empty dataset objects in stack()
Browse files: browse the repository at this point in the history
  • Loading branch information
andrewdalpino committed Oct 30, 2021
1 parent ec2c41c commit af69742
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 39 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
- 1.2.2
- Allow empty dataset objects in `stack()`

- 1.2.1
- Refactor stratified methods on Labeled dataset
- Narrower typehints
Expand Down
4 changes: 2 additions & 2 deletions src/Datasets/Dataset.php
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,10 @@ abstract public static function fromIterator(iterable $iterator) : self;
/**
* Stack a number of datasets on top of each other to form a single dataset.
*
* @param \Rubix\ML\Datasets\Dataset[] $datasets
* @param iterable<\Rubix\ML\Datasets\Dataset> $datasets
* @return static
*/
abstract public static function stack(array $datasets) : self;
abstract public static function stack(iterable $datasets) : self;

/**
* Return a 2-tuple containing the shape of the sample matrix, i.e. the number of rows and columns.
Expand Down
51 changes: 28 additions & 23 deletions src/Datasets/Labeled.php
Original file line number Diff line number Diff line change
Expand Up @@ -96,31 +96,35 @@ public static function fromIterator(iterable $iterator) : self
/**
* Stack a number of datasets on top of each other to form a single dataset.
*
* @param \Rubix\ML\Datasets\Labeled[] $datasets
* @param iterable<\Rubix\ML\Datasets\Labeled> $datasets
* @throws \Rubix\ML\Exceptions\InvalidArgumentException
* @return self
*/
public static function stack(array $datasets) : self
public static function stack(iterable $datasets) : self
{
$n = $datasets[array_key_first($datasets)]->numFeatures();

$samples = $labels = [];

foreach ($datasets as $dataset) {
foreach ($datasets as $i => $dataset) {
if (!$dataset instanceof Labeled) {
throw new InvalidArgumentException('Dataset must be'
. ' an instance of Labeled, ' . get_class($dataset)
. ' given.');
}

if ($dataset->numFeatures() !== $n) {
throw new InvalidArgumentException('Dataset must have'
. " the same number of columns, $n expected but "
. $dataset->numFeatures() . ' given.');
if ($dataset->empty()) {
continue;
}

if (isset($lastNumFeatures) and $dataset->numFeatures() !== $lastNumFeatures) {
throw new InvalidArgumentException("Dataset $i must have"
. " the same number of columns, $lastNumFeatures"
. " expected but {$dataset->numFeatures()} given.");
}

$samples[] = $dataset->samples();
$labels[] = $dataset->labels();

$lastNumFeatures = $dataset->numFeatures();
}

return self::quick(
Expand Down Expand Up @@ -463,16 +467,17 @@ public function split(float $ratio = 0.5) : array

$n = (int) floor($ratio * $this->numSamples());

return [
self::quick(
array_slice($this->samples, 0, $n),
array_slice($this->labels, 0, $n)
),
self::quick(
array_slice($this->samples, $n),
array_slice($this->labels, $n)
),
];
$left = self::quick(
array_slice($this->samples, 0, $n),
array_slice($this->labels, 0, $n)
);

$right = self::quick(
array_slice($this->samples, $n),
array_slice($this->labels, $n)
);

return [$left, $right];
}

/**
Expand All @@ -498,10 +503,10 @@ public function stratifiedSplit(float $ratio = 0.5) : array
$rightStrata[] = $right;
}

return [
self::stack($leftStrata),
self::stack($rightStrata),
];
$left = self::stack($leftStrata);
$right = self::stack($rightStrata);

return [$left, $right];
}

/**
Expand Down
30 changes: 17 additions & 13 deletions src/Datasets/Unlabeled.php
Original file line number Diff line number Diff line change
Expand Up @@ -68,29 +68,33 @@ public static function fromIterator(iterable $iterator) : self
/**
* Stack a number of datasets on top of each other to form a single dataset.
*
* @param \Rubix\ML\Datasets\Dataset[] $datasets
* @param iterable<\Rubix\ML\Datasets\Dataset> $datasets
* @throws \Rubix\ML\Exceptions\InvalidArgumentException
* @return self
*/
public static function stack(array $datasets) : self
public static function stack(iterable $datasets) : self
{
$n = $datasets[array_key_first($datasets)]->numFeatures();

$samples = [];

foreach ($datasets as $dataset) {
foreach ($datasets as $i => $dataset) {
if (!$dataset instanceof Dataset) {
throw new InvalidArgumentException('Dataset must implement'
. ' the Dataset interface.');
}

if ($dataset->numFeatures() !== $n) {
throw new InvalidArgumentException('Dataset must have'
. " the same number of columns, $n expected but"
. " {$dataset->numFeatures()} given.");
if ($dataset->empty()) {
continue;
}

if (isset($lastNumFeatures) and $dataset->numFeatures() !== $lastNumFeatures) {
throw new InvalidArgumentException("Dataset $i must have"
. " the same number of features, $lastNumFeatures"
. " expected but {$dataset->numFeatures()} given.");
}

$samples[] = $dataset->samples();

$lastNumFeatures = $dataset->numFeatures();
}

return self::quick(array_merge(...$samples));
Expand Down Expand Up @@ -261,10 +265,10 @@ public function split(float $ratio = 0.5) : array

$n = (int) floor($ratio * $this->numSamples());

return [
self::quick(array_slice($this->samples, 0, $n)),
self::quick(array_slice($this->samples, $n)),
];
$left = self::quick(array_slice($this->samples, 0, $n));
$right = self::quick(array_slice($this->samples, $n));

return [$left, $right];
}

/**
Expand Down
2 changes: 1 addition & 1 deletion src/constants.php
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
*
* @var string
*/
const VERSION = '1.2.1';
const VERSION = '1.2.2';

/**
* A small number used in substitution of 0.
Expand Down

0 comments on commit af69742

Please sign in to comment.