# mnist_helpers.py
# Plotting helpers for MNIST experiments.
# Standard scientific Python imports
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
import numpy as np
def show_some_digits(images, targets, sample_size=24, title_text='Digit {}'):
    """
    Visualize randomly chosen digits in a grid plot.

    images      - array of flattened digits, one 784-pixel image per row [:, 784]
    targets     - final labels, indexed the same way as ``images``
    sample_size - number of digits to draw; indices come from
                  ``np.random.choice`` with its default settings, so the same
                  digit may be picked more than once
    title_text  - format string for each subplot title; ``{}`` is replaced
                  by the digit's label
    """
    n_cols = 6
    # plt.subplot requires integer grid dimensions, while np.ceil returns a
    # float, so convert explicitly.
    n_rows = int(np.ceil(sample_size / float(n_cols)))

    rand_idx = np.random.choice(images.shape[0], sample_size)
    images_and_labels = zip(images[rand_idx], targets[rand_idx])

    plt.figure(1, figsize=(15, 12), dpi=160)
    for index, (image, label) in enumerate(images_and_labels):
        plt.subplot(n_rows, n_cols, index + 1)
        plt.axis('off')
        # each image is flat, so reshape its 784 values into a 28x28 2D array
        plt.imshow(image.reshape(28, 28), cmap=plt.cm.gray_r,
                   interpolation='nearest')
        plt.title(title_text.format(label))
    plt.show()
def plot_confusion_matrix(cm, title='Confusion matrix', cmap=plt.cm.Blues):
    """
    Plot a confusion matrix as a colour-coded image.

    cm    - confusion matrix, drawn with plt.imshow; the y axis is labelled
            'True label' and the x axis 'Predicted label'
    title - title shown above the plot
    cmap  - matplotlib colormap used to encode the cell values
    """
    plt.figure(1, figsize=(15, 12), dpi=160)
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    # tight_layout is a one-shot adjustment based on the artists that exist
    # when it is called, so run it only after the axis labels are in place.
    plt.tight_layout()
    plt.show()
class MidpointNormalize(Normalize):
    """
    Piecewise-linear normalizer with an adjustable centre.

    Maps vmin -> 0, midpoint -> 0.5 and vmax -> 1, interpolating linearly
    in between. Values outside [vmin, vmax] are clamped to 0 or 1 by
    np.interp.
    """

    def __init__(self, vmin=None, vmax=None, midpoint=None, clip=False):
        # midpoint - data value that should land in the middle of the
        #            colormap; None means halfway between vmin and vmax
        self.midpoint = midpoint
        Normalize.__init__(self, vmin, vmax, clip)

    def __call__(self, value, clip=None):
        """
        Normalize ``value`` into the 0..1 range.

        ``clip`` is accepted for interface compatibility with Normalize but
        is not used: np.interp already clamps out-of-range values.
        """
        # Fill in vmin/vmax from the data when they were left as None, so
        # the normalizer also works when called directly.
        self.autoscale_None(value)

        midpoint = self.midpoint
        if midpoint is None:
            # No explicit centre given: fall back to a plain linear mapping.
            midpoint = 0.5 * (self.vmin + self.vmax)

        x, y = [self.vmin, midpoint, self.vmax], [0, 0.5, 1]
        # Carry over the mask of masked input so masked cells stay masked.
        return np.ma.masked_array(np.interp(value, x, y),
                                  mask=np.ma.getmask(value))
def plot_param_space_heatmap(scores, C_range, gamma_range, vmin=0.5, midpoint=0.9):
    """
    Draw heatmap of the validation accuracy as a function of gamma and C

    Parameters
    ----------
    scores - 2D numpy array with accuracies; the y axis is labelled with
             C_range and the x axis with gamma_range
    C_range - sequence of C values, used as y tick labels
    gamma_range - sequence of gamma values, used as x tick labels
    vmin - score mapped to the low end of the colormap (default 0.5)
    midpoint - score mapped to the centre of the colormap (default 0.9)
    """
    # The scores are encoded as colors with the jet colormap. A custom
    # normalizer puts the centre of the colormap at ``midpoint`` so that small
    # variations among the high scores are easier to see, while the low
    # scores are not all collapsed to the same color.
    norm = MidpointNormalize(vmin=vmin, midpoint=midpoint)

    plt.figure(figsize=(8, 6))
    plt.subplots_adjust(left=.2, right=0.95, bottom=0.15, top=0.95)
    plt.imshow(scores, interpolation='nearest', cmap=plt.cm.jet, norm=norm)
    plt.xlabel('gamma')
    plt.ylabel('C')
    plt.colorbar()
    # One tick per parameter value, labelled with the value itself.
    plt.xticks(np.arange(len(gamma_range)), gamma_range, rotation=45)
    plt.yticks(np.arange(len(C_range)), C_range)
    plt.title('Validation accuracy')
    plt.show()
def plot_param_space_bubble(scores, x_range, y_range):
    """
    Draw a bubble (scatter) plot of the validation accuracy.

    Parameters
    ----------
    scores - numpy array with accuracies; sets both the size and the
             color of each bubble
    x_range - x coordinates of the bubbles (axis is labelled 'C')
    y_range - y coordinates of the bubbles (axis is labelled 'gamma')
    """
    # Fixed look of the bubbles, kept apart from the data-driven arguments.
    bubble_style = dict(cmap="Blues", alpha=0.4, edgecolors="grey", linewidth=2)

    plt.figure(figsize=(8, 6))
    plt.subplots_adjust(left=.2, right=0.95, bottom=0.15, top=0.95)

    # Bubble area and bubble color both follow the score.
    plt.scatter(x_range, y_range, s=scores * 2000, c=scores, **bubble_style)

    plt.title('Validation accuracy')
    plt.xlabel('C')
    plt.ylabel('gamma')
    plt.colorbar()
    plt.show()