Skip to content

Commit

Permalink
Fix Naive Bayes divide by zero when smoothing is 0
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewdalpino committed Sep 7, 2020
1 parent c90fdec commit a6956c8
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 14 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
- 0.1.5
- Compensate for zero vectors in Cosine kernel
- Fixed KMC2 random threshold calculation
- Fix Naive Bayes divide by zero when smoothing is 0

- 0.1.4
- Optimized Cosine distance for sparse vectors
Expand Down
2 changes: 1 addition & 1 deletion docs/classifiers/naive-bayes.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ Categorical Naive Bayes is a probability-based classifier that uses counting and
## Parameters
| # | Param | Default | Type | Description |
|---|---|---|---|---|
| 1 | alpha | 1.0 | float | The amount of (Laplace) smoothing added to the probabilities. |
| 1 | smoothing | 1.0 | float | The amount of (Laplace) smoothing added to the probabilities. |
| 2 | priors | null | array | The class prior probabilities as an associative array with class labels as keys and the prior probabilities as values. If null, then the learner will compute these values from the training data. |

## Example
Expand Down
24 changes: 12 additions & 12 deletions src/Classifiers/NaiveBayes.php
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class NaiveBayes implements Estimator, Learner, Online, Probabilistic, Persistab
*
* @var float
*/
protected $alpha;
protected $smoothing;

/**
* The class prior log probabilities.
Expand Down Expand Up @@ -97,15 +97,15 @@ class NaiveBayes implements Estimator, Learner, Online, Probabilistic, Persistab
];

/**
* @param float $alpha
* @param float $smoothing
* @param (int|float)[]|null $priors
* @throws \InvalidArgumentException
*/
public function __construct(float $alpha = 1.0, ?array $priors = null)
public function __construct(float $smoothing = 1.0, ?array $priors = null)
{
if ($alpha < 0.0) {
throw new InvalidArgumentException('Alpha must be'
. " greater than 0, $alpha given.");
if ($smoothing <= 0.0) {
throw new InvalidArgumentException('Smoothing must be'
. " greater than 0, $smoothing given.");
}

$logPriors = [];
Expand All @@ -128,7 +128,7 @@ public function __construct(float $alpha = 1.0, ?array $priors = null)
}
}

$this->alpha = $alpha;
$this->smoothing = $smoothing;
$this->logPriors = $logPriors;
$this->fitPriors = is_null($priors);
}
Expand Down Expand Up @@ -163,7 +163,7 @@ public function compatibility() : array
public function params() : array
{
return [
'alpha' => $this->alpha,
'smoothing' => $this->smoothing,
'priors' => $this->fitPriors ? null : $this->priors(),
];
}
Expand Down Expand Up @@ -250,12 +250,12 @@ public function partial(Dataset $dataset) : void
}
}

$total = array_sum($columnCounts) + (count($columnCounts) * $this->alpha);
$total = array_sum($columnCounts) + (count($columnCounts) * $this->smoothing);

$probs = [];

foreach ($columnCounts as $category => $count) {
$probs[$category] = log(($count + $this->alpha) / $total);
$probs[$category] = log(($count + $this->smoothing) / $total);
}

$classCounts[$column] = $columnCounts;
Expand All @@ -269,12 +269,12 @@ public function partial(Dataset $dataset) : void
}

if ($this->fitPriors) {
$total = array_sum($this->weights) + (count($this->weights) * $this->alpha);
$total = array_sum($this->weights) + (count($this->weights) * $this->smoothing);

$this->logPriors = [];

foreach ($this->weights as $class => $weight) {
$this->logPriors[$class] = log(($weight + $this->alpha) / $total);
$this->logPriors[$class] = log(($weight + $this->smoothing) / $total);
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion tests/Classifiers/NaiveBayesTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ public function compatibility() : void
public function params() : void
{
$expected = [
'alpha' => 1.0,
'smoothing' => 1.0,
'priors' => null,
];

Expand Down

0 comments on commit a6956c8

Please sign in to comment.