Skip to content

Commit

Permalink
add comments for vector binarize, fix client write fvecs
Browse files Browse the repository at this point in the history
Signed-off-by: Keming <[email protected]>
  • Loading branch information
kemingy committed Oct 15, 2024
1 parent 761d305 commit 0e087df
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 6 deletions.
2 changes: 0 additions & 2 deletions crates/disk/src/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,6 @@ impl CachedVector {
local_path: String,
s3_bucket: String,
s3_prefix: String,
// _mem_cache_num: u32,
// _disk_cache_mb: u32,
) -> Self {
let s3_config = aws_config::defaults(BehaviorVersion::v2024_03_28())
.load()
Expand Down
2 changes: 1 addition & 1 deletion scripts/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def write_vec(filepath: str, vecs: np.ndarray, vec_type: np.dtype = np.float32):
"""Write vectors to a file. Support `fvecs`, `ivecs` and `bvecs` format."""
with open(filepath, "wb") as f:
for vec in vecs:
f.write(pack(len(vec), "<i"))
f.write(pack("<i", len(vec)))
f.write(vec.tobytes())


Expand Down
10 changes: 7 additions & 3 deletions src/simd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,13 @@ pub unsafe fn vector_binarize_query(vec: &[u8], binary: &mut [u64]) {
// since it's not guaranteed that the vec is fully-aligned
let mut v = _mm256_loadu_si256(ptr);
ptr = ptr.add(1);
// only the lower 4 bits are useful due to the 4-bit scalar quantization
v = _mm256_slli_epi32(v, 4);
for j in 0..THETA_LOG_DIM as usize {
// extract the MSB of each u8
let mask = (_mm256_movemask_epi8(v) as u32) as u64;
// let shift = if (i / 32) % 2 == 0 { 32 } else { 0 };
let shift = ((i >> 5) & 1) << 5;
// (opposite version) let shift = if (i / 32) % 2 == 0 { 32 } else { 0 };
let shift = i & 32;
binary[(3 - j) * (length >> 6) + (i >> 6)] |= mask << shift;
v = _mm256_slli_epi32(v, 1);
}
Expand Down Expand Up @@ -190,7 +192,9 @@ pub unsafe fn scalar_quantize(
let scalar = _mm256_set1_ps(multiplier);
let mut sum256 = _mm256_setzero_si256();
let mask = _mm256_setr_epi8(
0, 4, 8, 12, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 4, 8, 12, -1, -1, -1, -1,
0, 4, 8, 12, -1, -1, -1, -1, //
-1, -1, -1, -1, -1, -1, -1, -1, //
0, 4, 8, 12, -1, -1, -1, -1, //
-1, -1, -1, -1, -1, -1, -1, -1,
);
let length = vec.len();
Expand Down

0 comments on commit 0e087df

Please sign in to comment.