Commit

feat: +/- num of workers per queue dynamically
taimingl committed Apr 2, 2024
1 parent d086a91 commit 153939c
Showing 2 changed files with 125 additions and 72 deletions.
148 changes: 93 additions & 55 deletions src/common/infra/ingest_buffer/task_queue.rs
@@ -20,88 +20,126 @@ use once_cell::sync::Lazy;
 
 use super::{entry::IngestEntry, workers::Workers};
 
-pub static TASKQUEUE: Lazy<TaskQueue> = Lazy::new(TaskQueue::new);
+// TODO: change those static to env
+static MIN_WORKER_CNT: usize = 3;
+static DEFAULT_CHANNEL_CAP: usize = 10;
+
+pub static TQMANAGER: Lazy<TaskQueueManager> = Lazy::new(TaskQueueManager::new);
 
 pub type RwMap<K, V> = tokio::sync::RwLock<hashbrown::HashMap<K, V>>;
 
-pub struct TaskQueue {
-    pub sender: Arc<Sender<IngestEntry>>,
-    pub receiver: Arc<Receiver<IngestEntry>>,
-    pub queues: RwMap<Arc<str>, Workers>, // key: stream, val: workers
-}
-
 pub async fn init() -> Result<(), anyhow::Error> {
-    _ = TASKQUEUE.queues.read().await.len();
+    _ = TQMANAGER.task_queues.read().await.len();
     Ok(())
 }
 
-pub async fn send_task(stream_name: &str, task: IngestEntry) -> Result<(), anyhow::Error> {
-    TASKQUEUE.send_task(stream_name, task).await
-}
+// pub async fn send_task(stream_name: &str, task: IngestEntry) -> Result<(), anyhow::Error> {
+//     TASKQUEUE.send_task(stream_name, task).await
+// }
+
+// pub async fn remove_worker_for(stream_name: &str) {
+//     TASKQUEUE.remove_worker_for(stream_name);
+// }
+
+pub struct TaskQueueManager {
+    pub task_queues: RwMap<Arc<str>, TaskQueue>, // key: stream, val: TaskQueue
+}
 
-impl TaskQueue {
+impl TaskQueueManager {
     pub fn new() -> Self {
-        // TODO: set default channel cap
-        let queue_cap = 10;
-        let (sender, receiver) = bounded::<IngestEntry>(queue_cap);
-        let queues = RwMap::default();
         Self {
-            sender: Arc::new(sender),
-            receiver: Arc::new(receiver),
-            queues,
+            task_queues: RwMap::default(),
         }
     }
 
+    pub async fn send_task(&self, stream_name: &str, task: IngestEntry) {
+        if !self.task_queue_avail(stream_name).await {
+            self.add_task_queue_for(stream_name).await;
+        }
+        let r = self.task_queues.read().await;
+        let tq = r.get(stream_name).unwrap();
+        let _ = tq.send_task(task).await;
+    }
+
+    // HELP: how do i invoke this function is a looping thread on the global instance
+    pub async fn remove_stopped_task_queues(&self) {
+        let interval = tokio::time::Duration::from_secs(600);
+        let mut interval = tokio::time::interval(interval);
+        interval.tick().await; // the first tick is immediate
+        loop {
+            let mut keys_to_remove = vec![];
+            {
+                let r = self.task_queues.read().await;
+                for (k, v) in r.iter() {
+                    if v.workers.running_worker_count().await == 0 {
+                        keys_to_remove.push(k.clone());
+                    }
+                }
+            }
+            {
+                let mut w = self.task_queues.write().await;
+                for key in keys_to_remove {
+                    w.remove(&key);
+                }
+            }
+            interval.tick().await;
+        }
+    }
+
+    async fn task_queue_avail(&self, stream_name: &str) -> bool {
+        let r = self.task_queues.read().await;
+        if let Some(tq) = r.get(stream_name) {
+            if tq.workers.running_worker_count().await == 0 {
+                tq.workers.add_workers_by(MIN_WORKER_CNT).await;
+            }
+            return true;
+        }
+        false
+    }
+
+    async fn add_task_queue_for(&self, stream_name: &str) {
+        let tq = TaskQueue::new(DEFAULT_CHANNEL_CAP);
+        let mut w = self.task_queues.write().await;
+        w.insert(Arc::from(stream_name), tq);
+    }
+}
+
+pub struct TaskQueue {
+    pub sender: Arc<Sender<IngestEntry>>,
+    pub receiver: Arc<Receiver<IngestEntry>>,
+    pub workers: Arc<Workers>,
+}
+
+impl TaskQueue {
+    // TODO: decide default initial workers count for a queue
+    pub fn new(channel_cap: usize) -> Self {
+        let (sender, receiver) = bounded::<IngestEntry>(channel_cap);
+        // let queues = RwMap::default();
+        let workers = Arc::new(Workers::new(MIN_WORKER_CNT, Arc::new(receiver.clone())));
+        Self {
+            sender: Arc::new(sender),
+            receiver: Arc::new(receiver),
+            workers,
+        }
+    }
+
     // TODO
-    // 1. add logic to increase # of workers for a queue or increase the channel capacity
+    // 1. add min worker count to increase the number of workers
     // 2. send status back. (??: public endpoint needs to respond IngestionResponse)
-    pub async fn send_task(
-        &self,
-        stream_name: &str,
-        task: IngestEntry,
-    ) -> Result<(), anyhow::Error> {
+    pub async fn send_task(&self, task: IngestEntry) -> Result<(), anyhow::Error> {
         if self.receiver.is_closed() {
             return Err(anyhow::anyhow!("Channel is closed. BUG"));
         }
-        if !self.queue_exists_for(stream_name).await {
-            let mut r = self.queues.write().await;
-            let workers = Workers::new(Arc::clone(&self.receiver));
-            r.insert(Arc::from(stream_name), workers);
-        }
-
-        let mut delay_secs = 1;
         while let Err(e) = self.sender.try_send(task.clone()) {
-            println!("channel is full {:?}. delay for now, TODO to add", e);
-            tokio::time::sleep(tokio::time::Duration::from_secs(delay_secs)).await;
-            delay_secs *= 2;
-            if delay_secs > 10 {
-                // waiting too long, not enough workers to take tasks out of the channel
-                // increase worker count for this stream
-                self.increase_workers_for(stream_name).await;
+            println!("channel is full {:?}.", e);
+            if self.workers.running_worker_count().await < MIN_WORKER_CNT {
+                self.workers.add_workers_by(MIN_WORKER_CNT).await;
             }
         }
         Ok(())
     }
 
-    // TODO: change the static incremental number
-    pub async fn increase_workers_for(&self, stream_name: &str) {
-        let mut rw = self.queues.write().await;
-        rw.entry(Arc::from(stream_name)).and_modify(|workers| {
-            workers.add_workers_by(5);
-        });
-    }
-
     pub async fn shut_down(&self) {
-        let mut r = self.queues.write().await;
-        let _: Vec<_> = r
-            .values_mut()
-            .map(|w| async { w.shut_down().await })
-            .collect();
         self.sender.close(); // all cloned receivers will shut down in next iteration
-    }
-
-    async fn queue_exists_for(&self, stream_name: &str) -> bool {
-        let r = self.queues.read().await;
-        r.contains_key(stream_name)
+        self.workers.shut_down().await;
     }
 }
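A note on the HELP comment in remove_stopped_task_queues above: the method already loops on its own 600-second interval, so one plausible wiring (an assumption on my part, not something this commit does) is to spawn it exactly once on the lazily initialized global, e.g. from init(). A minimal, self-contained sketch of that pattern, with a stand-in Manager type in place of TaskQueueManager and a placeholder cleanup predicate:

use once_cell::sync::Lazy;
use tokio::sync::RwLock;

// Stand-in for TaskQueueManager: just enough state to show the wiring.
struct Manager {
    task_queues: RwLock<Vec<String>>,
}

impl Manager {
    // Mirrors remove_stopped_task_queues: never returns, ticks on its own interval.
    async fn remove_stopped_task_queues(&self) {
        let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(600));
        interval.tick().await; // the first tick is immediate
        loop {
            // Placeholder predicate standing in for the real "no running workers" check.
            self.task_queues.write().await.retain(|q| !q.is_empty());
            interval.tick().await;
        }
    }
}

static TQMANAGER: Lazy<Manager> = Lazy::new(|| Manager {
    task_queues: RwLock::new(Vec::new()),
});

#[tokio::main]
async fn main() {
    // The static lives for the whole program, so the future is 'static and can
    // be spawned; a single spawn suffices because the method loops internally.
    tokio::spawn(TQMANAGER.remove_stopped_task_queues());
    // ... run the rest of the application ...
}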
49 changes: 32 additions & 17 deletions src/common/infra/ingest_buffer/workers.rs
@@ -21,55 +21,65 @@ use tokio::task::JoinHandle;
 
 use super::entry::IngestEntry;
 
+type RwVec<T> = tokio::sync::RwLock<Vec<T>>;
+
 pub struct Workers {
-    pub workers_cnt: usize,
     pub receiver: Arc<Receiver<IngestEntry>>,
-    pub workers: Vec<JoinHandle<()>>,
+    pub handles: RwVec<JoinHandle<()>>,
 }
 
 impl Workers {
-    // TODO: add default worker count
-    pub fn new(receiver: Arc<Receiver<IngestEntry>>) -> Self {
-        let workers_cnt = 1;
-        let mut workers = Vec::with_capacity(workers_cnt);
-        for i in 0..workers_cnt {
+    pub fn new(count: usize, receiver: Arc<Receiver<IngestEntry>>) -> Self {
+        let mut handles = Vec::with_capacity(count);
+        for i in 0..count {
            let r = receiver.clone();
            let worker = tokio::spawn(async move {
                let _ = process_task(i, r).await;
            });
-            workers.push(worker);
+            handles.push(worker);
        }
        Self {
-            workers_cnt,
            receiver,
-            workers,
+            handles: RwVec::from(handles),
        }
    }
 
-    pub fn add_workers_by(&mut self, count: usize) {
-        for i in self.workers_cnt..self.workers_cnt + count {
+    pub async fn add_workers_by(&self, count: usize) {
+        let mut rw = self.handles.write().await;
+        let curr_cnt = rw.len();
+        for i in curr_cnt..curr_cnt + count {
            let r = Arc::clone(&self.receiver);
-            let worker = tokio::spawn(async move {
+            let handle = tokio::spawn(async move {
                let _ = process_task(i, r).await;
            });
-            self.workers.push(worker);
+            rw.push(handle);
        }
-        self.workers_cnt += count;
    }
 
+    pub async fn running_worker_count(&self) -> usize {
+        let mut rw = self.handles.write().await;
+        rw.retain(|handle| !handle.is_finished());
+        rw.len()
+    }
+
    // TODO: handle join errors
-    pub async fn shut_down(&mut self) {
-        let _join_res = join_all(self.workers.drain(..)).await;
+    pub async fn shut_down(&self) {
+        let mut rw = self.handles.write().await;
+        let _join_res = join_all(rw.drain(..)).await;
    }
 }
 
 // TODO: add default delay between each pull
+// TODO: define max idle time before shutting down a worker
 // TODO: handle errors
 async fn process_task(
    id: usize,
    receiver: Arc<Receiver<IngestEntry>>,
 ) -> Result<(), anyhow::Error> {
    let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(2));
+    let mut time = std::time::Instant::now();
+    let max_idle_time = 10.0;
    println!("Worker {id} starting");
    loop {
        if receiver.is_closed() {
            break;
@@ -83,9 +93,14 @@ async fn process_task(
                "receiver {id} received and ingesting {} request.",
                receiver.len()
            );
+            // reset idle time
+            time = std::time::Instant::now();
            for req in received {
                req.ingest().await;
            }
+        } else if time.elapsed().as_secs_f64() > max_idle_time {
+            println!("worker {id} idle too long, shutting down");
+            break;
        }
        interval.tick().await;
    }
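Taken together, the two files give each queue a self-shrinking worker pool: process_task exits after roughly max_idle_time seconds without work, running_worker_count prunes the finished JoinHandles, and task_queue_avail re-adds MIN_WORKER_CNT workers the next time the stream sees traffic. A runnable sketch of just the pruning step, using only tokio (the sleep durations are illustrative, not from the commit):

use tokio::task::JoinHandle;

#[tokio::main]
async fn main() {
    // Spawn three "workers": two exit immediately (standing in for the idle
    // timeout), one keeps running.
    let mut handles: Vec<JoinHandle<()>> = (0..3)
        .map(|i| {
            tokio::spawn(async move {
                let secs = if i < 2 { 0 } else { 60 };
                tokio::time::sleep(tokio::time::Duration::from_secs(secs)).await;
            })
        })
        .collect();

    // Give the short-lived tasks a moment to finish.
    tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;

    // The same pruning running_worker_count does: keep only live handles.
    handles.retain(|handle| !handle.is_finished());
    println!("running workers: {}", handles.len()); // prints: running workers: 1
}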
