From a7117dd71be740cb57fd7b8b2bff8649e0a2ac7a Mon Sep 17 00:00:00 2001 From: Ryan Leckey Date: Sat, 27 Jun 2020 04:07:40 -0700 Subject: [PATCH] feat(any): introduce the Any database driver which enables choosing the database driver at runtime --- Cargo.toml | 9 + sqlx-core/Cargo.toml | 1 + sqlx-core/src/any/arguments.rs | 133 ++++++++ sqlx-core/src/any/connection/establish.rs | 40 +++ sqlx-core/src/any/connection/executor.rs | 221 ++++++++++++++ sqlx-core/src/any/connection/mod.rs | 170 +++++++++++ sqlx-core/src/any/database.rs | 39 +++ sqlx-core/src/any/decode.rs | 356 ++++++++++++++++++++++ sqlx-core/src/any/encode.rs | 354 +++++++++++++++++++++ sqlx-core/src/any/mod.rs | 29 ++ sqlx-core/src/any/options.rs | 88 ++++++ sqlx-core/src/any/row.rs | 80 +++++ sqlx-core/src/any/transaction.rs | 111 +++++++ sqlx-core/src/any/type_info.rs | 69 +++++ sqlx-core/src/any/types.rs | 243 +++++++++++++++ sqlx-core/src/any/value.rs | 174 +++++++++++ sqlx-core/src/lib.rs | 19 +- src/lib.rs | 11 + tests/any/any.rs | 87 ++++++ tests/postgres/postgres.rs | 4 + 20 files changed, 2236 insertions(+), 2 deletions(-) create mode 100644 sqlx-core/src/any/arguments.rs create mode 100644 sqlx-core/src/any/connection/establish.rs create mode 100644 sqlx-core/src/any/connection/executor.rs create mode 100644 sqlx-core/src/any/connection/mod.rs create mode 100644 sqlx-core/src/any/database.rs create mode 100644 sqlx-core/src/any/decode.rs create mode 100644 sqlx-core/src/any/encode.rs create mode 100644 sqlx-core/src/any/mod.rs create mode 100644 sqlx-core/src/any/options.rs create mode 100644 sqlx-core/src/any/row.rs create mode 100644 sqlx-core/src/any/transaction.rs create mode 100644 sqlx-core/src/any/type_info.rs create mode 100644 sqlx-core/src/any/types.rs create mode 100644 sqlx-core/src/any/value.rs create mode 100644 tests/any/any.rs diff --git a/Cargo.toml b/Cargo.toml index 4351d233..232c7bc8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -57,6 +57,7 @@ runtime-actix = [ "sqlx-core/runtime-actix", "sqlx-macros/runtime-actix" ] runtime-tokio = [ "sqlx-core/runtime-tokio", "sqlx-macros/runtime-tokio" ] # database +any = [ "sqlx-core/any" ] postgres = [ "sqlx-core/postgres", "sqlx-macros/postgres" ] mysql = [ "sqlx-core/mysql", "sqlx-macros/mysql" ] sqlite = [ "sqlx-core/sqlite", "sqlx-macros/sqlite" ] @@ -89,6 +90,14 @@ paste = "0.1.16" serde = { version = "1.0.111", features = [ "derive" ] } serde_json = "1.0.53" +# +# Any +# + +[[test]] +name = "any" +path = "tests/any/any.rs" + # # SQLite # diff --git a/sqlx-core/Cargo.toml b/sqlx-core/Cargo.toml index d8219763..fd8a8ff7 100644 --- a/sqlx-core/Cargo.toml +++ b/sqlx-core/Cargo.toml @@ -21,6 +21,7 @@ postgres = [ "md-5", "sha2", "base64", "sha-1", "rand", "hmac", "futures-channel mysql = [ "sha-1", "sha2", "generic-array", "num-bigint", "base64", "digest", "rand", "rsa" ] sqlite = [ "libsqlite3-sys" ] mssql = [ "uuid", "encoding_rs", "regex" ] +any = [] # types all-types = [ "chrono", "time", "bigdecimal", "ipnetwork", "json", "uuid" ] diff --git a/sqlx-core/src/any/arguments.rs b/sqlx-core/src/any/arguments.rs new file mode 100644 index 00000000..41b0b729 --- /dev/null +++ b/sqlx-core/src/any/arguments.rs @@ -0,0 +1,133 @@ +use crate::any::Any; +use crate::arguments::Arguments; +use crate::encode::Encode; +use crate::types::Type; + +#[derive(Default)] +pub struct AnyArguments<'q> { + values: Vec + Send + 'q>>, +} + +impl<'q> Arguments<'q> for AnyArguments<'q> { + type Database = Any; + + fn reserve(&mut self, additional: usize, _size: usize) { + self.values.reserve(additional); + } + + fn add(&mut self, value: T) + where + T: 'q + Send + Encode<'q, Self::Database> + Type, + { + self.values.push(Box::new(value)); + } +} + +pub struct AnyArgumentBuffer<'q>(pub(crate) AnyArgumentBufferKind<'q>); + +pub(crate) enum AnyArgumentBufferKind<'q> { + #[cfg(feature = "postgres")] + Postgres( + crate::postgres::PgArguments, + std::marker::PhantomData<&'q ()>, + ), + + #[cfg(feature = "mysql")] + MySql( + crate::mysql::MySqlArguments, + std::marker::PhantomData<&'q ()>, + ), + + #[cfg(feature = "sqlite")] + Sqlite(crate::sqlite::SqliteArguments<'q>), + + #[cfg(feature = "mssql")] + Mssql( + crate::mssql::MssqlArguments, + std::marker::PhantomData<&'q ()>, + ), +} + +// control flow inferred type bounds would be fun +// the compiler should know the branch is totally unreachable + +#[cfg(feature = "sqlite")] +#[allow(irrefutable_let_patterns)] +impl<'q> From> for crate::sqlite::SqliteArguments<'q> { + fn from(args: AnyArguments<'q>) -> Self { + let mut buf = AnyArgumentBuffer(AnyArgumentBufferKind::Sqlite(Default::default())); + + for value in args.values { + let _ = value.encode_by_ref(&mut buf); + } + + if let AnyArgumentBufferKind::Sqlite(args) = buf.0 { + args + } else { + unreachable!() + } + } +} + +#[cfg(feature = "mysql")] +#[allow(irrefutable_let_patterns)] +impl<'q> From> for crate::mysql::MySqlArguments { + fn from(args: AnyArguments<'q>) -> Self { + let mut buf = AnyArgumentBuffer(AnyArgumentBufferKind::MySql( + Default::default(), + std::marker::PhantomData, + )); + + for value in args.values { + let _ = value.encode_by_ref(&mut buf); + } + + if let AnyArgumentBufferKind::MySql(args, _) = buf.0 { + args + } else { + unreachable!() + } + } +} + +#[cfg(feature = "mssql")] +#[allow(irrefutable_let_patterns)] +impl<'q> From> for crate::mssql::MssqlArguments { + fn from(args: AnyArguments<'q>) -> Self { + let mut buf = AnyArgumentBuffer(AnyArgumentBufferKind::Mssql( + Default::default(), + std::marker::PhantomData, + )); + + for value in args.values { + let _ = value.encode_by_ref(&mut buf); + } + + if let AnyArgumentBufferKind::Mssql(args, _) = buf.0 { + args + } else { + unreachable!() + } + } +} + +#[cfg(feature = "postgres")] +#[allow(irrefutable_let_patterns)] +impl<'q> From> for crate::postgres::PgArguments { + fn from(args: AnyArguments<'q>) -> Self { + let mut buf = AnyArgumentBuffer(AnyArgumentBufferKind::Postgres( + Default::default(), + std::marker::PhantomData, + )); + + for value in args.values { + let _ = value.encode_by_ref(&mut buf); + } + + if let AnyArgumentBufferKind::Postgres(args, _) = buf.0 { + args + } else { + unreachable!() + } + } +} diff --git a/sqlx-core/src/any/connection/establish.rs b/sqlx-core/src/any/connection/establish.rs new file mode 100644 index 00000000..954a593b --- /dev/null +++ b/sqlx-core/src/any/connection/establish.rs @@ -0,0 +1,40 @@ +use crate::any::connection::AnyConnectionKind; +use crate::any::options::{AnyConnectOptions, AnyConnectOptionsKind}; +use crate::any::AnyConnection; +use crate::connection::Connect; +use crate::error::Error; + +impl AnyConnection { + pub(crate) async fn establish(options: &AnyConnectOptions) -> Result { + match &options.0 { + #[cfg(feature = "mysql")] + AnyConnectOptionsKind::MySql(options) => { + crate::mysql::MySqlConnection::connect_with(options) + .await + .map(AnyConnectionKind::MySql) + } + + #[cfg(feature = "postgres")] + AnyConnectOptionsKind::Postgres(options) => { + crate::postgres::PgConnection::connect_with(options) + .await + .map(AnyConnectionKind::Postgres) + } + + #[cfg(feature = "sqlite")] + AnyConnectOptionsKind::Sqlite(options) => { + crate::sqlite::SqliteConnection::connect_with(options) + .await + .map(AnyConnectionKind::Sqlite) + } + + #[cfg(feature = "mssql")] + AnyConnectOptionsKind::Mssql(options) => { + crate::mssql::MssqlConnection::connect_with(options) + .await + .map(AnyConnectionKind::Mssql) + } + } + .map(AnyConnection) + } +} diff --git a/sqlx-core/src/any/connection/executor.rs b/sqlx-core/src/any/connection/executor.rs new file mode 100644 index 00000000..5637a748 --- /dev/null +++ b/sqlx-core/src/any/connection/executor.rs @@ -0,0 +1,221 @@ +use either::Either; +use futures_core::future::BoxFuture; +use futures_core::stream::BoxStream; +use futures_util::{StreamExt, TryStreamExt}; + +use crate::any::connection::AnyConnectionKind; +use crate::any::row::AnyRowKind; +use crate::any::type_info::AnyTypeInfoKind; +use crate::any::{Any, AnyConnection, AnyRow, AnyTypeInfo}; +use crate::describe::{Column, Describe}; +use crate::error::Error; +use crate::executor::{Execute, Executor}; + +// FIXME: Some of the below, describe especially, is very messy/duplicated; perhaps we should have +// an `Into` that goes from `PgTypeInfo` to `AnyTypeInfo` and so on + +impl<'c> Executor<'c> for &'c mut AnyConnection { + type Database = Any; + + fn fetch_many<'e, 'q: 'e, E: 'q>( + self, + mut query: E, + ) -> BoxStream<'e, Result, Error>> + where + 'c: 'e, + E: Execute<'q, Self::Database>, + { + let arguments = query.take_arguments(); + let query = query.query(); + + match &mut self.0 { + #[cfg(feature = "postgres")] + AnyConnectionKind::Postgres(conn) => conn + .fetch_many((query, arguments.map(Into::into))) + .map_ok(|v| match v { + Either::Right(row) => Either::Right(AnyRow(AnyRowKind::Postgres(row))), + Either::Left(count) => Either::Left(count), + }) + .boxed(), + + #[cfg(feature = "mysql")] + AnyConnectionKind::MySql(conn) => conn + .fetch_many((query, arguments.map(Into::into))) + .map_ok(|v| match v { + Either::Right(row) => Either::Right(AnyRow(AnyRowKind::MySql(row))), + Either::Left(count) => Either::Left(count), + }) + .boxed(), + + #[cfg(feature = "sqlite")] + AnyConnectionKind::Sqlite(conn) => conn + .fetch_many((query, arguments.map(Into::into))) + .map_ok(|v| match v { + Either::Right(row) => Either::Right(AnyRow(AnyRowKind::Sqlite(row))), + Either::Left(count) => Either::Left(count), + }) + .boxed(), + + #[cfg(feature = "mssql")] + AnyConnectionKind::Mssql(conn) => conn + .fetch_many((query, arguments.map(Into::into))) + .map_ok(|v| match v { + Either::Right(row) => Either::Right(AnyRow(AnyRowKind::Mssql(row))), + Either::Left(count) => Either::Left(count), + }) + .boxed(), + } + } + + fn fetch_optional<'e, 'q: 'e, E: 'q>( + self, + mut query: E, + ) -> BoxFuture<'e, Result, Error>> + where + 'c: 'e, + E: Execute<'q, Self::Database>, + { + let arguments = query.take_arguments(); + let query = query.query(); + + Box::pin(async move { + Ok(match &mut self.0 { + #[cfg(feature = "postgres")] + AnyConnectionKind::Postgres(conn) => conn + .fetch_optional((query, arguments.map(Into::into))) + .await? + .map(AnyRowKind::Postgres), + + #[cfg(feature = "mysql")] + AnyConnectionKind::MySql(conn) => conn + .fetch_optional((query, arguments.map(Into::into))) + .await? + .map(AnyRowKind::MySql), + + #[cfg(feature = "sqlite")] + AnyConnectionKind::Sqlite(conn) => conn + .fetch_optional((query, arguments.map(Into::into))) + .await? + .map(AnyRowKind::Sqlite), + + #[cfg(feature = "mssql")] + AnyConnectionKind::Mssql(conn) => conn + .fetch_optional((query, arguments.map(Into::into))) + .await? + .map(AnyRowKind::Mssql), + } + .map(AnyRow)) + }) + } + + fn describe<'e, 'q: 'e, E: 'q>( + self, + query: E, + ) -> BoxFuture<'e, Result, Error>> + where + 'c: 'e, + E: Execute<'q, Self::Database>, + { + let query = query.query(); + + Box::pin(async move { + Ok(match &mut self.0 { + #[cfg(feature = "postgres")] + AnyConnectionKind::Postgres(conn) => { + conn.describe(query).await.map(|desc| Describe { + params: desc + .params + .into_iter() + .map(|ty| ty.map(AnyTypeInfoKind::Postgres).map(AnyTypeInfo)) + .collect(), + + columns: desc + .columns + .into_iter() + .map(|column| Column { + name: column.name, + not_null: column.not_null, + type_info: column + .type_info + .map(AnyTypeInfoKind::Postgres) + .map(AnyTypeInfo), + }) + .collect(), + })? + } + + #[cfg(feature = "mysql")] + AnyConnectionKind::MySql(conn) => { + conn.describe(query).await.map(|desc| Describe { + params: desc + .params + .into_iter() + .map(|ty| ty.map(AnyTypeInfoKind::MySql).map(AnyTypeInfo)) + .collect(), + + columns: desc + .columns + .into_iter() + .map(|column| Column { + name: column.name, + not_null: column.not_null, + type_info: column + .type_info + .map(AnyTypeInfoKind::MySql) + .map(AnyTypeInfo), + }) + .collect(), + })? + } + + #[cfg(feature = "sqlite")] + AnyConnectionKind::Sqlite(conn) => { + conn.describe(query).await.map(|desc| Describe { + params: desc + .params + .into_iter() + .map(|ty| ty.map(AnyTypeInfoKind::Sqlite).map(AnyTypeInfo)) + .collect(), + + columns: desc + .columns + .into_iter() + .map(|column| Column { + name: column.name, + not_null: column.not_null, + type_info: column + .type_info + .map(AnyTypeInfoKind::Sqlite) + .map(AnyTypeInfo), + }) + .collect(), + })? + } + + #[cfg(feature = "mssql")] + AnyConnectionKind::Mssql(conn) => { + conn.describe(query).await.map(|desc| Describe { + params: desc + .params + .into_iter() + .map(|ty| ty.map(AnyTypeInfoKind::Mssql).map(AnyTypeInfo)) + .collect(), + + columns: desc + .columns + .into_iter() + .map(|column| Column { + name: column.name, + not_null: column.not_null, + type_info: column + .type_info + .map(AnyTypeInfoKind::Mssql) + .map(AnyTypeInfo), + }) + .collect(), + })? + } + }) + }) + } +} diff --git a/sqlx-core/src/any/connection/mod.rs b/sqlx-core/src/any/connection/mod.rs new file mode 100644 index 00000000..cc549445 --- /dev/null +++ b/sqlx-core/src/any/connection/mod.rs @@ -0,0 +1,170 @@ +use futures_core::future::BoxFuture; + +use crate::any::{Any, AnyConnectOptions}; +use crate::connection::{Connect, Connection}; +use crate::error::Error; + +#[cfg(feature = "postgres")] +use crate::postgres; + +#[cfg(feature = "sqlite")] +use crate::sqlite; + +#[cfg(feature = "mssql")] +use crate::mssql; + +#[cfg(feature = "mysql")] +use crate::mysql; + +mod establish; +mod executor; + +/// A connection to _any_ SQLx database. +/// +/// The database driver used is determined by the scheme +/// of the connection url. +/// +/// ```text +/// postgres://postgres@localhost/test +/// sqlite://a.sqlite +/// ``` +#[derive(Debug)] +pub struct AnyConnection(pub(super) AnyConnectionKind); + +#[derive(Debug)] +pub(crate) enum AnyConnectionKind { + #[cfg(feature = "postgres")] + Postgres(postgres::PgConnection), + + #[cfg(feature = "mssql")] + Mssql(mssql::MssqlConnection), + + #[cfg(feature = "mysql")] + MySql(mysql::MySqlConnection), + + #[cfg(feature = "sqlite")] + Sqlite(sqlite::SqliteConnection), +} + +macro_rules! delegate_to { + ($self:ident.$method:ident($($arg:ident),*)) => { + match &$self.0 { + #[cfg(feature = "postgres")] + AnyConnectionKind::Postgres(conn) => conn.$method($($arg),*), + + #[cfg(feature = "mysql")] + AnyConnectionKind::MySql(conn) => conn.$method($($arg),*), + + #[cfg(feature = "sqlite")] + AnyConnectionKind::Sqlite(conn) => conn.$method($($arg),*), + + #[cfg(feature = "mssql")] + AnyConnectionKind::Mssql(conn) => conn.$method($($arg),*), + } + }; +} + +macro_rules! delegate_to_mut { + ($self:ident.$method:ident($($arg:ident),*)) => { + match &mut $self.0 { + #[cfg(feature = "postgres")] + AnyConnectionKind::Postgres(conn) => conn.$method($($arg),*), + + #[cfg(feature = "mysql")] + AnyConnectionKind::MySql(conn) => conn.$method($($arg),*), + + #[cfg(feature = "sqlite")] + AnyConnectionKind::Sqlite(conn) => conn.$method($($arg),*), + + #[cfg(feature = "mssql")] + AnyConnectionKind::Mssql(conn) => conn.$method($($arg),*), + } + }; +} + +impl Connection for AnyConnection { + type Database = Any; + + fn close(self) -> BoxFuture<'static, Result<(), Error>> { + match self.0 { + #[cfg(feature = "postgres")] + AnyConnectionKind::Postgres(conn) => conn.close(), + + #[cfg(feature = "mysql")] + AnyConnectionKind::MySql(conn) => conn.close(), + + #[cfg(feature = "sqlite")] + AnyConnectionKind::Sqlite(conn) => conn.close(), + + #[cfg(feature = "mssql")] + AnyConnectionKind::Mssql(conn) => conn.close(), + } + } + + fn ping(&mut self) -> BoxFuture<'_, Result<(), Error>> { + delegate_to_mut!(self.ping()) + } + + fn cached_statements_size(&self) -> usize { + match &self.0 { + #[cfg(feature = "postgres")] + AnyConnectionKind::Postgres(conn) => conn.cached_statements_size(), + + #[cfg(feature = "mysql")] + AnyConnectionKind::MySql(conn) => conn.cached_statements_size(), + + #[cfg(feature = "sqlite")] + AnyConnectionKind::Sqlite(conn) => conn.cached_statements_size(), + + // no cache + #[cfg(feature = "mssql")] + AnyConnectionKind::Mssql(_) => 0, + } + } + + fn clear_cached_statements(&mut self) -> BoxFuture<'_, Result<(), Error>> { + match &mut self.0 { + #[cfg(feature = "postgres")] + AnyConnectionKind::Postgres(conn) => conn.clear_cached_statements(), + + #[cfg(feature = "mysql")] + AnyConnectionKind::MySql(conn) => conn.clear_cached_statements(), + + #[cfg(feature = "sqlite")] + AnyConnectionKind::Sqlite(conn) => conn.clear_cached_statements(), + + // no cache + #[cfg(feature = "mssql")] + AnyConnectionKind::Mssql(_) => Box::pin(futures_util::future::ok(())), + } + } + + #[doc(hidden)] + fn flush(&mut self) -> BoxFuture<'_, Result<(), Error>> { + delegate_to_mut!(self.flush()) + } + + #[doc(hidden)] + fn should_flush(&self) -> bool { + delegate_to!(self.should_flush()) + } + + #[doc(hidden)] + fn get_ref(&self) -> &Self { + self + } + + #[doc(hidden)] + fn get_mut(&mut self) -> &mut Self { + self + } +} + +impl Connect for AnyConnection { + type Options = AnyConnectOptions; + + #[inline] + fn connect_with(options: &Self::Options) -> BoxFuture<'_, Result> { + Box::pin(AnyConnection::establish(options)) + } +} diff --git a/sqlx-core/src/any/database.rs b/sqlx-core/src/any/database.rs new file mode 100644 index 00000000..65e52f53 --- /dev/null +++ b/sqlx-core/src/any/database.rs @@ -0,0 +1,39 @@ +use crate::any::{ + AnyArgumentBuffer, AnyArguments, AnyConnection, AnyRow, AnyTransactionManager, AnyTypeInfo, + AnyValue, AnyValueRef, +}; +use crate::database::{Database, HasArguments, HasStatementCache, HasValueRef}; + +/// Opaque database driver. Capable of being used in place of any SQLx database driver. The actual +/// driver used will be selected at runtime, from the connection uri. +#[derive(Debug)] +pub struct Any; + +impl Database for Any { + type Connection = AnyConnection; + + type TransactionManager = AnyTransactionManager; + + type Row = AnyRow; + + type TypeInfo = AnyTypeInfo; + + type Value = AnyValue; +} + +impl<'r> HasValueRef<'r> for Any { + type Database = Any; + + type ValueRef = AnyValueRef<'r>; +} + +impl<'q> HasArguments<'q> for Any { + type Database = Any; + + type Arguments = AnyArguments<'q>; + + type ArgumentBuffer = AnyArgumentBuffer<'q>; +} + +// This _may_ be true, depending on the selected database +impl HasStatementCache for Any {} diff --git a/sqlx-core/src/any/decode.rs b/sqlx-core/src/any/decode.rs new file mode 100644 index 00000000..2d44d993 --- /dev/null +++ b/sqlx-core/src/any/decode.rs @@ -0,0 +1,356 @@ +use crate::any::value::AnyValueRefKind; +use crate::any::{Any, AnyValueRef}; +use crate::decode::Decode; +use crate::error::BoxDynError; +use crate::types::Type; + +#[cfg(feature = "postgres")] +use crate::postgres::Postgres; + +#[cfg(feature = "mysql")] +use crate::mysql::MySql; + +#[cfg(feature = "mssql")] +use crate::mssql::Mssql; + +#[cfg(feature = "sqlite")] +use crate::sqlite::Sqlite; + +// Implements Decode for any T where T supports Decode for any database that has support currently +// compiled into SQLx +impl<'r, T> Decode<'r, Any> for T +where + T: AnyDecode<'r>, +{ + fn decode(value: AnyValueRef<'r>) -> Result { + match value.0 { + #[cfg(feature = "mysql")] + AnyValueRefKind::MySql(value) => >::decode(value), + + #[cfg(feature = "sqlite")] + AnyValueRefKind::Sqlite(value) => { + >::decode(value) + } + + #[cfg(feature = "mssql")] + AnyValueRefKind::Mssql(value) => >::decode(value), + + #[cfg(feature = "postgres")] + AnyValueRefKind::Postgres(value) => { + >::decode(value) + } + } + } +} + +// FIXME: Find a nice way to auto-generate the below or petition Rust to add support for #[cfg] +// to trait bounds + +// all 4 + +#[cfg(all( + feature = "postgres", + feature = "mysql", + feature = "mssql", + feature = "sqlite" +))] +pub trait AnyDecode<'r>: + Decode<'r, Postgres> + + Type + + Decode<'r, MySql> + + Type + + Decode<'r, Mssql> + + Type + + Decode<'r, Sqlite> + + Type +{ +} + +#[cfg(all( + feature = "postgres", + feature = "mysql", + feature = "mssql", + feature = "sqlite" +))] +impl<'r, T> AnyDecode<'r> for T where + T: Decode<'r, Postgres> + + Type + + Decode<'r, MySql> + + Type + + Decode<'r, Mssql> + + Type + + Decode<'r, Sqlite> + + Type +{ +} + +// only 3 (4) + +#[cfg(all( + not(feature = "mssql"), + all(feature = "postgres", feature = "mysql", feature = "sqlite") +))] +pub trait AnyDecode<'r>: + Decode<'r, Postgres> + + Type + + Decode<'r, MySql> + + Type + + Decode<'r, Sqlite> + + Type +{ +} + +#[cfg(all( + not(feature = "mssql"), + all(feature = "postgres", feature = "mysql", feature = "sqlite") +))] +impl<'r, T> AnyDecode<'r> for T where + T: Decode<'r, Postgres> + + Type + + Decode<'r, MySql> + + Type + + Decode<'r, Sqlite> + + Type +{ +} + +#[cfg(all( + not(feature = "mysql"), + all(feature = "postgres", feature = "mssql", feature = "sqlite") +))] +pub trait AnyDecode<'r>: + Decode<'r, Postgres> + + Type + + Decode<'r, Mssql> + + Type + + Decode<'r, Sqlite> + + Type +{ +} + +#[cfg(all( + not(feature = "mysql"), + all(feature = "postgres", feature = "mssql", feature = "sqlite") +))] +impl<'r, T> AnyDecode<'r> for T where + T: Decode<'r, Postgres> + + Type + + Decode<'r, Mssql> + + Type + + Decode<'r, Sqlite> + + Type +{ +} + +#[cfg(all( + not(feature = "sqlite"), + all(feature = "postgres", feature = "mysql", feature = "mssql") +))] +pub trait AnyDecode<'r>: + Decode<'r, Postgres> + + Type + + Decode<'r, MySql> + + Type + + Decode<'r, Mssql> + + Type +{ +} + +#[cfg(all( + not(feature = "sqlite"), + all(feature = "postgres", feature = "mysql", feature = "mssql") +))] +impl<'r, T> AnyDecode<'r> for T where + T: Decode<'r, Postgres> + + Type + + Decode<'r, MySql> + + Type + + Decode<'r, Mssql> + + Type +{ +} + +#[cfg(all( + not(feature = "postgres"), + all(feature = "sqlite", feature = "mysql", feature = "mssql") +))] +pub trait AnyDecode<'r>: + Decode<'r, Sqlite> + + Type + + Decode<'r, MySql> + + Type + + Decode<'r, Mssql> + + Type +{ +} + +#[cfg(all( + not(feature = "postgres"), + all(feature = "sqlite", feature = "mysql", feature = "mssql") +))] +impl<'r, T> AnyDecode<'r> for T where + T: Decode<'r, Sqlite> + + Type + + Decode<'r, MySql> + + Type + + Decode<'r, Mssql> + + Type +{ +} + +// only 2 (6) + +#[cfg(all( + not(any(feature = "mssql", feature = "sqlite")), + all(feature = "postgres", feature = "mysql") +))] +pub trait AnyDecode<'r>: + Decode<'r, Postgres> + Type + Decode<'r, MySql> + Type +{ +} + +#[cfg(all( + not(any(feature = "mssql", feature = "sqlite")), + all(feature = "postgres", feature = "mysql") +))] +impl<'r, T> AnyDecode<'r> for T where + T: Decode<'r, Postgres> + Type + Decode<'r, MySql> + Type +{ +} + +#[cfg(all( + not(any(feature = "mysql", feature = "sqlite")), + all(feature = "postgres", feature = "mssql") +))] +pub trait AnyDecode<'r>: + Decode<'r, Postgres> + Type + Decode<'r, Mssql> + Type +{ +} + +#[cfg(all( + not(any(feature = "mysql", feature = "sqlite")), + all(feature = "postgres", feature = "mssql") +))] +impl<'r, T> AnyDecode<'r> for T where + T: Decode<'r, Postgres> + Type + Decode<'r, Mssql> + Type +{ +} + +#[cfg(all( + not(any(feature = "mysql", feature = "mssql")), + all(feature = "postgres", feature = "sqlite") +))] +pub trait AnyDecode<'r>: + Decode<'r, Postgres> + Type + Decode<'r, Sqlite> + Type +{ +} + +#[cfg(all( + not(any(feature = "mysql", feature = "mssql")), + all(feature = "postgres", feature = "sqlite") +))] +impl<'r, T> AnyDecode<'r> for T where + T: Decode<'r, Postgres> + Type + Decode<'r, Sqlite> + Type +{ +} + +#[cfg(all( + not(any(feature = "postgres", feature = "sqlite")), + all(feature = "mssql", feature = "mysql") +))] +pub trait AnyDecode<'r>: Decode<'r, Mssql> + Type + Decode<'r, MySql> + Type {} + +#[cfg(all( + not(any(feature = "postgres", feature = "sqlite")), + all(feature = "mssql", feature = "mysql") +))] +impl<'r, T> AnyDecode<'r> for T where + T: Decode<'r, Mssql> + Type + Decode<'r, MySql> + Type +{ +} + +#[cfg(all( + not(any(feature = "postgres", feature = "mysql")), + all(feature = "mssql", feature = "sqlite") +))] +pub trait AnyDecode<'r>: + Decode<'r, Mssql> + Type + Decode<'r, Sqlite> + Type +{ +} + +#[cfg(all( + not(any(feature = "postgres", feature = "mysql")), + all(feature = "mssql", feature = "sqlite") +))] +impl<'r, T> AnyDecode<'r> for T where + T: Decode<'r, Mssql> + Type + Decode<'r, Sqlite> + Type +{ +} + +#[cfg(all( + not(any(feature = "postgres", feature = "mssql")), + all(feature = "mysql", feature = "sqlite") +))] +pub trait AnyDecode<'r>: + Decode<'r, MySql> + Type + Decode<'r, Sqlite> + Type +{ +} + +#[cfg(all( + not(any(feature = "postgres", feature = "mssql")), + all(feature = "mysql", feature = "sqlite") +))] +impl<'r, T> AnyDecode<'r> for T where + T: Decode<'r, MySql> + Type + Decode<'r, Sqlite> + Type +{ +} + +// only 1 (4) + +#[cfg(all( + not(any(feature = "mysql", feature = "mssql", feature = "sqlite")), + feature = "postgres" +))] +pub trait AnyDecode<'r>: Decode<'r, Postgres> + Type {} + +#[cfg(all( + not(any(feature = "mysql", feature = "mssql", feature = "sqlite")), + feature = "postgres" +))] +impl<'r, T> AnyDecode<'r> for T where T: Decode<'r, Postgres> + Type {} + +#[cfg(all( + not(any(feature = "postgres", feature = "mssql", feature = "sqlite")), + feature = "mysql" +))] +pub trait AnyDecode<'r>: Decode<'r, MySql> + Type {} + +#[cfg(all( + not(any(feature = "postgres", feature = "mssql", feature = "sqlite")), + feature = "mysql" +))] +impl<'r, T> AnyDecode<'r> for T where T: Decode<'r, MySql> + Type {} + +#[cfg(all( + not(any(feature = "mysql", feature = "postgres", feature = "sqlite")), + feature = "mssql" +))] +pub trait AnyDecode<'r>: Decode<'r, Mssql> + Type {} + +#[cfg(all( + not(any(feature = "mysql", feature = "postgres", feature = "sqlite")), + feature = "mssql" +))] +impl<'r, T> AnyDecode<'r> for T where T: Decode<'r, Mssql> + Type {} + +#[cfg(all( + not(any(feature = "mysql", feature = "mssql", feature = "postgres")), + feature = "sqlite" +))] +pub trait AnyDecode<'r>: Decode<'r, Sqlite> + Type {} + +#[cfg(all( + not(any(feature = "mysql", feature = "mssql", feature = "postgres")), + feature = "sqlite" +))] +impl<'r, T> AnyDecode<'r> for T where T: Decode<'r, Sqlite> + Type {} diff --git a/sqlx-core/src/any/encode.rs b/sqlx-core/src/any/encode.rs new file mode 100644 index 00000000..a686b29a --- /dev/null +++ b/sqlx-core/src/any/encode.rs @@ -0,0 +1,354 @@ +use crate::any::arguments::AnyArgumentBufferKind; +use crate::any::{Any, AnyArgumentBuffer}; +use crate::encode::{Encode, IsNull}; +use crate::types::Type; + +#[cfg(feature = "postgres")] +use crate::postgres::Postgres; + +#[cfg(feature = "mysql")] +use crate::mysql::MySql; + +#[cfg(feature = "mssql")] +use crate::mssql::Mssql; + +#[cfg(feature = "sqlite")] +use crate::sqlite::Sqlite; + +// Implements Encode for any T where T supports Encode for any database that has support currently +// compiled into SQLx +impl<'q, T> Encode<'q, Any> for T +where + T: AnyEncode<'q>, +{ + fn encode_by_ref(&self, buf: &mut AnyArgumentBuffer<'q>) -> IsNull { + match &mut buf.0 { + #[cfg(feature = "postgres")] + AnyArgumentBufferKind::Postgres(args, _) => args.add(self), + + #[cfg(feature = "mysql")] + AnyArgumentBufferKind::MySql(args, _) => args.add(self), + + #[cfg(feature = "mssql")] + AnyArgumentBufferKind::Mssql(args, _) => args.add(self), + + #[cfg(feature = "sqlite")] + AnyArgumentBufferKind::Sqlite(args) => args.add(self), + } + + // unused + IsNull::No + } +} + +// FIXME: Find a nice way to auto-generate the below or petition Rust to add support for #[cfg] +// to trait bounds + +// all 4 + +#[cfg(all( + feature = "postgres", + feature = "mysql", + feature = "mssql", + feature = "sqlite" +))] +pub trait AnyEncode<'q>: + Encode<'q, Postgres> + + Type + + Encode<'q, MySql> + + Type + + Encode<'q, Mssql> + + Type + + Encode<'q, Sqlite> + + Type +{ +} + +#[cfg(all( + feature = "postgres", + feature = "mysql", + feature = "mssql", + feature = "sqlite" +))] +impl<'q, T> AnyEncode<'q> for T where + T: Encode<'q, Postgres> + + Type + + Encode<'q, MySql> + + Type + + Encode<'q, Mssql> + + Type + + Encode<'q, Sqlite> + + Type +{ +} + +// only 3 (4) + +#[cfg(all( + not(feature = "mssql"), + all(feature = "postgres", feature = "mysql", feature = "sqlite") +))] +pub trait AnyEncode<'q>: + Encode<'q, Postgres> + + Type + + Encode<'q, MySql> + + Type + + Encode<'q, Sqlite> + + Type +{ +} + +#[cfg(all( + not(feature = "mssql"), + all(feature = "postgres", feature = "mysql", feature = "sqlite") +))] +impl<'q, T> AnyEncode<'q> for T where + T: Encode<'q, Postgres> + + Type + + Encode<'q, MySql> + + Type + + Encode<'q, Sqlite> + + Type +{ +} + +#[cfg(all( + not(feature = "mysql"), + all(feature = "postgres", feature = "mssql", feature = "sqlite") +))] +pub trait AnyEncode<'q>: + Encode<'q, Postgres> + + Type + + Encode<'q, Mssql> + + Type + + Encode<'q, Sqlite> + + Type +{ +} + +#[cfg(all( + not(feature = "mysql"), + all(feature = "postgres", feature = "mssql", feature = "sqlite") +))] +impl<'q, T> AnyEncode<'q> for T where + T: Encode<'q, Postgres> + + Type + + Encode<'q, Mssql> + + Type + + Encode<'q, Sqlite> + + Type +{ +} + +#[cfg(all( + not(feature = "sqlite"), + all(feature = "postgres", feature = "mysql", feature = "mssql") +))] +pub trait AnyEncode<'q>: + Encode<'q, Postgres> + + Type + + Encode<'q, MySql> + + Type + + Encode<'q, Mssql> + + Type +{ +} + +#[cfg(all( + not(feature = "sqlite"), + all(feature = "postgres", feature = "mysql", feature = "mssql") +))] +impl<'q, T> AnyEncode<'q> for T where + T: Encode<'q, Postgres> + + Type + + Encode<'q, MySql> + + Type + + Encode<'q, Mssql> + + Type +{ +} + +#[cfg(all( + not(feature = "postgres"), + all(feature = "sqlite", feature = "mysql", feature = "mssql") +))] +pub trait AnyEncode<'q>: + Encode<'q, Sqlite> + + Type + + Encode<'q, MySql> + + Type + + Encode<'q, Mssql> + + Type +{ +} + +#[cfg(all( + not(feature = "postgres"), + all(feature = "sqlite", feature = "mysql", feature = "mssql") +))] +impl<'q, T> AnyEncode<'q> for T where + T: Encode<'q, Sqlite> + + Type + + Encode<'q, MySql> + + Type + + Encode<'q, Mssql> + + Type +{ +} + +// only 2 (6) + +#[cfg(all( + not(any(feature = "mssql", feature = "sqlite")), + all(feature = "postgres", feature = "mysql") +))] +pub trait AnyEncode<'q>: + Encode<'q, Postgres> + Type + Encode<'q, MySql> + Type +{ +} + +#[cfg(all( + not(any(feature = "mssql", feature = "sqlite")), + all(feature = "postgres", feature = "mysql") +))] +impl<'q, T> AnyEncode<'q> for T where + T: Encode<'q, Postgres> + Type + Encode<'q, MySql> + Type +{ +} + +#[cfg(all( + not(any(feature = "mysql", feature = "sqlite")), + all(feature = "postgres", feature = "mssql") +))] +pub trait AnyEncode<'q>: + Encode<'q, Postgres> + Type + Encode<'q, Mssql> + Type +{ +} + +#[cfg(all( + not(any(feature = "mysql", feature = "sqlite")), + all(feature = "postgres", feature = "mssql") +))] +impl<'q, T> AnyEncode<'q> for T where + T: Encode<'q, Postgres> + Type + Encode<'q, Mssql> + Type +{ +} + +#[cfg(all( + not(any(feature = "mysql", feature = "mssql")), + all(feature = "postgres", feature = "sqlite") +))] +pub trait AnyEncode<'q>: + Encode<'q, Postgres> + Type + Encode<'q, Sqlite> + Type +{ +} + +#[cfg(all( + not(any(feature = "mysql", feature = "mssql")), + all(feature = "postgres", feature = "sqlite") +))] +impl<'q, T> AnyEncode<'q> for T where + T: Encode<'q, Postgres> + Type + Encode<'q, Sqlite> + Type +{ +} + +#[cfg(all( + not(any(feature = "postgres", feature = "sqlite")), + all(feature = "mssql", feature = "mysql") +))] +pub trait AnyEncode<'q>: Encode<'q, Mssql> + Type + Encode<'q, MySql> + Type {} + +#[cfg(all( + not(any(feature = "postgres", feature = "sqlite")), + all(feature = "mssql", feature = "mysql") +))] +impl<'q, T> AnyEncode<'q> for T where + T: Encode<'q, Mssql> + Type + Encode<'q, MySql> + Type +{ +} + +#[cfg(all( + not(any(feature = "postgres", feature = "mysql")), + all(feature = "mssql", feature = "sqlite") +))] +pub trait AnyEncode<'q>: + Encode<'q, Mssql> + Type + Encode<'q, Sqlite> + Type +{ +} + +#[cfg(all( + not(any(feature = "postgres", feature = "mysql")), + all(feature = "mssql", feature = "sqlite") +))] +impl<'q, T> AnyEncode<'q> for T where + T: Encode<'q, Mssql> + Type + Encode<'q, Sqlite> + Type +{ +} + +#[cfg(all( + not(any(feature = "postgres", feature = "mssql")), + all(feature = "mysql", feature = "sqlite") +))] +pub trait AnyEncode<'q>: + Encode<'q, MySql> + Type + Encode<'q, Sqlite> + Type +{ +} + +#[cfg(all( + not(any(feature = "postgres", feature = "mssql")), + all(feature = "mysql", feature = "sqlite") +))] +impl<'q, T> AnyEncode<'q> for T where + T: Encode<'q, MySql> + Type + Encode<'q, Sqlite> + Type +{ +} + +// only 1 (4) + +#[cfg(all( + not(any(feature = "mysql", feature = "mssql", feature = "sqlite")), + feature = "postgres" +))] +pub trait AnyEncode<'q>: Encode<'q, Postgres> + Type {} + +#[cfg(all( + not(any(feature = "mysql", feature = "mssql", feature = "sqlite")), + feature = "postgres" +))] +impl<'q, T> AnyEncode<'q> for T where T: Encode<'q, Postgres> + Type {} + +#[cfg(all( + not(any(feature = "postgres", feature = "mssql", feature = "sqlite")), + feature = "mysql" +))] +pub trait AnyEncode<'q>: Encode<'q, MySql> + Type {} + +#[cfg(all( + not(any(feature = "postgres", feature = "mssql", feature = "sqlite")), + feature = "mysql" +))] +impl<'q, T> AnyEncode<'q> for T where T: Encode<'q, MySql> + Type {} + +#[cfg(all( + not(any(feature = "mysql", feature = "postgres", feature = "sqlite")), + feature = "mssql" +))] +pub trait AnyEncode<'q>: Encode<'q, Mssql> + Type {} + +#[cfg(all( + not(any(feature = "mysql", feature = "postgres", feature = "sqlite")), + feature = "mssql" +))] +impl<'q, T> AnyEncode<'q> for T where T: Encode<'q, Mssql> + Type {} + +#[cfg(all( + not(any(feature = "mysql", feature = "mssql", feature = "postgres")), + feature = "sqlite" +))] +pub trait AnyEncode<'q>: Encode<'q, Sqlite> + Type {} + +#[cfg(all( + not(any(feature = "mysql", feature = "mssql", feature = "postgres")), + feature = "sqlite" +))] +impl<'q, T> AnyEncode<'q> for T where T: Encode<'q, Sqlite> + Type {} diff --git a/sqlx-core/src/any/mod.rs b/sqlx-core/src/any/mod.rs new file mode 100644 index 00000000..4d1ff894 --- /dev/null +++ b/sqlx-core/src/any/mod.rs @@ -0,0 +1,29 @@ +mod arguments; +mod connection; +mod database; +mod decode; +mod encode; +mod options; +mod row; +mod transaction; +mod type_info; +mod types; +mod value; + +pub use arguments::{AnyArgumentBuffer, AnyArguments}; +pub use connection::AnyConnection; +pub use database::Any; +pub use decode::AnyDecode; +pub use encode::AnyEncode; +pub use options::AnyConnectOptions; +pub use row::AnyRow; +pub use transaction::AnyTransactionManager; +pub use type_info::AnyTypeInfo; +pub use value::{AnyValue, AnyValueRef}; + +pub type AnyPool = crate::pool::Pool; + +// NOTE: required due to the lack of lazy normalization +impl_into_arguments_for_arguments!(AnyArguments<'q>); +impl_executor_for_pool_connection!(Any, AnyConnection, AnyRow); +impl_executor_for_transaction!(Any, AnyRow); diff --git a/sqlx-core/src/any/options.rs b/sqlx-core/src/any/options.rs new file mode 100644 index 00000000..a4d5ce52 --- /dev/null +++ b/sqlx-core/src/any/options.rs @@ -0,0 +1,88 @@ +use std::str::FromStr; + +use crate::error::BoxDynError; + +#[cfg(feature = "postgres")] +use crate::postgres::PgConnectOptions; + +#[cfg(feature = "mysql")] +use crate::mysql::MySqlConnectOptions; + +#[cfg(feature = "sqlite")] +use crate::sqlite::SqliteConnectOptions; + +#[cfg(feature = "mssql")] +use crate::mssql::MssqlConnectOptions; + +/// Opaque options for connecting to a database. These may only be constructed by parsing from +/// a connection uri. +/// +/// ```text +/// postgres://postgres:password@localhost/database +/// mysql://root:password@localhost/database +/// ``` +pub struct AnyConnectOptions(pub(crate) AnyConnectOptionsKind); + +pub(crate) enum AnyConnectOptionsKind { + #[cfg(feature = "postgres")] + Postgres(PgConnectOptions), + + #[cfg(feature = "mysql")] + MySql(MySqlConnectOptions), + + #[cfg(feature = "sqlite")] + Sqlite(SqliteConnectOptions), + + #[cfg(feature = "mssql")] + Mssql(MssqlConnectOptions), +} + +impl FromStr for AnyConnectOptions { + type Err = BoxDynError; + + fn from_str(url: &str) -> Result { + match url { + #[cfg(feature = "postgres")] + _ if url.starts_with("postgres:") || url.starts_with("postgresql:") => { + PgConnectOptions::from_str(url).map(AnyConnectOptionsKind::Postgres) + } + + #[cfg(not(feature = "postgres"))] + _ if url.starts_with("postgres:") || url.starts_with("postgresql:") => { + Err("database URL has the scheme of a PostgreSQL database but the `postgres` feature is not enabled".into()) + } + + #[cfg(feature = "mysql")] + _ if url.starts_with("mysql:") || url.starts_with("mariadb:") => { + MySqlConnectOptions::from_str(url).map(AnyConnectOptionsKind::MySql) + } + + #[cfg(not(feature = "mysql"))] + _ if url.starts_with("mysql:") || url.starts_with("mariadb:") => { + Err("database URL has the scheme of a MySQL database but the `mysql` feature is not enabled".into()) + } + + #[cfg(feature = "sqlite")] + _ if url.starts_with("sqlite:") => { + SqliteConnectOptions::from_str(url).map(AnyConnectOptionsKind::Sqlite) + } + + #[cfg(not(feature = "sqlite"))] + _ if url.starts_with("sqlite:") => { + Err("database URL has the scheme of a SQLite database but the `sqlite` feature is not enabled".into()) + } + + #[cfg(feature = "mssql")] + _ if url.starts_with("mssql:") || url.starts_with("sqlserver:") => { + MssqlConnectOptions::from_str(url).map(AnyConnectOptionsKind::Mssql) + } + + #[cfg(not(feature = "mssql"))] + _ if url.starts_with("mssql:") || url.starts_with("sqlserver:") => { + Err("database URL has the scheme of a MSSQL database but the `mssql` feature is not enabled".into()) + } + + _ => Err(format!("unrecognized database url: {:?}", url).into()) + }.map(AnyConnectOptions) + } +} diff --git a/sqlx-core/src/any/row.rs b/sqlx-core/src/any/row.rs new file mode 100644 index 00000000..e876723f --- /dev/null +++ b/sqlx-core/src/any/row.rs @@ -0,0 +1,80 @@ +use crate::any::value::AnyValueRefKind; +use crate::any::{Any, AnyValueRef}; +use crate::database::HasValueRef; +use crate::error::Error; +use crate::row::{ColumnIndex, Row}; + +#[cfg(feature = "postgres")] +use crate::postgres::PgRow; + +#[cfg(feature = "mysql")] +use crate::mysql::MySqlRow; + +#[cfg(feature = "sqlite")] +use crate::sqlite::SqliteRow; + +#[cfg(feature = "mssql")] +use crate::mssql::MssqlRow; + +pub struct AnyRow(pub(crate) AnyRowKind); + +impl crate::row::private_row::Sealed for AnyRow {} + +pub(crate) enum AnyRowKind { + #[cfg(feature = "postgres")] + Postgres(PgRow), + + #[cfg(feature = "mysql")] + MySql(MySqlRow), + + #[cfg(feature = "sqlite")] + Sqlite(SqliteRow), + + #[cfg(feature = "mssql")] + Mssql(MssqlRow), +} + +impl Row for AnyRow { + type Database = Any; + + fn len(&self) -> usize { + match &self.0 { + #[cfg(feature = "postgres")] + AnyRowKind::Postgres(row) => row.len(), + + #[cfg(feature = "mysql")] + AnyRowKind::MySql(row) => row.len(), + + #[cfg(feature = "sqlite")] + AnyRowKind::Sqlite(row) => row.len(), + + #[cfg(feature = "mssql")] + AnyRowKind::Mssql(row) => row.len(), + } + } + + fn try_get_raw( + &self, + index: I, + ) -> Result<>::ValueRef, Error> + where + I: ColumnIndex, + { + let index = index.index(self)?; + + match &self.0 { + #[cfg(feature = "postgres")] + AnyRowKind::Postgres(row) => row.try_get_raw(index).map(AnyValueRefKind::Postgres), + + #[cfg(feature = "mysql")] + AnyRowKind::MySql(row) => row.try_get_raw(index).map(AnyValueRefKind::MySql), + + #[cfg(feature = "sqlite")] + AnyRowKind::Sqlite(row) => row.try_get_raw(index).map(AnyValueRefKind::Sqlite), + + #[cfg(feature = "mssql")] + AnyRowKind::Mssql(row) => row.try_get_raw(index).map(AnyValueRefKind::Mssql), + } + .map(AnyValueRef) + } +} diff --git a/sqlx-core/src/any/transaction.rs b/sqlx-core/src/any/transaction.rs new file mode 100644 index 00000000..9b5ab3c0 --- /dev/null +++ b/sqlx-core/src/any/transaction.rs @@ -0,0 +1,111 @@ +use futures_util::future::BoxFuture; + +use crate::any::connection::AnyConnectionKind; +use crate::any::{Any, AnyConnection}; +use crate::database::Database; +use crate::error::Error; +use crate::transaction::TransactionManager; + +pub struct AnyTransactionManager; + +impl TransactionManager for AnyTransactionManager { + type Database = Any; + + fn begin(conn: &mut AnyConnection, depth: usize) -> BoxFuture<'_, Result<(), Error>> { + match &mut conn.0 { + #[cfg(feature = "postgres")] + AnyConnectionKind::Postgres(conn) => { + ::TransactionManager::begin(conn, depth) + } + + #[cfg(feature = "mysql")] + AnyConnectionKind::MySql(conn) => { + ::TransactionManager::begin(conn, depth) + } + + #[cfg(feature = "sqlite")] + AnyConnectionKind::Sqlite(conn) => { + ::TransactionManager::begin(conn, depth) + } + + #[cfg(feature = "mssql")] + AnyConnectionKind::Mssql(conn) => { + ::TransactionManager::begin(conn, depth) + } + } + } + + fn commit(conn: &mut AnyConnection, depth: usize) -> BoxFuture<'_, Result<(), Error>> { + match &mut conn.0 { + #[cfg(feature = "postgres")] + AnyConnectionKind::Postgres(conn) => { + ::TransactionManager::commit(conn, depth) + } + + #[cfg(feature = "mysql")] + AnyConnectionKind::MySql(conn) => { + ::TransactionManager::commit(conn, depth) + } + + #[cfg(feature = "sqlite")] + AnyConnectionKind::Sqlite(conn) => { + ::TransactionManager::commit(conn, depth) + } + + #[cfg(feature = "mssql")] + AnyConnectionKind::Mssql(conn) => { + ::TransactionManager::commit(conn, depth) + } + } + } + + fn rollback(conn: &mut AnyConnection, depth: usize) -> BoxFuture<'_, Result<(), Error>> { + match &mut conn.0 { + #[cfg(feature = "postgres")] + AnyConnectionKind::Postgres(conn) => { + ::TransactionManager::rollback(conn, depth) + } + + #[cfg(feature = "mysql")] + AnyConnectionKind::MySql(conn) => { + ::TransactionManager::rollback(conn, depth) + } + + #[cfg(feature = "sqlite")] + AnyConnectionKind::Sqlite(conn) => { + ::TransactionManager::rollback(conn, depth) + } + + #[cfg(feature = "mssql")] + AnyConnectionKind::Mssql(conn) => { + ::TransactionManager::rollback(conn, depth) + } + } + } + + fn start_rollback(conn: &mut AnyConnection, depth: usize) { + match &mut conn.0 { + #[cfg(feature = "postgres")] + AnyConnectionKind::Postgres(conn) => { + ::TransactionManager::start_rollback( + conn, depth, + ) + } + + #[cfg(feature = "mysql")] + AnyConnectionKind::MySql(conn) => { + ::TransactionManager::start_rollback(conn, depth) + } + + #[cfg(feature = "sqlite")] + AnyConnectionKind::Sqlite(conn) => { + ::TransactionManager::start_rollback(conn, depth) + } + + #[cfg(feature = "mssql")] + AnyConnectionKind::Mssql(conn) => { + ::TransactionManager::start_rollback(conn, depth) + } + } + } +} diff --git a/sqlx-core/src/any/type_info.rs b/sqlx-core/src/any/type_info.rs new file mode 100644 index 00000000..a056f3bd --- /dev/null +++ b/sqlx-core/src/any/type_info.rs @@ -0,0 +1,69 @@ +use std::fmt::{self, Display, Formatter}; + +use crate::type_info::TypeInfo; + +#[cfg(feature = "postgres")] +use crate::postgres::PgTypeInfo; + +#[cfg(feature = "mysql")] +use crate::mysql::MySqlTypeInfo; + +#[cfg(feature = "sqlite")] +use crate::sqlite::SqliteTypeInfo; + +#[cfg(feature = "mssql")] +use crate::mssql::MssqlTypeInfo; + +#[derive(Debug, Clone, PartialEq)] +pub struct AnyTypeInfo(pub(crate) AnyTypeInfoKind); + +#[derive(Debug, Clone, PartialEq)] +pub(crate) enum AnyTypeInfoKind { + #[cfg(feature = "postgres")] + Postgres(PgTypeInfo), + + #[cfg(feature = "mysql")] + MySql(MySqlTypeInfo), + + #[cfg(feature = "sqlite")] + Sqlite(SqliteTypeInfo), + + #[cfg(feature = "mssql")] + Mssql(MssqlTypeInfo), +} + +impl TypeInfo for AnyTypeInfo { + fn name(&self) -> &str { + match &self.0 { + #[cfg(feature = "postgres")] + AnyTypeInfoKind::Postgres(ty) => ty.name(), + + #[cfg(feature = "mysql")] + AnyTypeInfoKind::MySql(ty) => ty.name(), + + #[cfg(feature = "sqlite")] + AnyTypeInfoKind::Sqlite(ty) => ty.name(), + + #[cfg(feature = "mssql")] + AnyTypeInfoKind::Mssql(ty) => ty.name(), + } + } +} + +impl Display for AnyTypeInfo { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match &self.0 { + #[cfg(feature = "postgres")] + AnyTypeInfoKind::Postgres(ty) => ty.fmt(f), + + #[cfg(feature = "mysql")] + AnyTypeInfoKind::MySql(ty) => ty.fmt(f), + + #[cfg(feature = "sqlite")] + AnyTypeInfoKind::Sqlite(ty) => ty.fmt(f), + + #[cfg(feature = "mssql")] + AnyTypeInfoKind::Mssql(ty) => ty.fmt(f), + } + } +} diff --git a/sqlx-core/src/any/types.rs b/sqlx-core/src/any/types.rs new file mode 100644 index 00000000..a32c0ba6 --- /dev/null +++ b/sqlx-core/src/any/types.rs @@ -0,0 +1,243 @@ +use crate::any::type_info::AnyTypeInfoKind; +use crate::any::{Any, AnyTypeInfo}; +use crate::database::Database; +use crate::types::Type; + +#[cfg(feature = "postgres")] +use crate::postgres::Postgres; + +#[cfg(feature = "mysql")] +use crate::mysql::MySql; + +#[cfg(feature = "mssql")] +use crate::mssql::Mssql; + +#[cfg(feature = "sqlite")] +use crate::sqlite::Sqlite; + +// Type is required by the bounds of the [Row] and [Arguments] trait but its been overridden in +// AnyRow and AnyArguments to not use this implementation; but instead, delegate to the +// database-specific implementation. +// +// The other use of this trait is for compile-time verification which is not feasible to support +// for the [Any] driver. +impl Type for T +where + T: AnyType, +{ + fn type_info() -> ::TypeInfo { + // FIXME: nicer panic explaining why this isn't possible + unimplemented!() + } + + fn compatible(ty: &AnyTypeInfo) -> bool { + match &ty.0 { + #[cfg(feature = "postgres")] + AnyTypeInfoKind::Postgres(ty) => >::compatible(&ty), + + #[cfg(feature = "mysql")] + AnyTypeInfoKind::MySql(ty) => >::compatible(&ty), + + #[cfg(feature = "sqlite")] + AnyTypeInfoKind::Sqlite(ty) => >::compatible(&ty), + + #[cfg(feature = "mssql")] + AnyTypeInfoKind::Mssql(ty) => >::compatible(&ty), + } + } +} + +// FIXME: Find a nice way to auto-generate the below or petition Rust to add support for #[cfg] +// to trait bounds + +// all 4 + +#[cfg(all( + feature = "postgres", + feature = "mysql", + feature = "mssql", + feature = "sqlite" +))] +pub trait AnyType: Type + Type + Type + Type {} + +#[cfg(all( + feature = "postgres", + feature = "mysql", + feature = "mssql", + feature = "sqlite" +))] +impl AnyType for T where T: Type + Type + Type + Type {} + +// only 3 (4) + +#[cfg(all( + not(feature = "mssql"), + all(feature = "postgres", feature = "mysql", feature = "sqlite") +))] +pub trait AnyType: Type + Type + Type {} + +#[cfg(all( + not(feature = "mssql"), + all(feature = "postgres", feature = "mysql", feature = "sqlite") +))] +impl AnyType for T where T: Type + Type + Type {} + +#[cfg(all( + not(feature = "mysql"), + all(feature = "postgres", feature = "mssql", feature = "sqlite") +))] +pub trait AnyType: Type + Type + Type {} + +#[cfg(all( + not(feature = "mysql"), + all(feature = "postgres", feature = "mssql", feature = "sqlite") +))] +impl AnyType for T where T: Type + Type + Type {} + +#[cfg(all( + not(feature = "sqlite"), + all(feature = "postgres", feature = "mysql", feature = "mssql") +))] +pub trait AnyType: Type + Type + Type {} + +#[cfg(all( + not(feature = "sqlite"), + all(feature = "postgres", feature = "mysql", feature = "mssql") +))] +impl AnyType for T where T: Type + Type + Type {} + +#[cfg(all( + not(feature = "postgres"), + all(feature = "sqlite", feature = "mysql", feature = "mssql") +))] +pub trait AnyType: Type + Type + Type {} + +#[cfg(all( + not(feature = "postgres"), + all(feature = "sqlite", feature = "mysql", feature = "mssql") +))] +impl AnyType for T where T: Type + Type + Type {} + +// only 2 (6) + +#[cfg(all( + not(any(feature = "mssql", feature = "sqlite")), + all(feature = "postgres", feature = "mysql") +))] +pub trait AnyType: Type + Type {} + +#[cfg(all( + not(any(feature = "mssql", feature = "sqlite")), + all(feature = "postgres", feature = "mysql") +))] +impl AnyType for T where T: Type + Type {} + +#[cfg(all( + not(any(feature = "mysql", feature = "sqlite")), + all(feature = "postgres", feature = "mssql") +))] +pub trait AnyType: Type + Type {} + +#[cfg(all( + not(any(feature = "mysql", feature = "sqlite")), + all(feature = "postgres", feature = "mssql") +))] +impl AnyType for T where T: Type + Type {} + +#[cfg(all( + not(any(feature = "mysql", feature = "mssql")), + all(feature = "postgres", feature = "sqlite") +))] +pub trait AnyType: Type + Type {} + +#[cfg(all( + not(any(feature = "mysql", feature = "mssql")), + all(feature = "postgres", feature = "sqlite") +))] +impl AnyType for T where T: Type + Type {} + +#[cfg(all( + not(any(feature = "postgres", feature = "sqlite")), + all(feature = "mssql", feature = "mysql") +))] +pub trait AnyType: Type + Type {} + +#[cfg(all( + not(any(feature = "postgres", feature = "sqlite")), + all(feature = "mssql", feature = "mysql") +))] +impl AnyType for T where T: Type + Type {} + +#[cfg(all( + not(any(feature = "postgres", feature = "mysql")), + all(feature = "mssql", feature = "sqlite") +))] +pub trait AnyType: Type + Type {} + +#[cfg(all( + not(any(feature = "postgres", feature = "mysql")), + all(feature = "mssql", feature = "sqlite") +))] +impl AnyType for T where T: Type + Type {} + +#[cfg(all( + not(any(feature = "postgres", feature = "mssql")), + all(feature = "mysql", feature = "sqlite") +))] +pub trait AnyType: Type + Type {} + +#[cfg(all( + not(any(feature = "postgres", feature = "mssql")), + all(feature = "mysql", feature = "sqlite") +))] +impl AnyType for T where T: Type + Type {} + +// only 1 (4) + +#[cfg(all( + not(any(feature = "mysql", feature = "mssql", feature = "sqlite")), + feature = "postgres" +))] +pub trait AnyType: Type {} + +#[cfg(all( + not(any(feature = "mysql", feature = "mssql", feature = "sqlite")), + feature = "postgres" +))] +impl AnyType for T where T: Type {} + +#[cfg(all( + not(any(feature = "postgres", feature = "mssql", feature = "sqlite")), + feature = "mysql" +))] +pub trait AnyType: Type {} + +#[cfg(all( + not(any(feature = "postgres", feature = "mssql", feature = "sqlite")), + feature = "mysql" +))] +impl AnyType for T where T: Type {} + +#[cfg(all( + not(any(feature = "mysql", feature = "postgres", feature = "sqlite")), + feature = "mssql" +))] +pub trait AnyType: Type {} + +#[cfg(all( + not(any(feature = "mysql", feature = "postgres", feature = "sqlite")), + feature = "mssql" +))] +impl AnyType for T where T: Type {} + +#[cfg(all( + not(any(feature = "mysql", feature = "mssql", feature = "postgres")), + feature = "sqlite" +))] +pub trait AnyType: Type {} + +#[cfg(all( + not(any(feature = "mysql", feature = "mssql", feature = "postgres")), + feature = "sqlite" +))] +impl AnyType for T where T: Type {} diff --git a/sqlx-core/src/any/value.rs b/sqlx-core/src/any/value.rs new file mode 100644 index 00000000..60095860 --- /dev/null +++ b/sqlx-core/src/any/value.rs @@ -0,0 +1,174 @@ +use std::borrow::Cow; + +use crate::any::type_info::AnyTypeInfoKind; +use crate::any::{Any, AnyTypeInfo}; +use crate::database::HasValueRef; +use crate::value::{Value, ValueRef}; + +#[cfg(feature = "postgres")] +use crate::postgres::{PgValue, PgValueRef}; + +#[cfg(feature = "mysql")] +use crate::mysql::{MySqlValue, MySqlValueRef}; + +#[cfg(feature = "sqlite")] +use crate::sqlite::{SqliteValue, SqliteValueRef}; + +#[cfg(feature = "mssql")] +use crate::mssql::{MssqlValue, MssqlValueRef}; + +pub struct AnyValue(AnyValueKind); + +pub(crate) enum AnyValueKind { + #[cfg(feature = "postgres")] + Postgres(PgValue), + + #[cfg(feature = "mysql")] + MySql(MySqlValue), + + #[cfg(feature = "sqlite")] + Sqlite(SqliteValue), + + #[cfg(feature = "mssql")] + Mssql(MssqlValue), +} + +pub struct AnyValueRef<'r>(pub(crate) AnyValueRefKind<'r>); + +pub(crate) enum AnyValueRefKind<'r> { + #[cfg(feature = "postgres")] + Postgres(PgValueRef<'r>), + + #[cfg(feature = "mysql")] + MySql(MySqlValueRef<'r>), + + #[cfg(feature = "sqlite")] + Sqlite(SqliteValueRef<'r>), + + #[cfg(feature = "mssql")] + Mssql(MssqlValueRef<'r>), +} + +impl Value for AnyValue { + type Database = Any; + + fn as_ref(&self) -> >::ValueRef { + AnyValueRef(match &self.0 { + #[cfg(feature = "postgres")] + AnyValueKind::Postgres(value) => AnyValueRefKind::Postgres(value.as_ref()), + + #[cfg(feature = "mysql")] + AnyValueKind::MySql(value) => AnyValueRefKind::MySql(value.as_ref()), + + #[cfg(feature = "sqlite")] + AnyValueKind::Sqlite(value) => AnyValueRefKind::Sqlite(value.as_ref()), + + #[cfg(feature = "mssql")] + AnyValueKind::Mssql(value) => AnyValueRefKind::Mssql(value.as_ref()), + }) + } + + fn type_info(&self) -> Option> { + match &self.0 { + #[cfg(feature = "postgres")] + AnyValueKind::Postgres(value) => value + .type_info() + .map(|ty| AnyTypeInfoKind::Postgres(ty.into_owned())), + + #[cfg(feature = "mysql")] + AnyValueKind::MySql(value) => value + .type_info() + .map(|ty| AnyTypeInfoKind::MySql(ty.into_owned())), + + #[cfg(feature = "sqlite")] + AnyValueKind::Sqlite(value) => value + .type_info() + .map(|ty| AnyTypeInfoKind::Sqlite(ty.into_owned())), + + #[cfg(feature = "mssql")] + AnyValueKind::Mssql(value) => value + .type_info() + .map(|ty| AnyTypeInfoKind::Mssql(ty.into_owned())), + } + .map(AnyTypeInfo) + .map(Cow::Owned) + } + + fn is_null(&self) -> bool { + match &self.0 { + #[cfg(feature = "postgres")] + AnyValueKind::Postgres(value) => value.is_null(), + + #[cfg(feature = "mysql")] + AnyValueKind::MySql(value) => value.is_null(), + + #[cfg(feature = "sqlite")] + AnyValueKind::Sqlite(value) => value.is_null(), + + #[cfg(feature = "mssql")] + AnyValueKind::Mssql(value) => value.is_null(), + } + } +} + +impl<'r> ValueRef<'r> for AnyValueRef<'r> { + type Database = Any; + + fn to_owned(&self) -> AnyValue { + AnyValue(match &self.0 { + #[cfg(feature = "postgres")] + AnyValueRefKind::Postgres(value) => AnyValueKind::Postgres(ValueRef::to_owned(value)), + + #[cfg(feature = "mysql")] + AnyValueRefKind::MySql(value) => AnyValueKind::MySql(ValueRef::to_owned(value)), + + #[cfg(feature = "sqlite")] + AnyValueRefKind::Sqlite(value) => AnyValueKind::Sqlite(ValueRef::to_owned(value)), + + #[cfg(feature = "mssql")] + AnyValueRefKind::Mssql(value) => AnyValueKind::Mssql(ValueRef::to_owned(value)), + }) + } + + fn type_info(&self) -> Option> { + match &self.0 { + #[cfg(feature = "postgres")] + AnyValueRefKind::Postgres(value) => value + .type_info() + .map(|ty| AnyTypeInfoKind::Postgres(ty.into_owned())), + + #[cfg(feature = "mysql")] + AnyValueRefKind::MySql(value) => value + .type_info() + .map(|ty| AnyTypeInfoKind::MySql(ty.into_owned())), + + #[cfg(feature = "sqlite")] + AnyValueRefKind::Sqlite(value) => value + .type_info() + .map(|ty| AnyTypeInfoKind::Sqlite(ty.into_owned())), + + #[cfg(feature = "mssql")] + AnyValueRefKind::Mssql(value) => value + .type_info() + .map(|ty| AnyTypeInfoKind::Mssql(ty.into_owned())), + } + .map(AnyTypeInfo) + .map(Cow::Owned) + } + + fn is_null(&self) -> bool { + match &self.0 { + #[cfg(feature = "postgres")] + AnyValueRefKind::Postgres(value) => value.is_null(), + + #[cfg(feature = "mysql")] + AnyValueRefKind::MySql(value) => value.is_null(), + + #[cfg(feature = "sqlite")] + AnyValueRefKind::Sqlite(value) => value.is_null(), + + #[cfg(feature = "mssql")] + AnyValueRefKind::Mssql(value) => value.is_null(), + } + } +} diff --git a/sqlx-core/src/lib.rs b/sqlx-core/src/lib.rs index 118b42c8..5db5b076 100644 --- a/sqlx-core/src/lib.rs +++ b/sqlx-core/src/lib.rs @@ -37,9 +37,14 @@ pub mod transaction; #[macro_use] pub mod encode; +#[macro_use] +pub mod decode; + +#[macro_use] +pub mod types; + mod common; pub mod database; -pub mod decode; pub mod describe; pub mod executor; pub mod from_row; @@ -50,9 +55,19 @@ pub mod query_as; pub mod query_scalar; pub mod row; pub mod type_info; -pub mod types; pub mod value; +#[cfg(all( + any( + feature = "postgres", + feature = "mysql", + feature = "mssql", + feature = "sqlite" + ), + feature = "any" +))] +pub mod any; + #[cfg(feature = "postgres")] #[cfg_attr(docsrs, doc(cfg(feature = "postgres")))] pub mod postgres; diff --git a/src/lib.rs b/src/lib.rs index a11f3ce3..5d9d9b52 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -21,6 +21,17 @@ pub use sqlx_core::describe; #[doc(inline)] pub use sqlx_core::error::{self, Error, Result}; +#[cfg(all( + any( + feature = "mysql", + feature = "sqlite", + feature = "postgres", + feature = "mssql" + ), + feature = "any" +))] +pub use sqlx_core::any::{self, Any, AnyConnection, AnyPool}; + #[cfg(feature = "mysql")] #[cfg_attr(docsrs, doc(cfg(feature = "mysql")))] pub use sqlx_core::mysql::{self, MySql, MySqlConnection, MySqlPool}; diff --git a/tests/any/any.rs b/tests/any/any.rs new file mode 100644 index 00000000..e32b338b --- /dev/null +++ b/tests/any/any.rs @@ -0,0 +1,87 @@ +use sqlx::any::AnyRow; +use sqlx::{Any, Connection, Executor, Row}; +use sqlx_test::new; + +#[sqlx_macros::test] +async fn it_connects() -> anyhow::Result<()> { + let mut conn = new::().await?; + + let value = sqlx::query("select 1 + 5") + .try_map(|row: AnyRow| row.try_get::(0)) + .fetch_one(&mut conn) + .await?; + + assert_eq!(6i32, value); + + conn.close().await?; + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_pings() -> anyhow::Result<()> { + let mut conn = new::().await?; + + conn.ping().await?; + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_executes_with_pool() -> anyhow::Result<()> { + let pool = sqlx_test::pool::().await?; + + let rows = pool.fetch_all("SELECT 1; SElECT 2").await?; + + assert_eq!(rows.len(), 2); + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_can_fail_and_recover() -> anyhow::Result<()> { + let mut conn = new::().await?; + + for i in 0..10 { + // make a query that will fail + let res = conn + .execute("INSERT INTO not_found (column) VALUES (10)") + .await; + + assert!(res.is_err()); + + // now try and use the connection + let val: i32 = conn + .fetch_one(&*format!("SELECT {}", i)) + .await? + .get_unchecked(0); + + assert_eq!(val, i); + } + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_can_fail_and_recover_with_pool() -> anyhow::Result<()> { + let pool = sqlx_test::pool::().await?; + + for i in 0..10 { + // make a query that will fail + let res = pool + .execute("INSERT INTO not_found (column) VALUES (10)") + .await; + + assert!(res.is_err()); + + // now try and use the connection + let val: i32 = pool + .fetch_one(&*format!("SELECT {}", i)) + .await? + .get_unchecked(0); + + assert_eq!(val, i); + } + + Ok(()) +} diff --git a/tests/postgres/postgres.rs b/tests/postgres/postgres.rs index e582b671..eb83f490 100644 --- a/tests/postgres/postgres.rs +++ b/tests/postgres/postgres.rs @@ -153,6 +153,7 @@ async fn it_can_fail_and_recover() -> anyhow::Result<()> { let res = conn .execute("INSERT INTO not_found (column) VALUES (10)") .await; + assert!(res.is_err()); // now try and use the connection @@ -160,6 +161,7 @@ async fn it_can_fail_and_recover() -> anyhow::Result<()> { .fetch_one(&*format!("SELECT {}::int4", i)) .await? .get(0); + assert_eq!(val, i); } @@ -175,6 +177,7 @@ async fn it_can_fail_and_recover_with_pool() -> anyhow::Result<()> { let res = pool .execute("INSERT INTO not_found (column) VALUES (10)") .await; + assert!(res.is_err()); // now try and use the connection @@ -182,6 +185,7 @@ async fn it_can_fail_and_recover_with_pool() -> anyhow::Result<()> { .fetch_one(&*format!("SELECT {}::int4", i)) .await? .get(0); + assert_eq!(val, i); }