From 1f7af3abc22f86f11cbe2848ac014e82826c4f7e Mon Sep 17 00:00:00 2001 From: Kevin Cox Date: Mon, 18 Aug 2025 19:16:52 -0400 Subject: [PATCH] SQLite: fix transaction level accounting with bad custom command. (#3981) In the previous code the worker would always assume that the custom command worked. However the higher level code would run a check and notice that a transaction was not actually started and raise an error without rolling back the transaction. This improves the code by moving the transaction check into the worker to ensure that the transaction depth tracker is only modified if the user's custom command actually started a transaction. Fixes: https://github.com/launchbadge/sqlx/issues/3932 --- sqlx-sqlite/src/connection/handle.rs | 10 ++++++++-- sqlx-sqlite/src/connection/mod.rs | 9 ++------- sqlx-sqlite/src/connection/worker.rs | 9 ++++++++- sqlx-sqlite/src/transaction.rs | 11 +---------- tests/sqlite/sqlite.rs | 22 ++++++++++++++++++++++ 5 files changed, 41 insertions(+), 20 deletions(-) diff --git a/sqlx-sqlite/src/connection/handle.rs b/sqlx-sqlite/src/connection/handle.rs index 07a5a6da..7df3b1b7 100644 --- a/sqlx-sqlite/src/connection/handle.rs +++ b/sqlx-sqlite/src/connection/handle.rs @@ -4,8 +4,8 @@ use std::{io, ptr}; use crate::error::Error; use libsqlite3_sys::{ - sqlite3, sqlite3_close, sqlite3_exec, sqlite3_extended_result_codes, sqlite3_last_insert_rowid, - sqlite3_open_v2, SQLITE_OK, + sqlite3, sqlite3_close, sqlite3_exec, sqlite3_extended_result_codes, sqlite3_get_autocommit, + sqlite3_last_insert_rowid, sqlite3_open_v2, SQLITE_OK, }; use crate::SqliteError; @@ -78,6 +78,12 @@ impl ConnectionHandle { } } + pub(crate) fn in_transaction(&mut self) -> bool { + // SAFETY: we have exclusive access to the database handle + let ret = unsafe { sqlite3_get_autocommit(self.as_ptr()) }; + ret == 0 + } + pub(crate) fn last_insert_rowid(&mut self) -> i64 { // SAFETY: we have exclusive access to the database handle unsafe { sqlite3_last_insert_rowid(self.as_ptr()) } diff --git a/sqlx-sqlite/src/connection/mod.rs b/sqlx-sqlite/src/connection/mod.rs index c32b81d3..218c7471 100644 --- a/sqlx-sqlite/src/connection/mod.rs +++ b/sqlx-sqlite/src/connection/mod.rs @@ -10,8 +10,8 @@ use std::ptr::NonNull; use futures_intrusive::sync::MutexGuard; use libsqlite3_sys::{ - sqlite3, sqlite3_commit_hook, sqlite3_get_autocommit, sqlite3_progress_handler, - sqlite3_rollback_hook, sqlite3_update_hook, SQLITE_DELETE, SQLITE_INSERT, SQLITE_UPDATE, + sqlite3, sqlite3_commit_hook, sqlite3_progress_handler, sqlite3_rollback_hook, + sqlite3_update_hook, SQLITE_DELETE, SQLITE_INSERT, SQLITE_UPDATE, }; #[cfg(feature = "preupdate-hook")] pub use preupdate_hook::*; @@ -545,11 +545,6 @@ impl LockedSqliteHandle<'_> { pub fn last_error(&mut self) -> Option { self.guard.handle.last_error() } - - pub(crate) fn in_transaction(&mut self) -> bool { - let ret = unsafe { sqlite3_get_autocommit(self.as_raw_handle().as_ptr()) }; - ret == 0 - } } impl Drop for ConnectionState { diff --git a/sqlx-sqlite/src/connection/worker.rs b/sqlx-sqlite/src/connection/worker.rs index ae50f3e8..7e84d456 100644 --- a/sqlx-sqlite/src/connection/worker.rs +++ b/sqlx-sqlite/src/connection/worker.rs @@ -213,6 +213,7 @@ impl ConnectionWorker { Command::Begin { tx, statement } => { let depth = shared.transaction_depth.load(Ordering::Acquire); + let is_custom_statement = statement.is_some(); let statement = match statement { // custom `BEGIN` statements are not allowed if // we're already in a transaction (we need to @@ -229,8 +230,14 @@ impl ConnectionWorker { let res = conn.handle .exec(statement.as_str()) - .map(|_| { + .and_then(|res| { + if is_custom_statement && !conn.handle.in_transaction() { + return Err(Error::BeginFailed) + } + shared.transaction_depth.fetch_add(1, Ordering::Release); + + Ok(res) }); let res_ok = res.is_ok(); diff --git a/sqlx-sqlite/src/transaction.rs b/sqlx-sqlite/src/transaction.rs index 145999ff..1c57d016 100644 --- a/sqlx-sqlite/src/transaction.rs +++ b/sqlx-sqlite/src/transaction.rs @@ -12,16 +12,7 @@ impl TransactionManager for SqliteTransactionManager { type Database = Sqlite; async fn begin(conn: &mut SqliteConnection, statement: Option) -> Result<(), Error> { - let is_custom_statement = statement.is_some(); - conn.worker.begin(statement).await?; - if is_custom_statement { - // Check that custom statement actually put the connection into a transaction. - let mut handle = conn.lock_handle().await?; - if !handle.in_transaction() { - return Err(Error::BeginFailed); - } - } - Ok(()) + conn.worker.begin(statement).await } fn commit(conn: &mut SqliteConnection) -> impl Future> + Send + '_ { diff --git a/tests/sqlite/sqlite.rs b/tests/sqlite/sqlite.rs index 99626c28..caa77760 100644 --- a/tests/sqlite/sqlite.rs +++ b/tests/sqlite/sqlite.rs @@ -1375,6 +1375,28 @@ async fn it_can_use_transaction_options() -> anyhow::Result<()> { Ok(()) } +#[sqlx_macros::test] +async fn it_can_recover_from_bad_transaction_begin() -> anyhow::Result<()> { + let mut conn = SqliteConnectOptions::new() + .in_memory(true) + .connect() + .await + .unwrap(); + + // This statement doesn't actually start a transaction. + assert!(conn.begin_with("SELECT 1").await.is_err()); + + // Transaction state bookkeeping should be correctly reset. + + let mut tx = conn.begin_with("BEGIN IMMEDIATE").await?; + let value = sqlx::query_scalar::<_, i32>("SELECT 1") + .fetch_one(&mut *tx) + .await?; + assert_eq!(value, 1); + + Ok(()) +} + fn transaction_state(handle: &mut LockedSqliteHandle) -> SqliteTransactionState { use libsqlite3_sys::{sqlite3_txn_state, SQLITE_TXN_NONE, SQLITE_TXN_READ, SQLITE_TXN_WRITE};