Skip to content

Commit

Permalink
use BTreeMap and enable test
Browse files Browse the repository at this point in the history
  • Loading branch information
somtochiama committed Aug 27, 2024
1 parent 92e27b1 commit 9d6ac2a
Showing 1 changed file with 111 additions and 101 deletions.
212 changes: 111 additions & 101 deletions crates/corro-types/src/sync.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::collections::BTreeMap;
use std::{cmp, collections::HashMap, io, ops::RangeInclusive};

use bytes::BytesMut;
Expand Down Expand Up @@ -79,9 +80,9 @@ pub enum SyncRejectionV1 {
#[derive(Debug, Default, Clone, PartialEq, Readable, Writable, Serialize, Deserialize)]
pub struct SyncStateV1 {
pub actor_id: ActorId,
pub heads: HashMap<ActorId, Version>,
pub need: HashMap<ActorId, Vec<RangeInclusive<Version>>>,
pub partial_need: HashMap<ActorId, HashMap<Version, Vec<RangeInclusive<CrsqlSeq>>>>,
pub heads: BTreeMap<ActorId, Version>,
pub need: BTreeMap<ActorId, Vec<RangeInclusive<Version>>>,
pub partial_need: BTreeMap<ActorId, HashMap<Version, Vec<RangeInclusive<CrsqlSeq>>>>,
#[speedy(default_on_eof)]
pub last_cleared_ts: Option<Timestamp>,
}
Expand Down Expand Up @@ -227,7 +228,7 @@ impl SyncStateV1 {
for overlap in other_haves.overlapping(range) {
let start = cmp::max(range.start(), overlap.start());
let end = cmp::min(range.end(), overlap.end());
let left = n - total;
let left = n - total - 1;
let new_end = cmp::min(*end, *start + left);
needs.entry(*actor_id).or_default().push(SyncNeedV1::Full {
versions: *start..=new_end,
Expand Down Expand Up @@ -301,20 +302,15 @@ impl SyncStateV1 {
}
}

let left = n - total;
let missing = match self.heads.get(actor_id) {
Some(our_head) => {
if head > our_head {
let new_head = cmp::min(*our_head + left, *head);
Some((*our_head + 1)..=new_head)
Some((*our_head + 1)..=*head)
} else {
None
}
}
None => {
let new_head = Version(cmp::min(head.0, left));
Some(Version(1)..=new_head)
}
None => Some(Version(1)..=*head),
};

if let Some(missing) = missing {
Expand All @@ -333,17 +329,18 @@ impl SyncStateV1 {
}
}

missing.into_iter().for_each(|v| {
total += v.end().0 - v.start().0 + 1;
needs
.entry(*actor_id)
.or_default()
.push(SyncNeedV1::Full { versions: v });
});
}
for v in missing {
let left = n - total - 1;
let new_end = cmp::min(*v.start() + left, *v.end());
needs.entry(*actor_id).or_default().push(SyncNeedV1::Full {
versions: *v.start()..=new_end,
});

if total >= n {
return needs;
total += new_end.0 - v.start().0 + 1;
if total >= n {
return needs;
}
}
}
}

Expand Down Expand Up @@ -646,86 +643,99 @@ mod tests {

// TODO: this test occasionally fails because get_n_needs loops over a HashMap which is
// unordered. This could probably be fixed by using a BTreeMap instead
// #[test]
// fn test_get_n_needs() -> eyre::Result<()> {
// let mut original_state: SyncStateV1 = SyncStateV1::default();
//
// // static strings so the order is predictible
// let actor1 = ActorId(Uuid::parse_str("adea794c-8bb8-4ca6-b04b-87ec22348326").unwrap());
// let actor2 = ActorId(Uuid::parse_str("0dea794c-8bb8-4ca6-b04b-87ec22348326").unwrap());
//
// let mut actor1_state = SyncStateV1::default();
// actor1_state.heads.insert(actor1, Version(20));
// let expected = HashMap::from([(
// actor1,
// vec![SyncNeedV1::Full {
// versions: Version(1)..=Version(10),
// }],
// )]);
// assert_eq!(original_state.get_n_needs(&actor1_state, 10), expected);
//
// let mut actor1_state = SyncStateV1::default();
// actor1_state.heads.insert(actor1, Version(8));
// actor1_state.heads.insert(actor2, Version(20));
// let got = original_state.get_n_needs(&actor1_state, 10);
// let expected = HashMap::from([
// (
// actor1,
// vec![SyncNeedV1::Full {
// versions: Version(1)..=Version(8),
// }],
// ),
// (
// actor2,
// vec![SyncNeedV1::Full {
// versions: Version(1)..=Version(2),
// }],
// ),
// ]);
// assert_eq!(got, expected);
//
// let mut actor1_state = SyncStateV1::default();
// actor1_state.heads.insert(actor1, Version(30));
// actor1_state.partial_need.insert(
// actor1,
// HashMap::from([(Version(13), vec![CrsqlSeq(20)..=CrsqlSeq(25)])]),
// );
// actor1_state
// .need
// .insert(actor1, vec![Version(21)..=Version(24)]);
//
// original_state.heads.insert(actor1, Version(10));
// original_state
// .need
// .insert(actor1, vec![Version(4)..=Version(5)]);
// original_state.partial_need.insert(
// actor1,
// HashMap::from([(Version(9), vec![CrsqlSeq(1)..=CrsqlSeq(10)])]),
// );
//
// let got = original_state.get_n_needs(&actor1_state, 10);
// let expected = HashMap::from([(
// actor1,
// vec![
// SyncNeedV1::Full {
// versions: Version(4)..=Version(5),
// },
// SyncNeedV1::Partial {
// version: Version(9),
// seqs: vec![CrsqlSeq(1)..=CrsqlSeq(10)],
// },
// SyncNeedV1::Full {
// versions: Version(11)..=Version(12),
// },
// SyncNeedV1::Full {
// versions: Version(14)..=Version(17),
// },
// ],
// )]);
// assert_eq!(got, expected);
//
// Ok(())
// }
#[test]
fn test_get_n_needs() {
let mut original_state: SyncStateV1 = SyncStateV1::default();

// static strings so the order is predictible
let actor1 = ActorId(Uuid::parse_str("0dea794c-8bb8-4ca6-b04b-87ec22348326").unwrap());
let actor2 = ActorId(Uuid::parse_str("adea794c-8bb8-4ca6-b04b-87ec22348326").unwrap());

let mut actor1_state = SyncStateV1::default();
actor1_state.heads.insert(actor1, Version(20));
let expected = HashMap::from([(
actor1,
vec![SyncNeedV1::Full {
versions: Version(1)..=Version(10),
}],
)]);
assert_eq!(original_state.get_n_needs(&actor1_state, 10), expected);

let mut actor1_state = SyncStateV1::default();
actor1_state.heads.insert(actor1, Version(8));
actor1_state.heads.insert(actor2, Version(20));
let got = original_state.get_n_needs(&actor1_state, 10);
let expected = HashMap::from([
(
actor1,
vec![SyncNeedV1::Full {
versions: Version(1)..=Version(8),
}],
),
(
actor2,
vec![SyncNeedV1::Full {
versions: Version(1)..=Version(2),
}],
),
]);
assert_eq!(got, expected);

let mut actor1_state = SyncStateV1::default();
actor1_state.heads.insert(actor1, Version(30));
actor1_state.partial_need.insert(
actor1,
HashMap::from([(Version(13), vec![CrsqlSeq(20)..=CrsqlSeq(25)])]),
);
actor1_state
.need
.insert(actor1, vec![Version(21)..=Version(24)]);

original_state.heads.insert(actor1, Version(10));
original_state
.need
.insert(actor1, vec![Version(4)..=Version(5)]);
original_state.partial_need.insert(
actor1,
HashMap::from([(Version(9), vec![CrsqlSeq(1)..=CrsqlSeq(10)])]),
);

let got = original_state.get_n_needs(&actor1_state, 10);
let expected = HashMap::from([(
actor1,
vec![
SyncNeedV1::Full {
versions: Version(4)..=Version(5),
},
SyncNeedV1::Partial {
version: Version(9),
seqs: vec![CrsqlSeq(1)..=CrsqlSeq(10)],
},
SyncNeedV1::Full {
versions: Version(11)..=Version(12),
},
SyncNeedV1::Full {
versions: Version(14)..=Version(18),
},
],
)]);
assert_eq!(got, expected);

let mut actor1_state = SyncStateV1::default();
actor1_state.heads.insert(actor1, Version(30));
original_state.heads.insert(actor1, Version(30));
original_state
.need
.insert(actor1, vec![Version(4)..=Version(20)]);
let got = original_state.get_n_needs(&actor1_state, 10);
let expected = HashMap::from([(
actor1,
vec![SyncNeedV1::Full {
versions: Version(4)..=Version(13),
}],
)]);
assert_eq!(got, expected);
}

#[test]
fn test_merge_need() {
Expand Down

0 comments on commit 9d6ac2a

Please sign in to comment.