diff --git a/sqlx-sqlite/src/connection/handle.rs b/sqlx-sqlite/src/connection/handle.rs index 07a5a6da..7df3b1b7 100644 --- a/sqlx-sqlite/src/connection/handle.rs +++ b/sqlx-sqlite/src/connection/handle.rs @@ -4,8 +4,8 @@ use std::{io, ptr}; use crate::error::Error; use libsqlite3_sys::{ - sqlite3, sqlite3_close, sqlite3_exec, sqlite3_extended_result_codes, sqlite3_last_insert_rowid, - sqlite3_open_v2, SQLITE_OK, + sqlite3, sqlite3_close, sqlite3_exec, sqlite3_extended_result_codes, sqlite3_get_autocommit, + sqlite3_last_insert_rowid, sqlite3_open_v2, SQLITE_OK, }; use crate::SqliteError; @@ -78,6 +78,12 @@ impl ConnectionHandle { } } + pub(crate) fn in_transaction(&mut self) -> bool { + // SAFETY: we have exclusive access to the database handle + let ret = unsafe { sqlite3_get_autocommit(self.as_ptr()) }; + ret == 0 + } + pub(crate) fn last_insert_rowid(&mut self) -> i64 { // SAFETY: we have exclusive access to the database handle unsafe { sqlite3_last_insert_rowid(self.as_ptr()) } diff --git a/sqlx-sqlite/src/connection/mod.rs b/sqlx-sqlite/src/connection/mod.rs index c32b81d3..218c7471 100644 --- a/sqlx-sqlite/src/connection/mod.rs +++ b/sqlx-sqlite/src/connection/mod.rs @@ -10,8 +10,8 @@ use std::ptr::NonNull; use futures_intrusive::sync::MutexGuard; use libsqlite3_sys::{ - sqlite3, sqlite3_commit_hook, sqlite3_get_autocommit, sqlite3_progress_handler, - sqlite3_rollback_hook, sqlite3_update_hook, SQLITE_DELETE, SQLITE_INSERT, SQLITE_UPDATE, + sqlite3, sqlite3_commit_hook, sqlite3_progress_handler, sqlite3_rollback_hook, + sqlite3_update_hook, SQLITE_DELETE, SQLITE_INSERT, SQLITE_UPDATE, }; #[cfg(feature = "preupdate-hook")] pub use preupdate_hook::*; @@ -545,11 +545,6 @@ impl LockedSqliteHandle<'_> { pub fn last_error(&mut self) -> Option { self.guard.handle.last_error() } - - pub(crate) fn in_transaction(&mut self) -> bool { - let ret = unsafe { sqlite3_get_autocommit(self.as_raw_handle().as_ptr()) }; - ret == 0 - } } impl Drop for ConnectionState { diff --git a/sqlx-sqlite/src/connection/worker.rs b/sqlx-sqlite/src/connection/worker.rs index ae50f3e8..7e84d456 100644 --- a/sqlx-sqlite/src/connection/worker.rs +++ b/sqlx-sqlite/src/connection/worker.rs @@ -213,6 +213,7 @@ impl ConnectionWorker { Command::Begin { tx, statement } => { let depth = shared.transaction_depth.load(Ordering::Acquire); + let is_custom_statement = statement.is_some(); let statement = match statement { // custom `BEGIN` statements are not allowed if // we're already in a transaction (we need to @@ -229,8 +230,14 @@ impl ConnectionWorker { let res = conn.handle .exec(statement.as_str()) - .map(|_| { + .and_then(|res| { + if is_custom_statement && !conn.handle.in_transaction() { + return Err(Error::BeginFailed) + } + shared.transaction_depth.fetch_add(1, Ordering::Release); + + Ok(res) }); let res_ok = res.is_ok(); diff --git a/sqlx-sqlite/src/transaction.rs b/sqlx-sqlite/src/transaction.rs index 145999ff..1c57d016 100644 --- a/sqlx-sqlite/src/transaction.rs +++ b/sqlx-sqlite/src/transaction.rs @@ -12,16 +12,7 @@ impl TransactionManager for SqliteTransactionManager { type Database = Sqlite; async fn begin(conn: &mut SqliteConnection, statement: Option) -> Result<(), Error> { - let is_custom_statement = statement.is_some(); - conn.worker.begin(statement).await?; - if is_custom_statement { - // Check that custom statement actually put the connection into a transaction. - let mut handle = conn.lock_handle().await?; - if !handle.in_transaction() { - return Err(Error::BeginFailed); - } - } - Ok(()) + conn.worker.begin(statement).await } fn commit(conn: &mut SqliteConnection) -> impl Future> + Send + '_ { diff --git a/tests/sqlite/sqlite.rs b/tests/sqlite/sqlite.rs index 99626c28..caa77760 100644 --- a/tests/sqlite/sqlite.rs +++ b/tests/sqlite/sqlite.rs @@ -1375,6 +1375,28 @@ async fn it_can_use_transaction_options() -> anyhow::Result<()> { Ok(()) } +#[sqlx_macros::test] +async fn it_can_recover_from_bad_transaction_begin() -> anyhow::Result<()> { + let mut conn = SqliteConnectOptions::new() + .in_memory(true) + .connect() + .await + .unwrap(); + + // This statement doesn't actually start a transaction. + assert!(conn.begin_with("SELECT 1").await.is_err()); + + // Transaction state bookkeeping should be correctly reset. + + let mut tx = conn.begin_with("BEGIN IMMEDIATE").await?; + let value = sqlx::query_scalar::<_, i32>("SELECT 1") + .fetch_one(&mut *tx) + .await?; + assert_eq!(value, 1); + + Ok(()) +} + fn transaction_state(handle: &mut LockedSqliteHandle) -> SqliteTransactionState { use libsqlite3_sys::{sqlite3_txn_state, SQLITE_TXN_NONE, SQLITE_TXN_READ, SQLITE_TXN_WRITE};