From 100738063a61a070eac31b3e8b9626e6cd413233 Mon Sep 17 00:00:00 2001 From: Keming Date: Tue, 3 Dec 2024 19:31:30 +0800 Subject: [PATCH] fix: x86 simd (#5) Signed-off-by: Keming --- src/simd.rs | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/simd.rs b/src/simd.rs index c563421..aca6f93 100644 --- a/src/simd.rs +++ b/src/simd.rs @@ -81,6 +81,9 @@ pub unsafe fn l2_squared_distance(lhs: &[f32], rhs: &[f32]) -> f32 { #[target_feature(enable = "avx,avx2")] #[inline] pub unsafe fn vector_binarize_query(vec: &[u8], binary: &mut [u64]) { + #[cfg(target_arch = "x86")] + use std::arch::x86::*; + #[cfg(target_arch = "x86_64")] use std::arch::x86_64::*; let length = vec.len(); @@ -112,6 +115,9 @@ pub unsafe fn vector_binarize_query(vec: &[u8], binary: &mut [u64]) { #[target_feature(enable = "avx")] #[inline] pub unsafe fn min_max_residual(res: &mut [f32], x: &[f32], y: &[f32]) -> (f32, f32) { + #[cfg(target_arch = "x86")] + use std::arch::x86::*; + #[cfg(target_arch = "x86_64")] use std::arch::x86_64::*; let mut min_32x8 = _mm256_set1_ps(f32::MAX); @@ -182,6 +188,9 @@ pub unsafe fn scalar_quantize( lower_bound: f32, multiplier: f32, ) -> u32 { + #[cfg(target_arch = "x86")] + use std::arch::x86::*; + #[cfg(target_arch = "x86_64")] use std::arch::x86_64::*; let mut quantize_ptr = quantized.as_mut_ptr() as *mut u64; @@ -246,6 +255,9 @@ pub unsafe fn scalar_quantize( #[target_feature(enable = "fma,avx,avx2")] #[inline] pub unsafe fn vector_dot_product(lhs: &[f32], rhs: &[f32]) -> f32 { + #[cfg(target_arch = "x86")] + use std::arch::x86::*; + #[cfg(target_arch = "x86_64")] use std::arch::x86_64::*; let mut lhs_ptr = lhs.as_ptr(); @@ -312,6 +324,9 @@ pub unsafe fn vector_dot_product(lhs: &[f32], rhs: &[f32]) -> f32 { #[target_feature(enable = "sse2,avx,avx2")] #[inline] pub unsafe fn binary_dot_product(lhs: &[u64], rhs: &[u64]) -> u32 { + #[cfg(target_arch = "x86")] + use std::arch::x86::*; + #[cfg(target_arch = "x86_64")] use std::arch::x86_64::*; let mut sum = 0;