diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index 7d653677..0d3fcd7f 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 # Unreleased - **breaking:** `serve` now requires the listener IO type to be `Sync` +- **changed:** Cancelling `serve` now cancels all tokio tasks it spawned for handling connections # Unreleased diff --git a/axum/src/macros.rs b/axum/src/macros.rs index 37b8fc3b..68e15365 100644 --- a/axum/src/macros.rs +++ b/axum/src/macros.rs @@ -75,6 +75,14 @@ macro_rules! trace { } } +#[cfg(feature = "tracing")] +#[allow(unused_macros)] +macro_rules! info { + ($($tt:tt)*) => { + tracing::info!($($tt)*) + } +} + #[cfg(feature = "tracing")] #[allow(unused_macros)] macro_rules! error { @@ -89,6 +97,12 @@ macro_rules! trace { ($($tt:tt)*) => {}; } +#[cfg(not(feature = "tracing"))] +#[allow(unused_macros)] +macro_rules! info { + ($($tt:tt)*) => {}; +} + #[cfg(not(feature = "tracing"))] #[allow(unused_macros)] macro_rules! error { diff --git a/axum/src/serve/mod.rs b/axum/src/serve/mod.rs index 089053f7..f73824c2 100644 --- a/axum/src/serve/mod.rs +++ b/axum/src/serve/mod.rs @@ -18,6 +18,7 @@ use hyper_util::{server::conn::auto::Builder, service::TowerToHyperService}; use tokio::{ io::{AsyncRead, AsyncWrite}, sync::watch, + task::JoinSet, }; use tower::ServiceExt as _; use tower_service::Service; @@ -187,19 +188,16 @@ where } = self; let (signal_tx, _signal_rx) = watch::channel(()); - let (_close_tx, close_rx) = watch::channel(()); + + // Use a JoinSet to propagate cancellation + let mut join_set = JoinSet::new(); loop { let (io, remote_addr) = listener.accept().await; let io = TokioIo::new(io); let conn_service = prep_serve_connection(&mut make_service, remote_addr, &io).await; - tokio::spawn(serve_connection( - io, - conn_service, - signal_tx.clone(), - close_rx.clone(), - )); + join_set.spawn(serve_connection(io, conn_service, signal_tx.clone())); } } } @@ -287,15 +285,18 @@ where let (signal_tx, signal_rx) = watch::channel(()); tokio::spawn(async move { signal.await; - trace!("received graceful shutdown signal. Telling tasks to shutdown"); + info!("received graceful shutdown signal"); drop(signal_rx); }); - let (close_tx, close_rx) = watch::channel(()); - + let mut join_set = JoinSet::new(); loop { let (io, remote_addr) = tokio::select! { conn = listener.accept() => conn, + // Eagerly drop tasks from the JoinSet again, to be able to + // report the number of tasks that are still running after the + // shutdown signal is received. + Some(_) = join_set.join_next() => continue, _ = signal_tx.closed() => { trace!("signal received, not accepting new connections"); break; @@ -304,22 +305,16 @@ where let io = TokioIo::new(io); let conn_service = prep_serve_connection(&mut make_service, remote_addr, &io).await; - tokio::spawn(serve_connection( - io, - conn_service, - signal_tx.clone(), - close_rx.clone(), - )); + join_set.spawn(serve_connection(io, conn_service, signal_tx.clone())); } - drop(close_rx); drop(listener); - trace!( - "waiting for {} task(s) to finish", - close_tx.receiver_count() - ); - close_tx.closed().await; + if !join_set.is_empty() { + info!(num_tasks = join_set.len(), "waiting for tasks to finish"); + } + + while join_set.join_next().await.is_some() {} } } @@ -391,12 +386,8 @@ where .unwrap_or_else(|err| match err {}) } -async fn serve_connection( - io: TokioIo, - conn_service: S, - signal_tx: watch::Sender<()>, - close_rx: watch::Receiver<()>, -) where +async fn serve_connection(io: TokioIo, conn_service: S, signal_tx: watch::Sender<()>) +where I: AsyncRead + AsyncWrite + Unpin + Send + 'static, S: Service + Clone + Send + 'static, S::Future: Send, @@ -428,8 +419,6 @@ async fn serve_connection( } } } - - drop(close_rx); } /// An incoming stream.