From 03e956b4e4c656fce51d427fa0f282784a28f00a Mon Sep 17 00:00:00 2001
From: Max Countryman
Date: Sun, 13 Oct 2024 10:42:06 -0700
Subject: [PATCH] provide graceful shutdown interface (#8)

This introduces a mechanism for politely asking Underway to shut down. To do
so, a new function, `graceful_shutdown`, is provided. Calling this function
sends a notification to a Postgres channel. Workers listen on this channel,
and when a message is received, they stop processing new tasks. If they're
already processing a task, they wait until that task is done or the task
timeout has elapsed, whichever comes first.

To stop the queue cleanly, this function should be used. If stopping
in-progress tasks is safe for your use case, then this can be ignored and the
queue can be stopped without any delay.

Closes #5
---
 ...09c0e325cfacbf1c37589300fb48e8d9eac49.json | 35 +++
 ...1e103d8e690ea0cb5189411834b9d8b246fc4.json | 23 ++
 Cargo.toml                                    |   1 +
 src/queue.rs                                  |  22 ++
 src/worker.rs                                 | 243 +++++++++++++++++-
 5 files changed, 310 insertions(+), 14 deletions(-)
 create mode 100644 .sqlx/query-45b1b27f9669db53892c4cfad7d09c0e325cfacbf1c37589300fb48e8d9eac49.json
 create mode 100644 .sqlx/query-54d124a54b2bb28f85b3ee9882f1e103d8e690ea0cb5189411834b9d8b246fc4.json

diff --git a/.sqlx/query-45b1b27f9669db53892c4cfad7d09c0e325cfacbf1c37589300fb48e8d9eac49.json b/.sqlx/query-45b1b27f9669db53892c4cfad7d09c0e325cfacbf1c37589300fb48e8d9eac49.json
new file mode 100644
index 0000000..3f3fae0
--- /dev/null
+++ b/.sqlx/query-45b1b27f9669db53892c4cfad7d09c0e325cfacbf1c37589300fb48e8d9eac49.json
@@ -0,0 +1,35 @@
+{
+  "db_name": "PostgreSQL",
+  "query": "\n            select count(*)\n            from underway.task\n            where state = $1\n            ",
+  "describe": {
+    "columns": [
+      {
+        "ordinal": 0,
+        "name": "count",
+        "type_info": "Int8"
+      }
+    ],
+    "parameters": {
+      "Left": [
+        {
+          "Custom": {
+            "name": "underway.task_state",
+            "kind": {
+              "Enum": [
+                "pending",
+                "in_progress",
+                "succeeded",
+                "cancelled",
+                "failed"
+              ]
+            }
+          }
+        }
+      ]
+    },
+    "nullable": [
+      null
+    ]
+  },
+  "hash": "45b1b27f9669db53892c4cfad7d09c0e325cfacbf1c37589300fb48e8d9eac49"
+}
diff --git a/.sqlx/query-54d124a54b2bb28f85b3ee9882f1e103d8e690ea0cb5189411834b9d8b246fc4.json b/.sqlx/query-54d124a54b2bb28f85b3ee9882f1e103d8e690ea0cb5189411834b9d8b246fc4.json
new file mode 100644
index 0000000..cd6c018
--- /dev/null
+++ b/.sqlx/query-54d124a54b2bb28f85b3ee9882f1e103d8e690ea0cb5189411834b9d8b246fc4.json
@@ -0,0 +1,23 @@
+{
+  "db_name": "PostgreSQL",
+  "query": "select pg_notify($1, $2)",
+  "describe": {
+    "columns": [
+      {
+        "ordinal": 0,
+        "name": "pg_notify",
+        "type_info": "Void"
+      }
+    ],
+    "parameters": {
+      "Left": [
+        "Text",
+        "Text"
+      ]
+    },
+    "nullable": [
+      null
+    ]
+  },
+  "hash": "54d124a54b2bb28f85b3ee9882f1e103d8e690ea0cb5189411834b9d8b246fc4"
+}
diff --git a/Cargo.toml b/Cargo.toml
index 80f3ee5..c9b1a15 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -32,6 +32,7 @@ tokio = { version = "1.40.0", features = [
 tracing = { version = "0.1.40", features = ["log"] }
 ulid = { version = "1.1.3", features = ["uuid"] }
 uuid = { version = "1.10.0", features = ["v4"] }
+num_cpus = "1.16.0"
 
 [dev-dependencies]
 futures = "0.3.30"
diff --git a/src/queue.rs b/src/queue.rs
index 21ab344..0865212 100644
--- a/src/queue.rs
+++ b/src/queue.rs
@@ -959,6 +959,28 @@ impl QueueBuilder {
     }
 }
 
+pub(crate) const SHUTDOWN_CHANNEL: &str = "underway_shutdown";
+
+/// Initiates a graceful shutdown by sending a `NOTIFY` to the
+/// `underway_shutdown` channel via the `pg_notify` function.
+///
+/// Workers listen on this channel, and when a message is received, they will
+/// stop processing further tasks and wait for in-progress tasks to finish or
+/// time out.
+///
+/// This can be useful when combined with [`tokio::signal`] to ensure queues are
+/// stopped cleanly when stopping your application.
+pub async fn graceful_shutdown<'a, E>(executor: E) -> Result
+where
+    E: PgExecutor<'a>,
+{
+    sqlx::query!("select pg_notify($1, $2)", SHUTDOWN_CHANNEL, "")
+        .execute(executor)
+        .await?;
+
+    Ok(())
+}
+
 #[cfg(test)]
 mod tests {
     use std::collections::HashSet;
diff --git a/src/worker.rs b/src/worker.rs
index 71e1b05..dfb7d2d 100644
--- a/src/worker.rs
+++ b/src/worker.rs
@@ -101,24 +101,63 @@
 //! # });
 //! # }
 //! ```
+//!
+//! # Stopping workers safely
+//!
+//! To ensure that workers are not interrupted while handling in-progress
+//! tasks, the [`graceful_shutdown`](crate::queue::graceful_shutdown) function
+//! is provided.
+//!
+//! This function allows you to politely ask all workers to stop processing new
+//! tasks. At the same time, workers are also aware of any in-progress tasks
+//! they're working on and will wait for these to be done or time out.
+//!
+//! For cases where it's unimportant to wait for tasks to complete, this routine
+//! can be ignored.
 
+use std::sync::{
+    atomic::{AtomicBool, Ordering},
+    Arc,
+};
+
 use jiff::{Span, ToSpan};
 use serde::{de::DeserializeOwned, Serialize};
-use sqlx::{postgres::types::PgInterval, PgConnection};
+use sqlx::{
+    postgres::{types::PgInterval, PgListener},
+    PgConnection,
+};
+use tokio::{sync::Semaphore, task::JoinSet};
 use tracing::instrument;
 
 use crate::{
     job::Job,
-    queue::{Error as QueueError, Queue},
+    queue::{Error as QueueError, Queue, SHUTDOWN_CHANNEL},
     task::{DequeuedTask, Error as TaskError, Id as TaskId, RetryCount, RetryPolicy, Task},
 };
 
 pub(crate) type Result = std::result::Result<(), Error>;
 
 /// A worker that's generic over the task it processes.
-#[derive(Debug, Clone)]
+#[derive(Debug)]
 pub struct Worker<T: Task> {
     queue: Queue<T>,
-    task: T,
+    task: Arc<T>,
+
+    // Limits the number of concurrent `Task::execute` invocations allowed for this worker.
+    concurrency_limit: usize,
+
+    // Indicates the underlying queue has received a shutdown signal.
+    queue_shutdown: Arc<AtomicBool>,
+}
+
+impl<T: Task> Clone for Worker<T> {
+    fn clone(&self) -> Self {
+        Self {
+            queue: self.queue.clone(),
+            task: Arc::clone(&self.task),
+            concurrency_limit: self.concurrency_limit,
+            queue_shutdown: self.queue_shutdown.clone(),
+        }
+    }
 }
 
 /// Worker errors.
@@ -153,7 +192,9 @@ where
     fn from(job: Job<I>) -> Self {
         Self {
             queue: job.queue.clone(),
-            task: job,
+            task: Arc::new(job),
+            concurrency_limit: num_cpus::get(),
+            queue_shutdown: Arc::new(false.into()),
         }
     }
 }
@@ -165,15 +206,28 @@ where
     fn from(job: &Job<I>) -> Self {
         Self {
             queue: job.queue.clone(),
-            task: job.clone(),
+            task: Arc::new(job.to_owned()),
+            concurrency_limit: num_cpus::get(),
+            queue_shutdown: Arc::new(false.into()),
         }
     }
 }
 
-impl<T: Task> Worker<T> {
+impl<T: Task + Sync> Worker<T> {
     /// Creates a new worker with the given queue and task.
-    pub const fn new(queue: Queue<T>, task: T) -> Self {
-        Self { queue, task }
+    pub fn new(queue: Queue<T>, task: T) -> Self {
+        Self {
+            queue,
+            task: Arc::new(task),
+            concurrency_limit: num_cpus::get(),
+            queue_shutdown: Arc::new(false.into()),
+        }
+    }
+
+    /// Sets the concurrency limit for this worker.
+    pub fn concurrency_limit(mut self, concurrency_limit: usize) -> Self {
+        self.concurrency_limit = concurrency_limit;
+        self
     }
 
     /// Runs the worker, processing tasks as they become available.
@@ -188,11 +242,80 @@ impl<T: Task> Worker<T> {
     /// polls.
     pub async fn run_every(&self, span: Span) -> Result {
         let mut interval = tokio::time::interval(span.try_into()?);
-        interval.tick().await;
+
+        // Set up a listener for shutdown notifications
+        let mut shutdown_listener = PgListener::connect_with(&self.queue.pool).await?;
+        shutdown_listener.listen(SHUTDOWN_CHANNEL).await?;
+
+        let concurrency_limit = Arc::new(Semaphore::new(self.concurrency_limit));
+        let mut processing_tasks = JoinSet::new();
+
         loop {
-            self.process_next_task().await?;
-            interval.tick().await;
+            tokio::select! {
+                shutdown_notif = shutdown_listener.recv() => {
+                    if let Err(err) = shutdown_notif {
+                        tracing::error!(%err, "NOTIFY resulted in an error");
+                        continue;
+                    }
+
+                    self.queue_shutdown.store(true, Ordering::SeqCst);
+
+                    let task_timeout = self.task.timeout();
+
+                    tracing::info!(
+                        task.timeout = ?task_timeout,
+                        "Waiting for all processing tasks or timeout"
+                    );
+
+                    // Try to join all the processing tasks before the task timeout.
+                    let shutdown_result = tokio::time::timeout(
+                        task_timeout.try_into()?,
+                        async {
+                            while let Some(res) = processing_tasks.join_next().await {
+                                if let Err(err) = res {
+                                    tracing::error!(%err, "A processing task failed during shutdown");
+                                }
+                            }
+                        }
+                    ).await;
+
+                    match shutdown_result {
+                        Ok(_) => {
+                            tracing::debug!("All processing tasks completed gracefully");
+                        },
+                        Err(_) => {
+                            let remaining_tasks = processing_tasks.len();
+                            tracing::warn!(remaining_tasks, "Reached task timeout before all tasks completed");
+                        },
+                    }
+
+                    break;
+                },
+
+                _ = interval.tick() => {
+                    if self.queue_shutdown.load(Ordering::SeqCst) {
+                        tracing::info!("Queue is shut down, so no new tasks will be processed");
+                        break;
+                    }
+
+                    let permit = concurrency_limit.clone().acquire_owned().await.expect("Concurrency limit semaphore should be open");
+                    processing_tasks.spawn({
+                        // TODO: Rather than clone the worker, we could have a separate type that
+                        // owns task processing.
+                        let worker = self.clone();
+
+                        async move {
+                            if let Err(err) = worker.process_next_task().await {
+                                tracing::error!(%err, "Error processing next task");
+                            }
+                            drop(permit);
+                        }
+                    });
+                }
+            }
         }
+
+        Ok(())
     }
 
     /// Processes the next available task in the queue.
@@ -234,6 +357,7 @@ impl<T: Task> Worker<T> {
             Ok(_) => {
                 self.queue.mark_task_succeeded(&mut *tx, task_id).await?;
             }
+
             Err(err) => {
                 self.handle_task_error(err, &mut tx, task_id, task_row)
                     .await?;
@@ -357,13 +481,16 @@ pub(crate) fn pg_interval_to_span(
 }
 
 #[cfg(test)]
 mod tests {
-    use std::sync::Arc;
+    use std::{sync::Arc, time::Duration as StdDuration};
 
     use sqlx::PgPool;
     use tokio::sync::Mutex;
 
     use super::*;
-    use crate::task::{Result as TaskResult, State as TaskState};
+    use crate::{
+        queue::graceful_shutdown,
+        task::{Result as TaskResult, State as TaskState},
+    };
 
     struct TestTask;
@@ -461,4 +588,92 @@ mod tests {
 
         Ok(())
     }
+
+    #[sqlx::test]
+    async fn test_graceful_shutdown(pool: PgPool) -> sqlx::Result<(), Error> {
+        let queue = Queue::builder()
+            .name("test_queue")
+            .pool(pool.clone())
+            .build()
+            .await?;
+
+        #[derive(Debug, Clone)]
+        struct LongRunningTask;
+
+        impl Task for LongRunningTask {
+            type Input = ();
+
+            async fn execute(&self, _: Self::Input) -> TaskResult {
+                tokio::time::sleep(StdDuration::from_secs(1)).await;
+                Ok(())
+            }
+        }
+
+        // Enqueue some tasks
+        for _ in 0..5 {
+            queue.enqueue(&pool, &LongRunningTask, ()).await?;
+        }
+
+        // Start workers
+        let worker = Worker::new(queue.clone(), LongRunningTask);
+        for _ in 0..2 {
+            let worker = worker.clone();
+            tokio::spawn(async move { worker.run().await });
+        }
+
+        let pending = sqlx::query_scalar!(
+            r#"
+            select count(*)
+            from underway.task
+            where state = $1
+            "#,
+            TaskState::Pending as _
+        )
+        .fetch_one(&pool)
+        .await?;
+        assert_eq!(pending, Some(5));
+
+        // Wait briefly to ensure workers are listening
+        tokio::time::sleep(StdDuration::from_secs(2)).await;
+
+        // Initiate graceful shutdown
+        graceful_shutdown(&pool).await?;
+
+        // Wait for tasks to be done
+        tokio::time::sleep(StdDuration::from_secs(5)).await;
+
+        let succeeded = sqlx::query_scalar!(
+            r#"
+            select count(*)
+            from underway.task
+            where state = $1
+            "#,
+            TaskState::Succeeded as _
+        )
+        .fetch_one(&pool)
+        .await?;
+        assert_eq!(succeeded, Some(5));
+
+        // New tasks shouldn't be processed
+        queue.enqueue(&pool, &LongRunningTask, ()).await?;
+
+        // Wait to ensure a worker would have seen the new task if one were processing
+        tokio::time::sleep(StdDuration::from_secs(5)).await;
+
+        let succeeded = sqlx::query_scalar!(
+            r#"
+            select count(*)
+            from underway.task
+            where state = $1
+            "#,
+            TaskState::Succeeded as _
+        )
+        .fetch_one(&pool)
+        .await?;
+
+        // Succeeded count remains the same since the workers have been shut down
+        assert_eq!(succeeded, Some(5));
+
+        Ok(())
+    }
 }
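
Usage sketch (not part of the patch): the following is one way an application might wire these pieces together once this change lands. It is illustrative only and makes a few assumptions not confirmed by the diff above: the task type, queue name, and `DATABASE_URL` lookup are hypothetical; `Queue`, `Task`, and `Worker` are assumed to be re-exported at the crate root and `queue::graceful_shutdown` / `task::Result` to be publicly reachable; and it relies on tokio's `signal`, `macros`, and `rt-multi-thread` features.

```rust
use sqlx::PgPool;
use underway::{queue::graceful_shutdown, Queue, Task, Worker};

// Hypothetical task; as in the test above, only `execute` is implemented.
#[derive(Debug, Clone)]
struct WelcomeEmailTask;

impl Task for WelcomeEmailTask {
    type Input = ();

    async fn execute(&self, _input: Self::Input) -> underway::task::Result {
        // Real work would go here.
        Ok(())
    }
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let pool = PgPool::connect(&std::env::var("DATABASE_URL")?).await?;

    let queue = Queue::builder()
        .name("example_queue")
        .pool(pool.clone())
        .build()
        .await?;

    queue.enqueue(&pool, &WelcomeEmailTask, ()).await?;

    // Run a worker in the background, capping concurrent executions at four.
    let worker = Worker::new(queue.clone(), WelcomeEmailTask).concurrency_limit(4);
    tokio::spawn(async move { worker.run().await });

    // On Ctrl-C, send the shutdown NOTIFY; workers stop picking up new tasks
    // and wait for in-progress ones to finish or hit their timeout.
    tokio::signal::ctrl_c().await?;
    graceful_shutdown(&pool).await?;

    Ok(())
}
```

The important detail is ordering: `graceful_shutdown` is sent only once the application has decided to stop, after which workers exit their run loops on their own rather than being aborted mid-task.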