Merge pull request #71 from Freax13/master

add derives for Encode and Decode
This commit is contained in:
Austin Bonander 2020-01-24 13:13:06 -08:00 committed by GitHub
commit f0c88da152
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 189 additions and 1 deletions

View File

@ -85,5 +85,9 @@ required-features = [ "mysql" ]
name = "mysql-types-chrono"
required-features = [ "mysql", "chrono", "macros" ]
[[test]]
name = "derives"
required-features = [ "macros" ]
[profile.release]
lto = true

View File

@ -0,0 +1,88 @@
use quote::quote;
use syn::{parse_quote, Data, DataStruct, DeriveInput, Fields, FieldsUnnamed};
pub(crate) fn expand_derive_encode(input: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
match &input.data {
Data::Struct(DataStruct {
fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }),
..
}) if unnamed.len() == 1 => {
let ident = &input.ident;
let ty = &unnamed.first().unwrap().ty;
// extract type generics
let generics = &input.generics;
let (_, ty_generics, _) = generics.split_for_impl();
// add db type for impl generics & where clause
let mut generics = generics.clone();
generics.params.insert(0, parse_quote!(DB: sqlx::Database));
generics
.make_where_clause()
.predicates
.push(parse_quote!(#ty: sqlx::encode::Encode<DB>));
let (impl_generics, _, where_clause) = generics.split_for_impl();
Ok(quote!(
impl #impl_generics sqlx::encode::Encode<DB> for #ident #ty_generics #where_clause {
fn encode(&self, buf: &mut std::vec::Vec<u8>) {
sqlx::encode::Encode::encode(&self.0, buf)
}
fn encode_nullable(&self, buf: &mut std::vec::Vec<u8>) -> sqlx::encode::IsNull {
sqlx::encode::Encode::encode_nullable(&self.0, buf)
}
fn size_hint(&self) -> usize {
sqlx::encode::Encode::size_hint(&self.0)
}
}
))
}
_ => Err(syn::Error::new_spanned(
input,
"expected a tuple struct with a single field",
)),
}
}
pub(crate) fn expand_derive_decode(input: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
match &input.data {
Data::Struct(DataStruct {
fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }),
..
}) if unnamed.len() == 1 => {
let ident = &input.ident;
let ty = &unnamed.first().unwrap().ty;
// extract type generics
let generics = &input.generics;
let (_, ty_generics, _) = generics.split_for_impl();
// add db type for impl generics & where clause
let mut generics = generics.clone();
generics.params.insert(0, parse_quote!(DB: sqlx::Database));
generics
.make_where_clause()
.predicates
.push(parse_quote!(#ty: sqlx::decode::Decode<DB>));
let (impl_generics, _, where_clause) = generics.split_for_impl();
Ok(quote!(
impl #impl_generics sqlx::decode::Decode<DB> for #ident #ty_generics #where_clause {
fn decode(raw: &[u8]) -> std::result::Result<Self, sqlx::decode::DecodeError> {
<#ty as sqlx::decode::Decode<DB>>::decode(raw).map(Self)
}
fn decode_null() -> std::result::Result<Self, sqlx::decode::DecodeError> {
<#ty as sqlx::decode::Decode<DB>>::decode_null().map(Self)
}
fn decode_nullable(raw: std::option::Option<&[u8]>) -> std::result::Result<Self, sqlx::decode::DecodeError> {
<#ty as sqlx::decode::Decode<DB>>::decode_nullable(raw).map(Self)
}
}
))
}
_ => Err(syn::Error::new_spanned(
input,
"expected a tuple struct with a single field",
)),
}
}

View File

@ -19,6 +19,8 @@ type Result<T> = std::result::Result<T, Error>;
mod database;
mod derives;
mod query_macros;
use query_macros::*;
@ -134,3 +136,21 @@ pub fn query_as(input: TokenStream) -> TokenStream {
pub fn query_file_as(input: TokenStream) -> TokenStream {
async_macro!(db, input: QueryAsMacroInput => expand_query_file_as(input, db))
}
#[proc_macro_derive(Encode)]
pub fn derive_encode(tokenstream: TokenStream) -> TokenStream {
let input = syn::parse_macro_input!(tokenstream as syn::DeriveInput);
match derives::expand_derive_encode(input) {
Ok(ts) => ts.into(),
Err(e) => e.to_compile_error().into(),
}
}
#[proc_macro_derive(Decode)]
pub fn derive_decode(tokenstream: TokenStream) -> TokenStream {
let input = syn::parse_macro_input!(tokenstream as syn::DeriveInput);
match derives::expand_derive_decode(input) {
Ok(ts) => ts.into(),
Err(e) => e.to_compile_error().into(),
}
}

6
src/decode.rs Normal file
View File

@ -0,0 +1,6 @@
//! Types and traits for decoding values from the database.
pub use sqlx_core::decode::*;
#[cfg(feature = "macros")]
pub use sqlx_macros::Decode;

6
src/encode.rs Normal file
View File

@ -0,0 +1,6 @@
//! Types and traits for encoding values to the database.
pub use sqlx_core::encode::*;
#[cfg(feature = "macros")]
pub use sqlx_macros::Encode;

View File

@ -7,7 +7,7 @@ compile_error!("one of 'runtime-async-std' or 'runtime-tokio' features must be e
compile_error!("only one of 'runtime-async-std' or 'runtime-tokio' features must be enabled");
// Modules
pub use sqlx_core::{arguments, decode, describe, encode, error, pool, row, types};
pub use sqlx_core::{arguments, describe, error, pool, row, types};
// Types
pub use sqlx_core::{
@ -41,3 +41,7 @@ pub mod ty_cons;
#[cfg(feature = "macros")]
#[doc(hidden)]
pub mod result_ext;
pub mod encode;
pub mod decode;

60
tests/derives.rs Normal file
View File

@ -0,0 +1,60 @@
use sqlx::decode::Decode;
use sqlx::encode::Encode;
#[derive(PartialEq, Debug, Encode, Decode)]
struct Foo(i32);
#[test]
#[cfg(feature = "mysql")]
fn encode_mysql() {
encode_with_db::<sqlx::MySql>();
}
#[test]
#[cfg(feature = "postgres")]
fn encode_postgres() {
encode_with_db::<sqlx::Postgres>();
}
#[allow(dead_code)]
fn encode_with_db<DB: sqlx::Database>()
where
Foo: Encode<DB>,
i32: Encode<DB>,
{
let example = Foo(0x1122_3344);
let mut encoded = Vec::new();
let mut encoded_orig = Vec::new();
Encode::<DB>::encode(&example, &mut encoded);
Encode::<DB>::encode(&example.0, &mut encoded_orig);
assert_eq!(encoded, encoded_orig);
}
#[test]
#[cfg(feature = "mysql")]
fn decode_mysql() {
decode_with_db::<sqlx::MySql>();
}
#[test]
#[cfg(feature = "postgres")]
fn decode_postgres() {
decode_with_db::<sqlx::Postgres>();
}
#[allow(dead_code)]
fn decode_with_db<DB: sqlx::Database>()
where
Foo: Decode<DB> + Encode<DB>,
{
let example = Foo(0x1122_3344);
let mut encoded = Vec::new();
Encode::<DB>::encode(&example, &mut encoded);
let decoded = Foo::decode(&encoded).unwrap();
assert_eq!(example, decoded);
}