Skip to content

Commit

Permalink
rewrite binary_dot_product with simd
Browse files Browse the repository at this point in the history
Signed-off-by: Keming <[email protected]>
  • Loading branch information
kemingy committed Sep 15, 2024
1 parent 8f962f7 commit edabd4a
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 1 deletion.
64 changes: 64 additions & 0 deletions src/simd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -272,3 +272,67 @@ pub unsafe fn vector_dot_product(lhs: &ColRef<f32>, rhs: &ColRef<f32>) -> f32 {

sum
}

/// Compute the binary dot product of two vectors.
///
/// Refer to: https://github.com/komrad36/popcount
///
/// # Safety
///
/// This function is marked unsafe because it requires the AVX2 intrinsics.
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "sse2,avx,avx2")]
pub unsafe fn binary_dot_product(lhs: &[u64], rhs: &[u64]) -> u32 {
use std::arch::x86_64::*;

let mut x_ptr = lhs.as_ptr() as *const __m256i;
let mut y_ptr = rhs.as_ptr() as *const __m256i;

let length = lhs.len() / 4;
let rest = lhs.len() & 0b11;
let lookup_table = _mm256_setr_epi8(
0, 1, 1, 2, 1, 2, 2, 3, // 0-7
1, 2, 2, 3, 2, 3, 3, 4, // 8-15
0, 1, 1, 2, 1, 2, 2, 3, // 16-23
1, 2, 2, 3, 2, 3, 3, 4, // 24-31
);
let mask = _mm256_set1_epi8(15);
let zero = _mm256_setzero_si256();

#[inline]
unsafe fn mm256_popcnt_epi64(
x: __m256i,
lookup_table: __m256i,
mask: __m256i,
zero: __m256i,
) -> __m256i {
use std::arch::x86_64::*;

let mut low = _mm256_and_si256(x, mask);
let mut high = _mm256_and_si256(_mm256_srli_epi64(x, 4), mask);
low = _mm256_shuffle_epi8(lookup_table, low);
high = _mm256_shuffle_epi8(lookup_table, high);
_mm256_sad_epu8(_mm256_add_epi8(low, high), zero)
}

let mut sum256 = _mm256_setzero_si256();
for _ in 0..length {
let x256 = _mm256_loadu_si256(x_ptr);
let y256 = _mm256_loadu_si256(y_ptr);
let and = _mm256_and_si256(x256, y256);
sum256 = _mm256_add_epi64(sum256, mm256_popcnt_epi64(and, lookup_table, mask, zero));
x_ptr = x_ptr.add(1);
y_ptr = y_ptr.add(1);
}

let xa = _mm_add_epi64(
_mm256_castsi256_si128(sum256),
_mm256_extracti128_si256(sum256, 1),
);
let mut sum = _mm_cvtsi128_si64(_mm_add_epi64(xa, _mm_shuffle_epi32(xa, 78))) as u32;
for i in 0..rest {
sum += (lhs[4 * length + i] & rhs[4 * length + i]).count_ones();
}

sum
}
15 changes: 14 additions & 1 deletion src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,20 @@ pub fn asymmetric_binary_dot_product(x: &[u64], y: &[u64]) -> u32 {
let length = x.len();
let mut y_slice = y;
for i in 0..THETA_LOG_DIM as usize {
res += binary_dot_product(x, y_slice) << i;
res += {
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
{
if is_x86_feature_detected!("avx2") {
unsafe { crate::simd::binary_dot_product(x, y_slice) << i }
} else {
binary_dot_product(x, y_slice) << i
}
}
#[cfg(not(any(target_arch = "x86_64", target_arch = "x86")))]
{
binary_dot_product(x, y_slice) << i
}
};
y_slice = &y_slice[length..];
}
res
Expand Down

0 comments on commit edabd4a

Please sign in to comment.