From bb933decb72c2edce0cc820a734c8fa6b84aaff8 Mon Sep 17 00:00:00 2001 From: Tom Dohrmann Date: Sat, 18 Jan 2020 13:30:16 +0100 Subject: [PATCH] add derives for Encode and Decode --- sqlx-macros/src/derives.rs | 110 +++++++++++++++++++++++++++++++++++++ sqlx-macros/src/lib.rs | 20 +++++++ src/decode.rs | 4 ++ src/encode.rs | 4 ++ src/lib.rs | 6 +- tests/derives.rs | 27 +++++++++ 6 files changed, 170 insertions(+), 1 deletion(-) create mode 100644 sqlx-macros/src/derives.rs create mode 100644 src/decode.rs create mode 100644 src/encode.rs create mode 100644 tests/derives.rs diff --git a/sqlx-macros/src/derives.rs b/sqlx-macros/src/derives.rs new file mode 100644 index 00000000..07e2fa8d --- /dev/null +++ b/sqlx-macros/src/derives.rs @@ -0,0 +1,110 @@ +use proc_macro2::Span; +use quote::quote; +use syn::punctuated::Punctuated; +use syn::token::{Gt, Lt, Where}; +use syn::{ + parse_quote, Data, DataStruct, DeriveInput, Fields, PredicateType, Token, WhereClause, + WherePredicate, +}; + +pub(crate) fn expand_derive_encode(input: DeriveInput) -> syn::Result { + if let Data::Struct(DataStruct { + fields: Fields::Unnamed(fields), + .. + }) = &input.data + { + let fields = &fields.unnamed; + if fields.len() == 1 { + let ident = &input.ident; + let ty = &fields.first().unwrap().ty; + + // extract type generics + let generics = &input.generics; + let (_, ty_generics, _) = generics.split_for_impl(); + + // add db type for impl generics & where clause + let mut generics = generics.clone(); + generics.params.insert(0, parse_quote!(DB: sqlx::Database)); + generics + .make_where_clause() + .predicates + .push(parse_quote!(#ty: sqlx::encode::Encode)); + let (impl_generics, _, where_clause) = generics.split_for_impl(); + + Ok(quote!( + impl #impl_generics sqlx::encode::Encode for #ident #ty_generics #where_clause { + fn encode(&self, buf: &mut std::vec::Vec) { + sqlx::encode::Encode::encode(&self.0, buf) + } + fn encode_nullable(&self, buf: &mut std::vec::Vec) -> sqlx::encode::IsNull { + sqlx::encode::Encode::encode_nullable(&self.0, buf) + } + fn size_hint(&self) -> usize { + sqlx::encode::Encode::size_hint(&self.0) + } + } + )) + } else { + Err(syn::Error::new_spanned( + input, + "expected a tuple struct with a single field", + )) + } + } else { + Err(syn::Error::new_spanned( + input, + "expected a tuple struct with a single field", + )) + } +} + +pub(crate) fn expand_derive_decode(input: DeriveInput) -> syn::Result { + if let Data::Struct(DataStruct { + fields: Fields::Unnamed(fields), + .. + }) = &input.data + { + let fields = &fields.unnamed; + if fields.len() == 1 { + let ident = &input.ident; + let ty = &fields.first().unwrap().ty; + + // extract type generics + let generics = &input.generics; + let (_, ty_generics, _) = generics.split_for_impl(); + + // add db type for impl generics & where clause + let mut generics = generics.clone(); + generics.params.insert(0, parse_quote!(DB: sqlx::Database)); + generics + .make_where_clause() + .predicates + .push(parse_quote!(#ty: sqlx::decode::Decode)); + let (impl_generics, _, where_clause) = generics.split_for_impl(); + + Ok(quote!( + impl #impl_generics sqlx::decode::Decode for #ident #ty_generics #where_clause { + fn decode(raw: &[u8]) -> std::result::Result { + <#ty as sqlx::decode::Decode>::decode(raw).map(Self) + } + fn decode_null() -> std::result::Result { + <#ty as sqlx::decode::Decode>::decode_null().map(Self) + } + fn decode_nullable(raw: std::option::Option<&[u8]>) -> std::result::Result { + <#ty as sqlx::decode::Decode>::decode_nullable(raw).map(Self) + } + } + )) + } else { + Err(syn::Error::new_spanned( + input, + "expected a tuple struct with a single field", + )) + } + } else { + Err(syn::Error::new_spanned( + input, + "expected a tuple struct with a single field", + )) + } +} diff --git a/sqlx-macros/src/lib.rs b/sqlx-macros/src/lib.rs index 3df0d784..37eed8b0 100644 --- a/sqlx-macros/src/lib.rs +++ b/sqlx-macros/src/lib.rs @@ -21,6 +21,8 @@ type Result = std::result::Result; mod database; +mod derives; + mod query_macros; use query_macros::*; @@ -136,3 +138,21 @@ pub fn query_as(input: TokenStream) -> TokenStream { pub fn query_file_as(input: TokenStream) -> TokenStream { async_macro!(db, input: QueryAsMacroInput => expand_query_file_as(input, db)) } + +#[proc_macro_derive(Encode)] +pub fn derive_encode(tokenstream: TokenStream) -> TokenStream { + let input = syn::parse_macro_input!(tokenstream as syn::DeriveInput); + match derives::expand_derive_encode(input) { + Ok(ts) => ts.into(), + Err(e) => e.to_compile_error().into(), + } +} + +#[proc_macro_derive(Decode)] +pub fn derive_decode(tokenstream: TokenStream) -> TokenStream { + let input = syn::parse_macro_input!(tokenstream as syn::DeriveInput); + match derives::expand_derive_decode(input) { + Ok(ts) => ts.into(), + Err(e) => e.to_compile_error().into(), + } +} diff --git a/src/decode.rs b/src/decode.rs new file mode 100644 index 00000000..22ce4554 --- /dev/null +++ b/src/decode.rs @@ -0,0 +1,4 @@ +pub use sqlx_core::decode::*; + +#[cfg(feature = "macros")] +pub use sqlx_macros::Decode; diff --git a/src/encode.rs b/src/encode.rs new file mode 100644 index 00000000..0c2b7f41 --- /dev/null +++ b/src/encode.rs @@ -0,0 +1,4 @@ +pub use sqlx_core::encode::*; + +#[cfg(feature = "macros")] +pub use sqlx_macros::Encode; diff --git a/src/lib.rs b/src/lib.rs index 5d249e34..fffa92b8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,7 +7,7 @@ compile_error!("one of 'runtime-async-std' or 'runtime-tokio' features must be e compile_error!("only one of 'runtime-async-std' or 'runtime-tokio' features must be enabled"); // Modules -pub use sqlx_core::{arguments, decode, describe, encode, error, pool, row, types}; +pub use sqlx_core::{arguments, describe, error, pool, row, types}; // Types pub use sqlx_core::{ @@ -42,3 +42,7 @@ pub mod ty_cons; #[cfg(feature = "macros")] #[doc(hidden)] pub mod result_ext; + +pub mod encode; + +pub mod decode; diff --git a/tests/derives.rs b/tests/derives.rs new file mode 100644 index 00000000..0de3fc6d --- /dev/null +++ b/tests/derives.rs @@ -0,0 +1,27 @@ +#[test] +#[cfg(feature = "macros")] +fn encode() { + use sqlx::encode::Encode; + + #[derive(Encode)] + struct Foo(i32); + + #[cfg(feature = "postgres")] + let _: Box> = Box::new(Foo(1)); + #[cfg(feature = "mysql")] + let _: Box> = Box::new(Foo(1)); +} + +#[test] +#[cfg(feature = "macros")] +fn decode() { + use sqlx::decode::Decode; + + #[derive(Decode)] + struct Foo(i32); + + #[cfg(feature = "postgres")] + let _: Box> = Box::new(Foo(1)); + #[cfg(feature = "mysql")] + let _: Box> = Box::new(Foo(1)); +}