From cd6735b5d787aa0daf51d62c61e0985e7b0123cd Mon Sep 17 00:00:00 2001 From: Ryan Leckey Date: Wed, 8 Apr 2020 02:13:37 -0700 Subject: [PATCH] fix(sqlite): handle empty statements, fixes #231 --- sqlx-core/src/sqlite/arguments.rs | 36 ++++----- sqlx-core/src/sqlite/statement.rs | 125 +++++++++++++++++++----------- sqlx-core/src/sqlite/value.rs | 57 ++++++++++---- 3 files changed, 138 insertions(+), 80 deletions(-) diff --git a/sqlx-core/src/sqlite/arguments.rs b/sqlx-core/src/sqlite/arguments.rs index eb0a75d4..b2b6bbba 100644 --- a/sqlx-core/src/sqlite/arguments.rs +++ b/sqlx-core/src/sqlite/arguments.rs @@ -71,6 +71,16 @@ impl Arguments for SqliteArguments { impl SqliteArgumentValue { pub(super) fn bind(&self, statement: &mut Statement, index: usize) -> crate::Result<()> { + let handle = unsafe { + if let Some(handle) = statement.handle() { + handle + } else { + // drop all requested bindings for a null/empty statement + // note that this _should_ not happen as argument size for a null statement should be zero + return Ok(()); + } + }; + // TODO: Handle error of trying to bind too many parameters here let index = index as c_int; @@ -83,13 +93,7 @@ impl SqliteArgumentValue { let bytes_len = bytes.len() as i32; unsafe { - sqlite3_bind_blob( - statement.handle(), - index, - bytes_ptr, - bytes_len, - SQLITE_TRANSIENT(), - ) + sqlite3_bind_blob(handle, index, bytes_ptr, bytes_len, SQLITE_TRANSIENT()) } } @@ -100,29 +104,21 @@ impl SqliteArgumentValue { let bytes_len = bytes.len() as i32; unsafe { - sqlite3_bind_text( - statement.handle(), - index, - bytes_ptr, - bytes_len, - SQLITE_TRANSIENT(), - ) + sqlite3_bind_text(handle, index, bytes_ptr, bytes_len, SQLITE_TRANSIENT()) } } SqliteArgumentValue::Double(value) => unsafe { - sqlite3_bind_double(statement.handle(), index, *value) + sqlite3_bind_double(handle, index, *value) }, - SqliteArgumentValue::Int(value) => unsafe { - sqlite3_bind_int(statement.handle(), index, *value) - }, + SqliteArgumentValue::Int(value) => unsafe { sqlite3_bind_int(handle, index, *value) }, SqliteArgumentValue::Int64(value) => unsafe { - sqlite3_bind_int64(statement.handle(), index, *value) + sqlite3_bind_int64(handle, index, *value) }, - SqliteArgumentValue::Null => unsafe { sqlite3_bind_null(statement.handle(), index) }, + SqliteArgumentValue::Null => unsafe { sqlite3_bind_null(handle, index) }, }; if status != SQLITE_OK { diff --git a/sqlx-core/src/sqlite/statement.rs b/sqlx-core/src/sqlite/statement.rs index d33383de..37b1b69a 100644 --- a/sqlx-core/src/sqlite/statement.rs +++ b/sqlx-core/src/sqlite/statement.rs @@ -37,7 +37,7 @@ pub(super) struct SqliteStatementHandle(NonNull); /// /// The statement is finalized ( `sqlite3_finalize` ) on drop. pub(super) struct Statement { - handle: SqliteStatementHandle, + handle: Option, pub(super) connection: SqliteConnectionHandle, pub(super) worker: Worker, pub(super) tail: usize, @@ -94,7 +94,7 @@ impl Statement { let mut self_ = Self { worker: conn.worker.clone(), connection: conn.handle, - handle: SqliteStatementHandle(NonNull::new(statement_handle).unwrap()), + handle: NonNull::new(statement_handle).map(SqliteStatementHandle), columns: HashMap::new(), tail, }; @@ -113,8 +113,8 @@ impl Statement { /// Returns a pointer to the raw C pointer backing this statement. #[inline] - pub(super) unsafe fn handle(&self) -> *mut sqlite3_stmt { - self.handle.0.as_ptr() + pub(super) unsafe fn handle(&self) -> Option<*mut sqlite3_stmt> { + self.handle.map(|handle| handle.0.as_ptr()) } pub(super) fn data_count(&mut self) -> usize { @@ -126,43 +126,59 @@ impl Statement { // The value is correct only if there was a recent call to // sqlite3_step that returned SQLITE_ROW. - let count: c_int = unsafe { sqlite3_data_count(self.handle()) }; - count as usize + unsafe { self.handle().map_or(0, |handle| sqlite3_data_count(handle)) as usize } } pub(super) fn column_count(&mut self) -> usize { // https://sqlite.org/c3ref/column_count.html - let count = unsafe { sqlite3_column_count(self.handle()) }; - count as usize + unsafe { + self.handle() + .map_or(0, |handle| sqlite3_column_count(handle)) as usize + } } pub(super) fn column_name(&mut self, index: usize) -> &str { - // https://sqlite.org/c3ref/column_name.html - let name = unsafe { - let ptr = sqlite3_column_name(self.handle(), index as c_int); - debug_assert!(!ptr.is_null()); + unsafe { + self.handle() + .map(|handle| { + // https://sqlite.org/c3ref/column_name.html + let ptr = sqlite3_column_name(handle, index as c_int); + debug_assert!(!ptr.is_null()); - CStr::from_ptr(ptr) - }; - - name.to_str().unwrap() + CStr::from_ptr(ptr) + }) + .map_or(Ok(""), |name| name.to_str()) + .unwrap() + } } pub(super) fn column_decltype(&mut self, index: usize) -> Option<&str> { - let name = unsafe { - let ptr = sqlite3_column_decltype(self.handle(), index as c_int); + unsafe { + self.handle() + .and_then(|handle| { + let ptr = sqlite3_column_decltype(handle, index as c_int); - if ptr.is_null() { - None - } else { - Some(CStr::from_ptr(ptr)) - } - }; - - name.map(|s| s.to_str().unwrap()) + if ptr.is_null() { + None + } else { + Some(CStr::from_ptr(ptr)) + } + }) + .map(|name| name.to_str().unwrap()) + } } pub(super) fn column_not_null(&mut self, index: usize) -> crate::Result> { + let handle = unsafe { + if let Some(handle) = self.handle() { + handle + } else { + // we do not know the nullablility of a column that doesn't exist on a statement + // that doesn't exist + return Ok(None); + } + }; + unsafe { // https://sqlite.org/c3ref/column_database_name.html // @@ -171,9 +187,9 @@ impl Statement { // sqlite3_finalize() or until the statement is automatically reprepared by the // first call to sqlite3_step() for a particular run or until the same information // is requested again in a different encoding. - let db_name = sqlite3_column_database_name(self.handle(), index as c_int); - let table_name = sqlite3_column_table_name(self.handle(), index as c_int); - let origin_name = sqlite3_column_origin_name(self.handle(), index as c_int); + let db_name = sqlite3_column_database_name(handle, index as c_int); + let table_name = sqlite3_column_table_name(handle, index as c_int); + let origin_name = sqlite3_column_origin_name(handle, index as c_int); if db_name.is_null() || table_name.is_null() || origin_name.is_null() { return Ok(None); @@ -213,8 +229,10 @@ impl Statement { pub(super) fn params(&mut self) -> usize { // https://www.hwaci.com/sw/sqlite/c3ref/bind_parameter_count.html - let num = unsafe { sqlite3_bind_parameter_count(self.handle()) }; - num as usize + unsafe { + self.handle() + .map_or(0, |handle| sqlite3_bind_parameter_count(handle)) as usize + } } pub(super) fn bind(&mut self, arguments: &mut SqliteArguments) -> crate::Result<()> { @@ -230,36 +248,47 @@ impl Statement { } pub(super) fn reset(&mut self) { + let handle = unsafe { + if let Some(handle) = self.handle() { + handle + } else { + // nothing to reset if its null + return; + } + }; + // https://sqlite.org/c3ref/reset.html // https://sqlite.org/c3ref/clear_bindings.html // the status value of reset is ignored because it merely propagates // the status of the most recently invoked step function - let _ = unsafe { sqlite3_reset(self.handle()) }; - - let _ = unsafe { sqlite3_clear_bindings(self.handle()) }; + let _ = unsafe { sqlite3_reset(handle) }; + let _ = unsafe { sqlite3_clear_bindings(handle) }; } pub(super) async fn step(&mut self) -> crate::Result { // https://sqlite.org/c3ref/step.html - let handle = self.handle; + if let Some(handle) = self.handle { + let status = unsafe { + self.worker + .run(move || sqlite3_step(handle.0.as_ptr())) + .await + }; - let status = unsafe { - self.worker - .run(move || sqlite3_step(handle.0.as_ptr())) - .await - }; + match status { + SQLITE_DONE => Ok(Step::Done), - match status { - SQLITE_DONE => Ok(Step::Done), + SQLITE_ROW => Ok(Step::Row), - SQLITE_ROW => Ok(Step::Row), - - _ => { - return Err(SqliteError::from_connection(self.connection.0.as_ptr()).into()); + _ => { + return Err(SqliteError::from_connection(self.connection.0.as_ptr()).into()); + } } + } else { + // An empty (null) query will always emit `Step::Done` + Ok(Step::Done) } } } @@ -268,7 +297,9 @@ impl Drop for Statement { fn drop(&mut self) { // https://sqlite.org/c3ref/finalize.html unsafe { - let _ = sqlite3_finalize(self.handle()); + if let Some(handle) = self.handle() { + let _ = sqlite3_finalize(handle); + } } } } diff --git a/sqlx-core/src/sqlite/value.rs b/sqlx-core/src/sqlite/value.rs index 4086d5fd..961127f1 100644 --- a/sqlx-core/src/sqlite/value.rs +++ b/sqlx-core/src/sqlite/value.rs @@ -31,7 +31,14 @@ impl<'c> SqliteValue<'c> { } fn r#type(&self) -> Option { - let type_code = unsafe { sqlite3_column_type(self.statement.handle(), self.index) }; + let type_code = unsafe { + if let Some(handle) = self.statement.handle() { + sqlite3_column_type(handle, self.index) + } else { + // unreachable: null statements do not have any values to type + return None; + } + }; // SQLITE_INTEGER, SQLITE_FLOAT, SQLITE_TEXT, SQLITE_BLOB, or SQLITE_NULL match type_code { @@ -47,41 +54,65 @@ impl<'c> SqliteValue<'c> { /// Returns the 32-bit INTEGER result. pub(super) fn int(&self) -> i32 { - unsafe { sqlite3_column_int(self.statement.handle(), self.index) } + unsafe { + self.statement + .handle() + .map_or(0, |handle| sqlite3_column_int(handle, self.index)) + } } /// Returns the 64-bit INTEGER result. pub(super) fn int64(&self) -> i64 { - unsafe { sqlite3_column_int64(self.statement.handle(), self.index) } + unsafe { + self.statement + .handle() + .map_or(0, |handle| sqlite3_column_int64(handle, self.index)) + } } /// Returns the 64-bit, REAL result. pub(super) fn double(&self) -> f64 { - unsafe { sqlite3_column_double(self.statement.handle(), self.index) } + unsafe { + self.statement + .handle() + .map_or(0.0, |handle| sqlite3_column_double(handle, self.index)) + } } /// Returns the UTF-8 TEXT result. pub(super) fn text(&self) -> Option<&'c str> { unsafe { - let ptr = sqlite3_column_text(self.statement.handle(), self.index); + self.statement.handle().and_then(|handle| { + let ptr = sqlite3_column_text(handle, self.index); - if ptr.is_null() { - None - } else { - Some(from_utf8_unchecked(CStr::from_ptr(ptr as _).to_bytes())) - } + if ptr.is_null() { + None + } else { + Some(from_utf8_unchecked(CStr::from_ptr(ptr as _).to_bytes())) + } + }) } } fn bytes(&self) -> usize { // Returns the size of the result in bytes. - let len = unsafe { sqlite3_column_bytes(self.statement.handle(), self.index) }; - len as usize + unsafe { + self.statement + .handle() + .map_or(0, |handle| sqlite3_column_bytes(handle, self.index)) as usize + } } /// Returns the BLOB result. pub(super) fn blob(&self) -> &'c [u8] { - let ptr = unsafe { sqlite3_column_blob(self.statement.handle(), self.index) }; + let ptr = unsafe { + if let Some(handle) = self.statement.handle() { + sqlite3_column_blob(handle, self.index) + } else { + // Null statements do not exist + return &[]; + } + }; if ptr.is_null() { // Empty BLOBs are received as null pointers