diff --git a/axum/src/middleware/mod.rs b/axum/src/middleware/mod.rs index 22dab143..41b7671c 100644 --- a/axum/src/middleware/mod.rs +++ b/axum/src/middleware/mod.rs @@ -6,6 +6,7 @@ mod from_extractor; mod from_fn; mod map_request; mod map_response; +mod response_axum_body; pub use self::from_extractor::{ from_extractor, from_extractor_with_state, FromExtractor, FromExtractorLayer, @@ -17,6 +18,9 @@ pub use self::map_request::{ pub use self::map_response::{ map_response, map_response_with_state, MapResponse, MapResponseLayer, }; +pub use self::response_axum_body::{ + ResponseAxumBody, ResponseAxumBodyFuture, ResponseAxumBodyLayer, +}; pub use crate::extension::AddExtension; pub mod future { diff --git a/axum/src/middleware/response_axum_body.rs b/axum/src/middleware/response_axum_body.rs new file mode 100644 index 00000000..786cb4d9 --- /dev/null +++ b/axum/src/middleware/response_axum_body.rs @@ -0,0 +1,76 @@ +use std::{ + error::Error, + future::Future, + pin::Pin, + task::{ready, Context, Poll}, +}; + +use axum_core::{body::Body, response::Response}; +use bytes::Bytes; +use http_body::Body as HttpBody; +use pin_project_lite::pin_project; +use tower::{Layer, Service}; + +/// Layer that transforms the Response body to [`crate::body::Body`]. +/// +/// This is useful when another layer maps the body to some other type to convert it back. +#[derive(Debug, Clone)] +pub struct ResponseAxumBodyLayer; + +impl Layer for ResponseAxumBodyLayer { + type Service = ResponseAxumBody; + + fn layer(&self, inner: S) -> Self::Service { + ResponseAxumBody::(inner) + } +} + +/// Service generated by [`ResponseAxumBodyLayer`]. +#[derive(Debug, Clone)] +pub struct ResponseAxumBody(S); + +impl Service for ResponseAxumBody +where + S: Service>, + ResBody: HttpBody + Send + 'static, + ::Error: Error + Send + Sync, +{ + type Response = Response; + + type Error = S::Error; + + type Future = ResponseAxumBodyFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.0.poll_ready(cx) + } + + fn call(&mut self, req: Request) -> Self::Future { + ResponseAxumBodyFuture { + inner: self.0.call(req), + } + } +} + +pin_project! { + /// Response future for [`ResponseAxumBody`]. + pub struct ResponseAxumBodyFuture { + #[pin] + inner: Fut, + } +} + +impl Future for ResponseAxumBodyFuture +where + Fut: Future, E>>, + ResBody: HttpBody + Send + 'static, + ::Error: Error + Send + Sync, +{ + type Output = Result, E>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + let res = ready!(this.inner.poll(cx)?); + Poll::Ready(Ok(res.map(Body::new))) + } +}