Skip to content

Commit

Permalink
Add transformer assertions
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewdalpino committed Jul 8, 2020
1 parent 9d4f370 commit 52de887
Show file tree
Hide file tree
Showing 8 changed files with 91 additions and 17 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
- 0.1.0-rc5
- Improved logging for Verbose Learners
- Added max document frequency to Word Count Vectorizer
- Whitespace Trimmer is now a separate transformer
- Whitespace Trimmer is now a separate Transformer
- Text Normalizers no longer remove extra whitespace
- Added extra characters pattern to Regex Filter class constants
- Moved Lambda Function transformer to Extras package
Expand Down
22 changes: 11 additions & 11 deletions src/Datasets/Generators/Blob.php
Original file line number Diff line number Diff line change
Expand Up @@ -35,43 +35,43 @@ class Blob implements Generator
*
* @var \Tensor\Vector|int|float
*/
protected $stddev;
protected $stdDev;

/**
* @param (int|float)[] $center
* @param int|float|(int|float)[] $stddev
* @param int|float|(int|float)[] $stdDev
* @throws \InvalidArgumentException
*/
public function __construct(array $center = [0, 0], $stddev = 1.0)
public function __construct(array $center = [0, 0], $stdDev = 1.0)
{
if (empty($center)) {
throw new InvalidArgumentException('Cannot generate samples'
. ' with dimensionality less than 1.');
}

if (is_array($stddev)) {
if (count($center) !== count($stddev)) {
if (is_array($stdDev)) {
if (count($center) !== count($stdDev)) {
throw new InvalidArgumentException('Number of center'
. ' coordinates and standard deviations must be equal.');
}

foreach ($stddev as $value) {
foreach ($stdDev as $value) {
if ($value < 0) {
throw new InvalidArgumentException('Standard deviation'
. " must be greater than 0, $value given.");
}
}

$stddev = Vector::quick($stddev);
$stdDev = Vector::quick($stdDev);
} else {
if ($stddev <= 0) {
if ($stdDev <= 0) {
throw new InvalidArgumentException('Standard deviation'
. " must be greater than 0, $stddev given.");
. " must be greater than 0, $stdDev given.");
}
}

$this->center = Vector::quick($center);
$this->stddev = $stddev;
$this->stdDev = $stdDev;
}

/**
Expand All @@ -95,7 +95,7 @@ public function generate(int $n) : Unlabeled
$d = $this->dimensions();

$samples = Matrix::gaussian($n, $d)
->multiply($this->stddev)
->multiply($this->stdDev)
->add($this->center)
->asArray();

Expand Down
6 changes: 6 additions & 0 deletions src/Transformers/GaussianRandomProjector.php
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,16 @@ class GaussianRandomProjector implements Transformer, Stateful, Stringable
*
* @param int $n
* @param float $maxDistortion
* @throws \InvalidArgumentException
* @return int
*/
public static function minDimensions(int $n, float $maxDistortion = 0.1) : int
{
if ($n < 0) {
throw new InvalidArgumentException('Number of samples'
. " must be be greater than 0, $n given.");
}

$denominator = $maxDistortion ** 2 / 2.0 - $maxDistortion ** 3 / 3.0;

return (int) round(4.0 * log($n) / $denominator);
Expand Down
36 changes: 34 additions & 2 deletions tests/Transformers/GaussianRandomProjectorTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
use Rubix\ML\Transformers\GaussianRandomProjector;
use PHPUnit\Framework\TestCase;
use RuntimeException;
use Generator;

/**
* @group Transformers
Expand Down Expand Up @@ -47,10 +48,41 @@ public function build() : void

/**
* @test
* @dataProvider minDimensionsProvider
*
* @param int $n
* @param float $maxDistortion
* @param int $expected
*/
public function minDimensions() : void
public function minDimensions(int $n, float $maxDistortion, int $expected) : void
{
$this->assertEquals(663, GaussianRandomProjector::minDimensions(1000000, 0.5));
$this->assertEquals($expected, GaussianRandomProjector::minDimensions($n, $maxDistortion));
}

/**
* @return \Generator<array>
*/
public function minDimensionsProvider() : Generator
{
yield [10, 0.1, 1974];

yield [100, 0.1, 3947];

yield [1000, 0.1, 5921];

yield [10000, 0.1, 7895];

yield [100000, 0.1, 9868];

yield [1000000, 0.1, 11842];

yield [10000, 0.01, 741772];

yield [10000, 0.3, 1023];

yield [10000, 0.5, 442];

yield [10000, 0.99, 221];
}

/**
Expand Down
12 changes: 12 additions & 0 deletions tests/Transformers/RobustStandardizerTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,18 @@ public function fitUpdateTransform() : void

$this->assertTrue($this->transformer->fitted());

$medians = $this->transformer->medians();

$this->assertIsArray($medians);
$this->assertCount(3, $medians);
$this->assertContainsOnly('float', $medians);

$mads = $this->transformer->mads();

$this->assertIsArray($mads);
$this->assertCount(3, $mads);
$this->assertContainsOnly('float', $mads);

$sample = $this->generator->generate(1)
->apply($this->transformer)
->sample(0);
Expand Down
12 changes: 9 additions & 3 deletions tests/Transformers/TfIdfTransformerTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,18 @@ public function fitTransform() : void

$this->assertTrue($this->transformer->fitted());

$dfs = $this->transformer->dfs();

$this->assertIsArray($dfs);
$this->assertCount(19, $dfs);
$this->assertContainsOnly('int', $dfs);

$this->dataset->apply($this->transformer);

$outcome = [
[1.6931471805599454, 3.8630462173553424, 0., 0., 1.2876820724517808, 0., 0., 0., 1.2876820724517808, 2.5753641449035616, 0., 2.5753641449035616, 0., 0., 0., 6.772588722239782, 1.2876820724517808, 0., 1.6931471805599454],
[0., 1.2876820724517808, 1.6931471805599454, 0., 0., 2.5753641449035616, 1.6931471805599454, 0., 0., 0., 0., 3.8630462173553424, 0., 1.6931471805599454, 0., 0., 0., 0., 0.],
[0., 0., 0., 1.6931471805599454, 2.5753641449035616, 3.8630462173553424, 0., 0., 5.150728289807123, 2.5753641449035616, 0., 0., 1.6931471805599454, 0., 3.386294361119891, 0., 1.2876820724517808, 0., 0.],
[1.6931471805599454, 3.8630462173553424, 0.0, 0.0, 1.2876820724517808, 0.0, 0.0, 0.0, 1.2876820724517808, 2.5753641449035616, 0.0, 2.5753641449035616, 0.0, 0.0, 0.0, 6.772588722239782, 1.2876820724517808, 0.0, 1.6931471805599454],
[0.0, 1.2876820724517808, 1.6931471805599454, 0.0, 0.0, 2.5753641449035616, 1.6931471805599454, 0.0, 0.0, 0.0, 0.0, 3.8630462173553424, 0.0, 1.6931471805599454, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 1.6931471805599454, 2.5753641449035616, 3.8630462173553424, 0.0, 0.0, 5.150728289807123, 2.5753641449035616, 0.0, 0.0, 1.6931471805599454, 0.0, 3.386294361119891, 0.0, 1.2876820724517808, 0.0, 0.0],
];

$this->assertEquals($outcome, $this->dataset->samples());
Expand Down
6 changes: 6 additions & 0 deletions tests/Transformers/WordCountVectorizerTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,12 @@ public function fitTransform() : void

$this->assertTrue($this->transformer->fitted());

$vocabulary = current($this->transformer->vocabularies());

$this->assertIsArray($vocabulary);
$this->assertCount(20, $vocabulary);
$this->assertContainsOnly('string', $vocabulary);

$this->dataset->apply($this->transformer);

$outcome = [
Expand Down
12 changes: 12 additions & 0 deletions tests/Transformers/ZScaleStandardizerTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,18 @@ public function fitUpdateTransform() : void

$this->assertTrue($this->transformer->fitted());

$means = $this->transformer->means();

$this->assertIsArray($means);
$this->assertCount(3, $means);
$this->assertContainsOnly('float', $means);

$variances = $this->transformer->variances();

$this->assertIsArray($variances);
$this->assertCount(3, $variances);
$this->assertContainsOnly('float', $variances);

$sample = $this->generator->generate(1)
->apply($this->transformer)
->sample(0);
Expand Down

0 comments on commit 52de887

Please sign in to comment.