diff --git a/Cargo.toml b/Cargo.toml index 49a716e7..300f6443 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -85,5 +85,9 @@ required-features = [ "mysql" ] name = "mysql-types-chrono" required-features = [ "mysql", "chrono", "macros" ] +[[test]] +name = "derives" +required-features = [ "macros" ] + [profile.release] lto = true diff --git a/sqlx-macros/src/derives.rs b/sqlx-macros/src/derives.rs new file mode 100644 index 00000000..460b05c8 --- /dev/null +++ b/sqlx-macros/src/derives.rs @@ -0,0 +1,88 @@ +use quote::quote; +use syn::{parse_quote, Data, DataStruct, DeriveInput, Fields, FieldsUnnamed}; + +pub(crate) fn expand_derive_encode(input: DeriveInput) -> syn::Result { + match &input.data { + Data::Struct(DataStruct { + fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }), + .. + }) if unnamed.len() == 1 => { + let ident = &input.ident; + let ty = &unnamed.first().unwrap().ty; + + // extract type generics + let generics = &input.generics; + let (_, ty_generics, _) = generics.split_for_impl(); + + // add db type for impl generics & where clause + let mut generics = generics.clone(); + generics.params.insert(0, parse_quote!(DB: sqlx::Database)); + generics + .make_where_clause() + .predicates + .push(parse_quote!(#ty: sqlx::encode::Encode)); + let (impl_generics, _, where_clause) = generics.split_for_impl(); + + Ok(quote!( + impl #impl_generics sqlx::encode::Encode for #ident #ty_generics #where_clause { + fn encode(&self, buf: &mut std::vec::Vec) { + sqlx::encode::Encode::encode(&self.0, buf) + } + fn encode_nullable(&self, buf: &mut std::vec::Vec) -> sqlx::encode::IsNull { + sqlx::encode::Encode::encode_nullable(&self.0, buf) + } + fn size_hint(&self) -> usize { + sqlx::encode::Encode::size_hint(&self.0) + } + } + )) + } + _ => Err(syn::Error::new_spanned( + input, + "expected a tuple struct with a single field", + )), + } +} + +pub(crate) fn expand_derive_decode(input: DeriveInput) -> syn::Result { + match &input.data { + Data::Struct(DataStruct { + fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }), + .. + }) if unnamed.len() == 1 => { + let ident = &input.ident; + let ty = &unnamed.first().unwrap().ty; + + // extract type generics + let generics = &input.generics; + let (_, ty_generics, _) = generics.split_for_impl(); + + // add db type for impl generics & where clause + let mut generics = generics.clone(); + generics.params.insert(0, parse_quote!(DB: sqlx::Database)); + generics + .make_where_clause() + .predicates + .push(parse_quote!(#ty: sqlx::decode::Decode)); + let (impl_generics, _, where_clause) = generics.split_for_impl(); + + Ok(quote!( + impl #impl_generics sqlx::decode::Decode for #ident #ty_generics #where_clause { + fn decode(raw: &[u8]) -> std::result::Result { + <#ty as sqlx::decode::Decode>::decode(raw).map(Self) + } + fn decode_null() -> std::result::Result { + <#ty as sqlx::decode::Decode>::decode_null().map(Self) + } + fn decode_nullable(raw: std::option::Option<&[u8]>) -> std::result::Result { + <#ty as sqlx::decode::Decode>::decode_nullable(raw).map(Self) + } + } + )) + } + _ => Err(syn::Error::new_spanned( + input, + "expected a tuple struct with a single field", + )), + } +} diff --git a/sqlx-macros/src/lib.rs b/sqlx-macros/src/lib.rs index 7ae98436..ede05206 100644 --- a/sqlx-macros/src/lib.rs +++ b/sqlx-macros/src/lib.rs @@ -19,6 +19,8 @@ type Result = std::result::Result; mod database; +mod derives; + mod query_macros; use query_macros::*; @@ -134,3 +136,21 @@ pub fn query_as(input: TokenStream) -> TokenStream { pub fn query_file_as(input: TokenStream) -> TokenStream { async_macro!(db, input: QueryAsMacroInput => expand_query_file_as(input, db)) } + +#[proc_macro_derive(Encode)] +pub fn derive_encode(tokenstream: TokenStream) -> TokenStream { + let input = syn::parse_macro_input!(tokenstream as syn::DeriveInput); + match derives::expand_derive_encode(input) { + Ok(ts) => ts.into(), + Err(e) => e.to_compile_error().into(), + } +} + +#[proc_macro_derive(Decode)] +pub fn derive_decode(tokenstream: TokenStream) -> TokenStream { + let input = syn::parse_macro_input!(tokenstream as syn::DeriveInput); + match derives::expand_derive_decode(input) { + Ok(ts) => ts.into(), + Err(e) => e.to_compile_error().into(), + } +} diff --git a/src/decode.rs b/src/decode.rs new file mode 100644 index 00000000..dcc99713 --- /dev/null +++ b/src/decode.rs @@ -0,0 +1,6 @@ +//! Types and traits for decoding values from the database. + +pub use sqlx_core::decode::*; + +#[cfg(feature = "macros")] +pub use sqlx_macros::Decode; diff --git a/src/encode.rs b/src/encode.rs new file mode 100644 index 00000000..d25c971c --- /dev/null +++ b/src/encode.rs @@ -0,0 +1,6 @@ +//! Types and traits for encoding values to the database. + +pub use sqlx_core::encode::*; + +#[cfg(feature = "macros")] +pub use sqlx_macros::Encode; diff --git a/src/lib.rs b/src/lib.rs index 0a861f48..46e5e8bb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,7 +7,7 @@ compile_error!("one of 'runtime-async-std' or 'runtime-tokio' features must be e compile_error!("only one of 'runtime-async-std' or 'runtime-tokio' features must be enabled"); // Modules -pub use sqlx_core::{arguments, decode, describe, encode, error, pool, row, types}; +pub use sqlx_core::{arguments, describe, error, pool, row, types}; // Types pub use sqlx_core::{ @@ -41,3 +41,7 @@ pub mod ty_cons; #[cfg(feature = "macros")] #[doc(hidden)] pub mod result_ext; + +pub mod encode; + +pub mod decode; diff --git a/tests/derives.rs b/tests/derives.rs new file mode 100644 index 00000000..9c7aa411 --- /dev/null +++ b/tests/derives.rs @@ -0,0 +1,60 @@ +use sqlx::decode::Decode; +use sqlx::encode::Encode; + +#[derive(PartialEq, Debug, Encode, Decode)] +struct Foo(i32); + +#[test] +#[cfg(feature = "mysql")] +fn encode_mysql() { + encode_with_db::(); +} + +#[test] +#[cfg(feature = "postgres")] +fn encode_postgres() { + encode_with_db::(); +} + +#[allow(dead_code)] +fn encode_with_db() +where + Foo: Encode, + i32: Encode, +{ + let example = Foo(0x1122_3344); + + let mut encoded = Vec::new(); + let mut encoded_orig = Vec::new(); + + Encode::::encode(&example, &mut encoded); + Encode::::encode(&example.0, &mut encoded_orig); + + assert_eq!(encoded, encoded_orig); +} + +#[test] +#[cfg(feature = "mysql")] +fn decode_mysql() { + decode_with_db::(); +} + +#[test] +#[cfg(feature = "postgres")] +fn decode_postgres() { + decode_with_db::(); +} + +#[allow(dead_code)] +fn decode_with_db() +where + Foo: Decode + Encode, +{ + let example = Foo(0x1122_3344); + + let mut encoded = Vec::new(); + Encode::::encode(&example, &mut encoded); + + let decoded = Foo::decode(&encoded).unwrap(); + assert_eq!(example, decoded); +}