use crate::connection::handle::ConnectionHandle; use crate::connection::LogSettings; use crate::connection::{ConnectionState, Statements}; use crate::error::Error; use crate::{SqliteConnectOptions, SqliteError}; use libsqlite3_sys::{ sqlite3, sqlite3_busy_timeout, sqlite3_db_config, sqlite3_extended_result_codes, sqlite3_free, sqlite3_load_extension, sqlite3_open_v2, SQLITE_DBCONFIG_ENABLE_LOAD_EXTENSION, SQLITE_OK, SQLITE_OPEN_CREATE, SQLITE_OPEN_FULLMUTEX, SQLITE_OPEN_MEMORY, SQLITE_OPEN_NOMUTEX, SQLITE_OPEN_PRIVATECACHE, SQLITE_OPEN_READONLY, SQLITE_OPEN_READWRITE, SQLITE_OPEN_SHAREDCACHE, }; use percent_encoding::NON_ALPHANUMERIC; use sqlx_core::IndexMap; use std::collections::BTreeMap; use std::ffi::{c_void, CStr, CString}; use std::io; use std::os::raw::c_int; use std::ptr::{addr_of_mut, null, null_mut}; use std::sync::atomic::{AtomicUsize, Ordering}; use std::time::Duration; // This was originally `AtomicU64` but that's not supported on MIPS (or PowerPC): // https://github.com/launchbadge/sqlx/issues/2859 // https://doc.rust-lang.org/stable/std/sync/atomic/index.html#portability static THREAD_ID: AtomicUsize = AtomicUsize::new(0); enum SqliteLoadExtensionMode { /// Enables only the C-API, leaving the SQL function disabled. Enable, /// Disables both the C-API and the SQL function. DisableAll, } impl SqliteLoadExtensionMode { fn as_int(self) -> c_int { match self { SqliteLoadExtensionMode::Enable => 1, SqliteLoadExtensionMode::DisableAll => 0, } } } pub struct EstablishParams { filename: CString, open_flags: i32, busy_timeout: Duration, statement_cache_capacity: usize, log_settings: LogSettings, extensions: IndexMap>, pub(crate) thread_name: String, pub(crate) command_channel_size: usize, #[cfg(feature = "regexp")] register_regexp_function: bool, } impl EstablishParams { pub fn from_options(options: &SqliteConnectOptions) -> Result { let mut filename = options .filename .to_str() .ok_or_else(|| { io::Error::new( io::ErrorKind::InvalidData, "filename passed to SQLite must be valid UTF-8", ) })? .to_owned(); // By default, we connect to an in-memory database. // [SQLITE_OPEN_NOMUTEX] will instruct [sqlite3_open_v2] to return an error if it // cannot satisfy our wish for a thread-safe, lock-free connection object let mut flags = if options.serialized { SQLITE_OPEN_FULLMUTEX } else { SQLITE_OPEN_NOMUTEX }; flags |= if options.read_only { SQLITE_OPEN_READONLY } else if options.create_if_missing { SQLITE_OPEN_CREATE | SQLITE_OPEN_READWRITE } else { SQLITE_OPEN_READWRITE }; if options.in_memory { flags |= SQLITE_OPEN_MEMORY; } flags |= if options.shared_cache { SQLITE_OPEN_SHAREDCACHE } else { SQLITE_OPEN_PRIVATECACHE }; let mut query_params = BTreeMap::new(); if options.immutable { query_params.insert("immutable", "true"); } if let Some(vfs) = options.vfs.as_deref() { query_params.insert("vfs", &vfs); } if !query_params.is_empty() { filename = format!( "file:{}?{}", percent_encoding::percent_encode(filename.as_bytes(), &NON_ALPHANUMERIC), serde_urlencoded::to_string(&query_params).unwrap() ); flags |= libsqlite3_sys::SQLITE_OPEN_URI; } let filename = CString::new(filename).map_err(|_| { io::Error::new( io::ErrorKind::InvalidData, "filename passed to SQLite must not contain nul bytes", ) })?; let extensions = options .extensions .iter() .map(|(name, entry)| { let entry = entry .as_ref() .map(|e| { CString::new(e.as_bytes()).map_err(|_| { io::Error::new( io::ErrorKind::InvalidData, "extension entrypoint names passed to SQLite must not contain nul bytes" ) }) }) .transpose()?; Ok(( CString::new(name.as_bytes()).map_err(|_| { io::Error::new( io::ErrorKind::InvalidData, "extension names passed to SQLite must not contain nul bytes", ) })?, entry, )) }) .collect::>, io::Error>>()?; let thread_id = THREAD_ID.fetch_add(1, Ordering::AcqRel); Ok(Self { filename, open_flags: flags, busy_timeout: options.busy_timeout, statement_cache_capacity: options.statement_cache_capacity, log_settings: options.log_settings.clone(), extensions, thread_name: (options.thread_name)(thread_id as u64), command_channel_size: options.command_channel_size, #[cfg(feature = "regexp")] register_regexp_function: options.register_regexp_function, }) } // Enable extension loading via the db_config function, as recommended by the docs rather // than the more obvious `sqlite3_enable_load_extension` // https://www.sqlite.org/c3ref/db_config.html // https://www.sqlite.org/c3ref/c_dbconfig_defensive.html#sqlitedbconfigenableloadextension unsafe fn sqlite3_set_load_extension( db: *mut sqlite3, mode: SqliteLoadExtensionMode, ) -> Result<(), Error> { let status = sqlite3_db_config( db, SQLITE_DBCONFIG_ENABLE_LOAD_EXTENSION, mode.as_int(), null::(), ); if status != SQLITE_OK { return Err(Error::Database(Box::new(SqliteError::new(db)))); } Ok(()) } pub(crate) fn establish(&self) -> Result { let mut handle = null_mut(); // let mut status = unsafe { sqlite3_open_v2(self.filename.as_ptr(), &mut handle, self.open_flags, null()) }; if handle.is_null() { // Failed to allocate memory return Err(Error::Io(io::Error::new( io::ErrorKind::OutOfMemory, "SQLite is unable to allocate memory to hold the sqlite3 object", ))); } // SAFE: tested for NULL just above // This allows any returns below to close this handle with RAII let handle = unsafe { ConnectionHandle::new(handle) }; if status != SQLITE_OK { return Err(Error::Database(Box::new(SqliteError::new(handle.as_ptr())))); } // Enable extended result codes // https://www.sqlite.org/c3ref/extended_result_codes.html unsafe { // NOTE: ignore the failure here sqlite3_extended_result_codes(handle.as_ptr(), 1); } if !self.extensions.is_empty() { // Enable loading extensions unsafe { Self::sqlite3_set_load_extension(handle.as_ptr(), SqliteLoadExtensionMode::Enable)?; } for ext in self.extensions.iter() { // `sqlite3_load_extension` is unusual as it returns its errors via an out-pointer // rather than by calling `sqlite3_errmsg` let mut error = null_mut(); status = unsafe { sqlite3_load_extension( handle.as_ptr(), ext.0.as_ptr(), ext.1.as_ref().map_or(null(), |e| e.as_ptr()), addr_of_mut!(error), ) }; if status != SQLITE_OK { // SAFETY: We become responsible for any memory allocation at `&error`, so test // for null and take an RAII version for returns let err_msg = if !error.is_null() { unsafe { let e = CStr::from_ptr(error).into(); sqlite3_free(error as *mut c_void); e } } else { CString::new("Unknown error when loading extension") .expect("text should be representable as a CString") }; return Err(Error::Database(Box::new(SqliteError::extension( handle.as_ptr(), &err_msg, )))); } } // Preempt any hypothetical security issues arising from leaving ENABLE_LOAD_EXTENSION // on by disabling the flag again once we've loaded all the requested modules. // Fail-fast (via `?`) if disabling the extension loader didn't work for some reason, // avoids an unexpected state going undetected. unsafe { Self::sqlite3_set_load_extension( handle.as_ptr(), SqliteLoadExtensionMode::DisableAll, )?; } } #[cfg(feature = "regexp")] if self.register_regexp_function { // configure a `regexp` function for sqlite, it does not come with one by default let status = crate::regexp::register(handle.as_ptr()); if status != SQLITE_OK { return Err(Error::Database(Box::new(SqliteError::new(handle.as_ptr())))); } } // Configure a busy timeout // This causes SQLite to automatically sleep in increasing intervals until the time // when there is something locked during [sqlite3_step]. // // We also need to convert the u128 value to i32, checking we're not overflowing. let ms = i32::try_from(self.busy_timeout.as_millis()) .expect("Given busy timeout value is too big."); status = unsafe { sqlite3_busy_timeout(handle.as_ptr(), ms) }; if status != SQLITE_OK { return Err(Error::Database(Box::new(SqliteError::new(handle.as_ptr())))); } Ok(ConnectionState { handle, statements: Statements::new(self.statement_cache_capacity), transaction_depth: 0, log_settings: self.log_settings.clone(), progress_handler_callback: None, }) } }