Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

A draft API for validation of replicated shares #936

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ipa-core/src/protocol/basics/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ mod reshare;
mod reveal;
mod share_known_value;
pub mod sum_of_product;
pub mod validate;

pub use check_zero::check_zero;
pub use if_else::if_else;
Expand Down
218 changes: 218 additions & 0 deletions ipa-core/src/protocol/basics/validate.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
use std::{
convert::Infallible,
marker::PhantomData,
pin::Pin,
task::{Context as TaskContext, Poll},
};

use futures::{
future::try_join,
stream::{Fuse, Stream, StreamExt},
Future, FutureExt,
};
use generic_array::GenericArray;
use pin_project::pin_project;
use sha2::{
digest::{typenum::Unsigned, FixedOutput, OutputSizeUser},
Digest, Sha256,
};

use crate::{
error::Error,
ff::Serializable,
helpers::{Direction, Message},
protocol::{context::Context, RecordId},
secret_sharing::{replicated::ReplicatedSecretSharing, SharedValue},
};

type HashFunction = Sha256;
type HashSize = <HashFunction as OutputSizeUser>::OutputSize;
type HashOutputArray = [u8; <HashSize as Unsigned>::USIZE];

#[derive(Debug, Clone, PartialEq, Eq)]
struct HashValue(GenericArray<u8, HashSize>);

impl Serializable for HashValue {
type Size = HashSize;
type DeserializationError = Infallible;

fn serialize(&self, buf: &mut GenericArray<u8, Self::Size>) {
buf.copy_from_slice(self.0.as_slice())
}

fn deserialize(buf: &GenericArray<u8, Self::Size>) -> Result<Self, Self::DeserializationError> {
Ok(Self(buf.to_owned()))
}
}

impl Message for HashValue {}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have also defined a hash struct that implements message in my draft. Should we move it to a more general crate since we seem to need it in several places?


struct ReplicatedValidatorFinalization<C> {
f: Pin<Box<dyn Future<Output = Result<(), Error>>>>,
ctx: C,
}

impl<C: Context + 'static> ReplicatedValidatorFinalization<C> {
fn new(active: ReplicatedValidatorActive<C>) -> Self {
let ReplicatedValidatorActive {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it may be nicer from API's perspective to let ReplicatedValidatorActive to turn itself into a pair of hashes

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As this is just internal, I didn't do that. Though I did implement From for HashValue, which made this function a little less ugly.

ctx,
left_hash,
right_hash,
} = active;
// Ugh: The version of sha2 we currently use doesn't use the same GenericArray version as we do.
let left_hash = HashValue(GenericArray::from(<HashOutputArray>::from(
left_hash.finalize_fixed(),
)));
let right_hash = HashValue(GenericArray::from(<HashOutputArray>::from(
right_hash.finalize_fixed(),
)));
let left_peer = ctx.role().peer(Direction::Left);
let right_peer = ctx.role().peer(Direction::Left);
martinthomson marked this conversation as resolved.
Show resolved Hide resolved
let ctx_ref = &ctx;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you want to borrow here or it is acceptable to clone the context and move it inside the box?

If we want to borrow from C, we would need to create a self-referential struct, so maybe we need an inner struct that is pinned and holds f that borrows from ctx and ctx

however, I was able to compile it just by let ctx = ctx.clone() but that's probably not what you want here

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I managed to get this to compile with a similar trick. Though the tests are now failing because the TestWorld doesn't have a 'static lifetime. I put the 'static condition on to avoid problems creating the future, ...


let f = Box::pin(async move {
try_join(
ctx_ref
.send_channel(left_peer)
.send(RecordId::FIRST, left_hash.clone()),
ctx_ref
.send_channel(right_peer)
.send(RecordId::FIRST, right_hash.clone()),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we need to send it to both, left and right. It seems sufficient to me if each party verifies one hash. It doesn't add significant costs if a party checks both hashes, however it also doesn't seem to add anything either.

)
.await?;
let (left_recvd, right_recvd) = try_join(
ctx_ref.recv_channel(left_peer).receive(RecordId::FIRST),
ctx_ref.recv_channel(right_peer).receive(RecordId::FIRST),
)
.await?;
if left_hash == left_recvd && right_hash == right_recvd {
Ok(())
} else {
Err(Error::Internal) // TODO add a code
}
});
Self { f, ctx }
}

fn poll(&mut self, cx: &mut TaskContext<'_>) -> Poll<Result<(), Error>> {
self.f.poll_unpin(cx)
}
}

struct ReplicatedValidatorActive<C> {
ctx: C,
left_hash: Sha256,
right_hash: Sha256,
}

impl<C: Context + 'static> ReplicatedValidatorActive<C> {
fn new(ctx: C) -> Self {
Self {
ctx,
left_hash: HashFunction::new(),
right_hash: HashFunction::new(),
}
}

fn update<S, V>(&mut self, s: &S)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For debugging, it might be useful if the update API takes both a Step and data, and creates a trace of the validator inputs. Then we can diagnose where things went wrong if there is a mismatch. (Obviously, we would want a flag so we only pay the cost of the detailed tracing when needed.)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The context carries a step, so we could smuggle that in somewhere.

I'm more concerned that I've been unable to instantiate this code :)

where
S: ReplicatedSecretSharing<V>,
martinthomson marked this conversation as resolved.
Show resolved Hide resolved
V: SharedValue,
{
let mut buf = GenericArray::default(); // ::<u8, <V as Serializable>::Size>
s.left().serialize(&mut buf);
self.left_hash.update(buf.as_slice());
s.right().serialize(&mut buf);
self.right_hash.update(buf.as_slice());
}

fn finalize(self) -> ReplicatedValidatorFinalization<C> {
ReplicatedValidatorFinalization::new(self)
}
}

enum ReplicatedValidatorState<C> {
/// While the validator is waiting, it holds a context reference.
Pending(Option<ReplicatedValidatorActive<C>>),
/// After the validator has taken all of its inputs, it holds a future.
Finalizing(ReplicatedValidatorFinalization<C>),
}

impl<C: Context + 'static> ReplicatedValidatorState<C> {
/// # Panics
/// This panics if it is called after `finalize()`.
fn update<S, V>(&mut self, s: &S)
where
S: ReplicatedSecretSharing<V>,
V: SharedValue,
{
if let Self::Pending(Some(a)) = self {
a.update(s);
} else {
panic!();
}
}

fn poll(&mut self, cx: &mut TaskContext<'_>) -> Poll<Result<(), Error>> {
match self {
Self::Pending(ref mut active) => {
let mut f = active.take().unwrap().finalize();
let res = f.poll(cx);
*self = ReplicatedValidatorState::Finalizing(f);
res
}
Self::Finalizing(f) => f.poll(cx),
}
}
}

#[pin_project]
struct ReplicatedValidator<C, T: Stream, S, V> {
#[pin]
input: Fuse<T>,
state: ReplicatedValidatorState<C>,
_marker: PhantomData<(S, V)>,
}

impl<C: Context + 'static, T: Stream, S, V> ReplicatedValidator<C, T, S, V> {
pub fn new(ctx: C, s: T) -> Self {
Self {
input: s.fuse(),
state: ReplicatedValidatorState::Pending(Some(ReplicatedValidatorActive::new(ctx))),
_marker: PhantomData,
}
}
}

impl<C, T, S, V> Stream for ReplicatedValidator<C, T, S, V>
where
C: Context + 'static,
T: Stream<Item = Result<S, Error>>,
S: ReplicatedSecretSharing<V>,
V: SharedValue,
{
type Item = Result<S, Error>;

fn poll_next(self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<Option<Self::Item>> {
let this = self.project();
match this.input.poll_next(cx) {
Poll::Ready(Some(v)) => match v {
Ok(v) => {
this.state.update(&v);
Poll::Ready(Some(Ok(v)))
}
Err(e) => Poll::Ready(Some(Err(e))),
},
Poll::Ready(None) => match this.state.poll(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(Ok(())) => Poll::Ready(None),
Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
},
Poll::Pending => Poll::Pending,
}
}

fn size_hint(&self) -> (usize, Option<usize>) {
self.input.size_hint()
}
}
Loading