diff --git a/project/src/quic/lib.rs b/project/src/quic/lib.rs index 0cfbe4b..2399961 100644 --- a/project/src/quic/lib.rs +++ b/project/src/quic/lib.rs @@ -10,7 +10,7 @@ use packet::{ }; use rand::RngCore; use rustls::{ - pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer}, + pki_types::{CertificateDer, PrivateKeyDer}, quic::{Connection as RustlsConnection, KeyChange, Keys, PacketKeySet, Version}, Side, }; @@ -349,29 +349,20 @@ impl Connection { let initial_local_scid = ConnectionId::generate_with_length(8); let orig_dcid = head.dcid.clone(); - let mut transport_config = transport_parameters::TransportConfig::default(); - transport_config - .original_destination_connection_id(orig_dcid.id()) - .initial_source_connection_id(initial_local_scid.id()) - .stateless_reset_token( - token::StatelessResetToken::new(&hmac_reset_key, &initial_local_scid) - .token - .to_vec(), - ); - - //Allocate byte buffer and encode transport config to create rustls connection - let mut buf = [0u8; 1024]; - let mut param_buffer = OctetsMut::with_slice(&mut buf); - transport_config.encode(&mut param_buffer).unwrap(); - let (data, _) = param_buffer.split_at(param_buffer.off()).unwrap(); + let mut tpc = transport_parameters::TransportConfig::default(); + tpc.original_destination_connection_id = + transport_parameters::OriginalDestinationConnectionId::try_from(orig_dcid.clone())?; + tpc.initial_source_connection_id = + transport_parameters::InitialSourceConnectionId::try_from(initial_local_scid.clone())?; + tpc.stateless_reset_token = transport_parameters::StatelessResetTokenTP::try_from( + token::StatelessResetToken::new(&hmac_reset_key, &initial_local_scid), + )?; + + let data = tpc.encode(Side::Server)?; let conn = RustlsConnection::Server( - rustls::quic::ServerConnection::new( - server_config, - rustls::quic::Version::V1, - data.to_vec(), - ) - .unwrap(), + rustls::quic::ServerConnection::new(server_config, rustls::quic::Version::V1, data) + .unwrap(), ); let initial_space: PacketNumberSpace = PacketNumberSpace { @@ -602,10 +593,11 @@ impl Inner { self.process_payload(header, packet_raw)?; if let Some(tpc) = self.tls_session.quic_transport_parameters() { - self.remote_tpc.update(tpc); + self.remote_tpc.update(tpc).unwrap(); } - if self.remote_tpc.get_original_scid() != self.initial_remote_scid { + if *self.remote_tpc.initial_source_connection_id.get().unwrap() != self.initial_remote_scid + { return Err(terror::Error::quic_protocol_violation( "scids from packet header and transport parameters differ", )); @@ -923,7 +915,7 @@ impl Inner { .sort_by(|a, b| b.cmp(a)); //TODO figure out delay - let ack_delay = 64 * (2 ^ self.remote_tpc.ack_delay_exponent.as_varint()); + let ack_delay = 64 * (2 ^ self.remote_tpc.ack_delay_exponent.get().unwrap().get()); //directly generate ack frame from packet number vector let ack_frame = AckFrame::from_packet_number_vec( @@ -993,7 +985,7 @@ impl PacketNumberSpace { } #[derive(Eq, Hash, PartialEq, Clone)] -struct ConnectionId { +pub struct ConnectionId { id: Vec, } diff --git a/project/src/quic/packet.rs b/project/src/quic/packet.rs index 3fd454c..0393945 100644 --- a/project/src/quic/packet.rs +++ b/project/src/quic/packet.rs @@ -1023,6 +1023,7 @@ pub fn varint_length(num: u64) -> usize { 0..=63 => 1, 64..=16383 => 2, 16384..=1073741823 => 3, + 1073741824..=4611686018427387903 => 4, _ => unreachable!("number exceeded abnormally large size"), } } diff --git a/project/src/quic/terror.rs b/project/src/quic/terror.rs index 816aecc..2fce354 100644 --- a/project/src/quic/terror.rs +++ b/project/src/quic/terror.rs @@ -32,6 +32,16 @@ impl Error { taurus_error!(crypto_error, 0x07); taurus_error!(quic_protocol_violation, 0x0a); taurus_error!(taurus_misc_error, 0xff); + + pub fn quic_transport_error(reason: T, code: QuicTransportError) -> Self + where + T: Into, + { + Self { + code: code as u64, + msg: reason.into(), + } + } } impl fmt::Display for Error { @@ -66,7 +76,6 @@ pub enum QuicTransportError { KeyUpdateError = 0x0e, AeadLimitReached = 0x0f, NoViablePath = 0x10, - CryptoError(CryptoError), } impl fmt::Display for QuicTransportError { @@ -93,7 +102,6 @@ impl fmt::Display for QuicTransportError { QuicTransportError::KeyUpdateError => write!(f, "0x0e key update error"), QuicTransportError::AeadLimitReached => write!(f, "0x0f aead limit reached"), QuicTransportError::NoViablePath => write!(f, "0x10 no viable path"), - QuicTransportError::CryptoError(c) => write!(f, "{} crypto error", c), } } } diff --git a/project/src/quic/token.rs b/project/src/quic/token.rs index ec731ae..e3d2080 100644 --- a/project/src/quic/token.rs +++ b/project/src/quic/token.rs @@ -1,5 +1,6 @@ use crate::ConnectionId; +#[derive(PartialEq, Default)] pub struct StatelessResetToken { pub token: [u8; 0x10], } diff --git a/project/src/quic/transport_parameters.rs b/project/src/quic/transport_parameters.rs index 6011092..02fc0d5 100644 --- a/project/src/quic/transport_parameters.rs +++ b/project/src/quic/transport_parameters.rs @@ -1,274 +1,671 @@ -use octets::{BufferTooShortError, Octets, OctetsMut}; -use std::{ - fmt, - net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6}, -}; +use core::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6}; -use crate::token; -use crate::ConnectionId; +use octets::{Octets, OctetsMut}; -//RFC 9000 Section 18.2 -//TODO expand to RFC 9287 & draft-ietf-quic-ack-frequency -pub struct TransportConfig { - pub original_destination_connection_id: Option, - pub max_idle_timeout: TransportParameter, - pub stateless_reset_token: Option, - pub max_udp_payload_size: TransportParameter, - pub initial_max_data: TransportParameter, - pub initial_max_stream_data_bidi_local: TransportParameter, - pub initial_max_stream_data_bidi_remote: TransportParameter, - pub initial_max_stream_data_uni: TransportParameter, - pub initial_max_streams_bidi: TransportParameter, - pub initial_max_streams_uni: TransportParameter, - pub ack_delay_exponent: TransportParameter, - pub max_ack_delay: TransportParameter, - pub disable_active_migration: Option, - pub preferred_address: Option, - pub active_connection_id_limit: TransportParameter, - pub initial_source_connection_id: Option, - pub retry_source_connection_id: Option, +use crate::{terror, token::StatelessResetToken, ConnectionId, MAX_CID_SIZE}; + +trait IOHandler { + fn encode(value: &T, buf: &mut OctetsMut) -> Result<(), octets::BufferTooShortError>; + fn decode(buf: &mut Octets) -> Result; } -impl TransportConfig { - pub fn default() -> Self { - Self { - original_destination_connection_id: None, - max_idle_timeout: TransportParameter::new(0x0001, 0, vec![]), - stateless_reset_token: None, - max_udp_payload_size: TransportParameter::new(0x0003, 2, vec![255, 247]), - initial_max_data: TransportParameter::new(0x0004, 0, vec![]), - initial_max_stream_data_bidi_local: TransportParameter::new(0x0005, 0, vec![]), - initial_max_stream_data_bidi_remote: TransportParameter::new(0x0006, 0, vec![]), - initial_max_stream_data_uni: TransportParameter::new(0x0007, 0, vec![]), - initial_max_streams_bidi: TransportParameter::new(0x0008, 0, vec![]), - initial_max_streams_uni: TransportParameter::new(0x0009, 0, vec![]), - ack_delay_exponent: TransportParameter::new(0x000a, 1, vec![3]), - max_ack_delay: TransportParameter::new(0x000b, 1, vec![25]), - disable_active_migration: Some(TransportParameter::new(0x000c, 0, vec![])), - preferred_address: None, - active_connection_id_limit: TransportParameter::new(0x000e, 1, vec![2]), - initial_source_connection_id: None, - retry_source_connection_id: None, - } +impl IOHandler for ConnectionId { + fn encode( + value: &ConnectionId, + buf: &mut OctetsMut, + ) -> Result<(), octets::BufferTooShortError> { + buf.put_varint(value.len() as u64)?; + buf.put_bytes(&value.id)?; + Ok(()) } - pub fn update(&mut self, buffer: &[u8]) { - let mut buf = octets::Octets::with_slice(buffer); - while let Ok(tp) = TransportParameter::decode(&mut buf) { - match tp.id { - 0x0001 => self.max_idle_timeout = tp, - 0x0002 => self.stateless_reset_token = Some(tp), - 0x0003 => self.max_udp_payload_size = tp, - 0x0004 => self.initial_max_data = tp, - 0x0005 => self.initial_max_stream_data_bidi_local = tp, - 0x0006 => self.initial_max_stream_data_bidi_remote = tp, - 0x0007 => self.initial_max_stream_data_uni = tp, - 0x0008 => self.initial_max_streams_bidi = tp, - 0x0009 => self.initial_max_streams_uni = tp, - 0x000a => self.ack_delay_exponent = tp, - 0x000b => self.max_ack_delay = tp, - 0x000c => self.disable_active_migration = Some(tp), - 0x000e => self.active_connection_id_limit = tp, - 0x000f => self.initial_source_connection_id = Some(tp), - _ => println!( - "unknown transport parameter with id {:x?} and value: {:x?}", - tp.id, tp.value - ), - } - } + fn decode(buf: &mut Octets) -> Result { + let length = buf.get_varint()?; + Ok(ConnectionId::from_vec( + buf.get_bytes(length.try_into().unwrap())?.to_vec(), + )) } +} + +#[derive(PartialEq, Default)] +pub struct VarInt { + value: u64, +} - pub fn get_original_scid(&self) -> ConnectionId { - self.initial_source_connection_id.as_ref().unwrap().as_cid() +impl VarInt { + pub fn get(&self) -> u64 { + self.value } +} + +impl From for VarInt { + fn from(x: u64) -> VarInt { + VarInt { value: x } + } +} - pub fn original_destination_connection_id(&mut self, orig_dcid: &Vec) -> &mut Self { - self.original_destination_connection_id = Some(TransportParameter::new( - 0x0000, - orig_dcid.len() as u64, - orig_dcid.clone(), - )); - self +impl IOHandler for VarInt { + fn encode(value: &VarInt, buf: &mut OctetsMut) -> Result<(), octets::BufferTooShortError> { + let length = crate::packet::varint_length(value.value) as u64; + println!( + "encoding {:x?} with length field {:x?}", + value.value, length + ); + buf.put_varint(length)?; + buf.put_varint(value.value)?; + Ok(()) } - pub fn stateless_reset_token(&mut self, token: Vec) -> &mut Self { - self.stateless_reset_token = Some(TransportParameter::new(0x0002, 16, token)); - self + fn decode(buf: &mut Octets) -> Result { + let _ = buf.get_varint()?; + Ok(Self { + value: buf.get_varint()?, + }) } +} - pub fn preferred_address(&mut self, address: PreferredAddress) -> &mut Self { - self.preferred_address = Some(address); - self +impl IOHandler for StatelessResetToken { + fn encode( + token: &StatelessResetToken, + buf: &mut OctetsMut, + ) -> Result<(), octets::BufferTooShortError> { + buf.put_bytes(&token.token)?; + Ok(()) } - pub fn initial_source_connection_id(&mut self, initial_scid: &Vec) -> &mut Self { - self.initial_source_connection_id = Some(TransportParameter::new( - 0x0000, - initial_scid.len() as u64, - initial_scid.clone(), - )); - self + fn decode(buf: &mut Octets) -> Result { + let length = buf.get_varint()?; + Ok(StatelessResetToken::from( + buf.get_bytes(length.try_into().unwrap())?.to_vec(), + )) } +} + +#[derive(PartialEq, Default)] +pub struct PreferredAddressData { + pub address_v4: Option, + pub address_v6: Option, + pub conn_id: ConnectionId, + pub stateless_reset_token: StatelessResetToken, +} + +impl PreferredAddressData { + pub fn len(&self) -> usize { + let mut length: usize = 0; + if self.address_v4.is_some() { + length += 4 + 2; + } + if self.address_v6.is_some() { + length += 16 + 2; + } + length += 1 + self.conn_id.len(); + length += 16; + length + } +} + +impl IOHandler for PreferredAddressData { + fn decode(_buf: &mut Octets) -> Result { + todo!() + } + + fn encode( + pa: &PreferredAddressData, + buf: &mut OctetsMut, + ) -> Result<(), octets::BufferTooShortError> { + buf.put_varint(0x000d)?; + buf.put_varint(pa.len() as u64)?; + + buf.put_bytes( + pa.address_v4 + .map_or(Ipv4Addr::UNSPECIFIED.octets(), |a| a.ip().octets()) + .as_ref(), + )?; + buf.put_u16(pa.address_v4.map_or(0, |a| a.port()))?; - pub fn retry_source_connection_id(&mut self, retry_scid: &Vec) -> &mut Self { - self.retry_source_connection_id = Some(TransportParameter::new( - 0x0000, - retry_scid.len() as u64, - retry_scid.clone(), - )); - self + buf.put_bytes( + pa.address_v6 + .map_or(Ipv6Addr::UNSPECIFIED.octets(), |a| a.ip().octets()) + .as_ref(), + )?; + buf.put_u16(pa.address_v6.map_or(0, |a| a.port()))?; + + buf.put_u8(pa.conn_id.len() as u8)?; + buf.put_bytes(pa.conn_id.as_slice())?; + + buf.put_bytes(pa.stateless_reset_token.token.as_ref())?; + + Ok(()) } +} + +trait TransportParameter: Sized { + const ID: usize; + + type ValueType; - //TODO check if client or server, currently only server - pub fn encode(&self, buf: &mut OctetsMut) -> Result<(), BufferTooShortError> { - //Server only - if let Some(tp) = &self.original_destination_connection_id { - tp.encode(buf)?; + fn get_value(&self) -> Option<&Self::ValueType>; + + fn decode(buf: &mut Octets) -> Result; + + fn encode(&self, buf: &mut OctetsMut) -> Result<(), terror::Error>; +} + +macro_rules! transport_parameter { + ($name:ident, $id:expr, $valuetype:ty) => { + transport_parameter!($name, $id, $valuetype, <$valuetype as Default>::default()); + }; + ($name:ident, $id:expr, $valuetype:ty, $default:expr) => { + pub struct $name { + value: $valuetype, } - self.max_idle_timeout.encode(buf)?; + impl $name { + //Expose get method so that trait can be private + pub fn get(&self) -> Option<&$valuetype> { + self.get_value() + } + } - //Server only - if let Some(srt) = &self.stateless_reset_token { - srt.encode(buf)?; + impl Default for $name { + fn default() -> Self { + Self { value: $default } + } } - self.max_udp_payload_size.encode(buf)?; - self.initial_max_data.encode(buf)?; - self.initial_max_stream_data_bidi_local.encode(buf)?; - self.initial_max_stream_data_bidi_remote.encode(buf)?; - self.initial_max_stream_data_uni.encode(buf)?; - self.initial_max_streams_bidi.encode(buf)?; - self.initial_max_streams_uni.encode(buf)?; - self.ack_delay_exponent.encode(buf)?; - self.max_ack_delay.encode(buf)?; - - if let Some(dam) = &self.disable_active_migration { - dam.encode(buf)?; + impl TryFrom<$valuetype> for $name { + type Error = terror::Error; + + fn try_from(value: $valuetype) -> Result { + Self { value }.validate() + } } - //Server only - if let Some(pa) = &self.preferred_address { - pa.encode(buf)?; + impl TransportParameter for $name { + const ID: usize = $id; + + type ValueType = $valuetype; + + fn get_value(&self) -> Option<&Self::ValueType> { + Some(&self.value) + } + + fn decode(buf: &mut Octets) -> Result { + Ok(Self { + value: >::decode(buf) + .map_err(|e| terror::Error::buffer_size_error(format!("{}", e)))?, + }) + } + + fn encode(&self, buf: &mut OctetsMut) -> Result<(), terror::Error> { + if self.value != $default { + buf.put_varint(Self::ID as u64) + .map_err(|e| terror::Error::buffer_size_error(format!("{}", e)))?; + + >::encode(&self.value, buf) + .map_err(|e| terror::Error::buffer_size_error(format!("{}", e)))?; + } + Ok(()) + } + } + }; +} + +macro_rules! zero_sized_transport_parameter { + ($name:ident, $id:expr) => { + pub struct $name { + enabled: bool, } - self.active_connection_id_limit.encode(buf)?; + impl $name { + pub fn get(&self) -> bool { + self.enabled + } + } - if let Some(initial_scid) = &self.initial_source_connection_id { - initial_scid.encode(buf)?; + impl From for $name { + fn from(b: bool) -> Self { + Self { enabled: b } + } } - //Server only - if let Some(retry_scid) = &self.retry_source_connection_id { - retry_scid.encode(buf)?; + impl Default for $name { + fn default() -> Self { + Self { enabled: false } + } } - Ok(()) + impl TransportParameter for $name { + const ID: usize = $id; + + type ValueType = bool; + + fn get_value(&self) -> Option<&Self::ValueType> { + Some(&self.enabled) + } + + fn decode(buf: &mut Octets) -> Result { + buf.skip(1) + .map_err(|e| terror::Error::buffer_size_error(format!("{}", e)))?; + Ok(Self { enabled: true }) + } + + fn encode(&self, buf: &mut OctetsMut) -> Result<(), terror::Error> { + if self.enabled { + buf.put_varint(Self::ID as u64) + .map_err(|e| terror::Error::buffer_size_error(format!("{}", e)))?; + buf.put_u8(0x00) + .map_err(|e| terror::Error::buffer_size_error(format!("{}", e)))?; + } + Ok(()) + } + } + }; +} + +transport_parameter!(OriginalDestinationConnectionId, 0x00, ConnectionId); + +impl OriginalDestinationConnectionId { + fn validate(self) -> Result { + if self.value.len() > MAX_CID_SIZE && self.value.len() > 0 { + return Err(terror::Error::quic_transport_error( + "malformed, badly formatted or absent original destination connection id", + terror::QuicTransportError::TransportParameterError, + )); + } + Ok(self) + } +} + +transport_parameter!(MaxIdleTimeout, 0x01, VarInt); + +impl MaxIdleTimeout { + fn validate(self) -> Result { + Ok(self) } } -//RFC 9000 Section 18: Transport Parameter Encoding -//Id and length fields are defined as variable length integers, we store them in human -//readable form and only encode/decode if needed with Octets -pub struct TransportParameter { - id: u64, - length: u64, - value: Vec, +transport_parameter!(StatelessResetTokenTP, 0x02, StatelessResetToken); + +impl StatelessResetTokenTP { + fn validate(self) -> Result { + Ok(self) + } } -impl TransportParameter { - pub fn new(id: u64, length: u64, value: Vec) -> Self { - Self { id, length, value } +transport_parameter!(MaxUdpPayloadSize, 0x03, VarInt, 0xfff7.into()); + +impl MaxUdpPayloadSize { + fn validate(self) -> Result { + if (1200..=65527).contains(&self.value.value) { + return Err(terror::Error::quic_transport_error( + format!("invalid maximum udp payload size of {:?}", self.value.value), + terror::QuicTransportError::TransportParameterError, + )); + } + Ok(self) } +} - pub fn as_varint(&self) -> u64 { - let mut b = Octets::with_slice(&self.value); - b.get_varint().unwrap() +transport_parameter!(InitialMaxData, 0x04, VarInt); + +impl InitialMaxData { + fn validate(self) -> Result { + Ok(self) } +} + +transport_parameter!(InitialMaxStreamDataBidiLocal, 0x05, VarInt); - pub fn as_cid(&self) -> ConnectionId { - ConnectionId::from_vec(self.value.clone()) +impl InitialMaxStreamDataBidiLocal { + fn validate(self) -> Result { + Ok(self) } +} - pub fn decode(data: &mut Octets<'_>) -> Result { - let id = data.get_varint()?; - let length = data.get_varint()?; - let value = data.get_bytes(length as usize).unwrap(); - Ok(Self { - id, - length, - value: value.to_vec(), - }) +transport_parameter!(InitialMaxStreamDataBidiRemote, 0x06, VarInt); + +impl InitialMaxStreamDataBidiRemote { + fn validate(self) -> Result { + Ok(self) } +} - pub fn encode(&self, data: &mut OctetsMut<'_>) -> Result<(), BufferTooShortError> { - data.put_varint(self.id)?; - data.put_varint(self.length)?; - data.put_bytes(&self.value)?; - Ok(()) +transport_parameter!(InitialMaxStreamDataUni, 0x07, VarInt); + +impl InitialMaxStreamDataUni { + fn validate(self) -> Result { + Ok(self) } } -impl fmt::Display for TransportParameter { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "0x{:x} len:0x{:x} data:{}", - self.id, - self.length, - self.value - .iter() - .map(|val| format!("{:x}", val)) - .collect::>() - .join(" ") - ) +transport_parameter!(InitialMaxStreamsBidi, 0x08, VarInt); + +impl InitialMaxStreamsBidi { + fn validate(self) -> Result { + //if a stream id were to be greater than 2^60 it wouldnt be encodable as stream id + if self.value.value > 2u64.pow(60) { + return Err(terror::Error::quic_transport_error( + format!( + "invalid maximum bidirectional streams {:?}", + self.value.value + ), + terror::QuicTransportError::TransportParameterError, + )); + } + Ok(self) } } -pub struct PreferredAddress { - pub address_v4: Option, - pub address_v6: Option, - pub conn_id: ConnectionId, - pub stateless_reset_token: token::StatelessResetToken, +transport_parameter!(InitialMaxStreamsUni, 0x09, VarInt); + +impl InitialMaxStreamsUni { + fn validate(self) -> Result { + //if a stream id were to be greater than 2^60 it wouldnt be encodable as stream id + if self.value.value > 2u64.pow(60) { + return Err(terror::Error::quic_transport_error( + format!( + "invalid maximum unidirectional streams {:?}", + self.value.value + ), + terror::QuicTransportError::TransportParameterError, + )); + } + Ok(self) + } +} + +transport_parameter!(AckDelayExponent, 0x0a, VarInt, 0x3.into()); + +impl AckDelayExponent { + fn validate(self) -> Result { + if self.value.value > 20 { + return Err(terror::Error::quic_transport_error( + "ack delay exponent cant be greater than 20", + terror::QuicTransportError::TransportParameterError, + )); + } + Ok(self) + } +} + +transport_parameter!(MaxAckDelay, 0x0b, VarInt, 0x19.into()); + +impl MaxAckDelay { + fn validate(self) -> Result { + if self.value.value >= 2u64.pow(14) { + return Err(terror::Error::quic_transport_error( + "max ack delay cant be equal or greater than 2^14", + terror::QuicTransportError::TransportParameterError, + )); + } + Ok(self) + } } +zero_sized_transport_parameter!(DisableActiveMigration, 0x0c); +transport_parameter!(PreferredAddress, 0x0d, PreferredAddressData); + impl PreferredAddress { - pub fn len(&self) -> usize { - let mut length: usize = 0; - if self.address_v4.is_some() { - length += 4 + 2; + fn validate(self) -> Result { + if self.value.address_v4.is_some() || self.value.address_v6.is_some() { + return Err(terror::Error::quic_transport_error( + "preferred address parameter must have at least one address set", + terror::QuicTransportError::TransportParameterError, + )); } - if self.address_v6.is_some() { - length += 16 + 2; + Ok(self) + } +} + +transport_parameter!(ActiveConnectionIdLimit, 0x0e, VarInt, 0x02.into()); + +impl ActiveConnectionIdLimit { + fn validate(self) -> Result { + if self.value.value < 2 { + return Err(terror::Error::quic_transport_error( + "active connection id limit must be at least 2", + terror::QuicTransportError::TransportParameterError, + )); } - length += 1 + self.conn_id.len(); - length += 16; - length + Ok(self) } +} - pub fn encode(&self, data: &mut OctetsMut<'_>) -> Result<(), BufferTooShortError> { - data.put_varint(0x000d)?; - data.put_varint(self.len() as u64)?; +transport_parameter!(InitialSourceConnectionId, 0x0f, ConnectionId); - data.put_bytes( - self.address_v4 - .map_or(Ipv4Addr::UNSPECIFIED.octets(), |a| a.ip().octets()) - .as_ref(), - )?; - data.put_u16(self.address_v4.map_or(0, |a| a.port()))?; +impl InitialSourceConnectionId { + fn validate(self) -> Result { + if self.value.len() > MAX_CID_SIZE && self.value.len() > 0 { + return Err(terror::Error::quic_transport_error( + "malformed, badly formatted or absent initial source connection id", + terror::QuicTransportError::TransportParameterError, + )); + } + Ok(self) + } +} - data.put_bytes( - self.address_v6 - .map_or(Ipv6Addr::UNSPECIFIED.octets(), |a| a.ip().octets()) - .as_ref(), - )?; - data.put_u16(self.address_v6.map_or(0, |a| a.port()))?; +transport_parameter!(RetrySourceConnectionId, 0x10, ConnectionId); - data.put_u8(self.conn_id.len() as u8)?; - data.put_bytes(self.conn_id.as_slice())?; +impl RetrySourceConnectionId { + fn validate(self) -> Result { + if self.value.len() > MAX_CID_SIZE && self.value.len() > 0 { + return Err(terror::Error::quic_transport_error( + "malformed, badly formatted or absent retry source connection id", + terror::QuicTransportError::TransportParameterError, + )); + } + Ok(self) + } +} + +//Params outside of RFC 9000 +zero_sized_transport_parameter!(Grease, 0xb6); +transport_parameter!(MaxDatagramFrameSize, 0x20, VarInt, 0x00.into()); + +impl MaxDatagramFrameSize { + fn validate(self) -> Result { + Ok(self) + } +} - data.put_bytes(self.stateless_reset_token.token.as_ref())?; +zero_sized_transport_parameter!(GreaseQuicBit, 0x2ab2); +transport_parameter!(MinAckDelay, 0xff04de1a, VarInt); + +impl MinAckDelay { + fn validate(self) -> Result { + Ok(self) + } +} + +//RFC 9000 Section 18.2 +//TODO expand to RFC 9287 & draft-ietf-quic-ack-frequency +#[derive(Default)] +pub struct TransportConfig { + pub original_destination_connection_id: OriginalDestinationConnectionId, + pub max_idle_timeout: MaxIdleTimeout, + pub stateless_reset_token: StatelessResetTokenTP, + pub max_udp_payload_size: MaxUdpPayloadSize, + pub initial_max_data: InitialMaxData, + pub initial_max_stream_data_bidi_local: InitialMaxStreamDataBidiLocal, + pub initial_max_stream_data_bidi_remote: InitialMaxStreamDataBidiRemote, + pub initial_max_stream_data_uni: InitialMaxStreamDataUni, + pub initial_max_streams_bidi: InitialMaxStreamsBidi, + pub initial_max_streams_uni: InitialMaxStreamsUni, + pub ack_delay_exponent: AckDelayExponent, + pub max_ack_delay: MaxAckDelay, + pub disable_active_migration: DisableActiveMigration, + pub preferred_address: PreferredAddress, + pub active_connection_id_limit: ActiveConnectionIdLimit, + pub initial_source_connection_id: InitialSourceConnectionId, + pub retry_source_connection_id: RetrySourceConnectionId, + + //Params outside of RFC 9000 + pub grease: Grease, + pub max_datagram_frame_size: MaxDatagramFrameSize, + pub grease_quic_bit: GreaseQuicBit, + pub min_ack_delay: MinAckDelay, +} + +impl TransportConfig { + pub fn decode(buf: &[u8]) -> Result { + let mut tpc = Self::default(); + tpc.update(buf)?; + Ok(tpc) + } + + pub fn update(&mut self, buf: &[u8]) -> Result<(), terror::Error> { + let mut b = octets::Octets::with_slice(buf); + while let Ok(id) = b.get_varint() { + match id { + 0x00 => { + self.original_destination_connection_id = + OriginalDestinationConnectionId::decode(&mut b)? + } + 0x0001 => self.max_idle_timeout = MaxIdleTimeout::decode(&mut b)?, + 0x0002 => self.stateless_reset_token = StatelessResetTokenTP::decode(&mut b)?, + 0x0003 => self.max_udp_payload_size = MaxUdpPayloadSize::decode(&mut b)?, + 0x0004 => self.initial_max_data = InitialMaxData::decode(&mut b)?, + 0x0005 => { + self.initial_max_stream_data_bidi_local = + InitialMaxStreamDataBidiLocal::decode(&mut b)? + } + 0x0006 => { + self.initial_max_stream_data_bidi_remote = + InitialMaxStreamDataBidiRemote::decode(&mut b)? + } + 0x0007 => { + self.initial_max_stream_data_uni = InitialMaxStreamDataUni::decode(&mut b)? + } + 0x0008 => self.initial_max_streams_bidi = InitialMaxStreamsBidi::decode(&mut b)?, + 0x0009 => self.initial_max_streams_uni = InitialMaxStreamsUni::decode(&mut b)?, + 0x000a => self.ack_delay_exponent = AckDelayExponent::decode(&mut b)?, + 0x000b => self.max_ack_delay = MaxAckDelay::decode(&mut b)?, + 0x000c => self.disable_active_migration = DisableActiveMigration::decode(&mut b)?, + 0x000e => { + self.active_connection_id_limit = ActiveConnectionIdLimit::decode(&mut b)? + } + 0x000f => { + self.initial_source_connection_id = InitialSourceConnectionId::decode(&mut b)? + } + 0x0010 => { + self.retry_source_connection_id = RetrySourceConnectionId::decode(&mut b)? + } + 0x00b6 => self.grease = Grease::decode(&mut b)?, + 0x0020 => self.max_datagram_frame_size = MaxDatagramFrameSize::decode(&mut b)?, + 0x2ab2 => self.grease_quic_bit = GreaseQuicBit::decode(&mut b)?, + 0xff04de1a => self.min_ack_delay = MinAckDelay::decode(&mut b)?, + _ => { + let data = b + .get_bytes_with_varint_length() + .map_err(|e| terror::Error::buffer_size_error(format!("{}", e)))?; + println!("unknown transport parameter with id {:x?}: {:x?}", id, data); + } + } + } Ok(()) } + + pub fn encode(&self, _side: rustls::Side) -> Result, terror::Error> { + let mut vec = vec![0u8; 1024]; + let written: usize; + + { + let mut buf = OctetsMut::with_slice(&mut vec); + + macro_rules! write_tp { + ($name:ident) => { + self.$name.encode(&mut buf)?; + }; + } + + write_tp!(original_destination_connection_id); + write_tp!(max_idle_timeout); + write_tp!(stateless_reset_token); + write_tp!(max_udp_payload_size); + write_tp!(initial_max_data); + write_tp!(initial_max_stream_data_bidi_local); + write_tp!(initial_max_stream_data_bidi_remote); + write_tp!(initial_max_stream_data_uni); + write_tp!(initial_max_streams_bidi); + write_tp!(initial_max_streams_uni); + write_tp!(ack_delay_exponent); + write_tp!(max_ack_delay); + write_tp!(disable_active_migration); + write_tp!(preferred_address); + write_tp!(active_connection_id_limit); + write_tp!(initial_source_connection_id); + write_tp!(retry_source_connection_id); + + //other transport params only after GREASE + write_tp!(grease); + write_tp!(max_datagram_frame_size); + write_tp!(grease_quic_bit); + write_tp!(min_ack_delay); + + written = buf.off(); + } + + vec.resize(written, 0x00); + + Ok(vec) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_transport_parameter_encoding() { + let tpc = TransportConfig { + grease: Grease::from(true), + min_ack_delay: MinAckDelay::try_from(VarInt::from(1000)).unwrap(), + ack_delay_exponent: AckDelayExponent::try_from(VarInt::from(5)).unwrap(), + original_destination_connection_id: OriginalDestinationConnectionId::try_from( + ConnectionId::from_vec(vec![0xab, 0xab, 0xab, 0xab]), + ) + .unwrap(), + ..TransportConfig::default() + }; + + let result = tpc.encode(rustls::Side::Server).unwrap(); + + let expected = vec![ + 0x00, 0x04, 0xab, 0xab, 0xab, 0xab, 0x0a, 0x01, 0x05, 0x40, 0xb6, 0x00, 0xc0, 0x00, + 0x00, 0x00, 0xff, 0x04, 0xde, 0x1a, 0x02, 0x43, 0xe8, + ]; + + assert_eq!(result, expected); + } + + #[test] + #[should_panic] + fn test_transport_parameter_validation() { + AckDelayExponent::try_from(VarInt::from(21)).unwrap(); + } + + #[test] + fn test_transport_parameter_decoding() { + let raw = vec![ + 0x01, 0x02, 0x67, 0x10, 0x03, 0x02, 0x45, 0xc0, 0x04, 0x08, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0x05, 0x04, 0x80, 0x13, 0x12, 0xd0, 0x06, 0x04, 0x80, 0x13, + 0x12, 0xd0, 0x07, 0x04, 0x80, 0x13, 0x12, 0xd0, 0x08, 0x02, 0x40, 0x64, 0x09, 0x02, + 0x40, 0x64, 0x0e, 0x01, 0x05, 0x40, 0xb6, 0x00, 0x20, 0x04, 0x80, 0x00, 0xff, 0xff, + 0x0f, 0x08, 0x03, 0x25, 0x05, 0xd0, 0x49, 0x6f, 0x4c, 0x31, 0x6a, 0xb2, 0x00, 0xc0, + 0x00, 0x00, 0x00, 0xff, 0x04, 0xde, 0x1a, 0x02, 0x43, 0xe8, + ]; + + let tpc = TransportConfig::decode(&raw).unwrap(); + + assert_eq!(tpc.max_idle_timeout.get().unwrap().get(), 10000); + assert_eq!( + tpc.initial_source_connection_id.get().unwrap().id, + vec![0x03, 0x25, 0x05, 0xd0, 0x49, 0x6f, 0x4c, 0x31] + ); + assert_eq!(tpc.active_connection_id_limit.get().unwrap().get(), 5); + assert!(tpc.grease.get()); + } }