feat(macros): implement query_scalar!() and variants

Signed-off-by: Austin Bonander <austin@launchbadge.com>
This commit is contained in:
Austin Bonander 2020-06-23 20:51:39 -07:00 committed by Ryan Leckey
parent b0c430ed18
commit 7c32928ebc
No known key found for this signature in database
GPG Key ID: F8AA68C235AB08C9
8 changed files with 396 additions and 64 deletions

View File

@ -39,5 +39,14 @@ macro_rules! impl_into_arguments_for_arguments {
};
}
/// used by the query macros to prevent supernumerary `.bind()` calls
pub struct ImmutableArguments<'q, DB: HasArguments<'q>>(pub <DB as HasArguments<'q>>::Arguments);
impl<'q, DB: HasArguments<'q>> IntoArguments<'q, DB> for ImmutableArguments<'q, DB> {
fn into_arguments(self) -> <DB as HasArguments<'q>>::Arguments {
self.0
}
}
// TODO: Impl `IntoArguments` for &[&dyn Encode]
// TODO: Impl `IntoArguments` for (impl Encode, ...) x16

View File

@ -27,6 +27,7 @@ enum QuerySrc {
pub enum RecordType {
Given(Type),
Scalar,
Generated,
}
@ -62,7 +63,21 @@ impl Parse for QueryMacroInput {
let exprs = input.parse::<ExprArray>()?;
args = Some(exprs.elems.into_iter().collect())
} else if key == "record" {
if !matches!(record_type, RecordType::Generated) {
return Err(input.error("colliding `scalar` or `record` key"));
}
record_type = RecordType::Given(input.parse()?);
} else if key == "scalar" {
if !matches!(record_type, RecordType::Generated) {
return Err(input.error("colliding `scalar` or `record` key"));
}
// we currently expect only `scalar = _`
// a `query_as_scalar!()` variant seems less useful than just overriding the type
// of the column in SQL
input.parse::<syn::Token![_]>()?;
record_type = RecordType::Scalar;
} else if key == "checked" {
let lit_bool = input.parse::<LitBool>()?;
checked = lit_bool.value;

View File

@ -1,4 +1,3 @@
use std::borrow::Cow;
use std::env;
use std::path::{Path, PathBuf};
use std::process::Command;
@ -231,11 +230,9 @@ where
if let Some(num) = num_parameters {
if num != input.arg_exprs.len() {
return Err(syn::Error::new(
Span::call_site(),
format!("expected {} parameters, got {}", num, input.arg_exprs.len()),
)
.into());
return Err(
format!("expected {} parameters, got {}", num, input.arg_exprs.len()).into(),
);
}
}
@ -256,10 +253,10 @@ where
sqlx::query_with::<#db_path, _>(#sql, #query_args)
}
} else {
let columns = output::columns_to_rust::<DB>(&data.describe)?;
let (out_ty, mut record_tokens) = match input.record_type {
match input.record_type {
RecordType::Generated => {
let columns = output::columns_to_rust::<DB>(&data.describe)?;
let record_name: Type = syn::parse_str("Record").unwrap();
for rust_col in &columns {
@ -278,26 +275,31 @@ where
}| quote!(#ident: #type_,),
);
let record_tokens = quote! {
let mut record_tokens = quote! {
#[derive(Debug)]
struct #record_name {
#(#record_fields)*
}
};
(Cow::Owned(record_name), record_tokens)
record_tokens.extend(output::quote_query_as::<DB>(
&input,
&record_name,
&query_args,
&columns,
));
record_tokens
}
RecordType::Given(ref out_ty) => (Cow::Borrowed(out_ty), quote!()),
};
RecordType::Given(ref out_ty) => {
let columns = output::columns_to_rust::<DB>(&data.describe)?;
record_tokens.extend(output::quote_query_as::<DB>(
&input,
&out_ty,
&query_args,
&columns,
));
record_tokens
output::quote_query_as::<DB>(&input, out_ty, &query_args, &columns)
}
RecordType::Scalar => {
output::quote_query_scalar::<DB>(&input, &query_args, &data.describe)?
}
}
};
let ret_tokens = quote! {

View File

@ -75,51 +75,48 @@ impl Display for DisplayColumn<'_> {
}
pub fn columns_to_rust<DB: DatabaseExt>(describe: &Describe<DB>) -> crate::Result<Vec<RustColumn>> {
describe
.columns()
.iter()
.enumerate()
.map(|(i, column)| -> crate::Result<_> {
// add raw prefix to all identifiers
let decl = ColumnDecl::parse(&column.name())
.map_err(|e| format!("column name {:?} is invalid: {}", column.name(), e))?;
let ColumnOverride { nullability, type_ } = decl.r#override;
let nullable = match nullability {
ColumnNullabilityOverride::NonNull => false,
ColumnNullabilityOverride::Nullable => true,
ColumnNullabilityOverride::None => describe.nullable(i).unwrap_or(true),
};
let type_ = match (type_, nullable) {
(ColumnTypeOverride::Exact(type_), false) => {
ColumnType::Exact(type_.to_token_stream())
}
(ColumnTypeOverride::Exact(type_), true) => {
ColumnType::Exact(quote! { Option<#type_> })
}
(ColumnTypeOverride::Wildcard, false) => ColumnType::Wildcard,
(ColumnTypeOverride::Wildcard, true) => ColumnType::OptWildcard,
(ColumnTypeOverride::None, _) => {
let type_ = get_column_type::<DB>(i, column);
if !nullable {
ColumnType::Exact(type_)
} else {
ColumnType::Exact(quote! { Option<#type_> })
}
}
};
Ok(RustColumn {
ident: decl.ident,
type_,
})
})
(0..describe.columns().len())
.map(|i| column_to_rust(describe, i))
.collect::<crate::Result<Vec<_>>>()
}
fn column_to_rust<DB: DatabaseExt>(describe: &Describe<DB>, i: usize) -> crate::Result<RustColumn> {
let column = &describe.columns()[i];
// add raw prefix to all identifiers
let decl = ColumnDecl::parse(&column.name())
.map_err(|e| format!("column name {:?} is invalid: {}", column.name(), e))?;
let ColumnOverride { nullability, type_ } = decl.r#override;
let nullable = match nullability {
ColumnNullabilityOverride::NonNull => false,
ColumnNullabilityOverride::Nullable => true,
ColumnNullabilityOverride::None => describe.nullable(i).unwrap_or(true),
};
let type_ = match (type_, nullable) {
(ColumnTypeOverride::Exact(type_), false) => ColumnType::Exact(type_.to_token_stream()),
(ColumnTypeOverride::Exact(type_), true) => ColumnType::Exact(quote! { Option<#type_> }),
(ColumnTypeOverride::Wildcard, false) => ColumnType::Wildcard,
(ColumnTypeOverride::Wildcard, true) => ColumnType::OptWildcard,
(ColumnTypeOverride::None, _) => {
let type_ = get_column_type::<DB>(i, column);
if !nullable {
ColumnType::Exact(type_)
} else {
ColumnType::Exact(quote! { Option<#type_> })
}
}
};
Ok(RustColumn {
ident: decl.ident,
type_,
})
}
pub fn quote_query_as<DB: DatabaseExt>(
input: &QueryMacroInput,
out_ty: &Type,
@ -171,6 +168,43 @@ pub fn quote_query_as<DB: DatabaseExt>(
}
}
pub fn quote_query_scalar<DB: DatabaseExt>(
input: &QueryMacroInput,
bind_args: &Ident,
describe: &Describe<DB>,
) -> crate::Result<TokenStream> {
let columns = describe.columns();
if columns.len() != 1 {
return Err(syn::Error::new(
input.src_span,
format!("expected exactly 1 column, got {}", columns.len()),
)
.into());
}
// attempt to parse a column override, otherwise fall back to the inferred type of the column
let ty = if let Ok(rust_col) = column_to_rust(describe, 0) {
rust_col.type_.to_token_stream()
} else if input.checked {
let ty = get_column_type::<DB>(0, &columns[0]);
if describe.nullable(0).unwrap_or(true) {
quote! { Option<#ty> }
} else {
ty
}
} else {
quote! { _ }
};
let db = DB::db_path();
let query = &input.src;
Ok(quote! {
sqlx::query_scalar_with::<#db, #ty, _>(#query, #bind_args)
})
}
fn get_column_type<DB: DatabaseExt>(i: usize, column: &DB::Column) -> TokenStream {
let type_info = &*column.type_info();
@ -268,7 +302,9 @@ fn parse_ident(name: &str) -> crate::Result<Ident> {
// workaround for the following issue (it's semi-fixed but still spits out extra diagnostics)
// https://github.com/dtolnay/syn/issues/749#issuecomment-575451318
let is_valid_ident = name.chars().all(|c| c.is_alphanumeric() || c == '_');
let is_valid_ident = !name.is_empty() &&
name.starts_with(|c: char| c.is_alphabetic() || c == '_') &&
name.chars().all(|c| c.is_alphanumeric() || c == '_');
if is_valid_ident {
let ident = String::from("r#") + name;

View File

@ -615,6 +615,77 @@ macro_rules! query_file_as_unchecked (
})
);
/// A variant of [query!] which expects a single column from the query and evaluates to an
/// instance of [QueryScalar][crate::query::QueryScalar].
///
/// The name of the column is not required to be a valid Rust identifier, however you can still
/// use the column type override syntax in which case the column name _does_ have to be a valid
/// Rust identifier for the override to parse properly. If the override parse fails the error
/// is silently ignored (we just don't have a reliable way to tell the difference). **If you're
/// getting a different type than expected, please check to see if your override syntax is correct
/// before opening an issue.**
///
/// Wildcard overrides like in [query_as!] are also allowed, in which case the output type
/// is left up to inference.
///
/// See [query!] for more information.
#[macro_export]
#[cfg_attr(docsrs, doc(cfg(feature = "macros")))]
macro_rules! query_scalar (
($query:expr) => (
$crate::sqlx_macros::expand_query!(scalar = _, source = $query)
);
($query:expr, $($args:tt)*) => (
$crate::sqlx_macros::expand_query!(scalar = _, source = $query, args = [$($args)*])
)
);
/// A variant of [query_scalar!] which takes a file path like [query_file!].
#[macro_export]
#[cfg_attr(docsrs, doc(cfg(feature = "macros")))]
macro_rules! query_file_scalar (
($path:literal) => (
$crate::sqlx_macros::expand_query!(scalar = _, source_file = $path)
);
($path:literal, $($args:tt)*) => (
$crate::sqlx_macros::expand_query!(scalar = _, source_file = $path, args = [$($args)*])
)
);
/// A variant of [query_scalar!] which does not typecheck bind parameters and leaves the output type
/// to inference. The query itself is still checked that it is syntactically and semantically
/// valid for the database, that it only produces one column and that the number of bind parameters
/// is correct.
///
/// For this macro variant the name of the column is irrelevant.
#[macro_export]
#[cfg_attr(docsrs, doc(cfg(feature = "macros")))]
macro_rules! query_scalar_unchecked (
($query:expr) => (
$crate::sqlx_macros::expand_query!(scalar = _, source = $query, checked = false)
);
($query:expr, $($args:tt)*) => (
$crate::sqlx_macros::expand_query!(scalar = _, source = $query, args = [$($args)*], checked = false)
)
);
/// A variant of [query_file_scalar!] which does not typecheck bind parameters and leaves the output
/// type to inference. The query itself is still checked that it is syntactically and
/// semantically valid for the database, that it only produces one column and that the number of
/// bind parameters is correct.
///
/// For this macro variant the name of the column is irrelevant.
#[macro_export]
#[cfg_attr(docsrs, doc(cfg(feature = "macros")))]
macro_rules! query_file_scalar_unchecked (
($path:literal) => (
$crate::sqlx_macros::expand_query!(scalar = _, source_file = $path, checked = false)
);
($path:literal, $($args:tt)*) => (
$crate::sqlx_macros::expand_query!(scalar = _, source_file = $path, args = [$($args)*], checked = false)
)
);
/// Embeds migrations into the binary by expanding to a static instance of [Migrator][crate::migrate::Migrator].
///
/// ```rust,ignore

View File

@ -58,6 +58,72 @@ async fn test_query_as_raw() -> anyhow::Result<()> {
Ok(())
}
#[sqlx_macros::test]
async fn test_query_scalar() -> anyhow::Result<()> {
let mut conn = new::<MySql>().await?;
let id = sqlx::query_scalar!("select 1").fetch_one(&mut conn).await?;
// MySQL tells us `LONG LONG` while MariaDB just `LONG`
assert_eq!(id, 1);
// invalid column names are ignored
let id = sqlx::query_scalar!(r#"select 1 as `&foo`"#)
.fetch_one(&mut conn)
.await?;
assert_eq!(id, 1);
let id = sqlx::query_scalar!(r#"select 1 as `foo!`"#)
.fetch_one(&mut conn)
.await?;
assert_eq!(id, 1);
let id = sqlx::query_scalar!(r#"select 1 as `foo?`"#)
.fetch_one(&mut conn)
.await?;
assert_eq!(id, Some(1));
let id = sqlx::query_scalar!(r#"select 1 as `foo: MyInt`"#)
.fetch_one(&mut conn)
.await?;
assert_eq!(id, MyInt(1));
let id = sqlx::query_scalar!(r#"select 1 as `foo?: MyInt`"#)
.fetch_one(&mut conn)
.await?;
assert_eq!(id, Some(MyInt(1)));
let id = sqlx::query_scalar!(r#"select 1 as `foo!: MyInt`"#)
.fetch_one(&mut conn)
.await?;
assert_eq!(id, MyInt(1));
let id: MyInt = sqlx::query_scalar!(r#"select 1 as `foo: _`"#)
.fetch_one(&mut conn)
.await?;
assert_eq!(id, MyInt(1));
let id: MyInt = sqlx::query_scalar!(r#"select 1 as `foo?: _`"#)
.fetch_one(&mut conn)
.await?
// don't hint that it should be `Option<MyInt>`
.unwrap();
assert_eq!(id, MyInt(1));
let id: MyInt = sqlx::query_scalar!(r#"select 1 as `foo!: _`"#)
.fetch_one(&mut conn)
.await?;
assert_eq!(id, MyInt(1));
Ok(())
}
#[sqlx_macros::test]
async fn test_query_as_bool() -> anyhow::Result<()> {
let mut conn = new::<MySql>().await?;

View File

@ -174,6 +174,74 @@ async fn test_query_file_as() -> anyhow::Result<()> {
Ok(())
}
#[sqlx_macros::test]
async fn test_query_scalar() -> anyhow::Result<()> {
let mut conn = new::<Postgres>().await?;
let id = sqlx::query_scalar!("select 1").fetch_one(&mut conn).await?;
// nullability inference can't handle expressions
assert_eq!(id, Some(1i32));
// invalid column names are ignored
let id = sqlx::query_scalar!(r#"select 1 as "&foo""#)
.fetch_one(&mut conn)
.await?;
assert_eq!(id, Some(1i32));
let id = sqlx::query_scalar!(r#"select 1 as "foo!""#)
.fetch_one(&mut conn)
.await?;
assert_eq!(id, 1i32);
let id = sqlx::query_scalar!(r#"select 1 as "foo?""#)
.fetch_one(&mut conn)
.await?;
assert_eq!(id, Some(1i32));
let id = sqlx::query_scalar!(r#"select 1 as "foo: MyInt4""#)
.fetch_one(&mut conn)
.await?;
assert_eq!(id, Some(MyInt4(1i32)));
let id = sqlx::query_scalar!(r#"select 1 as "foo?: MyInt4""#)
.fetch_one(&mut conn)
.await?;
assert_eq!(id, Some(MyInt4(1i32)));
let id = sqlx::query_scalar!(r#"select 1 as "foo!: MyInt4""#)
.fetch_one(&mut conn)
.await?;
assert_eq!(id, MyInt4(1i32));
let id: MyInt4 = sqlx::query_scalar!(r#"select 1 as "foo: _""#)
.fetch_one(&mut conn)
.await?
// don't hint that it should be `Option<MyInt4>`
.unwrap();
assert_eq!(id, MyInt4(1i32));
let id: MyInt4 = sqlx::query_scalar!(r#"select 1 as "foo?: _""#)
.fetch_one(&mut conn)
.await?
// don't hint that it should be `Option<MyInt4>`
.unwrap();
assert_eq!(id, MyInt4(1i32));
let id: MyInt4 = sqlx::query_scalar!(r#"select 1 as "foo!: _""#)
.fetch_one(&mut conn)
.await?;
assert_eq!(id, MyInt4(1i32));
Ok(())
}
#[sqlx_macros::test]
async fn query_by_string() -> anyhow::Result<()> {
let mut conn = new::<Postgres>().await?;

View File

@ -110,6 +110,71 @@ async fn test_query_as_raw() -> anyhow::Result<()> {
Ok(())
}
#[sqlx_macros::test]
async fn test_query_scalar() -> anyhow::Result<()> {
let mut conn = new::<Sqlite>().await?;
let id = sqlx::query_scalar!("select 1").fetch_one(&mut conn).await?;
assert_eq!(id, 1i32);
// invalid column names are ignored
let id = sqlx::query_scalar!(r#"select 1 as "&foo""#)
.fetch_one(&mut conn)
.await?;
assert_eq!(id, 1i32);
let id = sqlx::query_scalar!(r#"select 1 as "foo!""#)
.fetch_one(&mut conn)
.await?;
assert_eq!(id, 1i32);
let id = sqlx::query_scalar!(r#"select 1 as "foo?""#)
.fetch_one(&mut conn)
.await?;
assert_eq!(id, Some(1i32));
let id = sqlx::query_scalar!(r#"select 1 as "foo: MyInt""#)
.fetch_one(&mut conn)
.await?;
assert_eq!(id, MyInt(1i64));
let id = sqlx::query_scalar!(r#"select 1 as "foo?: MyInt""#)
.fetch_one(&mut conn)
.await?;
assert_eq!(id, Some(MyInt(1i64)));
let id = sqlx::query_scalar!(r#"select 1 as "foo!: MyInt""#)
.fetch_one(&mut conn)
.await?;
assert_eq!(id, MyInt(1i64));
let id: MyInt = sqlx::query_scalar!(r#"select 1 as "foo: _""#)
.fetch_one(&mut conn)
.await?;
assert_eq!(id, MyInt(1i64));
let id: MyInt = sqlx::query_scalar!(r#"select 1 as "foo?: _""#)
.fetch_one(&mut conn)
.await?
// don't hint that it should be `Option<MyInt>`
.unwrap();
assert_eq!(id, MyInt(1i64));
let id: MyInt = sqlx::query_scalar!(r#"select 1 as "foo!: _""#)
.fetch_one(&mut conn)
.await?;
assert_eq!(id, MyInt(1i64));
Ok(())
}
#[sqlx_macros::test]
async fn macro_select_from_view() -> anyhow::Result<()> {
let mut conn = new::<Sqlite>().await?;