Skip to content

Commit

Permalink
use dot_product simd to replace faer mul
Browse files Browse the repository at this point in the history
Signed-off-by: Keming <[email protected]>
  • Loading branch information
kemingy committed Sep 15, 2024
1 parent 0411821 commit e5fd6ec
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 11 deletions.
14 changes: 6 additions & 8 deletions src/rabitq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use core::f32;
use std::collections::BinaryHeap;
use std::path::Path;

use faer::{Col, ColRef, Mat, MatRef, Row, RowRef};
use faer::{Col, ColRef, Mat, MatRef, Row};
use log::debug;
// use nalgebra::{DMatrix, DMatrixView, DVector, DVectorView};
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -177,20 +177,20 @@ fn scalar_quantize(
/// Project the vector to the orthogonal matrix.
#[allow(dead_code)]
#[inline]
fn project(vec: &RowRef<f32>, orthogonal: &MatRef<f32>) -> Row<f32> {
fn project(vec: &ColRef<f32>, orthogonal: &MatRef<f32>) -> Col<f32> {
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
{
if is_x86_feature_detected!("avx2") {
Row::from_fn(orthogonal.ncols(), |i| unsafe {
Col::from_fn(orthogonal.ncols(), |i| unsafe {
crate::simd::vector_dot_product(vec, &orthogonal.col(i))
})
} else {
vec * orthogonal
(vec.transpose() * orthogonal).transpose().to_owned()
}
}
#[cfg(not(any(target_arch = "x86_64", target_arch = "x86")))]
{
vec * orthogonal
(vec.transpose() * orthogonal).transpose().to_owned()
}
}

Expand Down Expand Up @@ -418,9 +418,7 @@ impl RaBitQ {
topk: usize,
heuristic_rank: bool,
) -> Vec<(f32, u32)> {
let y_projected = (query.transpose() * &self.orthogonal)
.transpose()
.to_owned();
let y_projected = project(query, &self.orthogonal.as_ref());
let k = self.centroids.shape().1;
let mut lists = Vec::with_capacity(k);
let mut residual = vec![0f32; self.dim as usize];
Expand Down
6 changes: 3 additions & 3 deletions src/simd.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Accelerate with SIMD.
use faer::{ColRef, RowRef};
use faer::ColRef;

use crate::consts::THETA_LOG_DIM;

Expand Down Expand Up @@ -229,12 +229,12 @@ pub unsafe fn scalar_quantize(
/// This function is marked unsafe because it requires the AVX intrinsics.
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "fma,avx,avx2")]
pub unsafe fn vector_dot_product(lhs: &RowRef<f32>, rhs: &ColRef<f32>) -> f32 {
pub unsafe fn vector_dot_product(lhs: &ColRef<f32>, rhs: &ColRef<f32>) -> f32 {
use std::arch::x86_64::*;

let mut lhs_ptr = lhs.as_ptr();
let mut rhs_ptr = rhs.as_ptr();
let length = lhs.ncols();
let length = lhs.nrows();
let rest = length & 0b111;
let (mut vx, mut vy): (__m256, __m256);
let mut accumulate = _mm256_setzero_ps();
Expand Down

0 comments on commit e5fd6ec

Please sign in to comment.