From daeb87bef19fff856fe8d4be47a91a8136774872 Mon Sep 17 00:00:00 2001 From: gridbox Date: Thu, 12 Sep 2024 14:57:02 -0400 Subject: [PATCH] Add sqlite commit and rollback hooks (#3500) * fix: Derive clone for SqliteOperation * feat: Add sqlite commit and rollback hooks --------- Co-authored-by: John Smith --- sqlx-sqlite/src/connection/establish.rs | 2 + sqlx-sqlite/src/connection/mod.rs | 124 +++++++++++++++++++++++- tests/sqlite/sqlite.rs | 114 +++++++++++++++++++++- 3 files changed, 236 insertions(+), 4 deletions(-) diff --git a/sqlx-sqlite/src/connection/establish.rs b/sqlx-sqlite/src/connection/establish.rs index 6438b6b7..40f9b4c3 100644 --- a/sqlx-sqlite/src/connection/establish.rs +++ b/sqlx-sqlite/src/connection/establish.rs @@ -296,6 +296,8 @@ impl EstablishParams { log_settings: self.log_settings.clone(), progress_handler_callback: None, update_hook_callback: None, + commit_hook_callback: None, + rollback_hook_callback: None, }) } } diff --git a/sqlx-sqlite/src/connection/mod.rs b/sqlx-sqlite/src/connection/mod.rs index 3588b94f..a579b8a6 100644 --- a/sqlx-sqlite/src/connection/mod.rs +++ b/sqlx-sqlite/src/connection/mod.rs @@ -11,8 +11,8 @@ use futures_core::future::BoxFuture; use futures_intrusive::sync::MutexGuard; use futures_util::future; use libsqlite3_sys::{ - sqlite3, sqlite3_progress_handler, sqlite3_update_hook, SQLITE_DELETE, SQLITE_INSERT, - SQLITE_UPDATE, + sqlite3, sqlite3_commit_hook, sqlite3_progress_handler, sqlite3_rollback_hook, + sqlite3_update_hook, SQLITE_DELETE, SQLITE_INSERT, SQLITE_UPDATE, }; pub(crate) use handle::ConnectionHandle; @@ -63,7 +63,7 @@ pub struct LockedSqliteHandle<'a> { pub(crate) struct Handler(NonNull bool + Send + 'static>); unsafe impl Send for Handler {} -#[derive(Debug, PartialEq, Eq)] +#[derive(Debug, PartialEq, Eq, Clone)] pub enum SqliteOperation { Insert, Update, @@ -91,6 +91,12 @@ pub struct UpdateHookResult<'a> { pub(crate) struct UpdateHookHandler(NonNull); unsafe impl Send for UpdateHookHandler {} +pub(crate) struct CommitHookHandler(NonNull bool + Send + 'static>); +unsafe impl Send for CommitHookHandler {} + +pub(crate) struct RollbackHookHandler(NonNull); +unsafe impl Send for RollbackHookHandler {} + pub(crate) struct ConnectionState { pub(crate) handle: ConnectionHandle, @@ -106,6 +112,10 @@ pub(crate) struct ConnectionState { progress_handler_callback: Option, update_hook_callback: Option, + + commit_hook_callback: Option, + + rollback_hook_callback: Option, } impl ConnectionState { @@ -127,6 +137,24 @@ impl ConnectionState { } } } + + pub(crate) fn remove_commit_hook(&mut self) { + if let Some(mut handler) = self.commit_hook_callback.take() { + unsafe { + sqlite3_commit_hook(self.handle.as_ptr(), None, ptr::null_mut()); + let _ = { Box::from_raw(handler.0.as_mut()) }; + } + } + } + + pub(crate) fn remove_rollback_hook(&mut self) { + if let Some(mut handler) = self.rollback_hook_callback.take() { + unsafe { + sqlite3_rollback_hook(self.handle.as_ptr(), None, ptr::null_mut()); + let _ = { Box::from_raw(handler.0.as_mut()) }; + } + } + } } pub(crate) struct Statements { @@ -284,6 +312,31 @@ extern "C" fn update_hook( } } +extern "C" fn commit_hook(callback: *mut c_void) -> c_int +where + F: FnMut() -> bool, +{ + unsafe { + let r = catch_unwind(|| { + let callback: *mut F = callback.cast::(); + (*callback)() + }); + c_int::from(!r.unwrap_or_default()) + } +} + +extern "C" fn rollback_hook(callback: *mut c_void) +where + F: FnMut(), +{ + unsafe { + let _ = catch_unwind(|| { + let callback: *mut F = callback.cast::(); + (*callback)() + }); + } +} + impl LockedSqliteHandle<'_> { /// Returns the underlying sqlite3* connection handle. /// @@ -368,6 +421,61 @@ impl LockedSqliteHandle<'_> { } } + /// Sets a commit hook that is invoked whenever a transaction is committed. If the commit hook callback + /// returns `false`, then the operation is turned into a ROLLBACK. + /// + /// Only a single commit hook may be defined at one time per database connection; setting a new commit hook + /// overrides the old one. + /// + /// The commit hook callback must not do anything that will modify the database connection that invoked + /// the commit hook. Note that sqlite3_prepare_v2() and sqlite3_step() both modify their database connections + /// in this context. + /// + /// See https://www.sqlite.org/c3ref/commit_hook.html + pub fn set_commit_hook(&mut self, callback: F) + where + F: FnMut() -> bool + Send + 'static, + { + unsafe { + let callback_boxed = Box::new(callback); + // SAFETY: `Box::into_raw()` always returns a non-null pointer. + let callback = NonNull::new_unchecked(Box::into_raw(callback_boxed)); + let handler = callback.as_ptr() as *mut _; + self.guard.remove_commit_hook(); + self.guard.commit_hook_callback = Some(CommitHookHandler(callback)); + + sqlite3_commit_hook( + self.as_raw_handle().as_mut(), + Some(commit_hook::), + handler, + ); + } + } + + /// Sets a rollback hook that is invoked whenever a transaction rollback occurs. The rollback callback is not + /// invoked if a transaction is automatically rolled back because the database connection is closed. + /// + /// See https://www.sqlite.org/c3ref/commit_hook.html + pub fn set_rollback_hook(&mut self, callback: F) + where + F: FnMut() + Send + 'static, + { + unsafe { + let callback_boxed = Box::new(callback); + // SAFETY: `Box::into_raw()` always returns a non-null pointer. + let callback = NonNull::new_unchecked(Box::into_raw(callback_boxed)); + let handler = callback.as_ptr() as *mut _; + self.guard.remove_rollback_hook(); + self.guard.rollback_hook_callback = Some(RollbackHookHandler(callback)); + + sqlite3_rollback_hook( + self.as_raw_handle().as_mut(), + Some(rollback_hook::), + handler, + ); + } + } + /// Removes the progress handler on a database connection. The method does nothing if no handler was set. pub fn remove_progress_handler(&mut self) { self.guard.remove_progress_handler(); @@ -376,6 +484,14 @@ impl LockedSqliteHandle<'_> { pub fn remove_update_hook(&mut self) { self.guard.remove_update_hook(); } + + pub fn remove_commit_hook(&mut self) { + self.guard.remove_commit_hook(); + } + + pub fn remove_rollback_hook(&mut self) { + self.guard.remove_rollback_hook(); + } } impl Drop for ConnectionState { @@ -384,6 +500,8 @@ impl Drop for ConnectionState { self.statements.clear(); self.remove_progress_handler(); self.remove_update_hook(); + self.remove_commit_hook(); + self.remove_rollback_hook(); } } diff --git a/tests/sqlite/sqlite.rs b/tests/sqlite/sqlite.rs index c47b1a77..b733ccbb 100644 --- a/tests/sqlite/sqlite.rs +++ b/tests/sqlite/sqlite.rs @@ -806,7 +806,7 @@ async fn test_query_with_update_hook() -> anyhow::Result<()> { assert_eq!(result.operation, SqliteOperation::Insert); assert_eq!(result.database, "main"); assert_eq!(result.table, "tweet"); - assert_eq!(result.rowid, 3); + assert_eq!(result.rowid, 2); }); let _ = sqlx::query("INSERT INTO tweet ( id, text ) VALUES ( 3, 'Hello, World' )") @@ -848,3 +848,115 @@ async fn test_multiple_set_update_hook_calls_drop_old_handler() -> anyhow::Resul assert_eq!(1, Arc::strong_count(&ref_counted_object)); Ok(()) } + +#[sqlx_macros::test] +async fn test_query_with_commit_hook() -> anyhow::Result<()> { + let mut conn = new::().await?; + + // Using this string as a canary to ensure the callback doesn't get called with the wrong data pointer. + let state = format!("test"); + conn.lock_handle().await?.set_commit_hook(move || { + assert_eq!(state, "test"); + false + }); + + let mut tx = conn.begin().await?; + sqlx::query("INSERT INTO tweet ( id, text ) VALUES ( 4, 'Hello, World' )") + .execute(&mut *tx) + .await?; + match tx.commit().await { + Err(sqlx::Error::Database(err)) => { + assert_eq!(err.message(), String::from("constraint failed")) + } + _ => panic!("expected an error"), + } + + Ok(()) +} + +#[sqlx_macros::test] +async fn test_multiple_set_commit_hook_calls_drop_old_handler() -> anyhow::Result<()> { + let ref_counted_object = Arc::new(0); + assert_eq!(1, Arc::strong_count(&ref_counted_object)); + + { + let mut conn = new::().await?; + + let o = ref_counted_object.clone(); + conn.lock_handle().await?.set_commit_hook(move || { + println!("{o:?}"); + true + }); + assert_eq!(2, Arc::strong_count(&ref_counted_object)); + + let o = ref_counted_object.clone(); + conn.lock_handle().await?.set_commit_hook(move || { + println!("{o:?}"); + true + }); + assert_eq!(2, Arc::strong_count(&ref_counted_object)); + + let o = ref_counted_object.clone(); + conn.lock_handle().await?.set_commit_hook(move || { + println!("{o:?}"); + true + }); + assert_eq!(2, Arc::strong_count(&ref_counted_object)); + + conn.lock_handle().await?.remove_commit_hook(); + } + + assert_eq!(1, Arc::strong_count(&ref_counted_object)); + Ok(()) +} + +#[sqlx_macros::test] +async fn test_query_with_rollback_hook() -> anyhow::Result<()> { + let mut conn = new::().await?; + + // Using this string as a canary to ensure the callback doesn't get called with the wrong data pointer. + let state = format!("test"); + conn.lock_handle().await?.set_rollback_hook(move || { + assert_eq!(state, "test"); + }); + + let mut tx = conn.begin().await?; + sqlx::query("INSERT INTO tweet ( id, text ) VALUES (5, 'Hello, World' )") + .execute(&mut *tx) + .await?; + tx.rollback().await?; + Ok(()) +} + +#[sqlx_macros::test] +async fn test_multiple_set_rollback_hook_calls_drop_old_handler() -> anyhow::Result<()> { + let ref_counted_object = Arc::new(0); + assert_eq!(1, Arc::strong_count(&ref_counted_object)); + + { + let mut conn = new::().await?; + + let o = ref_counted_object.clone(); + conn.lock_handle().await?.set_rollback_hook(move || { + println!("{o:?}"); + }); + assert_eq!(2, Arc::strong_count(&ref_counted_object)); + + let o = ref_counted_object.clone(); + conn.lock_handle().await?.set_rollback_hook(move || { + println!("{o:?}"); + }); + assert_eq!(2, Arc::strong_count(&ref_counted_object)); + + let o = ref_counted_object.clone(); + conn.lock_handle().await?.set_rollback_hook(move || { + println!("{o:?}"); + }); + assert_eq!(2, Arc::strong_count(&ref_counted_object)); + + conn.lock_handle().await?.remove_rollback_hook(); + } + + assert_eq!(1, Arc::strong_count(&ref_counted_object)); + Ok(()) +}