diff --git a/examples/websockets/Cargo.toml b/examples/websockets/Cargo.toml index cb307989..8c2439f0 100644 --- a/examples/websockets/Cargo.toml +++ b/examples/websockets/Cargo.toml @@ -6,8 +6,20 @@ publish = false [dependencies] axum = { path = "../../axum", features = ["ws", "headers"] } +futures = "0.3" +futures-util = { version = "0.3", default-features = false, features = ["sink", "std"] } headers = "0.3" tokio = { version = "1.0", features = ["full"] } +tokio-tungstenite = "0.18.0" +tower = { version = "0.4", features = ["util"] } tower-http = { version = "0.3.0", features = ["fs", "trace"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } + +[[bin]] +name = "example-websockets" +path = "src/main.rs" + +[[bin]] +name = "example-client" +path = "src/client.rs" diff --git a/examples/websockets/assets/index.html b/examples/websockets/assets/index.html index 390bb86b..8a333db1 100644 --- a/examples/websockets/assets/index.html +++ b/examples/websockets/assets/index.html @@ -1 +1,2 @@ +Open the console to see stuff, then refresh to initiate exchange. diff --git a/examples/websockets/assets/script.js b/examples/websockets/assets/script.js index 3f166736..cdc0f9a4 100644 --- a/examples/websockets/assets/script.js +++ b/examples/websockets/assets/script.js @@ -7,3 +7,19 @@ socket.addEventListener('open', function (event) { socket.addEventListener('message', function (event) { console.log('Message from server ', event.data); }); + + +setTimeout(() => { + const obj = { hello: "world" }; + const blob = new Blob([JSON.stringify(obj, null, 2)], { + type: "application/json", + }); + console.log("Sending blob over websocket"); + socket.send(blob); +}, 1000); + +setTimeout(() => { + socket.send('About done here...'); + console.log("Sending close over websocket"); + socket.close(3000, "Crash and Burn!"); +}, 3000); \ No newline at end of file diff --git a/examples/websockets/src/client.rs b/examples/websockets/src/client.rs new file mode 100644 index 00000000..9b160048 --- /dev/null +++ b/examples/websockets/src/client.rs @@ -0,0 +1,160 @@ +//! Based on tokio-tungstenite example websocket client, but with multiple +//! concurrent websocket clients in one package +//! +//! This will connect to a server specified in the SERVER with N_CLIENTS +//! concurrent connections, and then flood some test messages over websocket. +//! This will also print whatever it gets into stdout. +//! +//! Note that this is not currently optimized for performance, especially around +//! stdout mutex management. Rather it's intended to show an example of working with axum's +//! websocket server and how the client-side and server-side code can be quite similar. +//! + +use futures_util::stream::FuturesUnordered; +use futures_util::{SinkExt, StreamExt}; +use std::borrow::Cow; +use std::ops::ControlFlow; +use std::time::Instant; + +// we will use tungstenite for websocket client impl (same library as what axum is using) +use tokio_tungstenite::{ + connect_async, + tungstenite::protocol::{frame::coding::CloseCode, CloseFrame, Message}, +}; + +const N_CLIENTS: usize = 2; //set to desired number +const SERVER: &'static str = "ws://127.0.0.1:3000/ws"; + +#[tokio::main] +async fn main() { + let start_time = Instant::now(); + //spawn several clients that will concurrently talk to the server + let mut clients = (0..N_CLIENTS) + .into_iter() + .map(|cli| tokio::spawn(spawn_client(cli))) + .collect::>(); + + //wait for all our clients to exit + while clients.next().await.is_some() {} + + let end_time = Instant::now(); + + //total time should be the same no matter how many clients we spawn + println!( + "Total time taken {:#?} with {N_CLIENTS} concurrent clients, should be about 6.45 seconds.", + end_time - start_time + ); +} + +//creates a client. quietly exits on failure. +async fn spawn_client(who: usize) { + let ws_stream = match connect_async(SERVER).await { + Ok((stream, response)) => { + println!("Handshake for client {} has been completed", who); + // This will be the HTTP response, same as with server this is the last moment we + // can still access HTTP stuff. + println!("Server response was {:?}", response); + stream + } + Err(e) => { + println!("WebSocket handshake for client {who} failed with {e}!"); + return; + } + }; + + let (mut sender, mut receiver) = ws_stream.split(); + + //we can ping the server for start + sender + .send(Message::Ping("Hello, Server!".into())) + .await + .expect("Can not send!"); + + //spawn an async sender to push some more messages into the server + let mut send_task = tokio::spawn(async move { + for i in 1..30 { + // In any websocket error, break loop. + if sender + .send(Message::Text(format!("Message number {}...", i))) + .await + .is_err() + { + //just as with server, if send fails there is nothing we can do but exit. + return; + } + + tokio::time::sleep(std::time::Duration::from_millis(300)).await; + } + + // When we are done we may want our client to close connection cleanly. + println!("Sending close to {}...", who); + if let Err(e) = sender + .send(Message::Close(Some(CloseFrame { + code: CloseCode::Normal, + reason: Cow::from("Goodbye"), + }))) + .await + { + println!("Could not send Close due to {:?}, probably it is ok?", e); + }; + }); + + //receiver just prints whatever it gets + let mut recv_task = tokio::spawn(async move { + while let Some(Ok(msg)) = receiver.next().await { + // print message and break if instructed to do so + if process_message(msg, who).is_break() { + break; + } + } + }); + + //wait for either task to finish and kill the other task + tokio::select! { + _ = (&mut send_task) => { + recv_task.abort(); + }, + _ = (&mut recv_task) => { + send_task.abort(); + } + } +} + +/// Function to handle messages we get (with a slight twist that Frame variant is visible +/// since we are working with the underlying tungstenite library directly without axum here). +fn process_message(msg: Message, who: usize) -> ControlFlow<(), ()> { + match msg { + Message::Text(t) => { + println!(">>> {} got str: {:?}", who, t); + } + Message::Binary(d) => { + println!(">>> {} got {} bytes: {:?}", who, d.len(), d); + } + Message::Close(c) => { + if let Some(cf) = c { + println!( + ">>> {} got close with code {} and reason `{}`", + who, cf.code, cf.reason + ); + } else { + println!(">>> {} somehow got close message without CloseFrame", who); + } + return ControlFlow::Break(()); + } + + Message::Pong(v) => { + println!(">>> {} got pong with {:?}", who, v); + } + // Just as with axum server, the underlying tungstenite websocket library + // will handle Ping for you automagically by replying with Pong and copying the + // v according to spec. But if you need the contents of the pings you can see them here. + Message::Ping(v) => { + println!(">>> {} got ping with {:?}", who, v); + } + + Message::Frame(_) => { + unreachable!("This is never supposed to happen") + } + } + ControlFlow::Continue(()) +} diff --git a/examples/websockets/src/main.rs b/examples/websockets/src/main.rs index 2ee9f749..26a31a24 100644 --- a/examples/websockets/src/main.rs +++ b/examples/websockets/src/main.rs @@ -1,9 +1,15 @@ //! Example websocket server. //! -//! Run with +//! Run the server with //! //! ```not_rust -//! cd examples && cargo run -p example-websockets +//! cargo run -p example-websockets +//! firefox http://localhost:3000 +//! ``` +//! +//! Alternatively you can run the rust client with +//! ```not_rust +//! cargo run -p example-client //! ``` use axum::{ @@ -16,13 +22,24 @@ use axum::{ routing::{get, get_service}, Router, }; + +use std::borrow::Cow; +use std::ops::ControlFlow; use std::{net::SocketAddr, path::PathBuf}; use tower_http::{ services::ServeDir, trace::{DefaultMakeSpan, TraceLayer}, }; + use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; +//allows to extract the IP of connecting user +use axum::extract::connect_info::ConnectInfo; +use axum::extract::ws::CloseFrame; + +//allows to split the websocket stream into separate TX and RX branches +use futures::{sink::SinkExt, stream::StreamExt}; + #[tokio::main] async fn main() { tracing_subscriber::registry() @@ -59,58 +76,173 @@ async fn main() { let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); axum::Server::bind(&addr) - .serve(app.into_make_service()) + .serve(app.into_make_service_with_connect_info::()) .await .unwrap(); } +/// The handler for the HTTP request (this gets called when the HTTP GET lands at the start +/// of websocket negotiation. After this completes, the actual switching from HTTP to +/// websocket protocol will occur. +/// This is the last point where we can extract TCP/IP metadata such as IP address of the client +/// as well as things from HTTP headers such as user-agent of the browser etc. async fn ws_handler( ws: WebSocketUpgrade, user_agent: Option>, + ConnectInfo(addr): ConnectInfo, ) -> impl IntoResponse { - if let Some(TypedHeader(user_agent)) = user_agent { - println!("`{}` connected", user_agent.as_str()); - } - - ws.on_upgrade(handle_socket) + let user_agent = if let Some(TypedHeader(user_agent)) = user_agent { + user_agent.to_string() + } else { + String::from("Unknown browser") + }; + println!("`{}` at {} connected.", user_agent, addr.to_string()); + // finalize the upgrade process by returning upgrade callback. + // we can customize the callback by sending additional info such as address. + ws.on_upgrade(move |socket| handle_socket(socket, addr)) } -async fn handle_socket(mut socket: WebSocket) { +/// Actual websocket statemachine (one will be spawned per connection) +async fn handle_socket(mut socket: WebSocket, who: SocketAddr) { + //send a ping (unsupported by some browsers) just to kick things off and get a response + if let Ok(_) = socket.send(Message::Ping(vec![1, 2, 3])).await { + println!("Pinged {}...", who); + } else { + println!("Could not send ping {}!", who); + // no Error here since the only thing we can do is to close the connection. + // If we can not send messages, there is no way to salvage the statemachine anyway. + return; + } + + // receive single message form a client (we can either receive or send with socket). + // this will likely be the Pong for our Ping or a hello message from client. + // waiting for message from a client will block this task, but will not block other client's + // connections. if let Some(msg) = socket.recv().await { if let Ok(msg) = msg { - match msg { - Message::Text(t) => { - println!("client sent str: {:?}", t); - } - Message::Binary(_) => { - println!("client sent binary data"); - } - Message::Ping(_) => { - println!("socket ping"); - } - Message::Pong(_) => { - println!("socket pong"); - } - Message::Close(_) => { - println!("client disconnected"); - return; - } + if process_message(msg, who).is_break() { + return; } } else { - println!("client disconnected"); + println!("client {} abruptly disconnected", who); return; } } - loop { + // Since each client gets individual statemachine, we can pause handling + // when necessary to wait for some external event (in this case illustrated by sleeping). + // Waiting for this client to finish getting his greetings does not prevent other clients form + // connecting to server and receiving their greetings. + for i in 1..5 { if socket - .send(Message::Text(String::from("Hi!"))) + .send(Message::Text(String::from(format!("Hi {} times!", i)))) .await .is_err() { - println!("client disconnected"); + println!("client {} abruptly disconnected", who); return; } - tokio::time::sleep(std::time::Duration::from_secs(3)).await; + tokio::time::sleep(std::time::Duration::from_millis(100)).await; } + + // By splitting socket we can send and receive at the same time. In this example we will send + // unsolicited messages to client based on some sort of server's internal event (i.e .timer). + let (mut sender, mut receiver) = socket.split(); + + // Spawn a task that will push several messages to the client (does not matter what client does) + let mut send_task = tokio::spawn(async move { + let n_msg = 20; + for i in 0..n_msg { + // In case of any websocket error, we exit. + if sender + .send(Message::Text(format!("Server message {} ...", i))) + .await + .is_err() + { + return i; + } + + tokio::time::sleep(std::time::Duration::from_millis(300)).await; + } + + println!("Sending close to {}...", who); + if let Err(e) = sender + .send(Message::Close(Some(CloseFrame { + code: axum::extract::ws::close_code::NORMAL, + reason: Cow::from("Goodbye"), + }))) + .await + { + println!("Could not send Close due to {}, probably it is ok?", e); + } + n_msg + }); + + // This second task will receive messages from client and print them on server console + let mut recv_task = tokio::spawn(async move { + let mut cnt = 0; + while let Some(Ok(msg)) = receiver.next().await { + cnt += 1; + // print message and break if instructed to do so + if process_message(msg, who).is_break() { + break; + } + } + cnt + }); + + // If any one of the tasks exit, abort the other. + tokio::select! { + rv_a = (&mut send_task) => { + match rv_a { + Ok(a) => println!("{} messages sent to {}", a, who), + Err(a) => println!("Error sending messages {:?}", a) + } + recv_task.abort(); + }, + rv_b = (&mut recv_task) => { + match rv_b { + Ok(b) => println!("Received {} messages", b), + Err(b) => println!("Error receiving messages {:?}", b) + } + send_task.abort(); + } + } + + // returning from the handler closes the websocket connection + println!("Websocket context {} destroyed", who); +} + +/// helper to print contents of messages to stdout. Has special treatment for Close. +fn process_message(msg: Message, who: SocketAddr) -> ControlFlow<(), ()> { + match msg { + Message::Text(t) => { + println!(">>> {} sent str: {:?}", who, t); + } + Message::Binary(d) => { + println!(">>> {} sent {} bytes: {:?}", who, d.len(), d); + } + Message::Close(c) => { + if let Some(cf) = c { + println!( + ">>> {} sent close with code {} and reason `{}`", + who, cf.code, cf.reason + ); + } else { + println!(">>> {} somehow sent close message without CloseFrame", who); + } + return ControlFlow::Break(()); + } + + Message::Pong(v) => { + println!(">>> {} sent pong with {:?}", who, v); + } + // You should never need to manually handle Message::Ping, as axum's websocket library + // will do so for you automagically by replying with Pong and copying the v according to + // spec. But if you need the contents of the pings you can see them here. + Message::Ping(v) => { + println!(">>> {} sent ping with {:?}", who, v); + } + } + ControlFlow::Continue(()) }