Remove unwraps via '?' with anyhow crate for example-oauth (#2069)

This commit is contained in:
Rodrigo Santiago 2023-07-04 15:48:58 -04:00 committed by GitHub
parent 0ed02a9a46
commit 8cb11e7f94
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 85 additions and 25 deletions

View File

@ -5,6 +5,7 @@ edition = "2021"
publish = false publish = false
[dependencies] [dependencies]
anyhow = "1"
async-session = "3.0.0" async-session = "3.0.0"
axum = { path = "../../axum" } axum = { path = "../../axum" }
axum-extra = { path = "../../axum-extra", features = ["typed-header"] } axum-extra = { path = "../../axum-extra", features = ["typed-header"] }

View File

@ -8,6 +8,7 @@
//! CLIENT_ID=REPLACE_ME CLIENT_SECRET=REPLACE_ME cargo run -p example-oauth //! CLIENT_ID=REPLACE_ME CLIENT_SECRET=REPLACE_ME cargo run -p example-oauth
//! ``` //! ```
use anyhow::{Context, Result};
use async_session::{MemoryStore, Session, SessionStore}; use async_session::{MemoryStore, Session, SessionStore};
use axum::{ use axum::{
async_trait, async_trait,
@ -18,7 +19,7 @@ use axum::{
RequestPartsExt, Router, RequestPartsExt, Router,
}; };
use axum_extra::{headers, typed_header::TypedHeaderRejectionReason, TypedHeader}; use axum_extra::{headers, typed_header::TypedHeaderRejectionReason, TypedHeader};
use http::{header, request::Parts}; use http::{header, request::Parts, StatusCode};
use oauth2::{ use oauth2::{
basic::BasicClient, reqwest::async_http_client, AuthUrl, AuthorizationCode, ClientId, basic::BasicClient, reqwest::async_http_client, AuthUrl, AuthorizationCode, ClientId,
ClientSecret, CsrfToken, RedirectUrl, Scope, TokenResponse, TokenUrl, ClientSecret, CsrfToken, RedirectUrl, Scope, TokenResponse, TokenUrl,
@ -41,7 +42,7 @@ async fn main() {
// `MemoryStore` is just used as an example. Don't use this in production. // `MemoryStore` is just used as an example. Don't use this in production.
let store = MemoryStore::new(); let store = MemoryStore::new();
let oauth_client = oauth_client(); let oauth_client = oauth_client().unwrap();
let app_state = AppState { let app_state = AppState {
store, store,
oauth_client, oauth_client,
@ -57,9 +58,21 @@ async fn main() {
let listener = tokio::net::TcpListener::bind("127.0.0.1:3000") let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
.await .await
.context("failed to bind TcpListener")
.unwrap();
tracing::debug!(
"listening on {}",
listener
.local_addr()
.context("failed to return local address")
.unwrap()
);
axum::serve(listener, app)
.await
.context("failed to serve service")
.unwrap(); .unwrap();
tracing::debug!("listening on {}", listener.local_addr().unwrap());
axum::serve(listener, app).await.unwrap();
} }
#[derive(Clone)] #[derive(Clone)]
@ -80,7 +93,7 @@ impl FromRef<AppState> for BasicClient {
} }
} }
fn oauth_client() -> BasicClient { fn oauth_client() -> Result<BasicClient, AppError> {
// Environment variables (* = required): // Environment variables (* = required):
// *"CLIENT_ID" "REPLACE_ME"; // *"CLIENT_ID" "REPLACE_ME";
// *"CLIENT_SECRET" "REPLACE_ME"; // *"CLIENT_SECRET" "REPLACE_ME";
@ -88,8 +101,8 @@ fn oauth_client() -> BasicClient {
// "AUTH_URL" "https://discord.com/api/oauth2/authorize?response_type=code"; // "AUTH_URL" "https://discord.com/api/oauth2/authorize?response_type=code";
// "TOKEN_URL" "https://discord.com/api/oauth2/token"; // "TOKEN_URL" "https://discord.com/api/oauth2/token";
let client_id = env::var("CLIENT_ID").expect("Missing CLIENT_ID!"); let client_id = env::var("CLIENT_ID").context("Missing CLIENT_ID!")?;
let client_secret = env::var("CLIENT_SECRET").expect("Missing CLIENT_SECRET!"); let client_secret = env::var("CLIENT_SECRET").context("Missing CLIENT_SECRET!")?;
let redirect_url = env::var("REDIRECT_URL") let redirect_url = env::var("REDIRECT_URL")
.unwrap_or_else(|_| "http://127.0.0.1:3000/auth/authorized".to_string()); .unwrap_or_else(|_| "http://127.0.0.1:3000/auth/authorized".to_string());
@ -100,13 +113,15 @@ fn oauth_client() -> BasicClient {
let token_url = env::var("TOKEN_URL") let token_url = env::var("TOKEN_URL")
.unwrap_or_else(|_| "https://discord.com/api/oauth2/token".to_string()); .unwrap_or_else(|_| "https://discord.com/api/oauth2/token".to_string());
BasicClient::new( Ok(BasicClient::new(
ClientId::new(client_id), ClientId::new(client_id),
Some(ClientSecret::new(client_secret)), Some(ClientSecret::new(client_secret)),
AuthUrl::new(auth_url).unwrap(), AuthUrl::new(auth_url).context("failed to create new authorization server URL")?,
Some(TokenUrl::new(token_url).unwrap()), Some(TokenUrl::new(token_url).context("failed to create new token endpoint URL")?),
) )
.set_redirect_uri(RedirectUrl::new(redirect_url).unwrap()) .set_redirect_uri(
RedirectUrl::new(redirect_url).context("failed to create new redirection URL")?,
))
} }
// The user data we'll get back from Discord. // The user data we'll get back from Discord.
@ -151,17 +166,27 @@ async fn protected(user: User) -> impl IntoResponse {
async fn logout( async fn logout(
State(store): State<MemoryStore>, State(store): State<MemoryStore>,
TypedHeader(cookies): TypedHeader<headers::Cookie>, TypedHeader(cookies): TypedHeader<headers::Cookie>,
) -> impl IntoResponse { ) -> Result<impl IntoResponse, AppError> {
let cookie = cookies.get(COOKIE_NAME).unwrap(); let cookie = cookies
let session = match store.load_session(cookie.to_string()).await.unwrap() { .get(COOKIE_NAME)
.context("unexpected error getting cookie name")?;
let session = match store
.load_session(cookie.to_string())
.await
.context("failed to load session")?
{
Some(s) => s, Some(s) => s,
// No session active, just redirect // No session active, just redirect
None => return Redirect::to("/"), None => return Ok(Redirect::to("/")),
}; };
store.destroy_session(session).await.unwrap(); store
.destroy_session(session)
.await
.context("failed to destroy session")?;
Redirect::to("/") Ok(Redirect::to("/"))
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
@ -175,13 +200,13 @@ async fn login_authorized(
Query(query): Query<AuthRequest>, Query(query): Query<AuthRequest>,
State(store): State<MemoryStore>, State(store): State<MemoryStore>,
State(oauth_client): State<BasicClient>, State(oauth_client): State<BasicClient>,
) -> impl IntoResponse { ) -> Result<impl IntoResponse, AppError> {
// Get an auth token // Get an auth token
let token = oauth_client let token = oauth_client
.exchange_code(AuthorizationCode::new(query.code.clone())) .exchange_code(AuthorizationCode::new(query.code.clone()))
.request_async(async_http_client) .request_async(async_http_client)
.await .await
.unwrap(); .context("failed in sending request request to authorization server")?;
// Fetch user data from discord // Fetch user data from discord
let client = reqwest::Client::new(); let client = reqwest::Client::new();
@ -191,26 +216,35 @@ async fn login_authorized(
.bearer_auth(token.access_token().secret()) .bearer_auth(token.access_token().secret())
.send() .send()
.await .await
.unwrap() .context("failed in sending request to target Url")?
.json::<User>() .json::<User>()
.await .await
.unwrap(); .context("failed to deserialize response as JSON")?;
// Create a new session filled with user data // Create a new session filled with user data
let mut session = Session::new(); let mut session = Session::new();
session.insert("user", &user_data).unwrap(); session
.insert("user", &user_data)
.context("failed in inserting serialized value into session")?;
// Store session and get corresponding cookie // Store session and get corresponding cookie
let cookie = store.store_session(session).await.unwrap().unwrap(); let cookie = store
.store_session(session)
.await
.context("failed to store session")?
.context("unexpected error retrieving cookie value")?;
// Build the cookie // Build the cookie
let cookie = format!("{}={}; SameSite=Lax; Path=/", COOKIE_NAME, cookie); let cookie = format!("{}={}; SameSite=Lax; Path=/", COOKIE_NAME, cookie);
// Set cookie // Set cookie
let mut headers = HeaderMap::new(); let mut headers = HeaderMap::new();
headers.insert(SET_COOKIE, cookie.parse().unwrap()); headers.insert(
SET_COOKIE,
cookie.parse().context("failed to parse cookie")?,
);
(headers, Redirect::to("/")) Ok((headers, Redirect::to("/")))
} }
struct AuthRedirect; struct AuthRedirect;
@ -256,3 +290,28 @@ where
Ok(user) Ok(user)
} }
} }
// Use anyhow, define error and enable '?'
// For a simplified example of using anyhow in axum check /examples/anyhow-error-response
#[derive(Debug)]
struct AppError(anyhow::Error);
// Tell axum how to convert `AppError` into a response.
impl IntoResponse for AppError {
fn into_response(self) -> Response {
tracing::error!("Application error: {:#}", self.0);
(StatusCode::INTERNAL_SERVER_ERROR, "Something went wrong").into_response()
}
}
// This enables using `?` on functions that return `Result<_, anyhow::Error>` to turn them into
// `Result<_, AppError>`. That way you don't need to do that manually.
impl<E> From<E> for AppError
where
E: Into<anyhow::Error>,
{
fn from(err: E) -> Self {
Self(err.into())
}
}