From 6c9cabf985236e3775fc07b3f54d639553fd1424 Mon Sep 17 00:00:00 2001 From: Noa Date: Sat, 1 Feb 2025 09:17:00 -0600 Subject: [PATCH] Properly respond with sec-websocket-protocol under http/2 (#3141) --- axum/src/extract/ws.rs | 50 +++++++++++++++++++++++++++--------------- 1 file changed, 32 insertions(+), 18 deletions(-) diff --git a/axum/src/extract/ws.rs b/axum/src/extract/ws.rs index 70b132cf..08dbe204 100644 --- a/axum/src/extract/ws.rs +++ b/axum/src/extract/ws.rs @@ -338,7 +338,7 @@ impl WebSocketUpgrade { callback(socket).await; }); - if let Some(sec_websocket_key) = &self.sec_websocket_key { + let mut response = if let Some(sec_websocket_key) = &self.sec_websocket_key { // If `sec_websocket_key` was `Some`, we are using HTTP/1.1. #[allow(clippy::declare_interior_mutable_const)] @@ -346,26 +346,30 @@ impl WebSocketUpgrade { #[allow(clippy::declare_interior_mutable_const)] const WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket"); - let mut builder = Response::builder() + Response::builder() .status(StatusCode::SWITCHING_PROTOCOLS) .header(header::CONNECTION, UPGRADE) .header(header::UPGRADE, WEBSOCKET) .header( header::SEC_WEBSOCKET_ACCEPT, sign(sec_websocket_key.as_bytes()), - ); - - if let Some(protocol) = self.protocol { - builder = builder.header(header::SEC_WEBSOCKET_PROTOCOL, protocol); - } - - builder.body(Body::empty()).unwrap() + ) + .body(Body::empty()) + .unwrap() } else { // Otherwise, we are HTTP/2+. As established in RFC 9113 section 8.5, we just respond // with a 2XX with an empty body: // . Response::new(Body::empty()) + }; + + if let Some(protocol) = self.protocol { + response + .headers_mut() + .insert(header::SEC_WEBSOCKET_PROTOCOL, protocol); } + + response } } @@ -1092,10 +1096,11 @@ mod tests { #[crate::test] async fn integration_test() { let addr = spawn_service(echo_app()); - let (socket, _response) = tokio_tungstenite::connect_async(format!("ws://{addr}/echo")) - .await - .unwrap(); - test_echo_app(socket).await; + let uri = format!("ws://{addr}/echo").try_into().unwrap(); + let req = tungstenite::client::ClientRequestBuilder::new(uri) + .with_sub_protocol(TEST_ECHO_APP_REQ_SUBPROTO); + let (socket, response) = tokio_tungstenite::connect_async(req).await.unwrap(); + test_echo_app(socket, response.headers()).await; } #[crate::test] @@ -1123,21 +1128,22 @@ mod tests { .extension(hyper::ext::Protocol::from_static("websocket")) .uri("/echo") .header("sec-websocket-version", "13") + .header("sec-websocket-protocol", TEST_ECHO_APP_REQ_SUBPROTO) .header("Host", "server.example.com") .body(Body::empty()) .unwrap(); - let response = send_request.send_request(req).await.unwrap(); + let mut response = send_request.send_request(req).await.unwrap(); let status = response.status(); if status != 200 { let body = response.into_body().collect().await.unwrap().to_bytes(); let body = std::str::from_utf8(&body).unwrap(); panic!("response status was {status}: {body}"); } - let upgraded = hyper::upgrade::on(response).await.unwrap(); + let upgraded = hyper::upgrade::on(&mut response).await.unwrap(); let upgraded = TokioIo::new(upgraded); let socket = WebSocketStream::from_raw_socket(upgraded, protocol::Role::Client, None).await; - test_echo_app(socket).await; + test_echo_app(socket, response.headers()).await; } fn echo_app() -> Router { @@ -1158,11 +1164,19 @@ mod tests { Router::new().route( "/echo", - any(|ws: WebSocketUpgrade| ready(ws.on_upgrade(handle_socket))), + any(|ws: WebSocketUpgrade| { + ready(ws.protocols(["echo2", "echo"]).on_upgrade(handle_socket)) + }), ) } - async fn test_echo_app(mut socket: WebSocketStream) { + const TEST_ECHO_APP_REQ_SUBPROTO: &str = "echo3, echo"; + async fn test_echo_app( + mut socket: WebSocketStream, + headers: &http::HeaderMap, + ) { + assert_eq!(headers[http::header::SEC_WEBSOCKET_PROTOCOL], "echo"); + let input = tungstenite::Message::Text(tungstenite::Utf8Bytes::from_static("foobar")); socket.send(input.clone()).await.unwrap(); let output = socket.next().await.unwrap().unwrap();