diff --git a/sqlx-core/src/error.rs b/sqlx-core/src/error.rs index 17774add..98b42fbc 100644 --- a/sqlx-core/src/error.rs +++ b/sqlx-core/src/error.rs @@ -34,6 +34,12 @@ pub enum Error { #[error("error with configuration: {0}")] Configuration(#[source] BoxDynError), + /// One or more of the arguments to the called function was invalid. + /// + /// The string contains more information. + #[error("{0}")] + InvalidArgument(String), + /// Error returned from the database. #[error("error returned from database: {0}")] Database(#[source] Box), @@ -79,7 +85,7 @@ pub enum Error { }, /// Error occured while encoding a value. - #[error("error occured while encoding a value: {0}")] + #[error("error occurred while encoding a value: {0}")] Encode(#[source] BoxDynError), /// Error occurred while decoding a value. @@ -136,6 +142,12 @@ impl Error { Error::Protocol(err.to_string()) } + #[doc(hidden)] + #[inline] + pub fn database(err: impl DatabaseError) -> Self { + Error::Database(Box::new(err)) + } + #[doc(hidden)] #[inline] pub fn config(err: impl StdError + Send + Sync + 'static) -> Self { diff --git a/sqlx-sqlite/src/connection/collation.rs b/sqlx-sqlite/src/connection/collation.rs index 573a9af8..e7422138 100644 --- a/sqlx-sqlite/src/connection/collation.rs +++ b/sqlx-sqlite/src/connection/collation.rs @@ -10,7 +10,6 @@ use libsqlite3_sys::{sqlite3_create_collation_v2, SQLITE_OK, SQLITE_UTF8}; use crate::connection::handle::ConnectionHandle; use crate::error::Error; -use crate::SqliteError; #[derive(Clone)] pub struct Collation { @@ -67,7 +66,7 @@ impl Collation { } else { // The xDestroy callback is not called if the sqlite3_create_collation_v2() function fails. drop(unsafe { Arc::from_raw(raw_f) }); - Err(Error::Database(Box::new(SqliteError::new(handle.as_ptr())))) + Err(handle.expect_error().into()) } } } @@ -112,7 +111,7 @@ where } else { // The xDestroy callback is not called if the sqlite3_create_collation_v2() function fails. drop(unsafe { Box::from_raw(boxed_f) }); - Err(Error::Database(Box::new(SqliteError::new(handle.as_ptr())))) + Err(handle.expect_error().into()) } } diff --git a/sqlx-sqlite/src/connection/establish.rs b/sqlx-sqlite/src/connection/establish.rs index 5b8aa01b..334b1616 100644 --- a/sqlx-sqlite/src/connection/establish.rs +++ b/sqlx-sqlite/src/connection/establish.rs @@ -204,10 +204,10 @@ impl EstablishParams { // SAFE: tested for NULL just above // This allows any returns below to close this handle with RAII - let handle = unsafe { ConnectionHandle::new(handle) }; + let mut handle = unsafe { ConnectionHandle::new(handle) }; if status != SQLITE_OK { - return Err(Error::Database(Box::new(SqliteError::new(handle.as_ptr())))); + return Err(Error::Database(Box::new(handle.expect_error()))); } // Enable extended result codes @@ -226,33 +226,29 @@ impl EstablishParams { for ext in self.extensions.iter() { // `sqlite3_load_extension` is unusual as it returns its errors via an out-pointer // rather than by calling `sqlite3_errmsg` - let mut error = null_mut(); + let mut error_msg = null_mut(); status = unsafe { sqlite3_load_extension( handle.as_ptr(), ext.0.as_ptr(), ext.1.as_ref().map_or(null(), |e| e.as_ptr()), - addr_of_mut!(error), + addr_of_mut!(error_msg), ) }; if status != SQLITE_OK { + let mut e = handle.expect_error(); + // SAFETY: We become responsible for any memory allocation at `&error`, so test // for null and take an RAII version for returns - let err_msg = if !error.is_null() { - unsafe { - let e = CStr::from_ptr(error).into(); - sqlite3_free(error as *mut c_void); - e - } - } else { - CString::new("Unknown error when loading extension") - .expect("text should be representable as a CString") - }; - return Err(Error::Database(Box::new(SqliteError::extension( - handle.as_ptr(), - &err_msg, - )))); + if !error_msg.is_null() { + e = e.with_message(unsafe { + let msg = CStr::from_ptr(error_msg).to_string_lossy().into(); + sqlite3_free(error_msg as *mut c_void); + msg + }); + } + return Err(Error::Database(Box::new(e))); } } // Preempt any hypothetical security issues arising from leaving ENABLE_LOAD_EXTENSION // on by disabling the flag again once we've loaded all the requested modules. @@ -271,7 +267,7 @@ impl EstablishParams { // configure a `regexp` function for sqlite, it does not come with one by default let status = crate::regexp::register(handle.as_ptr()); if status != SQLITE_OK { - return Err(Error::Database(Box::new(SqliteError::new(handle.as_ptr())))); + return Err(Error::Database(Box::new(handle.expect_error()))); } } @@ -286,7 +282,7 @@ impl EstablishParams { status = unsafe { sqlite3_busy_timeout(handle.as_ptr(), ms) }; if status != SQLITE_OK { - return Err(Error::Database(Box::new(SqliteError::new(handle.as_ptr())))); + return Err(Error::Database(Box::new(handle.expect_error()))); } Ok(ConnectionState { diff --git a/sqlx-sqlite/src/connection/handle.rs b/sqlx-sqlite/src/connection/handle.rs index aaf5b74e..60fbe17d 100644 --- a/sqlx-sqlite/src/connection/handle.rs +++ b/sqlx-sqlite/src/connection/handle.rs @@ -46,6 +46,17 @@ impl ConnectionHandle { unsafe { sqlite3_last_insert_rowid(self.as_ptr()) } } + pub(crate) fn last_error(&mut self) -> Option { + // SAFETY: we have exclusive access to the database handle + unsafe { SqliteError::try_new(self.as_ptr()) } + } + + #[track_caller] + pub(crate) fn expect_error(&mut self) -> SqliteError { + self.last_error() + .expect("expected error code to be set in current context") + } + pub(crate) fn exec(&mut self, query: impl Into) -> Result<(), Error> { let query = query.into(); let query = CString::new(query).map_err(|_| err_protocol!("query contains nul bytes"))?; diff --git a/sqlx-sqlite/src/connection/mod.rs b/sqlx-sqlite/src/connection/mod.rs index 7412eef1..3316ad40 100644 --- a/sqlx-sqlite/src/connection/mod.rs +++ b/sqlx-sqlite/src/connection/mod.rs @@ -40,6 +40,7 @@ mod handle; pub(crate) mod intmap; #[cfg(feature = "preupdate-hook")] mod preupdate_hook; +pub(crate) mod serialize; mod worker; @@ -544,7 +545,7 @@ impl LockedSqliteHandle<'_> { } pub fn last_error(&mut self) -> Option { - SqliteError::try_new(self.guard.handle.as_ptr()) + self.guard.handle.last_error() } } diff --git a/sqlx-sqlite/src/connection/serialize.rs b/sqlx-sqlite/src/connection/serialize.rs new file mode 100644 index 00000000..c8835093 --- /dev/null +++ b/sqlx-sqlite/src/connection/serialize.rs @@ -0,0 +1,297 @@ +use super::ConnectionState; +use crate::{error::Error, SqliteConnection, SqliteError}; +use libsqlite3_sys::{ + sqlite3_deserialize, sqlite3_free, sqlite3_malloc64, sqlite3_serialize, + SQLITE_DESERIALIZE_FREEONCLOSE, SQLITE_DESERIALIZE_READONLY, SQLITE_DESERIALIZE_RESIZEABLE, + SQLITE_NOMEM, SQLITE_OK, +}; +use std::ffi::c_char; +use std::fmt::Debug; +use std::{ + ops::{Deref, DerefMut}, + ptr, + ptr::NonNull, +}; + +impl SqliteConnection { + /// Serialize the given SQLite database schema using [`sqlite3_serialize()`]. + /// + /// The returned buffer is a SQLite managed allocation containing the equivalent data + /// as writing the database to disk. It is freed on-drop. + /// + /// To serialize the primary, unqualified schema (`main`), pass `None` for the schema name. + /// + /// # Errors + /// * [`Error::InvalidArgument`] if the schema name contains a zero/NUL byte (`\0`). + /// * [`Error::Database`] if the schema does not exist or another error occurs. + /// + /// [`sqlite3_serialize()`]: https://sqlite.org/c3ref/serialize.html + pub async fn serialize(&mut self, schema: Option<&str>) -> Result { + let schema = schema.map(SchemaName::try_from).transpose()?; + + self.worker.serialize(schema).await + } + + /// Deserialize a SQLite database from a buffer into the specified schema using [`sqlite3_deserialize()`]. + /// + /// The given schema will be disconnected and re-connected as an in-memory database + /// backed by `data`, which should be the serialized form of a database previously returned + /// by a call to [`Self::serialize()`], documented as being equivalent to + /// the contents of the database file on disk. + /// + /// An error will be returned if a schema with the given name is not already attached. + /// You can use `ATTACH ':memory' as ""` to create an empty schema first. + /// + /// Pass `None` to deserialize to the primary, unqualified schema (`main`). + /// + /// The SQLite connection will take ownership of `data` and will free it when the connection + /// is closed or the schema is detached ([`SQLITE_DESERIALIZE_FREEONCLOSE`][deserialize-flags]). + /// + /// If `read_only` is `true`, the schema is opened as read-only ([`SQLITE_DESERIALIZE_READONLY`][deserialize-flags]). + /// If `false`, the schema is marked as resizable ([`SQLITE_DESERIALIZE_RESIZABLE`][deserialize-flags]). + /// + /// If the database is in WAL mode, an error is returned. + /// See [`sqlite3_deserialize()`] for details. + /// + /// # Errors + /// * [`Error::InvalidArgument`] if the schema name contains a zero/NUL byte (`\0`). + /// * [`Error::Database`] if an error occurs during deserialization. + /// + /// [`sqlite3_deserialize()`]: https://sqlite.org/c3ref/deserialize.html + /// [deserialize-flags]: https://sqlite.org/c3ref/c_deserialize_freeonclose.html + pub async fn deserialize( + &mut self, + schema: Option<&str>, + data: SqliteOwnedBuf, + read_only: bool, + ) -> Result<(), Error> { + let schema = schema.map(SchemaName::try_from).transpose()?; + + self.worker.deserialize(schema, data, read_only).await + } +} + +pub(crate) fn serialize( + conn: &mut ConnectionState, + schema: Option, +) -> Result { + let mut size = 0; + + let buf = unsafe { + let ptr = sqlite3_serialize( + conn.handle.as_ptr(), + schema.as_ref().map_or(ptr::null(), SchemaName::as_ptr), + &mut size, + 0, + ); + + // looking at the source, `sqlite3_serialize` actually sets `size = -1` on error: + // https://github.com/sqlite/sqlite/blob/da5f81387843f92652128087a8f8ecef0b79461d/src/memdb.c#L776 + usize::try_from(size) + .ok() + .and_then(|size| SqliteOwnedBuf::from_raw(ptr, size)) + }; + + if let Some(buf) = buf { + return Ok(buf); + } + + if let Some(error) = conn.handle.last_error() { + return Err(error.into()); + } + + if size > 0 { + // If `size` is positive but `sqlite3_serialize` still returned NULL, + // the most likely culprit is an out-of-memory condition. + return Err(SqliteError::from_code(SQLITE_NOMEM).into()); + } + + // Otherwise, the schema was probably not found. + // We return the equivalent error as when you try to execute `PRAGMA .page_count` + // against a non-existent schema. + Err(SqliteError::generic(format!( + "database {} does not exist", + schema.as_ref().map_or("main", SchemaName::as_str) + )) + .into()) +} + +pub(crate) fn deserialize( + conn: &mut ConnectionState, + schema: Option, + data: SqliteOwnedBuf, + read_only: bool, +) -> Result<(), Error> { + // SQLITE_DESERIALIZE_FREEONCLOSE causes SQLite to take ownership of the buffer + let mut flags = SQLITE_DESERIALIZE_FREEONCLOSE; + if read_only { + flags |= SQLITE_DESERIALIZE_READONLY; + } else { + flags |= SQLITE_DESERIALIZE_RESIZEABLE; + } + + let (buf, size) = data.into_raw(); + + let rc = unsafe { + sqlite3_deserialize( + conn.handle.as_ptr(), + schema.as_ref().map_or(ptr::null(), SchemaName::as_ptr), + buf, + i64::try_from(size).unwrap(), + i64::try_from(size).unwrap(), + flags, + ) + }; + + match rc { + SQLITE_OK => Ok(()), + SQLITE_NOMEM => Err(SqliteError::from_code(SQLITE_NOMEM).into()), + // SQLite unfortunately doesn't set any specific message for deserialization errors. + _ => Err(SqliteError::generic("an error occurred during deserialization").into()), + } +} + +/// Memory buffer owned and allocated by SQLite. Freed on drop. +/// +/// Intended primarily for use with [`SqliteConnection::serialize()`] and [`SqliteConnection::deserialize()`]. +/// +/// Can be created from `&[u8]` using the `TryFrom` impl. The slice must not be empty. +#[derive(Debug)] +pub struct SqliteOwnedBuf { + ptr: NonNull, + size: usize, +} + +unsafe impl Send for SqliteOwnedBuf {} +unsafe impl Sync for SqliteOwnedBuf {} + +impl Drop for SqliteOwnedBuf { + fn drop(&mut self) { + unsafe { + sqlite3_free(self.ptr.as_ptr().cast()); + } + } +} + +impl SqliteOwnedBuf { + /// Uses `sqlite3_malloc` to allocate a buffer and returns a pointer to it. + /// + /// # Safety + /// The allocated buffer is uninitialized. + unsafe fn with_capacity(size: usize) -> Option { + let ptr = sqlite3_malloc64(u64::try_from(size).unwrap()).cast::(); + Self::from_raw(ptr, size) + } + + /// Creates a new mem buffer from a pointer that has been created with sqlite_malloc + /// + /// # Safety: + /// * The pointer must point to a valid allocation created by `sqlite3_malloc()`, or `NULL`. + unsafe fn from_raw(ptr: *mut u8, size: usize) -> Option { + Some(Self { + ptr: NonNull::new(ptr)?, + size, + }) + } + + fn into_raw(self) -> (*mut u8, usize) { + let raw = (self.ptr.as_ptr(), self.size); + // this is used in sqlite_deserialize and + // underlying buffer must not be freed + std::mem::forget(self); + raw + } +} + +/// # Errors +/// Returns [`Error::InvalidArgument`] if the slice is empty. +impl TryFrom<&[u8]> for SqliteOwnedBuf { + type Error = Error; + + fn try_from(bytes: &[u8]) -> Result { + unsafe { + // SAFETY: `buf` is not initialized until `ptr::copy_nonoverlapping` completes. + let mut buf = Self::with_capacity(bytes.len()).ok_or_else(|| { + Error::InvalidArgument("SQLite owned buffer cannot be empty".to_string()) + })?; + ptr::copy_nonoverlapping(bytes.as_ptr(), buf.ptr.as_mut(), buf.size); + Ok(buf) + } + } +} + +impl Deref for SqliteOwnedBuf { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + unsafe { std::slice::from_raw_parts(self.ptr.as_ptr(), self.size) } + } +} + +impl DerefMut for SqliteOwnedBuf { + fn deref_mut(&mut self) -> &mut Self::Target { + unsafe { std::slice::from_raw_parts_mut(self.ptr.as_mut(), self.size) } + } +} + +impl AsRef<[u8]> for SqliteOwnedBuf { + fn as_ref(&self) -> &[u8] { + self.deref() + } +} + +impl AsMut<[u8]> for SqliteOwnedBuf { + fn as_mut(&mut self) -> &mut [u8] { + self.deref_mut() + } +} + +/// Checked schema name to pass to SQLite. +/// +/// # Safety: +/// * Valid UTF-8 (not guaranteed by `CString`) +/// * No internal zero bytes (`\0`) (not guaranteed by `String`) +/// * Terminated with a zero byte (`\0`) (not guaranteed by `String`) +#[derive(Debug)] +pub(crate) struct SchemaName(Box); + +impl SchemaName { + /// Get the schema name as a string without the zero byte terminator. + pub fn as_str(&self) -> &str { + &self.0[..self.0.len() - 1] + } + + /// Get a pointer to the string data, suitable for passing as C's `*const char`. + /// + /// # Safety + /// The string data is guaranteed to be terminated with a zero byte. + pub fn as_ptr(&self) -> *const c_char { + self.0.as_ptr() as *const c_char + } +} + +impl<'a> TryFrom<&'a str> for SchemaName { + type Error = Error; + + fn try_from(name: &'a str) -> Result { + // SAFETY: we must ensure that the string does not contain an internal NULL byte + if let Some(pos) = name.as_bytes().iter().position(|&b| b == 0) { + return Err(Error::InvalidArgument(format!( + "schema name {name:?} contains a zero byte at index {pos}" + ))); + } + + let capacity = name.len().checked_add(1).unwrap(); + + let mut s = String::new(); + // `String::with_capacity()` does not guarantee that it will not overallocate, + // which might mean an unnecessary reallocation to make `capacity == len` + // in the conversion to `Box`. + s.reserve_exact(capacity); + + s.push_str(name); + s.push('\0'); + + Ok(SchemaName(s.into())) + } +} diff --git a/sqlx-sqlite/src/connection/worker.rs b/sqlx-sqlite/src/connection/worker.rs index c1c67636..8a1d140b 100644 --- a/sqlx-sqlite/src/connection/worker.rs +++ b/sqlx-sqlite/src/connection/worker.rs @@ -21,6 +21,8 @@ use crate::connection::execute; use crate::connection::ConnectionState; use crate::{Sqlite, SqliteArguments, SqliteQueryResult, SqliteRow, SqliteStatement}; +use super::serialize::{deserialize, serialize, SchemaName, SqliteOwnedBuf}; + // Each SQLite connection has a dedicated thread. // TODO: Tweak this so that we can use a thread pool per pool of SQLite3 connections to reduce @@ -54,6 +56,16 @@ enum Command { tx: flume::Sender, Error>>, limit: Option, }, + Serialize { + schema: Option, + tx: oneshot::Sender>, + }, + Deserialize { + schema: Option, + data: SqliteOwnedBuf, + read_only: bool, + tx: oneshot::Sender>, + }, Begin { tx: rendezvous_oneshot::Sender>, }, @@ -263,6 +275,12 @@ impl ConnectionWorker { } } } + Command::Serialize { schema, tx } => { + tx.send(serialize(&mut conn, schema)).ok(); + } + Command::Deserialize { schema, data, read_only, tx } => { + tx.send(deserialize(&mut conn, schema, data, read_only)).ok(); + } Command::ClearCache { tx } => { conn.statements.clear(); update_cached_statements_size(&conn, &shared.cached_statements_size); @@ -358,6 +376,29 @@ impl ConnectionWorker { self.oneshot_cmd(|tx| Command::Ping { tx }).await } + pub(crate) async fn deserialize( + &mut self, + schema: Option, + data: SqliteOwnedBuf, + read_only: bool, + ) -> Result<(), Error> { + self.oneshot_cmd(|tx| Command::Deserialize { + schema, + data, + read_only, + tx, + }) + .await? + } + + pub(crate) async fn serialize( + &mut self, + schema: Option, + ) -> Result { + self.oneshot_cmd(|tx| Command::Serialize { schema, tx }) + .await? + } + async fn oneshot_cmd(&mut self, command: F) -> Result where F: FnOnce(oneshot::Sender) -> Command, diff --git a/sqlx-sqlite/src/error.rs b/sqlx-sqlite/src/error.rs index 0d34bc10..eee2e8b1 100644 --- a/sqlx-sqlite/src/error.rs +++ b/sqlx-sqlite/src/error.rs @@ -2,12 +2,12 @@ use std::error::Error as StdError; use std::ffi::CStr; use std::fmt::{self, Display, Formatter}; use std::os::raw::c_int; -use std::{borrow::Cow, str::from_utf8_unchecked}; +use std::{borrow::Cow, str}; use libsqlite3_sys::{ - sqlite3, sqlite3_errmsg, sqlite3_extended_errcode, SQLITE_CONSTRAINT_CHECK, + sqlite3, sqlite3_errmsg, sqlite3_errstr, sqlite3_extended_errcode, SQLITE_CONSTRAINT_CHECK, SQLITE_CONSTRAINT_FOREIGNKEY, SQLITE_CONSTRAINT_NOTNULL, SQLITE_CONSTRAINT_PRIMARYKEY, - SQLITE_CONSTRAINT_UNIQUE, + SQLITE_CONSTRAINT_UNIQUE, SQLITE_ERROR, }; pub(crate) use sqlx_core::error::*; @@ -18,15 +18,15 @@ pub(crate) use sqlx_core::error::*; #[derive(Debug)] pub struct SqliteError { code: c_int, - message: String, + message: Cow<'static, str>, } impl SqliteError { - pub(crate) fn new(handle: *mut sqlite3) -> Self { + pub(crate) unsafe fn new(handle: *mut sqlite3) -> Self { Self::try_new(handle).expect("There should be an error") } - pub(crate) fn try_new(handle: *mut sqlite3) -> Option { + pub(crate) unsafe fn try_new(handle: *mut sqlite3) -> Option { // returns the extended result code even when extended result codes are disabled let code: c_int = unsafe { sqlite3_extended_errcode(handle) }; @@ -39,20 +39,44 @@ impl SqliteError { let msg = sqlite3_errmsg(handle); debug_assert!(!msg.is_null()); - from_utf8_unchecked(CStr::from_ptr(msg).to_bytes()) + str::from_utf8_unchecked(CStr::from_ptr(msg).to_bytes()).to_owned() }; Some(Self { code, - message: message.to_owned(), + message: message.into(), }) } /// For errors during extension load, the error message is supplied via a separate pointer - pub(crate) fn extension(handle: *mut sqlite3, error_msg: &CStr) -> Self { - let mut err = Self::new(handle); - err.message = unsafe { from_utf8_unchecked(error_msg.to_bytes()).to_owned() }; - err + pub(crate) fn with_message(mut self, error_msg: String) -> Self { + self.message = error_msg.into(); + self + } + + pub(crate) fn from_code(code: c_int) -> Self { + let message = unsafe { + let errstr = sqlite3_errstr(code); + + if !errstr.is_null() { + // SAFETY: `errstr` is guaranteed to be UTF-8 + // The lifetime of the string is "internally managed"; + // the implementation just selects from an array of static strings. + // We copy to an owned buffer in case `libsqlite3` is dynamically loaded somehow. + Cow::Owned(str::from_utf8_unchecked(CStr::from_ptr(errstr).to_bytes()).into()) + } else { + Cow::Borrowed("") + } + }; + + SqliteError { code, message } + } + + pub(crate) fn generic(message: impl Into>) -> Self { + Self { + code: SQLITE_ERROR, + message: message.into(), + } } } diff --git a/sqlx-sqlite/src/lib.rs b/sqlx-sqlite/src/lib.rs index f1a45c3d..e4a122b6 100644 --- a/sqlx-sqlite/src/lib.rs +++ b/sqlx-sqlite/src/lib.rs @@ -46,6 +46,7 @@ use std::sync::atomic::AtomicBool; pub use arguments::{SqliteArgumentValue, SqliteArguments}; pub use column::SqliteColumn; +pub use connection::serialize::SqliteOwnedBuf; #[cfg(feature = "preupdate-hook")] pub use connection::PreupdateHookResult; pub use connection::{LockedSqliteHandle, SqliteConnection, SqliteOperation, UpdateHookResult}; diff --git a/sqlx-sqlite/src/statement/handle.rs b/sqlx-sqlite/src/statement/handle.rs index 2925d1a1..ccc299fc 100644 --- a/sqlx-sqlite/src/statement/handle.rs +++ b/sqlx-sqlite/src/statement/handle.rs @@ -81,8 +81,8 @@ impl StatementHandle { } #[inline] - pub(crate) fn last_error(&self) -> SqliteError { - SqliteError::new(unsafe { self.db_handle() }) + pub(crate) fn last_error(&mut self) -> SqliteError { + unsafe { SqliteError::new(self.db_handle()) } } #[inline] diff --git a/sqlx-sqlite/src/statement/virtual.rs b/sqlx-sqlite/src/statement/virtual.rs index 6be980c3..2817146b 100644 --- a/sqlx-sqlite/src/statement/virtual.rs +++ b/sqlx-sqlite/src/statement/virtual.rs @@ -184,7 +184,7 @@ fn prepare( }; if status != SQLITE_OK { - return Err(SqliteError::new(conn).into()); + return Err(unsafe { SqliteError::new(conn).into() }); } // tail should point to the first byte past the end of the first SQL diff --git a/tests/sqlite/sqlite.rs b/tests/sqlite/sqlite.rs index 16b4b2d9..92a11387 100644 --- a/tests/sqlite/sqlite.rs +++ b/tests/sqlite/sqlite.rs @@ -2,12 +2,10 @@ use futures::TryStreamExt; use rand::{Rng, SeedableRng}; use rand_xoshiro::Xoshiro256PlusPlus; use sqlx::sqlite::{SqliteConnectOptions, SqliteOperation, SqlitePoolOptions}; -use sqlx::Decode; use sqlx::{ query, sqlite::Sqlite, sqlite::SqliteRow, Column, ConnectOptions, Connection, Executor, Row, SqliteConnection, SqlitePool, Statement, TypeInfo, }; -use sqlx::{Value, ValueRef}; use sqlx_test::new; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; @@ -271,7 +269,7 @@ async fn it_handles_empty_queries() -> anyhow::Result<()> { } #[sqlx_macros::test] -fn it_binds_parameters() -> anyhow::Result<()> { +async fn it_binds_parameters() -> anyhow::Result<()> { let mut conn = new::().await?; let v: i32 = sqlx::query_scalar("SELECT ?") @@ -293,7 +291,7 @@ fn it_binds_parameters() -> anyhow::Result<()> { } #[sqlx_macros::test] -fn it_binds_dollar_parameters() -> anyhow::Result<()> { +async fn it_binds_dollar_parameters() -> anyhow::Result<()> { let mut conn = new::().await?; let v: (i32, i32) = sqlx::query_as("SELECT $1, $2") @@ -973,6 +971,8 @@ async fn test_multiple_set_rollback_hook_calls_drop_old_handler() -> anyhow::Res #[cfg(feature = "sqlite-preupdate-hook")] #[sqlx_macros::test] async fn test_query_with_preupdate_hook_insert() -> anyhow::Result<()> { + use sqlx::Decode; + let mut conn = new::().await?; static CALLED: AtomicBool = AtomicBool::new(false); // Using this string as a canary to ensure the callback doesn't get called with the wrong data pointer. @@ -1021,6 +1021,8 @@ async fn test_query_with_preupdate_hook_insert() -> anyhow::Result<()> { #[cfg(feature = "sqlite-preupdate-hook")] #[sqlx_macros::test] async fn test_query_with_preupdate_hook_delete() -> anyhow::Result<()> { + use sqlx::Decode; + let mut conn = new::().await?; let _ = sqlx::query("INSERT INTO tweet ( id, text ) VALUES ( 5, 'Hello, World' )") .execute(&mut conn) @@ -1064,6 +1066,9 @@ async fn test_query_with_preupdate_hook_delete() -> anyhow::Result<()> { #[cfg(feature = "sqlite-preupdate-hook")] #[sqlx_macros::test] async fn test_query_with_preupdate_hook_update() -> anyhow::Result<()> { + use sqlx::Decode; + use sqlx::{Value, ValueRef}; + let mut conn = new::().await?; let _ = sqlx::query("INSERT INTO tweet ( id, text ) VALUES ( 6, 'Hello, World' )") .execute(&mut conn) @@ -1193,3 +1198,121 @@ async fn test_get_last_error() -> anyhow::Result<()> { Ok(()) } + +#[sqlx_macros::test] +async fn test_serialize_deserialize() -> anyhow::Result<()> { + let mut conn = SqliteConnection::connect("sqlite::memory:").await?; + + sqlx::raw_sql("create table foo(bar integer not null, baz text not null)") + .execute(&mut conn) + .await?; + + sqlx::query("insert into foo(bar, baz) values (1234, 'Lorem ipsum'), (5678, 'dolor sit amet')") + .execute(&mut conn) + .await?; + + let serialized = conn.serialize(None).await?; + + // Close and open a new connection to ensure cleanliness. + conn.close().await?; + let mut conn = SqliteConnection::connect("sqlite::memory:").await?; + + conn.deserialize(None, serialized, false).await?; + + let rows = sqlx::query_as::<_, (i32, String)>("select bar, baz from foo") + .fetch_all(&mut conn) + .await?; + + assert_eq!(rows.len(), 2); + + assert_eq!(rows[0].0, 1234); + assert_eq!(rows[0].1, "Lorem ipsum"); + + assert_eq!(rows[1].0, 5678); + assert_eq!(rows[1].1, "dolor sit amet"); + + Ok(()) +} +#[sqlx_macros::test] +async fn test_serialize_deserialize_with_schema() -> anyhow::Result<()> { + let mut conn = SqliteConnection::connect("sqlite::memory:").await?; + + sqlx::raw_sql( + "attach ':memory:' as foo; create table foo.foo(bar integer not null, baz text not null)", + ) + .execute(&mut conn) + .await?; + + sqlx::query( + "insert into foo.foo(bar, baz) values (1234, 'Lorem ipsum'), (5678, 'dolor sit amet')", + ) + .execute(&mut conn) + .await?; + + let serialized = conn.serialize(Some("foo")).await?; + + // Close and open a new connection to ensure cleanliness. + conn.close().await?; + let mut conn = SqliteConnection::connect("sqlite::memory:").await?; + + // Unexpected quirk: the schema must exist before deserialization. + sqlx::raw_sql("attach ':memory:' as foo") + .execute(&mut conn) + .await?; + + conn.deserialize(Some("foo"), serialized, false).await?; + + let rows = sqlx::query_as::<_, (i32, String)>("select bar, baz from foo.foo") + .fetch_all(&mut conn) + .await?; + + assert_eq!(rows.len(), 2); + + assert_eq!(rows[0].0, 1234); + assert_eq!(rows[0].1, "Lorem ipsum"); + + assert_eq!(rows[1].0, 5678); + assert_eq!(rows[1].1, "dolor sit amet"); + + Ok(()) +} + +#[sqlx_macros::test] +async fn test_serialize_nonexistent_schema() -> anyhow::Result<()> { + let mut conn = SqliteConnection::connect("sqlite::memory:").await?; + + let err = conn + .serialize(Some("foobar")) + .await + .expect_err("an error should have been returned"); + + let sqlx::Error::Database(dbe) = err else { + panic!("expected DatabaseError: {err:?}") + }; + + assert_eq!(dbe.code().as_deref(), Some("1")); + assert_eq!(dbe.message(), "database foobar does not exist"); + + Ok(()) +} + +#[sqlx_macros::test] +async fn test_serialize_invalid_schema() -> anyhow::Result<()> { + let mut conn = SqliteConnection::connect("sqlite::memory:").await?; + + let err = conn + .serialize(Some("foo\0bar")) + .await + .expect_err("an error should have been returned"); + + let sqlx::Error::InvalidArgument(msg) = err else { + panic!("expected InvalidArgument: {err:?}") + }; + + assert_eq!( + msg, + "schema name \"foo\\0bar\" contains a zero byte at index 3" + ); + + Ok(()) +}