-
Notifications
You must be signed in to change notification settings - Fork 46
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Further edits to temperature scaling methods (#169)
- Loading branch information
1 parent
247439b
commit f2c9b3c
Showing
11 changed files
with
230 additions
and
76 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,15 +1,15 @@ | ||
from fortuna.calibration.binary_classification.temp_scaling.bias_binary_temp_scaling import ( | ||
BiasBinaryClassificationTemperatureScaling, | ||
) | ||
from fortuna.calibration.binary_classification.temp_scaling.brier_binary_temp_scaling import ( | ||
BrierBinaryClassificationTemperatureScaling, | ||
) | ||
from fortuna.calibration.binary_classification.temp_scaling.crossentropy_binary_temp_scaling import ( | ||
CrossEntropyBinaryClassificationTemperatureScaling, | ||
) | ||
from fortuna.calibration.binary_classification.temp_scaling.f1_temp_scaling import ( | ||
F1BinaryClassificationTemperatureScaling, | ||
) | ||
from fortuna.calibration.binary_classification.temp_scaling.mse_binary_temp_scaling import ( | ||
MSEBinaryClassificationTemperatureScaling, | ||
) | ||
from fortuna.calibration.classification.temp_scaling.base import ( | ||
ClassificationTemperatureScaling, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
20 changes: 20 additions & 0 deletions
20
fortuna/calibration/binary_classification/temp_scaling/brier_binary_temp_scaling.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import numpy as np | ||
|
||
from fortuna.calibration.binary_classification.temp_scaling.base import ( | ||
BaseBinaryClassificationTemperatureScaling, | ||
) | ||
|
||
|
||
class BrierBinaryClassificationTemperatureScaling( | ||
BaseBinaryClassificationTemperatureScaling | ||
): | ||
""" | ||
A temperature scaling class for binary classification. | ||
It scales the probability that the target variables is positive with a single learnable parameters. | ||
The method attempts to minimize the MSE, or Brier score. | ||
""" | ||
|
||
def fit(self, probs: np.ndarray, targets: np.ndarray): | ||
self._check_probs(probs) | ||
self._check_targets(targets) | ||
self._temperature = np.mean(probs**2) / np.mean(probs * targets) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
12 changes: 0 additions & 12 deletions
12
fortuna/calibration/binary_classification/temp_scaling/mse_binary_temp_scaling.py
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.