From 7cdb68be1a436241ec16d8200c565a6fd656d4d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20OIRY?= <46460001+TheoOiry@users.noreply.github.com> Date: Tue, 12 Jul 2022 23:28:07 +0200 Subject: [PATCH] support flatten attribute in FromRow macro (#1959) * support flatten attribute in FromRow macro * added docs for flatten FromRow attribute --- sqlx-core/src/from_row.rs | 30 ++++++++++++ sqlx-macros/src/derives/attributes.rs | 9 +++- sqlx-macros/src/derives/row.rs | 69 ++++++++++++++------------- tests/postgres/derives.rs | 41 ++++++++++++++++ 4 files changed, 115 insertions(+), 34 deletions(-) diff --git a/sqlx-core/src/from_row.rs b/sqlx-core/src/from_row.rs index 6f05d008..a71b40af 100644 --- a/sqlx-core/src/from_row.rs +++ b/sqlx-core/src/from_row.rs @@ -92,6 +92,36 @@ use crate::row::Row; /// will set the value of the field `location` to the default value of `Option`, /// which is `None`. /// +/// ### `flatten` +/// +/// If you want to handle a field that implements [`FromRow`], +/// you can use the `flatten` attribute to specify that you want +/// it to use [`FromRow`] for parsing rather than the usual method. +/// For example: +/// +/// ```rust,ignore +/// #[derive(sqlx::FromRow)] +/// struct Address { +/// country: String, +/// city: String, +/// road: String, +/// } +/// +/// #[derive(sqlx::FromRow)] +/// struct User { +/// id: i32, +/// name: String, +/// #[sqlx(flatten)] +/// address: Address, +/// } +/// ``` +/// Given a query such as: +/// +/// ```sql +/// SELECT id, name, country, city, road FROM users; +/// ``` +/// +/// This field is compatible with the `default` attribute. pub trait FromRow<'r, R: Row>: Sized { fn from_row(row: &'r R) -> Result; } diff --git a/sqlx-macros/src/derives/attributes.rs b/sqlx-macros/src/derives/attributes.rs index 202b6b5a..46d7aef8 100644 --- a/sqlx-macros/src/derives/attributes.rs +++ b/sqlx-macros/src/derives/attributes.rs @@ -70,6 +70,7 @@ pub struct SqlxContainerAttributes { pub struct SqlxChildAttributes { pub rename: Option, pub default: bool, + pub flatten: bool, } pub fn parse_container_attributes(input: &[Attribute]) -> syn::Result { @@ -177,6 +178,7 @@ pub fn parse_container_attributes(input: &[Attribute]) -> syn::Result syn::Result { let mut rename = None; let mut default = false; + let mut flatten = false; for attr in input.iter().filter(|a| a.path.is_ident("sqlx")) { let meta = attr @@ -193,6 +195,7 @@ pub fn parse_child_attributes(input: &[Attribute]) -> syn::Result try_set!(rename, val.value(), value), Meta::Path(path) if path.is_ident("default") => default = true, + Meta::Path(path) if path.is_ident("flatten") => flatten = true, u => fail!(u, "unexpected attribute"), }, u => fail!(u, "unexpected attribute"), @@ -201,7 +204,11 @@ pub fn parse_child_attributes(input: &[Attribute]) -> syn::Result)); - for field in fields { - let ty = &field.ty; - - predicates.push(parse_quote!(#ty: ::sqlx::decode::Decode<#lifetime, R::Database>)); - predicates.push(parse_quote!(#ty: ::sqlx::types::Type)); - } - - let (impl_generics, _, where_clause) = generics.split_for_impl(); - let container_attributes = parse_container_attributes(&input.attrs)?; - let reads = fields.iter().filter_map(|field| -> Option { - let id = &field.ident.as_ref()?; - let attributes = parse_child_attributes(&field.attrs).unwrap(); - let id_s = attributes - .rename - .or_else(|| Some(id.to_string().trim_start_matches("r#").to_owned())) - .map(|s| match container_attributes.rename_all { - Some(pattern) => rename_all(&s, pattern), - None => s, - }) - .unwrap(); + let reads: Vec = fields + .iter() + .filter_map(|field| -> Option { + let id = &field.ident.as_ref()?; + let attributes = parse_child_attributes(&field.attrs).unwrap(); + let ty = &field.ty; - let ty = &field.ty; + let expr: Expr = if attributes.flatten { + predicates.push(parse_quote!(#ty: ::sqlx::FromRow<#lifetime, R>)); + parse_quote!(#ty::from_row(row)) + } else { + predicates.push(parse_quote!(#ty: ::sqlx::decode::Decode<#lifetime, R::Database>)); + predicates.push(parse_quote!(#ty: ::sqlx::types::Type)); - if attributes.default { - Some( - parse_quote!(let #id: #ty = row.try_get(#id_s).or_else(|e| match e { + let id_s = attributes + .rename + .or_else(|| Some(id.to_string().trim_start_matches("r#").to_owned())) + .map(|s| match container_attributes.rename_all { + Some(pattern) => rename_all(&s, pattern), + None => s, + }) + .unwrap(); + parse_quote!(row.try_get(#id_s)) + }; + + if attributes.default { + Some(parse_quote!(let #id: #ty = #expr.or_else(|e| match e { ::sqlx::Error::ColumnNotFound(_) => { ::std::result::Result::Ok(Default::default()) }, e => ::std::result::Result::Err(e) - })?;), - ) - } else { - Some(parse_quote!( - let #id: #ty = row.try_get(#id_s)?; - )) - } - }); + })?;)) + } else { + Some(parse_quote!( + let #id: #ty = #expr?; + )) + } + }) + .collect(); + + let (impl_generics, _, where_clause) = generics.split_for_impl(); let names = fields.iter().map(|field| &field.ident); diff --git a/tests/postgres/derives.rs b/tests/postgres/derives.rs index 23ac4187..3eedc2ab 100644 --- a/tests/postgres/derives.rs +++ b/tests/postgres/derives.rs @@ -573,3 +573,44 @@ async fn test_default() -> anyhow::Result<()> { Ok(()) } + +#[cfg(feature = "macros")] +#[sqlx_macros::test] +async fn test_flatten() -> anyhow::Result<()> { + #[derive(Debug, Default, sqlx::FromRow)] + struct AccountDefault { + default: Option, + } + + #[derive(Debug, sqlx::FromRow)] + struct UserInfo { + name: String, + surname: String, + } + + #[derive(Debug, sqlx::FromRow)] + struct AccountKeyword { + id: i32, + #[sqlx(flatten)] + info: UserInfo, + #[sqlx(default)] + #[sqlx(flatten)] + default: AccountDefault, + } + + let mut conn = new::().await?; + + let account: AccountKeyword = sqlx::query_as( + r#"SELECT * from (VALUES (1, 'foo', 'bar')) accounts("id", "name", "surname")"#, + ) + .fetch_one(&mut conn) + .await?; + println!("{:?}", account); + + assert_eq!(1, account.id); + assert_eq!("foo", account.info.name); + assert_eq!("bar", account.info.surname); + assert_eq!(None, account.default.default); + + Ok(()) +}