Skip to content

Commit

Permalink
Merge pull request #1391 from akoshelev/rp25519-less
Browse files Browse the repository at this point in the history
Optimize Ristretto points handling in IPA
  • Loading branch information
akoshelev authored Nov 4, 2024
2 parents d5cc8f8 + bedf3ee commit bbf2deb
Showing 1 changed file with 114 additions and 56 deletions.
170 changes: 114 additions & 56 deletions ipa-core/src/ff/curve_points.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use std::sync::OnceLock;

use curve25519_dalek::{
ristretto::{CompressedRistretto, RistrettoPoint},
Scalar,
};
use generic_array::GenericArray;
use typenum::U32;
use typenum::{U128, U32};

use crate::{
ff::{ec_prime_field::Fp25519, Serializable},
Expand All @@ -16,19 +18,42 @@ impl Block for CompressedRistretto {
type Size = U32;
}

///ristretto point for curve 25519,
/// we store it in compressed format since it is 3 times smaller and we do a limited amount of
/// arithmetic operations on the curve points
impl Block for RistrettoPoint {
type Size = U128;
}

/// Ristretto point for curve 25519, stored in uncompressed format for efficient
/// additions and multiplications.
///
/// We use ristretto points such that we have a prime order elliptic curve,
/// This is needed for the Dodis Yampolski PRF
/// We use Ristretto points such that we have a prime order elliptic curve,
/// This is needed for the Dodis Yampolski PRF.
///
/// decompressing invalid curve points will cause panics,
/// since we always generate curve points from scalars (elements in Fp25519) and
/// only deserialize previously serialized valid points, panics will not occur
/// However, we still added a debug assert to deserialize since values are sent by other servers
///
/// ## Memory/CPU tradeoff
/// We optimize for CPU utilization because invert operations are expensive.
/// Previous implementation used compressed format and we ended up with
/// 10 compress and decompress operations per cycle/row. Storing Ristretto
/// points in uncompressed format allowed us to go down to 3: one in serialize,
/// one in deserialize and one in hash. The only reason why we compress in hash
/// is to get access to raw bytes representation - potentially we could use an
/// API that does not exist in curve25519 crate.
///
/// This tradeoff means that it is highly recommended to avoid collecting
/// many of those points into a vector or any other collection.
/// For PRF evaluation, all of those points
/// are ephemeral (computed once per PRF cycle and then dropped), so this schema
/// works fine as there are only limited number of Ristretto points present
/// in memory at any given time.
///
/// Also, one need to be considerate of stack usage when thinking about vectorizing
/// their operations. Putting too many of those points together may blow up the stack
/// faster. A potential solution would be to keep the compressed view
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub struct RP25519(CompressedRistretto);
pub struct RP25519(RistrettoRepr);

impl Default for RP25519 {
fn default() -> Self {
Expand All @@ -38,9 +63,9 @@ impl Default for RP25519 {

/// Implementing trait for secret sharing
impl SharedValue for RP25519 {
type Storage = CompressedRistretto;
const BITS: u32 = 256;
const ZERO: Self = Self(CompressedRistretto([0_u8; 32]));
type Storage = RistrettoPoint;
const BITS: u32 = 1024;
const ZERO: Self = Self(RistrettoRepr::Zero);

impl_shared_value_common!();
}
Expand All @@ -58,31 +83,25 @@ impl Vectorizable<PRF_CHUNK> for RP25519 {
pub struct NonCanonicalEncoding(CompressedRistretto);

impl Serializable for RP25519 {
type Size = <<RP25519 as SharedValue>::Storage as Block>::Size;
type Size = <CompressedRistretto as Block>::Size;
type DeserializationError = NonCanonicalEncoding;

fn serialize(&self, buf: &mut GenericArray<u8, Self::Size>) {
*buf.as_mut() = self.0.to_bytes();
*buf.as_mut() = self.0.as_point().compress().to_bytes();
}

fn deserialize(buf: &GenericArray<u8, Self::Size>) -> Result<Self, Self::DeserializationError> {
let point = CompressedRistretto((*buf).into());
if cfg!(debug_assertions) && point.decompress().is_none() {
return Err(NonCanonicalEncoding(point));
}

Ok(RP25519(point))
let point = point.decompress().ok_or(NonCanonicalEncoding(point))?;
Ok(Self::from(point))
}
}

///## Panics
/// Panics when decompressing invalid curve point. This can happen when deserialize curve point
/// from bit array that does not have a valid representation on the curve
impl std::ops::Add for RP25519 {
type Output = Self;

fn add(self, rhs: Self) -> Self::Output {
Self((self.0.decompress().unwrap() + rhs.0.decompress().unwrap()).compress())
Self((self.0.as_point() + rhs.0.as_point()).into())
}
}

Expand All @@ -93,25 +112,19 @@ impl std::ops::AddAssign for RP25519 {
}
}

///## Panics
/// Panics when decompressing invalid curve point. This can happen when deserialize curve point
/// from bit array that does not have a valid representation on the curve
impl std::ops::Neg for RP25519 {
type Output = Self;

fn neg(self) -> Self::Output {
Self(self.0.decompress().unwrap().neg().compress())
Self(self.0.as_point().neg().into())
}
}

///## Panics
/// Panics when decompressing invalid curve point. This can happen when deserialize curve point
/// from bit array that does not have a valid representation on the curve
impl std::ops::Sub for RP25519 {
type Output = Self;

fn sub(self, rhs: Self) -> Self::Output {
Self((self.0.decompress().unwrap() - rhs.0.decompress().unwrap()).compress())
Self((self.0.as_point() - rhs.0.as_point()).into())
}
}

Expand All @@ -122,18 +135,13 @@ impl std::ops::SubAssign for RP25519 {
}
}

///Scalar Multiplication
/// Scalar Multiplication
/// allows to multiply curve points with scalars from Fp25519
///## Panics
/// Panics when decompressing invalid curve point. This can happen when deserialize curve point
/// from bit array that does not have a valid representation on the curve
impl std::ops::Mul<Fp25519> for RP25519 {
type Output = Self;

fn mul(self, rhs: Fp25519) -> RP25519 {
(self.0.decompress().unwrap() * Scalar::from(rhs))
.compress()
.into()
fn mul(self, rhs: Fp25519) -> Self {
Self((self.0.as_point() * Scalar::from(rhs)).into())
}
}

Expand All @@ -146,25 +154,19 @@ impl std::ops::MulAssign<Fp25519> for RP25519 {

impl From<Scalar> for RP25519 {
fn from(s: Scalar) -> Self {
RP25519(RistrettoPoint::mul_base(&s).compress())
Self::from(Fp25519::from(s))
}
}

impl From<Fp25519> for RP25519 {
fn from(s: Fp25519) -> Self {
RP25519(RistrettoPoint::mul_base(&s.into()).compress())
}
}

impl From<CompressedRistretto> for RP25519 {
fn from(s: CompressedRistretto) -> Self {
RP25519(s)
Self((RistrettoPoint::mul_base(&s.into())).into())
}
}

impl From<RP25519> for CompressedRistretto {
fn from(s: RP25519) -> Self {
s.0
impl From<RistrettoPoint> for RP25519 {
fn from(s: RistrettoPoint) -> Self {
Self(RistrettoRepr::Point(s))
}
}

Expand All @@ -175,7 +177,7 @@ macro_rules! cp_hash_impl {
fn from(s: RP25519) -> Self {
use hkdf::Hkdf;
use sha2::Sha256;
let hk = Hkdf::<Sha256>::new(None, s.0.as_bytes());
let hk = Hkdf::<Sha256>::new(None, s.0.as_point().compress().as_bytes());
let mut okm = <$u_type>::MIN.to_le_bytes();
//error invalid length from expand only happens when okm is very large
hk.expand(&[], &mut okm).unwrap();
Expand All @@ -185,8 +187,6 @@ macro_rules! cp_hash_impl {
};
}

cp_hash_impl!(u128);

cp_hash_impl!(u64);

/// implementing random curve point generation for testing purposes,
Expand All @@ -196,7 +196,64 @@ impl rand::distributions::Distribution<RP25519> for rand::distributions::Standar
fn sample<R: crate::rand::Rng + ?Sized>(&self, rng: &mut R) -> RP25519 {
let mut scalar_bytes = [0u8; 64];
rng.fill_bytes(&mut scalar_bytes);
RP25519(RistrettoPoint::from_uniform_bytes(&scalar_bytes).compress())
RP25519(RistrettoPoint::from_uniform_bytes(&scalar_bytes).into())
}
}

/// Internal representation of Ristretto point, suitable
/// for our needs.
/// Due to constraints imposed by dalek crate,
/// we can't construct a zero uncompressed Ristretto point
/// at compile time. It is only possible to construct
/// a compressed ristretto from 0 byte array in const context.
/// We work around that limitation by adding an enum value
/// that represents a zero value.
#[derive(Clone, Copy, Eq, Debug)]
enum RistrettoRepr {
/// In PRF code this path is never used as
/// we always construct Ristretto points from scalars.
/// Constructing this value and attempting to use it
/// as Ristretto point will panic
Zero,
Point(RistrettoPoint),
}

impl PartialEq for RistrettoRepr {
fn eq(&self, other: &Self) -> bool {
self.as_point().eq(other.as_point())
}
}

impl RistrettoRepr {
pub fn as_point(&self) -> &RistrettoPoint {
match self {
Self::Zero => {
if cfg!(test) {
static INSTANCE: OnceLock<RistrettoPoint> = OnceLock::new();
INSTANCE.get_or_init(|| {
let zero = CompressedRistretto([0_u8; 32]);
// we could also cache the compressed Ristretto, if we end up
// sending a lot of zeroes
zero.decompress().unwrap()
})
} else {
// We debated whether we should support it or no,
// and decided not to. There is a valid concern about
// keeping arithmetics on Ristretto point constant-time
// and short-cutting Zero representation has obvious problems
// and someone measuring the time it takes to multiply may
// guess correctly that one of the arguments was zero.
unimplemented!("Zero repr is not supported.")
}
}
Self::Point(p) => p,
}
}
}

impl From<RistrettoPoint> for RistrettoRepr {
fn from(value: RistrettoPoint) -> Self {
Self::Point(value)
}
}

Expand Down Expand Up @@ -232,8 +289,8 @@ mod test {
let b: RP25519 = a.into();
let d: Fp25519 = a.into();
let c: RP25519 = RP25519::from(d);
assert_eq!(b, RP25519(constants::RISTRETTO_BASEPOINT_COMPRESSED));
assert_eq!(c, RP25519(constants::RISTRETTO_BASEPOINT_COMPRESSED));
assert_eq!(b, RP25519::from(constants::RISTRETTO_BASEPOINT_POINT));
assert_eq!(c, RP25519::from(constants::RISTRETTO_BASEPOINT_POINT));
}

///testing simple curve arithmetics to check that `curve25519_dalek` library is used correctly
Expand All @@ -245,14 +302,15 @@ mod test {
let fp_c = fp_a + fp_b;
let fp_d = RP25519::from(fp_a) + RP25519::from(fp_b);
assert_eq!(fp_d, RP25519::from(fp_c));
assert_ne!(fp_d, RP25519(constants::RISTRETTO_BASEPOINT_COMPRESSED));
assert_ne!(fp_d, RP25519::from(constants::RISTRETTO_BASEPOINT_POINT));
let fp_e = rng.gen::<Fp25519>();
let fp_f = rng.gen::<Fp25519>();
let fp_g = fp_e * fp_f;
let fp_h = RP25519::from(fp_e) * fp_f;
assert_eq!(fp_h, RP25519::from(fp_g));
assert_ne!(fp_h, RP25519(constants::RISTRETTO_BASEPOINT_COMPRESSED));
assert_ne!(fp_h, RP25519::from(constants::RISTRETTO_BASEPOINT_POINT));
assert_eq!(RP25519::ZERO, fp_h * Scalar::ZERO.into());
assert_eq!(fp_h, fp_h + RP25519::ZERO);
}

///testing curve to unsigned integer conversion has entropy (!= 0)
Expand Down

0 comments on commit bbf2deb

Please sign in to comment.