-
Notifications
You must be signed in to change notification settings - Fork 14
/
metrics.py
executable file
·36 lines (33 loc) · 1.62 KB
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# Adapted from score written by wkentaro
# https://github.com/wkentaro/pytorch-fcn/blob/master/torchfcn/utils.py
import numpy as np
class runningScore(object):
    """Accumulate a confusion matrix over batches and derive IoU / Dice scores.

    ``confusion_matrix[t, p]`` counts the elements whose true label is ``t``
    and whose predicted label is ``p``.
    """

    # Classes averaged into the "Mean Dice" summary: class 0 is skipped
    # (presumably background -- confirm against the dataset) and at most
    # classes 1..8 are used. With n_classes <= 9 this is every class but 0.
    MEAN_CLASSES = slice(1, 9)

    def __init__(self, n_classes):
        """Create an empty ``n_classes`` x ``n_classes`` confusion matrix."""
        self.n_classes = n_classes
        self.confusion_matrix = np.zeros((n_classes, n_classes))

    def _fast_hist(self, label_true, label_pred, n_class):
        """Return the ``(n_class, n_class)`` confusion matrix of one sample.

        Elements whose true label lies outside ``[0, n_class)`` (e.g. an
        "ignore" label) are dropped. Both label arrays are cast to int so
        float-typed inputs are accepted by ``np.bincount``.

        Raises:
            ValueError: if a kept element has a prediction outside
                ``[0, n_class)``; such a value would otherwise be counted in
                the wrong cell or fail later with an opaque reshape error.
        """
        label_true = np.asarray(label_true)
        label_pred = np.asarray(label_pred)
        mask = (label_true >= 0) & (label_true < n_class)
        true_idx = label_true[mask].astype(int)
        pred_idx = label_pred[mask].astype(int)
        if pred_idx.size and (pred_idx.min() < 0 or pred_idx.max() >= n_class):
            raise ValueError(
                "predicted labels must lie in [0, %d)" % n_class
            )
        # Encode each (true, pred) pair as one flat index, count, then unfold.
        hist = np.bincount(n_class * true_idx + pred_idx,
                           minlength=n_class ** 2)
        return hist.reshape(n_class, n_class)

    def update(self, label_trues, label_preds):
        """Add a batch of (true, predicted) label arrays to the running matrix.

        ``label_trues`` and ``label_preds`` are iterated in lockstep; each
        pair is flattened before counting, so any matching shapes work.
        """
        for lt, lp in zip(label_trues, label_preds):
            self.confusion_matrix += self._fast_hist(
                np.asarray(lt).ravel(), np.asarray(lp).ravel(), self.n_classes
            )

    def get_scores(self):
        """Return ``(scores, cls_iu)`` computed from the accumulated matrix.

        ``scores`` maps ``'Dice : \\t'`` to the per-class Dice array and
        ``'Mean Dice : \\t'`` to the NaN-ignoring mean of the per-class Dice
        over ``MEAN_CLASSES``. ``cls_iu`` maps each class index to its IoU.
        Classes that never occur in either labels or predictions yield NaN.
        """
        hist = self.confusion_matrix
        true_pos = np.diag(hist)
        union = hist.sum(axis=1) + hist.sum(axis=0) - true_pos
        # 0/0 for absent classes is expected; keep the NaN, drop the warning.
        with np.errstate(divide='ignore', invalid='ignore'):
            iu = true_pos / union
        # Dice and IoU are related per class by D = 2J / (1 + J).
        dice = 2 * iu / (iu + 1)
        # Average the per-class Dice itself: applying the IoU->Dice formula to
        # the mean IoU is not the mean Dice and overstates it.
        mean_dice = np.nanmean(dice[self.MEAN_CLASSES])
        cls_iu = dict(zip(range(self.n_classes), iu))
        return {'Dice : \t': dice,
                'Mean Dice : \t': mean_dice}, cls_iu

    def reset(self):
        """Clear all accumulated counts."""
        self.confusion_matrix = np.zeros((self.n_classes, self.n_classes))