mirror of
https://github.com/launchbadge/sqlx.git
synced 2025-10-03 07:45:30 +00:00
feat(macros): type override annotations for columns
This commit is contained in:
parent
efc4df3eea
commit
eb831382e5
@ -208,6 +208,15 @@ where
|
|||||||
RecordType::Generated => {
|
RecordType::Generated => {
|
||||||
let record_name: Type = syn::parse_str("Record").unwrap();
|
let record_name: Type = syn::parse_str("Record").unwrap();
|
||||||
|
|
||||||
|
for rust_col in &columns {
|
||||||
|
if rust_col.type_.is_none() {
|
||||||
|
return Err(
|
||||||
|
"columns may not have wildcard overrides in `query!()` or `query_as!()"
|
||||||
|
.into(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let record_fields = columns.iter().map(
|
let record_fields = columns.iter().map(
|
||||||
|&output::RustColumn {
|
|&output::RustColumn {
|
||||||
ref ident,
|
ref ident,
|
||||||
|
@ -1,17 +1,19 @@
|
|||||||
use proc_macro2::{Ident, Span, TokenStream};
|
use proc_macro2::{Ident, Span, TokenStream};
|
||||||
use quote::quote;
|
use quote::{quote, ToTokens};
|
||||||
use syn::Type;
|
use syn::Type;
|
||||||
|
|
||||||
use sqlx_core::describe::Describe;
|
use sqlx_core::describe::{Column, Describe};
|
||||||
|
|
||||||
use crate::database::DatabaseExt;
|
use crate::database::DatabaseExt;
|
||||||
|
|
||||||
use crate::query::QueryMacroInput;
|
use crate::query::QueryMacroInput;
|
||||||
use std::fmt::{self, Display, Formatter};
|
use std::fmt::{self, Display, Formatter};
|
||||||
|
use syn::parse::{Parse, ParseStream};
|
||||||
|
use syn::Token;
|
||||||
|
|
||||||
pub struct RustColumn {
|
pub struct RustColumn {
|
||||||
pub(super) ident: Ident,
|
pub(super) ident: Ident,
|
||||||
pub(super) type_: TokenStream,
|
pub(super) type_: Option<TokenStream>,
|
||||||
}
|
}
|
||||||
|
|
||||||
struct DisplayColumn<'a> {
|
struct DisplayColumn<'a> {
|
||||||
@ -20,6 +22,18 @@ struct DisplayColumn<'a> {
|
|||||||
name: &'a str,
|
name: &'a str,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct ColumnDecl {
|
||||||
|
ident: Ident,
|
||||||
|
// TIL Rust still has OOP keywords like `abstract`, `final`, `override` and `virtual` reserved
|
||||||
|
r#override: Option<ColumnOverride>,
|
||||||
|
}
|
||||||
|
|
||||||
|
enum ColumnOverride {
|
||||||
|
NonNull,
|
||||||
|
Wildcard,
|
||||||
|
Exact(Type),
|
||||||
|
}
|
||||||
|
|
||||||
impl Display for DisplayColumn<'_> {
|
impl Display for DisplayColumn<'_> {
|
||||||
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
|
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
|
||||||
write!(f, "column #{} ({:?})", self.idx + 1, self.name)
|
write!(f, "column #{} ({:?})", self.idx + 1, self.name)
|
||||||
@ -32,58 +46,29 @@ pub fn columns_to_rust<DB: DatabaseExt>(describe: &Describe<DB>) -> crate::Resul
|
|||||||
.iter()
|
.iter()
|
||||||
.enumerate()
|
.enumerate()
|
||||||
.map(|(i, column)| -> crate::Result<_> {
|
.map(|(i, column)| -> crate::Result<_> {
|
||||||
let name = &*column.name;
|
// add raw prefix to all identifiers
|
||||||
let ident = parse_ident(name)?;
|
let decl = ColumnDecl::parse(&column.name)
|
||||||
|
.map_err(|e| format!("column name {:?} is invalid: {}", column.name, e))?;
|
||||||
|
|
||||||
let mut type_ = if let Some(type_info) = &column.type_info {
|
let type_ = match decl.r#override {
|
||||||
<DB as DatabaseExt>::return_type_for_id(&type_info).map_or_else(
|
Some(ColumnOverride::Exact(ty)) => Some(ty.to_token_stream()),
|
||||||
|| {
|
Some(ColumnOverride::Wildcard) => None,
|
||||||
let message = if let Some(feature_gate) =
|
Some(ColumnOverride::NonNull) => Some(get_column_type(i, column)),
|
||||||
<DB as DatabaseExt>::get_feature_gate(&type_info)
|
None => {
|
||||||
{
|
let type_ = get_column_type(i, column);
|
||||||
format!(
|
|
||||||
"optional feature `{feat}` required for type {ty} of {col}",
|
if !column.not_null.unwrap_or(false) {
|
||||||
ty = &type_info,
|
Some(quote! { Option<#type_> })
|
||||||
feat = feature_gate,
|
} else {
|
||||||
col = DisplayColumn {
|
Some(type_)
|
||||||
idx: i,
|
}
|
||||||
name: &*column.name
|
}
|
||||||
}
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
format!(
|
|
||||||
"unsupported type {ty} of {col}",
|
|
||||||
ty = type_info,
|
|
||||||
col = DisplayColumn {
|
|
||||||
idx: i,
|
|
||||||
name: &*column.name
|
|
||||||
}
|
|
||||||
)
|
|
||||||
};
|
|
||||||
syn::Error::new(Span::call_site(), message).to_compile_error()
|
|
||||||
},
|
|
||||||
|t| t.parse().unwrap(),
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
syn::Error::new(
|
|
||||||
Span::call_site(),
|
|
||||||
format!(
|
|
||||||
"database couldn't tell us the type of {col}; \
|
|
||||||
this can happen for columns that are the result of an expression",
|
|
||||||
col = DisplayColumn {
|
|
||||||
idx: i,
|
|
||||||
name: &*column.name
|
|
||||||
}
|
|
||||||
),
|
|
||||||
)
|
|
||||||
.to_compile_error()
|
|
||||||
};
|
};
|
||||||
|
|
||||||
if !column.not_null.unwrap_or(false) {
|
Ok(RustColumn {
|
||||||
type_ = quote! { Option<#type_> };
|
ident: decl.ident,
|
||||||
}
|
type_,
|
||||||
|
})
|
||||||
Ok(RustColumn { ident, type_ })
|
|
||||||
})
|
})
|
||||||
.collect::<crate::Result<Vec<_>>>()
|
.collect::<crate::Result<Vec<_>>>()
|
||||||
}
|
}
|
||||||
@ -103,13 +88,15 @@ pub fn quote_query_as<DB: DatabaseExt>(
|
|||||||
..
|
..
|
||||||
},
|
},
|
||||||
)| {
|
)| {
|
||||||
// For "checked" queries, the macro checks these at compile time and using "try_get"
|
match (input.checked, type_) {
|
||||||
// would also perform pointless runtime checks
|
// we guarantee the type is valid so we can skip the runtime check
|
||||||
|
(true, Some(type_)) => quote! {
|
||||||
if input.checked {
|
#ident: row.try_get_unchecked::<#type_, _>(#i).try_unwrap_optional()?
|
||||||
quote!( #ident: row.try_get_unchecked::<#type_, _>(#i).try_unwrap_optional()? )
|
},
|
||||||
} else {
|
// type was overridden to be a wildcard so we fallback to the runtime check
|
||||||
quote!( #ident: row.try_get_unchecked(#i)? )
|
(true, None) => quote! ( #ident: row.try_get(#i)? ),
|
||||||
|
// macro is the `_unchecked!()` variant so this will die in decoding if it's wrong
|
||||||
|
(false, _) => quote!( #ident: row.try_get_unchecked(#i)? ),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
@ -128,6 +115,99 @@ pub fn quote_query_as<DB: DatabaseExt>(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn get_column_type<DB: DatabaseExt>(i: usize, column: &Column<DB>) -> TokenStream {
|
||||||
|
if let Some(type_info) = &column.type_info {
|
||||||
|
<DB as DatabaseExt>::return_type_for_id(&type_info).map_or_else(
|
||||||
|
|| {
|
||||||
|
let message =
|
||||||
|
if let Some(feature_gate) = <DB as DatabaseExt>::get_feature_gate(&type_info) {
|
||||||
|
format!(
|
||||||
|
"optional feature `{feat}` required for type {ty} of {col}",
|
||||||
|
ty = &type_info,
|
||||||
|
feat = feature_gate,
|
||||||
|
col = DisplayColumn {
|
||||||
|
idx: i,
|
||||||
|
name: &*column.name
|
||||||
|
}
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
format!(
|
||||||
|
"unsupported type {ty} of {col}",
|
||||||
|
ty = type_info,
|
||||||
|
col = DisplayColumn {
|
||||||
|
idx: i,
|
||||||
|
name: &*column.name
|
||||||
|
}
|
||||||
|
)
|
||||||
|
};
|
||||||
|
syn::Error::new(Span::call_site(), message).to_compile_error()
|
||||||
|
},
|
||||||
|
|t| t.parse().unwrap(),
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
syn::Error::new(
|
||||||
|
Span::call_site(),
|
||||||
|
format!(
|
||||||
|
"database couldn't tell us the type of {col}; \
|
||||||
|
this can happen for columns that are the result of an expression",
|
||||||
|
col = DisplayColumn {
|
||||||
|
idx: i,
|
||||||
|
name: &*column.name
|
||||||
|
}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.to_compile_error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ColumnDecl {
|
||||||
|
fn parse(col_name: &str) -> crate::Result<Self> {
|
||||||
|
// find the end of the identifier because we want to use our own logic to parse it
|
||||||
|
// if we tried to feed this into `syn::parse_str()` we might get an un-great error
|
||||||
|
// for some kinds of invalid identifiers
|
||||||
|
let (ident, remainder) = if let Some(i) = col_name.find(&[':', '!'][..]) {
|
||||||
|
let (ident, remainder) = col_name.split_at(i);
|
||||||
|
|
||||||
|
(parse_ident(ident)?, remainder)
|
||||||
|
} else {
|
||||||
|
(parse_ident(col_name)?, "")
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(ColumnDecl {
|
||||||
|
ident,
|
||||||
|
r#override: if !remainder.is_empty() {
|
||||||
|
Some(syn::parse_str(remainder)?)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Parse for ColumnOverride {
|
||||||
|
fn parse(input: ParseStream) -> syn::Result<Self> {
|
||||||
|
let lookahead = input.lookahead1();
|
||||||
|
|
||||||
|
if lookahead.peek(Token![:]) {
|
||||||
|
input.parse::<Token![:]>()?;
|
||||||
|
|
||||||
|
let ty = Type::parse(input)?;
|
||||||
|
|
||||||
|
if let Type::Infer(_) = ty {
|
||||||
|
Ok(ColumnOverride::Wildcard)
|
||||||
|
} else {
|
||||||
|
Ok(ColumnOverride::Exact(ty))
|
||||||
|
}
|
||||||
|
} else if lookahead.peek(Token![!]) {
|
||||||
|
input.parse::<Token![!]>()?;
|
||||||
|
|
||||||
|
Ok(ColumnOverride::NonNull)
|
||||||
|
} else {
|
||||||
|
Err(lookahead.error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn parse_ident(name: &str) -> crate::Result<Ident> {
|
fn parse_ident(name: &str) -> crate::Result<Ident> {
|
||||||
// workaround for the following issue (it's semi-fixed but still spits out extra diagnostics)
|
// workaround for the following issue (it's semi-fixed but still spits out extra diagnostics)
|
||||||
// https://github.com/dtolnay/syn/issues/749#issuecomment-575451318
|
// https://github.com/dtolnay/syn/issues/749#issuecomment-575451318
|
||||||
|
@ -254,3 +254,50 @@ async fn fetch_is_usable_issue_224() -> anyhow::Result<()> {
|
|||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[sqlx_macros::test]
|
||||||
|
async fn test_column_override_not_null() -> anyhow::Result<()> {
|
||||||
|
let mut conn = new::<Postgres>().await?;
|
||||||
|
|
||||||
|
let record = sqlx::query!(r#"select 1 as "id!""#)
|
||||||
|
.fetch_one(&mut conn)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
assert_eq!(record.id, 1);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(PartialEq, Eq, Debug, sqlx::Type)]
|
||||||
|
#[sqlx(transparent)]
|
||||||
|
struct MyInt4(i32);
|
||||||
|
|
||||||
|
#[sqlx_macros::test]
|
||||||
|
async fn test_column_override_wildcard() -> anyhow::Result<()> {
|
||||||
|
struct Record {
|
||||||
|
id: MyInt4,
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut conn = new::<Postgres>().await?;
|
||||||
|
|
||||||
|
let record = sqlx::query_as!(Record, r#"select 1 as "id: _""#)
|
||||||
|
.fetch_one(&mut conn)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
assert_eq!(record.id, MyInt4(1));
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[sqlx_macros::test]
|
||||||
|
async fn test_column_override_exact() -> anyhow::Result<()> {
|
||||||
|
let mut conn = new::<Postgres>().await?;
|
||||||
|
|
||||||
|
let record = sqlx::query!(r#"select 1 as "id: MyInt4""#)
|
||||||
|
.fetch_one(&mut conn)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
assert_eq!(record.id, MyInt4(1));
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user