Skip to content

Commit

Permalink
Include fixes to SRTP (#619)
Browse files Browse the repository at this point in the history
* srtp: Fix roll over count calculation

This brings in fixes for ROC over the last couple of years from PION

* srtp: Fix packet length validation

Ported from PION
  • Loading branch information
haaspors authored Oct 19, 2024
1 parent ea8fb77 commit 225b7f5
Show file tree
Hide file tree
Showing 9 changed files with 331 additions and 101 deletions.
3 changes: 3 additions & 0 deletions srtp/src/cipher/cipher_aes_cm_hmac_sha1/ctrcipher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,9 @@ impl Cipher for CipherAesCmHmacSha1 {
}

let tail_offset = encrypted_len - (self.rtcp_auth_tag_len() + SRTCP_INDEX_SIZE);
if tail_offset < 8 {
return Err(Error::ErrTooShortRtcp);
}

let mut writer = Vec::with_capacity(tail_offset);

Expand Down
3 changes: 3 additions & 0 deletions srtp/src/cipher/cipher_aes_cm_hmac_sha1/opensslcipher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,9 @@ impl Cipher for CipherAesCmHmacSha1 {
}

let tail_offset = encrypted_len - (self.rtcp_auth_tag_len() + SRTCP_INDEX_SIZE);
if tail_offset < 8 {
return Err(Error::ErrTooShortRtcp);
}

let mut writer = Vec::with_capacity(tail_offset);

Expand Down
146 changes: 114 additions & 32 deletions srtp/src/context/context_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ fn test_valid_packet_counter() -> Result<()> {
0xcf, 0x90, 0x1e, 0xa5, 0xda, 0xd3, 0x2c, 0x15, 0x00, 0xa2, 0x24, 0xae, 0xae, 0xaf, 0x00,
0x00,
];
let counter = generate_counter(32846, s.rollover_counter, s.ssrc, &srtp_session_salt);
let counter = generate_counter(32846, (s.index >> 16) as _, s.ssrc, &srtp_session_salt);
assert_eq!(
counter, expected_counter,
"Session Key {counter:?} does not match expected {expected_counter:?}",
Expand All @@ -124,15 +124,13 @@ fn test_valid_packet_counter() -> Result<()> {

#[test]
fn test_rollover_count() -> Result<()> {
let mut s = SrtpSsrcState {
ssrc: DEFAULT_SSRC,
..Default::default()
};
let mut s = SrtpSsrcState::default();

// Set initial seqnum
let roc = s.next_rollover_count(65530);
let (roc, diff, ovf) = s.next_rollover_count(65530);
assert_eq!(roc, 0, "Initial rolloverCounter must be 0");
s.update_rollover_count(65530);
assert!(!ovf, "Should not overflow");
s.update_rollover_count(65530, diff);

// Invalid packets never update ROC
s.next_rollover_count(0);
Expand All @@ -142,64 +140,148 @@ fn test_rollover_count() -> Result<()> {
s.next_rollover_count(0);

// We rolled over to 0
let roc = s.next_rollover_count(0);
let (roc, diff, ovf) = s.next_rollover_count(0);
assert_eq!(roc, 1, "rolloverCounter was not updated after it crossed 0");
s.update_rollover_count(0);
assert!(!ovf, "Should not overflow");
s.update_rollover_count(0, diff);

let roc = s.next_rollover_count(65530);
let (roc, diff, ovf) = s.next_rollover_count(65530);
assert_eq!(
roc, 0,
"rolloverCounter was not updated when it rolled back, failed to handle out of order"
);
s.update_rollover_count(65530);
assert!(!ovf, "Should not overflow");
s.update_rollover_count(65530, diff);

let roc = s.next_rollover_count(5);
let (roc, diff, ovf) = s.next_rollover_count(5);
assert_eq!(
roc, 1,
"rolloverCounter was not updated when it rolled over initial, to handle out of order"
);
s.update_rollover_count(5);

s.next_rollover_count(6);
s.update_rollover_count(6);

s.next_rollover_count(7);
s.update_rollover_count(7);

let roc = s.next_rollover_count(8);
assert!(!ovf, "Should not overflow");
s.update_rollover_count(5, diff);

let (_, diff, _) = s.next_rollover_count(6);
s.update_rollover_count(6, diff);
let (_, diff, _) = s.next_rollover_count(7);
s.update_rollover_count(7, diff);
let (roc, diff, _) = s.next_rollover_count(8);
assert_eq!(
roc, 1,
"rolloverCounter was improperly updated for non-significant packets"
);
s.update_rollover_count(8);
s.update_rollover_count(8, diff);

// valid packets never update ROC
let roc = s.next_rollover_count(0x4000);
let (roc, diff, ovf) = s.next_rollover_count(0x4000);
assert_eq!(
roc, 1,
"rolloverCounter was improperly updated for non-significant packets"
);
s.update_rollover_count(0x4000);

let roc = s.next_rollover_count(0x8000);
assert!(!ovf, "Should not overflow");
s.update_rollover_count(0x4000, diff);
let (roc, diff, ovf) = s.next_rollover_count(0x8000);
assert_eq!(
roc, 1,
"rolloverCounter was improperly updated for non-significant packets"
);
s.update_rollover_count(0x8000);

let roc = s.next_rollover_count(0xFFFF);
assert!(!ovf, "Should not overflow");
s.update_rollover_count(0x8000, diff);
let (roc, diff, ovf) = s.next_rollover_count(0xFFFF);
assert_eq!(
roc, 1,
"rolloverCounter was improperly updated for non-significant packets"
);
s.update_rollover_count(0xFFFF);

let roc = s.next_rollover_count(0);
assert!(!ovf, "Should not overflow");
s.update_rollover_count(0xFFFF, diff);
let (roc, _, ovf) = s.next_rollover_count(0);
assert_eq!(
roc, 2,
"rolloverCounter must be incremented after wrapping, got {roc}"
);
assert!(!ovf, "Should not overflow");

Ok(())
}

#[test]
fn test_rollover_count_overflow() -> Result<()> {
let mut s = SrtpSsrcState {
index: (MAX_ROC as u64) << 16,
..Default::default()
};
s.update_rollover_count(0xFFFF, 0);
let (_, _, ovf) = s.next_rollover_count(0);
assert!(ovf, "Should overflow");

Ok(())
}

#[test]
fn test_rollover_count_2() -> Result<()> {
let mut s = SrtpSsrcState::default();

let (roc, diff, ovf) = s.next_rollover_count(30123);
assert_eq!(roc, 0, "Initial rolloverCounter must be 0");
assert!(!ovf, "Should not overflow");
s.update_rollover_count(30123, diff);

// 62892 = 30123 + (1 << 15) + 1
let (roc, diff, ovf) = s.next_rollover_count(62892);
assert_eq!(roc, 0, "Initial rolloverCounter must be 0");
assert!(!ovf, "Should not overflow");
s.update_rollover_count(62892, diff);
let (roc, diff, ovf) = s.next_rollover_count(204);
assert_eq!(roc, 1, "rolloverCounter was not updated after it crossed 0");
assert!(!ovf, "Should not overflow");
s.update_rollover_count(62892, diff);
let (roc, diff, ovf) = s.next_rollover_count(64535);
assert_eq!(
roc, 0,
"rolloverCounter was not updated when it rolled back, failed to handle out of order"
);
assert!(!ovf, "Should not overflow");
s.update_rollover_count(64535, diff);
let (roc, diff, ovf) = s.next_rollover_count(205);
assert_eq!(
roc, 1,
"rolloverCounter was improperly updated for non-significant packets"
);
assert!(!ovf, "Should not overflow");
s.update_rollover_count(205, diff);
let (roc, diff, ovf) = s.next_rollover_count(1);
assert_eq!(
roc, 1,
"rolloverCounter was improperly updated for non-significant packets"
);
assert!(!ovf, "Should not overflow");
s.update_rollover_count(1, diff);

let (roc, diff, ovf) = s.next_rollover_count(64532);
assert_eq!(
roc, 0,
"rolloverCounter was improperly updated for non-significant packets"
);
assert!(!ovf, "Should not overflow");
s.update_rollover_count(64532, diff);
let (roc, diff, ovf) = s.next_rollover_count(64534);
assert_eq!(
roc, 0,
"index was improperly updated for non-significant packets"
);
assert!(!ovf, "Should not overflow");
s.update_rollover_count(64534, diff);
let (roc, diff, ovf) = s.next_rollover_count(64532);
assert_eq!(
roc, 0,
"index was improperly updated for non-significant packets"
);
assert!(!ovf, "Should not overflow");
s.update_rollover_count(64532, diff);
let (roc, diff, ovf) = s.next_rollover_count(205);
assert_eq!(roc, 1, "index was not updated after it crossed 0");
assert!(!ovf, "Should not overflow");
s.update_rollover_count(205, diff);

Ok(())
}
Expand Down
99 changes: 46 additions & 53 deletions srtp/src/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,16 @@ use crate::protection_profile::*;
pub mod srtcp;
pub mod srtp;

const MAX_ROC_DISORDER: u16 = 100;
const MAX_ROC: u32 = u32::MAX;
const SEQ_NUM_MEDIAN: u16 = 1 << 15;
const SEQ_NUM_MAX: u16 = u16::MAX;

/// Encrypt/Decrypt state for a single SRTP SSRC
#[derive(Default)]
pub(crate) struct SrtpSsrcState {
ssrc: u32,
rollover_counter: u32,
index: u64,
rollover_has_processed: bool,
last_sequence_number: u16,
replay_detector: Option<Box<dyn ReplayDetector + Send + 'static>>,
}

Expand All @@ -40,61 +41,49 @@ pub(crate) struct SrtcpSsrcState {
}

impl SrtpSsrcState {
pub fn next_rollover_count(&self, sequence_number: u16) -> u32 {
let mut roc = self.rollover_counter;

if !self.rollover_has_processed {
} else if sequence_number == 0 {
// We exactly hit the rollover count

// Only update rolloverCounter if lastSequenceNumber is greater then MAX_ROCDISORDER
// otherwise we already incremented for disorder
if self.last_sequence_number > MAX_ROC_DISORDER {
roc += 1;
pub fn next_rollover_count(&self, sequence_number: u16) -> (u32, i32, bool) {
let local_roc = (self.index >> 16) as u32;
let local_seq = self.index as u16;

let mut guess_roc = local_roc;

let diff = if self.rollover_has_processed {
let seq = (sequence_number as i32).wrapping_sub(local_seq as i32);
// When local_roc is equal to 0, and entering seq-local_seq > SEQ_NUM_MEDIAN
// judgment, it will cause guess_roc calculation error
if self.index > SEQ_NUM_MEDIAN as _ {
if local_seq < SEQ_NUM_MEDIAN {
if seq > SEQ_NUM_MEDIAN as i32 {
guess_roc = local_roc.wrapping_sub(1);
seq.wrapping_sub(SEQ_NUM_MAX as i32 + 1)
} else {
seq
}
} else if local_seq - SEQ_NUM_MEDIAN > sequence_number {
guess_roc = local_roc.wrapping_add(1);
seq.wrapping_add(SEQ_NUM_MAX as i32 + 1)
} else {
seq
}
} else {
// local_roc is equal to 0
seq
}
} else if self.last_sequence_number < MAX_ROC_DISORDER
&& sequence_number > (MAX_SEQUENCE_NUMBER - MAX_ROC_DISORDER)
{
// Our last sequence number incremented because we crossed 0, but then our current number was within MAX_ROCDISORDER of the max
// So we fell behind, drop to account for jitter
roc -= 1;
} else if sequence_number < MAX_ROC_DISORDER
&& self.last_sequence_number > (MAX_SEQUENCE_NUMBER - MAX_ROC_DISORDER)
{
// our current is within a MAX_ROCDISORDER of 0
// and our last sequence number was a high sequence number, increment to account for jitter
roc += 1;
}
} else {
0i32
};

roc
(guess_roc, diff, (guess_roc == 0 && local_roc == MAX_ROC))
}

/// https://tools.ietf.org/html/rfc3550#appendix-A.1
pub fn update_rollover_count(&mut self, sequence_number: u16) {
pub fn update_rollover_count(&mut self, sequence_number: u16, diff: i32) {
if !self.rollover_has_processed {
self.index |= sequence_number as u64;
self.rollover_has_processed = true;
} else if sequence_number == 0 {
// We exactly hit the rollover count

// Only update rolloverCounter if lastSequenceNumber is greater then MAX_ROCDISORDER
// otherwise we already incremented for disorder
if self.last_sequence_number > MAX_ROC_DISORDER {
self.rollover_counter += 1;
}
} else if self.last_sequence_number < MAX_ROC_DISORDER
&& sequence_number > (MAX_SEQUENCE_NUMBER - MAX_ROC_DISORDER)
{
// Our last sequence number incremented because we crossed 0, but then our current number was within MAX_ROCDISORDER of the max
// So we fell behind, drop to account for jitter
self.rollover_counter -= 1;
} else if sequence_number < MAX_ROC_DISORDER
&& self.last_sequence_number > (MAX_SEQUENCE_NUMBER - MAX_ROC_DISORDER)
{
// our current is within a MAX_ROCDISORDER of 0
// and our last sequence number was a high sequence number, increment to account for jitter
self.rollover_counter += 1;
} else {
self.index = self.index.wrapping_add(diff as _);
}
self.last_sequence_number = sequence_number;
}
}

Expand Down Expand Up @@ -181,12 +170,16 @@ impl Context {

/// roc returns SRTP rollover counter value of specified SSRC.
fn get_roc(&self, ssrc: u32) -> Option<u32> {
self.srtp_ssrc_states.get(&ssrc).map(|s| s.rollover_counter)
self.srtp_ssrc_states
.get(&ssrc)
.map(|s| (s.index >> 16) as _)
}

/// set_roc sets SRTP rollover counter value of specified SSRC.
fn set_roc(&mut self, ssrc: u32, roc: u32) {
self.get_srtp_ssrc_state(ssrc).rollover_counter = roc;
let state = self.get_srtp_ssrc_state(ssrc);
state.index = (roc as u64) << 16;
state.rollover_has_processed = false;
}

/// index returns SRTCP index value of specified SSRC.
Expand All @@ -196,6 +189,6 @@ impl Context {

/// set_index sets SRTCP index value of specified SSRC.
fn set_index(&mut self, ssrc: u32, index: usize) {
self.get_srtcp_ssrc_state(ssrc).srtcp_index = index;
self.get_srtcp_ssrc_state(ssrc).srtcp_index = index % (MAX_SRTCP_INDEX + 1);
}
}
4 changes: 4 additions & 0 deletions srtp/src/context/srtcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ impl Context {
/// EncryptRTCP marshals and encrypts an RTCP packet, writing to the dst buffer provided.
/// If the dst buffer does not have the capacity to hold `len(plaintext) + 14` bytes, a new one will be allocated and returned.
pub fn encrypt_rtcp(&mut self, decrypted: &[u8]) -> Result<Bytes> {
if decrypted.len() < 8 {
return Err(Error::ErrTooShortRtcp);
}

let mut buf = decrypted;
rtcp::header::Header::unmarshal(&mut buf)?;

Expand Down
Loading

0 comments on commit 225b7f5

Please sign in to comment.