Skip to content

Commit

Permalink
refactored send task loop, reintegrated packet processing into new ar…
Browse files Browse the repository at this point in the history
…chitecture
  • Loading branch information
ilumary committed Jun 2, 2024
1 parent c84ca3e commit a392b33
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 105 deletions.
209 changes: 107 additions & 102 deletions project/src/quic/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use rustls::{
use std::{collections::HashMap, fmt, net::SocketAddr, sync::Arc};
use tokio::{
net::UdpSocket as TokioUdpSocket,
sync::{mpsc, RwLock},
sync::{mpsc, oneshot, RwLock},
task::JoinHandle,
};

Expand All @@ -36,9 +36,6 @@ pub struct Distributor {
//RFC 2104, used to generate reset tokens from connection ids
hmac_reset_key: ring::hmac::Key,

//shared channel for sending packets
sending_handle: mpsc::Sender<packet::Datagram>,

//channels for each connection
connection_send_handles: HashMap<ConnectionId, mpsc::Sender<packet::EarlyDatagram>>,

Expand All @@ -48,12 +45,9 @@ pub struct Distributor {

impl Distributor {
fn new(key: ring::hmac::Key, server_cfg: Option<Arc<rustls::ServerConfig>>) -> Self {
//placeholder for default initialization with 1 capacity
let (tx, _) = mpsc::channel::<packet::Datagram>(1);
Self {
server_config: server_cfg,
hmac_reset_key: key,
sending_handle: tx,
connection_send_handles: HashMap::new(),
cancellation_token: tokio_util::sync::CancellationToken::new(),
}
Expand All @@ -62,26 +56,30 @@ impl Distributor {

pub struct Acceptor {
//queue for initial packets
rx: mpsc::Receiver<(Vec<u8>, String, Header, TSDistributor)>,
rx: mpsc::Receiver<packet::InitialDatagram>,
}

impl Acceptor {
async fn accept(&mut self) -> Option<Connection> {
let (packet, remote, header, tsd) = self.rx.recv().await.unwrap();

//move into static connection function
println!("accepting connection from {}", &remote);
let (packet, remote, header, tsd, socket) = self.rx.recv().await.unwrap();

//TODO move cancellation token in connection, listen in connection for cancel without channel
let new_conn = Connection::early_connection((packet, remote, header), tsd)
.await
.unwrap();
let (connection, ready) =
Connection::early_connection((packet, remote, header), tsd, socket)
.await
.unwrap();

//while(conn.poll_state() != established) {
// match state
// new_conn.handle_packet() //recv().await
//}
match ready.await {
Ok(terror::QuicTransportError::NoError) => return Some(connection),
Ok(quic_error) => {
eprintln!("Connection could not be established: {}", quic_error);
}
Err(error) => {
eprintln!("Error retrieving connection state: {}", error);
}
}

// force terminate connection recv task
connection.abort();
None
}
}
Expand Down Expand Up @@ -164,11 +162,10 @@ impl ServerConfig {
}

pub struct Endpoint {
socket: Arc<TokioUdpSocket>,
address: String,

//task handles for recv and send loops
recv_loop_handle: Option<JoinHandle<Result<u64, terror::Error>>>,
send_loop_handle: Option<JoinHandle<Result<u64, terror::Error>>>,

//stores connection channel sender handles
distributor: TSDistributor,
Expand All @@ -185,21 +182,22 @@ impl Endpoint {
)));

Endpoint {
socket: Arc::new(
TokioUdpSocket::bind(addr)
.await
.expect("fatal error: socket bind failed"),
),
address: addr.to_string(),
recv_loop_handle: None,
send_loop_handle: None,
distributor,
}
}

pub async fn start_acceptor(&mut self) -> Acceptor {
//socket lives via Arc in the recv loop and in each connection
let socket = Arc::new(
TokioUdpSocket::bind(self.address.clone())
.await
.expect("fatal error: socket bind failed"),
);

//TODO add to configuration of (server) endpoint maximum number of buffered inital packets
let (tx_initial, rx_initial) = mpsc::channel::<packet::InitialDatagram>(64);
let recv_socket = self.socket.clone();
let dist = self.distributor.clone();
let cancellation_token = { dist.read().await.cancellation_token.clone() };

Expand All @@ -209,7 +207,7 @@ impl Endpoint {
.take(u16::MAX as usize)
.collect::<Vec<_>>();

let (size, src_addr) = match recv_socket.recv_from(&mut buffer).await {
let (size, src_addr) = match socket.recv_from(&mut buffer).await {
Ok((size, src_addr)) => (size, src_addr),
Err(error) => {
println!("Error while receiving datagram: {:?}", error);
Expand All @@ -229,7 +227,13 @@ impl Endpoint {
//stop accepting new connections when entering graceful shutdown
if partial_decode.is_inital() && !cancellation_token.is_cancelled() {
tx_initial
.send((buffer, src_addr.to_string(), partial_decode, dist.clone()))
.send((
buffer,
src_addr.to_string(),
partial_decode,
dist.clone(),
socket.clone(),
))
.await
.unwrap();
} else {
Expand Down Expand Up @@ -258,45 +262,16 @@ impl Endpoint {
}
}));

let send_socket = self.socket.clone();
let (tx, mut rx) = mpsc::channel::<packet::Datagram>(64);
let dist = self.distributor.clone();
let cancellation_token = { dist.read().await.cancellation_token.clone() };

{
//set sending handle
self.distributor.write().await.sending_handle = tx;
}

self.send_loop_handle = Some(tokio::spawn(async move {
while let Some(msg) = rx.recv().await {
let size = match send_socket.send_to(&msg.0, &msg.1).await {
Ok(size) => size,
Err(error) => {
return Err(terror::Error::socket_error(format!("{}", error)));
}
};

println!("sent {} bytes to {}", size, &msg.1);

if cancellation_token.is_cancelled() {
{
//return if all connections have been succesfully shut down
if dist.read().await.connection_send_handles.is_empty() {
return Ok(0);
}
}
}
}
Ok(0)
}));

Acceptor { rx: rx_initial }
}
}

pub struct Connection {
inner: Inner,

//send socket
socket: Arc<TokioUdpSocket>,
//stream acceptor handles
//bidi_stream_r: mpsc::Receiver<stream::data>,
//uni_stream_r: mpsc::Receiver<stream::data>,
}
Expand All @@ -306,7 +281,8 @@ impl Connection {
async fn early_connection(
inital_datagram: packet::EarlyDatagram,
tsd: TSDistributor,
) -> Result<Self, terror::Error> {
socket: Arc<TokioUdpSocket>,
) -> Result<(Self, oneshot::Receiver<terror::QuicTransportError>), terror::Error> {
let (mut buffer, src_addr, mut head) = inital_datagram;
let (hmac_reset_key, server_config) = {
let t = tsd.read().await;
Expand All @@ -321,8 +297,7 @@ impl Connection {
Version::V1,
Side::Server,
&head.dcid,
)
.await?;
);

let header_length = match head.decrypt(&mut buffer, ikp.remote.header.as_ref()) {
Ok(s) => s,
Expand Down Expand Up @@ -398,6 +373,8 @@ impl Connection {
.insert(initial_local_scid.clone(), transmit_q);
}

let (conn_read_tx, conn_ready_rx) = oneshot::channel::<terror::QuicTransportError>();

let inner = Inner::new(
Side::Server,
head.version,
Expand All @@ -407,48 +384,41 @@ impl Connection {
initial_local_scid,
src_addr.parse().unwrap(),
initial_space,
conn_read_tx,
);

let conn = Self { inner };
let mut conn = Self { inner, socket };

//process inital packet inside connection, all subsequent packets are sent through channel
//self.inner.accept(payload, head).await;
conn.inner.accept(&head, &mut buffer)?;

//send answer to initial packet
//conn.send((buffer, conn.remote))

Ok(conn.start(recv_q).await.unwrap())
//start recv loop to process more incoming packets
Ok((conn.start(recv_q).await.unwrap(), conn_ready_rx))
}

async fn derive_initial_keyset(
fn derive_initial_keyset(
server_cfg: Arc<rustls::ServerConfig>,
version: Version,
side: Side,
dcid: &ConnectionId,
) -> Result<Keys, terror::Error> {
println!(
"{:?} available cipher suites: {:?}",
&server_cfg.crypto_provider().cipher_suites.len(),
&server_cfg.crypto_provider().cipher_suites,
);

let cipher_suites = &server_cfg.crypto_provider().cipher_suites;

if cipher_suites.is_empty() {
return Err(terror::Error::no_cipher_suite("no supported cipher suites"));
}

//use first availible tls 1.3 crypto suite for now
if let Some(tls13_cipher_suite) = cipher_suites[0].tls13() {
return Ok(Keys::initial(
version,
tls13_cipher_suite,
tls13_cipher_suite.quic.unwrap(),
&dcid.id,
side,
));
}

Err(terror::Error::no_cipher_suite(
"no available tls 1.3 cipher suite",
))
) -> Keys {
/* for now only the rustls ring provider is supported, so we may omit numerous checks */
server_cfg
.crypto_provider()
.cipher_suites
.iter()
.find_map(|cs| match (cs.suite(), cs.tls13()) {
(rustls::CipherSuite::TLS13_AES_128_GCM_SHA256, Some(suite)) => {
Some(suite.quic_suite())
}
_ => None,
})
.flatten()
.expect("default crypto provider failed to provide initial cipher suite")
.keys(&dcid.id, side, version)
}

//starts recv_q, feeds into all other queues
Expand All @@ -467,6 +437,23 @@ impl Connection {
Ok(self)
}

async fn send(&self, dgram: packet::Datagram) -> Result<(), terror::Error> {
let size = match self.socket.send_to(&dgram.0, &dgram.1).await {
Ok(size) => size,
Err(error) => {
return Err(terror::Error::socket_error(format!("{}", error)));
}
};

println!("sent {} bytes to {}", size, &dgram.1);

Ok(())
}

fn abort(self) {
self.inner.loop_handle.unwrap().abort();
}

//pub async fn accept_bidi_stream() {}

//pub async fn accept_uni_stream() {}
Expand All @@ -477,8 +464,12 @@ impl Connection {
}

struct Inner {
//packet recv loop for connection
loop_handle: Option<JoinHandle<Result<u64, terror::Error>>>,

//oneshot channel
conn_ready: oneshot::Sender<terror::QuicTransportError>,

//side
side: Side,

Expand Down Expand Up @@ -530,9 +521,11 @@ impl Inner {
initial_local_scid: ConnectionId,
remote_address: SocketAddr,
initial_space: PacketNumberSpace,
conn_ready: oneshot::Sender<terror::QuicTransportError>,
) -> Self {
Self {
loop_handle: None,
conn_ready,
side,
version,
tls_session,
Expand Down Expand Up @@ -572,8 +565,17 @@ impl Inner {
}

//accepts new connection, passes payload to process_payload
fn accept(&mut self, header: &Header, payload_raw: &mut [u8]) -> Result<(), terror::Error> {
let mut payload = octets::OctetsMut::with_slice(payload_raw);
fn accept(&mut self, header: &Header, packet_raw: &mut [u8]) -> Result<(), terror::Error> {
let mut payload = octets::OctetsMut::with_slice(packet_raw);

//skip forth to packet payload
if let Err(error) = payload.skip(header.raw_length + header.packet_num_length as usize + 1)
{
return Err(terror::Error::packet_size_error(format!(
"header is longer than packet {}",
error
)));
};

self.process_payload(header, &mut payload);

Expand Down Expand Up @@ -637,10 +639,9 @@ impl Inner {
}

println!(
"stream_id {:#x} len {:#x} data {:#x?}",
"stream_id {:#x} len {:#x}",
stream_frame.stream_id,
stream_data.len(),
stream_data
);
} //STREAM
0x10 => {
Expand Down Expand Up @@ -745,6 +746,7 @@ impl Inner {
KeyChange::Handshake { keys } => keys,
KeyChange::OneRtt { keys, next } => {
self.next_secrets = Some(next);
//connection may be established here, investigate
keys
}
};
Expand Down Expand Up @@ -828,6 +830,9 @@ impl Inner {
}
};

//THIS IS THE WAY
//buf.put_varint_with_len(packet_length_min, 3);

let header_length = buf.off();

//calculate packet number offset
Expand Down
Loading

0 comments on commit a392b33

Please sign in to comment.