mirror of
https://github.com/tokio-rs/axum.git
synced 2025-09-27 04:50:31 +00:00
Remove unwraps via '?' with anyhow crate for example-oauth (#2069)
This commit is contained in:
parent
0ed02a9a46
commit
8cb11e7f94
@ -5,6 +5,7 @@ edition = "2021"
|
|||||||
publish = false
|
publish = false
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
|
anyhow = "1"
|
||||||
async-session = "3.0.0"
|
async-session = "3.0.0"
|
||||||
axum = { path = "../../axum" }
|
axum = { path = "../../axum" }
|
||||||
axum-extra = { path = "../../axum-extra", features = ["typed-header"] }
|
axum-extra = { path = "../../axum-extra", features = ["typed-header"] }
|
||||||
|
@ -8,6 +8,7 @@
|
|||||||
//! CLIENT_ID=REPLACE_ME CLIENT_SECRET=REPLACE_ME cargo run -p example-oauth
|
//! CLIENT_ID=REPLACE_ME CLIENT_SECRET=REPLACE_ME cargo run -p example-oauth
|
||||||
//! ```
|
//! ```
|
||||||
|
|
||||||
|
use anyhow::{Context, Result};
|
||||||
use async_session::{MemoryStore, Session, SessionStore};
|
use async_session::{MemoryStore, Session, SessionStore};
|
||||||
use axum::{
|
use axum::{
|
||||||
async_trait,
|
async_trait,
|
||||||
@ -18,7 +19,7 @@ use axum::{
|
|||||||
RequestPartsExt, Router,
|
RequestPartsExt, Router,
|
||||||
};
|
};
|
||||||
use axum_extra::{headers, typed_header::TypedHeaderRejectionReason, TypedHeader};
|
use axum_extra::{headers, typed_header::TypedHeaderRejectionReason, TypedHeader};
|
||||||
use http::{header, request::Parts};
|
use http::{header, request::Parts, StatusCode};
|
||||||
use oauth2::{
|
use oauth2::{
|
||||||
basic::BasicClient, reqwest::async_http_client, AuthUrl, AuthorizationCode, ClientId,
|
basic::BasicClient, reqwest::async_http_client, AuthUrl, AuthorizationCode, ClientId,
|
||||||
ClientSecret, CsrfToken, RedirectUrl, Scope, TokenResponse, TokenUrl,
|
ClientSecret, CsrfToken, RedirectUrl, Scope, TokenResponse, TokenUrl,
|
||||||
@ -41,7 +42,7 @@ async fn main() {
|
|||||||
|
|
||||||
// `MemoryStore` is just used as an example. Don't use this in production.
|
// `MemoryStore` is just used as an example. Don't use this in production.
|
||||||
let store = MemoryStore::new();
|
let store = MemoryStore::new();
|
||||||
let oauth_client = oauth_client();
|
let oauth_client = oauth_client().unwrap();
|
||||||
let app_state = AppState {
|
let app_state = AppState {
|
||||||
store,
|
store,
|
||||||
oauth_client,
|
oauth_client,
|
||||||
@ -57,9 +58,21 @@ async fn main() {
|
|||||||
|
|
||||||
let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
|
let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
|
||||||
.await
|
.await
|
||||||
|
.context("failed to bind TcpListener")
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
tracing::debug!(
|
||||||
|
"listening on {}",
|
||||||
|
listener
|
||||||
|
.local_addr()
|
||||||
|
.context("failed to return local address")
|
||||||
|
.unwrap()
|
||||||
|
);
|
||||||
|
|
||||||
|
axum::serve(listener, app)
|
||||||
|
.await
|
||||||
|
.context("failed to serve service")
|
||||||
.unwrap();
|
.unwrap();
|
||||||
tracing::debug!("listening on {}", listener.local_addr().unwrap());
|
|
||||||
axum::serve(listener, app).await.unwrap();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
@ -80,7 +93,7 @@ impl FromRef<AppState> for BasicClient {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn oauth_client() -> BasicClient {
|
fn oauth_client() -> Result<BasicClient, AppError> {
|
||||||
// Environment variables (* = required):
|
// Environment variables (* = required):
|
||||||
// *"CLIENT_ID" "REPLACE_ME";
|
// *"CLIENT_ID" "REPLACE_ME";
|
||||||
// *"CLIENT_SECRET" "REPLACE_ME";
|
// *"CLIENT_SECRET" "REPLACE_ME";
|
||||||
@ -88,8 +101,8 @@ fn oauth_client() -> BasicClient {
|
|||||||
// "AUTH_URL" "https://discord.com/api/oauth2/authorize?response_type=code";
|
// "AUTH_URL" "https://discord.com/api/oauth2/authorize?response_type=code";
|
||||||
// "TOKEN_URL" "https://discord.com/api/oauth2/token";
|
// "TOKEN_URL" "https://discord.com/api/oauth2/token";
|
||||||
|
|
||||||
let client_id = env::var("CLIENT_ID").expect("Missing CLIENT_ID!");
|
let client_id = env::var("CLIENT_ID").context("Missing CLIENT_ID!")?;
|
||||||
let client_secret = env::var("CLIENT_SECRET").expect("Missing CLIENT_SECRET!");
|
let client_secret = env::var("CLIENT_SECRET").context("Missing CLIENT_SECRET!")?;
|
||||||
let redirect_url = env::var("REDIRECT_URL")
|
let redirect_url = env::var("REDIRECT_URL")
|
||||||
.unwrap_or_else(|_| "http://127.0.0.1:3000/auth/authorized".to_string());
|
.unwrap_or_else(|_| "http://127.0.0.1:3000/auth/authorized".to_string());
|
||||||
|
|
||||||
@ -100,13 +113,15 @@ fn oauth_client() -> BasicClient {
|
|||||||
let token_url = env::var("TOKEN_URL")
|
let token_url = env::var("TOKEN_URL")
|
||||||
.unwrap_or_else(|_| "https://discord.com/api/oauth2/token".to_string());
|
.unwrap_or_else(|_| "https://discord.com/api/oauth2/token".to_string());
|
||||||
|
|
||||||
BasicClient::new(
|
Ok(BasicClient::new(
|
||||||
ClientId::new(client_id),
|
ClientId::new(client_id),
|
||||||
Some(ClientSecret::new(client_secret)),
|
Some(ClientSecret::new(client_secret)),
|
||||||
AuthUrl::new(auth_url).unwrap(),
|
AuthUrl::new(auth_url).context("failed to create new authorization server URL")?,
|
||||||
Some(TokenUrl::new(token_url).unwrap()),
|
Some(TokenUrl::new(token_url).context("failed to create new token endpoint URL")?),
|
||||||
)
|
)
|
||||||
.set_redirect_uri(RedirectUrl::new(redirect_url).unwrap())
|
.set_redirect_uri(
|
||||||
|
RedirectUrl::new(redirect_url).context("failed to create new redirection URL")?,
|
||||||
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
// The user data we'll get back from Discord.
|
// The user data we'll get back from Discord.
|
||||||
@ -151,17 +166,27 @@ async fn protected(user: User) -> impl IntoResponse {
|
|||||||
async fn logout(
|
async fn logout(
|
||||||
State(store): State<MemoryStore>,
|
State(store): State<MemoryStore>,
|
||||||
TypedHeader(cookies): TypedHeader<headers::Cookie>,
|
TypedHeader(cookies): TypedHeader<headers::Cookie>,
|
||||||
) -> impl IntoResponse {
|
) -> Result<impl IntoResponse, AppError> {
|
||||||
let cookie = cookies.get(COOKIE_NAME).unwrap();
|
let cookie = cookies
|
||||||
let session = match store.load_session(cookie.to_string()).await.unwrap() {
|
.get(COOKIE_NAME)
|
||||||
|
.context("unexpected error getting cookie name")?;
|
||||||
|
|
||||||
|
let session = match store
|
||||||
|
.load_session(cookie.to_string())
|
||||||
|
.await
|
||||||
|
.context("failed to load session")?
|
||||||
|
{
|
||||||
Some(s) => s,
|
Some(s) => s,
|
||||||
// No session active, just redirect
|
// No session active, just redirect
|
||||||
None => return Redirect::to("/"),
|
None => return Ok(Redirect::to("/")),
|
||||||
};
|
};
|
||||||
|
|
||||||
store.destroy_session(session).await.unwrap();
|
store
|
||||||
|
.destroy_session(session)
|
||||||
|
.await
|
||||||
|
.context("failed to destroy session")?;
|
||||||
|
|
||||||
Redirect::to("/")
|
Ok(Redirect::to("/"))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
@ -175,13 +200,13 @@ async fn login_authorized(
|
|||||||
Query(query): Query<AuthRequest>,
|
Query(query): Query<AuthRequest>,
|
||||||
State(store): State<MemoryStore>,
|
State(store): State<MemoryStore>,
|
||||||
State(oauth_client): State<BasicClient>,
|
State(oauth_client): State<BasicClient>,
|
||||||
) -> impl IntoResponse {
|
) -> Result<impl IntoResponse, AppError> {
|
||||||
// Get an auth token
|
// Get an auth token
|
||||||
let token = oauth_client
|
let token = oauth_client
|
||||||
.exchange_code(AuthorizationCode::new(query.code.clone()))
|
.exchange_code(AuthorizationCode::new(query.code.clone()))
|
||||||
.request_async(async_http_client)
|
.request_async(async_http_client)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.context("failed in sending request request to authorization server")?;
|
||||||
|
|
||||||
// Fetch user data from discord
|
// Fetch user data from discord
|
||||||
let client = reqwest::Client::new();
|
let client = reqwest::Client::new();
|
||||||
@ -191,26 +216,35 @@ async fn login_authorized(
|
|||||||
.bearer_auth(token.access_token().secret())
|
.bearer_auth(token.access_token().secret())
|
||||||
.send()
|
.send()
|
||||||
.await
|
.await
|
||||||
.unwrap()
|
.context("failed in sending request to target Url")?
|
||||||
.json::<User>()
|
.json::<User>()
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.context("failed to deserialize response as JSON")?;
|
||||||
|
|
||||||
// Create a new session filled with user data
|
// Create a new session filled with user data
|
||||||
let mut session = Session::new();
|
let mut session = Session::new();
|
||||||
session.insert("user", &user_data).unwrap();
|
session
|
||||||
|
.insert("user", &user_data)
|
||||||
|
.context("failed in inserting serialized value into session")?;
|
||||||
|
|
||||||
// Store session and get corresponding cookie
|
// Store session and get corresponding cookie
|
||||||
let cookie = store.store_session(session).await.unwrap().unwrap();
|
let cookie = store
|
||||||
|
.store_session(session)
|
||||||
|
.await
|
||||||
|
.context("failed to store session")?
|
||||||
|
.context("unexpected error retrieving cookie value")?;
|
||||||
|
|
||||||
// Build the cookie
|
// Build the cookie
|
||||||
let cookie = format!("{}={}; SameSite=Lax; Path=/", COOKIE_NAME, cookie);
|
let cookie = format!("{}={}; SameSite=Lax; Path=/", COOKIE_NAME, cookie);
|
||||||
|
|
||||||
// Set cookie
|
// Set cookie
|
||||||
let mut headers = HeaderMap::new();
|
let mut headers = HeaderMap::new();
|
||||||
headers.insert(SET_COOKIE, cookie.parse().unwrap());
|
headers.insert(
|
||||||
|
SET_COOKIE,
|
||||||
|
cookie.parse().context("failed to parse cookie")?,
|
||||||
|
);
|
||||||
|
|
||||||
(headers, Redirect::to("/"))
|
Ok((headers, Redirect::to("/")))
|
||||||
}
|
}
|
||||||
|
|
||||||
struct AuthRedirect;
|
struct AuthRedirect;
|
||||||
@ -256,3 +290,28 @@ where
|
|||||||
Ok(user)
|
Ok(user)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Use anyhow, define error and enable '?'
|
||||||
|
// For a simplified example of using anyhow in axum check /examples/anyhow-error-response
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct AppError(anyhow::Error);
|
||||||
|
|
||||||
|
// Tell axum how to convert `AppError` into a response.
|
||||||
|
impl IntoResponse for AppError {
|
||||||
|
fn into_response(self) -> Response {
|
||||||
|
tracing::error!("Application error: {:#}", self.0);
|
||||||
|
|
||||||
|
(StatusCode::INTERNAL_SERVER_ERROR, "Something went wrong").into_response()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// This enables using `?` on functions that return `Result<_, anyhow::Error>` to turn them into
|
||||||
|
// `Result<_, AppError>`. That way you don't need to do that manually.
|
||||||
|
impl<E> From<E> for AppError
|
||||||
|
where
|
||||||
|
E: Into<anyhow::Error>,
|
||||||
|
{
|
||||||
|
fn from(err: E) -> Self {
|
||||||
|
Self(err.into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user