mirror of
https://github.com/launchbadge/sqlx.git
synced 2025-10-02 15:25:32 +00:00
add progress handler support to sqlite (#2256)
* rebase main * fmt * use NonNull to fix UB * apply code suggestions * add test for multiple handler drops * remove nightly features for test
This commit is contained in:
parent
14d70feab1
commit
4f1ac1d606
@ -282,6 +282,7 @@ impl EstablishParams {
|
||||
statements: Statements::new(self.statement_cache_capacity),
|
||||
transaction_depth: 0,
|
||||
log_settings: self.log_settings.clone(),
|
||||
progress_handler_callback: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -1,12 +1,14 @@
|
||||
use futures_core::future::BoxFuture;
|
||||
use futures_intrusive::sync::MutexGuard;
|
||||
use futures_util::future;
|
||||
use libsqlite3_sys::sqlite3;
|
||||
use libsqlite3_sys::{sqlite3, sqlite3_progress_handler};
|
||||
use sqlx_core::common::StatementCache;
|
||||
use sqlx_core::error::Error;
|
||||
use sqlx_core::transaction::Transaction;
|
||||
use std::cmp::Ordering;
|
||||
use std::fmt::{self, Debug, Formatter};
|
||||
use std::os::raw::{c_int, c_void};
|
||||
use std::panic::catch_unwind;
|
||||
use std::ptr::NonNull;
|
||||
|
||||
use crate::connection::establish::EstablishParams;
|
||||
@ -51,6 +53,10 @@ pub struct LockedSqliteHandle<'a> {
|
||||
pub(crate) guard: MutexGuard<'a, ConnectionState>,
|
||||
}
|
||||
|
||||
/// Represents a callback handler that will be shared with the underlying sqlite3 connection.
|
||||
pub(crate) struct Handler(NonNull<dyn FnMut() -> bool + Send + 'static>);
|
||||
unsafe impl Send for Handler {}
|
||||
|
||||
pub(crate) struct ConnectionState {
|
||||
pub(crate) handle: ConnectionHandle,
|
||||
|
||||
@ -60,6 +66,22 @@ pub(crate) struct ConnectionState {
|
||||
pub(crate) statements: Statements,
|
||||
|
||||
log_settings: LogSettings,
|
||||
|
||||
/// Stores the progress handler set on the current connection. If the handler returns `false`,
|
||||
/// the query is interrupted.
|
||||
progress_handler_callback: Option<Handler>,
|
||||
}
|
||||
|
||||
impl ConnectionState {
|
||||
/// Drops the `progress_handler_callback` if it exists.
|
||||
pub(crate) fn remove_progress_handler(&mut self) {
|
||||
if let Some(mut handler) = self.progress_handler_callback.take() {
|
||||
unsafe {
|
||||
sqlite3_progress_handler(self.handle.as_ptr(), 0, None, 0 as *mut _);
|
||||
let _ = { Box::from_raw(handler.0.as_mut()) };
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct Statements {
|
||||
@ -177,6 +199,21 @@ impl Connection for SqliteConnection {
|
||||
}
|
||||
}
|
||||
|
||||
/// Implements a C binding to a progress callback. The function returns `0` if the
|
||||
/// user-provided callback returns `true`, and `1` otherwise to signal an interrupt.
|
||||
extern "C" fn progress_callback<F>(callback: *mut c_void) -> c_int
|
||||
where
|
||||
F: FnMut() -> bool,
|
||||
{
|
||||
unsafe {
|
||||
let r = catch_unwind(|| {
|
||||
let callback: *mut F = callback.cast::<F>();
|
||||
(*callback)()
|
||||
});
|
||||
c_int::from(!r.unwrap_or_default())
|
||||
}
|
||||
}
|
||||
|
||||
impl LockedSqliteHandle<'_> {
|
||||
/// Returns the underlying sqlite3* connection handle.
|
||||
///
|
||||
@ -206,12 +243,52 @@ impl LockedSqliteHandle<'_> {
|
||||
) -> Result<(), Error> {
|
||||
collation::create_collation(&mut self.guard.handle, name, compare)
|
||||
}
|
||||
|
||||
/// Sets a progress handler that is invoked periodically during long running calls. If the progress callback
|
||||
/// returns `false`, then the operation is interrupted.
|
||||
///
|
||||
/// `num_ops` is the approximate number of [virtual machine instructions](https://www.sqlite.org/opcode.html)
|
||||
/// that are evaluated between successive invocations of the callback. If `num_ops` is less than one then the
|
||||
/// progress handler is disabled.
|
||||
///
|
||||
/// Only a single progress handler may be defined at one time per database connection; setting a new progress
|
||||
/// handler cancels the old one.
|
||||
///
|
||||
/// The progress handler callback must not do anything that will modify the database connection that invoked
|
||||
/// the progress handler. Note that sqlite3_prepare_v2() and sqlite3_step() both modify their database connections
|
||||
/// in this context.
|
||||
pub fn set_progress_handler<F>(&mut self, num_ops: i32, mut callback: F)
|
||||
where
|
||||
F: FnMut() -> bool + Send + 'static,
|
||||
{
|
||||
unsafe {
|
||||
let callback_boxed = Box::new(callback);
|
||||
// SAFETY: `Box::into_raw()` always returns a non-null pointer.
|
||||
let callback = NonNull::new_unchecked(Box::into_raw(callback_boxed));
|
||||
let handler = callback.as_ptr() as *mut _;
|
||||
self.guard.remove_progress_handler();
|
||||
self.guard.progress_handler_callback = Some(Handler(callback));
|
||||
|
||||
sqlite3_progress_handler(
|
||||
self.as_raw_handle().as_mut(),
|
||||
num_ops,
|
||||
Some(progress_callback::<F>),
|
||||
handler,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Removes the progress handler on a database connection. The method does nothing if no handler was set.
|
||||
pub fn remove_progress_handler(&mut self) {
|
||||
self.guard.remove_progress_handler();
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for ConnectionState {
|
||||
fn drop(&mut self) {
|
||||
// explicitly drop statements before the connection handle is dropped
|
||||
self.statements.clear();
|
||||
self.remove_progress_handler();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -7,6 +7,7 @@ use sqlx::{
|
||||
SqliteConnection, SqlitePool, Statement, TypeInfo,
|
||||
};
|
||||
use sqlx_test::new;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[sqlx_macros::test]
|
||||
async fn it_connects() -> anyhow::Result<()> {
|
||||
@ -725,3 +726,71 @@ async fn concurrent_read_and_write() {
|
||||
read.await;
|
||||
write.await;
|
||||
}
|
||||
|
||||
#[sqlx_macros::test]
|
||||
async fn test_query_with_progress_handler() -> anyhow::Result<()> {
|
||||
let mut conn = new::<Sqlite>().await?;
|
||||
|
||||
// Using this string as a canary to ensure the callback doesn't get called with the wrong data pointer.
|
||||
let state = format!("test");
|
||||
conn.lock_handle().await?.set_progress_handler(1, move || {
|
||||
assert_eq!(state, "test");
|
||||
false
|
||||
});
|
||||
|
||||
match sqlx::query("SELECT 'hello' AS title")
|
||||
.fetch_all(&mut conn)
|
||||
.await
|
||||
{
|
||||
Err(sqlx::Error::Database(err)) => assert_eq!(err.message(), String::from("interrupted")),
|
||||
_ => panic!("expected an interrupt"),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[sqlx_macros::test]
|
||||
async fn test_multiple_set_progress_handler_calls_drop_old_handler() -> anyhow::Result<()> {
|
||||
let ref_counted_object = Arc::new(0);
|
||||
assert_eq!(1, Arc::strong_count(&ref_counted_object));
|
||||
|
||||
{
|
||||
let mut conn = new::<Sqlite>().await?;
|
||||
|
||||
let o = ref_counted_object.clone();
|
||||
conn.lock_handle().await?.set_progress_handler(1, move || {
|
||||
println!("{:?}", o);
|
||||
false
|
||||
});
|
||||
assert_eq!(2, Arc::strong_count(&ref_counted_object));
|
||||
|
||||
let o = ref_counted_object.clone();
|
||||
conn.lock_handle().await?.set_progress_handler(1, move || {
|
||||
println!("{:?}", o);
|
||||
false
|
||||
});
|
||||
assert_eq!(2, Arc::strong_count(&ref_counted_object));
|
||||
|
||||
let o = ref_counted_object.clone();
|
||||
conn.lock_handle().await?.set_progress_handler(1, move || {
|
||||
println!("{:?}", o);
|
||||
false
|
||||
});
|
||||
assert_eq!(2, Arc::strong_count(&ref_counted_object));
|
||||
|
||||
match sqlx::query("SELECT 'hello' AS title")
|
||||
.fetch_all(&mut conn)
|
||||
.await
|
||||
{
|
||||
Err(sqlx::Error::Database(err)) => {
|
||||
assert_eq!(err.message(), String::from("interrupted"))
|
||||
}
|
||||
_ => panic!("expected an interrupt"),
|
||||
}
|
||||
|
||||
conn.lock_handle().await?.remove_progress_handler();
|
||||
}
|
||||
|
||||
assert_eq!(1, Arc::strong_count(&ref_counted_object));
|
||||
Ok(())
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user