mirror of
https://github.com/tokio-rs/axum.git
synced 2025-10-02 07:20:38 +00:00
Properly respond with sec-websocket-protocol under http/2 (#3141)
This commit is contained in:
parent
0e6e96fb8c
commit
6c9cabf985
@ -338,7 +338,7 @@ impl<F> WebSocketUpgrade<F> {
|
|||||||
callback(socket).await;
|
callback(socket).await;
|
||||||
});
|
});
|
||||||
|
|
||||||
if let Some(sec_websocket_key) = &self.sec_websocket_key {
|
let mut response = if let Some(sec_websocket_key) = &self.sec_websocket_key {
|
||||||
// If `sec_websocket_key` was `Some`, we are using HTTP/1.1.
|
// If `sec_websocket_key` was `Some`, we are using HTTP/1.1.
|
||||||
|
|
||||||
#[allow(clippy::declare_interior_mutable_const)]
|
#[allow(clippy::declare_interior_mutable_const)]
|
||||||
@ -346,26 +346,30 @@ impl<F> WebSocketUpgrade<F> {
|
|||||||
#[allow(clippy::declare_interior_mutable_const)]
|
#[allow(clippy::declare_interior_mutable_const)]
|
||||||
const WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket");
|
const WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket");
|
||||||
|
|
||||||
let mut builder = Response::builder()
|
Response::builder()
|
||||||
.status(StatusCode::SWITCHING_PROTOCOLS)
|
.status(StatusCode::SWITCHING_PROTOCOLS)
|
||||||
.header(header::CONNECTION, UPGRADE)
|
.header(header::CONNECTION, UPGRADE)
|
||||||
.header(header::UPGRADE, WEBSOCKET)
|
.header(header::UPGRADE, WEBSOCKET)
|
||||||
.header(
|
.header(
|
||||||
header::SEC_WEBSOCKET_ACCEPT,
|
header::SEC_WEBSOCKET_ACCEPT,
|
||||||
sign(sec_websocket_key.as_bytes()),
|
sign(sec_websocket_key.as_bytes()),
|
||||||
);
|
)
|
||||||
|
.body(Body::empty())
|
||||||
if let Some(protocol) = self.protocol {
|
.unwrap()
|
||||||
builder = builder.header(header::SEC_WEBSOCKET_PROTOCOL, protocol);
|
|
||||||
}
|
|
||||||
|
|
||||||
builder.body(Body::empty()).unwrap()
|
|
||||||
} else {
|
} else {
|
||||||
// Otherwise, we are HTTP/2+. As established in RFC 9113 section 8.5, we just respond
|
// Otherwise, we are HTTP/2+. As established in RFC 9113 section 8.5, we just respond
|
||||||
// with a 2XX with an empty body:
|
// with a 2XX with an empty body:
|
||||||
// <https://datatracker.ietf.org/doc/html/rfc9113#name-the-connect-method>.
|
// <https://datatracker.ietf.org/doc/html/rfc9113#name-the-connect-method>.
|
||||||
Response::new(Body::empty())
|
Response::new(Body::empty())
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Some(protocol) = self.protocol {
|
||||||
|
response
|
||||||
|
.headers_mut()
|
||||||
|
.insert(header::SEC_WEBSOCKET_PROTOCOL, protocol);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
response
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1092,10 +1096,11 @@ mod tests {
|
|||||||
#[crate::test]
|
#[crate::test]
|
||||||
async fn integration_test() {
|
async fn integration_test() {
|
||||||
let addr = spawn_service(echo_app());
|
let addr = spawn_service(echo_app());
|
||||||
let (socket, _response) = tokio_tungstenite::connect_async(format!("ws://{addr}/echo"))
|
let uri = format!("ws://{addr}/echo").try_into().unwrap();
|
||||||
.await
|
let req = tungstenite::client::ClientRequestBuilder::new(uri)
|
||||||
.unwrap();
|
.with_sub_protocol(TEST_ECHO_APP_REQ_SUBPROTO);
|
||||||
test_echo_app(socket).await;
|
let (socket, response) = tokio_tungstenite::connect_async(req).await.unwrap();
|
||||||
|
test_echo_app(socket, response.headers()).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
#[crate::test]
|
#[crate::test]
|
||||||
@ -1123,21 +1128,22 @@ mod tests {
|
|||||||
.extension(hyper::ext::Protocol::from_static("websocket"))
|
.extension(hyper::ext::Protocol::from_static("websocket"))
|
||||||
.uri("/echo")
|
.uri("/echo")
|
||||||
.header("sec-websocket-version", "13")
|
.header("sec-websocket-version", "13")
|
||||||
|
.header("sec-websocket-protocol", TEST_ECHO_APP_REQ_SUBPROTO)
|
||||||
.header("Host", "server.example.com")
|
.header("Host", "server.example.com")
|
||||||
.body(Body::empty())
|
.body(Body::empty())
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let response = send_request.send_request(req).await.unwrap();
|
let mut response = send_request.send_request(req).await.unwrap();
|
||||||
let status = response.status();
|
let status = response.status();
|
||||||
if status != 200 {
|
if status != 200 {
|
||||||
let body = response.into_body().collect().await.unwrap().to_bytes();
|
let body = response.into_body().collect().await.unwrap().to_bytes();
|
||||||
let body = std::str::from_utf8(&body).unwrap();
|
let body = std::str::from_utf8(&body).unwrap();
|
||||||
panic!("response status was {status}: {body}");
|
panic!("response status was {status}: {body}");
|
||||||
}
|
}
|
||||||
let upgraded = hyper::upgrade::on(response).await.unwrap();
|
let upgraded = hyper::upgrade::on(&mut response).await.unwrap();
|
||||||
let upgraded = TokioIo::new(upgraded);
|
let upgraded = TokioIo::new(upgraded);
|
||||||
let socket = WebSocketStream::from_raw_socket(upgraded, protocol::Role::Client, None).await;
|
let socket = WebSocketStream::from_raw_socket(upgraded, protocol::Role::Client, None).await;
|
||||||
test_echo_app(socket).await;
|
test_echo_app(socket, response.headers()).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
fn echo_app() -> Router {
|
fn echo_app() -> Router {
|
||||||
@ -1158,11 +1164,19 @@ mod tests {
|
|||||||
|
|
||||||
Router::new().route(
|
Router::new().route(
|
||||||
"/echo",
|
"/echo",
|
||||||
any(|ws: WebSocketUpgrade| ready(ws.on_upgrade(handle_socket))),
|
any(|ws: WebSocketUpgrade| {
|
||||||
|
ready(ws.protocols(["echo2", "echo"]).on_upgrade(handle_socket))
|
||||||
|
}),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn test_echo_app<S: AsyncRead + AsyncWrite + Unpin>(mut socket: WebSocketStream<S>) {
|
const TEST_ECHO_APP_REQ_SUBPROTO: &str = "echo3, echo";
|
||||||
|
async fn test_echo_app<S: AsyncRead + AsyncWrite + Unpin>(
|
||||||
|
mut socket: WebSocketStream<S>,
|
||||||
|
headers: &http::HeaderMap,
|
||||||
|
) {
|
||||||
|
assert_eq!(headers[http::header::SEC_WEBSOCKET_PROTOCOL], "echo");
|
||||||
|
|
||||||
let input = tungstenite::Message::Text(tungstenite::Utf8Bytes::from_static("foobar"));
|
let input = tungstenite::Message::Text(tungstenite::Utf8Bytes::from_static("foobar"));
|
||||||
socket.send(input.clone()).await.unwrap();
|
socket.send(input.clone()).await.unwrap();
|
||||||
let output = socket.next().await.unwrap().unwrap();
|
let output = socket.next().await.unwrap().unwrap();
|
||||||
|
Loading…
x
Reference in New Issue
Block a user