mirror of
https://github.com/launchbadge/sqlx.git
synced 2025-10-03 15:55:45 +00:00
Merge pull request #71 from Freax13/master
add derives for Encode and Decode
This commit is contained in:
commit
f0c88da152
@ -85,5 +85,9 @@ required-features = [ "mysql" ]
|
|||||||
name = "mysql-types-chrono"
|
name = "mysql-types-chrono"
|
||||||
required-features = [ "mysql", "chrono", "macros" ]
|
required-features = [ "mysql", "chrono", "macros" ]
|
||||||
|
|
||||||
|
[[test]]
|
||||||
|
name = "derives"
|
||||||
|
required-features = [ "macros" ]
|
||||||
|
|
||||||
[profile.release]
|
[profile.release]
|
||||||
lto = true
|
lto = true
|
||||||
|
88
sqlx-macros/src/derives.rs
Normal file
88
sqlx-macros/src/derives.rs
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
use quote::quote;
|
||||||
|
use syn::{parse_quote, Data, DataStruct, DeriveInput, Fields, FieldsUnnamed};
|
||||||
|
|
||||||
|
pub(crate) fn expand_derive_encode(input: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
|
||||||
|
match &input.data {
|
||||||
|
Data::Struct(DataStruct {
|
||||||
|
fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }),
|
||||||
|
..
|
||||||
|
}) if unnamed.len() == 1 => {
|
||||||
|
let ident = &input.ident;
|
||||||
|
let ty = &unnamed.first().unwrap().ty;
|
||||||
|
|
||||||
|
// extract type generics
|
||||||
|
let generics = &input.generics;
|
||||||
|
let (_, ty_generics, _) = generics.split_for_impl();
|
||||||
|
|
||||||
|
// add db type for impl generics & where clause
|
||||||
|
let mut generics = generics.clone();
|
||||||
|
generics.params.insert(0, parse_quote!(DB: sqlx::Database));
|
||||||
|
generics
|
||||||
|
.make_where_clause()
|
||||||
|
.predicates
|
||||||
|
.push(parse_quote!(#ty: sqlx::encode::Encode<DB>));
|
||||||
|
let (impl_generics, _, where_clause) = generics.split_for_impl();
|
||||||
|
|
||||||
|
Ok(quote!(
|
||||||
|
impl #impl_generics sqlx::encode::Encode<DB> for #ident #ty_generics #where_clause {
|
||||||
|
fn encode(&self, buf: &mut std::vec::Vec<u8>) {
|
||||||
|
sqlx::encode::Encode::encode(&self.0, buf)
|
||||||
|
}
|
||||||
|
fn encode_nullable(&self, buf: &mut std::vec::Vec<u8>) -> sqlx::encode::IsNull {
|
||||||
|
sqlx::encode::Encode::encode_nullable(&self.0, buf)
|
||||||
|
}
|
||||||
|
fn size_hint(&self) -> usize {
|
||||||
|
sqlx::encode::Encode::size_hint(&self.0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
))
|
||||||
|
}
|
||||||
|
_ => Err(syn::Error::new_spanned(
|
||||||
|
input,
|
||||||
|
"expected a tuple struct with a single field",
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn expand_derive_decode(input: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
|
||||||
|
match &input.data {
|
||||||
|
Data::Struct(DataStruct {
|
||||||
|
fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }),
|
||||||
|
..
|
||||||
|
}) if unnamed.len() == 1 => {
|
||||||
|
let ident = &input.ident;
|
||||||
|
let ty = &unnamed.first().unwrap().ty;
|
||||||
|
|
||||||
|
// extract type generics
|
||||||
|
let generics = &input.generics;
|
||||||
|
let (_, ty_generics, _) = generics.split_for_impl();
|
||||||
|
|
||||||
|
// add db type for impl generics & where clause
|
||||||
|
let mut generics = generics.clone();
|
||||||
|
generics.params.insert(0, parse_quote!(DB: sqlx::Database));
|
||||||
|
generics
|
||||||
|
.make_where_clause()
|
||||||
|
.predicates
|
||||||
|
.push(parse_quote!(#ty: sqlx::decode::Decode<DB>));
|
||||||
|
let (impl_generics, _, where_clause) = generics.split_for_impl();
|
||||||
|
|
||||||
|
Ok(quote!(
|
||||||
|
impl #impl_generics sqlx::decode::Decode<DB> for #ident #ty_generics #where_clause {
|
||||||
|
fn decode(raw: &[u8]) -> std::result::Result<Self, sqlx::decode::DecodeError> {
|
||||||
|
<#ty as sqlx::decode::Decode<DB>>::decode(raw).map(Self)
|
||||||
|
}
|
||||||
|
fn decode_null() -> std::result::Result<Self, sqlx::decode::DecodeError> {
|
||||||
|
<#ty as sqlx::decode::Decode<DB>>::decode_null().map(Self)
|
||||||
|
}
|
||||||
|
fn decode_nullable(raw: std::option::Option<&[u8]>) -> std::result::Result<Self, sqlx::decode::DecodeError> {
|
||||||
|
<#ty as sqlx::decode::Decode<DB>>::decode_nullable(raw).map(Self)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
))
|
||||||
|
}
|
||||||
|
_ => Err(syn::Error::new_spanned(
|
||||||
|
input,
|
||||||
|
"expected a tuple struct with a single field",
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
@ -19,6 +19,8 @@ type Result<T> = std::result::Result<T, Error>;
|
|||||||
|
|
||||||
mod database;
|
mod database;
|
||||||
|
|
||||||
|
mod derives;
|
||||||
|
|
||||||
mod query_macros;
|
mod query_macros;
|
||||||
|
|
||||||
use query_macros::*;
|
use query_macros::*;
|
||||||
@ -134,3 +136,21 @@ pub fn query_as(input: TokenStream) -> TokenStream {
|
|||||||
pub fn query_file_as(input: TokenStream) -> TokenStream {
|
pub fn query_file_as(input: TokenStream) -> TokenStream {
|
||||||
async_macro!(db, input: QueryAsMacroInput => expand_query_file_as(input, db))
|
async_macro!(db, input: QueryAsMacroInput => expand_query_file_as(input, db))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[proc_macro_derive(Encode)]
|
||||||
|
pub fn derive_encode(tokenstream: TokenStream) -> TokenStream {
|
||||||
|
let input = syn::parse_macro_input!(tokenstream as syn::DeriveInput);
|
||||||
|
match derives::expand_derive_encode(input) {
|
||||||
|
Ok(ts) => ts.into(),
|
||||||
|
Err(e) => e.to_compile_error().into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[proc_macro_derive(Decode)]
|
||||||
|
pub fn derive_decode(tokenstream: TokenStream) -> TokenStream {
|
||||||
|
let input = syn::parse_macro_input!(tokenstream as syn::DeriveInput);
|
||||||
|
match derives::expand_derive_decode(input) {
|
||||||
|
Ok(ts) => ts.into(),
|
||||||
|
Err(e) => e.to_compile_error().into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
6
src/decode.rs
Normal file
6
src/decode.rs
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
//! Types and traits for decoding values from the database.
|
||||||
|
|
||||||
|
pub use sqlx_core::decode::*;
|
||||||
|
|
||||||
|
#[cfg(feature = "macros")]
|
||||||
|
pub use sqlx_macros::Decode;
|
6
src/encode.rs
Normal file
6
src/encode.rs
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
//! Types and traits for encoding values to the database.
|
||||||
|
|
||||||
|
pub use sqlx_core::encode::*;
|
||||||
|
|
||||||
|
#[cfg(feature = "macros")]
|
||||||
|
pub use sqlx_macros::Encode;
|
@ -7,7 +7,7 @@ compile_error!("one of 'runtime-async-std' or 'runtime-tokio' features must be e
|
|||||||
compile_error!("only one of 'runtime-async-std' or 'runtime-tokio' features must be enabled");
|
compile_error!("only one of 'runtime-async-std' or 'runtime-tokio' features must be enabled");
|
||||||
|
|
||||||
// Modules
|
// Modules
|
||||||
pub use sqlx_core::{arguments, decode, describe, encode, error, pool, row, types};
|
pub use sqlx_core::{arguments, describe, error, pool, row, types};
|
||||||
|
|
||||||
// Types
|
// Types
|
||||||
pub use sqlx_core::{
|
pub use sqlx_core::{
|
||||||
@ -41,3 +41,7 @@ pub mod ty_cons;
|
|||||||
#[cfg(feature = "macros")]
|
#[cfg(feature = "macros")]
|
||||||
#[doc(hidden)]
|
#[doc(hidden)]
|
||||||
pub mod result_ext;
|
pub mod result_ext;
|
||||||
|
|
||||||
|
pub mod encode;
|
||||||
|
|
||||||
|
pub mod decode;
|
||||||
|
60
tests/derives.rs
Normal file
60
tests/derives.rs
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
use sqlx::decode::Decode;
|
||||||
|
use sqlx::encode::Encode;
|
||||||
|
|
||||||
|
#[derive(PartialEq, Debug, Encode, Decode)]
|
||||||
|
struct Foo(i32);
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[cfg(feature = "mysql")]
|
||||||
|
fn encode_mysql() {
|
||||||
|
encode_with_db::<sqlx::MySql>();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[cfg(feature = "postgres")]
|
||||||
|
fn encode_postgres() {
|
||||||
|
encode_with_db::<sqlx::Postgres>();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
|
fn encode_with_db<DB: sqlx::Database>()
|
||||||
|
where
|
||||||
|
Foo: Encode<DB>,
|
||||||
|
i32: Encode<DB>,
|
||||||
|
{
|
||||||
|
let example = Foo(0x1122_3344);
|
||||||
|
|
||||||
|
let mut encoded = Vec::new();
|
||||||
|
let mut encoded_orig = Vec::new();
|
||||||
|
|
||||||
|
Encode::<DB>::encode(&example, &mut encoded);
|
||||||
|
Encode::<DB>::encode(&example.0, &mut encoded_orig);
|
||||||
|
|
||||||
|
assert_eq!(encoded, encoded_orig);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[cfg(feature = "mysql")]
|
||||||
|
fn decode_mysql() {
|
||||||
|
decode_with_db::<sqlx::MySql>();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[cfg(feature = "postgres")]
|
||||||
|
fn decode_postgres() {
|
||||||
|
decode_with_db::<sqlx::Postgres>();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
|
fn decode_with_db<DB: sqlx::Database>()
|
||||||
|
where
|
||||||
|
Foo: Decode<DB> + Encode<DB>,
|
||||||
|
{
|
||||||
|
let example = Foo(0x1122_3344);
|
||||||
|
|
||||||
|
let mut encoded = Vec::new();
|
||||||
|
Encode::<DB>::encode(&example, &mut encoded);
|
||||||
|
|
||||||
|
let decoded = Foo::decode(&encoded).unwrap();
|
||||||
|
assert_eq!(example, decoded);
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user