Skip to content

Commit

Permalink
fix: divide 64 in packages (#4)
Browse files: browse the repository at this point in the history
Signed-off-by: Keming <[email protected]>
  • Loading branch information
kemingy authored Dec 3, 2024
1 parent 1f1987a commit a9698f7
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 31 deletions.
8 changes: 4 additions & 4 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
members = ["crates/*"]

[workspace.package]
version = "0.2.1"
version = "0.2.2"
edition = "2021"
description = "A Rust implementation of the RaBitQ vector search algorithm."
license = "AGPL-3.0"
Expand Down
7 changes: 6 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
packages := cli disk service

build:
cargo b

format:
@cargo +nightly fmt
@$(foreach package, $(packages), cargo +nightly fmt --package $(package);)

lint:
@cargo +nightly fmt -- --check
@cargo +nightly fmt --check
@$(foreach package, $(packages), cargo +nightly fmt --package $(package) --check;)
@cargo clippy -- -D warnings
@$(foreach package, $(packages), cargo clippy --package $(package) -- -D warnings;)

test:
@cargo test
2 changes: 1 addition & 1 deletion crates/disk/src/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ impl CachedVector {
let s3_client = Arc::new(Client::new(&s3_config));
let num_per_block = BLOCK_BYTE_LIMIT / (4 * (dim + 1));
let total_num = num;
let total_block = (total_num + num_per_block - 1) / num_per_block;
let total_block = total_num.div_ceil(num_per_block);
let sqlite_conn = Connection::open(Path::new(&local_path)).expect("failed to open sqlite");
sqlite_conn
.execute(
Expand Down
28 changes: 17 additions & 11 deletions crates/disk/src/disk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@ use crate::cache::CachedVector;

/// Rank with cached raw vectors.
#[derive(Debug)]
pub struct CacheReRanker {
pub struct CacheReRanker<'a> {
threshold: f32,
topk: usize,
heap: BinaryHeap<(Ord32, AlwaysEqual<u32>)>,
query: Vec<f32>,
query: &'a [f32],
}

impl CacheReRanker {
fn new(query: Vec<f32>, topk: usize) -> Self {
impl<'a> CacheReRanker<'a> {
fn new(query: &'a [f32], topk: usize) -> Self {
Self {
threshold: f32::MAX,
query,
Expand All @@ -45,7 +45,7 @@ impl CacheReRanker {
for &(rough, u) in rough_distances.iter() {
if rough < self.threshold {
let accurate = cache
.get_query_vec_distance(&self.query, u)
.get_query_vec_distance(self.query, u)
.await
.expect("failed to get distance");
precise += 1;
Expand Down Expand Up @@ -142,11 +142,16 @@ impl DiskRaBitQ {

/// Query the topk nearest neighbors for the given query asynchronously.
pub async fn query(&self, query: Vec<f32>, probe: usize, topk: usize) -> Vec<(f32, u32)> {
assert_eq!(self.dim as usize, query.len());
let y_projected = project(&query, &self.orthogonal.as_ref());
assert_eq!(self.dim as usize, query.len().div_ceil(64) * 64);
// padding
let mut query_vec = query.to_vec();
if query.len() < self.dim as usize {
query_vec.extend_from_slice(&vec![0.0; self.dim as usize - query.len()]);
}

let y_projected = project(&query_vec, &self.orthogonal.as_ref());
let k = self.centroids.shape().1;
let mut lists = Vec::with_capacity(k);
let mut residual = vec![0f32; self.dim as usize];
for (i, centroid) in self.centroids.col_iter().enumerate() {
let dist = l2_squared_distance(
centroid
Expand All @@ -161,10 +166,11 @@ impl DiskRaBitQ {
lists.truncate(length);
lists.sort_by(|a, b| a.0.total_cmp(&b.0));

let mut re_ranker = CacheReRanker::new(query, topk);
let mut re_ranker = CacheReRanker::new(&query_vec, topk);
let mut residual = vec![0f32; self.dim as usize];
let mut quantized = vec![0u8; (self.dim as usize).div_ceil(64) * 64];
let mut rough_distances = Vec::new();
let mut quantized = vec![0u8; self.dim as usize];
let mut binary_vec = vec![0u64; self.dim as usize * THETA_LOG_DIM as usize / 64];
let mut binary_vec = vec![0u64; self.dim.div_ceil(64) as usize * THETA_LOG_DIM as usize];
for &(dist, i) in lists[..length].iter() {
let (lower_bound, upper_bound) =
min_max_residual(&mut residual, &y_projected.as_ref(), &self.centroids.col(i));
Expand Down
23 changes: 10 additions & 13 deletions crates/service/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,15 @@ mod args;
async fn shutdown_signal() {
let mut interrupt = signal(SignalKind::interrupt()).unwrap();
let mut terminate = signal(SignalKind::terminate()).unwrap();
loop {
tokio::select! {
_ = interrupt.recv() => {
info!("Received interrupt signal");
break;
}
_ = terminate.recv() => {
info!("Received terminate signal");
break;
}
};
}
tokio::select! {
_ = interrupt.recv() => {
info!("Received interrupt signal");
}
_ = terminate.recv() => {
info!("Received terminate signal");
}
};
info!("Shutting down");
}

async fn health_check() -> impl IntoResponse {
Expand Down Expand Up @@ -75,7 +72,7 @@ async fn main() {

let config: args::Args = argh::from_env();
let model_path = Path::new(&config.dir);
download_meta_from_s3(&config.bucket, &config.key, &model_path)
download_meta_from_s3(&config.bucket, &config.key, model_path)
.await
.expect("failed to download meta");
let rabitq =
Expand Down

0 comments on commit a9698f7

Please sign in to comment.