add derives for Encode and Decode

This commit is contained in:
Tom Dohrmann 2020-01-18 13:30:16 +01:00
parent 9141bd7561
commit bb933decb7
6 changed files with 170 additions and 1 deletions

110
sqlx-macros/src/derives.rs Normal file
View File

@ -0,0 +1,110 @@
use proc_macro2::Span;
use quote::quote;
use syn::punctuated::Punctuated;
use syn::token::{Gt, Lt, Where};
use syn::{
parse_quote, Data, DataStruct, DeriveInput, Fields, PredicateType, Token, WhereClause,
WherePredicate,
};
pub(crate) fn expand_derive_encode(input: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
if let Data::Struct(DataStruct {
fields: Fields::Unnamed(fields),
..
}) = &input.data
{
let fields = &fields.unnamed;
if fields.len() == 1 {
let ident = &input.ident;
let ty = &fields.first().unwrap().ty;
// extract type generics
let generics = &input.generics;
let (_, ty_generics, _) = generics.split_for_impl();
// add db type for impl generics & where clause
let mut generics = generics.clone();
generics.params.insert(0, parse_quote!(DB: sqlx::Database));
generics
.make_where_clause()
.predicates
.push(parse_quote!(#ty: sqlx::encode::Encode<DB>));
let (impl_generics, _, where_clause) = generics.split_for_impl();
Ok(quote!(
impl #impl_generics sqlx::encode::Encode<DB> for #ident #ty_generics #where_clause {
fn encode(&self, buf: &mut std::vec::Vec<u8>) {
sqlx::encode::Encode::encode(&self.0, buf)
}
fn encode_nullable(&self, buf: &mut std::vec::Vec<u8>) -> sqlx::encode::IsNull {
sqlx::encode::Encode::encode_nullable(&self.0, buf)
}
fn size_hint(&self) -> usize {
sqlx::encode::Encode::size_hint(&self.0)
}
}
))
} else {
Err(syn::Error::new_spanned(
input,
"expected a tuple struct with a single field",
))
}
} else {
Err(syn::Error::new_spanned(
input,
"expected a tuple struct with a single field",
))
}
}
pub(crate) fn expand_derive_decode(input: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
if let Data::Struct(DataStruct {
fields: Fields::Unnamed(fields),
..
}) = &input.data
{
let fields = &fields.unnamed;
if fields.len() == 1 {
let ident = &input.ident;
let ty = &fields.first().unwrap().ty;
// extract type generics
let generics = &input.generics;
let (_, ty_generics, _) = generics.split_for_impl();
// add db type for impl generics & where clause
let mut generics = generics.clone();
generics.params.insert(0, parse_quote!(DB: sqlx::Database));
generics
.make_where_clause()
.predicates
.push(parse_quote!(#ty: sqlx::decode::Decode<DB>));
let (impl_generics, _, where_clause) = generics.split_for_impl();
Ok(quote!(
impl #impl_generics sqlx::decode::Decode<DB> for #ident #ty_generics #where_clause {
fn decode(raw: &[u8]) -> std::result::Result<Self, sqlx::decode::DecodeError> {
<#ty as sqlx::decode::Decode<DB>>::decode(raw).map(Self)
}
fn decode_null() -> std::result::Result<Self, sqlx::decode::DecodeError> {
<#ty as sqlx::decode::Decode<DB>>::decode_null().map(Self)
}
fn decode_nullable(raw: std::option::Option<&[u8]>) -> std::result::Result<Self, sqlx::decode::DecodeError> {
<#ty as sqlx::decode::Decode<DB>>::decode_nullable(raw).map(Self)
}
}
))
} else {
Err(syn::Error::new_spanned(
input,
"expected a tuple struct with a single field",
))
}
} else {
Err(syn::Error::new_spanned(
input,
"expected a tuple struct with a single field",
))
}
}

View File

@ -21,6 +21,8 @@ type Result<T> = std::result::Result<T, Error>;
mod database;
mod derives;
mod query_macros;
use query_macros::*;
@ -136,3 +138,21 @@ pub fn query_as(input: TokenStream) -> TokenStream {
pub fn query_file_as(input: TokenStream) -> TokenStream {
async_macro!(db, input: QueryAsMacroInput => expand_query_file_as(input, db))
}
#[proc_macro_derive(Encode)]
pub fn derive_encode(tokenstream: TokenStream) -> TokenStream {
let input = syn::parse_macro_input!(tokenstream as syn::DeriveInput);
match derives::expand_derive_encode(input) {
Ok(ts) => ts.into(),
Err(e) => e.to_compile_error().into(),
}
}
#[proc_macro_derive(Decode)]
pub fn derive_decode(tokenstream: TokenStream) -> TokenStream {
let input = syn::parse_macro_input!(tokenstream as syn::DeriveInput);
match derives::expand_derive_decode(input) {
Ok(ts) => ts.into(),
Err(e) => e.to_compile_error().into(),
}
}

4
src/decode.rs Normal file
View File

@ -0,0 +1,4 @@
pub use sqlx_core::decode::*;
#[cfg(feature = "macros")]
pub use sqlx_macros::Decode;

4
src/encode.rs Normal file
View File

@ -0,0 +1,4 @@
pub use sqlx_core::encode::*;
#[cfg(feature = "macros")]
pub use sqlx_macros::Encode;

View File

@ -7,7 +7,7 @@ compile_error!("one of 'runtime-async-std' or 'runtime-tokio' features must be e
compile_error!("only one of 'runtime-async-std' or 'runtime-tokio' features must be enabled");
// Modules
pub use sqlx_core::{arguments, decode, describe, encode, error, pool, row, types};
pub use sqlx_core::{arguments, describe, error, pool, row, types};
// Types
pub use sqlx_core::{
@ -42,3 +42,7 @@ pub mod ty_cons;
#[cfg(feature = "macros")]
#[doc(hidden)]
pub mod result_ext;
pub mod encode;
pub mod decode;

27
tests/derives.rs Normal file
View File

@ -0,0 +1,27 @@
#[test]
#[cfg(feature = "macros")]
fn encode() {
use sqlx::encode::Encode;
#[derive(Encode)]
struct Foo(i32);
#[cfg(feature = "postgres")]
let _: Box<dyn Encode<sqlx::Postgres>> = Box::new(Foo(1));
#[cfg(feature = "mysql")]
let _: Box<dyn Encode<sqlx::MySql>> = Box::new(Foo(1));
}
#[test]
#[cfg(feature = "macros")]
fn decode() {
use sqlx::decode::Decode;
#[derive(Decode)]
struct Foo(i32);
#[cfg(feature = "postgres")]
let _: Box<dyn Decode<sqlx::Postgres>> = Box::new(Foo(1));
#[cfg(feature = "mysql")]
let _: Box<dyn Decode<sqlx::MySql>> = Box::new(Foo(1));
}