Skip to content

Commit

Permalink
Access to network instance for neural net estimators
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewdalpino committed Jun 28, 2018
1 parent 63314c4 commit 7319f67
Show file tree
Hide file tree
Showing 50 changed files with 693 additions and 27 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -819,6 +819,7 @@ A type of classifier that uses the logistic (sigmoid) function to distinguish be
| Method | Description |
|--|--|
| `progress() : array` | Returns an array with the training progress at each epoch of training. |
| `network() : Network` | Returns the underlying neural network instance or *null* if not trained. |

##### Example:
```php
Expand Down Expand Up @@ -849,6 +850,7 @@ Multiclass [Neural Network](#neural-network) model that uses a series of user-de
| Method | Description |
|--|--|
| `progress() : array` | Returns an array with the training progress at each epoch of training. |
| `network() : Network` | Returns the underlying neural network instance or *null* if not trained. |

##### Example:
```php
Expand Down Expand Up @@ -930,6 +932,7 @@ A generalization of logistic regression for multiple class outcomes using a sing
| Method | Description |
|--|--|
| `progress() : array` | Returns an array with the training progress at each epoch of training. |
| `network() : Network` | Returns the underlying neural network instance or *null* if not trained. |

##### Example:
```php
Expand Down Expand Up @@ -1118,6 +1121,7 @@ A [Neural Network](#neural-network) with a continuous output layer suitable for
| Method | Description |
|--|--|
| `progress() : array` | Returns an array with the validation score at each epoch of training. |
| `network() : Network` | Returns the underlying neural network instance or *null* if not trained. |

##### Example:
```php
Expand Down
14 changes: 13 additions & 1 deletion src/Classifiers/LogisticRegression.php
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class LogisticRegression implements Binary, Online, Probabilistic, Persistable
/**
* The underlying computational graph.
*
* @var \Rubix\ML\NeuralNet\Network
* @var \Rubix\ML\NeuralNet\Network|null
*/
protected $network;

Expand Down Expand Up @@ -122,13 +122,25 @@ public function __construct(int $batchSize = 10, Optimizer $optimizer = null,
}

/**
* Return the training progress of the estimator.
*
* @return array
*/
public function progress() : array
{
return $this->progress;
}

/**
* Return the underlying neural network instance or null if not trained.
*
* @return \Rubix\ML\NeuralNet\Network|null
*/
public function network() : ?Network
{
return $this->network;
}

/**
* @param \Rubix\ML\Datasets\Dataset $dataset
* @throws \InvalidArgumentException
Expand Down
14 changes: 13 additions & 1 deletion src/Classifiers/MultiLayerPerceptron.php
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ class MultiLayerPerceptron implements Multiclass, Online, Probabilistic, Persist
/**
* The underlying computational graph.
*
* @var \Rubix\ML\NeuralNet\Network
* @var \Rubix\ML\NeuralNet\Network|null
*/
protected $network;

Expand Down Expand Up @@ -180,13 +180,25 @@ public function __construct(array $hidden = [], int $batchSize = 50, Optimizer $
}

/**
* Return the training progress of the estimator.
*
* @return array
*/
public function progress() : array
{
return $this->progress;
}

/**
* Return the underlying neural network instance or null if not trained.
*
* @return \Rubix\ML\NeuralNet\Network|null
*/
public function network() : ?Network
{
return $this->network;
}

/**
* @param \Rubix\ML\Datasets\Dataset $dataset
* @throws \InvalidArgumentException
Expand Down
14 changes: 13 additions & 1 deletion src/Classifiers/SoftmaxClassifier.php
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class SoftmaxClassifier implements Multiclass, Online, Probabilistic, Persistabl
/**
* The underlying computational graph.
*
* @var \Rubix\ML\NeuralNet\Network
* @var \Rubix\ML\NeuralNet\Network|null
*/
protected $network;

Expand Down Expand Up @@ -122,13 +122,25 @@ public function __construct(int $batchSize = 10, Optimizer $optimizer = null,
}

/**
* Return the training progress of the estimator.
*
* @return array
*/
public function progress() : array
{
return $this->progress;
}

/**
* Return the underlying neural network instance or null if not trained.
*
* @return \Rubix\ML\NeuralNet\Network|null
*/
public function network() : ?Network
{
return $this->network;
}

/**
* @param \Rubix\ML\Datasets\Dataset $dataset
* @throws \InvalidArgumentException
Expand Down
4 changes: 2 additions & 2 deletions src/NeuralNet/Network.php
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,13 @@ public function parametric() : array
}

/**
* The depth of the network. i.e. the number of hidden layers.
* The depth of the network. i.e. the number of parametric layers.
*
* @return int
*/
public function depth() : int
{
return count($this->hidden());
return count($this->parametric());
}

/**
Expand Down
7 changes: 7 additions & 0 deletions src/Regressors/DummyRegressor.php
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
use Rubix\ML\Datasets\Labeled;
use Rubix\ML\Transformers\Strategies\Continuous;
use Rubix\ML\Transformers\Strategies\BlurryMean;
use InvalidArgumentException;

class DummyRegressor implements Regressor, Persistable
{
Expand Down Expand Up @@ -34,10 +35,16 @@ public function __construct(Continuous $strategy = null)
* Fit the training set to the given guessing strategy.
*
* @param \Rubix\ML\Datasets\Labeled $dataset
* @throws \InvalidArgumentException
* @return void
*/
public function train(Dataset $dataset) : void
{
if (!$dataset instanceof Labeled) {
throw new InvalidArgumentException('This Estimator requires a'
. ' Labeled training set.');
}

$this->strategy->fit($dataset->labels());
}

Expand Down
14 changes: 13 additions & 1 deletion src/Regressors/MLPRegressor.php
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ class MLPRegressor implements Regressor, Online, Persistable
/**
* The underlying computational graph.
*
* @var \Rubix\ML\NeuralNet\Network
* @var \Rubix\ML\NeuralNet\Network|null
*/
protected $network;

Expand Down Expand Up @@ -168,13 +168,25 @@ public function __construct(array $hidden, int $batchSize = 50, Optimizer $optim
}

/**
* Return the training progress of the estimator.
*
* @return array
*/
public function progress() : array
{
return $this->progress;
}

/**
* Return the underlying neural network instance or null if not trained.
*
* @return \Rubix\ML\NeuralNet\Network|null
*/
public function network() : ?Network
{
return $this->network;
}

/**
* @param \Rubix\ML\Datasets\Dataset $dataset
* @throws \InvalidArgumentException
Expand Down
11 changes: 11 additions & 0 deletions tests/Classifiers/AdaBoostTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
use Rubix\ML\Estimator;
use Rubix\ML\Persistable;
use Rubix\ML\Datasets\Labeled;
use Rubix\ML\Datasets\Unlabeled;
use Rubix\ML\Classifiers\Binary;
use Rubix\ML\Classifiers\AdaBoost;
use Rubix\ML\Classifiers\Classifier;
use Rubix\ML\Classifiers\DecisionTree;
use PHPUnit\Framework\TestCase;
use InvalidArgumentException;

class AdaBoostTest extends TestCase
{
Expand Down Expand Up @@ -81,4 +83,13 @@ public function test_make_prediction()
$this->assertEquals('male', $predictions[0]);
$this->assertEquals('female', $predictions[1]);
}

public function test_train_with_unlabeled()
{
$dataset = new Unlabeled([['bad']]);

$this->expectException(InvalidArgumentException::class);

$this->estimator->train($dataset);
}
}
11 changes: 11 additions & 0 deletions tests/Classifiers/CommitteeMachineTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
use Rubix\ML\Estimator;
use Rubix\ML\Probabilistic;
use Rubix\ML\Datasets\Labeled;
use Rubix\ML\Datasets\Unlabeled;
use Rubix\ML\Classifiers\Multiclass;
use Rubix\ML\Classifiers\Classifier;
use Rubix\ML\Kernels\Distance\Euclidean;
use Rubix\ML\Classifiers\CommitteeMachine;
use Rubix\ML\Classifiers\KNearestNeighbors;
use PHPUnit\Framework\TestCase;
use InvalidArgumentException;

class CommitteeMachineTest extends TestCase
{
Expand Down Expand Up @@ -100,4 +102,13 @@ public function test_predict_proba()
$this->assertLessThan(0.5, $probabilities[1]['male']);
$this->assertGreaterThan(0.5, $probabilities[1]['female']);
}

public function test_train_with_unlabeled()
{
$dataset = new Unlabeled([['bad']]);

$this->expectException(InvalidArgumentException::class);

$this->estimator->train($dataset);
}
}
14 changes: 12 additions & 2 deletions tests/Classifiers/DecisionTreeTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@

use Rubix\ML\Estimator;
use Rubix\ML\Persistable;
use Rubix\ML\Probabilistic;
use Rubix\ML\Datasets\Labeled;
use Rubix\ML\Datasets\Unlabeled;
use Rubix\ML\Classifiers\Multiclass;
use Rubix\ML\Classifiers\Classifier;
use Rubix\ML\Classifiers\DecisionTree;
use Rubix\ML\Probabilistic;

use PHPUnit\Framework\TestCase;
use InvalidArgumentException;

class DecisionTreeTest extends TestCase
{
Expand Down Expand Up @@ -97,4 +98,13 @@ public function test_predict_proba()
$this->assertLessThan(0.5, $probabilities[1]['male']);
$this->assertGreaterThan(0.5, $probabilities[1]['female']);
}

public function test_train_with_unlabeled()
{
$dataset = new Unlabeled([['bad']]);

$this->expectException(InvalidArgumentException::class);

$this->estimator->train($dataset);
}
}
11 changes: 11 additions & 0 deletions tests/Classifiers/DummyClassifierTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
use Rubix\ML\Estimator;
use Rubix\ML\Persistable;
use Rubix\ML\Datasets\Labeled;
use Rubix\ML\Datasets\Unlabeled;
use Rubix\ML\Classifiers\Classifier;
use Rubix\ML\Classifiers\DummyClassifier;
use Rubix\ML\Transformers\Strategies\PopularityContest;
use PHPUnit\Framework\TestCase;
use InvalidArgumentException;

class DummyClassifierTest extends TestCase
{
Expand Down Expand Up @@ -77,4 +79,13 @@ public function test_make_prediction()
$this->assertContains($predictions[0], ['male', 'female']);
$this->assertContains($predictions[1], ['male', 'female']);
}

public function test_train_with_unlabeled()
{
$dataset = new Unlabeled([['bad']]);

$this->expectException(InvalidArgumentException::class);

$this->estimator->train($dataset);
}
}
11 changes: 11 additions & 0 deletions tests/Classifiers/KNearestNeighborsTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@
use Rubix\ML\Persistable;
use Rubix\ML\Probabilistic;
use Rubix\ML\Datasets\Labeled;
use Rubix\ML\Datasets\Unlabeled;
use Rubix\ML\Classifiers\Multiclass;
use Rubix\ML\Classifiers\Classifier;
use Rubix\ML\Kernels\Distance\Euclidean;
use Rubix\ML\Classifiers\KNearestNeighbors;
use PHPUnit\Framework\TestCase;
use InvalidArgumentException;

class KNearestNeighborsTest extends TestCase
{
Expand Down Expand Up @@ -99,4 +101,13 @@ public function test_predict_proba()
$this->assertLessThan(0.5, $probabilities[1]['male']);
$this->assertGreaterThan(0.5, $probabilities[1]['female']);
}

public function test_train_with_unlabeled()
{
$dataset = new Unlabeled([['bad']]);

$this->expectException(InvalidArgumentException::class);

$this->estimator->train($dataset);
}
}
11 changes: 11 additions & 0 deletions tests/Classifiers/LogisticRegressionTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@
use Rubix\ML\Persistable;
use Rubix\ML\Probabilistic;
use Rubix\ML\Datasets\Labeled;
use Rubix\ML\Datasets\Unlabeled;
use Rubix\ML\Classifiers\Binary;
use Rubix\ML\Classifiers\Classifier;
use Rubix\ML\NeuralNet\Optimizers\Adam;
use Rubix\ML\Classifiers\LogisticRegression;
use PHPUnit\Framework\TestCase;
use InvalidArgumentException;

class LogisticRegressionTest extends TestCase
{
Expand Down Expand Up @@ -99,4 +101,13 @@ public function test_predict_proba()
$this->assertLessThan(0.5, $probabilities[1]['male']);
$this->assertGreaterThan(0.5, $probabilities[1]['female']);
}

public function test_train_with_unlabeled()
{
$dataset = new Unlabeled([['bad']]);

$this->expectException(InvalidArgumentException::class);

$this->estimator->train($dataset);
}
}
Loading

0 comments on commit 7319f67

Please sign in to comment.