mirror of
https://github.com/launchbadge/sqlx.git
synced 2025-12-29 04:04:56 +00:00
add derives for Encode and Decode
This commit is contained in:
parent
9141bd7561
commit
bb933decb7
110
sqlx-macros/src/derives.rs
Normal file
110
sqlx-macros/src/derives.rs
Normal file
@ -0,0 +1,110 @@
|
||||
use proc_macro2::Span;
|
||||
use quote::quote;
|
||||
use syn::punctuated::Punctuated;
|
||||
use syn::token::{Gt, Lt, Where};
|
||||
use syn::{
|
||||
parse_quote, Data, DataStruct, DeriveInput, Fields, PredicateType, Token, WhereClause,
|
||||
WherePredicate,
|
||||
};
|
||||
|
||||
pub(crate) fn expand_derive_encode(input: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
|
||||
if let Data::Struct(DataStruct {
|
||||
fields: Fields::Unnamed(fields),
|
||||
..
|
||||
}) = &input.data
|
||||
{
|
||||
let fields = &fields.unnamed;
|
||||
if fields.len() == 1 {
|
||||
let ident = &input.ident;
|
||||
let ty = &fields.first().unwrap().ty;
|
||||
|
||||
// extract type generics
|
||||
let generics = &input.generics;
|
||||
let (_, ty_generics, _) = generics.split_for_impl();
|
||||
|
||||
// add db type for impl generics & where clause
|
||||
let mut generics = generics.clone();
|
||||
generics.params.insert(0, parse_quote!(DB: sqlx::Database));
|
||||
generics
|
||||
.make_where_clause()
|
||||
.predicates
|
||||
.push(parse_quote!(#ty: sqlx::encode::Encode<DB>));
|
||||
let (impl_generics, _, where_clause) = generics.split_for_impl();
|
||||
|
||||
Ok(quote!(
|
||||
impl #impl_generics sqlx::encode::Encode<DB> for #ident #ty_generics #where_clause {
|
||||
fn encode(&self, buf: &mut std::vec::Vec<u8>) {
|
||||
sqlx::encode::Encode::encode(&self.0, buf)
|
||||
}
|
||||
fn encode_nullable(&self, buf: &mut std::vec::Vec<u8>) -> sqlx::encode::IsNull {
|
||||
sqlx::encode::Encode::encode_nullable(&self.0, buf)
|
||||
}
|
||||
fn size_hint(&self) -> usize {
|
||||
sqlx::encode::Encode::size_hint(&self.0)
|
||||
}
|
||||
}
|
||||
))
|
||||
} else {
|
||||
Err(syn::Error::new_spanned(
|
||||
input,
|
||||
"expected a tuple struct with a single field",
|
||||
))
|
||||
}
|
||||
} else {
|
||||
Err(syn::Error::new_spanned(
|
||||
input,
|
||||
"expected a tuple struct with a single field",
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn expand_derive_decode(input: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
|
||||
if let Data::Struct(DataStruct {
|
||||
fields: Fields::Unnamed(fields),
|
||||
..
|
||||
}) = &input.data
|
||||
{
|
||||
let fields = &fields.unnamed;
|
||||
if fields.len() == 1 {
|
||||
let ident = &input.ident;
|
||||
let ty = &fields.first().unwrap().ty;
|
||||
|
||||
// extract type generics
|
||||
let generics = &input.generics;
|
||||
let (_, ty_generics, _) = generics.split_for_impl();
|
||||
|
||||
// add db type for impl generics & where clause
|
||||
let mut generics = generics.clone();
|
||||
generics.params.insert(0, parse_quote!(DB: sqlx::Database));
|
||||
generics
|
||||
.make_where_clause()
|
||||
.predicates
|
||||
.push(parse_quote!(#ty: sqlx::decode::Decode<DB>));
|
||||
let (impl_generics, _, where_clause) = generics.split_for_impl();
|
||||
|
||||
Ok(quote!(
|
||||
impl #impl_generics sqlx::decode::Decode<DB> for #ident #ty_generics #where_clause {
|
||||
fn decode(raw: &[u8]) -> std::result::Result<Self, sqlx::decode::DecodeError> {
|
||||
<#ty as sqlx::decode::Decode<DB>>::decode(raw).map(Self)
|
||||
}
|
||||
fn decode_null() -> std::result::Result<Self, sqlx::decode::DecodeError> {
|
||||
<#ty as sqlx::decode::Decode<DB>>::decode_null().map(Self)
|
||||
}
|
||||
fn decode_nullable(raw: std::option::Option<&[u8]>) -> std::result::Result<Self, sqlx::decode::DecodeError> {
|
||||
<#ty as sqlx::decode::Decode<DB>>::decode_nullable(raw).map(Self)
|
||||
}
|
||||
}
|
||||
))
|
||||
} else {
|
||||
Err(syn::Error::new_spanned(
|
||||
input,
|
||||
"expected a tuple struct with a single field",
|
||||
))
|
||||
}
|
||||
} else {
|
||||
Err(syn::Error::new_spanned(
|
||||
input,
|
||||
"expected a tuple struct with a single field",
|
||||
))
|
||||
}
|
||||
}
|
||||
@ -21,6 +21,8 @@ type Result<T> = std::result::Result<T, Error>;
|
||||
|
||||
mod database;
|
||||
|
||||
mod derives;
|
||||
|
||||
mod query_macros;
|
||||
|
||||
use query_macros::*;
|
||||
@ -136,3 +138,21 @@ pub fn query_as(input: TokenStream) -> TokenStream {
|
||||
pub fn query_file_as(input: TokenStream) -> TokenStream {
|
||||
async_macro!(db, input: QueryAsMacroInput => expand_query_file_as(input, db))
|
||||
}
|
||||
|
||||
#[proc_macro_derive(Encode)]
|
||||
pub fn derive_encode(tokenstream: TokenStream) -> TokenStream {
|
||||
let input = syn::parse_macro_input!(tokenstream as syn::DeriveInput);
|
||||
match derives::expand_derive_encode(input) {
|
||||
Ok(ts) => ts.into(),
|
||||
Err(e) => e.to_compile_error().into(),
|
||||
}
|
||||
}
|
||||
|
||||
#[proc_macro_derive(Decode)]
|
||||
pub fn derive_decode(tokenstream: TokenStream) -> TokenStream {
|
||||
let input = syn::parse_macro_input!(tokenstream as syn::DeriveInput);
|
||||
match derives::expand_derive_decode(input) {
|
||||
Ok(ts) => ts.into(),
|
||||
Err(e) => e.to_compile_error().into(),
|
||||
}
|
||||
}
|
||||
|
||||
4
src/decode.rs
Normal file
4
src/decode.rs
Normal file
@ -0,0 +1,4 @@
|
||||
pub use sqlx_core::decode::*;
|
||||
|
||||
#[cfg(feature = "macros")]
|
||||
pub use sqlx_macros::Decode;
|
||||
4
src/encode.rs
Normal file
4
src/encode.rs
Normal file
@ -0,0 +1,4 @@
|
||||
pub use sqlx_core::encode::*;
|
||||
|
||||
#[cfg(feature = "macros")]
|
||||
pub use sqlx_macros::Encode;
|
||||
@ -7,7 +7,7 @@ compile_error!("one of 'runtime-async-std' or 'runtime-tokio' features must be e
|
||||
compile_error!("only one of 'runtime-async-std' or 'runtime-tokio' features must be enabled");
|
||||
|
||||
// Modules
|
||||
pub use sqlx_core::{arguments, decode, describe, encode, error, pool, row, types};
|
||||
pub use sqlx_core::{arguments, describe, error, pool, row, types};
|
||||
|
||||
// Types
|
||||
pub use sqlx_core::{
|
||||
@ -42,3 +42,7 @@ pub mod ty_cons;
|
||||
#[cfg(feature = "macros")]
|
||||
#[doc(hidden)]
|
||||
pub mod result_ext;
|
||||
|
||||
pub mod encode;
|
||||
|
||||
pub mod decode;
|
||||
|
||||
27
tests/derives.rs
Normal file
27
tests/derives.rs
Normal file
@ -0,0 +1,27 @@
|
||||
#[test]
|
||||
#[cfg(feature = "macros")]
|
||||
fn encode() {
|
||||
use sqlx::encode::Encode;
|
||||
|
||||
#[derive(Encode)]
|
||||
struct Foo(i32);
|
||||
|
||||
#[cfg(feature = "postgres")]
|
||||
let _: Box<dyn Encode<sqlx::Postgres>> = Box::new(Foo(1));
|
||||
#[cfg(feature = "mysql")]
|
||||
let _: Box<dyn Encode<sqlx::MySql>> = Box::new(Foo(1));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "macros")]
|
||||
fn decode() {
|
||||
use sqlx::decode::Decode;
|
||||
|
||||
#[derive(Decode)]
|
||||
struct Foo(i32);
|
||||
|
||||
#[cfg(feature = "postgres")]
|
||||
let _: Box<dyn Decode<sqlx::Postgres>> = Box::new(Foo(1));
|
||||
#[cfg(feature = "mysql")]
|
||||
let _: Box<dyn Decode<sqlx::MySql>> = Box::new(Foo(1));
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user