mirror of
https://github.com/tokio-rs/axum.git
synced 2025-09-26 20:40:29 +00:00
axum: add ListenerExt::limit_connections (#3489)
This commit is contained in:
parent
50f0082970
commit
9ed1ad69d2
@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
rejection type instead of `axum::response::Response` ([#3261])
|
||||
- **breaking:** `axum::serve` now applies hyper's default `header_read_timeout` ([#3478])
|
||||
- **added:** Implement `OptionalFromRequest` for `Multipart` ([#3220])
|
||||
- **added:** New `ListenerExt::limit_connections` allows limiting concurrent `axum::serve` connections ([#3489])
|
||||
- **changed:** `serve` has an additional generic argument and can now work with any response body
|
||||
type, not just `axum::body::Body` ([#3205])
|
||||
- **change:** Update minimum rust version to 1.78 ([#3412])
|
||||
@ -22,6 +23,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
[#3220]: https://github.com/tokio-rs/axum/pull/3220
|
||||
[#3412]: https://github.com/tokio-rs/axum/pull/3412
|
||||
[#3478]: https://github.com/tokio-rs/axum/pull/3478
|
||||
[#3489]: https://github.com/tokio-rs/axum/pull/3489
|
||||
|
||||
# 0.8.4
|
||||
|
||||
|
@ -1,8 +1,17 @@
|
||||
use std::{fmt, future::Future, time::Duration};
|
||||
use std::{
|
||||
fmt,
|
||||
future::Future,
|
||||
pin::Pin,
|
||||
sync::Arc,
|
||||
task::{Context, Poll},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use pin_project_lite::pin_project;
|
||||
use tokio::{
|
||||
io::{self, AsyncRead, AsyncWrite},
|
||||
net::{TcpListener, TcpStream},
|
||||
sync::{OwnedSemaphorePermit, Semaphore},
|
||||
};
|
||||
|
||||
/// Types that can listen for connections.
|
||||
@ -64,6 +73,24 @@ impl Listener for tokio::net::UnixListener {
|
||||
|
||||
/// Extensions to [`Listener`].
|
||||
pub trait ListenerExt: Listener + Sized {
|
||||
/// Limit the number of concurrent connections. Once the limit has
|
||||
/// been reached, no additional connections will be accepted until
|
||||
/// an existing connection is closed. Listener implementations will
|
||||
/// typically continue to queue incoming connections, up to an OS
|
||||
/// and implementation-specific listener backlog limit.
|
||||
///
|
||||
/// Compare [`tower::limit::concurrency`], which provides ways to
|
||||
/// limit concurrent in-flight requests, but does not limit connections
|
||||
/// that are idle or in the process of sending request headers.
|
||||
///
|
||||
/// [`tower::limit::concurrency`]: https://docs.rs/tower/latest/tower/limit/concurrency/
|
||||
fn limit_connections(self, limit: usize) -> ConnLimiter<Self> {
|
||||
ConnLimiter {
|
||||
listener: self,
|
||||
sem: Arc::new(Semaphore::new(limit)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Run a mutable closure on every accepted `Io`.
|
||||
///
|
||||
/// # Example
|
||||
@ -99,6 +126,84 @@ pub trait ListenerExt: Listener + Sized {
|
||||
|
||||
impl<L: Listener> ListenerExt for L {}
|
||||
|
||||
/// Return type of [`ListenerExt::limit_connections`].
|
||||
///
|
||||
/// See that method for details.
|
||||
#[derive(Debug)]
|
||||
pub struct ConnLimiter<T> {
|
||||
listener: T,
|
||||
sem: Arc<Semaphore>,
|
||||
}
|
||||
|
||||
impl<T: Listener> Listener for ConnLimiter<T> {
|
||||
type Io = ConnLimiterIo<T::Io>;
|
||||
type Addr = T::Addr;
|
||||
|
||||
async fn accept(&mut self) -> (Self::Io, Self::Addr) {
|
||||
let permit = self.sem.clone().acquire_owned().await.unwrap();
|
||||
let (io, addr) = self.listener.accept().await;
|
||||
(ConnLimiterIo { io, permit }, addr)
|
||||
}
|
||||
|
||||
fn local_addr(&self) -> tokio::io::Result<Self::Addr> {
|
||||
self.listener.local_addr()
|
||||
}
|
||||
}
|
||||
|
||||
pin_project! {
|
||||
/// A connection counted by [`ConnLimiter`].
|
||||
///
|
||||
/// See [`ListenerExt::limit_connections`] for details.
|
||||
#[derive(Debug)]
|
||||
pub struct ConnLimiterIo<T> {
|
||||
#[pin]
|
||||
io: T,
|
||||
permit: OwnedSemaphorePermit,
|
||||
}
|
||||
}
|
||||
|
||||
// Simply forward implementation to `io` field.
|
||||
impl<T: AsyncRead> AsyncRead for ConnLimiterIo<T> {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut io::ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
self.project().io.poll_read(cx, buf)
|
||||
}
|
||||
}
|
||||
|
||||
// Simply forward implementation to `io` field.
|
||||
impl<T: AsyncWrite> AsyncWrite for ConnLimiterIo<T> {
|
||||
fn is_write_vectored(&self) -> bool {
|
||||
self.io.is_write_vectored()
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
self.project().io.poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
self.project().io.poll_shutdown(cx)
|
||||
}
|
||||
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
self.project().io.poll_write(cx, buf)
|
||||
}
|
||||
|
||||
fn poll_write_vectored(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
bufs: &[std::io::IoSlice<'_>],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
self.project().io.poll_write_vectored(cx, bufs)
|
||||
}
|
||||
}
|
||||
|
||||
/// Return type of [`ListenerExt::tap_io`].
|
||||
///
|
||||
/// See that method for details.
|
||||
@ -165,3 +270,54 @@ fn is_connection_error(e: &io::Error) -> bool {
|
||||
| io::ErrorKind::ConnectionReset
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
|
||||
use tokio::{io, time};
|
||||
|
||||
use super::{Listener, ListenerExt};
|
||||
|
||||
#[tokio::test(start_paused = true)]
|
||||
async fn limit_connections() {
|
||||
static COUNT: AtomicUsize = AtomicUsize::new(0);
|
||||
|
||||
struct MyListener;
|
||||
|
||||
impl Listener for MyListener {
|
||||
type Io = io::DuplexStream;
|
||||
type Addr = ();
|
||||
|
||||
async fn accept(&mut self) -> (Self::Io, Self::Addr) {
|
||||
COUNT.fetch_add(1, Ordering::SeqCst);
|
||||
(io::duplex(0).0, ()) // dummy connection
|
||||
}
|
||||
|
||||
fn local_addr(&self) -> io::Result<Self::Addr> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
let mut listener = MyListener.limit_connections(1);
|
||||
|
||||
assert_eq!(COUNT.load(Ordering::SeqCst), 0);
|
||||
|
||||
// First 'accept' succeeds immediately.
|
||||
let conn1 = listener.accept().await;
|
||||
assert_eq!(COUNT.load(Ordering::SeqCst), 1);
|
||||
|
||||
time::timeout(time::Duration::from_secs(1), listener.accept())
|
||||
.await
|
||||
.expect_err("Second 'accept' should time out.");
|
||||
// It never reaches MyListener::accept to be counted.
|
||||
assert_eq!(COUNT.load(Ordering::SeqCst), 1);
|
||||
|
||||
// Close the first connection.
|
||||
drop(conn1);
|
||||
|
||||
// Now 'accept' again succeeds immediately.
|
||||
let _conn2 = listener.accept().await;
|
||||
assert_eq!(COUNT.load(Ordering::SeqCst), 2);
|
||||
}
|
||||
}
|
||||
|
@ -23,7 +23,7 @@ use tower_service::Service;
|
||||
|
||||
mod listener;
|
||||
|
||||
pub use self::listener::{Listener, ListenerExt, TapIo};
|
||||
pub use self::listener::{ConnLimiter, ConnLimiterIo, Listener, ListenerExt, TapIo};
|
||||
|
||||
/// Serve the service with the supplied listener.
|
||||
///
|
||||
|
Loading…
x
Reference in New Issue
Block a user