query_macros: allow Option<&str> to be passed in place of String

closes #93
This commit is contained in:
Austin Bonander
2020-01-27 19:02:46 -08:00
committed by Ryan Leckey
parent 4163388298
commit 800af574c5
13 changed files with 266 additions and 144 deletions

View File

@@ -1,7 +1,7 @@
impl_database_ext! {
sqlx::postgres::Postgres {
bool,
String,
String | &str,
i16,
i32,
i64,
@@ -9,7 +9,7 @@ impl_database_ext! {
f64,
// BYTEA
Vec<u8>,
Vec<u8> | &[u8],
#[cfg(feature = "uuid")]
sqlx::types::Uuid,

View File

@@ -1,79 +1,96 @@
use proc_macro2::TokenStream;
use quote::{quote, quote_spanned, ToTokens};
use syn::spanned::Spanned;
use syn::Expr;
use quote::{quote, quote_spanned, ToTokens};
use sqlx::describe::Describe;
use crate::database::{DatabaseExt, ParamChecking};
use crate::query_macros::QueryMacroInput;
/// Returns a tokenstream which typechecks the arguments passed to the macro
/// and binds them to `DB::Arguments` with the ident `query_args`.
pub fn quote_args<DB: DatabaseExt>(
input: &QueryMacroInput,
describe: &Describe<DB>,
) -> crate::Result<TokenStream> {
let db_path = DB::db_path();
if input.arg_names.is_empty() {
return Ok(quote! {
let args = ();
let query_args = <#db_path as sqlx::Database>::Arguments::default();
});
}
let arg_name = &input.arg_names;
let args_check = if DB::PARAM_CHECKING == ParamChecking::Strong {
let param_types = describe
describe
.param_types
.iter()
.zip(&*input.arg_exprs)
.zip(input.arg_names.iter().zip(&input.arg_exprs))
.enumerate()
.map(|(i, (type_, expr))| {
get_type_override(expr)
.map(|(i, (param_ty, (name, expr)))| -> crate::Result<_>{
let param_ty = get_type_override(expr)
.or_else(|| {
Some(
DB::param_type_for_id(type_)?
DB::param_type_for_id(param_ty)?
.parse::<proc_macro2::TokenStream>()
.unwrap(),
)
})
.ok_or_else(|| {
if let Some(feature_gate) = <DB as DatabaseExt>::get_feature_gate(&type_) {
if let Some(feature_gate) = <DB as DatabaseExt>::get_feature_gate(&param_ty) {
format!(
"optional feature `{}` required for type {} of param #{}",
feature_gate,
type_,
param_ty,
i + 1,
)
.into()
} else {
format!("unsupported type {} for param #{}", type_, i + 1).into()
format!("unsupported type {} for param #{}", param_ty, i + 1)
}
})
})
.collect::<crate::Result<Vec<_>>>()?;
})?;
let args_ty_cons = input.arg_names.iter().enumerate().map(|(i, expr)| {
// required or `quote!()` emits it as `Nusize`
let i = syn::Index::from(i);
quote_spanned!( expr.span() => {
use sqlx::ty_cons::TyConsExt as _;
sqlx::ty_cons::TyCons::new(&args.#i).ty_cons()
})
});
Ok(quote_spanned!(expr.span() =>
// this shouldn't actually run
if false {
use sqlx::ty_match::{WrapSameExt as _, MatchBorrowExt as _};
// we want to make sure it doesn't run
quote! {
if false {
let _: (#(#param_types),*,) = (#(#args_ty_cons),*,);
}
}
// evaluate the expression only once in case it contains moves
let _expr = sqlx::ty_match::dupe_value(&$#name);
// if `_expr` is `Option<T>`, get `Option<$ty>`, otherwise `$ty`
let ty_check = sqlx::ty_match::WrapSame::<#param_ty, _>::new(&_expr).wrap_same();
// if `_expr` is `&str`, convert `String` to `&str`
let (mut ty_check, match_borrow) = sqlx::ty_match::MatchBorrow::new(ty_check, &_expr);
ty_check = match_borrow.match_borrow();
// this causes move-analysis to effectively ignore this block
panic!();
}
))
})
.collect::<crate::Result<TokenStream>>()?
} else {
// all we can do is check arity which we did in `QueryMacroInput::describe_validate()`
TokenStream::new()
};
let args = input.arg_names.iter();
let args_count = input.arg_names.len();
Ok(quote! {
// emit as a tuple first so each expression is only evaluated once
let args = (#(&$#args),*,);
#args_check
// bind as a local expression, by-ref
#(let #arg_name = &$#arg_name;)*
let mut query_args = <#db_path as sqlx::Database>::Arguments::default();
query_args.reserve(
#args_count,
0 #(+ sqlx::encode::Encode::<#db_path>::size_hint(#arg_name))*
);
#(query_args.add(#arg_name);)*
})
}

View File

@@ -1,7 +1,7 @@
use std::env;
use proc_macro2::{Ident, Span};
use sqlx::runtime::fs;
use quote::{format_ident, ToTokens};
use syn::parse::{Parse, ParseStream};
use syn::punctuated::Punctuated;
use syn::spanned::Spanned;
@@ -9,10 +9,9 @@ use syn::token::Group;
use syn::{Expr, ExprLit, ExprPath, Lit};
use syn::{ExprGroup, Token};
use quote::{format_ident, ToTokens};
use sqlx::connection::Connection;
use sqlx::describe::Describe;
use sqlx::runtime::fs;
/// Macro input shared by `query!()` and `query_file!()`
pub struct QueryMacroInput {

View File

@@ -46,7 +46,6 @@ where
}
let args_tokens = args::quote_args(&input.query_input, &describe)?;
let arg_names = &input.query_input.arg_names;
let query_args = format_ident!("query_args");
@@ -58,10 +57,7 @@ where
&columns,
);
let db_path = <C::Database as DatabaseExt>::db_path();
let args_count = arg_names.len();
let arg_indices = (0..args_count).map(|i| syn::Index::from(i));
let arg_indices_2 = arg_indices.clone();
let arg_names = &input.query_input.arg_names;
Ok(quote! {
macro_rules! macro_result {
@@ -70,13 +66,6 @@ where
#args_tokens
let mut #query_args = <#db_path as sqlx::Database>::Arguments::default();
#query_args.reserve(
#args_count,
0 #(+ sqlx::encode::Encode::<#db_path>::size_hint(args.#arg_indices))*
);
#(#query_args.add(args.#arg_indices_2);)*
#output
}}
}

View File

@@ -26,9 +26,6 @@ where
let args = args::quote_args(&input, &describe)?;
let arg_names = &input.arg_names;
let args_count = arg_names.len();
let arg_indices = (0..args_count).map(|i| syn::Index::from(i));
let arg_indices_2 = arg_indices.clone();
let db_path = <C::Database as DatabaseExt>::db_path();
if describe.result_columns.is_empty() {
@@ -39,14 +36,6 @@ where
#args
let mut query_args = <#db_path as sqlx::Database>::Arguments::default();
query_args.reserve(
#args_count,
0 #(+ sqlx::encode::Encode::<#db_path>::size_hint(args.#arg_indices))*
);
#(query_args.add(args.#arg_indices_2);)*
sqlx::query::<#db_path>(#sql).bind_all(query_args)
}
}}
@@ -85,14 +74,6 @@ where
#args
let mut #query_args = <#db_path as sqlx::Database>::Arguments::default();
#query_args.reserve(
#args_count,
0 #(+ sqlx::encode::Encode::<#db_path>::size_hint(args.#arg_indices))*
);
#(#query_args.add(args.#arg_indices_2);)*
#output
}
}}