add support for building in "decoupled" mode

Austin Bonander 2020-04-13 23:00:19 -07:00 committed by Ryan Leckey
parent a9fb19b37d
commit 6913695588
24 changed files with 672 additions and 437 deletions

Cargo.lock (generated)

@@ -1928,10 +1928,13 @@ dependencies = [
"dotenv",
"futures 0.3.4",
"heck",
"lazy_static",
"hex",
"once_cell",
"proc-macro2",
"quote",
"serde",
"serde_json",
"sha2",
"sqlx-core",
"syn",
"tokio 0.2.13",


@@ -39,6 +39,9 @@ default = [ "macros", "runtime-async-std" ]
macros = [ "sqlx-macros" ]
tls = [ "sqlx-core/tls" ]
# offline building support in `sqlx-macros`
offline = ["sqlx-macros/offline", "sqlx-core/offline"]
# intended mainly for CI and docs
all = [ "tls", "all-database", "all-type" ]
all-database = [ "mysql", "sqlite", "postgres" ]


@@ -32,6 +32,9 @@ runtime-tokio = [ "async-native-tls/runtime-tokio", "tokio" ]
# intended for internal benchmarking, do not use
bench = []
# support offline/decoupled building (enables serialization of `Describe`)
offline = ["serde"]
[dependencies]
async-native-tls = { version = "0.3.2", default-features = false, optional = true }
async-std = { version = "1.5.0", features = [ "unstable" ], optional = true }


@@ -7,6 +7,14 @@ use crate::database::Database;
/// The return type of [`Executor::describe`].
///
/// [`Executor::describe`]: crate::executor::Executor::describe
#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
feature = "offline",
serde(bound(
serialize = "DB::TypeInfo: serde::Serialize, Column<DB>: serde::Serialize",
deserialize = "DB::TypeInfo: serde::de::DeserializeOwned, Column<DB>: serde::de::DeserializeOwned"
))
)]
#[non_exhaustive]
pub struct Describe<DB>
where
@@ -35,6 +43,14 @@ where
}
/// A single column of a result set.
#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
feature = "offline",
serde(bound(
serialize = "DB::TableId: serde::Serialize, DB::TypeInfo: serde::Serialize",
deserialize = "DB::TableId: serde::de::DeserializeOwned, DB::TypeInfo: serde::de::DeserializeOwned"
))
)]
#[non_exhaustive]
pub struct Column<DB>
where
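For context, a minimal sketch (not part of this diff; assumes `serde` and `serde_json` as dependencies) of why `Describe` and `Column` need the explicit `serde(bound(...))` above: with a generic parameter, the derive would otherwise emit a `DB: Serialize` bound, which never holds for a marker database type. The real requirement is only that the associated types stored in the fields are serializable.

use serde::{Deserialize, Serialize};

trait Database {
    type TypeInfo;
}

#[derive(Serialize, Deserialize)]
#[serde(bound(
    serialize = "DB::TypeInfo: Serialize",
    deserialize = "DB::TypeInfo: serde::de::DeserializeOwned"
))]
struct Describe<DB: Database> {
    // `DB` itself is never serialized, only its associated `TypeInfo` values
    param_types: Vec<DB::TypeInfo>,
}

struct Postgres;

impl Database for Postgres {
    type TypeInfo = String; // stand-in; the real type would be e.g. `PgTypeInfo`
}

fn main() {
    let describe = Describe::<Postgres> {
        param_types: vec!["int4".into()],
    };
    println!("{}", serde_json::to_string(&describe).unwrap());
}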


@@ -1,6 +1,7 @@
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/binary__log__types_8h.html
// https://mariadb.com/kb/en/library/resultset/#field-types
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))]
pub struct TypeId(pub u8);
// https://github.com/google/mysql/blob/c01fc2134d439282a21a2ddf687566e198ddee28/include/mysql_com.h#L429


@@ -4,6 +4,7 @@ use crate::mysql::protocol::{ColumnDefinition, FieldFlags, TypeId};
use crate::types::TypeInfo;
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))]
pub struct MySqlTypeInfo {
pub(crate) id: TypeId,
pub(crate) is_unsigned: bool,


@@ -2,6 +2,7 @@ use crate::postgres::types::try_resolve_type_name;
use std::fmt::{self, Display};
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))]
pub struct TypeId(pub(crate) u32);
// DEVELOPER PRO TIP: find builtin type OIDs easily by grepping this file


@@ -9,6 +9,7 @@ use std::sync::Arc;
/// Type information for a Postgres SQL type.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))]
pub struct PgTypeInfo {
pub(crate) id: Option<TypeId>,
pub(crate) name: SharedStr,
@@ -186,8 +187,38 @@ impl From<String> for SharedStr {
}
}
impl From<SharedStr> for String {
fn from(s: SharedStr) -> Self {
String::from(&*s)
}
}
impl fmt::Display for SharedStr {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
fmt.pad(self)
}
}
// manual impls because otherwise things get a little screwy with lifetimes
#[cfg(feature = "offline")]
impl<'de> serde::Deserialize<'de> for SharedStr {
fn deserialize<D>(deserializer: D) -> Result<Self, <D as serde::Deserializer<'de>>::Error>
where
D: serde::Deserializer<'de>,
{
Ok(String::deserialize(deserializer)?.into())
}
}
#[cfg(feature = "offline")]
impl serde::Serialize for SharedStr {
fn serialize<S>(
&self,
serializer: S,
) -> Result<<S as serde::Serializer>::Ok, <S as serde::Serializer>::Error>
where
S: serde::Serializer,
{
serializer.serialize_str(&self)
}
}
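A quick round-trip sketch (hypothetical; assumes `serde_json` is available and the `offline` feature is enabled) showing that the manual impls above make `SharedStr` serialize as a plain string:

fn main() -> Result<(), serde_json::Error> {
    let s: SharedStr = String::from("int4").into();

    // `serialize_str(&self)` means it comes out as a bare JSON string
    let json = serde_json::to_string(&s)?;
    assert_eq!(json, r#""int4""#);

    // deserializing goes through `String` and back into `SharedStr`
    let back: SharedStr = serde_json::from_str(&json)?;
    assert_eq!(String::from(back), "int4");
    Ok(())
}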


@@ -4,6 +4,7 @@ use crate::types::TypeInfo;
// https://www.sqlite.org/c3ref/c_blob.html
#[derive(Debug, PartialEq, Clone, Copy)]
#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))]
pub(crate) enum SqliteType {
Integer = 1,
Float = 2,
@@ -16,6 +17,7 @@ pub(crate) enum SqliteType {
// https://www.sqlite.org/datatype3.html#type_affinity
#[derive(Debug, PartialEq, Clone, Copy)]
#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))]
pub(crate) enum SqliteTypeAffinity {
Text,
Numeric,
@@ -25,6 +27,7 @@ pub(crate) enum SqliteTypeAffinity {
}
#[derive(Debug, Clone)]
#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))]
pub struct SqliteTypeInfo {
pub(crate) r#type: SqliteType,
pub(crate) affinity: Option<SqliteTypeAffinity>,


@@ -28,6 +28,14 @@ impl<'s> TryFrom<&'s String> for Url {
}
}
impl TryFrom<url::Url> for Url {
type Error = url::ParseError;
fn try_from(value: url::Url) -> Result<Self, Self::Error> {
Ok(Url(value))
}
}
impl Url {
#[allow(dead_code)]
pub(crate) fn as_str(&self) -> &str {


@@ -18,32 +18,38 @@ proc-macro = true
[features]
default = [ "runtime-async-std" ]
runtime-async-std = [ "sqlx/runtime-async-std", "async-std" ]
runtime-tokio = [ "sqlx/runtime-tokio", "tokio", "lazy_static" ]
runtime-async-std = [ "sqlx-core/runtime-async-std", "async-std" ]
runtime-tokio = [ "sqlx-core/runtime-tokio", "tokio", "once_cell" ]
# offline building support
offline = ["sqlx-core/offline", "serde", "serde_json", "hex", "sha2"]
# database
mysql = [ "sqlx/mysql" ]
postgres = [ "sqlx/postgres" ]
sqlite = [ "sqlx/sqlite" ]
mysql = [ "sqlx-core/mysql" ]
postgres = [ "sqlx-core/postgres" ]
sqlite = [ "sqlx-core/sqlite" ]
# type
bigdecimal = [ "sqlx/bigdecimal" ]
chrono = [ "sqlx/chrono" ]
time = [ "sqlx/time" ]
ipnetwork = [ "sqlx/ipnetwork" ]
uuid = [ "sqlx/uuid" ]
json = [ "sqlx/json", "serde_json" ]
bigdecimal = [ "sqlx-core/bigdecimal" ]
chrono = [ "sqlx-core/chrono" ]
time = [ "sqlx-core/time" ]
ipnetwork = [ "sqlx-core/ipnetwork" ]
uuid = [ "sqlx-core/uuid" ]
json = [ "sqlx-core/json", "serde_json" ]
[dependencies]
async-std = { version = "1.5.0", default-features = false, optional = true }
tokio = { version = "0.2.13", default-features = false, features = [ "rt-threaded" ], optional = true }
dotenv = { version = "0.15.0", default-features = false }
futures = { version = "0.3.4", default-features = false, features = [ "executor" ] }
hex = { version = "0.4.2", optional = true }
heck = "0.3"
proc-macro2 = { version = "1.0.9", default-features = false }
sqlx = { version = "0.3.5", default-features = false, path = "../sqlx-core", package = "sqlx-core" }
sqlx-core = { version = "0.3.5", default-features = false, path = "../sqlx-core" }
serde = { version = "1.0", optional = true }
serde_json = { version = "1.0", features = [ "raw_value" ], optional = true }
sha2 = { version = "0.8.1", optional = true }
syn = { version = "1.0.16", default-features = false, features = [ "full" ] }
quote = { version = "1.0.2", default-features = false }
url = { version = "2.1.1", default-features = false }
lazy_static = { version = "1.4.0", optional = true }
once_cell = { version = "1.3", features = ["std"], optional = true }


@@ -1,4 +1,4 @@
use sqlx::database::Database;
use sqlx_core::database::Database;
#[derive(PartialEq, Eq)]
#[allow(dead_code)]
@@ -10,6 +10,7 @@ pub enum ParamChecking {
pub trait DatabaseExt: Database {
const DATABASE_PATH: &'static str;
const ROW_PATH: &'static str;
const NAME: &'static str;
const PARAM_CHECKING: ParamChecking;
@@ -34,23 +35,25 @@ macro_rules! impl_database_ext {
$($(#[$meta:meta])? $ty:ty $(| $input:ty)?),*$(,)?
},
ParamChecking::$param_checking:ident,
feature-types: $name:ident => $get_gate:expr,
row = $row:path
feature-types: $ty_info:ident => $get_gate:expr,
row = $row:path,
name = $db_name:literal
) => {
impl $crate::database::DatabaseExt for $database {
const DATABASE_PATH: &'static str = stringify!($database);
const ROW_PATH: &'static str = stringify!($row);
const PARAM_CHECKING: $crate::database::ParamChecking = $crate::database::ParamChecking::$param_checking;
const NAME: &'static str = $db_name;
fn param_type_for_id(info: &Self::TypeInfo) -> Option<&'static str> {
match () {
$(
$(#[$meta])?
_ if <$ty as sqlx::types::Type<$database>>::type_info() == *info => Some(input_ty!($ty $(, $input)?)),
_ if <$ty as sqlx_core::types::Type<$database>>::type_info() == *info => Some(input_ty!($ty $(, $input)?)),
)*
$(
$(#[$meta])?
_ if sqlx::types::TypeInfo::compatible(&<$ty as sqlx::types::Type<$database>>::type_info(), &info) => Some(input_ty!($ty $(, $input)?)),
_ if sqlx_core::types::TypeInfo::compatible(&<$ty as sqlx_core::types::Type<$database>>::type_info(), &info) => Some(input_ty!($ty $(, $input)?)),
)*
_ => None
}
@@ -60,17 +63,17 @@ macro_rules! impl_database_ext {
match () {
$(
$(#[$meta])?
_ if <$ty as sqlx::types::Type<$database>>::type_info() == *info => return Some(stringify!($ty)),
_ if <$ty as sqlx_core::types::Type<$database>>::type_info() == *info => return Some(stringify!($ty)),
)*
$(
$(#[$meta])?
_ if sqlx::types::TypeInfo::compatible(&<$ty as sqlx::types::Type<$database>>::type_info(), &info) => return Some(stringify!($ty)),
_ if sqlx_core::types::TypeInfo::compatible(&<$ty as sqlx_core::types::Type<$database>>::type_info(), &info) => return Some(stringify!($ty)),
)*
_ => None
}
}
fn get_feature_gate($name: &Self::TypeInfo) -> Option<&'static str> {
fn get_feature_gate($ty_info: &Self::TypeInfo) -> Option<&'static str> {
$get_gate
}
}
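A self-contained sketch (simplified; not the macro's literal expansion) of what the new `name = ...` input contributes: a stable per-database identifier that saved query data can be keyed by and checked against at load time.

trait DatabaseExt {
    const NAME: &'static str;
}

struct Postgres;

impl DatabaseExt for Postgres {
    const NAME: &'static str = "PostgreSQL";
}

fn main() {
    // `NAME` is written into the saved query data, so data prepared against
    // one database cannot silently be applied to another.
    assert_eq!(Postgres::NAME, "PostgreSQL");
}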


@@ -1,5 +1,5 @@
impl_database_ext! {
sqlx::mysql::MySql {
sqlx_core::mysql::MySql {
u8,
u16,
u32,
@@ -18,33 +18,34 @@ impl_database_ext! {
Vec<u8>,
#[cfg(all(feature = "chrono", not(feature = "time")))]
sqlx::types::chrono::NaiveTime,
sqlx_core::types::chrono::NaiveTime,
#[cfg(all(feature = "chrono", not(feature = "time")))]
sqlx::types::chrono::NaiveDate,
sqlx_core::types::chrono::NaiveDate,
#[cfg(all(feature = "chrono", not(feature = "time")))]
sqlx::types::chrono::NaiveDateTime,
sqlx_core::types::chrono::NaiveDateTime,
#[cfg(all(feature = "chrono", not(feature = "time")))]
sqlx::types::chrono::DateTime<sqlx::types::chrono::Utc>,
sqlx_core::types::chrono::DateTime<sqlx_core::types::chrono::Utc>,
#[cfg(feature = "time")]
sqlx::types::time::Time,
sqlx_core::types::time::Time,
#[cfg(feature = "time")]
sqlx::types::time::Date,
sqlx_core::types::time::Date,
#[cfg(feature = "time")]
sqlx::types::time::PrimitiveDateTime,
sqlx_core::types::time::PrimitiveDateTime,
#[cfg(feature = "time")]
sqlx::types::time::OffsetDateTime,
sqlx_core::types::time::OffsetDateTime,
#[cfg(feature = "bigdecimal")]
sqlx::types::BigDecimal,
sqlx_core::types::BigDecimal,
},
ParamChecking::Weak,
feature-types: info => info.type_feature_gate(),
row = sqlx::mysql::MySqlRow
row = sqlx_core::mysql::MySqlRow,
name = "MySQL/MariaDB"
}


@@ -1,5 +1,5 @@
impl_database_ext! {
sqlx::postgres::Postgres {
sqlx_core::postgres::Postgres {
bool,
String | &str,
i8,
@@ -13,37 +13,37 @@ impl_database_ext! {
Vec<u8> | &[u8],
#[cfg(feature = "uuid")]
sqlx::types::Uuid,
sqlx_core::types::Uuid,
#[cfg(feature = "chrono")]
sqlx::types::chrono::NaiveTime,
sqlx_core::types::chrono::NaiveTime,
#[cfg(feature = "chrono")]
sqlx::types::chrono::NaiveDate,
sqlx_core::types::chrono::NaiveDate,
#[cfg(feature = "chrono")]
sqlx::types::chrono::NaiveDateTime,
sqlx_core::types::chrono::NaiveDateTime,
#[cfg(feature = "chrono")]
sqlx::types::chrono::DateTime<sqlx::types::chrono::Utc> | sqlx::types::chrono::DateTime<_>,
sqlx_core::types::chrono::DateTime<sqlx_core::types::chrono::Utc> | sqlx_core::types::chrono::DateTime<_>,
#[cfg(feature = "time")]
sqlx::types::time::Time,
sqlx_core::types::time::Time,
#[cfg(feature = "time")]
sqlx::types::time::Date,
sqlx_core::types::time::Date,
#[cfg(feature = "time")]
sqlx::types::time::PrimitiveDateTime,
sqlx_core::types::time::PrimitiveDateTime,
#[cfg(feature = "time")]
sqlx::types::time::OffsetDateTime,
sqlx_core::types::time::OffsetDateTime,
#[cfg(feature = "bigdecimal")]
sqlx::types::BigDecimal,
sqlx_core::types::BigDecimal,
#[cfg(feature = "ipnetwork")]
sqlx::types::ipnetwork::IpNetwork,
sqlx_core::types::ipnetwork::IpNetwork,
#[cfg(feature = "json")]
serde_json::Value,
@@ -61,41 +61,42 @@ impl_database_ext! {
#[cfg(feature = "uuid")]
Vec<sqlx::types::Uuid> | &[sqlx::types::Uuid],
Vec<sqlx_core::types::Uuid> | &[sqlx_core::types::Uuid],
#[cfg(feature = "chrono")]
Vec<sqlx::types::chrono::NaiveTime> | &[sqlx::types::sqlx::types::chrono::NaiveTime],
Vec<sqlx_core::types::chrono::NaiveTime> | &[sqlx_core::types::chrono::NaiveTime],
#[cfg(feature = "chrono")]
Vec<sqlx::types::chrono::NaiveDate> | &[sqlx::types::chrono::NaiveDate],
Vec<sqlx_core::types::chrono::NaiveDate> | &[sqlx_core::types::chrono::NaiveDate],
#[cfg(feature = "chrono")]
Vec<sqlx::types::chrono::NaiveDateTime> | &[sqlx::types::chrono::NaiveDateTime],
Vec<sqlx_core::types::chrono::NaiveDateTime> | &[sqlx_core::types::chrono::NaiveDateTime],
// TODO
// #[cfg(feature = "chrono")]
// Vec<sqlx::types::chrono::DateTime<sqlx::types::chrono::Utc>> | &[sqlx::types::chrono::DateTime<_>],
// Vec<sqlx_core::types::chrono::DateTime<sqlx_core::types::chrono::Utc>> | &[sqlx_core::types::chrono::DateTime<_>],
#[cfg(feature = "time")]
Vec<sqlx::types::time::Time> | &[sqlx::types::time::Time],
Vec<sqlx_core::types::time::Time> | &[sqlx_core::types::time::Time],
#[cfg(feature = "time")]
Vec<sqlx::types::time::Date> | &[sqlx::types::time::Date],
Vec<sqlx_core::types::time::Date> | &[sqlx_core::types::time::Date],
#[cfg(feature = "time")]
Vec<sqlx::types::time::PrimitiveDateTime> | &[sqlx::types::time::PrimitiveDateTime],
Vec<sqlx_core::types::time::PrimitiveDateTime> | &[sqlx_core::types::time::PrimitiveDateTime],
#[cfg(feature = "time")]
Vec<sqlx::types::time::OffsetDateTime> | &[sqlx::types::time::OffsetDateTime],
Vec<sqlx_core::types::time::OffsetDateTime> | &[sqlx_core::types::time::OffsetDateTime],
#[cfg(feature = "bigdecimal")]
Vec<sqlx::types::BigDecimal> | &[sqlx::types::BigDecimal],
Vec<sqlx_core::types::BigDecimal> | &[sqlx_core::types::BigDecimal],
#[cfg(feature = "ipnetwork")]
Vec<sqlx::types::ipnetwork::IpNetwork> | &[sqlx::types::ipnetwork::IpNetwork],
Vec<sqlx_core::types::ipnetwork::IpNetwork> | &[sqlx_core::types::ipnetwork::IpNetwork],
},
ParamChecking::Strong,
feature-types: info => info.type_feature_gate(),
row = sqlx::postgres::PgRow
row = sqlx_core::postgres::PgRow,
name = "PostgreSQL"
}


@@ -1,5 +1,5 @@
impl_database_ext! {
sqlx::sqlite::Sqlite {
sqlx_core::sqlite::Sqlite {
bool,
i32,
i64,
@@ -10,5 +10,6 @@ impl_database_ext! {
},
ParamChecking::Weak,
feature-types: _info => None,
row = sqlx::sqlite::SqliteRow
row = sqlx_core::sqlite::SqliteRow,
name = "SQLite"
}


@@ -11,8 +11,6 @@ use quote::quote;
#[cfg(feature = "runtime-async-std")]
use async_std::task::block_on;
use std::path::PathBuf;
use url::Url;
type Error = Box<dyn std::error::Error>;
@@ -26,23 +24,6 @@ mod runtime;
use query_macros::*;
#[cfg(feature = "runtime-tokio")]
lazy_static::lazy_static! {
static ref BASIC_RUNTIME: tokio::runtime::Runtime = {
tokio::runtime::Builder::new()
.threaded_scheduler()
.enable_io()
.enable_time()
.build()
.expect("failed to build tokio runtime")
};
}
#[cfg(feature = "runtime-tokio")]
fn block_on<F: std::future::Future>(future: F) -> F::Output {
BASIC_RUNTIME.enter(|| futures::executor::block_on(future))
}
fn macro_result(tokens: proc_macro2::TokenStream) -> TokenStream {
quote!(
macro_rules! macro_result {
@@ -52,141 +33,21 @@ fn macro_result(tokens: proc_macro2::TokenStream) -> TokenStream {
.into()
}
macro_rules! async_macro (
($db:ident, $input:ident: $ty:ty => $expr:expr) => {{
let $input = match syn::parse::<$ty>($input) {
Ok(input) => input,
Err(e) => return macro_result(e.to_compile_error()),
};
#[proc_macro]
pub fn expand_query(input: TokenStream) -> TokenStream {
let input = syn::parse_macro_input!(input as QueryMacroInput);
let res: Result<proc_macro2::TokenStream> = block_on(async {
use sqlx::connection::Connect;
// If a .env file exists at CARGO_MANIFEST_DIR, load environment variables from this,
// otherwise fallback to default dotenv behaviour.
if let Ok(dir) = std::env::var("CARGO_MANIFEST_DIR") {
let env_path = PathBuf::from(dir).join(".env");
if env_path.exists() {
dotenv::from_path(&env_path)
.map_err(|e| format!("failed to load environment from {:?}, {}", env_path, e))?
}
}
let db_url = Url::parse(&dotenv::var("DATABASE_URL").map_err(|_| "DATABASE_URL not set")?)?;
match db_url.scheme() {
#[cfg(feature = "sqlite")]
"sqlite" => {
let $db = sqlx::sqlite::SqliteConnection::connect(db_url.as_str())
.await
.map_err(|e| format!("failed to connect to database: {}", e))?;
$expr.await
}
#[cfg(not(feature = "sqlite"))]
"sqlite" => Err(format!(
"DATABASE_URL {} has the scheme of a SQLite database but the `sqlite` \
feature of sqlx was not enabled",
db_url
).into()),
#[cfg(feature = "postgres")]
"postgresql" | "postgres" => {
let $db = sqlx::postgres::PgConnection::connect(db_url.as_str())
.await
.map_err(|e| format!("failed to connect to database: {}", e))?;
$expr.await
}
#[cfg(not(feature = "postgres"))]
"postgresql" | "postgres" => Err(format!(
"DATABASE_URL {} has the scheme of a Postgres database but the `postgres` \
feature of sqlx was not enabled",
db_url
).into()),
#[cfg(feature = "mysql")]
"mysql" | "mariadb" => {
let $db = sqlx::mysql::MySqlConnection::connect(db_url.as_str())
.await
.map_err(|e| format!("failed to connect to database: {}", e))?;
$expr.await
}
#[cfg(not(feature = "mysql"))]
"mysql" | "mariadb" => Err(format!(
"DATABASE_URL {} has the scheme of a MySQL/MariaDB database but the `mysql` \
feature of sqlx was not enabled",
db_url
).into()),
scheme => Err(format!("unexpected scheme {:?} in DATABASE_URL {}", scheme, db_url).into()),
}
});
match res {
Ok(ts) => ts.into(),
Err(e) => {
if let Some(parse_err) = e.downcast_ref::<syn::Error>() {
macro_result(parse_err.to_compile_error())
} else {
let msg = e.to_string();
macro_result(quote!(compile_error!(#msg)))
}
match query_macros::expand_input(input) {
Ok(ts) => ts.into(),
Err(e) => {
if let Some(parse_err) = e.downcast_ref::<syn::Error>() {
macro_result(parse_err.to_compile_error())
} else {
let msg = e.to_string();
macro_result(quote!(compile_error!(#msg)))
}
}
}}
);
#[proc_macro]
#[allow(unused_variables)]
pub fn query(input: TokenStream) -> TokenStream {
#[allow(unused_variables)]
async_macro!(db, input: QueryMacroInput => expand_query(input, db, true))
}
#[proc_macro]
#[allow(unused_variables)]
pub fn query_unchecked(input: TokenStream) -> TokenStream {
#[allow(unused_variables)]
async_macro!(db, input: QueryMacroInput => expand_query(input, db, false))
}
#[proc_macro]
#[allow(unused_variables)]
pub fn query_file(input: TokenStream) -> TokenStream {
#[allow(unused_variables)]
async_macro!(db, input: QueryMacroInput => expand_query_file(input, db, true))
}
#[proc_macro]
#[allow(unused_variables)]
pub fn query_file_unchecked(input: TokenStream) -> TokenStream {
#[allow(unused_variables)]
async_macro!(db, input: QueryMacroInput => expand_query_file(input, db, false))
}
#[proc_macro]
#[allow(unused_variables)]
pub fn query_as(input: TokenStream) -> TokenStream {
#[allow(unused_variables)]
async_macro!(db, input: QueryAsMacroInput => expand_query_as(input, db, true))
}
#[proc_macro]
#[allow(unused_variables)]
pub fn query_file_as(input: TokenStream) -> TokenStream {
async_macro!(db, input: QueryAsMacroInput => expand_query_file_as(input, db, true))
}
#[proc_macro]
#[allow(unused_variables)]
pub fn query_as_unchecked(input: TokenStream) -> TokenStream {
#[allow(unused_variables)]
async_macro!(db, input: QueryAsMacroInput => expand_query_as(input, db, false))
}
#[proc_macro]
#[allow(unused_variables)]
pub fn query_file_as_unchecked(input: TokenStream) -> TokenStream {
async_macro!(db, input: QueryAsMacroInput => expand_query_file_as(input, db, false))
}
}
#[proc_macro_derive(Encode, attributes(sqlx))]


@@ -3,7 +3,7 @@ use syn::spanned::Spanned;
use syn::Expr;
use quote::{quote, quote_spanned, ToTokens};
use sqlx::describe::Describe;
use sqlx_core::describe::Describe;
use crate::database::{DatabaseExt, ParamChecking};
use crate::query_macros::QueryMacroInput;
@@ -13,7 +13,6 @@ use crate::query_macros::QueryMacroInput;
pub fn quote_args<DB: DatabaseExt>(
input: &QueryMacroInput,
describe: &Describe<DB>,
checked: bool,
) -> crate::Result<TokenStream> {
let db_path = DB::db_path();
@@ -25,7 +24,7 @@ pub fn quote_args<DB: DatabaseExt>(
let arg_name = &input.arg_names;
let args_check = if checked && DB::PARAM_CHECKING == ParamChecking::Strong {
let args_check = if input.checked && DB::PARAM_CHECKING == ParamChecking::Strong {
describe
.param_types
.iter()


@@ -0,0 +1,175 @@
use sqlx_core::database::Database;
use sqlx_core::describe::Describe;
use sqlx_core::executor::Executor;
#[cfg_attr(feature = "offline", derive(serde::Deserialize, serde::Serialize))]
#[cfg_attr(
feature = "offline",
serde(bound(
serialize = "Describe<DB>: serde::Serialize",
deserialize = "Describe<DB>: serde::de::DeserializeOwned"
))
)]
pub struct QueryData<DB: Database> {
pub(super) query: String,
pub(super) describe: Describe<DB>,
}
impl<DB: Database> QueryData<DB> {
pub async fn from_db(
conn: &mut impl Executor<Database = DB>,
query: &str,
) -> crate::Result<Self> {
Ok(QueryData {
query: query.into(),
describe: conn.describe(query).await?,
})
}
}
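A sketch of how the macro side is expected to call `from_db` (illustrative only; assumes the `postgres` feature, and module-level access since the fields are `pub(super)`):

async fn describe_query(db_url: &str, query: &str) -> Result<(), Box<dyn std::error::Error>> {
    use sqlx_core::connection::Connect;
    use sqlx_core::postgres::{PgConnection, Postgres};

    let mut conn = PgConnection::connect(db_url).await?;
    let data: QueryData<Postgres> = QueryData::from_db(&mut conn, query).await?;
    assert_eq!(data.query, query);
    Ok(())
}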
#[cfg(feature = "offline")]
pub mod offline {
use super::QueryData;
use std::fs::File;
use std::fmt::{self, Formatter};
use crate::database::DatabaseExt;
use proc_macro2::Span;
use serde::de::{Deserializer, MapAccess, Visitor};
use sqlx_core::describe::Describe;
use std::path::Path;
#[derive(serde::Deserialize)]
pub struct DynQueryData {
#[serde(skip)]
pub db_name: String,
pub query: String,
pub describe: serde_json::Value,
}
impl DynQueryData {
/// Find and deserialize the data table for this query from a shared `sqlx-data.json`
/// file. The expected structure is a JSON map keyed by the SHA-256 hash of queries in hex.
pub fn from_data_file(path: impl AsRef<Path>, query: &str) -> crate::Result<Self> {
serde_json::Deserializer::from_reader(
File::open(path.as_ref()).map_err(|e| {
format!("failed to open path {}: {}", path.as_ref().display(), e)
})?,
)
.deserialize_map(DataFileVisitor {
query,
hash: hash_string(query),
})
.map_err(Into::into)
}
}
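For reference, a sketch of the `sqlx-data.json` layout this loader expects (the hash key below is shortened and hypothetical; it is the hex SHA-256 of the query string, see `hash_string` further down, and the `describe` value is the serialized `Describe` for that database):

use serde_json::json;

fn main() {
    let data = json!({
        // must appear before any query keys; matched against `DatabaseExt::NAME`
        "db": "PostgreSQL",
        "0c8d1f3b...": {
            "query": "SELECT id FROM users WHERE id = $1",
            "describe": { "param_types": [], "result_columns": [] }
        }
    });
    println!("{}", serde_json::to_string_pretty(&data).unwrap());
}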
impl<DB: DatabaseExt> QueryData<DB>
where
Describe<DB>: serde::Serialize + serde::de::DeserializeOwned,
{
pub fn from_dyn_data(dyn_data: DynQueryData) -> crate::Result<Self> {
assert!(!dyn_data.db_name.is_empty());
if DB::NAME == dyn_data.db_name {
let describe: Describe<DB> = serde_json::from_value(dyn_data.describe)?;
Ok(QueryData {
query: dyn_data.query,
describe,
})
} else {
Err(format!(
"expected query data for {}, got data for {}",
DB::NAME,
dyn_data.db_name
)
.into())
}
}
pub fn save_in(&self, dir: impl AsRef<Path>, input_span: Span) -> crate::Result<()> {
// we save under the hash of the span representation because that should be unique
// per invocation
let path = dir.as_ref().join(format!(
"query-{}.json",
hash_string(&format!("{:?}", input_span))
));
serde_json::to_writer_pretty(
File::create(&path)
.map_err(|e| format!("failed to open path {}: {}", path.display(), e))?,
self,
)
.map_err(Into::into)
}
}
fn hash_string(query: &str) -> String {
// picked `sha2` because it's already in the dependency tree for both MySQL and Postgres
use sha2::{Digest, Sha256};
hex::encode(Sha256::digest(query.as_bytes()))
}
// lazily deserializes only the `QueryData` for the query we're looking for
struct DataFileVisitor<'a> {
query: &'a str,
hash: String,
}
impl<'de> Visitor<'de> for DataFileVisitor<'_> {
type Value = DynQueryData;
fn expecting(&self, f: &mut Formatter) -> fmt::Result {
write!(f, "expected map key {:?} or \"db\"", self.hash)
}
fn visit_map<A>(self, mut map: A) -> Result<Self::Value, <A as MapAccess<'de>>::Error>
where
A: MapAccess<'de>,
{
let mut db_name: Option<String> = None;
// unfortunately we can't avoid this copy because deserializing from `io::Read`
// doesn't support deserializing borrowed values
while let Some(key) = map.next_key::<String>()? {
// lazily deserialize the query data only
if key == "db" {
db_name = Some(map.next_value::<String>()?);
} else if key == self.hash {
let db_name = db_name.ok_or_else(|| {
serde::de::Error::custom("expected \"db\" key before query hash keys")
})?;
let mut query_data: DynQueryData = map.next_value()?;
return if query_data.query == self.query {
query_data.db_name = db_name;
Ok(query_data)
} else {
Err(serde::de::Error::custom(format_args!(
"hash collision for stored queries:\n{:?}\n{:?}",
self.query, query_data.query
)))
};
}
}
Err(serde::de::Error::custom(format_args!(
"failed to find data for query {}",
self.hash
)))
}
}
}


@@ -1,171 +1,122 @@
use std::env;
use std::fs;
use proc_macro2::{Ident, Span};
use quote::{format_ident, ToTokens};
use syn::parse::{Parse, ParseStream};
use syn::parse::{Parse, ParseBuffer, ParseStream};
use syn::punctuated::Punctuated;
use syn::spanned::Spanned;
use syn::token::Group;
use syn::{Expr, ExprLit, ExprPath, Lit};
use syn::{ExprGroup, Token};
use syn::{Error, Expr, ExprLit, ExprPath, Lit, LitBool, LitStr, Token};
use syn::{ExprArray, ExprGroup, Type};
use sqlx::connection::Connection;
use sqlx::describe::Describe;
use crate::runtime::fs;
use sqlx_core::connection::Connection;
use sqlx_core::describe::Describe;
/// Macro input shared by `query!()` and `query_file!()`
pub struct QueryMacroInput {
pub(super) source: String,
pub(super) source_span: Span,
pub(super) src: String,
pub(super) src_span: Span,
pub(super) data_src: DataSrc,
pub(super) record_type: RecordType,
// `arg0 .. argN` for N arguments
pub(super) arg_names: Vec<Ident>,
pub(super) arg_exprs: Vec<Expr>,
pub(super) checked: bool,
}
impl QueryMacroInput {
fn from_exprs(input: ParseStream, mut args: impl Iterator<Item = Expr>) -> syn::Result<Self> {
fn lit_err<T>(span: Span, unexpected: Expr) -> syn::Result<T> {
Err(syn::Error::new(
span,
format!(
"expected string literal, got {}",
unexpected.to_token_stream()
),
))
}
enum QuerySrc {
String(String),
File(String),
}
let (source, source_span) = match args.next() {
Some(Expr::Lit(ExprLit {
lit: Lit::Str(sql), ..
})) => (sql.value(), sql.span()),
Some(Expr::Group(ExprGroup {
expr,
group_token: Group { span },
..
})) => {
// this duplication with the above is necessary because `expr` is `Box<Expr>` here
// which we can't directly pattern-match without `box_patterns`
match *expr {
Expr::Lit(ExprLit {
lit: Lit::Str(sql), ..
}) => (sql.value(), span),
other_expr => return lit_err(span, other_expr),
}
}
Some(other_expr) => return lit_err(other_expr.span(), other_expr),
None => return Err(input.error("expected SQL string literal")),
};
pub enum DataSrc {
Env(String),
DbUrl(String),
File,
}
let arg_exprs: Vec<_> = args.collect();
let arg_names = (0..arg_exprs.len())
.map(|i| format_ident!("arg{}", i))
.collect();
Ok(Self {
source,
source_span,
arg_exprs,
arg_names,
})
}
pub async fn expand_file_src(self) -> syn::Result<Self> {
let source = read_file_src(&self.source, self.source_span).await?;
Ok(Self { source, ..self })
}
/// Run a parse/describe on the query described by this input and validate that it matches the
/// passed number of args
pub async fn describe_validate<C: Connection>(
&self,
conn: &mut C,
) -> crate::Result<Describe<C::Database>> {
let describe = conn
.describe(&*self.source)
.await
.map_err(|e| syn::Error::new(self.source_span, e))?;
if self.arg_names.len() != describe.param_types.len() {
return Err(syn::Error::new(
Span::call_site(),
format!(
"expected {} parameters, got {}",
describe.param_types.len(),
self.arg_names.len()
),
)
.into());
}
Ok(describe)
}
pub enum RecordType {
Given(Type),
Generated,
}
impl Parse for QueryMacroInput {
fn parse(input: ParseStream) -> syn::Result<Self> {
let args = Punctuated::<Expr, Token![,]>::parse_terminated(input)?.into_iter();
let mut query_src: Option<(QuerySrc, Span)> = None;
let mut data_src = DataSrc::Env("DATABASE_URL".into());
let mut args: Option<Vec<Expr>> = None;
let mut record_type = RecordType::Generated;
let mut checked = true;
Self::from_exprs(input, args)
}
}
let mut expect_comma = false;
/// Macro input shared by `query_as!()` and `query_file_as!()`
pub struct QueryAsMacroInput {
pub(super) as_ty: ExprPath,
pub(super) query_input: QueryMacroInput,
}
while !input.is_empty() {
if expect_comma {
let _ = input.parse::<syn::token::Comma>()?;
}
impl QueryAsMacroInput {
pub async fn expand_file_src(self) -> syn::Result<Self> {
Ok(Self {
query_input: self.query_input.expand_file_src().await?,
..self
})
}
}
let key: Ident = input.parse()?;
impl Parse for QueryAsMacroInput {
fn parse(input: ParseStream) -> syn::Result<Self> {
fn path_err<T>(span: Span, unexpected: Expr) -> syn::Result<T> {
Err(syn::Error::new(
span,
format!(
"expected path to a type, got {}",
unexpected.to_token_stream()
),
))
let _ = input.parse::<syn::token::Eq>()?;
if key == "source" {
let lit_str = input.parse::<LitStr>()?;
query_src = Some((QuerySrc::String(lit_str.value()), lit_str.span()));
} else if key == "source_file" {
let lit_str = input.parse::<LitStr>()?;
query_src = Some((QuerySrc::File(lit_str.value()), lit_str.span()));
} else if key == "args" {
let exprs = input.parse::<ExprArray>()?;
args = Some(exprs.elems.into_iter().collect())
} else if key == "record" {
record_type = RecordType::Given(input.parse()?);
} else if key == "checked" {
let lit_bool = input.parse::<LitBool>()?;
checked = lit_bool.value;
} else {
let message = format!("unexpected input key: {}", key);
return Err(syn::Error::new_spanned(key, message));
}
expect_comma = true;
}
let mut args = Punctuated::<Expr, Token![,]>::parse_terminated(input)?.into_iter();
let (src, src_span) =
query_src.ok_or_else(|| input.error("expected `source` or `source_file` key"))?;
let as_ty = match args.next() {
Some(Expr::Path(path)) => path,
Some(Expr::Group(ExprGroup {
expr,
group_token: Group { span },
..
})) => {
// this duplication with the above is necessary because `expr` is `Box<Expr>` here
// which we can't directly pattern-match without `box_patterns`
match *expr {
Expr::Path(path) => path,
other_expr => return path_err(span, other_expr),
}
}
Some(other_expr) => return path_err(other_expr.span(), other_expr),
None => return Err(input.error("expected path to SQL file")),
};
let arg_exprs = args.unwrap_or_default();
let arg_names = (0..arg_exprs.len())
.map(|i| format_ident!("arg{}", i))
.collect();
Ok(QueryAsMacroInput {
as_ty,
query_input: QueryMacroInput::from_exprs(input, args)?,
Ok(QueryMacroInput {
src: src.resolve(src_span)?,
src_span,
data_src,
record_type,
arg_names,
arg_exprs,
checked,
})
}
}
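A small parsing sketch (hypothetical; runnable from inside this crate, e.g. a sibling test module, since the fields are `pub(super)`) showing the keyword-style input this `Parse` impl accepts:

fn main() -> syn::Result<()> {
    let input: QueryMacroInput = syn::parse_str(
        r#"source = "SELECT id FROM users WHERE id = ?", args = [user_id], checked = false"#,
    )?;
    assert_eq!(input.src, "SELECT id FROM users WHERE id = ?");
    assert_eq!(input.arg_names.len(), 1);
    assert!(!input.checked);
    Ok(())
}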
async fn read_file_src(source: &str, source_span: Span) -> syn::Result<String> {
impl QuerySrc {
/// If the query source is a file, read it to a string. Otherwise return the query string.
fn resolve(self, source_span: Span) -> syn::Result<String> {
match self {
QuerySrc::String(string) => Ok(string),
QuerySrc::File(file) => read_file_src(&file, source_span),
}
}
}
fn read_file_src(source: &str, source_span: Span) -> syn::Result<String> {
use std::path::Path;
let path = Path::new(source);
@@ -201,7 +152,7 @@ async fn read_file_src(source: &str, source_span: Span) -> syn::Result<String> {
let file_path = base_dir_path.join(path);
fs::read_to_string(&file_path).await.map_err(|e| {
fs::read_to_string(&file_path).map_err(|e| {
syn::Error::new(
source_span,
format!(


@@ -1,68 +1,223 @@
use std::borrow::Cow;
use std::env;
use std::fmt::Display;
use std::path::PathBuf;
use proc_macro2::TokenStream;
use proc_macro2::{Ident, Span, TokenStream};
use syn::Type;
use url::Url;
pub use input::QueryMacroInput;
use quote::{format_ident, quote};
pub use input::{QueryAsMacroInput, QueryMacroInput};
pub use query::expand_query;
use sqlx_core::connection::Connect;
use sqlx_core::connection::Connection;
use sqlx_core::database::Database;
use sqlx_core::describe::Describe;
use crate::database::DatabaseExt;
use crate::query_macros::data::QueryData;
use crate::query_macros::input::RecordType;
use crate::runtime::block_on;
use sqlx::connection::Connection;
use sqlx::database::Database;
// pub use query::expand_query;
mod args;
mod data;
mod input;
mod output;
mod query;
// mod query;
pub async fn expand_query_file<C: Connection>(
input: QueryMacroInput,
conn: C,
checked: bool,
) -> crate::Result<TokenStream>
where
C::Database: DatabaseExt + Sized,
<C::Database as Database>::TypeInfo: Display,
{
expand_query(input.expand_file_src().await?, conn, checked).await
pub fn expand_input(input: QueryMacroInput) -> crate::Result<TokenStream> {
let manifest_dir =
env::var("CARGO_MANIFEST_DIR").map_err(|_| "`CARGO_MANIFEST_DIR` must be set")?;
// If a .env file exists at CARGO_MANIFEST_DIR, load environment variables from this,
// otherwise fallback to default dotenv behaviour.
let env_path = std::path::Path::new(&manifest_dir).join(".env");
if env_path.exists() {
dotenv::from_path(&env_path)
.map_err(|e| format!("failed to load environment from {:?}, {}", env_path, e))?
}
// if `dotenv` wasn't initialized by the above we make sure to do it here
match dotenv::var("DATABASE_URL").ok() {
Some(db_url) => expand_from_db(input, &db_url),
#[cfg(feature = "offline")]
None => {
let data_file_path = std::path::Path::new(&manifest_dir).join("sqlx-data.json");
if data_file_path.exists() {
expand_from_file(input, data_file_path)
} else {
Err(
"`DATABASE_URL` must be set, or `cargo sqlx prepare` must have been run \
and sqlx-data.json must exist, to use query macros"
.into(),
)
}
}
#[cfg(not(feature = "offline"))]
None => Err("`DATABASE_URL` must be set to use query macros".into()),
}
}
pub async fn expand_query_as<C: Connection>(
input: QueryAsMacroInput,
mut conn: C,
checked: bool,
fn expand_from_db(input: QueryMacroInput, db_url: &str) -> crate::Result<TokenStream> {
let db_url = Url::parse(db_url)?;
match db_url.scheme() {
#[cfg(feature = "postgres")]
"postgres" | "postgresql" => {
let data = block_on(async {
let mut conn = sqlx_core::postgres::PgConnection::connect(db_url).await?;
QueryData::from_db(&mut conn, &input.src).await
})?;
expand_with_data(input, data)
},
#[cfg(not(feature = "postgres"))]
"postgres" | "postgresql" => Err(format!("database URL has the scheme of a PostgreSQL database but the `postgres` feature is not enabled").into()),
#[cfg(feature = "mysql")]
"mysql" | "mariadb" => {
let data = block_on(async {
let mut conn = sqlx_core::mysql::MySqlConnection::connect(db_url).await?;
QueryData::from_db(&mut conn, &input.src).await
})?;
expand_with_data(input, data)
},
#[cfg(not(feature = "mysql"))]
"mysql" | "mariadb" => Err(format!("database URL has the scheme of a MySQL/MariaDB database but the `mysql` feature is not enabled").into()),
#[cfg(feature = "sqlite")]
"sqlite" => {
let data = block_on(async {
let mut conn = sqlx_core::sqlite::SqliteConnection::connect(db_url).await?;
QueryData::from_db(&mut conn, &input.src).await
})?;
expand_with_data(input, data)
},
#[cfg(not(feature = "sqlite"))]
"sqlite" => Err(format!("database URL has the scheme of a SQLite database but the `sqlite` feature is not enabled").into()),
scheme => Err(format!("unknown database URL scheme {:?}", scheme).into())
}
}
#[cfg(feature = "offline")]
pub fn expand_from_file(input: QueryMacroInput, file: PathBuf) -> crate::Result<TokenStream> {
use data::offline::DynQueryData;
let query_data = DynQueryData::from_data_file(file, &input.src)?;
assert!(!query_data.db_name.is_empty());
match &*query_data.db_name {
#[cfg(feature = "postgres")]
sqlx_core::postgres::Postgres::NAME => expand_with_data(
input,
QueryData::<sqlx_core::postgres::Postgres>::from_dyn_data(query_data)?,
),
#[cfg(feature = "mysql")]
sqlx_core::mysql::MySql::NAME => expand_with_data(
input,
QueryData::<sqlx_core::mysql::MySql>::from_dyn_data(query_data)?,
),
#[cfg(feature = "sqlite")]
sqlx_core::sqlite::Sqlite::NAME => expand_with_data(
input,
QueryData::<sqlx_core::sqlite::Sqlite>::from_dyn_data(query_data)?,
),
_ => Err(format!(
"found query data for {} but the feature for that database was not enabled",
query_data.db_name
)
.into()),
}
}
// marker trait for `Describe` that lets us conditionally require it to be `Serialize + Deserialize`
#[cfg(feature = "offline")]
trait DescribeExt: serde::Serialize + serde::de::DeserializeOwned {}
#[cfg(feature = "offline")]
impl<DB: Database> DescribeExt for Describe<DB> where
Describe<DB>: serde::Serialize + serde::de::DeserializeOwned
{
}
#[cfg(not(feature = "offline"))]
trait DescribeExt {}
#[cfg(not(feature = "offline"))]
impl<DB: Database> DescribeExt for Describe<DB> {}
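A self-contained sketch of the marker-trait trick above (assumes an `offline` cargo feature plus `serde`/`serde_json` dependencies): one unconditional bound that only implies `Serialize` when the feature is enabled, so the signature of `expand_with_data` needs no cfg gates.

#[cfg(feature = "offline")]
trait MarkerExt: serde::Serialize {}
#[cfg(feature = "offline")]
impl<T: serde::Serialize> MarkerExt for T {}

#[cfg(not(feature = "offline"))]
trait MarkerExt {}
#[cfg(not(feature = "offline"))]
impl<T> MarkerExt for T {}

// same signature either way; the body may use serde only when `offline` is on
fn save<T: MarkerExt>(value: &T) {
    #[cfg(feature = "offline")]
    serde_json::to_writer(std::io::sink(), value).expect("serialization failed");

    #[cfg(not(feature = "offline"))]
    let _ = value;
}

fn main() {
    save(&42);
}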
fn expand_with_data<DB: DatabaseExt>(
input: QueryMacroInput,
data: QueryData<DB>,
) -> crate::Result<TokenStream>
where
C::Database: DatabaseExt + Sized,
<C::Database as Database>::TypeInfo: Display,
Describe<DB>: DescribeExt,
{
let describe = input.query_input.describe_validate(&mut conn).await?;
if describe.result_columns.is_empty() {
// validate at the minimum that our args match the query's input parameters
if input.arg_names.len() != data.describe.param_types.len() {
return Err(syn::Error::new(
input.query_input.source_span,
"query must output at least one column",
Span::call_site(),
format!(
"expected {} parameters, got {}",
data.describe.param_types.len(),
input.arg_names.len()
),
)
.into());
}
let args_tokens = args::quote_args(&input.query_input, &describe, checked)?;
let args_tokens = args::quote_args(&input, &data.describe)?;
let query_args = format_ident!("query_args");
let columns = output::columns_to_rust(&describe)?;
let output = output::quote_query_as::<C::Database>(
&input.query_input.source,
&input.as_ty.path,
&query_args,
&columns,
checked,
);
let output = if data.describe.result_columns.is_empty() {
let db_path = DB::db_path();
let sql = &input.src;
let arg_names = &input.query_input.arg_names;
quote! {
sqlx::query::<#db_path>(#sql).bind_all(#query_args)
}
} else {
let columns = output::columns_to_rust::<DB>(&data.describe)?;
Ok(quote! {
let (out_ty, mut record_tokens) = match input.record_type {
RecordType::Generated => {
let record_name: Type = syn::parse_str("Record").unwrap();
let record_fields = columns.iter().map(
|&output::RustColumn {
ref ident,
ref type_,
}| quote!(#ident: #type_,),
);
let record_tokens = quote! {
#[derive(Debug)]
struct #record_name {
#(#record_fields)*
}
};
(Cow::Owned(record_name), record_tokens)
}
RecordType::Given(ref out_ty) => (Cow::Borrowed(out_ty), quote!()),
};
record_tokens.extend(output::quote_query_as::<DB>(
&input,
&out_ty,
&query_args,
&columns,
));
record_tokens
};
let arg_names = &input.arg_names;
let ret_tokens = quote! {
macro_rules! macro_result {
(#($#arg_names:expr),*) => {{
use sqlx::arguments::Arguments as _;
@@ -72,17 +227,14 @@ where
#output
}}
}
})
}
};
pub async fn expand_query_file_as<C: Connection>(
input: QueryAsMacroInput,
conn: C,
checked: bool,
) -> crate::Result<TokenStream>
where
C::Database: DatabaseExt + Sized,
<C::Database as Database>::TypeInfo: Display,
{
expand_query_as(input.expand_file_src().await?, conn, checked).await
#[cfg(feature = "offline")]
{
let save_dir = env::var("CARGO_TARGET_DIR").unwrap_or_else(|_| "target/sqlx".into());
let _ = std::fs::create_dir_all(&save_dir); // best-effort; a failure will surface when the data file is written
data.save_in(save_dir, input.src_span)?;
}
Ok(ret_tokens)
}


@@ -1,11 +1,12 @@
use proc_macro2::{Ident, Span, TokenStream};
use quote::quote;
use syn::Path;
use syn::{Path, Type};
use sqlx::describe::Describe;
use sqlx_core::describe::Describe;
use crate::database::DatabaseExt;
use crate::query_macros::QueryMacroInput;
use std::fmt::{self, Display, Formatter};
pub struct RustColumn {
@@ -98,11 +99,10 @@ pub fn columns_to_rust<DB: DatabaseExt>(describe: &Describe<DB>) -> crate::Resul
}
pub fn quote_query_as<DB: DatabaseExt>(
sql: &str,
out_ty: &Path,
input: &QueryMacroInput,
out_ty: &Type,
bind_args: &Ident,
columns: &[RustColumn],
checked: bool,
) -> TokenStream {
let instantiations = columns.iter().enumerate().map(
|(
@@ -116,7 +116,7 @@ pub fn quote_query_as<DB: DatabaseExt>(
// For "checked" queries, the macro checks these at compile time and using "try_get"
// would also perform pointless runtime checks
if checked {
if input.checked {
quote!( #ident: row.try_get_unchecked::<#type_, _>(#i).try_unwrap_optional()? )
} else {
quote!( #ident: row.try_get_unchecked(#i)? )
@@ -126,6 +126,7 @@ pub fn quote_query_as<DB: DatabaseExt>(
let db_path = DB::db_path();
let row_path = DB::row_path();
let sql = &input.src;
quote! {
sqlx::query::<#db_path>(#sql).bind_all(#bind_args).try_map(|row: #row_path| {


@@ -5,7 +5,7 @@ use proc_macro2::TokenStream;
use syn::{Ident, Path};
use quote::{format_ident, quote};
use sqlx::{connection::Connection, database::Database};
use sqlx_core::{connection::Connection, database::Database};
use super::{args, output, QueryMacroInput};
use crate::database::DatabaseExt;
@@ -22,7 +22,7 @@ where
<C::Database as Database>::TypeInfo: Display,
{
let describe = input.describe_validate(&mut conn).await?;
let sql = &input.source;
let sql = &input.src;
let args = args::quote_args(&input, &describe, checked)?;
@@ -33,7 +33,7 @@ where
return Ok(quote! {
macro_rules! macro_result {
(#($#arg_names:expr),*) => {{
use sqlx::arguments::Arguments as _;
use sqlx_core::arguments::Arguments as _;
#args
@@ -69,7 +69,7 @@ where
Ok(quote! {
macro_rules! macro_result {
(#($#arg_names:expr),*) => {{
use sqlx::arguments::Arguments as _;
use sqlx_core::arguments::Arguments as _;
#[derive(Debug)]
struct #record_type {


@@ -5,7 +5,23 @@ compile_error!("one of 'runtime-async-std' or 'runtime-tokio' features must be e
compile_error!("only one of 'runtime-async-std' or 'runtime-tokio' features must be enabled");
#[cfg(feature = "runtime-async-std")]
pub(crate) use async_std::fs;
pub(crate) use async_std::task::block_on;
#[cfg(feature = "runtime-tokio")]
pub(crate) use tokio::fs;
pub fn block_on<F: std::future::Future>(future: F) -> F::Output {
use once_cell::sync::Lazy;
use tokio::runtime::{self, Runtime};
// lazily initialize a global runtime once for multiple invocations of the macros
static RUNTIME: Lazy<Runtime> = Lazy::new(|| {
runtime::Builder::new()
// `.basic_scheduler()` requires calling `Runtime::block_on()` which needs mutability
.threaded_scheduler()
.enable_io()
.enable_time()
.build()
.expect("failed to initialize Tokio runtime")
});
RUNTIME.enter(|| futures::executor::block_on(future))
}
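A usage sketch (illustrative only): proc macros are synchronous, so the async describe step is driven to completion here, and the lazily built runtime is shared across every macro invocation in the same compiler process.

fn demo() {
    let answer = block_on(async { 40 + 2 });
    assert_eq!(answer, 42);
}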


@@ -122,14 +122,14 @@ macro_rules! query (
($query:literal) => ({
#[macro_use]
mod _macro_result {
$crate::sqlx_macros::query!($query);
$crate::sqlx_macros::expand_query!(source = $query);
}
macro_result!()
});
($query:literal, $($args:expr),*$(,)?) => ({
#[macro_use]
mod _macro_result {
$crate::sqlx_macros::query!($query, $($args),*);
$crate::sqlx_macros::expand_query!(source = $query, args = [$($args),*]);
}
macro_result!($($args),*)
})
@@ -140,19 +140,17 @@ macro_rules! query (
#[macro_export]
#[cfg_attr(docsrs, doc(cfg(feature = "macros")))]
macro_rules! query_unchecked (
// by emitting a macro definition from our proc-macro containing the result tokens,
// we no longer have a need for `proc-macro-hack`
($query:literal) => ({
#[macro_use]
mod _macro_result {
$crate::sqlx_macros::query_unchecked!($query);
$crate::sqlx_macros::expand_query!(source = $query, checked = false);
}
macro_result!()
});
($query:literal, $($args:expr),*$(,)?) => ({
#[macro_use]
mod _macro_result {
$crate::sqlx_macros::query_unchecked!($query, $($args),*);
$crate::sqlx_macros::expand_query!(source = $query, args = [$($args),*], checked = false);
}
macro_result!($($args),*)
})
@@ -203,17 +201,17 @@ macro_rules! query_unchecked (
#[macro_export]
#[cfg_attr(docsrs, doc(cfg(feature = "macros")))]
macro_rules! query_file (
($query:literal) => (#[allow(dead_code)]{
($path:literal) => (#[allow(dead_code)]{
#[macro_use]
mod _macro_result {
$crate::sqlx_macros::query_file!($query);
$crate::sqlx_macros::expand_query!(source_file = $path);
}
macro_result!()
});
($query:literal, $($args:expr),*$(,)?) => (#[allow(dead_code)]{
($path:literal, $($args:expr),*$(,)?) => (#[allow(dead_code)]{
#[macro_use]
mod _macro_result {
$crate::sqlx_macros::query_file!($query, $($args),*);
$crate::sqlx_macros::expand_query!(source_file = $path, args = [$($args),*]);
}
macro_result!($($args),*)
})
@@ -224,17 +222,17 @@ macro_rules! query_file (
#[macro_export]
#[cfg_attr(docsrs, doc(cfg(feature = "macros")))]
macro_rules! query_file_unchecked (
($query:literal) => (#[allow(dead_code)]{
($path:literal) => (#[allow(dead_code)]{
#[macro_use]
mod _macro_result {
$crate::sqlx_macros::query_file_unchecked!($query);
$crate::sqlx_macros::expand_query!(source_file = $path, checked = false);
}
macro_result!()
});
($query:literal, $($args:expr),*$(,)?) => (#[allow(dead_code)]{
($path:literal, $($args:expr),*$(,)?) => (#[allow(dead_code)]{
#[macro_use]
mod _macro_result {
$crate::sqlx_macros::query_file_unchecked!($query, $($args),*);
$crate::sqlx_macros::expand_query!(source_file = $path, args = [$($args),*], checked = false);
}
macro_result!($($args),*)
})
@@ -298,14 +296,14 @@ macro_rules! query_as (
($out_struct:path, $query:literal) => (#[allow(dead_code)] {
#[macro_use]
mod _macro_result {
$crate::sqlx_macros::query_as!($out_struct, $query);
$crate::sqlx_macros::expand_query!(record = $out_struct, source = $query);
}
macro_result!()
});
($out_struct:path, $query:literal, $($args:expr),*$(,)?) => (#[allow(dead_code)] {
#[macro_use]
mod _macro_result {
$crate::sqlx_macros::query_as!($out_struct, $query, $($args),*);
$crate::sqlx_macros::expand_query!(record = $out_struct, source = $query, args = [$($args),*]);
}
macro_result!($($args),*)
})
@@ -347,17 +345,17 @@ macro_rules! query_as (
#[macro_export]
#[cfg_attr(docsrs, doc(cfg(feature = "macros")))]
macro_rules! query_file_as (
($out_struct:path, $query:literal) => (#[allow(dead_code)] {
($out_struct:path, $path:literal) => (#[allow(dead_code)] {
#[macro_use]
mod _macro_result {
$crate::sqlx_macros::query_file_as!($out_struct, $query);
$crate::sqlx_macros::expand_query!(record = $out_struct, source_file = $path);
}
macro_result!()
});
($out_struct:path, $query:literal, $($args:tt),*$(,)?) => (#[allow(dead_code)] {
($out_struct:path, $path:literal, $($args:tt),*$(,)?) => (#[allow(dead_code)] {
#[macro_use]
mod _macro_result {
$crate::sqlx_macros::query_file_as!($out_struct, $query, $($args),*);
$crate::sqlx_macros::expand_query!(record = $out_struct, source_file = $path, args = [$($args),*]);
}
macro_result!($($args),*)
})
@@ -371,7 +369,7 @@ macro_rules! query_as_unchecked (
($out_struct:path, $query:literal) => (#[allow(dead_code)] {
#[macro_use]
mod _macro_result {
$crate::sqlx_macros::query_as_unchecked!($out_struct, $query);
$crate::sqlx_macros::expand_query!(record = $out_struct, source = $query, checked = false);
}
macro_result!()
});
@@ -379,7 +377,7 @@ macro_rules! query_as_unchecked (
($out_struct:path, $query:literal, $($args:expr),*$(,)?) => (#[allow(dead_code)] {
#[macro_use]
mod _macro_result {
$crate::sqlx_macros::query_as_unchecked!($out_struct, $query, $($args),*);
$crate::sqlx_macros::expand_query!(record = $out_struct, source = $query, args = [$($args),*], checked = false);
}
macro_result!($($args),*)
})
@@ -391,18 +389,18 @@ macro_rules! query_file_as_unchecked (
#[macro_export]
#[cfg_attr(docsrs, doc(cfg(feature = "macros")))]
macro_rules! query_file_as_unchecked (
($out_struct:path, $query:literal) => (#[allow(dead_code)] {
($out_struct:path, $path:literal) => (#[allow(dead_code)] {
#[macro_use]
mod _macro_result {
$crate::sqlx_macros::query_file_as_unchecked!($out_struct, $query);
$crate::sqlx_macros::expand_query!(record = $out_struct, source_file = $path, checked = false);
}
macro_result!()
});
($out_struct:path, $query:literal, $($args:tt),*$(,)?) => (#[allow(dead_code)] {
($out_struct:path, $path:literal, $($args:tt),*$(,)?) => (#[allow(dead_code)] {
#[macro_use]
mod _macro_result {
$crate::sqlx_macros::query_file_as_unchecked!($out_struct, $query, $($args),*);
$crate::sqlx_macros::expand_query!(record = $out_struct, source_file = $path, args = [$($args),*], checked = false);
}
macro_result!($($args),*)
})