diff --git a/examples/websockets/Cargo.toml b/examples/websockets/Cargo.toml
index cb307989..8c2439f0 100644
--- a/examples/websockets/Cargo.toml
+++ b/examples/websockets/Cargo.toml
@@ -6,8 +6,20 @@ publish = false
[dependencies]
axum = { path = "../../axum", features = ["ws", "headers"] }
+futures = "0.3"
+futures-util = { version = "0.3", default-features = false, features = ["sink", "std"] }
headers = "0.3"
tokio = { version = "1.0", features = ["full"] }
+tokio-tungstenite = "0.18.0"
+tower = { version = "0.4", features = ["util"] }
tower-http = { version = "0.3.0", features = ["fs", "trace"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
+
+[[bin]]
+name = "example-websockets"
+path = "src/main.rs"
+
+[[bin]]
+name = "example-client"
+path = "src/client.rs"
diff --git a/examples/websockets/assets/index.html b/examples/websockets/assets/index.html
index 390bb86b..8a333db1 100644
--- a/examples/websockets/assets/index.html
+++ b/examples/websockets/assets/index.html
@@ -1 +1,2 @@
+Open the console to see stuff, then refresh to initiate exchange.
diff --git a/examples/websockets/assets/script.js b/examples/websockets/assets/script.js
index 3f166736..cdc0f9a4 100644
--- a/examples/websockets/assets/script.js
+++ b/examples/websockets/assets/script.js
@@ -7,3 +7,19 @@ socket.addEventListener('open', function (event) {
socket.addEventListener('message', function (event) {
console.log('Message from server ', event.data);
});
+
+
+setTimeout(() => {
+ const obj = { hello: "world" };
+ const blob = new Blob([JSON.stringify(obj, null, 2)], {
+ type: "application/json",
+ });
+ console.log("Sending blob over websocket");
+ socket.send(blob);
+}, 1000);
+
+setTimeout(() => {
+ socket.send('About done here...');
+ console.log("Sending close over websocket");
+ socket.close(3000, "Crash and Burn!");
+}, 3000);
\ No newline at end of file
diff --git a/examples/websockets/src/client.rs b/examples/websockets/src/client.rs
new file mode 100644
index 00000000..9b160048
--- /dev/null
+++ b/examples/websockets/src/client.rs
@@ -0,0 +1,160 @@
+//! Based on tokio-tungstenite example websocket client, but with multiple
+//! concurrent websocket clients in one package
+//!
+//! This will connect to a server specified in the SERVER with N_CLIENTS
+//! concurrent connections, and then flood some test messages over websocket.
+//! This will also print whatever it gets into stdout.
+//!
+//! Note that this is not currently optimized for performance, especially around
+//! stdout mutex management. Rather it's intended to show an example of working with axum's
+//! websocket server and how the client-side and server-side code can be quite similar.
+//!
+
+use futures_util::stream::FuturesUnordered;
+use futures_util::{SinkExt, StreamExt};
+use std::borrow::Cow;
+use std::ops::ControlFlow;
+use std::time::Instant;
+
+// we will use tungstenite for websocket client impl (same library as what axum is using)
+use tokio_tungstenite::{
+ connect_async,
+ tungstenite::protocol::{frame::coding::CloseCode, CloseFrame, Message},
+};
+
+const N_CLIENTS: usize = 2; //set to desired number
+const SERVER: &'static str = "ws://127.0.0.1:3000/ws";
+
+#[tokio::main]
+async fn main() {
+ let start_time = Instant::now();
+ //spawn several clients that will concurrently talk to the server
+ let mut clients = (0..N_CLIENTS)
+ .into_iter()
+ .map(|cli| tokio::spawn(spawn_client(cli)))
+ .collect::>();
+
+ //wait for all our clients to exit
+ while clients.next().await.is_some() {}
+
+ let end_time = Instant::now();
+
+ //total time should be the same no matter how many clients we spawn
+ println!(
+ "Total time taken {:#?} with {N_CLIENTS} concurrent clients, should be about 6.45 seconds.",
+ end_time - start_time
+ );
+}
+
+//creates a client. quietly exits on failure.
+async fn spawn_client(who: usize) {
+ let ws_stream = match connect_async(SERVER).await {
+ Ok((stream, response)) => {
+ println!("Handshake for client {} has been completed", who);
+ // This will be the HTTP response, same as with server this is the last moment we
+ // can still access HTTP stuff.
+ println!("Server response was {:?}", response);
+ stream
+ }
+ Err(e) => {
+ println!("WebSocket handshake for client {who} failed with {e}!");
+ return;
+ }
+ };
+
+ let (mut sender, mut receiver) = ws_stream.split();
+
+ //we can ping the server for start
+ sender
+ .send(Message::Ping("Hello, Server!".into()))
+ .await
+ .expect("Can not send!");
+
+ //spawn an async sender to push some more messages into the server
+ let mut send_task = tokio::spawn(async move {
+ for i in 1..30 {
+ // In any websocket error, break loop.
+ if sender
+ .send(Message::Text(format!("Message number {}...", i)))
+ .await
+ .is_err()
+ {
+ //just as with server, if send fails there is nothing we can do but exit.
+ return;
+ }
+
+ tokio::time::sleep(std::time::Duration::from_millis(300)).await;
+ }
+
+ // When we are done we may want our client to close connection cleanly.
+ println!("Sending close to {}...", who);
+ if let Err(e) = sender
+ .send(Message::Close(Some(CloseFrame {
+ code: CloseCode::Normal,
+ reason: Cow::from("Goodbye"),
+ })))
+ .await
+ {
+ println!("Could not send Close due to {:?}, probably it is ok?", e);
+ };
+ });
+
+ //receiver just prints whatever it gets
+ let mut recv_task = tokio::spawn(async move {
+ while let Some(Ok(msg)) = receiver.next().await {
+ // print message and break if instructed to do so
+ if process_message(msg, who).is_break() {
+ break;
+ }
+ }
+ });
+
+ //wait for either task to finish and kill the other task
+ tokio::select! {
+ _ = (&mut send_task) => {
+ recv_task.abort();
+ },
+ _ = (&mut recv_task) => {
+ send_task.abort();
+ }
+ }
+}
+
+/// Function to handle messages we get (with a slight twist that Frame variant is visible
+/// since we are working with the underlying tungstenite library directly without axum here).
+fn process_message(msg: Message, who: usize) -> ControlFlow<(), ()> {
+ match msg {
+ Message::Text(t) => {
+ println!(">>> {} got str: {:?}", who, t);
+ }
+ Message::Binary(d) => {
+ println!(">>> {} got {} bytes: {:?}", who, d.len(), d);
+ }
+ Message::Close(c) => {
+ if let Some(cf) = c {
+ println!(
+ ">>> {} got close with code {} and reason `{}`",
+ who, cf.code, cf.reason
+ );
+ } else {
+ println!(">>> {} somehow got close message without CloseFrame", who);
+ }
+ return ControlFlow::Break(());
+ }
+
+ Message::Pong(v) => {
+ println!(">>> {} got pong with {:?}", who, v);
+ }
+ // Just as with axum server, the underlying tungstenite websocket library
+ // will handle Ping for you automagically by replying with Pong and copying the
+ // v according to spec. But if you need the contents of the pings you can see them here.
+ Message::Ping(v) => {
+ println!(">>> {} got ping with {:?}", who, v);
+ }
+
+ Message::Frame(_) => {
+ unreachable!("This is never supposed to happen")
+ }
+ }
+ ControlFlow::Continue(())
+}
diff --git a/examples/websockets/src/main.rs b/examples/websockets/src/main.rs
index 2ee9f749..26a31a24 100644
--- a/examples/websockets/src/main.rs
+++ b/examples/websockets/src/main.rs
@@ -1,9 +1,15 @@
//! Example websocket server.
//!
-//! Run with
+//! Run the server with
//!
//! ```not_rust
-//! cd examples && cargo run -p example-websockets
+//! cargo run -p example-websockets
+//! firefox http://localhost:3000
+//! ```
+//!
+//! Alternatively you can run the rust client with
+//! ```not_rust
+//! cargo run -p example-client
//! ```
use axum::{
@@ -16,13 +22,24 @@ use axum::{
routing::{get, get_service},
Router,
};
+
+use std::borrow::Cow;
+use std::ops::ControlFlow;
use std::{net::SocketAddr, path::PathBuf};
use tower_http::{
services::ServeDir,
trace::{DefaultMakeSpan, TraceLayer},
};
+
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
+//allows to extract the IP of connecting user
+use axum::extract::connect_info::ConnectInfo;
+use axum::extract::ws::CloseFrame;
+
+//allows to split the websocket stream into separate TX and RX branches
+use futures::{sink::SinkExt, stream::StreamExt};
+
#[tokio::main]
async fn main() {
tracing_subscriber::registry()
@@ -59,58 +76,173 @@ async fn main() {
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
tracing::debug!("listening on {}", addr);
axum::Server::bind(&addr)
- .serve(app.into_make_service())
+ .serve(app.into_make_service_with_connect_info::())
.await
.unwrap();
}
+/// The handler for the HTTP request (this gets called when the HTTP GET lands at the start
+/// of websocket negotiation. After this completes, the actual switching from HTTP to
+/// websocket protocol will occur.
+/// This is the last point where we can extract TCP/IP metadata such as IP address of the client
+/// as well as things from HTTP headers such as user-agent of the browser etc.
async fn ws_handler(
ws: WebSocketUpgrade,
user_agent: Option>,
+ ConnectInfo(addr): ConnectInfo,
) -> impl IntoResponse {
- if let Some(TypedHeader(user_agent)) = user_agent {
- println!("`{}` connected", user_agent.as_str());
- }
-
- ws.on_upgrade(handle_socket)
+ let user_agent = if let Some(TypedHeader(user_agent)) = user_agent {
+ user_agent.to_string()
+ } else {
+ String::from("Unknown browser")
+ };
+ println!("`{}` at {} connected.", user_agent, addr.to_string());
+ // finalize the upgrade process by returning upgrade callback.
+ // we can customize the callback by sending additional info such as address.
+ ws.on_upgrade(move |socket| handle_socket(socket, addr))
}
-async fn handle_socket(mut socket: WebSocket) {
+/// Actual websocket statemachine (one will be spawned per connection)
+async fn handle_socket(mut socket: WebSocket, who: SocketAddr) {
+ //send a ping (unsupported by some browsers) just to kick things off and get a response
+ if let Ok(_) = socket.send(Message::Ping(vec![1, 2, 3])).await {
+ println!("Pinged {}...", who);
+ } else {
+ println!("Could not send ping {}!", who);
+ // no Error here since the only thing we can do is to close the connection.
+ // If we can not send messages, there is no way to salvage the statemachine anyway.
+ return;
+ }
+
+ // receive single message form a client (we can either receive or send with socket).
+ // this will likely be the Pong for our Ping or a hello message from client.
+ // waiting for message from a client will block this task, but will not block other client's
+ // connections.
if let Some(msg) = socket.recv().await {
if let Ok(msg) = msg {
- match msg {
- Message::Text(t) => {
- println!("client sent str: {:?}", t);
- }
- Message::Binary(_) => {
- println!("client sent binary data");
- }
- Message::Ping(_) => {
- println!("socket ping");
- }
- Message::Pong(_) => {
- println!("socket pong");
- }
- Message::Close(_) => {
- println!("client disconnected");
- return;
- }
+ if process_message(msg, who).is_break() {
+ return;
}
} else {
- println!("client disconnected");
+ println!("client {} abruptly disconnected", who);
return;
}
}
- loop {
+ // Since each client gets individual statemachine, we can pause handling
+ // when necessary to wait for some external event (in this case illustrated by sleeping).
+ // Waiting for this client to finish getting his greetings does not prevent other clients form
+ // connecting to server and receiving their greetings.
+ for i in 1..5 {
if socket
- .send(Message::Text(String::from("Hi!")))
+ .send(Message::Text(String::from(format!("Hi {} times!", i))))
.await
.is_err()
{
- println!("client disconnected");
+ println!("client {} abruptly disconnected", who);
return;
}
- tokio::time::sleep(std::time::Duration::from_secs(3)).await;
+ tokio::time::sleep(std::time::Duration::from_millis(100)).await;
}
+
+ // By splitting socket we can send and receive at the same time. In this example we will send
+ // unsolicited messages to client based on some sort of server's internal event (i.e .timer).
+ let (mut sender, mut receiver) = socket.split();
+
+ // Spawn a task that will push several messages to the client (does not matter what client does)
+ let mut send_task = tokio::spawn(async move {
+ let n_msg = 20;
+ for i in 0..n_msg {
+ // In case of any websocket error, we exit.
+ if sender
+ .send(Message::Text(format!("Server message {} ...", i)))
+ .await
+ .is_err()
+ {
+ return i;
+ }
+
+ tokio::time::sleep(std::time::Duration::from_millis(300)).await;
+ }
+
+ println!("Sending close to {}...", who);
+ if let Err(e) = sender
+ .send(Message::Close(Some(CloseFrame {
+ code: axum::extract::ws::close_code::NORMAL,
+ reason: Cow::from("Goodbye"),
+ })))
+ .await
+ {
+ println!("Could not send Close due to {}, probably it is ok?", e);
+ }
+ n_msg
+ });
+
+ // This second task will receive messages from client and print them on server console
+ let mut recv_task = tokio::spawn(async move {
+ let mut cnt = 0;
+ while let Some(Ok(msg)) = receiver.next().await {
+ cnt += 1;
+ // print message and break if instructed to do so
+ if process_message(msg, who).is_break() {
+ break;
+ }
+ }
+ cnt
+ });
+
+ // If any one of the tasks exit, abort the other.
+ tokio::select! {
+ rv_a = (&mut send_task) => {
+ match rv_a {
+ Ok(a) => println!("{} messages sent to {}", a, who),
+ Err(a) => println!("Error sending messages {:?}", a)
+ }
+ recv_task.abort();
+ },
+ rv_b = (&mut recv_task) => {
+ match rv_b {
+ Ok(b) => println!("Received {} messages", b),
+ Err(b) => println!("Error receiving messages {:?}", b)
+ }
+ send_task.abort();
+ }
+ }
+
+ // returning from the handler closes the websocket connection
+ println!("Websocket context {} destroyed", who);
+}
+
+/// helper to print contents of messages to stdout. Has special treatment for Close.
+fn process_message(msg: Message, who: SocketAddr) -> ControlFlow<(), ()> {
+ match msg {
+ Message::Text(t) => {
+ println!(">>> {} sent str: {:?}", who, t);
+ }
+ Message::Binary(d) => {
+ println!(">>> {} sent {} bytes: {:?}", who, d.len(), d);
+ }
+ Message::Close(c) => {
+ if let Some(cf) = c {
+ println!(
+ ">>> {} sent close with code {} and reason `{}`",
+ who, cf.code, cf.reason
+ );
+ } else {
+ println!(">>> {} somehow sent close message without CloseFrame", who);
+ }
+ return ControlFlow::Break(());
+ }
+
+ Message::Pong(v) => {
+ println!(">>> {} sent pong with {:?}", who, v);
+ }
+ // You should never need to manually handle Message::Ping, as axum's websocket library
+ // will do so for you automagically by replying with Pong and copying the v according to
+ // spec. But if you need the contents of the pings you can see them here.
+ Message::Ping(v) => {
+ println!(">>> {} sent ping with {:?}", who, v);
+ }
+ }
+ ControlFlow::Continue(())
}