diff --git a/crates/corro-agent/src/agent/util.rs b/crates/corro-agent/src/agent/util.rs index 0c1238f1..826f239e 100644 --- a/crates/corro-agent/src/agent/util.rs +++ b/crates/corro-agent/src/agent/util.rs @@ -988,12 +988,11 @@ pub async fn process_multiple_changes( if let Some(ts) = last_cleared { let mut booked_writer = agent - .booked() - .blocking_write("process_multiple_changes(update_cleared_ts)"); + .booked() + .blocking_write("process_multiple_changes(update_cleared_ts)"); booked_writer.update_cleared_ts(ts); } - for (_, changeset, _, _) in changesets.iter() { if let Some(ts) = changeset.ts() { let dur = (agent.clock().new_timestamp().get_time() - ts.0).to_duration(); diff --git a/crates/corro-agent/src/api/peer.rs b/crates/corro-agent/src/api/peer.rs index 4522503b..85c60520 100644 --- a/crates/corro-agent/src/api/peer.rs +++ b/crates/corro-agent/src/api/peer.rs @@ -942,16 +942,6 @@ async fn process_sync( Ok(()) } -fn chunk_range + std::cmp::Ord + Copy>( - range: RangeInclusive, - chunk_size: usize, -) -> impl Iterator> { - range.clone().step_by(chunk_size).map(move |block_start| { - let block_end = (block_start + chunk_size as u64).min(*range.end()); - block_start..=block_end - }) -} - fn encode_sync_msg( codec: &mut LengthDelimitedCodec, encode_buf: &mut BytesMut, @@ -1127,22 +1117,7 @@ pub async fn parallel_sync( counter!("corro.sync.client.member", "id" => actor_id.to_string(), "addr" => addr.to_string()).increment(1); - let mut needs = our_sync_state.compute_available_needs(&their_sync_state); - - trace!(%actor_id, self_actor_id = %agent.actor_id(), "computed needs"); - - let cleared_ts = their_sync_state.last_cleared_ts; - - if let Some(ts) = cleared_ts { - if let Some(last_seen) = our_empty_ts.get(&actor_id) { - if last_seen.is_none() || last_seen.unwrap() < ts { - debug!(%actor_id, "got last cleared ts {cleared_ts:?} - out last_seen {last_seen:?}"); - needs.entry(actor_id).or_default().push( SyncNeedV1::Empty { ts: *last_seen }); - } - } - } - - 
Ok::<_, SyncError>((needs, tx, read)) + Ok::<_, SyncError>((their_sync_state, tx, read)) }.await ) }.instrument(info_span!("sync_client_handshake", %actor_id, %addr)) @@ -1150,15 +1125,15 @@ pub async fn parallel_sync( .collect::)>>() .await; - debug!("collected member needs and such!"); + debug!("collected member state and such!"); #[allow(clippy::manual_try_fold)] let syncers = results .into_iter() .fold(Ok(vec![]), |agg, (actor_id, addr, res)| match res { - Ok((needs, tx, read)) => { + Ok((state, tx, read)) => { let mut v = agg.unwrap_or_default(); - v.push((actor_id, addr, needs, tx, read)); + v.push((actor_id, addr, state, tx, read)); Ok(v) } Err(e) => { @@ -1177,12 +1152,26 @@ pub async fn parallel_sync( })?; let len = syncers.len(); + let actor_state: Vec<_> = syncers.iter().map(|x| (x.0, x.2.clone())).collect(); + let actor_needs = distribute_available_needs(our_sync_state, actor_state); let (readers, mut servers) = { let mut rng = rand::thread_rng(); syncers.into_iter().fold( (Vec::with_capacity(len), Vec::with_capacity(len)), - |(mut readers, mut servers), (actor_id, addr, needs, tx, read)| { + |(mut readers, mut servers), (actor_id, addr, state, tx, read)| { + let mut needs = actor_needs.get(&actor_id).cloned().unwrap_or_default(); + + let cleared_ts = state.last_cleared_ts; + if let Some(ts) = cleared_ts { + if let Some(last_seen) = our_empty_ts.get(&actor_id) { + if last_seen.is_none() || last_seen.unwrap() < ts { + debug!(%actor_id, "got last cleared ts {cleared_ts:?} - out last_seen {last_seen:?}"); + needs.push((actor_id, SyncNeedV1::Empty { ts: *last_seen })); + } + } + } + if needs.is_empty() { trace!(%actor_id, "no needs!"); return (readers, servers); @@ -1191,36 +1180,16 @@ pub async fn parallel_sync( trace!(%actor_id, "needs: {needs:?}"); - debug!(%actor_id, %addr, "needs len: {}", needs.values().map(|needs| needs.iter().map(|need| match need { + debug!(%actor_id, %addr, "needs len: {}", needs.iter().map(|(_, need)| match need { 
SyncNeedV1::Full {versions} => (versions.end().0 - versions.start().0) as usize + 1, SyncNeedV1::Partial {..} => 0, SyncNeedV1::Empty {..} => 0, - }).sum::()).sum::()); + }).sum::()); - let actor_needs = needs - .into_iter() - .flat_map(|(actor_id, needs)| { - let mut needs: Vec<_> = needs - .into_iter() - .flat_map(|need| match need { - // chunk the versions, sometimes it's 0..=1000000 and that's far too big for a chunk! - SyncNeedV1::Full { versions } => chunk_range(versions, 10) - .map(|versions| SyncNeedV1::Full { versions }) - .collect(), - - need => vec![need], - }) - .collect(); + needs.shuffle(&mut rng); - // NOTE: IMPORTANT! shuffle the vec so we don't keep looping over the same later on - needs.shuffle(&mut rng); - needs - .into_iter() - .map(|need| (actor_id, need)) - .collect::>() - }) - .collect::>(); + let actor_needs = needs.into_iter().collect::>(); servers.push(( actor_id, @@ -1466,6 +1435,40 @@ pub async fn parallel_sync( .sum::()) } +pub fn distribute_available_needs( + mut our_state: SyncStateV1, + mut states: Vec<(ActorId, SyncStateV1)>, +) -> HashMap> { + let mut final_needs: HashMap> = HashMap::new(); + + while !states.is_empty() { + let mut remove_keys = vec![]; + for (actor_id, state) in &states { + let actor_needs = our_state.get_n_needs(&state, 10); + // we can get no more needs from this actor + if actor_needs.is_empty() { + remove_keys.push(*actor_id); + } else { + let needs: Vec<_> = actor_needs + .clone() + .into_iter() + .flat_map(|(actor_id, needs)| { + needs.into_iter().map(move |need| (actor_id, need)) + }) + .collect(); + final_needs + .entry(*actor_id) + .or_default() + .extend_from_slice(&needs); + our_state.merge_needs(&actor_needs); + } + } + states.retain(|(actor, _)| !remove_keys.contains(actor)); + } + + final_needs +} + #[tracing::instrument(skip(agent, bookie, their_actor_id, read, write), fields(actor_id = %their_actor_id), err)] pub async fn serve_sync( agent: &Agent, @@ -1732,6 +1735,7 @@ mod tests { use rand::{Rng, 
RngCore}; use tempfile::TempDir; use tripwire::Tripwire; + use uuid::Uuid; use crate::{ agent::{process_multiple_changes, setup}, @@ -1739,6 +1743,51 @@ mod tests { }; use super::*; + #[test] + fn test_get_needs() -> eyre::Result<()> { + let original_state: SyncStateV1 = SyncStateV1::default(); + + let actor1 = ActorId(Uuid::new_v4()); + let actor2 = ActorId(Uuid::new_v4()); + let mut actor1_state = SyncStateV1::default(); + actor1_state.heads.insert(actor1, Version(60)); + actor1_state.heads.insert(actor2, Version(20)); + + let mut actor2_state = SyncStateV1::default(); + actor2_state.heads.insert(actor1, Version(60)); + actor2_state.heads.insert(actor2, Version(20)); + + let needs_map = distribute_available_needs( + original_state.clone(), + vec![ + (actor1, actor1_state.clone()), + (actor2, actor2_state.clone()), + ], + ); + println!("{:#?}", needs_map); + + let actor3 = ActorId(Uuid::new_v4()); + let mut actor3_state = SyncStateV1::default(); + actor3_state.heads.insert(actor3, Version(40)); + // actor 2 has seen only till Version(20) + actor2_state.heads.insert(actor3, Version(10)); + // actor 1 has seen up to 40 but has some needs + actor1_state.heads.insert(actor3, Version(40)); + actor1_state + .need + .insert(actor3, vec![(Version(3)..=Version(20))]); + + let needs_map = distribute_available_needs( + original_state, + vec![ + (actor1, actor1_state), + (actor2, actor2_state), + (actor3, actor3_state), + ], + ); + println!("{:#?}", needs_map); + Ok(()) + } #[tokio::test(flavor = "multi_thread")] async fn test_handle_need() -> eyre::Result<()> { diff --git a/crates/corro-types/src/agent.rs b/crates/corro-types/src/agent.rs index e199e53a..571a03aa 100644 --- a/crates/corro-types/src/agent.rs +++ b/crates/corro-types/src/agent.rs @@ -1213,7 +1213,6 @@ impl VersionsSnapshot { conn.prepare_cached("INSERT OR REPLACE INTO __corro_sync_state VALUES (?, ?)")? 
.execute((self.actor_id, ts))?; } - Ok(()) } diff --git a/crates/corro-types/src/broadcast.rs b/crates/corro-types/src/broadcast.rs index d2ea9b61..eb1b90cd 100644 --- a/crates/corro-types/src/broadcast.rs +++ b/crates/corro-types/src/broadcast.rs @@ -1,4 +1,9 @@ -use std::{cmp, fmt, io, num::NonZeroU32, ops::{Deref, RangeInclusive}, time::Duration}; +use std::{ + cmp, fmt, io, + num::NonZeroU32, + ops::{Deref, RangeInclusive}, + time::Duration, +}; use bytes::{Bytes, BytesMut}; use corro_api_types::{row_to_change, Change}; @@ -165,9 +170,14 @@ impl Changeset { // determine the estimated resource cost of processing a change pub fn processing_cost(&self) -> usize { match self { - Changeset::Empty { versions, .. } => cmp::min((versions.end().0 - versions.start().0) as usize + 1, 20), - Changeset::EmptySet { versions, .. } => versions.iter().map(|versions| cmp::min((versions.end().0 - versions.start().0) as usize + 1, 20)).sum::(), - Changeset::Full { changes, ..} => changes.len(), + Changeset::Empty { versions, .. } => { + cmp::min((versions.end().0 - versions.start().0) as usize + 1, 20) + } + Changeset::EmptySet { versions, .. } => versions + .iter() + .map(|versions| cmp::min((versions.end().0 - versions.start().0) as usize + 1, 20)) + .sum::(), + Changeset::Full { changes, .. 
} => changes.len(), } } diff --git a/crates/corro-types/src/sync.rs b/crates/corro-types/src/sync.rs index 303dec61..72dda7ee 100644 --- a/crates/corro-types/src/sync.rs +++ b/crates/corro-types/src/sync.rs @@ -85,8 +85,72 @@ pub struct SyncStateV1 { #[speedy(default_on_eof)] pub last_cleared_ts: Option, } - impl SyncStateV1 { + pub fn contains(&self, actor_id: &ActorId, versions: &RangeInclusive) -> bool { + !(self.heads.get(actor_id).cloned().unwrap_or_default() < *versions.end() + || self + .need + .get(actor_id) + .cloned() + .unwrap_or_default() + .iter() + .any(|x| x.start() < versions.end() && versions.start() < x.end())) + } + + pub fn merge_needs(&mut self, needs: &HashMap>) { + for (actor, need) in needs { + for n in need { + match n { + SyncNeedV1::Full { versions } => { + self.merge_full_version(*actor, versions); + } + SyncNeedV1::Partial { version, seqs } => { + let mut delete = false; + self.partial_need.entry(*actor).and_modify(|e| { + if let Some(need_seqs) = e.get_mut(version) { + let mut seqs_set = + RangeInclusiveSet::from_iter(need_seqs.clone().into_iter()); + for seq in seqs { + seqs_set.remove(seq.clone()) + } + *need_seqs = Vec::from_iter(seqs_set.clone().into_iter()); + if need_seqs.is_empty() { + delete = true + } + }; + }); + + // we have gotten all sequence numbers, delete need + if delete { + if let Some(partials) = self.partial_need.get_mut(actor) { + partials.remove(version); + }; + } + } + _ => {} + } + } + } + } + + pub fn merge_full_version(&mut self, actor_id: ActorId, version: &RangeInclusive) { + let head = self.heads.entry(actor_id).or_default(); + if version.end() > head { + // check for gaps + if *head + 1 < *version.start() { + let range = *head + 1..=*version.start() - 1; + self.need.entry(actor_id).or_default().push(range); + } + *head = *version.end(); + } + + self.need.entry(actor_id).and_modify(|e| { + let mut set = RangeInclusiveSet::from_iter(e.clone().into_iter()); + set.remove(version.clone()); + *e = 
Vec::from_iter(set.into_iter()); + }); + } + pub fn need_len(&self) -> u64 { self.need .values() @@ -124,6 +188,147 @@ impl SyncStateV1 { .unwrap_or(0) } + pub fn get_n_needs(&self, other: &SyncStateV1, n: u64) -> HashMap> { + let mut needs: HashMap> = HashMap::new(); + let mut total = 0; + + for (actor_id, head) in other.heads.iter() { + if *actor_id == self.actor_id { + continue; + } + if *head == Version(0) { + warn!(actor_id = %other.actor_id, "sent a 0 head version for actor id {}", actor_id); + continue; + } + let other_haves = { + let mut haves = RangeInclusiveSet::from_iter([(Version(1)..=*head)].into_iter()); + + // remove needs + if let Some(other_need) = other.need.get(actor_id) { + for need in other_need.iter() { + // create gaps + haves.remove(need.clone()); + } + } + + // remove partials + if let Some(other_partials) = other.partial_need.get(actor_id) { + for (v, _) in other_partials.iter() { + haves.remove(*v..=*v); + } + } + + // we are left with all the versions they fully have! 
+ haves + }; + + if let Some(our_need) = self.need.get(actor_id) { + for range in our_need.iter() { + for overlap in other_haves.overlapping(range) { + let start = cmp::max(range.start(), overlap.start()); + let end = cmp::min(range.end(), overlap.end()); + let left = n - total; + let new_end = cmp::min(*end, *start + left); + needs.entry(*actor_id).or_default().push(SyncNeedV1::Full { + versions: *start..=new_end, + }); + total += end.0 - start.0; + if total >= n { + return needs; + } + } + } + } + + if let Some(our_partials) = self.partial_need.get(actor_id) { + for (v, seqs) in our_partials.iter() { + if other_haves.contains(v) { + needs + .entry(*actor_id) + .or_default() + .push(SyncNeedV1::Partial { + version: *v, + seqs: seqs.clone(), + }); + } else if let Some(other_seqs) = other + .partial_need + .get(actor_id) + .and_then(|versions| versions.get(v)) + { + let max_other_seq = other_seqs.iter().map(|range| *range.end()).max(); + let max_our_seq = seqs.iter().map(|range| *range.end()).max(); + + let end_seq = cmp::max(max_other_seq, max_our_seq); + + if let Some(end) = end_seq { + let mut other_seqs_haves = + RangeInclusiveSet::from_iter([CrsqlSeq(0)..=end]); + + for seqs in other_seqs.iter() { + other_seqs_haves.remove(seqs.clone()); + } + + let seqs = seqs + .iter() + .flat_map(|range| { + other_seqs_haves + .overlapping(range) + .map(|overlap| { + let start = cmp::max(range.start(), overlap.start()); + let end = cmp::min(range.end(), overlap.end()); + *start..=*end + }) + .collect::>>() + }) + .collect::>>(); + + if !seqs.is_empty() { + needs + .entry(*actor_id) + .or_default() + .push(SyncNeedV1::Partial { version: *v, seqs }); + } + total += 1; + if total >= n { + return needs; + } + } + } + } + } + + let left = n - total; + let missing = match self.heads.get(actor_id) { + Some(our_head) => { + if head > our_head { + let new_head = cmp::min(*our_head + left, *head); + Some((*our_head + 1)..=new_head) + } else { + None + } + } + None => { + let new_head 
= Version(cmp::min(head.0, left));
+                    Some(Version(1)..=new_head)
+                }
+            };
+
+            if let Some(missing) = missing {
+                total += missing.end().0 - missing.start().0 + 1;
+                needs
+                    .entry(*actor_id)
+                    .or_default()
+                    .push(SyncNeedV1::Full { versions: missing });
+            }
+
+            if total >= n {
+                return needs;
+            }
+        }
+
+        needs
+    }
+
     pub fn compute_available_needs(
         &self,
         other: &SyncStateV1,
@@ -157,10 +362,10 @@ impl SyncStateV1 {
             }
 
             // we are left with all the versions they fully have!
-
             haves
         };
 
+        // NOTE(review): dropped leftover debug println of per-actor haves here
         if let Some(our_need) = self.need.get(actor_id) {
             for range in our_need.iter() {
                 for overlap in other_haves.overlapping(range) {
@@ -238,10 +443,21 @@ impl SyncStateV1 {
         };
 
         if let Some(missing) = missing {
-            needs
-                .entry(*actor_id)
-                .or_default()
-                .push(SyncNeedV1::Full { versions: missing });
+            let mut missing = RangeInclusiveSet::from_iter([missing].into_iter());
+            if let Some(other_needs) = other.need.get(actor_id) {
+                // remove needs
+                for need in other_needs.iter() {
+                    // create gaps
+                    missing.remove(need.clone());
+                }
+            }
+
+            missing.into_iter().for_each(|v| {
+                needs
+                    .entry(*actor_id)
+                    .or_default()
+                    .push(SyncNeedV1::Full { versions: v });
+            });
         }
     }
@@ -300,24 +516,31 @@ pub async fn generate_sync(bookie: &Bookie, self_actor_id: ActorId) -> SyncState
     let mut last_ts = None;
 
     for (actor_id, booked) in actors {
-        let bookedr = booked
-            .read(format!("generate_sync:{}", actor_id.as_simple()))
-            .await;
+        let (last_version, needs, partials, last_cleared_ts) = {
+            let bookedr = booked
+                .read(format!("generate_sync:{}", actor_id.as_simple()))
+                .await;
+            (
+                bookedr.last(),
+                bookedr.needed().clone(),
+                bookedr.partials.clone(),
+                bookedr.last_cleared_ts(),
+            )
+        };
 
-        let last_version = match { bookedr.last() } {
+        let last_version = match last_version {
             None => continue,
             Some(v) => v,
         };
 
-        let need: Vec<_> = bookedr.needed().iter().cloned().collect();
+        let need: Vec<_> = needs.iter().cloned().collect();
         if
!need.is_empty() { state.need.insert(actor_id, need); } { - for (v, partial) in bookedr - .partials + for (v, partial) in partials .iter() // don't set partial if it is effectively complete .filter(|(_, partial)| !partial.is_complete()) @@ -333,7 +556,7 @@ pub async fn generate_sync(bookie: &Bookie, self_actor_id: ActorId) -> SyncState } if actor_id == self_actor_id { - last_ts = bookedr.last_cleared_ts(); + last_ts = last_cleared_ts; } state.heads.insert(actor_id, last_version); @@ -391,10 +614,106 @@ impl SyncMessage { #[cfg(test)] mod tests { + use itertools::Itertools; use uuid::Uuid; use super::*; + #[test] + fn test_merge_need() { + let actor1 = ActorId(Uuid::new_v4()); + + let mut state = SyncStateV1::default(); + + let mut needs = HashMap::new(); + needs.insert( + actor1, + vec![SyncNeedV1::Full { + versions: Version(1)..=Version(50), + }], + ); + state.merge_needs(&needs); + assert_eq!(state.heads.get(&actor1).unwrap(), &Version(50)); + assert!(state.need.get(&actor1).is_none()); + assert!(state.partial_need.get(&actor1).is_none()); + + needs.get_mut(&actor1).unwrap().push(SyncNeedV1::Full { + versions: Version(70)..=Version(90), + }); + state.merge_needs(&needs); + assert_eq!(state.heads.get(&actor1).unwrap(), &Version(90)); + assert!(state + .need + .get(&actor1) + .unwrap() + .iter() + .contains(&(Version(51)..=Version(69)))); + assert!(state.partial_need.get(&actor1).is_none()); + + needs.get_mut(&actor1).unwrap().push(SyncNeedV1::Full { + versions: Version(60)..=Version(65), + }); + state.merge_needs(&needs); + assert_eq!(state.heads.get(&actor1).unwrap(), &Version(90)); + assert!(state + .need + .get(&actor1) + .unwrap() + .iter() + .contains(&(Version(51)..=Version(59)))); + assert!(state + .need + .get(&actor1) + .unwrap() + .iter() + .contains(&(Version(66)..=Version(69)))); + assert!(state.partial_need.get(&actor1).is_none()); + + needs.get_mut(&actor1).unwrap().push(SyncNeedV1::Partial { + version: Version(40), + seqs: 
vec![CrsqlSeq(22)..=CrsqlSeq(25)], + }); + state + .partial_need + .entry(actor1) + .or_default() + .entry(Version(40)) + .or_default() + .extend_from_slice(&vec![ + CrsqlSeq(1)..=CrsqlSeq(10), + CrsqlSeq(20)..=CrsqlSeq(25), + ]); + state.merge_needs(&needs); + assert!(state + .partial_need + .get(&actor1) + .unwrap() + .get(&Version(40)) + .unwrap() + .contains(&(CrsqlSeq(1)..=CrsqlSeq(10)))); + assert!(state + .partial_need + .get(&actor1) + .unwrap() + .get(&Version(40)) + .unwrap() + .contains(&(CrsqlSeq(20)..=CrsqlSeq(21)))); + + let mut needs= HashMap::new(); + needs.insert(actor1, vec![SyncNeedV1::Partial { + version: Version(40), + seqs: vec![CrsqlSeq(1)..=CrsqlSeq(10)], + }, SyncNeedV1::Partial { + version: Version(40), + seqs: vec![CrsqlSeq(20)..=CrsqlSeq(21)], + }]); + state.merge_needs(&needs); + assert!(state + .partial_need + .get(&actor1) + .unwrap().is_empty()); + } + #[test] fn test_compute_available_needs() { let actor1 = ActorId(Uuid::new_v4());