
Commit

Merge pull request #605 from Lorenzo-Perini/development
Development - Including Reject Option in Unsupervised Anomaly Detection
yzhao062 authored Sep 6, 2024
2 parents 2a80ac8 + 3842c2c commit 0ba2fc8
Showing 44 changed files with 655 additions and 0 deletions.
3 changes: 3 additions & 0 deletions README.rst
@@ -211,6 +211,7 @@ The full API Reference is available at `PyOD Documentation <https://pyod.readthe
* **predict(X)**: Determine whether a sample is an outlier or not as binary labels using the fitted detector.
* **predict_proba(X)**: Estimate the probability of a sample being an outlier using the fitted detector.
* **predict_confidence(X)**: Assess the model's confidence on a per-sample basis (applicable in predict and predict_proba) [#Perini2020Quantifying]_.
* **predict_with_rejection(X)**\ : Allow the detector to reject (i.e., abstain from making) highly uncertain predictions (output = -2) [#Perini2023Rejection]_.
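
A minimal usage sketch of the reject option (the ``KNN`` detector and the random data below are illustrative choices, not prescribed by this diff; any fitted PyOD detector exposes the same method):

.. code-block:: python

    import numpy as np
    from pyod.models.knn import KNN

    X_train = np.random.randn(200, 3)  # unlabeled training data
    X_test = np.random.randn(50, 3)

    clf = KNN()
    clf.fit(X_train)

    # 0 = inlier, 1 = outlier, -2 = rejected (prediction withheld)
    labels = clf.predict_with_rejection(X_test)

    # Optionally also retrieve the guarantees on the rejection behavior
    labels, stats = clf.predict_with_rejection(X_test, return_stats=True)
    expected_rejrate, ub_rejrate, ub_cost = stats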

**Key Attributes of a fitted model**:

@@ -567,6 +568,8 @@ Reference
.. [#Perini2020Quantifying] Perini, L., Vercruyssen, V., Davis, J. Quantifying the confidence of anomaly detectors in their example-wise predictions. In *Joint European Conference on Machine Learning and Knowledge Discovery in Databases (ECML-PKDD)*, 2020.
.. [#Perini2023Rejection] Perini, L., Davis, J. Unsupervised anomaly detection with rejection. In *Proceedings of the Thirty-seventh Conference on Neural Information Processing Systems (NeurIPS)*, 2023.
.. [#Ramaswamy2000Efficient] Ramaswamy, S., Rastogi, R. and Shim, K., 2000, May. Efficient algorithms for mining outliers from large data sets. *ACM Sigmod Record*\ , 29(2), pp. 427-438.
.. [#Rousseeuw1999A] Rousseeuw, P.J. and Driessen, K.V., 1999. A fast algorithm for the minimum covariance determinant estimator. *Technometrics*\ , 41(3), pp.212-223.
148 changes: 148 additions & 0 deletions pyod/models/base.py
@@ -19,6 +19,7 @@
from sklearn.utils import deprecated
from sklearn.utils.multiclass import check_classification_targets
from sklearn.utils.validation import check_is_fitted
from scipy.optimize import root_scalar

from .sklearn_base import _pprint
from ..utils.utility import precision_n_scores
@@ -293,7 +294,154 @@ def predict_confidence(self, X):
np.place(confidence, prediction == 0, 1 - confidence[prediction == 0])

return confidence

def predict_with_rejection(self, X, T=32, return_stats=False,
                           delta=0.1, c_fp=1, c_fn=1, c_r=-1):
    """Predict whether a particular sample is an outlier or not,
    allowing the detector to reject (i.e., abstain from making)
    low-confidence predictions (output = -2).

    Parameters
    ----------
    X : numpy array of shape (n_samples, n_features)
        The input samples.

    T : int, optional (default=32)
        Sets the rejection threshold to 1 - 2*exp(-T).
        The higher the value of T, the more predictions are rejected.

    return_stats : bool, optional (default=False)
        If True, also return three additional floats: the estimated
        rejection rate, the upper bound on the rejection rate, and
        the upper bound on the cost.

    delta : float, optional (default=0.1)
        The upper bound on the rejection rate holds with probability
        1 - delta.

    c_fp, c_fn, c_r : positive floats, optional (default=[1, 1, contamination])
        Costs for false positive predictions (c_fp), false negative
        predictions (c_fn), and rejections (c_r).

    Returns
    -------
    outlier_labels : numpy array of shape (n_samples,)
        For each observation, whether it should be considered an
        outlier according to the fitted model: 0 stands for inliers,
        1 for outliers, and -2 for rejected predictions.

    expected_rejection_rate : float, only if return_stats is True
    upperbound_rejection_rate : float, only if return_stats is True
    upperbound_cost : float, only if return_stats is True
    """
    check_is_fitted(self, ['decision_scores_', 'threshold_', 'labels_'])

    if c_r < 0:
        warnings.warn("The cost of rejection must be positive. "
                      "It has been set to the contamination rate.")
        c_r = self.contamination

    if delta <= 0 or delta >= 1:
        warnings.warn("delta must belong to (0, 1). "
                      "Its value has been set to 0.1.")
        delta = 0.1

    self.rejection_threshold_ = 1 - 2 * np.exp(-T)
    prediction = self.predict(X)
    confidence = self.predict_confidence(X)
    # Rescale the confidence to [0, 1], where values close to 0 mean
    # the detector is maximally uncertain about the predicted label.
    np.place(confidence, prediction == 0, 1 - confidence[prediction == 0])
    confidence = 2 * abs(confidence - .5)
    # Reject (output = -2) every prediction whose rescaled confidence
    # does not exceed the rejection threshold.
    prediction[np.where(confidence <= self.rejection_threshold_)[0]] = -2

    if return_stats:
        expected_rejrate, ub_rejrate, ub_cost = self.compute_rejection_stats(
            T=T, delta=delta, c_fp=c_fp, c_fn=c_fn, c_r=c_r)
        return prediction, [expected_rejrate, ub_rejrate, ub_cost]

    return prediction
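
# A quick numeric sketch (illustrative, not part of the class) of how T
# controls the rejection threshold used above:
#
#     >>> import numpy as np
#     >>> 1 - 2 * np.exp(-32)   # default T = 32
#     0.9999999999999747
#     >>> 1 - 2 * np.exp(-4)    # smaller T -> lower threshold, fewer rejections
#     0.9633687222225316
#
# A prediction is kept only when its rescaled confidence 2 * |conf - 0.5|
# exceeds the threshold; otherwise it is rejected and labeled -2.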


def compute_rejection_stats(self, T=32, delta=0.1, c_fp=1, c_fn=1,
                            c_r=-1, verbose=False):
    """Compute the statistics that come with adding a reject option
    to the unsupervised detector. These come with guarantees: an
    estimate of the expected rejection rate, an upper bound on the
    rejection rate, and an upper bound on the cost.

    Parameters
    ----------
    T : int, optional (default=32)
        Sets the rejection threshold to 1 - 2*exp(-T).
        The higher the value of T, the more predictions are rejected.

    delta : float, optional (default=0.1)
        The upper bound on the rejection rate holds with probability
        1 - delta.

    c_fp, c_fn, c_r : positive floats, optional (default=[1, 1, contamination])
        Costs for false positive predictions (c_fp), false negative
        predictions (c_fn), and rejections (c_r).

    verbose : bool, optional (default=False)
        If True, print the expected rejection rate, the upper bound
        on the rejection rate, and the upper bound on the cost.

    Returns
    -------
    expected_rejection_rate : float
        The expected rejection rate.
    upperbound_rejection_rate : float
        The upper bound on the rejection rate, satisfied with
        probability 1 - delta.
    upperbound_cost : float
        The upper bound on the cost.
    """

    check_is_fitted(self, ['decision_scores_', 'threshold_', 'labels_'])

    if c_r < 0:
        c_r = self.contamination

    if delta <= 0 or delta >= 1:
        delta = 0.1

    # Compute the expected rejection rate by inverting the binomial
    # CDF at the two confidence levels induced by T.
    n = len(self.decision_scores_)
    n_gamma_minus1 = int(n * self.contamination) - 1
    argsmin = (n_gamma_minus1, n, 1 - np.exp(-T))
    argsmax = (n_gamma_minus1, n, np.exp(-T))
    q1 = root_scalar(lambda p, k, n, C: binom.cdf(k, n, p) - C,
                     bracket=[0, 1], method='brentq', args=argsmin).root
    q2 = root_scalar(lambda p, k, n, C: binom.cdf(k, n, p) - C,
                     bracket=[0, 1], method='brentq', args=argsmax).root
    expected_reject_rate = q2 - q1

    # Compute the upper bound on the rejection rate, which holds with
    # probability 1 - delta.
    right_mar = (-self.contamination * (n + 2) + n + 1) / n \
                + (T * (n + 2)) / (np.sqrt(2 * n**3 * T))
    right_mar = min(1, right_mar)
    left_mar = (
        (2 + n * (1 - self.contamination) * (n + 1)) / n**2
        - np.sqrt(
            0.5 * n**5 * (
                2 * n * (
                    -3 * self.contamination**2
                    - 2 * n * (1 - self.contamination)**2
                    + 4 * self.contamination - 3
                )
                + T * (n + 2)**2 - 8
            )
        ) / n**4
    )
    left_mar = max(0, left_mar)
    add_term = 2 * np.sqrt(np.log(2 / delta) / (2 * n))
    upperbound_rejectrate = right_mar - left_mar + add_term

    # Compute the upper bound on the cost, reusing q1 and q2 from above.
    upperbound_cost = (np.min([self.contamination, q1]) * c_fp
                       + np.min([1 - q2, self.contamination]) * c_fn
                       + (q2 - q1) * c_r)

    if verbose:
        print("Expected rejection rate:", np.round(expected_reject_rate, 4))
        print("Upper bound on the rejection rate:",
              np.round(upperbound_rejectrate, 4))
        print("Upper bound on the cost:", np.round(upperbound_cost, 4))

    return expected_reject_rate, upperbound_rejectrate, upperbound_cost
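
# A standalone numeric sketch (illustrative values: n = 100 samples,
# contamination = 0.1, T = 32) of the expected-rejection-rate computation
# performed above:
#
#     >>> import numpy as np
#     >>> from scipy.optimize import root_scalar
#     >>> from scipy.stats import binom
#     >>> n, gamma, T = 100, 0.1, 32
#     >>> k = int(n * gamma) - 1
#     >>> f = lambda p, C: binom.cdf(k, n, p) - C
#     >>> q1 = root_scalar(f, bracket=[0, 1], method='brentq',
#     ...                  args=(1 - np.exp(-T),)).root
#     >>> q2 = root_scalar(f, bracket=[0, 1], method='brentq',
#     ...                  args=(np.exp(-T),)).root
#     >>> q2 - q1   # the expected rejection rate, a fraction in [0, 1]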

def _predict_rank(self, X, normalized=False):
"""Predict the outlyingness rank of a sample by a fitted model. The
method is for outlier detector score combination.
12 changes: 12 additions & 0 deletions pyod/test/test_abod.py
@@ -103,6 +103,18 @@ def test_prediction_proba_linear_confidence(self):
assert (confidence.min() >= 0)
assert (confidence.max() <= 1)

def test_prediction_with_rejection(self):
    pred_labels = self.clf.predict_with_rejection(self.X_test,
                                                  return_stats=False)
    assert_equal(pred_labels.shape, self.y_test.shape)

def test_prediction_with_rejection_stats(self):
    _, [expected_rejrate, ub_rejrate, ub_cost] = \
        self.clf.predict_with_rejection(self.X_test, return_stats=True)
    assert (expected_rejrate >= 0)
    assert (expected_rejrate <= 1)
    assert (ub_rejrate >= 0)
    assert (ub_rejrate <= 1)
    assert (ub_cost >= 0)

def test_fit_predict(self):
pred_labels = self.clf.fit_predict(self.X_train)
assert_equal(pred_labels.shape, self.y_train.shape)
12 changes: 12 additions & 0 deletions pyod/test/test_ae1svm.py
@@ -102,6 +102,18 @@ def test_fit_predict(self):
pred_labels = self.clf.fit_predict(self.X_train)
assert_equal(pred_labels.shape, self.y_train.shape)

def test_prediction_with_rejection(self):
    pred_labels = self.clf.predict_with_rejection(self.X_test,
                                                  return_stats=False)
    assert_equal(pred_labels.shape, self.y_test.shape)

def test_prediction_with_rejection_stats(self):
    _, [expected_rejrate, ub_rejrate, ub_cost] = \
        self.clf.predict_with_rejection(self.X_test, return_stats=True)
    assert (expected_rejrate >= 0)
    assert (expected_rejrate <= 1)
    assert (ub_rejrate >= 0)
    assert (ub_rejrate <= 1)
    assert (ub_cost >= 0)

def test_fit_predict_score(self):
self.clf.fit_predict_score(self.X_test, self.y_test)
self.clf.fit_predict_score(self.X_test, self.y_test,
12 changes: 12 additions & 0 deletions pyod/test/test_alad.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,18 @@ def test_prediction_proba_linear_confidence(self):
assert (confidence.min() >= 0)
assert (confidence.max() <= 1)

def test_prediction_with_rejection(self):
    pred_labels = self.clf.predict_with_rejection(self.X_test,
                                                  return_stats=False)
    assert_equal(pred_labels.shape, self.y_test.shape)

def test_prediction_with_rejection_stats(self):
    _, [expected_rejrate, ub_rejrate, ub_cost] = \
        self.clf.predict_with_rejection(self.X_test, return_stats=True)
    assert (expected_rejrate >= 0)
    assert (expected_rejrate <= 1)
    assert (ub_rejrate >= 0)
    assert (ub_rejrate <= 1)
    assert (ub_cost >= 0)

def test_fit_predict(self):
pred_labels = self.clf.fit_predict(self.X_train)
assert_equal(pred_labels.shape, self.y_train.shape)
12 changes: 12 additions & 0 deletions pyod/test/test_auto_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,18 @@ def test_prediction_proba_linear_confidence(self):
self.assertEqual(confidence.shape, self.y_test.shape)
self.assertInRange(confidence, 0, 1)

def test_prediction_with_rejection(self):
    pred_labels = self.clf.predict_with_rejection(self.X_test,
                                                  return_stats=False)
    self.assertEqual(pred_labels.shape, self.y_test.shape)

def test_prediction_with_rejection_stats(self):
    _, [expected_rejrate, ub_rejrate, ub_cost] = \
        self.clf.predict_with_rejection(self.X_test, return_stats=True)
    self.assertGreaterEqual(expected_rejrate, 0)
    self.assertLessEqual(expected_rejrate, 1)
    self.assertGreaterEqual(ub_rejrate, 0)
    self.assertLessEqual(ub_rejrate, 1)
    self.assertGreaterEqual(ub_cost, 0)

def test_fit_predict(self):
pred_labels = self.clf.fit_predict(self.X_train)
self.assertEqual(pred_labels.shape, self.y_train.shape)
12 changes: 12 additions & 0 deletions pyod/test/test_cblof.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,18 @@ def test_prediction_proba_linear_confidence(self):
assert_equal(confidence.shape, self.y_test.shape)
assert (confidence.min() >= 0)
assert (confidence.max() <= 1)

def test_prediction_with_rejection(self):
    pred_labels = self.clf.predict_with_rejection(self.X_test,
                                                  return_stats=False)
    assert_equal(pred_labels.shape, self.y_test.shape)

def test_prediction_with_rejection_stats(self):
    _, [expected_rejrate, ub_rejrate, ub_cost] = \
        self.clf.predict_with_rejection(self.X_test, return_stats=True)
    assert (expected_rejrate >= 0)
    assert (expected_rejrate <= 1)
    assert (ub_rejrate >= 0)
    assert (ub_rejrate <= 1)
    assert (ub_cost >= 0)

def test_fit_predict(self):
pred_labels = self.clf.fit_predict(self.X_train)
12 changes: 12 additions & 0 deletions pyod/test/test_cd.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,18 @@ def test_prediction_labels_confidence(self):
assert (confidence.min() >= 0)
assert (confidence.max() <= 1)

def test_prediction_with_rejection(self):
    pred_labels = self.clf.predict_with_rejection(self.X_test,
                                                  return_stats=False)
    assert_equal(pred_labels.shape, self.y_test.shape)

def test_prediction_with_rejection_stats(self):
    _, [expected_rejrate, ub_rejrate, ub_cost] = \
        self.clf.predict_with_rejection(self.X_test, return_stats=True)
    assert (expected_rejrate >= 0)
    assert (expected_rejrate <= 1)
    assert (ub_rejrate >= 0)
    assert (ub_rejrate <= 1)
    assert (ub_cost >= 0)

def test_prediction_proba_linear_confidence(self):
pred_proba, confidence = self.clf.predict_proba(self.X_test,
method='linear',
12 changes: 12 additions & 0 deletions pyod/test/test_cof.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,18 @@ def test_prediction_proba_linear_confidence(self):
assert (confidence.min() >= 0)
assert (confidence.max() <= 1)

def test_prediction_with_rejection(self):
    pred_labels = self.clf.predict_with_rejection(self.X_test,
                                                  return_stats=False)
    assert_equal(pred_labels.shape, self.y_test.shape)

def test_prediction_with_rejection_stats(self):
    _, [expected_rejrate, ub_rejrate, ub_cost] = \
        self.clf.predict_with_rejection(self.X_test, return_stats=True)
    assert (expected_rejrate >= 0)
    assert (expected_rejrate <= 1)
    assert (ub_rejrate >= 0)
    assert (ub_rejrate <= 1)
    assert (ub_cost >= 0)

def test_fit_predict(self):
pred_labels = self.clf.fit_predict(self.X_train)
assert_equal(pred_labels.shape, self.y_train.shape)
12 changes: 12 additions & 0 deletions pyod/test/test_copod.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,18 @@ def test_prediction_proba_linear_confidence(self):
assert (confidence.min() >= 0)
assert (confidence.max() <= 1)

def test_prediction_with_rejection(self):
    pred_labels = self.clf.predict_with_rejection(self.X_test,
                                                  return_stats=False)
    assert_equal(pred_labels.shape, self.y_test.shape)

def test_prediction_with_rejection_stats(self):
    _, [expected_rejrate, ub_rejrate, ub_cost] = \
        self.clf.predict_with_rejection(self.X_test, return_stats=True)
    assert (expected_rejrate >= 0)
    assert (expected_rejrate <= 1)
    assert (ub_rejrate >= 0)
    assert (ub_rejrate <= 1)
    assert (ub_cost >= 0)

def test_fit_predict(self):
pred_labels = self.clf.fit_predict(self.X_train)
assert_equal(pred_labels.shape, self.y_train.shape)
12 changes: 12 additions & 0 deletions pyod/test/test_deepsvdd.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,18 @@ def test_prediction_proba_linear_confidence(self):
assert (confidence.min() >= 0)
assert (confidence.max() <= 1)

def test_prediction_with_rejection(self):
    pred_labels = self.clf.predict_with_rejection(self.X_test,
                                                  return_stats=False)
    assert_equal(pred_labels.shape, self.y_test.shape)

def test_prediction_with_rejection_stats(self):
    _, [expected_rejrate, ub_rejrate, ub_cost] = \
        self.clf.predict_with_rejection(self.X_test, return_stats=True)
    assert (expected_rejrate >= 0)
    assert (expected_rejrate <= 1)
    assert (ub_rejrate >= 0)
    assert (ub_rejrate <= 1)
    assert (ub_cost >= 0)

def test_fit_predict(self):
pred_labels = self.clf.fit_predict(self.X_train)
assert_equal(pred_labels.shape, self.y_train.shape)
12 changes: 12 additions & 0 deletions pyod/test/test_devnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,18 @@ def test_prediction_proba_linear_confidence(self):
assert (confidence.min() >= 0)
assert (confidence.max() <= 1)

def test_prediction_with_rejection(self):
    pred_labels = self.clf.predict_with_rejection(self.X_test,
                                                  return_stats=False)
    assert_equal(pred_labels.shape, self.y_test.shape)

def test_prediction_with_rejection_stats(self):
    _, [expected_rejrate, ub_rejrate, ub_cost] = \
        self.clf.predict_with_rejection(self.X_test, return_stats=True)
    assert (expected_rejrate >= 0)
    assert (expected_rejrate <= 1)
    assert (ub_rejrate >= 0)
    assert (ub_rejrate <= 1)
    assert (ub_cost >= 0)

def test_fit_predict(self):
pred_labels = self.clf.fit_predict(self.X_train, self.y_train)
assert_equal(pred_labels.shape, self.y_train.shape)
12 changes: 12 additions & 0 deletions pyod/test/test_dif.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,18 @@ def test_prediction_proba_linear_confidence(self):
assert (confidence.min() >= 0)
assert (confidence.max() <= 1)

def test_prediction_with_rejection(self):
    pred_labels = self.clf.predict_with_rejection(self.X_test,
                                                  return_stats=False)
    assert_equal(pred_labels.shape, self.y_test.shape)

def test_prediction_with_rejection_stats(self):
    _, [expected_rejrate, ub_rejrate, ub_cost] = \
        self.clf.predict_with_rejection(self.X_test, return_stats=True)
    assert (expected_rejrate >= 0)
    assert (expected_rejrate <= 1)
    assert (ub_rejrate >= 0)
    assert (ub_rejrate <= 1)
    assert (ub_cost >= 0)

def test_fit_predict(self):
pred_labels = self.clf.fit_predict(self.X_train)
assert_equal(pred_labels.shape, self.y_train.shape)