Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: divide 64 in packages #4

Merged
merged 1 commit into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
members = ["crates/*"]

[workspace.package]
version = "0.2.1"
version = "0.2.2"
edition = "2021"
description = "A Rust implementation of the RaBitQ vector search algorithm."
license = "AGPL-3.0"
Expand Down
7 changes: 6 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
packages := cli disk service

build:
cargo b

format:
@cargo +nightly fmt
@$(foreach package, $(packages), cargo +nightly fmt --package $(package);)

lint:
@cargo +nightly fmt -- --check
@cargo +nightly fmt --check
@$(foreach package, $(packages), cargo +nightly fmt --package $(package) --check;)
@cargo clippy -- -D warnings
@$(foreach package, $(packages), cargo clippy --package $(package) -- -D warnings;)

test:
@cargo test
2 changes: 1 addition & 1 deletion crates/disk/src/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ impl CachedVector {
let s3_client = Arc::new(Client::new(&s3_config));
let num_per_block = BLOCK_BYTE_LIMIT / (4 * (dim + 1));
let total_num = num;
let total_block = (total_num + num_per_block - 1) / num_per_block;
let total_block = total_num.div_ceil(num_per_block);
let sqlite_conn = Connection::open(Path::new(&local_path)).expect("failed to open sqlite");
sqlite_conn
.execute(
Expand Down
28 changes: 17 additions & 11 deletions crates/disk/src/disk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@ use crate::cache::CachedVector;

/// Rank with cached raw vectors.
#[derive(Debug)]
pub struct CacheReRanker {
pub struct CacheReRanker<'a> {
threshold: f32,
topk: usize,
heap: BinaryHeap<(Ord32, AlwaysEqual<u32>)>,
query: Vec<f32>,
query: &'a [f32],
}

impl CacheReRanker {
fn new(query: Vec<f32>, topk: usize) -> Self {
impl<'a> CacheReRanker<'a> {
fn new(query: &'a [f32], topk: usize) -> Self {
Self {
threshold: f32::MAX,
query,
Expand All @@ -45,7 +45,7 @@ impl CacheReRanker {
for &(rough, u) in rough_distances.iter() {
if rough < self.threshold {
let accurate = cache
.get_query_vec_distance(&self.query, u)
.get_query_vec_distance(self.query, u)
.await
.expect("failed to get distance");
precise += 1;
Expand Down Expand Up @@ -142,11 +142,16 @@ impl DiskRaBitQ {

/// Query the topk nearest neighbors for the given query asynchronously.
pub async fn query(&self, query: Vec<f32>, probe: usize, topk: usize) -> Vec<(f32, u32)> {
assert_eq!(self.dim as usize, query.len());
let y_projected = project(&query, &self.orthogonal.as_ref());
assert_eq!(self.dim as usize, query.len().div_ceil(64) * 64);
// padding
let mut query_vec = query.to_vec();
if query.len() < self.dim as usize {
query_vec.extend_from_slice(&vec![0.0; self.dim as usize - query.len()]);
}

let y_projected = project(&query_vec, &self.orthogonal.as_ref());
let k = self.centroids.shape().1;
let mut lists = Vec::with_capacity(k);
let mut residual = vec![0f32; self.dim as usize];
for (i, centroid) in self.centroids.col_iter().enumerate() {
let dist = l2_squared_distance(
centroid
Expand All @@ -161,10 +166,11 @@ impl DiskRaBitQ {
lists.truncate(length);
lists.sort_by(|a, b| a.0.total_cmp(&b.0));

let mut re_ranker = CacheReRanker::new(query, topk);
let mut re_ranker = CacheReRanker::new(&query_vec, topk);
let mut residual = vec![0f32; self.dim as usize];
let mut quantized = vec![0u8; (self.dim as usize).div_ceil(64) * 64];
let mut rough_distances = Vec::new();
let mut quantized = vec![0u8; self.dim as usize];
let mut binary_vec = vec![0u64; self.dim as usize * THETA_LOG_DIM as usize / 64];
let mut binary_vec = vec![0u64; self.dim.div_ceil(64) as usize * THETA_LOG_DIM as usize];
for &(dist, i) in lists[..length].iter() {
let (lower_bound, upper_bound) =
min_max_residual(&mut residual, &y_projected.as_ref(), &self.centroids.col(i));
Expand Down
23 changes: 10 additions & 13 deletions crates/service/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,15 @@ mod args;
async fn shutdown_signal() {
let mut interrupt = signal(SignalKind::interrupt()).unwrap();
let mut terminate = signal(SignalKind::terminate()).unwrap();
loop {
tokio::select! {
_ = interrupt.recv() => {
info!("Received interrupt signal");
break;
}
_ = terminate.recv() => {
info!("Received terminate signal");
break;
}
};
}
tokio::select! {
_ = interrupt.recv() => {
info!("Received interrupt signal");
}
_ = terminate.recv() => {
info!("Received terminate signal");
}
};
info!("Shutting down");
}

async fn health_check() -> impl IntoResponse {
Expand Down Expand Up @@ -75,7 +72,7 @@ async fn main() {

let config: args::Args = argh::from_env();
let model_path = Path::new(&config.dir);
download_meta_from_s3(&config.bucket, &config.key, &model_path)
download_meta_from_s3(&config.bucket, &config.key, model_path)
.await
.expect("failed to download meta");
let rabitq =
Expand Down
Loading