From d112c4d807749b99473195d6cd0d839db00bbad9 Mon Sep 17 00:00:00 2001 From: Ryan Leckey Date: Sat, 4 Jul 2020 02:56:02 -0700 Subject: [PATCH] feat(sqlite): support expressions and multiple no-data statements in the macros --- sqlx-core/src/lib.rs | 1 + sqlx-core/src/sqlite/connection/describe.rs | 113 +++++++++++++++ sqlx-core/src/sqlite/connection/executor.rs | 35 +---- sqlx-core/src/sqlite/connection/explain.rs | 153 ++++++++++++++++++++ sqlx-core/src/sqlite/connection/mod.rs | 2 + sqlx-core/src/sqlite/statement/handle.rs | 23 ++- sqlx-core/src/sqlite/type_info.rs | 2 +- tests/sqlite/describe.rs | 99 ++++++++++++- 8 files changed, 390 insertions(+), 38 deletions(-) create mode 100644 sqlx-core/src/sqlite/connection/describe.rs diff --git a/sqlx-core/src/lib.rs b/sqlx-core/src/lib.rs index efbb69e45..a5193fd8b 100644 --- a/sqlx-core/src/lib.rs +++ b/sqlx-core/src/lib.rs @@ -46,6 +46,7 @@ pub mod types; #[macro_use] pub mod query; +mod column; mod common; pub mod database; pub mod describe; diff --git a/sqlx-core/src/sqlite/connection/describe.rs b/sqlx-core/src/sqlite/connection/describe.rs new file mode 100644 index 000000000..9a0aca87b --- /dev/null +++ b/sqlx-core/src/sqlite/connection/describe.rs @@ -0,0 +1,113 @@ +use crate::describe::{Column, Describe}; +use crate::error::Error; +use crate::sqlite::connection::explain::explain; +use crate::sqlite::statement::SqliteStatement; +use crate::sqlite::type_info::DataType; +use crate::sqlite::{Sqlite, SqliteConnection, SqliteTypeInfo}; +use futures_core::future::BoxFuture; + +pub(super) async fn describe( + conn: &mut SqliteConnection, + query: &str, +) -> Result, Error> { + describe_with(conn, query, vec![]).await +} + +pub(super) fn describe_with<'c: 'e, 'q: 'e, 'e>( + conn: &'c mut SqliteConnection, + query: &'q str, + fallback: Vec, +) -> BoxFuture<'e, Result, Error>> { + Box::pin(async move { + // describing a statement from SQLite can be involved + // each SQLx statement is comprised of multiple SQL statements + + let SqliteConnection { + ref mut handle, + ref worker, + .. + } = conn; + + let statement = SqliteStatement::prepare(handle, query, false); + + let mut columns = Vec::new(); + let mut num_params = 0; + + let mut statement = statement?; + + // we start by finding the first statement that *can* return results + while let Some((statement, _)) = statement.execute()? { + num_params += statement.bind_parameter_count(); + + let mut stepped = false; + + let num = statement.column_count(); + if num == 0 { + // no columns in this statement; skip + continue; + } + + // next we try to use [column_decltype] to inspect the type of each column + columns.reserve(num); + + for col in 0..num { + let name = statement.column_name(col).to_owned(); + + let type_info = if let Some(ty) = statement.column_decltype(col) { + ty + } else { + // if that fails, we back up and attempt to step the statement + // once *if* its read-only and then use [column_type] as a + // fallback to [column_decltype] + if !stepped && statement.read_only() && fallback.is_empty() { + stepped = true; + + worker.execute(statement); + worker.wake(); + + let _ = worker.step(statement).await?; + } + + let mut ty = statement.column_type_info(col); + + if ty.0 == DataType::Null { + if fallback.is_empty() { + // this will _still_ fail if there are no actual rows to return + // this happens more often than not for the macros as we tell + // users to execute against an empty database + + // as a last resort, we explain the original query and attempt to + // infer what would the expression types be as a fallback + // to [column_decltype] + + let fallback = explain(conn, statement.sql()).await?; + + return describe_with(conn, query, fallback).await; + } + + if let Some(fallback) = fallback.get(col).cloned() { + ty = fallback; + } + } + + ty + }; + + let not_null = statement.column_not_null(col)?; + + columns.push(Column { + name, + type_info: Some(type_info), + not_null, + }); + } + } + + // println!("describe ->> {:#?}", columns); + + Ok(Describe { + columns, + params: vec![None; num_params], + }) + }) +} diff --git a/sqlx-core/src/sqlite/connection/executor.rs b/sqlx-core/src/sqlite/connection/executor.rs index aa454be8a..436925461 100644 --- a/sqlx-core/src/sqlite/connection/executor.rs +++ b/sqlx-core/src/sqlite/connection/executor.rs @@ -3,14 +3,15 @@ use std::sync::Arc; use either::Either; use futures_core::future::BoxFuture; use futures_core::stream::BoxStream; -use futures_util::TryStreamExt; +use futures_util::{FutureExt, TryStreamExt}; use hashbrown::HashMap; use crate::common::StatementCache; -use crate::describe::{Column, Describe}; +use crate::describe::Describe; use crate::error::Error; use crate::executor::{Execute, Executor}; use crate::ext::ustr::UStr; +use crate::sqlite::connection::describe::describe; use crate::sqlite::connection::ConnectionHandle; use crate::sqlite::statement::{SqliteStatement, StatementHandle}; use crate::sqlite::{Sqlite, SqliteArguments, SqliteConnection, SqliteRow}; @@ -176,34 +177,6 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection { 'c: 'e, E: Execute<'q, Self::Database>, { - let query = query.query(); - let statement = SqliteStatement::prepare(&mut self.handle, query, false); - - Box::pin(async move { - let mut params = Vec::new(); - let mut columns = Vec::new(); - - if let Some(statement) = statement?.handles.get(0) { - // NOTE: we can infer *nothing* about parameters apart from the count - params.resize(statement.bind_parameter_count(), None); - - let num_columns = statement.column_count(); - columns.reserve(num_columns); - - for i in 0..num_columns { - let name = statement.column_name(i).to_owned(); - let type_info = statement.column_decltype(i); - let not_null = statement.column_not_null(i)?; - - columns.push(Column { - name, - type_info, - not_null, - }) - } - } - - Ok(Describe { params, columns }) - }) + describe(self, query.query()).boxed() } } diff --git a/sqlx-core/src/sqlite/connection/explain.rs b/sqlx-core/src/sqlite/connection/explain.rs index e69de29bb..20a5c1708 100644 --- a/sqlx-core/src/sqlite/connection/explain.rs +++ b/sqlx-core/src/sqlite/connection/explain.rs @@ -0,0 +1,153 @@ +use crate::error::Error; +use crate::query_as::query_as; +use crate::sqlite::type_info::DataType; +use crate::sqlite::{SqliteConnection, SqliteTypeInfo}; +use hashbrown::HashMap; + +const OP_INIT: &str = "Init"; +const OP_GOTO: &str = "Goto"; +const OP_COLUMN: &str = "Column"; +const OP_AGG_STEP: &str = "AggStep"; +const OP_MOVE: &str = "Move"; +const OP_COPY: &str = "Copy"; +const OP_SCOPY: &str = "SCopy"; +const OP_INT_COPY: &str = "IntCopy"; +const OP_STRING8: &str = "String8"; +const OP_INT64: &str = "Int64"; +const OP_INTEGER: &str = "Integer"; +const OP_REAL: &str = "Real"; +const OP_NOT: &str = "Not"; +const OP_BLOB: &str = "Blob"; +const OP_COUNT: &str = "Count"; +const OP_ROWID: &str = "Rowid"; +const OP_OR: &str = "Or"; +const OP_AND: &str = "And"; +const OP_BIT_AND: &str = "BitAnd"; +const OP_BIT_OR: &str = "BitOr"; +const OP_SHIFT_LEFT: &str = "ShiftLeft"; +const OP_SHIFT_RIGHT: &str = "ShiftRight"; +const OP_ADD: &str = "Add"; +const OP_SUBTRACT: &str = "Subtract"; +const OP_MULTIPLY: &str = "Multiply"; +const OP_DIVIDE: &str = "Divide"; +const OP_REMAINDER: &str = "Remainder"; +const OP_CONCAT: &str = "Concat"; +const OP_RESULT_ROW: &str = "ResultRow"; + +fn to_type(op: &str) -> DataType { + match op { + OP_REAL => DataType::Float, + OP_BLOB => DataType::Blob, + OP_AND | OP_OR => DataType::Bool, + OP_ROWID | OP_COUNT | OP_INT64 | OP_INTEGER => DataType::Int64, + OP_STRING8 => DataType::Text, + OP_COLUMN | _ => DataType::Null, + } +} + +pub(super) async fn explain( + conn: &mut SqliteConnection, + query: &str, +) -> Result, Error> { + let mut r = HashMap::::with_capacity(6); + + let program = + query_as::<_, (i64, String, i64, i64, i64, String)>(&*format!("EXPLAIN {}", query)) + .fetch_all(&mut *conn) + .await?; + + let mut program_i = 0; + let program_size = program.len(); + + while program_i < program_size { + let (_, ref opcode, p1, p2, p3, ref p4) = program[program_i]; + + match &**opcode { + OP_INIT => { + // start at + program_i = p2 as usize; + continue; + } + + OP_GOTO => { + // goto + program_i = p2 as usize; + continue; + } + + OP_COLUMN => { + // r[p3] = + r.insert(p3, DataType::Null); + } + + OP_AGG_STEP => { + if p4.starts_with("count(") { + // count(_) -> INTEGER + r.insert(p3, DataType::Int64); + } else if let Some(v) = r.get(&p2).copied() { + // r[p3] = AGG ( r[p2] ) + r.insert(p3, v); + } + } + + OP_COPY | OP_MOVE | OP_SCOPY | OP_INT_COPY => { + // r[p2] = r[p1] + if let Some(v) = r.get(&p1).copied() { + r.insert(p2, v); + } + } + + OP_OR | OP_AND | OP_BLOB | OP_COUNT | OP_REAL | OP_STRING8 | OP_INTEGER | OP_ROWID => { + // r[p2] = + r.insert(p2, to_type(&opcode)); + } + + OP_NOT => { + // r[p2] = NOT r[p1] + if let Some(a) = r.get(&p1).copied() { + r.insert(p2, a); + } + } + + OP_BIT_AND | OP_BIT_OR | OP_SHIFT_LEFT | OP_SHIFT_RIGHT | OP_ADD | OP_SUBTRACT + | OP_MULTIPLY | OP_DIVIDE | OP_REMAINDER | OP_CONCAT => { + // r[p3] = r[p1] + r[p2] + match (r.get(&p1).copied(), r.get(&p2).copied()) { + (Some(a), Some(b)) => { + r.insert(p3, if matches!(a, DataType::Null) { b } else { a }); + } + + (Some(v), None) => { + r.insert(p3, v); + } + + (None, Some(v)) => { + r.insert(p3, v); + } + + _ => {} + } + } + + OP_RESULT_ROW => { + // output = r[p1 .. p1 + p2] + let mut output = Vec::with_capacity(p2 as usize); + for i in p1..p1 + p2 { + output.push(SqliteTypeInfo(r.remove(&i).unwrap_or(DataType::Null))); + } + + return Ok(output); + } + + _ => { + // ignore unsupported operations + // if we fail to find an r later, we just give up + } + } + + program_i += 1; + } + + // no rows + Ok(vec![]) +} diff --git a/sqlx-core/src/sqlite/connection/mod.rs b/sqlx-core/src/sqlite/connection/mod.rs index 25bf9bbbd..31ac8ac26 100644 --- a/sqlx-core/src/sqlite/connection/mod.rs +++ b/sqlx-core/src/sqlite/connection/mod.rs @@ -15,8 +15,10 @@ use crate::sqlite::connection::establish::establish; use crate::sqlite::statement::{SqliteStatement, StatementWorker}; use crate::sqlite::{Sqlite, SqliteConnectOptions}; +mod describe; mod establish; mod executor; +mod explain; mod handle; pub(crate) use handle::ConnectionHandle; diff --git a/sqlx-core/src/sqlite/statement/handle.rs b/sqlx-core/src/sqlite/statement/handle.rs index c4f9325e1..c559f9e22 100644 --- a/sqlx-core/src/sqlite/statement/handle.rs +++ b/sqlx-core/src/sqlite/statement/handle.rs @@ -11,8 +11,8 @@ use libsqlite3_sys::{ sqlite3_column_count, sqlite3_column_database_name, sqlite3_column_decltype, sqlite3_column_double, sqlite3_column_int, sqlite3_column_int64, sqlite3_column_name, sqlite3_column_origin_name, sqlite3_column_table_name, sqlite3_column_type, - sqlite3_column_value, sqlite3_db_handle, sqlite3_stmt, sqlite3_table_column_metadata, - SQLITE_OK, SQLITE_TRANSIENT, SQLITE_UTF8, + sqlite3_column_value, sqlite3_db_handle, sqlite3_sql, sqlite3_stmt, sqlite3_stmt_readonly, + sqlite3_table_column_metadata, SQLITE_OK, SQLITE_TRANSIENT, SQLITE_UTF8, }; use crate::error::{BoxDynError, Error}; @@ -38,6 +38,21 @@ impl StatementHandle { sqlite3_db_handle(self.0.as_ptr()) } + pub(crate) fn read_only(&self) -> bool { + // https://sqlite.org/c3ref/stmt_readonly.html + unsafe { sqlite3_stmt_readonly(self.0.as_ptr()) != 0 } + } + + pub(crate) fn sql(&self) -> &str { + // https://sqlite.org/c3ref/expanded_sql.html + unsafe { + let raw = sqlite3_sql(self.0.as_ptr()); + debug_assert!(!raw.is_null()); + + from_utf8_unchecked(CStr::from_ptr(raw).to_bytes()) + } + } + #[inline] pub(crate) fn last_error(&self) -> SqliteError { SqliteError::new(unsafe { self.db_handle() }) @@ -68,6 +83,10 @@ impl StatementHandle { } } + pub(crate) fn column_type_info(&self, index: usize) -> SqliteTypeInfo { + SqliteTypeInfo(DataType::from_code(self.column_type(index))) + } + #[inline] pub(crate) fn column_decltype(&self, index: usize) -> Option { unsafe { diff --git a/sqlx-core/src/sqlite/type_info.rs b/sqlx-core/src/sqlite/type_info.rs index 5d1a58f3c..240880376 100644 --- a/sqlx-core/src/sqlite/type_info.rs +++ b/sqlx-core/src/sqlite/type_info.rs @@ -7,7 +7,7 @@ use libsqlite3_sys::{SQLITE_BLOB, SQLITE_FLOAT, SQLITE_INTEGER, SQLITE_NULL, SQL use crate::error::BoxDynError; use crate::type_info::TypeInfo; -#[derive(Debug, Clone, Eq, PartialEq)] +#[derive(Debug, Copy, Clone, Eq, PartialEq)] #[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] pub(crate) enum DataType { Null, diff --git a/tests/sqlite/describe.rs b/tests/sqlite/describe.rs index eac56fee8..a890bd8d9 100644 --- a/tests/sqlite/describe.rs +++ b/tests/sqlite/describe.rs @@ -1,6 +1,10 @@ +use sqlx::describe::Column; +use sqlx::error::DatabaseError; +use sqlx::sqlite::{SqliteConnectOptions, SqliteError}; use sqlx::{sqlite::Sqlite, Executor}; -use sqlx_core::describe::Column; +use sqlx::{Connect, SqliteConnection, TypeInfo}; use sqlx_test::new; +use std::env; fn type_names(columns: &[Column]) -> Vec { columns @@ -41,14 +45,101 @@ async fn it_describes_simple() -> anyhow::Result<()> { async fn it_describes_expression() -> anyhow::Result<()> { let mut conn = new::().await?; - let d = conn.describe("SELECT 1 + 10").await?; + let d = conn + .describe("SELECT 1 + 10, 5.12 * 2, 'Hello', x'deadbeef'") + .await?; + let columns = d.columns; + assert_eq!(columns[0].type_info.as_ref().unwrap().name(), "INTEGER"); assert_eq!(columns[0].name, "1 + 10"); assert_eq!(columns[0].not_null, None); - // SQLite cannot infer types for expressions - assert_eq!(columns[0].type_info, None); + assert_eq!(columns[1].type_info.as_ref().unwrap().name(), "REAL"); + assert_eq!(columns[1].name, "5.12 * 2"); + assert_eq!(columns[1].not_null, None); + + assert_eq!(columns[2].type_info.as_ref().unwrap().name(), "TEXT"); + assert_eq!(columns[2].name, "'Hello'"); + assert_eq!(columns[2].not_null, None); + + assert_eq!(columns[3].type_info.as_ref().unwrap().name(), "BLOB"); + assert_eq!(columns[3].name, "x'deadbeef'"); + assert_eq!(columns[3].not_null, None); + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_describes_expression_from_empty_table() -> anyhow::Result<()> { + let mut conn = new::().await?; + + conn.execute("CREATE TEMP TABLE _temp_empty ( name TEXT, a INT )") + .await?; + + let d = conn + .describe("SELECT COUNT(*), a + 1, name, 5.12, 'Hello' FROM _temp_empty") + .await?; + + assert_eq!(d.columns[0].type_info.as_ref().unwrap().name(), "INTEGER"); + assert_eq!(d.columns[1].type_info.as_ref().unwrap().name(), "INTEGER"); + assert_eq!(d.columns[2].type_info.as_ref().unwrap().name(), "TEXT"); + assert_eq!(d.columns[3].type_info.as_ref().unwrap().name(), "REAL"); + assert_eq!(d.columns[4].type_info.as_ref().unwrap().name(), "TEXT"); + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_describes_insert() -> anyhow::Result<()> { + let mut conn = new::().await?; + + let d = conn + .describe("INSERT INTO tweet (id, text) VALUES (2, 'Hello')") + .await?; + + assert_eq!(d.columns.len(), 0); + + let d = conn + .describe("INSERT INTO tweet (id, text) VALUES (2, 'Hello'); SELECT last_insert_rowid();") + .await?; + + assert_eq!(d.columns.len(), 1); + assert_eq!(d.columns[0].type_info.as_ref().unwrap().name(), "INTEGER"); + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_describes_insert_with_read_only() -> anyhow::Result<()> { + sqlx_test::setup_if_needed(); + + let mut options: SqliteConnectOptions = env::var("DATABASE_URL")?.parse().unwrap(); + options = options.read_only(true); + + let mut conn = SqliteConnection::connect_with(&options).await?; + + let d = conn + .describe("INSERT INTO tweet (id, text) VALUES (2, 'Hello')") + .await?; + + assert_eq!(d.columns.len(), 0); + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_describes_bad_statement() -> anyhow::Result<()> { + let mut conn = new::().await?; + + let err = conn.describe("SELECT 1 FROM not_found").await.unwrap_err(); + let err = err + .as_database_error() + .unwrap() + .downcast_ref::(); + + assert_eq!(err.message(), "no such table: not_found"); + assert_eq!(err.code().as_deref(), Some("1")); Ok(()) }