Fix session cookie example (#638)

* refactor: refine session cookie example

* refactor: refine session_cookie extraction

* refactor: avoid to_owned()

* chore: refine debug log

Co-authored-by: 荒野無燈 <ttys3.rust@gmail.com>
This commit is contained in:
ttys3 2021-12-22 22:27:13 +08:00 committed by GitHub
parent 4c48efc861
commit 3841ef44d5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 61 additions and 27 deletions

View File

@ -5,7 +5,7 @@ edition = "2018"
publish = false publish = false
[dependencies] [dependencies]
axum = { path = "../../axum" } axum = { path = "../../axum", features = ["headers"] }
tokio = { version = "1.0", features = ["full"] } tokio = { version = "1.0", features = ["full"] }
tracing = "0.1" tracing = "0.1"
tracing-subscriber = { version="0.3", features = ["env-filter"] } tracing-subscriber = { version="0.3", features = ["env-filter"] }

View File

@ -7,7 +7,8 @@
use async_session::{MemoryStore, Session, SessionStore as _}; use async_session::{MemoryStore, Session, SessionStore as _};
use axum::{ use axum::{
async_trait, async_trait,
extract::{Extension, FromRequest, RequestParts}, extract::{Extension, FromRequest, RequestParts, TypedHeader},
headers::Cookie,
http::{ http::{
self, self,
header::{HeaderMap, HeaderValue}, header::{HeaderMap, HeaderValue},
@ -18,9 +19,12 @@ use axum::{
AddExtensionLayer, Router, AddExtensionLayer, Router,
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::fmt::Debug;
use std::net::SocketAddr; use std::net::SocketAddr;
use uuid::Uuid; use uuid::Uuid;
const AXUM_SESSION_COOKIE_NAME: &str = "axum_session";
#[tokio::main] #[tokio::main]
async fn main() { async fn main() {
// Set the RUST_LOG, if it hasn't been explicitly defined // Set the RUST_LOG, if it hasn't been explicitly defined
@ -45,26 +49,34 @@ async fn main() {
} }
async fn handler(user_id: UserIdFromSession) -> impl IntoResponse { async fn handler(user_id: UserIdFromSession) -> impl IntoResponse {
let (headers, user_id) = match user_id { let (headers, user_id, create_cookie) = match user_id {
UserIdFromSession::FoundUserId(user_id) => (HeaderMap::new(), user_id), UserIdFromSession::FoundUserId(user_id) => (HeaderMap::new(), user_id, false),
UserIdFromSession::CreatedFreshUserId { user_id, cookie } => { UserIdFromSession::CreatedFreshUserId(new_user) => {
let mut headers = HeaderMap::new(); let mut headers = HeaderMap::new();
headers.insert(http::header::SET_COOKIE, cookie); headers.insert(http::header::SET_COOKIE, new_user.cookie);
(headers, user_id) (headers, new_user.user_id, true)
} }
}; };
dbg!(user_id); tracing::debug!("handler: user_id={:?} send_headers={:?}", user_id, headers);
headers (
headers,
format!(
"user_id={:?} session_cookie_name={} create_new_session_cookie={}",
user_id, AXUM_SESSION_COOKIE_NAME, create_cookie
),
)
}
struct FreshUserId {
pub user_id: UserId,
pub cookie: HeaderValue,
} }
enum UserIdFromSession { enum UserIdFromSession {
FoundUserId(UserId), FoundUserId(UserId),
CreatedFreshUserId { CreatedFreshUserId(FreshUserId),
user_id: UserId,
cookie: HeaderValue,
},
} }
#[async_trait] #[async_trait]
@ -79,28 +91,45 @@ where
.await .await
.expect("`MemoryStore` extension missing"); .expect("`MemoryStore` extension missing");
let headers = req.headers().expect("other extractor taken headers"); let cookie = Option::<TypedHeader<Cookie>>::from_request(req)
.await
.unwrap();
let cookie = if let Some(cookie) = headers let session_cookie = cookie
.get(http::header::COOKIE) .as_ref()
.and_then(|value| value.to_str().ok()) .and_then(|cookie| cookie.get(AXUM_SESSION_COOKIE_NAME));
.map(|value| value.to_string())
{ // return the new created session cookie for client
cookie if session_cookie.is_none() {
} else {
let user_id = UserId::new(); let user_id = UserId::new();
let mut session = Session::new(); let mut session = Session::new();
session.insert("user_id", user_id).unwrap(); session.insert("user_id", user_id).unwrap();
let cookie = store.store_session(session).await.unwrap().unwrap(); let cookie = store.store_session(session).await.unwrap().unwrap();
return Ok(Self::CreatedFreshUserId(FreshUserId {
return Ok(Self::CreatedFreshUserId {
user_id, user_id,
cookie: cookie.parse().unwrap(), cookie: HeaderValue::from_str(
}); format!("{}={}", AXUM_SESSION_COOKIE_NAME, cookie).as_str(),
}; )
.unwrap(),
}));
}
let user_id = if let Some(session) = store.load_session(cookie).await.unwrap() { tracing::debug!(
"UserIdFromSession: got session cookie from user agent, {}={}",
AXUM_SESSION_COOKIE_NAME,
session_cookie.unwrap()
);
// continue to decode the session cookie
let user_id = if let Some(session) = store
.load_session(session_cookie.unwrap().to_owned())
.await
.unwrap()
{
if let Some(user_id) = session.get::<UserId>("user_id") { if let Some(user_id) = session.get::<UserId>("user_id") {
tracing::debug!(
"UserIdFromSession: session decoded success, user_id={:?}",
user_id
);
user_id user_id
} else { } else {
return Err(( return Err((
@ -109,6 +138,11 @@ where
)); ));
} }
} else { } else {
tracing::debug!(
"UserIdFromSession: err session not exists in store, {}={}",
AXUM_SESSION_COOKIE_NAME,
session_cookie.unwrap()
);
return Err((StatusCode::BAD_REQUEST, "No session found for cookie")); return Err((StatusCode::BAD_REQUEST, "No session found for cookie"));
}; };