diff --git a/Cargo.lock b/Cargo.lock index 6fd861c2..ddd2b0ab 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1928,10 +1928,13 @@ dependencies = [ "dotenv", "futures 0.3.4", "heck", - "lazy_static", + "hex", + "once_cell", "proc-macro2", "quote", + "serde", "serde_json", + "sha2", "sqlx-core", "syn", "tokio 0.2.13", diff --git a/Cargo.toml b/Cargo.toml index 392ba259..3502d5ad 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,6 +39,9 @@ default = [ "macros", "runtime-async-std" ] macros = [ "sqlx-macros" ] tls = [ "sqlx-core/tls" ] +# offline building support in `sqlx-macros` +offline = ["sqlx-macros/offline", "sqlx-core/offline"] + # intended mainly for CI and docs all = [ "tls", "all-database", "all-type" ] all-database = [ "mysql", "sqlite", "postgres" ] diff --git a/sqlx-core/Cargo.toml b/sqlx-core/Cargo.toml index 7ffdd0ba..e3a976f0 100644 --- a/sqlx-core/Cargo.toml +++ b/sqlx-core/Cargo.toml @@ -32,6 +32,9 @@ runtime-tokio = [ "async-native-tls/runtime-tokio", "tokio" ] # intended for internal benchmarking, do not use bench = [] +# support offline/decoupled building (enables serialization of `Describe`) +offline = ["serde"] + [dependencies] async-native-tls = { version = "0.3.2", default-features = false, optional = true } async-std = { version = "1.5.0", features = [ "unstable" ], optional = true } diff --git a/sqlx-core/src/describe.rs b/sqlx-core/src/describe.rs index 2a7c2294..d66e4b4c 100644 --- a/sqlx-core/src/describe.rs +++ b/sqlx-core/src/describe.rs @@ -7,6 +7,14 @@ use crate::database::Database; /// The return type of [`Executor::describe`]. 
/// /// [`Executor::describe`]: crate::executor::Executor::describe +#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr( + feature = "offline", + serde(bound( + serialize = "DB::TypeInfo: serde::Serialize, Column: serde::Serialize", + deserialize = "DB::TypeInfo: serde::de::DeserializeOwned, Column: serde::de::DeserializeOwned" + )) +)] #[non_exhaustive] pub struct Describe where @@ -35,6 +43,14 @@ where } /// A single column of a result set. +#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr( + feature = "offline", + serde(bound( + serialize = "DB::TableId: serde::Serialize, DB::TypeInfo: serde::Serialize", + deserialize = "DB::TableId: serde::de::DeserializeOwned, DB::TypeInfo: serde::de::DeserializeOwned" + )) +)] #[non_exhaustive] pub struct Column where diff --git a/sqlx-core/src/mysql/protocol/type.rs b/sqlx-core/src/mysql/protocol/type.rs index 70f64789..5b92a951 100644 --- a/sqlx-core/src/mysql/protocol/type.rs +++ b/sqlx-core/src/mysql/protocol/type.rs @@ -1,6 +1,7 @@ // https://dev.mysql.com/doc/dev/mysql-server/8.0.12/binary__log__types_8h.html // https://mariadb.com/kb/en/library/resultset/#field-types #[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] pub struct TypeId(pub u8); // https://github.com/google/mysql/blob/c01fc2134d439282a21a2ddf687566e198ddee28/include/mysql_com.h#L429 diff --git a/sqlx-core/src/mysql/type_info.rs b/sqlx-core/src/mysql/type_info.rs index 9c86619a..10c8ba0c 100644 --- a/sqlx-core/src/mysql/type_info.rs +++ b/sqlx-core/src/mysql/type_info.rs @@ -4,6 +4,7 @@ use crate::mysql::protocol::{ColumnDefinition, FieldFlags, TypeId}; use crate::types::TypeInfo; #[derive(Clone, Debug, Default)] +#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] pub struct MySqlTypeInfo { pub(crate) id: TypeId, pub(crate) is_unsigned: bool, diff --git 
a/sqlx-core/src/postgres/protocol/type_id.rs b/sqlx-core/src/postgres/protocol/type_id.rs index 6da81511..1b85ecc6 100644 --- a/sqlx-core/src/postgres/protocol/type_id.rs +++ b/sqlx-core/src/postgres/protocol/type_id.rs @@ -2,6 +2,7 @@ use crate::postgres::types::try_resolve_type_name; use std::fmt::{self, Display}; #[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] pub struct TypeId(pub(crate) u32); // DEVELOPER PRO TIP: find builtin type OIDs easily by grepping this file diff --git a/sqlx-core/src/postgres/type_info.rs b/sqlx-core/src/postgres/type_info.rs index e953ae82..8f0f8b73 100644 --- a/sqlx-core/src/postgres/type_info.rs +++ b/sqlx-core/src/postgres/type_info.rs @@ -9,6 +9,7 @@ use std::sync::Arc; /// Type information for a Postgres SQL type. #[derive(Debug, Clone)] +#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] pub struct PgTypeInfo { pub(crate) id: Option, pub(crate) name: SharedStr, @@ -186,8 +187,38 @@ impl From for SharedStr { } } +impl From for String { + fn from(s: SharedStr) -> Self { + String::from(&*s) + } +} + impl fmt::Display for SharedStr { fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { fmt.pad(self) } } + +// manual impls because otherwise things get a little screwy with lifetimes +#[cfg(feature = "offline")] +impl<'de> serde::Deserialize<'de> for SharedStr { + fn deserialize(deserializer: D) -> Result>::Error> + where + D: serde::Deserializer<'de>, + { + Ok(String::deserialize(deserializer)?.into()) + } +} + +#[cfg(feature = "offline")] +impl serde::Serialize for SharedStr { + fn serialize( + &self, + serializer: S, + ) -> Result<::Ok, ::Error> + where + S: serde::Serializer, + { + serializer.serialize_str(&self) + } +} diff --git a/sqlx-core/src/sqlite/type_info.rs b/sqlx-core/src/sqlite/type_info.rs index e77dbb7f..43088ab7 100644 --- a/sqlx-core/src/sqlite/type_info.rs +++ b/sqlx-core/src/sqlite/type_info.rs @@ -4,6 +4,7 
@@ use crate::types::TypeInfo; // https://www.sqlite.org/c3ref/c_blob.html #[derive(Debug, PartialEq, Clone, Copy)] +#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] pub(crate) enum SqliteType { Integer = 1, Float = 2, @@ -16,6 +17,7 @@ pub(crate) enum SqliteType { // https://www.sqlite.org/datatype3.html#type_affinity #[derive(Debug, PartialEq, Clone, Copy)] +#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] pub(crate) enum SqliteTypeAffinity { Text, Numeric, @@ -25,6 +27,7 @@ pub(crate) enum SqliteTypeAffinity { } #[derive(Debug, Clone)] +#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] pub struct SqliteTypeInfo { pub(crate) r#type: SqliteType, pub(crate) affinity: Option, diff --git a/sqlx-core/src/url.rs b/sqlx-core/src/url.rs index ace27fdc..8015b9e1 100644 --- a/sqlx-core/src/url.rs +++ b/sqlx-core/src/url.rs @@ -28,6 +28,14 @@ impl<'s> TryFrom<&'s String> for Url { } } +impl TryFrom for Url { + type Error = url::ParseError; + + fn try_from(value: url::Url) -> Result { + Ok(Url(value)) + } +} + impl Url { #[allow(dead_code)] pub(crate) fn as_str(&self) -> &str { diff --git a/sqlx-macros/Cargo.toml b/sqlx-macros/Cargo.toml index 14a4355f..52f3e21c 100644 --- a/sqlx-macros/Cargo.toml +++ b/sqlx-macros/Cargo.toml @@ -18,32 +18,38 @@ proc-macro = true [features] default = [ "runtime-async-std" ] -runtime-async-std = [ "sqlx/runtime-async-std", "async-std" ] -runtime-tokio = [ "sqlx/runtime-tokio", "tokio", "lazy_static" ] +runtime-async-std = [ "sqlx-core/runtime-async-std", "async-std" ] +runtime-tokio = [ "sqlx-core/runtime-tokio", "tokio", "once_cell" ] + +# offline building support +offline = ["sqlx-core/offline", "serde", "serde_json", "hex", "sha2"] # database -mysql = [ "sqlx/mysql" ] -postgres = [ "sqlx/postgres" ] -sqlite = [ "sqlx/sqlite" ] +mysql = [ "sqlx-core/mysql" ] +postgres = [ "sqlx-core/postgres" ] +sqlite = [ "sqlx-core/sqlite" ] # type -bigdecimal = [ 
"sqlx/bigdecimal" ] -chrono = [ "sqlx/chrono" ] -time = [ "sqlx/time" ] -ipnetwork = [ "sqlx/ipnetwork" ] -uuid = [ "sqlx/uuid" ] -json = [ "sqlx/json", "serde_json" ] +bigdecimal = [ "sqlx-core/bigdecimal" ] +chrono = [ "sqlx-core/chrono" ] +time = [ "sqlx-core/time" ] +ipnetwork = [ "sqlx-core/ipnetwork" ] +uuid = [ "sqlx-core/uuid" ] +json = [ "sqlx-core/json", "serde_json" ] [dependencies] async-std = { version = "1.5.0", default-features = false, optional = true } tokio = { version = "0.2.13", default-features = false, features = [ "rt-threaded" ], optional = true } dotenv = { version = "0.15.0", default-features = false } futures = { version = "0.3.4", default-features = false, features = [ "executor" ] } +hex = { version = "0.4.2", optional = true } heck = "0.3" proc-macro2 = { version = "1.0.9", default-features = false } -sqlx = { version = "0.3.5", default-features = false, path = "../sqlx-core", package = "sqlx-core" } +sqlx-core = { version = "0.3.5", default-features = false, path = "../sqlx-core" } +serde = { version = "1.0", optional = true } serde_json = { version = "1.0", features = [ "raw_value" ], optional = true } +sha2 = { version = "0.8.1", optional = true } syn = { version = "1.0.16", default-features = false, features = [ "full" ] } quote = { version = "1.0.2", default-features = false } url = { version = "2.1.1", default-features = false } -lazy_static = { version = "1.4.0", optional = true } +once_cell = { version = "1.3", features = ["std"], optional = true } diff --git a/sqlx-macros/src/database/mod.rs b/sqlx-macros/src/database/mod.rs index 10d56c60..efb40af2 100644 --- a/sqlx-macros/src/database/mod.rs +++ b/sqlx-macros/src/database/mod.rs @@ -1,4 +1,4 @@ -use sqlx::database::Database; +use sqlx_core::database::Database; #[derive(PartialEq, Eq)] #[allow(dead_code)] @@ -10,6 +10,7 @@ pub enum ParamChecking { pub trait DatabaseExt: Database { const DATABASE_PATH: &'static str; const ROW_PATH: &'static str; + const NAME: &'static str; 
const PARAM_CHECKING: ParamChecking; @@ -34,23 +35,25 @@ macro_rules! impl_database_ext { $($(#[$meta:meta])? $ty:ty $(| $input:ty)?),*$(,)? }, ParamChecking::$param_checking:ident, - feature-types: $name:ident => $get_gate:expr, - row = $row:path + feature-types: $ty_info:ident => $get_gate:expr, + row = $row:path, + name = $db_name:literal ) => { impl $crate::database::DatabaseExt for $database { const DATABASE_PATH: &'static str = stringify!($database); const ROW_PATH: &'static str = stringify!($row); const PARAM_CHECKING: $crate::database::ParamChecking = $crate::database::ParamChecking::$param_checking; + const NAME: &'static str = $db_name; fn param_type_for_id(info: &Self::TypeInfo) -> Option<&'static str> { match () { $( $(#[$meta])? - _ if <$ty as sqlx::types::Type<$database>>::type_info() == *info => Some(input_ty!($ty $(, $input)?)), + _ if <$ty as sqlx_core::types::Type<$database>>::type_info() == *info => Some(input_ty!($ty $(, $input)?)), )* $( $(#[$meta])? - _ if sqlx::types::TypeInfo::compatible(&<$ty as sqlx::types::Type<$database>>::type_info(), &info) => Some(input_ty!($ty $(, $input)?)), + _ if sqlx_core::types::TypeInfo::compatible(&<$ty as sqlx_core::types::Type<$database>>::type_info(), &info) => Some(input_ty!($ty $(, $input)?)), )* _ => None } @@ -60,17 +63,17 @@ macro_rules! impl_database_ext { match () { $( $(#[$meta])? - _ if <$ty as sqlx::types::Type<$database>>::type_info() == *info => return Some(stringify!($ty)), + _ if <$ty as sqlx_core::types::Type<$database>>::type_info() == *info => return Some(stringify!($ty)), )* $( $(#[$meta])? 
- _ if sqlx::types::TypeInfo::compatible(&<$ty as sqlx::types::Type<$database>>::type_info(), &info) => return Some(stringify!($ty)), + _ if sqlx_core::types::TypeInfo::compatible(&<$ty as sqlx_core::types::Type<$database>>::type_info(), &info) => return Some(stringify!($ty)), )* _ => None } } - fn get_feature_gate($name: &Self::TypeInfo) -> Option<&'static str> { + fn get_feature_gate($ty_info: &Self::TypeInfo) -> Option<&'static str> { $get_gate } } diff --git a/sqlx-macros/src/database/mysql.rs b/sqlx-macros/src/database/mysql.rs index dee86a2c..3584a15e 100644 --- a/sqlx-macros/src/database/mysql.rs +++ b/sqlx-macros/src/database/mysql.rs @@ -1,5 +1,5 @@ impl_database_ext! { - sqlx::mysql::MySql { + sqlx_core::mysql::MySql { u8, u16, u32, @@ -18,33 +18,34 @@ impl_database_ext! { Vec, #[cfg(all(feature = "chrono", not(feature = "time")))] - sqlx::types::chrono::NaiveTime, + sqlx_core::types::chrono::NaiveTime, #[cfg(all(feature = "chrono", not(feature = "time")))] - sqlx::types::chrono::NaiveDate, + sqlx_core::types::chrono::NaiveDate, #[cfg(all(feature = "chrono", not(feature = "time")))] - sqlx::types::chrono::NaiveDateTime, + sqlx_core::types::chrono::NaiveDateTime, #[cfg(all(feature = "chrono", not(feature = "time")))] - sqlx::types::chrono::DateTime, + sqlx_core::types::chrono::DateTime, #[cfg(feature = "time")] - sqlx::types::time::Time, + sqlx_core::types::time::Time, #[cfg(feature = "time")] - sqlx::types::time::Date, + sqlx_core::types::time::Date, #[cfg(feature = "time")] - sqlx::types::time::PrimitiveDateTime, + sqlx_core::types::time::PrimitiveDateTime, #[cfg(feature = "time")] - sqlx::types::time::OffsetDateTime, + sqlx_core::types::time::OffsetDateTime, #[cfg(feature = "bigdecimal")] - sqlx::types::BigDecimal, + sqlx_core::types::BigDecimal, }, ParamChecking::Weak, feature-types: info => info.type_feature_gate(), - row = sqlx::mysql::MySqlRow + row = sqlx_core::mysql::MySqlRow, + name = "MySQL/MariaDB" } diff --git 
a/sqlx-macros/src/database/postgres.rs b/sqlx-macros/src/database/postgres.rs index 452ad27d..a52da8cd 100644 --- a/sqlx-macros/src/database/postgres.rs +++ b/sqlx-macros/src/database/postgres.rs @@ -1,5 +1,5 @@ impl_database_ext! { - sqlx::postgres::Postgres { + sqlx_core::postgres::Postgres { bool, String | &str, i8, @@ -13,37 +13,37 @@ impl_database_ext! { Vec | &[u8], #[cfg(feature = "uuid")] - sqlx::types::Uuid, + sqlx_core::types::Uuid, #[cfg(feature = "chrono")] - sqlx::types::chrono::NaiveTime, + sqlx_core::types::chrono::NaiveTime, #[cfg(feature = "chrono")] - sqlx::types::chrono::NaiveDate, + sqlx_core::types::chrono::NaiveDate, #[cfg(feature = "chrono")] - sqlx::types::chrono::NaiveDateTime, + sqlx_core::types::chrono::NaiveDateTime, #[cfg(feature = "chrono")] - sqlx::types::chrono::DateTime | sqlx::types::chrono::DateTime<_>, + sqlx_core::types::chrono::DateTime | sqlx_core::types::chrono::DateTime<_>, #[cfg(feature = "time")] - sqlx::types::time::Time, + sqlx_core::types::time::Time, #[cfg(feature = "time")] - sqlx::types::time::Date, + sqlx_core::types::time::Date, #[cfg(feature = "time")] - sqlx::types::time::PrimitiveDateTime, + sqlx_core::types::time::PrimitiveDateTime, #[cfg(feature = "time")] - sqlx::types::time::OffsetDateTime, + sqlx_core::types::time::OffsetDateTime, #[cfg(feature = "bigdecimal")] - sqlx::types::BigDecimal, + sqlx_core::types::BigDecimal, #[cfg(feature = "ipnetwork")] - sqlx::types::ipnetwork::IpNetwork, + sqlx_core::types::ipnetwork::IpNetwork, #[cfg(feature = "json")] serde_json::Value, @@ -61,41 +61,42 @@ impl_database_ext! 
{ #[cfg(feature = "uuid")] - Vec | &[sqlx::types::Uuid], + Vec | &[sqlx_core::types::Uuid], #[cfg(feature = "chrono")] - Vec | &[sqlx::types::sqlx::types::chrono::NaiveTime], + Vec | &[sqlx_core::types::sqlx_core::types::chrono::NaiveTime], #[cfg(feature = "chrono")] - Vec | &[sqlx::types::chrono::NaiveDate], + Vec | &[sqlx_core::types::chrono::NaiveDate], #[cfg(feature = "chrono")] - Vec | &[sqlx::types::chrono::NaiveDateTime], + Vec | &[sqlx_core::types::chrono::NaiveDateTime], // TODO // #[cfg(feature = "chrono")] - // Vec> | &[sqlx::types::chrono::DateTime<_>], + // Vec> | &[sqlx_core::types::chrono::DateTime<_>], #[cfg(feature = "time")] - Vec | &[sqlx::types::time::Time], + Vec | &[sqlx_core::types::time::Time], #[cfg(feature = "time")] - Vec | &[sqlx::types::time::Date], + Vec | &[sqlx_core::types::time::Date], #[cfg(feature = "time")] - Vec | &[sqlx::types::time::PrimitiveDateTime], + Vec | &[sqlx_core::types::time::PrimitiveDateTime], #[cfg(feature = "time")] - Vec | &[sqlx::types::time::OffsetDateTime], + Vec | &[sqlx_core::types::time::OffsetDateTime], #[cfg(feature = "bigdecimal")] - Vec | &[sqlx::types::BigDecimal], + Vec | &[sqlx_core::types::BigDecimal], #[cfg(feature = "ipnetwork")] - Vec | &[sqlx::types::ipnetwork::IpNetwork], + Vec | &[sqlx_core::types::ipnetwork::IpNetwork], }, ParamChecking::Strong, feature-types: info => info.type_feature_gate(), - row = sqlx::postgres::PgRow + row = sqlx_core::postgres::PgRow, + name = "PostgreSQL" } diff --git a/sqlx-macros/src/database/sqlite.rs b/sqlx-macros/src/database/sqlite.rs index c7ee2bf9..72cf2f6e 100644 --- a/sqlx-macros/src/database/sqlite.rs +++ b/sqlx-macros/src/database/sqlite.rs @@ -1,5 +1,5 @@ impl_database_ext! { - sqlx::sqlite::Sqlite { + sqlx_core::sqlite::Sqlite { bool, i32, i64, @@ -10,5 +10,6 @@ impl_database_ext! 
{ }, ParamChecking::Weak, feature-types: _info => None, - row = sqlx::sqlite::SqliteRow + row = sqlx_core::sqlite::SqliteRow, + name = "SQLite" } diff --git a/sqlx-macros/src/lib.rs b/sqlx-macros/src/lib.rs index beac1dd8..99b92a77 100644 --- a/sqlx-macros/src/lib.rs +++ b/sqlx-macros/src/lib.rs @@ -11,8 +11,6 @@ use quote::quote; #[cfg(feature = "runtime-async-std")] use async_std::task::block_on; -use std::path::PathBuf; - use url::Url; type Error = Box; @@ -26,23 +24,6 @@ mod runtime; use query_macros::*; -#[cfg(feature = "runtime-tokio")] -lazy_static::lazy_static! { - static ref BASIC_RUNTIME: tokio::runtime::Runtime = { - tokio::runtime::Builder::new() - .threaded_scheduler() - .enable_io() - .enable_time() - .build() - .expect("failed to build tokio runtime") - }; -} - -#[cfg(feature = "runtime-tokio")] -fn block_on(future: F) -> F::Output { - BASIC_RUNTIME.enter(|| futures::executor::block_on(future)) -} - fn macro_result(tokens: proc_macro2::TokenStream) -> TokenStream { quote!( macro_rules! macro_result { @@ -52,141 +33,21 @@ fn macro_result(tokens: proc_macro2::TokenStream) -> TokenStream { .into() } -macro_rules! async_macro ( - ($db:ident, $input:ident: $ty:ty => $expr:expr) => {{ - let $input = match syn::parse::<$ty>($input) { - Ok(input) => input, - Err(e) => return macro_result(e.to_compile_error()), - }; +#[proc_macro] +pub fn expand_query(input: TokenStream) -> TokenStream { + let input = syn::parse_macro_input!(input as QueryMacroInput); - let res: Result = block_on(async { - use sqlx::connection::Connect; - - // If a .env file exists at CARGO_MANIFEST_DIR, load environment variables from this, - // otherwise fallback to default dotenv behaviour. - if let Ok(dir) = std::env::var("CARGO_MANIFEST_DIR") { - let env_path = PathBuf::from(dir).join(".env"); - if env_path.exists() { - dotenv::from_path(&env_path) - .map_err(|e| format!("failed to load environment from {:?}, {}", env_path, e))? 
- } - } - - let db_url = Url::parse(&dotenv::var("DATABASE_URL").map_err(|_| "DATABASE_URL not set")?)?; - - match db_url.scheme() { - #[cfg(feature = "sqlite")] - "sqlite" => { - let $db = sqlx::sqlite::SqliteConnection::connect(db_url.as_str()) - .await - .map_err(|e| format!("failed to connect to database: {}", e))?; - - $expr.await - } - #[cfg(not(feature = "sqlite"))] - "sqlite" => Err(format!( - "DATABASE_URL {} has the scheme of a SQLite database but the `sqlite` \ - feature of sqlx was not enabled", - db_url - ).into()), - #[cfg(feature = "postgres")] - "postgresql" | "postgres" => { - let $db = sqlx::postgres::PgConnection::connect(db_url.as_str()) - .await - .map_err(|e| format!("failed to connect to database: {}", e))?; - - $expr.await - } - #[cfg(not(feature = "postgres"))] - "postgresql" | "postgres" => Err(format!( - "DATABASE_URL {} has the scheme of a Postgres database but the `postgres` \ - feature of sqlx was not enabled", - db_url - ).into()), - #[cfg(feature = "mysql")] - "mysql" | "mariadb" => { - let $db = sqlx::mysql::MySqlConnection::connect(db_url.as_str()) - .await - .map_err(|e| format!("failed to connect to database: {}", e))?; - - $expr.await - } - #[cfg(not(feature = "mysql"))] - "mysql" | "mariadb" => Err(format!( - "DATABASE_URL {} has the scheme of a MySQL/MariaDB database but the `mysql` \ - feature of sqlx was not enabled", - db_url - ).into()), - scheme => Err(format!("unexpected scheme {:?} in DATABASE_URL {}", scheme, db_url).into()), - } - }); - - match res { - Ok(ts) => ts.into(), - Err(e) => { - if let Some(parse_err) = e.downcast_ref::() { - macro_result(parse_err.to_compile_error()) - } else { - let msg = e.to_string(); - macro_result(quote!(compile_error!(#msg))) - } + match query_macros::expand_input(input) { + Ok(ts) => ts.into(), + Err(e) => { + if let Some(parse_err) = e.downcast_ref::() { + macro_result(parse_err.to_compile_error()) + } else { + let msg = e.to_string(); + macro_result(quote!(compile_error!(#msg))) } 
} - }} -); - -#[proc_macro] -#[allow(unused_variables)] -pub fn query(input: TokenStream) -> TokenStream { - #[allow(unused_variables)] - async_macro!(db, input: QueryMacroInput => expand_query(input, db, true)) -} - -#[proc_macro] -#[allow(unused_variables)] -pub fn query_unchecked(input: TokenStream) -> TokenStream { - #[allow(unused_variables)] - async_macro!(db, input: QueryMacroInput => expand_query(input, db, false)) -} - -#[proc_macro] -#[allow(unused_variables)] -pub fn query_file(input: TokenStream) -> TokenStream { - #[allow(unused_variables)] - async_macro!(db, input: QueryMacroInput => expand_query_file(input, db, true)) -} - -#[proc_macro] -#[allow(unused_variables)] -pub fn query_file_unchecked(input: TokenStream) -> TokenStream { - #[allow(unused_variables)] - async_macro!(db, input: QueryMacroInput => expand_query_file(input, db, false)) -} - -#[proc_macro] -#[allow(unused_variables)] -pub fn query_as(input: TokenStream) -> TokenStream { - #[allow(unused_variables)] - async_macro!(db, input: QueryAsMacroInput => expand_query_as(input, db, true)) -} - -#[proc_macro] -#[allow(unused_variables)] -pub fn query_file_as(input: TokenStream) -> TokenStream { - async_macro!(db, input: QueryAsMacroInput => expand_query_file_as(input, db, true)) -} - -#[proc_macro] -#[allow(unused_variables)] -pub fn query_as_unchecked(input: TokenStream) -> TokenStream { - #[allow(unused_variables)] - async_macro!(db, input: QueryAsMacroInput => expand_query_as(input, db, false)) -} - -#[proc_macro] -#[allow(unused_variables)] -pub fn query_file_as_unchecked(input: TokenStream) -> TokenStream { - async_macro!(db, input: QueryAsMacroInput => expand_query_file_as(input, db, false)) + } } #[proc_macro_derive(Encode, attributes(sqlx))] diff --git a/sqlx-macros/src/query_macros/args.rs b/sqlx-macros/src/query_macros/args.rs index 85e6d9aa..f42954f3 100644 --- a/sqlx-macros/src/query_macros/args.rs +++ b/sqlx-macros/src/query_macros/args.rs @@ -3,7 +3,7 @@ use 
syn::spanned::Spanned; use syn::Expr; use quote::{quote, quote_spanned, ToTokens}; -use sqlx::describe::Describe; +use sqlx_core::describe::Describe; use crate::database::{DatabaseExt, ParamChecking}; use crate::query_macros::QueryMacroInput; @@ -13,7 +13,6 @@ use crate::query_macros::QueryMacroInput; pub fn quote_args( input: &QueryMacroInput, describe: &Describe, - checked: bool, ) -> crate::Result { let db_path = DB::db_path(); @@ -25,7 +24,7 @@ pub fn quote_args( let arg_name = &input.arg_names; - let args_check = if checked && DB::PARAM_CHECKING == ParamChecking::Strong { + let args_check = if input.checked && DB::PARAM_CHECKING == ParamChecking::Strong { describe .param_types .iter() diff --git a/sqlx-macros/src/query_macros/data.rs b/sqlx-macros/src/query_macros/data.rs new file mode 100644 index 00000000..a2e469be --- /dev/null +++ b/sqlx-macros/src/query_macros/data.rs @@ -0,0 +1,175 @@ +use sqlx_core::connection::{Connect, Connection}; +use sqlx_core::database::Database; +use sqlx_core::describe::Describe; +use sqlx_core::executor::{Executor, RefExecutor}; +use url::Url; + +use std::fmt::{self, Display, Formatter}; + +use crate::database::DatabaseExt; +use proc_macro2::TokenStream; +use std::fs::File; +use syn::export::Span; + +// TODO: enable serialization +#[cfg_attr(feature = "offline", derive(serde::Deserialize, serde::Serialize))] +#[cfg_attr( + feature = "offline", + serde(bound( + serialize = "Describe: serde::Serialize", + deserialize = "Describe: serde::de::DeserializeOwned" + )) +)] +pub struct QueryData { + pub(super) query: String, + pub(super) describe: Describe, +} + +impl QueryData { + pub async fn from_db( + conn: &mut impl Executor, + query: &str, + ) -> crate::Result { + Ok(QueryData { + query: query.into(), + describe: conn.describe(query).await?, + }) + } +} + +#[cfg(feature = "offline")] +pub mod offline { + use super::QueryData; + use std::fs::File; + + use std::fmt::{self, Formatter}; + + use crate::database::DatabaseExt; + use 
proc_macro2::{Span, TokenStream}; + use serde::de::{Deserializer, MapAccess, Visitor}; + use sqlx_core::describe::Describe; + use sqlx_core::query::query; + use std::path::Path; + + #[derive(serde::Deserialize)] + pub struct DynQueryData { + #[serde(skip)] + pub db_name: String, + pub query: String, + pub describe: serde_json::Value, + } + + impl DynQueryData { + /// Find and deserialize the data table for this query from a shared `sqlx-data.json` + /// file. The expected structure is a JSON map keyed by the SHA-256 hash of queries in hex. + pub fn from_data_file(path: impl AsRef, query: &str) -> crate::Result { + serde_json::Deserializer::from_reader( + File::open(path.as_ref()).map_err(|e| { + format!("failed to open path {}: {}", path.as_ref().display(), e) + })?, + ) + .deserialize_map(DataFileVisitor { + query, + hash: hash_string(query), + }) + .map_err(Into::into) + } + } + + impl QueryData + where + Describe: serde::Serialize + serde::de::DeserializeOwned, + { + pub fn from_dyn_data(dyn_data: DynQueryData) -> crate::Result { + assert!(!dyn_data.db_name.is_empty()); + if DB::NAME == dyn_data.db_name { + let describe: Describe = serde_json::from_value(dyn_data.describe)?; + Ok(QueryData { + query: dyn_data.query, + describe, + }) + } else { + Err(format!( + "expected query data for {}, got data for {}", + DB::NAME, + dyn_data.db_name + ) + .into()) + } + } + + pub fn save_in(&self, dir: impl AsRef, input_span: Span) -> crate::Result<()> { + // we save under the hash of the span representation because that should be unique + // per invocation + let path = dir.as_ref().join(format!( + "query-{}.json", + hash_string(&format!("{:?}", input_span)) + )); + + serde_json::to_writer_pretty( + File::create(&path) + .map_err(|e| format!("failed to open path {}: {}", path.display(), e))?, + self, + ) + .map_err(Into::into) + } + } + + fn hash_string(query: &str) -> String { + // picked `sha2` because it's already in the dependency tree for both MySQL and Postgres + use 
sha2::{Digest, Sha256}; + + hex::encode(Sha256::digest(query.as_bytes())) + } + + // lazily deserializes only the `QueryData` for the query we're looking for + struct DataFileVisitor<'a> { + query: &'a str, + hash: String, + } + + impl<'de> Visitor<'de> for DataFileVisitor<'_> { + type Value = DynQueryData; + + fn expecting(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "expected map key {:?} or \"db\"", self.hash) + } + + fn visit_map(self, mut map: A) -> Result>::Error> + where + A: MapAccess<'de>, + { + let mut db_name: Option = None; + + // unfortunately we can't avoid this copy because deserializing from `io::Read` + // doesn't support deserializing borrowed values + while let Some(key) = map.next_key::()? { + // lazily deserialize the query data only + if key == "db" { + db_name = Some(map.next_value::()?); + } else if key == self.hash { + let db_name = db_name.ok_or_else(|| { + serde::de::Error::custom("expected \"db\" key before query hash keys") + })?; + + let mut query_data: DynQueryData = map.next_value()?; + + return if query_data.query == self.query { + query_data.db_name = db_name; + Ok(query_data) + } else { + Err(serde::de::Error::custom(format_args!( + "hash collision for stored queries:\n{:?}\n{:?}", + self.query, query_data.query + ))) + }; + } + } + + Err(serde::de::Error::custom(format_args!( + "failed to find data for query {}", + self.hash + ))) + } + } +} diff --git a/sqlx-macros/src/query_macros/input.rs b/sqlx-macros/src/query_macros/input.rs index c9601f3c..a9ce0857 100644 --- a/sqlx-macros/src/query_macros/input.rs +++ b/sqlx-macros/src/query_macros/input.rs @@ -1,171 +1,122 @@ use std::env; +use std::fs; use proc_macro2::{Ident, Span}; use quote::{format_ident, ToTokens}; -use syn::parse::{Parse, ParseStream}; +use syn::parse::{Parse, ParseBuffer, ParseStream}; use syn::punctuated::Punctuated; use syn::spanned::Spanned; use syn::token::Group; -use syn::{Expr, ExprLit, ExprPath, Lit}; -use syn::{ExprGroup, Token}; +use 
syn::{Error, Expr, ExprLit, ExprPath, Lit, LitBool, LitStr, Token}; +use syn::{ExprArray, ExprGroup, Type}; -use sqlx::connection::Connection; -use sqlx::describe::Describe; - -use crate::runtime::fs; +use sqlx_core::connection::Connection; +use sqlx_core::describe::Describe; /// Macro input shared by `query!()` and `query_file!()` pub struct QueryMacroInput { - pub(super) source: String, - pub(super) source_span: Span, + pub(super) src: String, + pub(super) src_span: Span, + + pub(super) data_src: DataSrc, + + pub(super) record_type: RecordType, + // `arg0 .. argN` for N arguments pub(super) arg_names: Vec, pub(super) arg_exprs: Vec, + + pub(super) checked: bool, } -impl QueryMacroInput { - fn from_exprs(input: ParseStream, mut args: impl Iterator) -> syn::Result { - fn lit_err(span: Span, unexpected: Expr) -> syn::Result { - Err(syn::Error::new( - span, - format!( - "expected string literal, got {}", - unexpected.to_token_stream() - ), - )) - } +enum QuerySrc { + String(String), + File(String), +} - let (source, source_span) = match args.next() { - Some(Expr::Lit(ExprLit { - lit: Lit::Str(sql), .. - })) => (sql.value(), sql.span()), - Some(Expr::Group(ExprGroup { - expr, - group_token: Group { span }, - .. - })) => { - // this duplication with the above is necessary because `expr` is `Box` here - // which we can't directly pattern-match without `box_patterns` - match *expr { - Expr::Lit(ExprLit { - lit: Lit::Str(sql), .. 
- }) => (sql.value(), span), - other_expr => return lit_err(span, other_expr), - } - } - Some(other_expr) => return lit_err(other_expr.span(), other_expr), - None => return Err(input.error("expected SQL string literal")), - }; +pub enum DataSrc { + Env(String), + DbUrl(String), + File, +} - let arg_exprs: Vec<_> = args.collect(); - let arg_names = (0..arg_exprs.len()) - .map(|i| format_ident!("arg{}", i)) - .collect(); - - Ok(Self { - source, - source_span, - arg_exprs, - arg_names, - }) - } - - pub async fn expand_file_src(self) -> syn::Result { - let source = read_file_src(&self.source, self.source_span).await?; - - Ok(Self { source, ..self }) - } - - /// Run a parse/describe on the query described by this input and validate that it matches the - /// passed number of args - pub async fn describe_validate( - &self, - conn: &mut C, - ) -> crate::Result> { - let describe = conn - .describe(&*self.source) - .await - .map_err(|e| syn::Error::new(self.source_span, e))?; - - if self.arg_names.len() != describe.param_types.len() { - return Err(syn::Error::new( - Span::call_site(), - format!( - "expected {} parameters, got {}", - describe.param_types.len(), - self.arg_names.len() - ), - ) - .into()); - } - - Ok(describe) - } +pub enum RecordType { + Given(Type), + Generated, } impl Parse for QueryMacroInput { fn parse(input: ParseStream) -> syn::Result { - let args = Punctuated::::parse_terminated(input)?.into_iter(); + let mut query_src: Option<(QuerySrc, Span)> = None; + let mut data_src = DataSrc::Env("DATABASE_URL".into()); + let mut args: Option> = None; + let mut record_type = RecordType::Generated; + let mut checked = true; - Self::from_exprs(input, args) - } -} + let mut expect_comma = false; -/// Macro input shared by `query_as!()` and `query_file_as!()` -pub struct QueryAsMacroInput { - pub(super) as_ty: ExprPath, - pub(super) query_input: QueryMacroInput, -} + while !input.is_empty() { + if expect_comma { + let _ = input.parse::()?; + } -impl QueryAsMacroInput 
{ - pub async fn expand_file_src(self) -> syn::Result { - Ok(Self { - query_input: self.query_input.expand_file_src().await?, - ..self - }) - } -} + let key: Ident = input.parse()?; -impl Parse for QueryAsMacroInput { - fn parse(input: ParseStream) -> syn::Result { - fn path_err(span: Span, unexpected: Expr) -> syn::Result { - Err(syn::Error::new( - span, - format!( - "expected path to a type, got {}", - unexpected.to_token_stream() - ), - )) + let _ = input.parse::()?; + + if key == "source" { + let lit_str = input.parse::()?; + query_src = Some((QuerySrc::String(lit_str.value()), lit_str.span())); + } else if key == "source_file" { + let lit_str = input.parse::()?; + query_src = Some((QuerySrc::File(lit_str.value()), lit_str.span())); + } else if key == "args" { + let exprs = input.parse::()?; + args = Some(exprs.elems.into_iter().collect()) + } else if key == "record" { + record_type = RecordType::Given(input.parse()?); + } else if key == "checked" { + let lit_bool = input.parse::()?; + checked = lit_bool.value; + } else { + let message = format!("unexpected input key: {}", key); + return Err(syn::Error::new_spanned(key, message)); + } + + expect_comma = true; } - let mut args = Punctuated::::parse_terminated(input)?.into_iter(); + let (src, src_span) = + query_src.ok_or_else(|| input.error("expected `source` or `source_file` key"))?; - let as_ty = match args.next() { - Some(Expr::Path(path)) => path, - Some(Expr::Group(ExprGroup { - expr, - group_token: Group { span }, - .. 
- })) => { - // this duplication with the above is necessary because `expr` is `Box` here - // which we can't directly pattern-match without `box_patterns` - match *expr { - Expr::Path(path) => path, - other_expr => return path_err(span, other_expr), - } - } - Some(other_expr) => return path_err(other_expr.span(), other_expr), - None => return Err(input.error("expected path to SQL file")), - }; + let arg_exprs = args.unwrap_or_default(); + let arg_names = (0..arg_exprs.len()) + .map(|i| format_ident!("arg{}", i)) + .collect(); - Ok(QueryAsMacroInput { - as_ty, - query_input: QueryMacroInput::from_exprs(input, args)?, + Ok(QueryMacroInput { + src: src.resolve(src_span)?, + src_span, + data_src, + record_type, + arg_names, + arg_exprs, + checked, }) } } -async fn read_file_src(source: &str, source_span: Span) -> syn::Result { +impl QuerySrc { + /// If the query source is a file, read it to a string. Otherwise return the query string. + fn resolve(self, source_span: Span) -> syn::Result { + match self { + QuerySrc::String(string) => Ok(string), + QuerySrc::File(file) => read_file_src(&file, source_span), + } + } +} + +fn read_file_src(source: &str, source_span: Span) -> syn::Result { use std::path::Path; let path = Path::new(source); @@ -201,7 +152,7 @@ async fn read_file_src(source: &str, source_span: Span) -> syn::Result { let file_path = base_dir_path.join(path); - fs::read_to_string(&file_path).await.map_err(|e| { + fs::read_to_string(&file_path).map_err(|e| { syn::Error::new( source_span, format!( diff --git a/sqlx-macros/src/query_macros/mod.rs b/sqlx-macros/src/query_macros/mod.rs index e2eeca7b..a1075e4e 100644 --- a/sqlx-macros/src/query_macros/mod.rs +++ b/sqlx-macros/src/query_macros/mod.rs @@ -1,68 +1,223 @@ +use std::borrow::Cow; +use std::env; use std::fmt::Display; +use std::path::PathBuf; -use proc_macro2::TokenStream; +use proc_macro2::{Ident, Span, TokenStream}; +use syn::Type; +use url::Url; + +pub use input::QueryMacroInput; use 
quote::{format_ident, quote}; - -pub use input::{QueryAsMacroInput, QueryMacroInput}; -pub use query::expand_query; +use sqlx_core::connection::Connect; +use sqlx_core::connection::Connection; +use sqlx_core::database::Database; +use sqlx_core::describe::Describe; use crate::database::DatabaseExt; +use crate::query_macros::data::QueryData; +use crate::query_macros::input::RecordType; +use crate::runtime::block_on; -use sqlx::connection::Connection; -use sqlx::database::Database; +// pub use query::expand_query; mod args; +mod data; mod input; mod output; -mod query; +// mod query; -pub async fn expand_query_file( - input: QueryMacroInput, - conn: C, - checked: bool, -) -> crate::Result -where - C::Database: DatabaseExt + Sized, - ::TypeInfo: Display, -{ - expand_query(input.expand_file_src().await?, conn, checked).await +pub fn expand_input(input: QueryMacroInput) -> crate::Result { + let manifest_dir = + env::var("CARGO_MANIFEST_DIR").map_err(|_| "`CARGO_MANIFEST_DIR` must be set")?; + + // If a .env file exists at CARGO_MANIFEST_DIR, load environment variables from this, + // otherwise fallback to default dotenv behaviour. + let env_path = std::path::Path::new(&manifest_dir).join(".env"); + if env_path.exists() { + dotenv::from_path(&env_path) + .map_err(|e| format!("failed to load environment from {:?}, {}", env_path, e))? 
+ } + + // if `dotenv` wasn't initialized by the above we make sure to do it here + match dotenv::var("DATABASE_URL").ok() { + Some(db_url) => expand_from_db(input, &db_url), + #[cfg(feature = "offline")] + None => { + let data_file_path = std::path::Path::new(&manifest_dir).join("sqlx-data.json"); + + if data_file_path.exists() { + expand_from_file(input, data_file_path) + } else { + Err( + "`DATABASE_URL` must be set, or `cargo sqlx prepare` must have been run \ + and sqlx-data.json must exist, to use query macros" + .into(), + ) + } + } + #[cfg(not(feature = "offline"))] + None => Err("`DATABASE_URL` must be set to use query macros".into()), + } } -pub async fn expand_query_as( - input: QueryAsMacroInput, - mut conn: C, - checked: bool, +fn expand_from_db(input: QueryMacroInput, db_url: &str) -> crate::Result { + let db_url = Url::parse(db_url)?; + match db_url.scheme() { + #[cfg(feature = "postgres")] + "postgres" | "postgresql" => { + let data = block_on(async { + let mut conn = sqlx_core::postgres::PgConnection::connect(db_url).await?; + QueryData::from_db(&mut conn, &input.src).await + })?; + + expand_with_data(input, data) + }, + #[cfg(not(feature = "postgres"))] + "postgres" | "postgresql" => Err(format!("database URL has the scheme of a PostgreSQL database but the `postgres` feature is not enabled").into()), + #[cfg(feature = "mysql")] + "mysql" | "mariadb" => { + let data = block_on(async { + let mut conn = sqlx_core::mysql::MySqlConnection::connect(db_url).await?; + QueryData::from_db(&mut conn, &input.src).await + })?; + + expand_with_data(input, data) + }, + #[cfg(not(feature = "mysql"))] + "mysql" | "mariadb" => Err(format!("database URL has the scheme of a MySQL/MariaDB database but the `mysql` feature is not enabled").into()), + #[cfg(feature = "sqlite")] + "sqlite" => { + let data = block_on(async { + let mut conn = sqlx_core::sqlite::SqliteConnection::connect(db_url).await?; + QueryData::from_db(&mut conn, &input.src).await + })?; + + 
expand_with_data(input, data) + }, + #[cfg(not(feature = "sqlite"))] + "sqlite" => Err(format!("database URL has the scheme of a SQLite database but the `sqlite` feature is not enabled").into()), + scheme => Err(format!("unknown database URL scheme {:?}", scheme).into()) + } +} + +#[cfg(feature = "offline")] +pub fn expand_from_file(input: QueryMacroInput, file: PathBuf) -> crate::Result { + use data::offline::DynQueryData; + + let query_data = DynQueryData::from_data_file(file, &input.src)?; + assert!(!query_data.db_name.is_empty()); + + match &*query_data.db_name { + #[cfg(feature = "postgres")] + sqlx_core::postgres::Postgres::NAME => expand_with_data( + input, + QueryData::::from_dyn_data(query_data)?, + ), + #[cfg(feature = "mysql")] + sqlx_core::mysql::MySql::NAME => expand_with_data( + input, + QueryData::::from_dyn_data(query_data)?, + ), + #[cfg(feature = "sqlite")] + sqlx_core::sqlite::Sqlite::NAME => expand_with_data( + input, + QueryData::::from_dyn_data(query_data)?, + ), + _ => Err(format!( + "found query data for {} but the feature for that database was not enabled", + query_data.db_name + ) + .into()), + } +} + +// marker trait for `Describe` that lets us conditionally require it to be `Serialize + Deserialize` +#[cfg(feature = "offline")] +trait DescribeExt: serde::Serialize + serde::de::DeserializeOwned {} + +#[cfg(feature = "offline")] +impl DescribeExt for Describe where + Describe: serde::Serialize + serde::de::DeserializeOwned +{ +} + +#[cfg(not(feature = "offline"))] +trait DescribeExt {} + +#[cfg(not(feature = "offline"))] +impl DescribeExt for Describe {} + +fn expand_with_data( + input: QueryMacroInput, + data: QueryData, ) -> crate::Result where - C::Database: DatabaseExt + Sized, - ::TypeInfo: Display, + Describe: DescribeExt, { - let describe = input.query_input.describe_validate(&mut conn).await?; - - if describe.result_columns.is_empty() { + // validate at the minimum that our args match the query's input parameters + if 
input.arg_names.len() != data.describe.param_types.len() { return Err(syn::Error::new( - input.query_input.source_span, - "query must output at least one column", + Span::call_site(), + format!( + "expected {} parameters, got {}", + data.describe.param_types.len(), + input.arg_names.len() + ), ) .into()); } - let args_tokens = args::quote_args(&input.query_input, &describe, checked)?; + let args_tokens = args::quote_args(&input, &data.describe)?; let query_args = format_ident!("query_args"); - let columns = output::columns_to_rust(&describe)?; - let output = output::quote_query_as::( - &input.query_input.source, - &input.as_ty.path, - &query_args, - &columns, - checked, - ); + let output = if data.describe.result_columns.is_empty() { + let db_path = DB::db_path(); + let sql = &input.src; - let arg_names = &input.query_input.arg_names; + quote! { + sqlx::query::<#db_path>(#sql).bind_all(#query_args) + } + } else { + let columns = output::columns_to_rust::(&data.describe)?; - Ok(quote! { + let (out_ty, mut record_tokens) = match input.record_type { + RecordType::Generated => { + let record_name: Type = syn::parse_str("Record").unwrap(); + + let record_fields = columns.iter().map( + |&output::RustColumn { + ref ident, + ref type_, + }| quote!(#ident: #type_,), + ); + + let record_tokens = quote! { + #[derive(Debug)] + struct #record_name { + #(#record_fields)* + } + }; + + (Cow::Owned(record_name), record_tokens) + } + RecordType::Given(ref out_ty) => (Cow::Borrowed(out_ty), quote!()), + }; + + record_tokens.extend(output::quote_query_as::( + &input, + &out_ty, + &query_args, + &columns, + )); + + record_tokens + }; + + let arg_names = &input.arg_names; + + let ret_tokens = quote! { macro_rules! 
macro_result { (#($#arg_names:expr),*) => {{ use sqlx::arguments::Arguments as _; @@ -72,17 +227,14 @@ where #output }} } - }) -} + }; -pub async fn expand_query_file_as( - input: QueryAsMacroInput, - conn: C, - checked: bool, -) -> crate::Result -where - C::Database: DatabaseExt + Sized, - ::TypeInfo: Display, -{ - expand_query_as(input.expand_file_src().await?, conn, checked).await + #[cfg(feature = "offline")] + { + let save_dir = env::var("CARGO_TARGET_DIR").unwrap_or_else(|_| "target/sqlx".into()); + std::fs::create_dir_all(&save_dir)?; + data.save_in(save_dir, input.src_span)?; + } + + Ok(ret_tokens) } diff --git a/sqlx-macros/src/query_macros/output.rs b/sqlx-macros/src/query_macros/output.rs index 45cbca2c..ee67751e 100644 --- a/sqlx-macros/src/query_macros/output.rs +++ b/sqlx-macros/src/query_macros/output.rs @@ -1,11 +1,12 @@ use proc_macro2::{Ident, Span, TokenStream}; use quote::quote; -use syn::Path; +use syn::{Path, Type}; -use sqlx::describe::Describe; +use sqlx_core::describe::Describe; use crate::database::DatabaseExt; +use crate::query_macros::QueryMacroInput; use std::fmt::{self, Display, Formatter}; pub struct RustColumn { @@ -98,11 +99,10 @@ pub fn columns_to_rust(describe: &Describe) -> crate::Resul } pub fn quote_query_as( - sql: &str, - out_ty: &Path, + input: &QueryMacroInput, + out_ty: &Type, bind_args: &Ident, columns: &[RustColumn], - checked: bool, ) -> TokenStream { let instantiations = columns.iter().enumerate().map( |( @@ -116,7 +116,7 @@ pub fn quote_query_as( // For "checked" queries, the macro checks these at compile time and using "try_get" // would also perform pointless runtime checks - if checked { + if input.checked { quote!( #ident: row.try_get_unchecked::<#type_, _>(#i).try_unwrap_optional()? ) } else { quote!( #ident: row.try_get_unchecked(#i)? ) @@ -126,6 +126,7 @@ pub fn quote_query_as( let db_path = DB::db_path(); let row_path = DB::row_path(); + let sql = &input.src; quote! 
{ sqlx::query::<#db_path>(#sql).bind_all(#bind_args).try_map(|row: #row_path| { diff --git a/sqlx-macros/src/query_macros/query.rs b/sqlx-macros/src/query_macros/query.rs index 54004067..094425ea 100644 --- a/sqlx-macros/src/query_macros/query.rs +++ b/sqlx-macros/src/query_macros/query.rs @@ -5,7 +5,7 @@ use proc_macro2::TokenStream; use syn::{Ident, Path}; use quote::{format_ident, quote}; -use sqlx::{connection::Connection, database::Database}; +use sqlx_core::{connection::Connection, database::Database}; use super::{args, output, QueryMacroInput}; use crate::database::DatabaseExt; @@ -22,7 +22,7 @@ where ::TypeInfo: Display, { let describe = input.describe_validate(&mut conn).await?; - let sql = &input.source; + let sql = &input.src; let args = args::quote_args(&input, &describe, checked)?; @@ -33,7 +33,7 @@ where return Ok(quote! { macro_rules! macro_result { (#($#arg_names:expr),*) => {{ - use sqlx::arguments::Arguments as _; + use sqlx_core::arguments::Arguments as _; #args @@ -69,7 +69,7 @@ where Ok(quote! { macro_rules! 
macro_result { (#($#arg_names:expr),*) => {{ - use sqlx::arguments::Arguments as _; + use sqlx_core::arguments::Arguments as _; #[derive(Debug)] struct #record_type { diff --git a/sqlx-macros/src/runtime.rs b/sqlx-macros/src/runtime.rs index b79060e5..2ad420df 100644 --- a/sqlx-macros/src/runtime.rs +++ b/sqlx-macros/src/runtime.rs @@ -5,7 +5,23 @@ compile_error!("one of 'runtime-async-std' or 'runtime-tokio' features must be e compile_error!("only one of 'runtime-async-std' or 'runtime-tokio' features must be enabled"); #[cfg(feature = "runtime-async-std")] -pub(crate) use async_std::fs; +pub(crate) use async_std::task::block_on; #[cfg(feature = "runtime-tokio")] -pub(crate) use tokio::fs; +pub fn block_on(future: F) -> F::Output { + use once_cell::sync::Lazy; + use tokio::runtime::{self, Runtime}; + + // lazily initialize a global runtime once for multiple invocations of the macros + static RUNTIME: Lazy = Lazy::new(|| { + runtime::Builder::new() + // `.basic_scheduler()` requires calling `Runtime::block_on()` which needs mutability + .threaded_scheduler() + .enable_io() + .enable_time() + .build() + .expect("failed to initialize Tokio runtime") + }); + + RUNTIME.enter(|| futures::executor::block_on(future)) +} diff --git a/src/macros.rs b/src/macros.rs index ade4105b..55a2c3ca 100644 --- a/src/macros.rs +++ b/src/macros.rs @@ -122,14 +122,14 @@ macro_rules! query ( ($query:literal) => ({ #[macro_use] mod _macro_result { - $crate::sqlx_macros::query!($query); + $crate::sqlx_macros::expand_query!(source = $query); } macro_result!() }); ($query:literal, $($args:expr),*$(,)?) => ({ #[macro_use] mod _macro_result { - $crate::sqlx_macros::query!($query, $($args),*); + $crate::sqlx_macros::expand_query!(source = $query, args = [$($args),*]); } macro_result!($($args),*) }) @@ -140,19 +140,17 @@ macro_rules! query ( #[macro_export] #[cfg_attr(docsrs, doc(cfg(feature = "macros")))] macro_rules! 
query_unchecked ( - // by emitting a macro definition from our proc-macro containing the result tokens, - // we no longer have a need for `proc-macro-hack` ($query:literal) => ({ #[macro_use] mod _macro_result { - $crate::sqlx_macros::query_unchecked!($query); + $crate::sqlx_macros::expand_query!(source = $query, checked = false); } macro_result!() }); ($query:literal, $($args:expr),*$(,)?) => ({ #[macro_use] mod _macro_result { - $crate::sqlx_macros::query_unchecked!($query, $($args),*); + $crate::sqlx_macros::expand_query!(source = $query, args = [$($args),*], checked = false); } macro_result!($($args),*) }) @@ -203,17 +201,17 @@ macro_rules! query_unchecked ( #[macro_export] #[cfg_attr(docsrs, doc(cfg(feature = "macros")))] macro_rules! query_file ( - ($query:literal) => (#[allow(dead_code)]{ + ($path:literal) => (#[allow(dead_code)]{ #[macro_use] mod _macro_result { - $crate::sqlx_macros::query_file!($query); + $crate::sqlx_macros::expand_query!(source_file = $path); } macro_result!() }); - ($query:literal, $($args:expr),*$(,)?) => (#[allow(dead_code)]{ + ($path:literal, $($args:expr),*$(,)?) => (#[allow(dead_code)]{ #[macro_use] mod _macro_result { - $crate::sqlx_macros::query_file!($query, $($args),*); + $crate::sqlx_macros::expand_query!(source_file = $path, args = [$($args),*]); } macro_result!($($args),*) }) @@ -224,17 +222,17 @@ macro_rules! query_file ( #[macro_export] #[cfg_attr(docsrs, doc(cfg(feature = "macros")))] macro_rules! query_file_unchecked ( - ($query:literal) => (#[allow(dead_code)]{ + ($path:literal) => (#[allow(dead_code)]{ #[macro_use] mod _macro_result { - $crate::sqlx_macros::query_file_unchecked!($query); + $crate::sqlx_macros::expand_query!(source_file = $path, checked = false); } macro_result!() }); - ($query:literal, $($args:expr),*$(,)?) => (#[allow(dead_code)]{ + ($path:literal, $($args:expr),*$(,)?) 
=> (#[allow(dead_code)]{ #[macro_use] mod _macro_result { - $crate::sqlx_macros::query_file_unchecked!($query, $($args),*); + $crate::sqlx_macros::expand_query!(source_file = $path, args = [$($args),*], checked = false); } macro_result!($($args),*) }) @@ -298,14 +296,14 @@ macro_rules! query_as ( ($out_struct:path, $query:literal) => (#[allow(dead_code)] { #[macro_use] mod _macro_result { - $crate::sqlx_macros::query_as!($out_struct, $query); + $crate::sqlx_macros::expand_query!(record = $out_struct, source = $query); } macro_result!() }); ($out_struct:path, $query:literal, $($args:expr),*$(,)?) => (#[allow(dead_code)] { #[macro_use] mod _macro_result { - $crate::sqlx_macros::query_as!($out_struct, $query, $($args),*); + $crate::sqlx_macros::expand_query!(record = $out_struct, source = $query, args = [$($args),*]); } macro_result!($($args),*) }) @@ -347,17 +345,17 @@ macro_rules! query_as ( #[macro_export] #[cfg_attr(docsrs, doc(cfg(feature = "macros")))] macro_rules! query_file_as ( - ($out_struct:path, $query:literal) => (#[allow(dead_code)] { + ($out_struct:path, $path:literal) => (#[allow(dead_code)] { #[macro_use] mod _macro_result { - $crate::sqlx_macros::query_file_as!($out_struct, $query); + $crate::sqlx_macros::expand_query!(record = $out_struct, source_file = $path); } macro_result!() }); - ($out_struct:path, $query:literal, $($args:tt),*$(,)?) => (#[allow(dead_code)] { + ($out_struct:path, $path:literal, $($args:tt),*$(,)?) => (#[allow(dead_code)] { #[macro_use] mod _macro_result { - $crate::sqlx_macros::query_file_as!($out_struct, $query, $($args),*); + $crate::sqlx_macros::expand_query!(record = $out_struct, source_file = $path, args = [$($args),*]); } macro_result!($($args),*) }) @@ -371,7 +369,7 @@ macro_rules! 
query_as_unchecked ( ($out_struct:path, $query:literal) => (#[allow(dead_code)] { #[macro_use] mod _macro_result { - $crate::sqlx_macros::query_as_unchecked!($out_struct, $query); + $crate::sqlx_macros::expand_query!(record = $out_struct, source = $query, checked = false); } macro_result!() }); @@ -379,7 +377,7 @@ macro_rules! query_as_unchecked ( ($out_struct:path, $query:literal, $($args:expr),*$(,)?) => (#[allow(dead_code)] { #[macro_use] mod _macro_result { - $crate::sqlx_macros::query_as_unchecked!($out_struct, $query, $($args),*); + $crate::sqlx_macros::expand_query!(record = $out_struct, source = $query, args = [$($args),*], checked = false); } macro_result!($($args),*) }) @@ -391,18 +389,18 @@ macro_rules! query_as_unchecked ( #[macro_export] #[cfg_attr(docsrs, doc(cfg(feature = "macros")))] macro_rules! query_file_as_unchecked ( - ($out_struct:path, $query:literal) => (#[allow(dead_code)] { + ($out_struct:path, $path:literal) => (#[allow(dead_code)] { #[macro_use] mod _macro_result { - $crate::sqlx_macros::query_file_as_unchecked!($out_struct, $query); + $crate::sqlx_macros::expand_query!(record = $out_struct, source_file = $path, checked = false); } macro_result!() }); - ($out_struct:path, $query:literal, $($args:tt),*$(,)?) => (#[allow(dead_code)] { + ($out_struct:path, $path:literal, $($args:tt),*$(,)?) => (#[allow(dead_code)] { #[macro_use] mod _macro_result { - $crate::sqlx_macros::query_file_as_unchecked!($out_struct, $query, $($args),*); + $crate::sqlx_macros::expand_query!(record = $out_struct, source_file = $path, args = [$($args),*], checked = false); } macro_result!($($args),*) })