feat(macros): implement query_scalar!() and variants

Signed-off-by: Austin Bonander <austin@launchbadge.com>
This commit is contained in:
Austin Bonander 2020-06-23 20:51:39 -07:00 committed by Ryan Leckey
parent b0c430ed18
commit 7c32928ebc
No known key found for this signature in database
GPG Key ID: F8AA68C235AB08C9
8 changed files with 396 additions and 64 deletions

View File

@ -39,5 +39,14 @@ macro_rules! impl_into_arguments_for_arguments {
}; };
} }
/// used by the query macros to prevent supernumerary `.bind()` calls
pub struct ImmutableArguments<'q, DB: HasArguments<'q>>(pub <DB as HasArguments<'q>>::Arguments);
impl<'q, DB: HasArguments<'q>> IntoArguments<'q, DB> for ImmutableArguments<'q, DB> {
fn into_arguments(self) -> <DB as HasArguments<'q>>::Arguments {
self.0
}
}
// TODO: Impl `IntoArguments` for &[&dyn Encode] // TODO: Impl `IntoArguments` for &[&dyn Encode]
// TODO: Impl `IntoArguments` for (impl Encode, ...) x16 // TODO: Impl `IntoArguments` for (impl Encode, ...) x16

View File

@ -27,6 +27,7 @@ enum QuerySrc {
pub enum RecordType { pub enum RecordType {
Given(Type), Given(Type),
Scalar,
Generated, Generated,
} }
@ -62,7 +63,21 @@ impl Parse for QueryMacroInput {
let exprs = input.parse::<ExprArray>()?; let exprs = input.parse::<ExprArray>()?;
args = Some(exprs.elems.into_iter().collect()) args = Some(exprs.elems.into_iter().collect())
} else if key == "record" { } else if key == "record" {
if !matches!(record_type, RecordType::Generated) {
return Err(input.error("colliding `scalar` or `record` key"));
}
record_type = RecordType::Given(input.parse()?); record_type = RecordType::Given(input.parse()?);
} else if key == "scalar" {
if !matches!(record_type, RecordType::Generated) {
return Err(input.error("colliding `scalar` or `record` key"));
}
// we currently expect only `scalar = _`
// a `query_as_scalar!()` variant seems less useful than just overriding the type
// of the column in SQL
input.parse::<syn::Token![_]>()?;
record_type = RecordType::Scalar;
} else if key == "checked" { } else if key == "checked" {
let lit_bool = input.parse::<LitBool>()?; let lit_bool = input.parse::<LitBool>()?;
checked = lit_bool.value; checked = lit_bool.value;

View File

@ -1,4 +1,3 @@
use std::borrow::Cow;
use std::env; use std::env;
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use std::process::Command; use std::process::Command;
@ -231,11 +230,9 @@ where
if let Some(num) = num_parameters { if let Some(num) = num_parameters {
if num != input.arg_exprs.len() { if num != input.arg_exprs.len() {
return Err(syn::Error::new( return Err(
Span::call_site(), format!("expected {} parameters, got {}", num, input.arg_exprs.len()).into(),
format!("expected {} parameters, got {}", num, input.arg_exprs.len()), );
)
.into());
} }
} }
@ -256,10 +253,10 @@ where
sqlx::query_with::<#db_path, _>(#sql, #query_args) sqlx::query_with::<#db_path, _>(#sql, #query_args)
} }
} else { } else {
let columns = output::columns_to_rust::<DB>(&data.describe)?; match input.record_type {
let (out_ty, mut record_tokens) = match input.record_type {
RecordType::Generated => { RecordType::Generated => {
let columns = output::columns_to_rust::<DB>(&data.describe)?;
let record_name: Type = syn::parse_str("Record").unwrap(); let record_name: Type = syn::parse_str("Record").unwrap();
for rust_col in &columns { for rust_col in &columns {
@ -278,26 +275,31 @@ where
}| quote!(#ident: #type_,), }| quote!(#ident: #type_,),
); );
let record_tokens = quote! { let mut record_tokens = quote! {
#[derive(Debug)] #[derive(Debug)]
struct #record_name { struct #record_name {
#(#record_fields)* #(#record_fields)*
} }
}; };
(Cow::Owned(record_name), record_tokens) record_tokens.extend(output::quote_query_as::<DB>(
&input,
&record_name,
&query_args,
&columns,
));
record_tokens
} }
RecordType::Given(ref out_ty) => (Cow::Borrowed(out_ty), quote!()), RecordType::Given(ref out_ty) => {
}; let columns = output::columns_to_rust::<DB>(&data.describe)?;
record_tokens.extend(output::quote_query_as::<DB>( output::quote_query_as::<DB>(&input, out_ty, &query_args, &columns)
&input, }
&out_ty, RecordType::Scalar => {
&query_args, output::quote_query_scalar::<DB>(&input, &query_args, &data.describe)?
&columns, }
)); }
record_tokens
}; };
let ret_tokens = quote! { let ret_tokens = quote! {

View File

@ -75,51 +75,48 @@ impl Display for DisplayColumn<'_> {
} }
pub fn columns_to_rust<DB: DatabaseExt>(describe: &Describe<DB>) -> crate::Result<Vec<RustColumn>> { pub fn columns_to_rust<DB: DatabaseExt>(describe: &Describe<DB>) -> crate::Result<Vec<RustColumn>> {
describe (0..describe.columns().len())
.columns() .map(|i| column_to_rust(describe, i))
.iter()
.enumerate()
.map(|(i, column)| -> crate::Result<_> {
// add raw prefix to all identifiers
let decl = ColumnDecl::parse(&column.name())
.map_err(|e| format!("column name {:?} is invalid: {}", column.name(), e))?;
let ColumnOverride { nullability, type_ } = decl.r#override;
let nullable = match nullability {
ColumnNullabilityOverride::NonNull => false,
ColumnNullabilityOverride::Nullable => true,
ColumnNullabilityOverride::None => describe.nullable(i).unwrap_or(true),
};
let type_ = match (type_, nullable) {
(ColumnTypeOverride::Exact(type_), false) => {
ColumnType::Exact(type_.to_token_stream())
}
(ColumnTypeOverride::Exact(type_), true) => {
ColumnType::Exact(quote! { Option<#type_> })
}
(ColumnTypeOverride::Wildcard, false) => ColumnType::Wildcard,
(ColumnTypeOverride::Wildcard, true) => ColumnType::OptWildcard,
(ColumnTypeOverride::None, _) => {
let type_ = get_column_type::<DB>(i, column);
if !nullable {
ColumnType::Exact(type_)
} else {
ColumnType::Exact(quote! { Option<#type_> })
}
}
};
Ok(RustColumn {
ident: decl.ident,
type_,
})
})
.collect::<crate::Result<Vec<_>>>() .collect::<crate::Result<Vec<_>>>()
} }
fn column_to_rust<DB: DatabaseExt>(describe: &Describe<DB>, i: usize) -> crate::Result<RustColumn> {
let column = &describe.columns()[i];
// add raw prefix to all identifiers
let decl = ColumnDecl::parse(&column.name())
.map_err(|e| format!("column name {:?} is invalid: {}", column.name(), e))?;
let ColumnOverride { nullability, type_ } = decl.r#override;
let nullable = match nullability {
ColumnNullabilityOverride::NonNull => false,
ColumnNullabilityOverride::Nullable => true,
ColumnNullabilityOverride::None => describe.nullable(i).unwrap_or(true),
};
let type_ = match (type_, nullable) {
(ColumnTypeOverride::Exact(type_), false) => ColumnType::Exact(type_.to_token_stream()),
(ColumnTypeOverride::Exact(type_), true) => ColumnType::Exact(quote! { Option<#type_> }),
(ColumnTypeOverride::Wildcard, false) => ColumnType::Wildcard,
(ColumnTypeOverride::Wildcard, true) => ColumnType::OptWildcard,
(ColumnTypeOverride::None, _) => {
let type_ = get_column_type::<DB>(i, column);
if !nullable {
ColumnType::Exact(type_)
} else {
ColumnType::Exact(quote! { Option<#type_> })
}
}
};
Ok(RustColumn {
ident: decl.ident,
type_,
})
}
pub fn quote_query_as<DB: DatabaseExt>( pub fn quote_query_as<DB: DatabaseExt>(
input: &QueryMacroInput, input: &QueryMacroInput,
out_ty: &Type, out_ty: &Type,
@ -171,6 +168,43 @@ pub fn quote_query_as<DB: DatabaseExt>(
} }
} }
pub fn quote_query_scalar<DB: DatabaseExt>(
input: &QueryMacroInput,
bind_args: &Ident,
describe: &Describe<DB>,
) -> crate::Result<TokenStream> {
let columns = describe.columns();
if columns.len() != 1 {
return Err(syn::Error::new(
input.src_span,
format!("expected exactly 1 column, got {}", columns.len()),
)
.into());
}
// attempt to parse a column override, otherwise fall back to the inferred type of the column
let ty = if let Ok(rust_col) = column_to_rust(describe, 0) {
rust_col.type_.to_token_stream()
} else if input.checked {
let ty = get_column_type::<DB>(0, &columns[0]);
if describe.nullable(0).unwrap_or(true) {
quote! { Option<#ty> }
} else {
ty
}
} else {
quote! { _ }
};
let db = DB::db_path();
let query = &input.src;
Ok(quote! {
sqlx::query_scalar_with::<#db, #ty, _>(#query, #bind_args)
})
}
fn get_column_type<DB: DatabaseExt>(i: usize, column: &DB::Column) -> TokenStream { fn get_column_type<DB: DatabaseExt>(i: usize, column: &DB::Column) -> TokenStream {
let type_info = &*column.type_info(); let type_info = &*column.type_info();
@ -268,7 +302,9 @@ fn parse_ident(name: &str) -> crate::Result<Ident> {
// workaround for the following issue (it's semi-fixed but still spits out extra diagnostics) // workaround for the following issue (it's semi-fixed but still spits out extra diagnostics)
// https://github.com/dtolnay/syn/issues/749#issuecomment-575451318 // https://github.com/dtolnay/syn/issues/749#issuecomment-575451318
let is_valid_ident = name.chars().all(|c| c.is_alphanumeric() || c == '_'); let is_valid_ident = !name.is_empty() &&
name.starts_with(|c: char| c.is_alphabetic() || c == '_') &&
name.chars().all(|c| c.is_alphanumeric() || c == '_');
if is_valid_ident { if is_valid_ident {
let ident = String::from("r#") + name; let ident = String::from("r#") + name;

View File

@ -615,6 +615,77 @@ macro_rules! query_file_as_unchecked (
}) })
); );
/// A variant of [query!] which expects a single column from the query and evaluates to an
/// instance of [QueryScalar][crate::query::QueryScalar].
///
/// The name of the column is not required to be a valid Rust identifier, however you can still
/// use the column type override syntax in which case the column name _does_ have to be a valid
/// Rust identifier for the override to parse properly. If the override parse fails the error
/// is silently ignored (we just don't have a reliable way to tell the difference). **If you're
/// getting a different type than expected, please check to see if your override syntax is correct
/// before opening an issue.**
///
/// Wildcard overrides like in [query_as!] are also allowed, in which case the output type
/// is left up to inference.
///
/// See [query!] for more information.
#[macro_export]
#[cfg_attr(docsrs, doc(cfg(feature = "macros")))]
macro_rules! query_scalar (
($query:expr) => (
$crate::sqlx_macros::expand_query!(scalar = _, source = $query)
);
($query:expr, $($args:tt)*) => (
$crate::sqlx_macros::expand_query!(scalar = _, source = $query, args = [$($args)*])
)
);
/// A variant of [query_scalar!] which takes a file path like [query_file!].
#[macro_export]
#[cfg_attr(docsrs, doc(cfg(feature = "macros")))]
macro_rules! query_file_scalar (
($path:literal) => (
$crate::sqlx_macros::expand_query!(scalar = _, source_file = $path)
);
($path:literal, $($args:tt)*) => (
$crate::sqlx_macros::expand_query!(scalar = _, source_file = $path, args = [$($args)*])
)
);
/// A variant of [query_scalar!] which does not typecheck bind parameters and leaves the output type
/// to inference. The query itself is still checked that it is syntactically and semantically
/// valid for the database, that it only produces one column and that the number of bind parameters
/// is correct.
///
/// For this macro variant the name of the column is irrelevant.
#[macro_export]
#[cfg_attr(docsrs, doc(cfg(feature = "macros")))]
macro_rules! query_scalar_unchecked (
($query:expr) => (
$crate::sqlx_macros::expand_query!(scalar = _, source = $query, checked = false)
);
($query:expr, $($args:tt)*) => (
$crate::sqlx_macros::expand_query!(scalar = _, source = $query, args = [$($args)*], checked = false)
)
);
/// A variant of [query_file_scalar!] which does not typecheck bind parameters and leaves the output
/// type to inference. The query itself is still checked that it is syntactically and
/// semantically valid for the database, that it only produces one column and that the number of
/// bind parameters is correct.
///
/// For this macro variant the name of the column is irrelevant.
#[macro_export]
#[cfg_attr(docsrs, doc(cfg(feature = "macros")))]
macro_rules! query_file_scalar_unchecked (
($path:literal) => (
$crate::sqlx_macros::expand_query!(scalar = _, source_file = $path, checked = false)
);
($path:literal, $($args:tt)*) => (
$crate::sqlx_macros::expand_query!(scalar = _, source_file = $path, args = [$($args)*], checked = false)
)
);
/// Embeds migrations into the binary by expanding to a static instance of [Migrator][crate::migrate::Migrator]. /// Embeds migrations into the binary by expanding to a static instance of [Migrator][crate::migrate::Migrator].
/// ///
/// ```rust,ignore /// ```rust,ignore

View File

@ -58,6 +58,72 @@ async fn test_query_as_raw() -> anyhow::Result<()> {
Ok(()) Ok(())
} }
#[sqlx_macros::test]
async fn test_query_scalar() -> anyhow::Result<()> {
let mut conn = new::<MySql>().await?;
let id = sqlx::query_scalar!("select 1").fetch_one(&mut conn).await?;
// MySQL tells us `LONG LONG` while MariaDB just `LONG`
assert_eq!(id, 1);
// invalid column names are ignored
let id = sqlx::query_scalar!(r#"select 1 as `&foo`"#)
.fetch_one(&mut conn)
.await?;
assert_eq!(id, 1);
let id = sqlx::query_scalar!(r#"select 1 as `foo!`"#)
.fetch_one(&mut conn)
.await?;
assert_eq!(id, 1);
let id = sqlx::query_scalar!(r#"select 1 as `foo?`"#)
.fetch_one(&mut conn)
.await?;
assert_eq!(id, Some(1));
let id = sqlx::query_scalar!(r#"select 1 as `foo: MyInt`"#)
.fetch_one(&mut conn)
.await?;
assert_eq!(id, MyInt(1));
let id = sqlx::query_scalar!(r#"select 1 as `foo?: MyInt`"#)
.fetch_one(&mut conn)
.await?;
assert_eq!(id, Some(MyInt(1)));
let id = sqlx::query_scalar!(r#"select 1 as `foo!: MyInt`"#)
.fetch_one(&mut conn)
.await?;
assert_eq!(id, MyInt(1));
let id: MyInt = sqlx::query_scalar!(r#"select 1 as `foo: _`"#)
.fetch_one(&mut conn)
.await?;
assert_eq!(id, MyInt(1));
let id: MyInt = sqlx::query_scalar!(r#"select 1 as `foo?: _`"#)
.fetch_one(&mut conn)
.await?
// don't hint that it should be `Option<MyInt>`
.unwrap();
assert_eq!(id, MyInt(1));
let id: MyInt = sqlx::query_scalar!(r#"select 1 as `foo!: _`"#)
.fetch_one(&mut conn)
.await?;
assert_eq!(id, MyInt(1));
Ok(())
}
#[sqlx_macros::test] #[sqlx_macros::test]
async fn test_query_as_bool() -> anyhow::Result<()> { async fn test_query_as_bool() -> anyhow::Result<()> {
let mut conn = new::<MySql>().await?; let mut conn = new::<MySql>().await?;

View File

@ -174,6 +174,74 @@ async fn test_query_file_as() -> anyhow::Result<()> {
Ok(()) Ok(())
} }
#[sqlx_macros::test]
async fn test_query_scalar() -> anyhow::Result<()> {
let mut conn = new::<Postgres>().await?;
let id = sqlx::query_scalar!("select 1").fetch_one(&mut conn).await?;
// nullability inference can't handle expressions
assert_eq!(id, Some(1i32));
// invalid column names are ignored
let id = sqlx::query_scalar!(r#"select 1 as "&foo""#)
.fetch_one(&mut conn)
.await?;
assert_eq!(id, Some(1i32));
let id = sqlx::query_scalar!(r#"select 1 as "foo!""#)
.fetch_one(&mut conn)
.await?;
assert_eq!(id, 1i32);
let id = sqlx::query_scalar!(r#"select 1 as "foo?""#)
.fetch_one(&mut conn)
.await?;
assert_eq!(id, Some(1i32));
let id = sqlx::query_scalar!(r#"select 1 as "foo: MyInt4""#)
.fetch_one(&mut conn)
.await?;
assert_eq!(id, Some(MyInt4(1i32)));
let id = sqlx::query_scalar!(r#"select 1 as "foo?: MyInt4""#)
.fetch_one(&mut conn)
.await?;
assert_eq!(id, Some(MyInt4(1i32)));
let id = sqlx::query_scalar!(r#"select 1 as "foo!: MyInt4""#)
.fetch_one(&mut conn)
.await?;
assert_eq!(id, MyInt4(1i32));
let id: MyInt4 = sqlx::query_scalar!(r#"select 1 as "foo: _""#)
.fetch_one(&mut conn)
.await?
// don't hint that it should be `Option<MyInt4>`
.unwrap();
assert_eq!(id, MyInt4(1i32));
let id: MyInt4 = sqlx::query_scalar!(r#"select 1 as "foo?: _""#)
.fetch_one(&mut conn)
.await?
// don't hint that it should be `Option<MyInt4>`
.unwrap();
assert_eq!(id, MyInt4(1i32));
let id: MyInt4 = sqlx::query_scalar!(r#"select 1 as "foo!: _""#)
.fetch_one(&mut conn)
.await?;
assert_eq!(id, MyInt4(1i32));
Ok(())
}
#[sqlx_macros::test] #[sqlx_macros::test]
async fn query_by_string() -> anyhow::Result<()> { async fn query_by_string() -> anyhow::Result<()> {
let mut conn = new::<Postgres>().await?; let mut conn = new::<Postgres>().await?;

View File

@ -110,6 +110,71 @@ async fn test_query_as_raw() -> anyhow::Result<()> {
Ok(()) Ok(())
} }
#[sqlx_macros::test]
async fn test_query_scalar() -> anyhow::Result<()> {
let mut conn = new::<Sqlite>().await?;
let id = sqlx::query_scalar!("select 1").fetch_one(&mut conn).await?;
assert_eq!(id, 1i32);
// invalid column names are ignored
let id = sqlx::query_scalar!(r#"select 1 as "&foo""#)
.fetch_one(&mut conn)
.await?;
assert_eq!(id, 1i32);
let id = sqlx::query_scalar!(r#"select 1 as "foo!""#)
.fetch_one(&mut conn)
.await?;
assert_eq!(id, 1i32);
let id = sqlx::query_scalar!(r#"select 1 as "foo?""#)
.fetch_one(&mut conn)
.await?;
assert_eq!(id, Some(1i32));
let id = sqlx::query_scalar!(r#"select 1 as "foo: MyInt""#)
.fetch_one(&mut conn)
.await?;
assert_eq!(id, MyInt(1i64));
let id = sqlx::query_scalar!(r#"select 1 as "foo?: MyInt""#)
.fetch_one(&mut conn)
.await?;
assert_eq!(id, Some(MyInt(1i64)));
let id = sqlx::query_scalar!(r#"select 1 as "foo!: MyInt""#)
.fetch_one(&mut conn)
.await?;
assert_eq!(id, MyInt(1i64));
let id: MyInt = sqlx::query_scalar!(r#"select 1 as "foo: _""#)
.fetch_one(&mut conn)
.await?;
assert_eq!(id, MyInt(1i64));
let id: MyInt = sqlx::query_scalar!(r#"select 1 as "foo?: _""#)
.fetch_one(&mut conn)
.await?
// don't hint that it should be `Option<MyInt>`
.unwrap();
assert_eq!(id, MyInt(1i64));
let id: MyInt = sqlx::query_scalar!(r#"select 1 as "foo!: _""#)
.fetch_one(&mut conn)
.await?;
assert_eq!(id, MyInt(1i64));
Ok(())
}
#[sqlx_macros::test] #[sqlx_macros::test]
async fn macro_select_from_view() -> anyhow::Result<()> { async fn macro_select_from_view() -> anyhow::Result<()> {
let mut conn = new::<Sqlite>().await?; let mut conn = new::<Sqlite>().await?;