mirror of
https://github.com/launchbadge/sqlx.git
synced 2025-10-02 15:25:32 +00:00
Merge pull request #71 from Freax13/master
add derives for Encode and Decode
This commit is contained in:
commit
f0c88da152
@ -85,5 +85,9 @@ required-features = [ "mysql" ]
|
||||
name = "mysql-types-chrono"
|
||||
required-features = [ "mysql", "chrono", "macros" ]
|
||||
|
||||
[[test]]
|
||||
name = "derives"
|
||||
required-features = [ "macros" ]
|
||||
|
||||
[profile.release]
|
||||
lto = true
|
||||
|
88
sqlx-macros/src/derives.rs
Normal file
88
sqlx-macros/src/derives.rs
Normal file
@ -0,0 +1,88 @@
|
||||
use quote::quote;
|
||||
use syn::{parse_quote, Data, DataStruct, DeriveInput, Fields, FieldsUnnamed};
|
||||
|
||||
pub(crate) fn expand_derive_encode(input: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
|
||||
match &input.data {
|
||||
Data::Struct(DataStruct {
|
||||
fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }),
|
||||
..
|
||||
}) if unnamed.len() == 1 => {
|
||||
let ident = &input.ident;
|
||||
let ty = &unnamed.first().unwrap().ty;
|
||||
|
||||
// extract type generics
|
||||
let generics = &input.generics;
|
||||
let (_, ty_generics, _) = generics.split_for_impl();
|
||||
|
||||
// add db type for impl generics & where clause
|
||||
let mut generics = generics.clone();
|
||||
generics.params.insert(0, parse_quote!(DB: sqlx::Database));
|
||||
generics
|
||||
.make_where_clause()
|
||||
.predicates
|
||||
.push(parse_quote!(#ty: sqlx::encode::Encode<DB>));
|
||||
let (impl_generics, _, where_clause) = generics.split_for_impl();
|
||||
|
||||
Ok(quote!(
|
||||
impl #impl_generics sqlx::encode::Encode<DB> for #ident #ty_generics #where_clause {
|
||||
fn encode(&self, buf: &mut std::vec::Vec<u8>) {
|
||||
sqlx::encode::Encode::encode(&self.0, buf)
|
||||
}
|
||||
fn encode_nullable(&self, buf: &mut std::vec::Vec<u8>) -> sqlx::encode::IsNull {
|
||||
sqlx::encode::Encode::encode_nullable(&self.0, buf)
|
||||
}
|
||||
fn size_hint(&self) -> usize {
|
||||
sqlx::encode::Encode::size_hint(&self.0)
|
||||
}
|
||||
}
|
||||
))
|
||||
}
|
||||
_ => Err(syn::Error::new_spanned(
|
||||
input,
|
||||
"expected a tuple struct with a single field",
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn expand_derive_decode(input: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
|
||||
match &input.data {
|
||||
Data::Struct(DataStruct {
|
||||
fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }),
|
||||
..
|
||||
}) if unnamed.len() == 1 => {
|
||||
let ident = &input.ident;
|
||||
let ty = &unnamed.first().unwrap().ty;
|
||||
|
||||
// extract type generics
|
||||
let generics = &input.generics;
|
||||
let (_, ty_generics, _) = generics.split_for_impl();
|
||||
|
||||
// add db type for impl generics & where clause
|
||||
let mut generics = generics.clone();
|
||||
generics.params.insert(0, parse_quote!(DB: sqlx::Database));
|
||||
generics
|
||||
.make_where_clause()
|
||||
.predicates
|
||||
.push(parse_quote!(#ty: sqlx::decode::Decode<DB>));
|
||||
let (impl_generics, _, where_clause) = generics.split_for_impl();
|
||||
|
||||
Ok(quote!(
|
||||
impl #impl_generics sqlx::decode::Decode<DB> for #ident #ty_generics #where_clause {
|
||||
fn decode(raw: &[u8]) -> std::result::Result<Self, sqlx::decode::DecodeError> {
|
||||
<#ty as sqlx::decode::Decode<DB>>::decode(raw).map(Self)
|
||||
}
|
||||
fn decode_null() -> std::result::Result<Self, sqlx::decode::DecodeError> {
|
||||
<#ty as sqlx::decode::Decode<DB>>::decode_null().map(Self)
|
||||
}
|
||||
fn decode_nullable(raw: std::option::Option<&[u8]>) -> std::result::Result<Self, sqlx::decode::DecodeError> {
|
||||
<#ty as sqlx::decode::Decode<DB>>::decode_nullable(raw).map(Self)
|
||||
}
|
||||
}
|
||||
))
|
||||
}
|
||||
_ => Err(syn::Error::new_spanned(
|
||||
input,
|
||||
"expected a tuple struct with a single field",
|
||||
)),
|
||||
}
|
||||
}
|
@ -19,6 +19,8 @@ type Result<T> = std::result::Result<T, Error>;
|
||||
|
||||
mod database;
|
||||
|
||||
mod derives;
|
||||
|
||||
mod query_macros;
|
||||
|
||||
use query_macros::*;
|
||||
@ -134,3 +136,21 @@ pub fn query_as(input: TokenStream) -> TokenStream {
|
||||
pub fn query_file_as(input: TokenStream) -> TokenStream {
|
||||
async_macro!(db, input: QueryAsMacroInput => expand_query_file_as(input, db))
|
||||
}
|
||||
|
||||
#[proc_macro_derive(Encode)]
|
||||
pub fn derive_encode(tokenstream: TokenStream) -> TokenStream {
|
||||
let input = syn::parse_macro_input!(tokenstream as syn::DeriveInput);
|
||||
match derives::expand_derive_encode(input) {
|
||||
Ok(ts) => ts.into(),
|
||||
Err(e) => e.to_compile_error().into(),
|
||||
}
|
||||
}
|
||||
|
||||
#[proc_macro_derive(Decode)]
|
||||
pub fn derive_decode(tokenstream: TokenStream) -> TokenStream {
|
||||
let input = syn::parse_macro_input!(tokenstream as syn::DeriveInput);
|
||||
match derives::expand_derive_decode(input) {
|
||||
Ok(ts) => ts.into(),
|
||||
Err(e) => e.to_compile_error().into(),
|
||||
}
|
||||
}
|
||||
|
6
src/decode.rs
Normal file
6
src/decode.rs
Normal file
@ -0,0 +1,6 @@
|
||||
//! Types and traits for decoding values from the database.
|
||||
|
||||
pub use sqlx_core::decode::*;
|
||||
|
||||
#[cfg(feature = "macros")]
|
||||
pub use sqlx_macros::Decode;
|
6
src/encode.rs
Normal file
6
src/encode.rs
Normal file
@ -0,0 +1,6 @@
|
||||
//! Types and traits for encoding values to the database.
|
||||
|
||||
pub use sqlx_core::encode::*;
|
||||
|
||||
#[cfg(feature = "macros")]
|
||||
pub use sqlx_macros::Encode;
|
@ -7,7 +7,7 @@ compile_error!("one of 'runtime-async-std' or 'runtime-tokio' features must be e
|
||||
compile_error!("only one of 'runtime-async-std' or 'runtime-tokio' features must be enabled");
|
||||
|
||||
// Modules
|
||||
pub use sqlx_core::{arguments, decode, describe, encode, error, pool, row, types};
|
||||
pub use sqlx_core::{arguments, describe, error, pool, row, types};
|
||||
|
||||
// Types
|
||||
pub use sqlx_core::{
|
||||
@ -41,3 +41,7 @@ pub mod ty_cons;
|
||||
#[cfg(feature = "macros")]
|
||||
#[doc(hidden)]
|
||||
pub mod result_ext;
|
||||
|
||||
pub mod encode;
|
||||
|
||||
pub mod decode;
|
||||
|
60
tests/derives.rs
Normal file
60
tests/derives.rs
Normal file
@ -0,0 +1,60 @@
|
||||
use sqlx::decode::Decode;
|
||||
use sqlx::encode::Encode;
|
||||
|
||||
#[derive(PartialEq, Debug, Encode, Decode)]
|
||||
struct Foo(i32);
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "mysql")]
|
||||
fn encode_mysql() {
|
||||
encode_with_db::<sqlx::MySql>();
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "postgres")]
|
||||
fn encode_postgres() {
|
||||
encode_with_db::<sqlx::Postgres>();
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn encode_with_db<DB: sqlx::Database>()
|
||||
where
|
||||
Foo: Encode<DB>,
|
||||
i32: Encode<DB>,
|
||||
{
|
||||
let example = Foo(0x1122_3344);
|
||||
|
||||
let mut encoded = Vec::new();
|
||||
let mut encoded_orig = Vec::new();
|
||||
|
||||
Encode::<DB>::encode(&example, &mut encoded);
|
||||
Encode::<DB>::encode(&example.0, &mut encoded_orig);
|
||||
|
||||
assert_eq!(encoded, encoded_orig);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "mysql")]
|
||||
fn decode_mysql() {
|
||||
decode_with_db::<sqlx::MySql>();
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "postgres")]
|
||||
fn decode_postgres() {
|
||||
decode_with_db::<sqlx::Postgres>();
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn decode_with_db<DB: sqlx::Database>()
|
||||
where
|
||||
Foo: Decode<DB> + Encode<DB>,
|
||||
{
|
||||
let example = Foo(0x1122_3344);
|
||||
|
||||
let mut encoded = Vec::new();
|
||||
Encode::<DB>::encode(&example, &mut encoded);
|
||||
|
||||
let decoded = Foo::decode(&encoded).unwrap();
|
||||
assert_eq!(example, decoded);
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user