diff --git a/sqlx-core/src/connection.rs b/sqlx-core/src/connection.rs index 3837a8b4..6f9a5fbf 100644 --- a/sqlx-core/src/connection.rs +++ b/sqlx-core/src/connection.rs @@ -24,7 +24,10 @@ pub trait Connection: Send { /// Begin a new transaction or establish a savepoint within the active transaction. /// /// Returns a [`Transaction`] for controlling and tracking the new transaction. - fn begin(&mut self) -> BoxFuture<'_, Result, Error>> { + fn begin(&mut self) -> BoxFuture<'_, Result, Error>> + where + Self: Sized, + { Transaction::begin(self) } @@ -34,6 +37,7 @@ pub trait Connection: Send { /// return an error, the transaction will be committed. fn transaction<'c: 'f, 'f, T, E, F, Fut>(&'c mut self, f: F) -> BoxFuture<'f, Result> where + Self: Sized, T: Send, F: FnOnce(&mut ::Connection) -> Fut + Send + 'f, E: From + Send, diff --git a/sqlx-core/src/ext/maybe_owned.rs b/sqlx-core/src/ext/maybe_owned.rs new file mode 100644 index 00000000..0cb51eb7 --- /dev/null +++ b/sqlx-core/src/ext/maybe_owned.rs @@ -0,0 +1,40 @@ +use std::ops::{Deref, DerefMut}; + +pub enum MaybeOwned<'a, T> { + Borrowed(&'a mut T), + Owned(T), +} + +impl<'a, T> From for MaybeOwned<'a, T> { + fn from(v: T) -> Self { + MaybeOwned::Owned(v) + } +} + +impl<'a, T> From<&'a mut T> for MaybeOwned<'a, T> { + fn from(v: &'a mut T) -> Self { + MaybeOwned::Borrowed(v) + } +} + +impl<'a, T> Deref for MaybeOwned<'a, T> { + type Target = T; + + #[inline] + fn deref(&self) -> &Self::Target { + match self { + MaybeOwned::Borrowed(v) => v, + MaybeOwned::Owned(v) => v, + } + } +} + +impl<'a, T> DerefMut for MaybeOwned<'a, T> { + #[inline] + fn deref_mut(&mut self) -> &mut Self::Target { + match self { + MaybeOwned::Borrowed(v) => v, + MaybeOwned::Owned(v) => v, + } + } +} diff --git a/sqlx-core/src/ext/mod.rs b/sqlx-core/src/ext/mod.rs index 4e920bb8..d25216f6 100644 --- a/sqlx-core/src/ext/mod.rs +++ b/sqlx-core/src/ext/mod.rs @@ -1 +1,2 @@ +pub mod maybe_owned; pub mod ustr; diff --git a/sqlx-core/src/pool/mod.rs b/sqlx-core/src/pool/mod.rs index 2514c9d1..10da0858 100644 --- a/sqlx-core/src/pool/mod.rs +++ b/sqlx-core/src/pool/mod.rs @@ -8,6 +8,7 @@ use std::{ use crate::database::Database; use crate::error::Error; +use crate::transaction::Transaction; use self::inner::SharedPool; use self::options::Options; @@ -60,6 +61,22 @@ impl Pool { self.0.try_acquire().map(|conn| conn.attach(&self.0)) } + /// Retrieves a new connection and immediately begins a new transaction. + pub async fn begin(&self) -> Result>, Error> { + Ok(Transaction::begin(self.acquire().await?).await?) + } + + /// Attempts to retrieve a new connection and immediately begins a new transaction if there + /// is one available. + pub async fn try_begin( + &self, + ) -> Result>>, Error> { + match self.try_acquire() { + Some(conn) => Transaction::begin(conn).await.map(Some), + None => Ok(None), + } + } + /// Ends the use of a connection pool. Prevents any new connections /// and will close all active connections when they are returned to the pool. /// diff --git a/sqlx-core/src/transaction.rs b/sqlx-core/src/transaction.rs index fdb60b8e..27879ffb 100644 --- a/sqlx-core/src/transaction.rs +++ b/sqlx-core/src/transaction.rs @@ -8,6 +8,7 @@ use futures_util::{future, FutureExt}; use crate::connection::Connection; use crate::database::Database; use crate::error::Error; +use crate::ext::maybe_owned::MaybeOwned; /// Generic management of database transactions. /// @@ -56,10 +57,10 @@ pub trait TransactionManager { /// [`rollback`]: #method.rollback pub struct Transaction<'c, DB, C = ::Connection> where - DB: ?Sized + Database, - C: ?Sized + Connection, + DB: Database, + C: Sized + Connection, { - connection: &'c mut C, + connection: MaybeOwned<'c, C>, // the depth of ~this~ transaction, depth directly equates to how many times [`begin()`] // was called without a corresponding [`commit()`] or [`rollback()`] @@ -68,10 +69,12 @@ where impl<'c, DB, C> Transaction<'c, DB, C> where - DB: ?Sized + Database, - C: ?Sized + Connection, + DB: Database, + C: Sized + Connection, { - pub(crate) fn begin(conn: &'c mut C) -> BoxFuture<'c, Result> { + pub(crate) fn begin(conn: impl Into>) -> BoxFuture<'c, Result> { + let mut conn = conn.into(); + Box::pin(async move { let depth = conn.transaction_depth(); @@ -85,20 +88,20 @@ where } /// Commits this transaction or savepoint. - pub async fn commit(self) -> Result<(), Error> { + pub async fn commit(mut self) -> Result<(), Error> { DB::TransactionManager::commit(self.connection.get_mut(), self.depth).await } /// Aborts this transaction or savepoint. - pub async fn rollback(self) -> Result<(), Error> { + pub async fn rollback(mut self) -> Result<(), Error> { DB::TransactionManager::rollback(self.connection.get_mut(), self.depth).await } } impl<'c, DB, C> Connection for Transaction<'c, DB, C> where - DB: ?Sized + Database, - C: ?Sized + Connection, + DB: Database, + C: Sized + Connection, { type Database = C::Database; @@ -137,7 +140,7 @@ where #[allow(unused_macros)] macro_rules! impl_executor_for_transaction { ($DB:ident, $Row:ident) => { - impl<'c, 't, C: ?Sized> crate::executor::Executor<'t> + impl<'c, 't, C: Sized> crate::executor::Executor<'t> for &'t mut crate::transaction::Transaction<'c, $DB, C> where C: crate::connection::Connection, @@ -189,8 +192,8 @@ macro_rules! impl_executor_for_transaction { impl<'c, DB, C> Debug for Transaction<'c, DB, C> where - DB: ?Sized + Database, - C: ?Sized + Connection, + DB: Database, + C: Sized + Connection, { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { // TODO: Show the full type <..<..<.. @@ -200,8 +203,8 @@ where impl<'c, DB, C> Deref for Transaction<'c, DB, C> where - DB: ?Sized + Database, - C: ?Sized + Connection, + DB: Database, + C: Sized + Connection, { type Target = ::Connection; @@ -213,8 +216,8 @@ where impl<'c, DB, C> DerefMut for Transaction<'c, DB, C> where - DB: ?Sized + Database, - C: ?Sized + Connection, + DB: Database, + C: Sized + Connection, { #[inline] fn deref_mut(&mut self) -> &mut Self::Target { @@ -224,8 +227,8 @@ where impl<'c, DB, C> Drop for Transaction<'c, DB, C> where - DB: ?Sized + Database, - C: ?Sized + Connection, + DB: Database, + C: Sized + Connection, { fn drop(&mut self) { // starts a rollback operation