Use tuples instead of ServiceBuilder internally (#2229)

This commit is contained in:
David Pedersen 2023-09-17 10:56:47 +02:00 committed by GitHub
parent 20f48af914
commit 9eb502c768
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 41 additions and 75 deletions

View File

@ -1254,7 +1254,7 @@ mod tests {
use axum_core::response::IntoResponse; use axum_core::response::IntoResponse;
use http::{header::ALLOW, HeaderMap}; use http::{header::ALLOW, HeaderMap};
use std::time::Duration; use std::time::Duration;
use tower::{timeout::TimeoutLayer, Service, ServiceBuilder, ServiceExt}; use tower::{timeout::TimeoutLayer, Service, ServiceExt};
use tower_http::{services::fs::ServeDir, validate_request::ValidateRequestHeaderLayer}; use tower_http::{services::fs::ServeDir, validate_request::ValidateRequestHeaderLayer};
#[crate::test] #[crate::test]
@ -1354,13 +1354,10 @@ mod tests {
.merge(delete_service(ServeDir::new("."))) .merge(delete_service(ServeDir::new(".")))
.fallback(|| async { StatusCode::NOT_FOUND }) .fallback(|| async { StatusCode::NOT_FOUND })
.put(ok) .put(ok)
.layer( .layer((
ServiceBuilder::new() HandleErrorLayer::new(|_| async { StatusCode::REQUEST_TIMEOUT }),
.layer(HandleErrorLayer::new(|_| async { TimeoutLayer::new(Duration::from_secs(10)),
StatusCode::REQUEST_TIMEOUT )),
}))
.layer(TimeoutLayer::new(Duration::from_secs(10))),
),
); );
let listener = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap(); let listener = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap();

View File

@ -17,8 +17,8 @@ use std::{
task::{Context, Poll}, task::{Context, Poll},
}; };
use tower::{ use tower::{
util::{BoxCloneService, MapResponseLayer, Oneshot}, util::{BoxCloneService, MapErrLayer, MapRequestLayer, MapResponseLayer, Oneshot},
ServiceBuilder, ServiceExt, ServiceExt,
}; };
use tower_layer::Layer; use tower_layer::Layer;
use tower_service::Service; use tower_service::Service;
@ -57,12 +57,12 @@ impl<E> Route<E> {
<L::Service as Service<Request>>::Future: Send + 'static, <L::Service as Service<Request>>::Future: Send + 'static,
NewError: 'static, NewError: 'static,
{ {
let layer = ServiceBuilder::new() let layer = (
.map_request(|req: Request<_>| req.map(Body::new)) MapRequestLayer::new(|req: Request<_>| req.map(Body::new)),
.map_err(Into::into) MapErrLayer::new(Into::into),
.layer(MapResponseLayer::new(IntoResponse::into_response)) MapResponseLayer::new(IntoResponse::into_response),
.layer(layer) layer,
.into_inner(); );
Route::new(layer.layer(self)) Route::new(layer.layer(self))
} }

View File

@ -1,6 +1,6 @@
use super::*; use super::*;
use std::future::{pending, ready}; use std::future::{pending, ready};
use tower::{timeout::TimeoutLayer, ServiceBuilder}; use tower::timeout::TimeoutLayer;
async fn unit() {} async fn unit() {}
@ -33,13 +33,10 @@ impl<R> Service<R> for Svc {
async fn handler() { async fn handler() {
let app = Router::new().route( let app = Router::new().route(
"/", "/",
get(forever.layer( get(forever.layer((
ServiceBuilder::new() HandleErrorLayer::new(|_: BoxError| async { StatusCode::REQUEST_TIMEOUT }),
.layer(HandleErrorLayer::new(|_: BoxError| async { timeout(),
StatusCode::REQUEST_TIMEOUT ))),
}))
.layer(timeout()),
)),
); );
let client = TestClient::new(app); let client = TestClient::new(app);
@ -52,13 +49,10 @@ async fn handler() {
async fn handler_multiple_methods_first() { async fn handler_multiple_methods_first() {
let app = Router::new().route( let app = Router::new().route(
"/", "/",
get(forever.layer( get(forever.layer((
ServiceBuilder::new() HandleErrorLayer::new(|_: BoxError| async { StatusCode::REQUEST_TIMEOUT }),
.layer(HandleErrorLayer::new(|_: BoxError| async { timeout(),
StatusCode::REQUEST_TIMEOUT )))
}))
.layer(timeout()),
))
.post(unit), .post(unit),
); );
@ -73,15 +67,10 @@ async fn handler_multiple_methods_middle() {
let app = Router::new().route( let app = Router::new().route(
"/", "/",
delete(unit) delete(unit)
.get( .get(forever.layer((
forever.layer( HandleErrorLayer::new(|_: BoxError| async { StatusCode::REQUEST_TIMEOUT }),
ServiceBuilder::new() timeout(),
.layer(HandleErrorLayer::new(|_: BoxError| async { )))
StatusCode::REQUEST_TIMEOUT
}))
.layer(timeout()),
),
)
.post(unit), .post(unit),
); );
@ -95,15 +84,10 @@ async fn handler_multiple_methods_middle() {
async fn handler_multiple_methods_last() { async fn handler_multiple_methods_last() {
let app = Router::new().route( let app = Router::new().route(
"/", "/",
delete(unit).get( delete(unit).get(forever.layer((
forever.layer( HandleErrorLayer::new(|_: BoxError| async { StatusCode::REQUEST_TIMEOUT }),
ServiceBuilder::new() timeout(),
.layer(HandleErrorLayer::new(|_: BoxError| async { ))),
StatusCode::REQUEST_TIMEOUT
}))
.layer(timeout()),
),
),
); );
let client = TestClient::new(app); let client = TestClient::new(app);

View File

@ -127,13 +127,10 @@ async fn layer_and_handle_error() {
let one = Router::new().route("/foo", get(|| async {})); let one = Router::new().route("/foo", get(|| async {}));
let two = Router::new() let two = Router::new()
.route("/timeout", get(std::future::pending::<()>)) .route("/timeout", get(std::future::pending::<()>))
.layer( .layer((
ServiceBuilder::new() HandleErrorLayer::new(|_| async { StatusCode::REQUEST_TIMEOUT }),
.layer(HandleErrorLayer::new(|_| async { TimeoutLayer::new(Duration::from_millis(10)),
StatusCode::REQUEST_TIMEOUT ));
}))
.layer(TimeoutLayer::new(Duration::from_millis(10))),
);
let app = one.merge(two); let app = one.merge(two);
let client = TestClient::new(app); let client = TestClient::new(app);

View File

@ -30,9 +30,7 @@ use std::{
task::{Context, Poll}, task::{Context, Poll},
time::Duration, time::Duration,
}; };
use tower::{ use tower::{service_fn, timeout::TimeoutLayer, util::MapResponseLayer, ServiceExt};
service_fn, timeout::TimeoutLayer, util::MapResponseLayer, ServiceBuilder, ServiceExt,
};
use tower_http::{limit::RequestBodyLimitLayer, validate_request::ValidateRequestHeaderLayer}; use tower_http::{limit::RequestBodyLimitLayer, validate_request::ValidateRequestHeaderLayer};
use tower_service::Service; use tower_service::Service;
@ -179,7 +177,6 @@ async fn routing_between_services() {
#[crate::test] #[crate::test]
async fn middleware_on_single_route() { async fn middleware_on_single_route() {
use tower::ServiceBuilder;
use tower_http::{compression::CompressionLayer, trace::TraceLayer}; use tower_http::{compression::CompressionLayer, trace::TraceLayer};
async fn handle(_: Request) -> &'static str { async fn handle(_: Request) -> &'static str {
@ -188,12 +185,7 @@ async fn middleware_on_single_route() {
let app = Router::new().route( let app = Router::new().route(
"/", "/",
get(handle.layer( get(handle.layer((TraceLayer::new_for_http(), CompressionLayer::new()))),
ServiceBuilder::new()
.layer(TraceLayer::new_for_http())
.layer(CompressionLayer::new())
.into_inner(),
)),
); );
let client = TestClient::new(app); let client = TestClient::new(app);
@ -309,13 +301,10 @@ async fn wildcard_sees_whole_url() {
async fn middleware_applies_to_routes_above() { async fn middleware_applies_to_routes_above() {
let app = Router::new() let app = Router::new()
.route("/one", get(std::future::pending::<()>)) .route("/one", get(std::future::pending::<()>))
.layer( .layer((
ServiceBuilder::new() HandleErrorLayer::new(|_: BoxError| async move { StatusCode::REQUEST_TIMEOUT }),
.layer(HandleErrorLayer::new(|_: BoxError| async move { TimeoutLayer::new(Duration::new(0, 0)),
StatusCode::REQUEST_TIMEOUT ))
}))
.layer(TimeoutLayer::new(Duration::new(0, 0))),
)
.route("/two", get(|| async {})); .route("/two", get(|| async {}));
let client = TestClient::new(app); let client = TestClient::new(app);

View File

@ -14,7 +14,6 @@ use axum::{
routing::post, routing::post,
Router, Router,
}; };
use tower::ServiceBuilder;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
#[tokio::main] #[tokio::main]
@ -29,7 +28,7 @@ async fn main() {
let app = Router::new() let app = Router::new()
.route("/", post(handler)) .route("/", post(handler))
.layer(ServiceBuilder::new().layer(middleware::from_fn(print_request_body))); .layer(middleware::from_fn(print_request_body));
let listener = tokio::net::TcpListener::bind("127.0.0.1:3000") let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
.await .await