Skip to content

Commit

Permalink
Merge pull request #68 from fjarri/stateful-entry-points
Browse files Browse the repository at this point in the history
Stateful entry points
  • Loading branch information
fjarri authored Nov 18, 2024
2 parents e43d728 + 53ceede commit bcfb81a
Show file tree
Hide file tree
Showing 14 changed files with 408 additions and 342 deletions.
57 changes: 30 additions & 27 deletions examples/src/simple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use rand_core::CryptoRngCore;
use serde::{Deserialize, Serialize};
use tracing::debug;

#[derive(Debug)]
pub struct SimpleProtocol;

#[derive(Debug, Clone, Serialize, Deserialize)]
Expand Down Expand Up @@ -111,11 +112,6 @@ impl Protocol for SimpleProtocol {
}
}

#[derive(Debug, Clone)]
pub struct Inputs<Id> {
pub all_ids: BTreeSet<Id>,
}

#[derive(Debug)]
pub(crate) struct Context<Id> {
pub(crate) id: Id,
Expand Down Expand Up @@ -149,30 +145,40 @@ struct Round1Payload {
x: u8,
}

impl<Id: PartyId> EntryPoint<Id> for Round1<Id> {
type Inputs = Inputs<Id>;
#[derive(Debug, Clone)]
pub struct SimpleProtocolEntryPoint<Id> {
all_ids: BTreeSet<Id>,
}

impl<Id: PartyId> SimpleProtocolEntryPoint<Id> {
pub fn new(all_ids: BTreeSet<Id>) -> Self {
Self { all_ids }
}
}

impl<Id: PartyId> EntryPoint<Id> for SimpleProtocolEntryPoint<Id> {
type Protocol = SimpleProtocol;
fn new(
fn make_round(
self,
_rng: &mut impl CryptoRngCore,
_shared_randomness: &[u8],
id: Id,
inputs: Self::Inputs,
id: &Id,
) -> Result<BoxedRound<Id, Self::Protocol>, LocalError> {
// Just some numbers associated with IDs to use in the dummy protocol.
// They will be the same on each node since IDs are ordered.
let ids_to_positions = inputs
let ids_to_positions = self
.all_ids
.iter()
.enumerate()
.map(|(idx, id)| (id.clone(), idx as u8))
.collect::<BTreeMap<_, _>>();

let mut ids = inputs.all_ids;
ids.remove(&id);
let mut ids = self.all_ids;
ids.remove(id);

Ok(BoxedRound::new_dynamic(Self {
Ok(BoxedRound::new_dynamic(Round1 {
context: Context {
id,
id: id.clone(),
other_ids: ids,
ids_to_positions,
},
Expand Down Expand Up @@ -318,6 +324,10 @@ impl<Id: PartyId> Round<Id> for Round2<Id> {
BTreeSet::new()
}

fn may_produce_result(&self) -> bool {
true
}

fn message_destinations(&self) -> &BTreeSet<Id> {
&self.context.other_ids
}
Expand Down Expand Up @@ -401,12 +411,12 @@ mod tests {

use manul::{
session::{signature::Keypair, SessionOutcome},
testing::{run_sync, BinaryFormat, TestSessionParams, TestSigner, TestVerifier},
testing::{run_sync, BinaryFormat, TestSessionParams, TestSigner},
};
use rand_core::OsRng;
use tracing_subscriber::EnvFilter;

use super::{Inputs, Round1};
use super::SimpleProtocolEntryPoint;

#[test]
fn round() {
Expand All @@ -415,23 +425,16 @@ mod tests {
.iter()
.map(|signer| signer.verifying_key())
.collect::<BTreeSet<_>>();
let inputs = signers
let entry_points = signers
.into_iter()
.map(|signer| {
(
signer,
Inputs {
all_ids: all_ids.clone(),
},
)
})
.map(|signer| (signer, SimpleProtocolEntryPoint::new(all_ids.clone())))
.collect::<Vec<_>>();

let my_subscriber = tracing_subscriber::fmt()
.with_env_filter(EnvFilter::from_default_env())
.finish();
let reports = tracing::subscriber::with_default(my_subscriber, || {
run_sync::<Round1<TestVerifier>, TestSessionParams<BinaryFormat>>(&mut OsRng, inputs).unwrap()
run_sync::<_, TestSessionParams<BinaryFormat>>(&mut OsRng, entry_points).unwrap()
});

for (_id, report) in reports {
Expand Down
87 changes: 55 additions & 32 deletions examples/src/simple_chain.rs
Original file line number Diff line number Diff line change
@@ -1,51 +1,82 @@
use alloc::collections::BTreeSet;
use core::fmt::Debug;

use manul::{
combinators::chain::{Chained, ChainedEntryPoint},
protocol::PartyId,
combinators::{
chain::{Chain, ChainedJoin, ChainedProtocol, ChainedSplit},
CombinatorEntryPoint,
},
protocol::{PartyId, Protocol},
};

use super::simple::{Inputs, Round1};

pub struct ChainedSimple;
use super::simple::{SimpleProtocol, SimpleProtocolEntryPoint};

/// A protocol that runs the [`SimpleProtocol`] twice, in sequence.
/// Illustrates the chain protocol combinator.
#[derive(Debug)]
pub struct NewInputs<Id>(Inputs<Id>);
pub struct DoubleSimpleProtocol;

impl ChainedProtocol for DoubleSimpleProtocol {
type Protocol1 = SimpleProtocol;
type Protocol2 = SimpleProtocol;
}

pub struct DoubleSimpleEntryPoint<Id> {
all_ids: BTreeSet<Id>,
}

impl<'a, Id: PartyId> From<&'a NewInputs<Id>> for Inputs<Id> {
fn from(source: &'a NewInputs<Id>) -> Self {
source.0.clone()
impl<Id: PartyId> DoubleSimpleEntryPoint<Id> {
pub fn new(all_ids: BTreeSet<Id>) -> Self {
Self { all_ids }
}
}

impl<Id: PartyId> From<(NewInputs<Id>, u8)> for Inputs<Id> {
fn from(source: (NewInputs<Id>, u8)) -> Self {
let (inputs, _result) = source;
inputs.0
impl<Id> CombinatorEntryPoint for DoubleSimpleEntryPoint<Id> {
type Combinator = Chain;
}

impl<Id> ChainedSplit<Id> for DoubleSimpleEntryPoint<Id>
where
Id: PartyId,
{
type Protocol = DoubleSimpleProtocol;
type EntryPoint = SimpleProtocolEntryPoint<Id>;
fn make_entry_point1(self) -> (Self::EntryPoint, impl ChainedJoin<Id, Protocol = Self::Protocol>) {
(
SimpleProtocolEntryPoint::new(self.all_ids.clone()),
DoubleSimpleProtocolTransition { all_ids: self.all_ids },
)
}
}

impl<Id: PartyId> Chained<Id> for ChainedSimple {
type Inputs = NewInputs<Id>;
type EntryPoint1 = Round1<Id>;
type EntryPoint2 = Round1<Id>;
#[derive(Debug)]
struct DoubleSimpleProtocolTransition<Id> {
all_ids: BTreeSet<Id>,
}

pub type DoubleSimpleEntryPoint<Id> = ChainedEntryPoint<Id, ChainedSimple>;
impl<Id> ChainedJoin<Id> for DoubleSimpleProtocolTransition<Id>
where
Id: PartyId,
{
type Protocol = DoubleSimpleProtocol;
type EntryPoint = SimpleProtocolEntryPoint<Id>;
fn make_entry_point2(self, _result: <SimpleProtocol as Protocol>::Result) -> Self::EntryPoint {
SimpleProtocolEntryPoint::new(self.all_ids)
}
}

#[cfg(test)]
mod tests {
use alloc::collections::BTreeSet;

use manul::{
session::{signature::Keypair, SessionOutcome},
testing::{run_sync, BinaryFormat, TestSessionParams, TestSigner, TestVerifier},
testing::{run_sync, BinaryFormat, TestSessionParams, TestSigner},
};
use rand_core::OsRng;
use tracing_subscriber::EnvFilter;

use super::{DoubleSimpleEntryPoint, NewInputs};
use crate::simple::Inputs;
use super::DoubleSimpleEntryPoint;

#[test]
fn round() {
Expand All @@ -54,24 +85,16 @@ mod tests {
.iter()
.map(|signer| signer.verifying_key())
.collect::<BTreeSet<_>>();
let inputs = signers
let entry_points = signers
.into_iter()
.map(|signer| {
(
signer,
NewInputs(Inputs {
all_ids: all_ids.clone(),
}),
)
})
.map(|signer| (signer, DoubleSimpleEntryPoint::new(all_ids.clone())))
.collect::<Vec<_>>();

let my_subscriber = tracing_subscriber::fmt()
.with_env_filter(EnvFilter::from_default_env())
.finish();
let reports = tracing::subscriber::with_default(my_subscriber, || {
run_sync::<DoubleSimpleEntryPoint<TestVerifier>, TestSessionParams<BinaryFormat>>(&mut OsRng, inputs)
.unwrap()
run_sync::<_, TestSessionParams<BinaryFormat>>(&mut OsRng, entry_points).unwrap()
});

for (_id, report) in reports {
Expand Down
44 changes: 16 additions & 28 deletions examples/src/simple_malicious.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,18 @@ use alloc::collections::BTreeSet;
use core::fmt::Debug;

use manul::{
combinators::misbehave::{Misbehaving, MisbehavingEntryPoint, MisbehavingInputs},
combinators::misbehave::{Misbehaving, MisbehavingEntryPoint},
protocol::{
Artifact, BoxedRound, Deserializer, DirectMessage, EntryPoint, LocalError, PartyId, ProtocolMessagePart,
RoundId, Serializer,
},
session::signature::Keypair,
testing::{run_sync, BinaryFormat, TestSessionParams, TestSigner, TestVerifier},
testing::{run_sync, BinaryFormat, TestSessionParams, TestSigner},
};
use rand_core::{CryptoRngCore, OsRng};
use tracing_subscriber::EnvFilter;

use crate::simple::{Inputs, Round1, Round1Message, Round2, Round2Message};
use crate::simple::{Round1, Round1Message, Round2, Round2Message, SimpleProtocolEntryPoint};

#[derive(Debug, Clone, Copy)]
enum Behavior {
Expand All @@ -25,7 +25,7 @@ enum Behavior {
struct MaliciousLogic;

impl<Id: PartyId> Misbehaving<Id, Behavior> for MaliciousLogic {
type EntryPoint = Round1<Id>;
type EntryPoint = SimpleProtocolEntryPoint<Id>;

fn modify_direct_message(
_rng: &mut impl CryptoRngCore,
Expand Down Expand Up @@ -78,9 +78,8 @@ fn serialized_garbage() {
.iter()
.map(|signer| signer.verifying_key())
.collect::<BTreeSet<_>>();
let inputs = Inputs { all_ids };

let run_inputs = signers
let entry_points = signers
.iter()
.enumerate()
.map(|(idx, signer)| {
Expand All @@ -90,19 +89,16 @@ fn serialized_garbage() {
None
};

let malicious_inputs = MisbehavingInputs {
inner_inputs: inputs.clone(),
behavior,
};
(*signer, malicious_inputs)
let entry_point = MaliciousEntryPoint::new(SimpleProtocolEntryPoint::new(all_ids.clone()), behavior);
(*signer, entry_point)
})
.collect::<Vec<_>>();

let my_subscriber = tracing_subscriber::fmt()
.with_env_filter(EnvFilter::from_default_env())
.finish();
let mut reports = tracing::subscriber::with_default(my_subscriber, || {
run_sync::<MaliciousEntryPoint<TestVerifier>, TestSessionParams<BinaryFormat>>(&mut OsRng, run_inputs).unwrap()
run_sync::<_, TestSessionParams<BinaryFormat>>(&mut OsRng, entry_points).unwrap()
});

let v0 = signers[0].verifying_key();
Expand All @@ -124,9 +120,8 @@ fn attributable_failure() {
.iter()
.map(|signer| signer.verifying_key())
.collect::<BTreeSet<_>>();
let inputs = Inputs { all_ids };

let run_inputs = signers
let entry_points = signers
.iter()
.enumerate()
.map(|(idx, signer)| {
Expand All @@ -136,19 +131,16 @@ fn attributable_failure() {
None
};

let malicious_inputs = MisbehavingInputs {
inner_inputs: inputs.clone(),
behavior,
};
(*signer, malicious_inputs)
let entry_point = MaliciousEntryPoint::new(SimpleProtocolEntryPoint::new(all_ids.clone()), behavior);
(*signer, entry_point)
})
.collect::<Vec<_>>();

let my_subscriber = tracing_subscriber::fmt()
.with_env_filter(EnvFilter::from_default_env())
.finish();
let mut reports = tracing::subscriber::with_default(my_subscriber, || {
run_sync::<MaliciousEntryPoint<TestVerifier>, TestSessionParams<BinaryFormat>>(&mut OsRng, run_inputs).unwrap()
run_sync::<_, TestSessionParams<BinaryFormat>>(&mut OsRng, entry_points).unwrap()
});

let v0 = signers[0].verifying_key();
Expand All @@ -170,9 +162,8 @@ fn attributable_failure_round2() {
.iter()
.map(|signer| signer.verifying_key())
.collect::<BTreeSet<_>>();
let inputs = Inputs { all_ids };

let run_inputs = signers
let entry_points = signers
.iter()
.enumerate()
.map(|(idx, signer)| {
Expand All @@ -182,19 +173,16 @@ fn attributable_failure_round2() {
None
};

let malicious_inputs = MisbehavingInputs {
inner_inputs: inputs.clone(),
behavior,
};
(*signer, malicious_inputs)
let entry_point = MaliciousEntryPoint::new(SimpleProtocolEntryPoint::new(all_ids.clone()), behavior);
(*signer, entry_point)
})
.collect::<Vec<_>>();

let my_subscriber = tracing_subscriber::fmt()
.with_env_filter(EnvFilter::from_default_env())
.finish();
let mut reports = tracing::subscriber::with_default(my_subscriber, || {
run_sync::<MaliciousEntryPoint<TestVerifier>, TestSessionParams<BinaryFormat>>(&mut OsRng, run_inputs).unwrap()
run_sync::<_, TestSessionParams<BinaryFormat>>(&mut OsRng, entry_points).unwrap()
});

let v0 = signers[0].verifying_key();
Expand Down
Loading

0 comments on commit bcfb81a

Please sign in to comment.