-
Notifications
You must be signed in to change notification settings - Fork 16
/
pyw_hnswlib.py
63 lines (53 loc) · 1.88 KB
/
pyw_hnswlib.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
# import hnswlib
import numpy as np
import threading
import pickle
import hnswlib
class Index:
    """Wrapper around ``hnswlib.Index`` that supports arbitrary labels.

    Each inserted vector is given a consecutive internal integer label, which
    is what hnswlib stores. ``dict_labels`` maps every internal label back to
    the caller-supplied id, or to the internal label itself when no ids were
    given. ``save_index``/``load_index`` persist that mapping next to the
    index file as ``<path>.pkl``.
    """

    def __init__(self, space, dim):
        """Create an empty wrapper.

        Args:
            space: Distance space name, forwarded to ``hnswlib.Index``.
            dim: Vector dimensionality, forwarded to ``hnswlib.Index``.
        """
        self.index = hnswlib.Index(space, dim)
        # Guards label allocation (cur_ind) and mutation of dict_labels.
        self.lock = threading.Lock()
        # Internal integer label -> caller-facing label.
        self.dict_labels = {}
        # Next unused internal label.
        self.cur_ind = 0

    def init_index(self, max_elements, ef_construction=200, M=16):
        """Allocate the underlying hnswlib index.

        Args:
            max_elements: Capacity passed to ``hnswlib.Index.init_index``.
            ef_construction: Build-time ``ef`` parameter, forwarded as-is.
            M: Graph connectivity parameter, forwarded as-is.
        """
        self.index.init_index(
            max_elements=max_elements, ef_construction=ef_construction, M=M
        )

    def add_items(self, data, ids=None, num_threads=-1):
        """Insert vectors, optionally tagging each with an arbitrary label.

        Args:
            data: Array-like of shape ``(n, dim)``, or a single 1-D vector of
                length ``dim`` (treated as one row). Empty input is a no-op.
            ids: Optional sequence of caller-facing labels, one per row. When
                omitted, each row's label is its internal integer label.
            num_threads: Forwarded to ``hnswlib.Index.add_items``; ``-1``
                (the default) leaves the thread count to hnswlib.

        Raises:
            ValueError: If ``ids`` is given and its length does not match the
                number of rows in ``data``.
        """
        data = np.asarray(data)
        if data.size == 0:
            return  # nothing to insert
        if data.ndim == 1:
            # A single vector counts as one row, so allocate one label for it
            # rather than one label per component.
            data = data.reshape(1, -1)
        num_added = len(data)
        if ids is not None and len(ids) != num_added:
            raise ValueError(
                f"got {num_added} vectors but {len(ids)} ids; lengths must match"
            )

        with self.lock:
            # Reserve a contiguous block of internal labels and register their
            # mapping before the vectors become visible to queries.
            start = self.cur_ind
            self.cur_ind += num_added
            label_range = range(start, start + num_added)
            if ids is None:
                self.dict_labels.update((lbl, lbl) for lbl in label_range)
            else:
                self.dict_labels.update(zip(label_range, ids))

        # The insertion itself runs outside the lock; only the label
        # bookkeeping above is serialised by it.
        self.index.add_items(
            data,
            ids=np.arange(start, start + num_added),
            num_threads=num_threads,
        )

    def set_ef(self, ef):
        """Set the query-time ``ef`` parameter on the underlying index."""
        self.index.set_ef(ef)

    def load_index(self, path, max_elements=0):
        """Load an index and its label mapping previously written by ``save_index``.

        Args:
            path: Index file path; the label mapping is read from
                ``path + ".pkl"``.
            max_elements: Forwarded to ``hnswlib.Index.load_index``. The
                default of ``0`` matches hnswlib's own default.
        """
        # Read the sidecar first so that a missing or corrupt mapping file
        # leaves this object untouched instead of half-loaded.
        with open(path + ".pkl", "rb") as f:
            # SECURITY: pickle.load can execute arbitrary code. Only load
            # mapping files from a trusted source.
            cur_ind, dict_labels = pickle.load(f)
        self.index.load_index(path, max_elements=max_elements)
        with self.lock:
            self.cur_ind, self.dict_labels = cur_ind, dict_labels

    def save_index(self, path):
        """Save the index to ``path`` and the label mapping to ``path + ".pkl"``."""
        self.index.save_index(path)
        # Snapshot under the lock so concurrent add_items calls cannot mutate
        # the dict while it is pickled. The snapshot is taken after the index
        # save, and mappings are registered before vectors are inserted, so
        # the snapshot covers every label present in the saved index.
        with self.lock:
            state = (self.cur_ind, dict(self.dict_labels))
        with open(path + ".pkl", "wb") as f:
            pickle.dump(state, f)

    def set_num_threads(self, num_threads):
        """Set the default thread count on the underlying index."""
        self.index.set_num_threads(num_threads)

    def knn_query(self, data, k=1, num_threads=-1):
        """Find the ``k`` nearest neighbours of each query vector.

        Args:
            data: Query vector(s), forwarded to ``hnswlib.Index.knn_query``.
            k: Number of neighbours per query.
            num_threads: Forwarded to hnswlib; ``-1`` leaves it to hnswlib.

        Returns:
            ``(labels, distances)``: ``labels`` is a list with one inner list
            of caller-facing labels per query row, and ``distances`` is the
            distance array returned by hnswlib, unchanged.
        """
        labels_int, distances = self.index.knn_query(
            data=data, k=k, num_threads=num_threads
        )
        labels = [
            [self.dict_labels[int(lbl)] for lbl in row] for row in labels_int
        ]
        return labels, distances