axum: add ListenerExt::limit_connections (#3489)

Søren Løvborg 2025-09-23 23:25:36 +02:00 committed by GitHub
parent 50f0082970
commit 9ed1ad69d2
3 changed files with 160 additions and 2 deletions


@@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
rejection type instead of `axum::response::Response` ([#3261])
- **breaking:** `axum::serve` now applies hyper's default `header_read_timeout` ([#3478])
- **added:** Implement `OptionalFromRequest` for `Multipart` ([#3220])
- **added:** New `ListenerExt::limit_connections` allows limiting concurrent `axum::serve` connections ([#3489])
- **changed:** `serve` has an additional generic argument and can now work with any response body
type, not just `axum::body::Body` ([#3205])
- **change:** Update minimum rust version to 1.78 ([#3412])
@@ -22,6 +23,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[#3220]: https://github.com/tokio-rs/axum/pull/3220
[#3412]: https://github.com/tokio-rs/axum/pull/3412
[#3478]: https://github.com/tokio-rs/axum/pull/3478
[#3489]: https://github.com/tokio-rs/axum/pull/3489
# 0.8.4


@@ -1,8 +1,17 @@
use std::{fmt, future::Future, time::Duration};
use std::{
fmt,
future::Future,
pin::Pin,
sync::Arc,
task::{Context, Poll},
time::Duration,
};
use pin_project_lite::pin_project;
use tokio::{
io::{self, AsyncRead, AsyncWrite},
net::{TcpListener, TcpStream},
sync::{OwnedSemaphorePermit, Semaphore},
};
/// Types that can listen for connections.
@@ -64,6 +73,24 @@ impl Listener for tokio::net::UnixListener {
/// Extensions to [`Listener`].
pub trait ListenerExt: Listener + Sized {
/// Limit the number of concurrent connections. Once the limit has
/// been reached, no additional connections will be accepted until
/// an existing connection is closed. Listener implementations will
/// typically continue to queue incoming connections, up to an OS
/// and implementation-specific listener backlog limit.
///
/// Compare [`tower::limit::concurrency`], which provides ways to
/// limit concurrent in-flight requests, but does not limit connections
/// that are idle or in the process of sending request headers.
///
/// [`tower::limit::concurrency`]: https://docs.rs/tower/latest/tower/limit/concurrency/
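///
/// # Example
///
/// A minimal usage sketch; the bind address and the limit of 64 are
/// illustrative, not recommendations:
///
/// ```
/// use axum::{Router, routing::get, serve::ListenerExt};
/// let app = Router::new().route("/", get(|| async { "Hello, World!" }));
///
/// # async {
/// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000")
///     .await
///     .unwrap()
///     .limit_connections(64);
/// axum::serve(listener, app).await.unwrap();
/// # };
/// ```
///
/// Since the returned [`ConnLimiter`] is itself a [`Listener`], the limit
/// composes with other [`ListenerExt`] adapters such as `tap_io`.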
fn limit_connections(self, limit: usize) -> ConnLimiter<Self> {
ConnLimiter {
listener: self,
sem: Arc::new(Semaphore::new(limit)),
}
}
/// Run a mutable closure on every accepted `Io`.
///
/// # Example
@@ -99,6 +126,84 @@ pub trait ListenerExt: Listener + Sized {
impl<L: Listener> ListenerExt for L {}
/// Return type of [`ListenerExt::limit_connections`].
///
/// See that method for details.
#[derive(Debug)]
pub struct ConnLimiter<T> {
listener: T,
sem: Arc<Semaphore>,
}
impl<T: Listener> Listener for ConnLimiter<T> {
type Io = ConnLimiterIo<T::Io>;
type Addr = T::Addr;
async fn accept(&mut self) -> (Self::Io, Self::Addr) {
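// Wait for a free connection slot before accepting. The semaphore is never
// closed, so `acquire_owned` cannot fail; the permit is carried by the
// returned `Io` and released when that connection is dropped.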
let permit = self.sem.clone().acquire_owned().await.unwrap();
let (io, addr) = self.listener.accept().await;
(ConnLimiterIo { io, permit }, addr)
}
fn local_addr(&self) -> tokio::io::Result<Self::Addr> {
self.listener.local_addr()
}
}
pin_project! {
/// A connection counted by [`ConnLimiter`].
///
/// See [`ListenerExt::limit_connections`] for details.
#[derive(Debug)]
pub struct ConnLimiterIo<T> {
#[pin]
io: T,
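// Dropping the permit together with the connection frees a slot in the
// shared semaphore, allowing another connection to be accepted.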
permit: OwnedSemaphorePermit,
}
}
// Simply forward implementation to `io` field.
impl<T: AsyncRead> AsyncRead for ConnLimiterIo<T> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut io::ReadBuf<'_>,
) -> Poll<io::Result<()>> {
self.project().io.poll_read(cx, buf)
}
}
// Simply forward implementation to `io` field.
impl<T: AsyncWrite> AsyncWrite for ConnLimiterIo<T> {
fn is_write_vectored(&self) -> bool {
self.io.is_write_vectored()
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.project().io.poll_flush(cx)
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.project().io.poll_shutdown(cx)
}
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
self.project().io.poll_write(cx, buf)
}
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[std::io::IoSlice<'_>],
) -> Poll<io::Result<usize>> {
self.project().io.poll_write_vectored(cx, bufs)
}
}
/// Return type of [`ListenerExt::tap_io`].
///
/// See that method for details.
@@ -165,3 +270,54 @@ fn is_connection_error(e: &io::Error) -> bool {
| io::ErrorKind::ConnectionReset
)
}
#[cfg(test)]
mod tests {
use std::sync::atomic::{AtomicUsize, Ordering};
use tokio::{io, time};
use super::{Listener, ListenerExt};
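// `start_paused = true` runs the test with a paused clock that auto-advances
// when all tasks are idle, so the one-second timeout below elapses instantly.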
#[tokio::test(start_paused = true)]
async fn limit_connections() {
static COUNT: AtomicUsize = AtomicUsize::new(0);
struct MyListener;
impl Listener for MyListener {
type Io = io::DuplexStream;
type Addr = ();
async fn accept(&mut self) -> (Self::Io, Self::Addr) {
COUNT.fetch_add(1, Ordering::SeqCst);
(io::duplex(0).0, ()) // dummy connection
}
fn local_addr(&self) -> io::Result<Self::Addr> {
Ok(())
}
}
let mut listener = MyListener.limit_connections(1);
assert_eq!(COUNT.load(Ordering::SeqCst), 0);
// First 'accept' succeeds immediately.
let conn1 = listener.accept().await;
assert_eq!(COUNT.load(Ordering::SeqCst), 1);
time::timeout(time::Duration::from_secs(1), listener.accept())
.await
.expect_err("Second 'accept' should time out.");
// It never reaches MyListener::accept to be counted.
assert_eq!(COUNT.load(Ordering::SeqCst), 1);
// Close the first connection.
drop(conn1);
// Now 'accept' again succeeds immediately.
let _conn2 = listener.accept().await;
assert_eq!(COUNT.load(Ordering::SeqCst), 2);
}
}


@@ -23,7 +23,7 @@ use tower_service::Service;
mod listener;
pub use self::listener::{Listener, ListenerExt, TapIo};
pub use self::listener::{ConnLimiter, ConnLimiterIo, Listener, ListenerExt, TapIo};
/// Serve the service with the supplied listener.
///