mirror of
https://github.com/tokio-rs/axum.git
synced 2025-10-02 07:20:38 +00:00
Add MockConnectInfo
(#1767)
This commit is contained in:
parent
cd86f7ec7a
commit
143c415955
@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||||||
- **added:** Add `FormRejection::FailedToDeserializeFormBody` which is returned
|
- **added:** Add `FormRejection::FailedToDeserializeFormBody` which is returned
|
||||||
if the request body couldn't be deserialized into the target type, as opposed
|
if the request body couldn't be deserialized into the target type, as opposed
|
||||||
to `FailedToDeserializeForm` which is only for query parameters ([#1683])
|
to `FailedToDeserializeForm` which is only for query parameters ([#1683])
|
||||||
|
- **added:** Add `MockConnectInfo` for setting `ConnectInfo` during tests
|
||||||
|
|
||||||
[#1683]: https://github.com/tokio-rs/axum/pull/1683
|
[#1683]: https://github.com/tokio-rs/axum/pull/1683
|
||||||
[#1690]: https://github.com/tokio-rs/axum/pull/1690
|
[#1690]: https://github.com/tokio-rs/axum/pull/1690
|
||||||
|
@ -137,15 +137,81 @@ where
|
|||||||
type Rejection = <Extension<Self> as FromRequestParts<S>>::Rejection;
|
type Rejection = <Extension<Self> as FromRequestParts<S>>::Rejection;
|
||||||
|
|
||||||
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
|
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
|
||||||
let Extension(connect_info) = Extension::<Self>::from_request_parts(parts, state).await?;
|
match Extension::<Self>::from_request_parts(parts, state).await {
|
||||||
Ok(connect_info)
|
Ok(Extension(connect_info)) => Ok(connect_info),
|
||||||
|
Err(err) => match parts.extensions.get::<MockConnectInfo<T>>() {
|
||||||
|
Some(MockConnectInfo(connect_info)) => Ok(Self(connect_info.clone())),
|
||||||
|
None => Err(err),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Middleware used to mock [`ConnectInfo`] during tests.
|
||||||
|
///
|
||||||
|
/// If you're accidentally using [`MockConnectInfo`] and
|
||||||
|
/// [`Router::into_make_service_with_connect_info`] at the same time then
|
||||||
|
/// [`Router::into_make_service_with_connect_info`] takes precedence.
|
||||||
|
///
|
||||||
|
/// # Example
|
||||||
|
///
|
||||||
|
/// ```
|
||||||
|
/// use axum::{
|
||||||
|
/// Router,
|
||||||
|
/// extract::connect_info::{MockConnectInfo, ConnectInfo},
|
||||||
|
/// body::Body,
|
||||||
|
/// routing::get,
|
||||||
|
/// http::{Request, StatusCode},
|
||||||
|
/// };
|
||||||
|
/// use std::net::SocketAddr;
|
||||||
|
/// use tower::ServiceExt;
|
||||||
|
///
|
||||||
|
/// async fn handler(ConnectInfo(addr): ConnectInfo<SocketAddr>) {}
|
||||||
|
///
|
||||||
|
/// // this router you can run with `app.into_make_service_with_connect_info::<SocketAddr>()`
|
||||||
|
/// fn app() -> Router {
|
||||||
|
/// Router::new().route("/", get(handler))
|
||||||
|
/// }
|
||||||
|
///
|
||||||
|
/// // use this router for tests
|
||||||
|
/// fn test_app() -> Router {
|
||||||
|
/// app().layer(MockConnectInfo(SocketAddr::from(([0, 0, 0, 0], 1337))))
|
||||||
|
/// }
|
||||||
|
///
|
||||||
|
/// // #[tokio::test]
|
||||||
|
/// async fn some_test() {
|
||||||
|
/// let app = test_app();
|
||||||
|
///
|
||||||
|
/// let request = Request::new(Body::empty());
|
||||||
|
/// let response = app.oneshot(request).await.unwrap();
|
||||||
|
/// assert_eq!(response.status(), StatusCode::OK);
|
||||||
|
/// }
|
||||||
|
/// #
|
||||||
|
/// # #[tokio::main]
|
||||||
|
/// # async fn main() {
|
||||||
|
/// # some_test().await;
|
||||||
|
/// # }
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// [`Router::into_make_service_with_connect_info`]: crate::Router::into_make_service_with_connect_info
|
||||||
|
#[derive(Clone, Copy, Debug)]
|
||||||
|
pub struct MockConnectInfo<T>(pub T);
|
||||||
|
|
||||||
|
impl<S, T> Layer<S> for MockConnectInfo<T>
|
||||||
|
where
|
||||||
|
T: Clone + Send + Sync + 'static,
|
||||||
|
{
|
||||||
|
type Service = <Extension<Self> as Layer<S>>::Service;
|
||||||
|
|
||||||
|
fn layer(&self, inner: S) -> Self::Service {
|
||||||
|
Extension(self.clone()).layer(inner)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::{routing::get, Router, Server};
|
use crate::{routing::get, test_helpers::TestClient, Router, Server};
|
||||||
use std::net::{SocketAddr, TcpListener};
|
use std::net::{SocketAddr, TcpListener};
|
||||||
|
|
||||||
#[crate::test]
|
#[crate::test]
|
||||||
@ -214,4 +280,48 @@ mod tests {
|
|||||||
let body = res.text().await.unwrap();
|
let body = res.text().await.unwrap();
|
||||||
assert_eq!(body, "it worked!");
|
assert_eq!(body, "it worked!");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[crate::test]
|
||||||
|
async fn mock_connect_info() {
|
||||||
|
async fn handler(ConnectInfo(addr): ConnectInfo<SocketAddr>) -> String {
|
||||||
|
format!("{addr}")
|
||||||
|
}
|
||||||
|
|
||||||
|
let app = Router::new()
|
||||||
|
.route("/", get(handler))
|
||||||
|
.layer(MockConnectInfo(SocketAddr::from(([0, 0, 0, 0], 1337))));
|
||||||
|
|
||||||
|
let client = TestClient::new(app);
|
||||||
|
|
||||||
|
let res = client.get("/").send().await;
|
||||||
|
let body = res.text().await;
|
||||||
|
assert!(body.starts_with("0.0.0.0:1337"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[crate::test]
|
||||||
|
async fn both_mock_and_real_connect_info() {
|
||||||
|
async fn handler(ConnectInfo(addr): ConnectInfo<SocketAddr>) -> String {
|
||||||
|
format!("{addr}")
|
||||||
|
}
|
||||||
|
|
||||||
|
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
|
||||||
|
let addr = listener.local_addr().unwrap();
|
||||||
|
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let app = Router::new()
|
||||||
|
.route("/", get(handler))
|
||||||
|
.layer(MockConnectInfo(SocketAddr::from(([0, 0, 0, 0], 1337))));
|
||||||
|
|
||||||
|
let server = Server::from_tcp(listener)
|
||||||
|
.unwrap()
|
||||||
|
.serve(app.into_make_service_with_connect_info::<SocketAddr>());
|
||||||
|
server.await.expect("server error");
|
||||||
|
});
|
||||||
|
|
||||||
|
let client = reqwest::Client::new();
|
||||||
|
|
||||||
|
let res = client.get(format!("http://{addr}")).send().await.unwrap();
|
||||||
|
let body = res.text().await.unwrap();
|
||||||
|
assert!(body.starts_with("127.0.0.1:"));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -4,7 +4,10 @@
|
|||||||
//! cargo test -p example-testing
|
//! cargo test -p example-testing
|
||||||
//! ```
|
//! ```
|
||||||
|
|
||||||
|
use std::net::SocketAddr;
|
||||||
|
|
||||||
use axum::{
|
use axum::{
|
||||||
|
extract::ConnectInfo,
|
||||||
routing::{get, post},
|
routing::{get, post},
|
||||||
Json, Router,
|
Json, Router,
|
||||||
};
|
};
|
||||||
@ -43,6 +46,10 @@ fn app() -> Router {
|
|||||||
Json(serde_json::json!({ "data": payload.0 }))
|
Json(serde_json::json!({ "data": payload.0 }))
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
.route(
|
||||||
|
"/requires-connect-into",
|
||||||
|
get(|ConnectInfo(addr): ConnectInfo<SocketAddr>| async move { format!("Hi {addr}") }),
|
||||||
|
)
|
||||||
// We can still add middleware
|
// We can still add middleware
|
||||||
.layer(TraceLayer::new_for_http())
|
.layer(TraceLayer::new_for_http())
|
||||||
}
|
}
|
||||||
@ -52,6 +59,7 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
use axum::{
|
use axum::{
|
||||||
body::Body,
|
body::Body,
|
||||||
|
extract::connect_info::MockConnectInfo,
|
||||||
http::{self, Request, StatusCode},
|
http::{self, Request, StatusCode},
|
||||||
};
|
};
|
||||||
use serde_json::{json, Value};
|
use serde_json::{json, Value};
|
||||||
@ -164,4 +172,21 @@ mod tests {
|
|||||||
let response = app.ready().await.unwrap().call(request).await.unwrap();
|
let response = app.ready().await.unwrap().call(request).await.unwrap();
|
||||||
assert_eq!(response.status(), StatusCode::OK);
|
assert_eq!(response.status(), StatusCode::OK);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Here we're calling `/requires-connect-into` which requires `ConnectInfo`
|
||||||
|
//
|
||||||
|
// That is normally set with `Router::into_make_service_with_connect_info` but we can't easily
|
||||||
|
// use that during tests. The solution is instead to set the `MockConnectInfo` layer during
|
||||||
|
// tests.
|
||||||
|
#[tokio::test]
|
||||||
|
async fn with_into_make_service_with_connect_info() {
|
||||||
|
let mut app = app().layer(MockConnectInfo(SocketAddr::from(([0, 0, 0, 0], 3000))));
|
||||||
|
|
||||||
|
let request = Request::builder()
|
||||||
|
.uri("/requires-connect-into")
|
||||||
|
.body(Body::empty())
|
||||||
|
.unwrap();
|
||||||
|
let response = app.ready().await.unwrap().call(request).await.unwrap();
|
||||||
|
assert_eq!(response.status(), StatusCode::OK);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user