From 4c53de22ff735b3c226f23a799830a69ac0fa09f Mon Sep 17 00:00:00 2001 From: rmazzine Date: Mon, 11 Sep 2023 20:29:49 -0300 Subject: [PATCH 1/4] Implement Counterplots Signed-off-by: rmazzine --- dice_ml/counterfactual_explanations.py | 53 ++++++++++++++- requirements.txt | 1 + tests/test_counterfactual_explanations.py | 79 +++++++++++++++++++++++ 3 files changed, 132 insertions(+), 1 deletion(-) diff --git a/dice_ml/counterfactual_explanations.py b/dice_ml/counterfactual_explanations.py index c3a5dcea..4c25f336 100644 --- a/dice_ml/counterfactual_explanations.py +++ b/dice_ml/counterfactual_explanations.py @@ -2,9 +2,12 @@ import os import jsonschema +import numpy as np +import pandas as pd +from counterplots import CreatePlot from raiutils.exceptions import UserConfigValidationException -from dice_ml.constants import _SchemaVersions +from dice_ml.constants import _SchemaVersions, BackEndTypes from dice_ml.diverse_counterfactuals import (CounterfactualExamples, _DiverseCFV2SchemaConstants) @@ -111,6 +114,54 @@ def visualize_as_list(self, display_sparse_df=True, display_sparse_df=display_sparse_df, show_only_changes=show_only_changes) + def plot_counterplots(self, dice_model): + """Plot counterfactual with CounterPlots package. + + :param dice_model: DiCE's model object. + """ + counterplots_out = [] + for cf_examples in self.cf_examples_list: + features_names = list(cf_examples.test_instance_df.columns)[:-1] + features_dtypes = list(cf_examples.test_instance_df.dtypes)[:-1] + factual_instance = cf_examples.test_instance_df.to_numpy()[0][:-1] + + def convert_data(x): + df_x = pd.DataFrame(data=x, columns=features_names) + # Transform each dtype according to features_dtypes + for feature_name, f_dtype in zip(features_names, features_dtypes): + df_x[feature_name] = df_x[feature_name].astype(f_dtype) + + return df_x + + if dice_model.backend == BackEndTypes.Sklearn: + factual_class_idx = np.argmax( + dice_model.model.predict_proba(convert_data([factual_instance]))) + def model_pred(x): + # Use one against all strategy + pred_prob = dice_model.model.predict_proba(convert_data(x)) + class_f_proba = pred_prob[:, factual_class_idx] + + # Probability for all other classes (excluding class 0) + not_class_f_proba = 1 - class_f_proba + + # Normalize to sum to 1 + class_f_proba = class_f_proba / (class_f_proba + not_class_f_proba) + + return class_f_proba + else: + def model_pred(x): + return dice_model.model.predict(dice_model.transformer.transform(convert_data(x))) + + for cf_instance in cf_examples.final_cfs_df.to_numpy(): + counterplots_out.append( + CreatePlot( + factual=factual_instance, + cf=cf_instance[:-1], + model_pred=model_pred, + feature_names=features_names, + )) + return counterplots_out + @staticmethod def _check_cf_exp_output_against_json_schema( cf_dict, version): diff --git a/requirements.txt b/requirements.txt index 7f89d1f4..1c2686cb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,3 +4,4 @@ pandas<2.0.0 scikit-learn tqdm raiutils>=0.4.0 +counterplots>=0.0.7 \ No newline at end of file diff --git a/tests/test_counterfactual_explanations.py b/tests/test_counterfactual_explanations.py index 4dcb5628..e6163a6f 100644 --- a/tests/test_counterfactual_explanations.py +++ b/tests/test_counterfactual_explanations.py @@ -1,8 +1,12 @@ import json import pytest +import unittest +from unittest.mock import patch, Mock from raiutils.exceptions import UserConfigValidationException +import pandas as pd +import numpy as np from dice_ml.counterfactual_explanations import CounterfactualExplanations @@ -319,3 +323,78 @@ def test_unsupported_versions_to_json(self, unsupported_version): counterfactual_explanations.to_json() assert "Unsupported serialization version {}".format(unsupported_version) in str(ucve) + + +class TestCounterfactualExplanations(unittest.TestCase): + + @patch('dice_ml.counterfactual_explanations.CreatePlot', return_value="dummy_plot") + def test_plot_counterplots_sklearn(self, mock_create_plot): + # Dummy DiCE's model object with a Sklearn backend + dummy_model = Mock() + dummy_model.backend = "sklearn" + dummy_model.model.predict_proba = Mock(return_value=np.array([[0.4, 0.6], [0.2, 0.8]])) + + # Sample cf_examples to test with + cf_examples_mock = Mock() + cf_examples_mock.test_instance_df = pd.DataFrame({ + 'feature1': [1], + 'feature2': [2], + 'target': [0] + }) + cf_examples_mock.final_cfs_df = pd.DataFrame({ + 'feature1': [1.1, 1.2], + 'feature2': [2.1, 2.2], + 'target': [1, 1] + }) + + counterfact = CounterfactualExplanations( + cf_examples_list=[cf_examples_mock], + local_importance=None, + summary_importance=None, + version=None) + + # Call function + result = counterfact.plot_counterplots(dummy_model) + + # Assert the CreatePlot was called twice (as there are 2 counterfactual instances) + self.assertEqual(mock_create_plot.call_count, 2) + + # Assert that the result is as expected + self.assertEqual(result, ["dummy_plot", "dummy_plot"]) + + @patch('dice_ml.counterfactual_explanations.CreatePlot', return_value="dummy_plot") + def test_plot_counterplots_non_sklearn(self, mock_create_plot): + # Sample Non-Sklearn backend + dummy_model = Mock() + dummy_model.backend = "NonSklearn" + dummy_model.model.predict = Mock(return_value=np.array([0, 1])) + dummy_model.transformer = Mock() + dummy_model.transformer.transform = Mock(return_value=np.array([[1, 2], [1.1, 2.1]])) + + # Sample cf_examples to test with + cf_examples_mock = Mock() + cf_examples_mock.test_instance_df = pd.DataFrame({ + 'feature1': [1], + 'feature2': [2], + 'target': [0] + }) + cf_examples_mock.final_cfs_df = pd.DataFrame({ + 'feature1': [1.1, 1.2], + 'feature2': [2.1, 2.2], + 'target': [1, 1] + }) + + counterfact = CounterfactualExplanations( + cf_examples_list=[cf_examples_mock], + local_importance=None, + summary_importance=None, + version=None) + + # Call function + result = counterfact.plot_counterplots(dummy_model) + + # Assert the CreatePlot was called twice (as there are 2 counterfactual instances) + self.assertEqual(mock_create_plot.call_count, 2) + + # Assert that the result is as expected + self.assertEqual(result, ["dummy_plot", "dummy_plot"]) From 6c0799212a857d457bf9379c7d946403fc16b123 Mon Sep 17 00:00:00 2001 From: rmazzine Date: Mon, 11 Sep 2023 20:49:04 -0300 Subject: [PATCH 2/4] Fix importing order Signed-off-by: rmazzine --- dice_ml/counterfactual_explanations.py | 2 +- tests/test_counterfactual_explanations.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/dice_ml/counterfactual_explanations.py b/dice_ml/counterfactual_explanations.py index 4c25f336..39b45a5b 100644 --- a/dice_ml/counterfactual_explanations.py +++ b/dice_ml/counterfactual_explanations.py @@ -7,7 +7,7 @@ from counterplots import CreatePlot from raiutils.exceptions import UserConfigValidationException -from dice_ml.constants import _SchemaVersions, BackEndTypes +from dice_ml.constants import BackEndTypes, _SchemaVersions from dice_ml.diverse_counterfactuals import (CounterfactualExamples, _DiverseCFV2SchemaConstants) diff --git a/tests/test_counterfactual_explanations.py b/tests/test_counterfactual_explanations.py index e6163a6f..65311564 100644 --- a/tests/test_counterfactual_explanations.py +++ b/tests/test_counterfactual_explanations.py @@ -1,12 +1,12 @@ import json +import unittest +from unittest.mock import Mock, patch +import numpy as np +import pandas as pd import pytest -import unittest -from unittest.mock import patch, Mock from raiutils.exceptions import UserConfigValidationException -import pandas as pd -import numpy as np from dice_ml.counterfactual_explanations import CounterfactualExplanations From 770121c742b1bf508d6ad9cf0d9fcee9ff5d384e Mon Sep 17 00:00:00 2001 From: rmazzine Date: Mon, 11 Sep 2023 21:22:07 -0300 Subject: [PATCH 3/4] Correct DataFrame numeric handle Signed-off-by: rmazzine --- dice_ml/counterfactual_explanations.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/dice_ml/counterfactual_explanations.py b/dice_ml/counterfactual_explanations.py index 39b45a5b..5b67fee7 100644 --- a/dice_ml/counterfactual_explanations.py +++ b/dice_ml/counterfactual_explanations.py @@ -121,15 +121,15 @@ def plot_counterplots(self, dice_model): """ counterplots_out = [] for cf_examples in self.cf_examples_list: - features_names = list(cf_examples.test_instance_df.columns)[:-1] - features_dtypes = list(cf_examples.test_instance_df.dtypes)[:-1] + self.features_names = list(cf_examples.test_instance_df.columns)[:-1] + self.features_dtypes = list(cf_examples.test_instance_df.dtypes)[:-1] factual_instance = cf_examples.test_instance_df.to_numpy()[0][:-1] def convert_data(x): - df_x = pd.DataFrame(data=x, columns=features_names) + df_x = pd.DataFrame(data=x, columns=self.features_names) # Transform each dtype according to features_dtypes - for feature_name, f_dtype in zip(features_names, features_dtypes): - df_x[feature_name] = df_x[feature_name].astype(f_dtype) + for feature_name, f_dtype in zip(self.features_names, self.features_dtypes): + df_x[feature_name] = pd.to_numeric(df_x[feature_name], errors='ignore').astype(f_dtype) return df_x @@ -158,7 +158,7 @@ def model_pred(x): factual=factual_instance, cf=cf_instance[:-1], model_pred=model_pred, - feature_names=features_names, + feature_names=self.features_names, )) return counterplots_out From a0f37a9d29edf417747878fe065b417c5e24bc12 Mon Sep 17 00:00:00 2001 From: rmazzine Date: Mon, 11 Sep 2023 21:25:56 -0300 Subject: [PATCH 4/4] Complay with flake8 rules Signed-off-by: rmazzine --- dice_ml/counterfactual_explanations.py | 5 +++-- tests/test_counterfactual_explanations.py | 10 +++++----- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/dice_ml/counterfactual_explanations.py b/dice_ml/counterfactual_explanations.py index 5b67fee7..ac1ea7e3 100644 --- a/dice_ml/counterfactual_explanations.py +++ b/dice_ml/counterfactual_explanations.py @@ -134,12 +134,13 @@ def convert_data(x): return df_x if dice_model.backend == BackEndTypes.Sklearn: - factual_class_idx = np.argmax( + self.factual_class_idx = np.argmax( dice_model.model.predict_proba(convert_data([factual_instance]))) + def model_pred(x): # Use one against all strategy pred_prob = dice_model.model.predict_proba(convert_data(x)) - class_f_proba = pred_prob[:, factual_class_idx] + class_f_proba = pred_prob[:, self.factual_class_idx] # Probability for all other classes (excluding class 0) not_class_f_proba = 1 - class_f_proba diff --git a/tests/test_counterfactual_explanations.py b/tests/test_counterfactual_explanations.py index 65311564..7593701a 100644 --- a/tests/test_counterfactual_explanations.py +++ b/tests/test_counterfactual_explanations.py @@ -325,7 +325,7 @@ def test_unsupported_versions_to_json(self, unsupported_version): assert "Unsupported serialization version {}".format(unsupported_version) in str(ucve) -class TestCounterfactualExplanations(unittest.TestCase): +class TestCounterfactualExplanationsPlot(unittest.TestCase): @patch('dice_ml.counterfactual_explanations.CreatePlot', return_value="dummy_plot") def test_plot_counterplots_sklearn(self, mock_create_plot): @@ -357,10 +357,10 @@ def test_plot_counterplots_sklearn(self, mock_create_plot): result = counterfact.plot_counterplots(dummy_model) # Assert the CreatePlot was called twice (as there are 2 counterfactual instances) - self.assertEqual(mock_create_plot.call_count, 2) + assert mock_create_plot.call_count == 2 # Assert that the result is as expected - self.assertEqual(result, ["dummy_plot", "dummy_plot"]) + assert result == ["dummy_plot", "dummy_plot"] @patch('dice_ml.counterfactual_explanations.CreatePlot', return_value="dummy_plot") def test_plot_counterplots_non_sklearn(self, mock_create_plot): @@ -394,7 +394,7 @@ def test_plot_counterplots_non_sklearn(self, mock_create_plot): result = counterfact.plot_counterplots(dummy_model) # Assert the CreatePlot was called twice (as there are 2 counterfactual instances) - self.assertEqual(mock_create_plot.call_count, 2) + assert mock_create_plot.call_count == 2 # Assert that the result is as expected - self.assertEqual(result, ["dummy_plot", "dummy_plot"]) + assert result == ["dummy_plot", "dummy_plot"]