Skip to content

Commit

Permalink
Further edits to temperature scaling methods (#169)
Browse files Browse the repository at this point in the history
  • Loading branch information
gianlucadetommaso authored Dec 14, 2023
1 parent 247439b commit f2c9b3c
Show file tree
Hide file tree
Showing 11 changed files with 230 additions and 76 deletions.
46 changes: 25 additions & 21 deletions benchmarks/calibration/temp_scaling/breast_cancer_temp_scaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@

from fortuna.calibration import (
BiasBinaryClassificationTemperatureScaling,
BrierBinaryClassificationTemperatureScaling,
ClassificationTemperatureScaling,
CrossEntropyBinaryClassificationTemperatureScaling,
F1BinaryClassificationTemperatureScaling,
MSEBinaryClassificationTemperatureScaling,
)
from fortuna.metric.classification import (
brier_score,
Expand Down Expand Up @@ -73,22 +73,26 @@ def binary_cross_entropy(probs: np.array, targets: np.ndarray) -> float:
before_f1 = f1(test_probs, test_targets)
before_ce = binary_cross_entropy(test_probs, test_targets)

mse_temp_scaler = MSEBinaryClassificationTemperatureScaling()
mse_temp_scaler.fit(probs=calib_probs, targets=calib_targets)
mse_temp_scaled_test_probs = mse_temp_scaler.predict_proba(probs=test_probs)
mse_temp_scaled_test_preds = mse_temp_scaler.predict(probs=test_probs)
mse_temp_scaled_brier_score = brier_score(mse_temp_scaled_test_probs, test_targets)
mse_temp_scaled_ece = expected_calibration_error(
brier_temp_scaler = BrierBinaryClassificationTemperatureScaling()
brier_temp_scaler.fit(probs=calib_probs, targets=calib_targets)
brier_temp_scaled_test_probs = brier_temp_scaler.predict_proba(probs=test_probs)
brier_temp_scaled_test_preds = brier_temp_scaler.predict(probs=test_probs)
brier_temp_scaled_brier_score = brier_score(
brier_temp_scaled_test_probs, test_targets
)
brier_temp_scaled_ece = expected_calibration_error(
probs=np.stack(
(1 - mse_temp_scaled_test_probs, mse_temp_scaled_test_probs), axis=1
(1 - brier_temp_scaled_test_probs, brier_temp_scaled_test_probs), axis=1
),
preds=mse_temp_scaled_test_preds,
preds=brier_temp_scaled_test_preds,
targets=test_targets,
)
mse_temp_scaled_prec = precision(mse_temp_scaled_test_preds, test_targets)
mse_temp_scaled_rec = recall(mse_temp_scaled_test_preds, test_targets)
mse_temp_scaled_f1 = f1(mse_temp_scaled_test_preds, test_targets)
mse_temp_scaled_ce = binary_cross_entropy(mse_temp_scaled_test_probs, test_targets)
brier_temp_scaled_prec = precision(brier_temp_scaled_test_preds, test_targets)
brier_temp_scaled_rec = recall(brier_temp_scaled_test_preds, test_targets)
brier_temp_scaled_f1 = f1(brier_temp_scaled_test_preds, test_targets)
brier_temp_scaled_ce = binary_cross_entropy(
brier_temp_scaled_test_probs, test_targets
)

ce_temp_scaler = CrossEntropyBinaryClassificationTemperatureScaling()
ce_temp_scaler.fit(probs=calib_probs, targets=calib_targets)
Expand Down Expand Up @@ -185,13 +189,13 @@ def binary_cross_entropy(probs: np.array, targets: np.ndarray) -> float:
before_f1,
],
[
"MSE binary temperature scaling",
mse_temp_scaled_brier_score,
mse_temp_scaled_ce,
mse_temp_scaled_ece,
mse_temp_scaled_prec,
mse_temp_scaled_rec,
mse_temp_scaled_f1,
"Brier binary temperature scaling",
brier_temp_scaled_brier_score,
brier_temp_scaled_ce,
brier_temp_scaled_ece,
brier_temp_scaled_prec,
brier_temp_scaled_rec,
brier_temp_scaled_f1,
],
[
"Cross-Entropy binary temperature scaling",
Expand Down Expand Up @@ -246,7 +250,7 @@ def binary_cross_entropy(probs: np.array, targets: np.ndarray) -> float:
print(
tabulate(
[
["MSE binary temperature scaling", mse_temp_scaler.temperature],
["Brier binary temperature scaling", brier_temp_scaler.temperature],
[
"Cross-Entropy binary temperature scaling",
ce_temp_scaler.temperature,
Expand Down
6 changes: 3 additions & 3 deletions fortuna/calibration/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from fortuna.calibration.binary_classification.temp_scaling.bias_binary_temp_scaling import (
BiasBinaryClassificationTemperatureScaling,
)
from fortuna.calibration.binary_classification.temp_scaling.brier_binary_temp_scaling import (
BrierBinaryClassificationTemperatureScaling,
)
from fortuna.calibration.binary_classification.temp_scaling.crossentropy_binary_temp_scaling import (
CrossEntropyBinaryClassificationTemperatureScaling,
)
from fortuna.calibration.binary_classification.temp_scaling.f1_temp_scaling import (
F1BinaryClassificationTemperatureScaling,
)
from fortuna.calibration.binary_classification.temp_scaling.mse_binary_temp_scaling import (
MSEBinaryClassificationTemperatureScaling,
)
from fortuna.calibration.classification.temp_scaling.base import (
ClassificationTemperatureScaling,
)
60 changes: 56 additions & 4 deletions fortuna/calibration/binary_classification/temp_scaling/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,67 @@ def __init__(self):
self._temperature = None

@abc.abstractmethod
def fit(self, probs: np.ndarray, targets: np.ndarray, **kwargs):
def fit(self, probs: np.ndarray, targets: np.ndarray):
"""
Fit the temperature scaling method.
Parameters
----------
probs: np.ndarray
A one-dimensional probabilities of positive target variables for each input.
targets: np.ndarray
A one-dimensional array of integer target variables for each input.
"""
pass

def predict_proba(self, probs: np.ndarray):
def predict_proba(self, probs: np.ndarray) -> np.ndarray:
"""
Predict the scaled probabilities for each input.
Parameters
----------
probs: np.ndarray
A one-dimensional probabilities of positive target variables for each input.
Returns
-------
np.ndarray
The predicted probabilities
"""
self._check_probs(probs)
return np.clip(probs / self._temperature, 0.0, 1.0)

def predict(self, probs: np.ndarray):
return (self.predict_proba(probs) >= 0.5).astype(int)
def predict(self, probs: np.ndarray, threshold: float = 0.5) -> np.ndarray:
"""
Predict the target variable for each input.
Parameters
----------
probs: np.ndarray
A one-dimensional probabilities of positive target variables for each input.
threshold: np.ndarray
The threshold on the predicted probabilities do decide whether a target variable is positive or
negative.
Returns
-------
np.ndarray
The predicted target variables.
"""
self._check_probs(probs)
return (self.predict_proba(probs) >= threshold).astype(int)

@property
def temperature(self):
return self._temperature

@staticmethod
def _check_probs(probs: np.ndarray):
if probs.ndim != 1:
raise ValueError("The array of probabilities must be one-dimensional.")

@staticmethod
def _check_targets(targets: np.ndarray):
if targets.ndim != 1:
raise ValueError("The array of targets must be one-dimensional.")
if targets.dtype != int:
raise ValueError("Each element in the array of targets must be an integer.")
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,13 @@
class BiasBinaryClassificationTemperatureScaling(
BaseBinaryClassificationTemperatureScaling
):
def fit(self, probs: np.ndarray, targets: np.ndarray, **kwargs):
"""
A temperature scaling class for binary classification.
It scales the probability that the target variables is positive with a single learnable parameters.
The method minimizes the expected bias.
"""

def fit(self, probs: np.ndarray, targets: np.ndarray):
self._check_probs(probs)
self._check_targets(targets)
self._temperature = np.mean(probs) / np.mean(targets)
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import numpy as np

from fortuna.calibration.binary_classification.temp_scaling.base import (
BaseBinaryClassificationTemperatureScaling,
)


class BrierBinaryClassificationTemperatureScaling(
BaseBinaryClassificationTemperatureScaling
):
"""
A temperature scaling class for binary classification.
It scales the probability that the target variables is positive with a single learnable parameters.
The method attempts to minimize the MSE, or Brier score.
"""

def fit(self, probs: np.ndarray, targets: np.ndarray):
self._check_probs(probs)
self._check_targets(targets)
self._temperature = np.mean(probs**2) / np.mean(probs * targets)
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from typing import Dict

import numpy as np
from scipy.optimize import newton
from scipy.optimize import brute

from fortuna.calibration.binary_classification.temp_scaling.base import (
BaseBinaryClassificationTemperatureScaling,
Expand All @@ -11,12 +9,22 @@
class CrossEntropyBinaryClassificationTemperatureScaling(
BaseBinaryClassificationTemperatureScaling
):
def fit(self, probs: np.ndarray, targets: np.ndarray, **kwargs) -> Dict:
scaled_probs = (1 - 1e-6) * (1e-6 + probs)
"""
A temperature scaling class for binary classification.
It scales the probability that the target variables is positive with a single learnable parameters.
The method minimizes the binary cross-entropy loss.
"""

def fit(self, probs: np.ndarray, targets: np.ndarray):
self._check_probs(probs)
self._check_targets(targets)

def temp_scaling_fn(phi):
return np.mean((1 - targets) / (1 - scaled_probs * np.exp(-phi))) - 1
def temp_scaling_fn(tau):
temp_probs = np.clip(probs / tau, 1e-9, 1 - 1e-9)
return -np.mean(
targets * np.log(temp_probs) + (1 - targets) * np.log(1 - temp_probs)
)

phi, status = newton(temp_scaling_fn, x0=0.0, full_output=True, disp=False)
self._temperature = np.exp(phi)
return status
self._temperature = brute(
temp_scaling_fn, ranges=[(np.min(probs), 10)], Ns=1000
)[0]
Original file line number Diff line number Diff line change
@@ -1,14 +1,29 @@
import numpy as np
from scipy.optimize import brute

from fortuna.calibration.binary_classification.temp_scaling.base import (
BaseBinaryClassificationTemperatureScaling,
)


class F1BinaryClassificationTemperatureScaling(
BaseBinaryClassificationTemperatureScaling
):
"""
A temperature scaling class for binary classification.
It scales the probability that the target variables is positive with a single learnable parameters.
The method attempts to maximize the F1 score.
"""

class F1BinaryClassificationTemperatureScaling:
def __init__(self):
super().__init__()
self._threshold = None
self._temperature = None

def fit(self, probs: np.ndarray, targets: np.ndarray, threshold: float):
self._check_probs(probs)
self._check_targets(targets)

self._threshold = threshold
n_pos_targets = np.sum(targets)

Expand All @@ -26,16 +41,10 @@ def loss_fn(tau):
loss_fn, ranges=[(np.min(probs), 1 / threshold)], Ns=1000
)[0]

def predict_proba(self, probs: np.ndarray):
return np.clip(probs / self._temperature, 0.0, 1.0)

def predict(self, probs: np.ndarray):
self._check_probs(probs)
return (self.predict_proba(probs) >= self._threshold).astype(int)

@property
def threshold(self):
return self._threshold

@property
def temperature(self):
return self._temperature

This file was deleted.

Loading

0 comments on commit f2c9b3c

Please sign in to comment.