-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path · utils.py
60 lines (54 loc) · 2.23 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import collections as mc
def get_adjacency_matrix(tri):
    """Build the symmetric 0/1 adjacency matrix of a Delaunay triangulation.

    Adapted from https://github.com/danielegrattarola/spektral

    Two input points are adjacent when they are both vertices of some
    simplex (a triangle in 2-D, a tetrahedron in 3-D, ...).

    Args:
        tri: ``scipy.spatial.Delaunay`` triangulation.

    Returns:
        ``(n, n)`` ``np.int64`` array with 1 where two points share a simplex
        edge and 0 elsewhere, where ``n = tri.points.shape[0]``.
    """
    # ``simplices`` replaces the removed ``vertices`` alias of Delaunay.
    simplices = np.asarray(tri.simplices)
    n = tri.points.shape[0]
    # Every unordered pair of vertex slots within a simplex is an edge; for
    # triangles this is the pairs (0, 1), (0, 2), (1, 2).
    first, second = np.triu_indices(simplices.shape[1], k=1)
    src = simplices[:, first].ravel()
    dst = simplices[:, second].ravel()
    adj = np.zeros((n, n), dtype=np.int64)
    # Edges shared by neighbouring simplices appear more than once; assigning
    # 1 repeatedly is harmless. Write both directions to keep it symmetric.
    adj[src, dst] = 1
    adj[dst, src] = 1
    return adj
def get_matches(X):
    """Return the index tuples of the non-zero entries of assignment matrix X.

    For a 2-D matrix each row of the result is one match ``(i, p)``. The
    result is an integer ndarray with one row per non-zero entry and one
    column per dimension of ``X`` (not a Python list).
    """
    per_axis_indices = np.nonzero(X)
    # Stack the per-axis index arrays side by side: one row per entry.
    return np.column_stack(per_axis_indices)
def draw_matches(plot, points1, points2, matches=None, colorm='g', s=50, linewidth=2):
    """Scatter two point sets on an axes and optionally connect matched pairs.

    Args:
        plot: ``(fig, ax)`` pair; only ``ax`` is drawn on.
        points1: array whose ``[:, 0]`` / ``[:, 1]`` columns are x / y.
        points2: same layout as ``points1``.
        matches: iterable of index pairs ``(i1, i2)`` linking
            ``points1[i1]`` to ``points2[i2]``; ``None`` draws only points.
        colorm: colour of the match lines.
        s: scatter marker size.
        linewidth: width of the match lines.
    """
    _, ax = plot
    # Draw on the given axes explicitly: ``plt.scatter`` targets pyplot's
    # "current" axes, which need not be ``ax`` (e.g. with subplots), and the
    # match lines below are always added to ``ax``.
    ax.scatter(points1[:, 0], points1[:, 1], s=s)
    ax.scatter(points2[:, 0], points2[:, 1], s=s)
    if matches is None:
        return
    lines = [[points1[i1], points2[i2]] for i1, i2 in matches]
    lc = mc.LineCollection(lines, colors=[colorm] * len(lines), linewidths=linewidth)
    ax.add_collection(lc)
def draw_results(plot, points1, points2, X=None, X_gt=None):
    """Draw predicted matches, colour-coded against the ground truth if given.

    Good matches (true positives) are drawn in green, bad matches (false
    positives) in red, and missed matches (false negatives) in yellow. If
    ``X_gt`` is not given, all matches in ``X`` are drawn in green. If ``X``
    is not given either, only the points are drawn.

    Args:
        plot: ``(fig, ax)`` pair, passed through to ``draw_matches``.
        points1: first point set, passed through to ``draw_matches``.
        points2: second point set, passed through to ``draw_matches``.
        X: assignment matrix; non-zero entries are predicted matches.
        X_gt: ground-truth assignment matrix with the same shape as ``X``.
    """
    if X is None:
        draw_matches(plot, points1, points2, matches=None)
        return
    if X_gt is None:
        draw_matches(plot, points1, points2, matches=get_matches(X))
        return
    # Cast to bool so & and ~ are logical operations: on float arrays & raises
    # TypeError, and on integer arrays ~ is bitwise NOT (~1 == -2).
    X = np.asarray(X, dtype=bool)
    X_gt = np.asarray(X_gt, dtype=bool)
    # Draw the points once, together with the true positives in green.
    # Calling draw_matches once per category would re-scatter the points each
    # time, stacking duplicate artists and advancing the colour cycle.
    draw_matches(plot, points1, points2, matches=get_matches(X & X_gt), colorm='g')
    _, ax = plot
    # False positives in red, false negatives in yellow (lines only).
    for mask, color in ((X & ~X_gt, 'r'), (~X & X_gt, 'y')):
        lines = [[points1[i1], points2[i2]] for i1, i2 in get_matches(mask)]
        ax.add_collection(
            mc.LineCollection(lines, colors=[color] * len(lines), linewidths=2)
        )