diff --git a/sqlx-macros/src/query/mod.rs b/sqlx-macros/src/query/mod.rs index 4e366c0f..d2876eaa 100644 --- a/sqlx-macros/src/query/mod.rs +++ b/sqlx-macros/src/query/mod.rs @@ -208,6 +208,15 @@ where RecordType::Generated => { let record_name: Type = syn::parse_str("Record").unwrap(); + for rust_col in &columns { + if rust_col.type_.is_none() { + return Err( + "columns may not have wildcard overrides in `query!()` or `query_as!()" + .into(), + ); + } + } + let record_fields = columns.iter().map( |&output::RustColumn { ref ident, diff --git a/sqlx-macros/src/query/output.rs b/sqlx-macros/src/query/output.rs index cdc6c738..e93b3867 100644 --- a/sqlx-macros/src/query/output.rs +++ b/sqlx-macros/src/query/output.rs @@ -1,17 +1,19 @@ use proc_macro2::{Ident, Span, TokenStream}; -use quote::quote; +use quote::{quote, ToTokens}; use syn::Type; -use sqlx_core::describe::Describe; +use sqlx_core::describe::{Column, Describe}; use crate::database::DatabaseExt; use crate::query::QueryMacroInput; use std::fmt::{self, Display, Formatter}; +use syn::parse::{Parse, ParseStream}; +use syn::Token; pub struct RustColumn { pub(super) ident: Ident, - pub(super) type_: TokenStream, + pub(super) type_: Option, } struct DisplayColumn<'a> { @@ -20,6 +22,18 @@ struct DisplayColumn<'a> { name: &'a str, } +struct ColumnDecl { + ident: Ident, + // TIL Rust still has OOP keywords like `abstract`, `final`, `override` and `virtual` reserved + r#override: Option, +} + +enum ColumnOverride { + NonNull, + Wildcard, + Exact(Type), +} + impl Display for DisplayColumn<'_> { fn fmt(&self, f: &mut Formatter) -> fmt::Result { write!(f, "column #{} ({:?})", self.idx + 1, self.name) @@ -32,58 +46,29 @@ pub fn columns_to_rust(describe: &Describe) -> crate::Resul .iter() .enumerate() .map(|(i, column)| -> crate::Result<_> { - let name = &*column.name; - let ident = parse_ident(name)?; + // add raw prefix to all identifiers + let decl = ColumnDecl::parse(&column.name) + .map_err(|e| format!("column name {:?} is invalid: {}", column.name, e))?; - let mut type_ = if let Some(type_info) = &column.type_info { - ::return_type_for_id(&type_info).map_or_else( - || { - let message = if let Some(feature_gate) = - ::get_feature_gate(&type_info) - { - format!( - "optional feature `{feat}` required for type {ty} of {col}", - ty = &type_info, - feat = feature_gate, - col = DisplayColumn { - idx: i, - name: &*column.name - } - ) - } else { - format!( - "unsupported type {ty} of {col}", - ty = type_info, - col = DisplayColumn { - idx: i, - name: &*column.name - } - ) - }; - syn::Error::new(Span::call_site(), message).to_compile_error() - }, - |t| t.parse().unwrap(), - ) - } else { - syn::Error::new( - Span::call_site(), - format!( - "database couldn't tell us the type of {col}; \ - this can happen for columns that are the result of an expression", - col = DisplayColumn { - idx: i, - name: &*column.name - } - ), - ) - .to_compile_error() + let type_ = match decl.r#override { + Some(ColumnOverride::Exact(ty)) => Some(ty.to_token_stream()), + Some(ColumnOverride::Wildcard) => None, + Some(ColumnOverride::NonNull) => Some(get_column_type(i, column)), + None => { + let type_ = get_column_type(i, column); + + if !column.not_null.unwrap_or(false) { + Some(quote! { Option<#type_> }) + } else { + Some(type_) + } + } }; - if !column.not_null.unwrap_or(false) { - type_ = quote! { Option<#type_> }; - } - - Ok(RustColumn { ident, type_ }) + Ok(RustColumn { + ident: decl.ident, + type_, + }) }) .collect::>>() } @@ -103,13 +88,15 @@ pub fn quote_query_as( .. }, )| { - // For "checked" queries, the macro checks these at compile time and using "try_get" - // would also perform pointless runtime checks - - if input.checked { - quote!( #ident: row.try_get_unchecked::<#type_, _>(#i).try_unwrap_optional()? ) - } else { - quote!( #ident: row.try_get_unchecked(#i)? ) + match (input.checked, type_) { + // we guarantee the type is valid so we can skip the runtime check + (true, Some(type_)) => quote! { + #ident: row.try_get_unchecked::<#type_, _>(#i).try_unwrap_optional()? + }, + // type was overridden to be a wildcard so we fallback to the runtime check + (true, None) => quote! ( #ident: row.try_get(#i)? ), + // macro is the `_unchecked!()` variant so this will die in decoding if it's wrong + (false, _) => quote!( #ident: row.try_get_unchecked(#i)? ), } }, ); @@ -128,6 +115,99 @@ pub fn quote_query_as( } } +fn get_column_type(i: usize, column: &Column) -> TokenStream { + if let Some(type_info) = &column.type_info { + ::return_type_for_id(&type_info).map_or_else( + || { + let message = + if let Some(feature_gate) = ::get_feature_gate(&type_info) { + format!( + "optional feature `{feat}` required for type {ty} of {col}", + ty = &type_info, + feat = feature_gate, + col = DisplayColumn { + idx: i, + name: &*column.name + } + ) + } else { + format!( + "unsupported type {ty} of {col}", + ty = type_info, + col = DisplayColumn { + idx: i, + name: &*column.name + } + ) + }; + syn::Error::new(Span::call_site(), message).to_compile_error() + }, + |t| t.parse().unwrap(), + ) + } else { + syn::Error::new( + Span::call_site(), + format!( + "database couldn't tell us the type of {col}; \ + this can happen for columns that are the result of an expression", + col = DisplayColumn { + idx: i, + name: &*column.name + } + ), + ) + .to_compile_error() + } +} + +impl ColumnDecl { + fn parse(col_name: &str) -> crate::Result { + // find the end of the identifier because we want to use our own logic to parse it + // if we tried to feed this into `syn::parse_str()` we might get an un-great error + // for some kinds of invalid identifiers + let (ident, remainder) = if let Some(i) = col_name.find(&[':', '!'][..]) { + let (ident, remainder) = col_name.split_at(i); + + (parse_ident(ident)?, remainder) + } else { + (parse_ident(col_name)?, "") + }; + + Ok(ColumnDecl { + ident, + r#override: if !remainder.is_empty() { + Some(syn::parse_str(remainder)?) + } else { + None + }, + }) + } +} + +impl Parse for ColumnOverride { + fn parse(input: ParseStream) -> syn::Result { + let lookahead = input.lookahead1(); + + if lookahead.peek(Token![:]) { + input.parse::()?; + + let ty = Type::parse(input)?; + + if let Type::Infer(_) = ty { + Ok(ColumnOverride::Wildcard) + } else { + Ok(ColumnOverride::Exact(ty)) + } + } else if lookahead.peek(Token![!]) { + input.parse::()?; + + Ok(ColumnOverride::NonNull) + } else { + Err(lookahead.error()) + } + } +} + fn parse_ident(name: &str) -> crate::Result { // workaround for the following issue (it's semi-fixed but still spits out extra diagnostics) // https://github.com/dtolnay/syn/issues/749#issuecomment-575451318 diff --git a/tests/postgres/macros.rs b/tests/postgres/macros.rs index 5026fe9f..c927bec3 100644 --- a/tests/postgres/macros.rs +++ b/tests/postgres/macros.rs @@ -254,3 +254,50 @@ async fn fetch_is_usable_issue_224() -> anyhow::Result<()> { Ok(()) } + +#[sqlx_macros::test] +async fn test_column_override_not_null() -> anyhow::Result<()> { + let mut conn = new::().await?; + + let record = sqlx::query!(r#"select 1 as "id!""#) + .fetch_one(&mut conn) + .await?; + + assert_eq!(record.id, 1); + + Ok(()) +} + +#[derive(PartialEq, Eq, Debug, sqlx::Type)] +#[sqlx(transparent)] +struct MyInt4(i32); + +#[sqlx_macros::test] +async fn test_column_override_wildcard() -> anyhow::Result<()> { + struct Record { + id: MyInt4, + } + + let mut conn = new::().await?; + + let record = sqlx::query_as!(Record, r#"select 1 as "id: _""#) + .fetch_one(&mut conn) + .await?; + + assert_eq!(record.id, MyInt4(1)); + + Ok(()) +} + +#[sqlx_macros::test] +async fn test_column_override_exact() -> anyhow::Result<()> { + let mut conn = new::().await?; + + let record = sqlx::query!(r#"select 1 as "id: MyInt4""#) + .fetch_one(&mut conn) + .await?; + + assert_eq!(record.id, MyInt4(1)); + + Ok(()) +}