diff --git a/examples/realworld-postgres/src/main.rs b/examples/realworld-postgres/src/main.rs index 2a6c4281..c58ddd38 100644 --- a/examples/realworld-postgres/src/main.rs +++ b/examples/realworld-postgres/src/main.rs @@ -50,7 +50,9 @@ async fn register(mut req: Request) -> Response { let body: RegisterRequestBody = req.body_json().await.unwrap(); let hash = hash_password(&body.password).unwrap(); - let mut pool = req.state(); + // Make a new transaction + let pool = req.state(); + let mut tx = pool.begin().await.unwrap(); let rec = sqlx::query!( r#" @@ -62,12 +64,15 @@ RETURNING id, username, email body.email, hash, ) - .fetch_one(&mut pool) + .fetch_one(&mut tx) .await .unwrap(); let token = generate_token(rec.id).unwrap(); + // Explicitly commit + tx.commit().await.unwrap(); + #[derive(serde::Serialize)] struct RegisterResponseBody { user: User, diff --git a/sqlx-core/src/lib.rs b/sqlx-core/src/lib.rs index b52c8817..aee2cae1 100644 --- a/sqlx-core/src/lib.rs +++ b/sqlx-core/src/lib.rs @@ -16,6 +16,7 @@ mod database; mod executor; mod query; mod query_as; +mod transaction; mod url; #[macro_use] @@ -47,6 +48,7 @@ pub use connection::{Connect, Connection}; pub use executor::Executor; pub use query::{query, Query}; pub use query_as::{query_as, QueryAs}; +pub use transaction::Transaction; #[doc(hidden)] pub use query_as::query_as_mapped; diff --git a/sqlx-core/src/mysql/connection.rs b/sqlx-core/src/mysql/connection.rs index 2115fa03..d357e2d1 100644 --- a/sqlx-core/src/mysql/connection.rs +++ b/sqlx-core/src/mysql/connection.rs @@ -7,7 +7,7 @@ use futures_core::future::BoxFuture; use sha1::Sha1; use crate::cache::StatementCache; -use crate::connection::Connection; +use crate::connection::{Connect, Connection}; use crate::io::{Buf, BufMut, BufStream, MaybeTlsStream}; use crate::mysql::error::MySqlError; use crate::mysql::protocol::{ @@ -475,7 +475,7 @@ impl MySqlConnection { } impl MySqlConnection { - pub(super) async fn open(url: crate::Result) -> crate::Result { + pub(super) async fn establish(url: crate::Result) -> crate::Result { let url = url?; let mut self_ = Self::new(&url).await?; @@ -598,19 +598,19 @@ impl MySqlConnection { T: TryInto, Self: Sized, { - Box::pin(MySqlConnection::open(url.try_into())) + Box::pin(MySqlConnection::establish(url.try_into())) } } impl Connect for MySqlConnection { type Connection = MySqlConnection; - fn connect(url: T) -> BoxFuture<'static, Result> + fn connect(url: T) -> BoxFuture<'static, crate::Result> where T: TryInto, Self: Sized, { - Box::pin(PgConnection::open(url.try_into())) + Box::pin(MySqlConnection::establish(url.try_into())) } } diff --git a/sqlx-core/src/mysql/mod.rs b/sqlx-core/src/mysql/mod.rs index 463b817f..37bd3047 100644 --- a/sqlx-core/src/mysql/mod.rs +++ b/sqlx-core/src/mysql/mod.rs @@ -26,16 +26,3 @@ pub use row::MySqlRow; /// An alias for [`Pool`], specialized for **MySQL**. pub type MySqlPool = super::Pool; - -use std::convert::TryInto; - -use crate::url::Url; - -// used in tests and hidden code in examples -#[doc(hidden)] -pub async fn connect(url: T) -> crate::Result -where - T: TryInto, -{ - MySqlConnection::open(url.try_into()).await -} diff --git a/sqlx-core/src/pool/executor.rs b/sqlx-core/src/pool/executor.rs index 4edff8d9..ff5bda5e 100644 --- a/sqlx-core/src/pool/executor.rs +++ b/sqlx-core/src/pool/executor.rs @@ -1,3 +1,5 @@ +use std::ops::DerefMut; + use futures_core::{future::BoxFuture, stream::BoxStream}; use futures_util::StreamExt; @@ -9,6 +11,8 @@ use crate::{ Database, }; +use super::PoolConnection; + impl Executor for Pool where C: Connection + Connect, @@ -108,3 +112,45 @@ where Box::pin(async move { self.acquire().await?.describe(query).await }) } } + +impl Executor for PoolConnection +where + C: Connection + Connect, +{ + type Database = ::Database; + + fn send<'e, 'q: 'e>(&'e mut self, commands: &'q str) -> BoxFuture<'e, crate::Result<()>> { + self.deref_mut().send(commands) + } + + fn execute<'e, 'q: 'e>( + &'e mut self, + query: &'q str, + args: <::Database as Database>::Arguments, + ) -> BoxFuture<'e, crate::Result> { + self.deref_mut().execute(query, args) + } + + fn fetch<'e, 'q: 'e>( + &'e mut self, + query: &'q str, + args: <::Database as Database>::Arguments, + ) -> BoxStream<'e, crate::Result<<::Database as Database>::Row>> { + self.deref_mut().fetch(query, args) + } + + fn fetch_optional<'e, 'q: 'e>( + &'e mut self, + query: &'q str, + args: <::Database as Database>::Arguments, + ) -> BoxFuture<'e, crate::Result::Database as Database>::Row>>> { + self.deref_mut().fetch_optional(query, args) + } + + fn describe<'e, 'q: 'e>( + &'e mut self, + query: &'q str, + ) -> BoxFuture<'e, crate::Result>> { + self.deref_mut().describe(query) + } +} diff --git a/sqlx-core/src/pool/mod.rs b/sqlx-core/src/pool/mod.rs index 78592092..1fb4385c 100644 --- a/sqlx-core/src/pool/mod.rs +++ b/sqlx-core/src/pool/mod.rs @@ -2,21 +2,26 @@ use std::{ fmt, + mem, ops::{Deref, DerefMut}, sync::Arc, time::{Duration, Instant}, }; +use futures_core::future::BoxFuture; + use crate::connection::{Connect, Connection}; +use crate::transaction::Transaction; use self::inner::SharedPool; -pub use self::options::Builder; use self::options::Options; mod executor; mod inner; mod options; +pub use self::options::Builder; + /// A pool of database connections. pub struct Pool(Arc>); @@ -84,6 +89,11 @@ where }) } + /// Retrieves a new connection and immediately begins a new transaction. + pub async fn begin(&self) -> crate::Result>> { + Ok(Transaction::new(0, self.acquire().await?).await?) + } + /// Ends the use of a connection pool. Prevents any new connections /// and will close all active connections when they are returned to the pool. /// @@ -172,6 +182,27 @@ where } } +impl Connection for PoolConnection +where + C: Connection + Connect, +{ + fn close(mut self) -> BoxFuture<'static, crate::Result<()>> { + Box::pin(async move { + if let Some(live) = self.live.take() { + let raw = live.raw; + + // Explicitly close the connection + raw.close().await?; + } + + // Forget ourself so it does not go back to the pool + mem::forget(self); + + Ok(()) + }) + } +} + impl Drop for PoolConnection where C: Connection + Connect, diff --git a/sqlx-core/src/postgres/connection.rs b/sqlx-core/src/postgres/connection.rs index a3f50b40..2be7d6d7 100644 --- a/sqlx-core/src/postgres/connection.rs +++ b/sqlx-core/src/postgres/connection.rs @@ -8,7 +8,7 @@ use rand::Rng; use sha2::{Digest, Sha256}; use crate::cache::StatementCache; -use crate::connection::Connection; +use crate::connection::{Connect, Connection}; use crate::io::{Buf, BufStream, MaybeTlsStream}; use crate::postgres::protocol::{ self, hi, Authentication, Decode, Encode, Message, SaslInitialResponse, SaslResponse, @@ -334,7 +334,7 @@ impl PgConnection { } impl PgConnection { - pub(super) async fn open(url: Result) -> Result { + pub(super) async fn establish(url: Result) -> Result { let url = url?; let stream = MaybeTlsStream::connect(&url, 5432).await?; @@ -402,7 +402,7 @@ impl PgConnection { T: TryInto, Self: Sized, { - Box::pin(PgConnection::open(url.try_into())) + Box::pin(PgConnection::establish(url.try_into())) } } @@ -414,7 +414,7 @@ impl Connect for PgConnection { T: TryInto, Self: Sized, { - Box::pin(PgConnection::open(url.try_into())) + Box::pin(PgConnection::establish(url.try_into())) } } diff --git a/sqlx-core/src/postgres/mod.rs b/sqlx-core/src/postgres/mod.rs index ffa060d3..4dab66a9 100644 --- a/sqlx-core/src/postgres/mod.rs +++ b/sqlx-core/src/postgres/mod.rs @@ -18,16 +18,3 @@ mod types; /// An alias for [`Pool`], specialized for **Postgres**. pub type PgPool = super::Pool; - -use std::convert::TryInto; - -use crate::url::Url; - -// used in tests and hidden code in examples -#[doc(hidden)] -pub async fn connect(url: T) -> crate::Result -where - T: TryInto, -{ - PgConnection::open(url.try_into()).await -} diff --git a/sqlx-core/src/transaction.rs b/sqlx-core/src/transaction.rs new file mode 100644 index 00000000..ff3d9d0c --- /dev/null +++ b/sqlx-core/src/transaction.rs @@ -0,0 +1,173 @@ +use std::ops::{Deref, DerefMut}; + +use async_std::task; +use futures_core::future::BoxFuture; +use futures_core::stream::BoxStream; + +use crate::database::Database; +use crate::describe::Describe; +use crate::executor::Executor; +use crate::connection::Connection; + +pub struct Transaction +where + T: Connection + Send + 'static, +{ + inner: Option, + depth: u32, +} + +impl Transaction +where + T: Connection + Send + 'static, +{ + pub(crate) async fn new(depth: u32, mut inner: T) -> crate::Result { + if depth == 0 { + inner.send("BEGIN").await?; + } else { + inner + .send(&format!("SAVEPOINT _sqlx_savepoint_{}", depth)) + .await?; + } + + Ok(Self { + inner: Some(inner), + depth: depth + 1, + }) + } + + pub async fn begin(mut self) -> crate::Result> { + Transaction::new(self.depth, self.inner.take().expect(ERR_FINALIZED)).await + } + + pub async fn commit(mut self) -> crate::Result { + let mut inner = self.inner.take().expect(ERR_FINALIZED); + let depth = self.depth; + + if depth == 1 { + inner.send("COMMIT").await?; + } else { + inner + .send(&format!("RELEASE SAVEPOINT _sqlx_savepoint_{}", depth - 1)) + .await?; + } + + Ok(inner) + } + + pub async fn rollback(mut self) -> crate::Result { + let mut inner = self.inner.take().expect(ERR_FINALIZED); + let depth = self.depth; + + if depth == 1 { + inner.send("ROLLBACK").await?; + } else { + inner + .send(&format!( + "ROLLBACK TO SAVEPOINT _sqlx_savepoint_{}", + depth - 1 + )) + .await?; + } + + Ok(inner) + } +} + +const ERR_FINALIZED: &str = "(bug) transaction already finalized"; + +impl Deref for Transaction +where + Conn: Connection, +{ + type Target = Conn; + + fn deref(&self) -> &Self::Target { + self.inner.as_ref().expect(ERR_FINALIZED) + } +} + +impl DerefMut for Transaction +where + Conn: Connection, +{ + fn deref_mut(&mut self) -> &mut Self::Target { + self.inner.as_mut().expect(ERR_FINALIZED) + } +} + +impl Connection for Transaction +where + T: Connection +{ + // Close is equivalent to ROLLBACK followed by CLOSE + fn close(self) -> BoxFuture<'static, crate::Result<()>> { + Box::pin(async move { + self.rollback().await?.close().await + }) + } +} + +impl Executor for Transaction +where + T: Connection, +{ + type Database = T::Database; + + fn send<'e, 'q: 'e>(&'e mut self, commands: &'q str) -> BoxFuture<'e, crate::Result<()>> { + self.deref_mut().send(commands) + } + + fn execute<'e, 'q: 'e>( + &'e mut self, + query: &'q str, + args: ::Arguments, + ) -> BoxFuture<'e, crate::Result> { + self.deref_mut().execute(query, args) + } + + fn fetch<'e, 'q: 'e>( + &'e mut self, + query: &'q str, + args: ::Arguments, + ) -> BoxStream<'e, crate::Result<::Row>> { + self.deref_mut().fetch(query, args) + } + + fn fetch_optional<'e, 'q: 'e>( + &'e mut self, + query: &'q str, + args: ::Arguments, + ) -> BoxFuture<'e, crate::Result::Row>>> { + self.deref_mut().fetch_optional(query, args) + } + + fn describe<'e, 'q: 'e>( + &'e mut self, + query: &'q str, + ) -> BoxFuture<'e, crate::Result>> { + self.deref_mut().describe(query) + } +} + +impl Drop for Transaction +where + Conn: Connection, +{ + fn drop(&mut self) { + if self.depth > 0 { + if let Some(mut inner) = self.inner.take() { + task::spawn(async move { + let res = inner.send("ROLLBACK").await; + + // If the rollback failed we need to close the inner connection + if res.is_err() { + // This will explicitly forget the connection so it will not + // return to the pool + let _ = inner.close().await; + } + }); + } + } + } +} diff --git a/sqlx-macros/src/lib.rs b/sqlx-macros/src/lib.rs index ea997c3e..b1f21f14 100644 --- a/sqlx-macros/src/lib.rs +++ b/sqlx-macros/src/lib.rs @@ -29,14 +29,14 @@ use query_macros::*; macro_rules! async_macro ( ($db:ident => $expr:expr) => {{ let res: Result = task::block_on(async { - use sqlx::Connection; + use sqlx::Connect; let db_url = Url::parse(&dotenv::var("DATABASE_URL").map_err(|_| "DATABASE_URL not set")?)?; match db_url.scheme() { #[cfg(feature = "postgres")] "postgresql" | "postgres" => { - let $db = sqlx::postgres::PgConnection::open(db_url.as_str()) + let $db = sqlx::postgres::PgConnection::connect(db_url.as_str()) .await .map_err(|e| format!("failed to connect to database: {}", e))?; @@ -50,7 +50,7 @@ macro_rules! async_macro ( ).into()), #[cfg(feature = "mysql")] "mysql" | "mariadb" => { - let $db = sqlx::mysql::MySqlConnection::open(db_url.as_str()) + let $db = sqlx::mysql::MySqlConnection::connect(db_url.as_str()) .await .map_err(|e| format!("failed to connect to database: {}", e))?;