Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support dim that cannot be divided by 64 #2

Merged
merged 2 commits into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@
- [x] HTTP service
- [ ] insert & update & delete
- [ ] cosine similarity distance
- [ ] early stop
2 changes: 1 addition & 1 deletion crates/cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ fn main() {
let mut total_time = 0.0;
let mut recall = 0.0;
let total_num = queries.len();
for (i, query) in queries.into_iter().enumerate() {
for (i, query) in queries.iter().enumerate() {
let start_time = Instant::now();
let res = rabitq.query(query, args.probe, args.topk, args.heuristic_rank);
total_time += start_time.elapsed().as_secs_f64();
Expand Down
47 changes: 36 additions & 11 deletions src/rabitq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ impl RaBitQ {
.collect();

let dim = orthogonal.nrows();
assert!(dim % 64 == 0);
let base = matrix_from_fvecs(&path.join("base.fvecs"))
.transpose()
.to_owned();
Expand Down Expand Up @@ -156,12 +157,30 @@ impl RaBitQ {

/// Build the RaBitQ model from the base and centroids files.
pub fn from_path(base_path: &Path, centroid_path: &Path) -> Self {
let base = matrix_from_fvecs(base_path);
let (n, dim) = base.shape();
let centroids = matrix_from_fvecs(centroid_path);
let k = centroids.shape().0;
let mut base = matrix_from_fvecs(base_path);
let n = base.nrows();
let mut dim = base.ncols();
let mut centroids = matrix_from_fvecs(centroid_path);
let k = centroids.nrows();
assert!(dim == centroids.ncols());

// padding to 64
if dim % 64 != 0 {
let dim_pad = dim.div_ceil(64) * 64;
base = Mat::from_fn(n, dim_pad, |i, j| match j < dim {
true => base.read(i, j),
false => 0.0,
});
centroids = Mat::from_fn(k, dim_pad, |i, j| match j < dim {
true => centroids.read(i, j),
false => 0.0,
});
dim = dim_pad;
}

debug!("n: {}, dim: {}, k: {}", n, dim, k);
let orthogonal = gen_random_qr_orthogonal(dim);
// let orthogonal = Mat::identity(dim, dim);
let rand_bias = gen_random_bias(dim);

// projection
Expand Down Expand Up @@ -248,16 +267,21 @@ impl RaBitQ {
/// Query the topk nearest neighbors for the given query.
pub fn query(
&self,
query: Vec<f32>,
query: &[f32],
probe: usize,
topk: usize,
heuristic_rank: bool,
) -> Vec<(f32, u32)> {
assert_eq!(self.dim as usize, query.len());
let y_projected = project(&query, &self.orthogonal.as_ref());
assert_eq!(self.dim as usize, query.len().div_ceil(64) * 64);
// padding
let mut query_vec = query.to_vec();
if query.len() < self.dim as usize {
query_vec.extend_from_slice(&vec![0.0; self.dim as usize - query.len()]);
}

let y_projected = project(&query_vec, &self.orthogonal.as_ref());
let k = self.centroids.shape().1;
let mut lists = Vec::with_capacity(k);
let mut residual = vec![0f32; self.dim as usize];
for (i, centroid) in self.centroids.col_iter().enumerate() {
let dist = l2_squared_distance(
centroid
Expand All @@ -272,10 +296,11 @@ impl RaBitQ {
lists.truncate(length);
lists.sort_by(|a, b| a.0.total_cmp(&b.0));

let mut re_ranker = new_re_ranker(query, topk, heuristic_rank);
let mut re_ranker = new_re_ranker(&query_vec, topk, heuristic_rank);
let mut residual = vec![0f32; self.dim as usize];
let mut quantized = vec![0u8; (self.dim as usize).div_ceil(64) * 64];
let mut rough_distances = Vec::new();
let mut quantized = vec![0u8; self.dim as usize];
let mut binary_vec = vec![0u64; self.dim as usize * THETA_LOG_DIM as usize / 64];
let mut binary_vec = vec![0u64; (self.dim).div_ceil(64) as usize * THETA_LOG_DIM as usize];
for &(dist, i) in lists[..length].iter() {
let (lower_bound, upper_bound) =
min_max_residual(&mut residual, &y_projected.as_ref(), &self.centroids.col(i));
Expand Down
34 changes: 17 additions & 17 deletions src/rerank.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,23 @@ use crate::ord32::{AlwaysEqual, Ord32};
use crate::utils::l2_squared_distance;

/// ReRanker enum.
pub enum ReRanker {
pub enum ReRanker<'a> {
/// ReRanker with heap.
Heap(HeapReRanker),
Heap(HeapReRanker<'a>),
/// ReRanker with heuristic.
Heuristic(HeuristicReRanker),
Heuristic(HeuristicReRanker<'a>),
}

/// Create a new re-ranker.
pub fn new_re_ranker(query: Vec<f32>, topk: usize, heuristic_rank: bool) -> ReRanker {
pub fn new_re_ranker(query: &[f32], topk: usize, heuristic_rank: bool) -> ReRanker {
if heuristic_rank {
ReRanker::Heuristic(HeuristicReRanker::new(query, topk))
} else {
ReRanker::Heap(HeapReRanker::new(query, topk))
}
}

impl ReRanker {
impl ReRanker<'_> {
/// Rank a batch of items.
pub fn rank_batch(
&mut self,
Expand Down Expand Up @@ -59,15 +59,15 @@ pub trait ReRankerTrait {

/// Rank with heap.
#[derive(Debug)]
pub struct HeapReRanker {
pub struct HeapReRanker<'a> {
threshold: f32,
topk: usize,
heap: BinaryHeap<(Ord32, AlwaysEqual<u32>)>,
query: Vec<f32>,
query: &'a [f32],
}

impl HeapReRanker {
fn new(query: Vec<f32>, topk: usize) -> Self {
impl<'a> HeapReRanker<'a> {
fn new(query: &'a [f32], topk: usize) -> Self {
Self {
threshold: f32::MAX,
query,
Expand All @@ -77,7 +77,7 @@ impl HeapReRanker {
}
}

impl ReRankerTrait for HeapReRanker {
impl ReRankerTrait for HeapReRanker<'_> {
fn rank_batch(&mut self, rough_distances: &[(f32, u32)], base: &MatRef<f32>, map_ids: &[u32]) {
let mut precise = 0;
for &(rough, u) in rough_distances.iter() {
Expand All @@ -86,7 +86,7 @@ impl ReRankerTrait for HeapReRanker {
base.col(u as usize)
.try_as_slice()
.expect("failed to get base slice"),
&self.query,
self.query,
);
precise += 1;
if accurate < self.threshold {
Expand Down Expand Up @@ -115,18 +115,18 @@ impl ReRankerTrait for HeapReRanker {

/// Rank in a heuristic way.
#[derive(Debug)]
pub struct HeuristicReRanker {
pub struct HeuristicReRanker<'a> {
threshold: f32,
recent_max_accurate: f32,
topk: usize,
array: Vec<(f32, u32)>,
query: Vec<f32>,
query: &'a [f32],
count: usize,
window_size: usize,
}

impl HeuristicReRanker {
fn new(query: Vec<f32>, topk: usize) -> Self {
impl<'a> HeuristicReRanker<'a> {
fn new(query: &'a [f32], topk: usize) -> Self {
Self {
threshold: f32::MAX,
recent_max_accurate: f32::MIN,
Expand All @@ -139,7 +139,7 @@ impl HeuristicReRanker {
}
}

impl ReRankerTrait for HeuristicReRanker {
impl ReRankerTrait for HeuristicReRanker<'_> {
fn rank_batch(&mut self, rough_distances: &[(f32, u32)], base: &MatRef<f32>, map_ids: &[u32]) {
let mut precise = 0;
for &(rough, u) in rough_distances.iter() {
Expand All @@ -148,7 +148,7 @@ impl ReRankerTrait for HeuristicReRanker {
base.col(u as usize)
.try_as_slice()
.expect("failed to get base slice"),
&self.query,
self.query,
);
precise += 1;
if accurate < self.threshold {
Expand Down
7 changes: 4 additions & 3 deletions src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use rand_distr::StandardNormal;

use crate::consts::THETA_LOG_DIM;

/// Generate a random orthogonal matrix from QR decomposition.
/// Generate a random orthogonal matrix from the standard normal distribution QR decomposition.
pub fn gen_random_qr_orthogonal(dim: usize) -> Mat<f32> {
let mut rng = rand::thread_rng();
let random: Mat<f32> = Mat::from_fn(dim, dim, |_, _| StandardNormal.sample(&mut rng));
Expand Down Expand Up @@ -51,7 +51,7 @@ pub fn matrix_from_fvecs(path: &Path) -> Mat<f32> {
/// Convert the vector to binary format and store in a u64 vector.
#[inline]
pub fn vector_binarize_u64(vec: &ColRef<f32>) -> Vec<u64> {
let mut binary = vec![0u64; (vec.nrows() + 63) / 64];
let mut binary = vec![0u64; vec.nrows().div_ceil(64)];
for (i, &v) in vec.iter().enumerate() {
if v > 0.0 {
binary[i / 64] |= 1 << (i % 64);
Expand Down Expand Up @@ -199,7 +199,8 @@ fn scalar_quantize_raw(
multiplier: f32,
) -> u32 {
let mut sum = 0u32;
for i in 0..quantized.len() {
assert!(vec.len() <= quantized.len());
for i in 0..vec.len() {
let q = ((vec[i] - lower_bound) * multiplier + bias[i]) as u8;
quantized[i] = q;
sum += q as u32;
Expand Down