Development - Including Reject Option in Unsupervised Anomaly Detection #605

Merged · 3 commits · Sep 6, 2024
3 changes: 3 additions & 0 deletions README.rst
@@ -211,6 +211,7 @@ The full API Reference is available at `PyOD Documentation <https://pyod.readthe
* **predict(X)**: Determine whether a sample is an outlier or not as binary labels using the fitted detector.
* **predict_proba(X)**: Estimate the probability of a sample being an outlier using the fitted detector.
* **predict_confidence(X)**: Assess the model's confidence on a per-sample basis (applicable in predict and predict_proba) [#Perini2020Quantifying]_.
* **predict_with_rejection(X)**: Allow the detector to reject (i.e., abstain from making) highly uncertain predictions (output = -2) [#Perini2023Rejection]_. A usage sketch follows below.
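
A minimal usage sketch of the reject option (illustrative only; the detector choice and synthetic data below are assumptions, not part of this changeset):

.. code-block:: python

    from pyod.models.knn import KNN
    from pyod.utils.data import generate_data

    X_train, X_test, y_train, y_test = generate_data(n_train=200, n_test=100)

    clf = KNN()
    clf.fit(X_train)

    # labels take values in {0, 1, -2}: inlier, outlier, or rejected
    labels = clf.predict_with_rejection(X_test)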

**Key Attributes of a fitted model**:

@@ -567,6 +568,8 @@ Reference

.. [#Perini2020Quantifying] Perini, L., Vercruyssen, V., Davis, J. Quantifying the confidence of anomaly detectors in their example-wise predictions. In *Joint European Conference on Machine Learning and Knowledge Discovery in Databases (ECML-PKDD)*, 2020.

.. [#Perini2023Rejection] Perini, L., Davis, J. Unsupervised anomaly detection with rejection. In *Proceedings of the Thirty-seventh Conference on Neural Information Processing Systems (NeurIPS)*, 2023.

.. [#Ramaswamy2000Efficient] Ramaswamy, S., Rastogi, R. and Shim, K., 2000, May. Efficient algorithms for mining outliers from large data sets. *ACM Sigmod Record*\ , 29(2), pp. 427-438.

.. [#Rousseeuw1999A] Rousseeuw, P.J. and Driessen, K.V., 1999. A fast algorithm for the minimum covariance determinant estimator. *Technometrics*\ , 41(3), pp.212-223.
148 changes: 148 additions & 0 deletions pyod/models/base.py
@@ -19,6 +19,7 @@
from sklearn.utils import deprecated
from sklearn.utils.multiclass import check_classification_targets
from sklearn.utils.validation import check_is_fitted
from scipy.optimize import root_scalar

from .sklearn_base import _pprint
from ..utils.utility import precision_n_scores
@@ -293,7 +294,154 @@ def predict_confidence(self, X):
np.place(confidence, prediction == 0, 1 - confidence[prediction == 0])

return confidence

def predict_with_rejection(self, X, T=32, return_stats=False,
                           delta=0.1, c_fp=1, c_fn=1, c_r=-1):
"""Predict whether a particular sample is an outlier or not,
allowing the detector to reject (i.e., output = -2)
low-confidence predictions.

Parameters
----------
X : numpy array of shape (n_samples, n_features)
The input samples.

T : int, optional (default=32)
It sets the rejection threshold to 1 - 2*exp(-T).
The higher the value of T, the more rejections are made.

return_stats: bool, optional (default = False)
If True, it also returns three additional float values:
the estimated rejection rate, the upper bound on the rejection
rate, and the upper bound on the cost.

delta: float, optional (default = 0.1)
The upper bound rejection rate holds with probability 1-delta.

c_fp, c_fn, c_r: floats (positive), optional (default = [1, 1, contamination])
Costs for false positive predictions (c_fp), false negative
predictions (c_fn), and rejections (c_r).

Returns
-------
outlier_labels : numpy array of shape (n_samples,)
For each observation, it tells whether it should be considered
an outlier according to the fitted model: 0 stands for inliers,
1 for outliers, and -2 for rejected predictions.

expected_rejection_rate: float, if return_stats is True;
upperbound_rejection_rate: float, if return_stats is True;
upperbound_cost: float, if return_stats is True;

"""
check_is_fitted(self, ['decision_scores_', 'threshold_', 'labels_'])
if c_r < 0:
warnings.warn("The cost of rejection must be positive. It has been set to the contamination rate.")
c_r = self.contamination

if delta <= 0 or delta >= 1:
warnings.warn("delta must belong to (0,1). Its value has been set to 0.1.")
delta = 0.1

self.rejection_threshold_ = 1 - 2 * np.exp(-T)
prediction = self.predict(X)
confidence = self.predict_confidence(X)
np.place(confidence, prediction == 0, 1 - confidence[prediction == 0])
confidence = 2 * abs(confidence - .5)
prediction[np.where(confidence <= self.rejection_threshold_)[0]] = -2

if return_stats:
expected_rejrate, ub_rejrate, ub_cost = self.compute_rejection_stats(
    T=T, delta=delta, c_fp=c_fp, c_fn=c_fn, c_r=c_r)
return prediction, [expected_rejrate, ub_rejrate, ub_cost]

return prediction
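
# Illustrative note (a sketch, not part of this changeset): how T sets
# the rejection threshold used above. The numbers are for exposition only.
#
#     import numpy as np
#     T = 32
#     threshold = 1 - 2 * np.exp(-T)   # ~= 1 - 2.5e-14, i.e. very close to 1
#     # A prediction is rejected when 2 * |confidence - 0.5| <= threshold,
#     # so only predictions whose confidence saturates near 0 or 1 survive.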


def compute_rejection_stats(self, T=32, delta=0.1, c_fp=1, c_fn=1, c_r=-1, verbose=False):
"""Compute the statistics that come with adding the reject option
to an unsupervised detector: an estimate of the expected rejection
rate, an upper bound on the rejection rate (holding with
probability 1-delta), and an upper bound on the cost.

Parameters
----------
T: int, optional (default=32)
It sets the rejection threshold to 1 - 2*exp(-T).
The higher the value of T, the more rejections are made.

delta: float, optional (default = 0.1)
The upper bound rejection rate holds with probability 1-delta.

c_fp, c_fn, c_r: floats (positive), optional (default = [1, 1, contamination])
Costs for false positive predictions (c_fp), false negative
predictions (c_fn), and rejections (c_r).

verbose: bool, optional (default = False)
If True, it prints the expected rejection rate, the upper bound on
the rejection rate, and the upper bound on the cost.

Returns
-------
expected_rejection_rate: float, the expected rejection rate;
upperbound_rejection_rate: float, the upper bound for the rejection rate
satisfied with probability 1-delta;
upperbound_cost: float, the upper bound for the cost;
"""

check_is_fitted(self, ['decision_scores_', 'threshold_', 'labels_'])

if c_r < 0:
c_r = self.contamination

if delta <= 0 or delta >= 1:
delta = 0.1

# Computing the expected rejection rate
n = len(self.decision_scores_)
n_gamma_minus1 = int(n * self.contamination) - 1
argsmin = (n_gamma_minus1, n, 1 - np.exp(-T))
argsmax = (n_gamma_minus1, n, np.exp(-T))
# invert the binomial CDF to get the two quantiles that delimit the rejection region
q1 = root_scalar(lambda p, k, n, C: binom.cdf(k, n, p) - C, bracket=[0, 1], method='brentq', args=argsmin).root
q2 = root_scalar(lambda p, k, n, C: binom.cdf(k, n, p) - C, bracket=[0, 1], method='brentq', args=argsmax).root
expected_reject_rate = q2 - q1

# Computing the upper bound for the rejection rate
right_mar = (-self.contamination * (n + 2) + n + 1) / n + (T * (n + 2)) / (np.sqrt(2 * n**3 * T))
right_mar = min(1, right_mar)
left_mar = (
(2 + n * (1 - self.contamination) * (n + 1)) / n**2
- np.sqrt(
0.5 * n**5 * (
2 * n * (
-3 * self.contamination**2
- 2 * n * (1 - self.contamination)**2
+ 4 * self.contamination - 3
)
+ T * (n + 2)**2 - 8
)
) / n**4
)
left_mar = max(0, left_mar)
add_term = 2 * np.sqrt(np.log(2 / delta) / (2 * n))
upperbound_rejectrate = right_mar - left_mar + add_term

# Computing the upper bound for the cost function
# (reuses q1 and q2 computed above for the expected rejection rate)
upperbound_cost = np.min([self.contamination, q1]) * c_fp \
    + np.min([1 - q2, self.contamination]) * c_fn + (q2 - q1) * c_r

if verbose:
print("Expected rejection rate: ", np.round(100 * expected_reject_rate, 4), '%')
print("Upper bound rejection rate: ", np.round(100 * upperbound_rejectrate, 4), '%')
print("Upper bound cost: ", np.round(upperbound_cost, 4))

return expected_reject_rate, upperbound_rejectrate, upperbound_cost
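
# Illustrative usage (a sketch, not part of this changeset): `det` is
# assumed to be any fitted PyOD detector and `X_test` a numpy array.
#
#     labels = det.predict_with_rejection(X_test)   # 0, 1, or -2 (rejected)
#     exp_rate, ub_rate, ub_cost = det.compute_rejection_stats(
#         T=32, delta=0.1, c_fp=1, c_fn=1, c_r=0.1, verbose=True)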

def _predict_rank(self, X, normalized=False):
"""Predict the outlyingness rank of a sample by a fitted model. The
method is for outlier detector score combination.
12 changes: 12 additions & 0 deletions pyod/test/test_abod.py
@@ -103,6 +103,18 @@ def test_prediction_proba_linear_confidence(self):
assert (confidence.min() >= 0)
assert (confidence.max() <= 1)

def test_prediction_with_rejection(self):
pred_labels = self.clf.predict_with_rejection(self.X_test, return_stats=False)
assert_equal(pred_labels.shape, self.y_test.shape)

def test_prediction_with_rejection_stats(self):
_, [expected_rejrate, ub_rejrate, ub_cost] = self.clf.predict_with_rejection(self.X_test, return_stats=True)
assert (expected_rejrate >= 0)
assert (expected_rejrate <= 1)
assert (ub_rejrate >= 0)
assert (ub_rejrate <= 1)
assert (ub_cost >= 0)

def test_fit_predict(self):
pred_labels = self.clf.fit_predict(self.X_train)
assert_equal(pred_labels.shape, self.y_train.shape)
12 changes: 12 additions & 0 deletions pyod/test/test_ae1svm.py
@@ -102,6 +102,18 @@ def test_fit_predict(self):
pred_labels = self.clf.fit_predict(self.X_train)
assert_equal(pred_labels.shape, self.y_train.shape)

def test_prediction_with_rejection(self):
pred_labels = self.clf.predict_with_rejection(self.X_test, return_stats=False)
assert_equal(pred_labels.shape, self.y_test.shape)

def test_prediction_with_rejection_stats(self):
_, [expected_rejrate, ub_rejrate, ub_cost] = self.clf.predict_with_rejection(self.X_test, return_stats=True)
assert (expected_rejrate >= 0)
assert (expected_rejrate <= 1)
assert (ub_rejrate >= 0)
assert (ub_rejrate <= 1)
assert (ub_cost >= 0)

def test_fit_predict_score(self):
self.clf.fit_predict_score(self.X_test, self.y_test)
self.clf.fit_predict_score(self.X_test, self.y_test,
12 changes: 12 additions & 0 deletions pyod/test/test_alad.py
@@ -119,6 +119,18 @@ def test_prediction_proba_linear_confidence(self):
assert (confidence.min() >= 0)
assert (confidence.max() <= 1)

def test_prediction_with_rejection(self):
pred_labels = self.clf.predict_with_rejection(self.X_test, return_stats=False)
assert_equal(pred_labels.shape, self.y_test.shape)

def test_prediction_with_rejection_stats(self):
_, [expected_rejrate, ub_rejrate, ub_cost] = self.clf.predict_with_rejection(self.X_test, return_stats=True)
assert (expected_rejrate >= 0)
assert (expected_rejrate <= 1)
assert (ub_rejrate >= 0)
assert (ub_rejrate <= 1)
assert (ub_cost >= 0)

def test_fit_predict(self):
pred_labels = self.clf.fit_predict(self.X_train)
assert_equal(pred_labels.shape, self.y_train.shape)
12 changes: 12 additions & 0 deletions pyod/test/test_auto_encoder.py
@@ -99,6 +99,18 @@ def test_prediction_proba_linear_confidence(self):
self.assertEqual(confidence.shape, self.y_test.shape)
self.assertInRange(confidence, 0, 1)

def test_prediction_with_rejection(self):
pred_labels = self.clf.predict_with_rejection(self.X_test, return_stats=False)
self.assertEqual(pred_labels.shape, self.y_test.shape)

def test_prediction_with_rejection_stats(self):
_, [expected_rejrate, ub_rejrate, ub_cost] = self.clf.predict_with_rejection(self.X_test, return_stats=True)
self.assertGreaterEqual(expected_rejrate, 0)
self.assertLessEqual(expected_rejrate, 1)
self.assertGreaterEqual(ub_rejrate, 0)
self.assertLessEqual(ub_rejrate, 1)
self.assertGreaterEqual(ub_cost, 0)

def test_fit_predict(self):
pred_labels = self.clf.fit_predict(self.X_train)
self.assertEqual(pred_labels.shape, self.y_train.shape)
12 changes: 12 additions & 0 deletions pyod/test/test_cblof.py
@@ -115,6 +115,18 @@ def test_prediction_proba_linear_confidence(self):
assert_equal(confidence.shape, self.y_test.shape)
assert (confidence.min() >= 0)
assert (confidence.max() <= 1)

def test_prediction_with_rejection(self):
pred_labels = self.clf.predict_with_rejection(self.X_test, return_stats=False)
assert_equal(pred_labels.shape, self.y_test.shape)

def test_prediction_with_rejection_stats(self):
_, [expected_rejrate, ub_rejrate, ub_cost] = self.clf.predict_with_rejection(self.X_test, return_stats=True)
assert (expected_rejrate >= 0)
assert (expected_rejrate <= 1)
assert (ub_rejrate >= 0)
assert (ub_rejrate <= 1)
assert (ub_cost >= 0)

def test_fit_predict(self):
pred_labels = self.clf.fit_predict(self.X_train)
12 changes: 12 additions & 0 deletions pyod/test/test_cd.py
@@ -89,6 +89,18 @@ def test_prediction_labels_confidence(self):
assert (confidence.min() >= 0)
assert (confidence.max() <= 1)

def test_prediction_with_rejection(self):
pred_labels = self.clf.predict_with_rejection(self.X_test, return_stats=False)
assert_equal(pred_labels.shape, self.y_test.shape)

def test_prediction_with_rejection_stats(self):
_, [expected_rejrate, ub_rejrate, ub_cost] = self.clf.predict_with_rejection(self.X_test, return_stats=True)
assert (expected_rejrate >= 0)
assert (expected_rejrate <= 1)
assert (ub_rejrate >= 0)
assert (ub_rejrate <= 1)
assert (ub_cost >= 0)

def test_prediction_proba_linear_confidence(self):
pred_proba, confidence = self.clf.predict_proba(self.X_test,
method='linear',
12 changes: 12 additions & 0 deletions pyod/test/test_cof.py
@@ -99,6 +99,18 @@ def test_prediction_proba_linear_confidence(self):
assert (confidence.min() >= 0)
assert (confidence.max() <= 1)

def test_prediction_with_rejection(self):
pred_labels = self.clf.predict_with_rejection(self.X_test, return_stats=False)
assert_equal(pred_labels.shape, self.y_test.shape)

def test_prediction_with_rejection_stats(self):
_, [expected_rejrate, ub_rejrate, ub_cost] = self.clf.predict_with_rejection(self.X_test, return_stats=True)
assert (expected_rejrate >= 0)
assert (expected_rejrate <= 1)
assert (ub_rejrate >= 0)
assert (ub_rejrate <= 1)
assert (ub_cost >= 0)

def test_fit_predict(self):
pred_labels = self.clf.fit_predict(self.X_train)
assert_equal(pred_labels.shape, self.y_train.shape)
12 changes: 12 additions & 0 deletions pyod/test/test_copod.py
@@ -97,6 +97,18 @@ def test_prediction_proba_linear_confidence(self):
assert (confidence.min() >= 0)
assert (confidence.max() <= 1)

def test_prediction_with_rejection(self):
pred_labels = self.clf.predict_with_rejection(self.X_test, return_stats=False)
assert_equal(pred_labels.shape, self.y_test.shape)

def test_prediction_with_rejection_stats(self):
_, [expected_rejrate, ub_rejrate, ub_cost] = self.clf.predict_with_rejection(self.X_test, return_stats=True)
assert (expected_rejrate >= 0)
assert (expected_rejrate <= 1)
assert (ub_rejrate >= 0)
assert (ub_rejrate <= 1)
assert (ub_cost >= 0)

def test_fit_predict(self):
pred_labels = self.clf.fit_predict(self.X_train)
assert_equal(pred_labels.shape, self.y_train.shape)
12 changes: 12 additions & 0 deletions pyod/test/test_deepsvdd.py
@@ -111,6 +111,18 @@ def test_prediction_proba_linear_confidence(self):
assert (confidence.min() >= 0)
assert (confidence.max() <= 1)

def test_prediction_with_rejection(self):
pred_labels = self.clf.predict_with_rejection(self.X_test, return_stats=False)
assert_equal(pred_labels.shape, self.y_test.shape)

def test_prediction_with_rejection_stats(self):
_, [expected_rejrate, ub_rejrate, ub_cost] = self.clf.predict_with_rejection(self.X_test, return_stats=True)
assert (expected_rejrate >= 0)
assert (expected_rejrate <= 1)
assert (ub_rejrate >= 0)
assert (ub_rejrate <= 1)
assert (ub_cost >= 0)

def test_fit_predict(self):
pred_labels = self.clf.fit_predict(self.X_train)
assert_equal(pred_labels.shape, self.y_train.shape)
12 changes: 12 additions & 0 deletions pyod/test/test_devnet.py
@@ -105,6 +105,18 @@ def test_prediction_proba_linear_confidence(self):
assert (confidence.min() >= 0)
assert (confidence.max() <= 1)

def test_prediction_with_rejection(self):
pred_labels = self.clf.predict_with_rejection(self.X_test, return_stats=False)
assert_equal(pred_labels.shape, self.y_test.shape)

def test_prediction_with_rejection_stats(self):
_, [expected_rejrate, ub_rejrate, ub_cost] = self.clf.predict_with_rejection(self.X_test, return_stats=True)
assert (expected_rejrate >= 0)
assert (expected_rejrate <= 1)
assert (ub_rejrate >= 0)
assert (ub_rejrate <= 1)
assert (ub_cost >= 0)

def test_fit_predict(self):
pred_labels = self.clf.fit_predict(self.X_train, self.y_train)
assert_equal(pred_labels.shape, self.y_train.shape)
12 changes: 12 additions & 0 deletions pyod/test/test_dif.py
@@ -102,6 +102,18 @@ def test_prediction_proba_linear_confidence(self):
assert (confidence.min() >= 0)
assert (confidence.max() <= 1)

def test_prediction_with_rejection(self):
pred_labels = self.clf.predict_with_rejection(self.X_test, return_stats=False)
assert_equal(pred_labels.shape, self.y_test.shape)

def test_prediction_with_rejection_stats(self):
_, [expected_rejrate, ub_rejrate, ub_cost] = self.clf.predict_with_rejection(self.X_test, return_stats=True)
assert (expected_rejrate >= 0)
assert (expected_rejrate <= 1)
assert (ub_rejrate >= 0)
assert (ub_rejrate <= 1)
assert (ub_cost >= 0)

def test_fit_predict(self):
pred_labels = self.clf.fit_predict(self.X_train)
assert_equal(pred_labels.shape, self.y_train.shape)