From 800af574c59ec47cbbb293a5660be589792ff2f4 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Mon, 27 Jan 2020 19:02:46 -0800 Subject: [PATCH] query_macros: allow `Option<&str>` to be passed in place of `String` closes #93 --- sqlx-core/src/encode.rs | 6 +- sqlx-macros/src/database/postgres.rs | 4 +- sqlx-macros/src/query_macros/args.rs | 79 ++++++++------ sqlx-macros/src/query_macros/input.rs | 5 +- sqlx-macros/src/query_macros/mod.rs | 13 +-- sqlx-macros/src/query_macros/query.rs | 19 ---- src/lib.rs | 2 +- src/macros.rs | 16 +-- src/ty_cons.rs | 59 ----------- src/ty_match.rs | 122 ++++++++++++++++++++++ tests/postgres-macros.rs | 16 ++- tests/ui/postgres/wrong_param_type.rs | 14 +++ tests/ui/postgres/wrong_param_type.stderr | 55 ++++++++++ 13 files changed, 266 insertions(+), 144 deletions(-) delete mode 100644 src/ty_cons.rs create mode 100644 src/ty_match.rs create mode 100644 tests/ui/postgres/wrong_param_type.rs create mode 100644 tests/ui/postgres/wrong_param_type.stderr diff --git a/sqlx-core/src/encode.rs b/sqlx-core/src/encode.rs index b7d38873..3f856dfd 100644 --- a/sqlx-core/src/encode.rs +++ b/sqlx-core/src/encode.rs @@ -75,10 +75,6 @@ where } fn size_hint(&self) -> usize { - if self.is_some() { - (*self).size_hint() - } else { - 0 - } + self.as_ref().map_or(0, Encode::size_hint) } } diff --git a/sqlx-macros/src/database/postgres.rs b/sqlx-macros/src/database/postgres.rs index c84da77c..8a221be1 100644 --- a/sqlx-macros/src/database/postgres.rs +++ b/sqlx-macros/src/database/postgres.rs @@ -1,7 +1,7 @@ impl_database_ext! { sqlx::postgres::Postgres { bool, - String, + String | &str, i16, i32, i64, @@ -9,7 +9,7 @@ impl_database_ext! { f64, // BYTEA - Vec, + Vec | &[u8], #[cfg(feature = "uuid")] sqlx::types::Uuid, diff --git a/sqlx-macros/src/query_macros/args.rs b/sqlx-macros/src/query_macros/args.rs index 438b777b..ce08efc0 100644 --- a/sqlx-macros/src/query_macros/args.rs +++ b/sqlx-macros/src/query_macros/args.rs @@ -1,79 +1,96 @@ use proc_macro2::TokenStream; -use quote::{quote, quote_spanned, ToTokens}; +use syn::spanned::Spanned; use syn::Expr; +use quote::{quote, quote_spanned, ToTokens}; use sqlx::describe::Describe; use crate::database::{DatabaseExt, ParamChecking}; use crate::query_macros::QueryMacroInput; +/// Returns a tokenstream which typechecks the arguments passed to the macro +/// and binds them to `DB::Arguments` with the ident `query_args`. pub fn quote_args( input: &QueryMacroInput, describe: &Describe, ) -> crate::Result { + let db_path = DB::db_path(); + if input.arg_names.is_empty() { return Ok(quote! { - let args = (); + let query_args = <#db_path as sqlx::Database>::Arguments::default(); }); } + let arg_name = &input.arg_names; + let args_check = if DB::PARAM_CHECKING == ParamChecking::Strong { - let param_types = describe + describe .param_types .iter() - .zip(&*input.arg_exprs) + .zip(input.arg_names.iter().zip(&input.arg_exprs)) .enumerate() - .map(|(i, (type_, expr))| { - get_type_override(expr) + .map(|(i, (param_ty, (name, expr)))| -> crate::Result<_>{ + let param_ty = get_type_override(expr) .or_else(|| { Some( - DB::param_type_for_id(type_)? + DB::param_type_for_id(param_ty)? .parse::() .unwrap(), ) }) .ok_or_else(|| { - if let Some(feature_gate) = ::get_feature_gate(&type_) { + if let Some(feature_gate) = ::get_feature_gate(¶m_ty) { format!( "optional feature `{}` required for type {} of param #{}", feature_gate, - type_, + param_ty, i + 1, ) - .into() } else { - format!("unsupported type {} for param #{}", type_, i + 1).into() + format!("unsupported type {} for param #{}", param_ty, i + 1) } - }) - }) - .collect::>>()?; + })?; - let args_ty_cons = input.arg_names.iter().enumerate().map(|(i, expr)| { - // required or `quote!()` emits it as `Nusize` - let i = syn::Index::from(i); - quote_spanned!( expr.span() => { - use sqlx::ty_cons::TyConsExt as _; - sqlx::ty_cons::TyCons::new(&args.#i).ty_cons() - }) - }); + Ok(quote_spanned!(expr.span() => + // this shouldn't actually run + if false { + use sqlx::ty_match::{WrapSameExt as _, MatchBorrowExt as _}; - // we want to make sure it doesn't run - quote! { - if false { - let _: (#(#param_types),*,) = (#(#args_ty_cons),*,); - } - } + // evaluate the expression only once in case it contains moves + let _expr = sqlx::ty_match::dupe_value(&$#name); + + // if `_expr` is `Option`, get `Option<$ty>`, otherwise `$ty` + let ty_check = sqlx::ty_match::WrapSame::<#param_ty, _>::new(&_expr).wrap_same(); + // if `_expr` is `&str`, convert `String` to `&str` + let (mut ty_check, match_borrow) = sqlx::ty_match::MatchBorrow::new(ty_check, &_expr); + + ty_check = match_borrow.match_borrow(); + + // this causes move-analysis to effectively ignore this block + panic!(); + } + )) + }) + .collect::>()? } else { // all we can do is check arity which we did in `QueryMacroInput::describe_validate()` TokenStream::new() }; - let args = input.arg_names.iter(); + let args_count = input.arg_names.len(); Ok(quote! { - // emit as a tuple first so each expression is only evaluated once - let args = (#(&$#args),*,); #args_check + + // bind as a local expression, by-ref + #(let #arg_name = &$#arg_name;)* + let mut query_args = <#db_path as sqlx::Database>::Arguments::default(); + query_args.reserve( + #args_count, + 0 #(+ sqlx::encode::Encode::<#db_path>::size_hint(#arg_name))* + ); + #(query_args.add(#arg_name);)* }) } diff --git a/sqlx-macros/src/query_macros/input.rs b/sqlx-macros/src/query_macros/input.rs index 445152e6..0969baee 100644 --- a/sqlx-macros/src/query_macros/input.rs +++ b/sqlx-macros/src/query_macros/input.rs @@ -1,7 +1,7 @@ use std::env; use proc_macro2::{Ident, Span}; -use sqlx::runtime::fs; +use quote::{format_ident, ToTokens}; use syn::parse::{Parse, ParseStream}; use syn::punctuated::Punctuated; use syn::spanned::Spanned; @@ -9,10 +9,9 @@ use syn::token::Group; use syn::{Expr, ExprLit, ExprPath, Lit}; use syn::{ExprGroup, Token}; -use quote::{format_ident, ToTokens}; - use sqlx::connection::Connection; use sqlx::describe::Describe; +use sqlx::runtime::fs; /// Macro input shared by `query!()` and `query_file!()` pub struct QueryMacroInput { diff --git a/sqlx-macros/src/query_macros/mod.rs b/sqlx-macros/src/query_macros/mod.rs index aa68ccf0..77ec93c3 100644 --- a/sqlx-macros/src/query_macros/mod.rs +++ b/sqlx-macros/src/query_macros/mod.rs @@ -46,7 +46,6 @@ where } let args_tokens = args::quote_args(&input.query_input, &describe)?; - let arg_names = &input.query_input.arg_names; let query_args = format_ident!("query_args"); @@ -58,10 +57,7 @@ where &columns, ); - let db_path = ::db_path(); - let args_count = arg_names.len(); - let arg_indices = (0..args_count).map(|i| syn::Index::from(i)); - let arg_indices_2 = arg_indices.clone(); + let arg_names = &input.query_input.arg_names; Ok(quote! { macro_rules! macro_result { @@ -70,13 +66,6 @@ where #args_tokens - let mut #query_args = <#db_path as sqlx::Database>::Arguments::default(); - #query_args.reserve( - #args_count, - 0 #(+ sqlx::encode::Encode::<#db_path>::size_hint(args.#arg_indices))* - ); - #(#query_args.add(args.#arg_indices_2);)* - #output }} } diff --git a/sqlx-macros/src/query_macros/query.rs b/sqlx-macros/src/query_macros/query.rs index d675b67b..105337cd 100644 --- a/sqlx-macros/src/query_macros/query.rs +++ b/sqlx-macros/src/query_macros/query.rs @@ -26,9 +26,6 @@ where let args = args::quote_args(&input, &describe)?; let arg_names = &input.arg_names; - let args_count = arg_names.len(); - let arg_indices = (0..args_count).map(|i| syn::Index::from(i)); - let arg_indices_2 = arg_indices.clone(); let db_path = ::db_path(); if describe.result_columns.is_empty() { @@ -39,14 +36,6 @@ where #args - let mut query_args = <#db_path as sqlx::Database>::Arguments::default(); - query_args.reserve( - #args_count, - 0 #(+ sqlx::encode::Encode::<#db_path>::size_hint(args.#arg_indices))* - ); - - #(query_args.add(args.#arg_indices_2);)* - sqlx::query::<#db_path>(#sql).bind_all(query_args) } }} @@ -85,14 +74,6 @@ where #args - let mut #query_args = <#db_path as sqlx::Database>::Arguments::default(); - #query_args.reserve( - #args_count, - 0 #(+ sqlx::encode::Encode::<#db_path>::size_hint(args.#arg_indices))* - ); - - #(#query_args.add(args.#arg_indices_2);)* - #output } }} diff --git a/src/lib.rs b/src/lib.rs index 669b48a2..1f78550f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -40,7 +40,7 @@ mod macros; // macro support #[cfg(feature = "macros")] #[doc(hidden)] -pub mod ty_cons; +pub mod ty_match; #[cfg(feature = "macros")] #[doc(hidden)] diff --git a/src/macros.rs b/src/macros.rs index 15563551..f0f092ea 100644 --- a/src/macros.rs +++ b/src/macros.rs @@ -97,9 +97,9 @@ macro_rules! query ( ($query:literal, $($args:expr),*$(,)?) => ({ #[macro_use] mod _macro_result { - $crate::sqlx_macros::query!($query, $($args),*); + $crate::sqlx_macros::query!($query, $($args)*); } - macro_result!($($args),*) + macro_result!($($args)*) }) ); @@ -158,9 +158,9 @@ macro_rules! query_file ( ($query:literal, $($args:expr),*$(,)?) => (#[allow(dead_code)]{ #[macro_use] mod _macro_result { - $crate::sqlx_macros::query_file!($query, $($args),*); + $crate::sqlx_macros::query_file!($query, $($args)*); } - macro_result!($($args),*) + macro_result!($($args)*) }) ); @@ -224,9 +224,9 @@ macro_rules! query_as ( ($out_struct:path, $query:literal, $($args:expr),*$(,)?) => (#[allow(dead_code)] { #[macro_use] mod _macro_result { - $crate::sqlx_macros::query_as!($out_struct, $query, $($args),*); + $crate::sqlx_macros::query_as!($out_struct, $query, $($args)*); } - macro_result!($($args),*) + macro_result!($($args)*) }) ); @@ -275,8 +275,8 @@ macro_rules! query_file_as ( ($out_struct:path, $query:literal, $($args:expr),*$(,)?) => (#[allow(dead_code)] { #[macro_use] mod _macro_result { - $crate::sqlx_macros::query_file_as!($out_struct, $query, $($args),*); + $crate::sqlx_macros::query_file_as!($out_struct, $query, $($args)*); } - macro_result!($($args),*) + macro_result!($($args)*) }) ); diff --git a/src/ty_cons.rs b/src/ty_cons.rs deleted file mode 100644 index c6e06e93..00000000 --- a/src/ty_cons.rs +++ /dev/null @@ -1,59 +0,0 @@ -use std::marker::PhantomData; - -// These types allow the `sqlx_macros::query_[as]!()` macros to polymorphically compare a -// given parameter's type to an expected parameter type even if the former -// is behind a reference or in `Option` - -#[doc(hidden)] -pub struct TyCons(PhantomData); - -impl TyCons { - pub fn new(_t: &T) -> TyCons { - TyCons(PhantomData) - } -} - -#[doc(hidden)] -pub trait TyConsExt: Sized { - type Cons; - fn ty_cons(self) -> Self::Cons { - panic!("should not be run, only for type resolution") - } -} - -impl TyCons> { - pub fn ty_cons(self) -> T { - panic!("should not be run, only for type resolution") - } -} - -// no overlap with the following impls because of the `: Sized` bound -impl TyConsExt for TyCons<&'_ T> { - type Cons = T; -} - -impl TyConsExt for TyCons<&'_ str> { - type Cons = String; -} - -impl TyConsExt for TyCons<&'_ [T]> { - type Cons = Vec; -} - -impl TyConsExt for TyCons> { - type Cons = T; -} - -impl TyConsExt for &'_ TyCons { - type Cons = T; -} - -#[test] -fn test_tycons_ext() { - if false { - let _: u64 = TyCons::new(&Some(5u64)).ty_cons(); - let _: u64 = TyCons::new(&Some(&5u64)).ty_cons(); - let _: u64 = TyCons::new(&&5u64).ty_cons(); - let _: u64 = TyCons::new(&5u64).ty_cons(); - } -} diff --git a/src/ty_match.rs b/src/ty_match.rs new file mode 100644 index 00000000..ebaae86e --- /dev/null +++ b/src/ty_match.rs @@ -0,0 +1,122 @@ +use std::marker::PhantomData; + +// These types allow the `query!()` and friends to compare a given parameter's type to +// an expected parameter type even if the former is behind a reference or in `Option`. + +// For query parameters, Postgres gives us a single type ID which we convert to an "expected" or +// preferred Rust type, but there can actually be several types that are compatible for a given type +// in input position. E.g. for an expected parameter of `String`, we want to accept `String`, +// `Option`, `&str` and `Option<&str>`. And for the best compiler errors we don't just +// want an `IsCompatible` trait (at least not without `#[on_unimplemented]` which is unstable +// for the foreseeable future). + +// We can do this by using autoref (for method calls, the compiler adds reference ops until +// it finds a matching impl) with impls that technically don't overlap as a hacky form of +// specialization (but this works only if all types are statically known, i.e. we're not in a +// generic context; this should suit 99% of use cases for the macros). + +pub fn same_type(_1: &T, _2: &T) {} + +pub struct WrapSame(PhantomData, PhantomData); + +impl WrapSame { + pub fn new(_arg: &U) -> Self { + WrapSame(PhantomData, PhantomData) + } +} + +pub trait WrapSameExt: Sized { + type Wrapped; + + fn wrap_same(self) -> Self::Wrapped { + panic!("only for type resolution") + } +} + +impl WrapSameExt for WrapSame> { + type Wrapped = Option; +} + +impl WrapSameExt for &'_ WrapSame { + type Wrapped = T; +} + +pub struct MatchBorrow(PhantomData, PhantomData); + +impl MatchBorrow { + pub fn new(t: T, _u: &U) -> (T, Self) { + (t, MatchBorrow(PhantomData, PhantomData)) + } +} + +pub trait MatchBorrowExt: Sized { + type Matched; + + fn match_borrow(self) -> Self::Matched { + panic!("only for type resolution") + } +} + +impl<'a> MatchBorrowExt for MatchBorrow, Option> { + type Matched = Option<&'a str>; +} + +impl<'a> MatchBorrowExt for MatchBorrow, Option>> { + type Matched = Option<&'a [u8]>; +} + +impl<'a> MatchBorrowExt for MatchBorrow, Option<&'a String>> { + type Matched = Option<&'a str>; +} + +impl<'a> MatchBorrowExt for MatchBorrow, Option<&'a Vec>> { + type Matched = Option<&'a [u8]>; +} + +impl<'a> MatchBorrowExt for MatchBorrow<&'a str, String> { + type Matched = &'a str; +} + +impl<'a> MatchBorrowExt for MatchBorrow<&'a [u8], Vec> { + type Matched = &'a [u8]; +} + +impl MatchBorrowExt for &'_ MatchBorrow { + type Matched = U; +} + +pub fn conjure_value() -> T { + panic!() +} + +pub fn dupe_value(_t: &T) -> T { + panic!() +} + +#[test] +fn test_dupe_value() { + let ref val = (String::new(),); + + if false { + let _: i32 = dupe_value(&0i32); + let _: String = dupe_value(&String::new()); + let _: String = dupe_value(&val.0); + } +} + +#[test] +fn test_wrap_same() { + if false { + let _: i32 = WrapSame::::new(&0i32).wrap_same(); + let _: i32 = WrapSame::::new(&"hello, world!").wrap_same(); + let _: Option = WrapSame::::new(&Some(String::new())).wrap_same(); + } +} + +#[test] +fn test_match_borrow() { + if false { + let (_, match_borrow) = MatchBorrow::new("", &String::new()); + let _: &str = match_borrow.match_borrow(); + } +} diff --git a/tests/postgres-macros.rs b/tests/postgres-macros.rs index 13fdf723..bf18f55d 100644 --- a/tests/postgres-macros.rs +++ b/tests/postgres-macros.rs @@ -54,9 +54,11 @@ struct Account { async fn test_query_as() -> anyhow::Result<()> { let mut conn = connect().await?; + let name: Option<&str> = None; let account = sqlx::query_as!( Account, - "SELECT * from (VALUES (1, null)) accounts(id, name)", + "SELECT * from (VALUES (1, $1)) accounts(id, name)", + name ) .fetch_one(&mut conn) .await?; @@ -114,12 +116,18 @@ async fn query_by_string() -> anyhow::Result<()> { let mut conn = connect().await?; let string = "Hello, world!".to_string(); + let ref tuple = ("Hello, world!".to_string(),); let result = sqlx::query!( "SELECT * from (VALUES('Hello, world!')) strings(string)\ - where string = $1 or string = $2", - string, - string[..] + where string in ($1, $2, $3, $4, $5, $6, $7)", + string, // make sure we don't actually take ownership here + &string[..], + Some(&string), + Some(&string[..]), + Option::::None, + string.clone(), + tuple.0 // make sure we're not trying to move out of a field expression ) .fetch_one(&mut conn) .await?; diff --git a/tests/ui/postgres/wrong_param_type.rs b/tests/ui/postgres/wrong_param_type.rs new file mode 100644 index 00000000..18737f0e --- /dev/null +++ b/tests/ui/postgres/wrong_param_type.rs @@ -0,0 +1,14 @@ +fn main() { + let _query = sqlx::query!("select $1::text", 0i32); + + let _query = sqlx::query!("select $1::text", &0i32); + + let _query = sqlx::query!("select $1::text", Some(0i32)); + + let arg = 0i32; + let _query = sqlx::query!("select $1::text", arg); + + let arg = Some(0i32); + let _query = sqlx::query!("select $1::text", arg); + let _query = sqlx::query!("select $1::text", arg.as_ref()); +} diff --git a/tests/ui/postgres/wrong_param_type.stderr b/tests/ui/postgres/wrong_param_type.stderr new file mode 100644 index 00000000..0166b851 --- /dev/null +++ b/tests/ui/postgres/wrong_param_type.stderr @@ -0,0 +1,55 @@ +error[E0308]: mismatched types + --> $DIR/wrong_param_type.rs:2:50 + | +2 | let _query = sqlx::query!("select $1::text", 0i32); + | ^^^^ expected `&str`, found `i32` + | + = note: this error originates in a macro (in Nightly builds, run with -Z macro-backtrace for more info) + +error[E0308]: mismatched types + --> $DIR/wrong_param_type.rs:4:50 + | +4 | let _query = sqlx::query!("select $1::text", &0i32); + | ^^^^^ expected `str`, found `i32` + | + = note: expected reference `&str` + found reference `&i32` + = note: this error originates in a macro (in Nightly builds, run with -Z macro-backtrace for more info) + +error[E0308]: mismatched types + --> $DIR/wrong_param_type.rs:6:50 + | +6 | let _query = sqlx::query!("select $1::text", Some(0i32)); + | ^^^^^^^^^^ expected `&str`, found `i32` + | + = note: expected enum `std::option::Option<&str>` + found enum `std::option::Option` + = note: this error originates in a macro (in Nightly builds, run with -Z macro-backtrace for more info) + +error[E0308]: mismatched types + --> $DIR/wrong_param_type.rs:9:50 + | +9 | let _query = sqlx::query!("select $1::text", arg); + | ^^^ expected `&str`, found `i32` + | + = note: this error originates in a macro (in Nightly builds, run with -Z macro-backtrace for more info) + +error[E0308]: mismatched types + --> $DIR/wrong_param_type.rs:12:50 + | +12 | let _query = sqlx::query!("select $1::text", arg); + | ^^^ expected `&str`, found `i32` + | + = note: expected enum `std::option::Option<&str>` + found enum `std::option::Option` + = note: this error originates in a macro (in Nightly builds, run with -Z macro-backtrace for more info) + +error[E0308]: mismatched types + --> $DIR/wrong_param_type.rs:13:50 + | +13 | let _query = sqlx::query!("select $1::text", arg.as_ref()); + | ^^^^^^^^^^^^ expected `str`, found `i32` + | + = note: expected enum `std::option::Option<&str>` + found enum `std::option::Option<&i32>` + = note: this error originates in a macro (in Nightly builds, run with -Z macro-backtrace for more info)