Skip to content

Commit

Permalink
Merge branch 'dev' into franziskus/towards-ml-kem-c-extraction1
Browse files Browse the repository at this point in the history
  • Loading branch information
franziskuskiefer authored May 3, 2024
2 parents eddefae + 4aec2f2 commit 31b92aa
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 33 deletions.
16 changes: 16 additions & 0 deletions libcrux-ml-kem/examples/decapsulate.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
use libcrux_ml_kem::{kyber768, ENCAPS_SEED_SIZE, KEY_GENERATION_SEED_SIZE};
use rand::{rngs::OsRng, RngCore};

/// Example program: repeatedly decapsulates one Kyber768 ciphertext.
///
/// A key pair and a ciphertext are produced once from seeds drawn from
/// `OsRng`; the decapsulation itself is then run 100,000 times with the
/// result discarded.
fn main() {
    // Seed for key generation.
    let mut randomness = [0u8; KEY_GENERATION_SEED_SIZE];
    OsRng.fill_bytes(&mut randomness);

    let key_pair = kyber768::generate_key_pair(randomness);

    // Separate seed, with its own size, for encapsulation.
    let mut randomness = [0u8; ENCAPS_SEED_SIZE];
    OsRng.fill_bytes(&mut randomness);
    // Only the ciphertext is used below; the second tuple element is
    // deliberately ignored (underscore prefix avoids an unused-variable warning).
    let (ct, _ss) = kyber768::encapsulate(key_pair.public_key(), randomness);

    for _ in 0..100_000 {
        let _ = kyber768::decapsulate(key_pair.private_key(), &ct);
    }
}
10 changes: 10 additions & 0 deletions libcrux-ml-kem/examples/keygen.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
use libcrux_ml_kem::{kyber768, KEY_GENERATION_SEED_SIZE};
use rand::{rngs::OsRng, RngCore};

/// Example program: runs Kyber768 key generation 100,000 times.
///
/// Every iteration refills the seed buffer from `OsRng` before generating a
/// key pair; the generated key pair is discarded.
fn main() {
    const ITERATIONS: usize = 100_000;

    // A single seed buffer, overwritten with fresh bytes on each pass.
    let mut seed = [0u8; KEY_GENERATION_SEED_SIZE];
    for _ in 0..ITERATIONS {
        OsRng.fill_bytes(&mut seed);
        let _ = kyber768::generate_key_pair(seed);
    }
}
4 changes: 3 additions & 1 deletion libcrux-ml-kem/src/compress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ pub(crate) fn compress_message_coefficient(fe: u16) -> u8 {
((shifted_positive_in_range >> 15) & 1) as u8
}

pub(crate) const CIPHERTEXT_COMPRESSION_MULTIPLIER: i32 = 10_321_340;

#[cfg_attr(hax,
hax_lib::requires(
(coefficient_bits == 4 ||
Expand All @@ -76,7 +78,7 @@ pub(crate) fn compress_ciphertext_coefficient(coefficient_bits: u8, fe: u16) ->
let mut compressed = (fe as u64) << coefficient_bits;
compressed += 1664 as u64;

compressed *= 10_321_340;
compressed *= CIPHERTEXT_COMPRESSION_MULTIPLIER as u64;
compressed >>= 35;

get_n_least_significant_bits(coefficient_bits, compressed as u32) as FieldElement
Expand Down
152 changes: 120 additions & 32 deletions libcrux-ml-kem/src/simd/simd256.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::{
BARRETT_MULTIPLIER, BARRETT_R, BARRETT_SHIFT, INVERSE_OF_MODULUS_MOD_MONTGOMERY_R,
MONTGOMERY_SHIFT,
},
compress::CIPHERTEXT_COMPRESSION_MULTIPLIER,
constants::FIELD_MODULUS,
simd::{portable, simd_trait::*},
};
Expand All @@ -16,10 +17,17 @@ pub(crate) struct SIMD256Vector {
/// Debugging helper: prints `prefix`, then the contents of `a` viewed as
/// eight `i32` values, using `{:?}` formatting.
#[allow(dead_code)]
fn print_m256i_as_i32s(a: __m256i, prefix: String) {
    let mut a_bytes = [0i32; 8];
    // SAFETY: `a_bytes` is 8 * 4 = 32 bytes of writable memory, exactly the
    // size of one `__m256i`. The unaligned store is required here: a
    // `[i32; 8]` is only guaranteed 4-byte alignment, whereas the aligned
    // `_mm256_store_si256` needs a 32-byte-aligned destination.
    unsafe { _mm256_storeu_si256(a_bytes.as_mut_ptr() as *mut __m256i, a) };
    println!("{}: {:?}", prefix, a_bytes);
}

/// Debugging helper: prints `prefix`, then the contents of `a` viewed as
/// four `i64` values, using `{:x?}` (hexadecimal debug) formatting.
#[allow(dead_code)]
fn print_m256i_as_i64s(a: __m256i, prefix: String) {
    let mut lanes = [0i64; 4];
    let destination = lanes.as_mut_ptr().cast::<__m256i>();
    // SAFETY: `lanes` is 4 * 8 = 32 bytes of writable memory, exactly the
    // size of one `__m256i`, and the unaligned store places no alignment
    // requirement on `destination`.
    unsafe { _mm256_storeu_si256(destination, a) };
    println!("{}: {:x?}", prefix, lanes);
}

#[allow(non_snake_case)]
#[inline(always)]
fn ZERO() -> SIMD256Vector {
Expand Down Expand Up @@ -134,6 +142,8 @@ fn barrett_reduce(v: SIMD256Vector) -> SIMD256Vector {
let mut t_high = _mm256_shuffle_epi32(v.elements, 0b00_11_00_01);
t_high = _mm256_mul_epi32(t_high, barrett_multiplier);
t_high = _mm256_add_epi64(t_high, barrett_r_halved);

// Right shift by 26, then left shift by 32
let quotient_high = _mm256_slli_epi64(t_high, 6);

let quotient = _mm256_blend_epi32(quotient_low, quotient_high, 0b1_0_1_0_1_0_1_0);
Expand All @@ -148,23 +158,16 @@ fn barrett_reduce(v: SIMD256Vector) -> SIMD256Vector {
#[inline(always)]
fn montgomery_reduce(v: SIMD256Vector) -> SIMD256Vector {
let reduced = unsafe {
let montgomery_shift_mask = _mm256_set1_epi32((1 << MONTGOMERY_SHIFT) - 1);
let field_modulus = _mm256_set1_epi32(FIELD_MODULUS);
let inverse_of_modulus_mod_montgomery_r =
_mm256_set1_epi32(INVERSE_OF_MODULUS_MOD_MONTGOMERY_R as i32);

let t = _mm256_and_si256(v.elements, montgomery_shift_mask);
let t = _mm256_mullo_epi32(t, inverse_of_modulus_mod_montgomery_r);

let k = _mm256_and_si256(t, montgomery_shift_mask);
let k = _mm256_slli_epi32(k, 16);
let k = _mm256_srai_epi32(k, 16);

let k_times_modulus = _mm256_mullo_epi32(k, field_modulus);
let c = _mm256_srai_epi32(k_times_modulus, MONTGOMERY_SHIFT as i32);
let value_high = _mm256_srai_epi32(v.elements, MONTGOMERY_SHIFT as i32);

_mm256_sub_epi32(value_high, c)
let t = _mm256_mullo_epi16(v.elements, inverse_of_modulus_mod_montgomery_r);
let k_times_modulus = _mm256_mulhi_epi16(t, field_modulus);
let value_high = _mm256_srli_epi32(v.elements, MONTGOMERY_SHIFT as i32);
let res = _mm256_sub_epi16(value_high, k_times_modulus);
let res = _mm256_slli_epi32(res, 16);
_mm256_srai_epi32(res, 16)
};

SIMD256Vector { elements: reduced }
Expand All @@ -189,43 +192,106 @@ fn compress_1(mut v: SIMD256Vector) -> SIMD256Vector {
}

#[inline(always)]
fn compress<const COEFFICIENT_BITS: i32>(v: SIMD256Vector) -> SIMD256Vector {
let input = portable::PortableVector::from_i32_array(to_i32_array(v));
let output = portable::PortableVector::compress::<COEFFICIENT_BITS>(input);
fn compress<const COEFFICIENT_BITS: i32>(mut v: SIMD256Vector) -> SIMD256Vector {
let compressed = unsafe {
let field_modulus_halved = _mm256_set1_epi32((FIELD_MODULUS - 1) / 2);
let coefficient_bits_mask = _mm256_set1_epi32((1 << COEFFICIENT_BITS) - 1);
let multiplier = _mm256_set1_epi32(CIPHERTEXT_COMPRESSION_MULTIPLIER);

from_i32_array(portable::PortableVector::to_i32_array(output))
v.elements = _mm256_slli_epi32(v.elements, COEFFICIENT_BITS);
v.elements = _mm256_add_epi32(v.elements, field_modulus_halved);

let compressed_half_1 = _mm256_mul_epu32(v.elements, multiplier);
let compressed_half_1 = _mm256_srli_epi64(compressed_half_1, 35);

let compressed_half_2 = _mm256_shuffle_epi32(v.elements, 0b00_11_00_01);
let compressed_half_2 = _mm256_mul_epu32(compressed_half_2, multiplier);
// Right shift by 35, and then left shift by 32
let compressed_half_2 = _mm256_srli_epi64(compressed_half_2, 3);

let compressed =
_mm256_blend_epi32(compressed_half_1, compressed_half_2, 0b1_0_1_0_1_0_1_0);

_mm256_and_si256(compressed, coefficient_bits_mask)
};

SIMD256Vector {
elements: compressed,
}
}

#[inline(always)]
fn ntt_layer_1_step(v: SIMD256Vector, zeta1: i32, zeta2: i32) -> SIMD256Vector {
let input = portable::PortableVector::from_i32_array(to_i32_array(v));
let output = portable::PortableVector::ntt_layer_1_step(input, zeta1, zeta2);
let result = unsafe {
let zetas = _mm256_set_epi32(-zeta2, -zeta2, zeta2, zeta2, -zeta1, -zeta1, zeta1, zeta1);
let zeta_multipliers = _mm256_shuffle_epi32(v.elements, 0b11_10_11_10);

from_i32_array(portable::PortableVector::to_i32_array(output))
let rhs = _mm256_mullo_epi32(zeta_multipliers, zetas);
let rhs = montgomery_reduce(SIMD256Vector { elements: rhs }).elements;

let lhs = _mm256_shuffle_epi32(v.elements, 0b01_00_01_00);

_mm256_add_epi32(rhs, lhs)
};

SIMD256Vector { elements: result }
}

#[inline(always)]
fn ntt_layer_2_step(v: SIMD256Vector, zeta: i32) -> SIMD256Vector {
let input = portable::PortableVector::from_i32_array(to_i32_array(v));
let output = portable::PortableVector::ntt_layer_2_step(input, zeta);
let result = unsafe {
let zetas = _mm256_set_epi32(-zeta, -zeta, -zeta, -zeta, zeta, zeta, zeta, zeta);
let zeta_multipliers = _mm256_permute4x64_epi64(v.elements, 0b11_10_11_10);

from_i32_array(portable::PortableVector::to_i32_array(output))
let rhs = _mm256_mullo_epi32(zeta_multipliers, zetas);
let rhs = montgomery_reduce(SIMD256Vector { elements: rhs }).elements;

let lhs = _mm256_permute4x64_epi64(v.elements, 0b01_00_01_00);

_mm256_add_epi32(rhs, lhs)
};

SIMD256Vector { elements: result }
}

#[inline(always)]
fn inv_ntt_layer_1_step(v: SIMD256Vector, zeta1: i32, zeta2: i32) -> SIMD256Vector {
let input = portable::PortableVector::from_i32_array(to_i32_array(v));
let output = portable::PortableVector::inv_ntt_layer_1_step(input, zeta1, zeta2);
let result = unsafe {
let zetas = _mm256_set_epi32(zeta2, zeta2, 0, 0, zeta1, zeta1, 0, 0);

from_i32_array(portable::PortableVector::to_i32_array(output))
let add_by_signs = _mm256_set_epi32(-1, -1, 1, 1, -1, -1, 1, 1);
let add_by = _mm256_shuffle_epi32(v.elements, 0b01_00_11_10);
let add_by = _mm256_mullo_epi32(add_by, add_by_signs);

let sums = _mm256_add_epi32(v.elements, add_by);

let products = _mm256_mullo_epi32(sums, zetas);
let products_reduced = montgomery_reduce(SIMD256Vector { elements: products }).elements;

_mm256_blend_epi32(sums, products_reduced, 0b1_1_0_0_1_1_0_0)
};

SIMD256Vector { elements: result }
}

#[inline(always)]
fn inv_ntt_layer_2_step(v: SIMD256Vector, zeta: i32) -> SIMD256Vector {
let input = portable::PortableVector::from_i32_array(to_i32_array(v));
let output = portable::PortableVector::inv_ntt_layer_2_step(input, zeta);
let result = unsafe {
let zetas = _mm256_set_epi32(zeta, zeta, zeta, zeta, 0, 0, 0, 0);

from_i32_array(portable::PortableVector::to_i32_array(output))
let add_by_signs = _mm256_set_epi32(-1, -1, -1, -1, 1, 1, 1, 1);
let add_by = _mm256_permute4x64_epi64(v.elements, 0b01_00_11_10);
let add_by = _mm256_mullo_epi32(add_by, add_by_signs);

let sums = _mm256_add_epi32(v.elements, add_by);

let products = _mm256_mullo_epi32(sums, zetas);
let products_reduced = montgomery_reduce(SIMD256Vector { elements: products }).elements;

_mm256_blend_epi32(sums, products_reduced, 0b1_1_1_1_0_0_0_0)
};

SIMD256Vector { elements: result }
}

#[inline(always)]
Expand Down Expand Up @@ -364,8 +430,30 @@ fn deserialize_5(v: &[u8]) -> SIMD256Vector {

#[inline(always)]
fn serialize_10(v: SIMD256Vector) -> [u8; 10] {
let input = portable::PortableVector::from_i32_array(to_i32_array(v));
portable::PortableVector::serialize_10(input)
let mut out = [0u8; 16];

unsafe {
let shifted = _mm256_sllv_epi32(v.elements, _mm256_set_epi32(10, 0, 10, 0, 10, 0, 10, 0));
let shifted = _mm256_shuffle_epi32(shifted, 0b_00_11_00_01);

let bits = _mm256_add_epi32(v.elements, shifted);

let bits = _mm256_shuffle_epi32(bits, 0b_00_00_10_00);
let bits = _mm256_sllv_epi32(bits, _mm256_set_epi32(0, 0, 0, 12, 0, 0, 0, 12));
let bits = _mm256_srli_epi64(bits, 12);

let bits = _mm256_permute4x64_epi64(bits, 0b00_00_10_00);
let shuffle_by = _mm256_set_epi8(
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
12, 11, 10, 9, 8, 4, 3, 2, 1, 0,
);

let bits_sequential = _mm256_shuffle_epi8(bits, shuffle_by);
let bits_sequential = _mm256_castsi256_si128(bits_sequential);
_mm_storeu_si128(out.as_mut_ptr() as *mut __m128i, bits_sequential);
};

out[0..10].try_into().unwrap()
}

#[inline(always)]
Expand Down

0 comments on commit 31b92aa

Please sign in to comment.