diff --git a/src/algorithms/rsa.rs b/src/algorithms/rsa.rs index d8e871bb..d743d407 100644 --- a/src/algorithms/rsa.rs +++ b/src/algorithms/rsa.rs @@ -1,7 +1,7 @@ //! Generic RSA implementation use alloc::vec::Vec; -use crypto_bigint::modular::BoxedResidueParams; +use crypto_bigint::modular::{BoxedResidue, BoxedResidueParams}; use crypto_bigint::{BoxedUint, RandomMod}; use num_bigint::{BigUint, IntoBigInt, IntoBigUint, ModInverse, ToBigInt}; use num_integer::{sqrt, Integer}; @@ -11,7 +11,7 @@ use subtle::CtOption; use zeroize::{Zeroize, Zeroizing}; use crate::errors::{Error, Result}; -use crate::key::{reduce, to_biguint, to_uint}; +use crate::key::{reduce, to_biguint, to_uint_exact}; use crate::traits::keys::{PrivateKeyPartsNew, PublicKeyPartsNew}; use crate::traits::PublicKeyParts; @@ -40,11 +40,13 @@ pub fn rsa_decrypt( priv_key: &impl PrivateKeyPartsNew, c_orig: &BigUint, ) -> Result { - // convert to crypto bigint - let c = to_uint(c_orig.clone()); let n = priv_key.n(); + let nbits = n.bits_precision(); + let c = to_uint_exact(c_orig.clone(), nbits); let d = priv_key.d(); + std::dbg!(nbits, d.bits_precision(), c.bits_precision()); + if c >= **n { return Err(Error::Decryption); } @@ -124,10 +126,10 @@ pub fn rsa_decrypt( c.zeroize(); m2.zeroize(); - to_uint(m.into_biguint().expect("failed to decrypt")) + to_uint_exact(m.into_biguint().expect("failed to decrypt"), nbits) } _ => { - let c = reduce(&c, n_params); + let c = reduce(&c, n_params.clone()); c.pow(&d).retrieve() } }; @@ -135,7 +137,7 @@ pub fn rsa_decrypt( match ir { Some(ref ir) => { // unblind - let res = to_biguint(&unblind(priv_key, &m, ir)); + let res = to_biguint(&unblind(&m, ir, n_params)); Ok(res) } None => Ok(to_biguint(&m)), @@ -204,7 +206,7 @@ fn blind( let c = { let r = reduce(&r, n_params.clone()); let mut rpowe = r.pow(key.e()).retrieve(); - let c = c.mul_mod(&rpowe, key.n()); + let c = mul_mod_params(c, &rpowe, n_params.clone()); rpowe.zeroize(); c @@ -213,9 +215,17 @@ fn blind( (c, unblinder) } +/// Computes `lhs.mul_mod(rhs, n)` with precomputed `n_param`. +fn mul_mod_params(lhs: &BoxedUint, rhs: &BoxedUint, n_params: BoxedResidueParams) -> BoxedUint { + // TODO: nicer api in crypto-bigint? + let lhs = BoxedResidue::new(lhs, n_params.clone()); + let rhs = BoxedResidue::new(rhs, n_params); + (lhs * rhs).retrieve() +} + /// Given an m and and unblinding factor, unblind the m. -fn unblind(key: &impl PublicKeyPartsNew, m: &BoxedUint, unblinder: &BoxedUint) -> BoxedUint { - m.mul_mod(unblinder, key.n()) +fn unblind(m: &BoxedUint, unblinder: &BoxedUint, n_params: BoxedResidueParams) -> BoxedUint { + mul_mod_params(m, unblinder, n_params) } /// The following (deterministic) algorithm also recovers the prime factors `p` and `q` of a modulus `n`, given the diff --git a/src/key.rs b/src/key.rs index fb15e84d..b70d9f6f 100644 --- a/src/key.rs +++ b/src/key.rs @@ -200,11 +200,13 @@ impl RsaPublicKey { /// Create a new public key from its components. pub fn new_with_max_size(n: BigUint, e: BigUint, max_size: usize) -> Result { - let k = Self { - n: NonZero::new(to_uint(n)).unwrap(), - e: to_uint(e), - }; - check_public_with_max_size(&k, max_size)?; + check_public_with_max_size(&n, &e, max_size)?; + + let n = NonZero::new(to_uint(n)).unwrap(); + // widen to 64bit + let e = to_uint_exact(e, 64); + let k = Self { n, e }; + Ok(k) } @@ -215,10 +217,30 @@ impl RsaPublicKey { /// Most applications should use [`RsaPublicKey::new`] or /// [`RsaPublicKey::new_with_max_size`] instead. pub fn new_unchecked(n: BigUint, e: BigUint) -> Self { - Self { - n: NonZero::new(to_uint(n)).unwrap(), - e: to_uint(e), - } + // TODO: widen? + let n = NonZero::new(to_uint(n)).unwrap(); + let e = to_uint_exact(e, 64); + Self { n, e } + } +} + +fn needed_bits(n: &BigUint) -> usize { + // widen to the max size bits + let n_bits = n.bits(); + + // TODO: better algorithm/more sizes + if n_bits <= 512 { + 512 + } else if n_bits <= 1024 { + 1024 + } else if n_bits <= 2048 { + 2048 + } else if n_bits <= 4096 { + 4096 + } else if n_bits <= 8192 { + 8192 + } else { + 16384 } } @@ -274,8 +296,16 @@ impl RsaPrivateKey { d: BigUint, primes: Vec, ) -> Result { + let n_c = NonZero::new(to_uint(n.clone())).unwrap(); + let nbits = n_c.bits_precision(); + + std::dbg!(nbits); + let mut should_validate = false; - let mut primes: Vec<_> = primes.into_iter().map(to_uint).collect(); + let mut primes: Vec<_> = primes + .into_iter() + .map(|p| to_uint_exact(p, nbits)) + .collect(); if primes.len() < 2 { if !primes.is_empty() { @@ -284,17 +314,15 @@ impl RsaPrivateKey { // Recover `p` and `q` from `d`. // See method in Appendix C.2: https://nvlpubs.nist.gov/nistpubs/SpecialPublications/NIST.SP.800-56Br2.pdf let (p, q) = recover_primes(&n, &e, &d)?; - primes.push(to_uint(p)); - primes.push(to_uint(q)); + primes.push(to_uint_exact(p, nbits)); + primes.push(to_uint_exact(q, nbits)); should_validate = true; } + let e = to_uint_exact(e, 64); let mut k = RsaPrivateKey { - pubkey_components: RsaPublicKey { - n: NonZero::new(to_uint(n)).unwrap(), - e: to_uint(e), - }, - d: to_uint(d), + pubkey_components: RsaPublicKey { n: n_c, e }, + d: to_uint_exact(d, nbits), primes, precomputed: None, }; @@ -363,6 +391,10 @@ impl RsaPrivateKey { if self.precomputed.is_some() { return Ok(()); } + + // already widened to what we need + let nbits = self.pubkey_components.n.bits_precision(); + let d = to_biguint(&self.d); let dp = &d % (&to_biguint(&self.primes[0]) - BigUint::one()); let dq = &d % (&to_biguint(&self.primes[1]) - BigUint::one()); @@ -378,14 +410,15 @@ impl RsaPrivateKey { for prime in &self.primes[2..] { let prime = to_biguint(prime); let res = CrtValueNew { - exp: to_uint(&d % (&prime - BigUint::one())), - r: to_uint(r.clone()), - coeff: to_uint( + exp: to_uint_exact(&d % (&prime - BigUint::one()), nbits), + r: to_uint_exact(r.clone(), nbits), + coeff: to_uint_exact( r.clone() .mod_inverse(&prime) .ok_or(Error::InvalidCoefficient)? .to_biguint() .unwrap(), + nbits, ), }; r *= prime; @@ -400,8 +433,8 @@ impl RsaPrivateKey { BoxedResidueParams::new(self.pubkey_components.n.clone().get()).unwrap(); self.precomputed = Some(PrecomputedValues { - dp: to_uint(dp), - dq: to_uint(dq), + dp: to_uint_exact(dp, nbits), + dq: to_uint_exact(dq, nbits), qinv, crt_values, residue_params, @@ -539,34 +572,31 @@ impl PrivateKeyPartsNew for RsaPrivateKey { /// Check that the public key is well formed and has an exponent within acceptable bounds. #[inline] pub fn check_public(public_key: &impl PublicKeyParts) -> Result<()> { - check_public_with_max_size(public_key, RsaPublicKey::MAX_SIZE) + check_public_with_max_size(&public_key.n(), &public_key.e(), RsaPublicKey::MAX_SIZE) } /// Check that the public key is well formed and has an exponent within acceptable bounds. #[inline] -fn check_public_with_max_size(public_key: &impl PublicKeyParts, max_size: usize) -> Result<()> { - if public_key.n().bits() > max_size { +fn check_public_with_max_size(n: &BigUint, e: &BigUint, max_size: usize) -> Result<()> { + if n.bits() > max_size { return Err(Error::ModulusTooLarge); } - let e = public_key - .e() - .to_u64() - .ok_or(Error::PublicExponentTooLarge)?; + let eu64 = e.to_u64().ok_or(Error::PublicExponentTooLarge)?; - if public_key.e() >= public_key.n() || public_key.n().is_even() { + if e >= n || n.is_even() { return Err(Error::InvalidModulus); } - if public_key.e().is_even() { + if e.is_even() { return Err(Error::InvalidExponent); } - if e < RsaPublicKey::MIN_PUB_EXPONENT { + if eu64 < RsaPublicKey::MIN_PUB_EXPONENT { return Err(Error::PublicExponentTooSmall); } - if e > RsaPublicKey::MAX_PUB_EXPONENT { + if eu64 > RsaPublicKey::MAX_PUB_EXPONENT { return Err(Error::PublicExponentTooLarge); } @@ -577,12 +607,35 @@ pub(crate) fn to_biguint(uint: &BoxedUint) -> BigUint { BigUint::from_bytes_be(&uint.to_be_bytes()) } +pub(crate) fn to_uint_exact(big_uint: BigUint, nbits: usize) -> BoxedUint { + let bytes = big_uint.to_bytes_be(); + let pad_count = Limb::BYTES - (bytes.len() % Limb::BYTES); + let mut padded_bytes = vec![0u8; pad_count]; + padded_bytes.extend_from_slice(&bytes); + + let res = BoxedUint::from_be_slice(&padded_bytes, padded_bytes.len() * 8).unwrap(); + + match res.bits_precision().cmp(&nbits) { + Ordering::Equal => res, + Ordering::Greater => panic!("too large: {} > {}", res.bits_precision(), nbits), + Ordering::Less => res.widen(nbits), + } +} + pub(crate) fn to_uint(big_uint: BigUint) -> BoxedUint { + let nbits = needed_bits(&big_uint); + let bytes = big_uint.to_bytes_be(); let pad_count = Limb::BYTES - (bytes.len() % Limb::BYTES); let mut padded_bytes = vec![0u8; pad_count]; padded_bytes.extend_from_slice(&bytes); - BoxedUint::from_be_slice(&padded_bytes, padded_bytes.len() * 8).unwrap() + + let res = BoxedUint::from_be_slice(&padded_bytes, padded_bytes.len() * 8).unwrap(); + + if res.bits() < nbits { + return res.widen(nbits); + } + res } pub(crate) fn reduce(n: &BoxedUint, p: BoxedResidueParams) -> BoxedResidue { @@ -614,10 +667,10 @@ mod tests { fn test_from_into() { let private_key = RsaPrivateKey { pubkey_components: RsaPublicKey { - n: NonZero::new(to_uint(BigUint::from_u64(100).unwrap())).unwrap(), - e: to_uint(BigUint::from_u64(200).unwrap()), + n: NonZero::new(to_uint(BigUint::from_u64(100).unwrap()).widen(64)).unwrap(), + e: to_uint(BigUint::from_u64(200).unwrap()).widen(64), }, - d: to_uint(BigUint::from_u64(123).unwrap()), + d: to_uint(BigUint::from_u64(123).unwrap()).widen(64), primes: vec![], precomputed: None, }; @@ -654,7 +707,8 @@ mod tests { let mut rng = ChaCha8Rng::from_seed([42; 32]); let exp = BigUint::from_u64(RsaPrivateKey::EXP).expect("invalid static exponent"); - for _ in 0..10 { + for i in 0..10 { + std::dbg!(i, $size); let components = generate_multi_prime_key_with_exp(&mut rng, $multi, $size, &exp).unwrap(); let private_key = RsaPrivateKey::from_components(