diff --git a/Cargo.lock b/Cargo.lock index 836d7536..240ca808 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -268,34 +268,13 @@ dependencies = [ "constant_time_eq", ] -[[package]] -name = "block-buffer" -version = "0.7.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0940dc441f31689269e10ac70eb1002a3a1d3ad1390e030043662eb7fe4688b" -dependencies = [ - "block-padding", - "byte-tools", - "byteorder", - "generic-array 0.12.3", -] - [[package]] name = "block-buffer" version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4152116fd6e9dadb291ae18fc1ec3575ed6d84c29642d97890f4b4a3417297e4" dependencies = [ - "generic-array 0.14.2", -] - -[[package]] -name = "block-padding" -version = "0.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa79dedbb091f449f1f39e53edf88d5dbe95f895dae6135a8d7b881fb5af73f5" -dependencies = [ - "byte-tools", + "generic-array", ] [[package]] @@ -329,12 +308,6 @@ version = "3.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2e8c087f005730276d1096a652e92a8bacee2e2472bcc9715a74d2bec38b5820" -[[package]] -name = "byte-tools" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3b5ca7a04898ad4bcd41c90c5285445ff5b791899bb1b0abdd2a2aa791211d7" - [[package]] name = "byteorder" version = "1.3.4" @@ -637,7 +610,7 @@ version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b584a330336237c1eecd3e94266efb216c56ed91225d634cb2991c5f3fd1aeab" dependencies = [ - "generic-array 0.14.2", + "generic-array", "subtle", ] @@ -691,22 +664,13 @@ dependencies = [ "tempfile", ] -[[package]] -name = "digest" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f3d0c8c8752312f9713efd397ff63acb9f85585afbf179282e720e7704954dd5" -dependencies = [ - "generic-array 0.12.3", -] - [[package]] name = "digest" version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d3dd60d1080a57a05ab032377049e0591415d2b31afd7028356dbf3cc6dcb066" dependencies = [ - "generic-array 0.14.2", + "generic-array", ] [[package]] @@ -726,6 +690,9 @@ name = "either" version = "1.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bb1f6b1ce1c140482ea30ddd3335fc0024ac7ee112895426e0a629a6c20adfe3" +dependencies = [ + "serde", +] [[package]] name = "encode_unicode" @@ -765,12 +732,6 @@ dependencies = [ "version_check", ] -[[package]] -name = "fake-simd" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e88a8acf291dafb59c2d96e8f59828f3838bb1a70398823ade51a84de6a6deed" - [[package]] name = "fastrand" version = "1.3.2" @@ -950,15 +911,6 @@ dependencies = [ "winapi 0.3.9", ] -[[package]] -name = "generic-array" -version = "0.12.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c68f0274ae0e023facc3c97b2e00f076be70e254bc851d972503b328db79b2ec" -dependencies = [ - "typenum", -] - [[package]] name = "generic-array" version = "0.14.2" @@ -1070,7 +1022,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "126888268dcc288495a26bf004b38c5fdbb31682f992c84ceb046a1f0fe38840" dependencies = [ "crypto-mac", - "digest 0.9.0", + "digest", ] [[package]] @@ -1385,9 +1337,9 @@ version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7b5a279bb9607f9f53c22d496eade00d138d1bdcccd07d74650387cf94942a15" dependencies = [ - "block-buffer 0.9.0", - "digest 0.9.0", - "opaque-debug 0.3.0", + "block-buffer", + "digest", + "opaque-debug", ] [[package]] @@ -1601,12 +1553,6 @@ version = "11.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a170cebd8021a008ea92e4db85a72f80b35df514ec664b296fdcbb654eac0b2c" -[[package]] -name = "opaque-debug" -version = "0.2.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2839e79665f131bdb5782e51f2c6c9599c133c6098982a54c794358bf432529c" - [[package]] name = "opaque-debug" version = "0.3.0" @@ -2104,7 +2050,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3648b669b10afeab18972c105e284a7b953a669b0be3514c27f9b17acab2f9cd" dependencies = [ "byteorder", - "digest 0.9.0", + "digest", "lazy_static", "num-bigint-dig", "num-integer", @@ -2112,7 +2058,7 @@ dependencies = [ "num-traits", "pem", "rand", - "sha2 0.9.1", + "sha2", "simple_asn1", "subtle", "thiserror", @@ -2294,11 +2240,11 @@ version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "170a36ea86c864a3f16dd2687712dd6646f7019f301e57537c7f4dc9f5916770" dependencies = [ - "block-buffer 0.9.0", + "block-buffer", "cfg-if", "cpuid-bool", - "digest 0.9.0", - "opaque-debug 0.3.0", + "digest", + "opaque-debug", ] [[package]] @@ -2307,29 +2253,17 @@ version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2579985fda508104f7587689507983eadd6a6e84dd35d6d115361f530916fa0d" -[[package]] -name = "sha2" -version = "0.8.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a256f46ea78a0c0d9ff00077504903ac881a1dafdc20da66545699e7776b3e69" -dependencies = [ - "block-buffer 0.7.3", - "digest 0.8.1", - "fake-simd", - "opaque-debug 0.2.3", -] - [[package]] name = "sha2" version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2933378ddfeda7ea26f48c555bdad8bb446bf8a3d17832dc83e380d444cfb8c1" dependencies = [ - "block-buffer 0.9.0", + "block-buffer", "cfg-if", "cpuid-bool", - "digest 0.9.0", - "opaque-debug 0.3.0", + "digest", + "opaque-debug", ] [[package]] @@ -2497,13 +2431,13 @@ dependencies = [ "crossbeam-channel", "crossbeam-queue", "crossbeam-utils 0.7.2", - "digest 0.9.0", + "digest", "either", "encoding_rs", "futures-channel", "futures-core", "futures-util", - "generic-array 0.14.2", + "generic-array", "hashbrown", "hex", "hmac", @@ -2526,7 +2460,7 @@ dependencies = [ "serde", "serde_json", "sha-1", - "sha2 0.9.1", + "sha2", "smallvec 1.4.0", "sqlformat", "sqlx-rt", @@ -2615,6 +2549,7 @@ name = "sqlx-macros" version = "0.4.0-pre" dependencies = [ "dotenv", + "either", "futures 0.3.5", "heck", "hex", @@ -2622,7 +2557,7 @@ dependencies = [ "quote", "serde", "serde_json", - "sha2 0.8.2", + "sha2", "sqlx-core", "sqlx-rt", "syn", @@ -2827,9 +2762,9 @@ dependencies = [ [[package]] name = "terminal_size" -version = "0.1.12" +version = "0.1.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8038f95fc7a6f351163f4b964af631bd26c9e828f7db085f2a84aca56f70d13b" +checksum = "9a14cd9f8c72704232f0bfc8455c0e861f0ad4eb60cc9ec8a170e231414c1e13" dependencies = [ "libc", "winapi 0.3.9", diff --git a/sqlx-core/Cargo.toml b/sqlx-core/Cargo.toml index bfc37265..8fa0dbc5 100644 --- a/sqlx-core/Cargo.toml +++ b/sqlx-core/Cargo.toml @@ -34,7 +34,7 @@ runtime-tokio = [ "sqlx-rt/runtime-tokio" ] runtime-actix = [ "sqlx-rt/runtime-actix" ] # support offline/decoupled building (enables serialization of `Describe`) -offline = [ "serde" ] +offline = [ "serde", "either/serde" ] [dependencies] atoi = "0.3.2" diff --git a/sqlx-core/src/column.rs b/sqlx-core/src/column.rs new file mode 100644 index 00000000..ede2b833 --- /dev/null +++ b/sqlx-core/src/column.rs @@ -0,0 +1,26 @@ +use crate::database::Database; +use std::fmt::Debug; + +pub trait Column: private_column::Sealed + 'static + Send + Sync + Debug { + type Database: Database; + + /// Gets the column ordinal. + /// + /// This can be used to unambiguously refer to this column within a row in case more than + /// one column have the same name + fn ordinal(&self) -> usize; + + /// Gets the column name or alias. + /// + /// The column name is unreliable (and can change between database minor versions) if this + /// column is an expression that has not been aliased. + fn name(&self) -> &str; + + /// Gets the type information for the column. + fn type_info(&self) -> &::TypeInfo; +} + +// Prevent users from implementing the `Row` trait. +pub(crate) mod private_column { + pub trait Sealed {} +} diff --git a/sqlx-core/src/database.rs b/sqlx-core/src/database.rs index 59a3c625..5963522f 100644 --- a/sqlx-core/src/database.rs +++ b/sqlx-core/src/database.rs @@ -3,6 +3,7 @@ use std::fmt::Debug; use crate::arguments::Arguments; +use crate::column::Column; use crate::connection::Connect; use crate::row::Row; use crate::transaction::TransactionManager; @@ -31,6 +32,9 @@ pub trait Database: /// The concrete `Row` implementation for this database. type Row: Row; + /// The concrete `Column` implementation for this database. + type Column: Column; + /// The concrete `TypeInfo` implementation for this database. type TypeInfo: TypeInfo; diff --git a/sqlx-core/src/describe.rs b/sqlx-core/src/describe.rs deleted file mode 100644 index 9036c90b..00000000 --- a/sqlx-core/src/describe.rs +++ /dev/null @@ -1,67 +0,0 @@ -//! Types for returning SQL type information about queries. -//! -//! The compile-time type checking within the query macros heavily lean on the information -//! provided within these types. - -use crate::database::Database; - -// TODO(@mehcode): Remove [pub] from Describe/Column and use methods to expose the properties - -/// A representation of a statement that _could_ have been executed against the database. -/// -/// Returned from [`Executor::describe`](crate::executor::Executor::describe). -/// -/// The compile-time verification within the query macros utilizes `describe` and this type to -/// act on an arbitrary query. -#[derive(Debug)] -#[non_exhaustive] -#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] -#[cfg_attr( - feature = "offline", - serde(bound( - serialize = "DB::TypeInfo: serde::Serialize", - deserialize = "DB::TypeInfo: serde::de::DeserializeOwned" - )) -)] -pub struct Describe -where - DB: Database, -{ - /// The expected types of the parameters. This is currently always an array of `None` values - /// on all databases drivers aside from PostgreSQL. - pub params: Vec>, - - /// The columns that will be found in the results from this query. - pub columns: Vec>, -} - -#[derive(Debug)] -#[non_exhaustive] -#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] -#[cfg_attr( - feature = "offline", - serde(bound( - serialize = "DB::TypeInfo: serde::Serialize", - deserialize = "DB::TypeInfo: serde::de::DeserializeOwned" - )) -)] -pub struct Column -where - DB: Database, -{ - /// The name of the result column. - /// - /// The column name is unreliable (and can change between database minor versions) if this - /// result column is an expression that has not been aliased. - pub name: String, - - /// The type information for the result column. - /// - /// This may be `None` if the type cannot be determined. This occurs in SQLite when - /// the column is an expression. - pub type_info: Option, - - /// Whether the column cannot be `NULL` (or if that is even knowable). - /// This value is only not `None` if received from a call to `describe`. - pub not_null: Option, -} diff --git a/sqlx-core/src/executor.rs b/sqlx-core/src/executor.rs index 0355d456..9abd859c 100644 --- a/sqlx-core/src/executor.rs +++ b/sqlx-core/src/executor.rs @@ -6,8 +6,8 @@ use futures_core::stream::BoxStream; use futures_util::{future, FutureExt, StreamExt, TryFutureExt, TryStreamExt}; use crate::database::{Database, HasArguments}; -use crate::describe::Describe; use crate::error::Error; +use crate::statement::StatementInfo; /// A type that contains or can provide a database /// connection to use for executing queries against the database. @@ -130,7 +130,7 @@ pub trait Executor<'c>: Send + Debug + Sized { fn describe<'e, 'q: 'e, E: 'q>( self, query: E, - ) -> BoxFuture<'e, Result, Error>> + ) -> BoxFuture<'e, Result, Error>> where 'c: 'e, E: Execute<'q, Self::Database>; diff --git a/sqlx-core/src/lib.rs b/sqlx-core/src/lib.rs index efbb69e4..2a3db584 100644 --- a/sqlx-core/src/lib.rs +++ b/sqlx-core/src/lib.rs @@ -46,9 +46,9 @@ pub mod types; #[macro_use] pub mod query; +pub mod column; mod common; pub mod database; -pub mod describe; pub mod executor; pub mod from_row; mod io; @@ -56,6 +56,7 @@ mod net; pub mod query_as; pub mod query_scalar; pub mod row; +pub mod statement; pub mod type_info; pub mod value; diff --git a/sqlx-core/src/mssql/column.rs b/sqlx-core/src/mssql/column.rs new file mode 100644 index 00000000..82a0e6f4 --- /dev/null +++ b/sqlx-core/src/mssql/column.rs @@ -0,0 +1,42 @@ +use crate::column::Column; +use crate::ext::ustr::UStr; +use crate::mssql::protocol::col_meta_data::{ColumnData, Flags}; +use crate::mssql::{Mssql, MssqlTypeInfo}; + +#[derive(Debug, Clone)] +#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] +pub struct MssqlColumn { + pub(crate) ordinal: usize, + pub(crate) name: UStr, + pub(crate) type_info: MssqlTypeInfo, + pub(crate) flags: Flags, +} + +impl crate::column::private_column::Sealed for MssqlColumn {} + +impl MssqlColumn { + pub(crate) fn new(meta: ColumnData, ordinal: usize) -> Self { + Self { + name: UStr::from(meta.col_name), + type_info: MssqlTypeInfo(meta.type_info), + ordinal, + flags: meta.flags, + } + } +} + +impl Column for MssqlColumn { + type Database = Mssql; + + fn ordinal(&self) -> usize { + self.ordinal + } + + fn name(&self) -> &str { + &*self.name + } + + fn type_info(&self) -> &MssqlTypeInfo { + &self.type_info + } +} diff --git a/sqlx-core/src/mssql/connection/describe.rs b/sqlx-core/src/mssql/connection/describe.rs new file mode 100644 index 00000000..4651f60f --- /dev/null +++ b/sqlx-core/src/mssql/connection/describe.rs @@ -0,0 +1,98 @@ +use crate::error::Error; +use crate::mssql::protocol::col_meta_data::Flags; +use crate::mssql::protocol::done::Status; +use crate::mssql::protocol::message::Message; +use crate::mssql::protocol::packet::PacketType; +use crate::mssql::protocol::rpc::{OptionFlags, Procedure, RpcRequest}; +use crate::mssql::{Mssql, MssqlArguments, MssqlConnection}; +use crate::statement::StatementInfo; +use either::Either; +use once_cell::sync::Lazy; +use regex::Regex; + +pub async fn describe( + conn: &mut MssqlConnection, + query: &str, +) -> Result, Error> { + // [sp_prepare] will emit the column meta data + // small issue is that we need to declare all the used placeholders with a "fallback" type + // we currently use regex to collect them; false positives are *okay* but false + // negatives would break the query + let proc = Either::Right(Procedure::Prepare); + + // NOTE: this does not support unicode identifiers; as we don't even support + // named parameters (yet) this is probably fine, for now + + static PARAMS_RE: Lazy = Lazy::new(|| Regex::new(r"@p[[:alnum:]]+").unwrap()); + + let mut params = String::new(); + let mut num_params = 0; + + for m in PARAMS_RE.captures_iter(query) { + if !params.is_empty() { + params.push_str(","); + } + + params.push_str(&m[0]); + + // NOTE: this means that a query! of `SELECT @p1` will have the macros believe + // it will return nvarchar(1); this is a greater issue with `query!` that we + // we need to circle back to. This doesn't happen much in practice however. + params.push_str(" nvarchar(1)"); + + num_params += 1; + } + + let params = if params.is_empty() { + None + } else { + Some(&*params) + }; + + let mut args = MssqlArguments::default(); + + args.declare("", 0_i32); + args.add_unnamed(params); + args.add_unnamed(query); + args.add_unnamed(0x0001_i32); // 1 = SEND_METADATA + + conn.stream.write_packet( + PacketType::Rpc, + RpcRequest { + transaction_descriptor: conn.stream.transaction_descriptor, + arguments: &args, + procedure: proc, + options: OptionFlags::empty(), + }, + ); + + conn.stream.flush().await?; + conn.stream.wait_until_ready().await?; + conn.stream.pending_done_count += 1; + + loop { + match conn.stream.recv_message().await? { + Message::DoneProc(done) | Message::Done(done) => { + if !done.status.contains(Status::DONE_MORE) { + // done with prepare + conn.stream.handle_done(&done); + break; + } + } + + _ => {} + } + } + + let mut nullable = Vec::with_capacity(conn.stream.columns.len()); + + for col in conn.stream.columns.iter() { + nullable.push(Some(col.flags.contains(Flags::NULLABLE))); + } + + Ok(StatementInfo { + parameters: Some(Either::Right(num_params)), + columns: (*conn.stream.columns).clone(), + nullable, + }) +} diff --git a/sqlx-core/src/mssql/connection/executor.rs b/sqlx-core/src/mssql/connection/executor.rs index a38820c0..c6363ce1 100644 --- a/sqlx-core/src/mssql/connection/executor.rs +++ b/sqlx-core/src/mssql/connection/executor.rs @@ -1,20 +1,18 @@ -use either::Either; -use futures_core::future::BoxFuture; -use futures_core::stream::BoxStream; -use futures_util::TryStreamExt; -use once_cell::sync::Lazy; -use regex::Regex; - -use crate::describe::{Column, Describe}; use crate::error::Error; use crate::executor::{Execute, Executor}; -use crate::mssql::protocol::col_meta_data::Flags; +use crate::mssql::connection::describe::describe; use crate::mssql::protocol::done::Status; use crate::mssql::protocol::message::Message; use crate::mssql::protocol::packet::PacketType; use crate::mssql::protocol::rpc::{OptionFlags, Procedure, RpcRequest}; use crate::mssql::protocol::sql_batch::SqlBatch; -use crate::mssql::{Mssql, MssqlArguments, MssqlConnection, MssqlRow, MssqlTypeInfo}; +use crate::mssql::{Mssql, MssqlArguments, MssqlConnection, MssqlRow}; +use crate::statement::StatementInfo; +use either::Either; +use futures_core::future::BoxFuture; +use futures_core::stream::BoxStream; +use futures_util::{FutureExt, TryStreamExt}; +use std::sync::Arc; impl MssqlConnection { async fn run(&mut self, query: &str, arguments: Option) -> Result<(), Error> { @@ -84,7 +82,10 @@ impl<'c> Executor<'c> for &'c mut MssqlConnection { match message { Message::Row(row) => { - r#yield!(Either::Right(MssqlRow { row })); + let columns = Arc::clone(&self.stream.columns); + let column_names = Arc::clone(&self.stream.column_names); + + r#yield!(Either::Right(MssqlRow { row, column_names, columns })); } Message::Done(done) | Message::DoneProc(done) => { @@ -139,98 +140,11 @@ impl<'c> Executor<'c> for &'c mut MssqlConnection { fn describe<'e, 'q: 'e, E: 'q>( self, query: E, - ) -> BoxFuture<'e, Result, Error>> + ) -> BoxFuture<'e, Result, Error>> where 'c: 'e, E: Execute<'q, Self::Database>, { - let s = query.query(); - - // [sp_prepare] will emit the column meta data - // small issue is that we need to declare all the used placeholders with a "fallback" type - // we currently use regex to collect them; false positives are *okay* but false - // negatives would break the query - let proc = Either::Right(Procedure::Prepare); - - // NOTE: this does not support unicode identifiers; as we don't even support - // named parameters (yet) this is probably fine, for now - - static PARAMS_RE: Lazy = Lazy::new(|| Regex::new(r"@p[[:alnum:]]+").unwrap()); - - let mut params = String::new(); - let mut num_params = 0; - - for m in PARAMS_RE.captures_iter(s) { - if !params.is_empty() { - params.push_str(","); - } - - params.push_str(&m[0]); - - // NOTE: this means that a query! of `SELECT @p1` will have the macros believe - // it will return nvarchar(1); this is a greater issue with `query!` that we - // we need to circle back to. This doesn't happen much in practice however. - params.push_str(" nvarchar(1)"); - - num_params += 1; - } - - let params = if params.is_empty() { - None - } else { - Some(&*params) - }; - - let mut args = MssqlArguments::default(); - - args.declare("", 0_i32); - args.add_unnamed(params); - args.add_unnamed(s); - args.add_unnamed(0x0001_i32); // 1 = SEND_METADATA - - self.stream.write_packet( - PacketType::Rpc, - RpcRequest { - transaction_descriptor: self.stream.transaction_descriptor, - arguments: &args, - procedure: proc, - options: OptionFlags::empty(), - }, - ); - - Box::pin(async move { - self.stream.flush().await?; - self.stream.wait_until_ready().await?; - self.stream.pending_done_count += 1; - - loop { - match self.stream.recv_message().await? { - Message::DoneProc(done) | Message::Done(done) => { - if !done.status.contains(Status::DONE_MORE) { - // done with prepare - self.stream.handle_done(&done); - break; - } - } - - _ => {} - } - } - - let mut columns = Vec::with_capacity(self.stream.columns.len()); - - for col in &self.stream.columns { - columns.push(Column { - name: col.col_name.clone(), - type_info: Some(MssqlTypeInfo(col.type_info.clone())), - not_null: Some(!col.flags.contains(Flags::NULLABLE)), - }); - } - - Ok(Describe { - params: vec![None; num_params], - columns, - }) - }) + describe(self, query.query()).boxed() } } diff --git a/sqlx-core/src/mssql/connection/mod.rs b/sqlx-core/src/mssql/connection/mod.rs index 1f4166b8..a3862282 100644 --- a/sqlx-core/src/mssql/connection/mod.rs +++ b/sqlx-core/src/mssql/connection/mod.rs @@ -1,15 +1,14 @@ -use std::fmt::{self, Debug, Formatter}; -use std::net::Shutdown; - -use futures_core::future::BoxFuture; -use futures_util::{future::ready, FutureExt, TryFutureExt}; - use crate::connection::{Connect, Connection}; use crate::error::Error; use crate::executor::Executor; use crate::mssql::connection::stream::MssqlStream; use crate::mssql::{Mssql, MssqlConnectOptions}; +use futures_core::future::BoxFuture; +use futures_util::{future::ready, FutureExt, TryFutureExt}; +use std::fmt::{self, Debug, Formatter}; +use std::net::Shutdown; +mod describe; mod establish; mod executor; mod stream; diff --git a/sqlx-core/src/mssql/connection/stream.rs b/sqlx-core/src/mssql/connection/stream.rs index f3088530..4d65b090 100644 --- a/sqlx-core/src/mssql/connection/stream.rs +++ b/sqlx-core/src/mssql/connection/stream.rs @@ -4,8 +4,9 @@ use bytes::Bytes; use sqlx_rt::TcpStream; use crate::error::Error; +use crate::ext::ustr::UStr; use crate::io::{BufStream, Encode}; -use crate::mssql::protocol::col_meta_data::{ColMetaData, ColumnData}; +use crate::mssql::protocol::col_meta_data::ColMetaData; use crate::mssql::protocol::done::{Done, Status as DoneStatus}; use crate::mssql::protocol::env_change::EnvChange; use crate::mssql::protocol::error::Error as ProtocolError; @@ -17,8 +18,10 @@ use crate::mssql::protocol::packet::{PacketHeader, PacketType, Status}; use crate::mssql::protocol::return_status::ReturnStatus; use crate::mssql::protocol::return_value::ReturnValue; use crate::mssql::protocol::row::Row; -use crate::mssql::{MssqlConnectOptions, MssqlDatabaseError}; +use crate::mssql::{MssqlColumn, MssqlConnectOptions, MssqlDatabaseError}; use crate::net::MaybeTlsStream; +use hashbrown::HashMap; +use std::sync::Arc; pub(crate) struct MssqlStream { inner: BufStream>, @@ -35,7 +38,8 @@ pub(crate) struct MssqlStream { // most recent column data from ColMetaData // we need to store this as its needed when decoding - pub(crate) columns: Vec, + pub(crate) columns: Arc>, + pub(crate) column_names: Arc>, } impl MssqlStream { @@ -46,7 +50,8 @@ impl MssqlStream { Ok(Self { inner, - columns: Vec::new(), + columns: Default::default(), + column_names: Default::default(), response: None, pending_done_count: 0, transaction_descriptor: 0, @@ -159,7 +164,11 @@ impl MssqlStream { MessageType::ColMetaData => { // NOTE: there isn't anything to return as the data gets // consumed by the stream for use in subsequent Row decoding - ColMetaData::get(buf, &mut self.columns)?; + ColMetaData::get( + buf, + Arc::make_mut(&mut self.columns), + Arc::make_mut(&mut self.column_names), + )?; continue; } }; diff --git a/sqlx-core/src/mssql/database.rs b/sqlx-core/src/mssql/database.rs index a5bea016..604a6467 100644 --- a/sqlx-core/src/mssql/database.rs +++ b/sqlx-core/src/mssql/database.rs @@ -1,7 +1,7 @@ use crate::database::{Database, HasArguments, HasValueRef}; use crate::mssql::{ - MssqlArguments, MssqlConnection, MssqlRow, MssqlTransactionManager, MssqlTypeInfo, MssqlValue, - MssqlValueRef, + MssqlArguments, MssqlColumn, MssqlConnection, MssqlRow, MssqlTransactionManager, MssqlTypeInfo, + MssqlValue, MssqlValueRef, }; /// MSSQL database driver. @@ -15,6 +15,8 @@ impl Database for Mssql { type Row = MssqlRow; + type Column = MssqlColumn; + type TypeInfo = MssqlTypeInfo; type Value = MssqlValue; diff --git a/sqlx-core/src/mssql/mod.rs b/sqlx-core/src/mssql/mod.rs index 22cc58de..d473564e 100644 --- a/sqlx-core/src/mssql/mod.rs +++ b/sqlx-core/src/mssql/mod.rs @@ -1,6 +1,7 @@ //! Microsoft SQL (MSSQL) database driver. mod arguments; +mod column; mod connection; mod database; mod error; @@ -14,6 +15,7 @@ pub mod types; mod value; pub use arguments::MssqlArguments; +pub use column::MssqlColumn; pub use connection::MssqlConnection; pub use database::Mssql; pub use error::MssqlDatabaseError; diff --git a/sqlx-core/src/mssql/protocol/col_meta_data.rs b/sqlx-core/src/mssql/protocol/col_meta_data.rs index 19f81bb8..73344372 100644 --- a/sqlx-core/src/mssql/protocol/col_meta_data.rs +++ b/sqlx-core/src/mssql/protocol/col_meta_data.rs @@ -2,8 +2,11 @@ use bitflags::bitflags; use bytes::{Buf, Bytes}; use crate::error::Error; +use crate::ext::ustr::UStr; use crate::mssql::io::MssqlBufExt; use crate::mssql::protocol::type_info::TypeInfo; +use crate::mssql::MssqlColumn; +use hashbrown::HashMap; #[derive(Debug)] pub(crate) struct ColMetaData; @@ -73,10 +76,16 @@ bitflags! { } impl ColMetaData { - pub(crate) fn get(buf: &mut Bytes, columns: &mut Vec) -> Result<(), Error> { + pub(crate) fn get( + buf: &mut Bytes, + columns: &mut Vec, + column_names: &mut HashMap, + ) -> Result<(), Error> { columns.clear(); + column_names.clear(); let mut count = buf.get_u16_le(); + let mut ordinal = 0; if count == 0xffff { // In the event that the client requested no metadata to be returned, the value of @@ -88,8 +97,13 @@ impl ColMetaData { } while count > 0 { - columns.push(ColumnData::get(buf)?); + let col = MssqlColumn::new(ColumnData::get(buf)?, ordinal); + + column_names.insert(col.name.clone(), ordinal); + columns.push(col); + count -= 1; + ordinal += 1; } Ok(()) diff --git a/sqlx-core/src/mssql/protocol/row.rs b/sqlx-core/src/mssql/protocol/row.rs index 02d21a53..1094019a 100644 --- a/sqlx-core/src/mssql/protocol/row.rs +++ b/sqlx-core/src/mssql/protocol/row.rs @@ -2,13 +2,10 @@ use bytes::Bytes; use crate::error::Error; use crate::io::BufExt; -use crate::mssql::protocol::col_meta_data::ColumnData; -use crate::mssql::MssqlTypeInfo; +use crate::mssql::{MssqlColumn, MssqlTypeInfo}; #[derive(Debug)] pub(crate) struct Row { - // TODO: Column names? - // FIXME: Columns Vec should be an Arc<_> pub(crate) column_types: Vec, pub(crate) values: Vec>, } @@ -17,7 +14,7 @@ impl Row { pub(crate) fn get( buf: &mut Bytes, nullable: bool, - columns: &[ColumnData], + columns: &[MssqlColumn], ) -> Result { let mut values = Vec::with_capacity(columns.len()); let mut column_types = Vec::with_capacity(columns.len()); @@ -29,10 +26,11 @@ impl Row { }; for (i, column) in columns.iter().enumerate() { - column_types.push(MssqlTypeInfo(column.type_info.clone())); + column_types.push(column.type_info.clone()); - if !(column.type_info.is_null() || (nullable && (nulls[i / 8] & (1 << (i % 8))) != 0)) { - values.push(column.type_info.get_value(buf)); + if !(column.type_info.0.is_null() || (nullable && (nulls[i / 8] & (1 << (i % 8))) != 0)) + { + values.push(column.type_info.0.get_value(buf)); } else { values.push(None); } diff --git a/sqlx-core/src/mssql/row.rs b/sqlx-core/src/mssql/row.rs index 59a5bed4..309571ed 100644 --- a/sqlx-core/src/mssql/row.rs +++ b/sqlx-core/src/mssql/row.rs @@ -1,10 +1,15 @@ use crate::error::Error; +use crate::ext::ustr::UStr; use crate::mssql::protocol::row::Row as ProtocolRow; -use crate::mssql::{Mssql, MssqlValueRef}; +use crate::mssql::{Mssql, MssqlColumn, MssqlValueRef}; use crate::row::{ColumnIndex, Row}; +use hashbrown::HashMap; +use std::sync::Arc; pub struct MssqlRow { pub(crate) row: ProtocolRow, + pub(crate) columns: Arc>, + pub(crate) column_names: Arc>, } impl crate::row::private_row::Sealed for MssqlRow {} @@ -12,9 +17,8 @@ impl crate::row::private_row::Sealed for MssqlRow {} impl Row for MssqlRow { type Database = Mssql; - #[inline] - fn len(&self) -> usize { - self.row.values.len() + fn columns(&self) -> &[MssqlColumn] { + &*self.columns } fn try_get_raw(&self, index: I) -> Result, Error> @@ -30,3 +34,12 @@ impl Row for MssqlRow { Ok(value) } } + +impl ColumnIndex for &'_ str { + fn index(&self, row: &MssqlRow) -> Result { + row.column_names + .get(*self) + .ok_or_else(|| Error::ColumnNotFound((*self).into())) + .map(|v| *v) + } +} diff --git a/sqlx-core/src/mssql/value.rs b/sqlx-core/src/mssql/value.rs index 3eae4bbc..55210056 100644 --- a/sqlx-core/src/mssql/value.rs +++ b/sqlx-core/src/mssql/value.rs @@ -1,5 +1,3 @@ -use std::borrow::Cow; - use bytes::Bytes; use crate::error::{BoxDynError, UnexpectedNullError}; @@ -32,8 +30,8 @@ impl ValueRef<'_> for MssqlValueRef<'_> { } } - fn type_info(&self) -> Option> { - Some(Cow::Borrowed(&self.type_info)) + fn type_info(&self) -> &MssqlTypeInfo { + &self.type_info } fn is_null(&self) -> bool { @@ -58,8 +56,8 @@ impl Value for MssqlValue { } } - fn type_info(&self) -> Option> { - Some(Cow::Borrowed(&self.type_info)) + fn type_info(&self) -> &MssqlTypeInfo { + &self.type_info } fn is_null(&self) -> bool { diff --git a/sqlx-core/src/mysql/column.rs b/sqlx-core/src/mysql/column.rs new file mode 100644 index 00000000..fac93b3b --- /dev/null +++ b/sqlx-core/src/mysql/column.rs @@ -0,0 +1,29 @@ +use crate::column::Column; +use crate::ext::ustr::UStr; +use crate::mysql::{MySql, MySqlTypeInfo}; + +#[derive(Debug, Clone)] +#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] +pub struct MySqlColumn { + pub(crate) ordinal: usize, + pub(crate) name: UStr, + pub(crate) type_info: MySqlTypeInfo, +} + +impl crate::column::private_column::Sealed for MySqlColumn {} + +impl Column for MySqlColumn { + type Database = MySql; + + fn ordinal(&self) -> usize { + self.ordinal + } + + fn name(&self) -> &str { + &*self.name + } + + fn type_info(&self) -> &MySqlTypeInfo { + &self.type_info + } +} diff --git a/sqlx-core/src/mysql/connection/executor.rs b/sqlx-core/src/mysql/connection/executor.rs index 7b28f360..ab9ba852 100644 --- a/sqlx-core/src/mysql/connection/executor.rs +++ b/sqlx-core/src/mysql/connection/executor.rs @@ -7,7 +7,6 @@ use futures_core::stream::BoxStream; use futures_core::Stream; use futures_util::{pin_mut, TryStreamExt}; -use crate::describe::{Column, Describe}; use crate::error::Error; use crate::executor::{Execute, Executor}; use crate::ext::ustr::UStr; @@ -19,10 +18,10 @@ use crate::mysql::protocol::statement::{ }; use crate::mysql::protocol::text::{ColumnDefinition, ColumnFlags, Query, TextRow}; use crate::mysql::protocol::Packet; -use crate::mysql::row::MySqlColumn; use crate::mysql::{ - MySql, MySqlArguments, MySqlConnection, MySqlRow, MySqlTypeInfo, MySqlValueFormat, + MySql, MySqlArguments, MySqlColumn, MySqlConnection, MySqlRow, MySqlTypeInfo, MySqlValueFormat, }; +use crate::statement::StatementInfo; impl MySqlConnection { async fn prepare(&mut self, query: &str) -> Result { @@ -83,22 +82,23 @@ impl MySqlConnection { for i in 0..num_columns { let def: ColumnDefinition = self.stream.recv().await?; - let name = (match (def.name()?, def.alias()?) { - (_, alias) if !alias.is_empty() => Some(alias), + let name = match (def.name()?, def.alias()?) { + (_, alias) if !alias.is_empty() => UStr::new(alias), - (name, _) if !name.is_empty() => Some(name), + (name, _) if !name.is_empty() => UStr::new(name), - _ => None, - }) - .map(UStr::new); + _ => UStr::from(""), + }; - if let Some(name) = &name { - column_names.insert(name.clone(), i as usize); - } + column_names.insert(name.clone(), i as usize); let type_info = MySqlTypeInfo::from_column(&def); - columns.push(MySqlColumn { name, type_info }); + columns.push(MySqlColumn { + name, + type_info, + ordinal: i as usize, + }); } self.stream.maybe_recv_eof().await?; @@ -245,7 +245,10 @@ impl<'c> Executor<'c> for &'c mut MySqlConnection { } #[doc(hidden)] - fn describe<'e, 'q: 'e, E: 'q>(self, query: E) -> BoxFuture<'e, Result, Error>> + fn describe<'e, 'q: 'e, E: 'q>( + self, + query: E, + ) -> BoxFuture<'e, Result, Error>> where 'c: 'e, E: Execute<'q, Self::Database>, @@ -257,14 +260,12 @@ impl<'c> Executor<'c> for &'c mut MySqlConnection { let ok: PrepareOk = self.stream.recv().await?; - let mut params = Vec::with_capacity(ok.params as usize); let mut columns = Vec::with_capacity(ok.columns as usize); + let mut nullable = Vec::with_capacity(ok.columns as usize); if ok.params > 0 { for _ in 0..ok.params { - let def: ColumnDefinition = self.stream.recv().await?; - - params.push(MySqlTypeInfo::from_column(&def)); + let _ = self.stream.recv_packet().await?; } self.stream.maybe_recv_eof().await?; @@ -275,22 +276,28 @@ impl<'c> Executor<'c> for &'c mut MySqlConnection { // once more on execute so we wait for that if ok.columns > 0 { - for _ in 0..(ok.columns as usize) { + for ordinal in 0..(ok.columns as usize) { let def: ColumnDefinition = self.stream.recv().await?; let ty = MySqlTypeInfo::from_column(&def); let alias = def.alias()?; - columns.push(Column { - name: if alias.is_empty() { def.name()? } else { alias }.to_owned(), + nullable.push(Some(!def.flags.contains(ColumnFlags::NOT_NULL))); + + columns.push(MySqlColumn { + ordinal, + name: UStr::new(if alias.is_empty() { def.name()? } else { alias }), type_info: ty, - not_null: Some(def.flags.contains(ColumnFlags::NOT_NULL)), }) } self.stream.maybe_recv_eof().await?; } - Ok(Describe { params, columns }) + Ok(StatementInfo { + parameters: Some(Either::Right(ok.params as usize)), + columns, + nullable, + }) }) } } diff --git a/sqlx-core/src/mysql/connection/mod.rs b/sqlx-core/src/mysql/connection/mod.rs index ccdc1260..e7e61c4c 100644 --- a/sqlx-core/src/mysql/connection/mod.rs +++ b/sqlx-core/src/mysql/connection/mod.rs @@ -12,8 +12,7 @@ use crate::executor::Executor; use crate::ext::ustr::UStr; use crate::mysql::protocol::statement::StmtClose; use crate::mysql::protocol::text::{Ping, Quit}; -use crate::mysql::row::MySqlColumn; -use crate::mysql::{MySql, MySqlConnectOptions}; +use crate::mysql::{MySql, MySqlColumn, MySqlConnectOptions}; mod auth; mod establish; diff --git a/sqlx-core/src/mysql/database.rs b/sqlx-core/src/mysql/database.rs index 6bb80832..178414db 100644 --- a/sqlx-core/src/mysql/database.rs +++ b/sqlx-core/src/mysql/database.rs @@ -1,7 +1,7 @@ use crate::database::{Database, HasArguments, HasStatementCache, HasValueRef}; use crate::mysql::value::{MySqlValue, MySqlValueRef}; use crate::mysql::{ - MySqlArguments, MySqlConnection, MySqlRow, MySqlTransactionManager, MySqlTypeInfo, + MySqlArguments, MySqlColumn, MySqlConnection, MySqlRow, MySqlTransactionManager, MySqlTypeInfo, }; /// MySQL database driver. @@ -15,6 +15,8 @@ impl Database for MySql { type Row = MySqlRow; + type Column = MySqlColumn; + type TypeInfo = MySqlTypeInfo; type Value = MySqlValue; diff --git a/sqlx-core/src/mysql/mod.rs b/sqlx-core/src/mysql/mod.rs index aca0cf22..4d34ea78 100644 --- a/sqlx-core/src/mysql/mod.rs +++ b/sqlx-core/src/mysql/mod.rs @@ -1,6 +1,7 @@ //! **MySQL** database driver. mod arguments; +mod column; mod connection; mod database; mod error; @@ -14,6 +15,7 @@ pub mod types; mod value; pub use arguments::MySqlArguments; +pub use column::MySqlColumn; pub use connection::MySqlConnection; pub use database::MySql; pub use error::MySqlDatabaseError; diff --git a/sqlx-core/src/mysql/protocol/row.rs b/sqlx-core/src/mysql/protocol/row.rs index 60e79c30..f027dada 100644 --- a/sqlx-core/src/mysql/protocol/row.rs +++ b/sqlx-core/src/mysql/protocol/row.rs @@ -9,10 +9,6 @@ pub(crate) struct Row { } impl Row { - pub(crate) fn len(&self) -> usize { - self.values.len() - } - pub(crate) fn get(&self, index: usize) -> Option<&[u8]> { self.values[index] .as_ref() diff --git a/sqlx-core/src/mysql/protocol/statement/row.rs b/sqlx-core/src/mysql/protocol/statement/row.rs index 29723daf..b3760de9 100644 --- a/sqlx-core/src/mysql/protocol/statement/row.rs +++ b/sqlx-core/src/mysql/protocol/statement/row.rs @@ -5,7 +5,7 @@ use crate::io::{BufExt, Decode}; use crate::mysql::io::MySqlBufExt; use crate::mysql::protocol::text::ColumnType; use crate::mysql::protocol::Row; -use crate::mysql::row::MySqlColumn; +use crate::mysql::MySqlColumn; // https://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html#packet-ProtocolBinary::ResultsetRow // https://dev.mysql.com/doc/internals/en/binary-protocol-value.html @@ -43,7 +43,7 @@ impl<'de> Decode<'de, &'de [MySqlColumn]> for BinaryRow { } // NOTE: MySQL will never generate NULL types for non-NULL values - let type_info = column.type_info.as_ref().unwrap(); + let type_info = &column.type_info; let size: usize = match type_info.r#type { ColumnType::String diff --git a/sqlx-core/src/mysql/protocol/text/row.rs b/sqlx-core/src/mysql/protocol/text/row.rs index c1a1e156..17d32f10 100644 --- a/sqlx-core/src/mysql/protocol/text/row.rs +++ b/sqlx-core/src/mysql/protocol/text/row.rs @@ -4,7 +4,7 @@ use crate::error::Error; use crate::io::Decode; use crate::mysql::io::MySqlBufExt; use crate::mysql::protocol::Row; -use crate::mysql::row::MySqlColumn; +use crate::mysql::MySqlColumn; #[derive(Debug)] pub(crate) struct TextRow(pub(crate) Row); diff --git a/sqlx-core/src/mysql/row.rs b/sqlx-core/src/mysql/row.rs index 8305ee33..7d361358 100644 --- a/sqlx-core/src/mysql/row.rs +++ b/sqlx-core/src/mysql/row.rs @@ -4,16 +4,9 @@ use hashbrown::HashMap; use crate::error::Error; use crate::ext::ustr::UStr; -use crate::mysql::{protocol, MySql, MySqlTypeInfo, MySqlValueFormat, MySqlValueRef}; +use crate::mysql::{protocol, MySql, MySqlColumn, MySqlValueFormat, MySqlValueRef}; use crate::row::{ColumnIndex, Row}; -// TODO: Merge with the other XXColumn types -#[derive(Debug, Clone)] -pub(crate) struct MySqlColumn { - pub(crate) name: Option, - pub(crate) type_info: Option, -} - /// Implementation of [`Row`] for MySQL. #[derive(Debug)] pub struct MySqlRow { @@ -28,9 +21,8 @@ impl crate::row::private_row::Sealed for MySqlRow {} impl Row for MySqlRow { type Database = MySql; - #[inline] - fn len(&self) -> usize { - self.row.len() + fn columns(&self) -> &[MySqlColumn] { + &self.columns } fn try_get_raw(&self, index: I) -> Result, Error> diff --git a/sqlx-core/src/mysql/type_info.rs b/sqlx-core/src/mysql/type_info.rs index f9ac50d3..b7dce12f 100644 --- a/sqlx-core/src/mysql/type_info.rs +++ b/sqlx-core/src/mysql/type_info.rs @@ -44,15 +44,11 @@ impl MySqlTypeInfo { } } - pub(crate) fn from_column(column: &ColumnDefinition) -> Option { - if column.r#type == ColumnType::Null { - None - } else { - Some(Self { - r#type: column.r#type, - flags: column.flags, - char_set: column.char_set, - }) + pub(crate) fn from_column(column: &ColumnDefinition) -> Self { + Self { + r#type: column.r#type, + flags: column.flags, + char_set: column.char_set, } } } diff --git a/sqlx-core/src/mysql/value.rs b/sqlx-core/src/mysql/value.rs index e56e75af..82b2e903 100644 --- a/sqlx-core/src/mysql/value.rs +++ b/sqlx-core/src/mysql/value.rs @@ -1,12 +1,9 @@ -use std::borrow::Cow; -use std::str::from_utf8; - -use bytes::Bytes; - use crate::error::{BoxDynError, UnexpectedNullError}; use crate::mysql::protocol::text::ColumnType; use crate::mysql::{MySql, MySqlTypeInfo}; use crate::value::{Value, ValueRef}; +use bytes::Bytes; +use std::str::from_utf8; #[derive(Debug, Clone, Copy)] #[repr(u8)] @@ -19,7 +16,7 @@ pub enum MySqlValueFormat { #[derive(Clone)] pub struct MySqlValue { value: Option, - type_info: Option, + type_info: MySqlTypeInfo, format: MySqlValueFormat, } @@ -28,7 +25,7 @@ pub struct MySqlValue { pub struct MySqlValueRef<'r> { pub(crate) value: Option<&'r [u8]>, pub(crate) row: Option<&'r Bytes>, - pub(crate) type_info: Option, + pub(crate) type_info: MySqlTypeInfo, pub(crate) format: MySqlValueFormat, } @@ -61,12 +58,12 @@ impl Value for MySqlValue { } } - fn type_info(&self) -> Option> { - self.type_info.as_ref().map(Cow::Borrowed) + fn type_info(&self) -> &MySqlTypeInfo { + &self.type_info } fn is_null(&self) -> bool { - is_null(self.value.as_deref(), self.type_info.as_ref()) + is_null(self.value.as_deref(), &self.type_info) } } @@ -89,18 +86,18 @@ impl<'r> ValueRef<'r> for MySqlValueRef<'r> { } } - fn type_info(&self) -> Option> { - self.type_info.as_ref().map(Cow::Borrowed) + fn type_info(&self) -> &MySqlTypeInfo { + &self.type_info } #[inline] fn is_null(&self) -> bool { - is_null(self.value.as_deref(), self.type_info.as_ref()) + is_null(self.value.as_deref(), &self.type_info) } } -fn is_null(value: Option<&[u8]>, ty: Option<&MySqlTypeInfo>) -> bool { - if let (Some(value), Some(ty)) = (value, ty) { +fn is_null(value: Option<&[u8]>, ty: &MySqlTypeInfo) -> bool { + if let Some(value) = value { // zero dates and date times should be treated the same as NULL if matches!( ty.r#type, diff --git a/sqlx-core/src/pool/executor.rs b/sqlx-core/src/pool/executor.rs index ddb31f0d..7c8a6717 100644 --- a/sqlx-core/src/pool/executor.rs +++ b/sqlx-core/src/pool/executor.rs @@ -4,10 +4,10 @@ use futures_core::stream::BoxStream; use futures_util::TryStreamExt; use crate::database::Database; -use crate::describe::Describe; use crate::error::Error; use crate::executor::{Execute, Executor}; use crate::pool::Pool; +use crate::statement::StatementInfo; impl<'p, DB: Database> Executor<'p> for &'_ Pool where @@ -52,7 +52,7 @@ where fn describe<'e, 'q: 'e, E: 'q>( self, query: E, - ) -> BoxFuture<'e, Result, Error>> + ) -> BoxFuture<'e, Result, Error>> where E: Execute<'q, Self::Database>, { @@ -103,7 +103,7 @@ macro_rules! impl_executor_for_pool_connection { query: E, ) -> futures_core::future::BoxFuture< 'e, - Result, crate::error::Error>, + Result, crate::error::Error>, > where 'c: 'e, diff --git a/sqlx-core/src/postgres/column.rs b/sqlx-core/src/postgres/column.rs new file mode 100644 index 00000000..64ea6067 --- /dev/null +++ b/sqlx-core/src/postgres/column.rs @@ -0,0 +1,31 @@ +use crate::column::Column; +use crate::ext::ustr::UStr; +use crate::postgres::{PgTypeInfo, Postgres}; + +#[derive(Debug, Clone)] +#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] +pub struct PgColumn { + pub(crate) ordinal: usize, + pub(crate) name: UStr, + pub(crate) type_info: PgTypeInfo, + pub(crate) relation_id: Option, + pub(crate) relation_attribute_no: Option, +} + +impl crate::column::private_column::Sealed for PgColumn {} + +impl Column for PgColumn { + type Database = Postgres; + + fn ordinal(&self) -> usize { + self.ordinal + } + + fn name(&self) -> &str { + &*self.name + } + + fn type_info(&self) -> &PgTypeInfo { + &self.type_info + } +} diff --git a/sqlx-core/src/postgres/connection/describe.rs b/sqlx-core/src/postgres/connection/describe.rs index cfbb1cb6..07424173 100644 --- a/sqlx-core/src/postgres/connection/describe.rs +++ b/sqlx-core/src/postgres/connection/describe.rs @@ -1,20 +1,15 @@ -use std::fmt::Write; -use std::mem; -use std::sync::Arc; - -use futures_util::{stream, StreamExt, TryStreamExt}; -use hashbrown::HashMap; - -use crate::describe::Column; use crate::error::Error; use crate::ext::ustr::UStr; use crate::postgres::message::{ParameterDescription, RowDescription}; -use crate::postgres::row::PgColumn; use crate::postgres::type_info::{PgCustomType, PgType, PgTypeKind}; -use crate::postgres::{PgArguments, PgConnection, PgTypeInfo, Postgres}; -use crate::query_as::{query_as, query_as_with}; -use crate::query_scalar::query_scalar; +use crate::postgres::{PgArguments, PgColumn, PgConnection, PgTypeInfo}; +use crate::query_as::query_as; +use crate::query_scalar::{query_scalar, query_scalar_with}; use futures_core::future::BoxFuture; +use hashbrown::HashMap; +use std::fmt::Write; +use std::mem; +use std::sync::Arc; impl PgConnection { pub(super) async fn handle_row_description( @@ -52,6 +47,7 @@ impl PgConnection { .await?; let column = PgColumn { + ordinal: index, name: name.clone(), type_info, relation_id: field.relation_id, @@ -74,11 +70,11 @@ impl PgConnection { pub(super) async fn handle_parameter_description( &mut self, desc: ParameterDescription, - ) -> Result>, Error> { + ) -> Result, Error> { let mut params = Vec::with_capacity(desc.types.len()); for ty in desc.types { - params.push(Some(self.maybe_fetch_type_info_by_oid(ty, true).await?)); + params.push(self.maybe_fetch_type_info_by_oid(ty, true).await?); } Ok(params) @@ -253,15 +249,15 @@ SELECT oid FROM pg_catalog.pg_type WHERE typname ILIKE $1 Ok(oid) } - pub(crate) async fn map_result_columns( + pub(crate) async fn get_nullable_for_columns( &mut self, columns: &[PgColumn], - ) -> Result>, Error> { + ) -> Result>, Error> { if columns.is_empty() { return Ok(vec![]); } - let mut query = String::from("SELECT col.idx, pg_attribute.attnotnull FROM (VALUES "); + let mut query = String::from("SELECT NOT pg_attribute.attnotnull FROM (VALUES "); let mut args = PgArguments::default(); for (i, (column, bind)) in columns.iter().zip((1..).step_by(3)).enumerate() { @@ -291,22 +287,8 @@ SELECT oid FROM pg_catalog.pg_type WHERE typname ILIKE $1 ORDER BY col.idx", ); - query_as_with::<_, (i32, Option), _>(&query, args) - .fetch(self) - .zip(stream::iter(columns.iter().enumerate())) - .map(|(row, (field_idx, column))| -> Result, Error> { - let (idx, not_null) = row?; - - // NOTE: it should be impossible for this to fire - debug_assert_eq!(idx, field_idx as i32); - - Ok(Column { - name: column.name.to_string(), - type_info: Some(column.type_info.clone()), - not_null, - }) - }) - .try_collect() + query_scalar_with::<_, Option, _>(&query, args) + .fetch_all(self) .await } } diff --git a/sqlx-core/src/postgres/connection/executor.rs b/sqlx-core/src/postgres/connection/executor.rs index 5cb7eeed..4af0a9d1 100644 --- a/sqlx-core/src/postgres/connection/executor.rs +++ b/sqlx-core/src/postgres/connection/executor.rs @@ -5,7 +5,6 @@ use futures_core::Stream; use futures_util::{pin_mut, TryStreamExt}; use std::sync::Arc; -use crate::describe::Describe; use crate::error::Error; use crate::executor::{Execute, Executor}; use crate::postgres::message::{ @@ -14,6 +13,7 @@ use crate::postgres::message::{ }; use crate::postgres::type_info::PgType; use crate::postgres::{PgArguments, PgConnection, PgRow, PgValueFormat, Postgres}; +use crate::statement::StatementInfo; async fn prepare( conn: &mut PgConnection, @@ -321,7 +321,7 @@ impl<'c> Executor<'c> for &'c mut PgConnection { fn describe<'e, 'q: 'e, E: 'q>( self, query: E, - ) -> BoxFuture<'e, Result, Error>> + ) -> BoxFuture<'e, Result, Error>> where 'c: 'e, E: Execute<'q, Self::Database>, @@ -343,10 +343,14 @@ impl<'c> Executor<'c> for &'c mut PgConnection { self.handle_row_description(rows, true).await?; - let columns = self.scratch_row_columns.clone(); - let columns = self.map_result_columns(&columns).await?; + let columns = (&*self.scratch_row_columns).clone(); + let nullable = self.get_nullable_for_columns(&columns).await?; - Ok(Describe { params, columns }) + Ok(StatementInfo { + columns, + nullable, + parameters: Some(Either::Left(params)), + }) }) } } diff --git a/sqlx-core/src/postgres/connection/mod.rs b/sqlx-core/src/postgres/connection/mod.rs index d4f6ea6a..31475708 100644 --- a/sqlx-core/src/postgres/connection/mod.rs +++ b/sqlx-core/src/postgres/connection/mod.rs @@ -15,8 +15,7 @@ use crate::postgres::connection::stream::PgStream; use crate::postgres::message::{ Close, Flush, Message, MessageFormat, ReadyForQuery, Terminate, TransactionStatus, }; -use crate::postgres::row::PgColumn; -use crate::postgres::{PgConnectOptions, PgTypeInfo, Postgres}; +use crate::postgres::{PgColumn, PgConnectOptions, PgTypeInfo, Postgres}; pub(crate) mod describe; mod establish; diff --git a/sqlx-core/src/postgres/database.rs b/sqlx-core/src/postgres/database.rs index 8b3af756..562dcda0 100644 --- a/sqlx-core/src/postgres/database.rs +++ b/sqlx-core/src/postgres/database.rs @@ -1,7 +1,9 @@ use crate::database::{Database, HasArguments, HasStatementCache, HasValueRef}; use crate::postgres::arguments::PgArgumentBuffer; use crate::postgres::value::{PgValue, PgValueRef}; -use crate::postgres::{PgArguments, PgConnection, PgRow, PgTransactionManager, PgTypeInfo}; +use crate::postgres::{ + PgArguments, PgColumn, PgConnection, PgRow, PgTransactionManager, PgTypeInfo, +}; /// PostgreSQL database driver. #[derive(Debug)] @@ -14,6 +16,8 @@ impl Database for Postgres { type Row = PgRow; + type Column = PgColumn; + type TypeInfo = PgTypeInfo; type Value = PgValue; diff --git a/sqlx-core/src/postgres/listener.rs b/sqlx-core/src/postgres/listener.rs index 90303c1d..d80b5541 100644 --- a/sqlx-core/src/postgres/listener.rs +++ b/sqlx-core/src/postgres/listener.rs @@ -6,12 +6,12 @@ use futures_channel::mpsc; use futures_core::future::BoxFuture; use futures_core::stream::{BoxStream, Stream}; -use crate::describe::Describe; use crate::error::Error; use crate::executor::{Execute, Executor}; use crate::pool::{Pool, PoolConnection}; use crate::postgres::message::{MessageFormat, Notification}; use crate::postgres::{PgConnection, PgRow, Postgres}; +use crate::statement::StatementInfo; use either::Either; /// A stream of asynchronous notifications from Postgres. @@ -221,7 +221,7 @@ impl<'c> Executor<'c> for &'c mut PgListener { fn describe<'e, 'q: 'e, E: 'q>( self, query: E, - ) -> BoxFuture<'e, Result, Error>> + ) -> BoxFuture<'e, Result, Error>> where 'c: 'e, E: Execute<'q, Self::Database>, diff --git a/sqlx-core/src/postgres/message/data_row.rs b/sqlx-core/src/postgres/message/data_row.rs index 3ec28103..a7377f68 100644 --- a/sqlx-core/src/postgres/message/data_row.rs +++ b/sqlx-core/src/postgres/message/data_row.rs @@ -18,11 +18,6 @@ pub struct DataRow { } impl DataRow { - #[inline] - pub(crate) fn len(&self) -> usize { - self.values.len() - } - #[inline] pub(crate) fn get(&self, index: usize) -> Option<&'_ [u8]> { self.values[index] diff --git a/sqlx-core/src/postgres/mod.rs b/sqlx-core/src/postgres/mod.rs index 60c335a4..13a9aa5e 100644 --- a/sqlx-core/src/postgres/mod.rs +++ b/sqlx-core/src/postgres/mod.rs @@ -1,6 +1,7 @@ //! **PostgreSQL** database driver. mod arguments; +mod column; mod connection; mod database; mod error; @@ -15,6 +16,7 @@ pub mod types; mod value; pub use arguments::{PgArgumentBuffer, PgArguments}; +pub use column::PgColumn; pub use connection::PgConnection; pub use database::Postgres; pub use error::{PgDatabaseError, PgErrorPosition}; diff --git a/sqlx-core/src/postgres/row.rs b/sqlx-core/src/postgres/row.rs index 63318df7..5cfbe864 100644 --- a/sqlx-core/src/postgres/row.rs +++ b/sqlx-core/src/postgres/row.rs @@ -6,19 +6,9 @@ use crate::error::Error; use crate::ext::ustr::UStr; use crate::postgres::message::DataRow; use crate::postgres::value::PgValueFormat; -use crate::postgres::{PgTypeInfo, PgValueRef, Postgres}; +use crate::postgres::{PgColumn, PgValueRef, Postgres}; use crate::row::{ColumnIndex, Row}; -// Result column of a prepared statement -// See RowDescription/Field for more information -#[derive(Debug, Clone)] -pub(crate) struct PgColumn { - pub(crate) name: UStr, - pub(crate) type_info: PgTypeInfo, - pub(crate) relation_id: Option, - pub(crate) relation_attribute_no: Option, -} - /// Implementation of [`Row`] for PostgreSQL. pub struct PgRow { pub(crate) data: DataRow, @@ -32,9 +22,8 @@ impl crate::row::private_row::Sealed for PgRow {} impl Row for PgRow { type Database = Postgres; - #[inline] - fn len(&self) -> usize { - self.data.len() + fn columns(&self) -> &[PgColumn] { + &self.columns } fn try_get_raw(&self, index: I) -> Result, Error> diff --git a/sqlx-core/src/postgres/type_info.rs b/sqlx-core/src/postgres/type_info.rs index c5c351c0..963aa0ab 100644 --- a/sqlx-core/src/postgres/type_info.rs +++ b/sqlx-core/src/postgres/type_info.rs @@ -970,8 +970,19 @@ impl PartialEq for PgType { // If there are OIDs available, use OIDs to perform a direct match a == b } else { - // Otherwise, perform a match on the name - self.name().eq_ignore_ascii_case(other.name()) + if (matches!(self, PgType::DeclareWithName(_)) + && matches!(other, PgType::DeclareWithOid(_))) + || (matches!(other, PgType::DeclareWithName(_)) + && matches!(self, PgType::DeclareWithOid(_))) + { + // One is a declare-with-name and the other is a declare-with-id + // This only occurs in the TEXT protocol with custom types + // Just opt-out of type checking here + true + } else { + // Otherwise, perform a match on the name + self.name().eq_ignore_ascii_case(other.name()) + } } } } diff --git a/sqlx-core/src/postgres/value.rs b/sqlx-core/src/postgres/value.rs index 00a34ac2..24ab6ba3 100644 --- a/sqlx-core/src/postgres/value.rs +++ b/sqlx-core/src/postgres/value.rs @@ -1,4 +1,3 @@ -use std::borrow::Cow; use std::str::from_utf8; use bytes::{Buf, Bytes}; @@ -81,14 +80,8 @@ impl Value for PgValue { } } - fn type_info(&self) -> Option> { - if self.format == PgValueFormat::Text { - // For TEXT encoding the type defined on the value is unreliable - // We don't even bother to return it so type checking is implicitly opted-out - None - } else { - Some(Cow::Borrowed(&self.type_info)) - } + fn type_info(&self) -> &PgTypeInfo { + &self.type_info } fn is_null(&self) -> bool { @@ -115,14 +108,8 @@ impl<'r> ValueRef<'r> for PgValueRef<'r> { } } - fn type_info(&self) -> Option> { - if self.format == PgValueFormat::Text { - // For TEXT encoding the type defined on the value is unreliable - // We don't even bother to return it so type checking is implicitly opted-out - None - } else { - Some(Cow::Borrowed(&self.type_info)) - } + fn type_info(&self) -> &PgTypeInfo { + &self.type_info } fn is_null(&self) -> bool { diff --git a/sqlx-core/src/row.rs b/sqlx-core/src/row.rs index 42971346..488cff30 100644 --- a/sqlx-core/src/row.rs +++ b/sqlx-core/src/row.rs @@ -75,7 +75,37 @@ pub trait Row: private_row::Sealed + Unpin + Send + Sync + 'static { } /// Returns the number of columns in this row. - fn len(&self) -> usize; + #[inline] + fn len(&self) -> usize { + self.columns().len() + } + + /// Gets the column information at `index`. + /// + /// A string index can be used to access a column by name and a `usize` index + /// can be used to access a column by position. + /// + /// # Panics + /// + /// Panics if `index` is out of bounds. + /// See [`try_column`](#method.try_column) for a non-panicking version. + fn column(&self, index: I) -> &::Column + where + I: ColumnIndex, + { + self.try_column(index).unwrap() + } + + /// Gets the column information at `index` or `None` if out of bounds. + fn try_column(&self, index: I) -> Result<&::Column, Error> + where + I: ColumnIndex, + { + Ok(&self.columns()[index.index(self)?]) + } + + /// Gets all columns in this statement. + fn columns(&self) -> &[::Column]; /// Index into the database row and decode a single value. /// @@ -139,15 +169,13 @@ pub trait Row: private_row::Sealed + Unpin + Send + Sync + 'static { let value = self.try_get_raw(&index)?; if !value.is_null() { - if let Some(ty) = value.type_info() { - // NOTE: we opt-out of asserting the type equivalency of NULL because of the - // high false-positive rate (e.g., `NULL` in Postgres is `TEXT`). - if !T::compatible(&ty) { - return Err(Error::ColumnDecode { - index: format!("{:?}", index), - source: mismatched_types::(&ty), - }); - } + let ty = value.type_info(); + + if !T::compatible(&ty) { + return Err(Error::ColumnDecode { + index: format!("{:?}", index), + source: mismatched_types::(&ty), + }); } } @@ -180,6 +208,7 @@ pub trait Row: private_row::Sealed + Unpin + Send + Sync + 'static { T: Decode<'r, Self::Database>, { let value = self.try_get_raw(&index)?; + T::decode(value).map_err(|source| Error::ColumnDecode { index: format!("{:?}", index), source, diff --git a/sqlx-core/src/sqlite/arguments.rs b/sqlx-core/src/sqlite/arguments.rs index cff26143..c82a0882 100644 --- a/sqlx-core/src/sqlite/arguments.rs +++ b/sqlx-core/src/sqlite/arguments.rs @@ -70,11 +70,13 @@ impl SqliteArguments<'_> { }; if n > self.values.len() { - return Err(err_protocol!( - "wrong number of parameters, parameter ?{} requested but have only {}", - n, - self.values.len() - )); + // SQLite treats unbound variables as NULL + // we reproduce this here + // If you are reading this and think this should be an error, open an issue and we can + // discuss configuring this somehow + // Note that the query macros have a different way of enforcing + // argument arity + break; } self.values[n - 1].bind(handle, param_i)?; diff --git a/sqlx-core/src/sqlite/column.rs b/sqlx-core/src/sqlite/column.rs new file mode 100644 index 00000000..b5dfc3fc --- /dev/null +++ b/sqlx-core/src/sqlite/column.rs @@ -0,0 +1,29 @@ +use crate::column::Column; +use crate::ext::ustr::UStr; +use crate::sqlite::{Sqlite, SqliteTypeInfo}; + +#[derive(Debug, Clone)] +#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] +pub struct SqliteColumn { + pub(crate) name: UStr, + pub(crate) ordinal: usize, + pub(crate) type_info: SqliteTypeInfo, +} + +impl crate::column::private_column::Sealed for SqliteColumn {} + +impl Column for SqliteColumn { + type Database = Sqlite; + + fn ordinal(&self) -> usize { + self.ordinal + } + + fn name(&self) -> &str { + &*self.name + } + + fn type_info(&self) -> &SqliteTypeInfo { + &self.type_info + } +} diff --git a/sqlx-core/src/sqlite/connection/describe.rs b/sqlx-core/src/sqlite/connection/describe.rs index 9a0aca87..1e627cb4 100644 --- a/sqlx-core/src/sqlite/connection/describe.rs +++ b/sqlx-core/src/sqlite/connection/describe.rs @@ -1,42 +1,31 @@ -use crate::describe::{Column, Describe}; use crate::error::Error; use crate::sqlite::connection::explain::explain; use crate::sqlite::statement::SqliteStatement; use crate::sqlite::type_info::DataType; -use crate::sqlite::{Sqlite, SqliteConnection, SqliteTypeInfo}; +use crate::sqlite::{Sqlite, SqliteColumn, SqliteConnection}; +use crate::statement::StatementInfo; +use either::Either; use futures_core::future::BoxFuture; +use std::convert::identity; -pub(super) async fn describe( - conn: &mut SqliteConnection, - query: &str, -) -> Result, Error> { - describe_with(conn, query, vec![]).await -} - -pub(super) fn describe_with<'c: 'e, 'q: 'e, 'e>( +pub(super) fn describe<'c: 'e, 'q: 'e, 'e>( conn: &'c mut SqliteConnection, query: &'q str, - fallback: Vec, -) -> BoxFuture<'e, Result, Error>> { +) -> BoxFuture<'e, Result, Error>> { Box::pin(async move { // describing a statement from SQLite can be involved // each SQLx statement is comprised of multiple SQL statements - let SqliteConnection { - ref mut handle, - ref worker, - .. - } = conn; - - let statement = SqliteStatement::prepare(handle, query, false); + let statement = SqliteStatement::prepare(&mut conn.handle, query, false); let mut columns = Vec::new(); + let mut nullable = Vec::new(); let mut num_params = 0; let mut statement = statement?; // we start by finding the first statement that *can* return results - while let Some((statement, _)) = statement.execute()? { + while let Some((statement, ..)) = statement.execute()? { num_params += statement.bind_parameter_count(); let mut stepped = false; @@ -50,6 +39,20 @@ pub(super) fn describe_with<'c: 'e, 'q: 'e, 'e>( // next we try to use [column_decltype] to inspect the type of each column columns.reserve(num); + // as a last resort, we explain the original query and attempt to + // infer what would the expression types be as a fallback + // to [column_decltype] + + // if explain.. fails, ignore the failure and we'll have no fallback + let (fallback, fallback_nullable) = match explain(conn, statement.sql()).await { + Ok(v) => v, + Err(err) => { + log::debug!("describe: explain introspection failed: {}", err); + + (vec![], vec![]) + } + }; + for col in 0..num { let name = statement.column_name(col).to_owned(); @@ -59,32 +62,18 @@ pub(super) fn describe_with<'c: 'e, 'q: 'e, 'e>( // if that fails, we back up and attempt to step the statement // once *if* its read-only and then use [column_type] as a // fallback to [column_decltype] - if !stepped && statement.read_only() && fallback.is_empty() { + if !stepped && statement.read_only() { stepped = true; - worker.execute(statement); - worker.wake(); + conn.worker.execute(statement); + conn.worker.wake(); - let _ = worker.step(statement).await?; + let _ = conn.worker.step(statement).await; } let mut ty = statement.column_type_info(col); if ty.0 == DataType::Null { - if fallback.is_empty() { - // this will _still_ fail if there are no actual rows to return - // this happens more often than not for the macros as we tell - // users to execute against an empty database - - // as a last resort, we explain the original query and attempt to - // infer what would the expression types be as a fallback - // to [column_decltype] - - let fallback = explain(conn, statement.sql()).await?; - - return describe_with(conn, query, fallback).await; - } - if let Some(fallback) = fallback.get(col).cloned() { ty = fallback; } @@ -93,21 +82,23 @@ pub(super) fn describe_with<'c: 'e, 'q: 'e, 'e>( ty }; - let not_null = statement.column_not_null(col)?; + nullable.push(statement.column_nullable(col)?.or_else(|| { + // if we do not *know* if this is nullable, check the EXPLAIN fallback + fallback_nullable.get(col).copied().and_then(identity) + })); - columns.push(Column { - name, - type_info: Some(type_info), - not_null, + columns.push(SqliteColumn { + name: name.into(), + type_info, + ordinal: col, }); } } - // println!("describe ->> {:#?}", columns); - - Ok(Describe { + Ok(StatementInfo { columns, - params: vec![None; num_params], + parameters: Some(Either::Right(num_params)), + nullable, }) }) } diff --git a/sqlx-core/src/sqlite/connection/executor.rs b/sqlx-core/src/sqlite/connection/executor.rs index 43692546..c2e297e2 100644 --- a/sqlx-core/src/sqlite/connection/executor.rs +++ b/sqlx-core/src/sqlite/connection/executor.rs @@ -7,14 +7,14 @@ use futures_util::{FutureExt, TryStreamExt}; use hashbrown::HashMap; use crate::common::StatementCache; -use crate::describe::Describe; use crate::error::Error; use crate::executor::{Execute, Executor}; use crate::ext::ustr::UStr; use crate::sqlite::connection::describe::describe; use crate::sqlite::connection::ConnectionHandle; use crate::sqlite::statement::{SqliteStatement, StatementHandle}; -use crate::sqlite::{Sqlite, SqliteArguments, SqliteConnection, SqliteRow}; +use crate::sqlite::{Sqlite, SqliteArguments, SqliteColumn, SqliteConnection, SqliteRow}; +use crate::statement::StatementInfo; fn prepare<'a>( conn: &mut ConnectionHandle, @@ -59,16 +59,26 @@ fn bind( fn emplace_row_metadata( statement: &StatementHandle, + columns: &mut Vec, column_names: &mut HashMap, ) -> Result<(), Error> { + columns.clear(); column_names.clear(); let num = statement.column_count(); column_names.reserve(num); + columns.reserve(num); for i in 0..num { let name: UStr = statement.column_name(i).to_owned().into(); + let type_info = statement.column_type_info(i); + + columns.push(SqliteColumn { + ordinal: i, + name: name.clone(), + type_info, + }); column_names.insert(name, i); } @@ -106,7 +116,9 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection { // bind arguments, if any, to the statement bind(&mut stmt, arguments)?; - while let Some((handle, last_row_values)) = stmt.execute()? { + while let Some((handle, columns, last_row_values)) = stmt.execute()? { + let mut have_metadata = false; + // tell the worker about the new statement worker.execute(handle); @@ -114,17 +126,24 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection { // the worker parks its thread on async-std when not in use worker.wake(); - emplace_row_metadata( - handle, - Arc::make_mut(scratch_row_column_names), - )?; - loop { // save the rows from the _current_ position on the statement // and send them to the still-live row object - SqliteRow::inflate_if_needed(handle, last_row_values.take()); + SqliteRow::inflate_if_needed(handle, &*columns, last_row_values.take()); - match worker.step(handle).await? { + let s = worker.step(handle).await?; + + if !have_metadata { + have_metadata = true; + + emplace_row_metadata( + handle, + Arc::make_mut(columns), + Arc::make_mut(scratch_row_column_names), + )?; + } + + match s { Either::Left(changes) => { r#yield!(Either::Left(changes)); @@ -134,6 +153,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection { Either::Right(()) => { let (row, weak_values_ref) = SqliteRow::current( *handle, + columns, scratch_row_column_names ); @@ -172,7 +192,10 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection { } #[doc(hidden)] - fn describe<'e, 'q: 'e, E: 'q>(self, query: E) -> BoxFuture<'e, Result, Error>> + fn describe<'e, 'q: 'e, E: 'q>( + self, + query: E, + ) -> BoxFuture<'e, Result, Error>> where 'c: 'e, E: Execute<'q, Self::Database>, diff --git a/sqlx-core/src/sqlite/connection/explain.rs b/sqlx-core/src/sqlite/connection/explain.rs index 20a5c170..a19cd939 100644 --- a/sqlx-core/src/sqlite/connection/explain.rs +++ b/sqlx-core/src/sqlite/connection/explain.rs @@ -3,21 +3,34 @@ use crate::query_as::query_as; use crate::sqlite::type_info::DataType; use crate::sqlite::{SqliteConnection, SqliteTypeInfo}; use hashbrown::HashMap; +use std::str::from_utf8; +// affinity +const SQLITE_AFF_NONE: u8 = 0x40; /* '@' */ +const SQLITE_AFF_BLOB: u8 = 0x41; /* 'A' */ +const SQLITE_AFF_TEXT: u8 = 0x42; /* 'B' */ +const SQLITE_AFF_NUMERIC: u8 = 0x43; /* 'C' */ +const SQLITE_AFF_INTEGER: u8 = 0x44; /* 'D' */ +const SQLITE_AFF_REAL: u8 = 0x45; /* 'E' */ + +// opcodes const OP_INIT: &str = "Init"; const OP_GOTO: &str = "Goto"; const OP_COLUMN: &str = "Column"; const OP_AGG_STEP: &str = "AggStep"; +const OP_FUNCTION: &str = "Function"; const OP_MOVE: &str = "Move"; const OP_COPY: &str = "Copy"; const OP_SCOPY: &str = "SCopy"; const OP_INT_COPY: &str = "IntCopy"; +const OP_CAST: &str = "Cast"; const OP_STRING8: &str = "String8"; const OP_INT64: &str = "Int64"; const OP_INTEGER: &str = "Integer"; const OP_REAL: &str = "Real"; const OP_NOT: &str = "Not"; const OP_BLOB: &str = "Blob"; +const OP_VARIABLE: &str = "Variable"; const OP_COUNT: &str = "Count"; const OP_ROWID: &str = "Rowid"; const OP_OR: &str = "Or"; @@ -34,7 +47,19 @@ const OP_REMAINDER: &str = "Remainder"; const OP_CONCAT: &str = "Concat"; const OP_RESULT_ROW: &str = "ResultRow"; -fn to_type(op: &str) -> DataType { +fn affinity_to_type(affinity: u8) -> DataType { + match affinity { + SQLITE_AFF_BLOB => DataType::Blob, + SQLITE_AFF_INTEGER => DataType::Int64, + SQLITE_AFF_NUMERIC => DataType::Numeric, + SQLITE_AFF_REAL => DataType::Float, + SQLITE_AFF_TEXT => DataType::Text, + + SQLITE_AFF_NONE | _ => DataType::Null, + } +} + +fn opcode_to_type(op: &str) -> DataType { match op { OP_REAL => DataType::Float, OP_BLOB => DataType::Blob, @@ -48,11 +73,12 @@ fn to_type(op: &str) -> DataType { pub(super) async fn explain( conn: &mut SqliteConnection, query: &str, -) -> Result, Error> { +) -> Result<(Vec, Vec>), Error> { let mut r = HashMap::::with_capacity(6); + let mut n = HashMap::::with_capacity(6); let program = - query_as::<_, (i64, String, i64, i64, i64, String)>(&*format!("EXPLAIN {}", query)) + query_as::<_, (i64, String, i64, i64, i64, Vec)>(&*format!("EXPLAIN {}", query)) .fetch_all(&mut *conn) .await?; @@ -78,15 +104,46 @@ pub(super) async fn explain( OP_COLUMN => { // r[p3] = r.insert(p3, DataType::Null); + n.insert(p3, true); + } + + OP_VARIABLE => { + // r[p2] = + r.insert(p2, DataType::Null); + n.insert(p3, true); + } + + OP_FUNCTION => { + // r[p1] = func( _ ) + match from_utf8(p4).map_err(Error::protocol)? { + "last_insert_rowid(0)" => { + // last_insert_rowid() -> INTEGER + r.insert(p3, DataType::Int64); + n.insert(p3, false); + } + + _ => {} + } } OP_AGG_STEP => { + let p4 = from_utf8(p4).map_err(Error::protocol)?; + if p4.starts_with("count(") { // count(_) -> INTEGER r.insert(p3, DataType::Int64); + n.insert(p3, false); } else if let Some(v) = r.get(&p2).copied() { // r[p3] = AGG ( r[p2] ) r.insert(p3, v); + n.insert(p3, n.get(&p2).copied().unwrap_or(true)); + } + } + + OP_CAST => { + // affinity(r[p1]) + if let Some(v) = r.get_mut(&p1) { + *v = affinity_to_type(p2 as u8); } } @@ -94,18 +151,21 @@ pub(super) async fn explain( // r[p2] = r[p1] if let Some(v) = r.get(&p1).copied() { r.insert(p2, v); + n.insert(p2, n.get(&p1).copied().unwrap_or(true)); } } OP_OR | OP_AND | OP_BLOB | OP_COUNT | OP_REAL | OP_STRING8 | OP_INTEGER | OP_ROWID => { // r[p2] = - r.insert(p2, to_type(&opcode)); + r.insert(p2, opcode_to_type(&opcode)); + n.insert(p2, false); } OP_NOT => { // r[p2] = NOT r[p1] if let Some(a) = r.get(&p1).copied() { r.insert(p2, a); + n.insert(p2, n.get(&p1).copied().unwrap_or(true)); } } @@ -127,16 +187,40 @@ pub(super) async fn explain( _ => {} } + + match (n.get(&p1).copied(), n.get(&p2).copied()) { + (Some(a), Some(b)) => { + n.insert(p3, a || b); + } + + (None, Some(b)) => { + n.insert(p3, b); + } + + (Some(a), None) => { + n.insert(p3, a); + } + + _ => {} + } } OP_RESULT_ROW => { // output = r[p1 .. p1 + p2] let mut output = Vec::with_capacity(p2 as usize); + let mut nullable = Vec::with_capacity(p2 as usize); + for i in p1..p1 + p2 { output.push(SqliteTypeInfo(r.remove(&i).unwrap_or(DataType::Null))); + + nullable.push(if n.remove(&i).unwrap_or(true) { + None + } else { + Some(false) + }); } - return Ok(output); + return Ok((output, nullable)); } _ => { @@ -149,5 +233,5 @@ pub(super) async fn explain( } // no rows - Ok(vec![]) + Ok((vec![], vec![])) } diff --git a/sqlx-core/src/sqlite/database.rs b/sqlx-core/src/sqlite/database.rs index fa00660f..de627f89 100644 --- a/sqlx-core/src/sqlite/database.rs +++ b/sqlx-core/src/sqlite/database.rs @@ -1,7 +1,7 @@ use crate::database::{Database, HasArguments, HasStatementCache, HasValueRef}; use crate::sqlite::{ - SqliteArgumentValue, SqliteArguments, SqliteConnection, SqliteRow, SqliteTransactionManager, - SqliteTypeInfo, SqliteValue, SqliteValueRef, + SqliteArgumentValue, SqliteArguments, SqliteColumn, SqliteConnection, SqliteRow, + SqliteTransactionManager, SqliteTypeInfo, SqliteValue, SqliteValueRef, }; /// Sqlite database driver. @@ -15,6 +15,8 @@ impl Database for Sqlite { type Row = SqliteRow; + type Column = SqliteColumn; + type TypeInfo = SqliteTypeInfo; type Value = SqliteValue; diff --git a/sqlx-core/src/sqlite/mod.rs b/sqlx-core/src/sqlite/mod.rs index ff891dc8..1ad42a9a 100644 --- a/sqlx-core/src/sqlite/mod.rs +++ b/sqlx-core/src/sqlite/mod.rs @@ -6,6 +6,7 @@ #![allow(unsafe_code)] mod arguments; +mod column; mod connection; mod database; mod error; @@ -18,6 +19,7 @@ pub mod types; mod value; pub use arguments::{SqliteArgumentValue, SqliteArguments}; +pub use column::SqliteColumn; pub use connection::SqliteConnection; pub use database::Sqlite; pub use error::SqliteError; diff --git a/sqlx-core/src/sqlite/row.rs b/sqlx-core/src/sqlite/row.rs index 5954d912..6c3e512f 100644 --- a/sqlx-core/src/sqlite/row.rs +++ b/sqlx-core/src/sqlite/row.rs @@ -9,7 +9,7 @@ use crate::error::Error; use crate::ext::ustr::UStr; use crate::row::{ColumnIndex, Row}; use crate::sqlite::statement::StatementHandle; -use crate::sqlite::{Sqlite, SqliteValue, SqliteValueRef}; +use crate::sqlite::{Sqlite, SqliteColumn, SqliteValue, SqliteValueRef}; /// Implementation of [`Row`] for SQLite. pub struct SqliteRow { @@ -25,6 +25,7 @@ pub struct SqliteRow { pub(crate) values: Arc>, pub(crate) num_values: usize, + pub(crate) columns: Arc>, pub(crate) column_names: Arc>, } @@ -45,6 +46,7 @@ impl SqliteRow { // to increment the statement with [step] pub(crate) fn current( statement: StatementHandle, + columns: &Arc>, column_names: &Arc>, ) -> (Self, Weak>) { let values = Arc::new(AtomicPtr::new(null_mut())); @@ -55,6 +57,7 @@ impl SqliteRow { statement, values, num_values: size, + columns: Arc::clone(columns), column_names: Arc::clone(column_names), }; @@ -63,12 +66,20 @@ impl SqliteRow { // inflates this Row into memory as a list of owned, protected SQLite value objects // this is called by the - pub(crate) fn inflate(statement: &StatementHandle, values_ref: &AtomicPtr) { + pub(crate) fn inflate( + statement: &StatementHandle, + columns: &[SqliteColumn], + values_ref: &AtomicPtr, + ) { let size = statement.column_count(); let mut values = Vec::with_capacity(size); for i in 0..size { - values.push(statement.column_value(i)); + values.push(unsafe { + let raw = statement.column_value(i); + + SqliteValue::new(raw, columns[i].type_info.clone()) + }); } // decay the array signifier and become just a normal, leaked array @@ -80,10 +91,11 @@ impl SqliteRow { pub(crate) fn inflate_if_needed( statement: &StatementHandle, + columns: &[SqliteColumn], weak_values_ref: Option>>, ) { if let Some(v) = weak_values_ref.and_then(|v| v.upgrade()) { - SqliteRow::inflate(statement, &v); + SqliteRow::inflate(statement, &columns, &v); } } } @@ -91,8 +103,8 @@ impl SqliteRow { impl Row for SqliteRow { type Database = Sqlite; - fn len(&self) -> usize { - self.num_values + fn columns(&self) -> &[SqliteColumn] { + &self.columns } fn try_get_raw(&self, index: I) -> Result, Error> @@ -109,7 +121,11 @@ impl Row for SqliteRow { Ok(SqliteValueRef::value(&values[index])) } else { - Ok(SqliteValueRef::statement(&self.statement, index)) + Ok(SqliteValueRef::statement( + &self.statement, + self.columns[index].type_info.clone(), + index, + )) } } } diff --git a/sqlx-core/src/sqlite/statement/handle.rs b/sqlx-core/src/sqlite/statement/handle.rs index c559f9e2..32f226e8 100644 --- a/sqlx-core/src/sqlite/statement/handle.rs +++ b/sqlx-core/src/sqlite/statement/handle.rs @@ -1,7 +1,9 @@ use std::ffi::c_void; use std::ffi::CStr; use std::os::raw::{c_char, c_int}; +use std::ptr; use std::ptr::NonNull; +use std::slice::from_raw_parts; use std::str::{from_utf8, from_utf8_unchecked}; use libsqlite3_sys::{ @@ -12,14 +14,12 @@ use libsqlite3_sys::{ sqlite3_column_double, sqlite3_column_int, sqlite3_column_int64, sqlite3_column_name, sqlite3_column_origin_name, sqlite3_column_table_name, sqlite3_column_type, sqlite3_column_value, sqlite3_db_handle, sqlite3_sql, sqlite3_stmt, sqlite3_stmt_readonly, - sqlite3_table_column_metadata, SQLITE_OK, SQLITE_TRANSIENT, SQLITE_UTF8, + sqlite3_table_column_metadata, sqlite3_value, SQLITE_OK, SQLITE_TRANSIENT, SQLITE_UTF8, }; use crate::error::{BoxDynError, Error}; use crate::sqlite::type_info::DataType; -use crate::sqlite::{SqliteError, SqliteTypeInfo, SqliteValue}; -use std::ptr; -use std::slice::from_raw_parts; +use crate::sqlite::{SqliteError, SqliteTypeInfo}; #[derive(Debug, Copy, Clone)] pub(crate) struct StatementHandle(pub(super) NonNull); @@ -104,7 +104,7 @@ impl StatementHandle { } } - pub(crate) fn column_not_null(&self, index: usize) -> Result, Error> { + pub(crate) fn column_nullable(&self, index: usize) -> Result, Error> { unsafe { // https://sqlite.org/c3ref/column_database_name.html // @@ -149,7 +149,7 @@ impl StatementHandle { return Err(SqliteError::new(self.db_handle()).into()); } - Ok(Some(not_null != 0)) + Ok(Some(not_null == 0)) } } @@ -249,8 +249,8 @@ impl StatementHandle { } #[inline] - pub(crate) fn column_value(&self, index: usize) -> SqliteValue { - unsafe { SqliteValue::new(sqlite3_column_value(self.0.as_ptr(), index as c_int)) } + pub(crate) fn column_value(&self, index: usize) -> *mut sqlite3_value { + unsafe { sqlite3_column_value(self.0.as_ptr(), index as c_int) } } pub(crate) fn column_blob(&self, index: usize) -> &[u8] { diff --git a/sqlx-core/src/sqlite/statement/mod.rs b/sqlx-core/src/sqlite/statement/mod.rs index 2730dd17..7d3e01c2 100644 --- a/sqlx-core/src/sqlite/statement/mod.rs +++ b/sqlx-core/src/sqlite/statement/mod.rs @@ -1,7 +1,7 @@ use std::i32; use std::os::raw::c_char; use std::ptr::{null, null_mut, NonNull}; -use std::sync::{atomic::AtomicPtr, Weak}; +use std::sync::{atomic::AtomicPtr, Arc, Weak}; use bytes::{Buf, Bytes}; use libsqlite3_sys::{ @@ -12,7 +12,7 @@ use smallvec::SmallVec; use crate::error::Error; use crate::sqlite::connection::ConnectionHandle; -use crate::sqlite::{SqliteError, SqliteRow, SqliteValue}; +use crate::sqlite::{SqliteColumn, SqliteError, SqliteRow, SqliteValue}; mod handle; mod worker; @@ -35,6 +35,9 @@ pub(crate) struct SqliteStatement { // we use a [`SmallVec`] to optimize for the most likely case of a single statement pub(crate) handles: SmallVec<[StatementHandle; 1]>, + // weak references to each set of columns + pub(crate) columns: SmallVec<[Arc>; 1]>, + // weak reference to the previous row from this connection // we use the notice of a successful upgrade of this reference as an indicator that the // row is still around, in which we then inflate the row such that we can let SQLite @@ -122,6 +125,7 @@ impl SqliteStatement { tail: query, handles, index: 0, + columns: SmallVec::from([Default::default(); 1]), last_row_values: SmallVec::from([None; 1]), }) } @@ -133,7 +137,14 @@ impl SqliteStatement { pub(crate) fn execute( &mut self, - ) -> Result>>)>, Error> { + ) -> Result< + Option<( + &StatementHandle, + &mut Arc>, + &mut Option>>, + )>, + Error, + > { while self.handles.len() == self.index { if self.tail.is_empty() { return Ok(None); @@ -143,6 +154,7 @@ impl SqliteStatement { unsafe { prepare(self.connection(), &mut self.tail, self.persistent)? } { self.handles.push(handle); + self.columns.push(Default::default()); self.last_row_values.push(None); } } @@ -152,6 +164,7 @@ impl SqliteStatement { Ok(Some(( &self.handles[index], + &mut self.columns[index], &mut self.last_row_values[index], ))) } @@ -160,7 +173,7 @@ impl SqliteStatement { self.index = 0; for (i, handle) in self.handles.iter().enumerate() { - SqliteRow::inflate_if_needed(&handle, self.last_row_values[i].take()); + SqliteRow::inflate_if_needed(&handle, &self.columns[i], self.last_row_values[i].take()); unsafe { // Reset A Prepared Statement Object @@ -176,7 +189,7 @@ impl SqliteStatement { impl Drop for SqliteStatement { fn drop(&mut self) { for (i, handle) in self.handles.drain(..).enumerate() { - SqliteRow::inflate_if_needed(&handle, self.last_row_values[i].take()); + SqliteRow::inflate_if_needed(&handle, &self.columns[i], self.last_row_values[i].take()); unsafe { // https://sqlite.org/c3ref/finalize.html diff --git a/sqlx-core/src/sqlite/types/bytes.rs b/sqlx-core/src/sqlite/types/bytes.rs index f3c97296..b67d4a1e 100644 --- a/sqlx-core/src/sqlite/types/bytes.rs +++ b/sqlx-core/src/sqlite/types/bytes.rs @@ -11,6 +11,10 @@ impl Type for [u8] { fn type_info() -> SqliteTypeInfo { SqliteTypeInfo(DataType::Blob) } + + fn compatible(ty: &SqliteTypeInfo) -> bool { + matches!(ty.0, DataType::Blob | DataType::Text) + } } impl<'q> Encode<'q, Sqlite> for &'q [u8] { @@ -31,6 +35,10 @@ impl Type for Vec { fn type_info() -> SqliteTypeInfo { <&[u8] as Type>::type_info() } + + fn compatible(ty: &SqliteTypeInfo) -> bool { + <&[u8] as Type>::compatible(ty) + } } impl<'q> Encode<'q, Sqlite> for Vec { diff --git a/sqlx-core/src/sqlite/value.rs b/sqlx-core/src/sqlite/value.rs index b89d05a9..3730a694 100644 --- a/sqlx-core/src/sqlite/value.rs +++ b/sqlx-core/src/sqlite/value.rs @@ -1,4 +1,3 @@ -use std::borrow::Cow; use std::ptr::NonNull; use std::slice::from_raw_parts; use std::str::from_utf8; @@ -6,18 +5,19 @@ use std::sync::Arc; use libsqlite3_sys::{ sqlite3_value, sqlite3_value_blob, sqlite3_value_bytes, sqlite3_value_double, - sqlite3_value_dup, sqlite3_value_int, sqlite3_value_int64, sqlite3_value_type, SQLITE_NULL, + sqlite3_value_dup, sqlite3_value_free, sqlite3_value_int, sqlite3_value_int64, + sqlite3_value_type, SQLITE_NULL, }; use crate::error::BoxDynError; use crate::sqlite::statement::StatementHandle; -use crate::sqlite::type_info::DataType; use crate::sqlite::{Sqlite, SqliteTypeInfo}; use crate::value::{Value, ValueRef}; enum SqliteValueData<'r> { Statement { statement: &'r StatementHandle, + type_info: SqliteTypeInfo, index: usize, }, @@ -31,41 +31,64 @@ impl<'r> SqliteValueRef<'r> { Self(SqliteValueData::Value(value)) } - pub(crate) fn statement(statement: &'r StatementHandle, index: usize) -> Self { - Self(SqliteValueData::Statement { statement, index }) + pub(crate) fn statement( + statement: &'r StatementHandle, + type_info: SqliteTypeInfo, + index: usize, + ) -> Self { + Self(SqliteValueData::Statement { + statement, + type_info, + index, + }) } pub(super) fn int(&self) -> i32 { match self.0 { - SqliteValueData::Statement { statement, index } => statement.column_int(index), + SqliteValueData::Statement { + statement, index, .. + } => statement.column_int(index), + SqliteValueData::Value(v) => v.int(), } } pub(super) fn int64(&self) -> i64 { match self.0 { - SqliteValueData::Statement { statement, index } => statement.column_int64(index), + SqliteValueData::Statement { + statement, index, .. + } => statement.column_int64(index), + SqliteValueData::Value(v) => v.int64(), } } pub(super) fn double(&self) -> f64 { match self.0 { - SqliteValueData::Statement { statement, index } => statement.column_double(index), + SqliteValueData::Statement { + statement, index, .. + } => statement.column_double(index), + SqliteValueData::Value(v) => v.double(), } } pub(super) fn blob(&self) -> &'r [u8] { match self.0 { - SqliteValueData::Statement { statement, index } => statement.column_blob(index), + SqliteValueData::Statement { + statement, index, .. + } => statement.column_blob(index), + SqliteValueData::Value(v) => v.blob(), } } pub(super) fn text(&self) -> Result<&'r str, BoxDynError> { match self.0 { - SqliteValueData::Statement { statement, index } => statement.column_text(index), + SqliteValueData::Statement { + statement, index, .. + } => statement.column_text(index), + SqliteValueData::Value(v) => v.text(), } } @@ -76,32 +99,28 @@ impl<'r> ValueRef<'r> for SqliteValueRef<'r> { fn to_owned(&self) -> SqliteValue { match self.0 { - SqliteValueData::Statement { statement, index } => statement.column_value(index), + SqliteValueData::Statement { + statement, + index, + ref type_info, + } => unsafe { SqliteValue::new(statement.column_value(index), type_info.clone()) }, + SqliteValueData::Value(v) => v.clone(), } } - fn type_info(&self) -> Option> { + fn type_info(&self) -> &SqliteTypeInfo { match self.0 { - SqliteValueData::Statement { statement, index } => statement - .column_decltype(index) - .or_else(|| { - // fall back to the storage class for expressions - Some(SqliteTypeInfo(DataType::from_code( - statement.column_type(index), - ))) - }) - .map(Cow::Owned), - - SqliteValueData::Value(v) => v.type_info(), + SqliteValueData::Statement { ref type_info, .. } => &type_info, + SqliteValueData::Value(v) => &v.type_info, } } fn is_null(&self) -> bool { match self.0 { - SqliteValueData::Statement { statement, index } => { - statement.column_type(index) == SQLITE_NULL - } + SqliteValueData::Statement { + statement, index, .. + } => statement.column_type(index) == SQLITE_NULL, SqliteValueData::Value(v) => v.is_null(), } @@ -109,43 +128,50 @@ impl<'r> ValueRef<'r> for SqliteValueRef<'r> { } #[derive(Clone)] -pub struct SqliteValue(pub(crate) Arc>); +pub struct SqliteValue { + pub(crate) handle: Arc, + pub(crate) type_info: SqliteTypeInfo, +} + +pub(crate) struct ValueHandle(NonNull); // SAFE: only protected value objects are stored in SqliteValue -unsafe impl Send for SqliteValue {} -unsafe impl Sync for SqliteValue {} +unsafe impl Send for ValueHandle {} +unsafe impl Sync for ValueHandle {} impl SqliteValue { - pub(crate) unsafe fn new(value: *mut sqlite3_value) -> Self { + pub(crate) unsafe fn new(value: *mut sqlite3_value, type_info: SqliteTypeInfo) -> Self { debug_assert!(!value.is_null()); - Self(Arc::new(NonNull::new_unchecked(sqlite3_value_dup(value)))) - } - fn r#type(&self) -> DataType { - DataType::from_code(unsafe { sqlite3_value_type(self.0.as_ptr()) }) + Self { + type_info, + handle: Arc::new(ValueHandle(NonNull::new_unchecked(sqlite3_value_dup( + value, + )))), + } } fn int(&self) -> i32 { - unsafe { sqlite3_value_int(self.0.as_ptr()) } + unsafe { sqlite3_value_int(self.handle.0.as_ptr()) } } fn int64(&self) -> i64 { - unsafe { sqlite3_value_int64(self.0.as_ptr()) } + unsafe { sqlite3_value_int64(self.handle.0.as_ptr()) } } fn double(&self) -> f64 { - unsafe { sqlite3_value_double(self.0.as_ptr()) } + unsafe { sqlite3_value_double(self.handle.0.as_ptr()) } } fn blob(&self) -> &[u8] { - let len = unsafe { sqlite3_value_bytes(self.0.as_ptr()) } as usize; + let len = unsafe { sqlite3_value_bytes(self.handle.0.as_ptr()) } as usize; if len == 0 { // empty blobs are NULL so just return an empty slice return &[]; } - let ptr = unsafe { sqlite3_value_blob(self.0.as_ptr()) } as *const u8; + let ptr = unsafe { sqlite3_value_blob(self.handle.0.as_ptr()) } as *const u8; debug_assert!(!ptr.is_null()); unsafe { from_raw_parts(ptr, len) } @@ -163,11 +189,19 @@ impl Value for SqliteValue { SqliteValueRef::value(self) } - fn type_info(&self) -> Option> { - Some(Cow::Owned(SqliteTypeInfo(self.r#type()))) + fn type_info(&self) -> &SqliteTypeInfo { + &self.type_info } fn is_null(&self) -> bool { - unsafe { sqlite3_value_type(self.0.as_ptr()) == SQLITE_NULL } + unsafe { sqlite3_value_type(self.handle.0.as_ptr()) == SQLITE_NULL } + } +} + +impl Drop for ValueHandle { + fn drop(&mut self) { + unsafe { + sqlite3_value_free(self.0.as_ptr()); + } } } diff --git a/sqlx-core/src/statement.rs b/sqlx-core/src/statement.rs new file mode 100644 index 00000000..bb6681a0 --- /dev/null +++ b/sqlx-core/src/statement.rs @@ -0,0 +1,61 @@ +use crate::database::Database; +use either::Either; +use std::convert::identity; + +/// Provides information on a prepared statement. +/// +/// Returned from [`Executor::describe`](trait.Executor.html#method.describe). +/// +/// The query macros (e.g., `query!`, `query_as!`, etc.) use the information here to validate +/// output and parameter types; and, generate an anonymous record. +#[derive(Debug)] +#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr( + feature = "offline", + serde(bound( + serialize = "DB::TypeInfo: serde::Serialize, DB::Column: serde::Serialize", + deserialize = "DB::TypeInfo: serde::de::DeserializeOwned, DB::Column: serde::de::DeserializeOwned", + )) +)] +#[doc(hidden)] +pub struct StatementInfo { + pub(crate) columns: Vec, + pub(crate) parameters: Option, usize>>, + pub(crate) nullable: Vec>, +} + +impl StatementInfo { + /// Gets the column information at `index`. + /// + /// Panics if `index` is out of bounds. + pub fn column(&self, index: usize) -> &DB::Column { + &self.columns[index] + } + + /// Gets the column information at `index` or `None` if out of bounds. + pub fn try_column(&self, index: usize) -> Option<&DB::Column> { + self.columns.get(index) + } + + /// Gets all columns in this statement. + pub fn columns(&self) -> &[DB::Column] { + &self.columns + } + + /// Gets the available information for parameters in this statement. + /// + /// Some drivers may return more or less than others. As an example, **PostgreSQL** will + /// return `Some(Either::Left(_))` with a full list of type information for each parameter. + /// However, **MSSQL** will return `None` as there is no information available. + pub fn parameters(&self) -> Option> { + self.parameters.as_ref().map(|p| match p { + Either::Left(params) => Either::Left(&**params), + Either::Right(count) => Either::Right(*count), + }) + } + + /// Gets whether a column may be `NULL`, if this information is available. + pub fn nullable(&self, column: usize) -> Option { + self.nullable.get(column).copied().and_then(identity) + } +} diff --git a/sqlx-core/src/transaction.rs b/sqlx-core/src/transaction.rs index 3c90dc1f..2c67a3d2 100644 --- a/sqlx-core/src/transaction.rs +++ b/sqlx-core/src/transaction.rs @@ -190,7 +190,7 @@ macro_rules! impl_executor_for_transaction { query: E, ) -> futures_core::future::BoxFuture< 'e, - Result, crate::error::Error>, + Result, crate::error::Error>, > where 't: 'e, diff --git a/sqlx-core/src/value.rs b/sqlx-core/src/value.rs index 1ef4c3b7..f800814d 100644 --- a/sqlx-core/src/value.rs +++ b/sqlx-core/src/value.rs @@ -1,5 +1,3 @@ -use std::borrow::Cow; - use crate::database::{Database, HasValueRef}; use crate::decode::Decode; use crate::error::{mismatched_types, Error}; @@ -12,11 +10,8 @@ pub trait Value { /// Get this value as a reference. fn as_ref(&self) -> >::ValueRef; - /// Get the type information, if available, for this value. - /// - /// Some database implementations do not implement type deduction for - /// expressions (`SELECT 2 + 5`); and, this will return `None` in those cases. - fn type_info(&self) -> Option::TypeInfo>>; + /// Get the type information for this value. + fn type_info(&self) -> &::TypeInfo; /// Returns `true` if the SQL value is `NULL`. fn is_null(&self) -> bool; @@ -68,10 +63,10 @@ pub trait Value { T: Decode<'r, Self::Database> + Type, { if !self.is_null() { - if let Some(ty) = self.type_info() { - if !T::compatible(&ty) { - return Err(Error::Decode(mismatched_types::(&ty))); - } + let ty = self.type_info(); + + if !T::compatible(&ty) { + return Err(Error::Decode(mismatched_types::(&ty))); } } @@ -108,11 +103,8 @@ pub trait ValueRef<'r>: Sized { /// this is a copy. fn to_owned(&self) -> ::Value; - /// Get the type information, if available, for this value. - /// - /// Some database implementations do not implement type deduction for - /// expressions (`SELECT 2 + 5`); and, this will return `None` in those cases. - fn type_info(&self) -> Option::TypeInfo>>; + /// Get the type information for this value. + fn type_info(&self) -> &::TypeInfo; /// Returns `true` if the SQL value is `NULL`. fn is_null(&self) -> bool; diff --git a/sqlx-macros/Cargo.toml b/sqlx-macros/Cargo.toml index 9bab0661..966b67af 100644 --- a/sqlx-macros/Cargo.toml +++ b/sqlx-macros/Cargo.toml @@ -45,12 +45,13 @@ dotenv = { version = "0.15.0", default-features = false } futures = { version = "0.3.4", default-features = false, features = [ "executor" ] } hex = { version = "0.4.2", optional = true } heck = "0.3.1" +either = "1.5.3" proc-macro2 = { version = "1.0.9", default-features = false } sqlx-core = { version = "0.4.0-pre", default-features = false, path = "../sqlx-core" } sqlx-rt = { version = "0.1.0-pre", default-features = false, path = "../sqlx-rt" } serde = { version = "1.0.111", optional = true } serde_json = { version = "1.0.30", features = [ "preserve_order" ], optional = true } -sha2 = { version = "0.8.2", optional = true } +sha2 = { version = "0.9.1", optional = true } syn = { version = "1.0.30", default-features = false, features = [ "full" ] } quote = { version = "1.0.6", default-features = false } url = { version = "2.1.1", default-features = false } diff --git a/sqlx-macros/src/query/args.rs b/sqlx-macros/src/query/args.rs index 756e3984..4b630a99 100644 --- a/sqlx-macros/src/query/args.rs +++ b/sqlx-macros/src/query/args.rs @@ -1,18 +1,17 @@ +use crate::database::DatabaseExt; +use crate::query::QueryMacroInput; +use either::Either; use proc_macro2::TokenStream; +use quote::{quote, quote_spanned}; +use sqlx_core::statement::StatementInfo; use syn::spanned::Spanned; use syn::{Expr, Type}; -use quote::{quote, quote_spanned}; -use sqlx_core::describe::Describe; - -use crate::database::{DatabaseExt, ParamChecking}; -use crate::query::QueryMacroInput; - /// Returns a tokenstream which typechecks the arguments passed to the macro /// and binds them to `DB::Arguments` with the ident `query_args`. pub fn quote_args( input: &QueryMacroInput, - describe: &Describe, + info: &StatementInfo, ) -> crate::Result { let db_path = DB::db_path(); @@ -24,71 +23,71 @@ pub fn quote_args( let arg_name = &input.arg_names; - let args_check = if input.checked && DB::PARAM_CHECKING == ParamChecking::Strong { - describe - .params - .iter() - .zip(input.arg_names.iter().zip(&input.arg_exprs)) - .enumerate() - .map(|(i, (param_ty, (name, expr)))| -> crate::Result<_> { - // TODO: We could remove the ParamChecking flag and just filter to only test params that are non-null - let param_ty = param_ty.as_ref().unwrap(); + let args_check = match info.parameters() { + None | Some(Either::Right(_)) => { + // all we can do is check arity which we did + TokenStream::new() + } - let param_ty = match get_type_override(expr) { - // TODO: enable this in 1.45 when we can strip `as _` - // without stripping these we get some pretty nasty type errors - Some(Type::Infer(_)) => return Err( - syn::Error::new_spanned( - expr, - "casts to `_` are not allowed in bind parameters yet" - ).into() - ), - // cast or type ascription will fail to compile if the type does not match - Some(_) => return Ok(quote!()), - None => { - DB::param_type_for_id(¶m_ty) - .ok_or_else(|| { - if let Some(feature_gate) = ::get_feature_gate(¶m_ty) { - format!( - "optional feature `{}` required for type {} of param #{}", - feature_gate, - param_ty, - i + 1, - ) - } else { - format!("unsupported type {} for param #{}", param_ty, i + 1) - } - })? - .parse::() - .map_err(|_| format!("Rust type mapping for {} not parsable", param_ty))? + Some(Either::Left(params)) => { + params + .iter() + .zip(input.arg_names.iter().zip(&input.arg_exprs)) + .enumerate() + .map(|(i, (param_ty, (name, expr)))| -> crate::Result<_> { + let param_ty = match get_type_override(expr) { + // TODO: enable this in 1.45 when we can strip `as _` + // without stripping these we get some pretty nasty type errors + Some(Type::Infer(_)) => return Err( + syn::Error::new_spanned( + expr, + "casts to `_` are not allowed in bind parameters yet" + ).into() + ), + // cast or type ascription will fail to compile if the type does not match + Some(_) => return Ok(quote!()), + None => { + DB::param_type_for_id(¶m_ty) + .ok_or_else(|| { + if let Some(feature_gate) = ::get_feature_gate(¶m_ty) { + format!( + "optional feature `{}` required for type {} of param #{}", + feature_gate, + param_ty, + i + 1, + ) + } else { + format!("unsupported type {} for param #{}", param_ty, i + 1) + } + })? + .parse::() + .map_err(|_| format!("Rust type mapping for {} not parsable", param_ty))? - } - }; + } + }; - Ok(quote_spanned!(expr.span() => - // this shouldn't actually run - if false { - use sqlx::ty_match::{WrapSameExt as _, MatchBorrowExt as _}; + Ok(quote_spanned!(expr.span() => + // this shouldn't actually run + if false { + use sqlx::ty_match::{WrapSameExt as _, MatchBorrowExt as _}; - // evaluate the expression only once in case it contains moves - let _expr = sqlx::ty_match::dupe_value(&$#name); + // evaluate the expression only once in case it contains moves + let _expr = sqlx::ty_match::dupe_value(&$#name); - // if `_expr` is `Option`, get `Option<$ty>`, otherwise `$ty` - let ty_check = sqlx::ty_match::WrapSame::<#param_ty, _>::new(&_expr).wrap_same(); - // if `_expr` is `&str`, convert `String` to `&str` - let (mut ty_check, match_borrow) = sqlx::ty_match::MatchBorrow::new(ty_check, &_expr); + // if `_expr` is `Option`, get `Option<$ty>`, otherwise `$ty` + let ty_check = sqlx::ty_match::WrapSame::<#param_ty, _>::new(&_expr).wrap_same(); + // if `_expr` is `&str`, convert `String` to `&str` + let (mut ty_check, match_borrow) = sqlx::ty_match::MatchBorrow::new(ty_check, &_expr); - ty_check = match_borrow.match_borrow(); + ty_check = match_borrow.match_borrow(); - // this causes move-analysis to effectively ignore this block - panic!(); - } - )) - }) - .collect::>()? - } else { - // all we can do is check arity which we did in `QueryMacroInput::describe_validate()` - TokenStream::new() + // this causes move-analysis to effectively ignore this block + panic!(); + } + )) + }) + .collect::>()? + } }; let args_count = input.arg_names.len(); diff --git a/sqlx-macros/src/query/data.rs b/sqlx-macros/src/query/data.rs index 9af35a48..97c4ea6f 100644 --- a/sqlx-macros/src/query/data.rs +++ b/sqlx-macros/src/query/data.rs @@ -1,20 +1,20 @@ use sqlx_core::database::Database; -use sqlx_core::describe::Describe; use sqlx_core::executor::Executor; +use sqlx_core::statement::StatementInfo; #[cfg_attr(feature = "offline", derive(serde::Deserialize, serde::Serialize))] #[cfg_attr( feature = "offline", serde(bound( - serialize = "Describe: serde::Serialize", - deserialize = "Describe: serde::de::DeserializeOwned" + serialize = "StatementInfo: serde::Serialize", + deserialize = "StatementInfo: serde::de::DeserializeOwned" )) )] #[derive(Debug)] pub struct QueryData { #[allow(dead_code)] pub(super) query: String, - pub(super) describe: Describe, + pub(super) describe: StatementInfo, #[cfg(feature = "offline")] pub(super) hash: String, } @@ -43,7 +43,7 @@ pub mod offline { use crate::database::DatabaseExt; use proc_macro2::Span; use serde::de::{Deserializer, IgnoredAny, MapAccess, Visitor}; - use sqlx_core::describe::Describe; + use sqlx_core::statement::StatementInfo; use std::path::Path; #[derive(serde::Deserialize)] @@ -75,14 +75,14 @@ pub mod offline { impl QueryData where - Describe: serde::Serialize + serde::de::DeserializeOwned, + StatementInfo: serde::Serialize + serde::de::DeserializeOwned, { pub fn from_dyn_data(dyn_data: DynQueryData) -> crate::Result { assert!(!dyn_data.db_name.is_empty()); assert!(!dyn_data.hash.is_empty()); if DB::NAME == dyn_data.db_name { - let describe: Describe = serde_json::from_value(dyn_data.describe)?; + let describe: StatementInfo = serde_json::from_value(dyn_data.describe)?; Ok(QueryData { query: dyn_data.query, describe, diff --git a/sqlx-macros/src/query/mod.rs b/sqlx-macros/src/query/mod.rs index 196d29b2..b6d4ba86 100644 --- a/sqlx-macros/src/query/mod.rs +++ b/sqlx-macros/src/query/mod.rs @@ -9,12 +9,13 @@ pub use input::QueryMacroInput; use quote::{format_ident, quote}; use sqlx_core::connection::Connect; use sqlx_core::database::Database; -use sqlx_core::describe::Describe; +use sqlx_core::statement::StatementInfo; use sqlx_rt::block_on; use crate::database::DatabaseExt; use crate::query::data::QueryData; use crate::query::input::RecordType; +use either::Either; mod args; mod data; @@ -154,47 +155,52 @@ pub fn expand_from_file( } } -// marker trait for `Describe` that lets us conditionally require it to be `Serialize + Deserialize` +// marker trait for `StatementInfo` that lets us conditionally require it to be `Serialize + Deserialize` #[cfg(feature = "offline")] -trait DescribeExt: serde::Serialize + serde::de::DeserializeOwned {} +trait StatementInfoExt: serde::Serialize + serde::de::DeserializeOwned {} #[cfg(feature = "offline")] -impl DescribeExt for Describe where +impl StatementInfoExt for StatementInfo where Describe: serde::Serialize + serde::de::DeserializeOwned { } #[cfg(not(feature = "offline"))] -trait DescribeExt {} +trait StatementInfoExt {} #[cfg(not(feature = "offline"))] -impl DescribeExt for Describe {} +impl StatementInfoExt for StatementInfo {} fn expand_with_data( input: QueryMacroInput, data: QueryData, ) -> crate::Result where - Describe: DescribeExt, + StatementInfo: StatementInfoExt, { // validate at the minimum that our args match the query's input parameters - if input.arg_names.len() != data.describe.params.len() { - return Err(syn::Error::new( - Span::call_site(), - format!( - "expected {} parameters, got {}", - data.describe.params.len(), - input.arg_names.len() - ), - ) - .into()); + let num_parameters = match data.describe.parameters() { + Some(Either::Left(params)) => Some(params.len()), + Some(Either::Right(num)) => Some(num), + + None => None, + }; + + if let Some(num) = num_parameters { + if num != input.arg_names.len() { + return Err(syn::Error::new( + Span::call_site(), + format!("expected {} parameters, got {}", num, input.arg_names.len()), + ) + .into()); + } } let args_tokens = args::quote_args(&input, &data.describe)?; let query_args = format_ident!("query_args"); - let output = if data.describe.columns.is_empty() { + let output = if data.describe.columns().is_empty() { let db_path = DB::db_path(); let sql = &input.src; diff --git a/sqlx-macros/src/query/output.rs b/sqlx-macros/src/query/output.rs index 659e747b..249ec9d9 100644 --- a/sqlx-macros/src/query/output.rs +++ b/sqlx-macros/src/query/output.rs @@ -2,7 +2,8 @@ use proc_macro2::{Ident, Span, TokenStream}; use quote::{quote, ToTokens}; use syn::Type; -use sqlx_core::describe::{Column, Describe}; +use sqlx_core::column::Column; +use sqlx_core::statement::StatementInfo; use crate::database::DatabaseExt; @@ -41,29 +42,31 @@ impl Display for DisplayColumn<'_> { } } -pub fn columns_to_rust(describe: &Describe) -> crate::Result> { +pub fn columns_to_rust( + describe: &StatementInfo, +) -> crate::Result> { describe - .columns + .columns() .iter() .enumerate() .map(|(i, column)| -> crate::Result<_> { // add raw prefix to all identifiers - let decl = ColumnDecl::parse(&column.name) - .map_err(|e| format!("column name {:?} is invalid: {}", column.name, e))?; + let decl = ColumnDecl::parse(&column.name()) + .map_err(|e| format!("column name {:?} is invalid: {}", column.name(), e))?; let type_ = match decl.r#override { Some(ColumnOverride::Exact(ty)) => Some(ty.to_token_stream()), Some(ColumnOverride::Wildcard) => None, // these three could be combined but I prefer the clarity here - Some(ColumnOverride::NonNull) => Some(get_column_type(i, column)), + Some(ColumnOverride::NonNull) => Some(get_column_type::(i, column)), Some(ColumnOverride::Nullable) => { - let type_ = get_column_type(i, column); + let type_ = get_column_type::(i, column); Some(quote! { Option<#type_> }) } None => { - let type_ = get_column_type(i, column); + let type_ = get_column_type::(i, column); - if column.not_null.unwrap_or(false) { + if !describe.nullable(i).unwrap_or(true) { Some(type_) } else { Some(quote! { Option<#type_> }) @@ -121,49 +124,36 @@ pub fn quote_query_as( } } -fn get_column_type(i: usize, column: &Column) -> TokenStream { - if let Some(type_info) = &column.type_info { - ::return_type_for_id(&type_info).map_or_else( - || { - let message = - if let Some(feature_gate) = ::get_feature_gate(&type_info) { - format!( - "optional feature `{feat}` required for type {ty} of {col}", - ty = &type_info, - feat = feature_gate, - col = DisplayColumn { - idx: i, - name: &*column.name - } - ) - } else { - format!( - "unsupported type {ty} of {col}", - ty = type_info, - col = DisplayColumn { - idx: i, - name: &*column.name - } - ) - }; - syn::Error::new(Span::call_site(), message).to_compile_error() - }, - |t| t.parse().unwrap(), - ) - } else { - syn::Error::new( - Span::call_site(), - format!( - "database couldn't tell us the type of {col}; \ - this can happen for columns that are the result of an expression", - col = DisplayColumn { - idx: i, - name: &*column.name - } - ), - ) - .to_compile_error() - } +fn get_column_type(i: usize, column: &DB::Column) -> TokenStream { + let type_info = &*column.type_info(); + + ::return_type_for_id(&type_info).map_or_else( + || { + let message = + if let Some(feature_gate) = ::get_feature_gate(&type_info) { + format!( + "optional feature `{feat}` required for type {ty} of {col}", + ty = &type_info, + feat = feature_gate, + col = DisplayColumn { + idx: i, + name: &*column.name() + } + ) + } else { + format!( + "unsupported type {ty} of {col}", + ty = type_info, + col = DisplayColumn { + idx: i, + name: &*column.name() + } + ) + }; + syn::Error::new(Span::call_site(), message).to_compile_error() + }, + |t| t.parse().unwrap(), + ) } impl ColumnDecl { diff --git a/src/lib.rs b/src/lib.rs index 5c9803be..cab4562f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,7 @@ #![cfg_attr(docsrs, feature(doc_cfg))] pub use sqlx_core::arguments::{Arguments, IntoArguments}; +pub use sqlx_core::column::Column; pub use sqlx_core::connection::{Connect, Connection}; pub use sqlx_core::database::{self, Database}; pub use sqlx_core::executor::{Execute, Executor}; @@ -10,14 +11,12 @@ pub use sqlx_core::query::{query, query_with}; pub use sqlx_core::query_as::{query_as, query_as_with}; pub use sqlx_core::query_scalar::{query_scalar, query_scalar_with}; pub use sqlx_core::row::{ColumnIndex, Row}; +pub use sqlx_core::statement::StatementInfo; pub use sqlx_core::transaction::{Transaction, TransactionManager}; pub use sqlx_core::type_info::TypeInfo; pub use sqlx_core::types::Type; pub use sqlx_core::value::{Value, ValueRef}; -#[doc(hidden)] -pub use sqlx_core::describe; - #[doc(inline)] pub use sqlx_core::error::{self, Error, Result}; diff --git a/tests/mssql/describe.rs b/tests/mssql/describe.rs index ae53c8d9..109e08ee 100644 --- a/tests/mssql/describe.rs +++ b/tests/mssql/describe.rs @@ -1,37 +1,41 @@ use sqlx::mssql::Mssql; -use sqlx::{describe::Column, Executor}; +use sqlx::{Column, Executor, TypeInfo}; use sqlx_test::new; -fn type_names(columns: &[Column]) -> Vec { - columns - .iter() - .filter_map(|col| Some(col.type_info.as_ref()?.to_string())) - .collect() -} - #[sqlx_macros::test] async fn it_describes_simple() -> anyhow::Result<()> { let mut conn = new::().await?; let d = conn.describe("SELECT * FROM tweet").await?; - let columns = d.columns; - assert_eq!(columns[0].name, "id"); - assert_eq!(columns[1].name, "text"); - assert_eq!(columns[2].name, "is_sent"); - assert_eq!(columns[3].name, "owner_id"); + assert_eq!(d.column(0).name(), "id"); + assert_eq!(d.column(1).name(), "text"); + assert_eq!(d.column(2).name(), "is_sent"); + assert_eq!(d.column(3).name(), "owner_id"); - assert_eq!(columns[0].not_null, Some(true)); - assert_eq!(columns[1].not_null, Some(true)); - assert_eq!(columns[2].not_null, Some(true)); - assert_eq!(columns[3].not_null, Some(false)); + assert_eq!(d.nullable(0), Some(false)); + assert_eq!(d.nullable(1), Some(false)); + assert_eq!(d.nullable(2), Some(false)); + assert_eq!(d.nullable(3), Some(true)); - let column_type_names = type_names(&columns); - - assert_eq!(column_type_names[0], "BIGINT"); - assert_eq!(column_type_names[1], "NVARCHAR"); - assert_eq!(column_type_names[2], "TINYINT"); - assert_eq!(column_type_names[3], "BIGINT"); + assert_eq!(d.column(0).type_info().name(), "BIGINT"); + assert_eq!(d.column(1).type_info().name(), "NVARCHAR"); + assert_eq!(d.column(2).type_info().name(), "TINYINT"); + assert_eq!(d.column(3).type_info().name(), "BIGINT"); + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_describes_with_params() -> anyhow::Result<()> { + let mut conn = new::().await?; + + let d = conn + .describe("SELECT text FROM tweet WHERE id = @p1") + .await?; + + assert_eq!(d.column(0).name(), "text"); + assert_eq!(d.nullable(0), Some(false)); Ok(()) } diff --git a/tests/mssql/mssql.rs b/tests/mssql/mssql.rs index 14b1fb25..0aeb4dfb 100644 --- a/tests/mssql/mssql.rs +++ b/tests/mssql/mssql.rs @@ -27,6 +27,18 @@ async fn it_can_select_expression() -> anyhow::Result<()> { Ok(()) } +#[sqlx_macros::test] +async fn it_can_select_expression_by_name() -> anyhow::Result<()> { + let mut conn = new::().await?; + + let row: MssqlRow = conn.fetch_one("SELECT 4 as _3").await?; + let v: i32 = row.try_get("_3")?; + + assert_eq!(v, 4); + + Ok(()) +} + #[sqlx_macros::test] async fn it_can_fail_to_connect() -> anyhow::Result<()> { let res = MssqlConnection::connect("mssql://sa@localhost").await; diff --git a/tests/mysql/describe.rs b/tests/mysql/describe.rs index 50f52207..5bf18267 100644 --- a/tests/mysql/describe.rs +++ b/tests/mysql/describe.rs @@ -1,38 +1,27 @@ use sqlx::mysql::MySql; -use sqlx::Executor; -use sqlx_core::describe::Column; +use sqlx::{Column, Executor, TypeInfo}; use sqlx_test::new; -fn type_names(columns: &[Column]) -> Vec { - columns - .iter() - .filter_map(|col| Some(col.type_info.as_ref()?.to_string())) - .collect() -} - #[sqlx_macros::test] async fn it_describes_simple() -> anyhow::Result<()> { let mut conn = new::().await?; let d = conn.describe("SELECT * FROM tweet").await?; - let columns = d.columns; - assert_eq!(columns[0].name, "id"); - assert_eq!(columns[1].name, "created_at"); - assert_eq!(columns[2].name, "text"); - assert_eq!(columns[3].name, "owner_id"); + assert_eq!(d.column(0).name(), "id"); + assert_eq!(d.column(1).name(), "created_at"); + assert_eq!(d.column(2).name(), "text"); + assert_eq!(d.column(3).name(), "owner_id"); - assert_eq!(columns[0].not_null, Some(true)); - assert_eq!(columns[1].not_null, Some(true)); - assert_eq!(columns[2].not_null, Some(true)); - assert_eq!(columns[3].not_null, Some(false)); + assert_eq!(d.nullable(0), Some(false)); + assert_eq!(d.nullable(1), Some(false)); + assert_eq!(d.nullable(2), Some(false)); + assert_eq!(d.nullable(3), Some(true)); - let column_type_names = type_names(&columns); - - assert_eq!(column_type_names[0], "BIGINT"); - assert_eq!(column_type_names[1], "TIMESTAMP"); - assert_eq!(column_type_names[2], "TEXT"); - assert_eq!(column_type_names[3], "BIGINT"); + assert_eq!(d.column(0).type_info().name(), "BIGINT"); + assert_eq!(d.column(1).type_info().name(), "TIMESTAMP"); + assert_eq!(d.column(2).type_info().name(), "TEXT"); + assert_eq!(d.column(3).type_info().name(), "BIGINT"); Ok(()) } @@ -45,9 +34,7 @@ async fn uses_alias_name() -> anyhow::Result<()> { .describe("SELECT text AS tweet_text FROM tweet") .await?; - let columns = d.columns; - - assert_eq!(columns[0].name, "tweet_text"); + assert_eq!(d.column(0).name(), "tweet_text"); Ok(()) } diff --git a/tests/postgres/describe.rs b/tests/postgres/describe.rs index d628a8dc..81ccbf41 100644 --- a/tests/postgres/describe.rs +++ b/tests/postgres/describe.rs @@ -1,37 +1,26 @@ -use sqlx::{postgres::Postgres, Executor}; -use sqlx_core::describe::Column; +use sqlx::{postgres::Postgres, Column, Executor, TypeInfo}; use sqlx_test::new; -fn type_names(columns: &[Column]) -> Vec { - columns - .iter() - .filter_map(|col| Some(col.type_info.as_ref()?.to_string())) - .collect() -} - #[sqlx_macros::test] async fn it_describes_simple() -> anyhow::Result<()> { let mut conn = new::().await?; let d = conn.describe("SELECT * FROM tweet").await?; - let columns = d.columns; - assert_eq!(columns[0].name, "id"); - assert_eq!(columns[1].name, "created_at"); - assert_eq!(columns[2].name, "text"); - assert_eq!(columns[3].name, "owner_id"); + assert_eq!(d.column(0).name(), "id"); + assert_eq!(d.column(1).name(), "created_at"); + assert_eq!(d.column(2).name(), "text"); + assert_eq!(d.column(3).name(), "owner_id"); - assert_eq!(columns[0].not_null, Some(true)); - assert_eq!(columns[1].not_null, Some(true)); - assert_eq!(columns[2].not_null, Some(true)); - assert_eq!(columns[3].not_null, Some(false)); + assert_eq!(d.nullable(0), Some(false)); + assert_eq!(d.nullable(1), Some(false)); + assert_eq!(d.nullable(2), Some(false)); + assert_eq!(d.nullable(3), Some(true)); - let column_type_names = type_names(&columns); - - assert_eq!(column_type_names[0], "INT8"); - assert_eq!(column_type_names[1], "TIMESTAMPTZ"); - assert_eq!(column_type_names[2], "TEXT"); - assert_eq!(column_type_names[3], "INT8"); + assert_eq!(d.column(0).type_info().name(), "INT8"); + assert_eq!(d.column(1).type_info().name(), "TIMESTAMPTZ"); + assert_eq!(d.column(2).type_info().name(), "TEXT"); + assert_eq!(d.column(3).type_info().name(), "INT8"); Ok(()) } @@ -41,18 +30,14 @@ async fn it_describes_expression() -> anyhow::Result<()> { let mut conn = new::().await?; let d = conn.describe("SELECT 1::int8 + 10").await?; - let columns = d.columns; // ?column? will cause the macro to emit an error ad ask the user to explicitly name the type - assert_eq!(columns[0].name, "?column?"); + assert_eq!(d.column(0).name(), "?column?"); // postgres cannot infer nullability from an expression // this will cause the macro to emit `Option<_>` - assert_eq!(columns[0].not_null, None); - - let column_type_names = type_names(&columns); - - assert_eq!(column_type_names[0], "INT8"); + assert_eq!(d.nullable(0), None); + assert_eq!(d.column(0).type_info().name(), "INT8"); Ok(()) } @@ -62,14 +47,13 @@ async fn it_describes_enum() -> anyhow::Result<()> { let mut conn = new::().await?; let d = conn.describe("SELECT 'open'::status as _1").await?; - let columns = d.columns; - assert_eq!(columns[0].name, "_1"); - assert_eq!(columns[0].not_null, None); + assert_eq!(d.column(0).name(), "_1"); - let ty = columns[0].type_info.as_ref().unwrap(); + let ty = d.column(0).type_info(); + + assert_eq!(ty.name(), "status"); - assert_eq!(ty.to_string(), "status"); assert_eq!( format!("{:?}", ty.__kind()), r#"Enum(["new", "open", "closed"])"# @@ -83,10 +67,9 @@ async fn it_describes_record() -> anyhow::Result<()> { let mut conn = new::().await?; let d = conn.describe("SELECT (true, 10::int2)").await?; - let columns = d.columns; - let ty = columns[0].type_info.as_ref().unwrap(); - assert_eq!(ty.to_string(), "RECORD"); + let ty = d.column(0).type_info(); + assert_eq!(ty.name(), "RECORD"); Ok(()) } @@ -99,11 +82,9 @@ async fn it_describes_composite() -> anyhow::Result<()> { .describe("SELECT ROW('name',10,500)::inventory_item") .await?; - let columns = d.columns; + let ty = d.column(0).type_info(); - let ty = columns[0].type_info.as_ref().unwrap(); - - assert_eq!(ty.to_string(), "inventory_item"); + assert_eq!(ty.name(), "inventory_item"); assert_eq!( format!("{:?}", ty.__kind()), diff --git a/tests/postgres/macros.rs b/tests/postgres/macros.rs index ec263f49..ee699823 100644 --- a/tests/postgres/macros.rs +++ b/tests/postgres/macros.rs @@ -14,7 +14,30 @@ async fn test_query() -> anyhow::Result<()> { .fetch_one(&mut conn) .await?; - println!("account ID: {:?}", account.id); + assert_eq!(account.id, Some(1)); + assert_eq!(account.name.as_deref(), Some("Herp Derpinson")); + + Ok(()) +} + +#[sqlx_macros::test] +async fn test_non_null() -> anyhow::Result<()> { + let mut conn = new::().await?; + let mut tx = conn.begin().await?; + + let _ = sqlx::query!("INSERT INTO tweet (text) VALUES ('Hello')") + .execute(&mut tx) + .await?; + + let row = sqlx::query!("SELECT id, text, owner_id FROM tweet LIMIT 1") + .fetch_one(&mut tx) + .await?; + + assert_eq!(row.id, 1); + assert_eq!(row.text, "Hello"); + assert_eq!(row.owner_id, None); + + // let the transaction rollback so we don't actually insert the tweet Ok(()) } diff --git a/tests/sqlite/describe.rs b/tests/sqlite/describe.rs index a890bd8d..4f420e6d 100644 --- a/tests/sqlite/describe.rs +++ b/tests/sqlite/describe.rs @@ -1,42 +1,55 @@ -use sqlx::describe::Column; use sqlx::error::DatabaseError; use sqlx::sqlite::{SqliteConnectOptions, SqliteError}; -use sqlx::{sqlite::Sqlite, Executor}; +use sqlx::{sqlite::Sqlite, Column, Executor}; use sqlx::{Connect, SqliteConnection, TypeInfo}; use sqlx_test::new; use std::env; -fn type_names(columns: &[Column]) -> Vec { - columns - .iter() - .filter_map(|col| Some(col.type_info.as_ref()?.to_string())) - .collect() -} - #[sqlx_macros::test] async fn it_describes_simple() -> anyhow::Result<()> { let mut conn = new::().await?; - let d = conn.describe("SELECT * FROM tweet").await?; + let info = conn.describe("SELECT * FROM tweet").await?; + let columns = info.columns(); - let columns = d.columns; + assert_eq!(columns[0].name(), "id"); + assert_eq!(columns[1].name(), "text"); + assert_eq!(columns[2].name(), "is_sent"); + assert_eq!(columns[3].name(), "owner_id"); - assert_eq!(columns[0].name, "id"); - assert_eq!(columns[1].name, "text"); - assert_eq!(columns[2].name, "is_sent"); - assert_eq!(columns[3].name, "owner_id"); + assert_eq!(columns[0].ordinal(), 0); + assert_eq!(columns[1].ordinal(), 1); + assert_eq!(columns[2].ordinal(), 2); + assert_eq!(columns[3].ordinal(), 3); - assert_eq!(columns[0].not_null, Some(true)); - assert_eq!(columns[1].not_null, Some(true)); - assert_eq!(columns[2].not_null, Some(true)); - assert_eq!(columns[3].not_null, Some(false)); // owner_id + assert_eq!(info.nullable(0), Some(false)); + assert_eq!(info.nullable(1), Some(false)); + assert_eq!(info.nullable(2), Some(false)); + assert_eq!(info.nullable(3), Some(true)); // owner_id - let column_type_names = type_names(&columns); + assert_eq!(columns[0].type_info().name(), "INTEGER"); + assert_eq!(columns[1].type_info().name(), "TEXT"); + assert_eq!(columns[2].type_info().name(), "BOOLEAN"); + assert_eq!(columns[3].type_info().name(), "INTEGER"); - assert_eq!(column_type_names[0], "INTEGER"); - assert_eq!(column_type_names[1], "TEXT"); - assert_eq!(column_type_names[2], "BOOLEAN"); - assert_eq!(column_type_names[3], "INTEGER"); + Ok(()) +} + +#[sqlx_macros::test] +async fn it_describes_variables() -> anyhow::Result<()> { + let mut conn = new::().await?; + + // without any context, we resolve to NULL + let info = conn.describe("SELECT ?1").await?; + + assert_eq!(info.column(0).type_info().name(), "NULL"); + assert_eq!(info.nullable(0), None); // unknown + + // context can be provided by using CAST(_ as _) + let info = conn.describe("SELECT CAST(?1 AS REAL)").await?; + + assert_eq!(info.column(0).type_info().name(), "REAL"); + assert_eq!(info.nullable(0), None); // unknown Ok(()) } @@ -49,23 +62,23 @@ async fn it_describes_expression() -> anyhow::Result<()> { .describe("SELECT 1 + 10, 5.12 * 2, 'Hello', x'deadbeef'") .await?; - let columns = d.columns; + let columns = d.columns(); - assert_eq!(columns[0].type_info.as_ref().unwrap().name(), "INTEGER"); - assert_eq!(columns[0].name, "1 + 10"); - assert_eq!(columns[0].not_null, None); + assert_eq!(columns[0].type_info().name(), "INTEGER"); + assert_eq!(columns[0].name(), "1 + 10"); + assert_eq!(d.nullable(0), Some(false)); // literal constant - assert_eq!(columns[1].type_info.as_ref().unwrap().name(), "REAL"); - assert_eq!(columns[1].name, "5.12 * 2"); - assert_eq!(columns[1].not_null, None); + assert_eq!(columns[1].type_info().name(), "REAL"); + assert_eq!(columns[1].name(), "5.12 * 2"); + assert_eq!(d.nullable(1), Some(false)); // literal constant - assert_eq!(columns[2].type_info.as_ref().unwrap().name(), "TEXT"); - assert_eq!(columns[2].name, "'Hello'"); - assert_eq!(columns[2].not_null, None); + assert_eq!(columns[2].type_info().name(), "TEXT"); + assert_eq!(columns[2].name(), "'Hello'"); + assert_eq!(d.nullable(2), Some(false)); // literal constant - assert_eq!(columns[3].type_info.as_ref().unwrap().name(), "BLOB"); - assert_eq!(columns[3].name, "x'deadbeef'"); - assert_eq!(columns[3].not_null, None); + assert_eq!(columns[3].type_info().name(), "BLOB"); + assert_eq!(columns[3].name(), "x'deadbeef'"); + assert_eq!(d.nullable(3), Some(false)); // literal constant Ok(()) } @@ -74,18 +87,46 @@ async fn it_describes_expression() -> anyhow::Result<()> { async fn it_describes_expression_from_empty_table() -> anyhow::Result<()> { let mut conn = new::().await?; - conn.execute("CREATE TEMP TABLE _temp_empty ( name TEXT, a INT )") + conn.execute("CREATE TEMP TABLE _temp_empty ( name TEXT NOT NULL, a INT )") .await?; let d = conn .describe("SELECT COUNT(*), a + 1, name, 5.12, 'Hello' FROM _temp_empty") .await?; - assert_eq!(d.columns[0].type_info.as_ref().unwrap().name(), "INTEGER"); - assert_eq!(d.columns[1].type_info.as_ref().unwrap().name(), "INTEGER"); - assert_eq!(d.columns[2].type_info.as_ref().unwrap().name(), "TEXT"); - assert_eq!(d.columns[3].type_info.as_ref().unwrap().name(), "REAL"); - assert_eq!(d.columns[4].type_info.as_ref().unwrap().name(), "TEXT"); + assert_eq!(d.column(0).type_info().name(), "INTEGER"); + assert_eq!(d.nullable(0), Some(false)); // COUNT(*) + + assert_eq!(d.column(1).type_info().name(), "INTEGER"); + assert_eq!(d.nullable(1), None); // `a + 1` is potentially nullable but we don't know for sure currently + + assert_eq!(d.column(2).type_info().name(), "TEXT"); + assert_eq!(d.nullable(2), Some(false)); // `name` is not nullable + + assert_eq!(d.column(3).type_info().name(), "REAL"); + assert_eq!(d.nullable(3), Some(false)); // literal constant + + assert_eq!(d.column(4).type_info().name(), "TEXT"); + assert_eq!(d.nullable(4), Some(false)); // literal constant + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_describes_expression_from_empty_table_with_star() -> anyhow::Result<()> { + let mut conn = new::().await?; + + conn.execute("CREATE TEMP TABLE _temp_empty ( name TEXT, a INT )") + .await?; + + let d = conn + .describe("SELECT *, 5, 'Hello' FROM _temp_empty") + .await?; + + assert_eq!(d.column(0).type_info().name(), "TEXT"); + assert_eq!(d.column(1).type_info().name(), "INTEGER"); + assert_eq!(d.column(2).type_info().name(), "INTEGER"); + assert_eq!(d.column(3).type_info().name(), "TEXT"); Ok(()) } @@ -98,14 +139,15 @@ async fn it_describes_insert() -> anyhow::Result<()> { .describe("INSERT INTO tweet (id, text) VALUES (2, 'Hello')") .await?; - assert_eq!(d.columns.len(), 0); + assert_eq!(d.columns().len(), 0); let d = conn .describe("INSERT INTO tweet (id, text) VALUES (2, 'Hello'); SELECT last_insert_rowid();") .await?; - assert_eq!(d.columns.len(), 1); - assert_eq!(d.columns[0].type_info.as_ref().unwrap().name(), "INTEGER"); + assert_eq!(d.columns().len(), 1); + assert_eq!(d.column(0).type_info().name(), "INTEGER"); + assert_eq!(d.nullable(0), Some(false)); Ok(()) } @@ -123,7 +165,7 @@ async fn it_describes_insert_with_read_only() -> anyhow::Result<()> { .describe("INSERT INTO tweet (id, text) VALUES (2, 'Hello')") .await?; - assert_eq!(d.columns.len(), 0); + assert_eq!(d.columns().len(), 0); Ok(()) } diff --git a/tests/sqlite/macros.rs b/tests/sqlite/macros.rs index 80ce50b0..1e9b842c 100644 --- a/tests/sqlite/macros.rs +++ b/tests/sqlite/macros.rs @@ -24,8 +24,8 @@ async fn macro_select_expression() -> anyhow::Result<()> { .fetch_one(&mut conn) .await?; - assert_eq!(Some(10), row._1); - assert_eq!(Some("Hello"), row._2.as_deref()); + assert_eq!(10, row._1); + assert_eq!("Hello", &*row._2); Ok(()) } @@ -40,9 +40,9 @@ async fn macro_select_partial_expression() -> anyhow::Result<()> { .fetch_one(&mut conn) .await?; - assert_eq!(Some(10), row._1); - assert_eq!(Some("Hello"), row._2.as_deref()); - assert_eq!(Some(6), row.id_p); + assert_eq!(10, row._1); + assert_eq!("Hello", &*row._2); + assert_eq!(6, row.id_p); assert_eq!("Herp Derpinson", row.name); assert_eq!(row.is_active, Some(true)); diff --git a/tests/sqlite/sqlite.db b/tests/sqlite/sqlite.db index eadde272..0dddba46 100644 Binary files a/tests/sqlite/sqlite.db and b/tests/sqlite/sqlite.db differ