sqlx/sqlx-postgres/src/transaction.rs
Austin Bonander 393b731d5e
Merge of #3427 (by @mpyw) and #3614 (by @bonsairobo) (#3765)
* feat: Implement `get_transaction_depth` for drivers

* test: Verify `get_transaction_depth()` on postgres

* Refactor: `TransactionManager` delegation without BC

SQLite implementation is currently WIP

* Fix: Avoid breaking changes on `AnyConnectionBackend`

* Refactor: Remove verbose `SqliteConnection` typing

* Feat: Implementation for SQLite

I have included an `AtomicUsize` in `WorkerSharedState`. Ideally, `load` and `fetch_add` would not be executed as two separate steps, but this is acceptable here because only one thread writes. To prevent writes from other threads, the field itself is private and is exposed through a `pub(crate)` getter method.

* Refactor: Same approach for `cached_statements_size`

ref: a66787d36d62876b55475ef2326d17bade817aed

* Fix: Add missing `is_in_transaction` for backend

* Doc: Remove verbose "synchronously" word

* Fix: Remove useless `mut` qualifier

* feat: add Connection::begin_with

This patch completes the plumbing of an optional statement from these methods to
`TransactionManager::begin` without any validation of the provided statement.

There is a new `Error::InvalidSavePoint` which is triggered by any attempt to
call `Connection::begin_with` when we are already inside of a transaction.

* feat: add Pool::begin_with and Pool::try_begin_with

* feat: add Error::BeginFailed and validate that custom "begin" statements are successful

* chore: add tests of Error::BeginFailed

* chore: add tests of Error::InvalidSavePointStatement

* chore: test begin_with works for all SQLite "BEGIN" statements

* chore: improve comment on Connection::begin_with

* feat: add default impl of `Connection::begin_with`

This makes the new method a non-breaking change.

* refactor: combine if statement + unwrap_or_else into one match

* feat: use in-memory SQLite DB to avoid conflicts across tests run in parallel

* feedback: remove public wrapper for sqlite3_txn_state

Move the wrapper directly into the test that uses it instead.

* fix: cache Status on MySqlConnection

* fix: compilation errors

* fix: format

* fix: postgres test

* refactor: delete `Connection::get_transaction_depth`

* fix: tests

---------

Co-authored-by: mpyw <ryosuke_i_628@yahoo.co.jp>
Co-authored-by: Duncan Fairbanks <duncanfairbanks6@gmail.com>
2025-03-10 14:29:46 -07:00


use futures_core::future::BoxFuture;
use sqlx_core::database::Database;
use std::borrow::Cow;

use crate::error::Error;
use crate::executor::Executor;
use crate::{PgConnection, Postgres};

pub(crate) use sqlx_core::transaction::*;

/// Implementation of [`TransactionManager`] for PostgreSQL.
pub struct PgTransactionManager;

impl TransactionManager for PgTransactionManager {
    type Database = Postgres;

    fn begin<'conn>(
        conn: &'conn mut PgConnection,
        statement: Option<Cow<'static, str>>,
    ) -> BoxFuture<'conn, Result<(), Error>> {
        Box::pin(async move {
            let depth = conn.inner.transaction_depth;

            let statement = match statement {
                // custom `BEGIN` statements are not allowed if we're already in
                // a transaction (we need to issue a `SAVEPOINT` instead)
                Some(_) if depth > 0 => return Err(Error::InvalidSavePointStatement),
                Some(statement) => statement,
                None => begin_ansi_transaction_sql(depth),
            };

            // Guard that calls `start_rollback` on drop unless it is defused below.
            let rollback = Rollback::new(conn);
            rollback.conn.queue_simple_query(&statement)?;
            rollback.conn.wait_until_ready().await?;

            // The statement completed but did not leave the connection in a transaction.
            if !rollback.conn.in_transaction() {
                return Err(Error::BeginFailed);
            }

            rollback.conn.inner.transaction_depth += 1;
            rollback.defuse();

            Ok(())
        })
    }

    fn commit(conn: &mut PgConnection) -> BoxFuture<'_, Result<(), Error>> {
        Box::pin(async move {
            // No-op when no transaction is open.
            if conn.inner.transaction_depth > 0 {
                conn.execute(&*commit_ansi_transaction_sql(conn.inner.transaction_depth))
                    .await?;

                conn.inner.transaction_depth -= 1;
            }

            Ok(())
        })
    }

    fn rollback(conn: &mut PgConnection) -> BoxFuture<'_, Result<(), Error>> {
        Box::pin(async move {
            // No-op when no transaction is open.
            if conn.inner.transaction_depth > 0 {
                conn.execute(&*rollback_ansi_transaction_sql(
                    conn.inner.transaction_depth,
                ))
                .await?;

                conn.inner.transaction_depth -= 1;
            }

            Ok(())
        })
    }

    fn start_rollback(conn: &mut PgConnection) {
        if conn.inner.transaction_depth > 0 {
            // Only queues the rollback statement; it does not wait for a response here.
            conn.queue_simple_query(&rollback_ansi_transaction_sql(conn.inner.transaction_depth))
                .expect("BUG: Rollback query somehow too large for protocol");

            conn.inner.transaction_depth -= 1;
        }
    }

    fn get_transaction_depth(conn: &<Self::Database as Database>::Connection) -> usize {
        conn.inner.transaction_depth
    }
}
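
/// Drop guard used by `begin`: unless `defuse` is called, dropping it
/// invokes `PgTransactionManager::start_rollback` on the connection.
struct Rollback<'c> {
    conn: &'c mut PgConnection,
    defuse: bool,
}

impl Drop for Rollback<'_> {
    fn drop(&mut self) {
        if !self.defuse {
            PgTransactionManager::start_rollback(self.conn)
        }
    }
}

impl<'c> Rollback<'c> {
    fn new(conn: &'c mut PgConnection) -> Self {
        Self {
            conn,
            defuse: false,
        }
    }

    // Takes `self` by value: the guard is dropped when this method returns,
    // after the flag is set, so `Drop` skips the rollback.
    fn defuse(mut self) {
        self.defuse = true;
    }
}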
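
The `Rollback` guard above relies on `Drop` plus a by-value `defuse` method. The following is a minimal, standalone sketch of that pattern, not the sqlx implementation: `MockConn`, `Guard`, `fail_begin`, and the free function `begin` are hypothetical names invented for illustration, and the mock only records statements in a `Vec` instead of talking to a server. Unlike `start_rollback` above, which only queues a rollback when the depth is greater than zero, this sketch records a rollback on every un-defused drop so that the guard's effect is visible.

```rust
// Hypothetical stand-in for a connection; this is NOT sqlx's `PgConnection`.
struct MockConn {
    depth: usize,
    // Statements "sent" so far, in order.
    log: Vec<String>,
    // When true, the simulated BEGIN does not open a transaction.
    fail_begin: bool,
}

// Drop guard: records a rollback when dropped, unless it was defused.
struct Guard<'c> {
    conn: &'c mut MockConn,
    defused: bool,
}

impl Drop for Guard<'_> {
    fn drop(&mut self) {
        if !self.defused {
            self.conn.log.push("ROLLBACK".to_string());
        }
    }
}

impl Guard<'_> {
    // Consumes the guard. It is dropped at the end of this call, after the
    // flag has been set, so `Drop` sees `defused == true` and does nothing.
    fn defuse(mut self) {
        self.defused = true;
    }
}

// Simplified analogue of `begin`: every early return drops the guard
// un-defused, which triggers the rollback path.
fn begin(conn: &mut MockConn) -> Result<(), String> {
    let guard = Guard {
        conn,
        defused: false,
    };

    guard.conn.log.push("BEGIN".to_string());

    if guard.conn.fail_begin {
        // `guard` is dropped here without being defused.
        return Err("begin failed".to_string());
    }

    guard.conn.depth += 1;
    guard.defuse();

    Ok(())
}
```

Calling `begin` on a `MockConn` with `fail_begin: true` leaves `depth` at 0 and records both statements in `log`; with `fail_begin: false` only `BEGIN` is recorded and `depth` becomes 1. Taking `self` by value in `defuse` means the guard cannot be used after it has been defused.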