Skip to content

Commit

Permalink
fix: x86 simd (#5)
Browse files Browse the repository at this point in the history
Signed-off-by: Keming <[email protected]>
  • Loading branch information
kemingy authored Dec 3, 2024
1 parent a9698f7 commit 1007380
Showing 1 changed file with 15 additions and 0 deletions.
15 changes: 15 additions & 0 deletions src/simd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ pub unsafe fn l2_squared_distance(lhs: &[f32], rhs: &[f32]) -> f32 {
#[target_feature(enable = "avx,avx2")]
#[inline]
pub unsafe fn vector_binarize_query(vec: &[u8], binary: &mut [u64]) {
#[cfg(target_arch = "x86")]
use std::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

let length = vec.len();
Expand Down Expand Up @@ -112,6 +115,9 @@ pub unsafe fn vector_binarize_query(vec: &[u8], binary: &mut [u64]) {
#[target_feature(enable = "avx")]
#[inline]
pub unsafe fn min_max_residual(res: &mut [f32], x: &[f32], y: &[f32]) -> (f32, f32) {
#[cfg(target_arch = "x86")]
use std::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

let mut min_32x8 = _mm256_set1_ps(f32::MAX);
Expand Down Expand Up @@ -182,6 +188,9 @@ pub unsafe fn scalar_quantize(
lower_bound: f32,
multiplier: f32,
) -> u32 {
#[cfg(target_arch = "x86")]
use std::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

let mut quantize_ptr = quantized.as_mut_ptr() as *mut u64;
Expand Down Expand Up @@ -246,6 +255,9 @@ pub unsafe fn scalar_quantize(
#[target_feature(enable = "fma,avx,avx2")]
#[inline]
pub unsafe fn vector_dot_product(lhs: &[f32], rhs: &[f32]) -> f32 {
#[cfg(target_arch = "x86")]
use std::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

let mut lhs_ptr = lhs.as_ptr();
Expand Down Expand Up @@ -312,6 +324,9 @@ pub unsafe fn vector_dot_product(lhs: &[f32], rhs: &[f32]) -> f32 {
#[target_feature(enable = "sse2,avx,avx2")]
#[inline]
pub unsafe fn binary_dot_product(lhs: &[u64], rhs: &[u64]) -> u32 {
#[cfg(target_arch = "x86")]
use std::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

let mut sum = 0;
Expand Down

0 comments on commit 1007380

Please sign in to comment.