From a4a55abc903358da44ba986af515048ab5436a8a Mon Sep 17 00:00:00 2001 From: xvzcf Date: Fri, 31 May 2024 21:10:19 +0200 Subject: [PATCH] More Kyber AVX2-serialization documentation. --- libcrux-ml-kem/src/vector/avx2/serialize.rs | 184 +++++++++++++++++--- 1 file changed, 155 insertions(+), 29 deletions(-) diff --git a/libcrux-ml-kem/src/vector/avx2/serialize.rs b/libcrux-ml-kem/src/vector/avx2/serialize.rs index b377d543a..4192e95b9 100644 --- a/libcrux-ml-kem/src/vector/avx2/serialize.rs +++ b/libcrux-ml-kem/src/vector/avx2/serialize.rs @@ -2,8 +2,17 @@ use super::*; #[inline(always)] pub(crate) fn serialize_1(vector: Vec256) -> [u8; 2] { + // Suppose |vector| is laid out as follows (superscript number indicates the + // corresponding bit is duplicated that many times): + // + // 0¹⁵a₀ 0¹⁵b₀ 0¹⁵c₀ 0¹⁵d₀ | 0¹⁵e₀ 0¹⁵f₀ 0¹⁵g₀ 0¹⁵h₀ | ... + // // We care only about the least significant bit in each lane, // move it to the most significant position to make it easier to work with. + // |vector| now becomes: + // + // a₀0¹⁵ b₀0¹⁵ c₀0¹⁵ d₀0¹⁵ | e₀0¹⁵ f₀0¹⁵ g₀0¹⁵ h₀0¹⁵ | ↩ + // i₀0¹⁵ j₀0¹⁵ k₀0¹⁵ l₀0¹⁵ | m₀0¹⁵ n₀0¹⁵ o₀0¹⁵ p₀0¹⁵ let lsb_to_msb = mm256_slli_epi16::<15>(vector); // Get the first 8 16-bit elements ... @@ -15,16 +24,26 @@ pub(crate) fn serialize_1(vector: Vec256) -> [u8; 2] { // ... and then pack them into 8-bit values using signed saturation. // This function packs all the |low_msbs|, and then the high ones. // - // We shifted by 15 above to take advantage of signed saturation: + // + // low_msbs = a₀0¹⁵ b₀0¹⁵ c₀0¹⁵ d₀0¹⁵ | e₀0¹⁵ f₀0¹⁵ g₀0¹⁵ h₀0¹⁵ + // high_msbs = i₀0¹⁵ j₀0¹⁵ k₀0¹⁵ l₀0¹⁵ | m₀0¹⁵ n₀0¹⁵ o₀0¹⁵ p₀0¹⁵ + // + // We shifted by 15 above to take advantage of the signed saturation performed + // by mm_packs_epi16: // // - if the sign bit of the 16-bit element being packed is 1, the // corresponding 8-bit element in |msbs| will be 0xFF. // - if the sign bit of the 16-bit element being packed is 0, the // corresponding 8-bit element in |msbs| will be 0. + // + // Thus, if, for example, a₀ = 1, e₀ = 1, and p₀ = 1, and every other bit + // is 0, after packing into 8 bit value, |msbs| will look like: + // + // 0xFF 0x00 0x00 0x00 | 0xFF 0x00 0x00 0x00 | 0x00 0x00 0x00 0x00 | 0x00 0x00 0x00 0xFF let msbs = mm_packs_epi16(low_msbs, high_msbs); - // Now that we have all 16 bits we need conveniently placed in one vector, - // extract them into two bytes. + // Now that every element is either 0xFF or 0x00, we just extract the most + // significant bit from each element and collate them into two bytes. let bits_packed = mm_movemask_epi8(msbs); let mut serialized = [0u8; 2]; @@ -41,8 +60,8 @@ pub(crate) fn deserialize_1(bytes: &[u8]) -> Vec256 { // duplicate them, and right-shift the 0th element by 0 bits, // the first element by 1 bit, the second by 2 bits and so on before AND-ing // with 0x1 to leave only the least signifinicant bit. - // But |_mm256_srlv_epi16| does not exist unfortunately, so we have to resort - // to a workaround. + // But since |_mm256_srlv_epi16| does not exist, so we have to resort to a + // workaround. // // Rather than shifting each element by a different amount, we'll multiply // each element by a value such that the bit we're interested in becomes the most @@ -161,28 +180,21 @@ pub(crate) fn serialize_4(vector: Vec256) -> [u8; 8] { #[inline(always)] pub(crate) fn deserialize_4(bytes: &[u8]) -> Vec256 { - let shift_lsbs_to_msbs = mm256_set_epi16( - 1 << 0, - 1 << 4, - 1 << 0, - 1 << 4, - 1 << 0, - 1 << 4, - 1 << 0, - 1 << 4, - 1 << 0, - 1 << 4, - 1 << 0, - 1 << 4, - 1 << 0, - 1 << 4, - 1 << 0, - 1 << 4, - ); - + // Every 4 bits from each byte of input should be put into its own 16-bit lane. + // Since |_mm256_srlv_epi16| does not exist, we have to resort to a workaround. + // + // Rather than shifting each element by a different amount, we'll multiply + // each element by a value such that the bits we're interested in become the most + // significant bits (of an 8-bit value). let coefficients = mm256_set_epi16( + // In this lane, the 4 bits we need to put are already the most + // significant bits of |bytes[7]|. bytes[7] as i16, + // In this lane, the 4 bits we need to put are the least significant bits, + // so we need to shift the 4 least-significant bits of |bytes[7]| to the + // most significant bits (of an 8-bit value). bytes[7] as i16, + // and so on ... bytes[6] as i16, bytes[6] as i16, bytes[5] as i16, @@ -199,9 +211,34 @@ pub(crate) fn deserialize_4(bytes: &[u8]) -> Vec256 { bytes[0] as i16, ); + let shift_lsbs_to_msbs = mm256_set_epi16( + // These constants are chosen to shift the bits of the values + // that we loaded into |coefficients|. + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + ); + let coefficients_in_msb = mm256_mullo_epi16(coefficients, shift_lsbs_to_msbs); + + // Once the 4-bit coefficients are in the most significant positions (of + // an 8-bit value), shift them all down by 4. let coefficients_in_lsb = mm256_srli_epi16::<4>(coefficients_in_msb); + // Zero the remaining bits. mm256_and_si256(coefficients_in_lsb, mm256_set1_epi16((1 << 4) - 1)) } @@ -209,6 +246,18 @@ pub(crate) fn deserialize_4(bytes: &[u8]) -> Vec256 { pub(crate) fn serialize_5(vector: Vec256) -> [u8; 10] { let mut serialized = [0u8; 32]; + // If |vector| is laid out as follows (superscript number indicates the + // corresponding bit is duplicated that many times): + // + // 0¹¹a₄a₃a₂a₁a₀ 0¹¹b₄b₃b₂b₁b₀ 0¹¹c₄c₃c₂c₁c₀ 0¹¹d₄d₃d₂d₁d₀ | ↩ + // 0¹¹e₄e₃e₂e₁e₀ 0¹¹f₄f₃f₂f₁f₀ 0¹¹g₄g₃g₂g₁g₀ 0¹¹h₄h₃h₂h₁h₀ | ↩ + // + // |adjacent_2_combined| will be laid out as a series of 32-bit integers, + // as follows: + // + // 0²²b₄b₃b₂b₁b₀a₄a₃a₂a₁a₀ 0²²d₄d₃d₂d₁d₀c₄c₃c₂c₁c₀ | ↩ + // 0²²f₄f₃f₂f₁f₀e₄e₃e₂e₁e₀ 0²²h₄h₃h₂h₁h₀g₄g₃g₂g₁g₀ | ↩ + // .... let adjacent_2_combined = mm256_madd_epi16( vector, mm256_set_epi16( @@ -231,23 +280,60 @@ pub(crate) fn serialize_5(vector: Vec256) -> [u8; 10] { ), ); + // Recall that |adjacent_2_combined| is laid out as follows: + // + // 0²²b₄b₃b₂b₁b₀a₄a₃a₂a₁a₀ 0²²d₄d₃d₂d₁d₀c₄c₃c₂c₁c₀ | ↩ + // 0²²f₄f₃f₂f₁f₀e₄e₃e₂e₁e₀ 0²²h₄h₃h₂h₁h₀g₄g₃g₂g₁g₀ | ↩ + // .... + // + // This shift results in: + // + // b₄b₃b₂b₁b₀a₄a₃a₂a₁a₀0²² 0²²d₄d₃d₂d₁d₀c₄c₃c₂c₁c₀ | ↩ + // f₄f₃f₂f₁f₀e₄e₃e₂e₁e₀0²² 0²²h₄h₃h₂h₁h₀g₄g₃g₂g₁g₀ | ↩ + // .... + // let adjacent_4_combined = mm256_sllv_epi32( adjacent_2_combined, mm256_set_epi32(0, 22, 0, 22, 0, 22, 0, 22), ); + + // |adjacent_4_combined|, when viewed as 64-bit lanes, is: + // + // 0²²d₄d₃d₂d₁d₀c₄c₃c₂c₁c₀b₄b₃b₂b₁b₀a₄a₃a₂a₁a₀0²² | ↩ + // 0²²h₄h₃h₂h₁h₀g₄g₃g₂g₁g₀f₄f₃f₂f₁f₀e₄e₃e₂e₁e₀0²² | ↩ + // ... + // + // so we just shift down by 22 bits to remove the least significant 0 bits + // that aren't part of the bits we need. let adjacent_4_combined = mm256_srli_epi64::<22>(adjacent_4_combined); + // |adjacent_4_combined|, when viewed as a set of 32-bit values, looks like: + // + // 0:0¹²d₄d₃d₂d₁d₀c₄c₃c₂c₁c₀b₄b₃b₂b₁b₀a₄a₃a₂a₁a₀ 1:0³² 2:0¹²h₄h₃h₂h₁h₀g₄g₃g₂g₁g₀f₄f₃f₂f₁f₀e₄e₃e₂e₁e₀ 3:0³² | ↩ + // + // To be able to read out the bytes in one go, we need to shifts the bits in + // position 2 to position 1 in each 128-bit lane. let adjacent_8_combined = mm256_shuffle_epi32::<0b00_00_10_00>(adjacent_4_combined); + + // |adjacent_8_combined|, when viewed as a set of 32-bit values, now looks like: + // + // 0¹²d₄d₃d₂d₁d₀c₄c₃c₂c₁c₀b₄b₃b₂b₁b₀a₄a₃a₂a₁a₀ 0¹²h₄h₃h₂h₁h₀g₄g₃g₂g₁g₀f₄f₃f₂f₁f₀e₄e₃e₂e₁e₀ 0³² 0³² | ↩ + // + // Once again, we line these bits up by shifting the up values at indices + // 0 and 5 by 12, viewing the resulting register as a set of 64-bit values, + // and then shifting down the 64-bit values by 12 bits. let adjacent_8_combined = mm256_sllv_epi32( adjacent_8_combined, - mm256_set_epi32(0, 12, 0, 12, 0, 12, 0, 12), + mm256_set_epi32(0, 0, 0, 12, 0, 0, 0, 12), ); let adjacent_8_combined = mm256_srli_epi64::<12>(adjacent_8_combined); + // We now have 40 bits starting at position 0 in the lower 128-bit lane, ... let lower_8 = mm256_castsi256_si128(adjacent_8_combined); - let upper_8 = mm256_extracti128_si256::<1>(adjacent_8_combined); - mm_storeu_bytes_si128(&mut serialized[0..16], lower_8); + + // ... and the second 40 bits at position 0 in the upper 128-bit lane + let upper_8 = mm256_extracti128_si256::<1>(adjacent_8_combined); mm_storeu_bytes_si128(&mut serialized[5..21], upper_8); serialized[0..10].try_into().unwrap() @@ -299,6 +385,19 @@ pub(crate) fn deserialize_5(bytes: &[u8]) -> Vec256 { pub(crate) fn serialize_10(vector: Vec256) -> [u8; 20] { let mut serialized = [0u8; 32]; + // If |vector| is laid out as follows (superscript number indicates the + // corresponding bit is duplicated that many times): + // + // 0⁶a₉a₈a₇a₆a₅a₄a₃a₂a₁a₀ 0⁶b₉b₈b₇b₆b₅b₄b₃b₂b₁b₀ 0⁶c₉c₈c₇c₆c₅c₄c₃c₂c₁c₀ 0⁶d₉d₈d₇d₆d₅d₄d₃d₂d₁d₀ | ↩ + // 0⁶e₉e₈e₇e₆e₅e₄e₃e₂e₁e₀ 0⁶f₉f₈f₇f₆f₅f₄f₃f₂f₁f₀ 0⁶g₉g₈g₇g₆g₅g₄g₃g₂g₁g₀ 0⁶h₉h₈h₇h₆h₅h₄h₃h₂h₁h₀ | ↩ + // ... + // + // |adjacent_2_combined| will be laid out as a series of 32-bit integers, + // as follows: + // + // 0¹²b₉b₈b₇b₆b₅b₄b₃b₂b₁b₀a₉a₈a₇a₆a₅a₄a₃a₂a₁a₀ 0¹²d₉d₈d₇d₆d₅d₄d₃d₂d₁d₀c₉c₈c₇c₆c₅c₄c₃c₂c₁c₀ | ↩ + // 0¹²f₉f₈f₇f₆f₅f₄f₃f₂f₁f₀e₉e₈e₇e₆e₅e₄e₃e₂e₁e₀ 0¹²h₉h₈h₇h₆h₅h₄h₃h₂h₁h₀g₉g₈g₇g₆g₅g₄g₃g₂g₁g₀ | ↩ + // .... let adjacent_2_combined = mm256_madd_epi16( vector, mm256_set_epi16( @@ -321,12 +420,37 @@ pub(crate) fn serialize_10(vector: Vec256) -> [u8; 20] { ), ); + // Shifting up the values at the even indices by 12, we get: + // + // b₉b₈b₇b₆b₅b₄b₃b₂b₁b₀a₉a₈a₇a₆a₅a₄a₃a₂a₁a₀0¹² 0¹²d₉d₈d₇d₆d₅d₄d₃d₂d₁d₀c₉c₈c₇c₆c₅c₄c₃c₂c₁c₀ | ↩ + // f₉f₈f₇f₆f₅f₄f₃f₂f₁f₀e₉e₈e₇e₆e₅e₄e₃e₂e₁e₀0¹² 0¹²h₉h₈h₇h₆h₅h₄h₃h₂h₁h₀g₉g₈g₇g₆g₅g₄g₃g₂g₁g₀ | ↩ + // ... let adjacent_4_combined = mm256_sllv_epi32( adjacent_2_combined, mm256_set_epi32(0, 12, 0, 12, 0, 12, 0, 12), ); + + // Viewing this as a set of 64-bit integers we get: + // + // 0¹²d₉d₈d₇d₆d₅d₄d₃d₂d₁d₀c₉c₈c₇c₆c₅c₄c₃c₂c₁c₀b₉b₈b₇b₆b₅b₄b₃b₂b₁b₀a₉a₈a₇a₆a₅a₄a₃a₂a₁a₀0¹² | ↩ + // 0¹²h₉h₈h₇h₆h₅h₄h₃h₂h₁h₀g₉g₈g₇g₆g₅g₄g₃g₂g₁g₀f₉f₈f₇f₆f₅f₄f₃f₂f₁f₀e₉e₈e₇e₆e₅e₄e₃e₂e₁e₀0¹² | ↩ + // ... + // + // Shifting down by 12 gives us: + // + // 0²⁴d₉d₈d₇d₆d₅d₄d₃d₂d₁d₀c₉c₈c₇c₆c₅c₄c₃c₂c₁c₀b₉b₈b₇b₆b₅b₄b₃b₂b₁b₀a₉a₈a₇a₆a₅a₄a₃a₂a₁a₀ | ↩ + // 0²⁴h₉h₈h₇h₆h₅h₄h₃h₂h₁h₀g₉g₈g₇g₆g₅g₄g₃g₂g₁g₀f₉f₈f₇f₆f₅f₄f₃f₂f₁f₀e₉e₈e₇e₆e₅e₄e₃e₂e₁e₀ | ↩ + // ... let adjacent_4_combined = mm256_srli_epi64::<12>(adjacent_4_combined); + // |adjacent_4_combined|, when the bottom and top 128 bit-lanes are grouped + // into bytes, looks like: + // + // 0₇0₆0₅B₄B₃B₂B₁B₀ | ↩ + // 0₁₅0₁₄0₁₃B₁₂B₁₁B₁₀B₉B₈ | ↩ + // + // In each 128-bit lane, we want to put bytes 8, 9, 10, 11, 12 after + // bytes 0, 1, 2, 3 to allow for sequential reading. let adjacent_8_combined = mm256_shuffle_epi8( adjacent_4_combined, mm256_set_epi8( @@ -335,10 +459,12 @@ pub(crate) fn serialize_10(vector: Vec256) -> [u8; 20] { ), ); + // We now have 64 bits starting at position 0 in the lower 128-bit lane, ... let lower_8 = mm256_castsi256_si128(adjacent_8_combined); - let upper_8 = mm256_extracti128_si256::<1>(adjacent_8_combined); - mm_storeu_bytes_si128(&mut serialized[0..16], lower_8); + + // and 64 bits starting at position 0 in the upper 128-bit lane. + let upper_8 = mm256_extracti128_si256::<1>(adjacent_8_combined); mm_storeu_bytes_si128(&mut serialized[10..26], upper_8); serialized[0..20].try_into().unwrap()