From 8ef96f219958d07e92f1147029de7eb7d3be8cfd Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Mon, 9 Aug 2021 15:29:31 +0200 Subject: [PATCH] Add test for routing matching multiple methods I don't believe we had a test for this --- examples/hello_world.rs | 30 ++++++++++++++---------------- src/tests/mod.rs | 23 +++++++++++++++++++++++ 2 files changed, 37 insertions(+), 16 deletions(-) diff --git a/examples/hello_world.rs b/examples/hello_world.rs index bdb0a877..91a28f68 100644 --- a/examples/hello_world.rs +++ b/examples/hello_world.rs @@ -1,26 +1,24 @@ -//! Run with -//! -//! ```not_rust -//! cargo run --example hello_world -//! ``` - use axum::prelude::*; -use std::net::SocketAddr; +use std::{convert::Infallible, net::SocketAddr, time::Duration}; +use tower::{limit::RateLimitLayer, BoxError, ServiceBuilder}; #[tokio::main] async fn main() { - // Set the RUST_LOG, if it hasn't been explicitly defined - if std::env::var("RUST_LOG").is_err() { - std::env::set_var("RUST_LOG", "hello_world=debug") - } - tracing_subscriber::fmt::init(); + let handler_layer = ServiceBuilder::new() + .buffer(1024) + .layer(RateLimitLayer::new(10, Duration::from_secs(10))) + .into_inner(); - // build our application with a route - let app = route("/", get(handler)); + let app = route( + "/", + get(handler + .layer(handler_layer) + .handle_error(|error: BoxError| { + Ok::<_, Infallible>(format!("Unhandled error: {}", error)) + })), + ); - // run it let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); - tracing::debug!("listening on {}", addr); axum::Server::bind(&addr) .serve(app.into_make_service()) .await diff --git a/src/tests/mod.rs b/src/tests/mod.rs index 9de386c5..bfaea8c1 100644 --- a/src/tests/mod.rs +++ b/src/tests/mod.rs @@ -604,6 +604,29 @@ async fn wrong_method_service() { assert_eq!(res.status(), StatusCode::NOT_FOUND); } +#[tokio::test] +async fn multiple_methods_for_one_handler() { + async fn root(_: Request) -> &'static str { + "Hello, World!" + } + + let app = route("/", on(MethodFilter::GET | MethodFilter::POST, root)); + + let addr = run_in_background(app).await; + + let client = reqwest::Client::new(); + + let res = client.get(format!("http://{}", addr)).send().await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + + let res = client + .post(format!("http://{}", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); +} + /// Run a `tower::Service` in the background and get a URI for it. pub(crate) async fn run_in_background(svc: S) -> SocketAddr where