Skip to content

Commit

Permalink
Added error statistics to residual analysis report
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewdalpino committed May 19, 2019
1 parent e7efd65 commit 11f81e2
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 40 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
- Fixed MSLE computation in Residual Analysis report
- Renamed RMSError Metric to RMSE
- Embedders no longer implement Estimator interface
- Added error statistics to Residual Analysis report

- 0.0.11-beta
- K Means now uses mini batch GD instead of SGD
Expand Down
53 changes: 29 additions & 24 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5883,28 +5883,29 @@ var_dump($result);
**Output:**

```sh
["label"]=> array(2) {
["wolf"]=> array(19) {
["accuracy"]=> float(0.6)
["precision"]=> float(0.66666666666667)
["recall"]=> float(0.66666666666667)
["specificity"]=> float(0.5)
["negative_predictive_value"]=> float(0.5)
["false_discovery_rate"]=> float(0.33333333333333)
["miss_rate"]=> float(0.33333333333333)
["fall_out"]=> float(0.5)
["false_omission_rate"]=> float(0.5)
["f1_score"]=> float(0.66666666666667)
["mcc"]=> float(0.16666666666667)
["informedness"]=> float(0.16666666666667)
["markedness"]=> float(0.16666666666667)
["true_positives"]=> int(2)
["true_negatives"]=> int(1)
["false_positives"]=> int(1)
["false_negatives"]=> int(1)
["cardinality"]=> int(3)
["density"]=> float(0.6)
['label']=> array(2) {
['wolf']=> array(19) {
['accuracy']=> float(0.6)
['precision']=> float(0.66666666666667)
['recall']=> float(0.66666666666667)
['specificity']=> float(0.5)
['negative_predictive_value']=> float(0.5)
['false_discovery_rate']=> float(0.33333333333333)
['miss_rate']=> float(0.33333333333333)
['fall_out']=> float(0.5)
['false_omission_rate']=> float(0.5)
['f1_score']=> float(0.66666666666667)
['mcc']=> float(0.16666666666667)
['informedness']=> float(0.16666666666667)
['markedness']=> float(0.16666666666667)
['true_positives']=> int(2)
['true_negatives']=> int(1)
['false_positives']=> int(1)
['false_negatives']=> int(1)
['cardinality']=> int(3)
['density']=> float(0.6)
}
...
```
### Residual Analysis
Expand Down Expand Up @@ -5933,16 +5934,20 @@ var_dump($result);
**Output:**
```sh
array(12) {
array(18) {
['mean_absolute_error']=> float(0.18220216502615122)
['mean_absolute_percentage_error']=> float(18.174348688407402)
['median_absolute_error']=> float(0.17700000000000005)
['mean_squared_error']=> float(0.05292430893457563)
['mean_squared_log_error']=> float(51.96853354084834)
['mean_absolute_percentage_error']=> float(18.174348688407402)
['rms_error']=> float(0.23005283944036775)
['mean_squared_log_error']=> float(51.96853354084834)
['r_squared']=> float(0.9999669635675313)
['error_mean']=> float(-0.07112216502615118)
['error_midrange']=> float(-0.12315541256537399)
['error_median']=> float(0.0007000000000000001)
['error_variance']=> float(0.04786594657656853)
['error_mad']=> float(0.17630000000000004)
['error_interquartile_range']=> float(0.455155412565378)
['error_skewness']=> float(-0.49093461098755187)
['error_kurtosis']=> float(-1.216490935575394)
['error_min']=> float(-0.423310825130748)
Expand Down
4 changes: 2 additions & 2 deletions src/CrossValidation/Metrics/SMAPE.php
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ public function score(array $predictions, array $labels) : float
foreach ($predictions as $i => $prediction) {
$label = $labels[$i];

$error += 100. * abs(($label - $prediction)
/ ((abs($prediction) + abs($label)) ?: EPSILON));
$error += 100. * abs(($prediction - $label)
/ ((abs($label) + abs($prediction)) ?: EPSILON));
}

return -($error / count($predictions));
Expand Down
24 changes: 16 additions & 8 deletions src/CrossValidation/Reports/ResidualAnalysis.php
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public function generate(array $predictions, array $labels) : array

$muHat = Stats::mean($labels);

$errors = $l1 = $l2 = $ape = $spe = $log = [];
$errors = $l1 = $l2 = $are = $sle = [];

$sse = $sst = 0.;

Expand All @@ -60,31 +60,39 @@ public function generate(array $predictions, array $labels) : array

$l1[] = abs($error);
$l2[] = $se = $error ** 2;
$ape[] = abs($error / ($prediction ?: EPSILON)) * 100.;
$log[] = log((1. + $label) / ((1. + $prediction) ?: EPSILON)) ** 2;
$are[] = abs($error / ($prediction ?: EPSILON));
$sle[] = log((1. + $label) / ((1. + $prediction) ?: EPSILON)) ** 2;

$sse += $se;
$sst += ($label - $muHat) ** 2;
}

$mse = Stats::mean($l2);

[$mean, $variance] = Stats::meanVar($errors);
[$median, $mad] = Stats::medianMad($errors);

$mse = Stats::mean($l2);
$min = min($errors);
$max = max($errors);

return [
'mean_absolute_error' => Stats::mean($l1),
'median_absolute_error' => Stats::median($l1),
'mean_absolute_percentage_error' => Stats::mean($ape),
'mean_squared_error' => $mse,
'mean_absolute_percentage_error' => 100. * Stats::mean($are),
'rms_error' => sqrt($mse),
'mean_squared_log_error' => Stats::mean($log),
'mean_squared_log_error' => Stats::mean($sle),
'r_squared' => 1. - ($sse / ($sst ?: EPSILON)),
'error_mean' => $mean,
'error_midrange' => ($min + $max) / 2.,
'error_median' => $median,
'error_variance' => $variance,
'error_mad' => $mad,
'error_interquartile_range' => Stats::iqr($errors),
'error_skewness' => Stats::skewness($errors, $mean),
'error_kurtosis' => Stats::kurtosis($errors, $mean),
'error_min' => min($errors),
'error_max' => max($errors),
'error_min' => $min,
'error_max' => $max,
'cardinality' => count($predictions),
];
}
Expand Down
20 changes: 14 additions & 6 deletions tests/CrossValidation/Reports/ResidualAnalysisTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,22 @@ public function generate_report_provider() : Generator
[11, 12, 14, 40, 55, 12, 16, 10, 2, 7],
[
'mean_absolute_error' => 0.8,
'mean_absolute_percentage_error' => 14.02077497665733,
'median_absolute_error' => 1.,
'mean_squared_error' => 1.,
'mean_squared_log_error' => 0.019107097505647368,
'mean_absolute_percentage_error' => 14.02077497665733,
'rms_error' => 1.,
'mean_squared_log_error' => 0.019107097505647368,
'r_squared' => 0.9958930551562692,
'error_mean' => -0.2,
'error_midrange' => -0.5,
'error_median' => 0.0,
'error_variance' => 0.9599999999999997,
'error_mad' => 1.0,
'error_interquartile_range' => 2.0,
'error_skewness' => -0.22963966338592326,
'error_kurtosis' => -1.0520833333333324,
'error_min' => -2,
'error_max' => 1,
'r_squared' => 0.9958930551562692,
'cardinality' => 10,
],
];
Expand All @@ -60,18 +64,22 @@ public function generate_report_provider() : Generator
[0.0019, -1.822, -0.9, 99.99, M_E],
[
'mean_absolute_error' => 0.18220216502615122,
'mean_absolute_percentage_error' => 18.174348688407402,
'median_absolute_error' => 0.17700000000000005,
'mean_squared_error' => 0.05292430893457563,
'mean_squared_log_error' => 51.96853354084834,
'mean_absolute_percentage_error' => 18.174348688407402,
'rms_error' => 0.23005283944036775,
'mean_squared_log_error' => 51.96853354084834,
'r_squared' => 0.9999669635675313,
'error_mean' => -0.07112216502615118,
'error_midrange' => -0.12315541256537399,
'error_median' => 0.0007000000000000001,
'error_variance' => 0.04786594657656853,
'error_mad' => 0.17630000000000004,
'error_interquartile_range' => 0.455155412565378,
'error_skewness' => -0.49093461098755187,
'error_kurtosis' => -1.216490935575394,
'error_min' => -0.423310825130748,
'error_max' => 0.17700000000000005,
'r_squared' => 0.9999669635675313,
'cardinality' => 5,
],
];
Expand Down

0 comments on commit 11f81e2

Please sign in to comment.