mirror of
https://github.com/launchbadge/sqlx.git
synced 2025-10-03 07:45:30 +00:00
add derive macros for weak & strong enums and structs
This commit is contained in:
parent
ada57fa566
commit
c3aeb275c2
@ -29,4 +29,4 @@ pub type MySqlPool = crate::pool::Pool<MySqlConnection>;
|
||||
make_query_as!(MySqlQueryAs, MySql, MySqlRow);
|
||||
impl_map_row_for_row!(MySql, MySqlRow);
|
||||
impl_column_index_for_row!(MySql);
|
||||
impl_from_row_for_tuples!(MySql, MySqlRow);
|
||||
impl_from_row_for_tuples!(MySql, MySqlRow);
|
@ -39,6 +39,9 @@ type_id_consts! {
|
||||
pub const VAR_CHAR: TypeId = TypeId(253); // or VAR_BINARY
|
||||
pub const TEXT: TypeId = TypeId(252); // or BLOB
|
||||
|
||||
// Enum
|
||||
pub const ENUM: TypeId = TypeId(247);
|
||||
|
||||
// More Bytes
|
||||
pub const TINY_BLOB: TypeId = TypeId(249);
|
||||
pub const MEDIUM_BLOB: TypeId = TypeId(250);
|
||||
|
@ -64,6 +64,11 @@ impl MySqlTypeInfo {
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
pub fn r#enum() -> Self {
|
||||
Self::new(TypeId::ENUM)
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for MySqlTypeInfo {
|
||||
@ -85,6 +90,7 @@ impl TypeInfo for MySqlTypeInfo {
|
||||
| TypeId::TINY_BLOB
|
||||
| TypeId::MEDIUM_BLOB
|
||||
| TypeId::LONG_BLOB
|
||||
| TypeId::ENUM
|
||||
if (self.is_binary == other.is_binary)
|
||||
&& match other.id {
|
||||
TypeId::VAR_CHAR
|
||||
@ -92,7 +98,8 @@ impl TypeInfo for MySqlTypeInfo {
|
||||
| TypeId::CHAR
|
||||
| TypeId::TINY_BLOB
|
||||
| TypeId::MEDIUM_BLOB
|
||||
| TypeId::LONG_BLOB => true,
|
||||
| TypeId::LONG_BLOB
|
||||
| TypeId::ENUM => true,
|
||||
|
||||
_ => false,
|
||||
} =>
|
||||
|
@ -57,6 +57,10 @@ impl PgTypeInfo {
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn oid(&self) -> u32 {
|
||||
self.id.0
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for PgTypeInfo {
|
||||
|
@ -1,108 +1,847 @@
|
||||
use proc_macro2::Ident;
|
||||
use quote::quote;
|
||||
use syn::{parse_quote, Data, DataStruct, DeriveInput, Fields, FieldsUnnamed};
|
||||
use std::iter::FromIterator;
|
||||
use syn::punctuated::Punctuated;
|
||||
use syn::token::Comma;
|
||||
use syn::{
|
||||
parse_quote, Arm, Attribute, Block, Data, DataEnum, DataStruct, DeriveInput, Expr, Field,
|
||||
Fields, FieldsNamed, FieldsUnnamed, Lit, Meta, MetaNameValue, NestedMeta, Variant,Stmt,
|
||||
};
|
||||
|
||||
pub(crate) fn expand_derive_encode(input: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
|
||||
match &input.data {
|
||||
Data::Struct(DataStruct {
|
||||
fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }),
|
||||
..
|
||||
}) if unnamed.len() == 1 => {
|
||||
let ident = &input.ident;
|
||||
let ty = &unnamed.first().unwrap().ty;
|
||||
macro_rules! assert_attribute {
|
||||
($e:expr, $err:expr, $input:expr) => {
|
||||
if !$e {
|
||||
return Err(syn::Error::new_spanned($input, $err));
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
// extract type generics
|
||||
let generics = &input.generics;
|
||||
let (_, ty_generics, _) = generics.split_for_impl();
|
||||
struct SqlxAttributes {
|
||||
transparent: bool,
|
||||
postgres_oid: Option<u32>,
|
||||
repr: Option<Ident>,
|
||||
rename: Option<String>,
|
||||
}
|
||||
|
||||
// add db type for impl generics & where clause
|
||||
let mut generics = generics.clone();
|
||||
generics.params.insert(0, parse_quote!(DB: sqlx::Database));
|
||||
generics
|
||||
.make_where_clause()
|
||||
.predicates
|
||||
.push(parse_quote!(#ty: sqlx::encode::Encode<DB>));
|
||||
let (impl_generics, _, where_clause) = generics.split_for_impl();
|
||||
fn parse_attributes(input: &[Attribute]) -> syn::Result<SqlxAttributes> {
|
||||
let mut transparent = None;
|
||||
let mut postgres_oid = None;
|
||||
let mut repr = None;
|
||||
let mut rename = None;
|
||||
|
||||
Ok(quote!(
|
||||
impl #impl_generics sqlx::encode::Encode<DB> for #ident #ty_generics #where_clause {
|
||||
fn encode(&self, buf: &mut <DB as sqlx::Database>::RawBuffer) {
|
||||
sqlx::encode::Encode::encode(&self.0, buf)
|
||||
}
|
||||
fn encode_nullable(&self, buf: &mut <DB as sqlx::Database>::RawBuffer) -> sqlx::encode::IsNull {
|
||||
sqlx::encode::Encode::encode_nullable(&self.0, buf)
|
||||
}
|
||||
fn size_hint(&self) -> usize {
|
||||
sqlx::encode::Encode::size_hint(&self.0)
|
||||
macro_rules! fail {
|
||||
($t:expr, $m:expr) => {
|
||||
return Err(syn::Error::new_spanned($t, $m));
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! try_set {
|
||||
($i:ident, $v:expr, $t:expr) => {
|
||||
match $i {
|
||||
None => $i = Some($v),
|
||||
Some(_) => fail!($t, "duplicate attribute"),
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
for attr in input {
|
||||
let meta = attr
|
||||
.parse_meta()
|
||||
.map_err(|e| syn::Error::new_spanned(attr, e))?;
|
||||
match meta {
|
||||
Meta::List(list) if list.path.is_ident("sqlx") => {
|
||||
for value in list.nested.iter() {
|
||||
match value {
|
||||
NestedMeta::Meta(meta) => match meta {
|
||||
Meta::Path(p) if p.is_ident("transparent") => {
|
||||
try_set!(transparent, true, value)
|
||||
}
|
||||
Meta::NameValue(MetaNameValue {
|
||||
path,
|
||||
lit: Lit::Str(val),
|
||||
..
|
||||
}) if path.is_ident("rename") => try_set!(rename, val.value(), value),
|
||||
Meta::List(list) if list.path.is_ident("postgres") => {
|
||||
for value in list.nested.iter() {
|
||||
match value {
|
||||
NestedMeta::Meta(Meta::NameValue(MetaNameValue {
|
||||
path,
|
||||
lit: Lit::Int(val),
|
||||
..
|
||||
})) if path.is_ident("oid") => {
|
||||
try_set!(postgres_oid, val.base10_parse()?, value);
|
||||
}
|
||||
u => fail!(u, "unexpected value"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
u => fail!(u, "unexpected attribute"),
|
||||
},
|
||||
u => fail!(u, "unexpected attribute"),
|
||||
}
|
||||
}
|
||||
))
|
||||
}
|
||||
Meta::List(list) if list.path.is_ident("repr") => {
|
||||
if list.nested.len() != 1 {
|
||||
fail!(&list.nested, "expected one value")
|
||||
}
|
||||
match list.nested.first().unwrap() {
|
||||
NestedMeta::Meta(Meta::Path(p)) if p.get_ident().is_some() => {
|
||||
try_set!(repr, p.get_ident().unwrap().clone(), list);
|
||||
}
|
||||
u => fail!(u, "unexpected value"),
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
_ => Err(syn::Error::new_spanned(
|
||||
input,
|
||||
"expected a tuple struct with a single field",
|
||||
)),
|
||||
}
|
||||
|
||||
Ok(SqlxAttributes {
|
||||
transparent: transparent.unwrap_or(false),
|
||||
postgres_oid,
|
||||
repr,
|
||||
rename,
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn expand_derive_decode(input: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
|
||||
fn check_transparent_attributes(input: &DeriveInput, field: &Field) -> syn::Result<()> {
|
||||
let attributes = parse_attributes(&input.attrs)?;
|
||||
assert_attribute!(
|
||||
attributes.transparent,
|
||||
"expected #[sqlx(transparent)]",
|
||||
input
|
||||
);
|
||||
#[cfg(feature = "postgres")]
|
||||
assert_attribute!(
|
||||
attributes.postgres_oid.is_none(),
|
||||
"unexpected #[sqlx(postgres(oid = ..))]",
|
||||
input
|
||||
);
|
||||
assert_attribute!(
|
||||
attributes.rename.is_none(),
|
||||
"unexpected #[sqlx(rename = ..)]",
|
||||
field
|
||||
);
|
||||
assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", input);
|
||||
let attributes = parse_attributes(&field.attrs)?;
|
||||
assert_attribute!(
|
||||
!attributes.transparent,
|
||||
"unexpected #[sqlx(transparent)]",
|
||||
field
|
||||
);
|
||||
#[cfg(feature = "postgres")]
|
||||
assert_attribute!(
|
||||
attributes.postgres_oid.is_none(),
|
||||
"unexpected #[sqlx(postgres(oid = ..))]",
|
||||
field
|
||||
);
|
||||
assert_attribute!(
|
||||
attributes.rename.is_none(),
|
||||
"unexpected #[sqlx(rename = ..)]",
|
||||
field
|
||||
);
|
||||
assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", field);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn check_enum_attributes<'a>(
|
||||
input: &'a DeriveInput,
|
||||
variants: &Punctuated<Variant, Comma>,
|
||||
) -> syn::Result<SqlxAttributes> {
|
||||
let attributes = parse_attributes(&input.attrs)?;
|
||||
assert_attribute!(
|
||||
!attributes.transparent,
|
||||
"unexpected #[sqlx(transparent)]",
|
||||
input
|
||||
);
|
||||
assert_attribute!(
|
||||
attributes.rename.is_none(),
|
||||
"unexpected #[sqlx(rename = ..)]",
|
||||
input
|
||||
);
|
||||
|
||||
for variant in variants {
|
||||
let attributes = parse_attributes(&variant.attrs)?;
|
||||
assert_attribute!(
|
||||
!attributes.transparent,
|
||||
"unexpected #[sqlx(transparent)]",
|
||||
variant
|
||||
);
|
||||
#[cfg(feature = "postgres")]
|
||||
assert_attribute!(
|
||||
attributes.postgres_oid.is_none(),
|
||||
"unexpected #[sqlx(postgres(oid = ..))]",
|
||||
variant
|
||||
);
|
||||
assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", variant);
|
||||
}
|
||||
|
||||
Ok(attributes)
|
||||
}
|
||||
|
||||
fn check_weak_enum_attributes(
|
||||
input: &DeriveInput,
|
||||
variants: &Punctuated<Variant, Comma>,
|
||||
) -> syn::Result<Ident> {
|
||||
let attributes = check_enum_attributes(input, variants)?;
|
||||
#[cfg(feature = "postgres")]
|
||||
assert_attribute!(
|
||||
attributes.postgres_oid.is_none(),
|
||||
"unexpected #[sqlx(postgres(oid = ..))]",
|
||||
input
|
||||
);
|
||||
assert_attribute!(attributes.repr.is_some(), "expected #[repr(..)]", input);
|
||||
for variant in variants {
|
||||
let attributes = parse_attributes(&variant.attrs)?;
|
||||
assert_attribute!(
|
||||
attributes.rename.is_none(),
|
||||
"unexpected #[sqlx(rename = ..)]",
|
||||
variant
|
||||
);
|
||||
}
|
||||
Ok(attributes.repr.unwrap())
|
||||
}
|
||||
|
||||
fn check_strong_enum_attributes(
|
||||
input: &DeriveInput,
|
||||
variants: &Punctuated<Variant, Comma>,
|
||||
) -> syn::Result<SqlxAttributes> {
|
||||
let attributes = check_enum_attributes(input, variants)?;
|
||||
#[cfg(feature = "postgres")]
|
||||
assert_attribute!(
|
||||
attributes.postgres_oid.is_some(),
|
||||
"expected #[sqlx(postgres(oid = ..))]",
|
||||
input
|
||||
);
|
||||
assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", input);
|
||||
Ok(attributes)
|
||||
}
|
||||
|
||||
fn check_struct_attributes<'a>(
|
||||
input: &'a DeriveInput,
|
||||
fields: &Punctuated<Field, Comma>,
|
||||
) -> syn::Result<SqlxAttributes> {
|
||||
let attributes = parse_attributes(&input.attrs)?;
|
||||
assert_attribute!(
|
||||
!attributes.transparent,
|
||||
"unexpected #[sqlx(transparent)]",
|
||||
input
|
||||
);
|
||||
#[cfg(feature = "postgres")]
|
||||
assert_attribute!(
|
||||
attributes.postgres_oid.is_some(),
|
||||
"expected #[sqlx(postgres(oid = ..))]",
|
||||
input
|
||||
);
|
||||
assert_attribute!(
|
||||
attributes.rename.is_none(),
|
||||
"unexpected #[sqlx(rename = ..)]",
|
||||
input
|
||||
);
|
||||
assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", input);
|
||||
|
||||
for field in fields {
|
||||
let attributes = parse_attributes(&field.attrs)?;
|
||||
assert_attribute!(
|
||||
!attributes.transparent,
|
||||
"unexpected #[sqlx(transparent)]",
|
||||
field
|
||||
);
|
||||
#[cfg(feature = "postgres")]
|
||||
assert_attribute!(
|
||||
attributes.postgres_oid.is_none(),
|
||||
"unexpected #[sqlx(postgres(oid = ..))]",
|
||||
field
|
||||
);
|
||||
assert_attribute!(
|
||||
attributes.rename.is_none(),
|
||||
"unexpected #[sqlx(rename = ..)]",
|
||||
field
|
||||
);
|
||||
assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", field);
|
||||
}
|
||||
|
||||
Ok(attributes)
|
||||
}
|
||||
|
||||
pub(crate) fn expand_derive_encode(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
|
||||
let args = parse_attributes(&input.attrs)?;
|
||||
|
||||
match &input.data {
|
||||
Data::Struct(DataStruct {
|
||||
fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }),
|
||||
..
|
||||
}) if unnamed.len() == 1 => {
|
||||
let ident = &input.ident;
|
||||
let ty = &unnamed.first().unwrap().ty;
|
||||
|
||||
// extract type generics
|
||||
let generics = &input.generics;
|
||||
let (_, ty_generics, _) = generics.split_for_impl();
|
||||
|
||||
let mut impls = Vec::new();
|
||||
|
||||
if cfg!(feature = "postgres") {
|
||||
let mut generics = generics.clone();
|
||||
generics.params.insert(0, parse_quote!('de));
|
||||
generics
|
||||
.make_where_clause()
|
||||
.predicates
|
||||
.push(parse_quote!(#ty: sqlx::decode::Decode<'de, sqlx::Postgres>));
|
||||
|
||||
let (impl_generics, _, where_clause) = generics.split_for_impl();
|
||||
|
||||
impls.push(quote!(
|
||||
impl #impl_generics sqlx::decode::Decode<'de, sqlx::Postgres> for #ident #ty_generics #where_clause {
|
||||
fn decode(value: <sqlx::Postgres as sqlx::HasRawValue<'de>>::RawValue) -> sqlx::Result<Self> {
|
||||
<#ty as sqlx::decode::Decode<'de, sqlx::Postgres>>::decode(value).map(Self)
|
||||
}
|
||||
}
|
||||
));
|
||||
}
|
||||
|
||||
if cfg!(feature = "mysql") {
|
||||
let mut generics = generics.clone();
|
||||
generics.params.insert(0, parse_quote!('de));
|
||||
generics
|
||||
.make_where_clause()
|
||||
.predicates
|
||||
.push(parse_quote!(#ty: sqlx::decode::Decode<'de, sqlx::MySql>));
|
||||
|
||||
let (impl_generics, _, where_clause) = generics.split_for_impl();
|
||||
|
||||
impls.push(quote!(
|
||||
impl #impl_generics sqlx::decode::Decode<'de, sqlx::MySql> for #ident #ty_generics #where_clause {
|
||||
fn decode(value: <sqlx::MySql as sqlx::HasRawValue<'de>>::RawValue) -> sqlx::Result<Self> {
|
||||
<#ty as sqlx::decode::Decode<'de, sqlx::MySql>>::decode(value).map(Self)
|
||||
}
|
||||
}
|
||||
));
|
||||
}
|
||||
|
||||
// panic!("{}", q)
|
||||
Ok(quote!(#(#impls)*))
|
||||
expand_derive_encode_transparent(&input, unnamed.first().unwrap())
|
||||
}
|
||||
Data::Enum(DataEnum { variants, .. }) => match args.repr {
|
||||
Some(_) => expand_derive_encode_weak_enum(input, variants),
|
||||
None => expand_derive_encode_strong_enum(input, variants),
|
||||
},
|
||||
Data::Struct(DataStruct {
|
||||
fields: Fields::Named(FieldsNamed { named, .. }),
|
||||
..
|
||||
}) => expand_derive_encode_struct(input, named),
|
||||
_ => Err(syn::Error::new_spanned(
|
||||
input,
|
||||
"expected a tuple struct with a single field",
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
fn expand_derive_encode_transparent(
|
||||
input: &DeriveInput,
|
||||
field: &Field,
|
||||
) -> syn::Result<proc_macro2::TokenStream> {
|
||||
check_transparent_attributes(input, field)?;
|
||||
|
||||
let ident = &input.ident;
|
||||
let ty = &field.ty;
|
||||
|
||||
// extract type generics
|
||||
let generics = &input.generics;
|
||||
let (_, ty_generics, _) = generics.split_for_impl();
|
||||
|
||||
// add db type for impl generics & where clause
|
||||
let mut generics = generics.clone();
|
||||
generics.params.insert(0, parse_quote!(DB: sqlx::Database));
|
||||
generics
|
||||
.make_where_clause()
|
||||
.predicates
|
||||
.push(parse_quote!(#ty: sqlx::encode::Encode<DB>));
|
||||
let (impl_generics, _, where_clause) = generics.split_for_impl();
|
||||
|
||||
Ok(quote!(
|
||||
impl #impl_generics sqlx::encode::Encode<DB> for #ident #ty_generics #where_clause {
|
||||
fn encode(&self, buf: &mut std::vec::Vec<u8>) {
|
||||
sqlx::encode::Encode::encode(&self.0, buf)
|
||||
}
|
||||
fn encode_nullable(&self, buf: &mut std::vec::Vec<u8>) -> sqlx::encode::IsNull {
|
||||
sqlx::encode::Encode::encode_nullable(&self.0, buf)
|
||||
}
|
||||
fn size_hint(&self) -> usize {
|
||||
sqlx::encode::Encode::size_hint(&self.0)
|
||||
}
|
||||
}
|
||||
))
|
||||
}
|
||||
|
||||
fn expand_derive_encode_weak_enum(
|
||||
input: &DeriveInput,
|
||||
variants: &Punctuated<Variant, Comma>,
|
||||
) -> syn::Result<proc_macro2::TokenStream> {
|
||||
let repr = check_weak_enum_attributes(input, &variants)?;
|
||||
|
||||
let ident = &input.ident;
|
||||
|
||||
Ok(quote!(
|
||||
impl<DB: sqlx::Database> sqlx::encode::Encode<DB> for #ident where #repr: sqlx::encode::Encode<DB> {
|
||||
fn encode(&self, buf: &mut std::vec::Vec<u8>) {
|
||||
sqlx::encode::Encode::encode(&(*self as #repr), buf)
|
||||
}
|
||||
fn encode_nullable(&self, buf: &mut std::vec::Vec<u8>) -> sqlx::encode::IsNull {
|
||||
sqlx::encode::Encode::encode_nullable(&(*self as #repr), buf)
|
||||
}
|
||||
fn size_hint(&self) -> usize {
|
||||
sqlx::encode::Encode::size_hint(&(*self as #repr))
|
||||
}
|
||||
}
|
||||
))
|
||||
}
|
||||
|
||||
fn expand_derive_encode_strong_enum(
|
||||
input: &DeriveInput,
|
||||
variants: &Punctuated<Variant, Comma>,
|
||||
) -> syn::Result<proc_macro2::TokenStream> {
|
||||
check_strong_enum_attributes(input, &variants)?;
|
||||
|
||||
let ident = &input.ident;
|
||||
|
||||
let mut tts = proc_macro2::TokenStream::new();
|
||||
|
||||
if cfg!(feature = "mysql") {
|
||||
let mut value_arms = Vec::new();
|
||||
for v in variants {
|
||||
let id = &v.ident;
|
||||
let attributes = parse_attributes(&v.attrs)?;
|
||||
if let Some(rename) = attributes.rename {
|
||||
value_arms.push(quote!(#ident :: #id => #rename,));
|
||||
} else {
|
||||
let name = id.to_string();
|
||||
value_arms.push(quote!(#ident :: #id => #name,));
|
||||
}
|
||||
}
|
||||
|
||||
tts.extend(quote!(
|
||||
impl<DB: sqlx::Database> sqlx::encode::Encode<DB> for #ident where str: sqlx::encode::Encode<DB> {
|
||||
fn encode(&self, buf: &mut std::vec::Vec<u8>) {
|
||||
let val = match self {
|
||||
#(#value_arms)*
|
||||
};
|
||||
<str as sqlx::encode::Encode<DB>>::encode(val, buf)
|
||||
}
|
||||
fn size_hint(&self) -> usize {
|
||||
let val = match self {
|
||||
#(#value_arms)*
|
||||
};
|
||||
<str as sqlx::encode::Encode<DB>>::size_hint(val)
|
||||
}
|
||||
}
|
||||
));
|
||||
}
|
||||
|
||||
Ok(tts)
|
||||
}
|
||||
|
||||
fn expand_derive_encode_struct(
|
||||
input: &DeriveInput,
|
||||
fields: &Punctuated<Field, Comma>,
|
||||
) -> syn::Result<proc_macro2::TokenStream> {
|
||||
check_struct_attributes(input, &fields)?;
|
||||
|
||||
let mut tts = proc_macro2::TokenStream::new();
|
||||
|
||||
if cfg!(feature = "postgres") {
|
||||
let ident = &input.ident;
|
||||
|
||||
let column_count = fields.len();
|
||||
|
||||
// extract type generics
|
||||
let generics = &input.generics;
|
||||
let (_, ty_generics, _) = generics.split_for_impl();
|
||||
|
||||
// add db type for impl generics & where clause
|
||||
let mut generics = generics.clone();
|
||||
let predicates = &mut generics.make_where_clause().predicates;
|
||||
for field in fields {
|
||||
let ty = &field.ty;
|
||||
predicates.push(parse_quote!(#ty: sqlx::encode::Encode<sqlx::Postgres>));
|
||||
predicates.push(parse_quote!(sqlx::Postgres: sqlx::types::HasSqlType<#ty>));
|
||||
}
|
||||
let (impl_generics, _, where_clause) = generics.split_for_impl();
|
||||
|
||||
let mut writes: Vec<Block> = Vec::new();
|
||||
for field in fields {
|
||||
let id = &field.ident;
|
||||
let ty = &field.ty;
|
||||
writes.push(parse_quote!({
|
||||
// write oid
|
||||
let info = <sqlx::Postgres as sqlx::types::HasSqlType<#ty>>::type_info();
|
||||
buf.extend(&info.oid().to_be_bytes());
|
||||
|
||||
// write zeros for length
|
||||
buf.extend(&[0; 4]);
|
||||
|
||||
let start = buf.len();
|
||||
sqlx::encode::Encode::<sqlx::Postgres>::encode(&self. #id, buf);
|
||||
let end = buf.len();
|
||||
let size = end - start;
|
||||
|
||||
// replaces zeros with actual length
|
||||
buf[start-4..start].copy_from_slice(&(size as u32).to_be_bytes());
|
||||
}));
|
||||
}
|
||||
|
||||
let mut sizes: Vec<Expr> = Vec::new();
|
||||
for field in fields {
|
||||
let id = &field.ident;
|
||||
let ty = &field.ty;
|
||||
sizes.push(
|
||||
parse_quote!(<#ty as sqlx::encode::Encode<sqlx::Postgres>>::size_hint(&self. #id)),
|
||||
);
|
||||
}
|
||||
|
||||
tts.extend(quote!(
|
||||
impl #impl_generics sqlx::encode::Encode<sqlx::Postgres> for #ident #ty_generics #where_clause {
|
||||
fn encode(&self, buf: &mut std::vec::Vec<u8>) {
|
||||
buf.extend(&(#column_count as u32).to_be_bytes());
|
||||
#(#writes)*
|
||||
}
|
||||
fn size_hint(&self) -> usize {
|
||||
4 + #column_count * (4 + 4) + #(#sizes)+*
|
||||
}
|
||||
}
|
||||
));
|
||||
}
|
||||
|
||||
Ok(tts)
|
||||
}
|
||||
|
||||
pub(crate) fn expand_derive_decode(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
|
||||
let attrs = parse_attributes(&input.attrs)?;
|
||||
match &input.data {
|
||||
Data::Struct(DataStruct {
|
||||
fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }),
|
||||
..
|
||||
}) if unnamed.len() == 1 => {
|
||||
expand_derive_decode_transparent(input, unnamed.first().unwrap())
|
||||
}
|
||||
Data::Enum(DataEnum { variants, .. }) => match attrs.repr {
|
||||
Some(_) => expand_derive_decode_weak_enum(input, variants),
|
||||
None => expand_derive_decode_strong_enum(input, variants),
|
||||
},
|
||||
Data::Struct(DataStruct {
|
||||
fields: Fields::Named(FieldsNamed { named, .. }),
|
||||
..
|
||||
}) => expand_derive_decode_struct(input, named),
|
||||
_ => Err(syn::Error::new_spanned(
|
||||
input,
|
||||
"expected a tuple struct with a single field",
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
fn expand_derive_decode_transparent(
|
||||
input: &DeriveInput,
|
||||
field: &Field,
|
||||
) -> syn::Result<proc_macro2::TokenStream> {
|
||||
check_transparent_attributes(input, field)?;
|
||||
|
||||
let ident = &input.ident;
|
||||
let ty = &field.ty;
|
||||
|
||||
// extract type generics
|
||||
let generics = &input.generics;
|
||||
let (_, ty_generics, _) = generics.split_for_impl();
|
||||
|
||||
// add db type for impl generics & where clause
|
||||
let mut generics = generics.clone();
|
||||
generics.params.insert(0, parse_quote!(DB: sqlx::Database));
|
||||
generics
|
||||
.make_where_clause()
|
||||
.predicates
|
||||
.push(parse_quote!(#ty: sqlx::decode::Decode<DB>));
|
||||
let (impl_generics, _, where_clause) = generics.split_for_impl();
|
||||
|
||||
Ok(quote!(
|
||||
impl #impl_generics sqlx::decode::Decode<DB> for #ident #ty_generics #where_clause {
|
||||
fn decode(raw: &[u8]) -> std::result::Result<Self, sqlx::decode::DecodeError> {
|
||||
<#ty as sqlx::decode::Decode<DB>>::decode(raw).map(Self)
|
||||
}
|
||||
fn decode_null() -> std::result::Result<Self, sqlx::decode::DecodeError> {
|
||||
<#ty as sqlx::decode::Decode<DB>>::decode_null().map(Self)
|
||||
}
|
||||
fn decode_nullable(raw: std::option::Option<&[u8]>) -> std::result::Result<Self, sqlx::decode::DecodeError> {
|
||||
<#ty as sqlx::decode::Decode<DB>>::decode_nullable(raw).map(Self)
|
||||
}
|
||||
}
|
||||
))
|
||||
}
|
||||
|
||||
fn expand_derive_decode_weak_enum(
|
||||
input: &DeriveInput,
|
||||
variants: &Punctuated<Variant, Comma>,
|
||||
) -> syn::Result<proc_macro2::TokenStream> {
|
||||
let repr = check_weak_enum_attributes(input, &variants)?;
|
||||
|
||||
let ident = &input.ident;
|
||||
let arms = variants
|
||||
.iter()
|
||||
.map(|v| {
|
||||
let id = &v.ident;
|
||||
parse_quote!(_ if (#ident :: #id as #repr) == val => Ok(#ident :: #id),)
|
||||
})
|
||||
.collect::<Vec<Arm>>();
|
||||
|
||||
Ok(quote!(
|
||||
impl<DB: sqlx::Database> sqlx::decode::Decode<DB> for #ident where #repr: sqlx::decode::Decode<DB> {
|
||||
fn decode(raw: &[u8]) -> std::result::Result<Self, sqlx::decode::DecodeError> {
|
||||
let val = <#repr as sqlx::decode::Decode<DB>>::decode(raw)?;
|
||||
match val {
|
||||
#(#arms)*
|
||||
_ => Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Invalid value")))
|
||||
}
|
||||
}
|
||||
}
|
||||
))
|
||||
}
|
||||
|
||||
fn expand_derive_decode_strong_enum(
|
||||
input: &DeriveInput,
|
||||
variants: &Punctuated<Variant, Comma>,
|
||||
) -> syn::Result<proc_macro2::TokenStream> {
|
||||
check_strong_enum_attributes(input, &variants)?;
|
||||
|
||||
let ident = &input.ident;
|
||||
|
||||
let mut value_arms = Vec::new();
|
||||
for v in variants {
|
||||
let id = &v.ident;
|
||||
let attributes = parse_attributes(&v.attrs)?;
|
||||
if let Some(rename) = attributes.rename {
|
||||
value_arms.push(quote!(#rename => Ok(#ident :: #id),));
|
||||
} else {
|
||||
let name = id.to_string();
|
||||
value_arms.push(quote!(#name => Ok(#ident :: #id),));
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: prevent heap allocation
|
||||
Ok(quote!(
|
||||
impl<DB: sqlx::Database> sqlx::decode::Decode<DB> for #ident where String: sqlx::decode::Decode<DB> {
|
||||
fn decode(buf: &[u8]) -> std::result::Result<Self, sqlx::decode::DecodeError> {
|
||||
let val = <String as sqlx::decode::Decode<DB>>::decode(buf)?;
|
||||
match val.as_str() {
|
||||
#(#value_arms)*
|
||||
_ => Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Invalid value")))
|
||||
}
|
||||
}
|
||||
}
|
||||
))
|
||||
}
|
||||
|
||||
fn expand_derive_decode_struct(
|
||||
input: &DeriveInput,
|
||||
fields: &Punctuated<Field, Comma>,
|
||||
) -> syn::Result<proc_macro2::TokenStream> {
|
||||
check_struct_attributes(input, fields)?;
|
||||
|
||||
let ident = &input.ident;
|
||||
|
||||
let column_count = fields.len();
|
||||
|
||||
// extract type generics
|
||||
let generics = &input.generics;
|
||||
let (_, ty_generics, _) = generics.split_for_impl();
|
||||
|
||||
// add db type for impl generics & where clause
|
||||
let mut generics = generics.clone();
|
||||
let predicates = &mut generics.make_where_clause().predicates;
|
||||
for field in fields {
|
||||
let ty = &field.ty;
|
||||
predicates.push(parse_quote!(#ty: sqlx::decode::Decode<sqlx::Postgres>));
|
||||
predicates.push(parse_quote!(sqlx::Postgres: sqlx::types::HasSqlType<#ty>));
|
||||
}
|
||||
let (impl_generics, _, where_clause) = generics.split_for_impl();
|
||||
|
||||
let mut reads: Vec<Vec<Stmt>> = Vec::new();
|
||||
let mut names: Vec<Ident> = Vec::new();
|
||||
for field in fields {
|
||||
let id = &field.ident;
|
||||
names.push(id.clone().unwrap());
|
||||
let ty = &field.ty;
|
||||
reads.push(parse_quote!(
|
||||
if buf.len() < 8 {
|
||||
return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Not enough data sent")));
|
||||
}
|
||||
|
||||
let oid = u32::from_be_bytes(std::convert::TryInto::try_into(&buf[0..4]).unwrap());
|
||||
if oid != <sqlx::Postgres as sqlx::types::HasSqlType<#ty>>::type_info().oid() {
|
||||
return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Invalid oid")));
|
||||
}
|
||||
|
||||
let len = u32::from_be_bytes(std::convert::TryInto::try_into(&buf[4..8]).unwrap()) as usize;
|
||||
|
||||
if buf.len() < 8 + len {
|
||||
return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Not enough data sent")));
|
||||
}
|
||||
|
||||
let raw = &buf[8..8+len];
|
||||
let #id = <#ty as sqlx::decode::Decode<sqlx::Postgres>>::decode(raw)?;
|
||||
|
||||
let buf = &buf[8+len..];
|
||||
));
|
||||
}
|
||||
let reads = reads.into_iter().flatten();
|
||||
|
||||
Ok(quote!(
|
||||
impl #impl_generics sqlx::decode::Decode<sqlx::Postgres> for #ident#ty_generics #where_clause {
|
||||
fn decode(buf: &[u8]) -> std::result::Result<Self, sqlx::decode::DecodeError> {
|
||||
if buf.len() < 4 {
|
||||
return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Not enough data sent")));
|
||||
}
|
||||
|
||||
let column_count = u32::from_be_bytes(std::convert::TryInto::try_into(&buf[..4]).unwrap()) as usize;
|
||||
if column_count != #column_count {
|
||||
return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Invalid column count")));
|
||||
}
|
||||
let buf = &buf[4..];
|
||||
|
||||
#(#reads)*
|
||||
|
||||
if !buf.is_empty() {
|
||||
return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new(format!("Too much data sent ({} bytes left)", buf.len()))));
|
||||
}
|
||||
|
||||
Ok(#ident {
|
||||
#(#names),*
|
||||
})
|
||||
}
|
||||
}
|
||||
))
|
||||
}
|
||||
|
||||
pub(crate) fn expand_derive_has_sql_type(
|
||||
input: &DeriveInput,
|
||||
) -> syn::Result<proc_macro2::TokenStream> {
|
||||
let attrs = parse_attributes(&input.attrs)?;
|
||||
match &input.data {
|
||||
Data::Struct(DataStruct {
|
||||
fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }),
|
||||
..
|
||||
}) if unnamed.len() == 1 => {
|
||||
expand_derive_has_sql_type_transparent(input, unnamed.first().unwrap())
|
||||
}
|
||||
Data::Enum(DataEnum { variants, .. }) => match attrs.repr {
|
||||
Some(_) => expand_derive_has_sql_type_weak_enum(input, variants),
|
||||
None => expand_derive_has_sql_type_strong_enum(input, variants),
|
||||
},
|
||||
Data::Struct(DataStruct {
|
||||
fields: Fields::Named(FieldsNamed { named, .. }),
|
||||
..
|
||||
}) => expand_derive_has_sql_type_struct(input, named),
|
||||
_ => Err(syn::Error::new_spanned(
|
||||
input,
|
||||
"expected a tuple struct with a single field",
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
fn expand_derive_has_sql_type_transparent(
|
||||
input: &DeriveInput,
|
||||
field: &Field,
|
||||
) -> syn::Result<proc_macro2::TokenStream> {
|
||||
check_transparent_attributes(input, field)?;
|
||||
|
||||
let ident = &input.ident;
|
||||
let ty = &field.ty;
|
||||
|
||||
// extract type generics
|
||||
let generics = &input.generics;
|
||||
let (impl_generics, ty_generics, _) = generics.split_for_impl();
|
||||
|
||||
// add db type for clause
|
||||
let mut generics = generics.clone();
|
||||
generics
|
||||
.make_where_clause()
|
||||
.predicates
|
||||
.push(parse_quote!(Self: sqlx::types::HasSqlType<#ty>));
|
||||
let (_, _, where_clause) = generics.split_for_impl();
|
||||
|
||||
let mut tts = proc_macro2::TokenStream::new();
|
||||
|
||||
if cfg!(feature = "mysql") {
|
||||
tts.extend(quote!(
|
||||
impl #impl_generics sqlx::types::HasSqlType< #ident #ty_generics > for sqlx::MySql #where_clause {
|
||||
fn type_info() -> Self::TypeInfo {
|
||||
<Self as HasSqlType<#ty>>::type_info()
|
||||
}
|
||||
}
|
||||
));
|
||||
}
|
||||
|
||||
if cfg!(feature = "postgres") {
|
||||
tts.extend(quote!(
|
||||
impl #impl_generics sqlx::types::HasSqlType< #ident #ty_generics > for sqlx::Postgres #where_clause {
|
||||
fn type_info() -> Self::TypeInfo {
|
||||
<Self as HasSqlType<#ty>>::type_info()
|
||||
}
|
||||
}
|
||||
));
|
||||
}
|
||||
|
||||
Ok(tts)
|
||||
}
|
||||
|
||||
fn expand_derive_has_sql_type_weak_enum(
|
||||
input: &DeriveInput,
|
||||
variants: &Punctuated<Variant, Comma>,
|
||||
) -> syn::Result<proc_macro2::TokenStream> {
|
||||
let repr = check_weak_enum_attributes(input, variants)?;
|
||||
|
||||
let ident = &input.ident;
|
||||
let mut tts = proc_macro2::TokenStream::new();
|
||||
|
||||
if cfg!(feature = "mysql") {
|
||||
tts.extend(quote!(
|
||||
impl sqlx::types::HasSqlType< #ident > for sqlx::MySql where Self: sqlx::types::HasSqlType< #repr > {
|
||||
fn type_info() -> Self::TypeInfo {
|
||||
<Self as HasSqlType<#repr>>::type_info()
|
||||
}
|
||||
}
|
||||
));
|
||||
}
|
||||
|
||||
if cfg!(feature = "postgres") {
|
||||
tts.extend(quote!(
|
||||
impl sqlx::types::HasSqlType< #ident > for sqlx::Postgres where Self: sqlx::types::HasSqlType< #repr > {
|
||||
fn type_info() -> Self::TypeInfo {
|
||||
<Self as HasSqlType<#repr>>::type_info()
|
||||
}
|
||||
}
|
||||
));
|
||||
}
|
||||
|
||||
Ok(tts)
|
||||
}
|
||||
|
||||
fn expand_derive_has_sql_type_strong_enum(
|
||||
input: &DeriveInput,
|
||||
variants: &Punctuated<Variant, Comma>,
|
||||
) -> syn::Result<proc_macro2::TokenStream> {
|
||||
let attributes = check_strong_enum_attributes(input, variants)?;
|
||||
|
||||
let ident = &input.ident;
|
||||
let mut tts = proc_macro2::TokenStream::new();
|
||||
|
||||
if cfg!(feature = "mysql") {
|
||||
tts.extend(quote!(
|
||||
impl sqlx::types::HasSqlType< #ident > for sqlx::MySql {
|
||||
fn type_info() -> Self::TypeInfo {
|
||||
sqlx::mysql::MySqlTypeInfo::r#enum()
|
||||
}
|
||||
}
|
||||
));
|
||||
}
|
||||
|
||||
if cfg!(feature = "postgres") {
|
||||
let oid = attributes.postgres_oid.unwrap();
|
||||
tts.extend(quote!(
|
||||
impl sqlx::types::HasSqlType< #ident > for sqlx::Postgres {
|
||||
fn type_info() -> Self::TypeInfo {
|
||||
sqlx::postgres::PgTypeInfo::with_oid(#oid)
|
||||
}
|
||||
}
|
||||
));
|
||||
}
|
||||
|
||||
Ok(tts)
|
||||
}
|
||||
|
||||
fn expand_derive_has_sql_type_struct(
|
||||
input: &DeriveInput,
|
||||
fields: &Punctuated<Field, Comma>,
|
||||
) -> syn::Result<proc_macro2::TokenStream> {
|
||||
let attributes = check_struct_attributes(input, fields)?;
|
||||
|
||||
let ident = &input.ident;
|
||||
let mut tts = proc_macro2::TokenStream::new();
|
||||
|
||||
if cfg!(feature = "postgres") {
|
||||
let oid = attributes.postgres_oid.unwrap();
|
||||
tts.extend(quote!(
|
||||
impl sqlx::types::HasSqlType< #ident > for sqlx::Postgres {
|
||||
fn type_info() -> Self::TypeInfo {
|
||||
sqlx::postgres::PgTypeInfo::with_oid(#oid)
|
||||
}
|
||||
}
|
||||
));
|
||||
}
|
||||
|
||||
Ok(tts)
|
||||
}
|
||||
|
||||
pub(crate) fn expand_derive_type(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
|
||||
let encode_tts = expand_derive_encode(input)?;
|
||||
let decode_tts = expand_derive_decode(input)?;
|
||||
let has_sql_type_tts = expand_derive_has_sql_type(input)?;
|
||||
|
||||
let combined = proc_macro2::TokenStream::from_iter(
|
||||
encode_tts
|
||||
.into_iter()
|
||||
.chain(decode_tts)
|
||||
.chain(has_sql_type_tts),
|
||||
);
|
||||
Ok(combined)
|
||||
}
|
||||
|
@ -151,19 +151,37 @@ pub fn query_file_as(input: TokenStream) -> TokenStream {
|
||||
async_macro!(db, input: QueryAsMacroInput => expand_query_file_as(input, db))
|
||||
}
|
||||
|
||||
#[proc_macro_derive(Encode)]
|
||||
#[proc_macro_derive(Encode, attributes(sqlx))]
|
||||
pub fn derive_encode(tokenstream: TokenStream) -> TokenStream {
|
||||
let input = syn::parse_macro_input!(tokenstream as syn::DeriveInput);
|
||||
match derives::expand_derive_encode(input) {
|
||||
match derives::expand_derive_encode(&input) {
|
||||
Ok(ts) => ts.into(),
|
||||
Err(e) => e.to_compile_error().into(),
|
||||
}
|
||||
}
|
||||
|
||||
#[proc_macro_derive(Decode)]
|
||||
#[proc_macro_derive(Decode, attributes(sqlx))]
|
||||
pub fn derive_decode(tokenstream: TokenStream) -> TokenStream {
|
||||
let input = syn::parse_macro_input!(tokenstream as syn::DeriveInput);
|
||||
match derives::expand_derive_decode(input) {
|
||||
match derives::expand_derive_decode(&input) {
|
||||
Ok(ts) => ts.into(),
|
||||
Err(e) => e.to_compile_error().into(),
|
||||
}
|
||||
}
|
||||
|
||||
#[proc_macro_derive(HasSqlType, attributes(sqlx))]
|
||||
pub fn derive_has_sql_type(tokenstream: TokenStream) -> TokenStream {
|
||||
let input = syn::parse_macro_input!(tokenstream as syn::DeriveInput);
|
||||
match derives::expand_derive_has_sql_type(&input) {
|
||||
Ok(ts) => ts.into(),
|
||||
Err(e) => e.to_compile_error().into(),
|
||||
}
|
||||
}
|
||||
|
||||
#[proc_macro_derive(Type, attributes(sqlx))]
|
||||
pub fn derive_type(tokenstream: TokenStream) -> TokenStream {
|
||||
let input = syn::parse_macro_input!(tokenstream as syn::DeriveInput);
|
||||
match derives::expand_derive_type(&input) {
|
||||
Ok(ts) => ts.into(),
|
||||
Err(e) => e.to_compile_error().into(),
|
||||
}
|
||||
|
@ -40,6 +40,9 @@ pub use sqlx_core::sqlite::{self, Sqlite, SqliteConnection, SqlitePool};
|
||||
#[doc(hidden)]
|
||||
pub extern crate sqlx_macros;
|
||||
|
||||
#[cfg(feature = "macros")]
|
||||
pub use sqlx_macros::Type;
|
||||
|
||||
#[cfg(feature = "macros")]
|
||||
mod macros;
|
||||
|
||||
|
6
src/types.rs
Normal file
6
src/types.rs
Normal file
@ -0,0 +1,6 @@
|
||||
//! Traits linking Rust types to SQL types.
|
||||
|
||||
pub use sqlx_core::types::*;
|
||||
|
||||
#[cfg(feature = "macros")]
|
||||
pub use sqlx_macros::HasSqlType;
|
328
tests/derives.rs
328
tests/derives.rs
@ -1,77 +1,311 @@
|
||||
use sqlx::decode::Decode;
|
||||
use sqlx::encode::Encode;
|
||||
use sqlx::types::{HasSqlType, TypeInfo};
|
||||
use std::fmt::Debug;
|
||||
|
||||
#[derive(PartialEq, Debug, Encode, Decode)]
|
||||
struct Foo(i32);
|
||||
#[derive(PartialEq, Debug, Encode, Decode, HasSqlType)]
|
||||
#[sqlx(transparent)]
|
||||
struct Transparent(i32);
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "postgres")]
|
||||
fn encode_with_postgres() {
|
||||
use sqlx_core::postgres::Postgres;
|
||||
#[derive(PartialEq, Debug, Clone, Copy, Encode, Decode, HasSqlType)]
|
||||
#[repr(i32)]
|
||||
#[allow(dead_code)]
|
||||
enum Weak {
|
||||
One,
|
||||
Two,
|
||||
Three,
|
||||
}
|
||||
|
||||
let example = Foo(0x1122_3344);
|
||||
#[derive(PartialEq, Debug, Encode, Decode, HasSqlType)]
|
||||
#[sqlx(postgres(oid = 10101010))]
|
||||
#[allow(dead_code)]
|
||||
enum Strong {
|
||||
One,
|
||||
Two,
|
||||
#[sqlx(rename = "four")]
|
||||
Three,
|
||||
}
|
||||
|
||||
let mut encoded = Vec::new();
|
||||
let mut encoded_orig = Vec::new();
|
||||
|
||||
Encode::<Postgres>::encode(&example, &mut encoded);
|
||||
Encode::<Postgres>::encode(&example.0, &mut encoded_orig);
|
||||
|
||||
assert_eq!(encoded, encoded_orig);
|
||||
#[derive(PartialEq, Debug, Encode, Decode, HasSqlType)]
|
||||
#[sqlx(postgres(oid = 20202020))]
|
||||
#[allow(dead_code)]
|
||||
struct Struct {
|
||||
field1: String,
|
||||
field2: i64,
|
||||
field3: bool,
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "mysql")]
|
||||
fn encode_with_mysql() {
|
||||
use sqlx_core::mysql::MySql;
|
||||
|
||||
let example = Foo(0x1122_3344);
|
||||
|
||||
let mut encoded = Vec::new();
|
||||
let mut encoded_orig = Vec::new();
|
||||
|
||||
Encode::<MySql>::encode(&example, &mut encoded);
|
||||
Encode::<MySql>::encode(&example.0, &mut encoded_orig);
|
||||
|
||||
assert_eq!(encoded, encoded_orig);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "mysql")]
|
||||
fn decode_mysql() {
|
||||
decode_with_db();
|
||||
fn encode_transparent_mysql() {
|
||||
encode_transparent::<sqlx::MySql>();
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "postgres")]
|
||||
fn decode_postgres() {
|
||||
decode_with_db();
|
||||
fn encode_transparent_postgres() {
|
||||
encode_transparent::<sqlx::Postgres>();
|
||||
}
|
||||
|
||||
#[cfg(feature = "postgres")]
|
||||
fn decode_with_db()
|
||||
#[allow(dead_code)]
|
||||
fn encode_transparent<DB: sqlx::Database>()
|
||||
where
|
||||
Foo: for<'de> Decode<'de, sqlx::Postgres> + Encode<sqlx::Postgres>,
|
||||
Transparent: Encode<DB>,
|
||||
i32: Encode<DB>,
|
||||
{
|
||||
let example = Foo(0x1122_3344);
|
||||
let example = Transparent(0x1122_3344);
|
||||
|
||||
let mut encoded = Vec::new();
|
||||
let mut encoded_orig = Vec::new();
|
||||
|
||||
Encode::<DB>::encode(&example, &mut encoded);
|
||||
Encode::<DB>::encode(&example.0, &mut encoded_orig);
|
||||
|
||||
assert_eq!(encoded, encoded_orig);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "mysql")]
|
||||
fn encode_weak_enum_mysql() {
|
||||
encode_weak_enum::<sqlx::MySql>();
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "postgres")]
|
||||
fn encode_weak_enum_postgres() {
|
||||
encode_weak_enum::<sqlx::Postgres>();
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn encode_weak_enum<DB: sqlx::Database>()
|
||||
where
|
||||
Weak: Encode<DB>,
|
||||
i32: Encode<DB>,
|
||||
{
|
||||
for example in [Weak::One, Weak::Two, Weak::Three].iter() {
|
||||
let mut encoded = Vec::new();
|
||||
let mut encoded_orig = Vec::new();
|
||||
|
||||
Encode::<DB>::encode(example, &mut encoded);
|
||||
Encode::<DB>::encode(&(*example as i32), &mut encoded_orig);
|
||||
|
||||
assert_eq!(encoded, encoded_orig);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "mysql")]
|
||||
fn encode_strong_enum_mysql() {
|
||||
encode_strong_enum::<sqlx::MySql>();
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "postgres")]
|
||||
fn encode_strong_enum_postgres() {
|
||||
encode_strong_enum::<sqlx::Postgres>();
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn encode_strong_enum<DB: sqlx::Database>()
|
||||
where
|
||||
Strong: Encode<DB>,
|
||||
str: Encode<DB>,
|
||||
{
|
||||
for (example, name) in [
|
||||
(Strong::One, "One"),
|
||||
(Strong::Two, "Two"),
|
||||
(Strong::Three, "four"),
|
||||
]
|
||||
.iter()
|
||||
{
|
||||
let mut encoded = Vec::new();
|
||||
let mut encoded_orig = Vec::new();
|
||||
|
||||
Encode::<sqlx::MySql>::encode(example, &mut encoded);
|
||||
Encode::<sqlx::MySql>::encode(name, &mut encoded_orig);
|
||||
|
||||
assert_eq!(encoded, encoded_orig);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "postgres")]
|
||||
fn encode_struct_postgres() {
|
||||
let field1 = "Foo".to_string();
|
||||
let field2 = 3;
|
||||
let field3 = false;
|
||||
|
||||
let example = Struct {
|
||||
field1: field1.clone(),
|
||||
field2,
|
||||
field3,
|
||||
};
|
||||
|
||||
let mut encoded = Vec::new();
|
||||
Encode::<sqlx::Postgres>::encode(&example, &mut encoded);
|
||||
|
||||
let decoded = Foo::decode(Some(sqlx::postgres::PgValue::Binary(&encoded))).unwrap();
|
||||
assert_eq!(example, decoded);
|
||||
let string_oid = <sqlx::Postgres as HasSqlType<String>>::type_info().oid();
|
||||
let i64_oid = <sqlx::Postgres as HasSqlType<i64>>::type_info().oid();
|
||||
let bool_oid = <sqlx::Postgres as HasSqlType<bool>>::type_info().oid();
|
||||
|
||||
// 3 columns
|
||||
assert_eq!(&[0, 0, 0, 3], &encoded[..4]);
|
||||
let encoded = &encoded[4..];
|
||||
|
||||
// check field1 (string)
|
||||
assert_eq!(&string_oid.to_be_bytes(), &encoded[0..4]);
|
||||
assert_eq!(&(field1.len() as u32).to_be_bytes(), &encoded[4..8]);
|
||||
assert_eq!(field1.as_bytes(), &encoded[8..8 + field1.len()]);
|
||||
let encoded = &encoded[8 + field1.len()..];
|
||||
|
||||
// check field2 (i64)
|
||||
assert_eq!(&i64_oid.to_be_bytes(), &encoded[0..4]);
|
||||
assert_eq!(&8u32.to_be_bytes(), &encoded[4..8]);
|
||||
assert_eq!(field2.to_be_bytes(), &encoded[8..16]);
|
||||
let encoded = &encoded[16..];
|
||||
|
||||
// check field3 (bool)
|
||||
assert_eq!(&bool_oid.to_be_bytes(), &encoded[0..4]);
|
||||
assert_eq!(&1u32.to_be_bytes(), &encoded[4..8]);
|
||||
assert_eq!(field3, encoded[8] != 0);
|
||||
let encoded = &encoded[9..];
|
||||
|
||||
assert!(encoded.is_empty());
|
||||
|
||||
let string_size = <String as Encode<sqlx::Postgres>>::size_hint(&field1);
|
||||
let i64_size = <i64 as Encode<sqlx::Postgres>>::size_hint(&field2);
|
||||
let bool_size = <bool as Encode<sqlx::Postgres>>::size_hint(&field3);
|
||||
|
||||
assert_eq!(
|
||||
4 + 3 * (4 + 4) + string_size + i64_size + bool_size,
|
||||
example.size_hint()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "mysql")]
|
||||
fn decode_with_db()
|
||||
where
|
||||
Foo: for<'de> Decode<'de, sqlx::MySql> + Encode<sqlx::MySql>,
|
||||
{
|
||||
let example = Foo(0x1122_3344);
|
||||
fn decode_transparent_mysql() {
|
||||
decode_with_db::<sqlx::MySql, Transparent>(Transparent(0x1122_3344));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "postgres")]
|
||||
fn decode_transparent_postgres() {
|
||||
decode_with_db::<sqlx::Postgres, Transparent>(Transparent(0x1122_3344));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "mysql")]
|
||||
fn decode_weak_enum_mysql() {
|
||||
decode_with_db::<sqlx::MySql, Weak>(Weak::One);
|
||||
decode_with_db::<sqlx::MySql, Weak>(Weak::Two);
|
||||
decode_with_db::<sqlx::MySql, Weak>(Weak::Three);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "postgres")]
|
||||
fn decode_weak_enum_postgres() {
|
||||
decode_with_db::<sqlx::Postgres, Weak>(Weak::One);
|
||||
decode_with_db::<sqlx::Postgres, Weak>(Weak::Two);
|
||||
decode_with_db::<sqlx::Postgres, Weak>(Weak::Three);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "mysql")]
|
||||
fn decode_strong_enum_mysql() {
|
||||
decode_with_db::<sqlx::MySql, Strong>(Strong::One);
|
||||
decode_with_db::<sqlx::MySql, Strong>(Strong::Two);
|
||||
decode_with_db::<sqlx::MySql, Strong>(Strong::Three);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "postgres")]
|
||||
fn decode_strong_enum_postgres() {
|
||||
decode_with_db::<sqlx::Postgres, Strong>(Strong::One);
|
||||
decode_with_db::<sqlx::Postgres, Strong>(Strong::Two);
|
||||
decode_with_db::<sqlx::Postgres, Strong>(Strong::Three);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "postgres")]
|
||||
fn decode_struct_postgres() {
|
||||
decode_with_db::<sqlx::Postgres, Struct>(Struct {
|
||||
field1: "Foo".to_string(),
|
||||
field2: 3,
|
||||
field3: true,
|
||||
});
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn decode_with_db<DB: sqlx::Database, V: Decode<DB> + Encode<DB> + PartialEq + Debug>(example: V) {
|
||||
let mut encoded = Vec::new();
|
||||
Encode::<sqlx::MySql>::encode(&example, &mut encoded);
|
||||
Encode::<DB>::encode(&example, &mut encoded);
|
||||
|
||||
let decoded = Foo::decode(Some(sqlx::mysql::MySqlValue::Binary(&encoded))).unwrap();
|
||||
let decoded = V::decode(&encoded).unwrap();
|
||||
assert_eq!(example, decoded);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "mysql")]
|
||||
fn has_sql_type_transparent_mysql() {
|
||||
has_sql_type_transparent::<sqlx::MySql>();
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "postgres")]
|
||||
fn has_sql_type_transparent_postgres() {
|
||||
has_sql_type_transparent::<sqlx::Postgres>();
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn has_sql_type_transparent<DB: sqlx::Database>()
|
||||
where
|
||||
DB: HasSqlType<Transparent> + HasSqlType<i32>,
|
||||
{
|
||||
let info: DB::TypeInfo = <DB as HasSqlType<Transparent>>::type_info();
|
||||
let info_orig: DB::TypeInfo = <DB as HasSqlType<i32>>::type_info();
|
||||
assert!(info.compatible(&info_orig));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "mysql")]
|
||||
fn has_sql_type_weak_enum_mysql() {
|
||||
has_sql_type_weak_enum::<sqlx::MySql>();
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "postgres")]
|
||||
fn has_sql_type_weak_enum_postgres() {
|
||||
has_sql_type_weak_enum::<sqlx::Postgres>();
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn has_sql_type_weak_enum<DB: sqlx::Database>()
|
||||
where
|
||||
DB: HasSqlType<Weak> + HasSqlType<i32>,
|
||||
{
|
||||
let info: DB::TypeInfo = <DB as HasSqlType<Weak>>::type_info();
|
||||
let info_orig: DB::TypeInfo = <DB as HasSqlType<i32>>::type_info();
|
||||
assert!(info.compatible(&info_orig));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "mysql")]
|
||||
fn has_sql_type_strong_enum_mysql() {
|
||||
let info: sqlx::mysql::MySqlTypeInfo = <sqlx::MySql as HasSqlType<Strong>>::type_info();
|
||||
assert!(info.compatible(&sqlx::mysql::MySqlTypeInfo::r#enum()))
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "postgres")]
|
||||
fn has_sql_type_strong_enum_postgres() {
|
||||
let info: sqlx::postgres::PgTypeInfo = <sqlx::Postgres as HasSqlType<Strong>>::type_info();
|
||||
assert!(info.compatible(&sqlx::postgres::PgTypeInfo::with_oid(10101010)))
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "postgres")]
|
||||
fn has_sql_type_struct_postgres() {
|
||||
let info: sqlx::postgres::PgTypeInfo = <sqlx::Postgres as HasSqlType<Struct>>::type_info();
|
||||
assert!(info.compatible(&sqlx::postgres::PgTypeInfo::with_oid(20202020)))
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user