diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index 9f28fd64..55314de5 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -9,6 +9,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **fixed:** Nested routers will now inherit fallbacks from outer routers ([#1521]) - **added:** Add `accept_unmasked_frames` setting in WebSocketUpgrade ([#1529]) +- **added:** Add `WebSocketUpgrade::on_failed_upgrade` to customize what to do + when upgrading a connection fails ([#1539]) + +[#1539]: https://github.com/tokio-rs/axum/pull/1539 [#1521]: https://github.com/tokio-rs/axum/pull/1521 diff --git a/axum/src/extract/ws.rs b/axum/src/extract/ws.rs index 63e938c6..2f55a6ba 100644 --- a/axum/src/extract/ws.rs +++ b/axum/src/extract/ws.rs @@ -134,18 +134,29 @@ use tokio_tungstenite::{ /// rejected. /// /// See the [module docs](self) for an example. -#[derive(Debug)] #[cfg_attr(docsrs, doc(cfg(feature = "ws")))] -pub struct WebSocketUpgrade { +pub struct WebSocketUpgrade { config: WebSocketConfig, /// The chosen protocol sent in the `Sec-WebSocket-Protocol` header of the response. protocol: Option, sec_websocket_key: HeaderValue, on_upgrade: OnUpgrade, + on_failed_upgrade: F, sec_websocket_protocol: Option, } -impl WebSocketUpgrade { +impl std::fmt::Debug for WebSocketUpgrade { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("WebSocketUpgrade") + .field("config", &self.config) + .field("protocol", &self.protocol) + .field("sec_websocket_key", &self.sec_websocket_key) + .field("sec_websocket_protocol", &self.sec_websocket_protocol) + .finish_non_exhaustive() + } +} + +impl WebSocketUpgrade { /// Set the size of the internal message send queue. pub fn max_send_queue(mut self, max: usize) -> Self { self.config.max_send_queue = Some(max); @@ -231,24 +242,71 @@ impl WebSocketUpgrade { self } + /// Provide a callback to call if upgrading the connection fails. + /// + /// The connection upgrade is performed in a background task. If that fails this callback + /// will be called. + /// + /// By default any errors will be silently ignored. + /// + /// # Example + /// + /// ``` + /// use axum::{ + /// extract::{WebSocketUpgrade}, + /// response::Response, + /// }; + /// + /// async fn handler(ws: WebSocketUpgrade) -> Response { + /// ws.on_failed_upgrade(|error| { + /// report_error(error); + /// }) + /// .on_upgrade(|socket| async { /* ... */ }) + /// } + /// # + /// # fn report_error(_: axum::Error) {} + /// ``` + pub fn on_failed_upgrade(self, callback: C) -> WebSocketUpgrade + where + C: OnFailedUpdgrade, + { + WebSocketUpgrade { + config: self.config, + protocol: self.protocol, + sec_websocket_key: self.sec_websocket_key, + on_upgrade: self.on_upgrade, + on_failed_upgrade: callback, + sec_websocket_protocol: self.sec_websocket_protocol, + } + } + /// Finalize upgrading the connection and call the provided callback with /// the stream. /// /// When using `WebSocketUpgrade`, the response produced by this method /// should be returned from the handler. See the [module docs](self) for an /// example. - pub fn on_upgrade(self, callback: F) -> Response + pub fn on_upgrade(self, callback: C) -> Response where - F: FnOnce(WebSocket) -> Fut + Send + 'static, + C: FnOnce(WebSocket) -> Fut + Send + 'static, Fut: Future + Send + 'static, + F: OnFailedUpdgrade, { let on_upgrade = self.on_upgrade; let config = self.config; + let on_failed_upgrade = self.on_failed_upgrade; let protocol = self.protocol.clone(); tokio::spawn(async move { - let upgraded = on_upgrade.await.expect("connection upgrade failed"); + let upgraded = match on_upgrade.await { + Ok(upgraded) => upgraded, + Err(err) => { + on_failed_upgrade.call(Error::new(err)); + return; + } + }; + let socket = WebSocketStream::from_raw_socket(upgraded, protocol::Role::Server, Some(config)) .await; @@ -281,8 +339,37 @@ impl WebSocketUpgrade { } } +/// What to do when a connection upgrade fails. +/// +/// See [`WebSocketUpgrade::on_failed_upgrade`] for more details. +pub trait OnFailedUpdgrade: Send + 'static { + /// Call the callback. + fn call(self, error: Error); +} + +impl OnFailedUpdgrade for F +where + F: FnOnce(Error) + Send + 'static, +{ + fn call(self, error: Error) { + self(error) + } +} + +/// The default `OnFailedUpdgrade` used by `WebSocketUpgrade`. +/// +/// It simply ignores the error. +#[non_exhaustive] +#[derive(Debug)] +pub struct DefaultOnFailedUpdgrade; + +impl OnFailedUpdgrade for DefaultOnFailedUpdgrade { + #[inline] + fn call(self, _error: Error) {} +} + #[async_trait] -impl FromRequestParts for WebSocketUpgrade +impl FromRequestParts for WebSocketUpgrade where S: Send + Sync, { @@ -323,6 +410,7 @@ where sec_websocket_key, on_upgrade, sec_websocket_protocol, + on_failed_upgrade: DefaultOnFailedUpdgrade, }) } } @@ -722,7 +810,7 @@ pub mod close_code { #[cfg(test)] mod tests { use super::*; - use crate::{body::Body, routing::get}; + use crate::{body::Body, routing::get, Router}; use http::{Request, Version}; use tower::ServiceExt; @@ -751,4 +839,21 @@ mod tests { assert_eq!(res.status(), StatusCode::OK); } + + #[allow(dead_code)] + fn default_on_failed_upgrade() { + async fn handler(ws: WebSocketUpgrade) -> Response { + ws.on_upgrade(|_| async {}) + } + let _: Router = Router::new().route("/", get(handler)); + } + + #[allow(dead_code)] + fn on_failed_upgrade() { + async fn handler(ws: WebSocketUpgrade) -> Response { + ws.on_failed_upgrade(|_error: Error| println!("oops!")) + .on_upgrade(|_| async {}) + } + let _: Router = Router::new().route("/", get(handler)); + } }