Skip to content

Commit

Permalink
Merge pull request #36 from fzi-forschungszentrum-informatik/COMTE
Browse files Browse the repository at this point in the history
Comte
  • Loading branch information
JHoelli committed Aug 8, 2023
2 parents 8158c8d + b949537 commit 503a196
Show file tree
Hide file tree
Showing 19 changed files with 604 additions and 12,456 deletions.
Binary file not shown.
Binary file modified ClassificationModels/models/ECG5000/OneHotEncoder.pkl
Binary file not shown.
Binary file modified ClassificationModels/models/ECG5000/ResNet
Binary file not shown.
Binary file modified ClassificationModels/models/ECG5000/ResNet_confusion_matrix.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
8 changes: 4 additions & 4 deletions ClassificationModels/models/ECG5000/classification_report.csv
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
0,1,2,3,4,accuracy,macro avg,weighted avg
0.9821088694328131,0.9282567652611705,0.6521739130434783,0.32051282051282054,0.5,0.9248888888888889,0.6766104736500564,0.928689990417054
0.9821088694328131,0.9276729559748428,0.3488372093023256,0.42857142857142855,0.09090909090909091,0.9248888888888889,0.5556199108381001,0.9248888888888889
0.9821088694328131,0.9279647687952186,0.4545454545454546,0.36674816625916873,0.15384615384615385,0.9248888888888889,0.5770426825757616,0.9249156524345059
2627.0,1590.0,86.0,175.0,22.0,0.9248888888888889,4500.0,4500.0
0.967861100849649,0.9443671766342142,0.6530612244897959,0.2962962962962963,0.5555555555555556,0.9117777777777778,0.6834282707651023,0.9254116138134725
0.9973353635325466,0.8540880503144654,0.37209302325581395,0.5028571428571429,0.22727272727272727,0.9117777777777778,0.5907292614465391,0.9117777777777778
0.9823772028496438,0.8969616908850727,0.47407407407407404,0.37288135593220334,0.3225806451612903,0.9117777777777778,0.6097749937804569,0.9155545293878521
2627.0,1590.0,86.0,175.0,22.0,0.9117777777777778,4500.0,4500.0
5 changes: 0 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,6 @@ pip install https://github.com/fzi-forschungszentrum-informatik/TSInterpret/arch
```


Due to the <a href='https://github.com/scikit-learn/sklearn-pypi-package'>sklearn brownout</a>, `pip install sklearn` is no longer available in third-party dependencies. As the current release of <a href='https://github.com/gkhayes/mlrose'>mlrose</a> still relies on sklearn, we eliminated the dependency. If you still want to use COMTE (with its dependency on mlrose), it can be installed via:
```shell
pip install https://github.com/gkhayes/mlrose/archive/refs/heads/master.zip
```


## 🍫 Quickstart
The following example creates a simple Neural Network based on TensorFlow and interprets the Classifier with Integrated Gradients and Temporal Saliency Rescaling [1].
Expand Down
2 changes: 1 addition & 1 deletion TSInterpret/InterpretabilityModels/counterfactual/CF.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def plot_in_one(
"""
if self.mode == "time":
item = item.reshape(item.shape[-1], item.shape[-2])
exp = exp.reshape(item.shape[-1], item.shape[-2])
exp = exp.reshape(exp.shape[-1], exp.shape[-2])
else:
item = item.reshape(item.shape[-2], item.shape[-1])
exp = exp.reshape(item.shape[-2], item.shape[-1])
Expand Down
Original file line number Diff line number Diff line change
@@ -1,23 +1,18 @@
import logging
import multiprocessing
import numbers
import sys
from typing import Tuple
import numpy as np
import pandas as pd

# Workaround for mlrose package
import six
from sklearn.neighbors import KDTree
from skopt import gbrt_minimize, gp_minimize

sys.modules["sklearn.externals.six"] = six
import mlrose

from TSInterpret.InterpretabilityModels.counterfactual.CF import CF
from TSInterpret.Models.PyTorchModel import PyTorchModel
from TSInterpret.Models.SklearnModel import SklearnModel
from TSInterpret.Models.TensorflowModel import TensorFlowModel
from TSInterpret.InterpretabilityModels.counterfactual.COMTE.Problem import (
LossDiscreteState,
Problem,
)
from TSInterpret.InterpretabilityModels.counterfactual.COMTE.Optmization_helpers import (
random_hill_climb,
)


class BaseExplanation:
Expand Down Expand Up @@ -91,9 +86,6 @@ def _get_distractors(self, x_test, to_maximize, n_distractors=2):
if isinstance(to_maximize, numbers.Integral):
to_maximize = np.unique(self.labels)[to_maximize]
distractors = []
# print('to_maximize',to_maximize)
# print('Class Tree',self.per_class_trees)
# print('Class Tree with id',self.per_class_trees[to_maximize])
for idx in (
self.per_class_trees[to_maximize]
.query(x_test.T.flatten().reshape(1, -1), k=n_distractors)[1]
Expand Down Expand Up @@ -332,58 +324,6 @@ def explain(self, x_test, to_maximize=None, num_features=10):
return other, target


class LossDiscreteState:
    """Loss over a binary swap mask between a sample and a distractor.

    Each entry of the evaluated mask corresponds to one entry of
    ``cols_swap``; a value of ``1`` replaces that row of the sample with the
    same row of the distractor. The loss combines a hinge-style penalty on the
    classifier output for ``label_idx`` with a penalty on swapping more than
    ``max_features`` rows.
    """

    def __init__(
        self,
        label_idx,
        clf,
        x_test,
        distractor,
        cols_swap,
        reg,
        max_features=3,
        maximize=True,
    ):
        """
        Arguments:
            label_idx: index into the classifier output whose score is optimized.
            clf: callable applied to an array of shape
                `(1, channels, window_size)`; its result is indexed as
                `clf(x)[0][label_idx]`.
            x_test: sample to modify; the last two axes are read as
                `(channels, window_size)`.
            distractor: array indexed like `x_test` that provides the
                replacement rows.
            cols_swap: row indices (along the channel axis) that may be swapped.
            reg: weight of the penalty on exceeding `max_features`.
            max_features: number of swaps allowed without penalty; `None`
                falls back to 3.
            maximize: if True, `evaluate` returns the negated loss.
        """
        self.target = label_idx
        self.clf = clf
        self.x_test = x_test
        self.reg = reg
        self.distractor = distractor
        self.cols_swap = cols_swap  # Column names that we can swap
        self.prob_type = "discrete"
        self.max_features = 3 if max_features is None else max_features
        self.maximize = maximize
        self.window_size = x_test.shape[-1]
        self.channels = x_test.shape[-2]

    def __call__(self, feature_matrix):
        """Alias for :meth:`evaluate` so the instance can be used as a function."""
        return self.evaluate(feature_matrix)

    def evaluate(self, feature_matrix):
        """Return the loss of the swap mask ``feature_matrix``.

        Arguments:
            feature_matrix: sequence of 0/1 flags, one per entry of
                `cols_swap`.
        Returns:
            The loss value, negated when `maximize` is True.
        Raises:
            ValueError: if the mask length differs from `len(cols_swap)`.
        """
        # Explicit check instead of `assert`, which is stripped under `-O`.
        if len(self.cols_swap) != len(feature_matrix):
            raise ValueError(
                f"feature_matrix has {len(feature_matrix)} entries, "
                f"expected {len(self.cols_swap)} (one per swappable column)"
            )

        # Work on a copy so the stored sample is never modified.
        new_case = self.x_test.copy()
        for col_replace, flag in zip(self.cols_swap, feature_matrix):
            if flag == 1:
                new_case[0][col_replace] = self.distractor[0][col_replace]

        replaced_feature_count = np.sum(feature_matrix)

        input_ = new_case.reshape(1, self.channels, self.window_size)
        result = self.clf(input_)[0][self.target]

        # Only swaps beyond the allowed budget are penalized.
        feature_loss = self.reg * np.maximum(
            0, replaced_feature_count - self.max_features
        )
        # Zero once the target score reaches 0.95, quadratic below that.
        loss_pred = np.square(np.maximum(0, 0.95 - result))

        loss_pred = loss_pred + feature_loss
        return -loss_pred if self.maximize else loss_pred

    def get_prob_type(self):
        """Return the problem type string (always ``"discrete"``)."""
        return self.prob_type


class OptimizedSearch(BaseExplanation):
def __init__(
self,
Expand Down Expand Up @@ -414,10 +354,8 @@ def opt_Discrete(self, to_maximize, x_test, dist, columns, init, num_features=No
max_features=num_features,
maximize=False,
)
problem = mlrose.DiscreteOpt(
length=len(columns), fitness_fn=fitness_fn, maximize=False, max_val=2
)
best_state, best_fitness = mlrose.random_hill_climb(
problem = Problem(length=len(columns), loss=fitness_fn, max_val=2)
best_state, best_fitness = random_hill_climb(
problem,
max_attempts=self.max_attemps,
max_iters=self.maxiter,
Expand Down Expand Up @@ -469,9 +407,6 @@ def explain(
if to_maximize is None:
to_maximize = np.argsort(orig_preds)[0][-2:-1][0]

# print('Current may',np.argmax(orig_preds))
# print(to_maximize)

if orig_label == to_maximize:
print("Original and Target Label are identical !")
return None, None
Expand All @@ -494,6 +429,8 @@ def _get_explanation(self, x_test, to_maximize, num_features):
distractors = self._get_distractors(
x_test, to_maximize, n_distractors=self.num_distractors
)
# print('distracotr shape',np.array(distractors).shape)
# print('distracotr classification',np.argmax(self.clf(np.array(distractors).reshape(2,6,100)), axis=1))

# Avoid constructing KDtrees twice
self.backup.per_class_trees = self.per_class_trees
Expand Down Expand Up @@ -537,7 +474,6 @@ def _get_explanation(self, x_test, to_maximize, num_features):

if not self.silent:
logging.info("Current probas: %s", probas)

if np.argmax(probas) == to_maximize:
current_best = np.max(probas)
if current_best > best_explanation_score:
Expand All @@ -549,100 +485,3 @@ def _get_explanation(self, x_test, to_maximize, num_features):
return None, None

return best_modified, best_explanation


class AtesCF(CF):
    """Calculates and Visualizes Counterfactuals for Multivariate Time Series in accordance to the paper [1].

    References
    ----------
    [1] Ates, Emre, et al.
    "Counterfactual Explanations for Multivariate Time Series."
    2021 International Conference on Applied Artificial Intelligence (ICAPAI). IEEE, 2021.
    ----------
    """

    def __init__(
        self,
        model,
        data,
        backend,
        mode,
        method="opt",
        number_distractors=2,
        max_attempts=1000,
        max_iter=1000,
        silent=False,
    ) -> None:
        """
        Arguments:
            model [torch.nn.Module, Callable, tf.keras.model]: Model to be interpreted.
            data Tuple: Reference Dataset as Tuple (x,y).
            backend str: desired Model Backend ('PYT', 'TF', 'SK').
            mode str: Name of second dimension: `time` -> `(-1, time, feature)` or `feat` -> `(-1, feature, time)`
            method str : 'opt' if optimized calculation, 'brute' for Brute Force
            number_distractors int: number of distractors to be used
            max_attempts int: forwarded to the optimized search as `max_attempts`.
            max_iter int: forwarded to the optimized search as `maxiter`.
            silent bool: logging.
        Raises:
            ValueError: if `mode` is neither 'time' nor 'feat'.
        """
        super().__init__(model, mode)
        self.backend = backend
        test_x, test_y = data
        shape = test_x.shape
        if mode == "time":
            # Parse test data into (1, feat, time):
            change = True
            self.ts_length = shape[-2]
            test_x = test_x.reshape(test_x.shape[0], test_x.shape[2], test_x.shape[1])
        elif mode == "feat":
            change = False
            self.ts_length = shape[-1]
        else:
            # Without this, `change` stays unbound and the wrapper construction
            # below fails with an unrelated-looking error.
            raise ValueError(f"Unsupported mode {mode!r}; expected 'time' or 'feat'.")

        # NOTE(review): an unrecognized backend leaves `self.predict` unset
        # here — confirm whether the base class provides a fallback.
        if backend == "PYT":
            self.predict = PyTorchModel(model, change).predict
        elif backend == "TF":
            self.predict = TensorFlowModel(model, change).predict
        elif backend == "SK":
            self.predict = SklearnModel(model, change).predict

        self.referenceset = (test_x, test_y)
        self.method = method
        self.silent = silent
        self.number_distractors = number_distractors
        self.max_attemps = max_attempts
        self.max_iter = max_iter

    def explain(
        self, x: np.ndarray, orig_class: int = None, target: int = None
    ) -> Tuple[np.ndarray, int]:
        """
        Calculates the Counterfactual according to Ates.
        Arguments:
            x (np.array): The instance to explain. Shape : `mode = time` -> `(1,time, feat)` or `mode = feat` -> `(1,feat, time)`
            orig_class int: currently unused.
            target int: target class. If no target class is given, the class with the second highest classification probability is selected.
        Returns:
            ([np.array], int): Tuple of Counterfactual and Label. Shape of CF : `mode = time` -> `(time, feat)` or `mode = feat` -> `(feat, time)`
        Raises:
            ValueError: if `self.method` is neither 'opt' nor 'brute'.
        """
        if self.mode != "feat":
            # Bring the instance into the same layout as the reference set.
            x = x.reshape(-1, x.shape[-1], x.shape[-2])
        train_x, train_y = self.referenceset
        if len(train_y.shape) > 1:
            # One-hot / probability labels are collapsed to class indices.
            train_y = np.argmax(train_y, axis=1)

        if self.method == "opt":
            opt = OptimizedSearch(
                self.predict,
                train_x,
                train_y,
                silent=self.silent,
                threads=1,
                num_distractors=self.number_distractors,
                max_attempts=self.max_attemps,
                maxiter=self.max_iter,
            )
            return opt.explain(x, to_maximize=target)
        if self.method == "brute":
            opt = BruteForceSearch(self.predict, train_x, train_y, threads=1)
            return opt.explain(x, to_maximize=target)

        # Previously fell through and returned None, which breaks tuple unpacking
        # at the call site.
        raise ValueError(
            f"Unsupported method {self.method!r}; expected 'opt' or 'brute'."
        )
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import numpy as np


def random_hill_climb(
    problem,
    max_attempts=10,
    max_iters=np.inf,
    restarts=0,
    init_state=None,
    curve=False,
    random_state=None,
):
    """Minimize the fitness of ``problem`` with randomized hill climbing.

    Starting from ``init_state`` (or from ``problem.reset()``), a random
    neighbor is drawn each iteration and adopted only if its fitness is
    strictly lower than the current one. The search is run ``restarts + 1``
    times and the lowest final fitness over all runs is reported.

    Arguments:
        problem: object providing `reset`, `set_state`, `get_state`,
            `get_fitness`, `eval_fitness` and `random_neighbor`.
        max_attempts int: consecutive non-improving neighbors tolerated
            before a run stops.
        max_iters: maximum number of iterations per run.
        restarts int: number of additional runs after the first one.
        init_state: state every run starts from; if None, `problem.reset()`
            chooses the start.
        curve bool: if True, the fitness after every iteration is recorded
            and plotted with matplotlib before returning.
        random_state: non-negative integer used to seed NumPy's global
            random generator; any other value leaves the generator untouched.
    Returns:
        (best_state, best_fitness): state and fitness of the best run.
        `best_state` is None if no run ended with a fitness below infinity.
    """
    # Zero is a valid seed, and NumPy integer types are accepted as well.
    # NOTE(review): reproducibility assumes `problem` draws from NumPy's
    # global generator — confirm against the Problem implementation.
    if isinstance(random_state, (int, np.integer)) and random_state >= 0:
        np.random.seed(random_state)

    best_fitness = np.inf
    best_state = None
    fitness_values = []

    for _ in range(restarts + 1):
        # Initialize optimization problem and attempts counter
        if init_state is None:
            problem.reset()
        else:
            problem.set_state(init_state)

        attempts = 0
        iters = 0

        while (attempts < max_attempts) and (iters < max_iters):
            iters += 1

            # Find random neighbor and evaluate fitness
            next_state = problem.random_neighbor()
            next_fitness = problem.eval_fitness(next_state)

            if next_fitness < problem.get_fitness():
                # Improvement: move there and reset the patience counter.
                problem.set_state(next_state)
                attempts = 0
            else:
                attempts += 1

            if curve:
                fitness_values.append(problem.get_fitness())

        # Update best state and best fitness
        if problem.get_fitness() < best_fitness:
            best_fitness = problem.get_fitness()
            best_state = problem.get_state()

    if curve:
        # Imported lazily so matplotlib is only needed when plotting.
        import matplotlib.pyplot as plt

        plt.plot(np.asarray(fitness_values))
        plt.show()

    return best_state, best_fitness
Loading

0 comments on commit 503a196

Please sign in to comment.