diff --git a/sqlx-core/src/sqlite/arguments.rs b/sqlx-core/src/sqlite/arguments.rs index 8bddda43..f5811233 100644 --- a/sqlx-core/src/sqlite/arguments.rs +++ b/sqlx-core/src/sqlite/arguments.rs @@ -4,7 +4,7 @@ use std::os::raw::c_int; use libsqlite3_sys::{ sqlite3_bind_blob, sqlite3_bind_double, sqlite3_bind_int, sqlite3_bind_int64, - sqlite3_bind_null, sqlite3_bind_text, SQLITE_OK, + sqlite3_bind_null, sqlite3_bind_text, SQLITE_OK, SQLITE_TRANSIENT, }; use crate::arguments::Arguments; @@ -13,6 +13,7 @@ use crate::sqlite::statement::SqliteStatement; use crate::sqlite::Sqlite; use crate::sqlite::SqliteError; use crate::types::Type; +use core::mem; #[derive(Debug, Clone)] pub enum SqliteArgumentValue { @@ -33,7 +34,22 @@ pub enum SqliteArgumentValue { #[derive(Default)] pub struct SqliteArguments { - pub(super) values: Vec, + index: usize, + values: Vec, +} + +impl SqliteArguments { + pub(crate) fn next(&mut self) -> Option { + if self.index >= self.values.len() { + return None; + } + + let mut value = SqliteArgumentValue::Null; + mem::swap(&mut value, &mut self.values[self.index]); + + self.index += 1; + Some(value) + } } impl Arguments for SqliteArguments { @@ -66,7 +82,13 @@ impl SqliteArgumentValue { let bytes_len = bytes.len() as i32; unsafe { - sqlite3_bind_blob(statement.handle.as_ptr(), index, bytes_ptr, bytes_len, None) + sqlite3_bind_blob( + statement.handle.as_ptr(), + index, + bytes_ptr, + bytes_len, + SQLITE_TRANSIENT(), + ) } } @@ -77,7 +99,13 @@ impl SqliteArgumentValue { let bytes_len = bytes.len() as i32; unsafe { - sqlite3_bind_text(statement.handle.as_ptr(), index, bytes_ptr, bytes_len, None) + sqlite3_bind_text( + statement.handle.as_ptr(), + index, + bytes_ptr, + bytes_len, + SQLITE_TRANSIENT(), + ) } } diff --git a/sqlx-core/src/sqlite/connection.rs b/sqlx-core/src/sqlite/connection.rs index 6a918239..663aa6f5 100644 --- a/sqlx-core/src/sqlite/connection.rs +++ b/sqlx-core/src/sqlite/connection.rs @@ -12,13 +12,15 @@ use libsqlite3_sys::{ }; use crate::connection::{Connect, Connection}; -use crate::runtime::spawn_blocking; use crate::sqlite::statement::SqliteStatement; use crate::sqlite::SqliteError; use crate::url::Url; pub struct SqliteConnection { pub(super) handle: NonNull, + // Storage of the most recently prepared, non-persistent statement + pub(super) statement: Option, + // Storage of persistent statements pub(super) statements: Vec, pub(super) statement_by_query: HashMap, } @@ -66,6 +68,7 @@ fn establish(url: crate::Result) -> crate::Result { Ok(SqliteConnection { handle: NonNull::new(handle).unwrap(), + statement: None, statements: Vec::with_capacity(10), statement_by_query: HashMap::with_capacity(10), }) diff --git a/sqlx-core/src/sqlite/cursor.rs b/sqlx-core/src/sqlite/cursor.rs index 1b1951af..0f13c4ab 100644 --- a/sqlx-core/src/sqlite/cursor.rs +++ b/sqlx-core/src/sqlite/cursor.rs @@ -3,23 +3,15 @@ use futures_core::future::BoxFuture; use crate::connection::ConnectionSource; use crate::cursor::Cursor; use crate::executor::Execute; -use crate::maybe_owned::MaybeOwned; use crate::pool::Pool; -use crate::sqlite::statement::{SqliteStatement, Step}; +use crate::sqlite::statement::Step; use crate::sqlite::{Sqlite, SqliteArguments, SqliteConnection, SqliteRow}; -enum State<'q> { - Query((&'q str, Option)), - Statement { - query: &'q str, - arguments: Option, - statement: MaybeOwned, - }, -} - pub struct SqliteCursor<'c, 'q> { - source: ConnectionSource<'c, SqliteConnection>, - state: State<'q>, + pub(super) source: ConnectionSource<'c, SqliteConnection>, + query: &'q str, + arguments: Option, + pub(super) statement: Option>, } impl<'c, 'q> Cursor<'c, 'q> for SqliteCursor<'c, 'q> { @@ -30,9 +22,13 @@ impl<'c, 'q> Cursor<'c, 'q> for SqliteCursor<'c, 'q> { Self: Sized, E: Execute<'q, Sqlite>, { + let (query, arguments) = query.into_parts(); + Self { source: ConnectionSource::Pool(pool.clone()), - state: State::Query(query.into_parts()), + statement: None, + query, + arguments, } } @@ -41,9 +37,13 @@ impl<'c, 'q> Cursor<'c, 'q> for SqliteCursor<'c, 'q> { Self: Sized, E: Execute<'q, Sqlite>, { + let (query, arguments) = query.into_parts(); + Self { source: ConnectionSource::Connection(conn.into()), - state: State::Query(query.into_parts()), + statement: None, + query, + arguments, } } @@ -57,41 +57,38 @@ async fn next<'a, 'c: 'a, 'q: 'a>( ) -> crate::Result>> { let conn = cursor.source.resolve().await?; - let statement = loop { - match cursor.state { - State::Query((query, ref mut arguments)) => { - let mut statement = conn.prepare(query, arguments.is_some())?; - let statement_ = statement.resolve(&mut conn.statements); + loop { + if cursor.statement.is_none() { + let key = conn.prepare(&mut cursor.query, cursor.arguments.is_some())?; - if let Some(arguments) = arguments { - statement_.bind(arguments)?; - } - - cursor.state = State::Statement { - statement, - query, - arguments: arguments.take(), - }; + if let Some(arguments) = &mut cursor.arguments { + conn.statement_mut(key).bind(arguments)?; } - State::Statement { - ref mut statement, .. - } => { - break statement; + cursor.statement = Some(key); + } + + let key = cursor.statement.unwrap(); + let statement = conn.statement_mut(key); + + let step = statement.step().await?; + + match step { + Step::Row => { + return Ok(Some(SqliteRow { + statement: key, + connection: conn, + })); + } + + Step::Done if cursor.query.is_empty() => { + return Ok(None); + } + + Step::Done => { + cursor.statement = None; + // continue } } - }; - - let statement_ = statement.resolve(&mut conn.statements); - - match statement_.step().await? { - Step::Done => { - // TODO: If there is more to do, we need to do more - Ok(None) - } - - Step::Row => Ok(Some(SqliteRow { - statement: &*statement_, - })), } } diff --git a/sqlx-core/src/sqlite/database.rs b/sqlx-core/src/sqlite/database.rs index 16447a27..746c41c2 100644 --- a/sqlx-core/src/sqlite/database.rs +++ b/sqlx-core/src/sqlite/database.rs @@ -10,8 +10,7 @@ impl Database for Sqlite { type TypeInfo = super::SqliteTypeInfo; - // TODO? - type TableId = u32; + type TableId = String; type RawBuffer = Vec; } diff --git a/sqlx-core/src/sqlite/executor.rs b/sqlx-core/src/sqlite/executor.rs index 94da669f..5819f395 100644 --- a/sqlx-core/src/sqlite/executor.rs +++ b/sqlx-core/src/sqlite/executor.rs @@ -5,46 +5,67 @@ use libsqlite3_sys::sqlite3_changes; use crate::cursor::Cursor; use crate::describe::{Column, Describe}; use crate::executor::{Execute, Executor, RefExecutor}; -use crate::maybe_owned::MaybeOwned; use crate::sqlite::cursor::SqliteCursor; use crate::sqlite::statement::{SqliteStatement, Step}; use crate::sqlite::types::SqliteType; use crate::sqlite::{Sqlite, SqliteConnection, SqliteTypeInfo}; impl SqliteConnection { + pub(super) fn statement(&self, key: Option) -> &SqliteStatement { + match key { + Some(key) => &self.statements[key], + None => self.statement.as_ref().unwrap(), + } + } + + pub(super) fn statement_mut(&mut self, key: Option) -> &mut SqliteStatement { + match key { + Some(key) => &mut self.statements[key], + None => self.statement.as_mut().unwrap(), + } + } + pub(super) fn prepare( &mut self, - query: &str, + query: &mut &str, persistent: bool, - ) -> crate::Result> { + ) -> crate::Result> { // TODO: Revisit statement caching and allow cache expiration by using a // generational index if !persistent { - // A non-persistent query will be immediately prepared and returned - return SqliteStatement::new(&mut self.handle, query, false).map(MaybeOwned::Owned); + // A non-persistent query will be immediately prepared and returned, + // regardless of the current state of the cache + self.statement = Some(SqliteStatement::new(&mut self.handle, query, false)?); + return Ok(None); } - if let Some(key) = self.statement_by_query.get(query) { + if let Some(key) = self.statement_by_query.get(&**query) { let statement = &mut self.statements[*key]; + // Adjust the passed in query string as if [string3_prepare] + // did the tail parsing + *query = &query[statement.tail..]; + // As this statement has very likely been used before, we reset // it to clear the bindings and its program state statement.reset(); - return Ok(MaybeOwned::Borrowed(*key)); + return Ok(Some(*key)); } // Prepare a new statement object; ensuring to tell SQLite that this will be stored // for a "long" time and re-used multiple times + let query_key = query.to_owned(); + let statement = SqliteStatement::new(&mut self.handle, query, true)?; + let key = self.statements.len(); - self.statement_by_query.insert(query.to_owned(), key); - self.statements - .push(SqliteStatement::new(&mut self.handle, query, true)?); + self.statement_by_query.insert(query_key, key); + self.statements.push(statement); - Ok(MaybeOwned::Borrowed(key)) + Ok(Some(key)) } // This is used for [affected_rows] in the public API. @@ -72,15 +93,21 @@ impl Executor for SqliteConnection { let (mut query, mut arguments) = query.into_parts(); Box::pin(async move { - let mut statement = self.prepare(query, arguments.is_some())?; - let statement_ = statement.resolve(&mut self.statements); + loop { + let key = self.prepare(&mut query, arguments.is_some())?; + let statement = self.statement_mut(key); - if let Some(arguments) = &mut arguments { - statement_.bind(arguments)?; - } + if let Some(arguments) = &mut arguments { + statement.bind(arguments)?; + } - while let Step::Row = statement_.step().await? { - // We only care about the rows modified; ignore + while let Step::Row = statement.step().await? { + // We only care about the rows modified; ignore + } + + if query.is_empty() { + break; + } } Ok(self.changes()) @@ -102,9 +129,9 @@ impl Executor for SqliteConnection { E: Execute<'q, Self::Database>, { Box::pin(async move { - let (query, _) = query.into_parts(); - let mut statement = self.prepare(query, false)?; - let statement = statement.resolve(&mut self.statements); + let (mut query, _) = query.into_parts(); + let key = self.prepare(&mut query, false)?; + let statement = self.statement(key); // First let's attempt to describe what we can about parameter types // Which happens to just be the count, heh diff --git a/sqlx-core/src/sqlite/row.rs b/sqlx-core/src/sqlite/row.rs index 3537884f..fee867f2 100644 --- a/sqlx-core/src/sqlite/row.rs +++ b/sqlx-core/src/sqlite/row.rs @@ -5,10 +5,17 @@ use crate::database::HasRow; use crate::row::{ColumnIndex, Row}; use crate::sqlite::statement::SqliteStatement; use crate::sqlite::value::SqliteResultValue; -use crate::sqlite::Sqlite; +use crate::sqlite::{Sqlite, SqliteConnection}; pub struct SqliteRow<'c> { - pub(super) statement: &'c SqliteStatement, + pub(super) statement: Option, + pub(super) connection: &'c SqliteConnection, +} + +impl SqliteRow<'_> { + fn statement(&self) -> &SqliteStatement { + self.connection.statement(self.statement) + } } impl<'c> Row<'c> for SqliteRow<'c> { @@ -24,7 +31,7 @@ impl<'c> Row<'c> for SqliteRow<'c> { // sqlite3_step that returned SQLITE_ROW. #[allow(unsafe_code)] - let count: c_int = unsafe { sqlite3_data_count(self.statement.handle.as_ptr()) }; + let count: c_int = unsafe { sqlite3_data_count(self.statement().handle.as_ptr()) }; count as usize } @@ -36,6 +43,7 @@ impl<'c> Row<'c> for SqliteRow<'c> { let index = index.resolve(self)?; let value = SqliteResultValue { index, + connection: self.connection, statement: self.statement, }; @@ -57,7 +65,7 @@ impl ColumnIndex for usize { impl ColumnIndex for &'_ str { fn resolve(self, row: &::Row) -> crate::Result { - row.statement + row.statement() .columns() .get(self) .ok_or_else(|| crate::Error::ColumnNotFound((*self).into())) diff --git a/sqlx-core/src/sqlite/statement.rs b/sqlx-core/src/sqlite/statement.rs index bb6a0b10..9f2fe628 100644 --- a/sqlx-core/src/sqlite/statement.rs +++ b/sqlx-core/src/sqlite/statement.rs @@ -1,6 +1,6 @@ use core::cell::{RefCell, RefMut}; use core::ops::Deref; -use core::ptr::{null_mut, NonNull}; +use core::ptr::{null, null_mut, NonNull}; use std::collections::HashMap; use std::ffi::CStr; @@ -22,6 +22,7 @@ pub(crate) enum Step { } pub struct SqliteStatement { + pub(super) tail: usize, pub(super) handle: NonNull, columns: RefCell>>, } @@ -37,7 +38,7 @@ unsafe impl Sync for SqliteStatement {} impl SqliteStatement { pub(super) fn new( handle: &mut NonNull, - query: &str, + query: &mut &str, persistent: bool, ) -> crate::Result { // TODO: Error on queries that are too large @@ -45,6 +46,7 @@ impl SqliteStatement { let query_len = query.len() as i32; let mut statement_handle: *mut sqlite3_stmt = null_mut(); let mut flags = SQLITE_PREPARE_NO_VTAB; + let mut tail: *const i8 = null(); if persistent { // SQLITE_PREPARE_PERSISTENT @@ -63,10 +65,15 @@ impl SqliteStatement { query_len, flags as u32, &mut statement_handle, - null_mut(), + &mut tail, ) }; + // If pzTail is not NULL then *pzTail is made to point to the first byte + // past the end of the first SQL statement in zSql. + let tail = (tail as usize) - (query_ptr as usize); + *query = &query[tail..].trim(); + if status != SQLITE_OK { return Err(SqliteError::new(status).into()); } @@ -74,6 +81,7 @@ impl SqliteStatement { Ok(Self { handle: NonNull::new(statement_handle).unwrap(), columns: RefCell::new(None), + tail, }) } @@ -132,8 +140,12 @@ impl SqliteStatement { } pub(super) fn bind(&mut self, arguments: &mut SqliteArguments) -> crate::Result<()> { - for (index, value) in arguments.values.iter().enumerate() { - value.bind(self, index + 1)?; + for index in 0..self.params() { + if let Some(value) = arguments.next() { + value.bind(self, index + 1)?; + } else { + break; + } } Ok(()) diff --git a/sqlx-core/src/sqlite/value.rs b/sqlx-core/src/sqlite/value.rs index 44550832..8f1dacc2 100644 --- a/sqlx-core/src/sqlite/value.rs +++ b/sqlx-core/src/sqlite/value.rs @@ -8,11 +8,19 @@ use libsqlite3_sys::{ use crate::sqlite::statement::SqliteStatement; use crate::sqlite::types::SqliteType; +use crate::sqlite::SqliteConnection; use core::slice; pub struct SqliteResultValue<'c> { pub(super) index: usize, - pub(super) statement: &'c SqliteStatement, + pub(super) statement: Option, + pub(super) connection: &'c SqliteConnection, +} + +impl SqliteResultValue<'_> { + fn statement(&self) -> &SqliteStatement { + self.connection.statement(self.statement) + } } // https://www.sqlite.org/c3ref/column_blob.html @@ -24,7 +32,7 @@ impl<'c> SqliteResultValue<'c> { pub(crate) fn r#type(&self) -> SqliteType { #[allow(unsafe_code)] let type_code = - unsafe { sqlite3_column_type(self.statement.handle.as_ptr(), self.index as i32) }; + unsafe { sqlite3_column_type(self.statement().handle.as_ptr(), self.index as i32) }; match type_code { SQLITE_INTEGER => SqliteType::Integer, @@ -40,21 +48,21 @@ impl<'c> SqliteResultValue<'c> { pub(crate) fn int(&self) -> i32 { #[allow(unsafe_code)] unsafe { - sqlite3_column_int(self.statement.handle.as_ptr(), self.index as i32) + sqlite3_column_int(self.statement().handle.as_ptr(), self.index as i32) } } pub(crate) fn int64(&self) -> i64 { #[allow(unsafe_code)] unsafe { - sqlite3_column_int64(self.statement.handle.as_ptr(), self.index as i32) + sqlite3_column_int64(self.statement().handle.as_ptr(), self.index as i32) } } pub(crate) fn double(&self) -> f64 { #[allow(unsafe_code)] unsafe { - sqlite3_column_double(self.statement.handle.as_ptr(), self.index as i32) + sqlite3_column_double(self.statement().handle.as_ptr(), self.index as i32) } } @@ -62,7 +70,8 @@ impl<'c> SqliteResultValue<'c> { #[allow(unsafe_code)] let raw = unsafe { CStr::from_ptr( - sqlite3_column_text(self.statement.handle.as_ptr(), self.index as i32) as *const i8, + sqlite3_column_text(self.statement().handle.as_ptr(), self.index as i32) + as *const i8, ) }; @@ -73,10 +82,10 @@ impl<'c> SqliteResultValue<'c> { let index = self.index as i32; #[allow(unsafe_code)] - let ptr = unsafe { sqlite3_column_blob(self.statement.handle.as_ptr(), index) }; + let ptr = unsafe { sqlite3_column_blob(self.statement().handle.as_ptr(), index) }; #[allow(unsafe_code)] - let len = unsafe { sqlite3_column_bytes(self.statement.handle.as_ptr(), index) }; + let len = unsafe { sqlite3_column_bytes(self.statement().handle.as_ptr(), index) }; #[allow(unsafe_code)] let raw = unsafe { slice::from_raw_parts(ptr as *const u8, len as usize) }; diff --git a/tests/sqlite-raw.rs b/tests/sqlite-raw.rs index be0b0545..04c3e107 100644 --- a/tests/sqlite-raw.rs +++ b/tests/sqlite-raw.rs @@ -2,3 +2,51 @@ use sqlx::{Cursor, Executor, Row, Sqlite}; use sqlx_test::new; + +#[cfg_attr(feature = "runtime-async-std", async_std::test)] +#[cfg_attr(feature = "runtime-tokio", tokio::test)] +async fn test_select_expression() -> anyhow::Result<()> { + let mut conn = new::().await?; + + let mut cursor = conn.fetch("SELECT 5"); + let row = cursor.next().await?.unwrap(); + + assert!(5i32 == row.try_get::(0)?); + + Ok(()) +} + +#[cfg_attr(feature = "runtime-async-std", async_std::test)] +#[cfg_attr(feature = "runtime-tokio", tokio::test)] +async fn test_multi_read_write() -> anyhow::Result<()> { + let mut conn = new::().await?; + + let mut cursor = conn.fetch( + " +CREATE TABLE IF NOT EXISTS _sqlx_test ( + id INT PRIMARY KEY, + text TEXT NOT NULL +); + +SELECT 'Hello World' as _1; + +INSERT INTO _sqlx_test (text) VALUES ('this is a test'); + +SELECT id, text FROM _sqlx_test; + ", + ); + + let row = cursor.next().await?.unwrap(); + + assert!("Hello World" == row.try_get::<&str, _>("_1")?); + + let row = cursor.next().await?.unwrap(); + + let id: i64 = row.try_get("id")?; + let text: &str = row.try_get("text")?; + + assert_eq!(0, id); + assert_eq!("this is a test", text); + + Ok(()) +} diff --git a/tests/sqlite.rs b/tests/sqlite.rs index de1d8ccb..af4b1e03 100644 --- a/tests/sqlite.rs +++ b/tests/sqlite.rs @@ -54,6 +54,40 @@ CREATE TEMPORARY TABLE users (id INTEGER PRIMARY KEY) Ok(()) } +#[cfg_attr(feature = "runtime-async-std", async_std::test)] +#[cfg_attr(feature = "runtime-tokio", tokio::test)] +async fn it_can_execute_multiple_statements() -> anyhow::Result<()> { + let mut conn = new::().await?; + + let affected = conn + .execute( + r#" +CREATE TEMPORARY TABLE users (id INTEGER PRIMARY KEY, other INTEGER); +INSERT INTO users DEFAULT VALUES; + "#, + ) + .await?; + + assert_eq!(affected, 1); + + for index in 2..5_i32 { + let (id, other): (i32, i32) = sqlx::query_as( + r#" +INSERT INTO users (other) VALUES (?); +SELECT id, other FROM users WHERE id = last_insert_rowid(); + "#, + ) + .bind(index) + .fetch_one(&mut conn) + .await?; + + assert_eq!(id, index); + assert_eq!(other, index); + } + + Ok(()) +} + #[cfg_attr(feature = "runtime-async-std", async_std::test)] #[cfg_attr(feature = "runtime-tokio", tokio::test)] async fn it_describes() -> anyhow::Result<()> {