Add a Transaction type to simplify dealing with Transactions

This commit is contained in:
Ryan Leckey 2020-01-03 22:42:10 -08:00
parent 28ed854b03
commit b1a27ddac2
10 changed files with 272 additions and 41 deletions

View File

@ -50,7 +50,9 @@ async fn register(mut req: Request<PgPool>) -> Response {
let body: RegisterRequestBody = req.body_json().await.unwrap(); let body: RegisterRequestBody = req.body_json().await.unwrap();
let hash = hash_password(&body.password).unwrap(); let hash = hash_password(&body.password).unwrap();
let mut pool = req.state(); // Make a new transaction
let pool = req.state();
let mut tx = pool.begin().await.unwrap();
let rec = sqlx::query!( let rec = sqlx::query!(
r#" r#"
@ -62,12 +64,15 @@ RETURNING id, username, email
body.email, body.email,
hash, hash,
) )
.fetch_one(&mut pool) .fetch_one(&mut tx)
.await .await
.unwrap(); .unwrap();
let token = generate_token(rec.id).unwrap(); let token = generate_token(rec.id).unwrap();
// Explicitly commit
tx.commit().await.unwrap();
#[derive(serde::Serialize)] #[derive(serde::Serialize)]
struct RegisterResponseBody { struct RegisterResponseBody {
user: User, user: User,

View File

@ -16,6 +16,7 @@ mod database;
mod executor; mod executor;
mod query; mod query;
mod query_as; mod query_as;
mod transaction;
mod url; mod url;
#[macro_use] #[macro_use]
@ -47,6 +48,7 @@ pub use connection::{Connect, Connection};
pub use executor::Executor; pub use executor::Executor;
pub use query::{query, Query}; pub use query::{query, Query};
pub use query_as::{query_as, QueryAs}; pub use query_as::{query_as, QueryAs};
pub use transaction::Transaction;
#[doc(hidden)] #[doc(hidden)]
pub use query_as::query_as_mapped; pub use query_as::query_as_mapped;

View File

@ -7,7 +7,7 @@ use futures_core::future::BoxFuture;
use sha1::Sha1; use sha1::Sha1;
use crate::cache::StatementCache; use crate::cache::StatementCache;
use crate::connection::Connection; use crate::connection::{Connect, Connection};
use crate::io::{Buf, BufMut, BufStream, MaybeTlsStream}; use crate::io::{Buf, BufMut, BufStream, MaybeTlsStream};
use crate::mysql::error::MySqlError; use crate::mysql::error::MySqlError;
use crate::mysql::protocol::{ use crate::mysql::protocol::{
@ -475,7 +475,7 @@ impl MySqlConnection {
} }
impl MySqlConnection { impl MySqlConnection {
pub(super) async fn open(url: crate::Result<Url>) -> crate::Result<Self> { pub(super) async fn establish(url: crate::Result<Url>) -> crate::Result<Self> {
let url = url?; let url = url?;
let mut self_ = Self::new(&url).await?; let mut self_ = Self::new(&url).await?;
@ -598,19 +598,19 @@ impl MySqlConnection {
T: TryInto<Url, Error = crate::Error>, T: TryInto<Url, Error = crate::Error>,
Self: Sized, Self: Sized,
{ {
Box::pin(MySqlConnection::open(url.try_into())) Box::pin(MySqlConnection::establish(url.try_into()))
} }
} }
impl Connect for MySqlConnection { impl Connect for MySqlConnection {
type Connection = MySqlConnection; type Connection = MySqlConnection;
fn connect<T>(url: T) -> BoxFuture<'static, Result<MySqlConnection>> fn connect<T>(url: T) -> BoxFuture<'static, crate::Result<MySqlConnection>>
where where
T: TryInto<Url, Error = crate::Error>, T: TryInto<Url, Error = crate::Error>,
Self: Sized, Self: Sized,
{ {
Box::pin(PgConnection::open(url.try_into())) Box::pin(MySqlConnection::establish(url.try_into()))
} }
} }

View File

@ -26,16 +26,3 @@ pub use row::MySqlRow;
/// An alias for [`Pool`], specialized for **MySQL**. /// An alias for [`Pool`], specialized for **MySQL**.
pub type MySqlPool = super::Pool<MySql>; pub type MySqlPool = super::Pool<MySql>;
use std::convert::TryInto;
use crate::url::Url;
// used in tests and hidden code in examples
#[doc(hidden)]
pub async fn connect<T>(url: T) -> crate::Result<MySqlConnection>
where
T: TryInto<Url, Error = crate::Error>,
{
MySqlConnection::open(url.try_into()).await
}

View File

@ -1,3 +1,5 @@
use std::ops::DerefMut;
use futures_core::{future::BoxFuture, stream::BoxStream}; use futures_core::{future::BoxFuture, stream::BoxStream};
use futures_util::StreamExt; use futures_util::StreamExt;
@ -9,6 +11,8 @@ use crate::{
Database, Database,
}; };
use super::PoolConnection;
impl<C> Executor for Pool<C> impl<C> Executor for Pool<C>
where where
C: Connection + Connect<Connection = C>, C: Connection + Connect<Connection = C>,
@ -108,3 +112,45 @@ where
Box::pin(async move { self.acquire().await?.describe(query).await }) Box::pin(async move { self.acquire().await?.describe(query).await })
} }
} }
impl<C> Executor for PoolConnection<C>
where
C: Connection + Connect<Connection = C>,
{
type Database = <C as Executor>::Database;
fn send<'e, 'q: 'e>(&'e mut self, commands: &'q str) -> BoxFuture<'e, crate::Result<()>> {
self.deref_mut().send(commands)
}
fn execute<'e, 'q: 'e>(
&'e mut self,
query: &'q str,
args: <<C as Executor>::Database as Database>::Arguments,
) -> BoxFuture<'e, crate::Result<u64>> {
self.deref_mut().execute(query, args)
}
fn fetch<'e, 'q: 'e>(
&'e mut self,
query: &'q str,
args: <<C as Executor>::Database as Database>::Arguments,
) -> BoxStream<'e, crate::Result<<<C as Executor>::Database as Database>::Row>> {
self.deref_mut().fetch(query, args)
}
fn fetch_optional<'e, 'q: 'e>(
&'e mut self,
query: &'q str,
args: <<C as Executor>::Database as Database>::Arguments,
) -> BoxFuture<'e, crate::Result<Option<<<C as Executor>::Database as Database>::Row>>> {
self.deref_mut().fetch_optional(query, args)
}
fn describe<'e, 'q: 'e>(
&'e mut self,
query: &'q str,
) -> BoxFuture<'e, crate::Result<Describe<Self::Database>>> {
self.deref_mut().describe(query)
}
}

View File

@ -2,21 +2,26 @@
use std::{ use std::{
fmt, fmt,
mem,
ops::{Deref, DerefMut}, ops::{Deref, DerefMut},
sync::Arc, sync::Arc,
time::{Duration, Instant}, time::{Duration, Instant},
}; };
use futures_core::future::BoxFuture;
use crate::connection::{Connect, Connection}; use crate::connection::{Connect, Connection};
use crate::transaction::Transaction;
use self::inner::SharedPool; use self::inner::SharedPool;
pub use self::options::Builder;
use self::options::Options; use self::options::Options;
mod executor; mod executor;
mod inner; mod inner;
mod options; mod options;
pub use self::options::Builder;
/// A pool of database connections. /// A pool of database connections.
pub struct Pool<C>(Arc<SharedPool<C>>); pub struct Pool<C>(Arc<SharedPool<C>>);
@ -84,6 +89,11 @@ where
}) })
} }
/// Retrieves a new connection and immediately begins a new transaction.
pub async fn begin(&self) -> crate::Result<Transaction<PoolConnection<C>>> {
Ok(Transaction::new(0, self.acquire().await?).await?)
}
/// Ends the use of a connection pool. Prevents any new connections /// Ends the use of a connection pool. Prevents any new connections
/// and will close all active connections when they are returned to the pool. /// and will close all active connections when they are returned to the pool.
/// ///
@ -172,6 +182,27 @@ where
} }
} }
impl<C> Connection for PoolConnection<C>
where
C: Connection + Connect<Connection = C>,
{
fn close(mut self) -> BoxFuture<'static, crate::Result<()>> {
Box::pin(async move {
if let Some(live) = self.live.take() {
let raw = live.raw;
// Explicitly close the connection
raw.close().await?;
}
// Forget ourself so it does not go back to the pool
mem::forget(self);
Ok(())
})
}
}
impl<C> Drop for PoolConnection<C> impl<C> Drop for PoolConnection<C>
where where
C: Connection + Connect<Connection = C>, C: Connection + Connect<Connection = C>,

View File

@ -8,7 +8,7 @@ use rand::Rng;
use sha2::{Digest, Sha256}; use sha2::{Digest, Sha256};
use crate::cache::StatementCache; use crate::cache::StatementCache;
use crate::connection::Connection; use crate::connection::{Connect, Connection};
use crate::io::{Buf, BufStream, MaybeTlsStream}; use crate::io::{Buf, BufStream, MaybeTlsStream};
use crate::postgres::protocol::{ use crate::postgres::protocol::{
self, hi, Authentication, Decode, Encode, Message, SaslInitialResponse, SaslResponse, self, hi, Authentication, Decode, Encode, Message, SaslInitialResponse, SaslResponse,
@ -334,7 +334,7 @@ impl PgConnection {
} }
impl PgConnection { impl PgConnection {
pub(super) async fn open(url: Result<Url>) -> Result<Self> { pub(super) async fn establish(url: Result<Url>) -> Result<Self> {
let url = url?; let url = url?;
let stream = MaybeTlsStream::connect(&url, 5432).await?; let stream = MaybeTlsStream::connect(&url, 5432).await?;
@ -402,7 +402,7 @@ impl PgConnection {
T: TryInto<Url, Error = crate::Error>, T: TryInto<Url, Error = crate::Error>,
Self: Sized, Self: Sized,
{ {
Box::pin(PgConnection::open(url.try_into())) Box::pin(PgConnection::establish(url.try_into()))
} }
} }
@ -414,7 +414,7 @@ impl Connect for PgConnection {
T: TryInto<Url, Error = crate::Error>, T: TryInto<Url, Error = crate::Error>,
Self: Sized, Self: Sized,
{ {
Box::pin(PgConnection::open(url.try_into())) Box::pin(PgConnection::establish(url.try_into()))
} }
} }

View File

@ -18,16 +18,3 @@ mod types;
/// An alias for [`Pool`], specialized for **Postgres**. /// An alias for [`Pool`], specialized for **Postgres**.
pub type PgPool = super::Pool<Postgres>; pub type PgPool = super::Pool<Postgres>;
use std::convert::TryInto;
use crate::url::Url;
// used in tests and hidden code in examples
#[doc(hidden)]
pub async fn connect<T>(url: T) -> crate::Result<PgConnection>
where
T: TryInto<Url, Error = crate::Error>,
{
PgConnection::open(url.try_into()).await
}

View File

@ -0,0 +1,173 @@
use std::ops::{Deref, DerefMut};
use async_std::task;
use futures_core::future::BoxFuture;
use futures_core::stream::BoxStream;
use crate::database::Database;
use crate::describe::Describe;
use crate::executor::Executor;
use crate::connection::Connection;
pub struct Transaction<T>
where
T: Connection + Send + 'static,
{
inner: Option<T>,
depth: u32,
}
impl<T> Transaction<T>
where
T: Connection + Send + 'static,
{
pub(crate) async fn new(depth: u32, mut inner: T) -> crate::Result<Self> {
if depth == 0 {
inner.send("BEGIN").await?;
} else {
inner
.send(&format!("SAVEPOINT _sqlx_savepoint_{}", depth))
.await?;
}
Ok(Self {
inner: Some(inner),
depth: depth + 1,
})
}
pub async fn begin(mut self) -> crate::Result<Transaction<T>> {
Transaction::new(self.depth, self.inner.take().expect(ERR_FINALIZED)).await
}
pub async fn commit(mut self) -> crate::Result<T> {
let mut inner = self.inner.take().expect(ERR_FINALIZED);
let depth = self.depth;
if depth == 1 {
inner.send("COMMIT").await?;
} else {
inner
.send(&format!("RELEASE SAVEPOINT _sqlx_savepoint_{}", depth - 1))
.await?;
}
Ok(inner)
}
pub async fn rollback(mut self) -> crate::Result<T> {
let mut inner = self.inner.take().expect(ERR_FINALIZED);
let depth = self.depth;
if depth == 1 {
inner.send("ROLLBACK").await?;
} else {
inner
.send(&format!(
"ROLLBACK TO SAVEPOINT _sqlx_savepoint_{}",
depth - 1
))
.await?;
}
Ok(inner)
}
}
const ERR_FINALIZED: &str = "(bug) transaction already finalized";
impl<Conn> Deref for Transaction<Conn>
where
Conn: Connection,
{
type Target = Conn;
fn deref(&self) -> &Self::Target {
self.inner.as_ref().expect(ERR_FINALIZED)
}
}
impl<Conn> DerefMut for Transaction<Conn>
where
Conn: Connection,
{
fn deref_mut(&mut self) -> &mut Self::Target {
self.inner.as_mut().expect(ERR_FINALIZED)
}
}
impl<T> Connection for Transaction<T>
where
T: Connection
{
// Close is equivalent to ROLLBACK followed by CLOSE
fn close(self) -> BoxFuture<'static, crate::Result<()>> {
Box::pin(async move {
self.rollback().await?.close().await
})
}
}
impl<T> Executor for Transaction<T>
where
T: Connection,
{
type Database = T::Database;
fn send<'e, 'q: 'e>(&'e mut self, commands: &'q str) -> BoxFuture<'e, crate::Result<()>> {
self.deref_mut().send(commands)
}
fn execute<'e, 'q: 'e>(
&'e mut self,
query: &'q str,
args: <Self::Database as Database>::Arguments,
) -> BoxFuture<'e, crate::Result<u64>> {
self.deref_mut().execute(query, args)
}
fn fetch<'e, 'q: 'e>(
&'e mut self,
query: &'q str,
args: <Self::Database as Database>::Arguments,
) -> BoxStream<'e, crate::Result<<Self::Database as Database>::Row>> {
self.deref_mut().fetch(query, args)
}
fn fetch_optional<'e, 'q: 'e>(
&'e mut self,
query: &'q str,
args: <Self::Database as Database>::Arguments,
) -> BoxFuture<'e, crate::Result<Option<<Self::Database as Database>::Row>>> {
self.deref_mut().fetch_optional(query, args)
}
fn describe<'e, 'q: 'e>(
&'e mut self,
query: &'q str,
) -> BoxFuture<'e, crate::Result<Describe<Self::Database>>> {
self.deref_mut().describe(query)
}
}
impl<Conn> Drop for Transaction<Conn>
where
Conn: Connection,
{
fn drop(&mut self) {
if self.depth > 0 {
if let Some(mut inner) = self.inner.take() {
task::spawn(async move {
let res = inner.send("ROLLBACK").await;
// If the rollback failed we need to close the inner connection
if res.is_err() {
// This will explicitly forget the connection so it will not
// return to the pool
let _ = inner.close().await;
}
});
}
}
}
}

View File

@ -29,14 +29,14 @@ use query_macros::*;
macro_rules! async_macro ( macro_rules! async_macro (
($db:ident => $expr:expr) => {{ ($db:ident => $expr:expr) => {{
let res: Result<proc_macro2::TokenStream> = task::block_on(async { let res: Result<proc_macro2::TokenStream> = task::block_on(async {
use sqlx::Connection; use sqlx::Connect;
let db_url = Url::parse(&dotenv::var("DATABASE_URL").map_err(|_| "DATABASE_URL not set")?)?; let db_url = Url::parse(&dotenv::var("DATABASE_URL").map_err(|_| "DATABASE_URL not set")?)?;
match db_url.scheme() { match db_url.scheme() {
#[cfg(feature = "postgres")] #[cfg(feature = "postgres")]
"postgresql" | "postgres" => { "postgresql" | "postgres" => {
let $db = sqlx::postgres::PgConnection::open(db_url.as_str()) let $db = sqlx::postgres::PgConnection::connect(db_url.as_str())
.await .await
.map_err(|e| format!("failed to connect to database: {}", e))?; .map_err(|e| format!("failed to connect to database: {}", e))?;
@ -50,7 +50,7 @@ macro_rules! async_macro (
).into()), ).into()),
#[cfg(feature = "mysql")] #[cfg(feature = "mysql")]
"mysql" | "mariadb" => { "mysql" | "mariadb" => {
let $db = sqlx::mysql::MySqlConnection::open(db_url.as_str()) let $db = sqlx::mysql::MySqlConnection::connect(db_url.as_str())
.await .await
.map_err(|e| format!("failed to connect to database: {}", e))?; .map_err(|e| format!("failed to connect to database: {}", e))?;