mirror of
https://github.com/launchbadge/sqlx.git
synced 2025-09-27 13:01:43 +00:00
feat(sqlite): add preupdate hook (#3625)
* feat: add preupdate hook * address some PR comments * add SqliteValueRef variant that takes a borrowed sqlite value pointer * add PhantomData for additional lifetime check
This commit is contained in:
parent
f6d2fa3a3d
commit
aae800090b
6
.github/workflows/sqlx.yml
vendored
6
.github/workflows/sqlx.yml
vendored
@ -39,7 +39,7 @@ jobs:
|
||||
- run: >
|
||||
cargo clippy
|
||||
--no-default-features
|
||||
--features all-databases,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }},macros
|
||||
--features all-databases,_unstable-all-types,sqlite-preupdate-hook,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }},macros
|
||||
-- -D warnings
|
||||
|
||||
# Run beta for new warnings but don't break the build.
|
||||
@ -47,7 +47,7 @@ jobs:
|
||||
- run: >
|
||||
cargo +beta clippy
|
||||
--no-default-features
|
||||
--features all-databases,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }},macros
|
||||
--features all-databases,_unstable-all-types,sqlite-preupdate-hook,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }},macros
|
||||
--target-dir target/beta/
|
||||
|
||||
check-minimal-versions:
|
||||
@ -140,7 +140,7 @@ jobs:
|
||||
- run: >
|
||||
cargo test
|
||||
--no-default-features
|
||||
--features any,macros,${{ matrix.linking }},_unstable-all-types,runtime-${{ matrix.runtime }}
|
||||
--features any,macros,${{ matrix.linking }},${{ matrix.linking == 'sqlite' && 'sqlite-preupdate-hook,' || ''}}_unstable-all-types,runtime-${{ matrix.runtime }}
|
||||
--
|
||||
--test-threads=1
|
||||
env:
|
||||
|
1
Cargo.lock
generated
1
Cargo.lock
generated
@ -3879,6 +3879,7 @@ dependencies = [
|
||||
"serde_urlencoded",
|
||||
"sqlx",
|
||||
"sqlx-core",
|
||||
"thiserror 2.0.11",
|
||||
"time",
|
||||
"tracing",
|
||||
"url",
|
||||
|
@ -50,7 +50,7 @@ authors.workspace = true
|
||||
repository.workspace = true
|
||||
|
||||
[package.metadata.docs.rs]
|
||||
features = ["all-databases", "_unstable-all-types"]
|
||||
features = ["all-databases", "_unstable-all-types", "sqlite-preupdate-hook"]
|
||||
rustdoc-args = ["--cfg", "docsrs"]
|
||||
|
||||
[features]
|
||||
@ -108,6 +108,7 @@ postgres = ["sqlx-postgres", "sqlx-macros?/postgres"]
|
||||
mysql = ["sqlx-mysql", "sqlx-macros?/mysql"]
|
||||
sqlite = ["_sqlite", "sqlx-sqlite/bundled", "sqlx-macros?/sqlite"]
|
||||
sqlite-unbundled = ["_sqlite", "sqlx-sqlite/unbundled", "sqlx-macros?/sqlite-unbundled"]
|
||||
sqlite-preupdate-hook = ["sqlx-sqlite/preupdate-hook"]
|
||||
|
||||
# types
|
||||
json = ["sqlx-macros?/json", "sqlx-mysql?/json", "sqlx-postgres?/json", "sqlx-sqlite?/json"]
|
||||
|
@ -196,6 +196,10 @@ be removed in the future.
|
||||
* May result in link errors if the SQLite version is too old. Version `3.20.0` or newer is recommended.
|
||||
* Can increase build time due to the use of bindgen.
|
||||
|
||||
- `sqlite-preupdate-hook`: enables SQLite's [preupdate hook](https://sqlite.org/c3ref/preupdate_count.html) API.
|
||||
* Exposed as a separate feature because it's generally not enabled by default.
|
||||
* Using this feature with `sqlite-unbundled` may cause linker failures if the system SQLite version does not support it.
|
||||
|
||||
- `any`: Add support for the `Any` database driver, which can proxy to a database driver at runtime.
|
||||
|
||||
- `derive`: Add support for the derive family macros, those are `FromRow`, `Type`, `Encode`, `Decode`.
|
||||
|
@ -23,6 +23,8 @@ uuid = ["dep:uuid", "sqlx-core/uuid"]
|
||||
|
||||
regexp = ["dep:regex"]
|
||||
|
||||
preupdate-hook = ["libsqlite3-sys/preupdate_hook"]
|
||||
|
||||
bundled = ["libsqlite3-sys/bundled"]
|
||||
unbundled = ["libsqlite3-sys/buildtime_bindgen"]
|
||||
|
||||
@ -48,6 +50,7 @@ atoi = "2.0"
|
||||
|
||||
log = "0.4.18"
|
||||
tracing = { version = "0.1.37", features = ["log"] }
|
||||
thiserror = "2.0.0"
|
||||
|
||||
serde = { version = "1.0.145", features = ["derive"], optional = true }
|
||||
regex = { version = "1.5.5", optional = true }
|
||||
|
@ -296,6 +296,8 @@ impl EstablishParams {
|
||||
log_settings: self.log_settings.clone(),
|
||||
progress_handler_callback: None,
|
||||
update_hook_callback: None,
|
||||
#[cfg(feature = "preupdate-hook")]
|
||||
preupdate_hook_callback: None,
|
||||
commit_hook_callback: None,
|
||||
rollback_hook_callback: None,
|
||||
})
|
||||
|
@ -14,6 +14,8 @@ use libsqlite3_sys::{
|
||||
sqlite3, sqlite3_commit_hook, sqlite3_progress_handler, sqlite3_rollback_hook,
|
||||
sqlite3_update_hook, SQLITE_DELETE, SQLITE_INSERT, SQLITE_UPDATE,
|
||||
};
|
||||
#[cfg(feature = "preupdate-hook")]
|
||||
pub use preupdate_hook::*;
|
||||
|
||||
pub(crate) use handle::ConnectionHandle;
|
||||
use sqlx_core::common::StatementCache;
|
||||
@ -36,6 +38,8 @@ mod executor;
|
||||
mod explain;
|
||||
mod handle;
|
||||
pub(crate) mod intmap;
|
||||
#[cfg(feature = "preupdate-hook")]
|
||||
mod preupdate_hook;
|
||||
|
||||
mod worker;
|
||||
|
||||
@ -88,6 +92,7 @@ pub struct UpdateHookResult<'a> {
|
||||
pub table: &'a str,
|
||||
pub rowid: i64,
|
||||
}
|
||||
|
||||
pub(crate) struct UpdateHookHandler(NonNull<dyn FnMut(UpdateHookResult) + Send + 'static>);
|
||||
unsafe impl Send for UpdateHookHandler {}
|
||||
|
||||
@ -112,6 +117,8 @@ pub(crate) struct ConnectionState {
|
||||
progress_handler_callback: Option<Handler>,
|
||||
|
||||
update_hook_callback: Option<UpdateHookHandler>,
|
||||
#[cfg(feature = "preupdate-hook")]
|
||||
preupdate_hook_callback: Option<preupdate_hook::PreupdateHookHandler>,
|
||||
|
||||
commit_hook_callback: Option<CommitHookHandler>,
|
||||
|
||||
@ -138,6 +145,16 @@ impl ConnectionState {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "preupdate-hook")]
|
||||
pub(crate) fn remove_preupdate_hook(&mut self) {
|
||||
if let Some(mut handler) = self.preupdate_hook_callback.take() {
|
||||
unsafe {
|
||||
libsqlite3_sys::sqlite3_preupdate_hook(self.handle.as_ptr(), None, ptr::null_mut());
|
||||
let _ = { Box::from_raw(handler.0.as_mut()) };
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn remove_commit_hook(&mut self) {
|
||||
if let Some(mut handler) = self.commit_hook_callback.take() {
|
||||
unsafe {
|
||||
@ -421,6 +438,34 @@ impl LockedSqliteHandle<'_> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Registers a hook that is invoked prior to each `INSERT`, `UPDATE`, and `DELETE` operation on a database table.
|
||||
/// At most one preupdate hook may be registered at a time on a single database connection.
|
||||
///
|
||||
/// The preupdate hook only fires for changes to real database tables;
|
||||
/// it is not invoked for changes to virtual tables or to system tables like sqlite_sequence or sqlite_stat1.
|
||||
///
|
||||
/// See https://sqlite.org/c3ref/preupdate_count.html
|
||||
#[cfg(feature = "preupdate-hook")]
|
||||
pub fn set_preupdate_hook<F>(&mut self, callback: F)
|
||||
where
|
||||
F: FnMut(PreupdateHookResult) + Send + 'static,
|
||||
{
|
||||
unsafe {
|
||||
let callback_boxed = Box::new(callback);
|
||||
// SAFETY: `Box::into_raw()` always returns a non-null pointer.
|
||||
let callback = NonNull::new_unchecked(Box::into_raw(callback_boxed));
|
||||
let handler = callback.as_ptr() as *mut _;
|
||||
self.guard.remove_preupdate_hook();
|
||||
self.guard.preupdate_hook_callback = Some(PreupdateHookHandler(callback));
|
||||
|
||||
libsqlite3_sys::sqlite3_preupdate_hook(
|
||||
self.as_raw_handle().as_mut(),
|
||||
Some(preupdate_hook::<F>),
|
||||
handler,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Sets a commit hook that is invoked whenever a transaction is committed. If the commit hook callback
|
||||
/// returns `false`, then the operation is turned into a ROLLBACK.
|
||||
///
|
||||
@ -485,6 +530,11 @@ impl LockedSqliteHandle<'_> {
|
||||
self.guard.remove_update_hook();
|
||||
}
|
||||
|
||||
#[cfg(feature = "preupdate-hook")]
|
||||
pub fn remove_preupdate_hook(&mut self) {
|
||||
self.guard.remove_preupdate_hook();
|
||||
}
|
||||
|
||||
pub fn remove_commit_hook(&mut self) {
|
||||
self.guard.remove_commit_hook();
|
||||
}
|
||||
|
160
sqlx-sqlite/src/connection/preupdate_hook.rs
Normal file
160
sqlx-sqlite/src/connection/preupdate_hook.rs
Normal file
@ -0,0 +1,160 @@
|
||||
use super::SqliteOperation;
|
||||
use crate::type_info::DataType;
|
||||
use crate::{SqliteError, SqliteTypeInfo, SqliteValueRef};
|
||||
|
||||
use libsqlite3_sys::{
|
||||
sqlite3, sqlite3_preupdate_count, sqlite3_preupdate_depth, sqlite3_preupdate_new,
|
||||
sqlite3_preupdate_old, sqlite3_value, sqlite3_value_type, SQLITE_OK,
|
||||
};
|
||||
use std::ffi::CStr;
|
||||
use std::marker::PhantomData;
|
||||
use std::os::raw::{c_char, c_int, c_void};
|
||||
use std::panic::catch_unwind;
|
||||
use std::ptr;
|
||||
use std::ptr::NonNull;
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum PreupdateError {
|
||||
/// Error returned from the database.
|
||||
#[error("error returned from database: {0}")]
|
||||
Database(#[source] SqliteError),
|
||||
/// Index is not within the valid column range
|
||||
#[error("{0} is not within the valid column range")]
|
||||
ColumnIndexOutOfBounds(i32),
|
||||
/// Column value accessor was invoked from an invalid operation
|
||||
#[error("column value accessor was invoked from an invalid operation")]
|
||||
InvalidOperation,
|
||||
}
|
||||
|
||||
pub(crate) struct PreupdateHookHandler(
|
||||
pub(super) NonNull<dyn FnMut(PreupdateHookResult) + Send + 'static>,
|
||||
);
|
||||
unsafe impl Send for PreupdateHookHandler {}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct PreupdateHookResult<'a> {
|
||||
pub operation: SqliteOperation,
|
||||
pub database: &'a str,
|
||||
pub table: &'a str,
|
||||
db: *mut sqlite3,
|
||||
// The database pointer should not be usable after the preupdate hook.
|
||||
// The lifetime on this struct needs to ensure it cannot outlive the callback.
|
||||
_db_lifetime: PhantomData<&'a ()>,
|
||||
old_row_id: i64,
|
||||
new_row_id: i64,
|
||||
}
|
||||
|
||||
impl<'a> PreupdateHookResult<'a> {
|
||||
/// Gets the amount of columns in the row being inserted, deleted, or updated.
|
||||
pub fn get_column_count(&self) -> i32 {
|
||||
unsafe { sqlite3_preupdate_count(self.db) }
|
||||
}
|
||||
|
||||
/// Gets the depth of the query that triggered the preupdate hook.
|
||||
/// Returns 0 if the preupdate callback was invoked as a result of
|
||||
/// a direct insert, update, or delete operation;
|
||||
/// 1 for inserts, updates, or deletes invoked by top-level triggers;
|
||||
/// 2 for changes resulting from triggers called by top-level triggers; and so forth.
|
||||
pub fn get_query_depth(&self) -> i32 {
|
||||
unsafe { sqlite3_preupdate_depth(self.db) }
|
||||
}
|
||||
|
||||
/// Gets the row id of the row being updated/deleted.
|
||||
/// Returns an error if called from an insert operation.
|
||||
pub fn get_old_row_id(&self) -> Result<i64, PreupdateError> {
|
||||
if self.operation == SqliteOperation::Insert {
|
||||
return Err(PreupdateError::InvalidOperation);
|
||||
}
|
||||
Ok(self.old_row_id)
|
||||
}
|
||||
|
||||
/// Gets the row id of the row being inserted/updated.
|
||||
/// Returns an error if called from a delete operation.
|
||||
pub fn get_new_row_id(&self) -> Result<i64, PreupdateError> {
|
||||
if self.operation == SqliteOperation::Delete {
|
||||
return Err(PreupdateError::InvalidOperation);
|
||||
}
|
||||
Ok(self.new_row_id)
|
||||
}
|
||||
|
||||
/// Gets the value of the row being updated/deleted at the specified index.
|
||||
/// Returns an error if called from an insert operation or the index is out of bounds.
|
||||
pub fn get_old_column_value(&self, i: i32) -> Result<SqliteValueRef<'a>, PreupdateError> {
|
||||
if self.operation == SqliteOperation::Insert {
|
||||
return Err(PreupdateError::InvalidOperation);
|
||||
}
|
||||
self.validate_column_index(i)?;
|
||||
|
||||
let mut p_value: *mut sqlite3_value = ptr::null_mut();
|
||||
unsafe {
|
||||
let ret = sqlite3_preupdate_old(self.db, i, &mut p_value);
|
||||
self.get_value(ret, p_value)
|
||||
}
|
||||
}
|
||||
|
||||
/// Gets the value of the row being inserted/updated at the specified index.
|
||||
/// Returns an error if called from a delete operation or the index is out of bounds.
|
||||
pub fn get_new_column_value(&self, i: i32) -> Result<SqliteValueRef<'a>, PreupdateError> {
|
||||
if self.operation == SqliteOperation::Delete {
|
||||
return Err(PreupdateError::InvalidOperation);
|
||||
}
|
||||
self.validate_column_index(i)?;
|
||||
|
||||
let mut p_value: *mut sqlite3_value = ptr::null_mut();
|
||||
unsafe {
|
||||
let ret = sqlite3_preupdate_new(self.db, i, &mut p_value);
|
||||
self.get_value(ret, p_value)
|
||||
}
|
||||
}
|
||||
|
||||
fn validate_column_index(&self, i: i32) -> Result<(), PreupdateError> {
|
||||
if i < 0 || i >= self.get_column_count() {
|
||||
return Err(PreupdateError::ColumnIndexOutOfBounds(i));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
unsafe fn get_value(
|
||||
&self,
|
||||
ret: i32,
|
||||
p_value: *mut sqlite3_value,
|
||||
) -> Result<SqliteValueRef<'a>, PreupdateError> {
|
||||
if ret != SQLITE_OK {
|
||||
return Err(PreupdateError::Database(SqliteError::new(self.db)));
|
||||
}
|
||||
let data_type = DataType::from_code(sqlite3_value_type(p_value));
|
||||
// SAFETY: SQLite will free the sqlite3_value when the callback returns
|
||||
Ok(SqliteValueRef::borrowed(p_value, SqliteTypeInfo(data_type)))
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) extern "C" fn preupdate_hook<F>(
|
||||
callback: *mut c_void,
|
||||
db: *mut sqlite3,
|
||||
op_code: c_int,
|
||||
database: *const c_char,
|
||||
table: *const c_char,
|
||||
old_row_id: i64,
|
||||
new_row_id: i64,
|
||||
) where
|
||||
F: FnMut(PreupdateHookResult) + Send + 'static,
|
||||
{
|
||||
unsafe {
|
||||
let _ = catch_unwind(|| {
|
||||
let callback: *mut F = callback.cast::<F>();
|
||||
let operation: SqliteOperation = op_code.into();
|
||||
let database = CStr::from_ptr(database).to_str().unwrap_or_default();
|
||||
let table = CStr::from_ptr(table).to_str().unwrap_or_default();
|
||||
|
||||
(*callback)(PreupdateHookResult {
|
||||
operation,
|
||||
database,
|
||||
table,
|
||||
old_row_id,
|
||||
new_row_id,
|
||||
db,
|
||||
_db_lifetime: PhantomData,
|
||||
})
|
||||
});
|
||||
}
|
||||
}
|
@ -46,6 +46,8 @@ use std::sync::atomic::AtomicBool;
|
||||
|
||||
pub use arguments::{SqliteArgumentValue, SqliteArguments};
|
||||
pub use column::SqliteColumn;
|
||||
#[cfg(feature = "preupdate-hook")]
|
||||
pub use connection::PreupdateHookResult;
|
||||
pub use connection::{LockedSqliteHandle, SqliteConnection, SqliteOperation, UpdateHookResult};
|
||||
pub use database::Sqlite;
|
||||
pub use error::SqliteError;
|
||||
|
@ -1,4 +1,5 @@
|
||||
use std::borrow::Cow;
|
||||
use std::marker::PhantomData;
|
||||
use std::ptr::NonNull;
|
||||
use std::slice::from_raw_parts;
|
||||
use std::str::from_utf8;
|
||||
@ -17,6 +18,7 @@ use crate::{Sqlite, SqliteTypeInfo};
|
||||
|
||||
enum SqliteValueData<'r> {
|
||||
Value(&'r SqliteValue),
|
||||
BorrowedHandle(ValueHandle<'r>),
|
||||
}
|
||||
|
||||
pub struct SqliteValueRef<'r>(SqliteValueData<'r>);
|
||||
@ -26,31 +28,44 @@ impl<'r> SqliteValueRef<'r> {
|
||||
Self(SqliteValueData::Value(value))
|
||||
}
|
||||
|
||||
// SAFETY: The supplied sqlite3_value must not be null and SQLite must free it. It will not be freed on drop.
|
||||
// The lifetime on this struct should tie it to whatever scope it's valid for before SQLite frees it.
|
||||
#[allow(unused)]
|
||||
pub(crate) unsafe fn borrowed(value: *mut sqlite3_value, type_info: SqliteTypeInfo) -> Self {
|
||||
debug_assert!(!value.is_null());
|
||||
let handle = ValueHandle::new_borrowed(NonNull::new_unchecked(value), type_info);
|
||||
Self(SqliteValueData::BorrowedHandle(handle))
|
||||
}
|
||||
|
||||
// NOTE: `int()` is deliberately omitted because it will silently truncate a wider value,
|
||||
// which is likely to cause bugs:
|
||||
// https://github.com/launchbadge/sqlx/issues/3179
|
||||
// (Similar bug in Postgres): https://github.com/launchbadge/sqlx/issues/3161
|
||||
pub(super) fn int64(&self) -> i64 {
|
||||
match self.0 {
|
||||
SqliteValueData::Value(v) => v.int64(),
|
||||
match &self.0 {
|
||||
SqliteValueData::Value(v) => v.0.int64(),
|
||||
SqliteValueData::BorrowedHandle(v) => v.int64(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn double(&self) -> f64 {
|
||||
match self.0 {
|
||||
SqliteValueData::Value(v) => v.double(),
|
||||
match &self.0 {
|
||||
SqliteValueData::Value(v) => v.0.double(),
|
||||
SqliteValueData::BorrowedHandle(v) => v.double(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn blob(&self) -> &'r [u8] {
|
||||
match self.0 {
|
||||
SqliteValueData::Value(v) => v.blob(),
|
||||
match &self.0 {
|
||||
SqliteValueData::Value(v) => v.0.blob(),
|
||||
SqliteValueData::BorrowedHandle(v) => v.blob(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn text(&self) -> Result<&'r str, BoxDynError> {
|
||||
match self.0 {
|
||||
SqliteValueData::Value(v) => v.text(),
|
||||
match &self.0 {
|
||||
SqliteValueData::Value(v) => v.0.text(),
|
||||
SqliteValueData::BorrowedHandle(v) => v.text(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -59,50 +74,66 @@ impl<'r> ValueRef<'r> for SqliteValueRef<'r> {
|
||||
type Database = Sqlite;
|
||||
|
||||
fn to_owned(&self) -> SqliteValue {
|
||||
match self.0 {
|
||||
SqliteValueData::Value(v) => v.clone(),
|
||||
match &self.0 {
|
||||
SqliteValueData::Value(v) => (*v).clone(),
|
||||
SqliteValueData::BorrowedHandle(v) => unsafe {
|
||||
SqliteValue::new(v.value.as_ptr(), v.type_info.clone())
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn type_info(&self) -> Cow<'_, SqliteTypeInfo> {
|
||||
match self.0 {
|
||||
match &self.0 {
|
||||
SqliteValueData::Value(v) => v.type_info(),
|
||||
SqliteValueData::BorrowedHandle(v) => v.type_info(),
|
||||
}
|
||||
}
|
||||
|
||||
fn is_null(&self) -> bool {
|
||||
match self.0 {
|
||||
match &self.0 {
|
||||
SqliteValueData::Value(v) => v.is_null(),
|
||||
SqliteValueData::BorrowedHandle(v) => v.is_null(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct SqliteValue {
|
||||
pub(crate) handle: Arc<ValueHandle>,
|
||||
pub(crate) type_info: SqliteTypeInfo,
|
||||
pub struct SqliteValue(Arc<ValueHandle<'static>>);
|
||||
|
||||
pub(crate) struct ValueHandle<'a> {
|
||||
value: NonNull<sqlite3_value>,
|
||||
type_info: SqliteTypeInfo,
|
||||
free_on_drop: bool,
|
||||
_sqlite_value_lifetime: PhantomData<&'a ()>,
|
||||
}
|
||||
|
||||
pub(crate) struct ValueHandle(NonNull<sqlite3_value>);
|
||||
|
||||
// SAFE: only protected value objects are stored in SqliteValue
|
||||
unsafe impl Send for ValueHandle {}
|
||||
unsafe impl Sync for ValueHandle {}
|
||||
|
||||
impl SqliteValue {
|
||||
pub(crate) unsafe fn new(value: *mut sqlite3_value, type_info: SqliteTypeInfo) -> Self {
|
||||
debug_assert!(!value.is_null());
|
||||
unsafe impl<'a> Send for ValueHandle<'a> {}
|
||||
unsafe impl<'a> Sync for ValueHandle<'a> {}
|
||||
|
||||
impl ValueHandle<'static> {
|
||||
fn new_owned(value: NonNull<sqlite3_value>, type_info: SqliteTypeInfo) -> Self {
|
||||
Self {
|
||||
value,
|
||||
type_info,
|
||||
handle: Arc::new(ValueHandle(NonNull::new_unchecked(sqlite3_value_dup(
|
||||
value,
|
||||
)))),
|
||||
free_on_drop: true,
|
||||
_sqlite_value_lifetime: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> ValueHandle<'a> {
|
||||
fn new_borrowed(value: NonNull<sqlite3_value>, type_info: SqliteTypeInfo) -> Self {
|
||||
Self {
|
||||
value,
|
||||
type_info,
|
||||
free_on_drop: false,
|
||||
_sqlite_value_lifetime: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
fn type_info_opt(&self) -> Option<SqliteTypeInfo> {
|
||||
let dt = DataType::from_code(unsafe { sqlite3_value_type(self.handle.0.as_ptr()) });
|
||||
let dt = DataType::from_code(unsafe { sqlite3_value_type(self.value.as_ptr()) });
|
||||
|
||||
if let DataType::Null = dt {
|
||||
None
|
||||
@ -112,15 +143,15 @@ impl SqliteValue {
|
||||
}
|
||||
|
||||
fn int64(&self) -> i64 {
|
||||
unsafe { sqlite3_value_int64(self.handle.0.as_ptr()) }
|
||||
unsafe { sqlite3_value_int64(self.value.as_ptr()) }
|
||||
}
|
||||
|
||||
fn double(&self) -> f64 {
|
||||
unsafe { sqlite3_value_double(self.handle.0.as_ptr()) }
|
||||
unsafe { sqlite3_value_double(self.value.as_ptr()) }
|
||||
}
|
||||
|
||||
fn blob(&self) -> &[u8] {
|
||||
let len = unsafe { sqlite3_value_bytes(self.handle.0.as_ptr()) };
|
||||
fn blob<'b>(&self) -> &'b [u8] {
|
||||
let len = unsafe { sqlite3_value_bytes(self.value.as_ptr()) };
|
||||
|
||||
// This likely means UB in SQLite itself or our usage of it;
|
||||
// signed integer overflow is UB in the C standard.
|
||||
@ -133,15 +164,45 @@ impl SqliteValue {
|
||||
return &[];
|
||||
}
|
||||
|
||||
let ptr = unsafe { sqlite3_value_blob(self.handle.0.as_ptr()) } as *const u8;
|
||||
let ptr = unsafe { sqlite3_value_blob(self.value.as_ptr()) } as *const u8;
|
||||
debug_assert!(!ptr.is_null());
|
||||
|
||||
unsafe { from_raw_parts(ptr, len) }
|
||||
}
|
||||
|
||||
fn text(&self) -> Result<&str, BoxDynError> {
|
||||
fn text<'b>(&self) -> Result<&'b str, BoxDynError> {
|
||||
Ok(from_utf8(self.blob())?)
|
||||
}
|
||||
|
||||
fn type_info(&self) -> Cow<'_, SqliteTypeInfo> {
|
||||
self.type_info_opt()
|
||||
.map(Cow::Owned)
|
||||
.unwrap_or(Cow::Borrowed(&self.type_info))
|
||||
}
|
||||
|
||||
fn is_null(&self) -> bool {
|
||||
unsafe { sqlite3_value_type(self.value.as_ptr()) == SQLITE_NULL }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Drop for ValueHandle<'a> {
|
||||
fn drop(&mut self) {
|
||||
if self.free_on_drop {
|
||||
unsafe {
|
||||
sqlite3_value_free(self.value.as_ptr());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SqliteValue {
|
||||
// SAFETY: The sqlite3_value must be non-null and SQLite must not free it. It will be freed on drop.
|
||||
pub(crate) unsafe fn new(value: *mut sqlite3_value, type_info: SqliteTypeInfo) -> Self {
|
||||
debug_assert!(!value.is_null());
|
||||
let handle =
|
||||
ValueHandle::new_owned(NonNull::new_unchecked(sqlite3_value_dup(value)), type_info);
|
||||
Self(Arc::new(handle))
|
||||
}
|
||||
}
|
||||
|
||||
impl Value for SqliteValue {
|
||||
@ -152,21 +213,11 @@ impl Value for SqliteValue {
|
||||
}
|
||||
|
||||
fn type_info(&self) -> Cow<'_, SqliteTypeInfo> {
|
||||
self.type_info_opt()
|
||||
.map(Cow::Owned)
|
||||
.unwrap_or(Cow::Borrowed(&self.type_info))
|
||||
self.0.type_info()
|
||||
}
|
||||
|
||||
fn is_null(&self) -> bool {
|
||||
unsafe { sqlite3_value_type(self.handle.0.as_ptr()) == SQLITE_NULL }
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for ValueHandle {
|
||||
fn drop(&mut self) {
|
||||
unsafe {
|
||||
sqlite3_value_free(self.0.as_ptr());
|
||||
}
|
||||
self.0.is_null()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,6 +1,14 @@
|
||||
#![cfg_attr(docsrs, feature(doc_cfg))]
|
||||
#![doc = include_str!("lib.md")]
|
||||
|
||||
#[cfg(all(
|
||||
feature = "sqlite-preupdate-hook",
|
||||
not(any(feature = "sqlite", feature = "sqlite-unbundled"))
|
||||
))]
|
||||
compile_error!(
|
||||
"sqlite-preupdate-hook requires either 'sqlite' or 'sqlite-unbundled' to be enabled"
|
||||
);
|
||||
|
||||
pub use sqlx_core::acquire::Acquire;
|
||||
pub use sqlx_core::arguments::{Arguments, IntoArguments};
|
||||
pub use sqlx_core::column::Column;
|
||||
|
@ -2,11 +2,14 @@ use futures::TryStreamExt;
|
||||
use rand::{Rng, SeedableRng};
|
||||
use rand_xoshiro::Xoshiro256PlusPlus;
|
||||
use sqlx::sqlite::{SqliteConnectOptions, SqliteOperation, SqlitePoolOptions};
|
||||
use sqlx::Decode;
|
||||
use sqlx::{
|
||||
query, sqlite::Sqlite, sqlite::SqliteRow, Column, ConnectOptions, Connection, Executor, Row,
|
||||
SqliteConnection, SqlitePool, Statement, TypeInfo,
|
||||
};
|
||||
use sqlx::{Value, ValueRef};
|
||||
use sqlx_test::new;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
||||
#[sqlx_macros::test]
|
||||
@ -798,7 +801,7 @@ async fn test_multiple_set_progress_handler_calls_drop_old_handler() -> anyhow::
|
||||
#[sqlx_macros::test]
|
||||
async fn test_query_with_update_hook() -> anyhow::Result<()> {
|
||||
let mut conn = new::<Sqlite>().await?;
|
||||
|
||||
static CALLED: AtomicBool = AtomicBool::new(false);
|
||||
// Using this string as a canary to ensure the callback doesn't get called with the wrong data pointer.
|
||||
let state = format!("test");
|
||||
conn.lock_handle().await?.set_update_hook(move |result| {
|
||||
@ -807,11 +810,13 @@ async fn test_query_with_update_hook() -> anyhow::Result<()> {
|
||||
assert_eq!(result.database, "main");
|
||||
assert_eq!(result.table, "tweet");
|
||||
assert_eq!(result.rowid, 2);
|
||||
CALLED.store(true, Ordering::Relaxed);
|
||||
});
|
||||
|
||||
let _ = sqlx::query("INSERT INTO tweet ( id, text ) VALUES ( 3, 'Hello, World' )")
|
||||
.execute(&mut conn)
|
||||
.await?;
|
||||
assert!(CALLED.load(Ordering::Relaxed));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@ -852,10 +857,11 @@ async fn test_multiple_set_update_hook_calls_drop_old_handler() -> anyhow::Resul
|
||||
#[sqlx_macros::test]
|
||||
async fn test_query_with_commit_hook() -> anyhow::Result<()> {
|
||||
let mut conn = new::<Sqlite>().await?;
|
||||
|
||||
static CALLED: AtomicBool = AtomicBool::new(false);
|
||||
// Using this string as a canary to ensure the callback doesn't get called with the wrong data pointer.
|
||||
let state = format!("test");
|
||||
conn.lock_handle().await?.set_commit_hook(move || {
|
||||
CALLED.store(true, Ordering::Relaxed);
|
||||
assert_eq!(state, "test");
|
||||
false
|
||||
});
|
||||
@ -870,7 +876,7 @@ async fn test_query_with_commit_hook() -> anyhow::Result<()> {
|
||||
}
|
||||
_ => panic!("expected an error"),
|
||||
}
|
||||
|
||||
assert!(CALLED.load(Ordering::Relaxed));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -916,8 +922,10 @@ async fn test_query_with_rollback_hook() -> anyhow::Result<()> {
|
||||
|
||||
// Using this string as a canary to ensure the callback doesn't get called with the wrong data pointer.
|
||||
let state = format!("test");
|
||||
static CALLED: AtomicBool = AtomicBool::new(false);
|
||||
conn.lock_handle().await?.set_rollback_hook(move || {
|
||||
assert_eq!(state, "test");
|
||||
CALLED.store(true, Ordering::Relaxed);
|
||||
});
|
||||
|
||||
let mut tx = conn.begin().await?;
|
||||
@ -925,6 +933,7 @@ async fn test_query_with_rollback_hook() -> anyhow::Result<()> {
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
tx.rollback().await?;
|
||||
assert!(CALLED.load(Ordering::Relaxed));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -960,3 +969,206 @@ async fn test_multiple_set_rollback_hook_calls_drop_old_handler() -> anyhow::Res
|
||||
assert_eq!(1, Arc::strong_count(&ref_counted_object));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(feature = "sqlite-preupdate-hook")]
|
||||
#[sqlx_macros::test]
|
||||
async fn test_query_with_preupdate_hook_insert() -> anyhow::Result<()> {
|
||||
let mut conn = new::<Sqlite>().await?;
|
||||
static CALLED: AtomicBool = AtomicBool::new(false);
|
||||
// Using this string as a canary to ensure the callback doesn't get called with the wrong data pointer.
|
||||
let state = format!("test");
|
||||
conn.lock_handle().await?.set_preupdate_hook({
|
||||
move |result| {
|
||||
assert_eq!(state, "test");
|
||||
assert_eq!(result.operation, SqliteOperation::Insert);
|
||||
assert_eq!(result.database, "main");
|
||||
assert_eq!(result.table, "tweet");
|
||||
|
||||
assert_eq!(4, result.get_column_count());
|
||||
assert_eq!(2, result.get_new_row_id().unwrap());
|
||||
assert_eq!(0, result.get_query_depth());
|
||||
assert_eq!(
|
||||
4,
|
||||
<i64 as Decode<Sqlite>>::decode(result.get_new_column_value(0).unwrap()).unwrap()
|
||||
);
|
||||
assert_eq!(
|
||||
"Hello, World",
|
||||
<String as Decode<Sqlite>>::decode(result.get_new_column_value(1).unwrap())
|
||||
.unwrap()
|
||||
);
|
||||
// out of bounds access should return an error
|
||||
assert!(result.get_new_column_value(4).is_err());
|
||||
// old values aren't available for inserts
|
||||
assert!(result.get_old_column_value(0).is_err());
|
||||
assert!(result.get_old_row_id().is_err());
|
||||
|
||||
CALLED.store(true, Ordering::Relaxed);
|
||||
}
|
||||
});
|
||||
|
||||
let _ = sqlx::query("INSERT INTO tweet ( id, text ) VALUES ( 4, 'Hello, World' )")
|
||||
.execute(&mut conn)
|
||||
.await?;
|
||||
|
||||
assert!(CALLED.load(Ordering::Relaxed));
|
||||
conn.lock_handle().await?.remove_preupdate_hook();
|
||||
let _ = sqlx::query("DELETE FROM tweet where id = 4")
|
||||
.execute(&mut conn)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(feature = "sqlite-preupdate-hook")]
|
||||
#[sqlx_macros::test]
|
||||
async fn test_query_with_preupdate_hook_delete() -> anyhow::Result<()> {
|
||||
let mut conn = new::<Sqlite>().await?;
|
||||
let _ = sqlx::query("INSERT INTO tweet ( id, text ) VALUES ( 5, 'Hello, World' )")
|
||||
.execute(&mut conn)
|
||||
.await?;
|
||||
static CALLED: AtomicBool = AtomicBool::new(false);
|
||||
// Using this string as a canary to ensure the callback doesn't get called with the wrong data pointer.
|
||||
let state = format!("test");
|
||||
conn.lock_handle().await?.set_preupdate_hook(move |result| {
|
||||
assert_eq!(state, "test");
|
||||
assert_eq!(result.operation, SqliteOperation::Delete);
|
||||
assert_eq!(result.database, "main");
|
||||
assert_eq!(result.table, "tweet");
|
||||
|
||||
assert_eq!(4, result.get_column_count());
|
||||
assert_eq!(2, result.get_old_row_id().unwrap());
|
||||
assert_eq!(0, result.get_query_depth());
|
||||
assert_eq!(
|
||||
5,
|
||||
<i64 as Decode<Sqlite>>::decode(result.get_old_column_value(0).unwrap()).unwrap()
|
||||
);
|
||||
assert_eq!(
|
||||
"Hello, World",
|
||||
<String as Decode<Sqlite>>::decode(result.get_old_column_value(1).unwrap()).unwrap()
|
||||
);
|
||||
// out of bounds access should return an error
|
||||
assert!(result.get_old_column_value(4).is_err());
|
||||
// new values aren't available for deletes
|
||||
assert!(result.get_new_column_value(0).is_err());
|
||||
assert!(result.get_new_row_id().is_err());
|
||||
|
||||
CALLED.store(true, Ordering::Relaxed);
|
||||
});
|
||||
|
||||
let _ = sqlx::query("DELETE FROM tweet WHERE id = 5")
|
||||
.execute(&mut conn)
|
||||
.await?;
|
||||
assert!(CALLED.load(Ordering::Relaxed));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(feature = "sqlite-preupdate-hook")]
|
||||
#[sqlx_macros::test]
|
||||
async fn test_query_with_preupdate_hook_update() -> anyhow::Result<()> {
|
||||
let mut conn = new::<Sqlite>().await?;
|
||||
let _ = sqlx::query("INSERT INTO tweet ( id, text ) VALUES ( 6, 'Hello, World' )")
|
||||
.execute(&mut conn)
|
||||
.await?;
|
||||
static CALLED: AtomicBool = AtomicBool::new(false);
|
||||
let sqlite_value_stored: Arc<std::sync::Mutex<Option<_>>> = Default::default();
|
||||
// Using this string as a canary to ensure the callback doesn't get called with the wrong data pointer.
|
||||
let state = format!("test");
|
||||
conn.lock_handle().await?.set_preupdate_hook({
|
||||
let sqlite_value_stored = sqlite_value_stored.clone();
|
||||
move |result| {
|
||||
assert_eq!(state, "test");
|
||||
assert_eq!(result.operation, SqliteOperation::Update);
|
||||
assert_eq!(result.database, "main");
|
||||
assert_eq!(result.table, "tweet");
|
||||
|
||||
assert_eq!(4, result.get_column_count());
|
||||
assert_eq!(4, result.get_column_count());
|
||||
|
||||
assert_eq!(2, result.get_old_row_id().unwrap());
|
||||
assert_eq!(2, result.get_new_row_id().unwrap());
|
||||
|
||||
assert_eq!(0, result.get_query_depth());
|
||||
assert_eq!(0, result.get_query_depth());
|
||||
|
||||
assert_eq!(
|
||||
6,
|
||||
<i64 as Decode<Sqlite>>::decode(result.get_old_column_value(0).unwrap()).unwrap()
|
||||
);
|
||||
assert_eq!(
|
||||
6,
|
||||
<i64 as Decode<Sqlite>>::decode(result.get_new_column_value(0).unwrap()).unwrap()
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
"Hello, World",
|
||||
<String as Decode<Sqlite>>::decode(result.get_old_column_value(1).unwrap())
|
||||
.unwrap()
|
||||
);
|
||||
assert_eq!(
|
||||
"Hello, World2",
|
||||
<String as Decode<Sqlite>>::decode(result.get_new_column_value(1).unwrap())
|
||||
.unwrap()
|
||||
);
|
||||
*sqlite_value_stored.lock().unwrap() =
|
||||
Some(result.get_old_column_value(0).unwrap().to_owned());
|
||||
|
||||
// out of bounds access should return an error
|
||||
assert!(result.get_old_column_value(4).is_err());
|
||||
assert!(result.get_new_column_value(4).is_err());
|
||||
|
||||
CALLED.store(true, Ordering::Relaxed);
|
||||
}
|
||||
});
|
||||
|
||||
let _ = sqlx::query("UPDATE tweet SET text = 'Hello, World2' WHERE id = 6")
|
||||
.execute(&mut conn)
|
||||
.await?;
|
||||
|
||||
assert!(CALLED.load(Ordering::Relaxed));
|
||||
conn.lock_handle().await?.remove_preupdate_hook();
|
||||
let _ = sqlx::query("DELETE FROM tweet where id = 6")
|
||||
.execute(&mut conn)
|
||||
.await?;
|
||||
// Ensure that taking an owned SqliteValue maintains a valid reference after the callback returns
|
||||
assert_eq!(
|
||||
6,
|
||||
<i64 as Decode<Sqlite>>::decode(
|
||||
sqlite_value_stored.lock().unwrap().take().unwrap().as_ref()
|
||||
)
|
||||
.unwrap()
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(feature = "sqlite-preupdate-hook")]
|
||||
#[sqlx_macros::test]
|
||||
async fn test_multiple_set_preupdate_hook_calls_drop_old_handler() -> anyhow::Result<()> {
|
||||
let ref_counted_object = Arc::new(0);
|
||||
assert_eq!(1, Arc::strong_count(&ref_counted_object));
|
||||
|
||||
{
|
||||
let mut conn = new::<Sqlite>().await?;
|
||||
|
||||
let o = ref_counted_object.clone();
|
||||
conn.lock_handle().await?.set_preupdate_hook(move |_| {
|
||||
println!("{o:?}");
|
||||
});
|
||||
assert_eq!(2, Arc::strong_count(&ref_counted_object));
|
||||
|
||||
let o = ref_counted_object.clone();
|
||||
conn.lock_handle().await?.set_preupdate_hook(move |_| {
|
||||
println!("{o:?}");
|
||||
});
|
||||
assert_eq!(2, Arc::strong_count(&ref_counted_object));
|
||||
|
||||
let o = ref_counted_object.clone();
|
||||
conn.lock_handle().await?.set_preupdate_hook(move |_| {
|
||||
println!("{o:?}");
|
||||
});
|
||||
assert_eq!(2, Arc::strong_count(&ref_counted_object));
|
||||
|
||||
conn.lock_handle().await?.remove_preupdate_hook();
|
||||
}
|
||||
|
||||
assert_eq!(1, Arc::strong_count(&ref_counted_object));
|
||||
Ok(())
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user