Skip to content

Commit

Permalink
use &[f32] directly for all SIMD functions
Browse files Browse the repository at this point in the history
Signed-off-by: Keming <[email protected]>
  • Loading branch information
kemingy committed Oct 3, 2024
1 parent e4f1ec3 commit 0d969bd
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 80 deletions.
69 changes: 39 additions & 30 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 6 additions & 11 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use argh::FromArgs;
use env_logger::Env;
use log::{debug, info};
use rabitq::metrics::METRICS;
use rabitq::utils::{calculate_recall, matrix1d_from_vec, read_vecs};
use rabitq::utils::{calculate_recall, read_vecs};
use rabitq::RaBitQ;

#[derive(FromArgs, Debug)]
Expand Down Expand Up @@ -65,24 +65,19 @@ fn main() {
debug!("querying...");
let mut total_time = 0.0;
let mut recall = 0.0;
for (i, query) in queries.iter().enumerate() {
let query_vec = matrix1d_from_vec(query);
let total_num = queries.len();
for (i, query) in queries.into_iter().enumerate() {
let start_time = Instant::now();
let res = rabitq.query(
&query_vec.as_ref(),
args.probe,
args.topk,
args.heuristic_rank,
);
let res = rabitq.query(query, args.probe, args.topk, args.heuristic_rank);
total_time += start_time.elapsed().as_secs_f64();
let ids: Vec<i32> = res.iter().map(|(_, id)| *id as i32).collect();
recall += calculate_recall(&truth[i], &ids, args.topk);
}

info!(
"QPS: {}, recall: {}",
queries.len() as f64 / total_time,
recall / queries.len() as f32
total_num as f64 / total_time,
recall / total_num as f32
);
info!("Metrics [{}]", METRICS.to_str());
}
17 changes: 11 additions & 6 deletions src/rabitq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
use core::f32;
use std::path::Path;

use faer::{Col, ColRef, Mat, Row};
use faer::{Col, Mat, Row};
use log::debug;
use serde::{Deserialize, Serialize};

Expand Down Expand Up @@ -208,7 +208,7 @@ impl RaBitQ {
let labels = labels
.into_iter()
.map(|mut v| {
v.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
v.sort_by(|a, b| a.1.partial_cmp(&b.1).expect("failed to compare labels"));
v.into_iter().map(|(i, _)| i).collect::<Vec<_>>()
})
.collect::<Vec<_>>();
Expand Down Expand Up @@ -243,18 +243,23 @@ impl RaBitQ {
/// Query the topk nearest neighbors for the given query.
pub fn query(
&self,
query: &ColRef<f32>,
query: Vec<f32>,
probe: usize,
topk: usize,
heuristic_rank: bool,
) -> Vec<(f32, u32)> {
assert_eq!(self.dim as usize, query.nrows());
let y_projected = project(query, &self.orthogonal.as_ref());
assert_eq!(self.dim as usize, query.len());
let y_projected = project(&query, &self.orthogonal.as_ref());
let k = self.centroids.shape().1;
let mut lists = Vec::with_capacity(k);
let mut residual = vec![0f32; self.dim as usize];
for (i, centroid) in self.centroids.col_iter().enumerate() {
let dist = l2_squared_distance(&centroid, &y_projected.as_ref());
let dist = l2_squared_distance(
centroid
.try_as_slice()
.expect("failed to get centroid slice"),
y_projected.as_slice(),
);
lists.push((dist, i));
}
let length = probe.min(k);
Expand Down
32 changes: 21 additions & 11 deletions src/rerank.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
use std::collections::BinaryHeap;

use faer::{Col, ColRef, MatRef};
use faer::MatRef;

use crate::consts::WINDOW_SIZE;
use crate::metrics::METRICS;
Expand All @@ -14,7 +14,7 @@ pub enum ReRanker {
Heuristic(HeuristicReRanker),
}

pub fn new_re_ranker(query: &ColRef<f32>, topk: usize, heuristic_rank: bool) -> ReRanker {
pub fn new_re_ranker(query: Vec<f32>, topk: usize, heuristic_rank: bool) -> ReRanker {
if heuristic_rank {
ReRanker::Heuristic(HeuristicReRanker::new(query, topk))
} else {
Expand Down Expand Up @@ -53,14 +53,14 @@ pub struct HeapReRanker {
threshold: f32,
topk: usize,
heap: BinaryHeap<(Ord32, AlwaysEqual<u32>)>,
query: Col<f32>,
query: Vec<f32>,
}

impl HeapReRanker {
fn new(query: &ColRef<f32>, topk: usize) -> Self {
fn new(query: Vec<f32>, topk: usize) -> Self {
Self {
threshold: f32::MAX,
query: query.to_owned(),
query,
topk,
heap: BinaryHeap::with_capacity(topk),
}
Expand All @@ -72,7 +72,12 @@ impl ReRankerTrait for HeapReRanker {
let mut precise = 0;
for &(rough, u) in rough_distances.iter() {
if rough < self.threshold {
let accurate = l2_squared_distance(&base.col(u as usize), &self.query.as_ref());
let accurate = l2_squared_distance(
base.col(u as usize)
.try_as_slice()
.expect("failed to get base slice"),
&self.query,
);
precise += 1;
if accurate < self.threshold {
self.heap
Expand All @@ -81,7 +86,7 @@ impl ReRankerTrait for HeapReRanker {
self.heap.pop();
}
if self.heap.len() == self.topk {
self.threshold = self.heap.peek().unwrap().0.into();
self.threshold = self.heap.peek().expect("failed to peek heap").0.into();
}
}
}
Expand All @@ -104,17 +109,17 @@ pub struct HeuristicReRanker {
recent_max_accurate: f32,
topk: usize,
array: Vec<(f32, u32)>,
query: Col<f32>,
query: Vec<f32>,
count: usize,
window_size: usize,
}

impl HeuristicReRanker {
fn new(query: &ColRef<f32>, topk: usize) -> Self {
fn new(query: Vec<f32>, topk: usize) -> Self {
Self {
threshold: f32::MAX,
recent_max_accurate: f32::MIN,
query: query.to_owned(),
query,
topk,
array: Vec::with_capacity(topk),
count: 0,
Expand All @@ -128,7 +133,12 @@ impl ReRankerTrait for HeuristicReRanker {
let mut precise = 0;
for &(rough, u) in rough_distances.iter() {
if rough < self.threshold {
let accurate = l2_squared_distance(&base.col(u as usize), &self.query.as_ref());
let accurate = l2_squared_distance(
base.col(u as usize)
.try_as_slice()
.expect("failed to get base slice"),
&self.query,
);
precise += 1;
if accurate < self.threshold {
self.array.push((accurate, map_ids[u as usize]));
Expand Down
Loading

0 comments on commit 0d969bd

Please sign in to comment.