Fix IntoResponse for tuples overriding error response codes

Co-authored-by: David Pedersen <david.pdrsn@gmail.com>
Co-authored-by: Yann Simon <yann.simon@commercetools.com>
This commit is contained in:
Jonas Platte 2025-12-27 20:45:53 +01:00
parent e3b32f48b5
commit a997c76b56
No known key found for this signature in database
GPG Key ID: 7D261D771D915378
10 changed files with 479 additions and 38 deletions

1
Cargo.lock generated
View File

@ -171,6 +171,7 @@ dependencies = [
"hyper",
"mime",
"pin-project-lite",
"serde",
"sync_wrapper",
"tokio",
"tower-http",

View File

@ -55,6 +55,7 @@ axum = { path = "../axum", features = ["__private"] }
axum-extra = { path = "../axum-extra", features = ["typed-header"] }
axum-macros = { path = "../axum-macros", features = ["__private"] }
hyper = "1.0.0"
serde = { version = "1.0.200", features = ["derive"] }
tokio = { version = "1.25.0", features = ["macros"] }
tower-http = { version = "0.6.0", features = ["limit"] }

View File

@ -1,4 +1,4 @@
use super::{IntoResponseParts, Response, ResponseParts};
use super::{ForceStatusCode, IntoResponseFailed, IntoResponseParts, Response, ResponseParts};
use crate::{body::Body, BoxError};
use bytes::{buf::Chain, Buf, Bytes, BytesMut};
use http::{
@ -329,7 +329,9 @@ where
{
fn into_response(self) -> Response {
let mut res = self.1.into_response();
*res.status_mut() = self.0;
if res.extensions().get::<IntoResponseFailed>().is_none() {
*res.status_mut() = self.0;
}
res
}
}
@ -405,18 +407,16 @@ macro_rules! impl_into_response {
let ($($ty),*, res) = self;
let res = res.into_response();
let parts = ResponseParts { res };
$(
let parts = match $ty.into_response_parts(parts) {
if res.extensions().get::<IntoResponseFailed>().is_none() {
let parts = ResponseParts { res };
let parts = match ($($ty,)*).into_response_parts(parts) {
Ok(parts) => parts,
Err(err) => {
return err.into_response();
}
Err(err) => return err.into_response(),
};
)*
parts.res
parts.res
} else {
res
}
}
}
@ -430,16 +430,40 @@ macro_rules! impl_into_response {
let (status, $($ty),*, res) = self;
let res = res.into_response();
let parts = ResponseParts { res };
$(
let parts = match $ty.into_response_parts(parts) {
if res.extensions().get::<IntoResponseFailed>().is_none() {
let parts = ResponseParts { res };
let mut parts = match ($($ty,)*).into_response_parts(parts) {
Ok(parts) => parts,
Err(err) => {
return err.into_response();
}
Err(err) => return err.into_response(),
};
)*
// Don't call `(status, parts.res).into_response()` since that checks for
// `IntoResponseFailed` and skips setting the status. We've already done that
// check here so overriding the status is required if returning
// `(IntoResponseFailed, StatusCode::INTERNAL_SERVER_ERROR)`
*parts.res.status_mut() = status;
parts.res
} else {
res
}
}
}
#[allow(non_snake_case)]
impl<R, $($ty,)*> IntoResponse for (ForceStatusCode, $($ty),*, R)
where
$( $ty: IntoResponseParts, )*
R: IntoResponse,
{
fn into_response(self) -> Response {
let (status, $($ty),*, res) = self;
let res = res.into_response();
let parts = ResponseParts { res };
let parts = match ($($ty,)*).into_response_parts(parts) {
Ok(parts) => parts,
Err(err) => return err.into_response(),
};
(status, parts.res).into_response()
}
@ -455,17 +479,22 @@ macro_rules! impl_into_response {
let (outer_parts, $($ty),*, res) = self;
let res = res.into_response();
let parts = ResponseParts { res };
$(
let parts = match $ty.into_response_parts(parts) {
if res.extensions().get::<IntoResponseFailed>().is_none() {
let parts = ResponseParts { res };
let mut parts = match ($($ty,)*).into_response_parts(parts) {
Ok(parts) => parts,
Err(err) => {
return err.into_response();
}
Err(err) => return err.into_response(),
};
)*
(outer_parts, parts.res).into_response()
// Don't call `(outer_parts, parts.res).into_response()` for the same reason we
// don't call `(status, parts.res).into_response()` in the above impl.
*parts.res.status_mut() = outer_parts.status;
parts.res.headers_mut().extend(outer_parts.headers);
parts.res.extensions_mut().extend(outer_parts.extensions);
parts.res
} else {
res
}
}
}

View File

@ -241,7 +241,9 @@ macro_rules! impl_into_response_parts {
let res = match $ty.into_response_parts(res) {
Ok(res) => res,
Err(err) => {
return Err(err.into_response());
let mut err_res = err.into_response();
err_res.extensions_mut().insert(super::IntoResponseFailed);
return Err(err_res);
}
};
)*
@ -270,3 +272,19 @@ impl IntoResponseParts for () {
Ok(res)
}
}
#[cfg(test)]
mod tests {
use http::StatusCode;
use crate::response::IntoResponse;
#[test]
fn failed_into_response_parts() {
let response = (StatusCode::CREATED, [("\n", "\n")]).into_response();
assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
let response = (StatusCode::CREATED, [("\n", "\n")], ()).into_response();
assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
}
}

View File

@ -4,6 +4,10 @@
//!
//! [`axum::response`]: https://docs.rs/axum/0.8/axum/response/index.html
use std::convert::Infallible;
use http::StatusCode;
use crate::body::Body;
mod append_headers;
@ -128,3 +132,87 @@ where
Self(value.into_response())
}
}
/// Response part that stops status code overrides.
///
/// This type should be used by types implementing [`IntoResponseParts`] or
/// [`IntoResponse`] when they fail to produce the response usually expected of
/// them and return some sort of error response instead.
///
/// It is checked used by the tuple impls of [`IntoResponse`] that have a
/// [`StatusCode`] as their first element to ignore that status code.
/// Consider the following example:
///
/// ```no_run
/// # use axum::Json;
/// # use http::StatusCode;
/// # #[derive(serde::Serialize)]
/// # struct CreatedResponse { }
/// fn my_handler(/* ... */) -> (StatusCode, Json<CreatedResponse>) {
/// // This response type's serialization may fail
/// let response = CreatedResponse { /* ... */ };
/// (StatusCode::CREATED, Json(response))
/// }
/// ```
///
/// When `response` serialization succeeds, the server responds with a status
/// code of 201 Created (overwriting `Json`s default status code of 200 OK),
/// and the expected JSON payload.
///
/// When `response` serialization fails hoewever, `impl IntoResponse for Json`
/// return a response with status code 500 Internal Server Error, and
/// `IntoResponseFailed` as a response extension, and the 201 Created override
/// is ignored.
///
/// This is a behavior introduced with axum 0.9.\
/// To force a status code override even when an inner [`IntoResponseParts`] /
/// [`IntoResponse`] failed, use [`ForceStatusCode`].
#[derive(Copy, Clone, Debug)]
pub struct IntoResponseFailed;
impl IntoResponseParts for IntoResponseFailed {
type Error = Infallible;
fn into_response_parts(self, mut res: ResponseParts) -> Result<ResponseParts, Self::Error> {
res.extensions_mut().insert(self);
Ok(res)
}
}
/// Not sure it makes sense to return `IntoResponseFailed` as the whole response. You should
/// probably at least combine it with a status code.
///
/// ```compile_fail
/// fn foo()
/// where
/// axum_core::response::IntoResponseFailed: axum_core::response::IntoResponse,
/// {}
/// ```
#[allow(dead_code)]
fn into_response_failed_doesnt_impl_into_response() {}
/// Set the status code regardless of whether [`IntoResponseFailed`] is used or not.
///
/// See the docs for [`IntoResponseFailed`] for more details.
#[derive(Debug, Copy, Clone, Default)]
pub struct ForceStatusCode(pub StatusCode);
impl IntoResponse for ForceStatusCode {
fn into_response(self) -> Response {
let mut res = ().into_response();
*res.status_mut() = self.0;
res
}
}
impl<R> IntoResponse for (ForceStatusCode, R)
where
R: IntoResponse,
{
fn into_response(self) -> Response {
let (ForceStatusCode(status), res) = self;
let mut res = res.into_response();
*res.status_mut() = status;
res
}
}

View File

@ -4,7 +4,7 @@ use axum_core::__composite_rejection as composite_rejection;
use axum_core::__define_rejection as define_rejection;
use axum_core::{
extract::{rejection::BytesRejection, FromRequest, Request},
response::{IntoResponse, Response},
response::{IntoResponse, IntoResponseFailed, Response},
RequestExt,
};
use bytes::BytesMut;
@ -131,7 +131,12 @@ where
let mut buf = BytesMut::with_capacity(self.0.encoded_len());
match &self.0.encode(&mut buf) {
Ok(()) => buf.into_response(),
Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response(),
Err(err) => (
StatusCode::INTERNAL_SERVER_ERROR,
IntoResponseFailed,
err.to_string(),
)
.into_response(),
}
}
}

View File

@ -1,6 +1,6 @@
use std::sync::Arc;
use axum_core::response::{IntoResponse, Response};
use axum_core::response::{IntoResponse, IntoResponseFailed, Response};
use bytes::{BufMut, Bytes, BytesMut};
use http::{header, HeaderValue, StatusCode};
use serde_core::Serialize;
@ -78,7 +78,12 @@ impl IntoResponse for ErasedJson {
bytes,
)
.into_response(),
Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response(),
Err(err) => (
StatusCode::INTERNAL_SERVER_ERROR,
IntoResponseFailed,
err.to_string(),
)
.into_response(),
}
}
}

View File

@ -1,6 +1,6 @@
use crate::extract::Request;
use crate::extract::{rejection::*, FromRequest, RawForm};
use axum_core::response::{IntoResponse, Response};
use axum_core::response::{IntoResponse, IntoResponseFailed, Response};
use axum_core::RequestExt;
use http::header::CONTENT_TYPE;
use http::StatusCode;
@ -117,7 +117,12 @@ where
body,
)
.into_response(),
Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response(),
Err(err) => (
StatusCode::INTERNAL_SERVER_ERROR,
IntoResponseFailed,
err.to_string(),
)
.into_response(),
}
}

View File

@ -1,7 +1,7 @@
use crate::extract::Request;
use crate::extract::{rejection::*, FromRequest};
use axum_core::extract::OptionalFromRequest;
use axum_core::response::{IntoResponse, Response};
use axum_core::response::{IntoResponse, IntoResponseFailed, Response};
use bytes::{BufMut, Bytes, BytesMut};
use http::{
header::{self, HeaderMap, HeaderValue},
@ -216,6 +216,7 @@ where
header::CONTENT_TYPE,
HeaderValue::from_static(mime::TEXT_PLAIN_UTF_8.as_ref()),
)],
IntoResponseFailed,
err.to_string(),
)
.into_response(),

View File

@ -19,7 +19,8 @@ pub use crate::Extension;
#[doc(inline)]
pub use axum_core::response::{
AppendHeaders, ErrorResponse, IntoResponse, IntoResponseParts, Response, ResponseParts, Result,
AppendHeaders, ErrorResponse, IntoResponse, IntoResponseFailed, IntoResponseParts, Response,
ResponseParts, Result,
};
#[doc(inline)]
@ -85,10 +86,16 @@ impl IntoResponse for NoContent {
#[cfg(test)]
mod tests {
use crate::extract::Extension;
use crate::test_helpers::*;
use crate::Json;
use crate::{routing::get, Router};
use axum_core::response::IntoResponse;
use axum_core::response::ForceStatusCode;
use axum_core::response::{
IntoResponse, IntoResponseFailed, IntoResponseParts, Response, ResponseParts,
};
use http::HeaderMap;
use http::{StatusCode, Uri};
use std::collections::HashMap;
// just needs to compile
#[allow(dead_code)]
@ -247,6 +254,287 @@ mod tests {
.route("/", get(header_array_extension_mixed_body));
}
#[test]
fn status_code_tuple_doesnt_override_error() {
// sanity check where there is just one status code
assert_eq!(
StatusCode::INTERNAL_SERVER_ERROR.into_response().status(),
StatusCode::INTERNAL_SERVER_ERROR
);
assert_eq!(
(StatusCode::INTERNAL_SERVER_ERROR,)
.into_response()
.status(),
StatusCode::INTERNAL_SERVER_ERROR
);
// non-5xx status should be changed
assert_eq!(
(StatusCode::SEE_OTHER, StatusCode::NO_CONTENT)
.into_response()
.status(),
StatusCode::SEE_OTHER
);
let res = (
StatusCode::SEE_OTHER,
[("location", "foo")],
StatusCode::NO_CONTENT,
)
.into_response();
assert_eq!(res.status(), StatusCode::SEE_OTHER);
assert_eq!(res.headers()["location"], "foo");
// 5xx status codes are also changed
assert_eq!(
(StatusCode::SEE_OTHER, StatusCode::INTERNAL_SERVER_ERROR)
.into_response()
.status(),
StatusCode::SEE_OTHER
);
let res = (
StatusCode::SEE_OTHER,
[("location", "foo")],
StatusCode::INTERNAL_SERVER_ERROR,
)
.into_response();
assert_eq!(res.status(), StatusCode::SEE_OTHER);
assert_eq!(res.headers()["location"], "foo");
// the status is not changed if `IntoResponseFailed` is used
assert_eq!(
(
StatusCode::SEE_OTHER,
(IntoResponseFailed, StatusCode::INTERNAL_SERVER_ERROR)
)
.into_response()
.status(),
StatusCode::INTERNAL_SERVER_ERROR
);
let res = (
StatusCode::SEE_OTHER,
[("location", "foo")],
(IntoResponseFailed, StatusCode::INTERNAL_SERVER_ERROR),
)
.into_response();
assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
assert!(res.headers().get("location").is_none());
// response parts from the inner response do run
let res = (
// with status override
StatusCode::SEE_OTHER,
[("location", "foo")],
(
[("x-bar", "bar")],
IntoResponseFailed,
[("x-foo", "foo")],
StatusCode::INTERNAL_SERVER_ERROR,
),
)
.into_response();
assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
assert!(res.headers().get("location").is_none());
assert_eq!(res.headers()["x-foo"], "foo");
assert_eq!(res.headers()["x-bar"], "bar");
let res = (
// without status override
[("location", "foo")],
(
[("x-bar", "bar")],
IntoResponseFailed,
[("x-foo", "foo")],
StatusCode::INTERNAL_SERVER_ERROR,
),
)
.into_response();
assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
assert!(res.headers().get("location").is_none());
assert_eq!(res.headers()["x-foo"], "foo");
assert_eq!(res.headers()["x-bar"], "bar");
// (Parts, ...)
let res = (
Response::new(()).into_parts().0,
[("location", "foo")],
(
[("x-bar", "bar")],
IntoResponseFailed,
[("x-foo", "foo")],
StatusCode::INTERNAL_SERVER_ERROR,
),
)
.into_response();
assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
assert!(res.headers().get("location").is_none());
assert_eq!(res.headers()["x-foo"], "foo");
assert_eq!(res.headers()["x-bar"], "bar");
// (Response<()>, ...)
let res = (
Response::new(()),
[("location", "foo")],
(
[("x-bar", "bar")],
IntoResponseFailed,
[("x-foo", "foo")],
StatusCode::INTERNAL_SERVER_ERROR,
),
)
.into_response();
assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
assert!(res.headers().get("location").is_none());
assert_eq!(res.headers()["x-foo"], "foo");
assert_eq!(res.headers()["x-bar"], "bar");
}
#[test]
fn into_response_parts_failing_sets_extension() {
struct Fail;
impl IntoResponseParts for Fail {
type Error = ();
fn into_response_parts(
self,
_res: ResponseParts,
) -> Result<ResponseParts, Self::Error> {
Err(())
}
}
impl IntoResponse for Fail {
fn into_response(self) -> Response {
(self, ()).into_response()
}
}
assert!(Fail
.into_response()
.extensions()
.get::<IntoResponseFailed>()
.is_some());
assert!((StatusCode::INTERNAL_SERVER_ERROR, Fail, ())
.into_response()
.extensions()
.get::<IntoResponseFailed>()
.is_some());
assert!((Response::new(()).into_parts().0, Fail, ())
.into_response()
.extensions()
.get::<IntoResponseFailed>()
.is_some());
assert!((Response::new(()), Fail, ())
.into_response()
.extensions()
.get::<IntoResponseFailed>()
.is_some());
}
#[test]
fn doenst_override_status_code_when_using_into_response_failed_at_same_level() {
assert_eq!(
(StatusCode::INTERNAL_SERVER_ERROR, IntoResponseFailed, ())
.into_response()
.status(),
StatusCode::INTERNAL_SERVER_ERROR,
);
#[derive(Clone)]
struct Thing;
let res = (
Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.header("x-foo", "foo")
.extension(Thing)
.body(())
.unwrap()
.into_parts()
.0,
IntoResponseFailed,
(),
)
.into_response();
assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
assert_eq!(res.headers()["x-foo"], "foo");
assert!(res.extensions().get::<Thing>().is_some());
// just a sanity check
assert_eq!(
(IntoResponseFailed, ()).into_response().status(),
StatusCode::OK,
);
}
#[test]
fn force_overriding_status_code() {
assert_eq!(
ForceStatusCode(StatusCode::IM_A_TEAPOT)
.into_response()
.status(),
StatusCode::IM_A_TEAPOT
);
assert_eq!(
(ForceStatusCode(StatusCode::IM_A_TEAPOT),)
.into_response()
.status(),
StatusCode::IM_A_TEAPOT
);
assert_eq!(
(ForceStatusCode(StatusCode::IM_A_TEAPOT), ())
.into_response()
.status(),
StatusCode::IM_A_TEAPOT
);
assert_eq!(
(
ForceStatusCode(StatusCode::IM_A_TEAPOT),
IntoResponseFailed,
StatusCode::INTERNAL_SERVER_ERROR,
)
.into_response()
.status(),
StatusCode::IM_A_TEAPOT
);
}
#[crate::test]
async fn status_code_tuple_doesnt_override_error_json() {
let app = Router::new()
.route(
"/",
get(|| async {
let not_json_compatible = HashMap::from([(Vec::from([1, 2, 3]), 123)]);
(StatusCode::IM_A_TEAPOT, Json(not_json_compatible))
}),
)
.route(
"/two",
get(|| async {
let not_json_compatible = HashMap::from([(Vec::from([1, 2, 3]), 123)]);
(
ForceStatusCode(StatusCode::IM_A_TEAPOT),
Json(not_json_compatible),
)
}),
);
let client = TestClient::new(app);
let res = client.get("/").await;
assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
let res = client.get("/two").await;
assert_eq!(res.status(), StatusCode::IM_A_TEAPOT);
}
#[test]
fn no_content() {
assert_eq!(