Skip to content

Commit

Permalink
Forego unnecessary logistic computation in Logit Boost
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewdalpino committed Feb 20, 2022
1 parent 359cc24 commit 14eed45
Show file tree
Hide file tree
Showing 12 changed files with 21 additions and 20 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
- 1.3.3
- Forego unnecessary logistic computation in Logit Boost

- 1.3.2
- Optimize Binary output layer

Expand Down
7 changes: 3 additions & 4 deletions src/Classifiers/ExtraTreeClassifier.php
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,7 @@ public function proba(Dataset $dataset) : array
}

/**
* Terminate the branch by selecting the class outcome with the highest
* probability.
* Terminate the branch by selecting the class outcome with the highest probability.
*
* @param \Rubix\ML\Datasets\Labeled $dataset
* @return \Rubix\ML\Graph\Nodes\Best
Expand All @@ -235,9 +234,9 @@ protected function terminate(Labeled $dataset) : Best

$p = $counts[$outcome] / $n;

$impurity = -($p * log($p));
$entropy = -($p * log($p));

return new Best($outcome, $probabilities, $impurity, $n);
return new Best($outcome, $probabilities, $entropy, $n);
}

/**
Expand Down
2 changes: 1 addition & 1 deletion src/Classifiers/LogisticRegression.php
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class LogisticRegression implements Estimator, Learner, Online, Probabilistic, R
/**
* The number of training samples to process at a time.
*
* @var int
* @var positive-int
*/
protected int $batchSize;

Expand Down
7 changes: 2 additions & 5 deletions src/Classifiers/LogitBoost.php
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,6 @@ public function train(Dataset $dataset) : void
$zHat = $booster->predict($training);

$z = array_map([$this, 'updateZ'], $zHat, $z);

$out = array_map('Rubix\ML\sigmoid', $z);

$this->losses[$epoch] = $loss;
Expand All @@ -439,12 +438,10 @@ public function train(Dataset $dataset) : void

$zTest = array_map([$this, 'updateZ'], $zHat, $zTest);

$outTest = array_map('Rubix\ML\sigmoid', $zTest);

$predictions = [];

foreach ($outTest as $probability) {
$predictions[] = $probability < 0.5 ? $classes[0] : $classes[1];
foreach ($zTest as $value) {
$predictions[] = $value < 0.0 ? $classes[0] : $classes[1];
}

$score = $this->metric->score($predictions, $testing->labels());
Expand Down
2 changes: 1 addition & 1 deletion src/Classifiers/MultilayerPerceptron.php
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class MultilayerPerceptron implements Estimator, Learner, Online, Probabilistic,
/**
* The number of training samples to process at a time.
*
* @var int
* @var positive-int
*/
protected int $batchSize;

Expand Down
2 changes: 1 addition & 1 deletion src/Classifiers/SoftmaxClassifier.php
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class SoftmaxClassifier implements Estimator, Learner, Online, Probabilistic, Ve
/**
* The number of training samples to process at a time.
*
* @var int
* @var positive-int
*/
protected int $batchSize;

Expand Down
2 changes: 1 addition & 1 deletion src/Clusterers/KMeans.php
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class KMeans implements Estimator, Learner, Online, Probabilistic, Verbose, Pers
/**
* The size of each mini batch in samples.
*
* @var int
* @var positive-int
*/
protected int $batchSize;

Expand Down
5 changes: 3 additions & 2 deletions src/Datasets/Labeled.php
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,7 @@ public function stratifiedFold(int $k = 10) : array
* not enough samples to fill an entire batch, then the dataset will contain
* as many samples and labels as possible.
*
* @param int $n
* @param positive-int $n
* @return list<self>
*/
public function batch(int $n = 50) : array
Expand Down Expand Up @@ -742,7 +742,8 @@ public function randomWeightedSubsetWithReplacement(int $n, array $weights) : se
. ' but ' . count($weights) . ' given.');
}

$numLevels = (int) round(sqrt(count($weights)));
/** @var positive-int $numLevels */
$numLevels = (int) round(sqrt(count($weights))) ?: 1;

$levels = array_chunk($weights, $numLevels, true);

Expand Down
3 changes: 2 additions & 1 deletion src/Datasets/Unlabeled.php
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ public function fold(int $k = 3) : array
* not enough samples to fill an entire batch, then the dataset will contain
* as many samples as possible.
*
* @param int $n
* @param positive-int $n
* @return list<self>
*/
public function batch(int $n = 50) : array
Expand Down Expand Up @@ -453,6 +453,7 @@ public function randomWeightedSubsetWithReplacement(int $n, array $weights) : se
. ' but ' . count($weights) . ' given.');
}

/** @var positive-int $numLevels */
$numLevels = (int) round(sqrt(count($weights)));

$levels = array_chunk($weights, $numLevels, true);
Expand Down
2 changes: 1 addition & 1 deletion src/Regressors/Adaline.php
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class Adaline implements Estimator, Learner, Online, RanksFeatures, Verbose, Per
/**
* The number of training samples to process at a time.
*
* @var int
* @var positive-int
*/
protected int $batchSize;

Expand Down
2 changes: 1 addition & 1 deletion src/Regressors/MLPRegressor.php
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class MLPRegressor implements Estimator, Learner, Online, Verbose, Persistable
/**
* The number of training samples to process at a time.
*
* @var int
* @var positive-int
*/
protected int $batchSize;

Expand Down
4 changes: 2 additions & 2 deletions tests/Classifiers/LogitBoostTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ class LogitBoostTest extends TestCase
protected function setUp() : void
{
$this->generator = new Agglomerate([
'male' => new Blob([69.2, 195.7, 40.0], [2.0, 6.0, 0.6]),
'female' => new Blob([63.7, 168.5, 38.1], [1.6, 5.0, 0.8]),
'male' => new Blob([69.2, 195.7, 40.0], [2.0, 6.4, 0.6]),
'female' => new Blob([63.7, 168.5, 38.1], [1.8, 5.0, 0.8]),
]);

$this->estimator = new LogitBoost(new RegressionTree(3), 0.1, 0.5, 1000, 1e-4, 5, 0.1, new FBeta());
Expand Down

0 comments on commit 14eed45

Please sign in to comment.