-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
122 lines (101 loc) · 4.01 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import random
import os
import tensorflow as tf
import tensorflow_hub as hub
import datetime
def view_random_images(target_dir, target_class):
    """Plot one randomly chosen image from the ``target_class`` sub-directory.

    Args:
        target_dir: Directory that contains one sub-directory per class. It
            may be given with or without a trailing path separator.
        target_class: Name of the class sub-directory to sample from; it is
            also used as the plot title.

    Returns:
        The image array returned by ``matplotlib.image.imread``.

    Raises:
        ValueError: if the class directory is empty (raised by
            ``random.sample``).
    """
    # os.path.join works whether or not target_dir ends with a separator;
    # plain string concatenation silently produced paths like
    # "data/trainpizza" when the trailing slash was missing.
    target_folder = os.path.join(target_dir, target_class)
    # Pick one directory entry at random.
    random_image = random.sample(os.listdir(target_folder), 1)[0]
    img = mpimg.imread(os.path.join(target_folder, random_image))
    plt.imshow(img)
    plt.title(target_class)
    plt.axis("off")
    print(f"Image shape: {img.shape}")  # show the shape of the image
    return img
def plot_training_curves(history):
    """Plot accuracy and loss curves from a Keras training history.

    Draws two side-by-side panels: accuracy on the left and loss on the right.
    Each panel shows the training curve and, when the history contains it,
    the matching validation curve.

    Args:
        history: Object with a ``history`` dict attribute (e.g. the value
            returned by ``model.fit``). It must contain ``'accuracy'`` and
            ``'loss'``. ``'val_accuracy'`` and ``'val_loss'`` are plotted only
            if present.
    """
    metrics = history.history
    # (history key, panel title, y-axis label) for each subplot, left to right.
    panels = (
        ("accuracy", "Model accuracy", "Accuracy"),
        ("loss", "Model loss", "Loss"),
    )
    plt.figure(figsize=(12, 4))
    for position, (metric, title, ylabel) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(metrics[metric])
        legend = ["Train"]
        # Validation keys exist only when fit() was given validation data;
        # indexing them unconditionally raised KeyError otherwise.
        val_metric = f"val_{metric}"
        if val_metric in metrics:
            plt.plot(metrics[val_metric])
            # These curves come from the validation split, not a test set.
            legend.append("Validation")
        plt.title(title)
        plt.ylabel(ylabel)
        plt.xlabel("Epoch")
        plt.legend(legend, loc="upper left")
    plt.tight_layout()
    plt.show()
# Load an image from disk and resize it so it can be used with our model
# (prediction on custom data).
def load_and_prep_image(filename, img_shape=224):
    """Read an image file and turn it into a model-ready float tensor.

    The file is decoded as a 3-channel (RGB) image, resized to
    ``(img_shape, img_shape)`` and rescaled so pixel values lie in [0, 1].

    Args:
        filename: Path to an image file readable by ``tf.io.decode_image``
            (JPEG, PNG, BMP or GIF).
        img_shape: Side length, in pixels, of the square output image.

    Returns:
        A tensor of shape ``(img_shape, img_shape, 3)``.
    """
    img = tf.io.read_file(filename)
    # channels=3 converts grayscale and RGBA files to RGB, so every image has
    # the channel count the model expects. expand_animations=False returns only
    # the first frame of a GIF as a 3-D tensor instead of a 4-D stack of frames.
    img = tf.image.decode_image(img, channels=3, expand_animations=False)
    img = tf.image.resize(img, size=[img_shape, img_shape])
    # Rescale the image (get all values between 0 and 1).
    return img / 255.
# Predict on a single image file and plot it with the predicted class name;
# works for both binary (single output) and multi-class models.
def pred_and_plot(model, filename, class_names):
    """Predict the class of the image at ``filename`` and plot it.

    The image is preprocessed with ``load_and_prep_image`` and passed to the
    model as a batch of one. The predicted class name is shown as the plot
    title.

    Args:
        model: Model exposing ``predict``. Assumes it returns one row of scores
            per input image: a single score for binary models, one score per
            class otherwise (TODO confirm for custom models).
        filename: Path of the image to classify.
        class_names: Sequence mapping class index to class name.
    """
    img = load_and_prep_image(filename)
    # The model expects a batch, so add a leading batch dimension of size 1.
    pred = model.predict(tf.expand_dims(img, axis=0))
    scores = pred[0]
    if len(scores) > 1:
        # Multi-class: take the index of the highest score.
        class_index = int(tf.argmax(scores))
    else:
        # Binary: round the single score to 0 or 1. Take the scalar element
        # explicitly; calling int() on a shape-(1,) array/tensor relies on
        # NumPy's deprecated array-to-scalar conversion.
        class_index = int(tf.round(scores[0]))
    pred_class = class_names[class_index]
    # Plot the image and predicted class.
    plt.imshow(img)
    plt.title(f"Prediction: {pred_class}")
    plt.axis(False)
def create_feature_extractor_th(model_url, input_shape=(224, 224, 3), trainable=False):
    """Wrap a pretrained TensorFlow Hub model as a Keras layer.

    Args:
        model_url: TF Hub handle or URL of the pretrained model to load.
        input_shape: Shape of one input example, without the batch dimension.
            Defaults to ``(224, 224, 3)``.
        trainable: Whether the pretrained weights may be updated during
            training. Defaults to ``False``, which freezes the already-learned
            patterns for pure feature extraction.

    Returns:
        A ``hub.KerasLayer`` named ``'feature_extraction_layer'``.
    """
    return hub.KerasLayer(model_url,
                          trainable=trainable,
                          name='feature_extraction_layer',
                          input_shape=input_shape)
# Walk through a directory tree and report its contents.
def walk_through_dir(path):
    """Print how many sub-directories and files each directory under ``path`` holds.

    Directories are visited in ``os.walk`` order (top-down, starting with
    ``path`` itself). Every regular directory entry that is not a directory
    is counted as an "image", whatever its actual file type.
    """
    for current_dir, subdirs, files in os.walk(path):
        n_dirs = len(subdirs)
        n_files = len(files)
        print(f"There are {n_dirs} directories and {n_files} images in {current_dir}")
# Create a model checkpoint callback that saves the model's weights only.
def checkpoint_callback_weights_only(path, save_best_only=False, monitor='val_accuracy'):
    """Build a ``ModelCheckpoint`` callback that writes only model weights.

    Args:
        path: File path where the weights are written.
        save_best_only: If ``True``, a checkpoint is written only when
            ``monitor`` improves. Defaults to ``False`` (save every epoch).
        monitor: Metric name that Keras consults when ``save_best_only`` is
            ``True``. TF2 Keras logs validation accuracy as ``'val_accuracy'``
            when compiled with ``metrics=['accuracy']``; the previous
            hard-coded ``'val_acc'`` would not be found.

    Returns:
        A configured ``tf.keras.callbacks.ModelCheckpoint``.
    """
    return tf.keras.callbacks.ModelCheckpoint(
        filepath=path,
        save_weights_only=True,
        save_best_only=save_best_only,
        save_freq='epoch',  # save at the end of every epoch
        verbose=0,  # don't print a message on each save
        monitor=monitor,
    )
# Create a TensorBoard callback that logs into a timestamped run directory.
def create_tensorboard_callback(dir_name, experiment_name):
    """Return a TensorBoard callback logging to ``dir_name/experiment_name/<timestamp>``.

    The components are joined with ``os.sep``. The timestamp is the current
    local time formatted as ``YYYYMMDD-HHMMSS``, so runs started in different
    seconds get separate directories. The chosen log directory is printed
    before the callback is returned.
    """
    timestamp = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
    log_dir = os.sep.join([dir_name, experiment_name, timestamp])
    callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir)
    print(f'Saving Tensorboard log files to: {log_dir}')
    return callback