Multi-scalar-mul for TwistedEdwards (#297)
* Generalize MSM to any curve: Short Weierstrass + Twisted Edwards

* Better cacheline prefetching

* Generalize MSM to TwistedEdwards curves including Bandersnatch and Banderwagon
mratsim authored Nov 20, 2023
1 parent 5f7ba18 commit d77bb79
Showing 27 changed files with 744 additions and 358 deletions.
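For context, a minimal usage sketch of the generalized API (not part of the diff): the same multi-scalar-mul entry points now accept Twisted Edwards points such as Bandersnatch. The module paths, the ECP_TwEdwards_Aff type name, the seed call and the availability of matchingOrderBigInt through these imports are assumptions inferred from the files below, not verified against the library.

import
  constantine/math/config/curves,                      # assumed top-level module layout
  constantine/math/arithmetic,
  constantine/math/elliptic/ec_twistededwards_projective,
  constantine/math/elliptic/ec_twistededwards_affine,  # assumed home of ECP_TwEdwards_Aff
  constantine/math/elliptic/ec_multi_scalar_mul,       # assumed home of multiScalarMul_vartime
  helpers/prng_unsafe                                   # repo-internal helper, as in the benchmarks

var rng: RngState
rng.seed(0xFACADE)

type EC = ECP_TwEdwards_Prj[Fp[Bandersnatch]]
const N = 1 shl 10

var coefs  = newSeq[matchingOrderBigInt(Bandersnatch)](N)
var points = newSeq[ECP_TwEdwards_Aff[Fp[Bandersnatch]]](N)
for i in 0 ..< N:
  var P = rng.random_unsafe(EC)
  P.clearCofactor()                 # stay in the prime-order subgroup
  points[i].affine(P)
  coefs[i] = rng.random_unsafe(matchingOrderBigInt(Bandersnatch))

var r: EC
r.multiScalarMul_vartime(coefs, points)   # same call shape as for Short Weierstrass curves
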
49 changes: 49 additions & 0 deletions benchmarks/bench_ec_msm_bandersnatch.nim
@@ -0,0 +1,49 @@
# Constantine
# Copyright (c) 2018-2019 Status Research & Development GmbH
# Copyright (c) 2020-Present Mamy André-Ratsimbazafy
# Licensed and distributed under either of
# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
# at your option. This file may not be copied, modified, or distributed except according to those terms.

import
  # Internals
  ../constantine/math/config/curves,
  ../constantine/math/arithmetic,
  ../constantine/math/elliptic/ec_twistededwards_projective,
  # Helpers
  ../helpers/prng_unsafe,
  ./bench_elliptic_parallel_template

# ############################################################
#
#             Benchmark of multi-scalar-mul on
#             Twisted Edwards elliptic curves
#               in projective coordinates
#
# ############################################################


const Iters = 10_000
const AvailableCurves = [
  Bandersnatch
]

# const testNumPoints = [10, 100, 1000, 10000, 100000]
# const testNumPoints = [64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]
const testNumPoints = [1 shl 8, 1 shl 9, 1 shl 10, 1 shl 11, 1 shl 12, 1 shl 13, 1 shl 14, 1 shl 15, 1 shl 16, 1 shl 17, 1 shl 22]

proc main() =
  separator()
  staticFor i, 0, AvailableCurves.len:
    const curve = AvailableCurves[i]
    var ctx = createBenchMsmContext(ECP_TwEdwards_Prj[Fp[curve]], testNumPoints)
    separator()
    for numPoints in testNumPoints:
      let batchIters = max(1, Iters div numPoints)
      ctx.msmParallelBench(numPoints, batchIters)
      separator()
    separator()

main()
notes()
File renamed without changes.
@@ -39,10 +39,11 @@ proc main() =
separator()
staticFor i, 0, AvailableCurves.len:
const curve = AvailableCurves[i]
var ctx = createBenchMsmContext(ECP_ShortW_Jac[Fp[curve], G1], testNumPoints)
separator()
for numPoints in testNumPoints:
let batchIters = max(1, Iters div numPoints)
msmParallelBench(ECP_ShortW_Jac[Fp[curve], G1], numPoints, batchIters)
ctx.msmParallelBench(numPoints, batchIters)
separator()
separator()

@@ -14,7 +14,6 @@ import
ec_shortweierstrass_projective,
ec_shortweierstrass_jacobian],
# Helpers
../helpers/prng_unsafe,
./bench_elliptic_parallel_template

# ############################################################
@@ -32,17 +31,18 @@ const AvailableCurves = [
]

# const testNumPoints = [10, 100, 1000, 10000, 100000]
# const testNumPoints = [64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]
const testNumPoints = [1 shl 16, 1 shl 22]
const testNumPoints = [64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 1 shl 22]
# const testNumPoints = [1 shl 16, 1 shl 22]

proc main() =
separator()
staticFor i, 0, AvailableCurves.len:
const curve = AvailableCurves[i]
var ctx = createBenchMsmContext(ECP_ShortW_Jac[Fp[curve], G1], testNumPoints)
separator()
for numPoints in testNumPoints:
let batchIters = max(1, Iters div numPoints)
msmParallelBench(ECP_ShortW_Jac[Fp[curve], G1], numPoints, batchIters)
ctx.msmParallelBench(numPoints, batchIters)
separator()
separator()

File renamed without changes.
@@ -39,10 +39,11 @@ proc main() =
separator()
staticFor i, 0, AvailableCurves.len:
const curve = AvailableCurves[i]
var ctx = createBenchMsmContext(ECP_ShortW_Jac[Fp[curve], G1], testNumPoints)
separator()
for numPoints in testNumPoints:
let batchIters = max(1, Iters div numPoints)
msmParallelBench(ECP_ShortW_Jac[Fp[curve], G1], numPoints, batchIters)
ctx.msmParallelBench(numPoints, batchIters)
separator()
separator()

1 change: 1 addition & 0 deletions benchmarks/bench_ec_msm_pasta.nim.cfg
@@ -0,0 +1 @@
--threads:on
86 changes: 57 additions & 29 deletions benchmarks/bench_elliptic_parallel_template.nim
@@ -35,32 +35,43 @@ export bench_elliptic_template
#
# ############################################################

proc multiAddParallelBench*(EC: typedesc, numPoints: int, iters: int) =
var points = newSeq[ECP_ShortW_Aff[EC.F, EC.G]](numPoints)
proc multiAddParallelBench*(EC: typedesc, numInputs: int, iters: int) =
var points = newSeq[ECP_ShortW_Aff[EC.F, EC.G]](numInputs)

for i in 0 ..< numPoints:
for i in 0 ..< numInputs:
points[i] = rng.random_unsafe(ECP_ShortW_Aff[EC.F, EC.G])

var r{.noInit.}: EC

let tp = Threadpool.new()

bench("EC parallel batch add (" & align($tp.numThreads, 2) & " threads) " & $EC.G & " (" & $numPoints & " points)", EC, iters):
bench("EC parallel batch add (" & align($tp.numThreads, 2) & " threads) " & $EC.G & " (" & $numInputs & " points)", EC, iters):
tp.sum_reduce_vartime_parallel(r, points)

tp.shutdown()

proc msmParallelBench*(EC: typedesc, numPoints: int, iters: int) =
# Multi-scalar multiplication
# ---------------------------------------------------------------------------

type BenchMsmContext*[EC] = object
tp: Threadpool
numInputs: int
coefs: seq[matchingOrderBigInt(EC.F.C)]
points: seq[affine(EC)]

proc createBenchMsmContext*(EC: typedesc, inputSizes: openArray[int]): BenchMsmContext[EC] =
result.tp = Threadpool.new()
let maxNumInputs = inputSizes.max()

const bits = EC.F.C.getCurveOrderBitwidth()
var points = newSeq[ECP_ShortW_Aff[EC.F, EC.G]](numPoints)
var scalars = newSeq[BigInt[bits]](numPoints)
type ECaff = affine(EC)

# Creating millions of points and clearing their cofactor takes a long long time
var tp = Threadpool.new()
result.points = newSeq[ECaff](maxNumInputs)
result.coefs = newSeq[BigInt[bits]](maxNumInputs)

proc genCoefPointPairs(rngSeed: uint64, start, len: int, points: ptr ECP_ShortW_Aff[EC.F, EC.G], scalars: ptr BigInt[bits]) {.nimcall.} =
let points = cast[ptr UncheckedArray[ECP_ShortW_Aff[EC.F, EC.G]]](points) # TODO use views to reduce verbosity
let scalars = cast[ptr UncheckedArray[BigInt[bits]]](scalars)
proc genCoefPointPairsChunk[EC, ECaff](rngSeed: uint64, start, len: int, points: ptr ECaff, coefs: ptr BigInt[bits]) {.nimcall.} =
let points = cast[ptr UncheckedArray[ECaff]](points)
let coefs = cast[ptr UncheckedArray[BigInt[bits]]](coefs)

# RNGs are not threadsafe, create a threadlocal one seeded from the global RNG
var threadRng: RngState
@@ -70,60 +81,77 @@ proc msmParallelBench*(EC: typedesc, numPoints: int, iters: int) =
var tmp = threadRng.random_unsafe(EC)
tmp.clearCofactor()
points[i].affine(tmp)
scalars[i] = rng.random_unsafe(BigInt[bits])
coefs[i] = rng.random_unsafe(BigInt[bits])

let chunks = balancedChunksPrioNumber(0, maxNumInputs, result.tp.numThreads)


stdout.write &"Generating {maxNumInputs} (coefs, points) pairs ... "
stdout.flushFile()

let chunks = balancedChunksPrioNumber(0, numPoints, tp.numThreads)
let start = getMonotime()

syncScope:
for (id, start, size) in items(chunks):
tp.spawn genCoefPointPairs(rng.next(), start, size, points[0].addr, scalars[0].addr)
result.tp.spawn genCoefPointPairsChunk[EC, ECaff](rng.next(), start, size, result.points[0].addr, result.coefs[0].addr)

# Even if child threads are sleeping, it seems like perf is lower when there are threads around
# maybe because the kernel has more overhead or time quanta to keep track of, so shut them down.
tp.shutdown()
result.tp.shutdown()

let stop = getMonotime()
stdout.write &"in {float64(inNanoSeconds(stop-start)) / 1e6:6.3f} ms\n"

proc msmParallelBench*[EC](ctx: var BenchMsmContext[EC], numInputs: int, iters: int) =
const bits = EC.F.C.getCurveOrderBitwidth()
type ECaff = affine(EC)

template coefs: untyped = ctx.coefs.toOpenArray(0, numInputs-1)
template points: untyped = ctx.points.toOpenArray(0, numInputs-1)


var r{.noInit.}: EC
var startNaive, stopNaive, startMSMbaseline, stopMSMbaseline, startMSMopt, stopMSMopt, startMSMpara, stopMSMpara: MonoTime

if numPoints <= 100000:
if numInputs <= 100000:
startNaive = getMonotime()
bench("EC scalar muls " & align($numPoints, 10) & " (" & $bits & "-bit coefs, points)", EC, iters):
bench("EC scalar muls " & align($numInputs, 10) & " (" & $bits & "-bit coefs, points)", EC, iters):
var tmp: EC
r.setInf()
for i in 0 ..< points.len:
tmp.fromAffine(points[i])
tmp.scalarMul(scalars[i])
tmp.scalarMul(coefs[i])
r += tmp
stopNaive = getMonotime()

if numPoints <= 100000:
if numInputs <= 100000:
startMSMbaseline = getMonotime()
bench("EC multi-scalar-mul baseline " & align($numPoints, 10) & " (" & $bits & "-bit coefs, points)", EC, iters):
r.multiScalarMul_reference_vartime(scalars, points)
bench("EC multi-scalar-mul baseline " & align($numInputs, 10) & " (" & $bits & "-bit coefs, points)", EC, iters):
r.multiScalarMul_reference_vartime(coefs, points)
stopMSMbaseline = getMonotime()

block:
startMSMopt = getMonotime()
bench("EC multi-scalar-mul optimized " & align($numPoints, 10) & " (" & $bits & "-bit coefs, points)", EC, iters):
r.multiScalarMul_vartime(scalars, points)
bench("EC multi-scalar-mul optimized " & align($numInputs, 10) & " (" & $bits & "-bit coefs, points)", EC, iters):
r.multiScalarMul_vartime(coefs, points)
stopMSMopt = getMonotime()

block:
tp = Threadpool.new()
ctx.tp = Threadpool.new()

startMSMpara = getMonotime()
bench("EC multi-scalar-mul" & align($tp.numThreads & " threads", 11) & align($numPoints, 10) & " (" & $bits & "-bit coefs, points)", EC, iters):
tp.multiScalarMul_vartime_parallel(r, scalars, points)
bench("EC multi-scalar-mul" & align($ctx.tp.numThreads & " threads", 11) & align($numInputs, 10) & " (" & $bits & "-bit coefs, points)", EC, iters):
ctx.tp.multiScalarMul_vartime_parallel(r, coefs, points)
stopMSMpara = getMonotime()

tp.shutdown()
ctx.tp.shutdown()

let perfNaive = inNanoseconds((stopNaive-startNaive) div iters)
let perfMSMbaseline = inNanoseconds((stopMSMbaseline-startMSMbaseline) div iters)
let perfMSMopt = inNanoseconds((stopMSMopt-startMSMopt) div iters)
let perfMSMpara = inNanoseconds((stopMSMpara-startMSMpara) div iters)

if numPoints <= 100000:
if numInputs <= 100000:
let speedupBaseline = float(perfNaive) / float(perfMSMbaseline)
echo &"Speedup ratio baseline over naive linear combination: {speedupBaseline:>6.3f}x"

1 change: 1 addition & 0 deletions benchmarks/bench_elliptic_template.nim
@@ -35,6 +35,7 @@ import

export notes
export abstractions # generic sandwich on SecretBool and SecretBool in Jacobian sum
export bench_blueprint

proc separator*() = separator(179)

27 changes: 17 additions & 10 deletions constantine.nimble
@@ -358,8 +358,8 @@ const testDesc: seq[tuple[path: string, useGMP: bool]] = @[
("tests/math_elliptic_curves/t_ec_shortw_jacext_g1_mixed_add.nim", false),

# ("tests/math_elliptic_curves/t_ec_twedwards_prj_add_double", false),
# ("tests/math_elliptic_curves/t_ec_twedwards_prj_mul_sanity", false),
# ("tests/math_elliptic_curves/t_ec_twedwards_prj_mul_distri", false),
("tests/math_elliptic_curves/t_ec_twedwards_prj_mul_sanity", false),
("tests/math_elliptic_curves/t_ec_twedwards_prj_mul_distri", false),


# Elliptic curve arithmetic G2
@@ -433,6 +433,7 @@ const testDesc: seq[tuple[path: string, useGMP: bool]] = @[
("tests/math_elliptic_curves/t_ec_shortw_jacext_g1_sum_reduce.nim", false),
("tests/math_elliptic_curves/t_ec_shortw_prj_g1_msm.nim", false),
("tests/math_elliptic_curves/t_ec_shortw_jac_g1_msm.nim", false),
("tests/math_elliptic_curves/t_ec_twedwards_prj_msm.nim", false),

# Subgroups and cofactors
# ----------------------------------------------------------
@@ -525,8 +526,10 @@ const benchDesc = [
"bench_ec_g1",
"bench_ec_g1_scalar_mul",
"bench_ec_g1_batch",
"bench_ec_g1_msm_bn254_snarks",
"bench_ec_g1_msm_bls12_381",
"bench_ec_msm_bandersnatch",
"bench_ec_msm_bn254_snarks_g1",
"bench_ec_msm_bls12_381_g1",
"bench_ec_msm_pasta",
"bench_ec_g2",
"bench_ec_g2_scalar_mul",
"bench_pairing_bls12_377",
@@ -879,14 +882,18 @@ task bench_ec_g1_scalar_mul, "Run benchmark on Elliptic Curve group 𝔾1 (Scala
# Elliptic curve G1 - Multi-scalar-mul
# ------------------------------------------

task bench_ec_g1_msm_pasta, "Run benchmark on Elliptic Curve group 𝔾1 (Multi-Scalar-Mul) for Pasta curves - CC compiler":
runBench("bench_ec_g1_msm_pasta")
task bench_ec_msm_pasta, "Run benchmark: Multi-Scalar-Mul for Pasta curves - CC compiler":
runBench("bench_ec_msm_pasta")

task bench_ec_g1_msm_bn254_snarks, "Run benchmark on Elliptic Curve group 𝔾1 (Multi-Scalar-Mul) for BN254-Snarks - CC compiler":
runBench("bench_ec_g1_msm_bn254_snarks")
task bench_ec_msm_bn254_snarks_g1, "Run benchmark: Multi-Scalar-Mul for BN254-Snarks 𝔾1 - CC compiler":
runBench("bench_ec_msm_bn254_snarks_g1")

task bench_ec_msm_bls12_381_g1, "Run benchmark: Multi-Scalar-Mul for BLS12-381 𝔾1 - CC compiler":
runBench("bench_ec_msm_bls12_381_g1")

task bench_ec_msm_bandersnatch, "Run benchmark: Multi-Scalar-Mul for Bandersnatch - CC compiler":
runBench("bench_ec_msm_bandersnatch")

task bench_ec_g1_msm_bls12_381, "Run benchmark on Elliptic Curve group 𝔾1 (Multi-Scalar-Mul) for BLS12-381 - CC compiler":
runBench("bench_ec_g1_msm_bls12_381")

# Elliptic curve G2
# ------------------------------------------
34 changes: 34 additions & 0 deletions constantine/math/constants/bandersnatch_subgroups.nim
@@ -0,0 +1,34 @@
# Constantine
# Copyright (c) 2018-2019 Status Research & Development GmbH
# Copyright (c) 2020-Present Mamy André-Ratsimbazafy
# Licensed and distributed under either of
# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
# at your option. This file may not be copied, modified, or distributed except according to those terms.

import
  # Internals
  ../config/curves,
  ../arithmetic,
  ../elliptic/ec_twistededwards_projective

# ############################################################
#
# Clear Cofactor
#
# ############################################################

func clearCofactorReference*(P: var ECP_TwEdwards_Prj[Fp[Bandersnatch]]) {.inline.} =
  ## Clear the cofactor of Bandersnatch
  # https://hackmd.io/@6iQDuIePQjyYBqDChYw_jg/BJBNcv9fq#Bandersnatch-Subgroup
  #
  # Bandersnatch subgroup
  #
  # The group structure of Bandersnatch is ℤ₂ x ℤ₂ x ℤₚ, where p is prime.
  #
  # The non-cyclic subgroup, which we may refer to as the 2-torsion subgroup, is:
  #   E[2] = {(0, 1), D₀, D₁, D₂}
  #
  # Remark: all of these points have order 1 or 2, so doubling any point in the
  #         Bandersnatch group suffices to clear the cofactor, i.e. one does not
  #         need to multiply by the full cofactor, 4.
  # Remark: the 2-torsion subgroup may also be called the small-order subgroup.
  P.double()
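A sketch of how the new cofactor-clearing routine could be exercised (module paths assumed, as in the earlier sketch): because the full group is ℤ₂ x ℤ₂ x ℤₚ, a single doubling maps the 2-torsion component to the identity, so the result lands in the prime-order subgroup without multiplying by the full cofactor 4.

import
  constantine/math/config/curves,
  constantine/math/arithmetic,
  constantine/math/elliptic/ec_twistededwards_projective,
  constantine/math/constants/zoo_subgroups,   # re-exports bandersnatch_subgroups (see next file)
  helpers/prng_unsafe

var rng: RngState
rng.seed(42)

var P = rng.random_unsafe(ECP_TwEdwards_Prj[Fp[Bandersnatch]])
P.clearCofactorReference()   # one double(): the ℤ₂ x ℤ₂ part vanishes, cofactor 4 is cleared
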
2 changes: 2 additions & 0 deletions constantine/math/constants/zoo_subgroups.nim
@@ -9,6 +9,7 @@
import
# Internals
../config/curves,
./bandersnatch_subgroups,
./bls12_377_subgroups,
./bls12_381_subgroups,
./bn254_nogami_subgroups,
@@ -19,6 +20,7 @@ import
./secp256k1_subgroups

export
bandersnatch_subgroups,
bls12_377_subgroups,
bls12_381_subgroups,
bn254_nogami_subgroups,