Merge remote-tracking branch 'origin/main' into sqlx-toml

# Conflicts:
#	Cargo.lock
#	Cargo.toml
#	sqlx-cli/src/database.rs
#	sqlx-cli/src/lib.rs
#	sqlx-mysql/src/connection/executor.rs
This commit is contained in:
Austin Bonander
2025-02-27 17:04:34 -08:00
61 changed files with 1984 additions and 995 deletions

View File

@@ -23,6 +23,8 @@ uuid = ["dep:uuid", "sqlx-core/uuid"]
regexp = ["dep:regex"]
preupdate-hook = ["libsqlite3-sys/preupdate_hook"]
bundled = ["libsqlite3-sys/bundled"]
unbundled = ["libsqlite3-sys/buildtime_bindgen"]
@@ -48,6 +50,7 @@ atoi = "2.0"
log = "0.4.18"
tracing = { version = "0.1.37", features = ["log"] }
thiserror = "2.0.0"
serde = { version = "1.0.145", features = ["derive"], optional = true }
regex = { version = "1.5.5", optional = true }

View File

@@ -17,6 +17,7 @@ use sqlx_core::database::Database;
use sqlx_core::describe::Describe;
use sqlx_core::executor::Executor;
use sqlx_core::transaction::TransactionManager;
use std::pin::pin;
sqlx_core::declare_driver_with_optional_migrate!(DRIVER = Sqlite);
@@ -105,12 +106,12 @@ impl AnyConnectionBackend for SqliteConnection {
let args = arguments.map(map_arguments);
Box::pin(async move {
let stream = self
.worker
.execute(query, args, self.row_channel_size, persistent, Some(1))
.map_ok(flume::Receiver::into_stream)
.await?;
futures_util::pin_mut!(stream);
let mut stream = pin!(
self.worker
.execute(query, args, self.row_channel_size, persistent, Some(1))
.map_ok(flume::Receiver::into_stream)
.await?
);
if let Some(Either::Right(row)) = stream.try_next().await? {
return Ok(Some(AnyRow::try_from(&row)?));

View File

@@ -296,6 +296,8 @@ impl EstablishParams {
log_settings: self.log_settings.clone(),
progress_handler_callback: None,
update_hook_callback: None,
#[cfg(feature = "preupdate-hook")]
preupdate_hook_callback: None,
commit_hook_callback: None,
rollback_hook_callback: None,
})

View File

@@ -8,7 +8,7 @@ use sqlx_core::describe::Describe;
use sqlx_core::error::Error;
use sqlx_core::executor::{Execute, Executor};
use sqlx_core::Either;
use std::future;
use std::{future, pin::pin};
impl<'c> Executor<'c> for &'c mut SqliteConnection {
type Database = Sqlite;
@@ -56,13 +56,11 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection {
let persistent = query.persistent() && arguments.is_some();
Box::pin(async move {
let stream = self
let mut stream = pin!(self
.worker
.execute(sql, arguments, self.row_channel_size, persistent, Some(1))
.map_ok(flume::Receiver::into_stream)
.try_flatten_stream();
futures_util::pin_mut!(stream);
.try_flatten_stream());
while let Some(res) = stream.try_next().await? {
if let Either::Right(row) = res {

View File

@@ -14,6 +14,8 @@ use libsqlite3_sys::{
sqlite3, sqlite3_commit_hook, sqlite3_progress_handler, sqlite3_rollback_hook,
sqlite3_update_hook, SQLITE_DELETE, SQLITE_INSERT, SQLITE_UPDATE,
};
#[cfg(feature = "preupdate-hook")]
pub use preupdate_hook::*;
pub(crate) use handle::ConnectionHandle;
use sqlx_core::common::StatementCache;
@@ -26,7 +28,7 @@ use crate::connection::establish::EstablishParams;
use crate::connection::worker::ConnectionWorker;
use crate::options::OptimizeOnClose;
use crate::statement::VirtualStatement;
use crate::{Sqlite, SqliteConnectOptions};
use crate::{Sqlite, SqliteConnectOptions, SqliteError};
pub(crate) mod collation;
pub(crate) mod describe;
@@ -36,6 +38,8 @@ mod executor;
mod explain;
mod handle;
pub(crate) mod intmap;
#[cfg(feature = "preupdate-hook")]
mod preupdate_hook;
mod worker;
@@ -88,6 +92,7 @@ pub struct UpdateHookResult<'a> {
pub table: &'a str,
pub rowid: i64,
}
pub(crate) struct UpdateHookHandler(NonNull<dyn FnMut(UpdateHookResult) + Send + 'static>);
unsafe impl Send for UpdateHookHandler {}
@@ -112,6 +117,8 @@ pub(crate) struct ConnectionState {
progress_handler_callback: Option<Handler>,
update_hook_callback: Option<UpdateHookHandler>,
#[cfg(feature = "preupdate-hook")]
preupdate_hook_callback: Option<preupdate_hook::PreupdateHookHandler>,
commit_hook_callback: Option<CommitHookHandler>,
@@ -138,6 +145,16 @@ impl ConnectionState {
}
}
#[cfg(feature = "preupdate-hook")]
pub(crate) fn remove_preupdate_hook(&mut self) {
if let Some(mut handler) = self.preupdate_hook_callback.take() {
unsafe {
libsqlite3_sys::sqlite3_preupdate_hook(self.handle.as_ptr(), None, ptr::null_mut());
let _ = { Box::from_raw(handler.0.as_mut()) };
}
}
}
pub(crate) fn remove_commit_hook(&mut self) {
if let Some(mut handler) = self.commit_hook_callback.take() {
unsafe {
@@ -421,6 +438,34 @@ impl LockedSqliteHandle<'_> {
}
}
/// Registers a hook that is invoked prior to each `INSERT`, `UPDATE`, and `DELETE` operation on a database table.
/// At most one preupdate hook may be registered at a time on a single database connection.
///
/// The preupdate hook only fires for changes to real database tables;
/// it is not invoked for changes to virtual tables or to system tables like sqlite_sequence or sqlite_stat1.
///
/// See https://sqlite.org/c3ref/preupdate_count.html
#[cfg(feature = "preupdate-hook")]
pub fn set_preupdate_hook<F>(&mut self, callback: F)
where
F: FnMut(PreupdateHookResult) + Send + 'static,
{
unsafe {
let callback_boxed = Box::new(callback);
// SAFETY: `Box::into_raw()` always returns a non-null pointer.
let callback = NonNull::new_unchecked(Box::into_raw(callback_boxed));
let handler = callback.as_ptr() as *mut _;
self.guard.remove_preupdate_hook();
self.guard.preupdate_hook_callback = Some(PreupdateHookHandler(callback));
libsqlite3_sys::sqlite3_preupdate_hook(
self.as_raw_handle().as_mut(),
Some(preupdate_hook::<F>),
handler,
);
}
}
/// Sets a commit hook that is invoked whenever a transaction is committed. If the commit hook callback
/// returns `false`, then the operation is turned into a ROLLBACK.
///
@@ -485,6 +530,11 @@ impl LockedSqliteHandle<'_> {
self.guard.remove_update_hook();
}
#[cfg(feature = "preupdate-hook")]
pub fn remove_preupdate_hook(&mut self) {
self.guard.remove_preupdate_hook();
}
pub fn remove_commit_hook(&mut self) {
self.guard.remove_commit_hook();
}
@@ -492,6 +542,10 @@ impl LockedSqliteHandle<'_> {
pub fn remove_rollback_hook(&mut self) {
self.guard.remove_rollback_hook();
}
pub fn last_error(&mut self) -> Option<SqliteError> {
SqliteError::try_new(self.guard.handle.as_ptr())
}
}
impl Drop for ConnectionState {

View File

@@ -0,0 +1,160 @@
use super::SqliteOperation;
use crate::type_info::DataType;
use crate::{SqliteError, SqliteTypeInfo, SqliteValueRef};
use libsqlite3_sys::{
sqlite3, sqlite3_preupdate_count, sqlite3_preupdate_depth, sqlite3_preupdate_new,
sqlite3_preupdate_old, sqlite3_value, sqlite3_value_type, SQLITE_OK,
};
use std::ffi::CStr;
use std::marker::PhantomData;
use std::os::raw::{c_char, c_int, c_void};
use std::panic::catch_unwind;
use std::ptr;
use std::ptr::NonNull;
#[derive(Debug, thiserror::Error)]
pub enum PreupdateError {
/// Error returned from the database.
#[error("error returned from database: {0}")]
Database(#[source] SqliteError),
/// Index is not within the valid column range
#[error("{0} is not within the valid column range")]
ColumnIndexOutOfBounds(i32),
/// Column value accessor was invoked from an invalid operation
#[error("column value accessor was invoked from an invalid operation")]
InvalidOperation,
}
pub(crate) struct PreupdateHookHandler(
pub(super) NonNull<dyn FnMut(PreupdateHookResult) + Send + 'static>,
);
unsafe impl Send for PreupdateHookHandler {}
#[derive(Debug)]
pub struct PreupdateHookResult<'a> {
pub operation: SqliteOperation,
pub database: &'a str,
pub table: &'a str,
db: *mut sqlite3,
// The database pointer should not be usable after the preupdate hook.
// The lifetime on this struct needs to ensure it cannot outlive the callback.
_db_lifetime: PhantomData<&'a ()>,
old_row_id: i64,
new_row_id: i64,
}
impl<'a> PreupdateHookResult<'a> {
/// Gets the amount of columns in the row being inserted, deleted, or updated.
pub fn get_column_count(&self) -> i32 {
unsafe { sqlite3_preupdate_count(self.db) }
}
/// Gets the depth of the query that triggered the preupdate hook.
/// Returns 0 if the preupdate callback was invoked as a result of
/// a direct insert, update, or delete operation;
/// 1 for inserts, updates, or deletes invoked by top-level triggers;
/// 2 for changes resulting from triggers called by top-level triggers; and so forth.
pub fn get_query_depth(&self) -> i32 {
unsafe { sqlite3_preupdate_depth(self.db) }
}
/// Gets the row id of the row being updated/deleted.
/// Returns an error if called from an insert operation.
pub fn get_old_row_id(&self) -> Result<i64, PreupdateError> {
if self.operation == SqliteOperation::Insert {
return Err(PreupdateError::InvalidOperation);
}
Ok(self.old_row_id)
}
/// Gets the row id of the row being inserted/updated.
/// Returns an error if called from a delete operation.
pub fn get_new_row_id(&self) -> Result<i64, PreupdateError> {
if self.operation == SqliteOperation::Delete {
return Err(PreupdateError::InvalidOperation);
}
Ok(self.new_row_id)
}
/// Gets the value of the row being updated/deleted at the specified index.
/// Returns an error if called from an insert operation or the index is out of bounds.
pub fn get_old_column_value(&self, i: i32) -> Result<SqliteValueRef<'a>, PreupdateError> {
if self.operation == SqliteOperation::Insert {
return Err(PreupdateError::InvalidOperation);
}
self.validate_column_index(i)?;
let mut p_value: *mut sqlite3_value = ptr::null_mut();
unsafe {
let ret = sqlite3_preupdate_old(self.db, i, &mut p_value);
self.get_value(ret, p_value)
}
}
/// Gets the value of the row being inserted/updated at the specified index.
/// Returns an error if called from a delete operation or the index is out of bounds.
pub fn get_new_column_value(&self, i: i32) -> Result<SqliteValueRef<'a>, PreupdateError> {
if self.operation == SqliteOperation::Delete {
return Err(PreupdateError::InvalidOperation);
}
self.validate_column_index(i)?;
let mut p_value: *mut sqlite3_value = ptr::null_mut();
unsafe {
let ret = sqlite3_preupdate_new(self.db, i, &mut p_value);
self.get_value(ret, p_value)
}
}
fn validate_column_index(&self, i: i32) -> Result<(), PreupdateError> {
if i < 0 || i >= self.get_column_count() {
return Err(PreupdateError::ColumnIndexOutOfBounds(i));
}
Ok(())
}
unsafe fn get_value(
&self,
ret: i32,
p_value: *mut sqlite3_value,
) -> Result<SqliteValueRef<'a>, PreupdateError> {
if ret != SQLITE_OK {
return Err(PreupdateError::Database(SqliteError::new(self.db)));
}
let data_type = DataType::from_code(sqlite3_value_type(p_value));
// SAFETY: SQLite will free the sqlite3_value when the callback returns
Ok(SqliteValueRef::borrowed(p_value, SqliteTypeInfo(data_type)))
}
}
pub(super) extern "C" fn preupdate_hook<F>(
callback: *mut c_void,
db: *mut sqlite3,
op_code: c_int,
database: *const c_char,
table: *const c_char,
old_row_id: i64,
new_row_id: i64,
) where
F: FnMut(PreupdateHookResult) + Send + 'static,
{
unsafe {
let _ = catch_unwind(|| {
let callback: *mut F = callback.cast::<F>();
let operation: SqliteOperation = op_code.into();
let database = CStr::from_ptr(database).to_str().unwrap_or_default();
let table = CStr::from_ptr(table).to_str().unwrap_or_default();
(*callback)(PreupdateHookResult {
operation,
database,
table,
old_row_id,
new_row_id,
db,
_db_lifetime: PhantomData,
})
});
}
}

View File

@@ -151,7 +151,8 @@ impl ConnectionWorker {
match limit {
None => {
for res in iter {
if tx.send(res).is_err() {
let has_error = res.is_err();
if tx.send(res).is_err() || has_error {
break;
}
}
@@ -171,7 +172,8 @@ impl ConnectionWorker {
}
}
}
if tx.send(res).is_err() {
let has_error = res.is_err();
if tx.send(res).is_err() || has_error {
break;
}
}

View File

@@ -23,9 +23,17 @@ pub struct SqliteError {
impl SqliteError {
pub(crate) fn new(handle: *mut sqlite3) -> Self {
Self::try_new(handle).expect("There should be an error")
}
pub(crate) fn try_new(handle: *mut sqlite3) -> Option<Self> {
// returns the extended result code even when extended result codes are disabled
let code: c_int = unsafe { sqlite3_extended_errcode(handle) };
if code == 0 {
return None;
}
// return English-language text that describes the error
let message = unsafe {
let msg = sqlite3_errmsg(handle);
@@ -34,10 +42,10 @@ impl SqliteError {
from_utf8_unchecked(CStr::from_ptr(msg).to_bytes())
};
Self {
Some(Self {
code,
message: message.to_owned(),
}
})
}
/// For errors during extension load, the error message is supplied via a separate pointer

View File

@@ -46,6 +46,8 @@ use std::sync::atomic::AtomicBool;
pub use arguments::{SqliteArgumentValue, SqliteArguments};
pub use column::SqliteColumn;
#[cfg(feature = "preupdate-hook")]
pub use connection::PreupdateHookResult;
pub use connection::{LockedSqliteHandle, SqliteConnection, SqliteOperation, UpdateHookResult};
pub use database::Sqlite;
pub use error::SqliteError;

View File

@@ -30,6 +30,10 @@ impl TestSupport for Sqlite {
) -> BoxFuture<'_, Result<FixtureSnapshot<Self>, Error>> {
todo!()
}
fn db_name(args: &TestArgs) -> String {
convert_path(args.test_path)
}
}
async fn test_context(args: &TestArgs) -> Result<TestContext<Sqlite>, Error> {

View File

@@ -1,4 +1,5 @@
use std::borrow::Cow;
use std::marker::PhantomData;
use std::ptr::NonNull;
use std::slice::from_raw_parts;
use std::str::from_utf8;
@@ -17,6 +18,7 @@ use crate::{Sqlite, SqliteTypeInfo};
enum SqliteValueData<'r> {
Value(&'r SqliteValue),
BorrowedHandle(ValueHandle<'r>),
}
pub struct SqliteValueRef<'r>(SqliteValueData<'r>);
@@ -26,31 +28,44 @@ impl<'r> SqliteValueRef<'r> {
Self(SqliteValueData::Value(value))
}
// SAFETY: The supplied sqlite3_value must not be null and SQLite must free it. It will not be freed on drop.
// The lifetime on this struct should tie it to whatever scope it's valid for before SQLite frees it.
#[allow(unused)]
pub(crate) unsafe fn borrowed(value: *mut sqlite3_value, type_info: SqliteTypeInfo) -> Self {
debug_assert!(!value.is_null());
let handle = ValueHandle::new_borrowed(NonNull::new_unchecked(value), type_info);
Self(SqliteValueData::BorrowedHandle(handle))
}
// NOTE: `int()` is deliberately omitted because it will silently truncate a wider value,
// which is likely to cause bugs:
// https://github.com/launchbadge/sqlx/issues/3179
// (Similar bug in Postgres): https://github.com/launchbadge/sqlx/issues/3161
pub(super) fn int64(&self) -> i64 {
match self.0 {
SqliteValueData::Value(v) => v.int64(),
match &self.0 {
SqliteValueData::Value(v) => v.0.int64(),
SqliteValueData::BorrowedHandle(v) => v.int64(),
}
}
pub(super) fn double(&self) -> f64 {
match self.0 {
SqliteValueData::Value(v) => v.double(),
match &self.0 {
SqliteValueData::Value(v) => v.0.double(),
SqliteValueData::BorrowedHandle(v) => v.double(),
}
}
pub(super) fn blob(&self) -> &'r [u8] {
match self.0 {
SqliteValueData::Value(v) => v.blob(),
match &self.0 {
SqliteValueData::Value(v) => v.0.blob(),
SqliteValueData::BorrowedHandle(v) => v.blob(),
}
}
pub(super) fn text(&self) -> Result<&'r str, BoxDynError> {
match self.0 {
SqliteValueData::Value(v) => v.text(),
match &self.0 {
SqliteValueData::Value(v) => v.0.text(),
SqliteValueData::BorrowedHandle(v) => v.text(),
}
}
}
@@ -59,50 +74,66 @@ impl<'r> ValueRef<'r> for SqliteValueRef<'r> {
type Database = Sqlite;
fn to_owned(&self) -> SqliteValue {
match self.0 {
SqliteValueData::Value(v) => v.clone(),
match &self.0 {
SqliteValueData::Value(v) => (*v).clone(),
SqliteValueData::BorrowedHandle(v) => unsafe {
SqliteValue::new(v.value.as_ptr(), v.type_info.clone())
},
}
}
fn type_info(&self) -> Cow<'_, SqliteTypeInfo> {
match self.0 {
match &self.0 {
SqliteValueData::Value(v) => v.type_info(),
SqliteValueData::BorrowedHandle(v) => v.type_info(),
}
}
fn is_null(&self) -> bool {
match self.0 {
match &self.0 {
SqliteValueData::Value(v) => v.is_null(),
SqliteValueData::BorrowedHandle(v) => v.is_null(),
}
}
}
#[derive(Clone)]
pub struct SqliteValue {
pub(crate) handle: Arc<ValueHandle>,
pub(crate) type_info: SqliteTypeInfo,
pub struct SqliteValue(Arc<ValueHandle<'static>>);
pub(crate) struct ValueHandle<'a> {
value: NonNull<sqlite3_value>,
type_info: SqliteTypeInfo,
free_on_drop: bool,
_sqlite_value_lifetime: PhantomData<&'a ()>,
}
pub(crate) struct ValueHandle(NonNull<sqlite3_value>);
// SAFE: only protected value objects are stored in SqliteValue
unsafe impl Send for ValueHandle {}
unsafe impl Sync for ValueHandle {}
impl SqliteValue {
pub(crate) unsafe fn new(value: *mut sqlite3_value, type_info: SqliteTypeInfo) -> Self {
debug_assert!(!value.is_null());
unsafe impl<'a> Send for ValueHandle<'a> {}
unsafe impl<'a> Sync for ValueHandle<'a> {}
impl ValueHandle<'static> {
fn new_owned(value: NonNull<sqlite3_value>, type_info: SqliteTypeInfo) -> Self {
Self {
value,
type_info,
handle: Arc::new(ValueHandle(NonNull::new_unchecked(sqlite3_value_dup(
value,
)))),
free_on_drop: true,
_sqlite_value_lifetime: PhantomData,
}
}
}
impl<'a> ValueHandle<'a> {
fn new_borrowed(value: NonNull<sqlite3_value>, type_info: SqliteTypeInfo) -> Self {
Self {
value,
type_info,
free_on_drop: false,
_sqlite_value_lifetime: PhantomData,
}
}
fn type_info_opt(&self) -> Option<SqliteTypeInfo> {
let dt = DataType::from_code(unsafe { sqlite3_value_type(self.handle.0.as_ptr()) });
let dt = DataType::from_code(unsafe { sqlite3_value_type(self.value.as_ptr()) });
if let DataType::Null = dt {
None
@@ -112,15 +143,15 @@ impl SqliteValue {
}
fn int64(&self) -> i64 {
unsafe { sqlite3_value_int64(self.handle.0.as_ptr()) }
unsafe { sqlite3_value_int64(self.value.as_ptr()) }
}
fn double(&self) -> f64 {
unsafe { sqlite3_value_double(self.handle.0.as_ptr()) }
unsafe { sqlite3_value_double(self.value.as_ptr()) }
}
fn blob(&self) -> &[u8] {
let len = unsafe { sqlite3_value_bytes(self.handle.0.as_ptr()) };
fn blob<'b>(&self) -> &'b [u8] {
let len = unsafe { sqlite3_value_bytes(self.value.as_ptr()) };
// This likely means UB in SQLite itself or our usage of it;
// signed integer overflow is UB in the C standard.
@@ -133,15 +164,45 @@ impl SqliteValue {
return &[];
}
let ptr = unsafe { sqlite3_value_blob(self.handle.0.as_ptr()) } as *const u8;
let ptr = unsafe { sqlite3_value_blob(self.value.as_ptr()) } as *const u8;
debug_assert!(!ptr.is_null());
unsafe { from_raw_parts(ptr, len) }
}
fn text(&self) -> Result<&str, BoxDynError> {
fn text<'b>(&self) -> Result<&'b str, BoxDynError> {
Ok(from_utf8(self.blob())?)
}
fn type_info(&self) -> Cow<'_, SqliteTypeInfo> {
self.type_info_opt()
.map(Cow::Owned)
.unwrap_or(Cow::Borrowed(&self.type_info))
}
fn is_null(&self) -> bool {
unsafe { sqlite3_value_type(self.value.as_ptr()) == SQLITE_NULL }
}
}
impl<'a> Drop for ValueHandle<'a> {
fn drop(&mut self) {
if self.free_on_drop {
unsafe {
sqlite3_value_free(self.value.as_ptr());
}
}
}
}
impl SqliteValue {
// SAFETY: The sqlite3_value must be non-null and SQLite must not free it. It will be freed on drop.
pub(crate) unsafe fn new(value: *mut sqlite3_value, type_info: SqliteTypeInfo) -> Self {
debug_assert!(!value.is_null());
let handle =
ValueHandle::new_owned(NonNull::new_unchecked(sqlite3_value_dup(value)), type_info);
Self(Arc::new(handle))
}
}
impl Value for SqliteValue {
@@ -152,21 +213,11 @@ impl Value for SqliteValue {
}
fn type_info(&self) -> Cow<'_, SqliteTypeInfo> {
self.type_info_opt()
.map(Cow::Owned)
.unwrap_or(Cow::Borrowed(&self.type_info))
self.0.type_info()
}
fn is_null(&self) -> bool {
unsafe { sqlite3_value_type(self.handle.0.as_ptr()) == SQLITE_NULL }
}
}
impl Drop for ValueHandle {
fn drop(&mut self) {
unsafe {
sqlite3_value_free(self.0.as_ptr());
}
self.0.is_null()
}
}