
Shrike.jl: Fast approximate nearest neighbor search with random projection trees. (Benchmarks included)

djpasseyjr/Shrike.jl

Shrike.jl

[Figure: Random Projection Splits]

Shrike is a 100% pure Julia package for building ensembles of random projection trees. Random projection trees are a generalization of KD-trees and are used to quickly find approximate nearest neighbors or build k-nearest-neighbor graphs. They adapt to the low intrinsic dimensionality that is often present in high-dimensional data.

The implementation here is based on the MRPT algorithm. This package also includes optimizations for knn-graph creation and has built-in support for multithreading.
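To make the idea of a random projection split concrete, here is a minimal sketch of a single tree in plain Julia. It is not Shrike's implementation: the function rp_leaves is hypothetical, the tree is reduced to its list of leaves, and it assumes dense Gaussian projection directions with median splits, which may differ from how Shrike or MRPT choose directions and split points.

```julia
using Random, Statistics, LinearAlgebra

# Hypothetical sketch of one random projection tree, returned as its leaves.
# Each column of X is a data point; `idx` holds the point indices in this node.
# At every level the points are projected onto a random Gaussian direction
# and split at the median projection value.
function rp_leaves(X::AbstractMatrix, idx::Vector{Int}, depth::Int, rng::AbstractRNG)
    # Stop at the requested depth, or if the node is too small to split.
    (depth == 0 || length(idx) < 2) && return [idx]
    direction = randn(rng, size(X, 1))     # random projection direction
    proj = X[:, idx]' * direction          # 1-D projection of each point in the node
    cut = median(proj)                     # median split keeps the tree balanced
    left = idx[proj .<= cut]
    right = idx[proj .> cut]
    return vcat(rp_leaves(X, left, depth - 1, rng),
                rp_leaves(X, right, depth - 1, rng))
end
```

With median splits each level halves a node, so a tree of depth d has 2^d leaves of about npoints / 2^d points each. An ensemble is simply several such trees built with different random directions.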

Installation

To install, type

] add Shrike

at the Julia REPL (pressing ] switches to package mode), or run

using Pkg
Pkg.add("Shrike")

Build an Index

To build an ensemble of random projection trees, use the ShrikeIndex type.

using Shrike
maxk = 100
X = rand(100, 10000)
shi = ShrikeIndex(X, maxk; depth=8, ntrees=10)

The constructor accepts a data matrix X in which each column represents a data point.

  1. maxk represents the maximum number of nearest neighbors you will be able to find with this index. maxk is used to set a safe depth for the tree. You can also construct an index without this parameter if you need to.
  2. depth describes the number of times each random projection tree will split the data. Leaf nodes in the tree contain about npoints / 2^depth data points. Increasing depth increases speed but decreases accuracy. By default, the index sets depth as large as possible.
  3. ntrees controls the number of trees in the ensemble. More trees give higher accuracy but use more memory.

In this case, since we need an index that can find the 100 nearest neighbors, setting depth = 8 would produce leaves of roughly 10000 / 2^8 ≈ 39 points, fewer than 100. The index infers this from maxk and sets depth to the largest value that still leaves at least maxk points per leaf. Here that is depth = 6 (about 156 points per leaf), since depth = 7 would give only about 78.
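The depth arithmetic above can be reproduced with a small helper. This is a hypothetical sketch (safe_depth and effective_depth are not part of Shrike) that assumes each split halves a node and rounds down; Shrike's own rule for uneven splits may differ.

```julia
# Hypothetical helper: the largest depth at which every leaf still holds at
# least `maxk` points, assuming each split halves a node (so a leaf at depth
# d holds about npoints / 2^d points, rounded down).
function safe_depth(npoints::Integer, maxk::Integer)
    maxk >= 1 || throw(ArgumentError("maxk must be positive"))
    d = 0
    while npoints ÷ 2^(d + 1) >= maxk
        d += 1
    end
    return d
end

# The depth actually used is the smaller of the requested and the safe depth.
effective_depth(npoints, maxk, requested) = min(requested, safe_depth(npoints, maxk))
```

For the example above, effective_depth(10000, 100, 8) returns 6, while a smaller request such as depth = 4 is left unchanged.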

To query the index for the approximate 10 nearest neighbors of a point, use:

k = 10
q = X[:, 1]
approx_nn = ann(shi, q, k; vote_cutoff=2)
  1. The vote_cutoff parameter sets how many "votes" a point needs in order to be included in the final linear search. Increasing vote_cutoff speeds up the algorithm but may reduce accuracy. Each tree "votes" for all points in the relevant leaf nodes. If the leaves contain few points and there are few trees, the chance that a point receives more than one vote is low. Thus, when depth is large and ntrees is less than 5, it is recommended to set vote_cutoff = 1.
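The voting step can be sketched as follows. This is an illustrative sketch, not Shrike's ann: the function voted_knn is hypothetical and assumes you already have, for each tree, the indices of the points in the leaf that the query reached.

```julia
using LinearAlgebra

# Hypothetical sketch: count votes across trees, keep points with at least
# `vote_cutoff` votes, then run an exact linear search over those candidates.
# `leaves_for_query[t]` holds the point indices in the leaf the query reached
# in tree t; each column of X is a data point.
function voted_knn(X::AbstractMatrix, q::AbstractVector,
                   leaves_for_query::Vector{Vector{Int}}, k::Int, vote_cutoff::Int)
    votes = Dict{Int,Int}()
    for leaf in leaves_for_query, i in leaf
        votes[i] = get(votes, i, 0) + 1        # one vote per tree containing i
    end
    candidates = [i for (i, v) in votes if v >= vote_cutoff]
    dists = [norm(X[:, i] - q) for i in candidates]
    order = sortperm(dists)                    # closest candidates first
    return candidates[order[1:min(k, length(order))]]
end
```

Raising vote_cutoff shrinks the candidate set, which makes the linear search cheaper but can drop true neighbors that only one tree found.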

KNN-Graphs

This package includes fast algorithms to generate k-nearest-neighbor graphs and has specialized functions for this purpose. It uses neighbor of neighbor exploration (outlined here) to efficiently improve the accuracy of a knn-graph.
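One round of neighbor-of-neighbor exploration can be sketched as below, under the assumption that nn[i, :] already holds k distinct approximate neighbors of point i, none equal to i. The function explore_once is hypothetical and only illustrates the idea: a neighbor of a neighbor is a good candidate, so checking those candidates can replace poor entries with closer points.

```julia
using LinearAlgebra

# Hypothetical sketch of one round of neighbor-of-neighbor exploration.
# nn is an npoints x k matrix; nn[i, :] holds the current approximate
# neighbors of X[:, i]. For each point, gather its neighbors and their
# neighbors, then keep the k closest candidates.
function explore_once(X::AbstractMatrix, nn::Matrix{Int})
    npoints, k = size(nn)
    newnn = similar(nn)
    for i in 1:npoints
        cands = Set{Int}(nn[i, :])
        for j in nn[i, :]
            union!(cands, nn[j, :])            # add the neighbors of neighbor j
        end
        delete!(cands, i)                      # a point is not its own neighbor
        c = collect(cands)
        d = [norm(X[:, j] - X[:, i]) for j in c]
        newnn[i, :] = c[sortperm(d)[1:k]]      # k closest candidates
    end
    return newnn
end
```

Calling explore_once repeatedly corresponds loosely to running more exploration iterations.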

Nearest neighbor graphs give a sparse topology to large datasets. Their structure can be used to project the data onto a lower-dimensional manifold, to cluster data points with community detection algorithms, or to perform other analyses.

To generate nearest neighbor graphs:

using Shrike
X = rand(100, 10000)
shi = ShrikeIndex(X; depth=6, ntrees=5)
k = 10
g = knngraph(shi, k; vote_cutoff=1, ne_iters=1, gtype=SimpleDiGraph)
  1. The vote_cutoff parameter signifies how many "votes" a point needs in order to be included in a linear search.
  2. ne_iters controls how many iterations of neighbor exploration the algorithm performs. Successive iterations are increasingly fast. It is recommended to use more iterations of neighbor exploration when the number of trees is small, and fewer when many trees are used.
  3. The gtype parameter allows the user to specify a LightGraphs.jl graph type to return. gtype=identity returns a sparse adjacency matrix.

If an array of nearest neighbor indices is preferred,

nn = allknn(shi, k; vote_cutoff=1, ne_iters=0)

can be used to generate a shi.npoints × k array of integer indices, where nn[i, :] holds the nearest neighbors of X[:, i]. The keyword arguments work the same way as in knngraph (described above).
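For reference, here is one way to turn an npoints × k neighbor array like nn into a directed sparse adjacency matrix with Julia's SparseArrays standard library. The function knn_adjacency is hypothetical and only illustrates the idea; the exact matrix returned by knngraph with gtype=identity (for example its element type or weights) may differ.

```julia
using SparseArrays

# Hypothetical sketch: build a directed adjacency matrix A from an
# npoints x k neighbor array, with A[i, j] = 1 when j is a neighbor of i.
function knn_adjacency(nn::Matrix{Int})
    npoints, k = size(nn)
    rows = repeat(1:npoints, inner=k)      # row index i repeated k times
    cols = vec(permutedims(nn))            # nn[1, :], then nn[2, :], ...
    return sparse(rows, cols, ones(Int, npoints * k), npoints, npoints)
end
```

Each row of the result has exactly k nonzeros, one per neighbor.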

Threading

Shrike has built-in support for multithreading. To allocate multiple threads, start Julia with the --threads flag:

user@sys:~$ julia --threads 4

To see this at work, consider a small-scale example:

user@sys:~$ cmd="using Shrike; shi=ShrikeIndex(rand(100, 10000)); @time knngraph(shi, 10, ne_iters=1)"
user@sys:~$ julia -e "$cmd"
  12.373127 seconds (8.66 M allocations: 4.510 GiB, 6.85% gc time, 18.88% compilation time)
user@sys:~$ julia  --threads 4 -e "$cmd"
  6.306410 seconds (8.67 M allocations: 4.498 GiB, 13.12% gc time, 31.64% compilation time)

(This assumes that Shrike is installed.)
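The general pattern behind this kind of speedup is a per-point loop whose iterations are independent, so they can be spread across threads. The sketch below shows that pattern with Threads.@threads on an exact brute-force search; threaded_bruteforce_knn is a hypothetical example and is not meant to describe Shrike's internals.

```julia
using LinearAlgebra

# Hypothetical sketch: exact k-nearest neighbors for every column of X, with
# the outer loop split across the available threads. Each iteration writes
# only row i of `nn`, so threads never write to the same location.
function threaded_bruteforce_knn(X::AbstractMatrix, k::Int)
    npoints = size(X, 2)
    nn = zeros(Int, npoints, k)
    Threads.@threads for i in 1:npoints
        d = [norm(X[:, j] - X[:, i]) for j in 1:npoints]
        d[i] = Inf                         # exclude the point itself
        nn[i, :] = sortperm(d)[1:k]
    end
    return nn
end
```

Started with julia --threads 4, the loop runs on four threads; Threads.nthreads() reports how many are available.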

Benchmark

This package was compared with the original MRPT C++ implementation (on which this algorithm is based); Annoy, a popular package for approximate nearest neighbor search; and NearestNeighbors.jl, a Julia package for nearest neighbor search. The benchmarks were written in the spirit of ann-benchmarks, a repository for comparing approximate nearest neighbor algorithms, and the datasets were taken directly from ann-benchmarks. The following are links to the HDF5 files in question: FashionMNIST, SIFT, MNIST and GIST. The benchmarks below were run on a compute cluster, with every algorithm restricted to a single thread.

[Plot: FashionMNIST Speed Comparison]

In this plot, up and to the right is better (faster queries, better recall). Each point represents one parameter combination. For full documentation of the parameters and timing methods, consult the original scripts in the benchmark/ directory.

On this dataset, Shrike performs better for most parameter combinations. On SIFT (below), by contrast, some parameter combinations are not as strong. We speculate that this is related to the high dimensionality of the points in FashionMNIST (d = 784) compared with the lower dimensionality of SIFT (d = 128).
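Recall in these plots measures how many of the true nearest neighbors an algorithm returns. A common definition, recall@k, is sketched below; the function recall_at_k is hypothetical, and the benchmark scripts (like ann-benchmarks) may compute recall differently, for example with a distance tolerance.

```julia
# Hypothetical sketch of recall@k: the fraction of the true k nearest
# neighbors that appear in the approximate result, averaged over queries.
# Row i of each matrix holds the k neighbor indices for query i.
function recall_at_k(approx::Matrix{Int}, exact::Matrix{Int})
    nq, k = size(exact)
    hits = sum(length(intersect(Set(approx[i, :]), Set(exact[i, :]))) for i in 1:nq)
    return hits / (nq * k)
end
```

An exact method such as NearestNeighbors.jl scores 1.0 by this measure.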

[Plot: SIFT Speed Comparison]

It is important to note that NearestNeighbors.jl is designed to return the exact k nearest neighbors as quickly as possible and does not approximate, hence its high accuracy and lower speed.

The takeaway here is that Shrike is fast! It may even be a little faster than the original C++ implementation. Go Julia! We should note that Shrike was not benchmarked against state-of-the-art algorithms for approximate nearest neighbor search. Those algorithms are faster than Annoy and MRPT, but unfortunately the developers of Shrike aren't familiar with them.
