Skip to content

Commit

Permalink
K Means and Mean Shift now implement Probabilistic
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewdalpino committed Apr 18, 2019
1 parent dbca7e5 commit 509024a
Show file tree
Hide file tree
Showing 8 changed files with 111 additions and 17 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
- Added Random, K-MC2, and Plus Plus seeders
- Accelerated Mean Shift with Ball Tree
- Added radius estimation to Mean Shift
- K Means and Mean Shift now implement Probabilistic
- Gaussian Mixture now supports seeders
- Changed order of K Means hyperparameters
- Moved Ranking interface to anomaly detector namespace
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2159,7 +2159,7 @@ $estimator = new GaussianMixture(5, 1e-4, 100, new KMC2(50));
### K Means
A fast online centroid-based hard clustering algorithm capable of clustering linearly separable data points given some prior knowledge of the target number of clusters (defined by *k*). K Means with inertia is trained using adaptive mini batch gradient descent and minimizes the inertial cost function. Inertia is defined as the sum of the distances between each sample and its nearest cluster centroid.

**Interfaces:** [Estimator](#estimators), [Learner](#learner), [Online](#online), [Persistable](#persistable), [Verbose](#verbose)
**Interfaces:** [Estimator](#estimators), [Learner](#learner), [Online](#online), [Probabilistic](#probabilistic), [Persistable](#persistable), [Verbose](#verbose)

**Compatibility:** Continuous

Expand Down Expand Up @@ -2210,7 +2210,7 @@ A hierarchical clustering algorithm that uses peak finding to locate the local m

> **Note**: Seeding Mean Shift using a [Seeder](#seeders) can often speed up convergence of large datasets. The default is to initialize all training samples as seeds.
**Interfaces:** [Estimator](#estimators), [Learner](#learner), [Verbose](#verbose), [Persistable](#persistable)
**Interfaces:** [Estimator](#estimators), [Learner](#learner), [Probabilistic](#probabilistic), [Verbose](#verbose), [Persistable](#persistable)

**Compatibility:** Continuous

Expand Down
48 changes: 46 additions & 2 deletions src/Clusterers/KMeans.php
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@

namespace Rubix\ML\Clusterers;

use Rubix\ML\Online;
use Rubix\ML\Learner;
use Rubix\ML\Verbose;
use Rubix\ML\Estimator;
use Rubix\Tensor\Matrix;
use Rubix\ML\Persistable;
use Rubix\ML\Probabilistic;
use Rubix\ML\Datasets\Labeled;
use Rubix\ML\Datasets\Dataset;
use Rubix\ML\Other\Helpers\Params;
Expand Down Expand Up @@ -37,7 +39,7 @@
* @package Rubix/ML
* @author Andrew DalPino
*/
class KMeans implements Estimator, Learner, Persistable, Verbose
class KMeans implements Estimator, Learner, Online, Probabilistic, Persistable, Verbose
{
use LoggerAware;

Expand Down Expand Up @@ -328,7 +330,7 @@ public function partial(Dataset $dataset) : void
}
}

$inertia = $this->inertia($samples, $labels);
$inertia = $this->inertia($samples);

$this->steps[] = $inertia;

Expand Down Expand Up @@ -371,6 +373,25 @@ public function predict(Dataset $dataset) : array
return array_map([self::class, 'assign'], $dataset->samples());
}

/**
 * Estimate the cluster membership probabilities of every sample in a dataset.
 *
 * Each sample is passed through membership(), so every row of the result
 * holds one value per centroid. The keys of the returned array are the keys
 * of the dataset's samples.
 *
 * @param \Rubix\ML\Datasets\Dataset $dataset
 * @throws \InvalidArgumentException
 * @throws \RuntimeException if no centroids are available yet
 * @return array
 */
public function proba(Dataset $dataset) : array
{
    if (empty($this->centroids)) {
        throw new RuntimeException('Estimator has not been trained.');
    }

    // NOTE(review): presumably the source of the documented
    // InvalidArgumentException - confirm against the validator.
    DatasetIsCompatibleWithEstimator::check($dataset, $this);

    // Bind the callback to this instance rather than to the class name:
    // membership() is an instance method, and an object-bound callable
    // is explicit about that and honours overrides in subclasses.
    return array_map([$this, 'membership'], $dataset->samples());
}

/**
* Label a given sample based on its distance from a particular centroid.
*
Expand All @@ -394,6 +415,29 @@ protected function assign(array $sample) : int
return (int) $bestCluster;
}

/**
 * Return the membership of a sample to each of the k centroids.
 *
 * Membership is inversely proportional to the distance between the sample
 * and a centroid, so the nearest centroid receives the largest share. The
 * values are normalized by their sum and are returned in the iteration
 * order of the centroids.
 *
 * @param array $sample
 * @return array
 */
protected function membership(array $sample) : array
{
    $weights = [];

    foreach ($this->centroids as $centroid) {
        $distance = $this->kernel->compute($sample, $centroid);

        // Floor the distance at EPSILON so that a sample lying exactly on
        // a centroid does not cause a division by zero.
        $weights[] = 1.0 / max($distance, self::EPSILON);
    }

    // Keep the denominator non-zero even if every weight collapses to 0.
    $total = array_sum($weights) ?: self::EPSILON;

    $membership = [];

    foreach ($weights as $weight) {
        $membership[] = $weight / $total;
    }

    return $membership;
}

/**
* Calculate the sum of distances between all samples and their closest
* centroid.
Expand Down
49 changes: 46 additions & 3 deletions src/Clusterers/MeanShift.php
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
use Rubix\ML\Estimator;
use Rubix\Tensor\Matrix;
use Rubix\ML\Persistable;
use Rubix\ML\Probabilistic;
use Rubix\ML\Graph\BallTree;
use Rubix\ML\Datasets\Dataset;
use Rubix\ML\Other\Helpers\Stats;
Expand Down Expand Up @@ -36,7 +37,7 @@
* @package Rubix/ML
* @author Andrew DalPino
*/
class MeanShift implements Estimator, Learner, Verbose, Persistable
class MeanShift implements Estimator, Learner, Probabilistic, Verbose, Persistable
{
use LoggerAware;

Expand Down Expand Up @@ -328,7 +329,7 @@ public function train(Dataset $dataset) : void
}
}

$shift = $this->centroidShift($centroids, $previous);
$shift = $this->shift($centroids, $previous);

$this->steps[] = $shift;

Expand Down Expand Up @@ -373,6 +374,25 @@ public function predict(Dataset $dataset) : array
return array_map([self::class, 'assign'], $dataset->samples());
}

/**
 * Estimate the cluster membership probabilities of every sample in a dataset.
 *
 * Each sample is passed through membership(), so every row of the result
 * holds one value per centroid. The keys of the returned array are the keys
 * of the dataset's samples.
 *
 * @param \Rubix\ML\Datasets\Dataset $dataset
 * @throws \InvalidArgumentException
 * @throws \RuntimeException if no centroids are available yet
 * @return array
 */
public function proba(Dataset $dataset) : array
{
    if (empty($this->centroids)) {
        throw new RuntimeException('Estimator has not been trained.');
    }

    // NOTE(review): presumably the source of the documented
    // InvalidArgumentException - confirm against the validator.
    DatasetIsCompatibleWithEstimator::check($dataset, $this);

    // Bind the callback to this instance rather than to the class name:
    // membership() is an instance method, and an object-bound callable
    // is explicit about that and honours overrides in subclasses.
    return array_map([$this, 'membership'], $dataset->samples());
}

/**
* Label a given sample based on its distance from a particular centroid.
*
Expand All @@ -396,14 +416,37 @@ protected function assign(array $sample) : int
return (int) $bestCluster;
}

/**
 * Return the membership of a sample to each of the centroids.
 *
 * Membership is inversely proportional to the distance between the sample
 * and a centroid, so the nearest centroid receives the largest share. The
 * values are normalized by their sum and are returned in the iteration
 * order of the centroids.
 *
 * @param array $sample
 * @return array
 */
protected function membership(array $sample) : array
{
    $weights = [];

    foreach ($this->centroids as $centroid) {
        $distance = $this->kernel->compute($sample, $centroid);

        // Floor the distance at EPSILON so that a sample lying exactly on
        // a centroid does not cause a division by zero.
        $weights[] = 1.0 / max($distance, self::EPSILON);
    }

    // Keep the denominator non-zero even if every weight collapses to 0.
    $total = array_sum($weights) ?: self::EPSILON;

    $membership = [];

    foreach ($weights as $weight) {
        $membership[] = $weight / $total;
    }

    return $membership;
}

/**
* Calculate the magnitude (l1) of centroid shift from the previous epoch.
*
* @param array $current
* @param array $previous
* @return float
*/
protected function centroidShift(array $current, array $previous) : float
protected function shift(array $current, array $previous) : float
{
$shift = 0.;

Expand Down
18 changes: 9 additions & 9 deletions src/Regressors/GradientBoost.php
Original file line number Diff line number Diff line change
Expand Up @@ -266,11 +266,10 @@ public function train(Dataset $dataset) : void
$yHat[] = $dataset->label($i) - $prediction;
}

$residuals = Labeled::quick($dataset->samples(), $yHat);
$residual = Labeled::quick($dataset->samples(), $yHat);

if ($this->logger) {
$this->logger->info('Attempting to correct'
. " error residuals with $this->estimators "
$this->logger->info("Boosting with $this->estimators "
. Params::shortName($this->booster)
. ($this->estimators > 1 ? 's' : ''));
}
Expand All @@ -282,20 +281,21 @@ public function train(Dataset $dataset) : void
for ($epoch = 1; $epoch <= $this->estimators; $epoch++) {
$booster = clone $this->booster;

$subset = $residuals->randomize()->head($p);
$subset = $residual->randomize()->head($p);

$booster->train($subset);

$predictions = $booster->predict($residuals);
$predictions = $booster->predict($residual);

$labels = $residual->labels();

$loss = 0.;
$yHat = [];

foreach ($predictions as $i => $prediction) {
$label = $residuals->label($i);
$label = $labels[$i];

$loss += ($label - $prediction) ** 2;
$yHat[] = $label - ($this->rate * $prediction);
$yHat[$i] = $label - ($this->rate * $prediction);
}

$loss /= $n;
Expand All @@ -319,7 +319,7 @@ public function train(Dataset $dataset) : void
break 1;
}

$residuals = Labeled::quick($residuals->samples(), $yHat);
$residual = Labeled::quick($residual->samples(), $yHat);

$previous = $loss;
}
Expand Down
4 changes: 4 additions & 0 deletions tests/Clusterers/KMeansTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

namespace Rubix\ML\Tests\Clusterers;

use Rubix\ML\Online;
use Rubix\ML\Learner;
use Rubix\ML\Verbose;
use Rubix\ML\Estimator;
use Rubix\ML\Persistable;
use Rubix\ML\Probabilistic;
use Rubix\ML\Clusterers\KMeans;
use Rubix\ML\Datasets\Unlabeled;
use Rubix\ML\Other\Helpers\DataType;
Expand Down Expand Up @@ -54,6 +56,8 @@ public function test_build_clusterer()
{
$this->assertInstanceOf(KMeans::class, $this->estimator);
$this->assertInstanceOf(Learner::class, $this->estimator);
$this->assertInstanceOf(Online::class, $this->estimator);
$this->assertInstanceOf(Probabilistic::class, $this->estimator);
$this->assertInstanceOf(Persistable::class, $this->estimator);
$this->assertInstanceOf(Verbose::class, $this->estimator);
$this->assertInstanceOf(Estimator::class, $this->estimator);
Expand Down
2 changes: 2 additions & 0 deletions tests/Clusterers/MeanShiftTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
use Rubix\ML\Verbose;
use Rubix\ML\Estimator;
use Rubix\ML\Persistable;
use Rubix\ML\Probabilistic;
use Rubix\ML\Datasets\Unlabeled;
use Rubix\ML\Clusterers\MeanShift;
use Rubix\ML\Other\Helpers\DataType;
Expand Down Expand Up @@ -54,6 +55,7 @@ public function test_build_clusterer()
{
$this->assertInstanceOf(MeanShift::class, $this->estimator);
$this->assertInstanceOf(Learner::class, $this->estimator);
$this->assertInstanceOf(Probabilistic::class, $this->estimator);
$this->assertInstanceOf(Verbose::class, $this->estimator);
$this->assertInstanceOf(Persistable::class, $this->estimator);
$this->assertInstanceOf(Estimator::class, $this->estimator);
Expand Down
2 changes: 1 addition & 1 deletion tests/Regressors/GradientBoostTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class GradientBoostTest extends TestCase
{
protected const TRAIN_SIZE = 400;
protected const TEST_SIZE = 10;
protected const MIN_SCORE = 0.7;
protected const MIN_SCORE = 0.9;

protected const RANDOM_SEED = 0;

Expand Down

0 comments on commit 509024a

Please sign in to comment.