From 5ada3f3ae65ec188ce202e1426d6a6e905f6ad33 Mon Sep 17 00:00:00 2001 From: Ryan Leckey Date: Mon, 23 Mar 2020 21:18:03 -0700 Subject: [PATCH] Implement #[derive(FromRow)] --- sqlx-macros/src/derives/mod.rs | 2 + sqlx-macros/src/derives/row.rs | 98 ++++++++++++++++++++++++++++++++++ sqlx-macros/src/lib.rs | 10 ++++ src/lib.rs | 2 +- 4 files changed, 111 insertions(+), 1 deletion(-) create mode 100644 sqlx-macros/src/derives/row.rs diff --git a/sqlx-macros/src/derives/mod.rs b/sqlx-macros/src/derives/mod.rs index 2d8208f2..2613d1e2 100644 --- a/sqlx-macros/src/derives/mod.rs +++ b/sqlx-macros/src/derives/mod.rs @@ -1,11 +1,13 @@ mod attributes; mod decode; mod encode; +mod row; mod r#type; pub(crate) use decode::expand_derive_decode; pub(crate) use encode::expand_derive_encode; pub(crate) use r#type::expand_derive_type; +pub(crate) use row::expand_derive_from_row; use self::attributes::RenameAll; use std::iter::FromIterator; diff --git a/sqlx-macros/src/derives/row.rs b/sqlx-macros/src/derives/row.rs new file mode 100644 index 00000000..ffd0f0fd --- /dev/null +++ b/sqlx-macros/src/derives/row.rs @@ -0,0 +1,98 @@ +use proc_macro2::Span; +use quote::quote; +use syn::{ + parse_quote, punctuated::Punctuated, token::Comma, Data, DataStruct, DeriveInput, Field, + Fields, FieldsNamed, Lifetime, LifetimeDef, Stmt, +}; + +pub fn expand_derive_from_row(input: &DeriveInput) -> syn::Result { + match &input.data { + Data::Struct(DataStruct { + fields: Fields::Named(FieldsNamed { named, .. }), + .. + }) => expand_derive_from_row_struct(input, named), + + Data::Struct(DataStruct { + fields: Fields::Unnamed(_), + .. + }) => Err(syn::Error::new_spanned( + input, + "tuple structs are not supported", + )), + + Data::Struct(DataStruct { + fields: Fields::Unit, + .. + }) => Err(syn::Error::new_spanned( + input, + "unit structs are not supported", + )), + + Data::Enum(_) => Err(syn::Error::new_spanned(input, "enums are not supported")), + + Data::Union(_) => Err(syn::Error::new_spanned(input, "unions are not supported")), + } +} + +fn expand_derive_from_row_struct( + input: &DeriveInput, + fields: &Punctuated, +) -> syn::Result { + let ident = &input.ident; + + let generics = &input.generics; + + let (lifetime, provided) = generics + .lifetimes() + .next() + .map(|def| (def.lifetime.clone(), false)) + .unwrap_or_else(|| (Lifetime::new("'a", Span::call_site()), true)); + + let (_, ty_generics, _) = generics.split_for_impl(); + + let mut generics = generics.clone(); + generics + .params + .insert(0, parse_quote!(R: sqlx::Row<#lifetime>)); + + if provided { + generics.params.insert(0, parse_quote!(#lifetime)); + } + + let predicates = &mut generics.make_where_clause().predicates; + + predicates.push(parse_quote!(&#lifetime str: sqlx::row::ColumnIndex<#lifetime, R>)); + + for field in fields { + let ty = &field.ty; + + predicates.push(parse_quote!(#ty: sqlx::decode::Decode<#lifetime, R::Database>)); + predicates.push(parse_quote!(#ty: sqlx::types::Type)); + } + + let (impl_generics, _, where_clause) = generics.split_for_impl(); + + let reads = fields.iter().filter_map(|field| -> Option { + let id = &field.ident.as_ref()?; + let id_s = id.to_string(); + let ty = &field.ty; + + Some(parse_quote!( + let #id: #ty = row.try_get(#id_s)?; + )) + }); + + let names = fields.iter().map(|field| &field.ident); + + Ok(quote!( + impl #impl_generics sqlx::row::FromRow<#lifetime, R> for #ident #ty_generics #where_clause { + fn from_row(row: R) -> sqlx::Result { + #(#reads)* + + Ok(#ident { + #(#names),* + }) + } + } + )) +} diff --git a/sqlx-macros/src/lib.rs b/sqlx-macros/src/lib.rs index d0a19995..0ca9279a 100644 --- a/sqlx-macros/src/lib.rs +++ b/sqlx-macros/src/lib.rs @@ -177,3 +177,13 @@ pub fn derive_type(tokenstream: TokenStream) -> TokenStream { Err(e) => e.to_compile_error().into(), } } + +#[proc_macro_derive(FromRow)] +pub fn derive_from_row(input: TokenStream) -> TokenStream { + let input = syn::parse_macro_input!(input as syn::DeriveInput); + + match derives::expand_derive_from_row(&input) { + Ok(ts) => ts.into(), + Err(e) => e.to_compile_error().into(), + } +} diff --git a/src/lib.rs b/src/lib.rs index 89919b25..659be123 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -41,7 +41,7 @@ pub use sqlx_core::sqlite::{self, Sqlite, SqliteConnection, SqlitePool}; pub extern crate sqlx_macros; #[cfg(feature = "macros")] -pub use sqlx_macros::Type; +pub use sqlx_macros::{FromRow, Type}; #[cfg(feature = "macros")] mod macros;