Skip to content

Commit

Permalink
Merge pull request #293 from cryspen/goutam/more-kyber-avx2-docs
Browse files Browse the repository at this point in the history
More Kyber AVX2-serialization documentation.
  • Loading branch information
franziskuskiefer committed Jun 5, 2024
2 parents 29f4d91 + a4a55ab commit 1202ef2
Showing 1 changed file with 155 additions and 29 deletions.
184 changes: 155 additions & 29 deletions libcrux-ml-kem/src/vector/avx2/serialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,17 @@ use super::*;

#[inline(always)]
pub(crate) fn serialize_1(vector: Vec256) -> [u8; 2] {
// Suppose |vector| is laid out as follows (superscript number indicates the
// corresponding bit is duplicated that many times):
//
// 0¹⁵a₀ 0¹⁵b₀ 0¹⁵c₀ 0¹⁵d₀ | 0¹⁵e₀ 0¹⁵f₀ 0¹⁵g₀ 0¹⁵h₀ | ...
//
// We care only about the least significant bit in each lane,
// move it to the most significant position to make it easier to work with.
// |vector| now becomes:
//
// a₀0¹⁵ b₀0¹⁵ c₀0¹⁵ d₀0¹⁵ | e₀0¹⁵ f₀0¹⁵ g₀0¹⁵ h₀0¹⁵ | ↩
// i₀0¹⁵ j₀0¹⁵ k₀0¹⁵ l₀0¹⁵ | m₀0¹⁵ n₀0¹⁵ o₀0¹⁵ p₀0¹⁵
let lsb_to_msb = mm256_slli_epi16::<15>(vector);

// Get the first 8 16-bit elements ...
Expand All @@ -15,16 +24,26 @@ pub(crate) fn serialize_1(vector: Vec256) -> [u8; 2] {
// ... and then pack them into 8-bit values using signed saturation.
// This function packs all the |low_msbs|, and then the high ones.
//
// We shifted by 15 above to take advantage of signed saturation:
//
// low_msbs = a₀0¹⁵ b₀0¹⁵ c₀0¹⁵ d₀0¹⁵ | e₀0¹⁵ f₀0¹⁵ g₀0¹⁵ h₀0¹⁵
// high_msbs = i₀0¹⁵ j₀0¹⁵ k₀0¹⁵ l₀0¹⁵ | m₀0¹⁵ n₀0¹⁵ o₀0¹⁵ p₀0¹⁵
//
// We shifted by 15 above to take advantage of the signed saturation performed
// by mm_packs_epi16:
//
// - if the sign bit of the 16-bit element being packed is 1, the
// corresponding 8-bit element in |msbs| will be 0xFF.
// - if the sign bit of the 16-bit element being packed is 0, the
// corresponding 8-bit element in |msbs| will be 0.
//
// Thus, if, for example, a₀ = 1, e₀ = 1, and p₀ = 1, and every other bit
// is 0, after packing into 8 bit value, |msbs| will look like:
//
// 0xFF 0x00 0x00 0x00 | 0xFF 0x00 0x00 0x00 | 0x00 0x00 0x00 0x00 | 0x00 0x00 0x00 0xFF
let msbs = mm_packs_epi16(low_msbs, high_msbs);

// Now that we have all 16 bits we need conveniently placed in one vector,
// extract them into two bytes.
// Now that every element is either 0xFF or 0x00, we just extract the most
// significant bit from each element and collate them into two bytes.
let bits_packed = mm_movemask_epi8(msbs);

let mut serialized = [0u8; 2];
Expand All @@ -41,8 +60,8 @@ pub(crate) fn deserialize_1(bytes: &[u8]) -> Vec256 {
// duplicate them, and right-shift the 0th element by 0 bits,
// the first element by 1 bit, the second by 2 bits and so on before AND-ing
// with 0x1 to leave only the least signifinicant bit.
// But |_mm256_srlv_epi16| does not exist unfortunately, so we have to resort
// to a workaround.
// But since |_mm256_srlv_epi16| does not exist, so we have to resort to a
// workaround.
//
// Rather than shifting each element by a different amount, we'll multiply
// each element by a value such that the bit we're interested in becomes the most
Expand Down Expand Up @@ -161,28 +180,21 @@ pub(crate) fn serialize_4(vector: Vec256) -> [u8; 8] {

#[inline(always)]
pub(crate) fn deserialize_4(bytes: &[u8]) -> Vec256 {
let shift_lsbs_to_msbs = mm256_set_epi16(
1 << 0,
1 << 4,
1 << 0,
1 << 4,
1 << 0,
1 << 4,
1 << 0,
1 << 4,
1 << 0,
1 << 4,
1 << 0,
1 << 4,
1 << 0,
1 << 4,
1 << 0,
1 << 4,
);

// Every 4 bits from each byte of input should be put into its own 16-bit lane.
// Since |_mm256_srlv_epi16| does not exist, we have to resort to a workaround.
//
// Rather than shifting each element by a different amount, we'll multiply
// each element by a value such that the bits we're interested in become the most
// significant bits (of an 8-bit value).
let coefficients = mm256_set_epi16(
// In this lane, the 4 bits we need to put are already the most
// significant bits of |bytes[7]|.
bytes[7] as i16,
// In this lane, the 4 bits we need to put are the least significant bits,
// so we need to shift the 4 least-significant bits of |bytes[7]| to the
// most significant bits (of an 8-bit value).
bytes[7] as i16,
// and so on ...
bytes[6] as i16,
bytes[6] as i16,
bytes[5] as i16,
Expand All @@ -199,16 +211,53 @@ pub(crate) fn deserialize_4(bytes: &[u8]) -> Vec256 {
bytes[0] as i16,
);

let shift_lsbs_to_msbs = mm256_set_epi16(
// These constants are chosen to shift the bits of the values
// that we loaded into |coefficients|.
1 << 0,
1 << 4,
1 << 0,
1 << 4,
1 << 0,
1 << 4,
1 << 0,
1 << 4,
1 << 0,
1 << 4,
1 << 0,
1 << 4,
1 << 0,
1 << 4,
1 << 0,
1 << 4,
);

let coefficients_in_msb = mm256_mullo_epi16(coefficients, shift_lsbs_to_msbs);

// Once the 4-bit coefficients are in the most significant positions (of
// an 8-bit value), shift them all down by 4.
let coefficients_in_lsb = mm256_srli_epi16::<4>(coefficients_in_msb);

// Zero the remaining bits.
mm256_and_si256(coefficients_in_lsb, mm256_set1_epi16((1 << 4) - 1))
}

#[inline(always)]
pub(crate) fn serialize_5(vector: Vec256) -> [u8; 10] {
let mut serialized = [0u8; 32];

// If |vector| is laid out as follows (superscript number indicates the
// corresponding bit is duplicated that many times):
//
// 0¹¹a₄a₃a₂a₁a₀ 0¹¹b₄b₃b₂b₁b₀ 0¹¹c₄c₃c₂c₁c₀ 0¹¹d₄d₃d₂d₁d₀ | ↩
// 0¹¹e₄e₃e₂e₁e₀ 0¹¹f₄f₃f₂f₁f₀ 0¹¹g₄g₃g₂g₁g₀ 0¹¹h₄h₃h₂h₁h₀ | ↩
//
// |adjacent_2_combined| will be laid out as a series of 32-bit integers,
// as follows:
//
// 0²²b₄b₃b₂b₁b₀a₄a₃a₂a₁a₀ 0²²d₄d₃d₂d₁d₀c₄c₃c₂c₁c₀ | ↩
// 0²²f₄f₃f₂f₁f₀e₄e₃e₂e₁e₀ 0²²h₄h₃h₂h₁h₀g₄g₃g₂g₁g₀ | ↩
// ....
let adjacent_2_combined = mm256_madd_epi16(
vector,
mm256_set_epi16(
Expand All @@ -231,23 +280,60 @@ pub(crate) fn serialize_5(vector: Vec256) -> [u8; 10] {
),
);

// Recall that |adjacent_2_combined| is laid out as follows:
//
// 0²²b₄b₃b₂b₁b₀a₄a₃a₂a₁a₀ 0²²d₄d₃d₂d₁d₀c₄c₃c₂c₁c₀ | ↩
// 0²²f₄f₃f₂f₁f₀e₄e₃e₂e₁e₀ 0²²h₄h₃h₂h₁h₀g₄g₃g₂g₁g₀ | ↩
// ....
//
// This shift results in:
//
// b₄b₃b₂b₁b₀a₄a₃a₂a₁a₀0²² 0²²d₄d₃d₂d₁d₀c₄c₃c₂c₁c₀ | ↩
// f₄f₃f₂f₁f₀e₄e₃e₂e₁e₀0²² 0²²h₄h₃h₂h₁h₀g₄g₃g₂g₁g₀ | ↩
// ....
//
let adjacent_4_combined = mm256_sllv_epi32(
adjacent_2_combined,
mm256_set_epi32(0, 22, 0, 22, 0, 22, 0, 22),
);

// |adjacent_4_combined|, when viewed as 64-bit lanes, is:
//
// 0²²d₄d₃d₂d₁d₀c₄c₃c₂c₁c₀b₄b₃b₂b₁b₀a₄a₃a₂a₁a₀0²² | ↩
// 0²²h₄h₃h₂h₁h₀g₄g₃g₂g₁g₀f₄f₃f₂f₁f₀e₄e₃e₂e₁e₀0²² | ↩
// ...
//
// so we just shift down by 22 bits to remove the least significant 0 bits
// that aren't part of the bits we need.
let adjacent_4_combined = mm256_srli_epi64::<22>(adjacent_4_combined);

// |adjacent_4_combined|, when viewed as a set of 32-bit values, looks like:
//
// 0:0¹²d₄d₃d₂d₁d₀c₄c₃c₂c₁c₀b₄b₃b₂b₁b₀a₄a₃a₂a₁a₀ 1:0³² 2:0¹²h₄h₃h₂h₁h₀g₄g₃g₂g₁g₀f₄f₃f₂f₁f₀e₄e₃e₂e₁e₀ 3:0³² | ↩
//
// To be able to read out the bytes in one go, we need to shifts the bits in
// position 2 to position 1 in each 128-bit lane.
let adjacent_8_combined = mm256_shuffle_epi32::<0b00_00_10_00>(adjacent_4_combined);

// |adjacent_8_combined|, when viewed as a set of 32-bit values, now looks like:
//
// 0¹²d₄d₃d₂d₁d₀c₄c₃c₂c₁c₀b₄b₃b₂b₁b₀a₄a₃a₂a₁a₀ 0¹²h₄h₃h₂h₁h₀g₄g₃g₂g₁g₀f₄f₃f₂f₁f₀e₄e₃e₂e₁e₀ 0³² 0³² | ↩
//
// Once again, we line these bits up by shifting the up values at indices
// 0 and 5 by 12, viewing the resulting register as a set of 64-bit values,
// and then shifting down the 64-bit values by 12 bits.
let adjacent_8_combined = mm256_sllv_epi32(
adjacent_8_combined,
mm256_set_epi32(0, 12, 0, 12, 0, 12, 0, 12),
mm256_set_epi32(0, 0, 0, 12, 0, 0, 0, 12),
);
let adjacent_8_combined = mm256_srli_epi64::<12>(adjacent_8_combined);

// We now have 40 bits starting at position 0 in the lower 128-bit lane, ...
let lower_8 = mm256_castsi256_si128(adjacent_8_combined);
let upper_8 = mm256_extracti128_si256::<1>(adjacent_8_combined);

mm_storeu_bytes_si128(&mut serialized[0..16], lower_8);

// ... and the second 40 bits at position 0 in the upper 128-bit lane
let upper_8 = mm256_extracti128_si256::<1>(adjacent_8_combined);
mm_storeu_bytes_si128(&mut serialized[5..21], upper_8);

serialized[0..10].try_into().unwrap()
Expand Down Expand Up @@ -299,6 +385,19 @@ pub(crate) fn deserialize_5(bytes: &[u8]) -> Vec256 {
pub(crate) fn serialize_10(vector: Vec256) -> [u8; 20] {
let mut serialized = [0u8; 32];

// If |vector| is laid out as follows (superscript number indicates the
// corresponding bit is duplicated that many times):
//
// 0⁶a₉a₈a₇a₆a₅a₄a₃a₂a₁a₀ 0⁶b₉b₈b₇b₆b₅b₄b₃b₂b₁b₀ 0⁶c₉c₈c₇c₆c₅c₄c₃c₂c₁c₀ 0⁶d₉d₈d₇d₆d₅d₄d₃d₂d₁d₀ | ↩
// 0⁶e₉e₈e₇e₆e₅e₄e₃e₂e₁e₀ 0⁶f₉f₈f₇f₆f₅f₄f₃f₂f₁f₀ 0⁶g₉g₈g₇g₆g₅g₄g₃g₂g₁g₀ 0⁶h₉h₈h₇h₆h₅h₄h₃h₂h₁h₀ | ↩
// ...
//
// |adjacent_2_combined| will be laid out as a series of 32-bit integers,
// as follows:
//
// 0¹²b₉b₈b₇b₆b₅b₄b₃b₂b₁b₀a₉a₈a₇a₆a₅a₄a₃a₂a₁a₀ 0¹²d₉d₈d₇d₆d₅d₄d₃d₂d₁d₀c₉c₈c₇c₆c₅c₄c₃c₂c₁c₀ | ↩
// 0¹²f₉f₈f₇f₆f₅f₄f₃f₂f₁f₀e₉e₈e₇e₆e₅e₄e₃e₂e₁e₀ 0¹²h₉h₈h₇h₆h₅h₄h₃h₂h₁h₀g₉g₈g₇g₆g₅g₄g₃g₂g₁g₀ | ↩
// ....
let adjacent_2_combined = mm256_madd_epi16(
vector,
mm256_set_epi16(
Expand All @@ -321,12 +420,37 @@ pub(crate) fn serialize_10(vector: Vec256) -> [u8; 20] {
),
);

// Shifting up the values at the even indices by 12, we get:
//
// b₉b₈b₇b₆b₅b₄b₃b₂b₁b₀a₉a₈a₇a₆a₅a₄a₃a₂a₁a₀0¹² 0¹²d₉d₈d₇d₆d₅d₄d₃d₂d₁d₀c₉c₈c₇c₆c₅c₄c₃c₂c₁c₀ | ↩
// f₉f₈f₇f₆f₅f₄f₃f₂f₁f₀e₉e₈e₇e₆e₅e₄e₃e₂e₁e₀0¹² 0¹²h₉h₈h₇h₆h₅h₄h₃h₂h₁h₀g₉g₈g₇g₆g₅g₄g₃g₂g₁g₀ | ↩
// ...
let adjacent_4_combined = mm256_sllv_epi32(
adjacent_2_combined,
mm256_set_epi32(0, 12, 0, 12, 0, 12, 0, 12),
);

// Viewing this as a set of 64-bit integers we get:
//
// 0¹²d₉d₈d₇d₆d₅d₄d₃d₂d₁d₀c₉c₈c₇c₆c₅c₄c₃c₂c₁c₀b₉b₈b₇b₆b₅b₄b₃b₂b₁b₀a₉a₈a₇a₆a₅a₄a₃a₂a₁a₀0¹² | ↩
// 0¹²h₉h₈h₇h₆h₅h₄h₃h₂h₁h₀g₉g₈g₇g₆g₅g₄g₃g₂g₁g₀f₉f₈f₇f₆f₅f₄f₃f₂f₁f₀e₉e₈e₇e₆e₅e₄e₃e₂e₁e₀0¹² | ↩
// ...
//
// Shifting down by 12 gives us:
//
// 0²⁴d₉d₈d₇d₆d₅d₄d₃d₂d₁d₀c₉c₈c₇c₆c₅c₄c₃c₂c₁c₀b₉b₈b₇b₆b₅b₄b₃b₂b₁b₀a₉a₈a₇a₆a₅a₄a₃a₂a₁a₀ | ↩
// 0²⁴h₉h₈h₇h₆h₅h₄h₃h₂h₁h₀g₉g₈g₇g₆g₅g₄g₃g₂g₁g₀f₉f₈f₇f₆f₅f₄f₃f₂f₁f₀e₉e₈e₇e₆e₅e₄e₃e₂e₁e₀ | ↩
// ...
let adjacent_4_combined = mm256_srli_epi64::<12>(adjacent_4_combined);

// |adjacent_4_combined|, when the bottom and top 128 bit-lanes are grouped
// into bytes, looks like:
//
// 0₇0₆0₅B₄B₃B₂B₁B₀ | ↩
// 0₁₅0₁₄0₁₃B₁₂B₁₁B₁₀B₉B₈ | ↩
//
// In each 128-bit lane, we want to put bytes 8, 9, 10, 11, 12 after
// bytes 0, 1, 2, 3 to allow for sequential reading.
let adjacent_8_combined = mm256_shuffle_epi8(
adjacent_4_combined,
mm256_set_epi8(
Expand All @@ -335,10 +459,12 @@ pub(crate) fn serialize_10(vector: Vec256) -> [u8; 20] {
),
);

// We now have 64 bits starting at position 0 in the lower 128-bit lane, ...
let lower_8 = mm256_castsi256_si128(adjacent_8_combined);
let upper_8 = mm256_extracti128_si256::<1>(adjacent_8_combined);

mm_storeu_bytes_si128(&mut serialized[0..16], lower_8);

// and 64 bits starting at position 0 in the upper 128-bit lane.
let upper_8 = mm256_extracti128_si256::<1>(adjacent_8_combined);
mm_storeu_bytes_si128(&mut serialized[10..26], upper_8);

serialized[0..20].try_into().unwrap()
Expand Down

0 comments on commit 1202ef2

Please sign in to comment.