serve: Use JoinSet for termination handling

This commit is contained in:
Jonas Platte 2025-04-26 23:49:03 +02:00
parent b5cd093928
commit d959a6fe68
No known key found for this signature in database
GPG Key ID: 7D261D771D915378
3 changed files with 34 additions and 30 deletions

View File

@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
# Unreleased
- **breaking:** `serve` now requires the listener IO type to be `Sync`
- **changed:** Cancelling `serve` now cancels all tokio tasks it spawned for handling connections
# Unreleased

View File

@ -75,6 +75,14 @@ macro_rules! trace {
}
}
#[cfg(feature = "tracing")]
#[allow(unused_macros)]
macro_rules! info {
($($tt:tt)*) => {
tracing::info!($($tt)*)
}
}
#[cfg(feature = "tracing")]
#[allow(unused_macros)]
macro_rules! error {
@ -89,6 +97,12 @@ macro_rules! trace {
($($tt:tt)*) => {};
}
#[cfg(not(feature = "tracing"))]
#[allow(unused_macros)]
macro_rules! info {
($($tt:tt)*) => {};
}
#[cfg(not(feature = "tracing"))]
#[allow(unused_macros)]
macro_rules! error {

View File

@ -18,6 +18,7 @@ use hyper_util::{server::conn::auto::Builder, service::TowerToHyperService};
use tokio::{
io::{AsyncRead, AsyncWrite},
sync::watch,
task::JoinSet,
};
use tower::ServiceExt as _;
use tower_service::Service;
@ -187,19 +188,16 @@ where
} = self;
let (signal_tx, _signal_rx) = watch::channel(());
let (_close_tx, close_rx) = watch::channel(());
// Use a JoinSet to propagate cancellation
let mut join_set = JoinSet::new();
loop {
let (io, remote_addr) = listener.accept().await;
let io = TokioIo::new(io);
let conn_service = prep_serve_connection(&mut make_service, remote_addr, &io).await;
tokio::spawn(serve_connection(
io,
conn_service,
signal_tx.clone(),
close_rx.clone(),
));
join_set.spawn(serve_connection(io, conn_service, signal_tx.clone()));
}
}
}
@ -287,15 +285,18 @@ where
let (signal_tx, signal_rx) = watch::channel(());
tokio::spawn(async move {
signal.await;
trace!("received graceful shutdown signal. Telling tasks to shutdown");
info!("received graceful shutdown signal");
drop(signal_rx);
});
let (close_tx, close_rx) = watch::channel(());
let mut join_set = JoinSet::new();
loop {
let (io, remote_addr) = tokio::select! {
conn = listener.accept() => conn,
// Eagerly drop tasks from the JoinSet again, to be able to
// report the number of tasks that are still running after the
// shutdown signal is received.
Some(_) = join_set.join_next() => continue,
_ = signal_tx.closed() => {
trace!("signal received, not accepting new connections");
break;
@ -304,22 +305,16 @@ where
let io = TokioIo::new(io);
let conn_service = prep_serve_connection(&mut make_service, remote_addr, &io).await;
tokio::spawn(serve_connection(
io,
conn_service,
signal_tx.clone(),
close_rx.clone(),
));
join_set.spawn(serve_connection(io, conn_service, signal_tx.clone()));
}
drop(close_rx);
drop(listener);
trace!(
"waiting for {} task(s) to finish",
close_tx.receiver_count()
);
close_tx.closed().await;
if !join_set.is_empty() {
info!(num_tasks = join_set.len(), "waiting for tasks to finish");
}
while join_set.join_next().await.is_some() {}
}
}
@ -391,12 +386,8 @@ where
.unwrap_or_else(|err| match err {})
}
async fn serve_connection<I, S>(
io: TokioIo<I>,
conn_service: S,
signal_tx: watch::Sender<()>,
close_rx: watch::Receiver<()>,
) where
async fn serve_connection<I, S>(io: TokioIo<I>, conn_service: S, signal_tx: watch::Sender<()>)
where
I: AsyncRead + AsyncWrite + Unpin + Send + 'static,
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
S::Future: Send,
@ -428,8 +419,6 @@ async fn serve_connection<I, S>(
}
}
}
drop(close_rx);
}
/// An incoming stream.