From bb933decb72c2edce0cc820a734c8fa6b84aaff8 Mon Sep 17 00:00:00 2001 From: Tom Dohrmann Date: Sat, 18 Jan 2020 13:30:16 +0100 Subject: [PATCH 1/7] add derives for Encode and Decode --- sqlx-macros/src/derives.rs | 110 +++++++++++++++++++++++++++++++++++++ sqlx-macros/src/lib.rs | 20 +++++++ src/decode.rs | 4 ++ src/encode.rs | 4 ++ src/lib.rs | 6 +- tests/derives.rs | 27 +++++++++ 6 files changed, 170 insertions(+), 1 deletion(-) create mode 100644 sqlx-macros/src/derives.rs create mode 100644 src/decode.rs create mode 100644 src/encode.rs create mode 100644 tests/derives.rs diff --git a/sqlx-macros/src/derives.rs b/sqlx-macros/src/derives.rs new file mode 100644 index 00000000..07e2fa8d --- /dev/null +++ b/sqlx-macros/src/derives.rs @@ -0,0 +1,110 @@ +use proc_macro2::Span; +use quote::quote; +use syn::punctuated::Punctuated; +use syn::token::{Gt, Lt, Where}; +use syn::{ + parse_quote, Data, DataStruct, DeriveInput, Fields, PredicateType, Token, WhereClause, + WherePredicate, +}; + +pub(crate) fn expand_derive_encode(input: DeriveInput) -> syn::Result { + if let Data::Struct(DataStruct { + fields: Fields::Unnamed(fields), + .. + }) = &input.data + { + let fields = &fields.unnamed; + if fields.len() == 1 { + let ident = &input.ident; + let ty = &fields.first().unwrap().ty; + + // extract type generics + let generics = &input.generics; + let (_, ty_generics, _) = generics.split_for_impl(); + + // add db type for impl generics & where clause + let mut generics = generics.clone(); + generics.params.insert(0, parse_quote!(DB: sqlx::Database)); + generics + .make_where_clause() + .predicates + .push(parse_quote!(#ty: sqlx::encode::Encode)); + let (impl_generics, _, where_clause) = generics.split_for_impl(); + + Ok(quote!( + impl #impl_generics sqlx::encode::Encode for #ident #ty_generics #where_clause { + fn encode(&self, buf: &mut std::vec::Vec) { + sqlx::encode::Encode::encode(&self.0, buf) + } + fn encode_nullable(&self, buf: &mut std::vec::Vec) -> sqlx::encode::IsNull { + sqlx::encode::Encode::encode_nullable(&self.0, buf) + } + fn size_hint(&self) -> usize { + sqlx::encode::Encode::size_hint(&self.0) + } + } + )) + } else { + Err(syn::Error::new_spanned( + input, + "expected a tuple struct with a single field", + )) + } + } else { + Err(syn::Error::new_spanned( + input, + "expected a tuple struct with a single field", + )) + } +} + +pub(crate) fn expand_derive_decode(input: DeriveInput) -> syn::Result { + if let Data::Struct(DataStruct { + fields: Fields::Unnamed(fields), + .. + }) = &input.data + { + let fields = &fields.unnamed; + if fields.len() == 1 { + let ident = &input.ident; + let ty = &fields.first().unwrap().ty; + + // extract type generics + let generics = &input.generics; + let (_, ty_generics, _) = generics.split_for_impl(); + + // add db type for impl generics & where clause + let mut generics = generics.clone(); + generics.params.insert(0, parse_quote!(DB: sqlx::Database)); + generics + .make_where_clause() + .predicates + .push(parse_quote!(#ty: sqlx::decode::Decode)); + let (impl_generics, _, where_clause) = generics.split_for_impl(); + + Ok(quote!( + impl #impl_generics sqlx::decode::Decode for #ident #ty_generics #where_clause { + fn decode(raw: &[u8]) -> std::result::Result { + <#ty as sqlx::decode::Decode>::decode(raw).map(Self) + } + fn decode_null() -> std::result::Result { + <#ty as sqlx::decode::Decode>::decode_null().map(Self) + } + fn decode_nullable(raw: std::option::Option<&[u8]>) -> std::result::Result { + <#ty as sqlx::decode::Decode>::decode_nullable(raw).map(Self) + } + } + )) + } else { + Err(syn::Error::new_spanned( + input, + "expected a tuple struct with a single field", + )) + } + } else { + Err(syn::Error::new_spanned( + input, + "expected a tuple struct with a single field", + )) + } +} diff --git a/sqlx-macros/src/lib.rs b/sqlx-macros/src/lib.rs index 3df0d784..37eed8b0 100644 --- a/sqlx-macros/src/lib.rs +++ b/sqlx-macros/src/lib.rs @@ -21,6 +21,8 @@ type Result = std::result::Result; mod database; +mod derives; + mod query_macros; use query_macros::*; @@ -136,3 +138,21 @@ pub fn query_as(input: TokenStream) -> TokenStream { pub fn query_file_as(input: TokenStream) -> TokenStream { async_macro!(db, input: QueryAsMacroInput => expand_query_file_as(input, db)) } + +#[proc_macro_derive(Encode)] +pub fn derive_encode(tokenstream: TokenStream) -> TokenStream { + let input = syn::parse_macro_input!(tokenstream as syn::DeriveInput); + match derives::expand_derive_encode(input) { + Ok(ts) => ts.into(), + Err(e) => e.to_compile_error().into(), + } +} + +#[proc_macro_derive(Decode)] +pub fn derive_decode(tokenstream: TokenStream) -> TokenStream { + let input = syn::parse_macro_input!(tokenstream as syn::DeriveInput); + match derives::expand_derive_decode(input) { + Ok(ts) => ts.into(), + Err(e) => e.to_compile_error().into(), + } +} diff --git a/src/decode.rs b/src/decode.rs new file mode 100644 index 00000000..22ce4554 --- /dev/null +++ b/src/decode.rs @@ -0,0 +1,4 @@ +pub use sqlx_core::decode::*; + +#[cfg(feature = "macros")] +pub use sqlx_macros::Decode; diff --git a/src/encode.rs b/src/encode.rs new file mode 100644 index 00000000..0c2b7f41 --- /dev/null +++ b/src/encode.rs @@ -0,0 +1,4 @@ +pub use sqlx_core::encode::*; + +#[cfg(feature = "macros")] +pub use sqlx_macros::Encode; diff --git a/src/lib.rs b/src/lib.rs index 5d249e34..fffa92b8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,7 +7,7 @@ compile_error!("one of 'runtime-async-std' or 'runtime-tokio' features must be e compile_error!("only one of 'runtime-async-std' or 'runtime-tokio' features must be enabled"); // Modules -pub use sqlx_core::{arguments, decode, describe, encode, error, pool, row, types}; +pub use sqlx_core::{arguments, describe, error, pool, row, types}; // Types pub use sqlx_core::{ @@ -42,3 +42,7 @@ pub mod ty_cons; #[cfg(feature = "macros")] #[doc(hidden)] pub mod result_ext; + +pub mod encode; + +pub mod decode; diff --git a/tests/derives.rs b/tests/derives.rs new file mode 100644 index 00000000..0de3fc6d --- /dev/null +++ b/tests/derives.rs @@ -0,0 +1,27 @@ +#[test] +#[cfg(feature = "macros")] +fn encode() { + use sqlx::encode::Encode; + + #[derive(Encode)] + struct Foo(i32); + + #[cfg(feature = "postgres")] + let _: Box> = Box::new(Foo(1)); + #[cfg(feature = "mysql")] + let _: Box> = Box::new(Foo(1)); +} + +#[test] +#[cfg(feature = "macros")] +fn decode() { + use sqlx::decode::Decode; + + #[derive(Decode)] + struct Foo(i32); + + #[cfg(feature = "postgres")] + let _: Box> = Box::new(Foo(1)); + #[cfg(feature = "mysql")] + let _: Box> = Box::new(Foo(1)); +} From 24007b143a1f4fccb324e924703956475548b51a Mon Sep 17 00:00:00 2001 From: Tom Dohrmann Date: Sat, 18 Jan 2020 13:39:49 +0100 Subject: [PATCH 2/7] fix tests for decode derive --- tests/derives.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/derives.rs b/tests/derives.rs index 0de3fc6d..e9b36e0a 100644 --- a/tests/derives.rs +++ b/tests/derives.rs @@ -21,7 +21,7 @@ fn decode() { struct Foo(i32); #[cfg(feature = "postgres")] - let _: Box> = Box::new(Foo(1)); + >::decode_null().ok(); #[cfg(feature = "mysql")] - let _: Box> = Box::new(Foo(1)); + >::decode_null().ok(); } From 40a0e113b70fe0f88fb49da2a2c78a7d66d83158 Mon Sep 17 00:00:00 2001 From: Tom Dohrmann Date: Fri, 24 Jan 2020 18:34:57 +0100 Subject: [PATCH 3/7] copy doc comment --- src/decode.rs | 2 ++ src/encode.rs | 2 ++ 2 files changed, 4 insertions(+) diff --git a/src/decode.rs b/src/decode.rs index 22ce4554..dcc99713 100644 --- a/src/decode.rs +++ b/src/decode.rs @@ -1,3 +1,5 @@ +//! Types and traits for decoding values from the database. + pub use sqlx_core::decode::*; #[cfg(feature = "macros")] diff --git a/src/encode.rs b/src/encode.rs index 0c2b7f41..d25c971c 100644 --- a/src/encode.rs +++ b/src/encode.rs @@ -1,3 +1,5 @@ +//! Types and traits for encoding values to the database. + pub use sqlx_core::encode::*; #[cfg(feature = "macros")] From 60ef8627134989eb64b11106538f51ac8160cf91 Mon Sep 17 00:00:00 2001 From: Tom Dohrmann Date: Fri, 24 Jan 2020 18:35:25 +0100 Subject: [PATCH 4/7] lift if to pattern match with pattern guard --- sqlx-macros/src/derives.rs | 52 +++++++++++++------------------------- 1 file changed, 18 insertions(+), 34 deletions(-) diff --git a/sqlx-macros/src/derives.rs b/sqlx-macros/src/derives.rs index 07e2fa8d..c4e5a7e3 100644 --- a/sqlx-macros/src/derives.rs +++ b/sqlx-macros/src/derives.rs @@ -3,20 +3,18 @@ use quote::quote; use syn::punctuated::Punctuated; use syn::token::{Gt, Lt, Where}; use syn::{ - parse_quote, Data, DataStruct, DeriveInput, Fields, PredicateType, Token, WhereClause, - WherePredicate, + parse_quote, Data, DataStruct, DeriveInput, Fields, FieldsUnnamed, PredicateType, Token, + WhereClause, WherePredicate, }; pub(crate) fn expand_derive_encode(input: DeriveInput) -> syn::Result { - if let Data::Struct(DataStruct { - fields: Fields::Unnamed(fields), - .. - }) = &input.data - { - let fields = &fields.unnamed; - if fields.len() == 1 { + match &input.data { + Data::Struct(DataStruct { + fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }), + .. + }) if unnamed.len() == 1 => { let ident = &input.ident; - let ty = &fields.first().unwrap().ty; + let ty = &unnamed.first().unwrap().ty; // extract type generics let generics = &input.generics; @@ -44,30 +42,22 @@ pub(crate) fn expand_derive_encode(input: DeriveInput) -> syn::Result Err(syn::Error::new_spanned( input, "expected a tuple struct with a single field", - )) + )), } } pub(crate) fn expand_derive_decode(input: DeriveInput) -> syn::Result { - if let Data::Struct(DataStruct { - fields: Fields::Unnamed(fields), - .. - }) = &input.data - { - let fields = &fields.unnamed; - if fields.len() == 1 { + match &input.data { + Data::Struct(DataStruct { + fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }), + .. + }) if unnamed.len() == 1 => { let ident = &input.ident; - let ty = &fields.first().unwrap().ty; + let ty = &unnamed.first().unwrap().ty; // extract type generics let generics = &input.generics; @@ -95,16 +85,10 @@ pub(crate) fn expand_derive_decode(input: DeriveInput) -> syn::Result Err(syn::Error::new_spanned( input, "expected a tuple struct with a single field", - )) + )), } } From 1e9e816fab374c8e070defbac5d7e8c07d1b9db0 Mon Sep 17 00:00:00 2001 From: Tom Dohrmann Date: Fri, 24 Jan 2020 18:36:04 +0100 Subject: [PATCH 5/7] add proper encode/decode test --- Cargo.toml | 4 +++ tests/derives.rs | 77 ++++++++++++++++++++++++++++++++++-------------- 2 files changed, 59 insertions(+), 22 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 4889f558..84b7dfd7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -84,5 +84,9 @@ required-features = [ "mysql" ] name = "mysql-types-chrono" required-features = [ "mysql", "chrono", "macros" ] +[[test]] +name = "derives" +required-features = [ "macros" ] + [profile.release] lto = true diff --git a/tests/derives.rs b/tests/derives.rs index e9b36e0a..9c7aa411 100644 --- a/tests/derives.rs +++ b/tests/derives.rs @@ -1,27 +1,60 @@ +use sqlx::decode::Decode; +use sqlx::encode::Encode; + +#[derive(PartialEq, Debug, Encode, Decode)] +struct Foo(i32); + #[test] -#[cfg(feature = "macros")] -fn encode() { - use sqlx::encode::Encode; - - #[derive(Encode)] - struct Foo(i32); - - #[cfg(feature = "postgres")] - let _: Box> = Box::new(Foo(1)); - #[cfg(feature = "mysql")] - let _: Box> = Box::new(Foo(1)); +#[cfg(feature = "mysql")] +fn encode_mysql() { + encode_with_db::(); } #[test] -#[cfg(feature = "macros")] -fn decode() { - use sqlx::decode::Decode; - - #[derive(Decode)] - struct Foo(i32); - - #[cfg(feature = "postgres")] - >::decode_null().ok(); - #[cfg(feature = "mysql")] - >::decode_null().ok(); +#[cfg(feature = "postgres")] +fn encode_postgres() { + encode_with_db::(); +} + +#[allow(dead_code)] +fn encode_with_db() +where + Foo: Encode, + i32: Encode, +{ + let example = Foo(0x1122_3344); + + let mut encoded = Vec::new(); + let mut encoded_orig = Vec::new(); + + Encode::::encode(&example, &mut encoded); + Encode::::encode(&example.0, &mut encoded_orig); + + assert_eq!(encoded, encoded_orig); +} + +#[test] +#[cfg(feature = "mysql")] +fn decode_mysql() { + decode_with_db::(); +} + +#[test] +#[cfg(feature = "postgres")] +fn decode_postgres() { + decode_with_db::(); +} + +#[allow(dead_code)] +fn decode_with_db() +where + Foo: Decode + Encode, +{ + let example = Foo(0x1122_3344); + + let mut encoded = Vec::new(); + Encode::::encode(&example, &mut encoded); + + let decoded = Foo::decode(&encoded).unwrap(); + assert_eq!(example, decoded); } From a94e60880d78474a0b2b33e24fe53aebacdfdc10 Mon Sep 17 00:00:00 2001 From: Tom Dohrmann Date: Fri, 24 Jan 2020 18:52:01 +0100 Subject: [PATCH 6/7] remove unused imports --- sqlx-macros/src/derives.rs | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/sqlx-macros/src/derives.rs b/sqlx-macros/src/derives.rs index c4e5a7e3..acd9ec58 100644 --- a/sqlx-macros/src/derives.rs +++ b/sqlx-macros/src/derives.rs @@ -1,10 +1,6 @@ -use proc_macro2::Span; use quote::quote; -use syn::punctuated::Punctuated; -use syn::token::{Gt, Lt, Where}; use syn::{ - parse_quote, Data, DataStruct, DeriveInput, Fields, FieldsUnnamed, PredicateType, Token, - WhereClause, WherePredicate, + parse_quote, Data, DataStruct, DeriveInput, Fields, FieldsUnnamed, }; pub(crate) fn expand_derive_encode(input: DeriveInput) -> syn::Result { From 8ca342932772f50f1a887dbcc54fe567d9f1aba1 Mon Sep 17 00:00:00 2001 From: Tom Dohrmann Date: Fri, 24 Jan 2020 18:54:41 +0100 Subject: [PATCH 7/7] format code --- sqlx-macros/src/derives.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sqlx-macros/src/derives.rs b/sqlx-macros/src/derives.rs index acd9ec58..460b05c8 100644 --- a/sqlx-macros/src/derives.rs +++ b/sqlx-macros/src/derives.rs @@ -1,7 +1,5 @@ use quote::quote; -use syn::{ - parse_quote, Data, DataStruct, DeriveInput, Fields, FieldsUnnamed, -}; +use syn::{parse_quote, Data, DataStruct, DeriveInput, Fields, FieldsUnnamed}; pub(crate) fn expand_derive_encode(input: DeriveInput) -> syn::Result { match &input.data {