mirror of
https://github.com/tokio-rs/axum.git
synced 2025-10-02 15:24:54 +00:00
Make websocket handlers support extractors (#41)
This commit is contained in:
parent
d927c819d3
commit
d843f4378b
@ -5,11 +5,12 @@
|
|||||||
//! ```
|
//! ```
|
||||||
//! RUST_LOG=tower_http=debug,key_value_store=trace \
|
//! RUST_LOG=tower_http=debug,key_value_store=trace \
|
||||||
//! cargo run \
|
//! cargo run \
|
||||||
//! --features ws \
|
//! --all-features \
|
||||||
//! --example websocket
|
//! --example websocket
|
||||||
//! ```
|
//! ```
|
||||||
|
|
||||||
use axum::{
|
use axum::{
|
||||||
|
extract::TypedHeader,
|
||||||
prelude::*,
|
prelude::*,
|
||||||
routing::nest,
|
routing::nest,
|
||||||
service::ServiceExt,
|
service::ServiceExt,
|
||||||
@ -57,7 +58,13 @@ async fn main() {
|
|||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_socket(mut socket: WebSocket) {
|
async fn handle_socket(
|
||||||
|
mut socket: WebSocket,
|
||||||
|
// websocket handlers can also use extractors
|
||||||
|
TypedHeader(user_agent): TypedHeader<headers::UserAgent>,
|
||||||
|
) {
|
||||||
|
println!("`{}` connected", user_agent.as_str());
|
||||||
|
|
||||||
if let Some(msg) = socket.recv().await {
|
if let Some(msg) = socket.recv().await {
|
||||||
let msg = msg.unwrap();
|
let msg = msg.unwrap();
|
||||||
println!("Client says: {:?}", msg);
|
println!("Client says: {:?}", msg);
|
||||||
|
@ -248,7 +248,7 @@ use crate::{response::IntoResponse, util::ByteStr};
|
|||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use bytes::{Buf, Bytes};
|
use bytes::{Buf, Bytes};
|
||||||
use futures_util::stream::Stream;
|
use futures_util::stream::Stream;
|
||||||
use http::{header, Extensions, HeaderMap, Method, Request, Uri, Version};
|
use http::{header, Extensions, HeaderMap, Method, Request, Response, Uri, Version};
|
||||||
use rejection::*;
|
use rejection::*;
|
||||||
use serde::de::DeserializeOwned;
|
use serde::de::DeserializeOwned;
|
||||||
use std::{
|
use std::{
|
||||||
@ -475,6 +475,46 @@ impl<B> RequestParts<B> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl<B> FromRequest<B> for ()
|
||||||
|
where
|
||||||
|
B: Send,
|
||||||
|
{
|
||||||
|
type Rejection = Infallible;
|
||||||
|
|
||||||
|
async fn from_request(_: &mut RequestParts<B>) -> Result<(), Self::Rejection> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
macro_rules! impl_from_request {
|
||||||
|
() => {
|
||||||
|
};
|
||||||
|
|
||||||
|
( $head:ident, $($tail:ident),* $(,)? ) => {
|
||||||
|
#[async_trait]
|
||||||
|
#[allow(non_snake_case)]
|
||||||
|
impl<B, $head, $($tail,)*> FromRequest<B> for ($head, $($tail,)*)
|
||||||
|
where
|
||||||
|
$head: FromRequest<B> + Send,
|
||||||
|
$( $tail: FromRequest<B> + Send, )*
|
||||||
|
B: Send,
|
||||||
|
{
|
||||||
|
type Rejection = Response<crate::body::Body>;
|
||||||
|
|
||||||
|
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
|
||||||
|
let $head = $head::from_request(req).await.map_err(IntoResponse::into_response)?;
|
||||||
|
$( let $tail = $tail::from_request(req).await.map_err(IntoResponse::into_response)?; )*
|
||||||
|
Ok(($head, $($tail,)*))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl_from_request!($($tail,)*);
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
impl_from_request!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16);
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl<T, B> FromRequest<B> for Option<T>
|
impl<T, B> FromRequest<B> for Option<T>
|
||||||
where
|
where
|
||||||
@ -1233,3 +1273,39 @@ where
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Extractor that extracts the raw query string, without parsing it.
|
||||||
|
///
|
||||||
|
/// # Example
|
||||||
|
///
|
||||||
|
/// ```rust,no_run
|
||||||
|
/// use axum::prelude::*;
|
||||||
|
/// use futures::StreamExt;
|
||||||
|
///
|
||||||
|
/// async fn handler(extract::RawQuery(query): extract::RawQuery) {
|
||||||
|
/// // ...
|
||||||
|
/// }
|
||||||
|
///
|
||||||
|
/// let app = route("/users", get(handler));
|
||||||
|
/// # async {
|
||||||
|
/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
|
||||||
|
/// # };
|
||||||
|
/// ```
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct RawQuery(pub Option<String>);
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl<B> FromRequest<B> for RawQuery
|
||||||
|
where
|
||||||
|
B: Send,
|
||||||
|
{
|
||||||
|
type Rejection = Infallible;
|
||||||
|
|
||||||
|
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
|
||||||
|
let query = req
|
||||||
|
.uri()
|
||||||
|
.and_then(|uri| uri.query())
|
||||||
|
.map(|query| query.to_string());
|
||||||
|
Ok(Self(query))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -172,7 +172,7 @@ where
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
mod sealed {
|
pub(crate) mod sealed {
|
||||||
#![allow(unreachable_pub, missing_docs, missing_debug_implementations)]
|
#![allow(unreachable_pub, missing_docs, missing_debug_implementations)]
|
||||||
|
|
||||||
pub trait HiddentTrait {}
|
pub trait HiddentTrait {}
|
||||||
@ -188,8 +188,8 @@ mod sealed {
|
|||||||
/// See the [module docs](crate::handler) for more details.
|
/// See the [module docs](crate::handler) for more details.
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait Handler<B, In>: Sized {
|
pub trait Handler<B, In>: Sized {
|
||||||
// This seals the trait. We cannot use the regular "sealed super trait" approach
|
// This seals the trait. We cannot use the regular "sealed super trait"
|
||||||
// due to coherence.
|
// approach due to coherence.
|
||||||
#[doc(hidden)]
|
#[doc(hidden)]
|
||||||
type Sealed: sealed::HiddentTrait;
|
type Sealed: sealed::HiddentTrait;
|
||||||
|
|
||||||
@ -256,7 +256,8 @@ where
|
|||||||
}
|
}
|
||||||
|
|
||||||
macro_rules! impl_handler {
|
macro_rules! impl_handler {
|
||||||
() => {};
|
() => {
|
||||||
|
};
|
||||||
|
|
||||||
( $head:ident, $($tail:ident),* $(,)? ) => {
|
( $head:ident, $($tail:ident),* $(,)? ) => {
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
|
13
src/lib.rs
13
src/lib.rs
@ -65,7 +65,7 @@
|
|||||||
//! ["extractors"](#extractors) as arguments and returns something that
|
//! ["extractors"](#extractors) as arguments and returns something that
|
||||||
//! can be converted [into a response](#building-responses).
|
//! can be converted [into a response](#building-responses).
|
||||||
//!
|
//!
|
||||||
//! Handlers is where you custom domain logic lives and axum applications are
|
//! Handlers is where your custom domain logic lives and axum applications are
|
||||||
//! built by routing between handlers.
|
//! built by routing between handlers.
|
||||||
//!
|
//!
|
||||||
//! Some examples of handlers:
|
//! Some examples of handlers:
|
||||||
@ -78,14 +78,14 @@
|
|||||||
//! // Handler that immediately returns an empty `200 OK` response.
|
//! // Handler that immediately returns an empty `200 OK` response.
|
||||||
//! async fn unit_handler() {}
|
//! async fn unit_handler() {}
|
||||||
//!
|
//!
|
||||||
//! // Handler that immediately returns an empty `200 Ok` response with a plain
|
//! // Handler that immediately returns an empty `200 OK` response with a plain
|
||||||
//! // text body.
|
//! // text body.
|
||||||
//! async fn string_handler() -> String {
|
//! async fn string_handler() -> String {
|
||||||
//! "Hello, World!".to_string()
|
//! "Hello, World!".to_string()
|
||||||
//! }
|
//! }
|
||||||
//!
|
//!
|
||||||
//! // Handler that buffers the request body and returns it if it is valid UTF-8
|
//! // Handler that buffers the request body and returns it.
|
||||||
//! async fn buffer_body(body: Bytes) -> Result<String, StatusCode> {
|
//! async fn echo(body: Bytes) -> Result<String, StatusCode> {
|
||||||
//! if let Ok(string) = String::from_utf8(body.to_vec()) {
|
//! if let Ok(string) = String::from_utf8(body.to_vec()) {
|
||||||
//! Ok(string)
|
//! Ok(string)
|
||||||
//! } else {
|
//! } else {
|
||||||
@ -248,7 +248,7 @@
|
|||||||
//! "foo"
|
//! "foo"
|
||||||
//! }
|
//! }
|
||||||
//!
|
//!
|
||||||
//! // String works too and will get a text/plain content-type
|
//! // String works too and will get a `text/plain` content-type
|
||||||
//! async fn plain_text_string(uri: Uri) -> String {
|
//! async fn plain_text_string(uri: Uri) -> String {
|
||||||
//! format!("Hi from {}", uri.path())
|
//! format!("Hi from {}", uri.path())
|
||||||
//! }
|
//! }
|
||||||
@ -547,14 +547,13 @@
|
|||||||
//! [`Timeout`]: tower::timeout::Timeout
|
//! [`Timeout`]: tower::timeout::Timeout
|
||||||
//! [examples]: https://github.com/tokio-rs/axum/tree/main/examples
|
//! [examples]: https://github.com/tokio-rs/axum/tree/main/examples
|
||||||
|
|
||||||
#![doc(html_root_url = "https://docs.rs/tower-http/0.1.0")]
|
#![doc(html_root_url = "https://docs.rs/axum/0.1.0")]
|
||||||
#.
|
opaque_future! {
|
||||||
#[derive(Debug)]
|
/// Response future for [`WebSocketUpgrade`](super::WebSocketUpgrade).
|
||||||
pub struct ResponseFuture(Result<Option<HeaderValue>, Option<(StatusCode, &'static str)>>);
|
pub type ResponseFuture = futures_util::future::BoxFuture<'static, Result<Response<BoxBody>, Infallible>>;
|
||||||
|
|
||||||
impl ResponseFuture {
|
|
||||||
pub(super) fn ok(key: HeaderValue) -> Self {
|
|
||||||
Self(Ok(Some(key)))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(super) fn err(status: StatusCode, body: &'static str) -> Self {
|
|
||||||
Self(Err(Some((status, body))))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Future for ResponseFuture {
|
|
||||||
type Output = Result<Response<Full<Bytes>>, Infallible>;
|
|
||||||
|
|
||||||
fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
|
|
||||||
let res = match self.get_mut().0.as_mut() {
|
|
||||||
Ok(key) => Response::builder()
|
|
||||||
.status(StatusCode::SWITCHING_PROTOCOLS)
|
|
||||||
.header(
|
|
||||||
http::header::CONNECTION,
|
|
||||||
HeaderValue::from_str("upgrade").unwrap(),
|
|
||||||
)
|
|
||||||
.header(
|
|
||||||
http::header::UPGRADE,
|
|
||||||
HeaderValue::from_str("websocket").unwrap(),
|
|
||||||
)
|
|
||||||
.header(
|
|
||||||
http::header::SEC_WEBSOCKET_ACCEPT,
|
|
||||||
sign(key.take().unwrap().as_bytes()),
|
|
||||||
)
|
|
||||||
.body(Full::new(Bytes::new()))
|
|
||||||
.unwrap(),
|
|
||||||
Err(err) => {
|
|
||||||
let (status, body) = err.take().unwrap();
|
|
||||||
Response::builder()
|
|
||||||
.status(status)
|
|
||||||
.body(Full::from(body))
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
Poll::Ready(Ok(res))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn sign(key: &[u8]) -> HeaderValue {
|
|
||||||
let mut sha1 = Sha1::default();
|
|
||||||
sha1.update(key);
|
|
||||||
sha1.update(&b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"[..]);
|
|
||||||
let b64 = Bytes::from(base64::encode(&sha1.finalize()));
|
|
||||||
HeaderValue::from_maybe_shared(b64).expect("base64 is a valid value")
|
|
||||||
}
|
}
|
||||||
|
294
src/ws/mod.rs
294
src/ws/mod.rs
@ -17,11 +17,48 @@
|
|||||||
//! # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
|
//! # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
|
||||||
//! # };
|
//! # };
|
||||||
//! ```
|
//! ```
|
||||||
|
//!
|
||||||
|
//! Websocket handlers can also use extractors, however the first function
|
||||||
|
//! argument must be of type [`WebSocket`]:
|
||||||
|
//!
|
||||||
|
//! ```
|
||||||
|
//! use axum::{prelude::*, extract::{RequestParts, FromRequest}, ws::{ws, WebSocket}};
|
||||||
|
//! use http::{HeaderMap, StatusCode};
|
||||||
|
//!
|
||||||
|
//! /// An extractor that authorizes requests.
|
||||||
|
//! struct RequireAuth;
|
||||||
|
//!
|
||||||
|
//! #[async_trait::async_trait]
|
||||||
|
//! impl<B> FromRequest<B> for RequireAuth
|
||||||
|
//! where
|
||||||
|
//! B: Send,
|
||||||
|
//! {
|
||||||
|
//! type Rejection = StatusCode;
|
||||||
|
//!
|
||||||
|
//! async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
|
||||||
|
//! # unimplemented!()
|
||||||
|
//! // Put your auth logic here...
|
||||||
|
//! }
|
||||||
|
//! }
|
||||||
|
//!
|
||||||
|
//! let app = route("/ws", ws(handle_socket));
|
||||||
|
//!
|
||||||
|
//! async fn handle_socket(
|
||||||
|
//! mut socket: WebSocket,
|
||||||
|
//! // Run `RequireAuth` for each request before upgrading.
|
||||||
|
//! _auth: RequireAuth,
|
||||||
|
//! ) {
|
||||||
|
//! // ...
|
||||||
|
//! }
|
||||||
|
//! # async {
|
||||||
|
//! # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
|
||||||
|
//! # };
|
||||||
|
//! ```
|
||||||
|
|
||||||
use crate::{
|
use crate::body::{box_body, BoxBody};
|
||||||
routing::EmptyRouter,
|
use crate::extract::{FromRequest, RequestParts};
|
||||||
service::{BoxResponseBody, OnMethod},
|
use crate::response::IntoResponse;
|
||||||
};
|
use async_trait::async_trait;
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
use future::ResponseFuture;
|
use future::ResponseFuture;
|
||||||
use futures_util::{sink::SinkExt, stream::StreamExt};
|
use futures_util::{sink::SinkExt, stream::StreamExt};
|
||||||
@ -31,7 +68,11 @@ use http::{
|
|||||||
};
|
};
|
||||||
use http_body::Full;
|
use http_body::Full;
|
||||||
use hyper::upgrade::{OnUpgrade, Upgraded};
|
use hyper::upgrade::{OnUpgrade, Upgraded};
|
||||||
use std::{borrow::Cow, convert::Infallible, fmt, future::Future, task::Context, task::Poll};
|
use sha1::{Digest, Sha1};
|
||||||
|
use std::{
|
||||||
|
borrow::Cow, convert::Infallible, fmt, future::Future, marker::PhantomData, task::Context,
|
||||||
|
task::Poll,
|
||||||
|
};
|
||||||
use tokio_tungstenite::{
|
use tokio_tungstenite::{
|
||||||
tungstenite::protocol::{self, WebSocketConfig},
|
tungstenite::protocol::{self, WebSocketConfig},
|
||||||
WebSocketStream,
|
WebSocketStream,
|
||||||
@ -44,31 +85,103 @@ pub mod future;
|
|||||||
/// each connection.
|
/// each connection.
|
||||||
///
|
///
|
||||||
/// See the [module docs](crate::ws) for more details.
|
/// See the [module docs](crate::ws) for more details.
|
||||||
pub fn ws<F, Fut, B>(callback: F) -> OnMethod<BoxResponseBody<WebSocketUpgrade<F>, B>, EmptyRouter>
|
pub fn ws<F, B, T>(callback: F) -> WebSocketUpgrade<F, B, T>
|
||||||
where
|
where
|
||||||
F: FnOnce(WebSocket) -> Fut + Clone + Send + 'static,
|
F: WebSocketHandler<B, T>,
|
||||||
Fut: Future<Output = ()> + Send + 'static,
|
|
||||||
{
|
{
|
||||||
let svc = WebSocketUpgrade {
|
WebSocketUpgrade {
|
||||||
callback,
|
callback,
|
||||||
config: WebSocketConfig::default(),
|
config: WebSocketConfig::default(),
|
||||||
};
|
_request_body: PhantomData,
|
||||||
crate::service::get::<_, B>(svc)
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Trait for async functions that can be used to handle websocket requests.
|
||||||
|
///
|
||||||
|
/// You shouldn't need to depend on this trait directly. It is automatically
|
||||||
|
/// implemented to closures of the right types.
|
||||||
|
///
|
||||||
|
/// See the [module docs](crate::ws) for more details.
|
||||||
|
#[async_trait]
|
||||||
|
pub trait WebSocketHandler<B, In>: Sized {
|
||||||
|
// This seals the trait. We cannot use the regular "sealed super trait"
|
||||||
|
// approach due to coherence.
|
||||||
|
#[doc(hidden)]
|
||||||
|
type Sealed: crate::handler::sealed::HiddentTrait;
|
||||||
|
|
||||||
|
/// Call the handler with the given websocket stream and input parsed by
|
||||||
|
/// extractors.
|
||||||
|
async fn call(self, stream: WebSocket, input: In);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl<F, Fut, B> WebSocketHandler<B, ()> for F
|
||||||
|
where
|
||||||
|
F: FnOnce(WebSocket) -> Fut + Send,
|
||||||
|
Fut: Future<Output = ()> + Send,
|
||||||
|
B: Send,
|
||||||
|
{
|
||||||
|
type Sealed = crate::handler::sealed::Hidden;
|
||||||
|
|
||||||
|
async fn call(self, stream: WebSocket, _: ()) {
|
||||||
|
self(stream).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
macro_rules! impl_ws_handler {
|
||||||
|
() => {
|
||||||
|
};
|
||||||
|
|
||||||
|
( $head:ident, $($tail:ident),* $(,)? ) => {
|
||||||
|
#[async_trait]
|
||||||
|
#[allow(non_snake_case)]
|
||||||
|
impl<F, Fut, B, $head, $($tail,)*> WebSocketHandler<B, ($head, $($tail,)*)> for F
|
||||||
|
where
|
||||||
|
B: Send,
|
||||||
|
$head: FromRequest<B> + Send + 'static,
|
||||||
|
$( $tail: FromRequest<B> + Send + 'static, )*
|
||||||
|
F: FnOnce(WebSocket, $head, $($tail,)*) -> Fut + Send,
|
||||||
|
Fut: Future<Output = ()> + Send,
|
||||||
|
{
|
||||||
|
type Sealed = crate::handler::sealed::Hidden;
|
||||||
|
|
||||||
|
async fn call(self, stream: WebSocket, ($head, $($tail,)*): ($head, $($tail,)*)) {
|
||||||
|
self(stream, $head, $($tail,)*).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl_ws_handler!($($tail,)*);
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
impl_ws_handler!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16);
|
||||||
|
|
||||||
/// [`Service`] that upgrades connections to websockets and spawns a task to
|
/// [`Service`] that upgrades connections to websockets and spawns a task to
|
||||||
/// handle the stream.
|
/// handle the stream.
|
||||||
///
|
///
|
||||||
/// Created with [`ws`].
|
/// Created with [`ws`].
|
||||||
///
|
///
|
||||||
/// See the [module docs](crate::ws) for more details.
|
/// See the [module docs](crate::ws) for more details.
|
||||||
#[derive(Clone)]
|
pub struct WebSocketUpgrade<F, B, T> {
|
||||||
pub struct WebSocketUpgrade<F> {
|
|
||||||
callback: F,
|
callback: F,
|
||||||
config: WebSocketConfig,
|
config: WebSocketConfig,
|
||||||
|
_request_body: PhantomData<fn() -> (B, T)>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<F> fmt::Debug for WebSocketUpgrade<F> {
|
impl<F, B, T> Clone for WebSocketUpgrade<F, B, T>
|
||||||
|
where
|
||||||
|
F: Clone,
|
||||||
|
{
|
||||||
|
fn clone(&self) -> Self {
|
||||||
|
Self {
|
||||||
|
callback: self.callback.clone(),
|
||||||
|
config: self.config,
|
||||||
|
_request_body: PhantomData,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<F, B, T> fmt::Debug for WebSocketUpgrade<F, B, T> {
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
f.debug_struct("WebSocketUpgrade")
|
f.debug_struct("WebSocketUpgrade")
|
||||||
.field("callback", &format_args!("{}", std::any::type_name::<F>()))
|
.field("callback", &format_args!("{}", std::any::type_name::<F>()))
|
||||||
@ -77,7 +190,7 @@ impl<F> fmt::Debug for WebSocketUpgrade<F> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<F> WebSocketUpgrade<F> {
|
impl<F, B, T> WebSocketUpgrade<F, B, T> {
|
||||||
/// Set the size of the internal message send queue.
|
/// Set the size of the internal message send queue.
|
||||||
pub fn max_send_queue(mut self, max: usize) -> Self {
|
pub fn max_send_queue(mut self, max: usize) -> Self {
|
||||||
self.config.max_send_queue = Some(max);
|
self.config.max_send_queue = Some(max);
|
||||||
@ -97,12 +210,13 @@ impl<F> WebSocketUpgrade<F> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<ReqBody, F, Fut> Service<Request<ReqBody>> for WebSocketUpgrade<F>
|
impl<ReqBody, F, T> Service<Request<ReqBody>> for WebSocketUpgrade<F, ReqBody, T>
|
||||||
where
|
where
|
||||||
F: FnOnce(WebSocket) -> Fut + Clone + Send + 'static,
|
F: WebSocketHandler<ReqBody, T> + Clone + Send + 'static,
|
||||||
Fut: Future<Output = ()> + Send + 'static,
|
T: FromRequest<ReqBody> + Send + 'static,
|
||||||
|
ReqBody: Send + 'static,
|
||||||
{
|
{
|
||||||
type Response = Response<Full<Bytes>>;
|
type Response = Response<BoxBody>;
|
||||||
type Error = Infallible;
|
type Error = Infallible;
|
||||||
type Future = ResponseFuture;
|
type Future = ResponseFuture;
|
||||||
|
|
||||||
@ -111,62 +225,104 @@ where
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
|
fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
|
||||||
if !header_eq(
|
let this = self.clone();
|
||||||
&req,
|
|
||||||
header::CONNECTION,
|
|
||||||
HeaderValue::from_static("upgrade"),
|
|
||||||
) {
|
|
||||||
return ResponseFuture::err(
|
|
||||||
StatusCode::BAD_REQUEST,
|
|
||||||
"Connection header did not include 'upgrade'",
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if !header_eq(&req, header::UPGRADE, HeaderValue::from_static("websocket")) {
|
ResponseFuture(Box::pin(async move {
|
||||||
return ResponseFuture::err(
|
if req.method() != http::Method::GET {
|
||||||
StatusCode::BAD_REQUEST,
|
return response(StatusCode::NOT_FOUND, "Request method must be `GET`");
|
||||||
"`Upgrade` header did not include 'websocket'",
|
}
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if !header_eq(
|
if !header_eq(
|
||||||
&req,
|
&req,
|
||||||
header::SEC_WEBSOCKET_VERSION,
|
header::CONNECTION,
|
||||||
HeaderValue::from_static("13"),
|
HeaderValue::from_static("upgrade"),
|
||||||
) {
|
) {
|
||||||
return ResponseFuture::err(
|
return response(
|
||||||
StatusCode::BAD_REQUEST,
|
StatusCode::BAD_REQUEST,
|
||||||
"`Sec-Websocket-Version` header did not include '13'",
|
"Connection header did not include 'upgrade'",
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
let key = if let Some(key) = req.headers_mut().remove(header::SEC_WEBSOCKET_KEY) {
|
if !header_eq(&req, header::UPGRADE, HeaderValue::from_static("websocket")) {
|
||||||
key
|
return response(
|
||||||
} else {
|
StatusCode::BAD_REQUEST,
|
||||||
return ResponseFuture::err(
|
"`Upgrade` header did not include 'websocket'",
|
||||||
StatusCode::BAD_REQUEST,
|
);
|
||||||
"`Sec-Websocket-Key` header missing",
|
}
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
let on_upgrade = req.extensions_mut().remove::<OnUpgrade>().unwrap();
|
if !header_eq(
|
||||||
|
&req,
|
||||||
|
header::SEC_WEBSOCKET_VERSION,
|
||||||
|
HeaderValue::from_static("13"),
|
||||||
|
) {
|
||||||
|
return response(
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
"`Sec-Websocket-Version` header did not include '13'",
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
let config = self.config;
|
let key = if let Some(key) = req.headers_mut().remove(header::SEC_WEBSOCKET_KEY) {
|
||||||
let callback = self.callback.clone();
|
key
|
||||||
|
} else {
|
||||||
|
return response(
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
"`Sec-Websocket-Key` header missing",
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
tokio::spawn(async move {
|
let on_upgrade = req.extensions_mut().remove::<OnUpgrade>().unwrap();
|
||||||
let upgraded = on_upgrade.await.unwrap();
|
|
||||||
let socket =
|
|
||||||
WebSocketStream::from_raw_socket(upgraded, protocol::Role::Server, Some(config))
|
|
||||||
.await;
|
|
||||||
let socket = WebSocket { inner: socket };
|
|
||||||
callback(socket).await;
|
|
||||||
});
|
|
||||||
|
|
||||||
ResponseFuture::ok(key)
|
let config = this.config;
|
||||||
|
let callback = this.callback.clone();
|
||||||
|
|
||||||
|
let mut req = RequestParts::new(req);
|
||||||
|
let input = match T::from_request(&mut req).await {
|
||||||
|
Ok(input) => input,
|
||||||
|
Err(rejection) => {
|
||||||
|
let res = rejection.into_response().map(box_body);
|
||||||
|
return Ok(res);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let upgraded = on_upgrade.await.unwrap();
|
||||||
|
let socket = WebSocketStream::from_raw_socket(
|
||||||
|
upgraded,
|
||||||
|
protocol::Role::Server,
|
||||||
|
Some(config),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
let socket = WebSocket { inner: socket };
|
||||||
|
callback.call(socket, input).await;
|
||||||
|
});
|
||||||
|
|
||||||
|
let res = Response::builder()
|
||||||
|
.status(StatusCode::SWITCHING_PROTOCOLS)
|
||||||
|
.header(
|
||||||
|
http::header::CONNECTION,
|
||||||
|
HeaderValue::from_str("upgrade").unwrap(),
|
||||||
|
)
|
||||||
|
.header(
|
||||||
|
http::header::UPGRADE,
|
||||||
|
HeaderValue::from_str("websocket").unwrap(),
|
||||||
|
)
|
||||||
|
.header(http::header::SEC_WEBSOCKET_ACCEPT, sign(key.as_bytes()))
|
||||||
|
.body(box_body(Full::new(Bytes::new())))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
Ok(res)
|
||||||
|
}))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn response<E>(status: StatusCode, body: &'static str) -> Result<Response<BoxBody>, E> {
|
||||||
|
let res = Response::builder()
|
||||||
|
.status(status)
|
||||||
|
.body(box_body(Full::from(body)))
|
||||||
|
.unwrap();
|
||||||
|
Ok(res)
|
||||||
|
}
|
||||||
|
|
||||||
fn header_eq<B>(req: &Request<B>, key: HeaderName, value: HeaderValue) -> bool {
|
fn header_eq<B>(req: &Request<B>, key: HeaderName, value: HeaderValue) -> bool {
|
||||||
if let Some(header) = req.headers().get(&key) {
|
if let Some(header) = req.headers().get(&key) {
|
||||||
header.as_bytes().eq_ignore_ascii_case(value.as_bytes())
|
header.as_bytes().eq_ignore_ascii_case(value.as_bytes())
|
||||||
@ -175,6 +331,14 @@ fn header_eq<B>(req: &Request<B>, key: HeaderName, value: HeaderValue) -> bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn sign(key: &[u8]) -> HeaderValue {
|
||||||
|
let mut sha1 = Sha1::default();
|
||||||
|
sha1.update(key);
|
||||||
|
sha1.update(&b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"[..]);
|
||||||
|
let b64 = Bytes::from(base64::encode(&sha1.finalize()));
|
||||||
|
HeaderValue::from_maybe_shared(b64).expect("base64 is a valid value")
|
||||||
|
}
|
||||||
|
|
||||||
/// A stream of websocket messages.
|
/// A stream of websocket messages.
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct WebSocket {
|
pub struct WebSocket {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user