diff --git a/sqlx-core/src/arguments.rs b/sqlx-core/src/arguments.rs index d484f12e..88671762 100644 --- a/sqlx-core/src/arguments.rs +++ b/sqlx-core/src/arguments.rs @@ -39,5 +39,14 @@ macro_rules! impl_into_arguments_for_arguments { }; } +/// used by the query macros to prevent supernumerary `.bind()` calls +pub struct ImmutableArguments<'q, DB: HasArguments<'q>>(pub >::Arguments); + +impl<'q, DB: HasArguments<'q>> IntoArguments<'q, DB> for ImmutableArguments<'q, DB> { + fn into_arguments(self) -> >::Arguments { + self.0 + } +} + // TODO: Impl `IntoArguments` for &[&dyn Encode] // TODO: Impl `IntoArguments` for (impl Encode, ...) x16 diff --git a/sqlx-macros/src/query/input.rs b/sqlx-macros/src/query/input.rs index ddcd6d4a..86627d60 100644 --- a/sqlx-macros/src/query/input.rs +++ b/sqlx-macros/src/query/input.rs @@ -27,6 +27,7 @@ enum QuerySrc { pub enum RecordType { Given(Type), + Scalar, Generated, } @@ -62,7 +63,21 @@ impl Parse for QueryMacroInput { let exprs = input.parse::()?; args = Some(exprs.elems.into_iter().collect()) } else if key == "record" { + if !matches!(record_type, RecordType::Generated) { + return Err(input.error("colliding `scalar` or `record` key")); + } + record_type = RecordType::Given(input.parse()?); + } else if key == "scalar" { + if !matches!(record_type, RecordType::Generated) { + return Err(input.error("colliding `scalar` or `record` key")); + } + + // we currently expect only `scalar = _` + // a `query_as_scalar!()` variant seems less useful than just overriding the type + // of the column in SQL + input.parse::()?; + record_type = RecordType::Scalar; } else if key == "checked" { let lit_bool = input.parse::()?; checked = lit_bool.value; diff --git a/sqlx-macros/src/query/mod.rs b/sqlx-macros/src/query/mod.rs index 0e06b460..65d956fd 100644 --- a/sqlx-macros/src/query/mod.rs +++ b/sqlx-macros/src/query/mod.rs @@ -1,4 +1,3 @@ -use std::borrow::Cow; use std::env; use std::path::{Path, PathBuf}; use std::process::Command; @@ -231,11 +230,9 @@ where if let Some(num) = num_parameters { if num != input.arg_exprs.len() { - return Err(syn::Error::new( - Span::call_site(), - format!("expected {} parameters, got {}", num, input.arg_exprs.len()), - ) - .into()); + return Err( + format!("expected {} parameters, got {}", num, input.arg_exprs.len()).into(), + ); } } @@ -256,10 +253,10 @@ where sqlx::query_with::<#db_path, _>(#sql, #query_args) } } else { - let columns = output::columns_to_rust::(&data.describe)?; - - let (out_ty, mut record_tokens) = match input.record_type { + match input.record_type { RecordType::Generated => { + let columns = output::columns_to_rust::(&data.describe)?; + let record_name: Type = syn::parse_str("Record").unwrap(); for rust_col in &columns { @@ -278,26 +275,31 @@ where }| quote!(#ident: #type_,), ); - let record_tokens = quote! { + let mut record_tokens = quote! { #[derive(Debug)] struct #record_name { #(#record_fields)* } }; - (Cow::Owned(record_name), record_tokens) + record_tokens.extend(output::quote_query_as::( + &input, + &record_name, + &query_args, + &columns, + )); + + record_tokens } - RecordType::Given(ref out_ty) => (Cow::Borrowed(out_ty), quote!()), - }; + RecordType::Given(ref out_ty) => { + let columns = output::columns_to_rust::(&data.describe)?; - record_tokens.extend(output::quote_query_as::( - &input, - &out_ty, - &query_args, - &columns, - )); - - record_tokens + output::quote_query_as::(&input, out_ty, &query_args, &columns) + } + RecordType::Scalar => { + output::quote_query_scalar::(&input, &query_args, &data.describe)? + } + } }; let ret_tokens = quote! { diff --git a/sqlx-macros/src/query/output.rs b/sqlx-macros/src/query/output.rs index 84f91f6d..821fac35 100644 --- a/sqlx-macros/src/query/output.rs +++ b/sqlx-macros/src/query/output.rs @@ -75,51 +75,48 @@ impl Display for DisplayColumn<'_> { } pub fn columns_to_rust(describe: &Describe) -> crate::Result> { - describe - .columns() - .iter() - .enumerate() - .map(|(i, column)| -> crate::Result<_> { - // add raw prefix to all identifiers - let decl = ColumnDecl::parse(&column.name()) - .map_err(|e| format!("column name {:?} is invalid: {}", column.name(), e))?; - - let ColumnOverride { nullability, type_ } = decl.r#override; - - let nullable = match nullability { - ColumnNullabilityOverride::NonNull => false, - ColumnNullabilityOverride::Nullable => true, - ColumnNullabilityOverride::None => describe.nullable(i).unwrap_or(true), - }; - let type_ = match (type_, nullable) { - (ColumnTypeOverride::Exact(type_), false) => { - ColumnType::Exact(type_.to_token_stream()) - } - (ColumnTypeOverride::Exact(type_), true) => { - ColumnType::Exact(quote! { Option<#type_> }) - } - - (ColumnTypeOverride::Wildcard, false) => ColumnType::Wildcard, - (ColumnTypeOverride::Wildcard, true) => ColumnType::OptWildcard, - - (ColumnTypeOverride::None, _) => { - let type_ = get_column_type::(i, column); - if !nullable { - ColumnType::Exact(type_) - } else { - ColumnType::Exact(quote! { Option<#type_> }) - } - } - }; - - Ok(RustColumn { - ident: decl.ident, - type_, - }) - }) + (0..describe.columns().len()) + .map(|i| column_to_rust(describe, i)) .collect::>>() } +fn column_to_rust(describe: &Describe, i: usize) -> crate::Result { + let column = &describe.columns()[i]; + + // add raw prefix to all identifiers + let decl = ColumnDecl::parse(&column.name()) + .map_err(|e| format!("column name {:?} is invalid: {}", column.name(), e))?; + + let ColumnOverride { nullability, type_ } = decl.r#override; + + let nullable = match nullability { + ColumnNullabilityOverride::NonNull => false, + ColumnNullabilityOverride::Nullable => true, + ColumnNullabilityOverride::None => describe.nullable(i).unwrap_or(true), + }; + let type_ = match (type_, nullable) { + (ColumnTypeOverride::Exact(type_), false) => ColumnType::Exact(type_.to_token_stream()), + (ColumnTypeOverride::Exact(type_), true) => ColumnType::Exact(quote! { Option<#type_> }), + + (ColumnTypeOverride::Wildcard, false) => ColumnType::Wildcard, + (ColumnTypeOverride::Wildcard, true) => ColumnType::OptWildcard, + + (ColumnTypeOverride::None, _) => { + let type_ = get_column_type::(i, column); + if !nullable { + ColumnType::Exact(type_) + } else { + ColumnType::Exact(quote! { Option<#type_> }) + } + } + }; + + Ok(RustColumn { + ident: decl.ident, + type_, + }) +} + pub fn quote_query_as( input: &QueryMacroInput, out_ty: &Type, @@ -171,6 +168,43 @@ pub fn quote_query_as( } } +pub fn quote_query_scalar( + input: &QueryMacroInput, + bind_args: &Ident, + describe: &Describe, +) -> crate::Result { + let columns = describe.columns(); + + if columns.len() != 1 { + return Err(syn::Error::new( + input.src_span, + format!("expected exactly 1 column, got {}", columns.len()), + ) + .into()); + } + + // attempt to parse a column override, otherwise fall back to the inferred type of the column + let ty = if let Ok(rust_col) = column_to_rust(describe, 0) { + rust_col.type_.to_token_stream() + } else if input.checked { + let ty = get_column_type::(0, &columns[0]); + if describe.nullable(0).unwrap_or(true) { + quote! { Option<#ty> } + } else { + ty + } + } else { + quote! { _ } + }; + + let db = DB::db_path(); + let query = &input.src; + + Ok(quote! { + sqlx::query_scalar_with::<#db, #ty, _>(#query, #bind_args) + }) +} + fn get_column_type(i: usize, column: &DB::Column) -> TokenStream { let type_info = &*column.type_info(); @@ -268,7 +302,9 @@ fn parse_ident(name: &str) -> crate::Result { // workaround for the following issue (it's semi-fixed but still spits out extra diagnostics) // https://github.com/dtolnay/syn/issues/749#issuecomment-575451318 - let is_valid_ident = name.chars().all(|c| c.is_alphanumeric() || c == '_'); + let is_valid_ident = !name.is_empty() && + name.starts_with(|c: char| c.is_alphabetic() || c == '_') && + name.chars().all(|c| c.is_alphanumeric() || c == '_'); if is_valid_ident { let ident = String::from("r#") + name; diff --git a/src/macros.rs b/src/macros.rs index 7d2f7a5e..433e9ddc 100644 --- a/src/macros.rs +++ b/src/macros.rs @@ -615,6 +615,77 @@ macro_rules! query_file_as_unchecked ( }) ); +/// A variant of [query!] which expects a single column from the query and evaluates to an +/// instance of [QueryScalar][crate::query::QueryScalar]. +/// +/// The name of the column is not required to be a valid Rust identifier, however you can still +/// use the column type override syntax in which case the column name _does_ have to be a valid +/// Rust identifier for the override to parse properly. If the override parse fails the error +/// is silently ignored (we just don't have a reliable way to tell the difference). **If you're +/// getting a different type than expected, please check to see if your override syntax is correct +/// before opening an issue.** +/// +/// Wildcard overrides like in [query_as!] are also allowed, in which case the output type +/// is left up to inference. +/// +/// See [query!] for more information. +#[macro_export] +#[cfg_attr(docsrs, doc(cfg(feature = "macros")))] +macro_rules! query_scalar ( + ($query:expr) => ( + $crate::sqlx_macros::expand_query!(scalar = _, source = $query) + ); + ($query:expr, $($args:tt)*) => ( + $crate::sqlx_macros::expand_query!(scalar = _, source = $query, args = [$($args)*]) + ) +); + +/// A variant of [query_scalar!] which takes a file path like [query_file!]. +#[macro_export] +#[cfg_attr(docsrs, doc(cfg(feature = "macros")))] +macro_rules! query_file_scalar ( + ($path:literal) => ( + $crate::sqlx_macros::expand_query!(scalar = _, source_file = $path) + ); + ($path:literal, $($args:tt)*) => ( + $crate::sqlx_macros::expand_query!(scalar = _, source_file = $path, args = [$($args)*]) + ) +); + +/// A variant of [query_scalar!] which does not typecheck bind parameters and leaves the output type +/// to inference. The query itself is still checked that it is syntactically and semantically +/// valid for the database, that it only produces one column and that the number of bind parameters +/// is correct. +/// +/// For this macro variant the name of the column is irrelevant. +#[macro_export] +#[cfg_attr(docsrs, doc(cfg(feature = "macros")))] +macro_rules! query_scalar_unchecked ( + ($query:expr) => ( + $crate::sqlx_macros::expand_query!(scalar = _, source = $query, checked = false) + ); + ($query:expr, $($args:tt)*) => ( + $crate::sqlx_macros::expand_query!(scalar = _, source = $query, args = [$($args)*], checked = false) + ) +); + +/// A variant of [query_file_scalar!] which does not typecheck bind parameters and leaves the output +/// type to inference. The query itself is still checked that it is syntactically and +/// semantically valid for the database, that it only produces one column and that the number of +/// bind parameters is correct. +/// +/// For this macro variant the name of the column is irrelevant. +#[macro_export] +#[cfg_attr(docsrs, doc(cfg(feature = "macros")))] +macro_rules! query_file_scalar_unchecked ( + ($path:literal) => ( + $crate::sqlx_macros::expand_query!(scalar = _, source_file = $path, checked = false) + ); + ($path:literal, $($args:tt)*) => ( + $crate::sqlx_macros::expand_query!(scalar = _, source_file = $path, args = [$($args)*], checked = false) + ) +); + /// Embeds migrations into the binary by expanding to a static instance of [Migrator][crate::migrate::Migrator]. /// /// ```rust,ignore diff --git a/tests/mysql/macros.rs b/tests/mysql/macros.rs index 8a3b1714..9b9b436c 100644 --- a/tests/mysql/macros.rs +++ b/tests/mysql/macros.rs @@ -58,6 +58,72 @@ async fn test_query_as_raw() -> anyhow::Result<()> { Ok(()) } +#[sqlx_macros::test] +async fn test_query_scalar() -> anyhow::Result<()> { + let mut conn = new::().await?; + + let id = sqlx::query_scalar!("select 1").fetch_one(&mut conn).await?; + // MySQL tells us `LONG LONG` while MariaDB just `LONG` + assert_eq!(id, 1); + + // invalid column names are ignored + let id = sqlx::query_scalar!(r#"select 1 as `&foo`"#) + .fetch_one(&mut conn) + .await?; + assert_eq!(id, 1); + + let id = sqlx::query_scalar!(r#"select 1 as `foo!`"#) + .fetch_one(&mut conn) + .await?; + assert_eq!(id, 1); + + let id = sqlx::query_scalar!(r#"select 1 as `foo?`"#) + .fetch_one(&mut conn) + .await?; + + assert_eq!(id, Some(1)); + + let id = sqlx::query_scalar!(r#"select 1 as `foo: MyInt`"#) + .fetch_one(&mut conn) + .await?; + + assert_eq!(id, MyInt(1)); + + let id = sqlx::query_scalar!(r#"select 1 as `foo?: MyInt`"#) + .fetch_one(&mut conn) + .await?; + + assert_eq!(id, Some(MyInt(1))); + + let id = sqlx::query_scalar!(r#"select 1 as `foo!: MyInt`"#) + .fetch_one(&mut conn) + .await?; + + assert_eq!(id, MyInt(1)); + + let id: MyInt = sqlx::query_scalar!(r#"select 1 as `foo: _`"#) + .fetch_one(&mut conn) + .await?; + + assert_eq!(id, MyInt(1)); + + let id: MyInt = sqlx::query_scalar!(r#"select 1 as `foo?: _`"#) + .fetch_one(&mut conn) + .await? + // don't hint that it should be `Option` + .unwrap(); + + assert_eq!(id, MyInt(1)); + + let id: MyInt = sqlx::query_scalar!(r#"select 1 as `foo!: _`"#) + .fetch_one(&mut conn) + .await?; + + assert_eq!(id, MyInt(1)); + + Ok(()) +} + #[sqlx_macros::test] async fn test_query_as_bool() -> anyhow::Result<()> { let mut conn = new::().await?; diff --git a/tests/postgres/macros.rs b/tests/postgres/macros.rs index e2fbc2d7..bc770e05 100644 --- a/tests/postgres/macros.rs +++ b/tests/postgres/macros.rs @@ -174,6 +174,74 @@ async fn test_query_file_as() -> anyhow::Result<()> { Ok(()) } +#[sqlx_macros::test] +async fn test_query_scalar() -> anyhow::Result<()> { + let mut conn = new::().await?; + + let id = sqlx::query_scalar!("select 1").fetch_one(&mut conn).await?; + // nullability inference can't handle expressions + assert_eq!(id, Some(1i32)); + + // invalid column names are ignored + let id = sqlx::query_scalar!(r#"select 1 as "&foo""#) + .fetch_one(&mut conn) + .await?; + assert_eq!(id, Some(1i32)); + + let id = sqlx::query_scalar!(r#"select 1 as "foo!""#) + .fetch_one(&mut conn) + .await?; + assert_eq!(id, 1i32); + + let id = sqlx::query_scalar!(r#"select 1 as "foo?""#) + .fetch_one(&mut conn) + .await?; + + assert_eq!(id, Some(1i32)); + + let id = sqlx::query_scalar!(r#"select 1 as "foo: MyInt4""#) + .fetch_one(&mut conn) + .await?; + + assert_eq!(id, Some(MyInt4(1i32))); + + let id = sqlx::query_scalar!(r#"select 1 as "foo?: MyInt4""#) + .fetch_one(&mut conn) + .await?; + + assert_eq!(id, Some(MyInt4(1i32))); + + let id = sqlx::query_scalar!(r#"select 1 as "foo!: MyInt4""#) + .fetch_one(&mut conn) + .await?; + + assert_eq!(id, MyInt4(1i32)); + + let id: MyInt4 = sqlx::query_scalar!(r#"select 1 as "foo: _""#) + .fetch_one(&mut conn) + .await? + // don't hint that it should be `Option` + .unwrap(); + + assert_eq!(id, MyInt4(1i32)); + + let id: MyInt4 = sqlx::query_scalar!(r#"select 1 as "foo?: _""#) + .fetch_one(&mut conn) + .await? + // don't hint that it should be `Option` + .unwrap(); + + assert_eq!(id, MyInt4(1i32)); + + let id: MyInt4 = sqlx::query_scalar!(r#"select 1 as "foo!: _""#) + .fetch_one(&mut conn) + .await?; + + assert_eq!(id, MyInt4(1i32)); + + Ok(()) +} + #[sqlx_macros::test] async fn query_by_string() -> anyhow::Result<()> { let mut conn = new::().await?; diff --git a/tests/sqlite/macros.rs b/tests/sqlite/macros.rs index ea4a657b..0afad758 100644 --- a/tests/sqlite/macros.rs +++ b/tests/sqlite/macros.rs @@ -110,6 +110,71 @@ async fn test_query_as_raw() -> anyhow::Result<()> { Ok(()) } +#[sqlx_macros::test] +async fn test_query_scalar() -> anyhow::Result<()> { + let mut conn = new::().await?; + + let id = sqlx::query_scalar!("select 1").fetch_one(&mut conn).await?; + assert_eq!(id, 1i32); + + // invalid column names are ignored + let id = sqlx::query_scalar!(r#"select 1 as "&foo""#) + .fetch_one(&mut conn) + .await?; + assert_eq!(id, 1i32); + + let id = sqlx::query_scalar!(r#"select 1 as "foo!""#) + .fetch_one(&mut conn) + .await?; + assert_eq!(id, 1i32); + + let id = sqlx::query_scalar!(r#"select 1 as "foo?""#) + .fetch_one(&mut conn) + .await?; + + assert_eq!(id, Some(1i32)); + + let id = sqlx::query_scalar!(r#"select 1 as "foo: MyInt""#) + .fetch_one(&mut conn) + .await?; + + assert_eq!(id, MyInt(1i64)); + + let id = sqlx::query_scalar!(r#"select 1 as "foo?: MyInt""#) + .fetch_one(&mut conn) + .await?; + + assert_eq!(id, Some(MyInt(1i64))); + + let id = sqlx::query_scalar!(r#"select 1 as "foo!: MyInt""#) + .fetch_one(&mut conn) + .await?; + + assert_eq!(id, MyInt(1i64)); + + let id: MyInt = sqlx::query_scalar!(r#"select 1 as "foo: _""#) + .fetch_one(&mut conn) + .await?; + + assert_eq!(id, MyInt(1i64)); + + let id: MyInt = sqlx::query_scalar!(r#"select 1 as "foo?: _""#) + .fetch_one(&mut conn) + .await? + // don't hint that it should be `Option` + .unwrap(); + + assert_eq!(id, MyInt(1i64)); + + let id: MyInt = sqlx::query_scalar!(r#"select 1 as "foo!: _""#) + .fetch_one(&mut conn) + .await?; + + assert_eq!(id, MyInt(1i64)); + + Ok(()) +} + #[sqlx_macros::test] async fn macro_select_from_view() -> anyhow::Result<()> { let mut conn = new::().await?;