-
Notifications
You must be signed in to change notification settings - Fork 4
/
app.py
146 lines (120 loc) · 4.61 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import streamlit as st
from io import BytesIO
from PIL import Image
import requests
import boto3
from math import exp
import matplotlib.pyplot as plt
# Endpoint names offered in the sidebar "Amazon SageMaker Endpoint" selector.
# NOTE(review): "XXXXXXX" looks like a redacted placeholder - replace with the
# real endpoint name(s) before deploying.
SAGEMAKER_ENDPOINTS = ["XXXXXXX"]
# Region names offered in the sidebar "AWS Region" selector; the chosen value
# is passed as region_name when the SageMaker runtime client is created.
# NOTE(review): placeholder value as well - confirm the real region(s).
REGIONS = ["XXXXXX"]
def img_upload_side_bar(col_display) -> tuple:
    """Render the two sidebar uploaders and preview each uploaded image.

    Args:
        col_display: Indexable pair of Streamlit containers. Element 0 hosts
            the preview of the first upload, element 1 the second.

    Returns:
        ``(left_file, right_file, left_image_loc, right_image_loc)``. Each
        ``*_file`` is whatever ``st.sidebar.file_uploader`` returned, and each
        ``*_image_loc`` is the ``st.empty()`` placeholder that holds the
        preview, or ``None`` when nothing was uploaded for that side.
    """

    def upload_vis_image(file_path, col_vis):
        """Preview the upload inside *col_vis*; return its placeholder or None."""
        if file_path is None:
            return None
        # getvalue() hands back the whole buffer without moving the stream
        # position, matching how the rest of this file reads the uploads.
        # read() would advance the position, so a later read() of the same
        # object would yield no bytes and the preview could not be decoded.
        image = (
            Image.open(BytesIO(file_path.getvalue()), mode="r")
            .convert("RGB")
            .resize((600, 400))
        )
        with col_vis:
            # A placeholder (rather than drawing straight into the column)
            # lets the caller replace the preview later.
            image_loc = st.empty()
            image_loc.image(image, use_column_width=True)
        return image_loc

    st.sidebar.title("Image Uploading")
    # Disabling warning
    st.set_option("deprecation.showfileUploaderEncoding", False)

    # Choose your own image
    left_file = st.sidebar.file_uploader(
        "Upload 1st Image", type=["png", "jpeg", "jpg"]
    )
    left_image_loc = upload_vis_image(left_file, col_display[0])

    right_file = st.sidebar.file_uploader(
        "Upload 2nd Image", type=["png", "jpeg", "jpg"]
    )
    right_image_loc = upload_vis_image(right_file, col_display[1])

    return left_file, right_file, left_image_loc, right_image_loc
def button(left_image_path, right_image_path):
    """Render the endpoint controls and, on "Predict", invoke the endpoint.

    Args:
        left_image_path: Uploaded first image (an object exposing
            ``getvalue()``), or ``None`` if nothing was uploaded.
        right_image_path: Uploaded second image, or ``None``.

    Returns:
        ``(neg, pos, maps)`` where ``neg`` and ``pos`` are the first two items
        of the endpoint response and ``maps`` is a list of the remaining items
        (empty when there are none; presumably the activation-map data when
        the CAM box is ticked - confirm against the model server).
        ``(None, None, None)`` when "Predict" was not pressed or an image is
        missing.

    Raises:
        ValueError, SyntaxError: If the endpoint response is not a plain
            Python literal.
    """
    # Local import keeps this block self-contained; only needed for parsing.
    import ast

    st.sidebar.write("\n")
    st.sidebar.title("Model Endpoint")
    region = st.sidebar.selectbox("AWS Region", REGIONS, help="AWS Region")
    endpoint_name = st.sidebar.selectbox(
        "Amazon SageMaker Endpoint",
        SAGEMAKER_ENDPOINTS,
        help="Amazon SageMaker Endpoint to invoke",
    )
    cam = st.sidebar.checkbox("Classification Activation Map (CAM)")

    no_result = (None, None, None)
    if not st.sidebar.button("Predict"):
        return no_result
    if left_image_path is None:
        st.sidebar.error("Please upload 1st image.")
        return no_result
    if right_image_path is None:
        st.sidebar.error("Please upload 2nd image.")
        return no_result

    with st.spinner("Analyzing..."):
        # The request is only *prepared*, never sent: requests is used purely
        # to build the multipart/form-data body and its Content-Type header
        # (including the boundary). The localhost URL is a dummy.
        r = requests.Request(
            "POST",
            "http://localhost:8080/invocations",
            data={"cam": str(cam)},
            files={
                "left": left_image_path.getvalue(),
                "right": right_image_path.getvalue(),
            },
        ).prepare()
        content_type, payload = r.headers["Content-Type"], r.body

        client = boto3.client("sagemaker-runtime", region_name=region)
        response = client.invoke_endpoint(
            EndpointName=endpoint_name, ContentType=content_type, Body=payload
        )
        raw = response["Body"].read()
        if isinstance(raw, (bytes, bytearray)):
            raw = raw.decode("utf-8")

        # SECURITY: this used to be eval() on the raw response, which would
        # execute arbitrary code returned by the remote service.
        # ast.literal_eval only accepts Python literals (numbers, lists,
        # tuples, ...), so a response of that shape parses the same way.
        neg, pos, *maps = ast.literal_eval(raw)
        return neg, pos, maps
def prediction(negative_logit, positive_logit):
    """Return the two-class softmax of the given logits.

    Args:
        negative_logit: Logit of the negative class.
        positive_logit: Logit of the positive class.

    Returns:
        Tuple ``(p_negative, p_positive)``; the two values sum to 1.
    """
    # Softmax is unchanged when the same constant is subtracted from every
    # logit. Shifting by the maximum keeps each exp() argument <= 0, so exp()
    # cannot overflow for large logits (math.exp raises above ~709).
    shift = max(negative_logit, positive_logit)
    neg_weight = exp(negative_logit - shift)
    pos_weight = exp(positive_logit - shift)
    total = neg_weight + pos_weight
    return neg_weight / total, pos_weight / total
def cam_vis(cam_map, col_vis, file_path):
    """Draw *cam_map* as a translucent heat map over the uploaded image.

    Args:
        cam_map: Array-like handed to ``Axes.imshow`` as the overlay.
            NOTE(review): expected shape is not visible here - confirm
            against the endpoint response.
        col_vis: Streamlit element whose ``pyplot`` method renders the figure.
        file_path: Uploaded image exposing ``getvalue()``; used as background.
    """
    fig, ax = plt.subplots(figsize=(5, 5))
    try:
        fig.patch.set_visible(False)
        image = (
            Image.open(BytesIO(file_path.getvalue()), mode="r")
            .convert("RGB")
            .resize((600, 400))
        )
        ax.imshow(image)
        # extent stretches the map over the 600x400 background so the two
        # layers line up whatever the map's own resolution is.
        ax.imshow(
            cam_map,
            alpha=0.3,
            extent=(0, 600, 400, 0),
            interpolation="bilinear",
            cmap="jet",
        )
        ax.axis("off")
        ax.set(frame_on=False)
        fig.tight_layout()
        col_vis.pyplot(fig, width=None)
    finally:
        # Figures made via pyplot stay registered until explicitly closed.
        # Without this, every rerun of the script would leave another figure
        # in memory.
        plt.close(fig)
def main():
    """Lay out the page, collect the two uploads, and show the verdict.

    Shows "Similar" or "Different" together with the winning class
    probability, and overlays the activation maps on both images when the
    endpoint returned any.
    """
    # Wide layout
    st.set_page_config(layout="wide")
    st.title("Twin Classification")
    # Blank spacer line
    st.write("\n")

    columns = st.columns((1, 1))
    for column, caption in zip(columns, ("Left Image", "Right Image")):
        column.header(caption)

    left_image, right_image, left_image_loc, right_image_loc = img_upload_side_bar(
        columns
    )
    negative_logit, positive_logit, cam_map = button(left_image, right_image)
    if negative_logit is None or positive_logit is None:
        return

    # The bool doubles as an index: False -> 0 (negative), True -> 1 (positive).
    is_similar = negative_logit < positive_logit
    probabilities = prediction(negative_logit, positive_logit)
    prob = probabilities[is_similar]
    text = ("Different", "Similar")[is_similar]
    st.subheader(f"{text}: {prob:.2%}")

    if cam_map is None or len(cam_map) == 0:
        return

    # First half of the returned maps goes to the left image, the rest to
    # the right image.
    half = len(cam_map) // 2
    cam_vis(cam_map[:half], left_image_loc, left_image)
    cam_vis(cam_map[half:], right_image_loc, right_image)
# Build the page only when this file is executed as the main script, not
# when it is imported as a module.
if __name__ == "__main__":
    main()