-
Notifications
You must be signed in to change notification settings - Fork 23
/
predict.py
39 lines (31 loc) · 1.5 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import tensorflow as tf
from model import MLPMixer
from argparse import ArgumentParser
import os
import numpy as np
def build_parser():
    """Build the command-line parser for single-image MLP Mixer prediction.

    Returns:
        An ``ArgumentParser`` accepting ``--test-file-path`` (required),
        ``--model-folder`` (default ``<cwd>/model/mlp/``) and
        ``--image-size`` (default 150).
    """
    home_dir = os.getcwd()
    parser = ArgumentParser()
    # No default: the argument is required, so a default could never be used.
    # (The previous default also pointed at a directory, which load_img cannot
    # open as an image.)
    parser.add_argument("--test-file-path", type=str, required=True,
                        help="Path to a single image file to classify.")
    parser.add_argument("--model-folder", default='{}/model/mlp/'.format(home_dir), type=str,
                        help="Path passed to tf.keras.models.load_model.")
    parser.add_argument("--image-size", default=150, type=int,
                        help="Height and width (pixels) the image is resized to.")
    return parser


def load_image_batch(image_path, image_size):
    """Load one image file and return it as a resized batch of one.

    Args:
        image_path: Path to a single image file.
        image_size: Target height and width in pixels.

    Returns:
        A tensor of shape (1, image_size, image_size, channels), assuming
        Keras' default channels_last data format. Pixel values are left as
        produced by ``img_to_array``; no rescaling is applied here.
    """
    image = tf.keras.preprocessing.image.load_img(image_path)
    pixels = tf.keras.preprocessing.image.img_to_array(image)
    # Add a leading batch dimension: (H, W, C) -> (1, H, W, C).
    batch = np.expand_dims(pixels, axis=0)
    # NOTE(review): no pixel normalization happens here. Confirm that the saved
    # model rescales internally or was trained on raw pixel values.
    return tf.image.resize(batch, [image_size, image_size])


def predicted_class(predictions):
    """Return the highest-scoring class index for the first sample in a batch.

    Args:
        predictions: Array-like of shape (batch, num_classes).

    Returns:
        The ``int`` class index with the largest score for sample 0.
    """
    # Take the argmax over the class axis (axis=1) per sample, not over the
    # flattened array.
    return int(np.argmax(predictions, axis=1)[0])


def main(argv=None):
    """Parse CLI arguments, load the saved model, and print its prediction.

    Args:
        argv: Optional argument list. ``None`` means ``sys.argv[1:]``.
    """
    args = build_parser().parse_args(argv)

    print('---------------------Welcome to ProtonX MLP Mixer-------------------')
    print('Github: bangoc123')
    # NOTE(review): this address looks like an email-obfuscation artifact from
    # a web scrape. Confirm the real contact address.
    print('Email: [email protected]')
    print('---------------------------------------------------------------------')
    print('Predict using MLP Mixer for image path: {}'.format(args.test_file_path))
    print('===========================')

    # Load the model first, matching the original script's order.
    mlpmixer = tf.keras.models.load_model(args.model_folder)

    # A single image file (not a folder) is loaded and batched.
    x = load_image_batch(args.test_file_path, args.image_size)

    predictions = mlpmixer.predict(x)
    print('---------------------Prediction Result: -------------------')
    print('Output Softmax: {}'.format(predictions))
    print('This image belongs to class: {}'.format(predicted_class(predictions)))


if __name__ == "__main__":
    main()