diff --git a/CHANGELOG.md b/CHANGELOG.md index 8ca21426..6ba45462 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,9 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 # Unreleased -- Implement `Stream` for `WebSocket`. -- Implement `Sink` for `WebSocket`. -- Implement `Deref` most extractors. +- Implement `Stream` for `WebSocket` ([#52](https://github.com/tokio-rs/axum/pull/52)) +- Implement `Sink` for `WebSocket` ([#52](https://github.com/tokio-rs/axum/pull/52)) +- Implement `Deref` most extractors ([#56](https://github.com/tokio-rs/axum/pull/56)) +- Return `405 Method Not Allowed` for unsupported method for route ([#63](https://github.com/tokio-rs/axum/pull/63)) ## Breaking changes diff --git a/src/handler/mod.rs b/src/handler/mod.rs index 12a4a7e0..aaa9792a 100644 --- a/src/handler/mod.rs +++ b/src/handler/mod.rs @@ -168,7 +168,7 @@ where OnMethod { method, svc: handler.into_service(), - fallback: EmptyRouter::new(), + fallback: EmptyRouter::method_not_allowed(), } } diff --git a/src/lib.rs b/src/lib.rs index c2084469..90dee214 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -676,7 +676,7 @@ where { use routing::RoutingDsl; - routing::EmptyRouter::new().route(description, service) + routing::EmptyRouter::not_found().route(description, service) } mod sealed { diff --git a/src/routing.rs b/src/routing.rs index 8871fb95..75c20f3b 100644 --- a/src/routing.rs +++ b/src/routing.rs @@ -84,7 +84,6 @@ pub struct Route { } /// Trait for building routers. -// TODO(david): this name isn't great #[async_trait] pub trait RoutingDsl: crate::sealed::Sealed + Sized { /// Add another route to the router. @@ -364,21 +363,38 @@ fn insert_url_params(req: &mut Request, params: Vec<(String, String)>) { } } -/// A [`Service`] that responds with `404 Not Found` to all requests. +/// A [`Service`] that responds with `404 Not Found` or `405 Method not allowed` +/// to all requests. /// /// This is used as the bottom service in a router stack. You shouldn't have to /// use to manually. -pub struct EmptyRouter(PhantomData E>); +pub struct EmptyRouter { + status: StatusCode, + _marker: PhantomData E>, +} impl EmptyRouter { - pub(crate) fn new() -> Self { - Self(PhantomData) + pub(crate) fn not_found() -> Self { + Self { + status: StatusCode::NOT_FOUND, + _marker: PhantomData, + } + } + + pub(crate) fn method_not_allowed() -> Self { + Self { + status: StatusCode::METHOD_NOT_ALLOWED, + _marker: PhantomData, + } } } impl Clone for EmptyRouter { fn clone(&self) -> Self { - Self(PhantomData) + Self { + status: self.status, + _marker: PhantomData, + } } } @@ -405,7 +421,7 @@ impl Service> for EmptyRouter { fn call(&mut self, _req: Request) -> Self::Future { let mut res = Response::new(crate::body::empty()); - *res.status_mut() = StatusCode::NOT_FOUND; + *res.status_mut() = self.status; EmptyRouterFuture(future::ok(res)) } } @@ -806,7 +822,7 @@ where Nested { pattern: PathPattern::new(description), svc, - fallback: EmptyRouter::new(), + fallback: EmptyRouter::not_found(), } } diff --git a/src/service/mod.rs b/src/service/mod.rs index 39c4a18d..46a7728b 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -256,7 +256,7 @@ where inner: svc, _request_body: PhantomData, }, - fallback: EmptyRouter::new(), + fallback: EmptyRouter::method_not_allowed(), } } diff --git a/src/tests.rs b/src/tests.rs index 9d8fb2b5..d476109d 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -3,6 +3,7 @@ use crate::{ service, }; use bytes::Bytes; +use futures_util::future::Ready; use http::{header::AUTHORIZATION, Request, Response, StatusCode}; use hyper::{Body, Server}; use serde::Deserialize; @@ -10,6 +11,7 @@ use serde_json::json; use std::{ convert::Infallible, net::{SocketAddr, TcpListener}, + task::{Context, Poll}, time::Duration, }; use tower::{make::Shared, service_fn, BoxError, Service, ServiceBuilder}; @@ -677,6 +679,124 @@ async fn test_extractor_middleware() { assert_eq!(res.status(), StatusCode::OK); } +#[tokio::test] +async fn wrong_method_handler() { + let app = route("/", get(|| async {}).post(|| async {})).route("/foo", patch(|| async {})); + + let addr = run_in_background(app).await; + + let client = reqwest::Client::new(); + + let res = client + .patch(format!("http://{}", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::METHOD_NOT_ALLOWED); + + let res = client + .patch(format!("http://{}/foo", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + + let res = client + .post(format!("http://{}/foo", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::METHOD_NOT_ALLOWED); + + let res = client + .get(format!("http://{}/bar", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::NOT_FOUND); +} + +#[tokio::test] +async fn wrong_method_nest() { + let nested_app = route("/", get(|| async {})); + let app = crate::routing::nest("/", nested_app); + + let addr = run_in_background(app).await; + + let client = reqwest::Client::new(); + + let res = client.get(format!("http://{}", addr)).send().await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + + let res = client + .post(format!("http://{}", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::METHOD_NOT_ALLOWED); + + let res = client + .patch(format!("http://{}/foo", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::NOT_FOUND); +} + +#[tokio::test] +async fn wrong_method_service() { + #[derive(Clone)] + struct Svc; + + impl Service for Svc { + type Response = Response>; + type Error = Infallible; + type Future = Ready>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, _req: R) -> Self::Future { + futures_util::future::ok(Response::new(http_body::Empty::new())) + } + } + + let app = route("/", service::get(Svc).post(Svc)).route("/foo", service::patch(Svc)); + + let addr = run_in_background(app).await; + + let client = reqwest::Client::new(); + + let res = client + .patch(format!("http://{}", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::METHOD_NOT_ALLOWED); + + let res = client + .patch(format!("http://{}/foo", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + + let res = client + .post(format!("http://{}/foo", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::METHOD_NOT_ALLOWED); + + let res = client + .get(format!("http://{}/bar", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::NOT_FOUND); +} + /// Run a `tower::Service` in the background and get a URI for it. async fn run_in_background(svc: S) -> SocketAddr where