add reduce f32 256
Signed-off-by: Keming <[email protected]>
kemingy committed Sep 19, 2024
1 parent 0e7cd6e commit 37c688f
Showing 1 changed file with 36 additions and 8 deletions.
src/simd.rs: 44 changes (36 additions & 8 deletions)
@@ -12,6 +12,7 @@ use crate::consts::THETA_LOG_DIM;
 /// This function is marked unsafe because it requires the AVX intrinsics.
 #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
 #[target_feature(enable = "fma,avx")]
+#[inline]
 pub unsafe fn l2_squared_distance(lhs: &ColRef<f32>, rhs: &ColRef<f32>) -> f32 {
     #[cfg(target_arch = "x86")]
     use std::arch::x86::*;
@@ -23,7 +24,6 @@ pub unsafe fn l2_squared_distance(lhs: &ColRef<f32>, rhs: &ColRef<f32>) -> f32 {
     let mut rhs_ptr = rhs.as_ptr();
     let block_16_num = lhs.nrows() >> 4;
     let rest_num = lhs.nrows() & 0b1111;
-    let mut f32x8 = [0.0f32; 8];
     let (mut diff, mut vx, mut vy): (__m256, __m256, __m256);
     let mut sum = _mm256_setzero_ps();
 
@@ -51,10 +51,22 @@ pub unsafe fn l2_squared_distance(lhs: &ColRef<f32>, rhs: &ColRef<f32>) -> f32 {
         diff = _mm256_sub_ps(vx, vy);
         sum = _mm256_fmadd_ps(diff, diff, sum);
     }
-    _mm256_store_ps(f32x8.as_mut_ptr(), sum);
-    let mut res =
-        f32x8[0] + f32x8[1] + f32x8[2] + f32x8[3] + f32x8[4] + f32x8[5] + f32x8[6] + f32x8[7];
+
+    #[inline(always)]
+    unsafe fn reduce_f32_256(accumulate: __m256) -> f32 {
+        // add [4..7] to [0..3]
+        let mut combined = _mm256_add_ps(
+            accumulate,
+            _mm256_permute2f128_ps(accumulate, accumulate, 1),
+        );
+        // add [0..3] to [0..1]
+        combined = _mm256_hadd_ps(combined, combined);
+        // add [0..1] to [0]
+        combined = _mm256_hadd_ps(combined, combined);
+        _mm256_cvtss_f32(combined)
+    }
 
+    let mut res = reduce_f32_256(sum);
     for _ in 0..rest_num {
         let residual = *lhs_ptr - *rhs_ptr;
         res += residual * residual;
@@ -71,6 +83,7 @@ pub unsafe fn l2_squared_distance(lhs: &ColRef<f32>, rhs: &ColRef<f32>) -> f32 {
 /// This function is marked unsafe because it requires the AVX intrinsics.
 #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
 #[target_feature(enable = "avx,avx2")]
+#[inline]
 pub unsafe fn vector_binarize_query(vec: &[u8], binary: &mut [u64]) {
     use std::arch::x86_64::*;
 
@@ -99,6 +112,7 @@ pub unsafe fn vector_binarize_query(vec: &[u8], binary: &mut [u64]) {
 /// This function is marked unsafe because it requires the AVX intrinsics.
 #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
 #[target_feature(enable = "avx")]
+#[inline]
 pub unsafe fn min_max_residual(res: &mut [f32], x: &ColRef<f32>, y: &ColRef<f32>) -> (f32, f32) {
     use std::arch::x86_64::*;
 
@@ -163,6 +177,7 @@ pub unsafe fn min_max_residual(res: &mut [f32], x: &ColRef<f32>, y: &ColRef<f32>
 /// This function is marked unsafe because it requires the AVX intrinsics.
 #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
 #[target_feature(enable = "avx,avx2")]
+#[inline]
 pub unsafe fn scalar_quantize(
     quantized: &mut [u8],
     vec: &[f32],
@@ -229,6 +244,7 @@ pub unsafe fn scalar_quantize(
 /// This function is marked unsafe because it requires the AVX intrinsics.
 #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
 #[target_feature(enable = "fma,avx,avx2")]
+#[inline]
 pub unsafe fn vector_dot_product(lhs: &ColRef<f32>, rhs: &ColRef<f32>) -> f32 {
     use std::arch::x86_64::*;
 
@@ -238,7 +254,6 @@ pub unsafe fn vector_dot_product(lhs: &ColRef<f32>, rhs: &ColRef<f32>) -> f32 {
     let rest = length & 0b111;
     let (mut vx, mut vy): (__m256, __m256);
     let mut accumulate = _mm256_setzero_ps();
-    let mut f32x8 = [0.0f32; 8];
 
     for _ in 0..(length / 16) {
         vx = _mm256_loadu_ps(lhs_ptr);
@@ -260,9 +275,22 @@ pub unsafe fn vector_dot_product(lhs: &ColRef<f32>, rhs: &ColRef<f32>) -> f32 {
         lhs_ptr = lhs_ptr.add(8);
         rhs_ptr = rhs_ptr.add(8);
     }
-    _mm256_storeu_ps(f32x8.as_mut_ptr(), accumulate);
-    let mut sum =
-        f32x8[0] + f32x8[1] + f32x8[2] + f32x8[3] + f32x8[4] + f32x8[5] + f32x8[6] + f32x8[7];
+
+    #[inline(always)]
+    unsafe fn reduce_f32_256(accumulate: __m256) -> f32 {
+        // add [4..7] to [0..3]
+        let mut combined = _mm256_add_ps(
+            accumulate,
+            _mm256_permute2f128_ps(accumulate, accumulate, 1),
+        );
+        // add [0..3] to [0..1]
+        combined = _mm256_hadd_ps(combined, combined);
+        // add [0..1] to [0]
+        combined = _mm256_hadd_ps(combined, combined);
+        _mm256_cvtss_f32(combined)
+    }
+
+    let mut sum = reduce_f32_256(accumulate);
 
     for _ in 0..rest {
         sum += *lhs_ptr * *rhs_ptr;
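
Below is a minimal standalone sketch (not part of this commit) that checks the new hadd-based reduce_f32_256 reduction against a plain scalar sum. It assumes an x86_64 CPU and gates execution on runtime AVX detection, so it is safe to run anywhere it compiles.

// Standalone check of the horizontal-sum reduction added above.
// Hypothetical test harness, not part of src/simd.rs.
#[cfg(target_arch = "x86_64")]
fn main() {
    use std::arch::x86_64::*;

    // Same reduction as the commit: fold the high 128-bit lane onto the
    // low lane, then hadd twice to collapse four floats into one.
    #[target_feature(enable = "avx")]
    unsafe fn reduce_f32_256(accumulate: __m256) -> f32 {
        let mut combined = _mm256_add_ps(
            accumulate,
            _mm256_permute2f128_ps(accumulate, accumulate, 1),
        );
        combined = _mm256_hadd_ps(combined, combined);
        combined = _mm256_hadd_ps(combined, combined);
        _mm256_cvtss_f32(combined)
    }

    if is_x86_feature_detected!("avx") {
        let xs = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        // Unaligned load is fine here; note the removed code's
        // _mm256_store_ps required a 32-byte alignment that a plain
        // [f32; 8] does not guarantee.
        let v = unsafe { _mm256_loadu_ps(xs.as_ptr()) };
        let got = unsafe { reduce_f32_256(v) };
        assert_eq!(got, xs.iter().sum::<f32>()); // 36.0
        println!("reduce_f32_256 = {got}");
    }
}

#[cfg(not(target_arch = "x86_64"))]
fn main() {}

The reduction also avoids the round trip through memory that the deleted store-then-sum code paid on every call.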
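A hypothetical caller-side sketch follows (the names l2_squared and l2_squared_avx are illustrative, and the slice-based signature is an assumption: the crate's real functions take faer ColRef<f32>). It shows the standard pattern for invoking #[target_feature] functions like the ones in this file: verify CPU support at runtime, then fall back to scalar code.

// Hypothetical dispatch wrapper, not from this repository.
#[cfg(target_arch = "x86_64")]
fn l2_squared(lhs: &[f32], rhs: &[f32]) -> f32 {
    assert_eq!(lhs.len(), rhs.len());
    if is_x86_feature_detected!("avx") && is_x86_feature_detected!("fma") {
        // SAFETY: the required CPU features were just verified, which is
        // the contract that makes a #[target_feature] fn sound to call.
        unsafe { l2_squared_avx(lhs, rhs) }
    } else {
        // Scalar fallback for CPUs without AVX/FMA.
        lhs.iter().zip(rhs).map(|(x, y)| (x - y) * (x - y)).sum()
    }
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "fma,avx")]
unsafe fn l2_squared_avx(lhs: &[f32], rhs: &[f32]) -> f32 {
    use std::arch::x86_64::*;

    let mut sum = _mm256_setzero_ps();
    let chunks = lhs.len() / 8;
    for i in 0..chunks {
        let vx = _mm256_loadu_ps(lhs.as_ptr().add(i * 8));
        let vy = _mm256_loadu_ps(rhs.as_ptr().add(i * 8));
        let diff = _mm256_sub_ps(vx, vy);
        sum = _mm256_fmadd_ps(diff, diff, sum);
    }
    // Horizontal reduction as in the commit, then the scalar tail.
    let mut c = _mm256_add_ps(sum, _mm256_permute2f128_ps(sum, sum, 1));
    c = _mm256_hadd_ps(c, c);
    c = _mm256_hadd_ps(c, c);
    let mut res = _mm256_cvtss_f32(c);
    for i in (chunks * 8)..lhs.len() {
        let d = lhs[i] - rhs[i];
        res += d * d;
    }
    res
}

With a wrapper like this, callers never reach the unsafe AVX path on unsupported hardware, and the #[inline] hints added in this commit still let the compiler inline the SIMD body into the dispatching function.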
