Skip to content

Commit

Permalink
Use return position impl trait in trait (#77)
Browse files Browse the repository at this point in the history
  • Loading branch information
rklaehn committed May 16, 2024
2 parents f630f07 + 5cc9041 commit 1059370
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 460 deletions.
86 changes: 27 additions & 59 deletions src/transport/combined.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
//! Transport that combines two other transports
use super::{Connection, ConnectionCommon, ConnectionErrors, LocalAddr, ServerEndpoint};
use crate::{RpcMessage, Service};
use futures_lite::{future::Boxed as BoxFuture, Stream};
use futures_lite::Stream;
use futures_sink::Sink;
use futures_util::{FutureExt, TryFutureExt};
use pin_project::pin_project;
use std::{
error, fmt,
fmt::Debug,
marker::PhantomData,
pin::Pin,
result,
task::{Context, Poll},
};

Expand Down Expand Up @@ -273,19 +271,6 @@ impl<A: ConnectionErrors, B: ConnectionErrors> fmt::Display for AcceptBiError<A,

impl<A: ConnectionErrors, B: ConnectionErrors> error::Error for AcceptBiError<A, B> {}

/// Future returned by open_bi
pub type OpenBiFuture<A, B, In, Out> =
BoxFuture<result::Result<Socket<A, B, In, Out>, self::OpenBiError<A, B>>>;

/// Future returned by accept_bi
pub type AcceptBiFuture<A, B, In, Out> =
BoxFuture<result::Result<self::Socket<A, B, In, Out>, self::AcceptBiError<A, B>>>;

type Socket<A, B, In, Out> = (
self::SendSink<A, B, In, Out>,
self::RecvStream<A, B, In, Out>,
);

impl<A: ConnectionErrors, B: ConnectionErrors, S: Service> ConnectionErrors
for CombinedConnection<A, B, S>
{
Expand All @@ -304,24 +289,19 @@ impl<A: Connection<S::Res, S::Req>, B: Connection<S::Res, S::Req>, S: Service>
impl<A: Connection<S::Res, S::Req>, B: Connection<S::Res, S::Req>, S: Service>
Connection<S::Res, S::Req> for CombinedConnection<A, B, S>
{
fn open_bi(&self) -> OpenBiFuture<A, B, S::Res, S::Req> {
async fn open_bi(&self) -> Result<(Self::SendSink, Self::RecvStream), Self::OpenError> {
let this = self.clone();
async {
// try a first, then b
if let Some(a) = this.a {
let (send, recv) = a.open_bi().await.map_err(OpenBiError::A)?;
Ok((SendSink::A(send), RecvStream::A(recv)))
} else if let Some(b) = this.b {
let (send, recv) = b.open_bi().await.map_err(OpenBiError::B)?;
Ok((SendSink::B(send), RecvStream::B(recv)))
} else {
std::future::ready(Err(OpenBiError::NoChannel)).await
}
// try a first, then b
if let Some(a) = this.a {
let (send, recv) = a.open_bi().await.map_err(OpenBiError::A)?;
Ok((SendSink::A(send), RecvStream::A(recv)))
} else if let Some(b) = this.b {
let (send, recv) = b.open_bi().await.map_err(OpenBiError::B)?;
Ok((SendSink::B(send), RecvStream::B(recv)))
} else {
Err(OpenBiError::NoChannel)
}
.boxed()
}

type OpenBiFut = OpenBiFuture<A, B, S::Res, S::Req>;
}

impl<A: ConnectionErrors, B: ConnectionErrors, S: Service> ConnectionErrors
Expand All @@ -342,44 +322,32 @@ impl<A: ServerEndpoint<S::Req, S::Res>, B: ServerEndpoint<S::Req, S::Res>, S: Se
impl<A: ServerEndpoint<S::Req, S::Res>, B: ServerEndpoint<S::Req, S::Res>, S: Service>
ServerEndpoint<S::Req, S::Res> for CombinedServerEndpoint<A, B, S>
{
fn accept_bi(&self) -> AcceptBiFuture<A, B, S::Req, S::Res> {
let a_fut = if let Some(a) = &self.a {
a.accept_bi()
.map_ok(|(send, recv)| {
(
SendSink::<A, B, S::Req, S::Res>::A(send),
RecvStream::<A, B, S::Req, S::Res>::A(recv),
)
})
.map_err(AcceptBiError::A)
.left_future()
} else {
std::future::pending().right_future()
async fn accept_bi(&self) -> Result<(Self::SendSink, Self::RecvStream), Self::OpenError> {
let a_fut = async {
if let Some(a) = &self.a {
let (send, recv) = a.accept_bi().await.map_err(AcceptBiError::A)?;
Ok((SendSink::A(send), RecvStream::A(recv)))
} else {
std::future::pending().await
}
};
let b_fut = if let Some(b) = &self.b {
b.accept_bi()
.map_ok(|(send, recv)| {
(
SendSink::<A, B, S::Req, S::Res>::B(send),
RecvStream::<A, B, S::Req, S::Res>::B(recv),
)
})
.map_err(AcceptBiError::B)
.left_future()
} else {
std::future::pending().right_future()
let b_fut = async {
if let Some(b) = &self.b {
let (send, recv) = b.accept_bi().await.map_err(AcceptBiError::B)?;
Ok((SendSink::B(send), RecvStream::B(recv)))
} else {
std::future::pending().await
}
};
async move {
tokio::select! {
res = a_fut => res,
res = b_fut => res,
}
}
.boxed()
.await
}

type AcceptBiFut = AcceptBiFuture<A, B, S::Req, S::Res>;

fn local_addr(&self) -> &[LocalAddr] {
&self.local_addr
}
Expand Down
92 changes: 13 additions & 79 deletions src/transport/flume.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
//! Memory transport implementation using [flume]
//!
//! [flume]: https://docs.rs/flume/
use futures_lite::{Future, Stream};
use futures_lite::Stream;
use futures_sink::Sink;

use crate::{
transport::{Connection, ConnectionErrors, LocalAddr, ServerEndpoint},
RpcMessage, Service,
};
use core::fmt;
use std::{error, fmt::Display, marker::PhantomData, pin::Pin, result, task::Poll};
use std::{error, fmt::Display, pin::Pin, result, task::Poll};

use super::ConnectionCommon;

Expand Down Expand Up @@ -129,85 +129,17 @@ impl<S: Service> ConnectionErrors for FlumeServerEndpoint<S> {
type OpenError = self::AcceptBiError;
}

type Socket<In, Out> = (self::SendSink<Out>, self::RecvStream<In>);

/// Future returned by [FlumeConnection::open_bi]
pub struct OpenBiFuture<In: RpcMessage, Out: RpcMessage> {
inner: flume::r#async::SendFut<'static, Socket<Out, In>>,
res: Option<Socket<In, Out>>,
}

impl<In: RpcMessage, Out: RpcMessage> fmt::Debug for OpenBiFuture<In, Out> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("OpenBiFuture").finish()
}
}

impl<In: RpcMessage, Out: RpcMessage> OpenBiFuture<In, Out> {
fn new(inner: flume::r#async::SendFut<'static, Socket<Out, In>>, res: Socket<In, Out>) -> Self {
Self {
inner,
res: Some(res),
}
}
}

impl<In: RpcMessage, Out: RpcMessage> Future for OpenBiFuture<In, Out> {
type Output = result::Result<Socket<In, Out>, self::OpenBiError>;

fn poll(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Self::Output> {
match Pin::new(&mut self.inner).poll(cx) {
Poll::Ready(Ok(())) => self
.res
.take()
.map(|x| Poll::Ready(Ok(x)))
.unwrap_or(Poll::Pending),
Poll::Ready(Err(_)) => Poll::Ready(Err(self::OpenBiError::RemoteDropped)),
Poll::Pending => Poll::Pending,
}
}
}

/// Future returned by [FlumeServerEndpoint::accept_bi]
pub struct AcceptBiFuture<In: RpcMessage, Out: RpcMessage> {
wrapped: flume::r#async::RecvFut<'static, (SendSink<Out>, RecvStream<In>)>,
_p: PhantomData<(In, Out)>,
}

impl<In: RpcMessage, Out: RpcMessage> fmt::Debug for AcceptBiFuture<In, Out> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("AcceptBiFuture").finish()
}
}

impl<In: RpcMessage, Out: RpcMessage> Future for AcceptBiFuture<In, Out> {
type Output = result::Result<(SendSink<Out>, RecvStream<In>), AcceptBiError>;

fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
match Pin::new(&mut self.wrapped).poll(cx) {
Poll::Ready(Ok((send, recv))) => Poll::Ready(Ok((send, recv))),
Poll::Ready(Err(_)) => Poll::Ready(Err(AcceptBiError::RemoteDropped)),
Poll::Pending => Poll::Pending,
}
}
}

impl<S: Service> ConnectionCommon<S::Req, S::Res> for FlumeServerEndpoint<S> {
type SendSink = SendSink<S::Res>;
type RecvStream = RecvStream<S::Req>;
}

impl<S: Service> ServerEndpoint<S::Req, S::Res> for FlumeServerEndpoint<S> {
type AcceptBiFut = AcceptBiFuture<S::Req, S::Res>;

fn accept_bi(&self) -> Self::AcceptBiFut {
AcceptBiFuture {
wrapped: self.stream.clone().into_recv_async(),
_p: PhantomData,
}
async fn accept_bi(&self) -> Result<(Self::SendSink, Self::RecvStream), AcceptBiError> {
self.stream
.recv_async()
.await
.map_err(|_| AcceptBiError::RemoteDropped)
}

fn local_addr(&self) -> &[LocalAddr] {
Expand All @@ -229,9 +161,7 @@ impl<S: Service> ConnectionCommon<S::Res, S::Req> for FlumeConnection<S> {
}

impl<S: Service> Connection<S::Res, S::Req> for FlumeConnection<S> {
type OpenBiFut = OpenBiFuture<S::Res, S::Req>;

fn open_bi(&self) -> Self::OpenBiFut {
async fn open_bi(&self) -> Result<(Self::SendSink, Self::RecvStream), Self::OpenError> {
let (local_send, remote_recv) = flume::bounded::<S::Req>(128);
let (remote_send, local_recv) = flume::bounded::<S::Res>(128);
let remote_chan = (
Expand All @@ -242,7 +172,11 @@ impl<S: Service> Connection<S::Res, S::Req> for FlumeConnection<S> {
SendSink(local_send.into_sink()),
RecvStream(local_recv.into_stream()),
);
OpenBiFuture::new(self.sink.clone().into_send_async(remote_chan), local_chan)
self.sink
.send_async(remote_chan)
.await
.map_err(|_| OpenBiError::RemoteDropped)?;
Ok(local_chan)
}
}

Expand Down
Loading

0 comments on commit 1059370

Please sign in to comment.